test

commit 9cd69ac3e8
parent b066534c6d

README.md | 14
@@ -43,7 +43,7 @@ DeepKE provides a variety of knowledge extraction models.
 1. RE
 
 ```
-1.REGULAR
+1.STANDARD
 
 2.FEW-SHOT
 
@@ -53,12 +53,12 @@ DeepKE provides a variety of knowledge extraction models.
 2. NER
 
 ```
-REGULAR
+STANDARD
 ```
 
 3. AE
 ```
-REGULAR
+STANDARD
 ```
 
 
@@ -76,7 +76,7 @@ DeepKE provides a variety of knowledge extraction models.
 
 For the detailed workflow, see the dedicated README; RE includes the following three sub-functions
 
-**[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/re/regular/README.md)**
+**[STANDARD](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/re/regular/README.md)**
 
 FEW-SHOT
 
@@ -94,7 +94,7 @@ DeepKE provides a variety of knowledge extraction models.
 
 For the detailed workflow, see the dedicated README:
 
-**[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ner/regular/README.md)**
+**[STANDARD](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ner/regular/README.md)**
 
 3. AE
 
@@ -108,7 +108,7 @@ DeepKE provides a variety of knowledge extraction models.
 
 For the detailed workflow, see the dedicated README:
 
-**[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ae/regular/README.md)**
+**[STANDARD](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ae/regular/README.md)**
 
 
 
@@ -127,6 +127,8 @@ The architecture of DeepKE is shown below
 
 1. If `ModuleNotFoundError: No module named 'past'` appears after installation, running `pip install future` resolves it.
 
+1. When using pretrained language models, downloading them online during a run is slow; it is better to download them in advance and place them in the `pretrained` folder. See the readme.md inside that folder for the required file layout.
+
 ## Acknowledgements
 
 
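The second troubleshooting item added above suggests pre-downloading language models into the local `pretrained` folder. A minimal sketch of doing that, assuming the Hugging Face `transformers` API and an illustrative model name (neither is specified by this diff):

```python
# Sketch, not part of this commit: cache a pretrained language model
# into the local `pretrained/` folder so later runs load it offline.
# The model name below is an assumption chosen for illustration.
from transformers import AutoModel, AutoTokenizer

name = 'bert-base-chinese'
AutoTokenizer.from_pretrained(name).save_pretrained('pretrained/')
AutoModel.from_pretrained(name).save_pretrained('pretrained/')

# Later loads resolve against the local folder instead of the network:
tokenizer = AutoTokenizer.from_pretrained('pretrained/')
model = AutoModel.from_pretrained('pretrained/')
```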
@@ -8,8 +8,8 @@ from deepke.ae_re_tools import Serializer
 from deepke.ae_re_tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data
 import matplotlib.pyplot as plt
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-from deepke.ae_re_utils import load_pkl, load_csv
-import deepke.ae_re_models as models
+from deepke.ae_st_utils import load_pkl, load_csv
+import deepke.ae_st_models as models
 
 
 logger = logging.getLogger(__name__)
@@ -11,9 +11,9 @@ from torch.utils.tensorboard import SummaryWriter
 # self
 import sys
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-import deepke.ae_re_models as models
-from deepke.ae_re_tools import preprocess , CustomDataset, collate_fn ,train, validate
-from deepke.ae_re_utils import manual_seed, load_pkl
+import deepke.ae_st_models as models
+from deepke.ae_st_tools import preprocess , CustomDataset, collate_fn ,train, validate
+from deepke.ae_st_utils import manual_seed, load_pkl
 
 
 logger = logging.getLogger(__name__)
@@ -4,12 +4,12 @@ import torch
 import logging
 import hydra
 from hydra import utils
-from deepke.re_re_tools import Serializer
-from deepke.re_re_tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
+from deepke.re_st_tools import Serializer
+from deepke.re_st_tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
 import matplotlib.pyplot as plt
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-from deepke.re_re_utils import load_pkl, load_csv
-import deepke.re_re_models as models
+from deepke.re_st_utils import load_pkl, load_csv
+import deepke.re_st_models as models
 
 
 logger = logging.getLogger(__name__)
@@ -11,9 +11,9 @@ from torch.utils.tensorboard import SummaryWriter
 # self
 import sys
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-import deepke.re_re_models as models
-from deepke.re_re_tools import preprocess , CustomDataset, collate_fn ,train, validate
-from deepke.re_re_utils import manual_seed, load_pkl
+import deepke.re_st_models as models
+from deepke.re_st_tools import preprocess , CustomDataset, collate_fn ,train, validate
+from deepke.re_st_utils import manual_seed, load_pkl
 
 
 logger = logging.getLogger(__name__)
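Taken together, the four hunks above apply one mechanical rename: the `ae_re_*` and `re_re_*` packages (REGULAR) become `ae_st_*` and `re_st_*` (STANDARD), mirroring the README change. A sketch of the resulting import surface, using only module names that appear in this diff:

```python
# Post-commit import style for the AE and RE tasks (names from the diff).
import deepke.ae_st_models as ae_models
import deepke.re_st_models as re_models
from deepke.ae_st_tools import preprocess, CustomDataset, collate_fn, train, validate
from deepke.re_st_utils import manual_seed, load_pkl
```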
@@ -1,142 +0,0 @@
-import os
-import hydra
-import torch
-import logging
-import torch.nn as nn
-from torch import optim
-from hydra import utils
-import matplotlib.pyplot as plt
-from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
-# self
-import sys
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-import models as models
-from tools import preprocess , CustomDataset, collate_fn ,train, validate
-from utils import manual_seed, load_pkl
-
-
-logger = logging.getLogger(__name__)
-
-
-@hydra.main(config_path="../conf/config.yaml")
-def main(cfg):
-    cwd = utils.get_original_cwd()
-    cwd = cwd[0:-5]
-    cfg.cwd = cwd
-    cfg.pos_size = 2 * cfg.pos_limit + 2
-    logger.info(f'\n{cfg.pretty()}')
-
-    __Model__ = {
-        'cnn': models.PCNN,
-        'rnn': models.BiLSTM,
-        'transformer': models.Transformer,
-        'gcn': models.GCN,
-        'capsule': models.Capsule,
-        'lm': models.LM,
-    }
-
-    # device
-    if cfg.use_gpu and torch.cuda.is_available():
-        device = torch.device('cuda', cfg.gpu_id)
-    else:
-        device = torch.device('cpu')
-    logger.info(f'device: {device}')
-
-    # If the preprocessing itself is unchanged, it is best to comment this step out,
-    # so the data is not preprocessed again on every run
-    if cfg.preprocess:
-        preprocess(cfg)
-
-    train_data_path = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl')
-    valid_data_path = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl')
-    test_data_path = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl')
-    vocab_path = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')
-
-    if cfg.model_name == 'lm':
-        vocab_size = None
-    else:
-        vocab = load_pkl(vocab_path)
-        vocab_size = vocab.count
-    cfg.vocab_size = vocab_size
-
-    train_dataset = CustomDataset(train_data_path)
-    valid_dataset = CustomDataset(valid_data_path)
-    test_dataset = CustomDataset(test_data_path)
-
-    train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
-    valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
-    test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
-
-    model = __Model__[cfg.model_name](cfg)
-    model.to(device)
-    logger.info(f'\n {model}')
-
-    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
-    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)
-    criterion = nn.CrossEntropyLoss()
-
-    best_f1, best_epoch = -1, 0
-    es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, es_path, best_es_path = 1e8, -1, 0, 0, 0, -1, '', ''
-    train_losses, valid_losses = [], []
-
-    if cfg.show_plot and cfg.plot_utils == 'tensorboard':
-        writer = SummaryWriter('tensorboard')
-    else:
-        writer = None
-
-    logger.info('=' * 10 + ' Start training ' + '=' * 10)
-
-    for epoch in range(1, cfg.epoch + 1):
-        manual_seed(cfg.seed + epoch)
-        train_loss = train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg)
-        valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device, cfg)
-        scheduler.step(valid_loss)
-        model_path = model.save(epoch, cfg)
-        # logger.info(model_path)
-
-        train_losses.append(train_loss)
-        valid_losses.append(valid_loss)
-        if best_f1 < valid_f1:
-            best_f1 = valid_f1
-            best_epoch = epoch
-        # use the valid loss as the early-stopping criterion
-        if es_loss > valid_loss:
-            es_loss = valid_loss
-            es_f1 = valid_f1
-            es_epoch = epoch
-            es_patience = 0
-            es_path = model_path
-        else:
-            es_patience += 1
-            if es_patience >= cfg.early_stopping_patience:
-                best_es_epoch = es_epoch
-                best_es_f1 = es_f1
-                best_es_path = es_path
-
-    if cfg.show_plot:
-        if cfg.plot_utils == 'matplot':
-            plt.plot(train_losses, 'x-')
-            plt.plot(valid_losses, '+-')
-            plt.legend(['train', 'valid'])
-            plt.title('train/valid comparison loss')
-            plt.show()
-
-        if cfg.plot_utils == 'tensorboard':
-            for i in range(len(train_losses)):
-                writer.add_scalars('train/valid_comparison_loss', {
-                    'train': train_losses[i],
-                    'valid': valid_losses[i]
-                }, i)
-            writer.close()
-
-    logger.info(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '
-                f'this epoch macro f1: {best_es_f1:0.4f}')
-    logger.info(f'this model save path: {best_es_path}')
-    logger.info(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '
-                f'this epoch macro f1: {best_f1:.4f}')
-
-    logger.info('=====end of training====')
-    logger.info('')
-    logger.info('=====start test performance====')
-    validate(-1, model, test_dataloader, criterion, device, cfg)
-    logger.info('=====ending====')
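The deleted trainer above is a hydra entry point: every config key it reads (`model_name`, `epoch`, `batch_size`, ...) comes from `conf/config.yaml` and can be overridden at launch. A minimal sketch of that pattern, using the same old-style `config_path` pointing directly at the YAML file that the deleted code uses:

```python
# Standalone sketch of the hydra pattern used above (not the deleted file
# itself); assumes a conf/config.yaml defining at least `model_name`.
import hydra

@hydra.main(config_path='../conf/config.yaml')
def main(cfg):
    # any key can be overridden from the command line, e.g.:
    #   python main.py model_name=gcn epoch=50
    print(cfg.model_name)

if __name__ == '__main__':
    main()
```

Under this pattern, `python main.py model_name=gcn epoch=50` would select the GCN from the `__Model__` dict and train for 50 epochs without editing the YAML.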
@@ -1,147 +0,0 @@
-import os
-import sys
-import torch
-import logging
-import hydra
-from hydra import utils
-from serializer import Serializer
-from preprocess import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data
-import matplotlib.pyplot as plt
-sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-from utils import load_pkl, load_csv
-import models as models
-
-
-logger = logging.getLogger(__name__)
-
-
-def _preprocess_data(data, cfg):
-    vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
-    attribute_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'attribute.csv'), verbose=False)
-    atts = _handle_attribute_data(attribute_data)
-    cfg.vocab_size = vocab.count
-    serializer = Serializer(do_chinese_split=cfg.chinese_split)
-    serial = serializer.serialize
-
-    _serialize_sentence(data, serial, cfg)
-    _convert_tokens_into_index(data, vocab)
-    _add_pos_seq(data, cfg)
-    logger.info('start sentence preprocess...')
-    formats = '\nsentence: {}\nchinese_split: {}\n' \
-              'tokens: {}\ntoken2idx: {}\nlength: {}\nentity_index: {}\nattribute_value_index: {}'
-    logger.info(
-        formats.format(data[0]['sentence'], cfg.chinese_split,
-                       data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
-                       data[0]['entity_index'], data[0]['attribute_value_index']))
-    return data, atts
-
-
-def _get_predict_instance(cfg):
-    # prompt: use the built-in example? [y/n]; type exit to quit
-    flag = input('是否使用范例[y/n],退出请输入: exit .... ')
-    flag = flag.strip().lower()
-    if flag == 'y' or flag == 'yes':
-        sentence = '张冬梅,女,汉族,1968年2月生,河南淇县人,1988年7月加入中国共产党,1989年9月参加工作,中央党校经济管理专业毕业,中央党校研究生学历'
-        entity = '张冬梅'
-        attribute_value = '汉族'
-    elif flag == 'n' or flag == 'no':
-        # prompts: enter the sentence, the entity to predict, and the attribute value
-        sentence = input('请输入句子:')
-        entity = input('请输入句中需要预测的实体:')
-        attribute_value = input('请输入句中需要预测的属性值:')
-    elif flag == 'exit':
-        sys.exit(0)
-    else:
-        print('please input yes or no, or exit!')
-        _get_predict_instance(cfg)
-
-    instance = dict()
-    instance['sentence'] = sentence.strip()
-    instance['entity'] = entity.strip()
-    instance['attribute_value'] = attribute_value.strip()
-    instance['entity_offset'] = sentence.find(entity)
-    instance['attribute_value_offset'] = sentence.find(attribute_value)
-
-    return instance
-
-
-@hydra.main(config_path='../conf/config.yaml')
-def main(cfg):
-    cwd = utils.get_original_cwd()
-    cwd = cwd[0:-5]
-    cfg.cwd = cwd
-    cfg.pos_size = 2 * cfg.pos_limit + 2
-    print(cfg.pretty())
-
-    # get predict instance
-    instance = _get_predict_instance(cfg)
-    data = [instance]
-
-    # preprocess data
-    data, rels = _preprocess_data(data, cfg)
-
-    # model
-    __Model__ = {
-        'cnn': models.PCNN,
-        'rnn': models.BiLSTM,
-        'transformer': models.Transformer,
-        'gcn': models.GCN,
-        'capsule': models.Capsule,
-        'lm': models.LM,
-    }
-
-    # prediction is best done on the cpu
-    cfg.use_gpu = False
-    if cfg.use_gpu and torch.cuda.is_available():
-        device = torch.device('cuda', cfg.gpu_id)
-    else:
-        device = torch.device('cpu')
-    logger.info(f'device: {device}')
-
-    model = __Model__[cfg.model_name](cfg)
-    logger.info(f'model name: {cfg.model_name}')
-    logger.info(f'\n {model}')
-    model.load(cfg.fp, device=device)
-    model.to(device)
-    model.eval()
-
-    x = dict()
-    x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])
-
-    if cfg.model_name != 'lm':
-        x['entity_pos'], x['attribute_value_pos'] = torch.tensor([data[0]['entity_pos']]), torch.tensor([data[0]['attribute_value_pos']])
-        if cfg.model_name == 'cnn':
-            if cfg.use_pcnn:
-                x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
-        if cfg.model_name == 'gcn':
-            # no suitable parsing-tree tool was found, so the adjacency matrix is randomly initialized for now
-            adj = torch.empty(1, data[0]['seq_len'], data[0]['seq_len']).random_(2)
-            x['adj'] = adj
-
-    for key in x.keys():
-        x[key] = x[key].to(device)
-
-    with torch.no_grad():
-        y_pred = model(x)
-        y_pred = torch.softmax(y_pred, dim=-1)[0]
-        prob = y_pred.max().item()
-        prob_att = list(rels.keys())[y_pred.argmax().item()]
-        # logs: in the sentence, '<entity>' and '<attribute_value>' have attribute '<prob_att>' with confidence <prob>
-        logger.info(f"\"{data[0]['entity']}\" 和 \"{data[0]['attribute_value']}\" 在句中属性为:\"{prob_att}\",置信度为{prob:.2f}。")
-
-    if cfg.predict_plot:
-        plt.rcParams["font.family"] = 'Arial Unicode MS'
-        x = list(rels.keys())
-        height = list(y_pred.cpu().numpy())
-        plt.bar(x, height)
-        for x, y in zip(x, height):
-            plt.text(x, y, '%.2f' % y, ha="center", va="bottom")
-        plt.xlabel('关系')  # relation
-        plt.ylabel('置信度')  # confidence
-        plt.xticks(rotation=315)
-        plt.show()
-
-
-if __name__ == '__main__':
-    main()
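The no-grad block in the deleted `predict.py` turns logits into a label and a confidence by taking the softmax and mapping the argmax index back into the keys of the attribute dict returned by `_preprocess_data`. A self-contained toy version of that lookup (the label map and logits below are invented for illustration):

```python
# Toy reproduction of the label lookup above: softmax over the logits,
# then map the argmax index back to the attribute name.
import torch

rels = {'ethnicity': 0, 'birthplace': 1, 'education': 2}  # assumed label map
logits = torch.tensor([2.0, 0.3, -1.1])                   # assumed model output
probs = torch.softmax(logits, dim=-1)
label = list(rels.keys())[probs.argmax().item()]
confidence = probs.max().item()
print(f'{label}: {confidence:.2f}')
```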