test
This commit is contained in:
parent b066534c6d
commit 9cd69ac3e8

README.md (14 changed lines)

@@ -43,7 +43,7 @@ DeepKE provides a variety of knowledge extraction models.
 1. RE

 ```
-1.REGULAR
+1.STANDARD

 2.FEW-SHOT

@@ -53,12 +53,12 @@ DeepKE provides a variety of knowledge extraction models.
 2. NER

 ```
-REGULAR
+STANDARD
 ```

 3. AE

 ```
-REGULAR
+STANDARD
 ```

@@ -76,7 +76,7 @@ DeepKE provides a variety of knowledge extraction models.

 For the specific workflow, see the detailed README; RE includes the following three sub-functions:

-**[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/re/regular/README.md)**
+**[STANDARD](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/re/regular/README.md)**

 FEW-SHOT

@@ -94,7 +94,7 @@ DeepKE provides a variety of knowledge extraction models.

 For the specific workflow, see the detailed README:

-**[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ner/regular/README.md)**
+**[STANDARD](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ner/regular/README.md)**

 3. AE

@@ -108,7 +108,7 @@ DeepKE provides a variety of knowledge extraction models.

 For the specific workflow, see the detailed README:

-**[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ae/regular/README.md)**
+**[STANDARD](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ae/regular/README.md)**

@@ -127,6 +127,8 @@ The DeepKE architecture is shown below.

+1. If installation reports `ModuleNotFoundError: No module named 'past'`, run `pip install future` to resolve it.

+1. When using pretrained language models, downloading them online is slow; it is better to download them in advance and place them in the `pretrained` folder. See the readme.md inside that folder for the required file layout.

 ## Acknowledgements

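The second FAQ entry added above recommends placing pretrained language models in the local `pretrained` folder. A minimal sketch of what loading from that folder looks like, assuming the checkpoints are Hugging Face `transformers` models (an assumption; the commit does not show which library DeepKE's `lm` model wraps):

```python
# Sketch only: load a language model from the local `pretrained` folder
# instead of downloading it at run time. Assumes a transformers-style
# checkpoint (config.json, vocab.txt, model weights) inside that folder.
from transformers import BertModel, BertTokenizer

local_dir = 'pretrained/'  # hypothetical path, per the FAQ entry above
tokenizer = BertTokenizer.from_pretrained(local_dir)
model = BertModel.from_pretrained(local_dir)
```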
@@ -8,8 +8,8 @@ from deepke.ae_re_tools import Serializer
 from deepke.ae_re_tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data
 import matplotlib.pyplot as plt
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-from deepke.ae_re_utils import load_pkl, load_csv
-import deepke.ae_re_models as models
+from deepke.ae_st_utils import load_pkl, load_csv
+import deepke.ae_st_models as models


 logger = logging.getLogger(__name__)
@@ -11,9 +11,9 @@ from torch.utils.tensorboard import SummaryWriter
 # self
 import sys
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-import deepke.ae_re_models as models
-from deepke.ae_re_tools import preprocess, CustomDataset, collate_fn, train, validate
-from deepke.ae_re_utils import manual_seed, load_pkl
+import deepke.ae_st_models as models
+from deepke.ae_st_tools import preprocess, CustomDataset, collate_fn, train, validate
+from deepke.ae_st_utils import manual_seed, load_pkl


 logger = logging.getLogger(__name__)
@@ -4,12 +4,12 @@ import torch
 import logging
 import hydra
 from hydra import utils
-from deepke.re_re_tools import Serializer
-from deepke.re_re_tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
+from deepke.re_st_tools import Serializer
+from deepke.re_st_tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
 import matplotlib.pyplot as plt
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-from deepke.re_re_utils import load_pkl, load_csv
-import deepke.re_re_models as models
+from deepke.re_st_utils import load_pkl, load_csv
+import deepke.re_st_models as models


 logger = logging.getLogger(__name__)
@@ -11,9 +11,9 @@ from torch.utils.tensorboard import SummaryWriter
 # self
 import sys
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-import deepke.re_re_models as models
-from deepke.re_re_tools import preprocess, CustomDataset, collate_fn, train, validate
-from deepke.re_re_utils import manual_seed, load_pkl
+import deepke.re_st_models as models
+from deepke.re_st_tools import preprocess, CustomDataset, collate_fn, train, validate
+from deepke.re_st_utils import manual_seed, load_pkl


 logger = logging.getLogger(__name__)
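Both run.py hunks import `collate_fn`, and the deleted script below passes `collate_fn(cfg)` to `DataLoader`, so `collate_fn` is a factory that captures the config and returns the per-batch collate callable. A minimal sketch of that pattern (the field names `token2idx` and `seq_len` come from the scripts below; the padding logic and the `pad_index` config field are assumptions):

```python
import torch

def collate_fn(cfg):
    """Factory: capture cfg and return the callable DataLoader invokes per batch."""
    def collate(batch):
        # batch is a list of dicts with 'token2idx' (list[int]) and 'seq_len'
        max_len = max(item['seq_len'] for item in batch)
        pad = getattr(cfg, 'pad_index', 0)  # assumed config field
        words = torch.tensor([item['token2idx'] + [pad] * (max_len - item['seq_len'])
                              for item in batch])
        lens = torch.tensor([item['seq_len'] for item in batch])
        return {'word': words, 'lens': lens}
    return collate
```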
@@ -1,142 +0,0 @@
-import os
-import hydra
-import torch
-import logging
-import torch.nn as nn
-from torch import optim
-from hydra import utils
-import matplotlib.pyplot as plt
-from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
-# self
-import sys
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-import models as models
-from tools import preprocess, CustomDataset, collate_fn, train, validate
-from utils import manual_seed, load_pkl
-
-
-logger = logging.getLogger(__name__)
-
-
-@hydra.main(config_path="../conf/config.yaml")
-def main(cfg):
-    cwd = utils.get_original_cwd()
-    cwd = cwd[0:-5]
-    cfg.cwd = cwd
-    cfg.pos_size = 2 * cfg.pos_limit + 2
-    logger.info(f'\n{cfg.pretty()}')
-
-    __Model__ = {
-        'cnn': models.PCNN,
-        'rnn': models.BiLSTM,
-        'transformer': models.Transformer,
-        'gcn': models.GCN,
-        'capsule': models.Capsule,
-        'lm': models.LM,
-    }
-
-    # device
-    if cfg.use_gpu and torch.cuda.is_available():
-        device = torch.device('cuda', cfg.gpu_id)
-    else:
-        device = torch.device('cpu')
-    logger.info(f'device: {device}')
-
-    # if the preprocessing has not changed, it is best to comment this out so the data is not re-preprocessed on every run
-    if cfg.preprocess:
-        preprocess(cfg)
-
-    train_data_path = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl')
-    valid_data_path = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl')
-    test_data_path = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl')
-    vocab_path = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')
-
-    if cfg.model_name == 'lm':
-        vocab_size = None
-    else:
-        vocab = load_pkl(vocab_path)
-        vocab_size = vocab.count
-    cfg.vocab_size = vocab_size
-
-    train_dataset = CustomDataset(train_data_path)
-    valid_dataset = CustomDataset(valid_data_path)
-    test_dataset = CustomDataset(test_data_path)
-
-    train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
-    valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
-    test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
-
-    model = __Model__[cfg.model_name](cfg)
-    model.to(device)
-    logger.info(f'\n {model}')
-
-    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
-    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)
-    criterion = nn.CrossEntropyLoss()
-
-    best_f1, best_epoch = -1, 0
-    es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, es_path, best_es_path = 1e8, -1, 0, 0, 0, -1, '', ''
-    train_losses, valid_losses = [], []
-
-    if cfg.show_plot and cfg.plot_utils == 'tensorboard':
-        writer = SummaryWriter('tensorboard')
-    else:
-        writer = None
-
-    logger.info('=' * 10 + ' Start training ' + '=' * 10)
-
-    for epoch in range(1, cfg.epoch + 1):
-        manual_seed(cfg.seed + epoch)
-        train_loss = train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg)
-        valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device, cfg)
-        scheduler.step(valid_loss)
-        model_path = model.save(epoch, cfg)
-        # logger.info(model_path)
-
-        train_losses.append(train_loss)
-        valid_losses.append(valid_loss)
-        if best_f1 < valid_f1:
-            best_f1 = valid_f1
-            best_epoch = epoch
-        # use the valid loss as the early-stopping criterion
-        if es_loss > valid_loss:
-            es_loss = valid_loss
-            es_f1 = valid_f1
-            es_epoch = epoch
-            es_patience = 0
-            es_path = model_path
-        else:
-            es_patience += 1
-            if es_patience >= cfg.early_stopping_patience:
-                best_es_epoch = es_epoch
-                best_es_f1 = es_f1
-                best_es_path = es_path
-
-    if cfg.show_plot:
-        if cfg.plot_utils == 'matplot':
-            plt.plot(train_losses, 'x-')
-            plt.plot(valid_losses, '+-')
-            plt.legend(['train', 'valid'])
-            plt.title('train/valid comparison loss')
-            plt.show()
-
-        if cfg.plot_utils == 'tensorboard':
-            for i in range(len(train_losses)):
-                writer.add_scalars('train/valid_comparison_loss', {
-                    'train': train_losses[i],
-                    'valid': valid_losses[i]
-                }, i)
-            writer.close()
-
-    logger.info(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '
-                f'this epoch macro f1: {best_es_f1:0.4f}')
-    logger.info(f'this model save path: {best_es_path}')
-    logger.info(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '
-                f'this epoch macro f1: {best_f1:.4f}')
-
-    logger.info('=====end of training====')
-    logger.info('')
-    logger.info('=====start test performance====')
-    validate(-1, model, test_dataloader, criterion, device, cfg)
-    logger.info('=====ending====')
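The deleted run.py above tracks early stopping inline through `es_loss`, `es_patience`, and related variables; note that when patience runs out it only records the early-stopping point (`best_es_epoch`, `best_es_f1`, `best_es_path`) and keeps training to `cfg.epoch`. The same bookkeeping, factored into a small helper for clarity (a sketch, not part of the commit):

```python
class EarlyStoppingTracker:
    """Track the best validation loss and how long it has gone unimproved."""

    def __init__(self, patience: int):
        self.patience = patience
        self.best_loss = float('inf')
        self.best_epoch = 0
        self.counter = 0

    def step(self, epoch: int, valid_loss: float) -> bool:
        """Record this epoch; return True once patience is exhausted."""
        if valid_loss < self.best_loss:
            self.best_loss = valid_loss
            self.best_epoch = epoch
            self.counter = 0
            return False
        self.counter += 1
        return self.counter >= self.patience
```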
@@ -1,147 +0,0 @@
-import os
-import sys
-import torch
-import logging
-import hydra
-from hydra import utils
-from serializer import Serializer
-from preprocess import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data
-import matplotlib.pyplot as plt
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-from utils import load_pkl, load_csv
-import models as models
-
-
-logger = logging.getLogger(__name__)
-
-
-def _preprocess_data(data, cfg):
-    vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
-    attribute_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'attribute.csv'), verbose=False)
-    atts = _handle_attribute_data(attribute_data)
-    cfg.vocab_size = vocab.count
-    serializer = Serializer(do_chinese_split=cfg.chinese_split)
-    serial = serializer.serialize
-
-    _serialize_sentence(data, serial, cfg)
-    _convert_tokens_into_index(data, vocab)
-    _add_pos_seq(data, cfg)
-    logger.info('start sentence preprocess...')
-    formats = '\nsentence: {}\nchinese_split: {}\n' \
-              'tokens: {}\ntoken2idx: {}\nlength: {}\nentity_index: {}\nattribute_value_index: {}'
-    logger.info(
-        formats.format(data[0]['sentence'], cfg.chinese_split,
-                       data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
-                       data[0]['entity_index'], data[0]['attribute_value_index']))
-    return data, atts
-
-
-def _get_predict_instance(cfg):
-    flag = input('是否使用范例[y/n],退出请输入: exit .... ')
-    flag = flag.strip().lower()
-    if flag == 'y' or flag == 'yes':
-        sentence = '张冬梅,女,汉族,1968年2月生,河南淇县人,1988年7月加入中国共产党,1989年9月参加工作,中央党校经济管理专业毕业,中央党校研究生学历'
-        entity = '张冬梅'
-        attribute_value = '汉族'
-    elif flag == 'n' or flag == 'no':
-        sentence = input('请输入句子:')
-        entity = input('请输入句中需要预测的实体:')
-        attribute_value = input('请输入句中需要预测的属性值:')
-    elif flag == 'exit':
-        sys.exit(0)
-    else:
-        print('please input yes or no, or exit!')
-        _get_predict_instance(cfg)
-
-    instance = dict()
-    instance['sentence'] = sentence.strip()
-    instance['entity'] = entity.strip()
-    instance['attribute_value'] = attribute_value.strip()
-    instance['entity_offset'] = sentence.find(entity)
-    instance['attribute_value_offset'] = sentence.find(attribute_value)
-
-    return instance
-
-
-@hydra.main(config_path='../conf/config.yaml')
-def main(cfg):
-    cwd = utils.get_original_cwd()
-    cwd = cwd[0:-5]
-    cfg.cwd = cwd
-    cfg.pos_size = 2 * cfg.pos_limit + 2
-    print(cfg.pretty())
-
-    # get predict instance
-    instance = _get_predict_instance(cfg)
-    data = [instance]
-
-    # preprocess data
-    data, rels = _preprocess_data(data, cfg)
-
-    # model
-    __Model__ = {
-        'cnn': models.PCNN,
-        'rnn': models.BiLSTM,
-        'transformer': models.Transformer,
-        'gcn': models.GCN,
-        'capsule': models.Capsule,
-        'lm': models.LM,
-    }
-
-    # prediction is best done on the CPU
-    cfg.use_gpu = False
-    if cfg.use_gpu and torch.cuda.is_available():
-        device = torch.device('cuda', cfg.gpu_id)
-    else:
-        device = torch.device('cpu')
-    logger.info(f'device: {device}')
-
-    model = __Model__[cfg.model_name](cfg)
-    logger.info(f'model name: {cfg.model_name}')
-    logger.info(f'\n {model}')
-    model.load(cfg.fp, device=device)
-    model.to(device)
-    model.eval()
-
-    x = dict()
-    x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])
-
-    if cfg.model_name != 'lm':
-        x['entity_pos'], x['attribute_value_pos'] = torch.tensor([data[0]['entity_pos']]), torch.tensor([data[0]['attribute_value_pos']])
-        if cfg.model_name == 'cnn':
-            if cfg.use_pcnn:
-                x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
-        if cfg.model_name == 'gcn':
-            # no suitable tool for building the parsing tree was found, so the adjacency matrix is randomly initialized for now
-            adj = torch.empty(1, data[0]['seq_len'], data[0]['seq_len']).random_(2)
-            x['adj'] = adj
-
-    for key in x.keys():
-        x[key] = x[key].to(device)
-
-    with torch.no_grad():
-        y_pred = model(x)
-        y_pred = torch.softmax(y_pred, dim=-1)[0]
-        prob = y_pred.max().item()
-        prob_att = list(rels.keys())[y_pred.argmax().item()]
-        logger.info(f"\"{data[0]['entity']}\" 和 \"{data[0]['attribute_value']}\" 在句中属性为:\"{prob_att}\",置信度为{prob:.2f}。")
-
-    if cfg.predict_plot:
-        plt.rcParams["font.family"] = 'Arial Unicode MS'
-        x = list(rels.keys())
-        height = list(y_pred.cpu().numpy())
-        plt.bar(x, height)
-        for x, y in zip(x, height):
-            plt.text(x, y, '%.2f' % y, ha="center", va="bottom")
-        plt.xlabel('关系')
-        plt.ylabel('置信度')
-        plt.xticks(rotation=315)
-        plt.show()
-
-
-if __name__ == '__main__':
-    main()
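One detail of the deleted predict.py worth noting: `_get_predict_instance` computes offsets with `sentence.find(...)`, which returns -1 when the entity or attribute value does not occur in the sentence, and the original code passes that value straight into preprocessing. A sketch of the same construction with an explicit guard (illustrative only; `make_instance` is a hypothetical helper, not part of the commit):

```python
def make_instance(sentence: str, entity: str, attribute_value: str) -> dict:
    """Build a predict instance like _get_predict_instance does, but fail
    loudly when the entity or attribute value is not in the sentence."""
    sentence = sentence.strip()
    entity = entity.strip()
    attribute_value = attribute_value.strip()
    entity_offset = sentence.find(entity)
    attribute_value_offset = sentence.find(attribute_value)
    if entity_offset < 0 or attribute_value_offset < 0:
        raise ValueError('entity and attribute value must appear in the sentence')
    return {
        'sentence': sentence,
        'entity': entity,
        'attribute_value': attribute_value,
        'entity_offset': entity_offset,
        'attribute_value_offset': attribute_value_offset,
    }
```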