Delete Transformer_offical.py
parent 621a7f46b1
commit 26871a90b1
@@ -1,429 +0,0 @@
import copy
import torch
from torch.nn.init import xavier_uniform_
from torch.nn import Module, ModuleList, LayerNorm, Linear, Dropout, MultiheadAttention
import torch.nn.functional as F


# Code taken from torch 1.3.0; this is the official Transformer implementation.
# However, its interface is too rigid, so a custom version was reimplemented separately.
class Transformer(Module):
    r"""A transformer model. User is able to modify the attributes as needed. The architecture
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805)
    model with corresponding parameters.

    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multiheadattention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu).
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).

    Examples::
        >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
        >>> src = torch.rand((10, 32, 512))
        >>> tgt = torch.rand((20, 32, 512))
        >>> out = transformer_model(src, tgt)

    Note: A full example to apply nn.Transformer module for the word language model is available in
    https://github.com/pytorch/examples/tree/master/word_language_model
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", custom_encoder=None, custom_decoder=None):
        super(Transformer, self).__init__()

        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
            encoder_norm = LayerNorm(d_model)
            self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
            decoder_norm = LayerNorm(d_model)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                memory_mask=None, src_key_padding_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        r"""Take in and process masked source/target sequences.

        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
            src_mask: the additive mask for the src sequence (optional).
            tgt_mask: the additive mask for the tgt sequence (optional).
            memory_mask: the additive mask for the encoder output (optional).
            src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
            tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
            memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).

        Shape:
            - src: :math:`(S, N, E)`.
            - tgt: :math:`(T, N, E)`.
            - src_mask: :math:`(S, S)`.
            - tgt_mask: :math:`(T, T)`.
            - memory_mask: :math:`(T, S)`.
            - src_key_padding_mask: :math:`(N, S)`.
            - tgt_key_padding_mask: :math:`(N, T)`.
            - memory_key_padding_mask: :math:`(N, S)`.

            Note: [src/tgt/memory]_mask should be filled with
            float('-inf') for the masked positions and float(0.0) otherwise. These masks
            ensure that predictions for position i depend only on the unmasked positions
            j and are applied identically for each sequence in a batch.
            [src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions
            that should be masked with float('-inf') and False values will be unchanged.
            This mask ensures that no information will be taken from position i if
            it is masked, and has a separate mask for each sequence in a batch.

            - output: :math:`(T, N, E)`.

            Note: Due to the multi-head attention architecture in the transformer model,
            the output sequence length of a transformer is the same as the input sequence
            (i.e. target) length of the decoder.

            where S is the source sequence length, T is the target sequence length, N is the
            batch size, E is the feature number

        Examples:
            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
        """

        if src.size(1) != tgt.size(1):
            raise RuntimeError("the batch number of src and tgt must be equal")

        if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
            raise RuntimeError("the feature number of src and tgt must be equal to d_model")

        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_key_padding_mask,
                              memory_key_padding_mask=memory_key_padding_mask)
        return output

    def generate_square_subsequent_mask(self, sz):
        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
        Unmasked positions are filled with float(0.0).
        """
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""

        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
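

# A minimal sketch (not part of the original file) of the mask conventions documented in
# Transformer.forward above: additive float masks use -inf for blocked positions, while
# the key_padding_mask is a Byte/bool mask with True marking padded positions. The names
# model, src, tgt and src_lengths below are illustrative assumptions, in the same
# doctest style as the docstring examples:
#
# >>> model = Transformer(d_model=512, nhead=8)
# >>> tgt_mask = model.generate_square_subsequent_mask(tgt.size(0))  # (T, T) additive causal mask
# >>> src_key_padding_mask = torch.arange(src.size(0))[None, :] >= src_lengths[:, None]  # (N, S) bool
# >>> out = model(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask)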


class TransformerEncoder(Module):
    r"""TransformerEncoder is a stack of N encoder layers

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None):
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = src

        for i in range(self.num_layers):
            output = self.layers[i](output, src_mask=mask,
                                    src_key_padding_mask=src_key_padding_mask)

        if self.norm:
            output = self.norm(output)

        return output


class TransformerDecoder(Module):
    r"""TransformerDecoder is a stack of N decoder layers

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = transformer_decoder(tgt, memory)
    """

    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt, memory, tgt_mask=None,
                memory_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        r"""Pass the inputs (and mask) through the decoder layer in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        output = tgt

        for i in range(self.num_layers):
            output = self.layers[i](output, memory, tgt_mask=tgt_mask,
                                    memory_mask=memory_mask,
                                    tgt_key_padding_mask=tgt_key_padding_mask,
                                    memory_key_padding_mask=memory_key_padding_mask)

        if self.norm:
            output = self.norm(output)

        return output


class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        if hasattr(self, "activation"):
            src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        else:  # for backward compatibility
            src2 = self.linear2(self.dropout(F.relu(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


class TransformerDecoderLayer(Module):
    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).

    Examples::
        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        >>> memory = torch.rand(10, 32, 512)
        >>> tgt = torch.rand(20, 32, 512)
        >>> out = decoder_layer(tgt, memory)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model)

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        r"""Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        if hasattr(self, "activation"):
            tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        else:  # for backward compatibility
            tgt2 = self.linear2(self.dropout(F.relu(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt


def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu
    else:
        raise RuntimeError("activation should be relu/gelu, not %s." % activation)


if __name__ == '__main__':
    import torch.nn as nn
    torch.manual_seed(1)

    class Config():
        d_model = 8
        nhead = 4
        num_encoder_layers = 3
        num_decoder_layers = 3
        dim_feedforward = 64
        dropout = 0.1
        activation = 'gelu'

    cfg = Config()

    encoder_layer = nn.TransformerEncoderLayer(cfg.d_model, cfg.nhead, cfg.dim_feedforward, cfg.dropout,
                                               cfg.activation)
    encoder_norm = nn.LayerNorm(cfg.d_model)
    encoder = nn.TransformerEncoder(encoder_layer, cfg.num_encoder_layers, encoder_norm)

    decoder_layer = nn.TransformerDecoderLayer(cfg.d_model, cfg.nhead, cfg.dim_feedforward, cfg.dropout,
                                               cfg.activation)
    decoder_norm = nn.LayerNorm(cfg.d_model)
    decoder = nn.TransformerDecoder(decoder_layer, cfg.num_decoder_layers, decoder_norm)

    src = torch.randn((2, 7, 8))  # B, L, H
    tgt = torch.randn((2, 5, 8))
    src.transpose_(0, 1)  # to (L, B, H) as expected by the encoder/decoder
    tgt.transpose_(0, 1)
    src_mask = None
    tgt_mask = None
    memory_mask = None
    src_key_padding_mask = None
    tgt_key_padding_mask = None
    memory_key_padding_mask = None

    memory = encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
    output = decoder(tgt,
                     memory,
                     tgt_mask=tgt_mask,
                     memory_mask=memory_mask,
                     tgt_key_padding_mask=tgt_key_padding_mask,
                     memory_key_padding_mask=memory_key_padding_mask)
    memory.transpose_(0, 1)
    output.transpose_(0, 1)
    print(memory.shape, output.shape)  # torch.Size([2, 7, 8]) torch.Size([2, 5, 8])

    # Call the transformer directly
    transformer = nn.Transformer(cfg.d_model, cfg.nhead, cfg.num_encoder_layers, cfg.num_decoder_layers,
                                 cfg.dim_feedforward, cfg.dropout, cfg.activation)
    out = transformer(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=memory_mask)
    out.transpose_(0, 1)
    print(out.shape)
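
    # Added sketch (not in the original file): the demo above passes None for every mask.
    # As an assumed illustration of the mask semantics described in Transformer.forward,
    # build a causal (subsequent) additive mask for tgt and a boolean key-padding mask
    # that marks the last two src positions of batch element 0 as padding.
    causal_mask = transformer.generate_square_subsequent_mask(tgt.size(0))  # (T, T), -inf above the diagonal
    pad_mask = torch.zeros(2, src.size(0), dtype=torch.bool)                # (N, S), True = padded position
    pad_mask[0, -2:] = True
    masked_out = transformer(src, tgt, tgt_mask=causal_mask, src_key_padding_mask=pad_mask)
    print(masked_out.shape)  # expected torch.Size([5, 2, 8]) in (T, N, E) layout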