deepke/example/ae/standard/predict.py

153 lines
5.3 KiB
Python
Raw Permalink Normal View History

2021-09-14 16:08:14 +08:00
import os
import sys
import torch
import logging
import hydra
from hydra import utils
2021-09-17 17:49:28 +08:00
from deepke.attribution_extraction.standard.tools import Serializer
2021-09-23 12:40:14 +08:00
from deepke.attribution_extraction.standard.tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data , _lm_serialize
2021-09-14 16:08:14 +08:00
import matplotlib.pyplot as plt
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
2021-09-17 17:49:28 +08:00
from deepke.attribution_extraction.standard.utils import load_pkl, load_csv
import deepke.attribution_extraction.standard.models as models
2021-09-14 16:08:14 +08:00
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _preprocess_data(data, cfg):
    """Prepare raw prediction instances for the selected model.

    The attribute table is read from ``attribute.csv`` under
    ``cfg.cwd/cfg.data_path`` and passed through ``_handle_attribute_data``.
    For the ``'lm'`` model the instances are handed to ``_lm_serialize``;
    for every other model they are serialized, mapped to vocabulary indices
    and given position sequences, using the vocabulary stored in
    ``vocab.pkl`` under ``cfg.cwd/cfg.out_path``.

    Args:
        data: List of instance dicts; the helpers are called on it without
            using a return value, so it is presumably modified in place.
        cfg: Configuration object. ``cfg.vocab_size`` is set here for
            non-``'lm'`` models.

    Returns:
        A ``(data, atts)`` tuple, where ``atts`` is the result of
        ``_handle_attribute_data``.
    """
    attribute_csv = os.path.join(cfg.cwd, cfg.data_path, 'attribute.csv')
    atts = _handle_attribute_data(load_csv(attribute_csv, verbose=False))

    # The language-model path has its own serialization and needs no vocab.
    if cfg.model_name == 'lm':
        _lm_serialize(data, cfg)
        return data, atts

    vocab_file = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')
    vocab = load_pkl(vocab_file, verbose=False)
    cfg.vocab_size = vocab.count

    serialize_fn = Serializer(do_chinese_split=cfg.chinese_split).serialize
    _serialize_sentence(data, serialize_fn)
    _convert_tokens_into_index(data, vocab)
    _add_pos_seq(data, cfg)

    # Log the first instance so the preprocessing result can be inspected.
    logger.info('start sentence preprocess...')
    first = data[0]
    formats = (
        '\nsentence: {}\nchinese_split: {}\n'
        'tokens: {}\ntoken2idx: {}\nlength: {}\nentity_index: {}\nattribute_value_index: {}'
    )
    logger.info(
        formats.format(
            first['sentence'],
            cfg.chinese_split,
            first['tokens'],
            first['token2idx'],
            first['seq_len'],
            first['entity_index'],
            first['attribute_value_index'],
        )
    )
    return data, atts
def _get_predict_instance(cfg):
    """Interactively collect one instance to run prediction on.

    The user is asked whether to use the built-in example (``y``/``yes``),
    to type a sentence, entity and attribute value (``n``/``no``), or to
    quit (``exit``). Any other answer prints a hint and asks again.

    Args:
        cfg: Configuration object. Unused here; kept so existing callers
            that pass it continue to work.

    Returns:
        A dict with the keys ``sentence``, ``entity``, ``attribute_value``
        (all whitespace-stripped), plus ``entity_offset`` and
        ``attribute_value_offset``: the index of the entity / attribute
        value inside the stripped sentence, or ``-1`` when it does not
        occur (the ``str.find`` convention).

    Raises:
        SystemExit: When the user answers ``exit``.
    """
    # Loop instead of recursing: an invalid answer simply re-prompts, and the
    # code below only runs once all three values have been assigned.
    while True:
        flag = input('是否使用范例[y/n],退出请输入: exit .... ').strip().lower()
        if flag in ('y', 'yes'):
            sentence = '张冬梅汉族1968年2月生河南淇县人1988年7月加入中国共产党1989年9月参加工作中央党校经济管理专业毕业中央党校研究生学历'
            entity = '张冬梅'
            attribute_value = '汉族'
            break
        if flag in ('n', 'no'):
            sentence = input('请输入句子:')
            entity = input('请输入句中需要预测的实体:')
            attribute_value = input('请输入句中需要预测的属性值:')
            break
        if flag == 'exit':
            sys.exit(0)
        print('please input yes or no, or exit!')

    # Strip first so the offsets index into the sentence that is actually
    # stored, and so padding typed around the entity/value cannot hide a match.
    sentence = sentence.strip()
    entity = entity.strip()
    attribute_value = attribute_value.strip()

    instance = dict()
    instance['sentence'] = sentence
    instance['entity'] = entity
    instance['attribute_value'] = attribute_value
    instance['entity_offset'] = sentence.find(entity)
    instance['attribute_value_offset'] = sentence.find(attribute_value)
    return instance
2021-09-14 17:11:46 +08:00
@hydra.main(config_path='conf/config.yaml')
def main(cfg):
    """Predict the attribute of one interactively supplied instance.

    Reads a sentence/entity/attribute-value triple from the user,
    preprocesses it, builds the model selected by ``cfg.model_name``, loads
    its weights from ``cfg.fp`` and logs the highest-scoring attribute with
    its softmax probability. When ``cfg.predict_plot`` is truthy, a bar chart
    of all attribute probabilities is shown as well.

    Args:
        cfg: Configuration object supplied by Hydra. ``cwd``, ``pos_size``
            and ``use_gpu`` are overwritten here.
    """
    # Record the original working directory as reported by Hydra.
    cfg.cwd = utils.get_original_cwd()
    cfg.pos_size = 2 * cfg.pos_limit + 2
    print(cfg.pretty())

    # Get the instance to predict and preprocess it.
    data, atts = _preprocess_data([_get_predict_instance(cfg)], cfg)
    sample = data[0]

    # Available model classes, keyed by cfg.model_name.
    model_registry = {
        'cnn': models.PCNN,
        'rnn': models.BiLSTM,
        'transformer': models.Transformer,
        'gcn': models.GCN,
        'capsule': models.Capsule,
        'lm': models.LM,
    }

    # Prediction is best done on the CPU, so the GPU is switched off here.
    cfg.use_gpu = False
    gpu_usable = cfg.use_gpu and torch.cuda.is_available()
    device = torch.device('cuda', cfg.gpu_id) if gpu_usable else torch.device('cpu')
    logger.info(f'device: {device}')

    model = model_registry[cfg.model_name](cfg)
    logger.info(f'model name: {cfg.model_name}')
    logger.info(f'\n {model}')
    model.load(cfg.fp, device=device)
    model.to(device)
    model.eval()

    # Assemble a batch of size one from the preprocessed sample.
    batch = {
        'word': torch.tensor([sample['token2idx']]),
        'lens': torch.tensor([sample['seq_len']]),
    }
    if cfg.model_name != 'lm':
        batch['entity_pos'] = torch.tensor([sample['entity_pos']])
        batch['attribute_value_pos'] = torch.tensor([sample['attribute_value_pos']])
        if cfg.model_name == 'cnn' and cfg.use_pcnn:
            batch['pcnn_mask'] = torch.tensor([sample['entities_pos']])
        if cfg.model_name == 'gcn':
            # No suitable dependency-parsing tool was found, so the adjacency
            # matrix is randomly initialised for now.
            seq_len = sample['seq_len']
            batch['adj'] = torch.empty(1, seq_len, seq_len).random_(2)

    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    att_names = list(atts.keys())
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=-1)[0]
        best_prob = probs.max().item()
        best_att = att_names[probs.argmax().item()]
    logger.info(f"\"{sample['entity']}\"\"{sample['attribute_value']}\" 在句中属性为:\"{best_att}\",置信度为{best_prob:.2f}")

    if cfg.predict_plot:
        plt.rcParams["font.family"] = 'Arial Unicode MS'
        heights = list(probs.cpu().numpy())
        plt.bar(att_names, heights)
        # Write each probability above its bar.
        for label, score in zip(att_names, heights):
            plt.text(label, score, '%.2f' % score, ha="center", va="bottom")
        plt.xlabel('关系')
        plt.ylabel('置信度')
        plt.xticks(rotation=315)
        plt.show()
# Run the prediction entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()