update predict
This commit is contained in:
parent
d3971ef267
commit
7bc3e4effe
35
predict.py
35
predict.py
|
@ -34,7 +34,7 @@ def _preprocess_data(data, cfg):
|
|||
return data, rels
|
||||
|
||||
|
||||
def _get_predict_instance():
|
||||
def _get_predict_instance(cfg):
|
||||
flag = input('是否使用范例[y/n],退出请输入: exit .... ')
|
||||
flag = flag.strip().lower()
|
||||
if flag == 'y' or flag == 'yes':
|
||||
|
@ -46,9 +46,9 @@ def _get_predict_instance():
|
|||
elif flag == 'n' or flag == 'no':
|
||||
sentence = input('请输入句子:')
|
||||
head = input('请输入句中需要预测关系的头实体:')
|
||||
head_type = input('请输入头实体类型:')
|
||||
head_type = input('请输入头实体类型(可以为空,按enter跳过):')
|
||||
tail = input('请输入句中需要预测关系的尾实体:')
|
||||
tail_type = input('请输入尾实体类型:')
|
||||
tail_type = input('请输入尾实体类型(可以为空,按enter跳过):')
|
||||
elif flag == 'exit':
|
||||
sys.exit(0)
|
||||
else:
|
||||
|
@ -58,9 +58,14 @@ def _get_predict_instance():
|
|||
instance = dict()
|
||||
instance['sentence'] = sentence.strip()
|
||||
instance['head'] = head.strip()
|
||||
instance['head_type'] = head_type.strip()
|
||||
instance['tail'] = tail.strip()
|
||||
instance['tail_type'] = tail_type.strip()
|
||||
if head_type.strip() == '' or tail_type.strip() == '':
|
||||
cfg.replace_entity_with_type = False
|
||||
instance['head_type'] = 'None'
|
||||
instance['tail_type'] = 'None'
|
||||
else:
|
||||
instance['head_type'] = head_type.strip()
|
||||
instance['tail_type'] = tail_type.strip()
|
||||
|
||||
return instance
|
||||
|
||||
|
@ -68,15 +73,16 @@ def _get_predict_instance():
|
|||
# 自定义模型存储的路径
|
||||
fp = 'xxx/checkpoints/2019-12-03_17-35-30/cnn_epoch21.pth'
|
||||
|
||||
|
||||
@hydra.main(config_path='conf/config.yaml')
|
||||
def main(cfg):
|
||||
cwd = utils.get_original_cwd()
|
||||
cfg.cwd = cwd
|
||||
cfg.pos_size = 2 * cfg.pos_limit + 2
|
||||
# print(cfg.pretty())
|
||||
print(cfg.pretty())
|
||||
|
||||
# get predict instance
|
||||
instance = _get_predict_instance()
|
||||
instance = _get_predict_instance(cfg)
|
||||
data = [instance]
|
||||
|
||||
# preprocess data
|
||||
|
@ -85,6 +91,11 @@ def main(cfg):
|
|||
# model
|
||||
__Model__ = {
|
||||
'cnn': models.PCNN,
|
||||
'rnn': models.BiLSTM,
|
||||
'transformer': models.Transformer,
|
||||
'gcn': models.GCN,
|
||||
'capsule': models.Capsule,
|
||||
'lm': models.LM,
|
||||
}
|
||||
|
||||
# 最好在 cpu 上预测
|
||||
|
@ -96,18 +107,20 @@ def main(cfg):
|
|||
logger.info(f'device: {device}')
|
||||
|
||||
model = __Model__[cfg.model_name](cfg)
|
||||
logger.info(f'model name: {cfg.model_name}')
|
||||
logger.info(f'\n {model}')
|
||||
model.load(fp, device=device)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
logger.info(f'model name: {cfg.model_name}')
|
||||
logger.info(f'\n {model}')
|
||||
|
||||
x = dict()
|
||||
x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])
|
||||
if cfg.model_name != 'lm':
|
||||
x['head_pos'], x['tail_pos'] = torch.tensor([data[0]['head_pos']]), torch.tensor([data[0]['tail_pos']])
|
||||
if cfg.use_pcnn:
|
||||
x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
|
||||
if cfg.model_name == 'cnn':
|
||||
if cfg.use_pcnn:
|
||||
x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
|
||||
|
||||
for key in x.keys():
|
||||
x[key] = x[key].to(device)
|
||||
|
||||
|
|
Loading…
Reference in New Issue