Commit 84e346efe0 (parent c5c6ae5db0): update

README.md (67 lines changed)
@@ -1,43 +1,50 @@
# Deepke
# DeepKE

deepke is a Chinese relation extraction toolkit based on PyTorch.
DeepKE is a deep-learning Chinese relation extraction toolkit built on PyTorch.

## Requirements

- python >= 3.6
- torch >=1.0
- jieba >= 0.39
- scikit_learn >= 0.21
- pytorch_transformers>=1.0
- torch >= 1.0
- jieba >= 0.38
- matplotlib >= 3.0
- pytorch_transformers >= 1.2
- scikit_learn >= 0.20

## Project layout

```
├── checkpoints          # saved model weights after training
├── data                 # data directory
│   ├── origin           # raw datasets used for training
│   ├── train.csv        # training set
│   ├── test.csv         # test set
│   ├── relation.txt     # relation types
├── model                # model directory
│   ├── __init__.py
│   ├── BasicModule.py   # basic model configuration
│   ├── Embedding.py     # Embedding module
│   ├── CNN.py           # CNN & PCNN models
│   ├── BiLSTM.py        # BiLSTM model
│   ├── Transformer.py   # Transformer model
│   ├── Capsule.py       # Capsule model
│   ├── Bert.py          # pretrained language model
├── src
│   ├── config.py        # configuration file
│   ├── vocab.py         # vocabulary construction
│   ├── process.py       # data preprocessing before training
│   ├── dataset.py       # batching of input data during training
│   ├── trainer.py       # training loop functions
│   ├── utils.py         # utility functions
├── main.py              # main entry point
├── README.md            # readme file
├── bert_pretrained      # pretrained weights used with BERT
│   ├── vocab.txt        # BERT vocabulary
│   ├── config.json      # BERT model structure configuration
│   ├── pytorch_model.bin  # pretrained model weights
├── checkpoints          # saved model weights after training
├── data                 # data directory
│   ├── origin           # raw datasets used for training
│   ├── train.csv        # training set
│   ├── test.csv         # test set
│   ├── relation.txt     # relation types
├── deepke
├── model                # model directory
│   ├── __init__.py
│   ├── BasicModule.py   # basic model configuration
│   ├── Embedding.py     # Embedding module
│   ├── CNN.py           # CNN & PCNN models
│   ├── BiLSTM.py        # BiLSTM model
│   ├── Transformer.py   # Transformer model
│   ├── Capsule.py       # Capsule model
│   ├── Bert.py          # pretrained language model
├── __init__.py
├── config.py            # configuration file
├── vocab.py             # vocabulary construction
├── preprocess.py        # data preprocessing before training
├── dataset.py           # batching of input data during training
├── trainer.py           # training loop functions
├── utils.py             # utility functions
├── main.py              # main entry point
├── README.md            # readme file
```

## Quick start
@@ -5,7 +5,7 @@

The folder must contain three files:

- bert_config.json `configuration file for the BERT structure`
- pytorch_model.bin `model saved after pretraining`
- vocab.txt `BERT vocabulary`
- config.json `configuration file for the BERT model structure`
- pytorch_model.bin `pretrained model weights`
- vocab.txt `BERT model vocabulary`
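As a usage note, here is a minimal sketch of loading these local pretrained files with pytorch_transformers; the directory name `bert_pretrained` follows the project tree above, everything else is plain library usage and not taken from the commit itself:

```python
from pytorch_transformers import BertModel, BertTokenizer

# both calls accept a local directory holding config.json,
# pytorch_model.bin and vocab.txt instead of a hub model name
tokenizer = BertTokenizer.from_pretrained('bert_pretrained')
lm = BertModel.from_pretrained('bert_pretrained')

# serialize a sentence exactly as preprocess.py does later on
ids = tokenizer.encode('南京是江苏的省会', add_special_tokens=True)
```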
@@ -0,0 +1,6 @@
sentence,head,head_type,head_offset,tail,tail_type,tail_offset
“逆袭”系列微电影《宝贝》由优酷土豆股份有限公司于2012年出品,宝贝,影视作品,10,优酷土豆股份有限公司,企业,14
位于伦敦东南方的格林威治,为地球经线的起始点,格林威治,景点,8,伦敦,城市,2
崔恒源 男,1950年3月生,祖籍河南省孟县,现任孟县无缝钢管厂党委书记、厂长,崔恒源,人物,0,河南省孟县,地点,17
帅长斌,男,1964年6月生,江西九江人,帅长斌,人物,0,江西九江,地点,15
图为《西游记》拍摄幕后照片,猪八戒的大耳朵都掉了一只,可见当时拍摄条件实在有限,但是导演杨洁精益求精,使得这部电视剧成为经典,西游记,影视作品,3,杨洁,人物,44
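Each row of the sample above pairs a sentence with its head/tail entities, their types, and character offsets. A small sketch of reading it the way the new `utils.load_csv` does (one dict per row via `csv.DictReader`):

```python
import csv

with open('data/origin/train.csv', encoding='utf-8') as f:
    rows = list(csv.DictReader(f))

# offsets come back as strings and are cast to int later in preprocess.py
print(rows[0]['head'], rows[0]['head_type'], rows[0]['head_offset'])
# -> 宝贝 影视作品 10
```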
deepke/config.py (143 lines changed)
@@ -1,6 +1,63 @@
# location of the raw files
class TrainingConfig(object):
    seed = 1
    use_gpu = True
    gpu_id = 0
    epoch = 30
    learning_rate = 1e-3
    decay_rate = 0.5
    decay_patience = 3
    batch_size = 64
    train_log = True
    log_interval = 10
    show_plot = True
    f1_norm = ['macro', 'micro']


class ModelConfig(object):
    word_dim = 50
    pos_size = 102  # 2 * pos_limit + 2
    pos_dim = 5
    feature_dim = 60  # 50 + 5 * 2
    hidden_dim = 100
    dropout = 0.3


class CNNConfig(object):
    use_pcnn = True
    out_channels = 100
    kernel_size = [3, 5]


class RNNConfig(object):
    lstm_layers = 2
    last_hn = False


class GCNConfig(object):
    # TODO
    pass


class TransformerConfig(object):
    transformer_layers = 2


class CapsuleConfig(object):
    num_primary_units = 8
    num_output_units = 10  # relation_type
    primary_channels = 1
    primary_unit_size = 768
    output_unit_size = 128
    num_iterations = 5


class LMConfig(object):
    # lm_name = 'bert-base-chinese'  # download usage
    lm_file = 'bert_pretrained'  # cache file usage


class Config(object):
    # where the raw data is stored
    data_path = 'data/origin'
    # where the preprocessed files are stored
    out_path = 'data/out'
@@ -16,81 +73,37 @@ class Config(object):
    # minimum word frequency when building the vocab
    min_freq = 2

    # position embedding
    # position limit
    pos_limit = 50  # [-50, 50]
    pos_size = 102  # 2 * pos_limit + 2

    # model name
    # (CNN, BiLSTM, Transformer, Capsule, Bert)
    # (CNN, RNN, GCN, Transformer, Capsule, LM)
    model_name = 'CNN'

    # model
    word_dim = 50
    pos_dim = 5
    training = TrainingConfig()
    model = ModelConfig()
    cnn = CNNConfig()
    rnn = RNNConfig()
    gcn = GCNConfig()
    transformer = TransformerConfig()
    capsule = CapsuleConfig()
    lm = LMConfig()

    # feature_dim = 50 + 5 * 2
    hidden_dim = 100
    dropout = 0.3

    # PCNN config
    use_pcnn = True
    out_channels = 100
    kernel_size = [3, 5]

    # BiLSTM
    lstm_layers = 2
    last_hn = False

    # Transformer
    transformer_layers = 2

    # Capsule
    num_primary_units=8
    num_output_units=10  # relation_type
    primary_channels=1
    primary_unit_size=768
    output_unit_size=128
    num_iterations=5

    # Bert
    lm_name = 'bert-base-chinese'

    # train
    seed = 1
    use_gpu = True
    gpu_id = 3
    epoch = 30
    learning_rate = 1e-3
    decay_rate = 0.5
    decay_patience = 3
    batch_size = 64
    train_log = True
    log_interval = 10
    show_plot = True
    f1_norm = ['macro', 'micro']


def parse(self, kwargs):
def parse(self, kwargs, verbose=False):
    '''
    user can update the default hyperparamter
    user can update the default hyper parameters
    '''
    for k, v in kwargs.items():
        if not hasattr(self, k):
            raise Exception('opt has No key: {}'.format(k))
        setattr(self, k, v)

    print('*************************************************')
    print('user config:')
    for k, v in kwargs.items():
        if not k.startswith('__'):
            print("{} => {}".format(k, getattr(self, k)))

    print('*************************************************')
    if verbose:
        print('*************************************************')
        print('user config:')
        for k, v in kwargs.items():
            if not k.startswith('__'):
                print("{} => {}".format(k, getattr(self, k)))
        print('*************************************************')


Config.parse = parse

config =Config()
config = Config()
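The flat option list is being split into nested sub-configs (`config.training`, `config.model`, `config.cnn`, ...), and `parse` now only prints the override table when `verbose=True`. A short usage sketch against the defaults defined above:

```python
from deepke.config import config

# override a couple of defaults; unknown keys raise an Exception
config.parse({'model_name': 'LM', 'pos_limit': 30}, verbose=True)

print(config.model_name)           # 'LM'
print(config.training.batch_size)  # 64, from TrainingConfig
print(config.cnn.kernel_size)      # [3, 5], from CNNConfig
```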
@@ -1,35 +1,7 @@
import torch
from torch.utils.data import Dataset
from deepke.utils import load_pkl


class CustomLMDataset(Dataset):
    def __init__(self, fp):
        self.file = load_pkl(fp)

    def __getitem__(self, item):
        sample = self.file[item]
        return sample

    def __len__(self):
        return len(self.file)


def collate_fn_lm(batch):
    batch.sort(key=lambda data: len(data[0]), reverse=True)
    lens = [len(data[0]) for data in batch]
    max_len = max(lens)

    def _padding(x, max_len):
        return x + [0] * (max_len - len(x))

    sent_arr = []
    y_arr = []
    for data in batch:
        sent, data_y = data
        sent_arr.append(_padding(sent, max_len))
        y_arr.append(data_y)
    return torch.tensor(sent_arr), torch.tensor(y_arr)
from deepke.config import config


class CustomDataset(Dataset):
@@ -45,46 +17,57 @@ class CustomDataset(Dataset):


def collate_fn(batch):
    batch.sort(key=lambda data: len(data[0]), reverse=True)
    lens = [len(data[0]) for data in batch]
    max_len = max(lens)
    batch.sort(key=lambda data: data['seq_len'], reverse=True)

    max_len = 0
    for data in batch:
        if data['seq_len'] > max_len:
            max_len = data['seq_len']

    def _padding(x, max_len):
        return x + [0] * (max_len - len(x))

    sent_arr = []
    head_pos_arr = []
    tail_pos_arr = []
    mask_arr = []
    y_arr = []
    for data in batch:
        sent, head_pos, tail_pos, mask, data_y = data
        sent_arr.append(_padding(sent, max_len))
        head_pos_arr.append(_padding(head_pos, max_len))
        tail_pos_arr.append(_padding(tail_pos, max_len))
        mask_arr.append(_padding(mask, max_len))
        y_arr.append(data_y)
    return torch.tensor(sent_arr), torch.tensor(head_pos_arr), torch.tensor(
        tail_pos_arr), torch.tensor(mask_arr), torch.tensor(y_arr)
    if config.model_name == 'LM':
        x, y = [], []
        for data in batch:
            x.append(_padding(data['lm_idx'], max_len))
            y.append(data['target'])

        return torch.tensor(x), torch.tensor(y)

    else:
        sent, head_pos, tail_pos, mask_pos = [], [], [], []
        y = []
        for data in batch:
            sent.append(_padding(data['word2idx'], max_len))
            head_pos.append(_padding(data['head_pos'], max_len))
            tail_pos.append(_padding(data['tail_pos'], max_len))
            mask_pos.append(_padding(data['mask_pos'], max_len))
            y.append(data['target'])
        return torch.Tensor(sent), torch.Tensor(head_pos), torch.Tensor(
            tail_pos), torch.Tensor(mask_pos), torch.Tensor(y)


if __name__ == '__main__':
    from torch.utils.data import DataLoader
    vocab_path = 'data/out/vocab.pkl'
    train_data_path = 'data/out/train.pkl'
    vocab_path = '../data/out/vocab.pkl'
    train_data_path = '../data/out/train.pkl'
    vocab = load_pkl(vocab_path)

    train_dataset = CustomDataset(train_data_path)
    dataloader = DataLoader(train_dataset,
                            batch_size=4,
                            shuffle=False,
                            shuffle=True,
                            collate_fn=collate_fn)
    for idx, (*x, y) in enumerate(dataloader):
        sent, head_pos, tail_pos, mask = x

        raw_sents = []
        for i in range(4):
            raw_sent = [vocab.idx2word[i] for i in sent[i].numpy()]
            raw_sents.append(''.join(raw_sent))
        print(raw_sents, head_pos, tail_pos, mask, y, sep='\n\n')
    for idx, (*x, y) in enumerate(dataloader):
        print(x)
        print(y)
        break
        # sent, head_pos, tail_pos, mask_pos = x
        # raw_sents = []
        # for i in range(4):
        #     raw_sent = [vocab.idx2word[i] for i in sent[i].numpy()]
        #     raw_sents.append(''.join(raw_sent))
        # print(raw_sents, head_pos, tail_pos, mask, y, sep='\n\n')
        # break
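`collate_fn` now sorts a batch of sample dicts by `seq_len` and zero-pads every field to the longest sequence in the batch. A worked sketch of the non-LM branch with two toy samples (all values fabricated for illustration, and `config.model_name` assumed to be something other than 'LM'):

```python
from deepke.dataset import collate_fn

batch = [
    {'word2idx': [4, 8, 15, 16], 'head_pos': [51, 52, 53, 54],
     'tail_pos': [50, 51, 52, 53], 'mask_pos': [1, 2, 2, 3],
     'seq_len': 4, 'target': 2},
    {'word2idx': [23, 42], 'head_pos': [51, 52],
     'tail_pos': [52, 53], 'mask_pos': [1, 3],
     'seq_len': 2, 'target': 0},
]

sent, head_pos, tail_pos, mask_pos, y = collate_fn(batch)
# the shorter sample is padded with zeros up to max_len == 4:
# sent[1] -> tensor([23., 42., 0., 0.])   (torch.Tensor yields float tensors)
# y       -> tensor([2., 0.])
```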
@@ -3,6 +3,7 @@ import torch.nn as nn
import time
from deepke.utils import ensure_dir


class BasicModule(nn.Module):
    '''
    wraps nn.Module and provides save and load methods
@@ -27,7 +28,7 @@
            name = prefix + self.model_name + '_' + f'epoch{epoch}_'
            name = time.strftime(name + '%m%d_%H:%M:%S.pth')
        else:
            name = prefix + name + '_'+ self.model_name + '_' + f'epoch{epoch}_'
            name = prefix + name + '_' + self.model_name + '_' + f'epoch{epoch}_'
            name = time.strftime(name + '%m%d_%H:%M:%S.pth')
        torch.save(self.state_dict(), name)
        return name
        return name
@@ -8,26 +8,27 @@ class CNN(BasicModule):
    def __init__(self, vocab_size, config):
        super(CNN, self).__init__()
        self.model_name = 'CNN'
        self.out_channels = config.out_channels
        self.kernel_size = config.kernel_size
        self.word_dim = config.word_dim
        self.pos_size = config.pos_size
        self.pos_dim = config.pos_dim
        self.use_pcnn = config.use_pcnn
        self.hidden_dim = config.hidden_dim
        self.vocab_size = vocab_size
        self.word_dim = config.model.word_dim
        self.pos_size = config.model.pos_size
        self.pos_dim = config.model.pos_dim
        self.hidden_dim = config.model.hidden_dim
        self.dropout = config.model.dropout
        self.use_pcnn = config.cnn.use_pcnn
        self.out_channels = config.cnn.out_channels
        self.kernel_size = config.cnn.kernel_size
        self.out_dim = config.relation_type
        self.dropout = config.dropout

        if isinstance(self.kernel_size, int):
            self.kernel_size = [self.kernel_size]
        for k in self.kernel_size:
            assert k % 2 == 1, "kernel size has to be odd numbers."

        self.embedding = Embedding(vocab_size, self.word_dim, self.pos_size,
        self.embedding = Embedding(self.vocab_size, self.word_dim, self.pos_size,
                                   self.pos_dim)
        # PCNN embedding
        self.mask_embed = nn.Embedding(4, 3)
        masks = torch.tensor([[0, 0, 0], [100, 0, 0], [0, 100, 0], [0, 0,
        masks = torch.Tensor([[0, 0, 0], [100, 0, 0], [0, 100, 0], [0, 0,
                                                                     100]])
        self.mask_embed.weight.data.copy_(masks)
        self.mask_embed.weight.requires_grad = False
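The odd-kernel assertion makes sense if the convolutions pad with `k // 2` on each side to keep the sequence length unchanged, which only works out for odd `k`. The actual padding choice is not visible in this hunk, so the sketch below is illustrative arithmetic rather than the model's real forward pass:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 60, 35)                 # (batch, feature_dim = 50 + 5 * 2, seq_len)
for k in [3, 5]:                           # config.cnn.kernel_size
    conv = nn.Conv1d(60, 100, kernel_size=k, padding=k // 2)
    assert conv(x).shape[-1] == 35         # length is preserved only when k is odd
```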
@@ -8,20 +8,20 @@ class Capsule(BasicModule):
    def __init__(self, vocab_size, config):
        super(Capsule, self).__init__()
        self.model_name = 'Capsule'
        self.word_dim = config.word_dim
        self.pos_size = config.pos_size
        self.pos_dim = config.pos_dim
        self.hidden_dim = config.hidden_dim
        self.dropout = config.dropout
        self.vocab_size = vocab_size
        self.word_dim = config.model.word_dim
        self.pos_size = config.model.pos_size
        self.pos_dim = config.model.pos_dim
        self.hidden_dim = config.model.hidden_dim

        self.num_primary_units = config.num_primary_units
        self.num_output_units = config.num_output_units
        self.primary_channels = config.primary_channels
        self.primary_unit_size = config.primary_unit_size
        self.output_unit_size = config.output_unit_size
        self.num_iterations = config.num_iterations
        self.num_primary_units = config.capsule.num_primary_units
        self.num_output_units = config.capsule.num_output_units
        self.primary_channels = config.capsule.primary_channels
        self.primary_unit_size = config.capsule.primary_unit_size
        self.output_unit_size = config.capsule.output_unit_size
        self.num_iterations = config.capsule.num_iterations

        self.embedding = Embedding(vocab_size, self.word_dim, self.pos_size,
        self.embedding = Embedding(self.vocab_size, self.word_dim, self.pos_size,
                                   self.pos_dim)
        self.input_dim = self.word_dim + self.pos_dim * 2
        self.lstm = VarLenLSTM(
@@ -13,8 +13,8 @@ class Embedding(nn.Module):
    def forward(self, x):
        words, head_pos, tail_pos = x
        word_embed = self.word_embed(words)
        head_embed = self.head_pos_embed(head_pos)
        tail_embed = self.tail_pos_embed(tail_pos)
        feature_embed = torch.cat([word_embed, head_embed, tail_embed], dim=-1)
        head_pos_embed = self.head_pos_embed(head_pos)
        tail_pos_embed = self.tail_pos_embed(tail_pos)
        feature_embed = torch.cat([word_embed, head_pos_embed, tail_pos_embed], dim=-1)

        return feature_embed
@@ -0,0 +1,11 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from deepke.model import BasicModule, Embedding


class CNN(BasicModule):
    def __init__(self, vocab_size, config):
        super(CNN, self).__init__()
        self.model_name = 'CNN'
        # TODO
@@ -3,17 +3,18 @@ from deepke.model import BasicModule
from pytorch_transformers import BertModel


class Bert(BasicModule):
class LM(BasicModule):
    def __init__(self, vocab_size, config):
        super(Bert, self).__init__()
        self.model_name = 'Bert'
        self.lm_name = config.lm_name
        super(LM, self).__init__()
        self.model_name = 'LM'
        self.lm_name = config.lm.lm_file
        self.out_dim = config.relation_type

        self.lm = BertModel.from_pretrained(self.lm_name)
        self.fc = nn.Linear(768, self.out_dim)

    def forward(self, x):
        x = x[0]
        out = self.lm(x)[-1]
        out = self.fc(out)
        return out
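For context on `self.lm(x)[-1]`: with pytorch_transformers, `BertModel` returns `(sequence_output, pooled_output)`, so the last element is the pooled `[CLS]` vector of shape `(batch, 768)`, which is why the classification head is `nn.Linear(768, relation_type)`. A minimal shape check, assuming the local `bert_pretrained` weights from the README are in place and using made-up token ids:

```python
import torch
from pytorch_transformers import BertModel

lm = BertModel.from_pretrained('bert_pretrained')        # local cache directory
ids = torch.tensor([[101, 2769, 4263, 704, 1744, 102]])  # toy [CLS] ... [SEP] ids
sequence_output, pooled_output = lm(ids)
print(pooled_output.shape)                               # torch.Size([1, 768])
```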
@@ -55,16 +55,17 @@ class BiLSTM(BasicModule):
    def __init__(self, vocab_size, config):
        super(BiLSTM, self).__init__()
        self.model_name = 'BiLSTM'
        self.word_dim = config.word_dim
        self.pos_size = config.pos_size
        self.pos_dim = config.pos_dim
        self.hidden_dim = config.hidden_dim
        self.lstm_layers = config.lstm_layers
        self.last_hn = config.last_hn
        self.vocab_size = vocab_size
        self.word_dim = config.model.word_dim
        self.pos_size = config.model.pos_size
        self.pos_dim = config.model.pos_dim
        self.hidden_dim = config.model.hidden_dim
        self.dropout = config.model.dropout
        self.lstm_layers = config.rnn.lstm_layers
        self.last_hn = config.rnn.last_hn
        self.out_dim = config.relation_type
        self.dropout = config.dropout

        self.embedding = Embedding(vocab_size, self.word_dim, self.pos_size,
        self.embedding = Embedding(self.vocab_size, self.word_dim, self.pos_size,
                                   self.pos_dim)
        self.input_dim = self.word_dim + self.pos_dim * 2
        self.lstm = VarLenLSTM(self.input_dim,

@@ -92,13 +93,13 @@ class BiLSTM(BasicModule):

if __name__ == '__main__':
    torch.manual_seed(1)
    x = torch.tensor([
    x = torch.Tensor([
        [1, 2, 3, 4, 3, 2],
        [1, 2, 3, 0, 0, 0],
        [2, 4, 3, 0, 0, 0],
        [2, 3, 0, 0, 0, 0],
    ])
    x_len = torch.tensor([6, 3, 3, 2])
    x_len = torch.Tensor([6, 3, 3, 2])
    embedding = nn.Embedding(5, 10, padding_idx=0)
    model = VarLenLSTM(input_size=10,
                       hidden_size=30,
@@ -82,15 +82,16 @@ class Transformer(BasicModule):
    def __init__(self, vocab_size, config):
        super(Transformer, self).__init__()
        self.model_name = 'Transformer'
        self.word_dim = config.word_dim
        self.pos_size = config.pos_size
        self.pos_dim = config.pos_dim
        self.hidden_dim = config.hidden_dim
        self.dropout = config.dropout
        self.vocab_size = vocab_size
        self.word_dim = config.model.word_dim
        self.pos_size = config.model.pos_size
        self.pos_dim = config.model.pos_dim
        self.hidden_dim = config.model.hidden_dim
        self.dropout = config.model.dropout
        self.layers = config.transformer.transformer_layers
        self.out_dim = config.relation_type
        self.layers = config.transformer_layers

        self.embedding = Embedding(vocab_size, self.word_dim, self.pos_size,
        self.embedding = Embedding(self.vocab_size, self.word_dim, self.pos_size,
                                   self.pos_dim)
        self.feature_dim = self.word_dim + self.pos_dim * 2
        self.att = MultiHeadAttention(self.feature_dim, num_head=4)
@@ -1,7 +1,8 @@
from .Embedding import Embedding
from .BasicModule import BasicModule
from .Transformer import Transformer
from .BiLSTM import BiLSTM, VarLenLSTM
from .Embedding import Embedding
from .CNN import CNN
from .RNN import VarLenLSTM, BiLSTM
from .GCN import GCN
from .Transformer import Transformer
from .Capsule import Capsule
from .Bert import Bert
from .LM import LM
@@ -0,0 +1,202 @@
import os
import jieba
import logging
from typing import List, Dict
from pytorch_transformers import BertTokenizer
# self file
from deepke.vocab import Vocab
from deepke.config import config
from deepke.utils import ensure_dir, save_pkl, load_csv

jieba.setLogLevel(logging.INFO)

Path = str


def _mask_feature(entities_idx: List, sen_len: int) -> List:
    left = [1] * (entities_idx[0] + 1)
    middle = [2] * (entities_idx[1] - entities_idx[0] - 1)
    right = [3] * (sen_len - entities_idx[1])

    return left + middle + right


def _pos_feature(sent_len: int, entity_idx: int, entity_len: int,
                 pos_limit: int) -> List:

    left = list(range(-entity_idx, 0))
    middle = [0] * entity_len
    right = list(range(1, sent_len - entity_idx - entity_len + 1))
    pos = left + middle + right

    for i, p in enumerate(pos):
        if p > pos_limit:
            pos[i] = pos_limit
        if p < -pos_limit:
            pos[i] = -pos_limit
    pos = [p + pos_limit + 1 for p in pos]

    return pos


def _build_data(data: List[Dict], vocab: Vocab, relations: Dict) -> List[Dict]:

    if vocab.name == 'LM':
        for d in data:
            d['seq_len'] = len(d['lm_idx'])
            d['target'] = relations[d['relation']]

        return data

    for d in data:
        if vocab.name == 'word':
            word2idx = [vocab.word2idx.get(w, 1) for w in d['words']]
            seq_len = len(word2idx)
            head_idx, tail_idx = d['head_idx'], d['tail_idx']
            head_len, tail_len = 1, 1

        elif vocab.name == 'char':
            word2idx = [
                vocab.word2idx.get(w, 1) for w in d['sentence'].strip()
            ]
            seq_len = len(word2idx)
            head_idx, tail_idx = int(d['head_offset']), int(d['tail_offset'])
            head_len, tail_len = len(d['head']), len(d['tail'])

        entities_idx = [head_idx, tail_idx
                        ] if tail_idx > head_idx else [tail_idx, head_idx]
        head_pos = _pos_feature(seq_len, head_idx, head_len, config.pos_limit)
        tail_pos = _pos_feature(seq_len, tail_idx, tail_len, config.pos_limit)
        mask_pos = _mask_feature(entities_idx, seq_len)
        target = relations[d['relation']]

        d['word2idx'] = word2idx
        d['seq_len'] = seq_len
        d['head_pos'] = head_pos
        d['tail_pos'] = tail_pos
        d['mask_pos'] = mask_pos
        d['target'] = target

    return data


def _build_vocab(data: List[Dict], out_path: Path) -> Vocab:
    if config.word_segment:
        vocab = Vocab('word')
        for d in data:
            vocab.add_sent(d['words'])
    else:
        vocab = Vocab('char')
        for d in data:
            vocab.add_sent(d['sentence'].strip())
    vocab.trim(config.min_freq)

    ensure_dir(out_path)
    vocab_path = os.path.join(out_path, 'vocab.pkl')
    vocab_txt = os.path.join(out_path, 'vocab.txt')
    save_pkl(vocab_path, vocab, 'vocab')
    with open(vocab_txt, 'w', encoding='utf-8') as f:
        f.write(os.linesep.join([word for word in vocab.word2idx.keys()]))
    return vocab


def _split_sent(data: List[Dict], verbose: bool = True) -> List[Dict]:
    if verbose:
        print('need word segment, use jieba to split sentence')

    jieba.add_word('HEAD')
    jieba.add_word('TAIL')

    for d in data:
        sent = d['sentence'].strip()
        sent = sent.replace(d['head'], 'HEAD', 1)
        sent = sent.replace(d['tail'], 'TAIL', 1)
        sent = jieba.lcut(sent)
        head_idx, tail_idx = sent.index('HEAD'), sent.index('TAIL')
        sent[head_idx], sent[tail_idx] = d['head'], d['tail']
        d['words'] = sent
        d['head_idx'] = head_idx
        d['tail_idx'] = tail_idx
    return data


def _add_lm_data(data: List[Dict]) -> List[Dict]:
    'serialize the input sentence with the language model vocabulary'
    tokenizer = BertTokenizer.from_pretrained('../bert_pretrained')

    for d in data:
        sent = d['sentence'].strip()
        d['seq_len'] = len(sent)
        sent = sent.replace(d['head'], d['head_type'], 1)
        sent = sent.replace(d['tail'], d['tail_type'], 1)
        sent += '[SEP]' + d['head'] + '[SEP]' + d['tail']
        d['lm_idx'] = tokenizer.encode(sent, add_special_tokens=True)

    return data


def _load_relations(fp: Path) -> Dict:
    'read the relation file and store the relations as a dict, used to serialize relations'

    print(f'load {fp}')
    relations_arr = []
    relations_dict = {}

    with open(fp, encoding='utf-8') as f:
        for l in f:
            relations_arr.append(l.strip())

    for k, v in enumerate(relations_arr):
        relations_dict[v] = k

    return relations_dict


def process(data_path: Path, out_path: Path) -> None:
    print('===== start preprocess data =====')
    train_fp = os.path.join(data_path, 'train.csv')
    test_fp = os.path.join(data_path, 'test.csv')
    relation_fp = os.path.join(data_path, 'relation.txt')

    print('load raw files...')
    train_raw_data = load_csv(train_fp)
    test_raw_data = load_csv(test_fp)
    relations = _load_relations(relation_fp)

    # when using a pretrained language model
    if config.model_name == 'LM':
        print('\nuse pretrained language model serialize sentence...')
        train_raw_data = _add_lm_data(train_raw_data)
        test_raw_data = _add_lm_data(test_raw_data)
        vocab = Vocab('LM')

    else:
        # for Chinese, decide whether word segmentation is needed; skip it if the sentence is already segmented
        print('\nverify whether need split words...')
        if config.is_chinese and config.word_segment:
            train_raw_data = _split_sent(train_raw_data)
            test_raw_data = _split_sent(test_raw_data, verbose=False)

        print('build word vocabulary...')
        vocab = _build_vocab(train_raw_data, out_path)

    print('\nbuild train data...')
    train_data = _build_data(train_raw_data, vocab, relations)
    print('build test data...\n')
    test_data = _build_data(test_raw_data, vocab, relations)

    ensure_dir(out_path)
    train_data_path = os.path.join(out_path, 'train.pkl')
    test_data_path = os.path.join(out_path, 'test.pkl')

    save_pkl(train_data_path, train_data, 'train data')
    save_pkl(test_data_path, test_data, 'test data')

    print('===== end preprocess data =====')


if __name__ == "__main__":
    data_path = '../data/origin'
    out_path = '../data/out'

    process(data_path, out_path)
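A worked example of the two feature helpers above, for a 6-token sentence whose head entity sits at index 1 and tail entity at index 4 (both of length 1), with `pos_limit = 50`:

```python
# relative-position feature for the head entity
head_pos = _pos_feature(sent_len=6, entity_idx=1, entity_len=1, pos_limit=50)
# left = [-1], middle = [0], right = [1, 2, 3, 4]
# shifted by pos_limit + 1 = 51  ->  [50, 51, 52, 53, 54, 55]

# three-segment mask (before / between / after the entity pair),
# later consumed by the piecewise-pooling PCNN branch
mask_pos = _mask_feature(entities_idx=[1, 4], sen_len=6)
# -> [1, 1, 2, 2, 3, 3]
```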
@@ -1,246 +0,0 @@
import os
import csv
import json
import torch
import jieba
import logging
from typing import List, Tuple
# self file
from deepke.config import config
from deepke.vocab import Vocab
from deepke.utils import ensure_dir, save_pkl, load_csv, load_jsonld
from pytorch_transformers import BertTokenizer

jieba.setLogLevel(logging.INFO)


def build_lm_data(raw_data: List) -> List:
    tokenizer = BertTokenizer.from_pretrained(config.lm_name)
    sents = []
    for data in raw_data:
        sent = data[0]
        sub = data[1]
        obj = data[4]
        sent = '[CLS]' + sent + '[SEP]' + sub + '[SEP]' + obj + '[SEP]'
        input_ids = torch.tensor([tokenizer.encode(sent)])
        sents.append(input_ids)
    return sents


def mask_feature(entities_pos: List, sen_len: int) -> List:
    left = [1] * (entities_pos[0] + 1)
    middle = [2] * (entities_pos[1] - entities_pos[0] - 1)
    right = [3] * (sen_len - entities_pos[1])
    return left + middle + right


def pos_feature(sent_len: int, entity_pos: int, entity_len: int,
                pos_limit: int) -> List:
    left = list(range(-entity_pos, 0))
    middle = [0] * entity_len
    right = list(range(1, sent_len - entity_pos - entity_len + 1))
    pos = left + middle + right
    for i, p in enumerate(pos):
        if p > pos_limit:
            pos[i] = pos_limit
        if p < -pos_limit:
            pos[i] = -pos_limit
    pos = [p + pos_limit + 1 for p in pos]
    return pos


def build_data(raw_data: List[List], vocab) -> Tuple[List, List, List, List]:
    sents = []
    head_pos = []
    tail_pos = []
    mask_pos = []

    if vocab.name == 'word':
        for data in raw_data:
            sent = [vocab.word2idx.get(w, 1) for w in data[-2]]
            pos = list(range(len(sent)))
            head, tail = int(data[-1][0]), int(data[-1][1])
            entities_pos = [head, tail] if tail > head else [tail, head]
            head_p = pos_feature(len(sent), head, 1, config.pos_limit)
            tail_p = pos_feature(len(sent), tail, 1, config.pos_limit)
            mask_p = mask_feature(entities_pos, len(sent))
            sents.append(sent)
            head_pos.append(head_p)
            tail_pos.append(tail_p)
            mask_pos.append(mask_p)

    else:
        for data in raw_data:
            sent = [vocab.word2idx.get(w, 1) for w in data[0]]
            head, tail = int(data[3]), int(data[6])
            head_len, tail_len = len(data[1]), len(data[4])
            entities_pos = [head, tail] if tail > head else [tail, head]
            head_p = pos_feature(len(sent), head, head_len, config.pos_limit)
            tail_p = pos_feature(len(sent), tail, tail_len, config.pos_limit)
            mask_p = mask_feature(entities_pos, len(sent))
            head_pos.append(head_p)
            tail_pos.append(tail_p)
            mask_pos.append(mask_p)
            sents.append(sent)
    return sents, head_pos, tail_pos, mask_pos


def relation_tokenize(relations: List[str], fp: str) -> List[int]:
    rels_arr = []
    rels = {}
    out = []
    with open(fp, encoding='utf-8') as f:
        for l in f:
            rels_arr.append(l.strip())
    for i, rel in enumerate(rels_arr):
        rels[rel] = i
    for rel in relations:
        out.append(rels[rel])
    return out


def build_vocab(raw_data: List[List], out_path: str) -> Tuple[Vocab, str]:
    if config.word_segment:
        vocab = Vocab('word')
        for data in raw_data:
            vocab.add_sent(data[-2])
    else:
        vocab = Vocab('char')
        for data in raw_data:
            vocab.add_sent(data[0])
    vocab.trim(config.min_freq)

    ensure_dir(out_path)
    vocab_path = os.path.join(out_path, 'vocab.pkl')
    vocab_txt = os.path.join(out_path, 'vocab.txt')
    save_pkl(vocab_path, vocab, 'vocab')
    with open(vocab_txt, 'w', encoding='utf-8') as f:
        f.write(os.linesep.join([word for word in vocab.word2idx.keys()]))
    return vocab, vocab_path


def split_sents(raw_data: List[List], verbose: bool = True) -> List[List]:
    if verbose:
        print('need word segment, use jieba to split sentence')
    new_data = []
    jieba.add_word('HEAD')
    jieba.add_word('TAIL')
    for data in raw_data:
        head, tail = data[2], data[5]
        sent = data[0].replace(data[1], 'HEAD', 1)
        sent = sent.replace(data[4], 'TAIL', 1)
        sent = jieba.lcut(sent)
        head_pos, tail_pos = sent.index('HEAD'), sent.index('TAIL')
        sent[head_pos] = head
        sent[tail_pos] = tail
        data.append(sent)
        data.append([head_pos, tail_pos])
        new_data.append(data)
    return new_data


def exist_relation(fp: str, file_type: str) -> int:
    '''
    determine whether the file contains relation labels, i.e. whether it is for training or prediction
    when relation data exists, return the corresponding column index (int >= 0)
    otherwise return -1
    :param fp: file path
    :return: column index
    '''
    with open(fp, encoding='utf-8') as f:
        if file_type == 'csv':
            f = csv.DictReader(f)
        for l in f:
            if file_type == 'jsonld':
                l = json.loads(l)
            keys = list(l.keys())
            try:
                num = keys.index('relation')
            except:
                num = -1
            return num


def process(data_path: str, out_path: str, file_type: str) -> None:
    print('===== start preprocess data =====')

    file_type = file_type.lower()
    assert file_type in ['csv', 'jsonld']

    print('load raw files...')
    train_fp = os.path.join(data_path, 'train.' + file_type)
    test_fp = os.path.join(data_path, 'test.' + file_type)
    relation_fp = os.path.join(data_path, 'relation.txt')

    relation_place = exist_relation(train_fp, file_type)
    if file_type == 'csv':
        train_raw_data = load_csv(train_fp)
        test_raw_data = load_csv(test_fp)
    else:
        train_raw_data = load_jsonld(train_fp)
        test_raw_data = load_jsonld(test_fp)
    train_relation = []
    test_relation = []
    if relation_place > -1:
        for data in train_raw_data:
            train_relation.append(data.pop(relation_place))
        for data in test_raw_data:
            test_relation.append(data.pop(relation_place))

    # when using the pretrained language model
    if config.model_name == 'Bert':
        train_lm_sents = build_lm_data(train_raw_data)
        test_lm_sents = build_lm_data(test_raw_data)

    # for Chinese, decide whether word segmentation is needed; skip it if the sentence is already segmented
    print('\nverify whether need split words...')
    if config.is_chinese and config.word_segment:
        train_raw_data = split_sents(train_raw_data)
        test_raw_data = split_sents(test_raw_data, verbose=False)

    print('build sentence vocab...')
    vocab, vocab_path = build_vocab(train_raw_data, out_path)

    print('\nbuild train data...')
    train_sents, train_head_pos, train_tail_pos, train_mask_pos = build_data(
        train_raw_data, vocab)
    print('build test data...')
    test_sents, test_head_pos, test_tail_pos, test_mask_pos = build_data(
        test_raw_data, vocab)
    print('build relation data...\n')
    train_rel_tokens = relation_tokenize(train_relation, relation_fp)
    test_rel_tokens = relation_tokenize(test_relation, relation_fp)

    train_data = list(
        zip(train_sents, train_head_pos, train_tail_pos, train_mask_pos,
            train_rel_tokens))
    test_data = list(
        zip(test_sents, test_head_pos, test_tail_pos, test_mask_pos,
            test_rel_tokens))

    if config.model_name == 'Bert':
        train_data = list(zip(train_lm_sents, train_rel_tokens))
        test_data = list(zip(test_lm_sents, test_rel_tokens))

    ensure_dir(out_path)
    train_data_path = os.path.join(out_path, 'train.pkl')
    test_data_path = os.path.join(out_path, 'test.pkl')

    save_pkl(train_data_path, train_data, 'train data')
    save_pkl(test_data_path, test_data, 'test data')

    if config.model_name == 'Bert':
        train_lm_data_path = os.path.join(out_path, 'train_lm.pkl')
        test_lm_data_path = os.path.join(out_path, 'test_lm.pkl')

        save_pkl(train_lm_data_path, train_data, 'train data')
        save_pkl(test_lm_data_path, test_data, 'test data')

    print('===== end preprocess data =====')


if __name__ == "__main__":
    data_path = '../data/origin'
    out_path = '../data/out'

    process(data_path, out_path, file_type='csv')
@@ -9,13 +9,14 @@ def train(epoch, device, dataloader, model, optimizer, criterion, config):
    model.train()
    total_loss = []

    for batch_idx, batch in enumerate(dataloader, 1):
        *x, y = [data.to(device) for data in batch]
    for batch_idx, (*x, y) in enumerate(dataloader, 1):
        x = [i.to(device) for i in x]
        y = y.to(device)
        optimizer.zero_grad()
        y_pred = model(x)

        if model.model_name == 'Capsule':
            y = to_one_hot(y,config.relation_type)
            y = to_one_hot(y, config.relation_type)
            loss = model.loss(y_pred, y)
        else:
            loss = criterion(y_pred, y)

@@ -46,8 +47,9 @@ def validate(dataloader, model, device, config):
    with torch.no_grad():
        total_y_true = np.empty(0)
        total_y_pred = np.empty(0)
        for batch_idx, batch in enumerate(dataloader, 1):
            *x, y = [data.to(device) for data in batch]
        for batch_idx, (*x, y) in enumerate(dataloader, 1):
            x = [i.to(device) for i in x]
            y = y.to(device)
            y_pred = model(x)

            if model.model_name == 'Capsule':
@@ -23,6 +23,7 @@ __all__ = [
    'csv2jsonld',
]

Path = str


def to_one_hot(x, length):
    batch_size = x.size(0)

@@ -137,7 +138,7 @@ def load_pkl(fp: str, obj_name: str = 'data', verbose: bool = True) -> Any:
    return data


def save_pkl(fp: str, obj, obj_name: str = 'data',
def save_pkl(fp: Path, obj, obj_name: str = 'data',
             verbose: bool = True) -> None:
    if verbose:
        print(f'save {obj_name} in {fp}')

@@ -160,14 +161,11 @@ def ensure_dir(d: str, verbose: bool = True) -> None:

def load_csv(fp: str) -> List:
    print(f'load {fp}')
    datas = []

    with open(fp, encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for line in reader:
            data = list(line.values())
            datas.append(data)
    return datas
        return list(reader)


def load_jsonld(fp: str) -> List:
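Note the behavioral change in `load_csv`: it used to flatten each row into a list of values, and now it returns the `csv.DictReader` rows themselves, which is what the new preprocess module indexes by column name. A quick sketch of the difference on toy data (not from the repo):

```python
import csv
import io

raw = ("sentence,head,head_type,head_offset,tail,tail_type,tail_offset\n"
       "某句子,甲,人物,0,乙,地点,3\n")

old_style = [list(r.values()) for r in csv.DictReader(io.StringIO(raw))]
new_style = list(csv.DictReader(io.StringIO(raw)))

print(old_style[0][1])       # '甲'  -- positional access, as the old process.py did
print(new_style[0]['head'])  # '甲'  -- key access, as preprocess.py does
```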
@@ -12,13 +12,13 @@ class Vocab(object):
        self.word2count = {}
        self.idx2word = {}
        self.count = 0
        self.add_init_tokens()
        self._add_init_tokens()

    def add_init_tokens(self):
    def _add_init_tokens(self):
        for token in self.init_tokens:
            self.add_word(token)
            self._add_word(token)

    def add_word(self, word):
    def _add_word(self, word: str):
        if word not in self.word2idx:
            self.word2idx[word] = self.count
            self.word2count[word] = 1

@@ -29,12 +29,13 @@ class Vocab(object):

    def add_sent(self, sent: str):
        for word in sent:
            self.add_word(word)
            self._add_word(word)

    def trim(self, min_freq=2, verbose: bool = True):
        '''
        remove a word from the vocabulary when its frequency is below min_freq
        :param min_freq: minimum word frequency
        '''remove a word from the vocabulary when its frequency is below min_freq

        Args:
            param min_freq: minimum word frequency
        '''
        if self.trimed:
            return

@@ -42,32 +43,34 @@ class Vocab(object):

        keep_words = []
        new_words = []
        keep_words.extend(self.init_tokens)
        new_words.extend(self.init_tokens)

        for k, v in self.word2count.items():
            if v >= min_freq:
                keep_words.append(k)
                new_words.extend([k] * v)
        if verbose:
            print('after trim, keep words [{} / {}] = {:.2f}%'.format(
                len(keep_words + self.init_tokens), len(self.word2idx),
                len(keep_words + self.init_tokens) / len(self.word2idx) * 100))
                len(keep_words), len(self.word2idx),
                len(keep_words) / len(self.word2idx) * 100))

        # Reinitialize dictionaries
        self.word2idx = {}
        self.word2count = {}
        self.idx2word = {}
        self.count = 0
        self.add_init_tokens()
        for word in new_words:
            self.add_word(word)
            self._add_word(word)


if __name__ == '__main__':
    from nltk import word_tokenize
    vocab = Vocab('test')
    sent = ' 我是中国人,我爱中国。'
    # english
    # from nltk import word_tokenize
    # sent = "I'm chinese, I love China."
    # words = word_tokenize(sent)
    vocab = Vocab('test')
    sent = ' 我是中国人,我 爱中国。'
    print(sent, '\n')
    vocab.add_sent(sent)
    print(vocab.word2idx)
main.py (81 lines changed)
@@ -8,78 +8,63 @@ from deepke.config import config
from deepke import model
from deepke.utils import make_seed, load_pkl
from deepke.trainer import train, validate
from deepke.process import process
from deepke.dataset import CustomDataset, CustomLMDataset, collate_fn, collate_fn_lm
from deepke.preprocess import process
from deepke.dataset import CustomDataset, collate_fn

__Models__ = {
    "CNN": model.CNN,
    "BiLSTM": model.BiLSTM,
    "RNN": model.BiLSTM,
    "GCN": model.GCN,
    "Transformer": model.Transformer,
    "Capsule": model.Capsule,
    "Bert": model.Bert,
    "LM": model.LM,
}

parser = argparse.ArgumentParser(description='choose your model')
parser.add_argument('--model_name', type=str, default='CNN', help='model name')
parser.add_argument('--model_name', type=str, help='model name')
args = parser.parse_args()
model_name = args.model_name if args.model_name else config.model_name

make_seed(config.seed)
make_seed(config.training.seed)

if config.use_gpu and torch.cuda.is_available():
if config.training.use_gpu and torch.cuda.is_available():
    device = torch.device('cuda', config.gpu_id)
else:
    device = torch.device('cpu')

if not os.path.exists(config.out_path):
    process(config.data_path, config.out_path, file_type='csv')
# if not os.path.exists(config.out_path):
process(config.data_path, config.out_path)

if config.model_name == 'Bert':
    vocab_path = os.path.join(config.out_path, 'bert_vocab.txt')
    train_data_path = os.path.join(config.out_path, 'train_lm.pkl')
    test_data_path = os.path.join(config.out_path, 'test_lm.pkl')
train_data_path = os.path.join(config.out_path, 'train.pkl')
test_data_path = os.path.join(config.out_path, 'test.pkl')

if model_name == 'LM':
    vocab_size = None
else:
    vocab_path = os.path.join(config.out_path, 'vocab.pkl')
    train_data_path = os.path.join(config.out_path, 'train.pkl')
    test_data_path = os.path.join(config.out_path, 'test.pkl')
    vocab = load_pkl(vocab_path)
    vocab_size = len(vocab.word2idx)

    vocab = load_pkl(vocab_path)
    vocab_size = len(vocab.word2idx)

if config.model_name == 'Bert':
    train_dataset = CustomLMDataset(train_data_path)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn_lm)
    test_dataset = CustomLMDataset(test_data_path)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        collate_fn=collate_fn_lm,
    )
else:
    train_dataset = CustomDataset(train_data_path)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  collate_fn=collate_fn)
    test_dataset = CustomDataset(test_data_path)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        collate_fn=collate_fn,
    )
train_dataset = CustomDataset(train_data_path)
train_dataloader = DataLoader(train_dataset,
                              batch_size=config.training.batch_size,
                              shuffle=True,
                              collate_fn=collate_fn)
test_dataset = CustomDataset(test_data_path)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=config.training.batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

model = __Models__[model_name](vocab_size, config)
model.to(device)
print(model)
# print(model)

optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
optimizer = optim.Adam(model.parameters(), lr=config.training.learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'max', factor=config.decay_rate, patience=config.decay_patience)
    optimizer, 'max', factor=config.training.decay_rate, patience=config.training.decay_patience)
criterion = nn.CrossEntropyLoss()

best_macro_f1, best_macro_epoch = 0, 1

@@ -87,7 +72,7 @@ best_micro_f1, best_micro_epoch = 0, 1
best_macro_model, best_micro_model = '', ''
print('=' * 10, ' Start training ', '=' * 10)

for epoch in range(1, config.epoch + 1):
for epoch in range(1, config.training.epoch + 1):
    train(epoch, device, train_dataloader, model, optimizer, criterion, config)
    macro_f1, micro_f1 = validate(test_dataloader, model, device, config)
    model_name = model.save(epoch=epoch)
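Usage note: with these changes, training is driven entirely by deepke/config.py plus an optional command-line override, presumably something like `python main.py --model_name LM`; with no flag the script falls back to `config.model_name`. Also note that preprocessing now runs on every start, since the `os.path.exists(config.out_path)` guard is commented out.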
@@ -1,5 +1,5 @@
torch>=1.0
jieba>=0.39
scikit_learn>=0.21
pytorch_transformers>=1.0
matplotlib>=3.1
jieba>=0.38
pytorch_transformers>=1.2
matplotlib>=3.0
scikit_learn>=0.20