update transformer

leo 2019-12-03 22:43:03 +08:00
parent de97e387d9
commit b100404be0
3 changed files with 38 additions and 6 deletions


@@ -1,9 +1,12 @@
-hidden_size: 128
-intermediate_size: 256
 model_name: transformer
+hidden_size: ???  # taken from the embedding output width, no need to set it manually
+num_heads: 4  # hidden_size must be divisible by num_heads
 num_hidden_layers: 3
-num_heads: 4
+intermediate_size: 256
 dropout: 0.1
 layer_norm_eps: 1e-12
-hidden_act: gelu_new
+hidden_act: gelu_new  # [relu, gelu, swish, gelu_new]
 output_attentions: True
 output_hidden_states: True
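
For reference, a minimal standalone sketch (not part of the commit) of why hidden_size is left unspecified here: under the 'cat' strategy used in models/Transformer.py below, the transformer input width is the concatenation of the word embedding and two position embeddings, and multi-head attention needs that width to be divisible by num_heads. All sizes and names in this sketch are made up for illustration.

import torch
import torch.nn as nn

word_dim, pos_dim, num_heads = 100, 14, 4
vocab_size, max_len = 1000, 40              # made-up sizes for illustration

word_emb = nn.Embedding(vocab_size, word_dim)
head_pos_emb = nn.Embedding(2 * max_len, pos_dim)
tail_pos_emb = nn.Embedding(2 * max_len, pos_dim)

word = torch.randint(0, vocab_size, (2, max_len))
head_pos = torch.randint(0, 2 * max_len, (2, max_len))
tail_pos = torch.randint(0, 2 * max_len, (2, max_len))

# 'cat' strategy: concatenate word embedding with the two position embeddings
inputs = torch.cat([word_emb(word), head_pos_emb(head_pos), tail_pos_emb(tail_pos)], dim=-1)
hidden_size = inputs.size(-1)               # word_dim + 2 * pos_dim = 128
assert hidden_size % num_heads == 0         # each head gets hidden_size // num_heads dimensions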

models/Transformer.py (new file, 29 lines)

@@ -0,0 +1,29 @@
import torch.nn as nn
from . import BasicModule
from module import Embedding
from module import Transformer as TransformerBlock
from utils import seq_len_to_mask


class Transformer(BasicModule):
    def __init__(self, cfg):
        super(Transformer, self).__init__()
        # hidden_size follows the embedding output width, so it is set here
        # rather than in the config file
        if cfg.dim_strategy == 'cat':
            cfg.hidden_size = cfg.word_dim + 2 * cfg.pos_dim
        else:
            cfg.hidden_size = cfg.word_dim
        self.embedding = Embedding(cfg)
        self.transformer = TransformerBlock(cfg)
        self.fc = nn.Linear(cfg.hidden_size, cfg.num_relations)

    def forward(self, x):
        word, lens, head_pos, tail_pos = x['word'], x['lens'], x['head_pos'], x['tail_pos']
        mask = seq_len_to_mask(lens)
        inputs = self.embedding(word, head_pos, tail_pos)
        last_layer_hidden_state, all_hidden_states, all_attentions = self.transformer(inputs, key_padding_mask=mask)
        out_pool = last_layer_hidden_state.max(dim=1)[0]  # max-pool over the sequence dimension
        output = self.fc(out_pool)
        return output
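
The classification head above reduces the transformer output by max-pooling over the sequence dimension before the linear layer. A tiny self-contained sketch of just that step, with made-up shapes:

import torch
import torch.nn as nn

B, L, H, R = 2, 5, 128, 10               # batch, seq len, hidden_size, num_relations
hidden = torch.randn(B, L, H)            # stand-in for last_layer_hidden_state
pooled = hidden.max(dim=1)[0]            # [B, H]: strongest activation per feature across positions
logits = nn.Linear(H, R)(pooled)         # [B, R]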


@@ -115,8 +115,8 @@ class Transformer(nn.Module):
    def forward(self, hidden_states, key_padding_mask=None, attention_mask=None, head_mask=None):
        """
        :param hidden_states: [B, L, Hs]
        :param key_padding_mask: [B, S], positions set to 1/True are masked out
        :param attn_mask: [S] / [L, S], masks specific positions; entries set to 1/True are masked out
        :param head_mask: [N] / [L, N], masks specific heads; entries set to 1/True are masked out
        """
        if head_mask is not None:
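
Per this docstring, positions marked 1/True in key_padding_mask are the ones that get masked. A small standalone sketch of building such a [B, S] mask from sequence lengths under that convention (whether the repo's seq_len_to_mask helper, which models/Transformer.py passes in directly, follows the same convention is not visible in this diff, so the mask is constructed explicitly here):

import torch

lens = torch.tensor([5, 3, 1])                     # true lengths for a batch of B = 3
S = int(lens.max())
positions = torch.arange(S).unsqueeze(0)           # [1, S]
key_padding_mask = positions >= lens.unsqueeze(1)  # [B, S], True exactly at padded positions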