add docstring to transformer_tts and fastspeech
This commit is contained in:
parent
3d1fda0ce9
commit
429695d6a0
|
@ -23,12 +23,26 @@ class Decoder(dg.Layer):
|
|||
n_layers,
|
||||
n_head,
|
||||
d_k,
|
||||
d_v,
|
||||
d_q,
|
||||
d_model,
|
||||
d_inner,
|
||||
fft_conv1d_kernel,
|
||||
fft_conv1d_padding,
|
||||
dropout=0.1):
|
||||
"""Decoder layer of FastSpeech.
|
||||
|
||||
Args:
|
||||
len_max_seq (int): the max mel len of sequence.
|
||||
n_layers (int): the layers number of FFTBlock.
|
||||
n_head (int): the head number of multihead attention.
|
||||
d_k (int): the dim of key in multihead attention.
|
||||
d_q (int): the dim of query in multihead attention.
|
||||
d_model (int): the dim of hidden layer in multihead attention.
|
||||
d_inner (int): the dim of hidden layer in ffn.
|
||||
fft_conv1d_kernel (int): the conv kernel size in FFTBlock.
|
||||
fft_conv1d_padding (int): the conv padding size in FFTBlock.
|
||||
dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
|
||||
"""
|
||||
super(Decoder, self).__init__()
|
||||
|
||||
n_position = len_max_seq + 1
|
||||
|
@ -48,7 +62,7 @@ class Decoder(dg.Layer):
|
|||
d_inner,
|
||||
n_head,
|
||||
d_k,
|
||||
d_v,
|
||||
d_q,
|
||||
fft_conv1d_kernel,
|
||||
fft_conv1d_padding,
|
||||
dropout=dropout) for _ in range(n_layers)
|
||||
|
@ -58,26 +72,20 @@ class Decoder(dg.Layer):
|
|||
|
||||
def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None):
|
||||
"""
|
||||
Decoder layer of FastSpeech.
|
||||
Compute decoder outputs.
|
||||
|
||||
Args:
|
||||
enc_seq (Variable): The output of length regulator.
|
||||
Shape: (B, T_text, C), T_text means the timesteps of input text,
|
||||
dtype: float32.
|
||||
enc_pos (Variable): The spectrum position.
|
||||
Shape: (B, T_mel), T_mel means the timesteps of input spectrum,
|
||||
dtype: int64.
|
||||
non_pad_mask (Variable): the mask with non pad.
|
||||
Shape: (B, T_mel, 1),
|
||||
dtype: int64.
|
||||
slf_attn_mask (Variable, optional): the mask of mel spectrum. Defaults to None.
|
||||
Shape: (B, T_mel, T_mel),
|
||||
dtype: int64.
|
||||
enc_seq (Variable): shape(B, T_text, C), dtype float32,
|
||||
the output of length regulator, where T_text means the timesteps of input text,
|
||||
enc_pos (Variable): shape(B, T_mel), dtype int64,
|
||||
the spectrum position, where T_mel means the timesteps of input spectrum,
|
||||
non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad.
|
||||
slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64,
|
||||
the mask of mel spectrum. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dec_output (Variable): the decoder output.
|
||||
Shape: (B, T_mel, C).
|
||||
dec_slf_attn_list (list[Variable]): the decoder self attention list.
|
||||
Len: n_layers.
|
||||
dec_output (Variable): shape(B, T_mel, C), the decoder output.
|
||||
dec_slf_attn_list (list[Variable]): len(n_layers), the decoder self attention list.
|
||||
"""
|
||||
dec_slf_attn_list = []
|
||||
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
|
||||
|
|
|
@ -24,12 +24,27 @@ class Encoder(dg.Layer):
|
|||
n_layers,
|
||||
n_head,
|
||||
d_k,
|
||||
d_v,
|
||||
d_q,
|
||||
d_model,
|
||||
d_inner,
|
||||
fft_conv1d_kernel,
|
||||
fft_conv1d_padding,
|
||||
dropout=0.1):
|
||||
"""Encoder layer of FastSpeech.
|
||||
|
||||
Args:
|
||||
n_src_vocab (int): the number of source vocabulary.
|
||||
len_max_seq (int): the max mel len of sequence.
|
||||
n_layers (int): the layers number of FFTBlock.
|
||||
n_head (int): the head number of multihead attention.
|
||||
d_k (int): the dim of key in multihead attention.
|
||||
d_q (int): the dim of query in multihead attention.
|
||||
d_model (int): the dim of hidden layer in multihead attention.
|
||||
d_inner (int): the dim of hidden layer in ffn.
|
||||
fft_conv1d_kernel (int): the conv kernel size in FFTBlock.
|
||||
fft_conv1d_padding (int): the conv padding size in FFTBlock.
|
||||
dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
|
||||
"""
|
||||
super(Encoder, self).__init__()
|
||||
n_position = len_max_seq + 1
|
||||
self.n_head = n_head
|
||||
|
@ -53,7 +68,7 @@ class Encoder(dg.Layer):
|
|||
d_inner,
|
||||
n_head,
|
||||
d_k,
|
||||
d_v,
|
||||
d_q,
|
||||
fft_conv1d_kernel,
|
||||
fft_conv1d_padding,
|
||||
dropout=dropout) for _ in range(n_layers)
|
||||
|
@ -63,25 +78,20 @@ class Encoder(dg.Layer):
|
|||
|
||||
def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None):
|
||||
"""
|
||||
Encoder layer of FastSpeech.
|
||||
Encode text sequence.
|
||||
|
||||
Args:
|
||||
character (Variable): The input text characters.
|
||||
Shape: (B, T_text), T_text means the timesteps of input characters,
|
||||
dtype: float32.
|
||||
text_pos (Variable): The input text position.
|
||||
Shape: (B, T_text), dtype: int64.
|
||||
non_pad_mask (Variable): the mask with non pad.
|
||||
Shape: (B, T_text, 1),
|
||||
dtype: int64.
|
||||
slf_attn_mask (Variable, optional): the mask of input characters. Defaults to None.
|
||||
Shape: (B, T_text, T_text),
|
||||
dtype: int64.
|
||||
|
||||
character (Variable): shape(B, T_text), dtype float32, the input text characters,
|
||||
where T_text means the timesteps of input characters,
|
||||
text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
|
||||
non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad.
|
||||
slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64,
|
||||
the mask of input characters. Defaults to None.
|
||||
|
||||
Returns:
|
||||
enc_output (Variable), the encoder output. Shape(B, T_text, C)
|
||||
non_pad_mask (Variable), the mask with non pad. Shape(B, T_text, 1)
|
||||
enc_slf_attn_list (list[Variable]), the encoder self attention list.
|
||||
Len: n_layers.
|
||||
enc_output (Variable): shape(B, T_text, C), the encoder output.
|
||||
non_pad_mask (Variable): shape(B, T_text, 1), the mask with non pad.
|
||||
enc_slf_attn_list (list[Variable]): len(n_layers), the encoder self attention list.
|
||||
"""
|
||||
enc_slf_attn_list = []
|
||||
slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
|
||||
|
|
|
@ -25,7 +25,11 @@ from parakeet.models.fastspeech.decoder import Decoder
|
|||
|
||||
class FastSpeech(dg.Layer):
|
||||
def __init__(self, cfg):
|
||||
" FastSpeech"
|
||||
"""FastSpeech model.
|
||||
|
||||
Args:
|
||||
cfg: the yaml configs used in FastSpeech model.
|
||||
"""
|
||||
super(FastSpeech, self).__init__()
|
||||
|
||||
self.encoder = Encoder(
|
||||
|
@ -34,7 +38,7 @@ class FastSpeech(dg.Layer):
|
|||
n_layers=cfg['encoder_n_layer'],
|
||||
n_head=cfg['encoder_head'],
|
||||
d_k=cfg['fs_hidden_size'] // cfg['encoder_head'],
|
||||
d_v=cfg['fs_hidden_size'] // cfg['encoder_head'],
|
||||
d_q=cfg['fs_hidden_size'] // cfg['encoder_head'],
|
||||
d_model=cfg['fs_hidden_size'],
|
||||
d_inner=cfg['encoder_conv1d_filter_size'],
|
||||
fft_conv1d_kernel=cfg['fft_conv1d_filter'],
|
||||
|
@ -50,7 +54,7 @@ class FastSpeech(dg.Layer):
|
|||
n_layers=cfg['decoder_n_layer'],
|
||||
n_head=cfg['decoder_head'],
|
||||
d_k=cfg['fs_hidden_size'] // cfg['decoder_head'],
|
||||
d_v=cfg['fs_hidden_size'] // cfg['decoder_head'],
|
||||
d_q=cfg['fs_hidden_size'] // cfg['decoder_head'],
|
||||
d_model=cfg['fs_hidden_size'],
|
||||
d_inner=cfg['decoder_conv1d_filter_size'],
|
||||
fft_conv1d_kernel=cfg['fft_conv1d_filter'],
|
||||
|
@ -88,39 +92,31 @@ class FastSpeech(dg.Layer):
|
|||
length_target=None,
|
||||
alpha=1.0):
|
||||
"""
|
||||
FastSpeech model.
|
||||
Compute mel output from text character.
|
||||
|
||||
Args:
|
||||
character (Variable): The input text characters.
|
||||
Shape: (B, T_text), T_text means the timesteps of input characters, dtype: float32.
|
||||
text_pos (Variable): The input text position.
|
||||
Shape: (B, T_text), dtype: int64.
|
||||
mel_pos (Variable, optional): The spectrum position.
|
||||
Shape: (B, T_mel), T_mel means the timesteps of input spectrum, dtype: int64.
|
||||
enc_non_pad_mask (Variable): the mask with non pad.
|
||||
Shape: (B, T_text, 1),
|
||||
dtype: int64.
|
||||
dec_non_pad_mask (Variable): the mask with non pad.
|
||||
Shape: (B, T_mel, 1),
|
||||
dtype: int64.
|
||||
enc_slf_attn_mask (Variable, optional): the mask of input characters. Defaults to None.
|
||||
Shape: (B, T_text, T_text),
|
||||
dtype: int64.
|
||||
slf_attn_mask (Variable, optional): the mask of mel spectrum. Defaults to None.
|
||||
Shape: (B, T_mel, T_mel),
|
||||
dtype: int64.
|
||||
length_target (Variable, optional): The duration of phoneme compute from pretrained transformerTTS.
|
||||
Defaults to None. Shape: (B, T_text), dtype: int64.
|
||||
character (Variable): shape(B, T_text), dtype float32, the input text characters,
|
||||
where T_text means the timesteps of input characters,
|
||||
text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
|
||||
mel_pos (Variable, optional): shape(B, T_mel), dtype int64, the spectrum position,
|
||||
where T_mel means the timesteps of input spectrum,
|
||||
enc_non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad.
|
||||
dec_non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad.
|
||||
enc_slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64,
|
||||
the mask of input characters. Defaults to None.
|
||||
slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64,
|
||||
the mask of mel spectrum. Defaults to None.
|
||||
length_target (Variable, optional): shape(B, T_text), dtype int64,
|
||||
the duration of phonemes computed from a pretrained TransformerTTS. Defaults to None.
|
||||
alpha (float32, optional): The hyperparameter to determine the length of the expanded sequence
|
||||
mel, thereby controlling the voice speed. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
mel_output (Variable), the mel output before postnet. Shape: (B, T_mel, C),
|
||||
mel_output_postnet (Variable), the mel output after postnet. Shape: (B, T_mel, C).
|
||||
duration_predictor_output (Variable), the duration of phoneme compute with duration predictor.
|
||||
Shape: (B, T_text).
|
||||
enc_slf_attn_list (List[Variable]), the encoder self attention list. Len: enc_n_layers.
|
||||
dec_slf_attn_list (List[Variable]), the decoder self attention list. Len: dec_n_layers.
|
||||
mel_output (Variable): shape(B, T_mel, C), the mel output before postnet.
|
||||
mel_output_postnet (Variable): shape(B, T_mel, C), the mel output after postnet.
|
||||
duration_predictor_output (Variable): shape(B, T_text), the duration of phonemes computed by the duration predictor.
|
||||
enc_slf_attn_list (List[Variable]): len(enc_n_layers), the encoder self attention list.
|
||||
dec_slf_attn_list (List[Variable]): len(dec_n_layers), the decoder self attention list.
|
||||
"""
|
||||
|
||||
encoder_output, enc_slf_attn_list = self.encoder(
|
||||
|
|
|
@ -26,15 +26,27 @@ class FFTBlock(dg.Layer):
|
|||
d_inner,
|
||||
n_head,
|
||||
d_k,
|
||||
d_v,
|
||||
d_q,
|
||||
filter_size,
|
||||
padding,
|
||||
dropout=0.2):
|
||||
"""Feed forward structure based on self-attention.
|
||||
|
||||
Args:
|
||||
d_model (int): the dim of hidden layer in multihead attention.
|
||||
d_inner (int): the dim of hidden layer in ffn.
|
||||
n_head (int): the head number of multihead attention.
|
||||
d_k (int): the dim of key in multihead attention.
|
||||
d_q (int): the dim of query in multihead attention.
|
||||
filter_size (int): the conv kernel size.
|
||||
padding (int): the conv padding size.
|
||||
dropout (float, optional): dropout probability. Defaults to 0.2.
|
||||
"""
|
||||
super(FFTBlock, self).__init__()
|
||||
self.slf_attn = MultiheadAttention(
|
||||
d_model,
|
||||
d_k,
|
||||
d_v,
|
||||
d_q,
|
||||
num_head=n_head,
|
||||
is_bias=True,
|
||||
dropout=dropout,
|
||||
|
@ -48,20 +60,18 @@ class FFTBlock(dg.Layer):
|
|||
|
||||
def forward(self, enc_input, non_pad_mask, slf_attn_mask=None):
|
||||
"""
|
||||
Feed Forward Transformer block in FastSpeech.
|
||||
Feed forward block of FastSpeech
|
||||
|
||||
Args:
|
||||
enc_input (Variable): The embedding characters input.
|
||||
Shape: (B, T, C), T means the timesteps of input, dtype: float32.
|
||||
non_pad_mask (Variable): The mask of sequence.
|
||||
Shape: (B, T, 1), dtype: int64.
|
||||
slf_attn_mask (Variable, optional): The mask of self attention. Defaults to None.
|
||||
Shape(B, len_q, len_k), len_q means the sequence length of query,
|
||||
len_k means the sequence length of key, dtype: int64.
|
||||
enc_input (Variable): shape(B, T, C), dtype float32, the embedding characters input,
|
||||
where T means the timesteps of input.
|
||||
non_pad_mask (Variable): shape(B, T, 1), dtype int64, the mask of sequence.
|
||||
slf_attn_mask (Variable, optional): shape(B, len_q, len_k), dtype int64, the mask of self attention,
|
||||
where len_q means the sequence length of query and len_k means the sequence length of key. Defaults to None.
|
||||
|
||||
Returns:
|
||||
output (Variable), the output after self-attention & ffn. Shape: (B, T, C).
|
||||
slf_attn (Variable), the self attention. Shape: (B * n_head, T, T),
|
||||
output (Variable): shape(B, T, C), the output after self-attention & ffn.
|
||||
slf_attn (Variable): shape(B * n_head, T, T), the self attention.
|
||||
"""
|
||||
output, slf_attn = self.slf_attn(
|
||||
enc_input, enc_input, enc_input, mask=slf_attn_mask)
|
||||
|
|
|
@ -22,6 +22,14 @@ from parakeet.modules.customized import Conv1D
|
|||
|
||||
class LengthRegulator(dg.Layer):
|
||||
def __init__(self, input_size, out_channels, filter_size, dropout=0.1):
|
||||
"""Length Regulator block in FastSpeech.
|
||||
|
||||
Args:
|
||||
input_size (int): the channel number of input.
|
||||
out_channels (int): the output channel number.
|
||||
filter_size (int): the filter size of duration predictor.
|
||||
dropout (float, optional): dropout probability. Defaults to 0.1.
|
||||
"""
|
||||
super(LengthRegulator, self).__init__()
|
||||
self.duration_predictor = DurationPredictor(
|
||||
input_size=input_size,
|
||||
|
@ -66,20 +74,18 @@ class LengthRegulator(dg.Layer):
|
|||
|
||||
def forward(self, x, alpha=1.0, target=None):
|
||||
"""
|
||||
Length Regulator block in FastSpeech.
|
||||
Compute the length of mel from the encoder output using TransformerTTS attention.
|
||||
|
||||
Args:
|
||||
x (Variable): The encoder output.
|
||||
Shape: (B, T, C), dtype: float32.
|
||||
alpha (float32, optional): The hyperparameter to determine the length of
|
||||
x (Variable): shape(B, T, C), dtype float32, the encoder output.
|
||||
alpha (float32, optional): the hyperparameter to determine the length of
|
||||
the expanded sequence mel, thereby controlling the voice speed. Defaults to 1.0.
|
||||
target (Variable, optional): The duration of phoneme compute from pretrained transformerTTS.
|
||||
Defaults to None. Shape: (B, T_text), dtype: int64.
|
||||
target (Variable, optional): shape(B, T_text), dtype int64, the duration of phonemes computed from a pretrained TransformerTTS.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
output (Variable), the output after expand. Shape: (B, T, C),
|
||||
duration_predictor_output (Variable), the output of duration predictor.
|
||||
Shape: (B, T, C).
|
||||
output (Variable): shape(B, T, C), the output after expand.
|
||||
duration_predictor_output (Variable): shape(B, T, C), the output of duration predictor.
|
||||
"""
|
||||
duration_predictor_output = self.duration_predictor(x)
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
|
@ -95,6 +101,14 @@ class LengthRegulator(dg.Layer):
|
|||
|
||||
class DurationPredictor(dg.Layer):
|
||||
def __init__(self, input_size, out_channels, filter_size, dropout=0.1):
|
||||
"""Duration Predictor block in FastSpeech.
|
||||
|
||||
Args:
|
||||
input_size (int): the channel number of input.
|
||||
out_channels (int): the output channel number.
|
||||
filter_size (int): the filter size.
|
||||
dropout (float, optional): dropout probability. Defaults to 0.1.
|
||||
"""
|
||||
super(DurationPredictor, self).__init__()
|
||||
self.input_size = input_size
|
||||
self.out_channels = out_channels
|
||||
|
@ -137,12 +151,13 @@ class DurationPredictor(dg.Layer):
|
|||
|
||||
def forward(self, encoder_output):
|
||||
"""
|
||||
Duration Predictor block in FastSpeech.
|
||||
Predict the duration of each character.
|
||||
|
||||
Args:
|
||||
encoder_output (Variable): Shape(B, T, C), dtype: float32. The encoder output.
|
||||
encoder_output (Variable): shape(B, T, C), dtype float32, the encoder output.
|
||||
|
||||
Returns:
|
||||
out (Variable), Shape(B, T, C), the output of duration predictor.
|
||||
out (Variable): shape(B, T, C), the output of duration predictor.
|
||||
"""
|
||||
# encoder_output.shape(N, T, C)
|
||||
out = layers.transpose(encoder_output, [0, 2, 1])
|
||||
|
|
|
@ -30,6 +30,17 @@ class CBHG(dg.Layer):
|
|||
num_gru_layers=2,
|
||||
max_pool_kernel_size=2,
|
||||
is_post=False):
|
||||
"""CBHG Module
|
||||
|
||||
Args:
|
||||
hidden_size (int): dimension of hidden unit.
|
||||
batch_size (int): batch size of input.
|
||||
K (int, optional): number of convolution banks. Defaults to 16.
|
||||
projection_size (int, optional): dimension of projection unit. Defaults to 256.
|
||||
num_gru_layers (int, optional): number of layers of GRUcell. Defaults to 2.
|
||||
max_pool_kernel_size (int, optional): max pooling kernel size. Defaults to 2.
|
||||
is_post (bool, optional): whether post processing or not. Defaults to False.
|
||||
"""
|
||||
super(CBHG, self).__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
|
@ -169,13 +180,13 @@ class CBHG(dg.Layer):
|
|||
|
||||
def forward(self, input_):
|
||||
"""
|
||||
CBHG Module
|
||||
Convert linear spectrum to Mel spectrum.
|
||||
|
||||
Args:
|
||||
input_(Variable): The sequentially input.
|
||||
Shape: (B, C, T), dtype: float32.
|
||||
input_ (Variable): shape(B, C, T), dtype float32, the sequentially input.
|
||||
|
||||
Returns:
|
||||
(Variable): the CBHG output.
|
||||
out (Variable): shape(B, C, T), the CBHG output.
|
||||
"""
|
||||
|
||||
conv_list = []
|
||||
|
@ -217,6 +228,12 @@ class CBHG(dg.Layer):
|
|||
|
||||
class Highwaynet(dg.Layer):
|
||||
def __init__(self, num_units, num_layers=4):
|
||||
"""Highway network
|
||||
|
||||
Args:
|
||||
num_units (int): dimension of hidden unit.
|
||||
num_layers (int, optional): number of highway layers. Defaults to 4.
|
||||
"""
|
||||
super(Highwaynet, self).__init__()
|
||||
self.num_units = num_units
|
||||
self.num_layers = num_layers
|
||||
|
@ -250,13 +267,13 @@ class Highwaynet(dg.Layer):
|
|||
|
||||
def forward(self, input_):
|
||||
"""
|
||||
Highway network
|
||||
Args:
|
||||
input_(Variable): The sequentially input.
|
||||
Shape: (B, T, C), dtype: float32.
|
||||
Compute result of Highway network.
|
||||
|
||||
Args:
|
||||
input_ (Variable): shape(B, T, C), dtype float32, the sequentially input.
|
||||
|
||||
Returns:
|
||||
(Variable): the Highway output.
|
||||
out (Variable): the Highway output.
|
||||
"""
|
||||
out = input_
|
||||
|
||||
|
|
|
@ -23,6 +23,14 @@ from parakeet.models.transformer_tts.post_convnet import PostConvNet
|
|||
|
||||
class Decoder(dg.Layer):
|
||||
def __init__(self, num_hidden, config, num_head=4, n_layers=3):
|
||||
"""Decoder layer of TransformerTTS.
|
||||
|
||||
Args:
|
||||
num_hidden (int): the size of hidden layer in network.
|
||||
config: the yaml configs used in decoder.
|
||||
n_layers (int, optional): the layers number of multihead attention. Defaults to 3.
|
||||
num_head (int, optional): the head number of multihead attention. Defaults to 4.
|
||||
"""
|
||||
super(Decoder, self).__init__()
|
||||
self.num_hidden = num_hidden
|
||||
self.num_head = num_head
|
||||
|
@ -109,38 +117,26 @@ class Decoder(dg.Layer):
|
|||
m_self_mask=None,
|
||||
zero_mask=None):
|
||||
"""
|
||||
Decoder layer of TransformerTTS.
|
||||
Compute decoder outputs.
|
||||
|
||||
Args:
|
||||
key (Variable): The input key of decoder.
|
||||
Shape: (B, T_text, C), T_text means the timesteps of input text,
|
||||
dtype: float32.
|
||||
value (Variable): The . input value of decoder.
|
||||
Shape: (B, T_text, C), dtype: float32.
|
||||
query (Variable): The input query of decoder.
|
||||
Shape: (B, T_mel, C), T_mel means the timesteps of input spectrum,
|
||||
dtype: float32.
|
||||
positional (Variable): The spectrum position.
|
||||
Shape: (B, T_mel), dtype: int64.
|
||||
mask (Variable): the mask of decoder self attention.
|
||||
Shape: (B, T_mel, T_mel), dtype: int64.
|
||||
m_mask (Variable, optional): the query mask of encoder-decoder attention. Defaults to None.
|
||||
Shape: (B, T_mel, 1), dtype: int64.
|
||||
m_self_mask (Variable, optional): the query mask of decoder self attention. Defaults to None.
|
||||
Shape: (B, T_mel, 1), dtype: int64.
|
||||
zero_mask (Variable, optional): query mask of encoder-decoder attention. Defaults to None.
|
||||
Shape: (B, T_mel, T_text), dtype: int64.
|
||||
key (Variable): shape(B, T_text, C), dtype float32, the input key of decoder,
|
||||
where T_text means the timesteps of input text,
|
||||
value (Variable): shape(B, T_text, C), dtype float32, the input value of decoder.
|
||||
query (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder,
|
||||
where T_mel means the timesteps of input spectrum,
|
||||
positional (Variable): shape(B, T_mel), dtype int64, the spectrum position.
|
||||
mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention.
|
||||
m_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None.
|
||||
m_self_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self attention. Defaults to None.
|
||||
zero_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, query mask of encoder-decoder attention. Defaults to None.
|
||||
|
||||
Returns:
|
||||
mel_out (Variable): the decoder output after mel linear projection.
|
||||
Shape: (B, T_mel, C).
|
||||
out (Variable): the decoder output after post mel network.
|
||||
Shape: (B, T_mel, C).
|
||||
stop_tokens (Variable): the stop tokens of output.
|
||||
Shape: (B, T_mel, 1)
|
||||
attn_list (list[Variable]): the encoder-decoder attention list.
|
||||
Len: n_layers.
|
||||
selfattn_list (list[Variable]): the decoder self attention list.
|
||||
Len: n_layers.
|
||||
mel_out (Variable): shape(B, T_mel, C), the decoder output after mel linear projection.
|
||||
out (Variable): shape(B, T_mel, C), the decoder output after post mel network.
|
||||
stop_tokens (Variable): shape(B, T_mel, 1), the stop tokens of output.
|
||||
attn_list (list[Variable]): len(n_layers), the encoder-decoder attention list.
|
||||
selfattn_list (list[Variable]): len(n_layers), the decoder self attention list.
|
||||
"""
|
||||
|
||||
# get decoder mask with triangular matrix
|
||||
|
|
|
@ -21,6 +21,14 @@ from parakeet.models.transformer_tts.encoderprenet import EncoderPrenet
|
|||
|
||||
class Encoder(dg.Layer):
|
||||
def __init__(self, embedding_size, num_hidden, num_head=4, n_layers=3):
|
||||
"""Encoder layer of TransformerTTS.
|
||||
|
||||
Args:
|
||||
embedding_size (int): the size of position embedding.
|
||||
num_hidden (int): the size of hidden layer in network.
|
||||
n_layers (int, optional): the layers number of multihead attention. Defaults to 3.
|
||||
num_head (int, optional): the head number of multihead attention. Defaults to 4.
|
||||
"""
|
||||
super(Encoder, self).__init__()
|
||||
self.num_hidden = num_hidden
|
||||
self.num_head = num_head
|
||||
|
@ -58,23 +66,18 @@ class Encoder(dg.Layer):
|
|||
|
||||
def forward(self, x, positional, mask=None, query_mask=None):
|
||||
"""
|
||||
Encoder layer of TransformerTTS.
|
||||
Encode text sequence.
|
||||
|
||||
Args:
|
||||
x (Variable): The input character.
|
||||
Shape: (B, T_text), T_text means the timesteps of input text,
|
||||
dtype: float32.
|
||||
positional (Variable): The characters position.
|
||||
Shape: (B, T_text), dtype: int64.
|
||||
mask (Variable, optional): the mask of encoder self attention. Defaults to None.
|
||||
Shape: (B, T_text, T_text), dtype: int64.
|
||||
query_mask (Variable, optional): the query mask of encoder self attention. Defaults to None.
|
||||
Shape: (B, T_text, 1), dtype: int64.
|
||||
x (Variable): shape(B, T_text), dtype float32, the input character,
|
||||
where T_text means the timesteps of input text,
|
||||
positional (Variable): shape(B, T_text), dtype int64, the characters position.
|
||||
mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self attention. Defaults to None.
|
||||
query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self attention. Defaults to None.
|
||||
|
||||
Returns:
|
||||
x (Variable): the encoder output.
|
||||
Shape: (B, T_text, C).
|
||||
attentions (list[Variable]): the encoder self attention list.
|
||||
Len: n_layers.
|
||||
x (Variable): shape(B, T_text, C), the encoder output.
|
||||
attentions (list[Variable]): len(n_layers), the encoder self attention list.
|
||||
"""
|
||||
|
||||
if fluid.framework._dygraph_tracer()._train_mode:
|
||||
|
|
|
@ -22,6 +22,13 @@ import numpy as np
|
|||
|
||||
class EncoderPrenet(dg.Layer):
|
||||
def __init__(self, embedding_size, num_hidden, use_cudnn=True):
|
||||
""" Encoder prenet layer of TransformerTTS.
|
||||
|
||||
Args:
|
||||
embedding_size (int): the size of embedding.
|
||||
num_hidden (int): the size of hidden layer in network.
|
||||
use_cudnn (bool, optional): use cudnn or not. Defaults to True.
|
||||
"""
|
||||
super(EncoderPrenet, self).__init__()
|
||||
self.embedding_size = embedding_size
|
||||
self.num_hidden = num_hidden
|
||||
|
@ -82,14 +89,13 @@ class EncoderPrenet(dg.Layer):
|
|||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Encoder prenet layer of TransformerTTS.
|
||||
Prepare encoder input.
|
||||
|
||||
Args:
|
||||
x (Variable): The input character.
|
||||
Shape: (B, T_text), T_text means the timesteps of input text,
|
||||
dtype: float32.
|
||||
x (Variable): shape(B, T_text), dtype float32, the input character, where T_text means the timesteps of input text.
|
||||
|
||||
Returns:
|
||||
(Variable): the encoder prenet output. Shape: (B, T_text, C).
|
||||
(Variable): shape(B, T_text, C), the encoder prenet output.
|
||||
"""
|
||||
|
||||
x = self.embedding(x)
|
||||
|
|
|
@ -29,6 +29,19 @@ class PostConvNet(dg.Layer):
|
|||
use_cudnn=True,
|
||||
dropout=0.1,
|
||||
batchnorm_last=False):
|
||||
"""Decoder post conv net of TransformerTTS.
|
||||
|
||||
Args:
|
||||
n_mels (int, optional): the number of mel bands when calculating mel spectrograms. Defaults to 80.
|
||||
num_hidden (int, optional): the size of hidden layer in network. Defaults to 512.
|
||||
filter_size (int, optional): the filter size of Conv. Defaults to 5.
|
||||
padding (int, optional): the padding size of Conv. Defaults to 0.
|
||||
num_conv (int, optional): the num of Conv layers in network. Defaults to 5.
|
||||
outputs_per_step (int, optional): the num of output frames per step. Defaults to 1.
|
||||
use_cudnn (bool, optional): use cudnn in Conv or not. Defaults to True.
|
||||
dropout (float, optional): dropout probability. Defaults to 0.1.
|
||||
batchnorm_last (bool, optional): if batchnorm at last layer or not. Defaults to False.
|
||||
"""
|
||||
super(PostConvNet, self).__init__()
|
||||
|
||||
self.dropout = dropout
|
||||
|
@ -93,13 +106,13 @@ class PostConvNet(dg.Layer):
|
|||
|
||||
def forward(self, input):
|
||||
"""
|
||||
Decoder Post Conv Net of TransformerTTS.
|
||||
Compute the mel spectrum.
|
||||
|
||||
Args:
|
||||
input (Variable): The result of mel linear projection.
|
||||
Shape: (B, T, C), dtype: float32.
|
||||
input (Variable): shape(B, T, C), dtype float32, the result of mel linear projection.
|
||||
|
||||
Returns:
|
||||
(Variable): the result after postconvnet. Shape: (B, T, C),
|
||||
output (Variable): shape(B, T, C), the result after postconvnet.
|
||||
"""
|
||||
|
||||
input = layers.transpose(input, [0, 2, 1])
|
||||
|
|
|
@ -19,6 +19,14 @@ import paddle.fluid.layers as layers
|
|||
|
||||
class PreNet(dg.Layer):
|
||||
def __init__(self, input_size, hidden_size, output_size, dropout_rate=0.2):
|
||||
"""Prenet before passing through the network.
|
||||
|
||||
Args:
|
||||
input_size (int): the input channel size.
|
||||
hidden_size (int): the size of hidden layer in network.
|
||||
output_size (int): the output channel size.
|
||||
dropout_rate (float, optional): dropout probability. Defaults to 0.2.
|
||||
"""
|
||||
super(PreNet, self).__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
|
@ -44,20 +52,20 @@ class PreNet(dg.Layer):
|
|||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Pre Net before passing through the network.
|
||||
Prepare network input.
|
||||
|
||||
Args:
|
||||
x (Variable): The input value.
|
||||
Shape: (B, T, C), dtype: float32.
|
||||
x (Variable): shape(B, T, C), dtype float32, the input value.
|
||||
|
||||
Returns:
|
||||
(Variable), the result after prenet. Shape: (B, T, C),
|
||||
output (Variable): shape(B, T, C), the result after prenet.
|
||||
"""
|
||||
x = layers.dropout(
|
||||
layers.relu(self.linear1(x)),
|
||||
self.dropout_rate,
|
||||
dropout_implementation='upscale_in_train')
|
||||
x = layers.dropout(
|
||||
output = layers.dropout(
|
||||
layers.relu(self.linear2(x)),
|
||||
self.dropout_rate,
|
||||
dropout_implementation='upscale_in_train')
|
||||
return x
|
||||
return output
|
||||
|
|
|
@ -19,6 +19,11 @@ from parakeet.models.transformer_tts.decoder import Decoder
|
|||
|
||||
class TransformerTTS(dg.Layer):
|
||||
def __init__(self, config):
|
||||
"""TransformerTTS model.
|
||||
|
||||
Args:
|
||||
config: the yaml configs used in TransformerTTS model.
|
||||
"""
|
||||
super(TransformerTTS, self).__init__()
|
||||
self.encoder = Encoder(config['embedding_size'], config['hidden_size'])
|
||||
self.decoder = Decoder(config['hidden_size'], config)
|
||||
|
@ -37,43 +42,28 @@ class TransformerTTS(dg.Layer):
|
|||
dec_query_mask=None):
|
||||
"""
|
||||
TransformerTTS network.
|
||||
|
||||
Args:
|
||||
characters (Variable): The input character.
|
||||
Shape: (B, T_text), T_text means the timesteps of input text,
|
||||
dtype: float32.
|
||||
mel_input (Variable): The input query of decoder.
|
||||
Shape: (B, T_mel, C), T_mel means the timesteps of input spectrum,
|
||||
dtype: float32.
|
||||
pos_text (Variable): The characters position.
|
||||
Shape: (B, T_text), dtype: int64.
|
||||
dec_slf_mask (Variable): The spectrum position.
|
||||
Shape: (B, T_mel), dtype: int64.
|
||||
mask (Variable): the mask of decoder self attention.
|
||||
Shape: (B, T_mel, T_mel), dtype: int64.
|
||||
enc_slf_mask (Variable, optional): the mask of encoder self attention. Defaults to None.
|
||||
Shape: (B, T_text, T_text), dtype: int64.
|
||||
enc_query_mask (Variable, optional): the query mask of encoder self attention. Defaults to None.
|
||||
Shape: (B, T_text, 1), dtype: int64.
|
||||
dec_query_mask (Variable, optional): the query mask of encoder-decoder attention. Defaults to None.
|
||||
Shape: (B, T_mel, 1), dtype: int64.
|
||||
dec_query_slf_mask (Variable, optional): the query mask of decoder self attention. Defaults to None.
|
||||
Shape: (B, T_mel, 1), dtype: int64.
|
||||
enc_dec_mask (Variable, optional): query mask of encoder-decoder attention. Defaults to None.
|
||||
Shape: (B, T_mel, T_text), dtype: int64.
|
||||
characters (Variable): shape(B, T_text), dtype float32, the input character,
|
||||
where T_text means the timesteps of input text.
|
||||
mel_input (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder,
|
||||
where T_mel means the timesteps of input spectrum.
|
||||
pos_text (Variable): shape(B, T_text), dtype int64, the characters position.
|
||||
dec_slf_mask (Variable): shape(B, T_mel), dtype int64, the spectrum position.
|
||||
mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention.
|
||||
enc_slf_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self attention. Defaults to None.
|
||||
enc_query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self attention. Defaults to None.
|
||||
dec_query_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None.
|
||||
dec_query_slf_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self attention. Defaults to None.
|
||||
enc_dec_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, query mask of encoder-decoder attention. Defaults to None.
|
||||
|
||||
Returns:
|
||||
mel_output (Variable): the decoder output after mel linear projection.
|
||||
Shape: (B, T_mel, C).
|
||||
postnet_output (Variable): the decoder output after post mel network.
|
||||
Shape: (B, T_mel, C).
|
||||
stop_preds (Variable): the stop tokens of output.
|
||||
Shape: (B, T_mel, 1)
|
||||
attn_probs (list[Variable]): the encoder-decoder attention list.
|
||||
Len: n_layers.
|
||||
attns_enc (list[Variable]): the encoder self attention list.
|
||||
Len: n_layers.
|
||||
attns_dec (list[Variable]): the decoder self attention list.
|
||||
Len: n_layers.
|
||||
mel_output (Variable): shape(B, T_mel, C), the decoder output after mel linear projection.
|
||||
postnet_output (Variable): shape(B, T_mel, C), the decoder output after post mel network.
|
||||
stop_preds (Variable): shape(B, T_mel, 1), the stop tokens of output.
|
||||
attn_probs (list[Variable]): len(n_layers), the encoder-decoder attention list.
|
||||
attns_enc (list[Variable]): len(n_layers), the encoder self attention list.
|
||||
attns_dec (list[Variable]): len(n_layers), the decoder self attention list.
|
||||
"""
|
||||
key, attns_enc = self.encoder(
|
||||
characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask)
|
||||
|
|
|
@ -20,6 +20,12 @@ from parakeet.models.transformer_tts.cbhg import CBHG
|
|||
|
||||
class Vocoder(dg.Layer):
|
||||
def __init__(self, config, batch_size):
|
||||
"""CBHG Network (mel -> linear)
|
||||
|
||||
Args:
|
||||
config: the yaml configs used in Vocoder model.
|
||||
batch_size (int): the batch size of input.
|
||||
"""
|
||||
super(Vocoder, self).__init__()
|
||||
self.pre_proj = Conv1D(
|
||||
num_channels=config['audio']['num_mels'],
|
||||
|
@ -33,14 +39,13 @@ class Vocoder(dg.Layer):
|
|||
|
||||
def forward(self, mel):
|
||||
"""
|
||||
CBHG Network (mel -> linear)
|
||||
Compute mel spectrum to linear spectrum.
|
||||
|
||||
Args:
|
||||
mel (Variable): The input mel spectrum.
|
||||
Shape: (B, C, T), dtype: float32.
|
||||
mel (Variable): shape(B, C, T), dtype float32, the input mel spectrum.
|
||||
|
||||
Returns:
|
||||
(Variable): the linear output.
|
||||
Shape: (B, T, C).
|
||||
mag_pred (Variable): shape(B, T, C), the linear output.
|
||||
"""
|
||||
mel = layers.transpose(mel, [0, 2, 1])
|
||||
mel = self.pre_proj(mel)
|
||||
|
|
|
@ -43,10 +43,10 @@ class DynamicGRU(dg.Layer):
|
|||
Dynamic GRU block.
|
||||
|
||||
Args:
|
||||
input (Variable): The input value.
|
||||
Shape: (B, T, C), dtype: float32.
|
||||
input (Variable): shape(B, T, C), dtype float32, the input value.
|
||||
|
||||
Returns:
|
||||
output (Variable), the result compute by GRU. Shape: (B, T, C).
|
||||
output (Variable): shape(B, T, C), the result computed by GRU.
|
||||
"""
|
||||
hidden = self.h_0
|
||||
res = []
|
||||
|
|
|
@ -19,8 +19,6 @@ from parakeet.modules.customized import Conv1D
|
|||
|
||||
|
||||
class PositionwiseFeedForward(dg.Layer):
|
||||
''' A two-feed-forward-layer module '''
|
||||
|
||||
def __init__(self,
|
||||
d_in,
|
||||
num_hidden,
|
||||
|
@ -28,6 +26,16 @@ class PositionwiseFeedForward(dg.Layer):
|
|||
padding=0,
|
||||
use_cudnn=True,
|
||||
dropout=0.1):
|
||||
"""A two-feed-forward-layer module.
|
||||
|
||||
Args:
|
||||
d_in (int): the size of input channel.
|
||||
num_hidden (int): the size of hidden layer in network.
|
||||
filter_size (int): the filter size of Conv.
|
||||
padding (int, optional): the padding size of Conv. Defaults to 0.
|
||||
use_cudnn (bool, optional): use cudnn in Conv or not. Defaults to True.
|
||||
dropout (float, optional): dropout probability. Defaults to 0.1.
|
||||
"""
|
||||
super(PositionwiseFeedForward, self).__init__()
|
||||
self.num_hidden = num_hidden
|
||||
self.use_cudnn = use_cudnn
|
||||
|
@ -59,13 +67,13 @@ class PositionwiseFeedForward(dg.Layer):
|
|||
|
||||
def forward(self, input):
|
||||
"""
|
||||
Feed Forward Network.
|
||||
Compute feed forward network result.
|
||||
|
||||
Args:
|
||||
input (Variable): The input value.
|
||||
Shape: (B, T, C), dtype: float32.
|
||||
input (Variable): shape(B, T, C), dtype float32, the input value.
|
||||
|
||||
Returns:
|
||||
output (Variable), the result after FFN. Shape: (B, T, C).
|
||||
output (Variable): shape(B, T, C), the result after FFN.
|
||||
"""
|
||||
x = layers.transpose(input, [0, 2, 1])
|
||||
#FFN Network
|
||||
|
|
|
@ -50,6 +50,11 @@ class Linear(dg.Layer):
|
|||
|
||||
class ScaledDotProductAttention(dg.Layer):
|
||||
def __init__(self, d_key):
|
||||
"""Scaled dot product attention module.
|
||||
|
||||
Args:
|
||||
d_key (int): the dim of key in multihead attention.
|
||||
"""
|
||||
super(ScaledDotProductAttention, self).__init__()
|
||||
|
||||
self.d_key = d_key
|
||||
|
@ -63,23 +68,18 @@ class ScaledDotProductAttention(dg.Layer):
|
|||
query_mask=None,
|
||||
dropout=0.1):
|
||||
"""
|
||||
Scaled Dot Product Attention.
|
||||
Compute scaled dot product attention.
|
||||
|
||||
Args:
|
||||
key (Variable): The input key of scaled dot product attention.
|
||||
Shape: (B, T, C), dtype: float32.
|
||||
value (Variable): The input value of scaled dot product attention.
|
||||
Shape: (B, T, C), dtype: float32.
|
||||
query (Variable): The input query of scaled dot product attention.
|
||||
Shape: (B, T, C), dtype: float32.
|
||||
mask (Variable, optional): The mask of key. Defaults to None.
|
||||
Shape(B, T_q, T_k), dtype: float32.
|
||||
query_mask (Variable, optional): The mask of query. Defaults to None.
|
||||
Shape(B, T_q, T_q), dtype: float32.
|
||||
dropout (float32, optional): The probability of dropout. Defaults to 0.1.
|
||||
key (Variable): shape(B, T, C), dtype float32, the input key of scaled dot product attention.
|
||||
value (Variable): shape(B, T, C), dtype float32, the input value of scaled dot product attention.
|
||||
query (Variable): shape(B, T, C), dtype float32, the input query of scaled dot product attention.
|
||||
mask (Variable, optional): shape(B, T_q, T_k), dtype float32, the mask of key. Defaults to None.
|
||||
query_mask (Variable, optional): shape(B, T_q, T_q), dtype float32, the mask of query. Defaults to None.
|
||||
dropout (float32, optional): the probability of dropout. Defaults to 0.1.
|
||||
Returns:
|
||||
result (Variable), Shape(B, T, C), the result of mutihead attention.
|
||||
attention (Variable), Shape(n_head * B, T, C), the attention of key.
|
||||
result (Variable): shape(B, T, C), the result of multihead attention.
|
||||
attention (Variable): shape(n_head * B, T, C), the attention of key.
|
||||
"""
|
||||
# Compute attention score
|
||||
attention = layers.matmul(
|
||||
|
@ -110,6 +110,17 @@ class MultiheadAttention(dg.Layer):
|
|||
is_bias=False,
|
||||
dropout=0.1,
|
||||
is_concat=True):
|
||||
"""Multihead Attention.
|
||||
|
||||
Args:
|
||||
num_hidden (int): the number of hidden layer in network.
|
||||
d_k (int): the dim of key in multihead attention.
|
||||
d_q (int): the dim of query in multihead attention.
|
||||
num_head (int, optional): the head number of multihead attention. Defaults to 4.
|
||||
is_bias (bool, optional): whether have bias in linear layers. Defaults to False.
|
||||
dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
|
||||
is_concat (bool, optional): whether concat query and result. Defaults to True.
|
||||
"""
|
||||
super(MultiheadAttention, self).__init__()
|
||||
self.num_hidden = num_hidden
|
||||
self.num_head = num_head
|
||||
|
@ -133,22 +144,18 @@ class MultiheadAttention(dg.Layer):
|
|||
|
||||
def forward(self, key, value, query_input, mask=None, query_mask=None):
|
||||
"""
|
||||
Multihead Attention.
|
||||
Compute attention.
|
||||
|
||||
Args:
|
||||
key (Variable): The input key of attention.
|
||||
Shape: (B, T, C), dtype: float32.
|
||||
value (Variable): The input value of attention.
|
||||
Shape: (B, T, C), dtype: float32.
|
||||
query_input (Variable): The input query of attention.
|
||||
Shape: (B, T, C), dtype: float32.
|
||||
mask (Variable, optional): The mask of key. Defaults to None.
|
||||
Shape: (B, T_query, T_key), dtype: float32.
|
||||
query_mask (Variable, optional): The mask of query. Defaults to None.
|
||||
Shape: (B, T_query, T_key), dtype: float32.
|
||||
key (Variable): shape(B, T, C), dtype float32, the input key of attention.
|
||||
value (Variable): shape(B, T, C), dtype float32, the input value of attention.
|
||||
query_input (Variable): shape(B, T, C), dtype float32, the input query of attention.
|
||||
mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of key. Defaults to None.
|
||||
query_mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of query. Defaults to None.
|
||||
|
||||
Returns:
|
||||
result (Variable), the result of mutihead attention. Shape: (B, T, C).
|
||||
attention (Variable), the attention of key and query. Shape: (num_head * B, T, C)
|
||||
result (Variable): shape(B, T, C), the result of multihead attention.
|
||||
attention (Variable): shape(num_head * B, T, C), the attention of key and query.
|
||||
"""
|
||||
|
||||
batch_size = key.shape[0]
|
||||
|
|
Loading…
Reference in New Issue