This commit is contained in:
tlk-dsg 2021-09-16 14:30:03 +08:00
parent b066534c6d
commit 9cd69ac3e8
191 changed files with 20 additions and 608 deletions

View File

@ -43,7 +43,7 @@ DeepKE 提供了多种知识抽取模型。
1. RE
```
1.REGULAR
1.STANDARD
2.FEW-SHOT
@ -53,12 +53,12 @@ DeepKE 提供了多种知识抽取模型。
2. NER
```
REGULAR
STANDARD
```
3. AE
```
REGULAR
STANDARD
```
@ -76,7 +76,7 @@ DeepKE 提供了多种知识抽取模型。
具体流程请进入详细的README中RE包括了以下三个子功能
**[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/re/regular/README.md)**
**[STANDARD](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/re/regular/README.md)**
FEW-SHORT
@ -94,7 +94,7 @@ DeepKE 提供了多种知识抽取模型。
具体流程请进入详细的README中
**[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ner/regular/README.md)**
**[STANDARD](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ner/regular/README.md)**
3. AE
@ -108,7 +108,7 @@ DeepKE 提供了多种知识抽取模型。
具体流程请进入详细的README中:
**[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ae/regular/README.md)**
**[STANDARD](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ae/regular/README.md)**
@ -127,6 +127,8 @@ Deepke的架构图如下所示
1. 安装后提示 `ModuleNotFoundError: No module named 'past'`,输入命令 `pip install future` 即可解决。
1. 使用语言预训练模型时,在线安装下载模型比较慢,更建议提前下载好,存放到 pretrained 文件夹内。具体存放文件要求见文件夹内的 readme.md。
## 致谢

View File

Can't render this file because it is too large.

View File

Can't render this file because it is too large.

View File

Can't render this file because it is too large.

View File

@ -8,8 +8,8 @@ from deepke.ae_re_tools import Serializer
from deepke.ae_re_tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data
import matplotlib.pyplot as plt
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from deepke.ae_re_utils import load_pkl, load_csv
import deepke.ae_re_models as models
from deepke.ae_st_utils import load_pkl, load_csv
import deepke.ae_st_models as models
logger = logging.getLogger(__name__)

View File

@ -11,9 +11,9 @@ from torch.utils.tensorboard import SummaryWriter
# self
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
import deepke.ae_re_models as models
from deepke.ae_re_tools import preprocess , CustomDataset, collate_fn ,train, validate
from deepke.ae_re_utils import manual_seed, load_pkl
import deepke.ae_st_models as models
from deepke.ae_st_tools import preprocess , CustomDataset, collate_fn ,train, validate
from deepke.ae_st_utils import manual_seed, load_pkl
logger = logging.getLogger(__name__)

View File

Can't render this file because it is too large.

View File

@ -4,12 +4,12 @@ import torch
import logging
import hydra
from hydra import utils
from deepke.re_re_tools import Serializer
from deepke.re_re_tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
from deepke.re_st_tools import Serializer
from deepke.re_st_tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
import matplotlib.pyplot as plt
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from deepke.re_re_utils import load_pkl, load_csv
import deepke.re_re_models as models
from deepke.re_st_utils import load_pkl, load_csv
import deepke.re_st_models as models
logger = logging.getLogger(__name__)

View File

@ -11,9 +11,9 @@ from torch.utils.tensorboard import SummaryWriter
# self
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
import deepke.re_re_models as models
from deepke.re_re_tools import preprocess , CustomDataset, collate_fn ,train, validate
from deepke.re_re_utils import manual_seed, load_pkl
import deepke.re_st_models as models
from deepke.re_st_tools import preprocess , CustomDataset, collate_fn ,train, validate
from deepke.re_st_utils import manual_seed, load_pkl
logger = logging.getLogger(__name__)

View File

@ -1,142 +0,0 @@
import os
import hydra
import torch
import logging
import torch.nn as nn
from torch import optim
from hydra import utils
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# self
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
import models as models
from tools import preprocess , CustomDataset, collate_fn ,train, validate
from utils import manual_seed, load_pkl
logger = logging.getLogger(__name__)
@hydra.main(config_path="../conf/config.yaml")
def main(cfg):
cwd = utils.get_original_cwd()
cwd = cwd[0:-5]
cfg.cwd = cwd
cfg.pos_size = 2 * cfg.pos_limit + 2
logger.info(f'\n{cfg.pretty()}')
__Model__ = {
'cnn': models.PCNN,
'rnn': models.BiLSTM,
'transformer': models.Transformer,
'gcn': models.GCN,
'capsule': models.Capsule,
'lm': models.LM,
}
# device
if cfg.use_gpu and torch.cuda.is_available():
device = torch.device('cuda', cfg.gpu_id)
else:
device = torch.device('cpu')
logger.info(f'device: {device}')
# 如果不修改预处理的过程,这一步最好注释掉,不用每次运行都预处理数据一次
if cfg.preprocess:
preprocess(cfg)
train_data_path = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl')
valid_data_path = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl')
test_data_path = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl')
vocab_path = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')
if cfg.model_name == 'lm':
vocab_size = None
else:
vocab = load_pkl(vocab_path)
vocab_size = vocab.count
cfg.vocab_size = vocab_size
train_dataset = CustomDataset(train_data_path)
valid_dataset = CustomDataset(valid_data_path)
test_dataset = CustomDataset(test_data_path)
train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
model = __Model__[cfg.model_name](cfg)
model.to(device)
logger.info(f'\n {model}')
optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)
criterion = nn.CrossEntropyLoss()
best_f1, best_epoch = -1, 0
es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, es_path, best_es_path = 1e8, -1, 0, 0, 0, -1, '', ''
train_losses, valid_losses = [], []
if cfg.show_plot and cfg.plot_utils == 'tensorboard':
writer = SummaryWriter('tensorboard')
else:
writer = None
logger.info('=' * 10 + ' Start training ' + '=' * 10)
for epoch in range(1, cfg.epoch + 1):
manual_seed(cfg.seed + epoch)
train_loss = train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg)
valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device, cfg)
scheduler.step(valid_loss)
model_path = model.save(epoch, cfg)
# logger.info(model_path)
train_losses.append(train_loss)
valid_losses.append(valid_loss)
if best_f1 < valid_f1:
best_f1 = valid_f1
best_epoch = epoch
# 使用 valid loss 做 early stopping 的判断标准
if es_loss > valid_loss:
es_loss = valid_loss
es_f1 = valid_f1
es_epoch = epoch
es_patience = 0
es_path = model_path
else:
es_patience += 1
if es_patience >= cfg.early_stopping_patience:
best_es_epoch = es_epoch
best_es_f1 = es_f1
best_es_path = es_path
if cfg.show_plot:
if cfg.plot_utils == 'matplot':
plt.plot(train_losses, 'x-')
plt.plot(valid_losses, '+-')
plt.legend(['train', 'valid'])
plt.title('train/valid comparison loss')
plt.show()
if cfg.plot_utils == 'tensorboard':
for i in range(len(train_losses)):
writer.add_scalars('train/valid_comparison_loss', {
'train': train_losses[i],
'valid': valid_losses[i]
}, i)
writer.close()
logger.info(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '
f'this epoch macro f1: {best_es_f1:0.4f}')
logger.info(f'this model save path: {best_es_path}')
logger.info(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '
f'this epoch macro f1: {best_f1:.4f}')
logger.info('=====end of training====')
logger.info('')
logger.info('=====start test performance====')
validate(-1, model, test_dataloader, criterion, device, cfg)
logger.info('=====ending====')

View File

@ -1,147 +0,0 @@
import os
import sys
import torch
import logging
import hydra
from hydra import utils
from serializer import Serializer
from preprocess import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data
import matplotlib.pyplot as plt
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from utils import load_pkl, load_csv
import models as models
logger = logging.getLogger(__name__)
def _preprocess_data(data, cfg):
vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
attribute_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'attribute.csv'), verbose=False)
atts = _handle_attribute_data(attribute_data)
cfg.vocab_size = vocab.count
serializer = Serializer(do_chinese_split=cfg.chinese_split)
serial = serializer.serialize
_serialize_sentence(data, serial, cfg)
_convert_tokens_into_index(data, vocab)
_add_pos_seq(data, cfg)
logger.info('start sentence preprocess...')
formats = '\nsentence: {}\nchinese_split: {}\n' \
'tokens: {}\ntoken2idx: {}\nlength: {}\nentity_index: {}\nattribute_value_index: {}'
logger.info(
formats.format(data[0]['sentence'], cfg.chinese_split,
data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
data[0]['entity_index'], data[0]['attribute_value_index']))
return data, atts
def _get_predict_instance(cfg):
flag = input('是否使用范例[y/n],退出请输入: exit .... ')
flag = flag.strip().lower()
if flag == 'y' or flag == 'yes':
sentence = '张冬梅汉族1968年2月生河南淇县人1988年7月加入中国共产党1989年9月参加工作中央党校经济管理专业毕业中央党校研究生学历'
entity = '张冬梅'
attribute_value = '汉族'
elif flag == 'n' or flag == 'no':
sentence = input('请输入句子:')
entity = input('请输入句中需要预测的实体:')
attribute_value = input('请输入句中需要预测的属性值:')
elif flag == 'exit':
sys.exit(0)
else:
print('please input yes or no, or exit!')
_get_predict_instance(cfg)
instance = dict()
instance['sentence'] = sentence.strip()
instance['entity'] = entity.strip()
instance['attribute_value'] = attribute_value.strip()
instance['entity_offset'] = sentence.find(entity)
instance['attribute_value_offset'] = sentence.find(attribute_value)
return instance
@hydra.main(config_path='../conf/config.yaml')
def main(cfg):
cwd = utils.get_original_cwd()
cwd = cwd[0:-5]
cfg.cwd = cwd
cfg.pos_size = 2 * cfg.pos_limit + 2
print(cfg.pretty())
# get predict instance
instance = _get_predict_instance(cfg)
data = [instance]
# preprocess data
data, rels = _preprocess_data(data, cfg)
# model
__Model__ = {
'cnn': models.PCNN,
'rnn': models.BiLSTM,
'transformer': models.Transformer,
'gcn': models.GCN,
'capsule': models.Capsule,
'lm': models.LM,
}
# 最好在 cpu 上预测
cfg.use_gpu = False
if cfg.use_gpu and torch.cuda.is_available():
device = torch.device('cuda', cfg.gpu_id)
else:
device = torch.device('cpu')
logger.info(f'device: {device}')
model = __Model__[cfg.model_name](cfg)
logger.info(f'model name: {cfg.model_name}')
logger.info(f'\n {model}')
model.load(cfg.fp, device=device)
model.to(device)
model.eval()
x = dict()
x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])
if cfg.model_name != 'lm':
x['entity_pos'], x['attribute_value_pos'] = torch.tensor([data[0]['entity_pos']]), torch.tensor([data[0]['attribute_value_pos']])
if cfg.model_name == 'cnn':
if cfg.use_pcnn:
x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
if cfg.model_name == 'gcn':
# 没找到合适的做 parsing tree 的工具,暂时随机初始化
adj = torch.empty(1,data[0]['seq_len'],data[0]['seq_len']).random_(2)
x['adj'] = adj
for key in x.keys():
x[key] = x[key].to(device)
with torch.no_grad():
y_pred = model(x)
y_pred = torch.softmax(y_pred, dim=-1)[0]
prob = y_pred.max().item()
prob_att = list(rels.keys())[y_pred.argmax().item()]
logger.info(f"\"{data[0]['entity']}\"\"{data[0]['attribute_value']}\" 在句中属性为:\"{prob_att}\",置信度为{prob:.2f}")
if cfg.predict_plot:
plt.rcParams["font.family"] = 'Arial Unicode MS'
x = list(rels.keys())
height = list(y_pred.cpu().numpy())
plt.bar(x, height)
for x, y in zip(x, height):
plt.text(x, y, '%.2f' % y, ha="center", va="bottom")
plt.xlabel('关系')
plt.ylabel('置信度')
plt.xticks(rotation=315)
plt.show()
if __name__ == '__main__':
main()

Some files were not shown because too many files have changed in this diff Show More