test
This commit is contained in:
parent
e03a8fbeab
commit
a38dbb8df8
|
@ -5,7 +5,7 @@ import logging
|
|||
import hydra
|
||||
from hydra import utils
|
||||
from deepke.attribution_extraction.standard.tools import Serializer
|
||||
from deepke.attribution_extraction.standard.tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data
|
||||
from deepke.attribution_extraction.standard.tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data , _lm_serialize
|
||||
import matplotlib.pyplot as plt
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
|
||||
from deepke.attribution_extraction.standard.utils import load_pkl, load_csv
|
||||
|
@ -18,9 +18,10 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def _preprocess_data(data, cfg):
|
||||
vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
|
||||
attribute_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'attribute.csv'), verbose=False)
|
||||
atts = _handle_attribute_data(attribute_data)
|
||||
if cfg.model_name != 'lm':
|
||||
vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
|
||||
cfg.vocab_size = vocab.count
|
||||
serializer = Serializer(do_chinese_split=cfg.chinese_split)
|
||||
serial = serializer.serialize
|
||||
|
@ -35,6 +36,8 @@ def _preprocess_data(data, cfg):
|
|||
formats.format(data[0]['sentence'], cfg.chinese_split,
|
||||
data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
|
||||
data[0]['entity_index'], data[0]['attribute_value_index']))
|
||||
else:
|
||||
_lm_serialize(data,cfg)
|
||||
return data, atts
|
||||
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import logging
|
|||
import hydra
|
||||
from hydra import utils
|
||||
from deepke.relation_extraction.standard.tools import Serializer
|
||||
from deepke.relation_extraction.standard.tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
|
||||
from deepke.relation_extraction.standard.tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data , _lm_serialize
|
||||
import matplotlib.pyplot as plt
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
|
||||
from deepke.relation_extraction.standard.utils import load_pkl, load_csv
|
||||
|
@ -16,9 +16,12 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def _preprocess_data(data, cfg):
|
||||
vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
|
||||
|
||||
relation_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'relation.csv'), verbose=False)
|
||||
rels = _handle_relation_data(relation_data)
|
||||
|
||||
if cfg.model_name != 'lm':
|
||||
vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
|
||||
cfg.vocab_size = vocab.count
|
||||
serializer = Serializer(do_chinese_split=cfg.chinese_split)
|
||||
serial = serializer.serialize
|
||||
|
@ -33,6 +36,9 @@ def _preprocess_data(data, cfg):
|
|||
formats.format(data[0]['sentence'], cfg.chinese_split, cfg.replace_entity_with_type,
|
||||
cfg.replace_entity_with_scope, data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
|
||||
data[0]['head_idx'], data[0]['tail_idx']))
|
||||
else:
|
||||
_lm_serialize(data,cfg)
|
||||
|
||||
return data, rels
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue