commit 10916e91a8
parent df0009d59b
tlk-dsg, 2021-08-19 15:33:03 +08:00
75 changed files with 408 additions and 448 deletions

README.md

@@ -1,191 +1,75 @@

New README (added):

(Centered logo and badge block: build, license, documentation, release, Contributor Covenant and DOI badges; the logo and badge URLs still point at the Hugging Face Transformers repository and act as placeholders.)

<h3 align="center">
    <p>An open-source deep-learning-based framework for Chinese knowledge graph extraction</p>
</h3>

DeepKE provides a variety of knowledge extraction models.

## Online demo

Demo address:

1. NER
2. RE
    1. REGULAR
    2. FEW-SHOT
    3. DOCUMENT
3. AE

## Quick start

## Installation

Install with pip:

```
pip install deepke
```

## Model architecture

DeepKE contains the following modules; each item links to the readme of the corresponding sub-module:

1. NER
2. RE, which includes three sub-tasks:
    - **[REGULAR](https://github.com/tlk1997/deepke/blob/master/example/re/regular/re_regular.md)**
    - FEW-SHOT
    - DOCUMENT
3. AE

## Citation

Previous README (removed):

# DeepKE

DeepKE is a deep-learning toolkit for Chinese relation extraction built on PyTorch.

# Contributors

> Organization: [Knowledge Engine Lab, Zhejiang University](http://openkg.cn/)

> Mentors: 陈华钧, 张宁豫

Contributors: [@huajunsir](https://github.com/huajunsir), [@zxlzr](https://github.com/zxlzr), [@231sm](https://github.com/231sm), [@ruoxuwang](https://github.com/ruoxuwang), [@yezqNLP](https://github.com/yezqNLP), [@yuwl798180](https://github.com/yuwl798180), [@seventk](https://github.com/seventk)

## Requirements

> python >= 3.6

- torch >= 1.2
- hydra-core >= 0.11
- tensorboard >= 2.0
- matplotlib >= 3.1
- scikit-learn >= 0.22
- transformers >= 2.0
- jieba >= 0.39
- ~~pyhanlp >= 0.1.57~~ (used for Chinese dependency parsing, but it performs poorly on multi-clause sentences; suggestions for a better Chinese parser are welcome)

## Directory layout

```
├── conf                      # configuration directory
│   ├── config.yaml           # main configuration entry point
│   ├── preprocess.yaml       # data preprocessing settings
│   ├── train.yaml            # training hyperparameters
│   ├── hydra                 # hydra log output directory settings
│   ├── embedding.yaml        # embedding layer settings
│   ├── model                 # model configuration directory
│   │   ├── cnn.yaml          # cnn model parameters
│   │   ├── rnn.yaml          # rnn model parameters
│   │   ├── capsule.yaml      # capsule model parameters
│   │   ├── transformer.yaml  # transformer model parameters
│   │   ├── gcn.yaml          # gcn model parameters
│   │   ├── lm.yaml           # language-model parameters
├── pretrained                # weights of pretrained language models such as BERT
│   ├── vocab.txt             # BERT vocabulary
│   ├── config.json           # BERT model configuration
│   ├── pytorch_model.bin     # pretrained model weights
├── data                      # data directory
│   ├── origin                # raw data used for training
│   │   ├── train.csv         # training set
│   │   ├── valid.csv         # validation set
│   │   ├── test.csv          # test set
│   │   ├── relation.csv      # relation types
│   ├── out                   # output of preprocessing
├── module                    # reusable modules
│   ├── Embedding.py          # embedding layer
│   ├── CNN.py                # cnn
│   ├── RNN.py                # rnn
│   ├── Attention.py          # attention
│   ├── Transformer.py        # transformer
│   ├── Capsule.py            # capsule
│   ├── GCN.py                # gcn
├── models                    # model directory
│   ├── BasicModule.py        # shared model plumbing
│   ├── PCNN.py               # PCNN / CNN model
│   ├── BiLSTM.py             # BiLSTM model
│   ├── Transformer.py        # Transformer model
│   ├── LM.py                 # language model
│   ├── Capsule.py            # Capsule model
│   ├── GCN.py                # GCN model
├── tools                     # tools
│   ├── metrics.py            # evaluation metrics
│   ├── serializer.py         # sentence serialization used during preprocessing
│   ├── preprocess.py         # data preprocessing before training
│   ├── vocab.py              # token vocabulary construction
│   ├── dataset.py            # batching during training
│   ├── trainer.py            # training / validation loops
│   ├── main.py               # main entry point (training)
│   ├── predict.py            # prediction entry point (testing)
├── test                      # pytest tests
├── tutorial-notebooks        # simple jupyter notebook tutorials
├── utils                     # utility functions
│   ├── ioUtils.py            # io utilities
│   ├── nnUtils.py            # network utilities
├── README.md
```

## Quick start

The data are CSV files; a sample of the format:

sentence|relation|head|head_offset|tail|tail_offset
:---:|:---:|:---:|:---:|:---:|:---:
《岳父也是爹》是王军执导的电视剧,由马恩然、范明主演。|导演|岳父也是爹|1|王军|8
《九玄珠》是在纵横中文网连载的一部小说,作者是龙马。|连载网站|九玄珠|1|纵横中文网|7
提起杭州的美景,西湖总是第一个映入脑海的词语。|所在城市|西湖|8|杭州|2
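A minimal reading sketch (not part of the repository), assuming the CSV files carry a header row with the column names shown above:

```
import csv

with open('data/origin/train.csv', encoding='utf-8') as f:
    for row in csv.DictReader(f):
        sentence, relation = row['sentence'], row['relation']
        # head/tail are the entity strings; the offsets are their character positions in the sentence
        head, head_offset = row['head'], int(row['head_offset'])
        tail, tail_offset = row['tail'], int(row['tail_offset'])
```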
- Install dependencies: `pip install -r requirements.txt`
- Prepare the data: place the training data under `data/origin`. Three main files are required; for more data, the Baidu [Knowledge Extraction](http://ai.baidu.com/broad/download) corpus is recommended.
    - `train.csv`: training set
    - `valid.csv`: validation set
    - `test.csv`: test set
    - `relation.csv`: relation types
- Start training: `python main.py`
- Logs of each run are saved under `logs`, and model checkpoints under `checkpoints`.

## Details

See the [wiki](https://github.com/zjunlp/deepke/wiki).

## Notes (FAQ)

1. When using Anaconda, adding a domestic mirror such as the [Tsinghua mirror](https://mirrors.tuna.tsinghua.edu.cn/help/anaconda/) speeds up downloads considerably.
1. When using pip, a domestic mirror such as Aliyun is likewise recommended.
1. If `ModuleNotFoundError: No module named 'past'` appears after installation, run `pip install future`.
1. `python main.py --help` lists all configurable parameters, which can be overridden on the command line. Boolean parameters may be written as `1` / `0` instead of `True` / `False`.
    - e.g. `python main.py epoch=100 batch_size=128 use_gpu=False`
1. `python main.py xxx=xx,xx -m` runs a multirun sweep.
    - e.g. `python main.py model=cnn,rnn,lm chinese_split=0,1 -m` spawns 3*2=6 sub-tasks.
1. Chinese and English differ considerably in preprocessing; `serializer.py` is dedicated to serializing sentences into tokens, with jieba used for Chinese word segmentation.
    - English serialization has to handle casing, special punctuation, special characters, and whether to apply word-piece splitting.
    - Chinese serialization has to decide whether to segment into words, how to handle casing of embedded English letters, and whether to split English words into individual characters.
1. PCNN preprocessing splits each sentence into three segments around the head and tail entities for piecewise max pooling; a sentence that cannot be split into three segments cannot be handled by a uniform preprocessing step (a minimal sketch of this pooling follows this list).
    - For example, `杭州西湖` cannot be split into three segments no matter where it is cut.
    - The original paper splits as `[...head, ..., tail....]`; splitting as `[..., head...tail, ....]`, `[...head, ...tail, ....]` or `[..., head..., tail...]` works as well, with little practical difference.
1. Why is PCNN not better than CNN, and sometimes worse?
    - On the Baidu dataset, PCNN does not improve over the plain max pooling of CNN as much as expected; Baidu's [ARNOR](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/Research/ACL2019-ARNOR) results likewise show that PCNN is not necessarily better than CNN.
1. When using a pretrained language model, downloading it at run time is slow; downloading it in advance into the `pretrained` directory is recommended. See the `readme.md` inside that directory for the required files.
1. With little data, using the full 12-layer BERT directly does not work well. Possible remedies:
    - use fewer layers, e.g. 2 or 3;
    - further pretrain BERT on the task corpus with the language-model objective.
1. Using a GCN on a single sentence requires a dependency parse first, from which the word-adjacency matrix is built (edges adjacent in the parse tree are set to 1, all others to 0).
    - ~~`pyhanlp` was used to build the parse tree; it requires a Java runtime, see the [pyhanlp](https://github.com/hankcs/pyhanlp) documentation.~~ pyhanlp also performs poorly on multi-clause sentences and often treats a whole clause as a single node.
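As a minimal sketch of the piecewise max pooling mentioned above (illustrative only; the repository instead builds an `entities_pos` mask during preprocessing rather than slicing tensors like this):

```
import torch

def piecewise_max_pool(conv_out, head_pos, tail_pos):
    # conv_out: [L, H] convolution output of one sentence; assumes 0 < head_pos < tail_pos < L.
    # Split into three pieces around the two entity positions and max-pool each piece separately.
    pieces = [conv_out[:head_pos], conv_out[head_pos:tail_pos], conv_out[tail_pos:]]
    pooled = [p.max(dim=0).values for p in pieces]
    return torch.cat(pooled)  # [3 * H]
```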

benchmark/deepke目录.png (new binary file, 149 KiB, not shown)
example/ae/deepke目录.png (new binary file, 149 KiB, not shown)
(three further new binary image files, 149 KiB each, filenames not shown)


@@ -2,10 +2,10 @@ hydra:
  run:
    # Output directory for normal runs
-   dir: ../logs/${now:%Y-%m-%d_%H-%M-%S}
+   dir: logs/${now:%Y-%m-%d_%H-%M-%S}
  sweep:
    # Output directory for sweep runs
-   dir: ../logs/${now:%Y-%m-%d_%H-%M-%S}
+   dir: logs/${now:%Y-%m-%d_%H-%M-%S}
    # Output sub directory for sweep runs.
    subdir: ${hydra.job.num}_${hydra.job.id}
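With the `run.dir` setting above, Hydra changes the working directory of each run to `logs/<timestamp>`, so anything written with a relative path lands in that run's directory. A minimal sketch mirroring the entry points used in this commit (not an additional file of the commit):

```
import os
import hydra

@hydra.main(config_path='conf/config.yaml')
def main(cfg):
    # by this point Hydra has already switched the working directory to logs/<timestamp>
    print('run outputs are collected under:', os.getcwd())

if __name__ == '__main__':
    main()
```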



@@ -0,0 +1,156 @@
import os
import sys
import torch
import logging
import hydra
from hydra import utils
from deepke.tools import Serializer
from deepke.tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
import matplotlib.pyplot as plt
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from deepke.utils import load_pkl, load_csv
import deepke.models as models
logger = logging.getLogger(__name__)
def _preprocess_data(data, cfg):
vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
relation_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'relation.csv'), verbose=False)
rels = _handle_relation_data(relation_data)
cfg.vocab_size = vocab.count
serializer = Serializer(do_chinese_split=cfg.chinese_split)
serial = serializer.serialize
_serialize_sentence(data, serial, cfg)
_convert_tokens_into_index(data, vocab)
_add_pos_seq(data, cfg)
logger.info('start sentence preprocess...')
formats = '\nsentence: {}\nchinese_split: {}\nreplace_entity_with_type: {}\nreplace_entity_with_scope: {}\n' \
'tokens: {}\ntoken2idx: {}\nlength: {}\nhead_idx: {}\ntail_idx: {}'
logger.info(
formats.format(data[0]['sentence'], cfg.chinese_split, cfg.replace_entity_with_type,
cfg.replace_entity_with_scope, data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
data[0]['head_idx'], data[0]['tail_idx']))
return data, rels
def _get_predict_instance(cfg):
flag = input('是否使用范例[y/n],退出请输入: exit .... ')
flag = flag.strip().lower()
if flag == 'y' or flag == 'yes':
sentence = '《乡村爱情》是一部由知名导演赵本山在1985年所拍摄的农村青春偶像剧。'
head = '乡村爱情'
tail = '赵本山'
head_type = '电视剧'
tail_type = '人物'
elif flag == 'n' or flag == 'no':
sentence = input('请输入句子:')
head = input('请输入句中需要预测关系的头实体:')
head_type = input('请输入头实体类型可以为空按enter跳过')
tail = input('请输入句中需要预测关系的尾实体:')
tail_type = input('请输入尾实体类型可以为空按enter跳过')
elif flag == 'exit':
sys.exit(0)
    else:
        print('please input yes or no, or exit!')
        return _get_predict_instance(cfg)
instance = dict()
instance['sentence'] = sentence.strip()
instance['head'] = head.strip()
instance['tail'] = tail.strip()
if head_type.strip() == '' or tail_type.strip() == '':
cfg.replace_entity_with_type = False
instance['head_type'] = 'None'
instance['tail_type'] = 'None'
else:
instance['head_type'] = head_type.strip()
instance['tail_type'] = tail_type.strip()
return instance
@hydra.main(config_path='conf/config.yaml')
def main(cfg):
cwd = utils.get_original_cwd()
# cwd = cwd[0:-5]
cfg.cwd = cwd
cfg.pos_size = 2 * cfg.pos_limit + 2
print(cfg.pretty())
# get predict instance
instance = _get_predict_instance(cfg)
data = [instance]
# preprocess data
data, rels = _preprocess_data(data, cfg)
# model
__Model__ = {
'cnn': models.PCNN,
'rnn': models.BiLSTM,
'transformer': models.Transformer,
'gcn': models.GCN,
'capsule': models.Capsule,
'lm': models.LM,
}
    # prediction is best run on the CPU
cfg.use_gpu = False
if cfg.use_gpu and torch.cuda.is_available():
device = torch.device('cuda', cfg.gpu_id)
else:
device = torch.device('cpu')
logger.info(f'device: {device}')
model = __Model__[cfg.model_name](cfg)
logger.info(f'model name: {cfg.model_name}')
logger.info(f'\n {model}')
model.load(cfg.fp, device=device)
model.to(device)
model.eval()
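    # Assemble the model input for this single sentence: token indices and true length,
    # plus the relative-position features used by the non-LM models below.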
x = dict()
x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])
if cfg.model_name != 'lm':
x['head_pos'], x['tail_pos'] = torch.tensor([data[0]['head_pos']]), torch.tensor([data[0]['tail_pos']])
if cfg.model_name == 'cnn':
if cfg.use_pcnn:
x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
if cfg.model_name == 'gcn':
        # no suitable dependency-parsing tool was found, so the adjacency matrix is randomly initialised for now
adj = torch.empty(1,data[0]['seq_len'],data[0]['seq_len']).random_(2)
x['adj'] = adj
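        # Illustrative only: given a dependency parser that returns a head-token index for
        # each word, the adjacency matrix could be built from the parse instead, e.g.
        #   adj = torch.zeros(1, seq_len, seq_len)
        #   for child, head in enumerate(heads):
        #       if head >= 0:
        #           adj[0, child, head] = adj[0, head, child] = 1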
for key in x.keys():
x[key] = x[key].to(device)
with torch.no_grad():
y_pred = model(x)
y_pred = torch.softmax(y_pred, dim=-1)[0]
prob = y_pred.max().item()
prob_rel = list(rels.keys())[y_pred.argmax().item()]
logger.info(f"\"{data[0]['head']}\"\"{data[0]['tail']}\" 在句中关系为:\"{prob_rel}\",置信度为{prob:.2f}")
if cfg.predict_plot:
        # matplotlib does not display Chinese characters by default, so switch to a font that does
plt.rcParams["font.family"] = 'Arial Unicode MS'
x = list(rels.keys())
height = list(y_pred.cpu().numpy())
plt.bar(x, height)
for x, y in zip(x, height):
plt.text(x, y, '%.2f' % y, ha="center", va="bottom")
plt.xlabel('关系')
plt.ylabel('置信度')
plt.xticks(rotation=315)
plt.show()
if __name__ == '__main__':
main()


@@ -0,0 +1,38 @@
## Quick start
### Clone the repository
```
git clone git@github.com:zjunlp/DeepKE.git
```
### Install with pip
First create a Python virtual environment and activate it.
DeepKE can then be installed with:
```
pip install deepke
```
### Usage
First run training (all training parameters live in the conf directory; edit them there as needed):
```
python run.py
```
Then run prediction (edit predict.yaml in the conf directory to point to the model checkpoint to use):
```
python predict.py
```
## Model architecture
1. CNN
2. RNN
3. Capsule
4. GCN
5. Transformer
6. Pretrained language model


@@ -0,0 +1,8 @@
python == 3.8
torch == 1.5
hydra-core == 1.0.6
tensorboard == 2.4.1
matplotlib == 3.4.1
scikit-learn == 0.24.1
transformers == 4.5.0
jieba == 0.42.1

example/re/regular/run.py Normal file

@@ -0,0 +1,146 @@
import os
import hydra
import torch
import logging
import torch.nn as nn
from torch import optim
from hydra import utils
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# make the local deepke package importable
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
import deepke.models as models
from deepke.tools import preprocess , CustomDataset, collate_fn ,train, validate
from deepke.utils import manual_seed, load_pkl
logger = logging.getLogger(__name__)
@hydra.main(config_path="conf/config.yaml")
def main(cfg):
cwd = utils.get_original_cwd()
# cwd = cwd[0:-5]
cfg.cwd = cwd
cfg.pos_size = 2 * cfg.pos_limit + 2
logger.info(f'\n{cfg.pretty()}')
__Model__ = {
'cnn': models.PCNN,
'rnn': models.BiLSTM,
'transformer': models.Transformer,
'gcn': models.GCN,
'capsule': models.Capsule,
'lm': models.LM,
}
# device
if cfg.use_gpu and torch.cuda.is_available():
device = torch.device('cuda', cfg.gpu_id)
else:
device = torch.device('cpu')
logger.info(f'device: {device}')
    # if the preprocessing logic has not changed, it is best to comment this step out so the data are not re-preprocessed on every run
if cfg.preprocess:
preprocess(cfg)
train_data_path = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl')
valid_data_path = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl')
test_data_path = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl')
vocab_path = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')
if cfg.model_name == 'lm':
vocab_size = None
else:
vocab = load_pkl(vocab_path)
vocab_size = vocab.count
cfg.vocab_size = vocab_size
train_dataset = CustomDataset(train_data_path)
valid_dataset = CustomDataset(valid_data_path)
test_dataset = CustomDataset(test_data_path)
train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
model = __Model__[cfg.model_name](cfg)
model.to(device)
logger.info(f'\n {model}')
optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)
criterion = nn.CrossEntropyLoss()
best_f1, best_epoch = -1, 0
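    # Early-stopping bookkeeping: best validation loss / macro-f1 so far, the epochs where they
    # occurred, the number of epochs without improvement, and the corresponding checkpoint paths.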
es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, es_path, best_es_path = 1e8, -1, 0, 0, 0, -1, '', ''
train_losses, valid_losses = [], []
if cfg.show_plot and cfg.plot_utils == 'tensorboard':
writer = SummaryWriter('tensorboard')
else:
writer = None
logger.info('=' * 10 + ' Start training ' + '=' * 10)
for epoch in range(1, cfg.epoch + 1):
manual_seed(cfg.seed + epoch)
train_loss = train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg)
valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device, cfg)
scheduler.step(valid_loss)
model_path = model.save(epoch, cfg)
# logger.info(model_path)
train_losses.append(train_loss)
valid_losses.append(valid_loss)
if best_f1 < valid_f1:
best_f1 = valid_f1
best_epoch = epoch
        # use the validation loss as the early-stopping criterion
if es_loss > valid_loss:
es_loss = valid_loss
es_f1 = valid_f1
es_epoch = epoch
es_patience = 0
es_path = model_path
else:
es_patience += 1
if es_patience >= cfg.early_stopping_patience:
best_es_epoch = es_epoch
best_es_f1 = es_f1
best_es_path = es_path
if cfg.show_plot:
if cfg.plot_utils == 'matplot':
plt.plot(train_losses, 'x-')
plt.plot(valid_losses, '+-')
plt.legend(['train', 'valid'])
plt.title('train/valid comparison loss')
plt.show()
if cfg.plot_utils == 'tensorboard':
for i in range(len(train_losses)):
writer.add_scalars('train/valid_comparison_loss', {
'train': train_losses[i],
'valid': valid_losses[i]
}, i)
writer.close()
logger.info(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '
f'this epoch macro f1: {best_es_f1:0.4f}')
logger.info(f'this model save path: {best_es_path}')
logger.info(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '
f'this epoch macro f1: {best_f1:.4f}')
logger.info('=====end of training====')
logger.info('')
logger.info('=====start test performance====')
validate(-1, model, test_dataloader, criterion, device, cfg)
logger.info('=====ending====')
if __name__ == '__main__':
main()

@@ -1 +0,0 @@
test



@@ -1,38 +0,0 @@
import pytest
import torch
from utils import seq_len_to_mask
from module import DotAttention, MultiHeadAttention
torch.manual_seed(1)
q = torch.randn(4, 6, 20) # [B, L, H]
k = v = torch.randn(4, 5, 20) # [B, S, H]
key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
attention_mask = torch.tensor([1, 0, 0, 1, 0])  # positions where the value is 1 are masked out
head_mask = torch.tensor([0, 1, 0, 0])  # positions where the value is 1 are masked out
# m = DotAttention(dropout=0.0)
# ao,aw = m(q,k,v,key_padding_mask)
# print(ao.shape,aw.shape)
# print(aw)
def test_DotAttention():
m = DotAttention(dropout=0.0)
ao, aw = m(q, k, v, mask_out=key_padding_mask)
assert ao.shape == torch.Size([4, 6, 20])
assert aw.shape == torch.Size([4, 6, 5])
assert torch.all(aw[1, :, -1:].eq(0)) == torch.all(aw[2, :, -2:].eq(0)) == torch.all(aw[3, :, -3:].eq(0)) == True
def test_MultiHeadAttention():
m = MultiHeadAttention(embed_dim=20, num_heads=4, dropout=0.0)
ao, aw = m(q, k, v, key_padding_mask=key_padding_mask,attention_mask=attention_mask,head_mask=head_mask)
assert ao.shape == torch.Size([4, 6, 20])
assert aw.shape == torch.Size([4, 4, 6, 5])
assert aw.unbind(dim=1)[1].bool().any() == False
if __name__ == '__main__':
pytest.main()


@@ -1,32 +0,0 @@
import pytest
import torch
from module import CNN
from utils import seq_len_to_mask
class Config(object):
in_channels = 100
out_channels = 200
kernel_sizes = [3, 5, 7, 9, 11]
activation = 'gelu'
pooling_strategy = 'avg'
config = Config()
def test_CNN():
x = torch.randn(4, 5, 100)
seq = torch.arange(4, 0, -1)
mask = seq_len_to_mask(seq, max_len=5)
cnn = CNN(config)
out, out_pooling = cnn(x, mask=mask)
out_channels = config.out_channels * len(config.kernel_sizes)
assert out.shape == torch.Size([4, 5, out_channels])
assert out_pooling.shape == torch.Size([4, out_channels])
if __name__ == '__main__':
pytest.main()


@@ -1,38 +0,0 @@
import pytest
import torch
from module import Embedding
class Config(object):
vocab_size = 10
word_dim = 10
pos_size = 12 # 2 * pos_limit + 2
pos_dim = 5
dim_strategy = 'cat' # [cat, sum]
config = Config()
x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 3, 5, 0], [8, 4, 3, 0, 0]])
x_pos = torch.tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 0], [1, 2, 3, 0, 0]])
def test_Embedding_cat():
embed = Embedding(config)
feature = embed((x, x_pos))
dim = config.word_dim + config.pos_dim
assert feature.shape == torch.Size((3, 5, dim))
def test_Embedding_sum():
config.dim_strategy = 'sum'
embed = Embedding(config)
feature = embed((x, x_pos))
dim = config.word_dim
assert feature.shape == torch.Size((3, 5, dim))
if __name__ == '__main__':
pytest.main()


@@ -1,49 +0,0 @@
import pytest
import torch
from module import RNN
from utils import seq_len_to_mask
class Config(object):
type_rnn = 'LSTM'
input_size = 5
hidden_size = 4
num_layers = 3
dropout = 0.0
last_layer_hn = False
bidirectional = True
config = Config()
def test_CNN():
torch.manual_seed(1)
x = torch.tensor([[4, 3, 2, 1], [5, 6, 7, 0], [8, 10, 0, 0]])
x = torch.nn.Embedding(11, 5, padding_idx=0)(x) # B,L,H = 3,4,5
x_len = torch.tensor([4, 3, 2])
model = RNN(config)
output, hn = model(x, x_len)
B, L, _ = x.size()
H, N = config.hidden_size, config.num_layers
assert output.shape == torch.Size([B, L, H])
assert hn.shape == torch.Size([B, N, H])
config.bidirectional = False
model = RNN(config)
output, hn = model(x, x_len)
assert output.shape == torch.Size([B, L, H])
assert hn.shape == torch.Size([B, N, H])
config.last_layer_hn = True
model = RNN(config)
output, hn = model(x, x_len)
assert output.shape == torch.Size([B, L, H])
assert hn.shape == torch.Size([B, H])
if __name__ == '__main__':
pytest.main()


@@ -1,36 +0,0 @@
import pytest
from serializer import Serializer
def test_serializer_for_no_chinese_split():
text1 = "\nI\'m his pupp\'peer, and i have a ball\t"
text2 = '\t叫Stam一起到nba打篮球\n'
text3 = '\n\n现在时刻2014-04-08\t\t'
serializer = Serializer(do_chinese_split=False)
serial_text1 = serializer.serialize(text1)
serial_text2 = serializer.serialize(text2)
serial_text3 = serializer.serialize(text3)
assert serial_text1 == ['i', "'", 'm', 'his', 'pupp', "'", 'peer', ',', 'and', 'i', 'have', 'a', 'ball']
    assert serial_text2 == ['叫', 'stam', '一', '起', '到', 'nba', '打', '篮', '球']
    assert serial_text3 == ['现', '在', '时', '刻', '2014', '-', '04', '-', '08']
def test_serializer_for_chinese_split():
text1 = "\nI\'m his pupp\'peer, and i have a basketball\t"
text2 = '\t叫Stam一起到nba打篮球\n'
text3 = '\n\n现在时刻2014-04-08\t\t'
serializer = Serializer(do_chinese_split=True)
serial_text1 = serializer.serialize(text1)
serial_text2 = serializer.serialize(text2)
serial_text3 = serializer.serialize(text3)
assert serial_text1 == ['i', "'", 'm', 'his', 'pupp', "'", 'peer', ',', 'and', 'i', 'have', 'a', 'basketball']
    assert serial_text2 == ['叫', 'stam', '一起', '到', 'nba', '打篮球']
assert serial_text3 == ['现在', '时刻', '2014', '-', '04', '-', '08']
if __name__ == '__main__':
pytest.main()


@@ -1,40 +0,0 @@
import pytest
import torch
from module import Transformer
from utils import seq_len_to_mask
class Config():
hidden_size = 12
intermediate_size = 24
num_hidden_layers = 5
num_heads = 3
dropout = 0.0
layer_norm_eps = 1e-12
hidden_act = 'gelu_new'
output_attentions = True
output_hidden_states = True
config = Config()
def test_Transformer():
m = Transformer(config)
i = torch.randn(4, 5, 12) # [B, L, H]
key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
attention_mask = torch.tensor([1, 0, 0, 1, 0])  # positions where the value is 1 are masked out
head_mask = torch.tensor([0, 1, 0])  # positions where the value is 1 are masked out
out = m(i, key_padding_mask=key_padding_mask, attention_mask=attention_mask, head_mask=head_mask)
hn, h_all, att_weights = out
assert hn.shape == torch.Size([4, 5, 12])
assert torch.equal(h_all[0], i) and torch.equal(h_all[-1], hn) == True
assert len(h_all) == config.num_hidden_layers + 1
assert len(att_weights) == config.num_hidden_layers
assert att_weights[0].shape == torch.Size([4, 3, 5, 5])
assert att_weights[0].unbind(dim=1)[1].bool().any() == False
if __name__ == '__main__':
pytest.main()


@@ -1,38 +0,0 @@
import pytest
from serializer import Serializer
from vocab import Vocab
def test_vocab():
vocab = Vocab('test')
sent = ' 我是中国人,我爱中国。 I\'m Chinese, I love China'
serializer = Serializer(do_lower_case=True)
tokens = serializer.serialize(sent)
    assert tokens == [
        '我', '是', '中', '国', '人', ',', '我', '爱', '中', '国', '。', 'i', "'", 'm', 'chinese', ',', 'i', 'love', 'china'
    ]
vocab.add_words(tokens)
unk_str = '[UNK]'
unk_idx = vocab.word2idx[unk_str]
assert vocab.count == 22
assert len(vocab.word2idx) == len(vocab.idx2word) == len(vocab.word2idx) == 22
vocab.trim(2, verbose=False)
assert vocab.count == 11
assert len(vocab.word2idx) == len(vocab.idx2word) == len(vocab.word2idx) == 11
token2idx = [vocab.word2idx.get(i, unk_idx) for i in tokens]
assert len(tokens) == len(token2idx)
assert token2idx == [7, 1, 8, 9, 1, 1, 7, 1, 8, 9, 1, 10, 1, 1, 1, 1, 10, 1, 1]
idx2tokens = [vocab.idx2word.get(i, unk_str) for i in token2idx]
assert len(idx2tokens) == len(token2idx)
assert ' '.join(idx2tokens) == '我 [UNK] 中 国 [UNK] [UNK] 我 [UNK] 中 国 [UNK] i [UNK] [UNK] [UNK] [UNK] i [UNK] [UNK]'
if __name__ == '__main__':
pytest.main()