commit a38dbb8df8 (parent e03a8fbeab)
Author: tlk-dsg
Date:   2021-09-23 12:40:14 +08:00

2 changed files with 39 additions and 30 deletions

@@ -5,7 +5,7 @@ import logging
 import hydra
 from hydra import utils
 from deepke.attribution_extraction.standard.tools import Serializer
-from deepke.attribution_extraction.standard.tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data
+from deepke.attribution_extraction.standard.tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data, _lm_serialize
 import matplotlib.pyplot as plt
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
 from deepke.attribution_extraction.standard.utils import load_pkl, load_csv
@@ -18,23 +18,26 @@ logger = logging.getLogger(__name__)
 def _preprocess_data(data, cfg):
-    vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
     attribute_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'attribute.csv'), verbose=False)
     atts = _handle_attribute_data(attribute_data)
-    cfg.vocab_size = vocab.count
-    serializer = Serializer(do_chinese_split=cfg.chinese_split)
-    serial = serializer.serialize
+    if cfg.model_name != 'lm':
+        vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
+        cfg.vocab_size = vocab.count
+        serializer = Serializer(do_chinese_split=cfg.chinese_split)
+        serial = serializer.serialize
+        _serialize_sentence(data, serial, cfg)
+        _convert_tokens_into_index(data, vocab)
+        _add_pos_seq(data, cfg)
+        logger.info('start sentence preprocess...')
+        formats = '\nsentence: {}\nchinese_split: {}\n' \
+                  'tokens: {}\ntoken2idx: {}\nlength: {}\nentity_index: {}\nattribute_value_index: {}'
+        logger.info(
+            formats.format(data[0]['sentence'], cfg.chinese_split,
+                           data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
+                           data[0]['entity_index'], data[0]['attribute_value_index']))
-    _serialize_sentence(data, serial, cfg)
-    _convert_tokens_into_index(data, vocab)
-    _add_pos_seq(data, cfg)
-    logger.info('start sentence preprocess...')
-    formats = '\nsentence: {}\nchinese_split: {}\n' \
-              'tokens: {}\ntoken2idx: {}\nlength: {}\nentity_index: {}\nattribute_value_index: {}'
-    logger.info(
-        formats.format(data[0]['sentence'], cfg.chinese_split,
-                       data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
-                       data[0]['entity_index'], data[0]['attribute_value_index']))
+    else:
+        _lm_serialize(data, cfg)
     return data, atts
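
The change is the same in both files touched by this commit: the vocab-driven pipeline (Serializer, vocab.pkl lookup, position-sequence features) now runs only for non-LM models, while a pretrained language model gets its input through _lm_serialize, since its tokenizer carries its own vocabulary and makes cfg.vocab_size irrelevant. A minimal sketch of what such a helper can look like, assuming a HuggingFace BERT checkpoint at a hypothetical cfg.lm_file and entity/attribute_value fields suggested by the log output above (the actual DeepKE helper may differ):

from transformers import BertTokenizer

def _lm_serialize_sketch(data, cfg):
    # Hedged sketch: the pretrained tokenizer replaces vocab.pkl entirely,
    # so neither cfg.vocab_size nor the position features are produced.
    tokenizer = BertTokenizer.from_pretrained(cfg.lm_file)  # assumed checkpoint path
    for d in data:
        # Append the entity pair after [SEP] markers so the LM sees it
        # explicitly; field names 'entity' / 'attribute_value' are assumptions.
        sent = d['sentence'] + '[SEP]' + d['entity'] + '[SEP]' + d['attribute_value']
        d['token2idx'] = tokenizer.encode(sent, add_special_tokens=True)
        d['seq_len'] = len(d['token2idx'])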

@@ -5,7 +5,7 @@ import logging
 import hydra
 from hydra import utils
 from deepke.relation_extraction.standard.tools import Serializer
-from deepke.relation_extraction.standard.tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
+from deepke.relation_extraction.standard.tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data, _lm_serialize
 import matplotlib.pyplot as plt
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
 from deepke.relation_extraction.standard.utils import load_pkl, load_csv
@@ -16,23 +16,29 @@ logger = logging.getLogger(__name__)
 def _preprocess_data(data, cfg):
-    vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
     relation_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'relation.csv'), verbose=False)
     rels = _handle_relation_data(relation_data)
-    cfg.vocab_size = vocab.count
-    serializer = Serializer(do_chinese_split=cfg.chinese_split)
-    serial = serializer.serialize
-    _serialize_sentence(data, serial, cfg)
-    _convert_tokens_into_index(data, vocab)
-    _add_pos_seq(data, cfg)
-    logger.info('start sentence preprocess...')
-    formats = '\nsentence: {}\nchinese_split: {}\nreplace_entity_with_type: {}\nreplace_entity_with_scope: {}\n' \
-              'tokens: {}\ntoken2idx: {}\nlength: {}\nhead_idx: {}\ntail_idx: {}'
-    logger.info(
-        formats.format(data[0]['sentence'], cfg.chinese_split, cfg.replace_entity_with_type,
-                       cfg.replace_entity_with_scope, data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
-                       data[0]['head_idx'], data[0]['tail_idx']))
+    if cfg.model_name != 'lm':
+        vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
+        cfg.vocab_size = vocab.count
+        serializer = Serializer(do_chinese_split=cfg.chinese_split)
+        serial = serializer.serialize
+        _serialize_sentence(data, serial, cfg)
+        _convert_tokens_into_index(data, vocab)
+        _add_pos_seq(data, cfg)
+        logger.info('start sentence preprocess...')
+        formats = '\nsentence: {}\nchinese_split: {}\nreplace_entity_with_type: {}\nreplace_entity_with_scope: {}\n' \
+                  'tokens: {}\ntoken2idx: {}\nlength: {}\nhead_idx: {}\ntail_idx: {}'
+        logger.info(
+            formats.format(data[0]['sentence'], cfg.chinese_split, cfg.replace_entity_with_type,
+                           cfg.replace_entity_with_scope, data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
+                           data[0]['head_idx'], data[0]['tail_idx']))
+    else:
+        _lm_serialize(data, cfg)
     return data, rels
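
The relation-extraction file mirrors the attribution one, down to cfg.vocab_size never being set on the LM path. The only per-item features that path produces are token2idx and seq_len, so batching for prediction reduces to padding the id lists. A minimal sketch of that downstream step, assuming pad id 0 as in BERT-style tokenizers:

import torch

def collate_lm(batch):
    # Pad variable-length token id lists to the batch maximum.
    max_len = max(d['seq_len'] for d in batch)
    ids = torch.zeros(len(batch), max_len, dtype=torch.long)  # 0 == assumed [PAD] id
    for i, d in enumerate(batch):
        ids[i, :d['seq_len']] = torch.tensor(d['token2idx'])
    return ids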