test

This commit is contained in: parent db9d0e198c, commit 607ac898e3

README.md (16)
@@ -83,7 +83,7 @@ DeepKE provides a variety of knowledge extraction models.

The data is a txt file; a sample of the format:

-| sentence | Person | Location | Organization | Miscellaneous |
+| Sentence | Person | Location | Organization | Miscellaneous |
| :----------------------------------------------------------: | :----------------------------------: | :---------------: | :-------------------------: | :-------------------: |
| Australian Tom Moody took six for 82 but Chris Adams, 123, and Tim O'Gorman, 109, took Derbyshire to 471 and a first innings lead of 233. | Tom Moody, Chris Adams, Tim O'Gorman | / | Derbysire | Australian |
| Irene, a master student in Zhejiang University, Hangzhou, is traveling in Warsaw for Chopin Music Festival. | Irene | Hangzhou, Warsaw | Zhejiang University | Chopin Music Festival |
@@ -95,6 +95,20 @@ DeepKE provides a variety of knowledge extraction models.

3. AE

The data is a csv file; a sample of the format:

| Sentence | Attribute | Entity | Entity_offset | Attribute_value | Attribute_value_offset |
| :----------------------------------------------------: | :------: | :--------: | :---------: | :--------: | :---------: |
| 张冬梅,女,汉族,1968年2月生,河南淇县人,1988年7月加入中国共产党,1989年9月参加工作,中央党校经济管理专业毕业,中央党校研究生学历 | 民族 | 张冬梅 | 0 | 汉族 | 6 |
| 杨缨,字绵公,号钓溪,松溪县人,祖籍将乐,是北宋理学家杨时的七世孙 | 朝代 | 杨缨 | 0 | 北宋 | 22 |
| 2014年10月1日许鞍华执导的电影《黄金时代》上映,冯绍峰饰演与之差别极大的民国东北爷们萧军,演技受到肯定 | 上映时间 | 黄金时代 | 19 | 2014年10月1日 | 0 |

For the detailed workflow, see the dedicated README. AE includes the following sub-functions:

**[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ae/regular/README.md)**

## Model Architecture

The architecture of DeepKE is shown below:
Binary image file changed (not shown; previously 149 KiB).
@@ -0,0 +1,56 @@

## Quick Start

### Requirements

> python == 3.8

- torch == 1.5
- hydra-core == 1.0.6
- tensorboard == 2.4.1
- matplotlib == 3.4.1
- scikit-learn == 0.24.1
- transformers == 4.5.0
- jieba == 0.42.1
- deepke

### Clone the repository

```
git clone git@github.com:zjunlp/DeepKE.git
```

### Install with pip

First create a Python virtual environment and activate it.

- Install the dependencies: ```pip install -r requirements.txt```

### Train and predict with your own data

- Data: place the training data under the `data/origin` folder. It mainly consists of the following files:

  - `train.csv`: the training set

  - `valid.csv`: the validation set

  - `test.csv`: the test set

  - `attribute.csv`: the attribute types

- Start training: ```python run.py``` (all training parameters are in the conf folder; modify them there as needed)

- The logs of each run are saved in the `logs` folder, and the trained models are saved in the `checkpoints` folder.

- Run prediction: ```python predict.py```

## Models

1. CNN

2. RNN

3. Capsule

4. GCN

5. Transformer

6. Pre-trained language model (LM)
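Since all of these parameters are managed by Hydra through the conf folder, they can presumably also be overridden on the command line instead of editing the YAML files, e.g. `python run.py model=rnn epoch=10`, or swept with multirun as the trailing comments in the trainer suggest (`python predict.py chinese_split=0,1 -m`); the exact parameter names are the keys in the files under conf.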
@@ -0,0 +1,17 @@

# ??? is a mandatory value.
# you should be able to set it without open_dict
# but if you try to read it before it's set an error will get thrown.

# populated at runtime
cwd: ???


defaults:
  - hydra/output: custom
  - preprocess
  - train
  - embedding
  - predict
  - model: cnn  # [cnn, rnn, transformer, capsule, gcn, lm]
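For readers unfamiliar with the `???` marker: it is OmegaConf's mandatory-value placeholder (Hydra is built on OmegaConf). A minimal sketch, outside this repo, of how such a value behaves before and after being filled in at runtime:

```
from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

# minimal stand-in for the composed config; '???' marks a mandatory value
cfg = OmegaConf.create({"cwd": "???", "model": {"model_name": "cnn"}})

try:
    print(cfg.cwd)  # reading before it is set raises an error
except MissingMandatoryValue:
    print("cwd has not been set yet")

cfg.cwd = "/path/to/project"  # hypothetical path, populated at runtime
print(cfg.cwd)
```

This is why run.py and predict.py assign cfg.cwd and cfg.pos_size before anything reads them.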
@@ -0,0 +1,10 @@

# populated at runtime
vocab_size: ???
word_dim: 60
pos_size: ???  # 2 * pos_limit + 2
pos_dim: 10  # ignored when dim_strategy is 'sum'; in that case it is forced to equal word_dim

dim_strategy: sum  # [cat, sum]

# number of attribute classes
num_attributes: 7
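A small sketch (not repo code) of what dim_strategy means for the dimension fed into the encoders; the model classes later in this commit set in_channels / input_size / hidden_size with exactly this rule:

```
word_dim, pos_dim = 60, 10

# 'cat': the word embedding and the two position embeddings are concatenated
cat_dim = word_dim + 2 * pos_dim  # 60 + 2 * 10 = 80

# 'sum': the position embeddings are forced to word_dim and all three are added
sum_dim = word_dim                # stays 60

print(cat_dim, sum_dim)
```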
@@ -0,0 +1,11 @@

hydra:

  run:
    # Output directory for normal runs
    dir: logs/${now:%Y-%m-%d_%H-%M-%S}

  sweep:
    # Output directory for sweep runs
    dir: logs/${now:%Y-%m-%d_%H-%M-%S}
    # Output sub directory for sweep runs.
    subdir: ${hydra.job.num}_${hydra.job.id}
@@ -0,0 +1,20 @@

model_name: capsule

share_weights: True
num_iterations: 5  # number of routing iterations
dropout: 0.3

input_dim_capsule: ???  # taken from the preceding convolution, usually its output hidden size
dim_capsule: 50  # dimension of each output capsule
num_capsule: ???  # number of output capsules, equal to the number of classes, == num_attributes


# how the primary capsules are produced
# could be embedding / cnn / rnn
# cnn is used for now
in_channels: ???  # taken from the embedding output, no need to set
out_channels: 100  # == input_dim_capsule
kernel_sizes: [9]  # must be odd, and fairly large
activation: 'lrelu'  # [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
keep_length: False  # no padding needed, it would add too much useless information
pooling_strategy: cls  # irrelevant, never actually used here
@@ -0,0 +1,13 @@

model_name: cnn

in_channels: ???  # taken from the embedding output, no need to set
out_channels: 100
kernel_sizes: [3, 5, 7]  # must be odd, so the CNN output keeps the sentence length unchanged
activation: 'gelu'  # [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
pooling_strategy: 'max'  # [max, avg, cls]
keep_length: True
dropout: 0.3

# pcnn
use_pcnn: False
intermediate: 80
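A quick standalone check (a sketch, not repo code) of the comment on kernel_sizes: with an odd kernel k and padding k // 2, which is what the CNN module uses when keep_length is on, a 1-D convolution preserves the sequence length:

```
import torch
import torch.nn as nn

L = 70  # sentence length
x = torch.randn(1, 80, L)  # [batch, in_channels, length]

for k in (3, 5, 7):
    conv = nn.Conv1d(in_channels=80, out_channels=100, kernel_size=k, padding=k // 2)
    print(k, conv(x).shape[-1])  # prints 70 for every odd k
```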
@@ -0,0 +1,7 @@

model_name: gcn

num_layers: 3

input_size: ???  # taken from the embedding output, no need to set
hidden_size: 100
dropout: 0.3
@@ -0,0 +1,20 @@

model_name: lm

# where the pre-trained language model is stored when one is used
# lm_name = 'bert-base-chinese' # download usage
#lm_file: 'pretrained'
lm_file: '/home/yhy/transformers/bert-base-chinese'

# number of transformer layers; the base BERT model starts with 12,
# but with a small dataset fewer layers converge faster and work better
num_hidden_layers: 1


# parameters of the BiLSTM stacked on top
type_rnn: 'LSTM'  # [RNN, GRU, LSTM]
input_size: 768  # determined by BERT
hidden_size: 100  # must be even
num_layers: 1
dropout: 0.3
bidirectional: True
last_layer_hn: True
@@ -0,0 +1,10 @@

model_name: rnn

type_rnn: 'LSTM'  # [RNN, GRU, LSTM]

input_size: ???  # taken from the embedding output, no need to set
hidden_size: 150  # must be even
num_layers: 2
dropout: 0.3
bidirectional: True
last_layer_hn: True
@@ -0,0 +1,12 @@

model_name: transformer

hidden_size: ???  # taken from the embedding output, no need to set
num_heads: 4  # hidden_size must be divisible by num_heads
num_hidden_layers: 3
intermediate_size: 256
dropout: 0.1
layer_norm_eps: 1e-12
hidden_act: gelu_new  # [relu, gelu, swish, gelu_new]

output_attentions: True
output_hidden_states: True
@@ -0,0 +1,2 @@

# path of the saved model checkpoint to load
fp: 'xxx/checkpoints/2019-12-03_17-35-30/cnn_epoch21.pth'
@@ -0,0 +1,20 @@

# whether the data needs preprocessing
# if the preprocessing parameters have not changed, there is no need to preprocess again
preprocess: True

# where the raw data is stored
data_path: 'data/origin'

# where the preprocessed files are written
out_path: 'data/out'

# whether to run Chinese word segmentation
chinese_split: True

# minimum word frequency when building the vocab
min_freq: 3

# position limit: the limit on each token's position relative to the entity
# e.g. [-30, 30]; at embedding time 31 is added to everything, giving [1, 61]
# so there are 62 position tokens in total, with 0 reserved for padding
pos_limit: 30
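A sketch of the relative-position mapping the comments above describe (the actual preprocessing code is not shown in this diff, so the helper below is hypothetical): offsets are clipped to [-pos_limit, pos_limit], shifted by pos_limit + 1 into [1, 2 * pos_limit + 1], and 0 is kept for padding, which is why pos_size = 2 * pos_limit + 2.

```
pos_limit = 30

def relative_pos_index(token_idx: int, entity_idx: int) -> int:
    """Map a token's offset from the entity to a positive embedding index."""
    offset = token_idx - entity_idx
    offset = max(-pos_limit, min(pos_limit, offset))  # clip to [-30, 30]
    return offset + pos_limit + 1                     # shift into [1, 61]

pos_size = 2 * pos_limit + 2  # 62 indices in total, index 0 reserved for padding
print(relative_pos_index(0, 5), relative_pos_index(100, 5), pos_size)  # 26 61 62
```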
@@ -0,0 +1,21 @@

seed: 1

use_gpu: True
gpu_id: 0

epoch: 50
batch_size: 32
learning_rate: 3e-4
lr_factor: 0.7  # learning-rate decay factor
lr_patience: 3  # epochs to wait before decaying the learning rate
weight_decay: 1e-3  # L2 regularization

early_stopping_patience: 6

train_log: True
log_interval: 10
show_plot: True
only_comparison_plot: False
plot_utils: matplot  # [matplot, tensorboard]

predict_plot: True
@@ -0,0 +1,8 @@

attribute,index
None,0
民族,1
字,2
朝代,3
身高,4
创始人,5
上映时间,6
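A hedged sketch of turning attribute.csv into the attribute-to-index mapping used by the model's 7 output classes; the repo does this through load_csv and _handle_attribute_data, so the helper below is only illustrative:

```
import csv

# hypothetical helper; the repo's own loading code is not shown in this diff
def load_attribute_map(path: str = "data/origin/attribute.csv") -> dict:
    with open(path, encoding="utf-8") as f:
        return {row["attribute"]: int(row["index"]) for row in csv.DictReader(f)}

atts = load_attribute_map()
print(atts)      # {'None': 0, '民族': 1, '字': 2, '朝代': 3, '身高': 4, '创始人': 5, '上映时间': 6}
print(len(atts)) # 7, matching num_attributes in the embedding config
```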
File diffs suppressed because they are too large (3 files).
|
@ -0,0 +1,147 @@
|
|||
import os
|
||||
import sys
|
||||
import torch
|
||||
import logging
|
||||
import hydra
|
||||
from hydra import utils
|
||||
from serializer import Serializer
|
||||
from preprocess import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data
|
||||
import matplotlib.pyplot as plt
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
|
||||
from utils import load_pkl, load_csv
|
||||
import models as models
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _preprocess_data(data, cfg):
|
||||
vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
|
||||
attribute_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'attribute.csv'), verbose=False)
|
||||
atts = _handle_attribute_data(attribute_data)
|
||||
cfg.vocab_size = vocab.count
|
||||
serializer = Serializer(do_chinese_split=cfg.chinese_split)
|
||||
serial = serializer.serialize
|
||||
|
||||
_serialize_sentence(data, serial, cfg)
|
||||
_convert_tokens_into_index(data, vocab)
|
||||
_add_pos_seq(data, cfg)
|
||||
logger.info('start sentence preprocess...')
|
||||
formats = '\nsentence: {}\nchinese_split: {}\n' \
|
||||
'tokens: {}\ntoken2idx: {}\nlength: {}\nentity_index: {}\nattribute_value_index: {}'
|
||||
logger.info(
|
||||
formats.format(data[0]['sentence'], cfg.chinese_split,
|
||||
data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
|
||||
data[0]['entity_index'], data[0]['attribute_value_index']))
|
||||
return data, atts
|
||||
|
||||
|
||||
def _get_predict_instance(cfg):
|
||||
flag = input('是否使用范例[y/n],退出请输入: exit .... ')
|
||||
flag = flag.strip().lower()
|
||||
if flag == 'y' or flag == 'yes':
|
||||
sentence = '张冬梅,女,汉族,1968年2月生,河南淇县人,1988年7月加入中国共产党,1989年9月参加工作,中央党校经济管理专业毕业,中央党校研究生学历'
|
||||
entity = '张冬梅'
|
||||
attribute_value = '汉族'
|
||||
elif flag == 'n' or flag == 'no':
|
||||
sentence = input('请输入句子:')
|
||||
entity = input('请输入句中需要预测的实体:')
|
||||
attribute_value = input('请输入句中需要预测的属性值:')
|
||||
elif flag == 'exit':
|
||||
sys.exit(0)
|
||||
else:
    print('please input yes or no, or exit!')
    return _get_predict_instance(cfg)
|
||||
|
||||
|
||||
instance = dict()
|
||||
instance['sentence'] = sentence.strip()
|
||||
instance['entity'] = entity.strip()
|
||||
instance['attribute_value'] = attribute_value.strip()
|
||||
instance['entity_offset'] = sentence.find(entity)
|
||||
instance['attribute_value_offset'] = sentence.find(attribute_value)
|
||||
|
||||
return instance
|
||||
|
||||
|
||||
|
||||
|
||||
@hydra.main(config_path='../conf/config.yaml')
|
||||
def main(cfg):
|
||||
cwd = utils.get_original_cwd()
|
||||
cwd = cwd[0:-5]
|
||||
cfg.cwd = cwd
|
||||
cfg.pos_size = 2 * cfg.pos_limit + 2
|
||||
print(cfg.pretty())
|
||||
|
||||
# get predict instance
|
||||
instance = _get_predict_instance(cfg)
|
||||
data = [instance]
|
||||
|
||||
# preprocess data
|
||||
data, rels = _preprocess_data(data, cfg)
|
||||
|
||||
# model
|
||||
__Model__ = {
|
||||
'cnn': models.PCNN,
|
||||
'rnn': models.BiLSTM,
|
||||
'transformer': models.Transformer,
|
||||
'gcn': models.GCN,
|
||||
'capsule': models.Capsule,
|
||||
'lm': models.LM,
|
||||
}
|
||||
|
||||
# prediction is best done on the CPU
|
||||
cfg.use_gpu = False
|
||||
if cfg.use_gpu and torch.cuda.is_available():
|
||||
device = torch.device('cuda', cfg.gpu_id)
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
logger.info(f'device: {device}')
|
||||
|
||||
model = __Model__[cfg.model_name](cfg)
|
||||
logger.info(f'model name: {cfg.model_name}')
|
||||
logger.info(f'\n {model}')
|
||||
model.load(cfg.fp, device=device)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
x = dict()
|
||||
x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])
|
||||
|
||||
if cfg.model_name != 'lm':
|
||||
x['entity_pos'], x['attribute_value_pos'] = torch.tensor([data[0]['entity_pos']]), torch.tensor([data[0]['attribute_value_pos']])
|
||||
if cfg.model_name == 'cnn':
|
||||
if cfg.use_pcnn:
|
||||
x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
|
||||
if cfg.model_name == 'gcn':
|
||||
# no suitable dependency-parsing tool was found, so the adjacency matrix is randomly initialized for now
|
||||
adj = torch.empty(1,data[0]['seq_len'],data[0]['seq_len']).random_(2)
|
||||
x['adj'] = adj
|
||||
|
||||
|
||||
for key in x.keys():
|
||||
x[key] = x[key].to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
y_pred = model(x)
|
||||
y_pred = torch.softmax(y_pred, dim=-1)[0]
|
||||
prob = y_pred.max().item()
|
||||
prob_att = list(rels.keys())[y_pred.argmax().item()]
|
||||
logger.info(f"\"{data[0]['entity']}\" 和 \"{data[0]['attribute_value']}\" 在句中属性为:\"{prob_att}\",置信度为{prob:.2f}。")
|
||||
|
||||
if cfg.predict_plot:
|
||||
plt.rcParams["font.family"] = 'Arial Unicode MS'
|
||||
x = list(rels.keys())
|
||||
height = list(y_pred.cpu().numpy())
|
||||
plt.bar(x, height)
|
||||
for x, y in zip(x, height):
|
||||
plt.text(x, y, '%.2f' % y, ha="center", va="bottom")
|
||||
plt.xlabel('关系')
|
||||
plt.ylabel('置信度')
|
||||
plt.xticks(rotation=315)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@@ -0,0 +1,8 @@

torch == 1.5
hydra-core == 1.0.6
tensorboard == 2.4.1
matplotlib == 3.4.1
scikit-learn == 0.24.1
transformers == 4.5.0
jieba == 0.42.1
deepke
@ -0,0 +1,148 @@
|
|||
import os
|
||||
import hydra
|
||||
import torch
|
||||
import logging
|
||||
import torch.nn as nn
|
||||
from torch import optim
|
||||
from hydra import utils
|
||||
import matplotlib.pyplot as plt
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
# self
|
||||
import sys
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
|
||||
import models as models
|
||||
from tools import preprocess , CustomDataset, collate_fn ,train, validate
|
||||
from utils import manual_seed, load_pkl
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@hydra.main(config_path="../conf/config.yaml")
|
||||
def main(cfg):
|
||||
cwd = utils.get_original_cwd()
|
||||
cwd = cwd[0:-5]
|
||||
cfg.cwd = cwd
|
||||
cfg.pos_size = 2 * cfg.pos_limit + 2
|
||||
logger.info(f'\n{cfg.pretty()}')
|
||||
|
||||
__Model__ = {
|
||||
'cnn': models.PCNN,
|
||||
'rnn': models.BiLSTM,
|
||||
'transformer': models.Transformer,
|
||||
'gcn': models.GCN,
|
||||
'capsule': models.Capsule,
|
||||
'lm': models.LM,
|
||||
}
|
||||
|
||||
# device
|
||||
if cfg.use_gpu and torch.cuda.is_available():
|
||||
device = torch.device('cuda', cfg.gpu_id)
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
logger.info(f'device: {device}')
|
||||
|
||||
# if the preprocessing logic is unchanged, it is best to comment this step out so the data is not re-preprocessed on every run
|
||||
if cfg.preprocess:
|
||||
preprocess(cfg)
|
||||
|
||||
train_data_path = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl')
|
||||
valid_data_path = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl')
|
||||
test_data_path = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl')
|
||||
vocab_path = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')
|
||||
|
||||
if cfg.model_name == 'lm':
|
||||
vocab_size = None
|
||||
else:
|
||||
vocab = load_pkl(vocab_path)
|
||||
vocab_size = vocab.count
|
||||
cfg.vocab_size = vocab_size
|
||||
|
||||
train_dataset = CustomDataset(train_data_path)
|
||||
valid_dataset = CustomDataset(valid_data_path)
|
||||
test_dataset = CustomDataset(test_data_path)
|
||||
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
|
||||
valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
|
||||
test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
|
||||
|
||||
model = __Model__[cfg.model_name](cfg)
|
||||
model.to(device)
|
||||
logger.info(f'\n {model}')
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
best_f1, best_epoch = -1, 0
|
||||
es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, es_path, best_es_path = 1e8, -1, 0, 0, 0, -1, '', ''
|
||||
train_losses, valid_losses = [], []
|
||||
|
||||
if cfg.show_plot and cfg.plot_utils == 'tensorboard':
|
||||
writer = SummaryWriter('tensorboard')
|
||||
else:
|
||||
writer = None
|
||||
|
||||
logger.info('=' * 10 + ' Start training ' + '=' * 10)
|
||||
|
||||
for epoch in range(1, cfg.epoch + 1):
|
||||
manual_seed(cfg.seed + epoch)
|
||||
train_loss = train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg)
|
||||
valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device, cfg)
|
||||
scheduler.step(valid_loss)
|
||||
model_path = model.save(epoch, cfg)
|
||||
# logger.info(model_path)
|
||||
|
||||
train_losses.append(train_loss)
|
||||
valid_losses.append(valid_loss)
|
||||
if best_f1 < valid_f1:
|
||||
best_f1 = valid_f1
|
||||
best_epoch = epoch
|
||||
# use the validation loss as the early-stopping criterion
|
||||
if es_loss > valid_loss:
|
||||
es_loss = valid_loss
|
||||
es_f1 = valid_f1
|
||||
es_epoch = epoch
|
||||
es_patience = 0
|
||||
es_path = model_path
|
||||
else:
|
||||
es_patience += 1
|
||||
if es_patience >= cfg.early_stopping_patience:
|
||||
best_es_epoch = es_epoch
|
||||
best_es_f1 = es_f1
|
||||
best_es_path = es_path
|
||||
|
||||
if cfg.show_plot:
|
||||
if cfg.plot_utils == 'matplot':
|
||||
plt.plot(train_losses, 'x-')
|
||||
plt.plot(valid_losses, '+-')
|
||||
plt.legend(['train', 'valid'])
|
||||
plt.title('train/valid comparison loss')
|
||||
plt.show()
|
||||
|
||||
if cfg.plot_utils == 'tensorboard':
|
||||
for i in range(len(train_losses)):
|
||||
writer.add_scalars('train/valid_comparison_loss', {
|
||||
'train': train_losses[i],
|
||||
'valid': valid_losses[i]
|
||||
}, i)
|
||||
writer.close()
|
||||
|
||||
logger.info(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '
|
||||
f'this epoch macro f1: {best_es_f1:0.4f}')
|
||||
logger.info(f'this model save path: {best_es_path}')
|
||||
logger.info(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '
|
||||
f'this epoch macro f1: {best_f1:.4f}')
|
||||
|
||||
logger.info('=====end of training====')
|
||||
logger.info('')
|
||||
logger.info('=====start test performance====')
|
||||
validate(-1, model, test_dataloader, criterion, device, cfg)
|
||||
logger.info('=====ending====')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
# python predict.py --help  # show help for all parameters
|
||||
# python predict.py -c
|
||||
# python predict.py chinese_split=0,1 replace_entity_with_type=0,1 -m
|
Binary image file changed (not shown; previously 149 KiB).
@@ -0,0 +1,35 @@

import os
import time
import torch
import torch.nn as nn


class BasicModule(nn.Module):
    '''
    Wraps nn.Module and provides save and load methods.
    '''
    def __init__(self):
        super(BasicModule, self).__init__()

    def load(self, path, device):
        '''
        Load the model from the given path.
        '''
        self.load_state_dict(torch.load(path, map_location=device))

    def save(self, epoch=0, cfg=None):
        '''
        Save the model; by default the file name is "model name + epoch" under a timestamped folder.
        '''
        time_prefix = time.strftime('%Y-%m-%d_%H-%M-%S')
        prefix = os.path.join(cfg.cwd, 'checkpoints', time_prefix)
        os.makedirs(prefix, exist_ok=True)
        name = os.path.join(prefix, cfg.model_name + '_' + f'epoch{epoch}' + '.pth')

        torch.save(self.state_dict(), name)
        return name
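A usage sketch of the save/load round trip above; TinyModel and the SimpleNamespace config are stand-ins that only carry the fields save() actually reads (cwd, model_name), not the full Hydra config:

```
from types import SimpleNamespace
import torch

# stand-in config with only the attributes BasicModule.save() uses
cfg = SimpleNamespace(cwd=".", model_name="cnn")

class TinyModel(BasicModule):  # BasicModule as defined above
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

model = TinyModel()
path = model.save(epoch=1, cfg=cfg)  # e.g. ./checkpoints/<timestamp>/cnn_epoch1.pth

restored = TinyModel()
restored.load(path, device=torch.device("cpu"))
```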
@ -0,0 +1,26 @@
|
|||
import torch.nn as nn
|
||||
from . import BasicModule
|
||||
from module import Embedding, RNN
|
||||
|
||||
|
||||
class BiLSTM(BasicModule):
|
||||
def __init__(self, cfg):
|
||||
super(BiLSTM, self).__init__()
|
||||
|
||||
if cfg.dim_strategy == 'cat':
|
||||
cfg.input_size = cfg.word_dim + 2 * cfg.pos_dim
|
||||
else:
|
||||
cfg.input_size = cfg.word_dim
|
||||
|
||||
self.embedding = Embedding(cfg)
|
||||
self.bilstm = RNN(cfg)
|
||||
self.fc = nn.Linear(cfg.hidden_size, cfg.num_attributes)
|
||||
self.dropout = nn.Dropout(cfg.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
word, lens, entity_pos, attribute_value_pos = x['word'], x['lens'], x['entity_pos'], x['attribute_value_pos']
|
||||
inputs = self.embedding(word, entity_pos, attribute_value_pos)
|
||||
out, out_pool = self.bilstm(inputs, lens)
|
||||
output = self.fc(out_pool)
|
||||
|
||||
return output
|
|
@ -0,0 +1,51 @@
|
|||
import torch
|
||||
from . import BasicModule
|
||||
from module import Embedding, CNN
|
||||
from module import Capsule as CapsuleLayer
|
||||
|
||||
from utils import seq_len_to_mask, to_one_hot
|
||||
|
||||
|
||||
class Capsule(BasicModule):
|
||||
def __init__(self, cfg):
|
||||
super(Capsule, self).__init__()
|
||||
|
||||
if cfg.dim_strategy == 'cat':
|
||||
cfg.in_channels = cfg.word_dim + 2 * cfg.pos_dim
|
||||
else:
|
||||
cfg.in_channels = cfg.word_dim
|
||||
|
||||
# capsule config
|
||||
cfg.input_dim_capsule = cfg.out_channels
|
||||
cfg.num_capsule = cfg.num_attributes
|
||||
|
||||
self.num_attributes = cfg.num_attributes
|
||||
self.embedding = Embedding(cfg)
|
||||
self.cnn = CNN(cfg)
|
||||
self.capsule = CapsuleLayer(cfg)
|
||||
|
||||
def forward(self, x):
|
||||
word, lens, entity_pos, attribute_value_pos = x['word'], x['lens'], x['entity_pos'], x['attribute_value_pos']
|
||||
mask = seq_len_to_mask(lens)
|
||||
inputs = self.embedding(word, entity_pos, attribute_value_pos)
|
||||
|
||||
primary, _ = self.cnn(inputs)  # the CNN changes the length, so a precise mask is not possible; skipping the mask is acceptable since primary capsules only carry coarse-grained information
|
||||
output = self.capsule(primary)
|
||||
output = output.norm(p=2, dim=-1) # 求得模长再返回值
|
||||
|
||||
return output # [B, N]
|
||||
|
||||
def loss(self, predict, target, reduction='mean'):
|
||||
m_plus, m_minus, loss_lambda = 0.9, 0.1, 0.5
|
||||
|
||||
target = to_one_hot(target, self.num_attributes)
|
||||
max_l = (torch.relu(m_plus - predict))**2
|
||||
max_r = (torch.relu(predict - m_minus))**2
|
||||
loss = target * max_l + loss_lambda * (1 - target) * max_r
|
||||
loss = torch.sum(loss, dim=-1)
|
||||
|
||||
if reduction == 'sum':
|
||||
return loss.sum()
|
||||
else:
|
||||
# 默认情况为求平均
|
||||
return loss.mean()
|
|
@ -0,0 +1,32 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from . import BasicModule
|
||||
from module import Embedding
|
||||
from module import GCN as GCNBlock
|
||||
|
||||
from utils import seq_len_to_mask
|
||||
|
||||
|
||||
class GCN(BasicModule):
|
||||
def __init__(self, cfg):
|
||||
super(GCN, self).__init__()
|
||||
|
||||
if cfg.dim_strategy == 'cat':
|
||||
cfg.input_size = cfg.word_dim + 2 * cfg.pos_dim
|
||||
else:
|
||||
cfg.input_size = cfg.word_dim
|
||||
|
||||
self.embedding = Embedding(cfg)
|
||||
self.gcn = GCNBlock(cfg)
|
||||
self.fc = nn.Linear(cfg.hidden_size, cfg.num_attributes)
|
||||
|
||||
def forward(self, x):
|
||||
word, lens, entity_pos, attribute_value_pos, adj = x['word'], x['lens'], x['entity_pos'], x['attribute_value_pos'], x['adj']
|
||||
|
||||
inputs = self.embedding(word, entity_pos, attribute_value_pos)
|
||||
output = self.gcn(inputs, adj)
|
||||
output = output.max(dim=1)[0]
|
||||
output = self.fc(output)
|
||||
|
||||
return output
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
from torch import nn
|
||||
from . import BasicModule
|
||||
from module import RNN
|
||||
from transformers import BertModel
|
||||
|
||||
from utils import seq_len_to_mask
|
||||
|
||||
|
||||
class LM(BasicModule):
|
||||
def __init__(self, cfg):
|
||||
super(LM, self).__init__()
|
||||
self.bert = BertModel.from_pretrained(cfg.lm_file, num_hidden_layers=cfg.num_hidden_layers)
|
||||
self.bilstm = RNN(cfg)
|
||||
self.fc = nn.Linear(cfg.hidden_size, cfg.num_attributes)
|
||||
self.dropout = nn.Dropout(cfg.dropout)
|
||||
|
||||
def forward(self, x):
|
||||
word, lens = x['word'], x['lens']
|
||||
mask = seq_len_to_mask(lens, mask_pos_to_true=False)
|
||||
a = self.bert(word, attention_mask=mask)
|
||||
last_hidden_state = a[0]
|
||||
# pooler_output = a[1]
|
||||
_, out_pool = self.bilstm(last_hidden_state, lens)
|
||||
out_pool = self.dropout(out_pool)
|
||||
output = self.fc(out_pool)
|
||||
|
||||
return output
|
|
@ -0,0 +1,57 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from . import BasicModule
|
||||
from module import Embedding, CNN
|
||||
|
||||
from utils import seq_len_to_mask
|
||||
|
||||
|
||||
class PCNN(BasicModule):
|
||||
def __init__(self, cfg):
|
||||
super(PCNN, self).__init__()
|
||||
|
||||
self.use_pcnn = cfg.use_pcnn
|
||||
if cfg.dim_strategy == 'cat':
|
||||
cfg.in_channels = cfg.word_dim + 2 * cfg.pos_dim
|
||||
else:
|
||||
cfg.in_channels = cfg.word_dim
|
||||
|
||||
self.embedding = Embedding(cfg)
|
||||
self.cnn = CNN(cfg)
|
||||
self.fc1 = nn.Linear(len(cfg.kernel_sizes) * cfg.out_channels, cfg.intermediate)
|
||||
self.fc2 = nn.Linear(cfg.intermediate, cfg.num_attributes)
|
||||
self.dropout = nn.Dropout(cfg.dropout)
|
||||
|
||||
if self.use_pcnn:
|
||||
self.fc_pcnn = nn.Linear(3 * len(cfg.kernel_sizes) * cfg.out_channels,
|
||||
len(cfg.kernel_sizes) * cfg.out_channels)
|
||||
self.pcnn_mask_embedding = nn.Embedding(4, 3)
|
||||
masks = torch.tensor([[0, 0, 0], [100, 0, 0], [0, 100, 0], [0, 0, 100]])
|
||||
self.pcnn_mask_embedding.weight.data.copy_(masks)
|
||||
self.pcnn_mask_embedding.weight.requires_grad = False
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
word, lens, entity_pos, attribute_value_pos = x['word'], x['lens'], x['entity_pos'], x['attribute_value_pos']
|
||||
mask = seq_len_to_mask(lens)
|
||||
|
||||
inputs = self.embedding(word, entity_pos, attribute_value_pos)
|
||||
out, out_pool = self.cnn(inputs, mask=mask)
|
||||
|
||||
if self.use_pcnn:
|
||||
out = out.unsqueeze(-1) # [B, L, Hs, 1]
|
||||
pcnn_mask = x['pcnn_mask']
|
||||
pcnn_mask = self.pcnn_mask_embedding(pcnn_mask).unsqueeze(-2) # [B, L, 1, 3]
|
||||
out = out + pcnn_mask # [B, L, Hs, 3]
|
||||
out = out.max(dim=1)[0] - 100 # [B, Hs, 3]
|
||||
out_pool = out.view(out.size(0), -1) # [B, 3 * Hs]
|
||||
out_pool = F.leaky_relu(self.fc_pcnn(out_pool)) # [B, Hs]
|
||||
out_pool = self.dropout(out_pool)
|
||||
|
||||
output = self.fc1(out_pool)
|
||||
output = F.leaky_relu(output)
|
||||
output = self.dropout(output)
|
||||
output = self.fc2(output)
|
||||
|
||||
return output
|
|
@ -0,0 +1,30 @@
|
|||
import torch.nn as nn
|
||||
from . import BasicModule
|
||||
from module import Embedding
|
||||
from module import Transformer as TransformerBlock
|
||||
|
||||
from utils import seq_len_to_mask
|
||||
|
||||
|
||||
class Transformer(BasicModule):
|
||||
def __init__(self, cfg):
|
||||
super(Transformer, self).__init__()
|
||||
|
||||
if cfg.dim_strategy == 'cat':
|
||||
cfg.hidden_size = cfg.word_dim + 2 * cfg.pos_dim
|
||||
else:
|
||||
cfg.hidden_size = cfg.word_dim
|
||||
|
||||
self.embedding = Embedding(cfg)
|
||||
self.transformer = TransformerBlock(cfg)
|
||||
self.fc = nn.Linear(cfg.hidden_size, cfg.num_attributes)
|
||||
|
||||
def forward(self, x):
|
||||
word, lens, entity_pos, attribute_value_pos = x['word'], x['lens'], x['entity_pos'], x['attribute_value_pos']
|
||||
mask = seq_len_to_mask(lens)
|
||||
inputs = self.embedding(word, entity_pos, attribute_value_pos)
|
||||
last_layer_hidden_state, all_hidden_states, all_attentions = self.transformer(inputs, key_padding_mask=mask)
|
||||
out_pool = last_layer_hidden_state.max(dim=1)[0]
|
||||
output = self.fc(out_pool)
|
||||
|
||||
return output
|
|
@@ -0,0 +1,7 @@

from .BasicModule import BasicModule
from .PCNN import PCNN
from .BiLSTM import BiLSTM
from .Transformer import Transformer
from .Capsule import Capsule
from .GCN import GCN
from .LM import LM
@ -0,0 +1,141 @@
|
|||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DotAttention(nn.Module):
|
||||
def __init__(self, dropout=0.0):
|
||||
super(DotAttention, self).__init__()
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, Q, K, V, mask_out=None, head_mask=None):
|
||||
"""
|
||||
一般输入信息 X 时,假设 K = V = X
|
||||
|
||||
att_weight = softmax( score_func(q, k) )
|
||||
att = sum( att_weight * v )
|
||||
|
||||
:param Q: [..., L, H]
|
||||
:param K: [..., S, H]
|
||||
:param V: [..., S, H]
|
||||
:param mask_out: [..., 1, S]
|
||||
:return:
|
||||
"""
|
||||
H = Q.size(-1)
|
||||
|
||||
scale = float(H)**0.5
|
||||
attention_weight = torch.matmul(Q, K.transpose(-1, -2)) / scale
|
||||
|
||||
if mask_out is not None:
|
||||
# 当 DotAttention 单独使用时(几乎不会),保证维度一样
|
||||
while mask_out.dim() != Q.dim():
|
||||
mask_out = mask_out.unsqueeze(1)
|
||||
attention_weight.masked_fill_(mask_out, -1e8)
|
||||
|
||||
attention_weight = F.softmax(attention_weight, dim=-1)
|
||||
|
||||
attention_weight = F.dropout(attention_weight, self.dropout)
|
||||
|
||||
# mask heads if we want to:
|
||||
# multi head 才会使用
|
||||
if head_mask is not None:
|
||||
attention_weight = attention_weight * head_mask
|
||||
|
||||
attention_out = torch.matmul(attention_weight, V)
|
||||
|
||||
return attention_out, attention_weight
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, embed_dim, num_heads, dropout=0.0, output_attentions=True):
|
||||
"""
|
||||
:param embed_dim: 输入的维度,必须能被 num_heads 整除
|
||||
:param num_heads: attention 的个数
|
||||
:param dropout: float。
|
||||
"""
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
self.num_heads = num_heads
|
||||
self.output_attentions = output_attentions
|
||||
self.head_dim = int(embed_dim / num_heads)
|
||||
self.all_head_dim = self.head_dim * num_heads
|
||||
assert self.all_head_dim == embed_dim, logger.error(
|
||||
f"embed_dim{embed_dim} must be divisible by num_heads{num_heads}")
|
||||
|
||||
self.q_in = nn.Linear(embed_dim, self.all_head_dim)
|
||||
self.k_in = nn.Linear(embed_dim, self.all_head_dim)
|
||||
self.v_in = nn.Linear(embed_dim, self.all_head_dim)
|
||||
self.attention = DotAttention(dropout=dropout)
|
||||
self.out = nn.Linear(self.all_head_dim, embed_dim)
|
||||
|
||||
def forward(self, Q, K, V, key_padding_mask=None, attention_mask=None, head_mask=None):
|
||||
"""
|
||||
:param Q: [B, L, Hs]
|
||||
:param K: [B, S, Hs]
|
||||
:param V: [B, S, Hs]
|
||||
:param key_padding_mask: [B, S] 为 1/True 的地方需要 mask
|
||||
:param attention_mask: [S] / [L, S] 指定位置 mask 掉, 为 1/True 的地方需要 mask
|
||||
:param head_mask: [N] 指定 head mask 掉, 为 1/True 的地方需要 mask
|
||||
"""
|
||||
B, L, Hs = Q.shape
|
||||
S = V.size(1)
|
||||
N, H = self.num_heads, self.head_dim
|
||||
|
||||
q = self.q_in(Q).view(B, L, N, H).transpose(1, 2) # [B, N, L, H]
|
||||
k = self.k_in(K).view(B, S, N, H).transpose(1, 2) # [B, N, S, H]
|
||||
v = self.v_in(V).view(B, S, N, H).transpose(1, 2) # [B, N, S, H]
|
||||
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = key_padding_mask.ne(0)
|
||||
key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(1)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.ne(0)
|
||||
if attention_mask.dim() == 1:
|
||||
attention_mask = attention_mask.unsqueeze(0)
|
||||
elif attention_mask.dim() == 2:
|
||||
attention_mask = attention_mask.unsqueeze(0).unsqueeze(0).expand(B, -1, -1, -1)
|
||||
else:
|
||||
raise ValueError(f'attention_mask dim must be 1 or 2, can not be {attention_mask.dim()}')
|
||||
|
||||
if key_padding_mask is None:
|
||||
mask_out = attention_mask if attention_mask is not None else None
|
||||
else:
|
||||
mask_out = (key_padding_mask + attention_mask).ne(0) if attention_mask is not None else key_padding_mask
|
||||
|
||||
if head_mask is not None:
|
||||
head_mask = head_mask.eq(0)
|
||||
head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
attention_out, attention_weight = self.attention(q, k, v, mask_out=mask_out, head_mask=head_mask)
|
||||
|
||||
attention_out = attention_out.transpose(1, 2).reshape(B, L, N * H) # [B, N, L, H] -> [B, L, N * H]
|
||||
|
||||
# concat all heads, and do output linear
|
||||
attention_out = self.out(attention_out) # [B, L, N * H] -> [B, L, H]
|
||||
|
||||
if self.output_attentions:
|
||||
return attention_out, attention_weight
|
||||
else:
|
||||
return attention_out,
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
|
||||
from utils import seq_len_to_mask
|
||||
|
||||
q = torch.randn(4, 6, 20) # [B, L, H]
|
||||
k = v = torch.randn(4, 5, 20) # [B, S, H]
|
||||
key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
|
||||
attention_mask = torch.tensor([1, 0, 0, 1, 0]) # 为1 的地方 mask 掉
|
||||
head_mask = torch.tensor([0, 1]) # 为1 的地方 mask 掉
|
||||
|
||||
m = MultiHeadAttention(embed_dim=20, num_heads=2, dropout=0.0, output_attentions=True)
|
||||
ao, aw = m(q, k, v, key_padding_mask=key_padding_mask, attention_mask=attention_mask, head_mask=head_mask)
|
||||
print(ao.shape, aw.shape) # [B, L, H] [B, N, L, S]
|
||||
print(ao)
|
||||
print(aw.unbind(1))
|
|
@ -0,0 +1,114 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self):
|
||||
super(GELU, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
|
||||
class CNN(nn.Module):
|
||||
"""
|
||||
nlp 里为了保证输出的句长 = 输入的句长,一般使用奇数 kernel_size,如 [3, 5, 7, 9]
|
||||
当然也可以不等长输出,keep_length 设为 False
|
||||
此时,padding = k // 2
|
||||
stride 一般为 1
|
||||
"""
|
||||
def __init__(self, config):
|
||||
"""
|
||||
in_channels : 一般就是 word embedding 的维度,或者 hidden size 的维度
|
||||
out_channels : int
|
||||
kernel_sizes : list 为了保证输出长度=输入长度,必须为奇数: 3, 5, 7...
|
||||
activation : [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
|
||||
pooling_strategy : [max, avg, cls]
|
||||
dropout: : float
|
||||
"""
|
||||
super(CNN, self).__init__()
|
||||
|
||||
# self.xxx = config.xxx
|
||||
self.in_channels = config.in_channels
|
||||
self.out_channels = config.out_channels
|
||||
self.kernel_sizes = config.kernel_sizes
|
||||
self.activation = config.activation
|
||||
self.pooling_strategy = config.pooling_strategy
|
||||
self.dropout = config.dropout
|
||||
self.keep_length = config.keep_length
|
||||
for kernel_size in self.kernel_sizes:
|
||||
assert kernel_size % 2 == 1, "kernel size has to be odd numbers."
|
||||
|
||||
# convolution
|
||||
self.convs = nn.ModuleList([
|
||||
nn.Conv1d(in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=k,
|
||||
stride=1,
|
||||
padding=k // 2 if self.keep_length else 0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=False) for k in self.kernel_sizes
|
||||
])
|
||||
|
||||
# activation function
|
||||
assert self.activation in ['relu', 'lrelu', 'prelu', 'selu', 'celu', 'gelu', 'sigmoid', 'tanh'], \
|
||||
'activation function must choose from [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]'
|
||||
self.activations = nn.ModuleDict([
|
||||
['relu', nn.ReLU()],
|
||||
['lrelu', nn.LeakyReLU()],
|
||||
['prelu', nn.PReLU()],
|
||||
['selu', nn.SELU()],
|
||||
['celu', nn.CELU()],
|
||||
['gelu', GELU()],
|
||||
['sigmoid', nn.Sigmoid()],
|
||||
['tanh', nn.Tanh()],
|
||||
])
|
||||
|
||||
# pooling
|
||||
assert self.pooling_strategy in ['max', 'avg', 'cls'], 'pooling strategy must choose from [max, avg, cls]'
|
||||
|
||||
self.dropout = nn.Dropout(self.dropout)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
"""
|
||||
:param x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H] 一般是经过embedding后的值
|
||||
:param mask: [batch_size, max_len], 句长部分为0,padding部分为1。不影响卷积运算,max-pool一定不会pool到pad为0的位置
|
||||
:return:
|
||||
"""
|
||||
# [B, L, H] -> [B, H, L] (注释:将 H 维度当作输入 channel 维度)
|
||||
x = torch.transpose(x, 1, 2)
|
||||
|
||||
# convolution + activation [[B, H, L], ... ]
|
||||
act_fn = self.activations[self.activation]
|
||||
|
||||
x = [act_fn(conv(x)) for conv in self.convs]
|
||||
x = torch.cat(x, dim=1)
|
||||
|
||||
# mask
|
||||
if mask is not None:
|
||||
# [B, L] -> [B, 1, L]
|
||||
mask = mask.unsqueeze(1)
|
||||
x = x.masked_fill_(mask, 1e-12)
|
||||
|
||||
# pooling
|
||||
# [[B, H, L], ... ] -> [[B, H], ... ]
|
||||
if self.pooling_strategy == 'max':
|
||||
xp = F.max_pool1d(x, kernel_size=x.size(2)).squeeze(2)
|
||||
# 等价于 xp = torch.max(x, dim=2)[0]
|
||||
|
||||
elif self.pooling_strategy == 'avg':
|
||||
x_len = mask.squeeze().eq(0).sum(-1).unsqueeze(-1).to(torch.float).to(device=mask.device)
|
||||
xp = torch.sum(x, dim=-1) / x_len
|
||||
|
||||
else:
|
||||
# self.pooling_strategy == 'cls'
|
||||
xp = x[:, :, 0]
|
||||
|
||||
x = x.transpose(1, 2)
|
||||
x = self.dropout(x)
|
||||
xp = self.dropout(xp)
|
||||
|
||||
return x, xp # [B, L, Hs], [B, Hs]
|
|
@ -0,0 +1,54 @@
|
|||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Capsule(nn.Module):
|
||||
def __init__(self, cfg):
|
||||
super(Capsule, self).__init__()
|
||||
|
||||
# self.xxx = cfg.xxx
|
||||
self.input_dim_capsule = cfg.input_dim_capsule
|
||||
self.dim_capsule = cfg.dim_capsule
|
||||
self.num_capsule = cfg.num_capsule
|
||||
self.batch_size = cfg.batch_size
|
||||
self.share_weights = cfg.share_weights
|
||||
self.num_iterations = cfg.num_iterations
|
||||
|
||||
if self.share_weights:
|
||||
W = torch.zeros(1, self.input_dim_capsule, self.num_capsule * self.dim_capsule)
|
||||
else:
|
||||
W = torch.zeros(self.batch_size, self.input_dim_capsule, self.num_capsule * self.dim_capsule)
|
||||
|
||||
W = nn.init.xavier_normal_(W)
|
||||
self.W = nn.Parameter(W)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: [B, L, H] # 从 CNN / RNN 得到的结果
|
||||
L 作为 input_num_capsules, H 作为 input_dim_capsule
|
||||
"""
|
||||
B, I, _ = x.size() # I 是 input_num_capsules
|
||||
O, F = self.num_capsule, self.dim_capsule
|
||||
|
||||
u = torch.matmul(x, self.W)
|
||||
u = u.view(B, I, O, F).transpose(1, 2) # [B, O, I, F]
|
||||
|
||||
b = torch.zeros_like(u[:, :, :, 0]).to(device=u.device) # [B, O, I]
|
||||
for i in range(self.num_iterations):
|
||||
c = torch.softmax(b, dim=1) # [B, O_s, I]
|
||||
v = torch.einsum('boi,boif->bof', [c, u]) # [B, O, F]
|
||||
v = self.squash(v)
|
||||
b = torch.einsum('bof,boif->boi', [v, u]) # [B, O, I]
|
||||
|
||||
return v # [B, O, F] [B, num_capsule, dim_capsule]
|
||||
|
||||
@staticmethod
|
||||
def squash(x: torch.Tensor):
|
||||
x_norm = x.norm(p=2, dim=-1, keepdim=True)
|
||||
mag = x_norm**2
|
||||
out = x / x_norm * mag / (1 + mag)
|
||||
|
||||
return out
|
|
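The squash above rescales a vector to norm ||v||^2 / (1 + ||v||^2) while keeping its direction, so long vectors approach length 1 and short ones shrink toward 0. A small numeric check (not repo code):

```
import torch

def squash(x: torch.Tensor) -> torch.Tensor:
    # same formula as Capsule.squash above
    x_norm = x.norm(p=2, dim=-1, keepdim=True)
    mag = x_norm ** 2
    return x / x_norm * mag / (1 + mag)

v = torch.tensor([[3.0, 4.0], [0.03, 0.04]])
print(squash(v).norm(dim=-1))  # tensor([0.9615, 0.0025]): long vector -> ~1, short -> ~0
```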
@ -0,0 +1,39 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Embedding(nn.Module):
|
||||
def __init__(self, config):
|
||||
"""
|
||||
word embedding: 一般 0 为 padding
|
||||
pos embedding: 一般 0 为 padding
|
||||
dim_strategy: [cat, sum] 多个 embedding 是拼接还是相加
|
||||
"""
|
||||
super(Embedding, self).__init__()
|
||||
|
||||
# self.xxx = config.xxx
|
||||
self.vocab_size = config.vocab_size
|
||||
self.word_dim = config.word_dim
|
||||
self.pos_size = config.pos_size
|
||||
self.pos_dim = config.pos_dim if config.dim_strategy == 'cat' else config.word_dim
|
||||
self.dim_strategy = config.dim_strategy
|
||||
|
||||
self.wordEmbed = nn.Embedding(self.vocab_size, self.word_dim, padding_idx=0)
|
||||
self.headPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
|
||||
self.tailPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
|
||||
|
||||
self.layer_norm = nn.LayerNorm(self.word_dim)
|
||||
|
||||
def forward(self, *x):
|
||||
word, head, tail = x
|
||||
word_embedding = self.wordEmbed(word)
|
||||
head_embedding = self.headPosEmbed(head)
|
||||
tail_embedding = self.tailPosEmbed(tail)
|
||||
|
||||
if self.dim_strategy == 'cat':
|
||||
return torch.cat((word_embedding, head_embedding, tail_embedding), -1)
|
||||
elif self.dim_strategy == 'sum':
|
||||
# 此时 pos_dim == word_dim
|
||||
return self.layer_norm(word_embedding + head_embedding + tail_embedding)
|
||||
else:
|
||||
raise Exception('dim_strategy must choose from [sum, cat]')
|
|
@ -0,0 +1,143 @@
|
|||
import logging
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
class GCN(nn.Module):
|
||||
def __init__(self,cfg):
|
||||
super(GCN , self).__init__()
|
||||
|
||||
self.num_layers = cfg.num_layers
|
||||
self.input_size = cfg.input_size
|
||||
self.hidden_size = cfg.hidden_size
|
||||
self.dropout = cfg.dropout
|
||||
|
||||
self.fc1 = nn.Linear(self.input_size , self.hidden_size)
|
||||
self.fc = nn.Linear(self.hidden_size , self.hidden_size)
|
||||
self.weight_list = nn.ModuleList()
|
||||
for i in range(self.num_layers):
|
||||
self.weight_list.append(nn.Linear(self.hidden_size * (i + 1),self.hidden_size))
|
||||
self.dropout = nn.Dropout(self.dropout)
|
||||
|
||||
def forward(self , x, adj):
|
||||
L = adj.sum(2).unsqueeze(2) + 1
|
||||
outputs = self.fc1(x)
|
||||
cache_list = [outputs]
|
||||
output_list = []
|
||||
for l in range(self.num_layers):
|
||||
Ax = adj.bmm(outputs)
|
||||
AxW = self.weight_list[l](Ax)
|
||||
AxW = AxW + self.weight_list[l](outputs)
|
||||
AxW = AxW / L
|
||||
gAxW = F.relu(AxW)
|
||||
cache_list.append(gAxW)
|
||||
outputs = torch.cat(cache_list , dim=2)
|
||||
output_list.append(self.dropout(gAxW))
|
||||
# gcn_outputs = torch.cat(output_list, dim=2)
|
||||
gcn_outputs = output_list[self.num_layers - 1]
|
||||
gcn_outputs = gcn_outputs + self.fc1(x)
|
||||
|
||||
out = self.fc(gcn_outputs)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class Tree(object):
|
||||
def __init__(self):
|
||||
self.parent = None
|
||||
self.num_children = 0
|
||||
self.children = list()
|
||||
|
||||
def add_child(self, child):
|
||||
child.parent = self
|
||||
self.num_children += 1
|
||||
self.children.append(child)
|
||||
|
||||
def size(self):
|
||||
s = getattr(self, '_size', -1)
|
||||
if s != -1:
|
||||
return self._size
|
||||
else:
|
||||
count = 1
|
||||
for i in range(self.num_children):
|
||||
count += self.children[i].size()
|
||||
self._size = count
|
||||
return self._size
|
||||
|
||||
def __iter__(self):
|
||||
yield self
|
||||
for c in self.children:
|
||||
for x in c:
|
||||
yield x
|
||||
|
||||
def depth(self):
|
||||
d = getattr(self, '_depth', -1)
|
||||
if d != -1:
|
||||
return self._depth
|
||||
else:
|
||||
count = 0
|
||||
if self.num_children > 0:
|
||||
for i in range(self.num_children):
|
||||
child_depth = self.children[i].depth()
|
||||
if child_depth > count:
|
||||
count = child_depth
|
||||
count += 1
|
||||
self._depth = count
|
||||
return self._depth
|
||||
|
||||
|
||||
def head_to_adj(head, directed=True, self_loop=False):
|
||||
"""
|
||||
Convert a sequence of head indexes to an (numpy) adjacency matrix.
|
||||
"""
|
||||
seq_len = len(head)
|
||||
head = head[:seq_len]
|
||||
root = None
|
||||
nodes = [Tree() for _ in head]
|
||||
|
||||
for i in range(seq_len):
|
||||
h = head[i]
|
||||
setattr(nodes[i], 'idx', i)
|
||||
if h == 0:
|
||||
root = nodes[i]
|
||||
else:
|
||||
nodes[h - 1].add_child(nodes[i])
|
||||
|
||||
assert root is not None
|
||||
|
||||
ret = np.zeros((seq_len, seq_len), dtype=np.float32)
|
||||
queue = [root]
|
||||
idx = []
|
||||
while len(queue) > 0:
|
||||
t, queue = queue[0], queue[1:]
|
||||
idx += [t.idx]
|
||||
for c in t.children:
|
||||
ret[t.idx, c.idx] = 1
|
||||
queue += t.children
|
||||
|
||||
if not directed:
|
||||
ret = ret + ret.T
|
||||
|
||||
if self_loop:
|
||||
for i in idx:
|
||||
ret[i, i] = 1
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def pad_adj(adj, max_len):
    # pad the adjacency matrix with zeros up to max_len x max_len
    pad_len = max_len - adj.shape[0]
    for i in range(pad_len):
        adj = np.insert(adj, adj.shape[-1], 0, axis=1)
    for i in range(pad_len):
        adj = np.insert(adj, adj.shape[0], 0, axis=0)
    return adj
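A hedged usage sketch of head_to_adj and the fixed pad_adj above (both assumed to be in scope): a 1-based dependency-head sequence, with 0 marking the root, becomes an adjacency matrix that can then be padded to the batch's max length.

```
# head[i] is the 1-based index of token i's head; 0 marks the root
head = [2, 0, 2, 3]  # token 0 attaches to token 1 (the root), tokens 2 and 3 chain off it

adj = head_to_adj(head, directed=False, self_loop=True)
print(adj.shape)       # (4, 4)

padded = pad_adj(adj, max_len=6)
print(padded.shape)    # (6, 6), assuming the fixed pad_adj above that returns the padded matrix
```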
@ -0,0 +1,73 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
||||
|
||||
|
||||
class RNN(nn.Module):
|
||||
def __init__(self, config):
|
||||
"""
|
||||
type_rnn: RNN, GRU, LSTM 可选
|
||||
"""
|
||||
super(RNN, self).__init__()
|
||||
|
||||
# self.xxx = config.xxx
|
||||
self.input_size = config.input_size
|
||||
self.hidden_size = config.hidden_size // 2 if config.bidirectional else config.hidden_size
|
||||
self.num_layers = config.num_layers
|
||||
self.dropout = config.dropout
|
||||
self.bidirectional = config.bidirectional
|
||||
self.last_layer_hn = config.last_layer_hn
|
||||
self.type_rnn = config.type_rnn
|
||||
|
||||
rnn = eval(f'nn.{self.type_rnn}')
|
||||
self.rnn = rnn(input_size=self.input_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
bidirectional=self.bidirectional,
|
||||
bias=True,
|
||||
batch_first=True)
|
||||
|
||||
# 有bug
|
||||
# self._init_weights()
|
||||
|
||||
def _init_weights(self):
|
||||
"""orthogonal init yields generally good results than uniform init"""
|
||||
gain = 1 # use default value
|
||||
for nth in range(self.num_layers * self.bidirectional):
|
||||
# w_ih, (4 * hidden_size x input_size)
|
||||
nn.init.orthogonal_(self.rnn.all_weights[nth][0], gain=gain)
|
||||
# w_hh, (4 * hidden_size x hidden_size)
|
||||
nn.init.orthogonal_(self.rnn.all_weights[nth][1], gain=gain)
|
||||
# b_ih, (4 * hidden_size)
|
||||
nn.init.zeros_(self.rnn.all_weights[nth][2])
|
||||
# b_hh, (4 * hidden_size)
|
||||
nn.init.zeros_(self.rnn.all_weights[nth][3])
|
||||
|
||||
def forward(self, x, x_len):
|
||||
"""
|
||||
Args:
|
||||
torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H_in] 一般是经过embedding后的值
|
||||
x_len: torch.Tensor [L] 已经排好序的句长值
|
||||
Returns:
|
||||
output: torch.Tensor [B, L, H_out] 序列标注的使用结果
|
||||
hn: torch.Tensor [B, N, H_out] / [B, H_out] 分类的结果,当 last_layer_hn 时只有最后一层结果
|
||||
"""
|
||||
B, L, _ = x.size()
|
||||
H, N = self.hidden_size, self.num_layers
|
||||
|
||||
x_len = x_len.cpu()
|
||||
x = pack_padded_sequence(x, x_len, batch_first=True, enforce_sorted=True)
|
||||
output, hn = self.rnn(x)
|
||||
output, _ = pad_packed_sequence(output, batch_first=True, total_length=L)
|
||||
|
||||
if self.type_rnn == 'LSTM':
|
||||
hn = hn[0]
|
||||
if self.bidirectional:
|
||||
hn = hn.view(N, 2, B, H).transpose(1, 2).contiguous().view(N, B, 2 * H).transpose(0, 1)
|
||||
else:
|
||||
hn = hn.transpose(0, 1)
|
||||
if self.last_layer_hn:
|
||||
hn = hn[:, -1, :]
|
||||
|
||||
return output, hn
|
|
@ -0,0 +1,149 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .Attention import MultiHeadAttention
|
||||
|
||||
|
||||
def gelu(x):
|
||||
""" Original Implementation of the gelu activation function in Google Bert repo when initially created.
|
||||
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
||||
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
Also see https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
|
||||
def gelu_new(x):
|
||||
""" Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
|
||||
Also see https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new}
|
||||
|
||||
|
||||
class TransformerAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(TransformerAttention, self).__init__()
|
||||
|
||||
# self.xxx = config.xxx
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_heads
|
||||
self.dropout = config.dropout
|
||||
self.output_attentions = config.output_attentions
|
||||
self.layer_norm_eps = config.layer_norm_eps
|
||||
|
||||
self.multihead_attention = MultiHeadAttention(self.hidden_size, self.num_heads, self.dropout,
|
||||
self.output_attentions)
|
||||
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
|
||||
self.dropout = nn.Dropout(self.dropout)
|
||||
self.layerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
||||
|
||||
def forward(self, x, key_padding_mask=None, attention_mask=None, head_mask=None):
|
||||
"""
|
||||
:param x: [B, L, Hs]
|
||||
:param attention_mask: [B, L] padding后的句子后面补0了,补0的位置为True,前面部分为False
|
||||
:param head_mask: [L] [N,L]
|
||||
:return:
|
||||
"""
|
||||
attention_outputs = self.multihead_attention(x, x, x, key_padding_mask, attention_mask, head_mask)
|
||||
attention_output = attention_outputs[0]
|
||||
attention_output = self.dense(attention_output)
|
||||
attention_output = self.dropout(attention_output)
|
||||
attention_output = self.layerNorm(attention_output + x)
|
||||
outputs = (attention_output, ) + attention_outputs[1:] # 后面是 attention weight
|
||||
return outputs
|
||||
|
||||
|
||||
class TransformerOutput(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(TransformerOutput, self).__init__()
|
||||
|
||||
# self.xxx = config.xxx
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.dropout = config.dropout
|
||||
self.layer_norm_eps = config.layer_norm_eps
|
||||
|
||||
self.zoom_in = nn.Linear(self.hidden_size, self.intermediate_size)
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||
self.zoom_out = nn.Linear(self.intermediate_size, self.hidden_size)
|
||||
self.dropout = nn.Dropout(self.dropout)
|
||||
self.layerNorm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
|
||||
|
||||
def forward(self, input_tensor):
|
||||
hidden_states = self.zoom_in(input_tensor)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
hidden_states = self.zoom_out(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.layerNorm(hidden_states + input_tensor)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TransformerLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(TransformerLayer, self).__init__()
|
||||
|
||||
self.attention = TransformerAttention(config)
|
||||
self.output = TransformerOutput(config)
|
||||
|
||||
def forward(self, hidden_states, key_padding_mask=None, attention_mask=None, head_mask=None):
|
||||
attention_outputs = self.attention(hidden_states, key_padding_mask, attention_mask, head_mask)
|
||||
attention_output = attention_outputs[0]
|
||||
layer_output = self.output(attention_output)
|
||||
outputs = (layer_output, ) + attention_outputs[1:]
|
||||
return outputs
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(Transformer, self).__init__()
|
||||
|
||||
# self.xxx = config.xxx
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
|
||||
self.layer = nn.ModuleList([TransformerLayer(config) for _ in range(self.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, key_padding_mask=None, attention_mask=None, head_mask=None):
|
||||
"""
|
||||
:param hidden_states: [B, L, Hs]
|
||||
:param key_padding_mask: [B, S] 为 1/True 的地方需要 mask
|
||||
:param attn_mask: [S] / [L, S] 指定位置 mask 掉, 为 1/True 的地方需要 mask
|
||||
:param head_mask: [N] / [L, N] 指定 head mask 掉, 为 1/True 的地方需要 mask
|
||||
"""
|
||||
if head_mask is not None:
|
||||
if head_mask.dim() == 1:
|
||||
head_mask = head_mask.expand((self.num_hidden_layers, ) + head_mask.shape)
|
||||
else:
|
||||
head_mask = [None] * self.num_hidden_layers
|
||||
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states, )
|
||||
|
||||
layer_outputs = layer_module(hidden_states, key_padding_mask, attention_mask, head_mask[i])
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if self.output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1], )
|
||||
|
||||
# Add last layer
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states, )
|
||||
|
||||
outputs = (hidden_states, )
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states, )
|
||||
if self.output_attentions:
|
||||
outputs = outputs + (all_attentions, )
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
|
@@ -0,0 +1,7 @@

from .Embedding import Embedding
from .CNN import CNN
from .RNN import RNN
from .Attention import DotAttention, MultiHeadAttention
from .Transformer import Transformer
from .Capsule import Capsule
from .GCN import GCN

@@ -0,0 +1,6 @@

from .dataset import *
from .metrics import *
from .preprocess import *
from .serializer import *
from .trainer import *
from .vocab import *
@ -0,0 +1,74 @@
|
|||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import os
|
||||
import sys
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
|
||||
from deepke.utils import load_pkl
|
||||
|
||||
def collate_fn(cfg):
|
||||
|
||||
def collate_fn_intra(batch):
|
||||
"""
|
||||
Arg :
|
||||
batch () : 数据集
|
||||
Returna :
|
||||
x (dict) : key为词,value为长度
|
||||
y (List) : 关系对应值的集合
|
||||
"""
|
||||
batch.sort(key=lambda data: data['seq_len'], reverse=True)
|
||||
|
||||
max_len = batch[0]['seq_len']
|
||||
|
||||
def _padding(x, max_len):
|
||||
return x + [0] * (max_len - len(x))
|
||||
|
||||
x, y = dict(), []
|
||||
word, word_len = [], []
|
||||
head_pos, tail_pos = [], []
|
||||
pcnn_mask = []
|
||||
for data in batch:
|
||||
word.append(_padding(data['token2idx'], max_len))
|
||||
word_len.append(data['seq_len'])
|
||||
y.append(int(data['att2idx']))
|
||||
|
||||
if cfg.model_name != 'lm':
|
||||
head_pos.append(_padding(data['entity_pos'], max_len))
|
||||
tail_pos.append(_padding(data['attribute_value_pos'], max_len))
|
||||
if cfg.model_name == 'cnn':
|
||||
if cfg.use_pcnn:
|
||||
pcnn_mask.append(_padding(data['entities_pos'], max_len))
|
||||
|
||||
x['word'] = torch.tensor(word)
|
||||
x['lens'] = torch.tensor(word_len)
|
||||
y = torch.tensor(y)
|
||||
|
||||
if cfg.model_name != 'lm':
|
||||
x['entity_pos'] = torch.tensor(head_pos)
|
||||
x['attribute_value_pos'] = torch.tensor(tail_pos)
|
||||
if cfg.model_name == 'cnn' and cfg.use_pcnn:
|
||||
x['pcnn_mask'] = torch.tensor(pcnn_mask)
|
||||
if cfg.model_name == 'gcn':
|
||||
# no suitable dependency-parsing tool was found, so the adjacency matrix is randomly initialized for now
|
||||
B, L = len(batch), max_len
|
||||
adj = torch.empty(B, L, L).random_(2)
|
||||
x['adj'] = adj
|
||||
return x, y
|
||||
|
||||
return collate_fn_intra
|
||||
|
||||
|
||||
class CustomDataset(Dataset):
|
||||
"""
|
||||
默认使用 List 存储数据
|
||||
"""
|
||||
def __init__(self, fp):
|
||||
self.file = load_pkl(fp)
|
||||
|
||||
def __getitem__(self, item):
|
||||
sample = self.file[item]
|
||||
return sample
|
||||
|
||||
def __len__(self):
|
||||
return len(self.file)
|
||||
|
||||
|
|
@ -0,0 +1,142 @@
|
|||
import os
|
||||
import hydra
|
||||
import torch
|
||||
import logging
|
||||
import torch.nn as nn
|
||||
from torch import optim
|
||||
from hydra import utils
|
||||
import matplotlib.pyplot as plt
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
# self
|
||||
import sys
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
|
||||
import models as models
|
||||
from tools import preprocess , CustomDataset, collate_fn ,train, validate
|
||||
from utils import manual_seed, load_pkl
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@hydra.main(config_path="../conf/config.yaml")
|
||||
def main(cfg):
|
||||
cwd = utils.get_original_cwd()
|
||||
cwd = cwd[0:-5]
|
||||
cfg.cwd = cwd
|
||||
cfg.pos_size = 2 * cfg.pos_limit + 2
|
||||
logger.info(f'\n{cfg.pretty()}')
|
||||
|
||||
__Model__ = {
|
||||
'cnn': models.PCNN,
|
||||
'rnn': models.BiLSTM,
|
||||
'transformer': models.Transformer,
|
||||
'gcn': models.GCN,
|
||||
'capsule': models.Capsule,
|
||||
'lm': models.LM,
|
||||
}
|
||||
|
||||
# device
|
||||
if cfg.use_gpu and torch.cuda.is_available():
|
||||
device = torch.device('cuda', cfg.gpu_id)
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
logger.info(f'device: {device}')
|
||||
|
||||
# 如果不修改预处理的过程,这一步最好注释掉,不用每次运行都预处理数据一次
|
||||
if cfg.preprocess:
|
||||
preprocess(cfg)
|
||||
|
||||
train_data_path = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl')
|
||||
valid_data_path = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl')
|
||||
test_data_path = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl')
|
||||
vocab_path = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')
|
||||
|
||||
if cfg.model_name == 'lm':
|
||||
vocab_size = None
|
||||
else:
|
||||
vocab = load_pkl(vocab_path)
|
||||
vocab_size = vocab.count
|
||||
cfg.vocab_size = vocab_size
|
||||
|
||||
train_dataset = CustomDataset(train_data_path)
|
||||
valid_dataset = CustomDataset(valid_data_path)
|
||||
test_dataset = CustomDataset(test_data_path)
|
||||
|
||||
train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
|
||||
valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
|
||||
test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
|
||||
|
||||
model = __Model__[cfg.model_name](cfg)
|
||||
model.to(device)
|
||||
logger.info(f'\n {model}')
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
best_f1, best_epoch = -1, 0
|
||||
es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, es_path, best_es_path = 1e8, -1, 0, 0, 0, -1, '', ''
|
||||
train_losses, valid_losses = [], []
|
||||
|
||||
if cfg.show_plot and cfg.plot_utils == 'tensorboard':
|
||||
writer = SummaryWriter('tensorboard')
|
||||
else:
|
||||
writer = None
|
||||
|
||||
logger.info('=' * 10 + ' Start training ' + '=' * 10)
|
||||
|
||||
for epoch in range(1, cfg.epoch + 1):
|
||||
manual_seed(cfg.seed + epoch)
|
||||
train_loss = train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg)
|
||||
valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device, cfg)
|
||||
scheduler.step(valid_loss)
|
||||
model_path = model.save(epoch, cfg)
|
||||
# logger.info(model_path)
|
||||
|
||||
train_losses.append(train_loss)
|
||||
valid_losses.append(valid_loss)
|
||||
if best_f1 < valid_f1:
|
||||
best_f1 = valid_f1
|
||||
best_epoch = epoch
|
||||
# 使用 valid loss 做 early stopping 的判断标准
|
||||
if es_loss > valid_loss:
|
||||
es_loss = valid_loss
|
||||
es_f1 = valid_f1
|
||||
es_epoch = epoch
|
||||
es_patience = 0
|
||||
es_path = model_path
|
||||
else:
|
||||
es_patience += 1
|
||||
if es_patience >= cfg.early_stopping_patience:
|
||||
best_es_epoch = es_epoch
|
||||
best_es_f1 = es_f1
|
||||
best_es_path = es_path
|
||||
|
||||
if cfg.show_plot:
|
||||
if cfg.plot_utils == 'matplot':
|
||||
plt.plot(train_losses, 'x-')
|
||||
plt.plot(valid_losses, '+-')
|
||||
plt.legend(['train', 'valid'])
|
||||
plt.title('train/valid loss comparison')
|
||||
plt.show()
|
||||
|
||||
if cfg.plot_utils == 'tensorboard':
|
||||
for i in range(len(train_losses)):
|
||||
writer.add_scalars('train/valid_comparison_loss', {
|
||||
'train': train_losses[i],
|
||||
'valid': valid_losses[i]
|
||||
}, i)
|
||||
writer.close()
|
||||
|
||||
logger.info(f'best (by valid loss) early stopping epoch: {best_es_epoch}, '
|
||||
f'this epoch macro f1: {best_es_f1:0.4f}')
|
||||
logger.info(f'this model save path: {best_es_path}')
|
||||
logger.info(f'total {cfg.epoch} epochs, best (by valid macro f1) epoch: {best_epoch}, '
|
||||
f'this epoch macro f1: {best_f1:.4f}')
|
||||
|
||||
logger.info('===== end of training =====')
logger.info('')
logger.info('===== start test performance =====')
validate(-1, model, test_dataloader, criterion, device, cfg)
logger.info('===== end =====')
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from sklearn.metrics import precision_recall_fscore_support
|
||||
|
||||
|
||||
class Metric(metaclass=ABCMeta):
|
||||
@abstractmethod
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
"""
|
||||
Resets the metric to its initial state.
|
||||
This is called at the start of each epoch.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def update(self, *args):
|
||||
"""
|
||||
Updates the metric's state using the passed batch output.
|
||||
This is called once for each batch.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def compute(self):
|
||||
"""
|
||||
Computes the metric based on its accumulated state.
|
||||
This is called at the end of each epoch.
|
||||
:return: the actual quantity of interest
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class PRMetric():
|
||||
def __init__(self):
|
||||
"""
|
||||
暂时调用 sklearn 的方法
|
||||
"""
|
||||
self.y_true = np.empty(0)
|
||||
self.y_pred = np.empty(0)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
重置为0
|
||||
"""
|
||||
self.y_true = np.empty(0)
|
||||
self.y_pred = np.empty(0)
|
||||
|
||||
def update(self, y_true: torch.Tensor, y_pred: torch.Tensor):
|
||||
"""
|
||||
更新tensor,保留值,取消原有梯度
|
||||
"""
|
||||
y_true = y_true.cpu().detach().numpy()
|
||||
y_pred = y_pred.cpu().detach().numpy()
|
||||
y_pred = np.argmax(y_pred, axis=-1)
|
||||
|
||||
self.y_true = np.append(self.y_true, y_true)
|
||||
self.y_pred = np.append(self.y_pred, y_pred)
|
||||
|
||||
def compute(self):
|
||||
"""
|
||||
计算acc,p,r,f1并返回
|
||||
"""
|
||||
p, r, f1, _ = precision_recall_fscore_support(self.y_true, self.y_pred, average='macro', warn_for=tuple())
|
||||
_, _, acc, _ = precision_recall_fscore_support(self.y_true, self.y_pred, average='micro', warn_for=tuple())
|
||||
|
||||
return acc, p, r, f1
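if __name__ == '__main__':
    # 用法示例(仅作演示):y_pred 为模型输出的 logits,update 内部会先做 argmax 再累积
    metric = PRMetric()
    y_true = torch.tensor([0, 1, 1])
    y_pred = torch.tensor([[2.0, 0.1], [0.3, 1.5], [1.2, 0.4]])
    metric.update(y_true=y_true, y_pred=y_pred)
    acc, p, r, f1 = metric.compute()
    print(f'acc: {acc:.4f}, p: {p:.4f}, r: {r:.4f}, f1: {f1:.4f}')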
|
|
@ -0,0 +1,147 @@
|
|||
import os
|
||||
import sys
|
||||
import torch
|
||||
import logging
|
||||
import hydra
|
||||
from hydra import utils
|
||||
from serializer import Serializer
|
||||
from preprocess import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_attribute_data
|
||||
import matplotlib.pyplot as plt
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
|
||||
from utils import load_pkl, load_csv
|
||||
import models as models
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _preprocess_data(data, cfg):
|
||||
vocab = load_pkl(os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl'), verbose=False)
|
||||
attribute_data = load_csv(os.path.join(cfg.cwd, cfg.data_path, 'attribute.csv'), verbose=False)
|
||||
atts = _handle_attribute_data(attribute_data)
|
||||
cfg.vocab_size = vocab.count
|
||||
serializer = Serializer(do_chinese_split=cfg.chinese_split)
|
||||
serial = serializer.serialize
|
||||
|
||||
_serialize_sentence(data, serial, cfg)
|
||||
_convert_tokens_into_index(data, vocab)
|
||||
_add_pos_seq(data, cfg)
|
||||
logger.info('start sentence preprocess...')
|
||||
formats = '\nsentence: {}\nchinese_split: {}\n' \
|
||||
'tokens: {}\ntoken2idx: {}\nlength: {}\nentity_index: {}\nattribute_value_index: {}'
|
||||
logger.info(
|
||||
formats.format(data[0]['sentence'], cfg.chinese_split,
|
||||
data[0]['tokens'], data[0]['token2idx'], data[0]['seq_len'],
|
||||
data[0]['entity_index'], data[0]['attribute_value_index']))
|
||||
return data, atts
|
||||
|
||||
|
||||
def _get_predict_instance(cfg):
|
||||
flag = input('是否使用范例[y/n],退出请输入: exit .... ')
|
||||
flag = flag.strip().lower()
|
||||
if flag == 'y' or flag == 'yes':
|
||||
sentence = '张冬梅,女,汉族,1968年2月生,河南淇县人,1988年7月加入中国共产党,1989年9月参加工作,中央党校经济管理专业毕业,中央党校研究生学历'
|
||||
entity = '张冬梅'
|
||||
attribute_value = '汉族'
|
||||
elif flag == 'n' or flag == 'no':
|
||||
sentence = input('请输入句子:')
|
||||
entity = input('请输入句中需要预测的实体:')
|
||||
attribute_value = input('请输入句中需要预测的属性值:')
|
||||
elif flag == 'exit':
|
||||
sys.exit(0)
|
||||
else:
|
||||
print('please input yes or no, or exit!')
|
||||
return _get_predict_instance(cfg)
|
||||
|
||||
|
||||
instance = dict()
|
||||
instance['sentence'] = sentence.strip()
|
||||
instance['entity'] = entity.strip()
|
||||
instance['attribute_value'] = attribute_value.strip()
|
||||
instance['entity_offset'] = sentence.find(entity)
|
||||
instance['attribute_value_offset'] = sentence.find(attribute_value)
|
||||
|
||||
return instance
|
||||
|
||||
|
||||
|
||||
|
||||
@hydra.main(config_path='../conf/config.yaml')
|
||||
def main(cfg):
|
||||
cwd = utils.get_original_cwd()
|
||||
cwd = cwd[0:-5]
|
||||
cfg.cwd = cwd
|
||||
cfg.pos_size = 2 * cfg.pos_limit + 2
|
||||
print(cfg.pretty())
|
||||
|
||||
# get predict instance
|
||||
instance = _get_predict_instance(cfg)
|
||||
data = [instance]
|
||||
|
||||
# preprocess data
|
||||
data, rels = _preprocess_data(data, cfg)
|
||||
|
||||
# model
|
||||
__Model__ = {
|
||||
'cnn': models.PCNN,
|
||||
'rnn': models.BiLSTM,
|
||||
'transformer': models.Transformer,
|
||||
'gcn': models.GCN,
|
||||
'capsule': models.Capsule,
|
||||
'lm': models.LM,
|
||||
}
|
||||
|
||||
# 最好在 cpu 上预测
|
||||
cfg.use_gpu = False
|
||||
if cfg.use_gpu and torch.cuda.is_available():
|
||||
device = torch.device('cuda', cfg.gpu_id)
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
logger.info(f'device: {device}')
|
||||
|
||||
model = __Model__[cfg.model_name](cfg)
|
||||
logger.info(f'model name: {cfg.model_name}')
|
||||
logger.info(f'\n {model}')
|
||||
model.load(cfg.fp, device=device)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
x = dict()
|
||||
x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])
|
||||
|
||||
if cfg.model_name != 'lm':
|
||||
x['entity_pos'], x['attribute_value_pos'] = torch.tensor([data[0]['entity_pos']]), torch.tensor([data[0]['attribute_value_pos']])
|
||||
if cfg.model_name == 'cnn':
|
||||
if cfg.use_pcnn:
|
||||
x['pcnn_mask'] = torch.tensor([data[0]['entities_pos']])
|
||||
if cfg.model_name == 'gcn':
|
||||
# 没找到合适的做 parsing tree 的工具,暂时随机初始化
|
||||
adj = torch.empty(1,data[0]['seq_len'],data[0]['seq_len']).random_(2)
|
||||
x['adj'] = adj
|
||||
|
||||
|
||||
for key in x.keys():
|
||||
x[key] = x[key].to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
y_pred = model(x)
|
||||
y_pred = torch.softmax(y_pred, dim=-1)[0]
|
||||
prob = y_pred.max().item()
|
||||
prob_att = list(rels.keys())[y_pred.argmax().item()]
|
||||
logger.info(f"\"{data[0]['entity']}\" 和 \"{data[0]['attribute_value']}\" 在句中属性为:\"{prob_att}\",置信度为{prob:.2f}。")
|
||||
|
||||
if cfg.predict_plot:
|
||||
plt.rcParams["font.family"] = 'Arial Unicode MS'
|
||||
x = list(rels.keys())
|
||||
height = list(y_pred.cpu().numpy())
|
||||
plt.bar(x, height)
|
||||
for x, y in zip(x, height):
|
||||
plt.text(x, y, '%.2f' % y, ha="center", va="bottom")
|
||||
plt.xlabel('关系')
|
||||
plt.ylabel('置信度')
|
||||
plt.xticks(rotation=315)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,158 @@
|
|||
import os
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import List, Dict
|
||||
from transformers import BertTokenizer
|
||||
from serializer import Serializer
|
||||
from vocab import Vocab
|
||||
import sys
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
|
||||
from utils import save_pkl, load_csv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"_handle_pos_limit",
|
||||
"_add_pos_seq",
|
||||
"_convert_tokens_into_index",
|
||||
"_serialize_sentence",
|
||||
"_lm_serialize",
|
||||
"_add_attribute_data",
|
||||
"_handle_attribute_data",
|
||||
"preprocess"
|
||||
]
|
||||
def _handle_pos_limit(pos: List[int], limit: int) -> List[int]:
|
||||
for i,p in enumerate(pos):
|
||||
if p > limit:
|
||||
pos[i] = limit
|
||||
if p < -limit:
|
||||
pos[i] = -limit
|
||||
return [p + limit + 1 for p in pos]
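# 示例(仅作说明,设 limit = 2):
#   相对位置 [-5, -1, 0, 3] 先截断到 [-limit, limit] 得 [-2, -1, 0, 2],
#   再整体平移 limit + 1 得 [1, 2, 3, 5],保证所有位置索引为正,0 留给 padding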
|
||||
|
||||
def _add_pos_seq(train_data: List[Dict], cfg):
|
||||
for d in train_data:
|
||||
entities_idx = [d['entity_index'], d['attribute_value_index']] if d['entity_index'] < d['attribute_value_index'] else [d['attribute_value_index'], d['entity_index']]
|
||||
|
||||
d['entity_pos'] = list(map(lambda i: i - d['entity_index'], list(range(d['seq_len']))))
|
||||
d['entity_pos'] = _handle_pos_limit(d['entity_pos'],int(cfg.pos_limit))
|
||||
|
||||
d['attribute_value_pos'] = list(map(lambda i: i - d['attribute_value_index'], list(range(d['seq_len']))))
|
||||
d['attribute_value_pos'] = _handle_pos_limit(d['attribute_value_pos'],int(cfg.pos_limit))
|
||||
|
||||
if cfg.model_name == 'cnn':
|
||||
if cfg.use_pcnn:
|
||||
d['entities_pos'] = [1] * (entities_idx[0] + 1) + [2] * (entities_idx[1] - entities_idx[0] - 1) +\
|
||||
[3] * (d['seq_len'] - entities_idx[1])
|
||||
|
||||
def _convert_tokens_into_index(data: List[Dict], vocab):
|
||||
unk_str = '[UNK]'
|
||||
unk_idx = vocab.word2idx[unk_str]
|
||||
|
||||
for d in data:
|
||||
d['token2idx'] = [vocab.word2idx.get(i, unk_idx) for i in d['tokens']]
|
||||
d['seq_len'] = len(d['token2idx'])
|
||||
|
||||
def _serialize_sentence(data: List[Dict], serial, cfg):
|
||||
for d in data:
|
||||
sent = d['sentence'].strip()
|
||||
sent = sent.replace(d['entity'], ' entity ', 1).replace(d['attribute_value'], ' attribute_value ', 1)
|
||||
d['tokens'] = serial(sent, never_split=['entity','attribute_value'])
|
||||
entity_index, attribute_value_index = d['entity_offset'] , d['attribute_value_offset']
|
||||
d['entity_index'],d['attribute_value_index'] = int(entity_index) , int(attribute_value_index)
|
||||
|
||||
def _lm_serialize(data: List[Dict], cfg):
|
||||
logger.info('use bert tokenizer...')
|
||||
tokenizer = BertTokenizer.from_pretrained(cfg.lm_file)
|
||||
for d in data:
|
||||
sent = d['sentence'].strip()
|
||||
sent += '[SEP]' + d['entity'] + '[SEP]' + d['attribute_value']
|
||||
d['token2idx'] = tokenizer.encode(sent, add_special_tokens=True)
|
||||
d['seq_len'] = len(d['token2idx'])
|
||||
|
||||
def _add_attribute_data(atts: Dict, data: List) -> None:
|
||||
for d in data:
|
||||
d['att2idx'] = atts[d['attribute']]['index']
|
||||
|
||||
def _handle_attribute_data(attribute_data: List[Dict]) -> Dict:
|
||||
atts = OrderedDict()
|
||||
attribute_data = sorted(attribute_data, key=lambda i: int(i['index']))
|
||||
for d in attribute_data:
|
||||
atts[d['attribute']] = {
|
||||
'index': int(d['index'])
|
||||
}
|
||||
return atts
|
||||
|
||||
def preprocess(cfg):
|
||||
logger.info('===== start preprocess data =====')
|
||||
train_fp = os.path.join(cfg.cwd, cfg.data_path, 'train.csv')
|
||||
valid_fp = os.path.join(cfg.cwd, cfg.data_path, 'valid.csv')
|
||||
test_fp = os.path.join(cfg.cwd, cfg.data_path, 'test.csv')
|
||||
attribute_fp = os.path.join(cfg.cwd, cfg.data_path, 'attribute.csv')
|
||||
|
||||
logger.info('load raw files...')
|
||||
train_data = load_csv(train_fp)
|
||||
valid_data = load_csv(valid_fp)
|
||||
test_data = load_csv(test_fp)
|
||||
attribute_data = load_csv(attribute_fp)
|
||||
|
||||
logger.info('convert relation into index...')
|
||||
atts = _handle_attribute_data(attribute_data)
|
||||
_add_attribute_data(atts,train_data)
|
||||
_add_attribute_data(atts,test_data)
|
||||
_add_attribute_data(atts,valid_data)
|
||||
|
||||
logger.info('verify whether use pretrained language models...')
|
||||
if cfg.model_name == 'lm':
|
||||
logger.info('use pretrained language models serialize sentence...')
|
||||
_lm_serialize(train_data, cfg)
|
||||
_lm_serialize(valid_data, cfg)
|
||||
_lm_serialize(test_data, cfg)
|
||||
else:
|
||||
logger.info('serialize sentence into tokens...')
|
||||
serializer = Serializer(do_chinese_split=cfg.chinese_split, do_lower_case=True)
|
||||
serial = serializer.serialize
|
||||
_serialize_sentence(train_data, serial, cfg)
|
||||
_serialize_sentence(valid_data, serial, cfg)
|
||||
_serialize_sentence(test_data, serial, cfg)
|
||||
|
||||
logger.info('build vocabulary...')
|
||||
vocab = Vocab('word')
|
||||
train_tokens = [d['tokens'] for d in train_data]
|
||||
valid_tokens = [d['tokens'] for d in valid_data]
|
||||
test_tokens = [d['tokens'] for d in test_data]
|
||||
sent_tokens = [*train_tokens, *valid_tokens, *test_tokens]
|
||||
|
||||
for sent in sent_tokens:
|
||||
vocab.add_words(sent)
|
||||
vocab.trim(min_freq=cfg.min_freq)
|
||||
|
||||
_convert_tokens_into_index(train_data, vocab)
|
||||
_convert_tokens_into_index(valid_data, vocab)
|
||||
_convert_tokens_into_index(test_data, vocab)
|
||||
|
||||
logger.info('build position sequence...')
|
||||
_add_pos_seq(train_data, cfg)
|
||||
_add_pos_seq(valid_data, cfg)
|
||||
_add_pos_seq(test_data, cfg)
|
||||
|
||||
logger.info('save data for backup...')
|
||||
os.makedirs(os.path.join(cfg.cwd, cfg.out_path), exist_ok=True)
|
||||
train_save_fp = os.path.join(cfg.cwd, cfg.out_path, 'train.pkl')
|
||||
valid_save_fp = os.path.join(cfg.cwd, cfg.out_path, 'valid.pkl')
|
||||
test_save_fp = os.path.join(cfg.cwd, cfg.out_path, 'test.pkl')
|
||||
save_pkl(train_data, train_save_fp)
|
||||
save_pkl(valid_data, valid_save_fp)
|
||||
save_pkl(test_data, test_save_fp)
|
||||
|
||||
if cfg.model_name != 'lm':
|
||||
vocab_save_fp = os.path.join(cfg.cwd, cfg.out_path, 'vocab.pkl')
|
||||
vocab_txt = os.path.join(cfg.cwd, cfg.out_path, 'vocab.txt')
|
||||
save_pkl(vocab, vocab_save_fp)
|
||||
logger.info('save vocab in txt file, for watching...')
|
||||
with open(vocab_txt, 'w', encoding='utf-8') as f:
|
||||
f.write(os.linesep.join(vocab.word2idx.keys()))
|
||||
|
||||
logger.info('===== end preprocess data =====')
|
||||
|
||||
|
|
@ -0,0 +1,270 @@
|
|||
import re
|
||||
import unicodedata
|
||||
import jieba
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
jieba.setLogLevel(logging.INFO)
|
||||
|
||||
|
||||
class Serializer():
|
||||
def __init__(self, never_split: List = None, do_lower_case=True, do_chinese_split=False):
|
||||
self.never_split = never_split if never_split is not None else []
|
||||
self.do_lower_case = do_lower_case
|
||||
self.do_chinese_split = do_chinese_split
|
||||
|
||||
def serialize(self, text, never_split: List = None):
|
||||
"""
|
||||
将一段文本按照指定拆分规则,拆分成一个词汇 List
|
||||
Args :
|
||||
text (String) : 所需拆分文本
|
||||
never_split (List) : 不拆分的词,默认为空
|
||||
Return :
|
||||
output_tokens (List): 拆分后的结果
|
||||
"""
|
||||
never_split = self.never_split + (never_split if never_split is not None else [])
|
||||
text = self._clean_text(text)
|
||||
|
||||
if self.do_chinese_split:
|
||||
output_tokens = self._use_jieba_cut(text, never_split)
|
||||
return output_tokens
|
||||
|
||||
text = self._tokenize_chinese_chars(text)
|
||||
orig_tokens = self._orig_tokenize(text)
|
||||
split_tokens = []
|
||||
for token in orig_tokens:
|
||||
if self.do_lower_case and token not in never_split:
|
||||
token = token.lower()
|
||||
token = self._run_strip_accents(token)
|
||||
split_tokens.extend(self._run_split_on_punc(token, never_split=never_split))
|
||||
|
||||
output_tokens = self._whitespace_tokenize(" ".join(split_tokens))
|
||||
|
||||
return output_tokens
|
||||
|
||||
def _clean_text(self, text):
|
||||
"""
|
||||
删除文本中无效字符以及空白字符
|
||||
Arg :
|
||||
text (String) : 所需删除的文本
|
||||
Return :
|
||||
"".join(output) (String) : 删除后的文本
|
||||
"""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if cp == 0 or cp == 0xfffd or self.is_control(char):
|
||||
continue
|
||||
if self.is_whitespace(char):
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _use_jieba_cut(self, text, never_split):
|
||||
"""
|
||||
使用jieba分词
|
||||
Args :
|
||||
text (String) : 所需拆分文本
|
||||
never_split (List) : 不拆分的词
|
||||
Return :
|
||||
tokens (List) : 拆分完的结果
|
||||
"""
|
||||
for word in never_split:
|
||||
jieba.suggest_freq(word, True)
|
||||
tokens = jieba.lcut(text)
|
||||
if self.do_lower_case:
|
||||
tokens = [i.lower() for i in tokens]
|
||||
try:
|
||||
while True:
|
||||
tokens.remove(' ')
|
||||
except ValueError:  # 所有空格均已移除
|
||||
return tokens
|
||||
|
||||
def _tokenize_chinese_chars(self, text):
|
||||
"""
|
||||
在CJK字符周围添加空格
|
||||
Arg :
|
||||
text (String) : 所需拆分文本
|
||||
Return :
|
||||
"".join(output) (String) : 添加完后的文本
|
||||
"""
|
||||
output = []
|
||||
for char in text:
|
||||
cp = ord(char)
|
||||
if self.is_chinese_char(cp):
|
||||
output.append(" ")
|
||||
output.append(char)
|
||||
output.append(" ")
|
||||
else:
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _orig_tokenize(self, text):
|
||||
"""
|
||||
在空白和一些标点符号(如逗号或句点)上拆分文本
|
||||
Arg :
|
||||
text (String) : 所需拆分文本
|
||||
Return :
|
||||
tokens (List) : 分词完的结果
|
||||
"""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
# 常见的断句标点
|
||||
punc = """,.?!;: 、|,。?!;:《》「」【】/<>|\“ ”‘ ’"""
|
||||
punc_re = '|'.join(re.escape(x) for x in punc)
|
||||
tokens = re.sub(punc_re, lambda x: ' ' + x.group() + ' ', text)
|
||||
tokens = tokens.split()
|
||||
return tokens
|
||||
|
||||
def _whitespace_tokenize(self, text):
|
||||
"""
|
||||
进行基本的空白字符清理和分割
|
||||
Arg :
|
||||
text (String) : 所需拆分文本
|
||||
Return :
|
||||
tokens (List) : 分词完的结果
|
||||
"""
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return []
|
||||
tokens = text.split()
|
||||
return tokens
|
||||
|
||||
def _run_strip_accents(self, text):
|
||||
"""
|
||||
从文本中去除重音符号
|
||||
Arg :
|
||||
text (String) : 所需拆分文本
|
||||
Return :
|
||||
"".join(output) (String) : 去除后的文本
|
||||
|
||||
"""
|
||||
text = unicodedata.normalize("NFD", text)
|
||||
output = []
|
||||
for char in text:
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Mn":
|
||||
continue
|
||||
output.append(char)
|
||||
return "".join(output)
|
||||
|
||||
def _run_split_on_punc(self, text, never_split=None):
|
||||
"""
|
||||
通过标点符号拆分文本
|
||||
Args :
|
||||
text (String) : 所需拆分文本
|
||||
never_split (List) : 不拆分的词,默认为空
|
||||
Return :
|
||||
["".join(x) for x in output] (List) : 拆分完的结果
|
||||
"""
|
||||
|
||||
if never_split is not None and text in never_split:
|
||||
return [text]
|
||||
chars = list(text)
|
||||
i = 0
|
||||
start_new_word = True
|
||||
output = []
|
||||
while i < len(chars):
|
||||
char = chars[i]
|
||||
if self.is_punctuation(char):
|
||||
output.append([char])
|
||||
start_new_word = True
|
||||
else:
|
||||
if start_new_word:
|
||||
output.append([])
|
||||
start_new_word = False
|
||||
output[-1].append(char)
|
||||
i += 1
|
||||
|
||||
return ["".join(x) for x in output]
|
||||
|
||||
@staticmethod
|
||||
def is_control(char):
|
||||
"""
|
||||
判断字符是否为控制字符
|
||||
Arg :
|
||||
char : 字符
|
||||
Return :
|
||||
bool : 判断结果
|
||||
"""
|
||||
# These are technically control characters but we count them as whitespace
|
||||
# characters.
|
||||
if char == "\t" or char == "\n" or char == "\r":
|
||||
return False
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("C"):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_whitespace(char):
|
||||
"""
|
||||
判断字符是否为空白字符
|
||||
Arg :
|
||||
char : 字符
|
||||
Return :
|
||||
bool : 判断结果
|
||||
"""
|
||||
# \t, \n, and \r are technically control characters but we treat them
|
||||
# as whitespace since they are generally considered as such.
|
||||
if char == " " or char == "\t" or char == "\n" or char == "\r":
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat == "Zs":
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_chinese_char(cp):
|
||||
"""
|
||||
|
||||
判断字符是否为中文字符
|
||||
Arg :
|
||||
cp (char): 字符
|
||||
Return :
|
||||
bool : 判断结果
|
||||
|
||||
"""
|
||||
# This defines a "chinese character" as anything in the CJK Unicode block:
|
||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
||||
#
|
||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
||||
# despite its name. The modern Korean Hangul alphabet is a different block,
|
||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
||||
# space-separated words, so they are not treated specially and handled
|
||||
# like the all of the other languages.
|
||||
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
|
||||
(cp >= 0x3400 and cp <= 0x4DBF) or #
|
||||
(cp >= 0x20000 and cp <= 0x2A6DF) or #
|
||||
(cp >= 0x2A700 and cp <= 0x2B73F) or #
|
||||
(cp >= 0x2B740 and cp <= 0x2B81F) or #
|
||||
(cp >= 0x2B820 and cp <= 0x2CEAF) or (cp >= 0xF900 and cp <= 0xFAFF) or #
|
||||
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_punctuation(char):
|
||||
"""
|
||||
判断字符是否为标点字符
|
||||
Arg :
|
||||
char : 字符
|
||||
Return :
|
||||
bool : 判断结果
|
||||
"""
|
||||
cp = ord(char)
|
||||
# We treat all non-letter/number ASCII as punctuation.
|
||||
# Characters such as "^", "$", and "`" are not in the Unicode
|
||||
# Punctuation class but we treat them as punctuation anyways, for
|
||||
# consistency.
|
||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96)
|
||||
or (cp >= 123 and cp <= 126)):
|
||||
return True
|
||||
cat = unicodedata.category(char)
|
||||
if cat.startswith("P"):
|
||||
return True
|
||||
return False
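if __name__ == '__main__':
    # 用法示例(仅作演示):英文按空白/标点切分,中文可用 jieba 切分;
    # never_split 中的占位词(如 _serialize_sentence 注入的 'entity')不会被拆开
    s_en = Serializer(do_chinese_split=False)
    print(s_en.serialize('Irene is traveling in Warsaw.'))
    s_zh = Serializer(do_chinese_split=True)
    print(s_zh.serialize(' entity ,女,汉族,河南淇县人', never_split=['entity']))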
|
|
@ -0,0 +1,117 @@
|
|||
import torch
|
||||
import logging
|
||||
import matplotlib.pyplot as plt
|
||||
from metrics import PRMetric
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg):
|
||||
"""
|
||||
training the model.
|
||||
Args:
|
||||
epoch (int): number of training steps.
|
||||
model (class): model of training.
|
||||
dataloader (dict): dict of dataset iterator. Keys are tasknames, values are corresponding dataloaders.
|
||||
optimizer (Callable): optimizer of training.
|
||||
criterion (Callable): loss criterion of training.
|
||||
device (torch.device): device of training.
|
||||
writer (class): output to tensorboard.
|
||||
cfg: configuration of training.
|
||||
Return:
|
||||
losses[-1] : the loss of training
|
||||
"""
|
||||
model.train()
|
||||
|
||||
metric = PRMetric()
|
||||
losses = []
|
||||
|
||||
for batch_idx, (x, y) in enumerate(dataloader, 1):
|
||||
for key, value in x.items():
|
||||
x[key] = value.to(device)
|
||||
|
||||
y = y.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
y_pred = model(x)
|
||||
|
||||
if cfg.model_name == 'capsule':
|
||||
loss = model.loss(y_pred, y)
|
||||
else:
|
||||
loss = criterion(y_pred, y)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
metric.update(y_true=y, y_pred=y_pred)
|
||||
losses.append(loss.item())
|
||||
|
||||
data_total = len(dataloader.dataset)
|
||||
data_cal = data_total if batch_idx == len(dataloader) else batch_idx * len(y)
|
||||
if (cfg.train_log and batch_idx % cfg.log_interval == 0) or batch_idx == len(dataloader):
|
||||
# p r f1 皆为 macro,因为micro时三者相同,定义为acc
|
||||
acc, p, r, f1 = metric.compute()
|
||||
logger.info(f'Train Epoch {epoch}: [{data_cal}/{data_total} ({100. * data_cal / data_total:.0f}%)]\t'
|
||||
f'Loss: {loss.item():.6f}')
|
||||
logger.info(f'Train Epoch {epoch}: Acc: {100. * acc:.2f}%\t'
|
||||
f'macro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')
|
||||
|
||||
if cfg.show_plot and not cfg.only_comparison_plot:
|
||||
if cfg.plot_utils == 'matplot':
|
||||
plt.plot(losses)
|
||||
plt.title(f'epoch {epoch} train loss')
|
||||
plt.show()
|
||||
|
||||
if cfg.plot_utils == 'tensorboard':
|
||||
for i in range(len(losses)):
|
||||
writer.add_scalar(f'epoch_{epoch}_training_loss', losses[i], i)
|
||||
|
||||
return losses[-1]
|
||||
|
||||
|
||||
def validate(epoch, model, dataloader, criterion, device, cfg):
|
||||
"""
|
||||
validating the model.
|
||||
Args:
|
||||
epoch (int): number of validating steps.
|
||||
model (class): model of validating.
|
||||
dataloader (dict): dict of dataset iterator. Keys are tasknames, values are corresponding dataloaders.
|
||||
criterion (Callable): loss criterion of validating.
|
||||
device (torch.device): device of validating.
|
||||
cfg: configuration of validating.
|
||||
Return:
|
||||
f1 : f1 score
|
||||
loss : the loss of validating
|
||||
"""
|
||||
model.eval()
|
||||
|
||||
metric = PRMetric()
|
||||
losses = []
|
||||
|
||||
for batch_idx, (x, y) in enumerate(dataloader, 1):
|
||||
for key, value in x.items():
|
||||
x[key] = value.to(device)
|
||||
y = y.to(device)
|
||||
with torch.no_grad():
|
||||
y_pred = model(x)
|
||||
|
||||
if cfg.model_name == 'capsule':
|
||||
loss = model.loss(y_pred, y)
|
||||
else:
|
||||
loss = criterion(y_pred, y)
|
||||
|
||||
metric.update(y_true=y, y_pred=y_pred)
|
||||
losses.append(loss.item())
|
||||
|
||||
loss = sum(losses) / len(losses)
|
||||
acc, p, r, f1 = metric.compute()
|
||||
data_total = len(dataloader.dataset)
|
||||
|
||||
if epoch >= 0:
|
||||
logger.info(f'Valid Epoch {epoch}: [{data_total}/{data_total}](100%)\t Loss: {loss:.6f}')
|
||||
logger.info(f'Valid Epoch {epoch}: Acc: {100. * acc:.2f}%\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')
|
||||
else:
|
||||
logger.info(f'Test Data: [{data_total}/{data_total}](100%)\t Loss: {loss:.6f}')
|
||||
logger.info(f'Test Data: Acc: {100. * acc:.2f}%\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')
|
||||
|
||||
return f1, loss
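# 用法示例(仅作演示,run.py 中即按此方式调用):
#     train_loss = train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg)
#     valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion, device, cfg)
#     validate(-1, model, test_dataloader, criterion, device, cfg)  # epoch 为 -1 时按测试集日志输出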
|
|
@ -0,0 +1,113 @@
|
|||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Sequence, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SPECIAL_TOKENS_KEYS = [
|
||||
"pad_token",
|
||||
"unk_token",
|
||||
"mask_token",
|
||||
"cls_token",
|
||||
"sep_token",
|
||||
"bos_token",
|
||||
"eos_token",
|
||||
"head_token",
|
||||
"tail_token",
|
||||
|
||||
]
|
||||
|
||||
SPECIAL_TOKENS_VALUES = [
|
||||
"[PAD]",
|
||||
"[UNK]",
|
||||
"[MASK]",
|
||||
"[CLS]",
|
||||
"[SEP]",
|
||||
"[BOS]",
|
||||
"[EOS]",
|
||||
"HEAD",
|
||||
"TAIL",
|
||||
]
|
||||
|
||||
SPECIAL_TOKENS = OrderedDict(zip(SPECIAL_TOKENS_KEYS, SPECIAL_TOKENS_VALUES))
|
||||
|
||||
|
||||
class Vocab(object):
|
||||
"""
|
||||
构建词汇表,增加词汇,删除低频词汇
|
||||
"""
|
||||
def __init__(self, name: str = 'basic', init_tokens: Sequence = SPECIAL_TOKENS):
|
||||
self.name = name
|
||||
self.init_tokens = init_tokens
|
||||
self.trimed = False
|
||||
self.word2idx = {}
|
||||
self.word2count = {}
|
||||
self.idx2word = {}
|
||||
self.count = 0
|
||||
self._add_init_tokens()
|
||||
|
||||
def _add_init_tokens(self):
|
||||
"""
|
||||
添加初始tokens
|
||||
"""
|
||||
for token in self.init_tokens.values():
|
||||
self._add_word(token)
|
||||
|
||||
def _add_word(self, word: str):
|
||||
"""
|
||||
增加单个词汇
|
||||
Arg :
|
||||
word (String) : 增加的词汇
|
||||
"""
|
||||
if word not in self.word2idx:
|
||||
self.word2idx[word] = self.count
|
||||
self.word2count[word] = 1
|
||||
self.idx2word[self.count] = word
|
||||
self.count += 1
|
||||
else:
|
||||
self.word2count[word] += 1
|
||||
|
||||
def add_words(self, words: Sequence):
|
||||
"""
|
||||
通过数组增加词汇
|
||||
Arg :
|
||||
words (List) : 增加的词汇组
|
||||
"""
|
||||
for word in words:
|
||||
self._add_word(word)
|
||||
|
||||
def trim(self, min_freq=2, verbose: Optional[bool] = True):
|
||||
"""
|
||||
当 word 词频低于 min_freq 时,从词库中删除
|
||||
Args:
|
||||
min_freq (int): 最低词频
|
||||
verbose (bool) : 是否打印日志
|
||||
"""
|
||||
assert min_freq == int(min_freq), f'min_freq must be integer, can\'t be {min_freq}'
|
||||
min_freq = int(min_freq)
|
||||
if min_freq < 2:
|
||||
return
|
||||
if self.trimed:
|
||||
return
|
||||
self.trimed = True
|
||||
|
||||
keep_words = []
|
||||
new_words = []
|
||||
|
||||
for k, v in self.word2count.items():
|
||||
if v >= min_freq:
|
||||
keep_words.append(k)
|
||||
new_words.extend([k] * v)
|
||||
if verbose:
|
||||
keep_len = len(keep_words)
orig_len = len(self.word2idx) - len(self.init_tokens)
logger.info('vocab trimmed, keep words [{} / {}] = {:.2f}%'.format(
keep_len, orig_len, keep_len / orig_len * 100))
|
||||
|
||||
# Reinitialize dictionaries
|
||||
self.word2idx = {}
|
||||
self.word2count = {}
|
||||
self.idx2word = {}
|
||||
self.count = 0
|
||||
self._add_init_tokens()
|
||||
self.add_words(new_words)
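if __name__ == '__main__':
    # 用法示例(仅作演示):构建词表、统计词频并删除低频词
    vocab = Vocab('word')
    vocab.add_words(['中国', '北京', '中国'])
    vocab.trim(min_freq=2, verbose=False)
    print(vocab.word2idx)  # 特殊 token 与 '中国' 保留,低频词 '北京' 被删除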
|
|
@ -0,0 +1,2 @@
|
|||
from .ioUtils import *
|
||||
from .nnUtils import *
|
|
@ -0,0 +1,181 @@
|
|||
import os
|
||||
import csv
|
||||
import json
|
||||
import pickle
|
||||
import logging
|
||||
from typing import NewType, List, Tuple, Dict, Any
|
||||
|
||||
__all__ = [
|
||||
'load_pkl',
|
||||
'save_pkl',
|
||||
'load_csv',
|
||||
'save_csv',
|
||||
'load_jsonld',
|
||||
'save_jsonld',
|
||||
'jsonld2csv',
|
||||
'csv2jsonld',
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Path = str
|
||||
|
||||
|
||||
def load_pkl(fp: Path, verbose: bool = True) -> Any:
|
||||
"""
|
||||
读取文件
|
||||
Args :
|
||||
fp (String) : 读取数据地址
|
||||
verbose (bool) : 是否打印日志
|
||||
Return :
|
||||
data (Any) : 读取的数据
|
||||
"""
|
||||
if verbose:
|
||||
logger.info(f'load data from {fp}')
|
||||
|
||||
with open(fp, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
return data
|
||||
|
||||
|
||||
def save_pkl(data: Any, fp: Path, verbose: bool = True) -> None:
|
||||
"""
|
||||
保存文件
|
||||
Args :
|
||||
data (Any) : 数据
|
||||
fp (String) :保存的地址
|
||||
verbose (bool) : 是否打印日志
|
||||
"""
|
||||
if verbose:
|
||||
logger.info(f'save data in {fp}')
|
||||
|
||||
with open(fp, 'wb') as f:
|
||||
pickle.dump(data, f)
|
||||
|
||||
|
||||
def load_csv(fp: Path, is_tsv: bool = False, verbose: bool = True) -> List:
|
||||
"""
|
||||
读取csv格式文件
|
||||
Args :
|
||||
fp (String) : 保存地址
|
||||
is_tsv (bool) : 是否为excel-tab格式
|
||||
verbose (bool) : 是否打印日志
|
||||
Return :
|
||||
list(reader) (List): 读取的List数据
|
||||
"""
|
||||
if verbose:
|
||||
logger.info(f'load csv from {fp}')
|
||||
|
||||
dialect = 'excel-tab' if is_tsv else 'excel'
|
||||
with open(fp, encoding='utf-8') as f:
|
||||
reader = csv.DictReader(f, dialect=dialect)
|
||||
return list(reader)
|
||||
|
||||
|
||||
def save_csv(data: List[Dict], fp: Path, save_in_tsv: bool = False, write_head: bool = True, verbose: bool = True) -> None:
|
||||
"""
|
||||
保存csv格式文件
|
||||
Args :
|
||||
data (List) : 所需保存的List数据
|
||||
fp (String) : 保存地址
|
||||
save_in_tsv (bool) : 是否保存为excel-tab格式
|
||||
write_head (bool) : 是否写表头
|
||||
verbose (bool) : 是否打印日志
|
||||
"""
|
||||
if verbose:
|
||||
logger.info(f'save csv file in: {fp}')
|
||||
|
||||
with open(fp, 'w', encoding='utf-8') as f:
|
||||
fieldnames = data[0].keys()
|
||||
dialect = 'excel-tab' if save_in_tsv else 'excel'
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames, dialect=dialect)
|
||||
if write_head:
|
||||
writer.writeheader()
|
||||
writer.writerows(data)
|
||||
|
||||
|
||||
def load_jsonld(fp: Path, verbose: bool = True) -> List:
|
||||
|
||||
"""
|
||||
读取jsonld文件
|
||||
Args:
|
||||
fp (String): jsonld 文件地址
|
||||
verbose (bool): 是否打印日志
|
||||
Return:
|
||||
datas (List) : 读取后的List
|
||||
"""
|
||||
if verbose:
|
||||
logger.info(f'load jsonld from {fp}')
|
||||
|
||||
datas = []
|
||||
with open(fp, encoding='utf-8') as f:
|
||||
for l in f:
|
||||
line = json.loads(l)
|
||||
data = list(line.values())
|
||||
datas.append(data)
|
||||
|
||||
return datas
|
||||
|
||||
|
||||
def save_jsonld(fp):
|
||||
"""
|
||||
保存jsonld格式文件
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def jsonld2csv(fp: str, verbose: bool = True) -> str:
|
||||
"""
|
||||
读入 jsonld 文件,存储在同位置同名的 csv 文件
|
||||
Args:
|
||||
fp (String): jsonld 文件地址
|
||||
verbose (bool): 是否打印日志
|
||||
Return:
|
||||
fp_new (String):文件地址
|
||||
"""
|
||||
data = []
|
||||
root, ext = os.path.splitext(fp)
|
||||
fp_new = root + '.csv'
|
||||
if verbose:
|
||||
print(f'read jsonld file in: {fp}')
|
||||
with open(fp, encoding='utf-8') as f:
|
||||
for l in f:
|
||||
line = json.loads(l)
|
||||
data.append(line)
|
||||
if verbose:
|
||||
print('saving...')
|
||||
with open(fp_new, 'w', encoding='utf-8') as f:
|
||||
fieldnames = data[0].keys()
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames, dialect='excel')
|
||||
writer.writeheader()
|
||||
writer.writerows(data)
|
||||
if verbose:
|
||||
print(f'saved csv file in: {fp_new}')
|
||||
return fp_new
|
||||
|
||||
|
||||
def csv2jsonld(fp: str, verbose: bool = True) -> str:
|
||||
"""
|
||||
读入 csv 文件,存储在同位置同名的 jsonld 文件
|
||||
Args:
|
||||
fp (String): csv 文件地址
|
||||
verbose (bool): 是否打印日志
|
||||
Return:
|
||||
fp_new (String):文件地址
|
||||
"""
|
||||
data = []
|
||||
root, ext = os.path.splitext(fp)
|
||||
fp_new = root + '.jsonld'
|
||||
if verbose:
|
||||
print(f'read csv file in: {fp}')
|
||||
with open(fp, encoding='utf-8') as f:
|
||||
writer = csv.DictReader(f, fieldnames=None, dialect='excel')
|
||||
for line in writer:
|
||||
data.append(line)
|
||||
if verbose:
|
||||
print('saving...')
|
||||
with open(fp_new, 'w', encoding='utf-8') as f:
|
||||
f.write(os.linesep.join([json.dumps(l, ensure_ascii=False) for l in data]))
|
||||
if verbose:
|
||||
print(f'saved jsonld file in: {fp_new}')
|
||||
return fp_new
|
|
@ -0,0 +1,76 @@
|
|||
import torch
|
||||
import random
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Dict, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = [
|
||||
'manual_seed',
|
||||
'seq_len_to_mask',
|
||||
'to_one_hot',
|
||||
]
|
||||
|
||||
|
||||
def manual_seed(seed: int = 1) -> None:
|
||||
"""
|
||||
设置seed。
|
||||
"""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
#if torch.cuda.CUDA_ENABLED and use_deterministic_cudnn:
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
|
||||
def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None, mask_pos_to_true=True):
|
||||
"""
|
||||
将一个表示sequence length的一维数组转换为二维的mask,默认pad的位置为1。
|
||||
转变 1-d seq_len到2-d mask。
|
||||
|
||||
Args :
|
||||
seq_len (list, np.ndarray, torch.LongTensor) : shape将是(B,)
|
||||
max_len (int): 将长度pad到这个长度。默认(None)使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有区别,所以需要传入一个max_len使得mask的长度是pad到该长度。
|
||||
Return:
|
||||
mask (np.ndarray, torch.Tensor) : shape将是(B, max_length), 元素类似为bool或torch.uint8
|
||||
"""
|
||||
if isinstance(seq_len, list):
|
||||
seq_len = np.array(seq_len)
|
||||
|
||||
if isinstance(seq_len, np.ndarray):
|
||||
seq_len = torch.from_numpy(seq_len)
|
||||
|
||||
if isinstance(seq_len, torch.Tensor):
|
||||
assert seq_len.dim() == 1, logger.error(f"seq_len can only have one dimension, got {seq_len.dim()} != 1.")
|
||||
batch_size = seq_len.size(0)
|
||||
max_len = int(max_len) if max_len else seq_len.max().long()
|
||||
broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len.device)
|
||||
if mask_pos_to_true:
|
||||
mask = broad_cast_seq_len.ge(seq_len.unsqueeze(1))
|
||||
else:
|
||||
mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))
|
||||
else:
|
||||
raise logger.error("Only support 1-d list or 1-d numpy.ndarray or 1-d torch.Tensor.")
|
||||
|
||||
return mask
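# 示例(仅作说明):
#   seq_len_to_mask([2, 3, 1], max_len=4)
#   -> tensor([[False, False,  True,  True],
#              [False, False, False,  True],
#              [False,  True,  True,  True]])  # 默认 mask_pos_to_true=True,padding 位置为 True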
|
||||
|
||||
|
||||
def to_one_hot(x: torch.Tensor, length: int) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor):[B] , 一般是 target 的值
|
||||
length (int) : L ,一般是关系种类树
|
||||
Return:
|
||||
x_one_hot.to(device=x.device) (torch.Tensor) : [B, L] 每一行,只有对应位置为1,其余为0
|
||||
"""
|
||||
B = x.size(0)
|
||||
x_one_hot = torch.zeros(B, length)
|
||||
for i in range(B):
|
||||
x_one_hot[i, x[i]] = 1.0
|
||||
|
||||
return x_one_hot.to(device=x.device)
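# 示例(仅作说明):
#   to_one_hot(torch.tensor([1, 0]), 3)
#   -> tensor([[0., 1., 0.],
#              [1., 0., 0.]])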
|
|
@ -11,9 +11,9 @@ from torch.utils.tensorboard import SummaryWriter
|
|||
# self
|
||||
import sys
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
|
||||
import deepke.models as models
|
||||
from deepke.tools import preprocess , CustomDataset, collate_fn ,train, validate
|
||||
from deepke.utils import manual_seed, load_pkl
|
||||
import models as models
|
||||
from tools import preprocess , CustomDataset, collate_fn ,train, validate
|
||||
from utils import manual_seed, load_pkl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -4,12 +4,12 @@ import torch
|
|||
import logging
|
||||
import hydra
|
||||
from hydra import utils
|
||||
from deepke.tools import Serializer
|
||||
from deepke.tools import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
|
||||
from serializer import Serializer
|
||||
from preprocess import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
|
||||
import matplotlib.pyplot as plt
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
|
||||
from deepke.utils import load_pkl, load_csv
|
||||
import deepke.models as models
|
||||
from utils import load_pkl, load_csv
|
||||
import models as models
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|