parent 33a777cddf
commit 0e61eed58b
@@ -2,6 +2,7 @@ import torch
 from . import BasicModule
 from module import Embedding, CNN
+from module import Capsule as CapsuleLayer
 from utils import seq_len_to_mask, to_one_hot
 
 
 
@@ -47,4 +48,4 @@ class Capsule(BasicModule):
             return loss.sum()
         else:
             # by default, take the mean
-            return loss.mean()
+            return loss.mean()
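For context, the sum/mean branch above is the reduction step of the capsule margin loss. A minimal sketch of what such a loss typically looks like, assuming the standard margin-loss formulation (m+ = 0.9, m- = 0.1, lambda = 0.5); the function name and defaults are illustrative, not taken from this repo:

import torch

def margin_loss(probs, y_one_hot, m_pos=0.9, m_neg=0.1, lam=0.5, reduction='mean'):
    # probs: [B, num_classes] capsule lengths; y_one_hot: [B, num_classes]
    pos = y_one_hot * torch.relu(m_pos - probs) ** 2               # penalty for the true class
    neg = lam * (1 - y_one_hot) * torch.relu(probs - m_neg) ** 2   # penalty for the other classes
    loss = (pos + neg).sum(dim=-1)                                  # per-sample loss, [B]
    return loss.sum() if reduction == 'sum' else loss.mean()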
@@ -13,17 +13,16 @@ class DotAttention(nn.Module):
 
     def forward(self, Q, K, V, mask_out=None, head_mask=None):
         """
-        Given input X, it is generally assumed that K = V = Xs
+        Given input X, it is generally assumed that K = V = X
 
         att_weight = softmax( score_func(q, k) )
         att = sum( att_weight * v )
-        Args:
-            Q: [..., L, H]
-            K: [..., S, H]
-            V: [..., S, H]
-            mask_out: [..., 1, S]
-        Return:
-            attention_out
-            attention_weight
+        :param Q: [..., L, H]
+        :param K: [..., S, H]
+        :param V: [..., S, H]
+        :param mask_out: [..., 1, S]
+        :return:
         """
         H = Q.size(-1)
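The docstring above describes standard scaled dot-product attention: a softmax over scores, then a weighted sum of V. A minimal sketch of that computation, assuming mask_out marks the positions to exclude; this is illustrative, not the repo's exact implementation:

import torch
import torch.nn.functional as F

def dot_attention(Q, K, V, mask_out=None):
    H = Q.size(-1)
    scores = Q @ K.transpose(-1, -2) / H ** 0.5                   # [..., L, S]
    if mask_out is not None:
        scores = scores.masked_fill(mask_out, float('-inf'))      # masked positions get zero weight
    att_weight = F.softmax(scores, dim=-1)                        # [..., L, S]
    att = att_weight @ V                                          # [..., L, H]
    return att, att_weight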
@@ -53,10 +52,9 @@ class DotAttention(nn.Module):
 class MultiHeadAttention(nn.Module):
     def __init__(self, embed_dim, num_heads, dropout=0.0, output_attentions=True):
         """
-        Args:
-            embed_dim: the input dimension; must be divisible by num_heads
-            num_heads: the number of attention heads
-            dropout: float.
+        :param embed_dim: the input dimension; must be divisible by num_heads
+        :param num_heads: the number of attention heads
+        :param dropout: float.
         """
         super(MultiHeadAttention, self).__init__()
         self.num_heads = num_heads
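The divisibility requirement exists because each head works on an even slice of the embedding. A small sketch of the usual bookkeeping, as an illustration of the constraint rather than this class's actual code:

import torch.nn as nn

class HeadSplit(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0, 'embed_dim must be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads   # per-head dimension

    def split(self, x):
        # [B, L, embed_dim] -> [B, num_heads, L, head_dim]
        B, L, _ = x.shape
        return x.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)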
@@ -74,13 +72,12 @@ class MultiHeadAttention(nn.Module):
 
     def forward(self, Q, K, V, key_padding_mask=None, attention_mask=None, head_mask=None):
         """
-        Args:
-            Q: [B, L, Hs]
-            K: [B, S, Hs]
-            V: [B, S, Hs]
-            key_padding_mask: [B, S] positions marked 1/True are masked out
-            attention_mask: [S] / [L, S] masks the given positions; positions marked 1/True are masked out
-            head_mask: [N] masks the given heads; positions marked 1/True are masked out
+        :param Q: [B, L, Hs]
+        :param K: [B, S, Hs]
+        :param V: [B, S, Hs]
+        :param key_padding_mask: [B, S] positions marked 1/True are masked out
+        :param attention_mask: [S] / [L, S] masks the given positions; positions marked 1/True are masked out
+        :param head_mask: [N] masks the given heads; positions marked 1/True are masked out
         """
         B, L, Hs = Q.shape
         S = V.size(1)

@@ -123,5 +120,3 @@ class MultiHeadAttention(nn.Module):
             return attention_out, attention_weight
         else:
             return attention_out,
-
-
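Both mask arguments follow the same convention (1/True means "mask this position"), so they can be broadcast together into one boolean mask before the softmax. A hedged sketch of that combination step, with shapes as in the docstring; the helper name is made up for illustration:

import torch

def combine_masks(key_padding_mask=None, attention_mask=None, L=None):
    # key_padding_mask: [B, S] -> [B, 1, 1, S] so it broadcasts over heads and query positions
    # attention_mask:   [S] or [L, S] -> [1, 1, L, S]
    mask = None
    if key_padding_mask is not None:
        mask = key_padding_mask.bool()[:, None, None, :]
    if attention_mask is not None:
        am = attention_mask.bool()
        if am.dim() == 1:
            am = am[None, :].expand(L, -1)    # [S] -> [L, S]
        am = am[None, None, :, :]
        mask = am if mask is None else (mask | am)
    return mask  # positions that are True should be filled with -inf in the scores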
@@ -21,13 +21,12 @@ class CNN(nn.Module):
     """
     def __init__(self, config):
         """
-        Args:
-            in_channels: usually the word embedding dimension, or the hidden size dimension
-            out_channels: int
-            kernel_sizes: list; to guarantee output length = input length, every size must be odd: 3, 5, 7...
-            activation: [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
-            pooling_strategy: [max, avg, cls]
-            dropout: float
+        in_channels : usually the word embedding dimension, or the hidden size dimension
+        out_channels : int
+        kernel_sizes : list; to guarantee output length = input length, every size must be odd: 3, 5, 7...
+        activation : [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
+        pooling_strategy : [max, avg, cls]
+        dropout: : float
         """
         super(CNN, self).__init__()
 
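The odd-kernel-size requirement comes from "same" padding: with stride 1, a Conv1d keeps the output length equal to the input length only when padding = kernel_size // 2, which works out exactly for odd kernels. A minimal sketch of that relationship, independent of this repo's config handling:

import torch
import torch.nn as nn

k = 3                                     # any odd kernel size: 3, 5, 7...
conv = nn.Conv1d(in_channels=128, out_channels=100,
                 kernel_size=k, padding=k // 2)   # "same" padding for odd k
x = torch.randn(4, 128, 50)               # [B, H, L]
assert conv(x).shape[-1] == 50             # output length == input length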
@@ -75,12 +74,9 @@ class CNN(nn.Module):
 
     def forward(self, x, mask=None):
         """
-        Args:
-            x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H] usually the output of the embedding layer
-            mask: [batch_size, max_len], 0 over the real sequence, 1 over the padding. It does not affect the convolution; it guarantees max-pool never pools a padded position
-        Return:
-            x: torch.Tensor
-            xp: torch.Tensor
+        :param x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H] usually the output of the embedding layer
+        :param mask: [batch_size, max_len], 0 over the real sequence, 1 over the padding. It does not affect the convolution; it guarantees max-pool never pools a padded position
+        :return:
         """
         # [B, L, H] -> [B, H, L] (treat the H dimension as the input channel dimension)
         x = torch.transpose(x, 1, 2)
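The mask's job appears at the pooling step: padded positions are pushed to -inf before the max, so they can never be selected. A hedged sketch of masked max-pooling after the transpose above, assuming mask is 1 over padding as the docstring says:

import torch

def masked_max_pool(feats, mask):
    # feats: [B, C, L] convolution output; mask: [B, L] with 1/True on padding
    feats = feats.masked_fill(mask.unsqueeze(1).bool(), float('-inf'))
    return feats.max(dim=-1).values   # [B, C]; padded positions are never selected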
@@ -27,8 +27,7 @@ class Capsule(nn.Module):
 
     def forward(self, x):
         """
-        x: [B, L, H]    # the output of a CNN / RNN
         x: [B, L, H]    # the output of a CNN / RNN
            L serves as input_num_capsules, H as input_dim_capsule
         """
         B, I, _ = x.size()    # I is input_num_capsules
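Reading [B, L, H] as B batches of L input capsules of dimension H, the first step of dynamic routing projects every input capsule onto every output capsule. A shape-only sketch under that interpretation; the sizes and names here are illustrative:

import torch

B, L, H = 4, 50, 128             # batch, input_num_capsules, input_dim_capsule
num_caps, dim_caps = 10, 16      # output capsules and their dimension

x = torch.randn(B, L, H)
W = torch.randn(num_caps, H, dim_caps)          # one projection per output capsule
u_hat = torch.einsum('blh,nhd->bnld', x, W)     # [B, num_caps, L, dim_caps]
# routing then iteratively weights u_hat over the L input capsules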
@@ -5,10 +5,9 @@ import torch.nn as nn
 class Embedding(nn.Module):
     def __init__(self, config):
         """
-        Args:
-            word embedding: index 0 is usually the padding
-            pos embedding: index 0 is usually the padding
-            dim_strategy: [cat, sum] whether multiple embeddings are concatenated or summed
+        word embedding: index 0 is usually the padding
+        pos embedding: index 0 is usually the padding
+        dim_strategy: [cat, sum] whether multiple embeddings are concatenated or summed
         """
         super(Embedding, self).__init__()
 
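The dim_strategy choice trades width for alignment: cat concatenates word and position embeddings along the feature dimension, sum adds them (which requires equal dimensions). A minimal sketch of the two strategies, with hypothetical sizes:

import torch
import torch.nn as nn

word_emb = nn.Embedding(5000, 100, padding_idx=0)   # index 0 reserved for padding
pos_emb = nn.Embedding(102, 100, padding_idx=0)

tokens = torch.randint(1, 5000, (4, 50))
positions = torch.randint(1, 102, (4, 50))

cat = torch.cat([word_emb(tokens), pos_emb(positions)], dim=-1)   # [4, 50, 200]
summed = word_emb(tokens) + pos_emb(positions)                    # [4, 50, 100]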
@@ -46,11 +46,12 @@ class RNN(nn.Module):
 
     def forward(self, x, x_len):
         """
-        :param x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H_in] usually the output of the embedding layer
-        :param x_len: torch.Tensor [L] sequence lengths, already sorted
-        :return:
-            output: torch.Tensor [B, L, H_out] the output used for sequence labeling
-            hn: torch.Tensor [B, N, H_out] / [B, H_out] the output used for classification; only the last layer's result when last_layer_hn is set
+        Args:
+            x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H_in] usually the output of the embedding layer
+            x_len: torch.Tensor [L] sequence lengths, already sorted
+        Returns:
+            output: torch.Tensor [B, L, H_out] the output used for sequence labeling
+            hn: torch.Tensor [B, N, H_out] / [B, H_out] the output used for classification; only the last layer's result when last_layer_hn is set
         """
         B, L, _ = x.size()
         H, N = self.hidden_size, self.num_layers

@@ -71,4 +72,3 @@ class RNN(nn.Module):
 
         return output, hn
 
-
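The sorted-lengths requirement comes from pack_padded_sequence, which (unless enforce_sorted=False is passed) expects descending lengths so it can batch efficiently. A hedged sketch of the usual pack/unpack pattern around an RNN; the hyperparameters here are illustrative:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

rnn = nn.GRU(input_size=100, hidden_size=64, num_layers=2, batch_first=True)
x = torch.randn(4, 50, 100)                   # [B, L, H_in]
x_len = torch.tensor([50, 42, 30, 7])         # sorted, descending

packed = pack_padded_sequence(x, x_len, batch_first=True)
packed_out, hn = rnn(packed)                  # hn: [num_layers, B, H]
output, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=50)
# output: [B, L, H_out]; hn.transpose(0, 1): [B, N, H_out] as in the docstring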
@@ -1,6 +1,5 @@
 from .dataset import *
 from .metrics import *
-from .predict import *
 from .preprocess import *
 from .serializer import *
 from .trainer import *
@@ -5,7 +5,6 @@ import sys
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
 from utils import load_pkl
 
-
 def collate_fn(cfg):
 
     def collate_fn_intra(batch):

@@ -71,26 +70,3 @@ class CustomDataset(Dataset):
 
     def __len__(self):
         return len(self.file)
-
-
-if __name__ == '__main__':
-    from torch.utils.data import DataLoader
-    train_data_path = 'data/out/train.pkl'
-    vocab_path = 'data/out/vocab.pkl'
-    unk_str = 'UNK'
-    vocab = load_pkl(vocab_path)
-    train_ds = CustomDataset(train_data_path)
-    train_dl = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn, drop_last=False)
-    for batch_idx, (x, y) in enumerate(train_dl):
-        word = x['word']
-        for idx in word:
-            idx2token = ''.join([vocab.idx2word.get(i, unk_str) for i in idx.numpy()])
-            print(idx2token)
-        print(y)
-        break
-        # x, y = x.to(device), y.to(device)
-        # optimizer.zero_grad()
-        # y_pred = models(y)
-        # loss = criterion(y_pred, y)
-        # loss.backward()
-        # optimizer.step()
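The nested definition above is a common pattern: collate_fn(cfg) is a factory that closes over the config and returns the actual collate function handed to DataLoader. A minimal sketch of that shape, using the token2idx/seq_len/word field names visible in this diff; the padding logic and the rel2idx label key are hypothetical:

import torch

def collate_fn(cfg):
    def collate_fn_intra(batch):
        # sort by length so a downstream pack_padded_sequence sees descending lengths
        batch.sort(key=lambda sample: sample['seq_len'], reverse=True)
        max_len = batch[0]['seq_len']
        words = torch.zeros(len(batch), max_len, dtype=torch.long)   # 0 = padding index
        for i, sample in enumerate(batch):
            words[i, :sample['seq_len']] = torch.tensor(sample['token2idx'])
        y = torch.tensor([sample['rel2idx'] for sample in batch])
        return {'word': words}, y
    return collate_fn_intra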
@@ -9,6 +9,8 @@ import matplotlib.pyplot as plt
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 # self
+import sys
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
 import models
 from preprocess import preprocess
 from dataset import CustomDataset, collate_fn

@@ -17,10 +19,10 @@ from utils import manual_seed, load_pkl
 
 logger = logging.getLogger(__name__)
 
-@hydra.main(config_path='conf/config.yaml')
+@hydra.main(config_path='../conf/config.yaml')
 def main(cfg):
     cwd = utils.get_original_cwd()
     cwd = cwd[0:-5]
     cfg.cwd = cwd
     cfg.pos_size = 2 * cfg.pos_limit + 2
     logger.info(f'\n{cfg.pretty()}')

@@ -142,6 +144,4 @@ def main(cfg):
 
 if __name__ == '__main__':
     main()
-    # python predict.py --help   # view the available options
-    # python predict.py -c
-    # python predict.py chinese_split=0,1 replace_entity_with_type=0,1 -m
+
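The decorator change from 'conf/config.yaml' to '../conf/config.yaml' and the cwd[0:-5] slice point at the same layout change: the entry scripts now live in tools/, one level below conf/. Hydra resolves config_path relative to the script, and get_original_cwd() returns the launch directory, so the script strips the trailing 'tools' (5 characters) to get back to the project root. A hedged sketch of a more explicit equivalent, assuming the command is launched from tools/; this is an illustration, not the repo's code:

import os
from hydra import utils

cwd = utils.get_original_cwd()                       # e.g. /path/to/project/tools when launched there
project_root = os.path.dirname(cwd.rstrip(os.sep))   # drop the trailing 'tools' component
# equivalent in intent to cwd[0:-5], but independent of the directory name's length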
@@ -7,10 +7,10 @@ from hydra import utils
 from serializer import Serializer
 from preprocess import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
 import matplotlib.pyplot as plt
 
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
+import models
 from utils import load_pkl, load_csv
-import models
-
 
 logger = logging.getLogger(__name__)
 

@@ -73,9 +73,11 @@ def _get_predict_instance(cfg):
 
 
-@hydra.main(config_path='conf/config.yaml')
+
+@hydra.main(config_path='../conf/config.yaml')
 def main(cfg):
     cwd = utils.get_original_cwd()
+    cwd = cwd[0:-5]
     cfg.cwd = cwd
     cfg.pos_size = 2 * cfg.pos_limit + 2
     print(cfg.pretty())

@@ -98,7 +100,7 @@ def main(cfg):
     }
 
     # prediction is best done on the cpu
-    # cfg.use_gpu = False
+    cfg.use_gpu = False
     if cfg.use_gpu and torch.cuda.is_available():
         device = torch.device('cuda', cfg.gpu_id)
     else:

@@ -114,6 +116,7 @@ def main(cfg):
 
     x = dict()
     x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])
+
     if cfg.model_name != 'lm':
         x['head_pos'], x['tail_pos'] = torch.tensor([data[0]['head_pos']]), torch.tensor([data[0]['tail_pos']])
         if cfg.model_name == 'cnn':

@@ -124,6 +127,7 @@ def main(cfg):
             adj = torch.empty(1,data[0]['seq_len'],data[0]['seq_len']).random_(2)
             x['adj'] = adj
 
+
     for key in x.keys():
         x[key] = x[key].to(device)
 

@@ -149,4 +153,4 @@ def main(cfg):
 
 
 if __name__ == '__main__':
-    main()
+    main()
@@ -9,10 +9,18 @@ import sys
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
 from utils import save_pkl, load_csv
 
 logger = logging.getLogger(__name__)
 
-
-
+__all__ = [
+    "_handle_pos_limit",
+    "_add_pos_seq",
+    "_convert_tokens_into_index",
+    "_serialize_sentence",
+    "_lm_serialize",
+    "_add_relation_data",
+    "_handle_relation_data",
+    "preprocess"
+]
 def _handle_pos_limit(pos: List[int], limit: int) -> List[int]:
     """
     handle the sentence positions, applying the position limit
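This ties to the cfg.pos_size = 2 * cfg.pos_limit + 2 line seen in the entry scripts: relative positions are clipped to [-limit, limit] and then shifted to be non-negative so they can index an embedding table, with one extra slot left for padding. A hedged sketch of that clipping, as an illustration of the formula rather than the repo's exact code:

from typing import List

def handle_pos_limit(pos: List[int], limit: int) -> List[int]:
    # clip each relative position into [-limit, limit] ...
    clipped = [max(-limit, min(p, limit)) for p in pos]
    # ... then shift by limit + 1 into [1, 2*limit + 1], keeping 0 free for padding
    return [p + limit + 1 for p in clipped]

# embedding table size = 2 * limit + 2: indices 0 (padding) through 2*limit + 1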
@@ -147,7 +155,6 @@ def _handle_relation_data(relation_data: List[Dict]) -> Dict:
 
     return rels
 
-
 def preprocess(cfg):
     """
     the data preprocessing stage
@@ -268,6 +268,3 @@ class Serializer():
         if cat.startswith("P"):
             return True
         return False
-
-
-
tools/trainer.py (201 lines)

@@ -1,118 +1,113 @@
-import torch
 import logging
-import matplotlib.pyplot as plt
-from metrics import PRMetric
+from collections import OrderedDict
+from typing import Sequence, Optional
 
 logger = logging.getLogger(__name__)
 
-def train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg):
-    """
-    train the model
-    Args:
-        epoch (int): the training epoch
-        model (class): the model to train
-        dataloader (dict): the dataset
-        optimizer (Callable): the optimizer
-        criterion (Callable): the loss function
-        device (torch.device): the training device
-        writer (class): the output writer
-        cfg: the configuration
-    Return:
-        losses[-1] : the last loss value
-    """
-    model.train()
-
-    metric = PRMetric()
-    losses = []
-
-    for batch_idx, (x, y) in enumerate(dataloader, 1):
-        for key, value in x.items():
-            x[key] = value.to(device)
-        y = y.to(device)
-
-        optimizer.zero_grad()
-        y_pred = model(x)
-
-        if cfg.model_name == 'capsule':
-            loss = model.loss(y_pred, y)
-        else:
-            loss = criterion(y_pred, y)
-
-        loss.backward()
-        optimizer.step()
-
-        metric.update(y_true=y, y_pred=y_pred)
-        losses.append(loss.item())
-
-        data_total = len(dataloader.dataset)
-        data_cal = data_total if batch_idx == len(dataloader) else batch_idx * len(y)
-        if (cfg.train_log and batch_idx % cfg.log_interval == 0) or batch_idx == len(dataloader):
-            # p, r and f1 are all macro-averaged; their micro-averaged versions are identical, so that value is reported as acc
-            acc, p, r, f1 = metric.compute()
-            logger.info(f'Train Epoch {epoch}: [{data_cal}/{data_total} ({100. * data_cal / data_total:.0f}%)]\t'
-                        f'Loss: {loss.item():.6f}')
-            logger.info(f'Train Epoch {epoch}: Acc: {100. * acc:.2f}%\t'
-                        f'macro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')
-
-    if cfg.show_plot and not cfg.only_comparison_plot:
-        if cfg.plot_utils == 'matplot':
-            plt.plot(losses)
-            plt.title(f'epoch {epoch} train loss')
-            plt.show()
-
-        if cfg.plot_utils == 'tensorboard':
-            for i in range(len(losses)):
-                writer.add_scalar(f'epoch_{epoch}_training_loss', losses[i], i)
-
-    return losses[-1]
-
-def validate(epoch, model, dataloader, criterion, device, cfg):
-    """
-    validate the model
-    Args:
-        epoch (int): the training epoch
-        model (class): the model to train
-        dataloader (dict): the dataset
-        optimizer (Callable): the optimizer
-        criterion (Callable): the loss function
-        device (torch.device): the training device
-        cfg: the configuration
-    Return:
-        f1 : the f1 score
-        loss : the loss value
-    """
-    model.eval()
-
-    metric = PRMetric()
-    losses = []
-
-    for batch_idx, (x, y) in enumerate(dataloader, 1):
-        for key, value in x.items():
-            x[key] = value.to(device)
-        y = y.to(device)
-
-        with torch.no_grad():
-            y_pred = model(x)
-
-            if cfg.model_name == 'capsule':
-                loss = model.loss(y_pred, y)
-            else:
-                loss = criterion(y_pred, y)
-
-            metric.update(y_true=y, y_pred=y_pred)
-            losses.append(loss.item())
-
-    loss = sum(losses) / len(losses)
-    acc, p, r, f1 = metric.compute()
-    data_total = len(dataloader.dataset)
-
-    if epoch >= 0:
-        logger.info(f'Valid Epoch {epoch}: [{data_total}/{data_total}](100%)\t Loss: {loss:.6f}')
-        logger.info(f'Valid Epoch {epoch}: Acc: {100. * acc:.2f}%\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')
-    else:
-        logger.info(f'Test Data: [{data_total}/{data_total}](100%)\t Loss: {loss:.6f}')
-        logger.info(f'Test Data: Acc: {100. * acc:.2f}%\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')
-
-    return f1, loss
+SPECIAL_TOKENS_KEYS = [
+    "pad_token",
+    "unk_token",
+    "mask_token",
+    "cls_token",
+    "sep_token",
+    "bos_token",
+    "eos_token",
+    "head_token",
+    "tail_token",
+]
+
+SPECIAL_TOKENS_VALUES = [
+    "[PAD]",
+    "[UNK]",
+    "[MASK]",
+    "[CLS]",
+    "[SEP]",
+    "[BOS]",
+    "[EOS]",
+    "HEAD",
+    "TAIL",
+]
+
+SPECIAL_TOKENS = OrderedDict(zip(SPECIAL_TOKENS_KEYS, SPECIAL_TOKENS_VALUES))
+
+
+class Vocab(object):
+    """
+    build the vocabulary: add words, remove low-frequency words
+    """
+    def __init__(self, name: str = 'basic', init_tokens: Sequence = SPECIAL_TOKENS):
+        self.name = name
+        self.init_tokens = init_tokens
+        self.trimed = False
+        self.word2idx = {}
+        self.word2count = {}
+        self.idx2word = {}
+        self.count = 0
+        self._add_init_tokens()
+
+    def _add_init_tokens(self):
+        """
+        add the initial tokens
+        """
+        for token in self.init_tokens.values():
+            self._add_word(token)
+
+    def _add_word(self, word: str):
+        """
+        add a single word
+        Arg :
+            word (String) : the word to add
+        """
+        if word not in self.word2idx:
+            self.word2idx[word] = self.count
+            self.word2count[word] = 1
+            self.idx2word[self.count] = word
+            self.count += 1
+        else:
+            self.word2count[word] += 1
+
+    def add_words(self, words: Sequence):
+        """
+        add words from a sequence
+        Arg :
+            words (List) : the words to add
+        """
+        for word in words:
+            self._add_word(word)
+
+    def trim(self, min_freq=2, verbose: Optional[bool] = True):
+        """
+        remove a word from the vocabulary when its frequency is below min_freq
+        Args:
+            min_freq (int): the minimum word frequency
+            verbose (bool) : whether to print a log line
+        """
+        assert min_freq == int(min_freq), f'min_freq must be integer, can\'t be {min_freq}'
+        min_freq = int(min_freq)
+        if min_freq < 2:
+            return
+        if self.trimed:
+            return
+        self.trimed = True
+
+        keep_words = []
+        new_words = []
+
+        for k, v in self.word2count.items():
+            if v >= min_freq:
+                keep_words.append(k)
+                new_words.extend([k] * v)
+        if verbose:
+            before_len = len(keep_words)
+            after_len = len(self.word2idx) - len(self.init_tokens)
+            logger.info('vocab after be trimmed, keep words [{} / {}] = {:.2f}%'.format(
+                before_len, after_len, before_len / after_len * 100))
+
+        # Reinitialize dictionaries
+        self.word2idx = {}
+        self.word2count = {}
+        self.idx2word = {}
+        self.count = 0
+        self._add_init_tokens()
+        self.add_words(new_words)
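To make the new class concrete, a short usage sketch of Vocab as added above; the corpus is invented for illustration:

vocab = Vocab(name='demo')
vocab.add_words(['the', 'cat', 'sat', 'the', 'mat', 'the'])
print(vocab.word2count['the'])   # 3
print(len(vocab.word2idx))       # 13: the 9 special tokens plus 4 corpus words

vocab.trim(min_freq=2)           # rebuilds the dicts, dropping 'cat', 'sat', 'mat'
print('cat' in vocab.word2idx)   # False; only 'the' survives among the corpus words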