update transformer
This commit is contained in:
parent
de97e387d9
commit
b100404be0
|
@ -1,9 +1,12 @@
|
|||
hidden_size: 128
|
||||
intermediate_size: 256
|
||||
model_name: transformer
|
||||
|
||||
hidden_size: ??? # derived from the embedding output size by the model; do not set manually
|
||||
num_heads: 4 # must evenly divide hidden_size
|
||||
num_hidden_layers: 3
|
||||
num_heads: 4
|
||||
intermediate_size: 256
|
||||
dropout: 0.1
|
||||
layer_norm_eps: 1e-12
|
||||
hidden_act: gelu_new
|
||||
hidden_act: gelu_new # [relu, gelu, swish, gelu_new]
|
||||
|
||||
output_attentions: True
|
||||
output_hidden_states: True
|
|
@ -0,0 +1,29 @@
|
|||
import torch.nn as nn
|
||||
from . import BasicModule
|
||||
from module import Embedding
|
||||
from module import Transformer as TransformerBlock
|
||||
from utils import seq_len_to_mask
|
||||
|
||||
|
||||
class Transformer(BasicModule):
    """Sequence classifier: embedding -> transformer encoder -> max-pool -> linear.

    The encoder output is max-pooled over dimension 1 and projected to
    ``cfg.num_relations`` scores by a linear layer.
    """

    def __init__(self, cfg):
        """Build the embedding, encoder and output layers from ``cfg``.

        NOTE: ``cfg.hidden_size`` is overwritten in place before the
        sub-modules are built, so any value supplied by the caller is ignored.

        :param cfg: config object. Reads ``dim_strategy``, ``word_dim``,
            ``pos_dim`` (only when ``dim_strategy == 'cat'``) and
            ``num_relations``; the whole object is also passed to
            ``Embedding`` and ``TransformerBlock``.
        """
        super(Transformer, self).__init__()

        # Presumably 'cat' concatenates the word embedding with the two
        # position embeddings (hence word_dim + 2 * pos_dim), while any other
        # strategy keeps the width at word_dim -- confirm against Embedding.
        if cfg.dim_strategy == 'cat':
            cfg.hidden_size = cfg.word_dim + 2 * cfg.pos_dim
        else:
            cfg.hidden_size = cfg.word_dim

        self.embedding = Embedding(cfg)
        self.transformer = TransformerBlock(cfg)
        self.fc = nn.Linear(cfg.hidden_size, cfg.num_relations)

    def forward(self, x):
        """Compute the output scores for a batch.

        :param x: mapping with keys ``'word'``, ``'lens'``, ``'head_pos'``
            and ``'tail_pos'``.
        :return: output of the final linear layer applied to the max-pooled
            encoder states.
        """
        word, lens, head_pos, tail_pos = x['word'], x['lens'], x['head_pos'], x['tail_pos']
        mask = seq_len_to_mask(lens)
        inputs = self.embedding(word, head_pos, tail_pos)

        # Only the first returned item (last-layer hidden state) is needed.
        # Star-unpacking drops the unused extras and keeps this working no
        # matter how many optional outputs the encoder appends.
        # Assumes the encoder returns a tuple whose first item is the
        # last-layer hidden state -- TODO confirm against TransformerBlock.
        last_layer_hidden_state, *_ = self.transformer(inputs, key_padding_mask=mask)

        # Tensor.max(dim=...) returns (values, indices); keep the values.
        out_pool = last_layer_hidden_state.max(dim=1)[0]
        output = self.fc(out_pool)

        return output
|
|
@ -115,8 +115,8 @@ class Transformer(nn.Module):
|
|||
def forward(self, hidden_states, key_padding_mask=None, attention_mask=None, head_mask=None):
|
||||
"""
|
||||
:param hidden_states: [B, L, Hs]
|
||||
:param key_padding_mask: [B, S] 为 1/True 的地方需要 mask
|
||||
:param attn_mask: [S] / [L, S] 指定位置 mask 掉, 为 1/True 的地方需要 mask
|
||||
:param key_padding_mask: [B, S] 为 1/True 的地方需要 mask
|
||||
:param attn_mask: [S] / [L, S] 指定位置 mask 掉, 为 1/True 的地方需要 mask
|
||||
:param head_mask: [N] / [L, N] 指定 head mask 掉, 为 1/True 的地方需要 mask
|
||||
"""
|
||||
if head_mask is not None:
|
||||
|
|
Loading…
Reference in New Issue