update predict
This commit is contained in:
parent
d3971ef267
commit
7bc3e4effe
35
predict.py
35
predict.py
|
@ -34,7 +34,7 @@ def _preprocess_data(data, cfg):
|
||||||
return data, rels
|
return data, rels
|
||||||
|
|
||||||
|
|
||||||
def _get_predict_instance():
|
def _get_predict_instance(cfg):
|
||||||
flag = input('是否使用范例[y/n],退出请输入: exit .... ')
|
flag = input('是否使用范例[y/n],退出请输入: exit .... ')
|
||||||
flag = flag.strip().lower()
|
flag = flag.strip().lower()
|
||||||
if flag == 'y' or flag == 'yes':
|
if flag == 'y' or flag == 'yes':
|
||||||
|
@ -46,9 +46,9 @@ def _get_predict_instance():
|
||||||
elif flag == 'n' or flag == 'no':
|
elif flag == 'n' or flag == 'no':
|
||||||
sentence = input('请输入句子:')
|
sentence = input('请输入句子:')
|
||||||
head = input('请输入句中需要预测关系的头实体:')
|
head = input('请输入句中需要预测关系的头实体:')
|
||||||
head_type = input('请输入头实体类型:')
|
head_type = input('请输入头实体类型(可以为空,按enter跳过):')
|
||||||
tail = input('请输入句中需要预测关系的尾实体:')
|
tail = input('请输入句中需要预测关系的尾实体:')
|
||||||
tail_type = input('请输入尾实体类型:')
|
tail_type = input('请输入尾实体类型(可以为空,按enter跳过):')
|
||||||
elif flag == 'exit':
|
elif flag == 'exit':
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
else:
|
else:
|
||||||
|
@ -58,9 +58,14 @@ def _get_predict_instance():
|
||||||
instance = dict()
|
instance = dict()
|
||||||
instance['sentence'] = sentence.strip()
|
instance['sentence'] = sentence.strip()
|
||||||
instance['head'] = head.strip()
|
instance['head'] = head.strip()
|
||||||
instance['head_type'] = head_type.strip()
|
|
||||||
instance['tail'] = tail.strip()
|
instance['tail'] = tail.strip()
|
||||||
instance['tail_type'] = tail_type.strip()
|
if head_type.strip() == '' or tail_type.strip() == '':
|
||||||
|
cfg.replace_entity_with_type = False
|
||||||
|
instance['head_type'] = 'None'
|
||||||
|
instance['tail_type'] = 'None'
|
||||||
|
else:
|
||||||
|
instance['head_type'] = head_type.strip()
|
||||||
|
instance['tail_type'] = tail_type.strip()
|
||||||
|
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
@ -68,15 +73,16 @@ def _get_predict_instance():
|
||||||
# 自定义模型存储的路径
|
# 自定义模型存储的路径
|
||||||
fp = 'xxx/checkpoints/2019-12-03_17-35-30/cnn_epoch21.pth'
|
fp = 'xxx/checkpoints/2019-12-03_17-35-30/cnn_epoch21.pth'
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(config_path='conf/config.yaml')
|
@hydra.main(config_path='conf/config.yaml')
|
||||||
def main(cfg):
|
def main(cfg):
|
||||||
cwd = utils.get_original_cwd()
|
cwd = utils.get_original_cwd()
|
||||||
cfg.cwd = cwd
|
cfg.cwd = cwd
|
||||||
cfg.pos_size = 2 * cfg.pos_limit + 2
|
cfg.pos_size = 2 * cfg.pos_limit + 2
|
||||||
# print(cfg.pretty())
|
print(cfg.pretty())
|
||||||
|
|
||||||
# get predict instance
|
# get predict instance
|
||||||
instance = _get_predict_instance()
|
instance = _get_predict_instance(cfg)
|
||||||
data = [instance]
|
data = [instance]
|
||||||
|
|
||||||
# preprocess data
|
# preprocess data
|
||||||
|
@ -85,6 +91,11 @@ def main(cfg):
|
||||||
# model
|
# model
|
||||||
__Model__ = {
|
__Model__ = {
|
||||||
'cnn': models.PCNN,
|
'cnn': models.PCNN,
|
||||||
|
'rnn': models.BiLSTM,
|
||||||
|
'transformer': models.Transformer,
|
||||||
|
'gcn': models.GCN,
|
||||||
|
'capsule': models.Capsule,
|
||||||
|
'lm': models.LM,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 最好在 cpu 上预测
|
# 最好在 cpu 上预测
|
||||||
|
@ -96,18 +107,20 @@ def main(cfg):
|
||||||
logger.info(f'device: {device}')
|
logger.info(f'device: {device}')
|
||||||
|
|
||||||
model = __Model__[cfg.model_name](cfg)
|
model = __Model__[cfg.model_name](cfg)
|
||||||
|
logger.info(f'model name: {cfg.model_name}')
|
||||||
|
logger.info(f'\n {model}')
|
||||||
model.load(fp, device=device)
|
model.load(fp, device=device)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
logger.info(f'model name: {cfg.model_name}')
|
|
||||||
logger.info(f'\n {model}')
|
|
||||||
|
|
||||||
x = dict()
|
x = dict()
|
||||||
x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])
|
x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])
|
||||||
if cfg.model_name != 'lm':
|
if cfg.model_name != 'lm':
|
||||||
x['head_pos'], x['tail_pos'] = torch.tensor([data[0]['head_pos']]), torch.tensor([data[0]['tail_pos']])
|
x['head_pos'], x['tail_pos'] = torch.tensor([data[0]['head_pos']]), torch.tensor([data[0]['tail_pos']])
|
||||||
if cfg.use_pcnn:
|
if cfg.model_name == 'cnn':
|
||||||
x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
|
if cfg.use_pcnn:
|
||||||
|
x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
|
||||||
|
|
||||||
for key in x.keys():
|
for key in x.keys():
|
||||||
x[key] = x[key].to(device)
|
x[key] = x[key].to(device)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue