tlk-dsg 2021-06-10 14:30:49 +08:00
parent 33a777cddf
commit 0e61eed58b
13 changed files with 160 additions and 192 deletions

View File

@ -2,6 +2,7 @@ import torch
from . import BasicModule
from module import Embedding, CNN
from module import Capsule as CapsuleLayer
from utils import seq_len_to_mask, to_one_hot
@ -47,4 +48,4 @@ class Capsule(BasicModule):
return loss.sum()
else:
# by default, return the mean
return loss.mean()

View File

@ -13,17 +13,16 @@ class DotAttention(nn.Module):
def forward(self, Q, K, V, mask_out=None, head_mask=None):
"""
Typically the input is X, and it is assumed that K = V = Xs
Typically the input is X, and it is assumed that K = V = X
att_weight = softmax( score_func(q, k) )
att = sum( att_weight * v )
Args:
Q: [..., L, H]
K: [..., S, H]
V: [..., S, H]
mask_out: [..., 1, S]
Return:
attention_out
attention_weight
:param Q: [..., L, H]
:param K: [..., S, H]
:param V: [..., S, H]
:param mask_out: [..., 1, S]
:return:
"""
H = Q.size(-1)
@ -53,10 +52,9 @@ class DotAttention(nn.Module):
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads, dropout=0.0, output_attentions=True):
"""
Args:
embed_dim: the input dimension; must be divisible by num_heads
num_heads: the number of attention heads
dropout: float
:param embed_dim: the input dimension; must be divisible by num_heads
:param num_heads: the number of attention heads
:param dropout: float
"""
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
@ -74,13 +72,12 @@ class MultiHeadAttention(nn.Module):
def forward(self, Q, K, V, key_padding_mask=None, attention_mask=None, head_mask=None):
"""
Args:
Q: [B, L, Hs]
K: [B, S, Hs]
V: [B, S, Hs]
key_padding_mask: [B, S] positions marked 1/True are masked
attention_mask: [S] / [L, S] masks the specified positions; positions marked 1/True are masked
head_mask: [N] masks the specified heads; heads marked 1/True are masked
:param Q: [B, L, Hs]
:param K: [B, S, Hs]
:param V: [B, S, Hs]
:param key_padding_mask: [B, S] positions marked 1/True are masked
:param attention_mask: [S] / [L, S] masks the specified positions; positions marked 1/True are masked
:param head_mask: [N] masks the specified heads; heads marked 1/True are masked
"""
B, L, Hs = Q.shape
S = V.size(1)
@ -123,5 +120,3 @@ class MultiHeadAttention(nn.Module):
return attention_out, attention_weight
else:
return attention_out,
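
The docstrings above pin down the shapes ([..., L, H] queries against [..., S, H] keys/values) and the two-step formula att_weight = softmax( score_func(q, k) ), att = sum( att_weight * v ). A minimal self-contained sketch of that computation, independent of this repo's DotAttention class, assuming a scaled dot product as score_func and the [..., 1, S] mask convention:

import torch
import torch.nn.functional as F

def dot_attention(Q, K, V, mask_out=None):
    # Q: [..., L, H], K: [..., S, H], V: [..., S, H]
    H = Q.size(-1)
    # score_func(q, k): scaled dot product -> [..., L, S]
    scores = Q @ K.transpose(-1, -2) / H ** 0.5
    if mask_out is not None:
        # mask_out: [..., 1, S]; positions marked 1/True are masked out
        scores = scores.masked_fill(mask_out.bool(), float('-inf'))
    att_weight = F.softmax(scores, dim=-1)  # att_weight = softmax(score_func(q, k))
    att = att_weight @ V                    # att = sum(att_weight * v) -> [..., L, H]
    return att, att_weight

# e.g. B=2, L=3, S=4, H=8
out, w = dot_attention(torch.randn(2, 3, 8), torch.randn(2, 4, 8), torch.randn(2, 4, 8))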

View File

@ -21,13 +21,12 @@ class CNN(nn.Module):
"""
def __init__(self, config):
"""
Args:
in_channels: usually the word-embedding dimension or the hidden size
out_channels: int
kernel_sizes: list; must be odd (3, 5, 7, ...) so that output length = input length
activation: [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
pooling_strategy: [max, avg, cls]
dropout: float
in_channels : usually the word-embedding dimension or the hidden size
out_channels : int
kernel_sizes : list; must be odd (3, 5, 7, ...) so that output length = input length
activation : [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
pooling_strategy : [max, avg, cls]
dropout : float
"""
super(CNN, self).__init__()
@ -75,12 +74,9 @@ class CNN(nn.Module):
def forward(self, x, mask=None):
"""
Args:
x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H], usually the output of the embedding layer
mask: [batch_size, max_len]; 0 at token positions and 1 at padding positions; padding does not affect the convolution, and max-pooling will never select a padded position
Return:
x: torch.Tensor
xp: torch.Tensor
:param x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H], usually the output of the embedding layer
:param mask: [batch_size, max_len]; 0 at token positions and 1 at padding positions; padding does not affect the convolution, and max-pooling will never select a padded position
:return:
"""
# [B, L, H] -> [B, H, L] (treat the H dimension as the input channel dimension)
x = torch.transpose(x, 1, 2)
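
A short sketch of why odd kernel_sizes preserve the sequence length (symmetric padding of (k - 1) // 2) and of a masked max-pool that can never select a padded position; the sizes are made up and a bare nn.Conv1d stands in for this module's configured convolutions:

import torch
import torch.nn as nn

B, L, H, out_channels, k = 2, 7, 16, 32, 5   # k must be odd
conv = nn.Conv1d(in_channels=H, out_channels=out_channels,
                 kernel_size=k, padding=(k - 1) // 2)  # symmetric padding keeps the length
x = torch.randn(B, L, H).transpose(1, 2)     # [B, L, H] -> [B, H, L]
y = conv(x)
assert y.shape == (B, out_channels, L)       # output length == input length

mask = torch.tensor([[0, 0, 0, 0, 1, 1, 1],  # 0 = token, 1 = padding
                     [0, 0, 1, 1, 1, 1, 1]]).bool()
# fill padded positions with -inf, so max-pooling never picks them
pooled = y.masked_fill(mask.unsqueeze(1), float('-inf')).max(dim=-1).values  # [B, out_channels]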

View File

@ -27,8 +27,7 @@ class Capsule(nn.Module):
def forward(self, x):
"""
Args:
x: [B, L, H]    # the output of the CNN / RNN
L serves as input_num_capsules and H as input_dim_capsule
"""
B, I, _ = x.size()    # I is input_num_capsules
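
The layer treats L as input_num_capsules and H as input_dim_capsule. A sketch of the standard capsule squash nonlinearity such a layer typically applies; this is an assumption about the implementation, not taken from this diff:

import torch

def squash(s, dim=-1, eps=1e-8):
    # standard capsule squash: keeps the direction, maps the norm into [0, 1)
    n2 = (s ** 2).sum(dim=dim, keepdim=True)
    return (n2 / (1.0 + n2)) * s / torch.sqrt(n2 + eps)

x = torch.randn(2, 10, 64)   # [B, L, H]: L input capsules of dimension H
v = squash(x)                # same shape; each capsule vector now has norm < 1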

View File

@ -5,10 +5,9 @@ import torch.nn as nn
class Embedding(nn.Module):
def __init__(self, config):
"""
Args:
word embedding: usually 0-padded
pos embedding: usually 0-padded
dim_strategy: [cat, sum] whether multiple embeddings are concatenated or summed
word embedding: usually 0-padded
pos embedding: usually 0-padded
dim_strategy: [cat, sum] whether multiple embeddings are concatenated or summed
"""
super(Embedding, self).__init__()
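
dim_strategy controls how the word and position embeddings are combined; a minimal illustration with made-up tensors (cat grows the last dimension, sum requires matching dimensions):

import torch

word_emb = torch.randn(2, 5, 50)   # hypothetical word embedding, [B, L, 50]
pos_emb = torch.randn(2, 5, 50)    # hypothetical position embedding, [B, L, 50]

cat_out = torch.cat([word_emb, pos_emb], dim=-1)  # dim_strategy == 'cat': [B, L, 100]
sum_out = word_emb + pos_emb                      # dim_strategy == 'sum': [B, L, 50]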

View File

@ -46,11 +46,12 @@ class RNN(nn.Module):
def forward(self, x, x_len):
"""
:param x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H_in] 一般是经过embedding后的值
:param x_len: torch.Tensor [L] 已经排好序的句长值
:return:
output: torch.Tensor [B, L, H_out] 序列标注的使用结果
hn: torch.Tensor [B, N, H_out] / [B, H_out] 分类的结果 last_layer_hn 时只有最后一层结果
Args:
torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H_in] 一般是经过embedding后的值
x_len: torch.Tensor [L] 已经排好序的句长值
Returns:
output: torch.Tensor [B, L, H_out] 序列标注的使用结果
hn: torch.Tensor [B, N, H_out] / [B, H_out] 分类的结果 last_layer_hn 时只有最后一层结果
"""
B, L, _ = x.size()
H, N = self.hidden_size, self.num_layers
@ -71,4 +72,3 @@ class RNN(nn.Module):
return output, hn
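
The requirement that x_len arrive pre-sorted matches how pack_padded_sequence is commonly used with enforce_sorted=True. A generic sketch with a stand-in GRU, not this module's exact configuration:

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

rnn = nn.GRU(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
x = torch.randn(3, 5, 8)            # [B, L, H_in]
x_len = torch.tensor([5, 4, 2])     # already sorted, descending

packed = pack_padded_sequence(x, x_len, batch_first=True)  # skips padding in the recurrence
packed_out, hn = rnn(packed)
output, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=5)
# output: [B, L, H_out] for sequence labeling; hn: [N, B, H_out], last layer = hn[-1]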

View File

@ -1,6 +1,5 @@
from .dataset import *
from .metrics import *
from .predict import *
from .preprocess import *
from .serializer import *
from .trainer import *

View File

@ -5,7 +5,6 @@ import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from utils import load_pkl
def collate_fn(cfg):
def collate_fn_intra(batch):
@ -71,26 +70,3 @@ class CustomDataset(Dataset):
def __len__(self):
return len(self.file)
if __name__ == '__main__':
from torch.utils.data import DataLoader
train_data_path = 'data/out/train.pkl'
vocab_path = 'data/out/vocab.pkl'
unk_str = 'UNK'
vocab = load_pkl(vocab_path)
train_ds = CustomDataset(train_data_path)
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn, drop_last=False)
for batch_idx, (x, y) in enumerate(train_dl):
word = x['word']
for idx in word:
idx2token = ''.join([vocab.idx2word.get(i, unk_str) for i in idx.numpy()])
print(idx2token)
print(y)
break
# x, y = x.to(device), y.to(device)
# optimizer.zero_grad()
# y_pred = models(y)
# loss = criterion(y_pred, y)
# loss.backward()
# optimizer.step()
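
collate_fn(cfg) above is a factory: it closes over cfg and returns the actual collate function (collate_fn_intra), so the DataLoader must receive the called factory rather than the factory itself. A toy sketch of the pattern with a hypothetical pad index and batch layout:

import torch
from torch.utils.data import DataLoader

def make_collate_fn(pad_idx=0):
    def collate(batch):
        xs, ys = zip(*batch)
        max_len = max(len(x) for x in xs)
        padded = [list(x) + [pad_idx] * (max_len - len(x)) for x in xs]  # right-pad to max_len
        return torch.tensor(padded), torch.tensor(ys)
    return collate

ds = [([1, 2, 3], 0), ([4, 5], 1)]
dl = DataLoader(ds, batch_size=2, collate_fn=make_collate_fn())  # note: call the factory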

View File

@ -9,6 +9,8 @@ import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# self
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
import models
from preprocess import preprocess
from dataset import CustomDataset, collate_fn
@ -17,10 +19,10 @@ from utils import manual_seed, load_pkl
logger = logging.getLogger(__name__)
@hydra.main(config_path='conf/config.yaml')
@hydra.main(config_path='../conf/config.yaml')
def main(cfg):
cwd = utils.get_original_cwd()
cwd = cwd[0:-5]
cfg.cwd = cwd
cfg.pos_size = 2 * cfg.pos_limit + 2
logger.info(f'\n{cfg.pretty()}')
@ -142,6 +144,4 @@ def main(cfg):
if __name__ == '__main__':
main()
# python predict.py --help # show argument help
# python predict.py -c
# python predict.py chinese_split=0,1 replace_entity_with_type=0,1 -m
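
cwd = cwd[0:-5] appears to strip a trailing 'tools' from the original working directory, consistent with config_path moving to '../conf/config.yaml'; this reading is an inference, not stated in the diff. On a POSIX path the slice behaves like this:

import os

cwd = '/home/user/project/tools'   # hypothetical original cwd when run from tools/
assert cwd[0:-5] == '/home/user/project/'          # slicing off the 5 chars of 'tools'
assert os.path.dirname(cwd) + os.sep == cwd[0:-5]  # an equivalent, less brittle form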

View File

@ -7,10 +7,10 @@ from hydra import utils
from serializer import Serializer
from preprocess import _serialize_sentence, _convert_tokens_into_index, _add_pos_seq, _handle_relation_data
import matplotlib.pyplot as plt
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
import models
from utils import load_pkl, load_csv
import models
logger = logging.getLogger(__name__)
@ -73,9 +73,11 @@ def _get_predict_instance(cfg):
@hydra.main(config_path='conf/config.yaml')
@hydra.main(config_path='../conf/config.yaml')
def main(cfg):
cwd = utils.get_original_cwd()
cwd = cwd[0:-5]
cfg.cwd = cwd
cfg.pos_size = 2 * cfg.pos_limit + 2
print(cfg.pretty())
@ -98,7 +100,7 @@ def main(cfg):
}
# prediction is best run on the CPU
# cfg.use_gpu = False
cfg.use_gpu = False
if cfg.use_gpu and torch.cuda.is_available():
device = torch.device('cuda', cfg.gpu_id)
else:
@ -114,6 +116,7 @@ def main(cfg):
x = dict()
x['word'], x['lens'] = torch.tensor([data[0]['token2idx']]), torch.tensor([data[0]['seq_len']])
if cfg.model_name != 'lm':
x['head_pos'], x['tail_pos'] = torch.tensor([data[0]['head_pos']]), torch.tensor([data[0]['tail_pos']])
if cfg.model_name == 'cnn':
@ -124,6 +127,7 @@ def main(cfg):
adj = torch.empty(1,data[0]['seq_len'],data[0]['seq_len']).random_(2)
x['adj'] = adj
for key in x.keys():
x[key] = x[key].to(device)
@ -149,4 +153,4 @@ def main(cfg):
if __name__ == '__main__':
main()

View File

@ -9,10 +9,18 @@ import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
from utils import save_pkl, load_csv
logger = logging.getLogger(__name__)
__all__ = [
"_handle_pos_limit",
"_add_pos_seq",
"_convert_tokens_into_index",
"_serialize_sentence",
"_lm_serialize",
"_add_relation_data",
"_handle_relation_data",
"preprocess"
]
def _handle_pos_limit(pos: List[int], limit: int) -> List[int]:
"""
Handle the sentence length and apply the configured length limit
@ -147,7 +155,6 @@ def _handle_relation_data(relation_data: List[Dict]) -> Dict:
return rels
def preprocess(cfg):
"""
The data preprocessing stage

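A sketch of what a helper like _handle_pos_limit typically does: clip relative positions to [-limit, limit], then shift them to non-negative indices so 0 stays reserved for padding. The body below is an assumption, but it is consistent with the docstring and with cfg.pos_size = 2 * cfg.pos_limit + 2 set in main and predict above:

from typing import List

def handle_pos_limit(pos: List[int], limit: int) -> List[int]:
    # clip relative positions to [-limit, limit] ...
    clipped = [max(-limit, min(limit, p)) for p in pos]
    # ... then shift into [1, 2*limit + 1], leaving index 0 free for padding;
    # this matches pos_size = 2 * pos_limit + 2 used when sizing the embedding
    return [p + limit + 1 for p in clipped]

assert handle_pos_limit([-9, -1, 0, 3, 9], limit=4) == [1, 4, 5, 8, 9]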
View File

@ -268,6 +268,3 @@ class Serializer():
if cat.startswith("P"):
return True
return False

View File

@ -1,118 +1,113 @@
import torch
import logging
import matplotlib.pyplot as plt
from metrics import PRMetric

logger = logging.getLogger(__name__)

def train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg):
    """
    Train the model.
    Args:
        epoch (int): the current epoch number
        model (class): the model to train
        dataloader (dict): the dataset loader
        optimizer (Callable): the optimizer
        criterion (Callable): the loss function
        device (torch.device): the training device
        writer (class): the output writer
        cfg: the configuration
    Return:
        losses[-1]: the final loss value
    """
    model.train()

    metric = PRMetric()
    losses = []

    for batch_idx, (x, y) in enumerate(dataloader, 1):
        for key, value in x.items():
            x[key] = value.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        y_pred = model(x)

        if cfg.model_name == 'capsule':
            loss = model.loss(y_pred, y)
        else:
            loss = criterion(y_pred, y)

        loss.backward()
        optimizer.step()

        metric.update(y_true=y, y_pred=y_pred)
        losses.append(loss.item())

        data_total = len(dataloader.dataset)
        data_cal = data_total if batch_idx == len(dataloader) else batch_idx * len(y)
        if (cfg.train_log and batch_idx % cfg.log_interval == 0) or batch_idx == len(dataloader):
            # p, r and f1 are macro-averaged; their micro versions are all equal, which is reported as acc
            acc, p, r, f1 = metric.compute()
            logger.info(f'Train Epoch {epoch}: [{data_cal}/{data_total} ({100. * data_cal / data_total:.0f}%)]\t'
                        f'Loss: {loss.item():.6f}')
            logger.info(f'Train Epoch {epoch}: Acc: {100. * acc:.2f}%\t'
                        f'macro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')

    if cfg.show_plot and not cfg.only_comparison_plot:
        if cfg.plot_utils == 'matplot':
            plt.plot(losses)
            plt.title(f'epoch {epoch} train loss')
            plt.show()

        if cfg.plot_utils == 'tensorboard':
            for i in range(len(losses)):
                writer.add_scalar(f'epoch_{epoch}_training_loss', losses[i], i)

    return losses[-1]

def validate(epoch, model, dataloader, criterion, device, cfg):
    """
    Validate the model.
    Args:
        epoch (int): the current epoch number
        model (class): the model to validate
        dataloader (dict): the dataset loader
        criterion (Callable): the loss function
        device (torch.device): the device to run on
        cfg: the configuration
    Return:
        f1: the f1 score
        loss: the loss value
    """
    model.eval()

    metric = PRMetric()
    losses = []

    for batch_idx, (x, y) in enumerate(dataloader, 1):
        for key, value in x.items():
            x[key] = value.to(device)
        y = y.to(device)

        with torch.no_grad():
            y_pred = model(x)

            if cfg.model_name == 'capsule':
                loss = model.loss(y_pred, y)
            else:
                loss = criterion(y_pred, y)

            metric.update(y_true=y, y_pred=y_pred)
            losses.append(loss.item())

    loss = sum(losses) / len(losses)
    acc, p, r, f1 = metric.compute()
    data_total = len(dataloader.dataset)

    if epoch >= 0:
        logger.info(f'Valid Epoch {epoch}: [{data_total}/{data_total}](100%)\t Loss: {loss:.6f}')
        logger.info(f'Valid Epoch {epoch}: Acc: {100. * acc:.2f}%\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')
    else:
        logger.info(f'Test Data: [{data_total}/{data_total}](100%)\t Loss: {loss:.6f}')
        logger.info(f'Test Data: Acc: {100. * acc:.2f}%\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')

    return f1, loss

import logging
from collections import OrderedDict
from typing import Sequence, Optional

logger = logging.getLogger(__name__)

SPECIAL_TOKENS_KEYS = [
    "pad_token",
    "unk_token",
    "mask_token",
    "cls_token",
    "sep_token",
    "bos_token",
    "eos_token",
    "head_token",
    "tail_token",
]

SPECIAL_TOKENS_VALUES = [
    "[PAD]",
    "[UNK]",
    "[MASK]",
    "[CLS]",
    "[SEP]",
    "[BOS]",
    "[EOS]",
    "HEAD",
    "TAIL",
]

SPECIAL_TOKENS = OrderedDict(zip(SPECIAL_TOKENS_KEYS, SPECIAL_TOKENS_VALUES))

class Vocab(object):
    """
    Build the vocabulary: add words and remove low-frequency ones.
    """
    def __init__(self, name: str = 'basic', init_tokens: Sequence = SPECIAL_TOKENS):
        self.name = name
        self.init_tokens = init_tokens
        self.trimed = False
        self.word2idx = {}
        self.word2count = {}
        self.idx2word = {}
        self.count = 0
        self._add_init_tokens()

    def _add_init_tokens(self):
        """
        Add the initial special tokens.
        """
        for token in self.init_tokens.values():
            self._add_word(token)

    def _add_word(self, word: str):
        """
        Add a single word.
        Arg:
            word (String): the word to add
        """
        if word not in self.word2idx:
            self.word2idx[word] = self.count
            self.word2count[word] = 1
            self.idx2word[self.count] = word
            self.count += 1
        else:
            self.word2count[word] += 1

    def add_words(self, words: Sequence):
        """
        Add words from a sequence.
        Arg:
            words (List): the words to add
        """
        for word in words:
            self._add_word(word)

    def trim(self, min_freq=2, verbose: Optional[bool] = True):
        """
        Remove words whose frequency is below min_freq from the vocabulary.
        Args:
            min_freq (int): the minimum word frequency
            verbose (bool): whether to log the result
        """
        assert min_freq == int(min_freq), f'min_freq must be integer, can\'t be {min_freq}'
        min_freq = int(min_freq)
        if min_freq < 2:
            return
        if self.trimed:
            return
        self.trimed = True

        keep_words = []
        new_words = []

        for k, v in self.word2count.items():
            if v >= min_freq:
                keep_words.append(k)
                new_words.extend([k] * v)
        if verbose:
            before_len = len(keep_words)
            after_len = len(self.word2idx) - len(self.init_tokens)
            logger.info('vocab trimmed, keep words [{} / {}] = {:.2f}%'.format(
                before_len, after_len, before_len / after_len * 100))

        # Reinitialize dictionaries
        self.word2idx = {}
        self.word2count = {}
        self.idx2word = {}
        self.count = 0
        self._add_init_tokens()
        self.add_words(new_words)
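
A minimal usage sketch of the class above, with made-up words:

vocab = Vocab('example')
vocab.add_words(['head', 'attention', 'head', 'mask'])
assert vocab.word2count['head'] == 2
vocab.trim(min_freq=2)               # drops 'attention' and 'mask', keeps 'head'
assert 'attention' not in vocab.word2idx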