README.md

@@ -1,191 +1,75 @@
**New version (+75 lines):**

# DeepKE

<p align="center">
    <br>
    <img src="https://raw.githubusercontent.com/huggingface/transformers/master/docs/source/imgs/transformers_logo_name.png" width="400"/>
    <br>
<p>
<p align="center">
    <a href="https://circleci.com/gh/huggingface/transformers">
        <img alt="Build" src="https://img.shields.io/circleci/build/github/huggingface/transformers/master">
    </a>
    <a href="https://github.com/huggingface/transformers/blob/master/LICENSE">
        <img alt="GitHub" src="https://img.shields.io/github/license/huggingface/transformers.svg?color=blue">
    </a>
    <a href="https://huggingface.co/transformers/index.html">
        <img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/transformers/index.html.svg?down_color=red&down_message=offline&up_message=online">
    </a>
    <a href="https://github.com/huggingface/transformers/releases">
        <img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/transformers.svg">
    </a>
    <a href="https://github.com/huggingface/transformers/blob/master/CODE_OF_CONDUCT.md">
        <img alt="Contributor Covenant" src="https://img.shields.io/badge/Contributor%20Covenant-v2.0%20adopted-ff69b4.svg">
    </a>
    <a href="https://zenodo.org/badge/latestdoi/155220641"><img src="https://zenodo.org/badge/155220641.svg" alt="DOI"></a>
</p>

<h3 align="center">
    <p>An open-source Chinese knowledge graph extraction framework based on deep learning</p>
</h3>

<h3 align="center">
    <a href="https://hf.co/course"><img src="https://raw.githubusercontent.com/huggingface/transformers/master/docs/source/imgs/course_banner.png"></a>
</h3>

DeepKE provides a variety of knowledge extraction models.

## Online Demo

Demo links:

1. NER

2. RE
    1. REGULAR
    2. FEW-SHOT
    3. DOCUMENT

3. AE

## Getting Started

## Installation

Install with pip:

```
pip install deepke
```

## Model Architecture

DeepKE offers the following components (each item links to its module's readme):

1. NER

2. RE, which includes three sub-tasks:
    - **[REGULAR](https://github.com/tlk1997/deepke/blob/master/example/re/regular/re_regular.md)**
    - FEW-SHOT
    - DOCUMENT

3. AE

## Citation

---

**Old version (-191 lines, removed):**

# DeepKE

DeepKE is a PyTorch-based deep learning toolkit for Chinese relation extraction.

# Contributors

> Organization: [浙江大学知识引擎实验室](http://openkg.cn/)

> Mentors: 陈华钧, 张宁豫

---

<a href="https://github.com/huajunsir"><img src="https://avatars0.githubusercontent.com/u/1858627?s=64&v=4" width="48" height="48" alt="@huajunsir"></a>
<a href="https://github.com/zxlzr"><img src="https://avatars0.githubusercontent.com/u/1264492?s=64&v=4" width="48" height="48" alt="@zxlzr"></a>
<a href="https://github.com/231sm"><img src="https://avatars0.githubusercontent.com/u/26428692?s=64&v=4" width="48" height="48" alt="@231sm"></a>
<a href="https://github.com/ruoxuwang"><img src="https://avatars0.githubusercontent.com/u/19322627?s=64&v=4" width="48" height="48" alt="@ruoxuwang"></a>
<a href="https://github.com/yezqNLP"><img src="https://avatars0.githubusercontent.com/u/35182031?s=64&v=4" width="48" height="48" alt="@yezqNLP"></a>
<a href="https://github.com/yuwl798180"><img src="https://avatars0.githubusercontent.com/u/18118119?s=64&v=4" width="48" height="48" alt="@yuwl798180"></a>
<a href="https://github.com/seventk"><img src="https://avatars0.githubusercontent.com/u/37468830?s=64&v=4" width="48" height="48" alt="@seventk"></a>

## Requirements

> python >= 3.6

- torch >= 1.2
- hydra-core >= 0.11
- tensorboard >= 2.0
- matplotlib >= 3.1
- scikit-learn >= 0.22
- transformers >= 2.0
- jieba >= 0.39
- ~~pyhanlp >= 0.1.57~~ (used for Chinese dependency parsing, but it performs poorly on multi-clause sentences; recommendations for a better Chinese parser are welcome)

## Directory Layout

```
├── conf                     # configuration directory
│   ├── config.yaml          # main configuration entry point
│   ├── preprocess.yaml      # data preprocessing settings
│   ├── train.yaml           # training hyperparameters
│   ├── hydra                # hydra log output directory settings
│   ├── embedding.yaml       # embedding layer settings
│   ├── model                # model configuration directory
│   │   ├── cnn.yaml         # CNN model parameters
│   │   ├── rnn.yaml         # RNN model parameters
│   │   ├── capsule.yaml     # capsule model parameters
│   │   ├── transformer.yaml # transformer model parameters
│   │   ├── gcn.yaml         # GCN model parameters
│   │   ├── lm.yaml          # language model parameters
├── pretrained               # weights of pretrained language models such as BERT
│   ├── vocab.txt            # BERT vocabulary
│   ├── config.json          # BERT model configuration
│   ├── pytorch_model.bin    # pretrained weights
├── data                     # data directory
│   ├── origin               # raw datasets used for training
│   │   ├── train.csv        # training set
│   │   ├── valid.csv        # validation set
│   │   ├── test.csv         # test set
│   │   ├── relation.csv     # relation types
│   ├── out                  # output of preprocessing
├── module                   # reusable modules
│   ├── Embedding.py         # embedding layer
│   ├── CNN.py               # cnn
│   ├── RNN.py               # rnn
│   ├── Attention.py         # attention
│   ├── Transformer.py       # transformer
│   ├── Capsule.py           # capsule
│   ├── GCN.py               # gcn
├── models                   # model directory
│   ├── BasicModule.py       # basic model utilities
│   ├── PCNN.py              # PCNN / CNN model
│   ├── BiLSTM.py            # BiLSTM model
│   ├── Transformer.py       # Transformer model
│   ├── LM.py                # language model
│   ├── Capsule.py           # capsule model
│   ├── GCN.py               # GCN model
├── tools                    # tools directory
│   ├── metrics.py           # evaluation metrics
│   ├── serializer.py        # serializes sentences into tokens during preprocessing
│   ├── preprocess.py        # preprocessing before training
│   ├── vocab.py             # token vocabulary construction
│   ├── dataset.py           # batching during training
│   ├── trainer.py           # train / validation loop
│   ├── main.py              # main entry point (training)
│   ├── predict.py           # prediction entry point (testing)
├── test                     # pytest tests
├── tutorial-notebooks       # simple jupyter notebook tutorials
├── utils                    # common utility functions
│   ├── ioUtils.py           # io utilities
│   ├── nnUtils.py           # network utilities
├── README.md                # this file
```

## Quick Start

The data are CSV files; a sample is shown below (a loading sketch follows the quick-start steps):

sentence|relation|head|head_offset|tail|tail_offset
:---:|:---:|:---:|:---:|:---:|:---:
《岳父也是爹》是王军执导的电视剧,由马恩然、范明主演。|导演|岳父也是爹|1|王军|8
《九玄珠》是在纵横中文网连载的一部小说,作者是龙马。|连载网站|九玄珠|1|纵横中文网|7
提起杭州的美景,西湖总是第一个映入脑海的词语。|所在城市|西湖|8|杭州|2

- Install the dependencies: `pip install -r requirements.txt`
- Prepare the data: place the training data under `data/origin`. The training data consist mainly of three files; for more data, consider Baidu's [Knowledge Extraction](http://ai.baidu.com/broad/download) dataset.
  - `train.csv`: the training set
  - `valid.csv`: the validation set
  - `test.csv`: the test set
  - `relation.csv`: the relation types
- Start training: `python main.py`
- Logs for each run are saved under `logs`; model results are saved under `checkpoints`.
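As referenced above, a minimal sketch of loading this CSV format (assuming pandas is available; the path is illustrative):

```python
import pandas as pd

# columns: sentence, relation, head, head_offset, tail, tail_offset
df = pd.read_csv('data/origin/train.csv')
for _, row in df.iterrows():
    # head_offset / tail_offset are character offsets of the entities in the sentence
    assert row['sentence'][row['head_offset']:].startswith(str(row['head']))
    assert row['sentence'][row['tail_offset']:].startswith(str(row['tail']))
```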
## Details

See the [wiki](https://github.com/zjunlp/deepke/wiki).
## Notes (FAQ)

1. When using Anaconda, adding a domestic mirror such as the [Tsinghua mirror](https://mirrors.tuna.tsinghua.edu.cn/help/anaconda/) makes downloads much faster.
1. When using pip, a domestic mirror such as Aliyun's is likewise recommended for faster downloads.
1. If installation raises `ModuleNotFoundError: No module named 'past'`, run `pip install future` to fix it.
1. `python main.py --help` lists all configurable parameters, which can be overridden on the command line. Boolean parameters accept `1,0` in place of `True,False`.
    - e.g. `python main.py epoch=100 batch_size=128 use_gpu=False`
1. `python main.py xxx=xx,xx -m` runs a sweep over multiple tasks.
    - e.g. `python main.py model=cnn,rnn,lm chinese_split=0,1 -m` spawns 3*2=6 sub-tasks.
1. Chinese and English differ in many preprocessing details; `serializer.py` is dedicated to serializing sentences into tokens. Chinese word segmentation uses jieba.
    - English serialization must decide on: casing, handling of special punctuation, whether to split special English characters, whether to apply word-piece tokenization, etc.
    - Chinese serialization must decide on: whether to segment into words, how to case embedded English letters, whether to split English words into single characters, etc.
1. PCNN preprocessing splits each sentence into three segments around the head and tail positions and applies piecewise max pooling to each segment (sketched below). If a sentence cannot be split into three segments, it cannot be handled by this uniform preprocessing.
    - For example, the sentence `杭州西湖` cannot be split into three segments no matter where it is cut.
    - The original paper splits as `[...head, ..., tail....]`; `[..., head...tail, ....]`, `[...head, ...tail, ....]`, or `[..., head..., tail...]` also work, with little practical difference.
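    A minimal sketch of the piecewise max pooling described in this item (hypothetical tensor shapes and function name; not DeepKE's actual implementation):

    ```python
    import torch

    def piecewise_max_pool(feats: torch.Tensor, head_pos: int, tail_pos: int) -> torch.Tensor:
        """feats: [L, H] convolution features of one sentence; returns [3 * H]."""
        left, right = sorted((head_pos, tail_pos))
        # [...head, ..., tail....]: segment 1 ends at the head, segment 2 runs
        # up to the tail, segment 3 is the rest
        segments = [feats[:left + 1], feats[left + 1:right], feats[right:]]
        # a segment can be empty when head and tail are adjacent -- the
        # degenerate case discussed above
        pooled = [s.max(dim=0).values if s.size(0) > 0 else feats.new_zeros(feats.size(1))
                  for s in segments]
        return torch.cat(pooled)

    pooled = piecewise_max_pool(torch.randn(10, 8), head_pos=2, tail_pos=6)
    assert pooled.shape == torch.Size([24])
    ```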
1. Why is PCNN no better than CNN, and sometimes even worse?
    - On the Baidu dataset, PCNN does not improve over CNN as much as expected; most of the time it is even worse than plain max pooling. Baidu's [ARNOR](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/Research/ACL2019-ARNOR) results likewise show that PCNN is not necessarily better than CNN.
1. Downloading pretrained language models on the fly is slow; it is better to download them in advance into the `pretrained` directory. See the `readme.md` in that directory for the required files.
1. With a small dataset, using the full 12-layer BERT directly does not work well. Some remedies:
    - Reduce the number of layers, e.g. to 2 or 3, when data are scarce.
    - Further pretrain on the task corpus with BERT's language-model objective.
1. Using a GCN on a single sentence requires dependency parsing first, to build the adjacency matrix between words (edges adjacent in the parse tree are set to 1, others to 0); a construction sketch follows below.
    - ~~Currently the parse tree is built with the `pyhanlp` tool, which requires a Java package; see the [pyhanlp](https://github.com/hankcs/pyhanlp) docs.~~ pyhanlp also performs poorly on multi-clause sentences, often treating a whole clause as a single node.
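    A minimal sketch of building such an adjacency matrix, assuming a parser has already produced the index of each token's head (hypothetical input format; DeepKE used pyhanlp for the parsing step):

    ```python
    import torch

    def build_adjacency(head_indices):
        """head_indices[i] is the 1-based index of token i's parent in the
        dependency tree (0 for the root); returns a symmetric [L, L] 0/1 matrix."""
        L = len(head_indices)
        adj = torch.zeros(L, L)
        for i, h in enumerate(head_indices):
            if h > 0:  # the root has no parent edge
                adj[i, h - 1] = adj[h - 1, i] = 1.0
        return adj

    # toy 4-token sentence: token 2 is the root, tokens 1 and 3 attach to it,
    # token 4 attaches to token 3
    print(build_adjacency([2, 0, 2, 3]))
    ```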
@@ -2,10 +2,10 @@ hydra:
   run:
     # Output directory for normal runs
-    dir: ../logs/${now:%Y-%m-%d_%H-%M-%S}
+    dir: logs/${now:%Y-%m-%d_%H-%M-%S}

   sweep:
     # Output directory for sweep runs
-    dir: ../logs/${now:%Y-%m-%d_%H-%M-%S}
+    dir: logs/${now:%Y-%m-%d_%H-%M-%S}
     # Output sub directory for sweep runs.
     subdir: ${hydra.job.num}_${hydra.job.id}
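For reference, a minimal Python sketch of what the `${now:...}` interpolation resolves to (hydra applies `strftime` when the run starts, so each run gets its own timestamped log directory):

```python
from datetime import datetime

# roughly what hydra's ${now:%Y-%m-%d_%H-%M-%S} resolver produces
run_dir = "logs/" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
print(run_dir)  # e.g. logs/2021-04-30_12-00-00
```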
@@ -0,0 +1,156 @@
import os
import sys
import torch
import logging
import hydra
import matplotlib.pyplot as plt
from hydra import utils

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))

from deepke.tools import Serializer
from deepke.tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
from deepke.utils import load_pkl, load_csv
import deepke.models as models

logger = logging.getLogger(__name__)


def _preprocess_data(data, cfg):
    vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
    relation_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'relation.csv'), verbose=False)
    rels = _handle_relation_data(relation_data)
    cfg.vocab_size = vocab.count
    serializer = Serializer(do_chinese_split=cfg.chinese_split)
    serial = serializer.serialize

    logger.info('start sentence preprocess...')
    _serialize_sentence(data, serial, cfg)
    _convert_tokens_into_index(data, vocab)
    _add_pos_seq(data, cfg)
    formats = '\nsentence: {}\nchinese_split: {}\nreplace_entity_with_type: {}\nreplace_entity_with_scope: {}\n' \
              'tokens: {}\ntoken2idx: {}\nlength: {}\nhead_idx: {}\ntail_idx: {}'
    logger.info(
        formats.format(data[0]['sentence'], cfg.chinese_split, cfg.replace_entity_with_type,
                       cfg.replace_entity_with_scope, data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
                       data[0]['head_idx'], data[0]['tail_idx']))
    return data, rels


def _get_predict_instance(cfg):
    flag = input('是否使用范例[y/n],退出请输入: exit .... ')
    flag = flag.strip().lower()
    if flag == 'y' or flag == 'yes':
        sentence = '《乡村爱情》是一部由知名导演赵本山在1985年所拍摄的农村青春偶像剧。'
        head = '乡村爱情'
        tail = '赵本山'
        head_type = '电视剧'
        tail_type = '人物'
    elif flag == 'n' or flag == 'no':
        sentence = input('请输入句子:')
        head = input('请输入句中需要预测关系的头实体:')
        head_type = input('请输入头实体类型(可以为空,按enter跳过):')
        tail = input('请输入句中需要预测关系的尾实体:')
        tail_type = input('请输入尾实体类型(可以为空,按enter跳过):')
    elif flag == 'exit':
        sys.exit(0)
    else:
        print('please input yes or no, or exit!')
        # ask again until a valid choice is entered
        return _get_predict_instance(cfg)

    instance = dict()
    instance['sentence'] = sentence.strip()
    instance['head'] = head.strip()
    instance['tail'] = tail.strip()
    if head_type.strip() == '' or tail_type.strip() == '':
        cfg.replace_entity_with_type = False
        instance['head_type'] = 'None'
        instance['tail_type'] = 'None'
    else:
        instance['head_type'] = head_type.strip()
        instance['tail_type'] = tail_type.strip()

    return instance


@hydra.main(config_path='conf/config.yaml')
def main(cfg):
    cwd = utils.get_original_cwd()
    cfg.cwd = cwd
    cfg.pos_size = 2 * cfg.pos_limit + 2
    print(cfg.pretty())

    # get predict instance
    instance = _get_predict_instance(cfg)
    data = [instance]

    # preprocess data
    data, rels = _preprocess_data(data, cfg)

    # model
    __Model__ = {
        'cnn': models.PCNN,
        'rnn': models.BiLSTM,
        'transformer': models.Transformer,
        'gcn': models.GCN,
        'capsule': models.Capsule,
        'lm': models.LM,
    }

    # prediction is best run on the CPU
    cfg.use_gpu = False
    if cfg.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda', cfg.gpu_id)
    else:
        device = torch.device('cpu')
    logger.info(f'device: {device}')

    model = __Model__[cfg.model_name](cfg)
    logger.info(f'model name: {cfg.model_name}')
    logger.info(f'\n {model}')
    model.load(cfg.fp, device=device)
    model.to(device)
    model.eval()

    x = dict()
    x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])

    if cfg.model_name != 'lm':
        x['head_pos'], x['tail_pos'] = torch.tensor([data[0]['head_pos']]), torch.tensor([data[0]['tail_pos']])
        if cfg.model_name == 'cnn':
            if cfg.use_pcnn:
                x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
        if cfg.model_name == 'gcn':
            # no suitable dependency-parsing tool found yet, so the adjacency
            # matrix is randomly initialized for now
            adj = torch.empty(1, data[0]['seq_len'], data[0]['seq_len']).random_(2)
            x['adj'] = adj

    for key in x.keys():
        x[key] = x[key].to(device)

    with torch.no_grad():
        y_pred = model(x)
        y_pred = torch.softmax(y_pred, dim=-1)[0]
        prob = y_pred.max().item()
        prob_rel = list(rels.keys())[y_pred.argmax().item()]
        logger.info(f"\"{data[0]['head']}\" 和 \"{data[0]['tail']}\" 在句中关系为:\"{prob_rel}\",置信度为{prob:.2f}。")

    if cfg.predict_plot:
        # matplotlib cannot display Chinese with its default font
        plt.rcParams["font.family"] = 'Arial Unicode MS'
        rel_names = list(rels.keys())
        heights = list(y_pred.cpu().numpy())
        plt.bar(rel_names, heights)
        for rel, h in zip(rel_names, heights):
            plt.text(rel, h, '%.2f' % h, ha="center", va="bottom")
        plt.xlabel('关系')
        plt.ylabel('置信度')
        plt.xticks(rotation=315)
        plt.show()


if __name__ == '__main__':
    main()
@@ -0,0 +1,38 @@
## Getting Started

### Clone the code

```
git clone git@github.com:zjunlp/DeepKE.git
```

### Install with pip

First create a Python virtual environment and activate it, for example as shown below.
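For example, with conda (any virtual-environment tool works; the Python version follows requirements.txt):

```
conda create -n deepke python=3.8
conda activate deepke
```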
DeepKE can then be installed with:

```
pip install deepke
```

### Using the tool

First run training (all training parameters live in the conf folder; edit them there, or override them on the command line as shown after these steps):

```
python run.py
```

Then run prediction (first edit predict.yaml in the conf folder to set the path of the model to use):

```
python predict.py
```
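As noted above, because the configuration is managed by hydra, any value in conf can also be overridden on the command line instead of editing the yaml files (parameter names follow the training config, e.g. `epoch`, `batch_size`, `use_gpu` as used in run.py):

```
python run.py epoch=100 batch_size=128 use_gpu=False
```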
## Model Architecture

1. CNN
2. RNN
3. Capsule
4. GCN
5. Transformer
6. Pretrained language models
@@ -0,0 +1,8 @@
python == 3.8
torch == 1.5
hydra-core == 1.0.6
tensorboard == 2.4.1
matplotlib == 3.4.1
scikit-learn == 0.24.1
transformers == 4.5.0
jieba == 0.42.1
@@ -0,0 +1,146 @@
import os
import sys
import logging

import hydra
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch import optim
from hydra import utils
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))

import deepke.models as models
from deepke.tools import preprocess, CustomDataset, collate_fn, train, validate
from deepke.utils import manual_seed, load_pkl

logger = logging.getLogger(__name__)


@hydra.main(config_path="conf/config.yaml")
def main(cfg):
    cwd = utils.get_original_cwd()
    cfg.cwd = cwd
    cfg.pos_size = 2 * cfg.pos_limit + 2
    logger.info(f'\n{cfg.pretty()}')

    __Model__ = {
        'cnn': models.PCNN,
        'rnn': models.BiLSTM,
        'transformer': models.Transformer,
        'gcn': models.GCN,
        'capsule': models.Capsule,
        'lm': models.LM,
    }

    # device
    if cfg.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda', cfg.gpu_id)
    else:
        device = torch.device('cpu')
    logger.info(f'device: {device}')

    # if the preprocessing logic is unchanged, it is best to disable this step
    # so the data are not re-preprocessed on every run
    if cfg.preprocess:
        preprocess(cfg)

    train_data_path = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl')
    valid_data_path = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl')
    test_data_path = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl')
    vocab_path = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')

    if cfg.model_name == 'lm':
        vocab_size = None
    else:
        vocab = load_pkl(vocab_path)
        vocab_size = vocab.count
    cfg.vocab_size = vocab_size

    train_dataset = CustomDataset(train_data_path)
    valid_dataset = CustomDataset(valid_data_path)
    test_dataset = CustomDataset(test_data_path)

    train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
    valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
    test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))

    model = __Model__[cfg.model_name](cfg)
    model.to(device)
    logger.info(f'\n {model}')

    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)
    criterion = nn.CrossEntropyLoss()

    best_f1, best_epoch = -1, 0
    es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, es_path, best_es_path = 1e8, -1, 0, 0, 0, -1, '', ''
    train_losses, valid_losses = [], []

    if cfg.show_plot and cfg.plot_utils == 'tensorboard':
        writer = SummaryWriter('tensorboard')
    else:
        writer = None

    logger.info('=' * 10 + ' Start training ' + '=' * 10)

    for epoch in range(1, cfg.epoch + 1):
        manual_seed(cfg.seed + epoch)
        train_loss = train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg)
        valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device, cfg)
        scheduler.step(valid_loss)
        model_path = model.save(epoch, cfg)
        # logger.info(model_path)

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        if best_f1 < valid_f1:
            best_f1 = valid_f1
            best_epoch = epoch
        # use the validation loss as the early-stopping criterion
        if es_loss > valid_loss:
            es_loss = valid_loss
            es_f1 = valid_f1
            es_epoch = epoch
            es_patience = 0
            es_path = model_path
        else:
            es_patience += 1
            if es_patience >= cfg.early_stopping_patience:
                best_es_epoch = es_epoch
                best_es_f1 = es_f1
                best_es_path = es_path

    if cfg.show_plot:
        if cfg.plot_utils == 'matplot':
            plt.plot(train_losses, 'x-')
            plt.plot(valid_losses, '+-')
            plt.legend(['train', 'valid'])
            plt.title('train/valid comparison loss')
            plt.show()

        if cfg.plot_utils == 'tensorboard':
            for i in range(len(train_losses)):
                writer.add_scalars('train/valid_comparison_loss', {
                    'train': train_losses[i],
                    'valid': valid_losses[i]
                }, i)
            writer.close()

    logger.info(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '
                f'this epoch macro f1: {best_es_f1:0.4f}')
    logger.info(f'this model save path: {best_es_path}')
    logger.info(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '
                f'this epoch macro f1: {best_f1:.4f}')

    logger.info('=====end of training=====')
    logger.info('')
    logger.info('=====start test performance=====')
    validate(-1, model, test_dataloader, criterion, device, cfg)
    logger.info('=====ending=====')


if __name__ == '__main__':
    main()
@@ -1 +0,0 @@
-test
@@ -1,38 +0,0 @@
import pytest
import torch
from utils import seq_len_to_mask
from module import DotAttention, MultiHeadAttention

torch.manual_seed(1)
q = torch.randn(4, 6, 20)  # [B, L, H]
k = v = torch.randn(4, 5, 20)  # [B, S, H]
key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
attention_mask = torch.tensor([1, 0, 0, 1, 0])  # positions equal to 1 are masked out
head_mask = torch.tensor([0, 1, 0, 0])  # heads equal to 1 are masked out

# m = DotAttention(dropout=0.0)
# ao, aw = m(q, k, v, key_padding_mask)
# print(ao.shape, aw.shape)
# print(aw)


def test_DotAttention():
    m = DotAttention(dropout=0.0)
    ao, aw = m(q, k, v, mask_out=key_padding_mask)

    assert ao.shape == torch.Size([4, 6, 20])
    assert aw.shape == torch.Size([4, 6, 5])
    assert torch.all(aw[1, :, -1:].eq(0)) == torch.all(aw[2, :, -2:].eq(0)) == torch.all(aw[3, :, -3:].eq(0)) == True


def test_MultiHeadAttention():
    m = MultiHeadAttention(embed_dim=20, num_heads=4, dropout=0.0)
    ao, aw = m(q, k, v, key_padding_mask=key_padding_mask, attention_mask=attention_mask, head_mask=head_mask)

    assert ao.shape == torch.Size([4, 6, 20])
    assert aw.shape == torch.Size([4, 4, 6, 5])
    assert aw.unbind(dim=1)[1].bool().any() == False


if __name__ == '__main__':
    pytest.main()
@@ -1,32 +0,0 @@
import pytest
import torch
from module import CNN
from utils import seq_len_to_mask


class Config(object):
    in_channels = 100
    out_channels = 200
    kernel_sizes = [3, 5, 7, 9, 11]
    activation = 'gelu'
    pooling_strategy = 'avg'


config = Config()


def test_CNN():
    x = torch.randn(4, 5, 100)
    seq = torch.arange(4, 0, -1)
    mask = seq_len_to_mask(seq, max_len=5)

    cnn = CNN(config)
    out, out_pooling = cnn(x, mask=mask)
    out_channels = config.out_channels * len(config.kernel_sizes)
    assert out.shape == torch.Size([4, 5, out_channels])
    assert out_pooling.shape == torch.Size([4, out_channels])


if __name__ == '__main__':
    pytest.main()
@@ -1,38 +0,0 @@
import pytest
import torch
from module import Embedding


class Config(object):
    vocab_size = 10
    word_dim = 10
    pos_size = 12  # 2 * pos_limit + 2
    pos_dim = 5
    dim_strategy = 'cat'  # [cat, sum]


config = Config()

x = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 3, 5, 0], [8, 4, 3, 0, 0]])
x_pos = torch.tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 0], [1, 2, 3, 0, 0]])


def test_Embedding_cat():
    embed = Embedding(config)
    feature = embed((x, x_pos))
    dim = config.word_dim + config.pos_dim

    assert feature.shape == torch.Size((3, 5, dim))


def test_Embedding_sum():
    config.dim_strategy = 'sum'
    embed = Embedding(config)
    feature = embed((x, x_pos))
    dim = config.word_dim

    assert feature.shape == torch.Size((3, 5, dim))


if __name__ == '__main__':
    pytest.main()
@@ -1,49 +0,0 @@
import pytest
import torch
from module import RNN
from utils import seq_len_to_mask


class Config(object):
    type_rnn = 'LSTM'
    input_size = 5
    hidden_size = 4
    num_layers = 3
    dropout = 0.0
    last_layer_hn = False
    bidirectional = True


config = Config()


def test_RNN():
    torch.manual_seed(1)
    x = torch.tensor([[4, 3, 2, 1], [5, 6, 7, 0], [8, 10, 0, 0]])
    x = torch.nn.Embedding(11, 5, padding_idx=0)(x)  # B,L,H = 3,4,5
    x_len = torch.tensor([4, 3, 2])

    model = RNN(config)
    output, hn = model(x, x_len)

    B, L, _ = x.size()
    H, N = config.hidden_size, config.num_layers

    assert output.shape == torch.Size([B, L, H])
    assert hn.shape == torch.Size([B, N, H])

    config.bidirectional = False
    model = RNN(config)
    output, hn = model(x, x_len)
    assert output.shape == torch.Size([B, L, H])
    assert hn.shape == torch.Size([B, N, H])

    config.last_layer_hn = True
    model = RNN(config)
    output, hn = model(x, x_len)
    assert output.shape == torch.Size([B, L, H])
    assert hn.shape == torch.Size([B, H])


if __name__ == '__main__':
    pytest.main()
@@ -1,36 +0,0 @@
import pytest
from serializer import Serializer


def test_serializer_for_no_chinese_split():
    text1 = "\nI\'m his pupp\'peer, and i have a ball\t"
    text2 = '\t叫Stam一起到nba打篮球\n'
    text3 = '\n\n现在时刻2014-04-08\t\t'

    serializer = Serializer(do_chinese_split=False)
    serial_text1 = serializer.serialize(text1)
    serial_text2 = serializer.serialize(text2)
    serial_text3 = serializer.serialize(text3)

    assert serial_text1 == ['i', "'", 'm', 'his', 'pupp', "'", 'peer', ',', 'and', 'i', 'have', 'a', 'ball']
    assert serial_text2 == ['叫', 'stam', '一', '起', '到', 'nba', '打', '篮', '球']
    assert serial_text3 == ['现', '在', '时', '刻', '2014', '-', '04', '-', '08']


def test_serializer_for_chinese_split():
    text1 = "\nI\'m his pupp\'peer, and i have a basketball\t"
    text2 = '\t叫Stam一起到nba打篮球\n'
    text3 = '\n\n现在时刻2014-04-08\t\t'

    serializer = Serializer(do_chinese_split=True)
    serial_text1 = serializer.serialize(text1)
    serial_text2 = serializer.serialize(text2)
    serial_text3 = serializer.serialize(text3)

    assert serial_text1 == ['i', "'", 'm', 'his', 'pupp', "'", 'peer', ',', 'and', 'i', 'have', 'a', 'basketball']
    assert serial_text2 == ['叫', 'stam', '一起', '到', 'nba', '打篮球']
    assert serial_text3 == ['现在', '时刻', '2014', '-', '04', '-', '08']


if __name__ == '__main__':
    pytest.main()
@@ -1,40 +0,0 @@
import pytest
import torch
from module import Transformer
from utils import seq_len_to_mask


class Config():
    hidden_size = 12
    intermediate_size = 24
    num_hidden_layers = 5
    num_heads = 3
    dropout = 0.0
    layer_norm_eps = 1e-12
    hidden_act = 'gelu_new'
    output_attentions = True
    output_hidden_states = True


config = Config()


def test_Transformer():
    m = Transformer(config)
    i = torch.randn(4, 5, 12)  # [B, L, H]
    key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
    attention_mask = torch.tensor([1, 0, 0, 1, 0])  # positions equal to 1 are masked out
    head_mask = torch.tensor([0, 1, 0])  # heads equal to 1 are masked out

    out = m(i, key_padding_mask=key_padding_mask, attention_mask=attention_mask, head_mask=head_mask)
    hn, h_all, att_weights = out
    assert hn.shape == torch.Size([4, 5, 12])
    assert torch.equal(h_all[0], i) and torch.equal(h_all[-1], hn) == True
    assert len(h_all) == config.num_hidden_layers + 1
    assert len(att_weights) == config.num_hidden_layers
    assert att_weights[0].shape == torch.Size([4, 3, 5, 5])
    assert att_weights[0].unbind(dim=1)[1].bool().any() == False


if __name__ == '__main__':
    pytest.main()
@@ -1,38 +0,0 @@
import pytest
from serializer import Serializer
from vocab import Vocab


def test_vocab():
    vocab = Vocab('test')
    sent = ' 我是中国人,我爱中国。 I\'m Chinese, I love China'

    serializer = Serializer(do_lower_case=True)
    tokens = serializer.serialize(sent)
    assert tokens == [
        '我', '是', '中', '国', '人', ',', '我', '爱', '中', '国', '。', 'i', "'", 'm', 'chinese', ',', 'i', 'love', 'china'
    ]

    vocab.add_words(tokens)
    unk_str = '[UNK]'
    unk_idx = vocab.word2idx[unk_str]

    assert vocab.count == 22
    assert len(vocab.word2idx) == len(vocab.idx2word) == 22

    vocab.trim(2, verbose=False)

    assert vocab.count == 11
    assert len(vocab.word2idx) == len(vocab.idx2word) == 11

    token2idx = [vocab.word2idx.get(i, unk_idx) for i in tokens]
    assert len(tokens) == len(token2idx)
    assert token2idx == [7, 1, 8, 9, 1, 1, 7, 1, 8, 9, 1, 10, 1, 1, 1, 1, 10, 1, 1]

    idx2tokens = [vocab.idx2word.get(i, unk_str) for i in token2idx]
    assert len(idx2tokens) == len(token2idx)
    assert ' '.join(idx2tokens) == '我 [UNK] 中 国 [UNK] [UNK] 我 [UNK] 中 国 [UNK] i [UNK] [UNK] [UNK] [UNK] i [UNK] [UNK]'


if __name__ == '__main__':
    pytest.main()