commit 607ac898e3 (parent: db9d0e198c)
commit message: test

README.md: 16 lines changed
@@ -83,7 +83,7 @@ DeepKE provides a variety of knowledge extraction models.

 The data is a txt file; an example of the format:

-| sentence | Person | Location | Organization | Miscellaneous |
+| Sentence | Person | Location | Organization | Miscellaneous |
 | :---: | :---: | :---: | :---: | :---: |
 | Australian Tom Moody took six for 82 but Chris Adams, 123, and Tim O'Gorman, 109, took Derbyshire to 471 and a first innings lead of 233. | Tom Moody, Chris Adams, Tim O'Gorman | / | Derbyshire | Australian |
 | Irene, a master student in Zhejiang University, Hangzhou, is traveling in Warsaw for Chopin Music Festival. | Irene | Hangzhou, Warsaw | Zhejiang University | Chopin Music Festival |
@@ -95,6 +95,20 @@ DeepKE provides a variety of knowledge extraction models.

3. AE

   The data is a csv file; an example of the format:

   | Sentence | Attribute | Entity | Entity_offset | Attribute_value | Attribute_value_offset |
   | :---: | :---: | :---: | :---: | :---: | :---: |
   | 张冬梅,女,汉族,1968年2月生,河南淇县人,1988年7月加入中国共产党,1989年9月参加工作,中央党校经济管理专业毕业,中央党校研究生学历 | 民族 | 张冬梅 | 0 | 汉族 | 6 |
   | 杨缨,字绵公,号钓溪,松溪县人,祖籍将乐,是北宋理学家杨时的七世孙 | 朝代 | 杨缨 | 0 | 北宋 | 22 |
   | 2014年10月1日许鞍华执导的电影《黄金时代》上映,冯绍峰饰演与之差别极大的民国东北爷们萧军,演技受到肯定 | 上映时间 | 黄金时代 | 19 | 2014年10月1日 | 0 |

   For the detailed procedure, see the dedicated README; the following sub-module is included:

   **[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ae/regular/README.md)**

## Model Architecture

The architecture of DeepKE is shown below:
(Binary image file changed; not shown. Previous size: 149 KiB.)
@@ -0,0 +1,56 @@

## Quick Start

### Requirements

> python == 3.8

- torch == 1.5
- hydra-core == 1.0.6
- tensorboard == 2.4.1
- matplotlib == 3.4.1
- scikit-learn == 0.24.1
- transformers == 4.5.0
- jieba == 0.42.1
- deepke

### Clone the code

```
git clone git@github.com:zjunlp/DeepKE.git
```

### Install with pip

First create a Python virtual environment and activate it.

- Install the dependencies: ```pip install -r requirements.txt```

### Train and predict on your data

- Data layout: put the training data under the `data/origin` folder. The following files are used:

  - `train.csv`: the training set
  - `valid.csv`: the validation set
  - `test.csv`: the test set
  - `attribute.csv`: the attribute types

- Start training: ```python run.py``` (all training parameters live in the `conf` folder; edit them as needed)

- Logs for each run are saved under the `logs` folder, and model checkpoints under the `checkpoints` folder.

- Predict: ```python predict.py```

## Models

1. CNN
2. RNN
3. Capsule
4. GCN
5. Transformer
6. Pretrained language model
@@ -0,0 +1,17 @@

# ??? is a mandatory value.
# you should be able to set it without open_dict
# but if you try to read it before it's set an error will get thrown.

# populated at runtime
cwd: ???

defaults:
  - hydra/output: custom
  - preprocess
  - train
  - embedding
  - predict
  - model: cnn   # [cnn, rnn, transformer, capsule, gcn, lm]
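The `???` markers above are mandatory missing values: reading one before it is set raises an error, which is why the training and prediction entry points in this commit assign `cfg.cwd` and `cfg.pos_size` right at startup. A minimal standalone sketch of that behaviour with plain OmegaConf (the library hydra-core ships with), outside the full Hydra app:

```python
# Minimal sketch: how '???' mandatory values behave (plain OmegaConf, not the full Hydra app).
from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create({"cwd": "???", "pos_limit": 30})

try:
    _ = cfg.cwd                       # reading before it is populated raises an error
except MissingMandatoryValue:
    print("cwd is still missing")

cfg.cwd = "/path/to/project"          # populated at runtime, as the entry points do
cfg.pos_size = 2 * cfg.pos_limit + 2  # same pattern used for pos_size in embedding.yaml
print(cfg.cwd, cfg.pos_size)          # /path/to/project 62
```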
@@ -0,0 +1,10 @@

# populated at runtime
vocab_size: ???
word_dim: 60
pos_size: ???   # 2 * pos_limit + 2
pos_dim: 10     # ignored when dim_strategy is 'sum'; forced to equal word_dim

dim_strategy: sum   # [cat, sum]

# number of attribute classes
num_attributes: 7
@@ -0,0 +1,11 @@

hydra:

  run:
    # Output directory for normal runs
    dir: logs/${now:%Y-%m-%d_%H-%M-%S}

  sweep:
    # Output directory for sweep runs
    dir: logs/${now:%Y-%m-%d_%H-%M-%S}
    # Output sub directory for sweep runs.
    subdir: ${hydra.job.num}_${hydra.job.id}
@@ -0,0 +1,20 @@

model_name: capsule

share_weights: True
num_iterations: 5   # number of routing iterations
dropout: 0.3

input_dim_capsule: ???   # taken from the upstream convolution, usually its hidden_size
dim_capsule: 50          # dimension of each output capsule
num_capsule: ???         # number of output capsules; equals the number of classes, == num_attributes

# the primary capsules can be built from
# embedding / cnn / rnn
# cnn is used for now
in_channels: ???      # comes from the embedding output; no need to set it
out_channels: 100     # == input_dim_capsule
kernel_sizes: [9]     # must be odd, and fairly large
activation: 'lrelu'   # [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
keep_length: False    # no padding needed; it mostly adds noise
pooling_strategy: cls # irrelevant here, never used
@@ -0,0 +1,13 @@

model_name: cnn

in_channels: ???          # comes from the embedding output; no need to set it
out_channels: 100
kernel_sizes: [3, 5, 7]   # must be odd so the CNN output keeps the sentence length
activation: 'gelu'        # [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
pooling_strategy: 'max'   # [max, avg, cls]
keep_length: True
dropout: 0.3

# pcnn
use_pcnn: False
intermediate: 80
@@ -0,0 +1,7 @@

model_name: gcn

num_layers: 3

input_size: ???   # comes from the embedding output; no need to set it
hidden_size: 100
dropout: 0.3
@@ -0,0 +1,20 @@

model_name: lm

# location of the pretrained language model weights
# lm_name = 'bert-base-chinese' # download usage
#lm_file: 'pretrained'
lm_file: '/home/yhy/transformers/bert-base-chinese'

# number of transformer layers; base BERT starts with 12
# with small datasets, fewer layers often converge faster and work better
num_hidden_layers: 1

# parameters of the BiLSTM stacked on top
type_rnn: 'LSTM'    # [RNN, GRU, LSTM]
input_size: 768     # determined by BERT
hidden_size: 100    # must be even
num_layers: 1
dropout: 0.3
bidirectional: True
last_layer_hn: True
@@ -0,0 +1,10 @@

model_name: rnn

type_rnn: 'LSTM'   # [RNN, GRU, LSTM]

input_size: ???    # comes from the embedding output; no need to set it
hidden_size: 150   # must be even
num_layers: 2
dropout: 0.3
bidirectional: True
last_layer_hn: True
@@ -0,0 +1,12 @@

model_name: transformer

hidden_size: ???   # comes from the embedding output; no need to set it
num_heads: 4       # must divide hidden_size evenly
num_hidden_layers: 3
intermediate_size: 256
dropout: 0.1
layer_norm_eps: 1e-12
hidden_act: gelu_new   # [relu, gelu, swish, gelu_new]

output_attentions: True
output_hidden_states: True
@@ -0,0 +1,2 @@

# path of the saved model to load
fp: 'xxx/checkpoints/2019-12-03_17-35-30/cnn_epoch21.pth'
@@ -0,0 +1,20 @@

# whether the data needs preprocessing
# when the preprocessing parameters have not changed, there is no need to preprocess again
preprocess: True

# where the raw data lives
data_path: 'data/origin'

# where the preprocessed files are written
out_path: 'data/out'

# whether to run word segmentation
chinese_split: True

# minimum token frequency when building the vocab
min_freq: 3

# sentence-length limit: the position of each token relative to the entity is clipped to this range
# e.g. [-30, 30]; at embedding time everything is shifted by +31, giving [1, 61]
# that makes 62 position tokens in total, with 0 reserved for padding
pos_limit: 30
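A minimal sketch of the position encoding that `pos_limit` describes (the project's own `_add_pos_seq` helper lives in the preprocess module, which is not part of this diff, so the function below is only illustrative): positions relative to the entity are clipped to [-pos_limit, pos_limit] and shifted by pos_limit + 1, so that index 0 stays reserved for padding and pos_size = 2 * pos_limit + 2.

```python
# Illustrative only: clip token positions relative to an entity to [-pos_limit, pos_limit],
# then shift by pos_limit + 1 so values land in [1, 2 * pos_limit + 1] and 0 is left for padding.
def relative_positions(seq_len, entity_offset, pos_limit=30):
    pos = []
    for i in range(seq_len):
        p = max(-pos_limit, min(pos_limit, i - entity_offset))  # clip
        pos.append(p + pos_limit + 1)                            # shift
    return pos

print(relative_positions(seq_len=5, entity_offset=2))  # [29, 30, 31, 32, 33]
```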
@@ -0,0 +1,21 @@

seed: 1

use_gpu: True
gpu_id: 0

epoch: 50
batch_size: 32
learning_rate: 3e-4
lr_factor: 0.7     # learning-rate decay factor
lr_patience: 3     # epochs to wait before decaying the learning rate
weight_decay: 1e-3 # L2 regularization

early_stopping_patience: 6

train_log: True
log_interval: 10
show_plot: True
only_comparison_plot: False
plot_utils: matplot   # [matplot, tensorboard]

predict_plot: True
@@ -0,0 +1,8 @@

attribute,index
None,0
民族,1
字,2
朝代,3
身高,4
创始人,5
上映时间,6
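These seven rows (including `None`) are the label set behind `num_attributes: 7` in `embedding.yaml`. The project's own `_handle_attribute_data` loader is imported by the prediction script but is not part of this diff, so the sketch below only shows the shape of the mapping it is expected to produce:

```python
# Illustrative sketch: read attribute.csv into {attribute name: class index}.
import csv

def load_attributes(path='data/origin/attribute.csv'):
    atts = {}
    with open(path, encoding='utf-8') as f:
        for row in csv.DictReader(f):
            atts[row['attribute']] = int(row['index'])
    return atts

atts = load_attributes()
print(len(atts))        # 7, matching num_attributes in embedding.yaml
print(atts['上映时间'])  # 6
```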
(Three file diffs suppressed because they are too large.)
@@ -0,0 +1,147 @@

import os
import sys
import torch
import logging
import hydra
from hydra import utils
from serializer import Serializer
from preprocess import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data
import matplotlib.pyplot as plt
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from utils import load_pkl, load_csv
import models as models


logger = logging.getLogger(__name__)


def _preprocess_data(data, cfg):
    vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
    attribute_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'attribute.csv'), verbose=False)
    atts = _handle_attribute_data(attribute_data)
    cfg.vocab_size = vocab.count
    serializer = Serializer(do_chinese_split=cfg.chinese_split)
    serial = serializer.serialize

    _serialize_sentence(data, serial, cfg)
    _convert_tokens_into_index(data, vocab)
    _add_pos_seq(data, cfg)
    logger.info('start sentence preprocess...')
    formats = '\nsentence: {}\nchinese_split: {}\n' \
              'tokens: {}\ntoken2idx: {}\nlength: {}\nentity_index: {}\nattribute_value_index: {}'
    logger.info(
        formats.format(data[0]['sentence'], cfg.chinese_split,
                       data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
                       data[0]['entity_index'], data[0]['attribute_value_index']))
    return data, atts


def _get_predict_instance(cfg):
    flag = input('是否使用范例[y/n],退出请输入: exit .... ')   # "use the built-in example? [y/n]; type exit to quit"
    flag = flag.strip().lower()
    if flag == 'y' or flag == 'yes':
        sentence = '张冬梅,女,汉族,1968年2月生,河南淇县人,1988年7月加入中国共产党,1989年9月参加工作,中央党校经济管理专业毕业,中央党校研究生学历'
        entity = '张冬梅'
        attribute_value = '汉族'
    elif flag == 'n' or flag == 'no':
        sentence = input('请输入句子:')                       # "enter a sentence"
        entity = input('请输入句中需要预测的实体:')            # "enter the entity to predict for"
        attribute_value = input('请输入句中需要预测的属性值:')  # "enter the attribute value to predict for"
    elif flag == 'exit':
        sys.exit(0)
    else:
        print('please input yes or no, or exit!')
        # retry on invalid input; the original fell through with undefined variables
        return _get_predict_instance(cfg)

    instance = dict()
    instance['sentence'] = sentence.strip()
    instance['entity'] = entity.strip()
    instance['attribute_value'] = attribute_value.strip()
    instance['entity_offset'] = sentence.find(entity)
    instance['attribute_value_offset'] = sentence.find(attribute_value)

    return instance


@hydra.main(config_path='../conf/config.yaml')
def main(cfg):
    cwd = utils.get_original_cwd()
    cwd = cwd[0:-5]   # strip the trailing sub-directory so paths resolve from the project root
    cfg.cwd = cwd
    cfg.pos_size = 2 * cfg.pos_limit + 2
    print(cfg.pretty())

    # get predict instance
    instance = _get_predict_instance(cfg)
    data = [instance]

    # preprocess data
    data, rels = _preprocess_data(data, cfg)

    # model
    __Model__ = {
        'cnn': models.PCNN,
        'rnn': models.BiLSTM,
        'transformer': models.Transformer,
        'gcn': models.GCN,
        'capsule': models.Capsule,
        'lm': models.LM,
    }

    # prediction is best done on the CPU
    cfg.use_gpu = False
    if cfg.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda', cfg.gpu_id)
    else:
        device = torch.device('cpu')
    logger.info(f'device: {device}')

    model = __Model__[cfg.model_name](cfg)
    logger.info(f'model name: {cfg.model_name}')
    logger.info(f'\n {model}')
    model.load(cfg.fp, device=device)
    model.to(device)
    model.eval()

    x = dict()
    x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])

    if cfg.model_name != 'lm':
        x['entity_pos'], x['attribute_value_pos'] = torch.tensor([data[0]['entity_pos']]), torch.tensor([data[0]['attribute_value_pos']])
        if cfg.model_name == 'cnn':
            if cfg.use_pcnn:
                x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
        if cfg.model_name == 'gcn':
            # no suitable parsing-tree tool found yet, so the adjacency matrix is randomly initialized for now
            adj = torch.empty(1, data[0]['seq_len'], data[0]['seq_len']).random_(2)
            x['adj'] = adj

    for key in x.keys():
        x[key] = x[key].to(device)

    with torch.no_grad():
        y_pred = model(x)
        y_pred = torch.softmax(y_pred, dim=-1)[0]
        prob = y_pred.max().item()
        prob_att = list(rels.keys())[y_pred.argmax().item()]
        # logs the predicted attribute between the entity and the attribute value, with its confidence
        logger.info(f"\"{data[0]['entity']}\" 和 \"{data[0]['attribute_value']}\" 在句中属性为:\"{prob_att}\",置信度为{prob:.2f}。")

    if cfg.predict_plot:
        plt.rcParams["font.family"] = 'Arial Unicode MS'
        x = list(rels.keys())
        height = list(y_pred.cpu().numpy())
        plt.bar(x, height)
        for x, y in zip(x, height):
            plt.text(x, y, '%.2f' % y, ha="center", va="bottom")
        plt.xlabel('关系')    # "relation"
        plt.ylabel('置信度')  # "confidence"
        plt.xticks(rotation=315)
        plt.show()


if __name__ == '__main__':
    main()
@@ -0,0 +1,8 @@

torch == 1.5
hydra-core == 1.0.6
tensorboard == 2.4.1
matplotlib == 3.4.1
scikit-learn == 0.24.1
transformers == 4.5.0
jieba == 0.42.1
deepke
@@ -0,0 +1,148 @@

import os
import hydra
import torch
import logging
import torch.nn as nn
from torch import optim
from hydra import utils
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# self
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
import models as models
from tools import preprocess, CustomDataset, collate_fn, train, validate
from utils import manual_seed, load_pkl


logger = logging.getLogger(__name__)


@hydra.main(config_path="../conf/config.yaml")
def main(cfg):
    cwd = utils.get_original_cwd()
    cwd = cwd[0:-5]   # strip the trailing sub-directory so paths resolve from the project root
    cfg.cwd = cwd
    cfg.pos_size = 2 * cfg.pos_limit + 2
    logger.info(f'\n{cfg.pretty()}')

    __Model__ = {
        'cnn': models.PCNN,
        'rnn': models.BiLSTM,
        'transformer': models.Transformer,
        'gcn': models.GCN,
        'capsule': models.Capsule,
        'lm': models.LM,
    }

    # device
    if cfg.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda', cfg.gpu_id)
    else:
        device = torch.device('cpu')
    logger.info(f'device: {device}')

    # if the preprocessing itself does not change, it is best to skip this step instead of re-running it every time
    if cfg.preprocess:
        preprocess(cfg)

    train_data_path = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl')
    valid_data_path = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl')
    test_data_path = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl')
    vocab_path = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')

    if cfg.model_name == 'lm':
        vocab_size = None
    else:
        vocab = load_pkl(vocab_path)
        vocab_size = vocab.count
    cfg.vocab_size = vocab_size

    train_dataset = CustomDataset(train_data_path)
    valid_dataset = CustomDataset(valid_data_path)
    test_dataset = CustomDataset(test_data_path)

    train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
    valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
    test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))

    model = __Model__[cfg.model_name](cfg)
    model.to(device)
    logger.info(f'\n {model}')

    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)
    criterion = nn.CrossEntropyLoss()

    best_f1, best_epoch = -1, 0
    es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, es_path, best_es_path = 1e8, -1, 0, 0, 0, -1, '', ''
    train_losses, valid_losses = [], []

    if cfg.show_plot and cfg.plot_utils == 'tensorboard':
        writer = SummaryWriter('tensorboard')
    else:
        writer = None

    logger.info('=' * 10 + ' Start training ' + '=' * 10)

    for epoch in range(1, cfg.epoch + 1):
        manual_seed(cfg.seed + epoch)
        train_loss = train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg)
        valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device, cfg)
        scheduler.step(valid_loss)
        model_path = model.save(epoch, cfg)
        # logger.info(model_path)

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        if best_f1 < valid_f1:
            best_f1 = valid_f1
            best_epoch = epoch
        # the validation loss is used as the early-stopping criterion
        if es_loss > valid_loss:
            es_loss = valid_loss
            es_f1 = valid_f1
            es_epoch = epoch
            es_patience = 0
            es_path = model_path
        else:
            es_patience += 1
            if es_patience >= cfg.early_stopping_patience:
                best_es_epoch = es_epoch
                best_es_f1 = es_f1
                best_es_path = es_path

    if cfg.show_plot:
        if cfg.plot_utils == 'matplot':
            plt.plot(train_losses, 'x-')
            plt.plot(valid_losses, '+-')
            plt.legend(['train', 'valid'])
            plt.title('train/valid comparison loss')
            plt.show()

        if cfg.plot_utils == 'tensorboard':
            for i in range(len(train_losses)):
                writer.add_scalars('train/valid_comparison_loss', {
                    'train': train_losses[i],
                    'valid': valid_losses[i]
                }, i)
            writer.close()

    logger.info(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '
                f'this epoch macro f1: {best_es_f1:0.4f}')
    logger.info(f'this model save path: {best_es_path}')
    logger.info(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '
                f'this epoch macro f1: {best_f1:.4f}')

    logger.info('=====end of training====')
    logger.info('')
    logger.info('=====start test performance====')
    validate(-1, model, test_dataloader, criterion, device, cfg)
    logger.info('=====ending====')


if __name__ == '__main__':
    main()

    # python predict.py --help            # show parameter help
    # python predict.py -c
    # python predict.py chinese_split=0,1 replace_entity_with_type=0,1 -m
(Binary image file changed; not shown. Previous size: 149 KiB.)
@@ -0,0 +1,35 @@

import os
import time
import torch
import torch.nn as nn


class BasicModule(nn.Module):
    '''
    Wraps nn.Module and provides save and load methods.
    '''
    def __init__(self):
        super(BasicModule, self).__init__()

    def load(self, path, device):
        '''
        Load the model from the given path.
        '''
        self.load_state_dict(torch.load(path, map_location=device))

    def save(self, epoch=0, cfg=None):
        '''
        Save the model; by default the file name is "model name + timestamp".
        '''
        time_prefix = time.strftime('%Y-%m-%d_%H-%M-%S')
        prefix = os.path.join(cfg.cwd, 'checkpoints', time_prefix)
        os.makedirs(prefix, exist_ok=True)
        name = os.path.join(prefix, cfg.model_name + '_' + f'epoch{epoch}' + '.pth')

        torch.save(self.state_dict(), name)
        return name
@@ -0,0 +1,26 @@

import torch.nn as nn
from . import BasicModule
from module import Embedding, RNN


class BiLSTM(BasicModule):
    def __init__(self, cfg):
        super(BiLSTM, self).__init__()

        if cfg.dim_strategy == 'cat':
            cfg.input_size = cfg.word_dim + 2 * cfg.pos_dim
        else:
            cfg.input_size = cfg.word_dim

        self.embedding = Embedding(cfg)
        self.bilstm = RNN(cfg)
        self.fc = nn.Linear(cfg.hidden_size, cfg.num_attributes)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        word, lens, entity_pos, attribute_value_pos = x['word'], x['lens'], x['entity_pos'], x['attribute_value_pos']
        inputs = self.embedding(word, entity_pos, attribute_value_pos)
        out, out_pool = self.bilstm(inputs, lens)
        output = self.fc(out_pool)

        return output
@@ -0,0 +1,51 @@

import torch
from . import BasicModule
from module import Embedding, CNN
from module import Capsule as CapsuleLayer

from utils import seq_len_to_mask, to_one_hot


class Capsule(BasicModule):
    def __init__(self, cfg):
        super(Capsule, self).__init__()

        if cfg.dim_strategy == 'cat':
            cfg.in_channels = cfg.word_dim + 2 * cfg.pos_dim
        else:
            cfg.in_channels = cfg.word_dim

        # capsule config
        cfg.input_dim_capsule = cfg.out_channels
        cfg.num_capsule = cfg.num_attributes

        self.num_attributes = cfg.num_attributes
        self.embedding = Embedding(cfg)
        self.cnn = CNN(cfg)
        self.capsule = CapsuleLayer(cfg)

    def forward(self, x):
        word, lens, entity_pos, attribute_value_pos = x['word'], x['lens'], x['entity_pos'], x['attribute_value_pos']
        mask = seq_len_to_mask(lens)
        inputs = self.embedding(word, entity_pos, attribute_value_pos)

        # the CNN changes the length, so a targeted mask is not possible; skipping the mask
        # is acceptable since primary capsules only carry coarse-grained information
        primary, _ = self.cnn(inputs)
        output = self.capsule(primary)
        output = output.norm(p=2, dim=-1)  # take the capsule lengths (L2 norms) as class scores

        return output  # [B, N]

    def loss(self, predict, target, reduction='mean'):
        m_plus, m_minus, loss_lambda = 0.9, 0.1, 0.5

        target = to_one_hot(target, self.num_attributes)
        max_l = (torch.relu(m_plus - predict))**2
        max_r = (torch.relu(predict - m_minus))**2
        loss = target * max_l + loss_lambda * (1 - target) * max_r
        loss = torch.sum(loss, dim=-1)

        if reduction == 'sum':
            return loss.sum()
        else:
            # default: mean
            return loss.mean()
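The `loss` method above is the capsule-network margin loss; with m⁺ = 0.9, m⁻ = 0.1 and λ = 0.5 as in the code, each class k contributes

$$L_k \;=\; T_k \,\max(0,\, m^{+} - \lVert\mathbf{v}_k\rVert)^2 \;+\; \lambda\,(1 - T_k)\,\max(0,\, \lVert\mathbf{v}_k\rVert - m^{-})^2,$$

where T_k is the one-hot target produced by `to_one_hot` and ‖v_k‖ is the capsule length returned by `forward`; the per-sentence loss is the sum over k, then averaged (or summed) over the batch.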
@@ -0,0 +1,32 @@

import torch
import torch.nn as nn
from . import BasicModule
from module import Embedding
from module import GCN as GCNBlock

from utils import seq_len_to_mask


class GCN(BasicModule):
    def __init__(self, cfg):
        super(GCN, self).__init__()

        if cfg.dim_strategy == 'cat':
            cfg.input_size = cfg.word_dim + 2 * cfg.pos_dim
        else:
            cfg.input_size = cfg.word_dim

        self.embedding = Embedding(cfg)
        self.gcn = GCNBlock(cfg)
        self.fc = nn.Linear(cfg.hidden_size, cfg.num_attributes)

    def forward(self, x):
        word, lens, entity_pos, attribute_value_pos, adj = x['word'], x['lens'], x['entity_pos'], x['attribute_value_pos'], x['adj']

        inputs = self.embedding(word, entity_pos, attribute_value_pos)
        output = self.gcn(inputs, adj)
        output = output.max(dim=1)[0]
        output = self.fc(output)

        return output
@@ -0,0 +1,27 @@

from torch import nn
from . import BasicModule
from module import RNN
from transformers import BertModel

from utils import seq_len_to_mask


class LM(BasicModule):
    def __init__(self, cfg):
        super(LM, self).__init__()
        self.bert = BertModel.from_pretrained(cfg.lm_file, num_hidden_layers=cfg.num_hidden_layers)
        self.bilstm = RNN(cfg)
        self.fc = nn.Linear(cfg.hidden_size, cfg.num_attributes)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        word, lens = x['word'], x['lens']
        mask = seq_len_to_mask(lens, mask_pos_to_true=False)
        a = self.bert(word, attention_mask=mask)
        last_hidden_state = a[0]
        # pooler_output = a[1]
        _, out_pool = self.bilstm(last_hidden_state, lens)
        out_pool = self.dropout(out_pool)
        output = self.fc(out_pool)

        return output
@@ -0,0 +1,57 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from . import BasicModule
from module import Embedding, CNN

from utils import seq_len_to_mask


class PCNN(BasicModule):
    def __init__(self, cfg):
        super(PCNN, self).__init__()

        self.use_pcnn = cfg.use_pcnn
        if cfg.dim_strategy == 'cat':
            cfg.in_channels = cfg.word_dim + 2 * cfg.pos_dim
        else:
            cfg.in_channels = cfg.word_dim

        self.embedding = Embedding(cfg)
        self.cnn = CNN(cfg)
        self.fc1 = nn.Linear(len(cfg.kernel_sizes) * cfg.out_channels, cfg.intermediate)
        self.fc2 = nn.Linear(cfg.intermediate, cfg.num_attributes)
        self.dropout = nn.Dropout(cfg.dropout)

        if self.use_pcnn:
            self.fc_pcnn = nn.Linear(3 * len(cfg.kernel_sizes) * cfg.out_channels,
                                     len(cfg.kernel_sizes) * cfg.out_channels)
            self.pcnn_mask_embedding = nn.Embedding(4, 3)
            masks = torch.tensor([[0, 0, 0], [100, 0, 0], [0, 100, 0], [0, 0, 100]])
            self.pcnn_mask_embedding.weight.data.copy_(masks)
            self.pcnn_mask_embedding.weight.requires_grad = False

    def forward(self, x):
        word, lens, entity_pos, attribute_value_pos = x['word'], x['lens'], x['entity_pos'], x['attribute_value_pos']
        mask = seq_len_to_mask(lens)

        inputs = self.embedding(word, entity_pos, attribute_value_pos)
        out, out_pool = self.cnn(inputs, mask=mask)

        if self.use_pcnn:
            out = out.unsqueeze(-1)  # [B, L, Hs, 1]
            pcnn_mask = x['pcnn_mask']
            pcnn_mask = self.pcnn_mask_embedding(pcnn_mask).unsqueeze(-2)  # [B, L, 1, 3]
            out = out + pcnn_mask  # [B, L, Hs, 3]
            out = out.max(dim=1)[0] - 100  # [B, Hs, 3]
            out_pool = out.view(out.size(0), -1)  # [B, 3 * Hs]
            out_pool = F.leaky_relu(self.fc_pcnn(out_pool))  # [B, Hs]
            out_pool = self.dropout(out_pool)

        output = self.fc1(out_pool)
        output = F.leaky_relu(output)
        output = self.dropout(output)
        output = self.fc2(output)

        return output
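The `use_pcnn` branch above is piecewise max pooling done with an additive mask: each token's segment id in `pcnn_mask` picks a row of the frozen mask embedding, which adds +100 to exactly one of three copies of the feature map (the id mapped to `[0, 0, 0]` gets no boost), so a single max over the length dimension followed by subtracting 100 yields one maximum per segment. A toy standalone sketch of the trick (shapes and values are illustrative, not the project's own tensors):

```python
# Toy illustration of the additive piecewise max-pooling trick in PCNN.forward.
import torch

B, L, H = 1, 5, 2
out = torch.randn(B, L, H).unsqueeze(-1)          # [B, L, H, 1] convolution features
segment_ids = torch.tensor([[1, 1, 2, 3, 0]])     # [B, L]; id 0 receives no boost

mask_table = torch.tensor([[0, 0, 0],
                           [100, 0, 0],
                           [0, 100, 0],
                           [0, 0, 100]], dtype=torch.float)
pcnn_mask = mask_table[segment_ids].unsqueeze(-2)  # [B, L, 1, 3]

piecewise = (out + pcnn_mask).max(dim=1)[0] - 100  # [B, H, 3]: one max per segment
print(piecewise.shape)                             # torch.Size([1, 2, 3])
```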
@@ -0,0 +1,30 @@

import torch.nn as nn
from . import BasicModule
from module import Embedding
from module import Transformer as TransformerBlock

from utils import seq_len_to_mask


class Transformer(BasicModule):
    def __init__(self, cfg):
        super(Transformer, self).__init__()

        if cfg.dim_strategy == 'cat':
            cfg.hidden_size = cfg.word_dim + 2 * cfg.pos_dim
        else:
            cfg.hidden_size = cfg.word_dim

        self.embedding = Embedding(cfg)
        self.transformer = TransformerBlock(cfg)
        self.fc = nn.Linear(cfg.hidden_size, cfg.num_attributes)

    def forward(self, x):
        word, lens, entity_pos, attribute_value_pos = x['word'], x['lens'], x['entity_pos'], x['attribute_value_pos']
        mask = seq_len_to_mask(lens)
        inputs = self.embedding(word, entity_pos, attribute_value_pos)
        last_layer_hidden_state, all_hidden_states, all_attentions = self.transformer(inputs, key_padding_mask=mask)
        out_pool = last_layer_hidden_state.max(dim=1)[0]
        output = self.fc(out_pool)

        return output
@@ -0,0 +1,7 @@

from .BasicModule import BasicModule
from .PCNN import PCNN
from .BiLSTM import BiLSTM
from .Transformer import Transformer
from .Capsule import Capsule
from .GCN import GCN
from .LM import LM
@@ -0,0 +1,141 @@

import logging
import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)


class DotAttention(nn.Module):
    def __init__(self, dropout=0.0):
        super(DotAttention, self).__init__()
        self.dropout = dropout

    def forward(self, Q, K, V, mask_out=None, head_mask=None):
        """
        Normally, when the input is X, assume K = V = X.

        att_weight = softmax( score_func(q, k) )
        att = sum( att_weight * v )

        :param Q: [..., L, H]
        :param K: [..., S, H]
        :param V: [..., S, H]
        :param mask_out: [..., 1, S]
        :return:
        """
        H = Q.size(-1)

        scale = float(H)**0.5
        attention_weight = torch.matmul(Q, K.transpose(-1, -2)) / scale

        if mask_out is not None:
            # when DotAttention is used on its own (rarely), make the mask dimensions match
            while mask_out.dim() != Q.dim():
                mask_out = mask_out.unsqueeze(1)
            attention_weight.masked_fill_(mask_out, -1e8)

        attention_weight = F.softmax(attention_weight, dim=-1)

        attention_weight = F.dropout(attention_weight, self.dropout)

        # mask heads if we want to:
        # only used with multi-head attention
        if head_mask is not None:
            attention_weight = attention_weight * head_mask

        attention_out = torch.matmul(attention_weight, V)

        return attention_out, attention_weight


class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0, output_attentions=True):
        """
        :param embed_dim: input dimension; must be divisible by num_heads
        :param num_heads: number of attention heads
        :param dropout: float
        """
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.output_attentions = output_attentions
        self.head_dim = int(embed_dim / num_heads)
        self.all_head_dim = self.head_dim * num_heads
        assert self.all_head_dim == embed_dim, logger.error(
            f"embed_dim{embed_dim} must be divisible by num_heads{num_heads}")

        self.q_in = nn.Linear(embed_dim, self.all_head_dim)
        self.k_in = nn.Linear(embed_dim, self.all_head_dim)
        self.v_in = nn.Linear(embed_dim, self.all_head_dim)
        self.attention = DotAttention(dropout=dropout)
        self.out = nn.Linear(self.all_head_dim, embed_dim)

    def forward(self, Q, K, V, key_padding_mask=None, attention_mask=None, head_mask=None):
        """
        :param Q: [B, L, Hs]
        :param K: [B, S, Hs]
        :param V: [B, S, Hs]
        :param key_padding_mask: [B, S]; positions that are 1/True are masked out
        :param attention_mask: [S] / [L, S]; masks specific positions, 1/True means masked
        :param head_mask: [N]; masks specific heads, 1/True means masked
        """
        B, L, Hs = Q.shape
        S = V.size(1)
        N, H = self.num_heads, self.head_dim

        q = self.q_in(Q).view(B, L, N, H).transpose(1, 2)  # [B, N, L, H]
        k = self.k_in(K).view(B, S, N, H).transpose(1, 2)  # [B, N, S, H]
        v = self.v_in(V).view(B, S, N, H).transpose(1, 2)  # [B, N, S, H]

        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.ne(0)
            key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(1)

        if attention_mask is not None:
            attention_mask = attention_mask.ne(0)
            if attention_mask.dim() == 1:
                attention_mask = attention_mask.unsqueeze(0)
            elif attention_mask.dim() == 2:
                attention_mask = attention_mask.unsqueeze(0).unsqueeze(0).expand(B, -1, -1, -1)
            else:
                raise ValueError(f'attention_mask dim must be 1 or 2, can not be {attention_mask.dim()}')

        if key_padding_mask is None:
            mask_out = attention_mask if attention_mask is not None else None
        else:
            mask_out = (key_padding_mask + attention_mask).ne(0) if attention_mask is not None else key_padding_mask

        if head_mask is not None:
            head_mask = head_mask.eq(0)
            head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

        attention_out, attention_weight = self.attention(q, k, v, mask_out=mask_out, head_mask=head_mask)

        attention_out = attention_out.transpose(1, 2).reshape(B, L, N * H)  # [B, N, L, H] -> [B, L, N * H]

        # concat all heads, and do output linear
        attention_out = self.out(attention_out)  # [B, L, N * H] -> [B, L, H]

        if self.output_attentions:
            return attention_out, attention_weight
        else:
            return attention_out,


if __name__ == '__main__':
    import sys
    import os
    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
    from utils import seq_len_to_mask

    q = torch.randn(4, 6, 20)  # [B, L, H]
    k = v = torch.randn(4, 5, 20)  # [B, S, H]
    key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
    attention_mask = torch.tensor([1, 0, 0, 1, 0])  # positions with 1 are masked out
    head_mask = torch.tensor([0, 1])  # heads with 1 are masked out

    m = MultiHeadAttention(embed_dim=20, num_heads=2, dropout=0.0, output_attentions=True)
    ao, aw = m(q, k, v, key_padding_mask=key_padding_mask, attention_mask=attention_mask, head_mask=head_mask)
    print(ao.shape, aw.shape)  # [B, L, H] [B, N, L, S]
    print(ao)
    print(aw.unbind(1))
@@ -0,0 +1,114 @@

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class GELU(nn.Module):
    def __init__(self):
        super(GELU, self).__init__()

    def forward(self, x):
        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


class CNN(nn.Module):
    """
    In NLP, odd kernel sizes such as [3, 5, 7, 9] are usually used so that the output length equals the input length.
    Unequal lengths are also possible by setting keep_length to False.
    When the length is kept, padding = k // 2.
    stride is normally 1.
    """
    def __init__(self, config):
        """
        in_channels      : usually the word-embedding dimension, or the hidden size
        out_channels     : int
        kernel_sizes     : list; must be odd to keep output length == input length: 3, 5, 7...
        activation       : [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
        pooling_strategy : [max, avg, cls]
        dropout          : float
        """
        super(CNN, self).__init__()

        # self.xxx = config.xxx
        self.in_channels = config.in_channels
        self.out_channels = config.out_channels
        self.kernel_sizes = config.kernel_sizes
        self.activation = config.activation
        self.pooling_strategy = config.pooling_strategy
        self.dropout = config.dropout
        self.keep_length = config.keep_length
        for kernel_size in self.kernel_sizes:
            assert kernel_size % 2 == 1, "kernel size has to be odd numbers."

        # convolution
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=self.in_channels,
                      out_channels=self.out_channels,
                      kernel_size=k,
                      stride=1,
                      padding=k // 2 if self.keep_length else 0,
                      dilation=1,
                      groups=1,
                      bias=False) for k in self.kernel_sizes
        ])

        # activation function
        assert self.activation in ['relu', 'lrelu', 'prelu', 'selu', 'celu', 'gelu', 'sigmoid', 'tanh'], \
            'activation function must choose from [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]'
        self.activations = nn.ModuleDict([
            ['relu', nn.ReLU()],
            ['lrelu', nn.LeakyReLU()],
            ['prelu', nn.PReLU()],
            ['selu', nn.SELU()],
            ['celu', nn.CELU()],
            ['gelu', GELU()],
            ['sigmoid', nn.Sigmoid()],
            ['tanh', nn.Tanh()],
        ])

        # pooling
        assert self.pooling_strategy in ['max', 'avg', 'cls'], 'pooling strategy must choose from [max, avg, cls]'

        self.dropout = nn.Dropout(self.dropout)

    def forward(self, x, mask=None):
        """
        :param x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H], usually the embedding output
        :param mask: [batch_size, max_len]; 0 inside the sentence, 1 on the padding. It does not affect the convolution, and max-pooling never picks the padded (near-zero) positions.
        :return:
        """
        # [B, L, H] -> [B, H, L] (treat H as the input channel dimension)
        x = torch.transpose(x, 1, 2)

        # convolution + activation  [[B, H, L], ... ]
        act_fn = self.activations[self.activation]

        x = [act_fn(conv(x)) for conv in self.convs]
        x = torch.cat(x, dim=1)

        # mask
        if mask is not None:
            # [B, L] -> [B, 1, L]
            mask = mask.unsqueeze(1)
            x = x.masked_fill_(mask, 1e-12)

        # pooling
        # [[B, H, L], ... ] -> [[B, H], ... ]
        if self.pooling_strategy == 'max':
            xp = F.max_pool1d(x, kernel_size=x.size(2)).squeeze(2)
            # equivalent to xp = torch.max(x, dim=2)[0]

        elif self.pooling_strategy == 'avg':
            x_len = mask.squeeze().eq(0).sum(-1).unsqueeze(-1).to(torch.float).to(device=mask.device)
            xp = torch.sum(x, dim=-1) / x_len

        else:
            # self.pooling_strategy == 'cls'
            xp = x[:, :, 0]

        x = x.transpose(1, 2)
        x = self.dropout(x)
        xp = self.dropout(xp)

        return x, xp  # [B, L, Hs], [B, Hs]
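The `keep_length` behaviour documented above follows directly from the 1-D convolution length formula: with stride 1, dilation 1 and padding ⌊k/2⌋ for an odd kernel size k,

$$L_{\text{out}} \;=\; L_{\text{in}} + 2\lfloor k/2\rfloor - k + 1 \;=\; L_{\text{in}} \qquad (k \text{ odd}),$$

which is why the constructor asserts that every kernel size is odd.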
@@ -0,0 +1,54 @@

import logging
import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


class Capsule(nn.Module):
    def __init__(self, cfg):
        super(Capsule, self).__init__()

        # self.xxx = cfg.xxx
        self.input_dim_capsule = cfg.input_dim_capsule
        self.dim_capsule = cfg.dim_capsule
        self.num_capsule = cfg.num_capsule
        self.batch_size = cfg.batch_size
        self.share_weights = cfg.share_weights
        self.num_iterations = cfg.num_iterations

        if self.share_weights:
            W = torch.zeros(1, self.input_dim_capsule, self.num_capsule * self.dim_capsule)
        else:
            W = torch.zeros(self.batch_size, self.input_dim_capsule, self.num_capsule * self.dim_capsule)

        W = nn.init.xavier_normal_(W)
        self.W = nn.Parameter(W)

    def forward(self, x):
        """
        x: [B, L, H]  # the output of the CNN / RNN
        L acts as input_num_capsules, H as input_dim_capsule
        """
        B, I, _ = x.size()  # I is input_num_capsules
        O, F = self.num_capsule, self.dim_capsule

        u = torch.matmul(x, self.W)
        u = u.view(B, I, O, F).transpose(1, 2)  # [B, O, I, F]

        b = torch.zeros_like(u[:, :, :, 0]).to(device=u.device)  # [B, O, I]
        for i in range(self.num_iterations):
            c = torch.softmax(b, dim=1)  # [B, O_s, I]
            v = torch.einsum('boi,boif->bof', [c, u])  # [B, O, F]
            v = self.squash(v)
            b = torch.einsum('bof,boif->boi', [v, u])  # [B, O, I]

        return v  # [B, O, F]  [B, num_capsule, dim_capsule]

    @staticmethod
    def squash(x: torch.Tensor):
        x_norm = x.norm(p=2, dim=-1, keepdim=True)
        mag = x_norm**2
        out = x / x_norm * mag / (1 + mag)

        return out
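In closed form, the `squash` method above is the usual capsule squashing nonlinearity: it keeps a vector's direction while compressing its length into [0, 1),

$$\mathrm{squash}(\mathbf{x}) \;=\; \frac{\lVert\mathbf{x}\rVert^{2}}{1+\lVert\mathbf{x}\rVert^{2}}\;\frac{\mathbf{x}}{\lVert\mathbf{x}\rVert},$$

so the norms taken by the Capsule model's forward pass (`output.norm(p=2, dim=-1)`) behave like per-class scores.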
@@ -0,0 +1,39 @@

import torch
import torch.nn as nn


class Embedding(nn.Module):
    def __init__(self, config):
        """
        word embedding: index 0 is reserved for padding
        pos embedding:  index 0 is reserved for padding
        dim_strategy: [cat, sum] whether the embeddings are concatenated or summed
        """
        super(Embedding, self).__init__()

        # self.xxx = config.xxx
        self.vocab_size = config.vocab_size
        self.word_dim = config.word_dim
        self.pos_size = config.pos_size
        self.pos_dim = config.pos_dim if config.dim_strategy == 'cat' else config.word_dim
        self.dim_strategy = config.dim_strategy

        self.wordEmbed = nn.Embedding(self.vocab_size, self.word_dim, padding_idx=0)
        self.headPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
        self.tailPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)

        self.layer_norm = nn.LayerNorm(self.word_dim)

    def forward(self, *x):
        word, head, tail = x
        word_embedding = self.wordEmbed(word)
        head_embedding = self.headPosEmbed(head)
        tail_embedding = self.tailPosEmbed(tail)

        if self.dim_strategy == 'cat':
            return torch.cat((word_embedding, head_embedding, tail_embedding), -1)
        elif self.dim_strategy == 'sum':
            # in this case pos_dim == word_dim
            return self.layer_norm(word_embedding + head_embedding + tail_embedding)
        else:
            raise Exception('dim_strategy must choose from [sum, cat]')
@@ -0,0 +1,143 @@

import logging
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)


class GCN(nn.Module):
    def __init__(self, cfg):
        super(GCN, self).__init__()

        self.num_layers = cfg.num_layers
        self.input_size = cfg.input_size
        self.hidden_size = cfg.hidden_size
        self.dropout = cfg.dropout

        self.fc1 = nn.Linear(self.input_size, self.hidden_size)
        self.fc = nn.Linear(self.hidden_size, self.hidden_size)
        self.weight_list = nn.ModuleList()
        for i in range(self.num_layers):
            self.weight_list.append(nn.Linear(self.hidden_size * (i + 1), self.hidden_size))
        self.dropout = nn.Dropout(self.dropout)

    def forward(self, x, adj):
        L = adj.sum(2).unsqueeze(2) + 1
        outputs = self.fc1(x)
        cache_list = [outputs]
        output_list = []
        for l in range(self.num_layers):
            Ax = adj.bmm(outputs)
            AxW = self.weight_list[l](Ax)
            AxW = AxW + self.weight_list[l](outputs)
            AxW = AxW / L
            gAxW = F.relu(AxW)
            cache_list.append(gAxW)
            outputs = torch.cat(cache_list, dim=2)
            output_list.append(self.dropout(gAxW))
        # gcn_outputs = torch.cat(output_list, dim=2)
        gcn_outputs = output_list[self.num_layers - 1]
        gcn_outputs = gcn_outputs + self.fc1(x)

        out = self.fc(gcn_outputs)
        return out


class Tree(object):
    def __init__(self):
        self.parent = None
        self.num_children = 0
        self.children = list()

    def add_child(self, child):
        child.parent = self
        self.num_children += 1
        self.children.append(child)

    def size(self):
        s = getattr(self, '_size', -1)
        if s != -1:
            return self._size
        else:
            count = 1
            for i in range(self.num_children):
                count += self.children[i].size()
            self._size = count
            return self._size

    def __iter__(self):
        yield self
        for c in self.children:
            for x in c:
                yield x

    def depth(self):
        d = getattr(self, '_depth', -1)
        if d != -1:
            return self._depth
        else:
            count = 0
            if self.num_children > 0:
                for i in range(self.num_children):
                    child_depth = self.children[i].depth()
                    if child_depth > count:
                        count = child_depth
                count += 1
            self._depth = count
            return self._depth


def head_to_adj(head, directed=True, self_loop=False):
    """
    Convert a sequence of head indexes to an (numpy) adjacency matrix.
    """
    seq_len = len(head)
    head = head[:seq_len]
    root = None
    nodes = [Tree() for _ in head]

    for i in range(seq_len):
        h = head[i]
        setattr(nodes[i], 'idx', i)
        if h == 0:
            root = nodes[i]
        else:
            nodes[h - 1].add_child(nodes[i])

    assert root is not None

    ret = np.zeros((seq_len, seq_len), dtype=np.float32)
    queue = [root]
    idx = []
    while len(queue) > 0:
        t, queue = queue[0], queue[1:]
        idx += [t.idx]
        for c in t.children:
            ret[t.idx, c.idx] = 1
        queue += t.children

    if not directed:
        ret = ret + ret.T

    if self_loop:
        for i in idx:
            ret[i, i] = 1

    return ret


def pad_adj(adj, max_len):
    pad_len = max_len - adj.shape[0]
    for i in range(pad_len):
        adj = np.insert(adj, adj.shape[-1], 0, axis=1)
    for i in range(pad_len):  # fixed: the original iterated over range(len), which raises a TypeError
        adj = np.insert(adj, adj.shape[0], 0, axis=0)
    return adj  # return the padded matrix so the helper is usable
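The update inside `GCN.forward` above can be read as a degree-normalized aggregation of each node's neighbours plus itself,

$$h^{(l+1)} \;=\; \mathrm{ReLU}\!\left(\frac{A\,h^{(l)}W^{(l)} \;+\; h^{(l)}W^{(l)}}{d + 1}\right), \qquad d_i = \sum_j A_{ij},$$

where, for layers after the first, h^{(l)} is the concatenation of all previous layer outputs kept in `cache_list` (which is why `weight_list[l]` takes `hidden_size * (l + 1)` input features); in the code the normalizer is `L = adj.sum(2) + 1`.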
@@ -0,0 +1,73 @@

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class RNN(nn.Module):
    def __init__(self, config):
        """
        type_rnn: one of RNN, GRU, LSTM
        """
        super(RNN, self).__init__()

        # self.xxx = config.xxx
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size // 2 if config.bidirectional else config.hidden_size
        self.num_layers = config.num_layers
        self.dropout = config.dropout
        self.bidirectional = config.bidirectional
        self.last_layer_hn = config.last_layer_hn
        self.type_rnn = config.type_rnn

        rnn = eval(f'nn.{self.type_rnn}')
        self.rnn = rnn(input_size=self.input_size,
                       hidden_size=self.hidden_size,
                       num_layers=self.num_layers,
                       dropout=self.dropout,
                       bidirectional=self.bidirectional,
                       bias=True,
                       batch_first=True)

        # buggy, disabled for now
        # self._init_weights()

    def _init_weights(self):
        """orthogonal init generally yields better results than uniform init"""
        gain = 1  # use default value
        for nth in range(self.num_layers * self.bidirectional):
            # w_ih, (4 * hidden_size x input_size)
            nn.init.orthogonal_(self.rnn.all_weights[nth][0], gain=gain)
            # w_hh, (4 * hidden_size x hidden_size)
            nn.init.orthogonal_(self.rnn.all_weights[nth][1], gain=gain)
            # b_ih, (4 * hidden_size)
            nn.init.zeros_(self.rnn.all_weights[nth][2])
            # b_hh, (4 * hidden_size)
            nn.init.zeros_(self.rnn.all_weights[nth][3])

    def forward(self, x, x_len):
        """
        Args:
            x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H_in], usually the embedding output
            x_len: torch.Tensor [L], the already-sorted sentence lengths
        Returns:
            output: torch.Tensor [B, L, H_out], the per-token result used for sequence labelling
            hn: torch.Tensor [B, N, H_out] / [B, H_out], the result used for classification; only the last layer when last_layer_hn is set
        """
        B, L, _ = x.size()
        H, N = self.hidden_size, self.num_layers

        x_len = x_len.cpu()
        x = pack_padded_sequence(x, x_len, batch_first=True, enforce_sorted=True)
        output, hn = self.rnn(x)
        output, _ = pad_packed_sequence(output, batch_first=True, total_length=L)

        if self.type_rnn == 'LSTM':
            hn = hn[0]
        if self.bidirectional:
            hn = hn.view(N, 2, B, H).transpose(1, 2).contiguous().view(N, B, 2 * H).transpose(0, 1)
        else:
            hn = hn.transpose(0, 1)
        if self.last_layer_hn:
            hn = hn[:, -1, :]

        return output, hn
@ -0,0 +1,149 @@
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from .Attention import MultiHeadAttention


def gelu(x):
    """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
        Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def gelu_new(x):
    """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
        Also see https://arxiv.org/abs/1606.08415
    """
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new}


class TransformerAttention(nn.Module):
    def __init__(self, config):
        super(TransformerAttention, self).__init__()

        # self.xxx = config.xxx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.dropout = config.dropout
        self.output_attentions = config.output_attentions
        self.layer_norm_eps = config.layer_norm_eps

        self.multihead_attention = MultiHeadAttention(self.hidden_size, self.num_heads, self.dropout,
                                                      self.output_attentions)
        self.dense = nn.Linear(self.hidden_size, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout)
        self.layerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

    def forward(self, x, key_padding_mask=None, attention_mask=None, head_mask=None):
        """
        :param x: [B, L, Hs]
        :param attention_mask: [B, L] True at the zero-padded positions after the sentence, False elsewhere
        :param head_mask: [L] [N, L]
        :return:
        """
        attention_outputs = self.multihead_attention(x, x, x, key_padding_mask, attention_mask, head_mask)
        attention_output = attention_outputs[0]
        attention_output = self.dense(attention_output)
        attention_output = self.dropout(attention_output)
        attention_output = self.layerNorm(attention_output + x)
        outputs = (attention_output, ) + attention_outputs[1:]  # the remaining elements are attention weights
        return outputs


class TransformerOutput(nn.Module):
    def __init__(self, config):
        super(TransformerOutput, self).__init__()

        # self.xxx = config.xxx
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.dropout = config.dropout
        self.layer_norm_eps = config.layer_norm_eps

        self.zoom_in = nn.Linear(self.hidden_size, self.intermediate_size)
        self.intermediate_act_fn = ACT2FN[config.hidden_act]
        self.zoom_out = nn.Linear(self.intermediate_size, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout)
        self.layerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)

    def forward(self, input_tensor):
        hidden_states = self.zoom_in(input_tensor)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.zoom_out(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.layerNorm(hidden_states + input_tensor)
        return hidden_states


class TransformerLayer(nn.Module):
    def __init__(self, config):
        super(TransformerLayer, self).__init__()

        self.attention = TransformerAttention(config)
        self.output = TransformerOutput(config)

    def forward(self, hidden_states, key_padding_mask=None, attention_mask=None, head_mask=None):
        attention_outputs = self.attention(hidden_states, key_padding_mask, attention_mask, head_mask)
        attention_output = attention_outputs[0]
        layer_output = self.output(attention_output)
        outputs = (layer_output, ) + attention_outputs[1:]
        return outputs


class Transformer(nn.Module):
    def __init__(self, config):
        super(Transformer, self).__init__()

        # self.xxx = config.xxx
        self.num_hidden_layers = config.num_hidden_layers
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states

        self.layer = nn.ModuleList([TransformerLayer(config) for _ in range(self.num_hidden_layers)])

    def forward(self, hidden_states, key_padding_mask=None, attention_mask=None, head_mask=None):
        """
        :param hidden_states: [B, L, Hs]
        :param key_padding_mask: [B, S] positions marked 1/True are masked
        :param attention_mask: [S] / [L, S] masks the given positions; 1/True means masked
        :param head_mask: [N] / [L, N] masks the given heads; 1/True means masked
        """
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.expand((self.num_hidden_layers, ) + head_mask.shape)
        else:
            head_mask = [None] * self.num_hidden_layers

        all_hidden_states = ()
        all_attentions = ()
        for i, layer_module in enumerate(self.layer):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states, )

            layer_outputs = layer_module(hidden_states, key_padding_mask, attention_mask, head_mask[i])
            hidden_states = layer_outputs[0]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1], )

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states, )

        outputs = (hidden_states, )
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states, )
        if self.output_attentions:
            outputs = outputs + (all_attentions, )
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)
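# Illustrative sketch (not part of the committed module): a quick numerical check of how close
# the erf-based gelu above is to the tanh approximation used by gelu_new. The two helpers below
# simply restate the formulas from the file so the snippet runs on its own.
import math
import torch

def _gelu_erf(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def _gelu_tanh(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

xs = torch.linspace(-5, 5, steps=1001)
print((_gelu_erf(xs) - _gelu_tanh(xs)).abs().max())  # maximum difference stays well below 1e-2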
@ -0,0 +1,7 @@
from .Embedding import Embedding
from .CNN import CNN
from .RNN import RNN
from .Attention import DotAttention, MultiHeadAttention
from .Transformer import Transformer
from .Capsule import Capsule
from .GCN import GCN
@ -0,0 +1,6 @@
from .dataset import *
from .metrics import *
from .preprocess import *
from .serializer import *
from .trainer import *
from .vocab import *
@ -0,0 +1,74 @@
import torch
from torch.utils.data import Dataset
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from deepke.utils import load_pkl


def collate_fn(cfg):

    def collate_fn_intra(batch):
        """
        Arg:
            batch (): the batch of samples
        Return:
            x (dict): batched inputs, e.g. padded token indices ('word') and sequence lengths ('lens')
            y (List): the attribute label index of each sample
        """
        batch.sort(key=lambda data: data['seq_len'], reverse=True)

        max_len = batch[0]['seq_len']

        def _padding(x, max_len):
            return x + [0] * (max_len - len(x))

        x, y = dict(), []
        word, word_len = [], []
        head_pos, tail_pos = [], []
        pcnn_mask = []
        for data in batch:
            word.append(_padding(data['token2idx'], max_len))
            word_len.append(data['seq_len'])
            y.append(int(data['att2idx']))

            if cfg.model_name != 'lm':
                head_pos.append(_padding(data['entity_pos'], max_len))
                tail_pos.append(_padding(data['attribute_value_pos'], max_len))
                if cfg.model_name == 'cnn':
                    if cfg.use_pcnn:
                        pcnn_mask.append(_padding(data['entities_pos'], max_len))

        x['word'] = torch.tensor(word)
        x['lens'] = torch.tensor(word_len)
        y = torch.tensor(y)

        if cfg.model_name != 'lm':
            x['entity_pos'] = torch.tensor(head_pos)
            x['attribute_value_pos'] = torch.tensor(tail_pos)
            if cfg.model_name == 'cnn' and cfg.use_pcnn:
                x['pcnn_mask'] = torch.tensor(pcnn_mask)
            if cfg.model_name == 'gcn':
                # no suitable dependency-parsing tool found yet, so the adjacency matrix is randomly initialised for now
                B, L = len(batch), max_len
                adj = torch.empty(B, L, L).random_(2)
                x['adj'] = adj
        return x, y

    return collate_fn_intra


class CustomDataset(Dataset):
    """
    By default the data is stored as a List.
    """
    def __init__(self, fp):
        self.file = load_pkl(fp)

    def __getitem__(self, item):
        sample = self.file[item]
        return sample

    def __len__(self):
        return len(self.file)
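# Illustrative sketch: what collate_fn produces for a toy batch. The import path (`dataset`)
# and the cfg fields (model_name, use_pcnn) are assumptions read off the code above, not
# something this commit pins down elsewhere.
from types import SimpleNamespace
from dataset import collate_fn

cfg = SimpleNamespace(model_name='cnn', use_pcnn=False)
batch = [
    {'token2idx': [4, 9, 7], 'seq_len': 3, 'att2idx': 1, 'entity_pos': [5, 6, 7], 'attribute_value_pos': [3, 4, 5]},
    {'token2idx': [4, 2], 'seq_len': 2, 'att2idx': 0, 'entity_pos': [5, 6], 'attribute_value_pos': [4, 5]},
]
x, y = collate_fn(cfg)(batch)
print(x['word'])   # shorter sentence padded with 0 to the longest length in the batch
print(x['lens'])   # tensor([3, 2]) after sorting by seq_len (descending)
print(y)           # tensor([1, 0])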
@ -0,0 +1,142 @@
import os
import hydra
import torch
import logging
import torch.nn as nn
from torch import optim
from hydra import utils
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# self
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
import models as models
from tools import preprocess, CustomDataset, collate_fn, train, validate
from utils import manual_seed, load_pkl


logger = logging.getLogger(__name__)

@hydra.main(config_path="../conf/config.yaml")
def main(cfg):
    cwd = utils.get_original_cwd()
    cwd = cwd[0:-5]
    cfg.cwd = cwd
    cfg.pos_size = 2 * cfg.pos_limit + 2
    logger.info(f'\n{cfg.pretty()}')

    __Model__ = {
        'cnn': models.PCNN,
        'rnn': models.BiLSTM,
        'transformer': models.Transformer,
        'gcn': models.GCN,
        'capsule': models.Capsule,
        'lm': models.LM,
    }

    # device
    if cfg.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda', cfg.gpu_id)
    else:
        device = torch.device('cpu')
    logger.info(f'device: {device}')

    # if the preprocessing itself is not being changed, it is best to comment this out so the data is not re-preprocessed on every run
    if cfg.preprocess:
        preprocess(cfg)

    train_data_path = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl')
    valid_data_path = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl')
    test_data_path = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl')
    vocab_path = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')

    if cfg.model_name == 'lm':
        vocab_size = None
    else:
        vocab = load_pkl(vocab_path)
        vocab_size = vocab.count
    cfg.vocab_size = vocab_size

    train_dataset = CustomDataset(train_data_path)
    valid_dataset = CustomDataset(valid_data_path)
    test_dataset = CustomDataset(test_data_path)

    train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
    valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
    test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))

    model = __Model__[cfg.model_name](cfg)
    model.to(device)
    logger.info(f'\n {model}')

    optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)
    criterion = nn.CrossEntropyLoss()

    best_f1, best_epoch = -1, 0
    es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, es_path, best_es_path = 1e8, -1, 0, 0, 0, -1, '', ''
    train_losses, valid_losses = [], []

    if cfg.show_plot and cfg.plot_utils == 'tensorboard':
        writer = SummaryWriter('tensorboard')
    else:
        writer = None

    logger.info('=' * 10 + ' Start training ' + '=' * 10)

    for epoch in range(1, cfg.epoch + 1):
        manual_seed(cfg.seed + epoch)
        train_loss = train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg)
        valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device, cfg)
        scheduler.step(valid_loss)
        model_path = model.save(epoch, cfg)
        # logger.info(model_path)

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        if best_f1 < valid_f1:
            best_f1 = valid_f1
            best_epoch = epoch
        # use the validation loss as the early-stopping criterion
        if es_loss > valid_loss:
            es_loss = valid_loss
            es_f1 = valid_f1
            es_epoch = epoch
            es_patience = 0
            es_path = model_path
        else:
            es_patience += 1
            if es_patience >= cfg.early_stopping_patience:
                best_es_epoch = es_epoch
                best_es_f1 = es_f1
                best_es_path = es_path

    if cfg.show_plot:
        if cfg.plot_utils == 'matplot':
            plt.plot(train_losses, 'x-')
            plt.plot(valid_losses, '+-')
            plt.legend(['train', 'valid'])
            plt.title('train/valid comparison loss')
            plt.show()

        if cfg.plot_utils == 'tensorboard':
            for i in range(len(train_losses)):
                writer.add_scalars('train/valid_comparison_loss', {
                    'train': train_losses[i],
                    'valid': valid_losses[i]
                }, i)
            writer.close()

    logger.info(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '
                f'this epoch macro f1: {best_es_f1:0.4f}')
    logger.info(f'this model save path: {best_es_path}')
    logger.info(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '
                f'this epoch macro f1: {best_f1:.4f}')

    logger.info('=====end of training====')
    logger.info('')
    logger.info('=====start test performance====')
    validate(-1, model, test_dataloader, criterion, device, cfg)
    logger.info('=====ending====')
@ -0,0 +1,71 @@
import torch
import numpy as np
from abc import ABCMeta, abstractmethod
from sklearn.metrics import precision_recall_fscore_support


class Metric(metaclass=ABCMeta):
    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def reset(self):
        """
        Resets the metric to its initial state.
        This is called at the start of each epoch.
        """
        pass

    @abstractmethod
    def update(self, *args):
        """
        Updates the metric's state using the passed batch output.
        This is called once for each batch.
        """
        pass

    @abstractmethod
    def compute(self):
        """
        Computes the metric based on its accumulated state.
        This is called at the end of each epoch.
        :return: the actual quantity of interest
        """
        pass


class PRMetric():
    def __init__(self):
        """
        For now this simply wraps the sklearn implementation.
        """
        self.y_true = np.empty(0)
        self.y_pred = np.empty(0)

    def reset(self):
        """
        Reset the accumulated state.
        """
        self.y_true = np.empty(0)
        self.y_pred = np.empty(0)

    def update(self, y_true: torch.Tensor, y_pred: torch.Tensor):
        """
        Accumulate the new batch; the tensors are detached from the graph and their values kept.
        """
        y_true = y_true.cpu().detach().numpy()
        y_pred = y_pred.cpu().detach().numpy()
        y_pred = np.argmax(y_pred, axis=-1)

        self.y_true = np.append(self.y_true, y_true)
        self.y_pred = np.append(self.y_pred, y_pred)

    def compute(self):
        """
        Compute and return acc, p, r, f1.
        """
        p, r, f1, _ = precision_recall_fscore_support(self.y_true, self.y_pred, average='macro', warn_for=tuple())
        _, _, acc, _ = precision_recall_fscore_support(self.y_true, self.y_pred, average='micro', warn_for=tuple())

        return acc, p, r, f1
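# Illustrative sketch of PRMetric: accumulate a batch of logits and labels, then compute accuracy
# and macro precision/recall/F1. Importing from `metrics` mirrors the `from metrics import PRMetric`
# used by the trainer below; the module name is otherwise an assumption.
import torch
from metrics import PRMetric

metric = PRMetric()
y_true = torch.tensor([0, 1, 1, 2])
y_logits = torch.tensor([[2.0, 0.1, 0.1],   # -> 0
                         [0.2, 1.5, 0.3],   # -> 1
                         [0.9, 0.1, 0.2],   # -> 0 (wrong)
                         [0.1, 0.2, 2.2]])  # -> 2
metric.update(y_true=y_true, y_pred=y_logits)
acc, p, r, f1 = metric.compute()
print(acc, p, r, f1)  # acc = 0.75 for this toy batch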
@ -0,0 +1,147 @@
import os
import sys
import torch
import logging
import hydra
from hydra import utils
from serializer import Serializer
from preprocess import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data
import matplotlib.pyplot as plt
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from utils import load_pkl, load_csv
import models as models


logger = logging.getLogger(__name__)


def _preprocess_data(data, cfg):
    vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
    attribute_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'attribute.csv'), verbose=False)
    atts = _handle_attribute_data(attribute_data)
    cfg.vocab_size = vocab.count
    serializer = Serializer(do_chinese_split=cfg.chinese_split)
    serial = serializer.serialize

    _serialize_sentence(data, serial, cfg)
    _convert_tokens_into_index(data, vocab)
    _add_pos_seq(data, cfg)
    logger.info('start sentence preprocess...')
    formats = '\nsentence: {}\nchinese_split: {}\n' \
              'tokens: {}\ntoken2idx: {}\nlength: {}\nentity_index: {}\nattribute_value_index: {}'
    logger.info(
        formats.format(data[0]['sentence'], cfg.chinese_split,
                       data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
                       data[0]['entity_index'], data[0]['attribute_value_index']))
    return data, atts


def _get_predict_instance(cfg):
    flag = input('是否使用范例[y/n],退出请输入: exit .... ')
    flag = flag.strip().lower()
    if flag == 'y' or flag == 'yes':
        sentence = '张冬梅,女,汉族,1968年2月生,河南淇县人,1988年7月加入中国共产党,1989年9月参加工作,中央党校经济管理专业毕业,中央党校研究生学历'
        entity = '张冬梅'
        attribute_value = '汉族'
    elif flag == 'n' or flag == 'no':
        sentence = input('请输入句子:')
        entity = input('请输入句中需要预测的实体:')
        attribute_value = input('请输入句中需要预测的属性值:')
    elif flag == 'exit':
        sys.exit(0)
    else:
        print('please input yes or no, or exit!')
        return _get_predict_instance(cfg)

    instance = dict()
    instance['sentence'] = sentence.strip()
    instance['entity'] = entity.strip()
    instance['attribute_value'] = attribute_value.strip()
    instance['entity_offset'] = sentence.find(entity)
    instance['attribute_value_offset'] = sentence.find(attribute_value)

    return instance


@hydra.main(config_path='../conf/config.yaml')
def main(cfg):
    cwd = utils.get_original_cwd()
    cwd = cwd[0:-5]
    cfg.cwd = cwd
    cfg.pos_size = 2 * cfg.pos_limit + 2
    print(cfg.pretty())

    # get predict instance
    instance = _get_predict_instance(cfg)
    data = [instance]

    # preprocess data
    data, rels = _preprocess_data(data, cfg)

    # model
    __Model__ = {
        'cnn': models.PCNN,
        'rnn': models.BiLSTM,
        'transformer': models.Transformer,
        'gcn': models.GCN,
        'capsule': models.Capsule,
        'lm': models.LM,
    }

    # prediction is best run on the CPU
    cfg.use_gpu = False
    if cfg.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda', cfg.gpu_id)
    else:
        device = torch.device('cpu')
    logger.info(f'device: {device}')

    model = __Model__[cfg.model_name](cfg)
    logger.info(f'model name: {cfg.model_name}')
    logger.info(f'\n {model}')
    model.load(cfg.fp, device=device)
    model.to(device)
    model.eval()

    x = dict()
    x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])

    if cfg.model_name != 'lm':
        x['entity_pos'], x['attribute_value_pos'] = torch.tensor([data[0]['entity_pos']]), torch.tensor([data[0]['attribute_value_pos']])
        if cfg.model_name == 'cnn':
            if cfg.use_pcnn:
                x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
        if cfg.model_name == 'gcn':
            # no suitable dependency-parsing tool found yet, so the adjacency matrix is randomly initialised for now
            adj = torch.empty(1, data[0]['seq_len'], data[0]['seq_len']).random_(2)
            x['adj'] = adj

    for key in x.keys():
        x[key] = x[key].to(device)

    with torch.no_grad():
        y_pred = model(x)
        y_pred = torch.softmax(y_pred, dim=-1)[0]
        prob = y_pred.max().item()
        prob_att = list(rels.keys())[y_pred.argmax().item()]
        logger.info(f"\"{data[0]['entity']}\" 和 \"{data[0]['attribute_value']}\" 在句中属性为:\"{prob_att}\",置信度为{prob:.2f}。")

    if cfg.predict_plot:
        plt.rcParams["font.family"] = 'Arial Unicode MS'
        x = list(rels.keys())
        height = list(y_pred.cpu().numpy())
        plt.bar(x, height)
        for x, y in zip(x, height):
            plt.text(x, y, '%.2f' % y, ha="center", va="bottom")
        plt.xlabel('关系')
        plt.ylabel('置信度')
        plt.xticks(rotation=315)
        plt.show()


if __name__ == '__main__':
    main()
@ -0,0 +1,158 @@
import os
import logging
from collections import OrderedDict
from typing import List, Dict
from transformers import BertTokenizer
from serializer import Serializer
from vocab import Vocab
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from utils import save_pkl, load_csv

logger = logging.getLogger(__name__)

__all__ = [
    "_handle_pos_limit",
    "_add_pos_seq",
    "_convert_tokens_into_index",
    "_serialize_sentence",
    "_lm_serialize",
    "_add_attribute_data",
    "_handle_attribute_data",
    "preprocess"
]


def _handle_pos_limit(pos: List[int], limit: int) -> List[int]:
    for i, p in enumerate(pos):
        if p > limit:
            pos[i] = limit
        if p < -limit:
            pos[i] = -limit
    return [p + limit + 1 for p in pos]


def _add_pos_seq(train_data: List[Dict], cfg):
    for d in train_data:
        # keep the two indices in ascending order
        entities_idx = [d['entity_index'], d['attribute_value_index']
                        ] if d['entity_index'] < d['attribute_value_index'] else [d['attribute_value_index'], d['entity_index']]

        d['entity_pos'] = list(map(lambda i: i - d['entity_index'], list(range(d['seq_len']))))
        d['entity_pos'] = _handle_pos_limit(d['entity_pos'], int(cfg.pos_limit))

        d['attribute_value_pos'] = list(map(lambda i: i - d['attribute_value_index'], list(range(d['seq_len']))))
        d['attribute_value_pos'] = _handle_pos_limit(d['attribute_value_pos'], int(cfg.pos_limit))

        if cfg.model_name == 'cnn':
            if cfg.use_pcnn:
                d['entities_pos'] = [1] * (entities_idx[0] + 1) + [2] * (entities_idx[1] - entities_idx[0] - 1) +\
                                    [3] * (d['seq_len'] - entities_idx[1])


def _convert_tokens_into_index(data: List[Dict], vocab):
    unk_str = '[UNK]'
    unk_idx = vocab.word2idx[unk_str]

    for d in data:
        d['token2idx'] = [vocab.word2idx.get(i, unk_idx) for i in d['tokens']]
        d['seq_len'] = len(d['token2idx'])


def _serialize_sentence(data: List[Dict], serial, cfg):
    for d in data:
        sent = d['sentence'].strip()
        sent = sent.replace(d['entity'], ' entity ', 1).replace(d['attribute_value'], ' attribute_value ', 1)
        d['tokens'] = serial(sent, never_split=['entity', 'attribute_value'])
        entity_index, attribute_value_index = d['entity_offset'], d['attribute_value_offset']
        d['entity_index'], d['attribute_value_index'] = int(entity_index), int(attribute_value_index)


def _lm_serialize(data: List[Dict], cfg):
    logger.info('use bert tokenizer...')
    tokenizer = BertTokenizer.from_pretrained(cfg.lm_file)
    for d in data:
        sent = d['sentence'].strip()
        sent += '[SEP]' + d['entity'] + '[SEP]' + d['attribute_value']
        d['token2idx'] = tokenizer.encode(sent, add_special_tokens=True)
        d['seq_len'] = len(d['token2idx'])


def _add_attribute_data(atts: Dict, data: List) -> None:
    for d in data:
        d['att2idx'] = atts[d['attribute']]['index']


def _handle_attribute_data(attribute_data: List[Dict]) -> Dict:
    atts = OrderedDict()
    attribute_data = sorted(attribute_data, key=lambda i: int(i['index']))
    for d in attribute_data:
        atts[d['attribute']] = {
            'index': int(d['index'])
        }
    return atts


def preprocess(cfg):
    logger.info('===== start preprocess data =====')
    train_fp = os.path.join(cfg.cwd, cfg.data_path, 'train.csv')
    valid_fp = os.path.join(cfg.cwd, cfg.data_path, 'valid.csv')
    test_fp = os.path.join(cfg.cwd, cfg.data_path, 'test.csv')
    attribute_fp = os.path.join(cfg.cwd, cfg.data_path, 'attribute.csv')

    logger.info('load raw files...')
    train_data = load_csv(train_fp)
    valid_data = load_csv(valid_fp)
    test_data = load_csv(test_fp)
    attribute_data = load_csv(attribute_fp)

    logger.info('convert relation into index...')
    atts = _handle_attribute_data(attribute_data)
    _add_attribute_data(atts, train_data)
    _add_attribute_data(atts, test_data)
    _add_attribute_data(atts, valid_data)

    logger.info('verify whether use pretrained language models...')
    if cfg.model_name == 'lm':
        logger.info('use pretrained language models serialize sentence...')
        _lm_serialize(train_data, cfg)
        _lm_serialize(valid_data, cfg)
        _lm_serialize(test_data, cfg)
    else:
        logger.info('serialize sentence into tokens...')
        serializer = Serializer(do_chinese_split=cfg.chinese_split, do_lower_case=True)
        serial = serializer.serialize
        _serialize_sentence(train_data, serial, cfg)
        _serialize_sentence(valid_data, serial, cfg)
        _serialize_sentence(test_data, serial, cfg)

        logger.info('build vocabulary...')
        vocab = Vocab('word')
        train_tokens = [d['tokens'] for d in train_data]
        valid_tokens = [d['tokens'] for d in valid_data]
        test_tokens = [d['tokens'] for d in test_data]
        sent_tokens = [*train_tokens, *valid_tokens, *test_tokens]

        for sent in sent_tokens:
            vocab.add_words(sent)
        vocab.trim(min_freq=cfg.min_freq)

        _convert_tokens_into_index(train_data, vocab)
        _convert_tokens_into_index(valid_data, vocab)
        _convert_tokens_into_index(test_data, vocab)

        logger.info('build position sequence...')
        _add_pos_seq(train_data, cfg)
        _add_pos_seq(valid_data, cfg)
        _add_pos_seq(test_data, cfg)

    logger.info('save data for backup...')
    os.makedirs(os.path.join(cfg.cwd, cfg.out_path), exist_ok=True)
    train_save_fp = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl')
    valid_save_fp = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl')
    test_save_fp = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl')
    save_pkl(train_data, train_save_fp)
    save_pkl(valid_data, valid_save_fp)
    save_pkl(test_data, test_save_fp)

    if cfg.model_name != 'lm':
        vocab_save_fp = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')
        vocab_txt = os.path.join(cfg.cwd, cfg.out_path, 'vocab.txt')
        save_pkl(vocab, vocab_save_fp)
        logger.info('save vocab in txt file, for watching...')
        with open(vocab_txt, 'w', encoding='utf-8') as f:
            f.write(os.linesep.join(vocab.word2idx.keys()))

    logger.info('===== end preprocess data =====')
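# Illustrative sketch of the relative-position encoding: offsets are clipped to [-limit, limit]
# and then shifted into [1, 2*limit + 1], which is why main.py sets
# cfg.pos_size = 2 * cfg.pos_limit + 2 (index 0 is left free for padding).
from preprocess import _handle_pos_limit

print(_handle_pos_limit([-40, -2, 0, 3, 40], limit=30))
# [1, 29, 31, 34, 61]  -> every value falls inside [1, 61] when limit = 30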
@ -0,0 +1,270 @@
import re
import unicodedata
import jieba
import logging
from typing import List

logger = logging.getLogger(__name__)
jieba.setLogLevel(logging.INFO)


class Serializer():
    def __init__(self, never_split: List = None, do_lower_case=True, do_chinese_split=False):
        self.never_split = never_split if never_split is not None else []
        self.do_lower_case = do_lower_case
        self.do_chinese_split = do_chinese_split

    def serialize(self, text, never_split: List = None):
        """
        Split a piece of text into a list of tokens according to the specified rules.
        Args:
            text (String): the text to split
            never_split (List): words that must not be split, empty by default
        Return:
            output_tokens (List): the tokenized result
        """
        never_split = self.never_split + (never_split if never_split is not None else [])
        text = self._clean_text(text)

        if self.do_chinese_split:
            output_tokens = self._use_jieba_cut(text, never_split)
            return output_tokens

        text = self._tokenize_chinese_chars(text)
        orig_tokens = self._orig_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case and token not in never_split:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split=never_split))

        output_tokens = self._whitespace_tokenize(" ".join(split_tokens))

        return output_tokens

    def _clean_text(self, text):
        """
        Remove invalid characters and normalise whitespace characters in the text.
        Arg:
            text (String): the text to clean
        Return:
            "".join(output) (String): the cleaned text
        """
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or self.is_control(char):
                continue
            if self.is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _use_jieba_cut(self, text, never_split):
        """
        Tokenize with jieba.
        Args:
            text (String): the text to split
            never_split (List): words that must not be split
        Return:
            tokens (List): the tokenized result
        """
        for word in never_split:
            jieba.suggest_freq(word, True)
        tokens = jieba.lcut(text)
        if self.do_lower_case:
            tokens = [i.lower() for i in tokens]
        try:
            while True:
                tokens.remove(' ')
        except:
            return tokens

    def _tokenize_chinese_chars(self, text):
        """
        Add whitespace around CJK characters.
        Arg:
            text (String): the text to process
        Return:
            "".join(output) (String): the processed text
        """
        output = []
        for char in text:
            cp = ord(char)
            if self.is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _orig_tokenize(self, text):
        """
        Split text on whitespace and some punctuation (such as commas or periods).
        Arg:
            text (String): the text to split
        Return:
            tokens (List): the tokenized result
        """
        text = text.strip()
        if not text:
            return []
        # common sentence-breaking punctuation
        punc = """,.?!;: 、|,。?!;:《》「」【】/<>|\“ ”‘ ’"""
        punc_re = '|'.join(re.escape(x) for x in punc)
        tokens = re.sub(punc_re, lambda x: ' ' + x.group() + ' ', text)
        tokens = tokens.split()
        return tokens

    def _whitespace_tokenize(self, text):
        """
        Basic whitespace cleaning and splitting.
        Arg:
            text (String): the text to split
        Return:
            tokens (List): the tokenized result
        """
        text = text.strip()
        if not text:
            return []
        tokens = text.split()
        return tokens

    def _run_strip_accents(self, text):
        """
        Strip accents from the text.
        Arg:
            text (String): the text to process
        Return:
            "".join(output) (String): the text without accents
        """
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """
        Split text on punctuation.
        Args:
            text (String): the text to split
            never_split (List): words that must not be split, empty by default
        Return:
            ["".join(x) for x in output] (List): the split result
        """
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if self.is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    @staticmethod
    def is_control(char):
        """
        Check whether a character is a control character.
        Arg:
            char: the character
        Return:
            bool: the result
        """
        # These are technically control characters but we count them as whitespace
        # characters.
        if char == "\t" or char == "\n" or char == "\r":
            return False
        cat = unicodedata.category(char)
        if cat.startswith("C"):
            return True
        return False

    @staticmethod
    def is_whitespace(char):
        """
        Check whether a character is a whitespace character.
        Arg:
            char: the character
        Return:
            bool: the result
        """
        # \t, \n, and \r are technically control characters but we treat them
        # as whitespace since they are generally considered as such.
        if char == " " or char == "\t" or char == "\n" or char == "\r":
            return True
        cat = unicodedata.category(char)
        if cat == "Zs":
            return True
        return False

    @staticmethod
    def is_chinese_char(cp):
        """
        Check whether a code point is a CJK character.
        Arg:
            cp (int): the code point
        Return:
            bool: the result
        """
        # This defines a "chinese character" as anything in the CJK Unicode block:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
                (cp >= 0x3400 and cp <= 0x4DBF) or  #
                (cp >= 0x20000 and cp <= 0x2A6DF) or  #
                (cp >= 0x2A700 and cp <= 0x2B73F) or  #
                (cp >= 0x2B740 and cp <= 0x2B81F) or  #
                (cp >= 0x2B820 and cp <= 0x2CEAF) or (cp >= 0xF900 and cp <= 0xFAFF) or  #
                (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
            return True

        return False

    @staticmethod
    def is_punctuation(char):
        """
        Check whether a character is punctuation.
        Arg:
            char: the character
        Return:
            bool: the result
        """
        cp = ord(char)
        # We treat all non-letter/number ASCII as punctuation.
        # Characters such as "^", "$", and "`" are not in the Unicode
        # Punctuation class but we treat them as punctuation anyways, for
        # consistency.
        if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96)
                or (cp >= 123 and cp <= 126)):
            return True
        cat = unicodedata.category(char)
        if cat.startswith("P"):
            return True
        return False
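# Illustrative sketch of the Serializer: the same entry point handles English input
# (whitespace + punctuation splitting) and Chinese input (jieba). The `serializer` import path
# mirrors the `from serializer import Serializer` used in preprocess.py and is otherwise an assumption.
from serializer import Serializer

en = Serializer(do_chinese_split=False, do_lower_case=True)
print(en.serialize("Irene is traveling in Warsaw, Poland!"))
# ['irene', 'is', 'traveling', 'in', 'warsaw', ',', 'poland', '!']

zh = Serializer(do_chinese_split=True)
print(zh.serialize("张冬梅,女,汉族"))  # jieba word segmentation, whitespace removed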
@ -0,0 +1,117 @@
import torch
import logging
import matplotlib.pyplot as plt
from metrics import PRMetric

logger = logging.getLogger(__name__)


def train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg):
    """
    training the model.
    Args:
        epoch (int): the current epoch number.
        model (class): model of training.
        dataloader (dict): dict of dataset iterator. Keys are tasknames, values are corresponding dataloaders.
        optimizer (Callable): optimizer of training.
        criterion (Callable): loss criterion of training.
        device (torch.device): device of training.
        writer (class): output to tensorboard.
        cfg: configuration of training.
    Return:
        losses[-1]: the loss of training
    """
    model.train()

    metric = PRMetric()
    losses = []

    for batch_idx, (x, y) in enumerate(dataloader, 1):
        for key, value in x.items():
            x[key] = value.to(device)

        y = y.to(device)

        optimizer.zero_grad()
        y_pred = model(x)

        if cfg.model_name == 'capsule':
            loss = model.loss(y_pred, y)
        else:
            loss = criterion(y_pred, y)

        loss.backward()
        optimizer.step()

        metric.update(y_true=y, y_pred=y_pred)
        losses.append(loss.item())

        data_total = len(dataloader.dataset)
        data_cal = data_total if batch_idx == len(dataloader) else batch_idx * len(y)
        if (cfg.train_log and batch_idx % cfg.log_interval == 0) or batch_idx == len(dataloader):
            # p, r, f1 are all macro-averaged; under micro-averaging the three are identical, so that value is reported as acc
            acc, p, r, f1 = metric.compute()
            logger.info(f'Train Epoch {epoch}: [{data_cal}/{data_total} ({100. * data_cal / data_total:.0f}%)]\t'
                        f'Loss: {loss.item():.6f}')
            logger.info(f'Train Epoch {epoch}: Acc: {100. * acc:.2f}%\t'
                        f'macro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')

    if cfg.show_plot and not cfg.only_comparison_plot:
        if cfg.plot_utils == 'matplot':
            plt.plot(losses)
            plt.title(f'epoch {epoch} train loss')
            plt.show()

        if cfg.plot_utils == 'tensorboard':
            for i in range(len(losses)):
                writer.add_scalar(f'epoch_{epoch}_training_loss', losses[i], i)

    return losses[-1]


def validate(epoch, model, dataloader, criterion, device, cfg):
    """
    validating the model.
    Args:
        epoch (int): the current epoch number.
        model (class): model of validating.
        dataloader (dict): dict of dataset iterator. Keys are tasknames, values are corresponding dataloaders.
        criterion (Callable): loss criterion of validating.
        device (torch.device): device of validating.
        cfg: configuration of validating.
    Return:
        f1: f1 score
        loss: the loss of validating
    """
    model.eval()

    metric = PRMetric()
    losses = []

    for batch_idx, (x, y) in enumerate(dataloader, 1):
        for key, value in x.items():
            x[key] = value.to(device)
        y = y.to(device)
        with torch.no_grad():
            y_pred = model(x)

            if cfg.model_name == 'capsule':
                loss = model.loss(y_pred, y)
            else:
                loss = criterion(y_pred, y)

            metric.update(y_true=y, y_pred=y_pred)
            losses.append(loss.item())

    loss = sum(losses) / len(losses)
    acc, p, r, f1 = metric.compute()
    data_total = len(dataloader.dataset)

    if epoch >= 0:
        logger.info(f'Valid Epoch {epoch}: [{data_total}/{data_total}](100%)\t Loss: {loss:.6f}')
        logger.info(f'Valid Epoch {epoch}: Acc: {100. * acc:.2f}%\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')
    else:
        logger.info(f'Test Data: [{data_total}/{data_total}](100%)\t Loss: {loss:.6f}')
        logger.info(f'Test Data: Acc: {100. * acc:.2f}%\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')

    return f1, loss
@ -0,0 +1,113 @@
import logging
from collections import OrderedDict
from typing import Sequence, Optional

logger = logging.getLogger(__name__)

SPECIAL_TOKENS_KEYS = [
    "pad_token",
    "unk_token",
    "mask_token",
    "cls_token",
    "sep_token",
    "bos_token",
    "eos_token",
    "head_token",
    "tail_token",
]

SPECIAL_TOKENS_VALUES = [
    "[PAD]",
    "[UNK]",
    "[MASK]",
    "[CLS]",
    "[SEP]",
    "[BOS]",
    "[EOS]",
    "HEAD",
    "TAIL",
]

SPECIAL_TOKENS = OrderedDict(zip(SPECIAL_TOKENS_KEYS, SPECIAL_TOKENS_VALUES))


class Vocab(object):
    """
    Build the vocabulary: add words and trim low-frequency words.
    """
    def __init__(self, name: str = 'basic', init_tokens: Sequence = SPECIAL_TOKENS):
        self.name = name
        self.init_tokens = init_tokens
        self.trimed = False
        self.word2idx = {}
        self.word2count = {}
        self.idx2word = {}
        self.count = 0
        self._add_init_tokens()

    def _add_init_tokens(self):
        """
        Add the initial special tokens.
        """
        for token in self.init_tokens.values():
            self._add_word(token)

    def _add_word(self, word: str):
        """
        Add a single word.
        Arg:
            word (String): the word to add
        """
        if word not in self.word2idx:
            self.word2idx[word] = self.count
            self.word2count[word] = 1
            self.idx2word[self.count] = word
            self.count += 1
        else:
            self.word2count[word] += 1

    def add_words(self, words: Sequence):
        """
        Add words from a sequence.
        Arg:
            words (List): the words to add
        """
        for word in words:
            self._add_word(word)

    def trim(self, min_freq=2, verbose: Optional[bool] = True):
        """
        Remove words whose frequency is below min_freq from the vocabulary.
        Args:
            min_freq (int): minimum frequency
            verbose (bool): whether to log
        """
        assert min_freq == int(min_freq), f'min_freq must be integer, can\'t be {min_freq}'
        min_freq = int(min_freq)
        if min_freq < 2:
            return
        if self.trimed:
            return
        self.trimed = True

        keep_words = []
        new_words = []

        for k, v in self.word2count.items():
            if v >= min_freq:
                keep_words.append(k)
                new_words.extend([k] * v)
        if verbose:
            before_len = len(keep_words)
            after_len = len(self.word2idx) - len(self.init_tokens)
            logger.info('vocab trimmed, keep words [{} / {}] = {:.2f}%'.format(
                before_len, after_len, before_len / after_len * 100))

        # Reinitialize dictionaries
        self.word2idx = {}
        self.word2count = {}
        self.idx2word = {}
        self.count = 0
        self._add_init_tokens()
        self.add_words(new_words)
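# Illustrative sketch of Vocab: words are indexed after the 9 special tokens, and trim()
# rebuilds the vocabulary keeping only words whose frequency reaches min_freq. The `vocab`
# import path mirrors `from vocab import Vocab` in preprocess.py and is otherwise an assumption.
from vocab import Vocab

vocab = Vocab('word')
vocab.add_words(['张冬梅', '汉族', '中央党校'])
vocab.add_words(['汉族', '中央党校'])
print(vocab.count)                # 9 special tokens + 3 distinct words = 12
vocab.trim(min_freq=2)
print('汉族' in vocab.word2idx)    # True  (appears twice)
print('张冬梅' in vocab.word2idx)  # False (appears once, trimmed away)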
@ -0,0 +1,2 @@
from .ioUtils import *
from .nnUtils import *
@ -0,0 +1,181 @@
import os
import csv
import json
import pickle
import logging
from typing import NewType, List, Tuple, Dict, Any

__all__ = [
    'load_pkl',
    'save_pkl',
    'load_csv',
    'save_csv',
    'load_jsonld',
    'save_jsonld',
    'jsonld2csv',
    'csv2jsonld',
]

logger = logging.getLogger(__name__)

Path = str


def load_pkl(fp: Path, verbose: bool = True) -> Any:
    """
    Load a pickle file.
    Args:
        fp (String): path to read from
        verbose (bool): whether to log
    Return:
        data (Any): the loaded data
    """
    if verbose:
        logger.info(f'load data from {fp}')

    with open(fp, 'rb') as f:
        data = pickle.load(f)
        return data


def save_pkl(data: Any, fp: Path, verbose: bool = True) -> None:
    """
    Save data to a pickle file.
    Args:
        data (Any): the data
        fp (String): path to save to
        verbose (bool): whether to log
    """
    if verbose:
        logger.info(f'save data in {fp}')

    with open(fp, 'wb') as f:
        pickle.dump(data, f)


def load_csv(fp: Path, is_tsv: bool = False, verbose: bool = True) -> List:
    """
    Load a csv file.
    Args:
        fp (String): path to read from
        is_tsv (bool): whether the file uses the excel-tab dialect
        verbose (bool): whether to log
    Return:
        list(reader) (List): the loaded rows
    """
    if verbose:
        logger.info(f'load csv from {fp}')

    dialect = 'excel-tab' if is_tsv else 'excel'
    with open(fp, encoding='utf-8') as f:
        reader = csv.DictReader(f, dialect=dialect)
        return list(reader)


def save_csv(data: List[Dict], fp: Path, save_in_tsv: bool = False, write_head=True, verbose=True) -> None:
    """
    Save data to a csv file.
    Args:
        data (List): the rows to save
        fp (String): path to save to
        save_in_tsv (bool): whether to save in the excel-tab dialect
        write_head (bool): whether to write the header row
        verbose (bool): whether to log
    """
    if verbose:
        logger.info(f'save csv file in: {fp}')

    with open(fp, 'w', encoding='utf-8') as f:
        fieldnames = data[0].keys()
        dialect = 'excel-tab' if save_in_tsv else 'excel'
        writer = csv.DictWriter(f, fieldnames=fieldnames, dialect=dialect)
        if write_head:
            writer.writeheader()
        writer.writerows(data)


def load_jsonld(fp: Path, verbose: bool = True) -> List:
    """
    Load a jsonld file.
    Args:
        fp (String): path of the jsonld file
        verbose (bool): whether to log
    Return:
        datas (List): the loaded data
    """
    if verbose:
        logger.info(f'load jsonld from {fp}')

    datas = []
    with open(fp, encoding='utf-8') as f:
        for l in f:
            line = json.loads(l)
            data = list(line.values())
            datas.append(data)

    return datas


def save_jsonld(fp):
    """
    Save data to a jsonld file.
    """
    pass


def jsonld2csv(fp: str, verbose: bool = True) -> str:
    """
    Read a jsonld file and save it as a csv file with the same name in the same place.
    Args:
        fp (String): path of the jsonld file
        verbose (bool): whether to log
    Return:
        fp_new (String): path of the new file
    """
    data = []
    root, ext = os.path.splitext(fp)
    fp_new = root + '.csv'
    if verbose:
        print(f'read jsonld file in: {fp}')
    with open(fp, encoding='utf-8') as f:
        for l in f:
            line = json.loads(l)
            data.append(line)
    if verbose:
        print('saving...')
    with open(fp_new, 'w', encoding='utf-8') as f:
        fieldnames = data[0].keys()
        writer = csv.DictWriter(f, fieldnames=fieldnames, dialect='excel')
        writer.writeheader()
        writer.writerows(data)
    if verbose:
        print(f'saved csv file in: {fp_new}')
    return fp_new


def csv2jsonld(fp: str, verbose: bool = True) -> str:
    """
    Read a csv file and save it as a jsonld file with the same name in the same place.
    Args:
        fp (String): path of the csv file
        verbose (bool): whether to log
    Return:
        fp_new (String): path of the new file
    """
    data = []
    root, ext = os.path.splitext(fp)
    fp_new = root + '.jsonld'
    if verbose:
        print(f'read csv file in: {fp}')
    with open(fp, encoding='utf-8') as f:
        writer = csv.DictReader(f, fieldnames=None, dialect='excel')
        for line in writer:
            data.append(line)
    if verbose:
        print('saving...')
    with open(fp_new, 'w', encoding='utf-8') as f:
        f.write(os.linesep.join([json.dumps(l, ensure_ascii=False) for l in data]))
    if verbose:
        print(f'saved jsonld file in: {fp_new}')
    return fp_new
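# Illustrative sketch of the jsonld -> csv helper together with load_csv, using a temporary
# directory so nothing in the repository is touched. The `utils` import path follows the
# package __init__ above and is otherwise an assumption.
import json, os, tempfile
from utils import jsonld2csv, load_csv

with tempfile.TemporaryDirectory() as d:
    fp = os.path.join(d, 'sample.jsonld')
    with open(fp, 'w', encoding='utf-8') as f:
        f.write(json.dumps({'entity': '张冬梅', 'attribute_value': '汉族'}, ensure_ascii=False))
    csv_fp = jsonld2csv(fp, verbose=False)
    print(load_csv(csv_fp, verbose=False))  # [{'entity': '张冬梅', 'attribute_value': '汉族'}]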
@ -0,0 +1,76 @@
import torch
import random
import logging
import numpy as np
from typing import List, Tuple, Dict, Union

logger = logging.getLogger(__name__)

__all__ = [
    'manual_seed',
    'seq_len_to_mask',
    'to_one_hot',
]


def manual_seed(seed: int = 1) -> None:
    """
    Set the random seed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # if torch.cuda.CUDA_ENABLED and use_deterministic_cudnn:
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None, mask_pos_to_true=True):
    """
    Convert a 1-d array of sequence lengths into a 2-d mask; by default the padded positions are 1.

    Args:
        seq_len (list, np.ndarray, torch.LongTensor): shape (B,)
        max_len (int): pad the mask to this length. By default (None) the longest length in seq_len is used.
            Under nn.DataParallel the seq_len on different cards may differ, so a max_len can be passed in
            to pad every mask to the same length.
    Return:
        mask (np.ndarray, torch.Tensor): shape (B, max_len), elements are bool or torch.uint8
    """
    if isinstance(seq_len, list):
        seq_len = np.array(seq_len)

    if isinstance(seq_len, np.ndarray):
        seq_len = torch.from_numpy(seq_len)

    if isinstance(seq_len, torch.Tensor):
        assert seq_len.dim() == 1, logger.error(f"seq_len can only have one dimension, got {seq_len.dim()} != 1.")
        batch_size = seq_len.size(0)
        max_len = int(max_len) if max_len else seq_len.max().long()
        broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len.device)
        if mask_pos_to_true:
            mask = broad_cast_seq_len.ge(seq_len.unsqueeze(1))
        else:
            mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))
    else:
        raise TypeError("Only support 1-d list or 1-d numpy.ndarray or 1-d torch.Tensor.")

    return mask


def to_one_hot(x: torch.Tensor, length: int) -> torch.Tensor:
    """
    Args:
        x (torch.Tensor): [B], usually the target values
        length (int): L, usually the number of relation classes
    Return:
        x_one_hot.to(device=x.device) (torch.Tensor): [B, L] each row is 1 at the target position and 0 elsewhere
    """
    B = x.size(0)
    x_one_hot = torch.zeros(B, length)
    for i in range(B):
        x_one_hot[i, x[i]] = 1.0

    return x_one_hot.to(device=x.device)
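# Illustrative sketch of the two tensor helpers: seq_len_to_mask marks padded positions with True
# by default, and to_one_hot expands class indices to a one-hot matrix. Importing from `utils`
# follows the package __init__ above and is otherwise an assumption.
import torch
from utils import seq_len_to_mask, to_one_hot

print(seq_len_to_mask(torch.tensor([3, 1]), max_len=4))
# tensor([[False, False, False,  True],
#         [False,  True,  True,  True]])

print(to_one_hot(torch.tensor([2, 0]), length=3))
# tensor([[0., 0., 1.],
#         [1., 0., 0.]])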
@ -11,9 +11,9 @@ from torch.utils.tensorboard import SummaryWriter
 # self
 import sys
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-import deepke.models as models
-from deepke.tools import preprocess , CustomDataset, collate_fn ,train, validate
-from deepke.utils import manual_seed, load_pkl
+import models as models
+from tools import preprocess , CustomDataset, collate_fn ,train, validate
+from utils import manual_seed, load_pkl


 logger = logging.getLogger(__name__)

@ -4,12 +4,12 @@ import torch
 import logging
 import hydra
 from hydra import utils
-from deepke.tools import Serializer
-from deepke.tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
+from serializer import Serializer
+from preprocess import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
 import matplotlib.pyplot as plt
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
-from deepke.utils import load_pkl, load_csv
-import deepke.models as models
+from utils import load_pkl, load_csv
+import models as models


 logger = logging.getLogger(__name__)