update predict

This commit is contained in:
leo 2019-12-04 16:01:48 +08:00
parent d3971ef267
commit 7bc3e4effe
1 changed files with 24 additions and 11 deletions

View File

@ -34,7 +34,7 @@ def _preprocess_data(data, cfg):
return data, rels
def _get_predict_instance():
def _get_predict_instance(cfg):
flag = input('是否使用范例[y/n],退出请输入: exit .... ')
flag = flag.strip().lower()
if flag == 'y' or flag == 'yes':
@ -46,9 +46,9 @@ def _get_predict_instance():
elif flag == 'n' or flag == 'no':
sentence = input('请输入句子:')
head = input('请输入句中需要预测关系的头实体:')
head_type = input('请输入头实体类型')
head_type = input('请输入头实体类型可以为空按enter跳过')
tail = input('请输入句中需要预测关系的尾实体:')
tail_type = input('请输入尾实体类型')
tail_type = input('请输入尾实体类型可以为空按enter跳过')
elif flag == 'exit':
sys.exit(0)
else:
@ -58,9 +58,14 @@ def _get_predict_instance():
instance = dict()
instance['sentence'] = sentence.strip()
instance['head'] = head.strip()
instance['head_type'] = head_type.strip()
instance['tail'] = tail.strip()
instance['tail_type'] = tail_type.strip()
if head_type.strip() == '' or tail_type.strip() == '':
cfg.replace_entity_with_type = False
instance['head_type'] = 'None'
instance['tail_type'] = 'None'
else:
instance['head_type'] = head_type.strip()
instance['tail_type'] = tail_type.strip()
return instance
@ -68,15 +73,16 @@ def _get_predict_instance():
# 自定义模型存储的路径
fp = 'xxx/checkpoints/2019-12-03_17-35-30/cnn_epoch21.pth'
@hydra.main(config_path='conf/config.yaml')
def main(cfg):
cwd = utils.get_original_cwd()
cfg.cwd = cwd
cfg.pos_size = 2 * cfg.pos_limit + 2
# print(cfg.pretty())
print(cfg.pretty())
# get predict instance
instance = _get_predict_instance()
instance = _get_predict_instance(cfg)
data = [instance]
# preprocess data
@ -85,6 +91,11 @@ def main(cfg):
# model
__Model__ = {
'cnn': models.PCNN,
'rnn': models.BiLSTM,
'transformer': models.Transformer,
'gcn': models.GCN,
'capsule': models.Capsule,
'lm': models.LM,
}
# 最好在 cpu 上预测
@ -96,18 +107,20 @@ def main(cfg):
logger.info(f'device: {device}')
model = __Model__[cfg.model_name](cfg)
logger.info(f'model name: {cfg.model_name}')
logger.info(f'\n {model}')
model.load(fp, device=device)
model.to(device)
model.eval()
logger.info(f'model name: {cfg.model_name}')
logger.info(f'\n {model}')
x = dict()
x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])
if cfg.model_name != 'lm':
x['head_pos'], x['tail_pos'] = torch.tensor([data[0]['head_pos']]), torch.tensor([data[0]['tail_pos']])
if cfg.use_pcnn:
x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
if cfg.model_name == 'cnn':
if cfg.use_pcnn:
x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
for key in x.keys():
x[key] = x[key].to(device)