From 429695d6a06aa7e75dfdd69ccdca6fbf6475901b Mon Sep 17 00:00:00 2001
From: lifuchen
Date: Mon, 9 Mar 2020 11:57:49 +0000
Subject: [PATCH] add docstrings to transformer_tts and fastspeech

---
 parakeet/models/fastspeech/decoder.py         | 46 ++++++++------
 parakeet/models/fastspeech/encoder.py         | 48 +++++++++------
 parakeet/models/fastspeech/fastspeech.py      | 56 ++++++++---------
 parakeet/models/fastspeech/fft_block.py       | 34 +++++++----
 .../models/fastspeech/length_regulator.py     | 39 ++++++++----
 parakeet/models/transformer_tts/cbhg.py       | 35 ++++++++---
 parakeet/models/transformer_tts/decoder.py    | 54 ++++++++--------
 parakeet/models/transformer_tts/encoder.py    | 31 +++++-----
 .../models/transformer_tts/encoderprenet.py   | 16 +++--
 .../models/transformer_tts/post_convnet.py    | 21 +++++--
 parakeet/models/transformer_tts/prenet.py     | 20 ++++--
 .../models/transformer_tts/transformer_tts.py | 58 ++++++++----------
 parakeet/models/transformer_tts/vocoder.py    | 15 +++--
 parakeet/modules/dynamic_gru.py               |  6 +-
 parakeet/modules/ffn.py                       | 20 ++++--
 parakeet/modules/multihead_attention.py       | 61 +++++++++++--------
 16 files changed, 326 insertions(+), 234 deletions(-)

diff --git a/parakeet/models/fastspeech/decoder.py b/parakeet/models/fastspeech/decoder.py
index 96fefc1..30432d0 100644
--- a/parakeet/models/fastspeech/decoder.py
+++ b/parakeet/models/fastspeech/decoder.py
@@ -23,12 +23,26 @@ class Decoder(dg.Layer):
                  n_layers,
                  n_head,
                  d_k,
-                 d_v,
+                 d_q,
                  d_model,
                  d_inner,
                  fft_conv1d_kernel,
                  fft_conv1d_padding,
                  dropout=0.1):
+        """Decoder layer of FastSpeech.
+
+        Args:
+            len_max_seq (int): the maximum length of the mel spectrum sequence.
+            n_layers (int): the number of FFTBlock layers.
+            n_head (int): the number of heads in multihead attention.
+            d_k (int): the dimension of keys in multihead attention.
+            d_q (int): the dimension of queries in multihead attention.
+            d_model (int): the dimension of the hidden layers in multihead attention.
+            d_inner (int): the dimension of the hidden layer in the FFN.
+            fft_conv1d_kernel (int): the conv kernel size in FFTBlock.
+            fft_conv1d_padding (int): the conv padding size in FFTBlock.
+            dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
+        """
         super(Decoder, self).__init__()

         n_position = len_max_seq + 1
@@ -48,7 +62,7 @@ class Decoder(dg.Layer):
                 d_inner,
                 n_head,
                 d_k,
-                d_v,
+                d_q,
                 fft_conv1d_kernel,
                 fft_conv1d_padding,
                 dropout=dropout) for _ in range(n_layers)
@@ -58,26 +72,20 @@ class Decoder(dg.Layer):

     def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None):
         """
-        Decoder layer of FastSpeech.
+        Compute the decoder output.
+
         Args:
-            enc_seq (Variable): The output of length regulator.
-                Shape: (B, T_text, C), T_text means the timesteps of input text,
-                dtype: float32.
-            enc_pos (Variable): The spectrum position.
-                Shape: (B, T_mel), T_mel means the timesteps of input spectrum,
-                dtype: int64.
-            non_pad_mask (Variable): the mask with non pad.
-                Shape: (B, T_mel, 1),
-                dtype: int64.
-            slf_attn_mask (Variable, optional): the mask of mel spectrum. Defaults to None.
-                Shape: (B, T_mel, T_mel),
-                dtype: int64.
+            enc_seq (Variable): shape(B, T_mel, C), dtype float32, the output of the length
+                regulator, where T_mel is the number of input spectrum timesteps.
+            enc_pos (Variable): shape(B, T_mel), dtype int64, the spectrum position.
+            non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the non-padding mask.
+            slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64,
+                the self-attention mask of the mel spectrum. Defaults to None.

         Returns:
-            dec_output (Variable): the decoder output.
-                Shape: (B, T_mel, C).
-            dec_slf_attn_list (list[Variable]): the decoder self attention list.
-                Len: n_layers.
+            dec_output (Variable): shape(B, T_mel, C), the decoder output.
+            dec_slf_attn_list (list[Variable]): len(n_layers), the decoder self attention list.
         """
         dec_slf_attn_list = []
         slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
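
The non-padding and self-attention masks documented above are derived from the true (unpadded) sequence lengths. As a quick reference, a framework-agnostic NumPy sketch of one common construction (the helper names and the 1-means-padding convention are illustrative, not Parakeet API):

    import numpy as np

    def non_pad_mask(lengths, max_len):
        # (B, T_mel, 1): 1 for real frames, 0 for padded frames
        ids = np.arange(max_len)
        return (ids[None, :] < np.asarray(lengths)[:, None]).astype(np.int64)[:, :, None]

    def attn_pad_mask(lengths, max_len):
        # (B, T_mel, T_mel): 1 wherever the attended (key) position is padding
        ids = np.arange(max_len)
        pad = ids[None, :] >= np.asarray(lengths)[:, None]
        return np.tile(pad[:, None, :], (1, max_len, 1)).astype(np.int64)

    print(non_pad_mask([5, 3], 6).shape)   # (2, 6, 1)
    print(attn_pad_mask([5, 3], 6).shape)  # (2, 6, 6)
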
diff --git a/parakeet/models/fastspeech/encoder.py b/parakeet/models/fastspeech/encoder.py
index 0a39ded..d39fdc1 100644
--- a/parakeet/models/fastspeech/encoder.py
+++ b/parakeet/models/fastspeech/encoder.py
@@ -24,12 +24,27 @@ class Encoder(dg.Layer):
                  n_layers,
                  n_head,
                  d_k,
-                 d_v,
+                 d_q,
                  d_model,
                  d_inner,
                  fft_conv1d_kernel,
                  fft_conv1d_padding,
                  dropout=0.1):
+        """Encoder layer of FastSpeech.
+
+        Args:
+            n_src_vocab (int): the size of the source vocabulary.
+            len_max_seq (int): the maximum length of the input sequence.
+            n_layers (int): the number of FFTBlock layers.
+            n_head (int): the number of heads in multihead attention.
+            d_k (int): the dimension of keys in multihead attention.
+            d_q (int): the dimension of queries in multihead attention.
+            d_model (int): the dimension of the hidden layers in multihead attention.
+            d_inner (int): the dimension of the hidden layer in the FFN.
+            fft_conv1d_kernel (int): the conv kernel size in FFTBlock.
+            fft_conv1d_padding (int): the conv padding size in FFTBlock.
+            dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1.
+        """
         super(Encoder, self).__init__()
         n_position = len_max_seq + 1
         self.n_head = n_head
@@ -53,7 +68,7 @@ class Encoder(dg.Layer):
                 d_inner,
                 n_head,
                 d_k,
-                d_v,
+                d_q,
                 fft_conv1d_kernel,
                 fft_conv1d_padding,
                 dropout=dropout) for _ in range(n_layers)
@@ -63,25 +78,20 @@ class Encoder(dg.Layer):

     def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None):
         """
-        Encoder layer of FastSpeech.
+        Encode the text sequence.
+
         Args:
-            character (Variable): The input text characters.
-                Shape: (B, T_text), T_text means the timesteps of input characters,
-                dtype: float32.
-            text_pos (Variable): The input text position.
-                Shape: (B, T_text), dtype: int64.
-            non_pad_mask (Variable): the mask with non pad.
-                Shape: (B, T_text, 1),
-                dtype: int64.
-            slf_attn_mask (Variable, optional): the mask of input characters. Defaults to None.
-                Shape: (B, T_text, T_text),
-                dtype: int64.
-
+            character (Variable): shape(B, T_text), dtype float32, the input text characters,
+                where T_text is the number of input character timesteps.
+            text_pos (Variable): shape(B, T_text), dtype int64, the input text position.
+            non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the non-padding mask.
+            slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64,
+                the self-attention mask of the input characters. Defaults to None.
+
         Returns:
-            enc_output (Variable), the encoder output. Shape(B, T_text, C)
-            non_pad_mask (Variable), the mask with non pad. Shape(B, T_text, 1)
-            enc_slf_attn_list (list[Variable]), the encoder self attention list.
-                Len: n_layers.
+            enc_output (Variable): shape(B, T_text, C), the encoder output.
+            non_pad_mask (Variable): shape(B, T_text, 1), the non-padding mask.
+            enc_slf_attn_list (list[Variable]): len(n_layers), the encoder self attention list.
         """
         enc_slf_attn_list = []
         slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1])
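
Note that both FastSpeech forward methods tile the self-attention mask over the attention heads before handing it to the FFT blocks (layers.expand(slf_attn_mask, [self.n_head, 1, 1])). The equivalent NumPy operation, as a sanity check on the shapes:

    import numpy as np

    n_head, B, T = 2, 3, 4
    slf_attn_mask = np.zeros((B, T, T), dtype=np.int64)
    # repeating along the batch axis yields one (T, T) mask per head and sample
    expanded = np.tile(slf_attn_mask, (n_head, 1, 1))
    assert expanded.shape == (n_head * B, T, T)
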
""" enc_slf_attn_list = [] slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1]) diff --git a/parakeet/models/fastspeech/fastspeech.py b/parakeet/models/fastspeech/fastspeech.py index 75f5db9..5ee2de1 100644 --- a/parakeet/models/fastspeech/fastspeech.py +++ b/parakeet/models/fastspeech/fastspeech.py @@ -25,7 +25,11 @@ from parakeet.models.fastspeech.decoder import Decoder class FastSpeech(dg.Layer): def __init__(self, cfg): - " FastSpeech" + """FastSpeech model. + + Args: + cfg: the yaml configs used in FastSpeech model. + """ super(FastSpeech, self).__init__() self.encoder = Encoder( @@ -34,7 +38,7 @@ class FastSpeech(dg.Layer): n_layers=cfg['encoder_n_layer'], n_head=cfg['encoder_head'], d_k=cfg['fs_hidden_size'] // cfg['encoder_head'], - d_v=cfg['fs_hidden_size'] // cfg['encoder_head'], + d_q=cfg['fs_hidden_size'] // cfg['encoder_head'], d_model=cfg['fs_hidden_size'], d_inner=cfg['encoder_conv1d_filter_size'], fft_conv1d_kernel=cfg['fft_conv1d_filter'], @@ -50,7 +54,7 @@ class FastSpeech(dg.Layer): n_layers=cfg['decoder_n_layer'], n_head=cfg['decoder_head'], d_k=cfg['fs_hidden_size'] // cfg['decoder_head'], - d_v=cfg['fs_hidden_size'] // cfg['decoder_head'], + d_q=cfg['fs_hidden_size'] // cfg['decoder_head'], d_model=cfg['fs_hidden_size'], d_inner=cfg['decoder_conv1d_filter_size'], fft_conv1d_kernel=cfg['fft_conv1d_filter'], @@ -88,39 +92,31 @@ class FastSpeech(dg.Layer): length_target=None, alpha=1.0): """ - FastSpeech model. + Compute mel output from text character. Args: - character (Variable): The input text characters. - Shape: (B, T_text), T_text means the timesteps of input characters, dtype: float32. - text_pos (Variable): The input text position. - Shape: (B, T_text), dtype: int64. - mel_pos (Variable, optional): The spectrum position. - Shape: (B, T_mel), T_mel means the timesteps of input spectrum, dtype: int64. - enc_non_pad_mask (Variable): the mask with non pad. - Shape: (B, T_text, 1), - dtype: int64. - dec_non_pad_mask (Variable): the mask with non pad. - Shape: (B, T_mel, 1), - dtype: int64. - enc_slf_attn_mask (Variable, optional): the mask of input characters. Defaults to None. - Shape: (B, T_text, T_text), - dtype: int64. - slf_attn_mask (Variable, optional): the mask of mel spectrum. Defaults to None. - Shape: (B, T_mel, T_mel), - dtype: int64. - length_target (Variable, optional): The duration of phoneme compute from pretrained transformerTTS. - Defaults to None. Shape: (B, T_text), dtype: int64. + character (Variable): shape(B, T_text), dtype float32, the input text characters, + where T_text means the timesteps of input characters, + text_pos (Variable): shape(B, T_text), dtype int64, the input text position. + mel_pos (Variable, optional): shape(B, T_mel), dtype int64, the spectrum position, + where T_mel means the timesteps of input spectrum, + enc_non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad. + dec_non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad. + enc_slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, + the mask of input characters. Defaults to None. + slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64, + the mask of mel spectrum. Defaults to None. + length_target (Variable, optional): shape(B, T_text), dtype int64, + the duration of phoneme compute from pretrained transformerTTS. Defaults to None. alpha (float32, optional): The hyperparameter to determine the length of the expanded sequence mel, thereby controlling the voice speed. 
diff --git a/parakeet/models/fastspeech/fft_block.py b/parakeet/models/fastspeech/fft_block.py
index d58f5b2..b3c69ea 100644
--- a/parakeet/models/fastspeech/fft_block.py
+++ b/parakeet/models/fastspeech/fft_block.py
@@ -26,15 +26,27 @@ class FFTBlock(dg.Layer):
                  d_inner,
                  n_head,
                  d_k,
-                 d_v,
+                 d_q,
                  filter_size,
                  padding,
                  dropout=0.2):
+        """Feed-forward Transformer block based on self-attention.
+
+        Args:
+            d_model (int): the dimension of the hidden layers in multihead attention.
+            d_inner (int): the dimension of the hidden layer in the FFN.
+            n_head (int): the number of heads in multihead attention.
+            d_k (int): the dimension of keys in multihead attention.
+            d_q (int): the dimension of queries in multihead attention.
+            filter_size (int): the conv kernel size.
+            padding (int): the conv padding size.
+            dropout (float, optional): dropout probability. Defaults to 0.2.
+        """
         super(FFTBlock, self).__init__()
         self.slf_attn = MultiheadAttention(
             d_model,
             d_k,
-            d_v,
+            d_q,
             num_head=n_head,
             is_bias=True,
             dropout=dropout,
@@ -48,20 +60,18 @@ class FFTBlock(dg.Layer):

     def forward(self, enc_input, non_pad_mask, slf_attn_mask=None):
         """
-        Feed Forward Transformer block in FastSpeech.
+        Compute the output of the feed-forward Transformer block.

         Args:
-            enc_input (Variable): The embedding characters input.
-                Shape: (B, T, C), T means the timesteps of input, dtype: float32.
-            non_pad_mask (Variable): The mask of sequence.
-                Shape: (B, T, 1), dtype: int64.
-            slf_attn_mask (Variable, optional): The mask of self attention. Defaults to None.
-                Shape(B, len_q, len_k), len_q means the sequence length of query,
-                len_k means the sequence length of key, dtype: int64.
+            enc_input (Variable): shape(B, T, C), dtype float32, the embedded characters input,
+                where T is the number of input timesteps.
+            non_pad_mask (Variable): shape(B, T, 1), dtype int64, the non-padding mask of the sequence.
+            slf_attn_mask (Variable, optional): shape(B, len_q, len_k), dtype int64, the self-attention
+                mask, where len_q is the sequence length of the query and len_k is the sequence
+                length of the key. Defaults to None.

         Returns:
-            output (Variable), the output after self-attention & ffn. Shape: (B, T, C).
-            slf_attn (Variable), the self attention. Shape: (B * n_head, T, T),
+            output (Variable): shape(B, T, C), the output after self-attention & FFN.
+            slf_attn (Variable): shape(B * n_head, T, T), the self attention.
         """
         output, slf_attn = self.slf_attn(
             enc_input, enc_input, enc_input, mask=slf_attn_mask)
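
Inside the block, self-attention and the conv-based FFN run in sequence, and the non-padding mask is re-applied after each sub-layer so padded timesteps stay zero. A schematic NumPy sketch with identity stand-ins for the two sub-layers (residuals and layer normalization live inside the real sub-layers; this only shows the masking order):

    import numpy as np

    def fft_block(enc_input, non_pad_mask, slf_attn=lambda x: x, ffn=lambda x: x):
        output = slf_attn(enc_input) * non_pad_mask  # attention sub-layer, then mask
        output = ffn(output) * non_pad_mask          # conv FFN sub-layer, then mask
        return output

    x = np.random.randn(2, 6, 8)                     # (B, T, C)
    non_pad = np.ones((2, 6, 1)); non_pad[1, 4:] = 0
    assert (fft_block(x, non_pad)[1, 4:] == 0).all()
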
""" output, slf_attn = self.slf_attn( enc_input, enc_input, enc_input, mask=slf_attn_mask) diff --git a/parakeet/models/fastspeech/length_regulator.py b/parakeet/models/fastspeech/length_regulator.py index 6ba53d7..6fc6702 100644 --- a/parakeet/models/fastspeech/length_regulator.py +++ b/parakeet/models/fastspeech/length_regulator.py @@ -22,6 +22,14 @@ from parakeet.modules.customized import Conv1D class LengthRegulator(dg.Layer): def __init__(self, input_size, out_channels, filter_size, dropout=0.1): + """Length Regulator block in FastSpeech. + + Args: + input_size (int): the channel number of input. + out_channels (int): the output channel number. + filter_size (int): the filter size of duration predictor. + dropout (float, optional): dropout probability. Defaults to 0.1. + """ super(LengthRegulator, self).__init__() self.duration_predictor = DurationPredictor( input_size=input_size, @@ -66,20 +74,18 @@ class LengthRegulator(dg.Layer): def forward(self, x, alpha=1.0, target=None): """ - Length Regulator block in FastSpeech. + Compute length of mel from encoder output use TransformerTTS attention Args: - x (Variable): The encoder output. - Shape: (B, T, C), dtype: float32. - alpha (float32, optional): The hyperparameter to determine the length of + x (Variable): shape(B, T, C), dtype float32, the encoder output. + alpha (float32, optional): the hyperparameter to determine the length of the expanded sequence mel, thereby controlling the voice speed. Defaults to 1.0. - target (Variable, optional): The duration of phoneme compute from pretrained transformerTTS. - Defaults to None. Shape: (B, T_text), dtype: int64. + target (Variable, optional): shape(B, T_text), dtype int64, the duration of phoneme compute from pretrained transformerTTS. + Defaults to None. Returns: - output (Variable), the output after exppand. Shape: (B, T, C), - duration_predictor_output (Variable), the output of duration predictor. - Shape: (B, T, C). + output (Variable): shape(B, T, C), the output after exppand. + duration_predictor_output (Variable): shape(B, T, C), the output of duration predictor. """ duration_predictor_output = self.duration_predictor(x) if fluid.framework._dygraph_tracer()._train_mode: @@ -95,6 +101,14 @@ class LengthRegulator(dg.Layer): class DurationPredictor(dg.Layer): def __init__(self, input_size, out_channels, filter_size, dropout=0.1): + """Duration Predictor block in FastSpeech. + + Args: + input_size (int): the channel number of input. + out_channels (int): the output channel number. + filter_size (int): the filter size. + dropout (float, optional): dropout probability. Defaults to 0.1. + """ super(DurationPredictor, self).__init__() self.input_size = input_size self.out_channels = out_channels @@ -137,12 +151,13 @@ class DurationPredictor(dg.Layer): def forward(self, encoder_output): """ - Duration Predictor block in FastSpeech. + Predict the duration of each character. Args: - encoder_output (Variable): Shape(B, T, C), dtype: float32. The encoder output. + encoder_output (Variable): shape(B, T, C), dtype float32, the encoder output. + Returns: - out (Variable), Shape(B, T, C), the output of duration predictor. + out (Variable): shape(B, T, C), the output of duration predictor. 
""" # encoder_output.shape(N, T, C) out = layers.transpose(encoder_output, [0, 2, 1]) diff --git a/parakeet/models/transformer_tts/cbhg.py b/parakeet/models/transformer_tts/cbhg.py index b21cd87..5a28ebd 100644 --- a/parakeet/models/transformer_tts/cbhg.py +++ b/parakeet/models/transformer_tts/cbhg.py @@ -30,6 +30,17 @@ class CBHG(dg.Layer): num_gru_layers=2, max_pool_kernel_size=2, is_post=False): + """CBHG Module + + Args: + hidden_size (int): dimension of hidden unit. + batch_size (int): batch size of input. + K (int, optional): number of convolution banks. Defaults to 16. + projection_size (int, optional): dimension of projection unit. Defaults to 256. + num_gru_layers (int, optional): number of layers of GRUcell. Defaults to 2. + max_pool_kernel_size (int, optional): max pooling kernel size. Defaults to 2 + is_post (bool, optional): whether post processing or not. Defaults to False. + """ super(CBHG, self).__init__() self.hidden_size = hidden_size @@ -169,13 +180,13 @@ class CBHG(dg.Layer): def forward(self, input_): """ - CBHG Module + Convert linear spectrum to Mel spectrum. + Args: - input_(Variable): The sequentially input. - Shape: (B, C, T), dtype: float32. + input_ (Variable): shape(B, C, T), dtype float32, the sequentially input. Returns: - (Variable): the CBHG output. + out (Variable): shape(B, C, T), the CBHG output. """ conv_list = [] @@ -217,6 +228,12 @@ class CBHG(dg.Layer): class Highwaynet(dg.Layer): def __init__(self, num_units, num_layers=4): + """Highway network + + Args: + num_units (int): dimension of hidden unit. + num_layers (int, optional): number of highway layers. Defaults to 4. + """ super(Highwaynet, self).__init__() self.num_units = num_units self.num_layers = num_layers @@ -250,13 +267,13 @@ class Highwaynet(dg.Layer): def forward(self, input_): """ - Highway network - Args: - input_(Variable): The sequentially input. - Shape: (B, T, C), dtype: float32. + Compute result of Highway network. + Args: + input_(Variable): shape(B, T, C), dtype float32, the sequentially input. + Returns: - (Variable): the Highway output. + out(Variable): the Highway output. """ out = input_ diff --git a/parakeet/models/transformer_tts/decoder.py b/parakeet/models/transformer_tts/decoder.py index 54dc679..0b47e4f 100644 --- a/parakeet/models/transformer_tts/decoder.py +++ b/parakeet/models/transformer_tts/decoder.py @@ -23,6 +23,14 @@ from parakeet.models.transformer_tts.post_convnet import PostConvNet class Decoder(dg.Layer): def __init__(self, num_hidden, config, num_head=4, n_layers=3): + """Decoder layer of TransformerTTS. + + Args: + num_hidden (int): the number of source vocabulary. + config: the yaml configs used in decoder. + n_layers (int, optional): the layers number of multihead attention. Defaults to 4. + num_head (int, optional): the head number of multihead attention. Defaults to 3. + """ super(Decoder, self).__init__() self.num_hidden = num_hidden self.num_head = num_head @@ -109,38 +117,26 @@ class Decoder(dg.Layer): m_self_mask=None, zero_mask=None): """ - Decoder layer of TransformerTTS. + Compute decoder outputs. + Args: - key (Variable): The input key of decoder. - Shape: (B, T_text, C), T_text means the timesteps of input text, - dtype: float32. - value (Variable): The . input value of decoder. - Shape: (B, T_text, C), dtype: float32. - query (Variable): The input query of decoder. - Shape: (B, T_mel, C), T_mel means the timesteps of input spectrum, - dtype: float32. - positional (Variable): The spectrum position. - Shape: (B, T_mel), dtype: int64. 
diff --git a/parakeet/models/transformer_tts/decoder.py b/parakeet/models/transformer_tts/decoder.py
index 54dc679..0b47e4f 100644
--- a/parakeet/models/transformer_tts/decoder.py
+++ b/parakeet/models/transformer_tts/decoder.py
@@ -23,6 +23,14 @@ from parakeet.models.transformer_tts.post_convnet import PostConvNet

 class Decoder(dg.Layer):
     def __init__(self, num_hidden, config, num_head=4, n_layers=3):
+        """Decoder layer of TransformerTTS.
+
+        Args:
+            num_hidden (int): the size of the hidden layer in the network.
+            config: the yaml configs used in the decoder.
+            num_head (int, optional): the number of heads in multihead attention. Defaults to 4.
+            n_layers (int, optional): the number of multihead attention layers. Defaults to 3.
+        """
         super(Decoder, self).__init__()
         self.num_hidden = num_hidden
         self.num_head = num_head
@@ -109,38 +117,26 @@ class Decoder(dg.Layer):
                 m_self_mask=None,
                 zero_mask=None):
         """
-        Decoder layer of TransformerTTS.
+        Compute the decoder outputs.
+
         Args:
-            key (Variable): The input key of decoder.
-                Shape: (B, T_text, C), T_text means the timesteps of input text,
-                dtype: float32.
-            value (Variable): The . input value of decoder.
-                Shape: (B, T_text, C), dtype: float32.
-            query (Variable): The input query of decoder.
-                Shape: (B, T_mel, C), T_mel means the timesteps of input spectrum,
-                dtype: float32.
-            positional (Variable): The spectrum position.
-                Shape: (B, T_mel), dtype: int64.
-            mask (Variable): the mask of decoder self attention.
-                Shape: (B, T_mel, T_mel), dtype: int64.
-            m_mask (Variable, optional): the query mask of encoder-decoder attention. Defaults to None.
-                Shape: (B, T_mel, 1), dtype: int64.
-            m_self_mask (Variable, optional): the query mask of decoder self attention. Defaults to None.
-                Shape: (B, T_mel, 1), dtype: int64.
-            zero_mask (Variable, optional): query mask of encoder-decoder attention. Defaults to None.
-                Shape: (B, T_mel, T_text), dtype: int64.
+            key (Variable): shape(B, T_text, C), dtype float32, the input key of the decoder,
+                where T_text is the number of input text timesteps.
+            value (Variable): shape(B, T_text, C), dtype float32, the input value of the decoder.
+            query (Variable): shape(B, T_mel, C), dtype float32, the input query of the decoder,
+                where T_mel is the number of input spectrum timesteps.
+            positional (Variable): shape(B, T_mel), dtype int64, the spectrum position.
+            mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention.
+            m_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of
+                encoder-decoder attention. Defaults to None.
+            m_self_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of
+                decoder self attention. Defaults to None.
+            zero_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, the key mask of
+                encoder-decoder attention. Defaults to None.

         Returns:
-            mel_out (Variable): the decoder output after mel linear projection.
-                Shape: (B, T_mel, C).
-            out (Variable): the decoder output after post mel network.
-                Shape: (B, T_mel, C).
-            stop_tokens (Variable): the stop tokens of output.
-                Shape: (B, T_mel, 1)
-            attn_list (list[Variable]): the encoder-decoder attention list.
-                Len: n_layers.
-            selfattn_list (list[Variable]): the decoder self attention list.
-                Len: n_layers.
+            mel_out (Variable): shape(B, T_mel, C), the decoder output after the mel linear projection.
+            out (Variable): shape(B, T_mel, C), the decoder output after the post mel network.
+            stop_tokens (Variable): shape(B, T_mel, 1), the stop token predictions.
+            attn_list (list[Variable]): len(n_layers), the encoder-decoder attention list.
+            selfattn_list (list[Variable]): len(n_layers), the decoder self attention list.
         """

         # get decoder mask with triangular matrix
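
The "# get decoder mask with triangular matrix" step right after this docstring builds the causal mask that stops each mel frame from attending to future frames. One standard construction in NumPy (Parakeet's 0/1 convention may be inverted relative to this):

    import numpy as np

    T_mel = 5
    # ones strictly above the diagonal mark the disallowed future positions
    future_mask = np.triu(np.ones((T_mel, T_mel), dtype=np.int64), k=1)
    print(future_mask)  # row i may only attend to columns <= i
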
diff --git a/parakeet/models/transformer_tts/encoder.py b/parakeet/models/transformer_tts/encoder.py
index 748b423..48fc6c1 100644
--- a/parakeet/models/transformer_tts/encoder.py
+++ b/parakeet/models/transformer_tts/encoder.py
@@ -21,6 +21,14 @@ from parakeet.models.transformer_tts.encoderprenet import EncoderPrenet

 class Encoder(dg.Layer):
     def __init__(self, embedding_size, num_hidden, num_head=4, n_layers=3):
+        """Encoder layer of TransformerTTS.
+
+        Args:
+            embedding_size (int): the size of the character embedding.
+            num_hidden (int): the size of the hidden layer in the network.
+            num_head (int, optional): the number of heads in multihead attention. Defaults to 4.
+            n_layers (int, optional): the number of multihead attention layers. Defaults to 3.
+        """
         super(Encoder, self).__init__()
         self.num_hidden = num_hidden
         self.num_head = num_head
@@ -58,23 +66,18 @@ class Encoder(dg.Layer):

     def forward(self, x, positional, mask=None, query_mask=None):
         """
-        Encoder layer of TransformerTTS.
+        Encode the text sequence.
+
         Args:
-            x (Variable): The input character.
-                Shape: (B, T_text), T_text means the timesteps of input text,
-                dtype: float32.
-            positional (Variable): The characters position.
-                Shape: (B, T_text), dtype: int64.
-            mask (Variable, optional): the mask of encoder self attention. Defaults to None.
-                Shape: (B, T_text, T_text), dtype: int64.
-            query_mask (Variable, optional): the query mask of encoder self attention. Defaults to None.
-                Shape: (B, T_text, 1), dtype: int64.
+            x (Variable): shape(B, T_text), dtype float32, the input characters,
+                where T_text is the number of input text timesteps.
+            positional (Variable): shape(B, T_text), dtype int64, the position of the characters.
+            mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder
+                self attention. Defaults to None.
+            query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of
+                encoder self attention. Defaults to None.

         Returns:
-            x (Variable): the encoder output.
-                Shape: (B, T_text, C).
-            attentions (list[Variable]): the encoder self attention list.
-                Len: n_layers.
+            x (Variable): shape(B, T_text, C), the encoder output.
+            attentions (list[Variable]): len(n_layers), the encoder self attention list.
         """
         if fluid.framework._dygraph_tracer()._train_mode:
diff --git a/parakeet/models/transformer_tts/encoderprenet.py b/parakeet/models/transformer_tts/encoderprenet.py
index 0e4cd45..d692cea 100644
--- a/parakeet/models/transformer_tts/encoderprenet.py
+++ b/parakeet/models/transformer_tts/encoderprenet.py
@@ -22,6 +22,13 @@ import numpy as np

 class EncoderPrenet(dg.Layer):
     def __init__(self, embedding_size, num_hidden, use_cudnn=True):
+        """Encoder prenet layer of TransformerTTS.
+
+        Args:
+            embedding_size (int): the size of the embedding.
+            num_hidden (int): the size of the hidden layer in the network.
+            use_cudnn (bool, optional): whether to use cudnn. Defaults to True.
+        """
         super(EncoderPrenet, self).__init__()
         self.embedding_size = embedding_size
         self.num_hidden = num_hidden
@@ -82,14 +89,13 @@ class EncoderPrenet(dg.Layer):

     def forward(self, x):
         """
-        Encoder prenet layer of TransformerTTS.
+        Prepare the encoder input.
+
         Args:
-            x (Variable): The input character.
-                Shape: (B, T_text), T_text means the timesteps of input text,
-                dtype: float32.
+            x (Variable): shape(B, T_text), dtype float32, the input characters,
+                where T_text is the number of input text timesteps.

         Returns:
-            (Variable): the encoder prenet output. Shape: (B, T_text, C).
+            (Variable): shape(B, T_text, C), the encoder prenet output.
         """

         x = self.embedding(x)
+ """ super(PostConvNet, self).__init__() self.dropout = dropout @@ -93,13 +106,13 @@ class PostConvNet(dg.Layer): def forward(self, input): """ - Decocder Post Conv Net of TransformerTTS. + Compute the mel spectrum. Args: - input (Variable): The result of mel linear projection. - Shape: (B, T, C), dtype: float32. + input (Variable): shape(B, T, C), dtype float32, the result of mel linear projection. + Returns: - (Variable): the result after postconvnet. Shape: (B, T, C), + output (Variable): shape(B, T, C), the result after postconvnet. """ input = layers.transpose(input, [0, 2, 1]) diff --git a/parakeet/models/transformer_tts/prenet.py b/parakeet/models/transformer_tts/prenet.py index c6d2428..b033860 100644 --- a/parakeet/models/transformer_tts/prenet.py +++ b/parakeet/models/transformer_tts/prenet.py @@ -19,6 +19,14 @@ import paddle.fluid.layers as layers class PreNet(dg.Layer): def __init__(self, input_size, hidden_size, output_size, dropout_rate=0.2): + """Prenet before passing through the network. + + Args: + input_size (int): the input channel size. + hidden_size (int): the size of hidden layer in network. + output_size (int): the output channel size. + dropout_rate (float, optional): dropout probability. Defaults to 0.2. + """ super(PreNet, self).__init__() self.input_size = input_size self.hidden_size = hidden_size @@ -44,20 +52,20 @@ class PreNet(dg.Layer): def forward(self, x): """ - Pre Net before passing through the network. + Prepare network input. Args: - x (Variable): The input value. - Shape: (B, T, C), dtype: float32. + x (Variable): shape(B, T, C), dtype float32, the input value. + Returns: - (Variable), the result after pernet. Shape: (B, T, C), + output (Variable): shape(B, T, C), the result after pernet. """ x = layers.dropout( layers.relu(self.linear1(x)), self.dropout_rate, dropout_implementation='upscale_in_train') - x = layers.dropout( + output = layers.dropout( layers.relu(self.linear2(x)), self.dropout_rate, dropout_implementation='upscale_in_train') - return x + return output diff --git a/parakeet/models/transformer_tts/transformer_tts.py b/parakeet/models/transformer_tts/transformer_tts.py index ca605bf..75b283c 100644 --- a/parakeet/models/transformer_tts/transformer_tts.py +++ b/parakeet/models/transformer_tts/transformer_tts.py @@ -19,6 +19,11 @@ from parakeet.models.transformer_tts.decoder import Decoder class TransformerTTS(dg.Layer): def __init__(self, config): + """TransformerTTS model. + + Args: + config: the yaml configs used in TransformerTTS model. + """ super(TransformerTTS, self).__init__() self.encoder = Encoder(config['embedding_size'], config['hidden_size']) self.decoder = Decoder(config['hidden_size'], config) @@ -37,43 +42,28 @@ class TransformerTTS(dg.Layer): dec_query_mask=None): """ TransformerTTS network. + Args: - characters (Variable): The input character. - Shape: (B, T_text), T_text means the timesteps of input text, - dtype: float32. - mel_input (Variable): The input query of decoder. - Shape: (B, T_mel, C), T_mel means the timesteps of input spectrum, - dtype: float32. - pos_text (Variable): The characters position. - Shape: (B, T_text), dtype: int64. - dec_slf_mask (Variable): The spectrum position. - Shape: (B, T_mel), dtype: int64. - mask (Variable): the mask of decoder self attention. - Shape: (B, T_mel, T_mel), dtype: int64. - enc_slf_mask (Variable, optional): the mask of encoder self attention. Defaults to None. - Shape: (B, T_text, T_text), dtype: int64. 
diff --git a/parakeet/models/transformer_tts/transformer_tts.py b/parakeet/models/transformer_tts/transformer_tts.py
index ca605bf..75b283c 100644
--- a/parakeet/models/transformer_tts/transformer_tts.py
+++ b/parakeet/models/transformer_tts/transformer_tts.py
@@ -19,6 +19,11 @@ from parakeet.models.transformer_tts.decoder import Decoder

 class TransformerTTS(dg.Layer):
     def __init__(self, config):
+        """TransformerTTS model.
+
+        Args:
+            config: the yaml configs used in the TransformerTTS model.
+        """
         super(TransformerTTS, self).__init__()
         self.encoder = Encoder(config['embedding_size'], config['hidden_size'])
         self.decoder = Decoder(config['hidden_size'], config)
@@ -37,43 +42,28 @@ class TransformerTTS(dg.Layer):
                 dec_query_mask=None):
         """
         TransformerTTS network.
+
         Args:
-            characters (Variable): The input character.
-                Shape: (B, T_text), T_text means the timesteps of input text,
-                dtype: float32.
-            mel_input (Variable): The input query of decoder.
-                Shape: (B, T_mel, C), T_mel means the timesteps of input spectrum,
-                dtype: float32.
-            pos_text (Variable): The characters position.
-                Shape: (B, T_text), dtype: int64.
-            dec_slf_mask (Variable): The spectrum position.
-                Shape: (B, T_mel), dtype: int64.
-            mask (Variable): the mask of decoder self attention.
-                Shape: (B, T_mel, T_mel), dtype: int64.
-            enc_slf_mask (Variable, optional): the mask of encoder self attention. Defaults to None.
-                Shape: (B, T_text, T_text), dtype: int64.
-            enc_query_mask (Variable, optional): the query mask of encoder self attention. Defaults to None.
-                Shape: (B, T_text, 1), dtype: int64.
-            dec_query_mask (Variable, optional): the query mask of encoder-decoder attention. Defaults to None.
-                Shape: (B, T_mel, 1), dtype: int64.
-            dec_query_slf_mask (Variable, optional): the query mask of decoder self attention. Defaults to None.
-                Shape: (B, T_mel, 1), dtype: int64.
-            enc_dec_mask (Variable, optional): query mask of encoder-decoder attention. Defaults to None.
-                Shape: (B, T_mel, T_text), dtype: int64.
+            characters (Variable): shape(B, T_text), dtype float32, the input characters,
+                where T_text is the number of input text timesteps.
+            mel_input (Variable): shape(B, T_mel, C), dtype float32, the input query of the decoder,
+                where T_mel is the number of input spectrum timesteps.
+            pos_text (Variable): shape(B, T_text), dtype int64, the position of the characters.
+            dec_slf_mask (Variable): shape(B, T_mel), dtype int64, the spectrum position.
+            mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention.
+            enc_slf_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of
+                encoder self attention. Defaults to None.
+            enc_query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of
+                encoder self attention. Defaults to None.
+            dec_query_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of
+                encoder-decoder attention. Defaults to None.
+            dec_query_slf_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of
+                decoder self attention. Defaults to None.
+            enc_dec_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, the key mask of
+                encoder-decoder attention. Defaults to None.

         Returns:
-            mel_output (Variable): the decoder output after mel linear projection.
-                Shape: (B, T_mel, C).
-            postnet_output (Variable): the decoder output after post mel network.
-                Shape: (B, T_mel, C).
-            stop_preds (Variable): the stop tokens of output.
-                Shape: (B, T_mel, 1)
-            attn_probs (list[Variable]): the encoder-decoder attention list.
-                Len: n_layers.
-            attns_enc (list[Variable]): the encoder self attention list.
-                Len: n_layers.
-            attns_dec (list[Variable]): the decoder self attention list.
-                Len: n_layers.
+            mel_output (Variable): shape(B, T_mel, C), the decoder output after the mel linear projection.
+            postnet_output (Variable): shape(B, T_mel, C), the decoder output after the post mel network.
+            stop_preds (Variable): shape(B, T_mel, 1), the stop token predictions.
+            attn_probs (list[Variable]): len(n_layers), the encoder-decoder attention list.
+            attns_enc (list[Variable]): len(n_layers), the encoder self attention list.
+            attns_dec (list[Variable]): len(n_layers), the decoder self attention list.
         """
         key, attns_enc = self.encoder(
             characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask)
+ """ super(Vocoder, self).__init__() self.pre_proj = Conv1D( num_channels=config['audio']['num_mels'], @@ -33,14 +39,13 @@ class Vocoder(dg.Layer): def forward(self, mel): """ - CBHG Network (mel -> linear) + Compute mel spectrum to linear spectrum. + Args: - mel (Variable): The input mel spectrum. - Shape: (B, C, T), dtype: float32. + mel (Variable): shape(B, C, T), dtype float32, the input mel spectrum. Returns: - (Variable): the linear output. - Shape: (B, T, C). + mag_pred (Variable): shape(B, T, C), the linear output. """ mel = layers.transpose(mel, [0, 2, 1]) mel = self.pre_proj(mel) diff --git a/parakeet/modules/dynamic_gru.py b/parakeet/modules/dynamic_gru.py index 19fa060..b944b92 100644 --- a/parakeet/modules/dynamic_gru.py +++ b/parakeet/modules/dynamic_gru.py @@ -43,10 +43,10 @@ class DynamicGRU(dg.Layer): Dynamic GRU block. Args: - input (Variable): The input value. - Shape: (B, T, C), dtype: float32. + input (Variable): shape(B, T, C), dtype float32, the input value. + Returns: - output (Variable), the result compute by GRU. Shape: (B, T, C). + output (Variable): shape(B, T, C), the result compute by GRU. """ hidden = self.h_0 res = [] diff --git a/parakeet/modules/ffn.py b/parakeet/modules/ffn.py index 7aa7f4a..199f668 100644 --- a/parakeet/modules/ffn.py +++ b/parakeet/modules/ffn.py @@ -19,8 +19,6 @@ from parakeet.modules.customized import Conv1D class PositionwiseFeedForward(dg.Layer): - ''' A two-feed-forward-layer module ''' - def __init__(self, d_in, num_hidden, @@ -28,6 +26,16 @@ class PositionwiseFeedForward(dg.Layer): padding=0, use_cudnn=True, dropout=0.1): + """A two-feed-forward-layer module. + + Args: + d_in (int): the size of input channel. + num_hidden (int): the size of hidden layer in network. + filter_size (int): the filter size of Conv + padding (int, optional): the padding size of Conv. Defaults to 0. + use_cudnn (bool, optional): use cudnn in Conv or not. Defaults to True. + dropout (float, optional): dropout probability. Defaults to 0.1. + """ super(PositionwiseFeedForward, self).__init__() self.num_hidden = num_hidden self.use_cudnn = use_cudnn @@ -59,13 +67,13 @@ class PositionwiseFeedForward(dg.Layer): def forward(self, input): """ - Feed Forward Network. + Compute feed forward network result. Args: - input (Variable): The input value. - Shape: (B, T, C), dtype: float32. + input (Variable): shape(B, T, C), dtype float32, the input value. + Returns: - output (Variable), the result after FFN. Shape: (B, T, C). + output (Variable): shape(B, T, C), the result after FFN. """ x = layers.transpose(input, [0, 2, 1]) #FFN Networt diff --git a/parakeet/modules/multihead_attention.py b/parakeet/modules/multihead_attention.py index 04cf947..c6907e8 100644 --- a/parakeet/modules/multihead_attention.py +++ b/parakeet/modules/multihead_attention.py @@ -50,6 +50,11 @@ class Linear(dg.Layer): class ScaledDotProductAttention(dg.Layer): def __init__(self, d_key): + """Scaled dot product attention module. + + Args: + d_key (int): the dim of key in multihead attention. + """ super(ScaledDotProductAttention, self).__init__() self.d_key = d_key @@ -63,23 +68,18 @@ class ScaledDotProductAttention(dg.Layer): query_mask=None, dropout=0.1): """ - Scaled Dot Product Attention. + Compute scaled dot product attention. Args: - key (Variable): The input key of scaled dot product attention. - Shape: (B, T, C), dtype: float32. - value (Variable): The input value of scaled dot product attention. - Shape: (B, T, C), dtype: float32. 
diff --git a/parakeet/modules/multihead_attention.py b/parakeet/modules/multihead_attention.py
index 04cf947..c6907e8 100644
--- a/parakeet/modules/multihead_attention.py
+++ b/parakeet/modules/multihead_attention.py
@@ -50,6 +50,11 @@ class Linear(dg.Layer):

 class ScaledDotProductAttention(dg.Layer):
     def __init__(self, d_key):
+        """Scaled dot product attention module.
+
+        Args:
+            d_key (int): the dimension of keys in multihead attention.
+        """
         super(ScaledDotProductAttention, self).__init__()

         self.d_key = d_key
@@ -63,23 +68,18 @@ class ScaledDotProductAttention(dg.Layer):
                 query_mask=None,
                 dropout=0.1):
         """
-        Scaled Dot Product Attention.
+        Compute scaled dot product attention.
+
         Args:
-            key (Variable): The input key of scaled dot product attention.
-                Shape: (B, T, C), dtype: float32.
-            value (Variable): The input value of scaled dot product attention.
-                Shape: (B, T, C), dtype: float32.
-            query (Variable): The input query of scaled dot product attention.
-                Shape: (B, T, C), dtype: float32.
-            mask (Variable, optional): The mask of key. Defaults to None.
-                Shape(B, T_q, T_k), dtype: float32.
-            query_mask (Variable, optional): The mask of query. Defaults to None.
-                Shape(B, T_q, T_q), dtype: float32.
-            dropout (float32, optional): The probability of dropout. Defaults to 0.1.
+            key (Variable): shape(B, T, C), dtype float32, the input key of scaled dot product attention.
+            value (Variable): shape(B, T, C), dtype float32, the input value of scaled dot product attention.
+            query (Variable): shape(B, T, C), dtype float32, the input query of scaled dot product attention.
+            mask (Variable, optional): shape(B, T_q, T_k), dtype float32, the mask of the key. Defaults to None.
+            query_mask (Variable, optional): shape(B, T_q, T_q), dtype float32, the mask of the query. Defaults to None.
+            dropout (float32, optional): the probability of dropout. Defaults to 0.1.

         Returns:
-            result (Variable), Shape(B, T, C), the result of mutihead attention.
-            attention (Variable), Shape(n_head * B, T, C), the attention of key.
+            result (Variable): shape(B, T, C), the result of scaled dot product attention.
+            attention (Variable): shape(n_head * B, T, C), the attention of the key.
         """
         # Compute attention score
         attention = layers.matmul(
@@ -110,6 +110,17 @@ class MultiheadAttention(dg.Layer):
                  is_bias=False,
                  dropout=0.1,
                  is_concat=True):
+        """Multihead attention.
+
+        Args:
+            num_hidden (int): the size of the hidden layer in the network.
+            d_k (int): the dimension of keys in multihead attention.
+            d_q (int): the dimension of queries in multihead attention.
+            num_head (int, optional): the number of attention heads. Defaults to 4.
+            is_bias (bool, optional): whether the linear layers have bias. Defaults to False.
+            dropout (float, optional): dropout probability. Defaults to 0.1.
+            is_concat (bool, optional): whether to concatenate the query and the attention result. Defaults to True.
+        """
         super(MultiheadAttention, self).__init__()
         self.num_hidden = num_hidden
         self.num_head = num_head
@@ -133,22 +144,18 @@ class MultiheadAttention(dg.Layer):

     def forward(self, key, value, query_input, mask=None, query_mask=None):
         """
-        Multihead Attention.
+        Compute multihead attention.
+
         Args:
-            key (Variable): The input key of attention.
-                Shape: (B, T, C), dtype: float32.
-            value (Variable): The input value of attention.
-                Shape: (B, T, C), dtype: float32.
-            query_input (Variable): The input query of attention.
-                Shape: (B, T, C), dtype: float32.
-            mask (Variable, optional): The mask of key. Defaults to None.
-                Shape: (B, T_query, T_key), dtype: float32.
-            query_mask (Variable, optional): The mask of query. Defaults to None.
-                Shape: (B, T_query, T_key), dtype: float32.
+            key (Variable): shape(B, T, C), dtype float32, the input key of attention.
+            value (Variable): shape(B, T, C), dtype float32, the input value of attention.
+            query_input (Variable): shape(B, T, C), dtype float32, the input query of attention.
+            mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of the key.
+                Defaults to None.
+            query_mask (Variable, optional): shape(B, T_query, T_query), dtype float32, the mask of
+                the query. Defaults to None.

         Returns:
-            result (Variable), the result of mutihead attention. Shape: (B, T, C).
-            attention (Variable), the attention of key and query. Shape: (num_head * B, T, C)
+            result (Variable): shape(B, T, C), the result of multihead attention.
+            attention (Variable): shape(num_head * B, T, C), the attention of the key and query.
         """
         batch_size = key.shape[0]
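
For reference, the computation ScaledDotProductAttention implements, written out in NumPy (simplified: Parakeet additionally applies dropout and the query mask, and folds the heads into the batch dimension):

    import numpy as np

    def scaled_dot_product_attention(query, key, value, d_key, mask=None):
        scores = query @ key.transpose(0, 2, 1) / np.sqrt(d_key)  # (B, T_q, T_k)
        if mask is not None:
            scores = np.where(mask.astype(bool), -1e9, scores)    # nonzero mask -> blocked
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
        weights /= weights.sum(axis=-1, keepdims=True)            # softmax over keys
        return weights @ value, weights

    B, T, C = 2, 4, 8
    q, k, v = (np.random.randn(B, T, C) for _ in range(3))
    out, attn = scaled_dot_product_attention(q, k, v, d_key=C)
    assert out.shape == (B, T, C) and np.allclose(attn.sum(-1), 1.0)
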