diff --git a/examples/fastspeech/synthesis.sh b/examples/fastspeech/synthesis.sh index 32562d7..4daef57 100644 --- a/examples/fastspeech/synthesis.sh +++ b/examples/fastspeech/synthesis.sh @@ -6,6 +6,7 @@ python -u synthesis.py \ --checkpoint_path='checkpoint/' \ --fastspeech_step=71000 \ --log_dir='./log' \ +--config_path='configs/synthesis.yaml' \ if [ $? -ne 0 ]; then echo "Failed in synthesis!" diff --git a/examples/fastspeech/train.sh b/examples/fastspeech/train.sh index 11e78c4..2301ab3 100644 --- a/examples/fastspeech/train.sh +++ b/examples/fastspeech/train.sh @@ -13,7 +13,7 @@ python -u train.py \ --transformer_step=160000 \ --save_path='./checkpoint' \ --log_dir='./log' \ ---config_path='config/fastspeech.yaml' \ +--config_path='configs/fastspeech.yaml' \ #--checkpoint_path='./checkpoint' \ #--fastspeech_step=97000 \ diff --git a/examples/transformer_tts/synthesis.py b/examples/transformer_tts/synthesis.py index 2896634..0f7b83d 100644 --- a/examples/transformer_tts/synthesis.py +++ b/examples/transformer_tts/synthesis.py @@ -84,7 +84,7 @@ def synthesis(text_input, args): dec_slf_mask = get_triu_tensor( mel_input.numpy(), mel_input.numpy()).astype(np.float32) dec_slf_mask = fluid.layers.cast( - dg.to_variable(dec_slf_mask == 0), np.float32) + dg.to_variable(dec_slf_mask != 0), np.float32) * (-2**32 + 1) pos_mel = np.arange(1, mel_input.shape[1] + 1) pos_mel = fluid.layers.unsqueeze(dg.to_variable(pos_mel), [0]) mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model( @@ -157,6 +157,5 @@ if __name__ == '__main__': parser = argparse.ArgumentParser(description="Synthesis model") add_config_options_to_parser(parser) args = parser.parse_args() - synthesis( - "They emphasized the necessity that the information now being furnished be handled with judgment and care.", - args) + synthesis("Parakeet stands for Paddle PARAllel text-to-speech toolkit.", + args) diff --git a/examples/transformer_tts/synthesis.sh b/examples/transformer_tts/synthesis.sh index 42b704d..65dff46 100644 --- a/examples/transformer_tts/synthesis.sh +++ b/examples/transformer_tts/synthesis.sh @@ -2,14 +2,14 @@ # train model CUDA_VISIBLE_DEVICES=0 \ python -u synthesis.py \ ---max_len=600 \ ---transformer_step=160000 \ ---vocoder_step=90000 \ +--max_len=300 \ +--transformer_step=120000 \ +--vocoder_step=100000 \ --use_gpu=1 \ --checkpoint_path='./checkpoint' \ --log_dir='./log' \ --sample_path='./sample' \ ---config_path='config/synthesis.yaml' \ +--config_path='configs/synthesis.yaml' \ if [ $? -ne 0 ]; then echo "Failed in training!" 
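The dec_slf_mask change above switches synthesis from a boolean mask (dec_slf_mask == 0) to an additive one: positions a decoder step must not attend to receive a large negative bias (-2**32 + 1) that is added to the attention logits before softmax. A minimal numpy sketch of that masking scheme, with an illustrative helper standing in for get_triu_tensor (names and shapes here are for illustration only, not the toolkit's API):

import numpy as np

def future_mask(T):
    # Strictly upper-triangular matrix: entry (i, j) is 1 when j > i, i.e. when
    # decoder step i would look at a future frame j.
    return np.triu(np.ones((T, T), dtype=np.float32), k=1)

T = 4
mask = future_mask(T)

# Additive form: masked positions get a very large negative bias, so softmax
# drives their attention weights to (effectively) zero.
additive_mask = (mask != 0).astype(np.float32) * (-2**32 + 1)

scores = np.random.rand(T, T).astype(np.float32)  # toy attention logits
scores = scores + additive_mask

weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

# Each row sums to 1 and carries (near-)zero weight on future positions.
print(np.round(weights, 3))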
diff --git a/examples/transformer_tts/train_transformer.sh b/examples/transformer_tts/train_transformer.sh index 346d351..f5e47ee 100644 --- a/examples/transformer_tts/train_transformer.sh +++ b/examples/transformer_tts/train_transformer.sh @@ -14,7 +14,7 @@ python -u train_transformer.py \ --data_path='../../dataset/LJSpeech-1.1' \ --save_path='./checkpoint' \ --log_dir='./log' \ ---config_path='config/train_transformer.yaml' \ +--config_path='configs/train_transformer.yaml' \ #--checkpoint_path='./checkpoint' \ #--transformer_step=160000 \ diff --git a/examples/transformer_tts/train_vocoder.sh b/examples/transformer_tts/train_vocoder.sh index e453c83..ea57ebd 100644 --- a/examples/transformer_tts/train_vocoder.sh +++ b/examples/transformer_tts/train_vocoder.sh @@ -12,7 +12,7 @@ python -u train_vocoder.py \ --data_path='../../dataset/LJSpeech-1.1' \ --save_path='./checkpoint' \ --log_dir='./log' \ ---config_path='config/train_vocoder.yaml' \ +--config_path='configs/train_vocoder.yaml' \ #--checkpoint_path='./checkpoint' \ #--vocoder_step=27000 \ diff --git a/parakeet/models/fastspeech/decoder.py b/parakeet/models/fastspeech/decoder.py index 8432fc5..30432d0 100644 --- a/parakeet/models/fastspeech/decoder.py +++ b/parakeet/models/fastspeech/decoder.py @@ -23,12 +23,26 @@ class Decoder(dg.Layer): n_layers, n_head, d_k, - d_v, + d_q, d_model, d_inner, fft_conv1d_kernel, fft_conv1d_padding, dropout=0.1): + """Decoder layer of FastSpeech. + + Args: + len_max_seq (int): the max mel len of sequence. + n_layers (int): the layers number of FFTBlock. + n_head (int): the head number of multihead attention. + d_k (int): the dim of key in multihead attention. + d_q (int): the dim of query in multihead attention. + d_model (int): the dim of hidden layer in multihead attention. + d_inner (int): the dim of hidden layer in ffn. + fft_conv1d_kernel (int): the conv kernel size in FFTBlock. + fft_conv1d_padding (int): the conv padding size in FFTBlock. + dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1. + """ super(Decoder, self).__init__() n_position = len_max_seq + 1 @@ -48,7 +62,7 @@ class Decoder(dg.Layer): d_inner, n_head, d_k, - d_v, + d_q, fft_conv1d_kernel, fft_conv1d_padding, dropout=dropout) for _ in range(n_layers) @@ -58,16 +72,20 @@ class Decoder(dg.Layer): def forward(self, enc_seq, enc_pos, non_pad_mask, slf_attn_mask=None): """ - Decoder layer of FastSpeech. + Compute decoder outputs. Args: - enc_seq (Variable), Shape(B, text_T, C), dtype: float32. - The output of length regulator. - enc_pos (Variable, optional): Shape(B, T_mel), - dtype: int64. The spectrum position. T_mel means the timesteps of input spectrum. + enc_seq (Variable): shape(B, T_text, C), dtype float32, + the output of length regulator, where T_text means the timesteps of input text, + enc_pos (Variable): shape(B, T_mel), dtype int64, + the spectrum position, where T_mel means the timesteps of input spectrum, + non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad. + slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64, + the mask of mel spectrum. Defaults to None. + Returns: - dec_output (Variable), Shape(B, mel_T, C), the decoder output. - dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list. + dec_output (Variable): shape(B, T_mel, C), the decoder output. + dec_slf_attn_list (list[Variable]): len(n_layers), the decoder self attention list. 
""" dec_slf_attn_list = [] slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1]) diff --git a/parakeet/models/fastspeech/encoder.py b/parakeet/models/fastspeech/encoder.py index 15d634e..d39fdc1 100644 --- a/parakeet/models/fastspeech/encoder.py +++ b/parakeet/models/fastspeech/encoder.py @@ -24,12 +24,27 @@ class Encoder(dg.Layer): n_layers, n_head, d_k, - d_v, + d_q, d_model, d_inner, fft_conv1d_kernel, fft_conv1d_padding, dropout=0.1): + """Encoder layer of FastSpeech. + + Args: + n_src_vocab (int): the number of source vocabulary. + len_max_seq (int): the max mel len of sequence. + n_layers (int): the layers number of FFTBlock. + n_head (int): the head number of multihead attention. + d_k (int): the dim of key in multihead attention. + d_q (int): the dim of query in multihead attention. + d_model (int): the dim of hidden layer in multihead attention. + d_inner (int): the dim of hidden layer in ffn. + fft_conv1d_kernel (int): the conv kernel size in FFTBlock. + fft_conv1d_padding (int): the conv padding size in FFTBlock. + dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1. + """ super(Encoder, self).__init__() n_position = len_max_seq + 1 self.n_head = n_head @@ -53,7 +68,7 @@ class Encoder(dg.Layer): d_inner, n_head, d_k, - d_v, + d_q, fft_conv1d_kernel, fft_conv1d_padding, dropout=dropout) for _ in range(n_layers) @@ -63,18 +78,20 @@ class Encoder(dg.Layer): def forward(self, character, text_pos, non_pad_mask, slf_attn_mask=None): """ - Encoder layer of FastSpeech. - - Args: - character (Variable): Shape(B, T_text), dtype: float32. The input text - characters. T_text means the timesteps of input characters. - text_pos (Variable): Shape(B, T_text), dtype: int64. The input text - position. T_text means the timesteps of input characters. + Encode text sequence. + Args: + character (Variable): shape(B, T_text), dtype float32, the input text characters, + where T_text means the timesteps of input characters, + text_pos (Variable): shape(B, T_text), dtype int64, the input text position. + non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad. + slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, + the mask of input characters. Defaults to None. + Returns: - enc_output (Variable), Shape(B, text_T, C), the encoder output. - non_pad_mask (Variable), Shape(B, T_text, 1), the mask with non pad. - enc_slf_attn_list (list), Len(n_layers), Shape(B * n_head, text_T, text_T), the encoder self attention list. + enc_output (Variable): shape(B, T_text, C), the encoder output. + non_pad_mask (Variable): shape(B, T_text, 1), the mask with non pad. + enc_slf_attn_list (list[Variable]): len(n_layers), the encoder self attention list. """ enc_slf_attn_list = [] slf_attn_mask = layers.expand(slf_attn_mask, [self.n_head, 1, 1]) diff --git a/parakeet/models/fastspeech/fastspeech.py b/parakeet/models/fastspeech/fastspeech.py index a37d5fa..5ee2de1 100644 --- a/parakeet/models/fastspeech/fastspeech.py +++ b/parakeet/models/fastspeech/fastspeech.py @@ -25,7 +25,11 @@ from parakeet.models.fastspeech.decoder import Decoder class FastSpeech(dg.Layer): def __init__(self, cfg): - " FastSpeech" + """FastSpeech model. + + Args: + cfg: the yaml configs used in FastSpeech model. 
+ """ super(FastSpeech, self).__init__() self.encoder = Encoder( @@ -34,7 +38,7 @@ class FastSpeech(dg.Layer): n_layers=cfg['encoder_n_layer'], n_head=cfg['encoder_head'], d_k=cfg['fs_hidden_size'] // cfg['encoder_head'], - d_v=cfg['fs_hidden_size'] // cfg['encoder_head'], + d_q=cfg['fs_hidden_size'] // cfg['encoder_head'], d_model=cfg['fs_hidden_size'], d_inner=cfg['encoder_conv1d_filter_size'], fft_conv1d_kernel=cfg['fft_conv1d_filter'], @@ -50,7 +54,7 @@ class FastSpeech(dg.Layer): n_layers=cfg['decoder_n_layer'], n_head=cfg['decoder_head'], d_k=cfg['fs_hidden_size'] // cfg['decoder_head'], - d_v=cfg['fs_hidden_size'] // cfg['decoder_head'], + d_q=cfg['fs_hidden_size'] // cfg['decoder_head'], d_model=cfg['fs_hidden_size'], d_inner=cfg['decoder_conv1d_filter_size'], fft_conv1d_kernel=cfg['fft_conv1d_filter'], @@ -82,34 +86,37 @@ class FastSpeech(dg.Layer): text_pos, enc_non_pad_mask, dec_non_pad_mask, + mel_pos=None, enc_slf_attn_mask=None, dec_slf_attn_mask=None, - mel_pos=None, length_target=None, alpha=1.0): """ - FastSpeech model. + Compute mel output from text character. Args: - character (Variable): Shape(B, T_text), dtype: float32. The input text - characters. T_text means the timesteps of input characters. - text_pos (Variable): Shape(B, T_text), dtype: int64. The input text - position. T_text means the timesteps of input characters. - mel_pos (Variable, optional): Shape(B, T_mel), - dtype: int64. The spectrum position. T_mel means the timesteps of input spectrum. - length_target (Variable, optional): Shape(B, T_text), - dtype: int64. The duration of phoneme compute from pretrained transformerTTS. - alpha (Constant): - dtype: float32. The hyperparameter to determine the length of the expanded sequence - mel, thereby controlling the voice speed. + character (Variable): shape(B, T_text), dtype float32, the input text characters, + where T_text means the timesteps of input characters, + text_pos (Variable): shape(B, T_text), dtype int64, the input text position. + mel_pos (Variable, optional): shape(B, T_mel), dtype int64, the spectrum position, + where T_mel means the timesteps of input spectrum, + enc_non_pad_mask (Variable): shape(B, T_text, 1), dtype int64, the mask with non pad. + dec_non_pad_mask (Variable): shape(B, T_mel, 1), dtype int64, the mask with non pad. + enc_slf_attn_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, + the mask of input characters. Defaults to None. + slf_attn_mask (Variable, optional): shape(B, T_mel, T_mel), dtype int64, + the mask of mel spectrum. Defaults to None. + length_target (Variable, optional): shape(B, T_text), dtype int64, + the duration of phoneme compute from pretrained transformerTTS. Defaults to None. + alpha (float32, optional): The hyperparameter to determine the length of the expanded sequence + mel, thereby controlling the voice speed. Defaults to 1.0. Returns: - mel_output (Variable), Shape(B, mel_T, C), the mel output before postnet. - mel_output_postnet (Variable), Shape(B, mel_T, C), the mel output after postnet. - duration_predictor_output (Variable), Shape(B, text_T), the duration of phoneme compute - with duration predictor. - enc_slf_attn_list (Variable), Shape(B, text_T, text_T), the encoder self attention list. - dec_slf_attn_list (Variable), Shape(B, mel_T, mel_T), the decoder self attention list. + mel_output (Variable): shape(B, T_mel, C), the mel output before postnet. + mel_output_postnet (Variable): shape(B, T_mel, C), the mel output after postnet. 
+ duration_predictor_output (Variable): shape(B, T_text), the duration of phoneme compute with duration predictor. + enc_slf_attn_list (List[Variable]): len(enc_n_layers), the encoder self attention list. + dec_slf_attn_list (List[Variable]): len(dec_n_layers), the decoder self attention list. """ encoder_output, enc_slf_attn_list = self.encoder( @@ -118,7 +125,6 @@ class FastSpeech(dg.Layer): enc_non_pad_mask, slf_attn_mask=enc_slf_attn_mask) if fluid.framework._dygraph_tracer()._train_mode: - length_regulator_output, duration_predictor_output = self.length_regulator( encoder_output, target=length_target, alpha=alpha) decoder_output, dec_slf_attn_list = self.decoder( diff --git a/parakeet/models/fastspeech/fft_block.py b/parakeet/models/fastspeech/fft_block.py index 0c0ed4f..b3c69ea 100644 --- a/parakeet/models/fastspeech/fft_block.py +++ b/parakeet/models/fastspeech/fft_block.py @@ -26,15 +26,27 @@ class FFTBlock(dg.Layer): d_inner, n_head, d_k, - d_v, + d_q, filter_size, padding, dropout=0.2): + """Feed forward structure based on self-attention. + + Args: + d_model (int): the dim of hidden layer in multihead attention. + d_inner (int): the dim of hidden layer in ffn. + n_head (int): the head number of multihead attention. + d_k (int): the dim of key in multihead attention. + d_q (int): the dim of query in multihead attention. + filter_size (int): the conv kernel size. + padding (int): the conv padding size. + dropout (float, optional): dropout probability. Defaults to 0.2. + """ super(FFTBlock, self).__init__() self.slf_attn = MultiheadAttention( d_model, d_k, - d_v, + d_q, num_head=n_head, is_bias=True, dropout=dropout, @@ -48,18 +60,18 @@ class FFTBlock(dg.Layer): def forward(self, enc_input, non_pad_mask, slf_attn_mask=None): """ - Feed Forward Transformer block in FastSpeech. + Feed forward block of FastSpeech Args: - enc_input (Variable): Shape(B, T, C), dtype: float32. The embedding characters input. - T means the timesteps of input. - non_pad_mask (Variable): Shape(B, T, 1), dtype: int64. The mask of sequence. - slf_attn_mask (Variable): Shape(B, len_q, len_k), dtype: int64. The mask of self attention. - len_q means the sequence length of query, len_k means the sequence length of key. - + enc_input (Variable): shape(B, T, C), dtype float32, the embedding characters input, + where T means the timesteps of input. + non_pad_mask (Variable): shape(B, T, 1), dtype int64, the mask of sequence. + slf_attn_mask (Variable, optional): shape(B, len_q, len_k), dtype int64, the mask of self attention, + where len_q means the sequence length of query and len_k means the sequence length of key. Defaults to None. + Returns: - output (Variable), Shape(B, T, C), the output after self-attention & ffn. - slf_attn (Variable), Shape(B * n_head, T, T), the self attention. + output (Variable): shape(B, T, C), the output after self-attention & ffn. + slf_attn (Variable): shape(B * n_head, T, T), the self attention. """ output, slf_attn = self.slf_attn( enc_input, enc_input, enc_input, mask=slf_attn_mask) diff --git a/parakeet/models/fastspeech/length_regulator.py b/parakeet/models/fastspeech/length_regulator.py index f6bc803..6fc6702 100644 --- a/parakeet/models/fastspeech/length_regulator.py +++ b/parakeet/models/fastspeech/length_regulator.py @@ -22,6 +22,14 @@ from parakeet.modules.customized import Conv1D class LengthRegulator(dg.Layer): def __init__(self, input_size, out_channels, filter_size, dropout=0.1): + """Length Regulator block in FastSpeech. 
+ + Args: + input_size (int): the channel number of input. + out_channels (int): the output channel number. + filter_size (int): the filter size of duration predictor. + dropout (float, optional): dropout probability. Defaults to 0.1. + """ super(LengthRegulator, self).__init__() self.duration_predictor = DurationPredictor( input_size=input_size, @@ -66,18 +74,18 @@ class LengthRegulator(dg.Layer): def forward(self, x, alpha=1.0, target=None): """ - Length Regulator block in FastSpeech. + Compute length of mel from encoder output use TransformerTTS attention Args: - x (Variable): Shape(B, T, C), dtype: float32. The encoder output. - alpha (Constant): dtype: float32. The hyperparameter to determine the length of - the expanded sequence mel, thereby controlling the voice speed. - target (Variable): (Variable, optional): Shape(B, T_text), - dtype: int64. The duration of phoneme compute from pretrained transformerTTS. + x (Variable): shape(B, T, C), dtype float32, the encoder output. + alpha (float32, optional): the hyperparameter to determine the length of + the expanded sequence mel, thereby controlling the voice speed. Defaults to 1.0. + target (Variable, optional): shape(B, T_text), dtype int64, the duration of phoneme compute from pretrained transformerTTS. + Defaults to None. Returns: - output (Variable), Shape(B, T, C), the output after exppand. - duration_predictor_output (Variable), Shape(B, T, C), the output of duration predictor. + output (Variable): shape(B, T, C), the output after exppand. + duration_predictor_output (Variable): shape(B, T, C), the output of duration predictor. """ duration_predictor_output = self.duration_predictor(x) if fluid.framework._dygraph_tracer()._train_mode: @@ -93,6 +101,14 @@ class LengthRegulator(dg.Layer): class DurationPredictor(dg.Layer): def __init__(self, input_size, out_channels, filter_size, dropout=0.1): + """Duration Predictor block in FastSpeech. + + Args: + input_size (int): the channel number of input. + out_channels (int): the output channel number. + filter_size (int): the filter size. + dropout (float, optional): dropout probability. Defaults to 0.1. + """ super(DurationPredictor, self).__init__() self.input_size = input_size self.out_channels = out_channels @@ -135,12 +151,13 @@ class DurationPredictor(dg.Layer): def forward(self, encoder_output): """ - Duration Predictor block in FastSpeech. + Predict the duration of each character. Args: - encoder_output (Variable): Shape(B, T, C), dtype: float32. The encoder output. + encoder_output (Variable): shape(B, T, C), dtype float32, the encoder output. + Returns: - out (Variable), Shape(B, T, C), the output of duration predictor. + out (Variable): shape(B, T, C), the output of duration predictor. """ # encoder_output.shape(N, T, C) out = layers.transpose(encoder_output, [0, 2, 1]) diff --git a/parakeet/models/transformer_tts/cbhg.py b/parakeet/models/transformer_tts/cbhg.py index ca93536..5a28ebd 100644 --- a/parakeet/models/transformer_tts/cbhg.py +++ b/parakeet/models/transformer_tts/cbhg.py @@ -30,16 +30,19 @@ class CBHG(dg.Layer): num_gru_layers=2, max_pool_kernel_size=2, is_post=False): + """CBHG Module + + Args: + hidden_size (int): dimension of hidden unit. + batch_size (int): batch size of input. + K (int, optional): number of convolution banks. Defaults to 16. + projection_size (int, optional): dimension of projection unit. Defaults to 256. + num_gru_layers (int, optional): number of layers of GRUcell. Defaults to 2. + max_pool_kernel_size (int, optional): max pooling kernel size. 
Defaults to 2 + is_post (bool, optional): whether post processing or not. Defaults to False. + """ super(CBHG, self).__init__() - """ - :param hidden_size: dimension of hidden unit - :param batch_size: batch size - :param K: # of convolution banks - :param projection_size: dimension of projection unit - :param num_gru_layers: # of layers of GRUcell - :param max_pool_kernel_size: max pooling kernel size - :param is_post: whether post processing or not - """ + self.hidden_size = hidden_size self.projection_size = projection_size self.conv_list = [] @@ -176,7 +179,15 @@ class CBHG(dg.Layer): return x def forward(self, input_): - # input_.shape = [N, C, T] + """ + Convert mel spectrum to linear spectrum. + + Args: + input_ (Variable): shape(B, C, T), dtype float32, the sequential input. + + Returns: + out (Variable): shape(B, C, T), the CBHG output. + """ conv_list = [] conv_input = input_ @@ -217,6 +228,12 @@ class CBHG(dg.Layer): class Highwaynet(dg.Layer): def __init__(self, num_units, num_layers=4): + """Highway network + + Args: + num_units (int): dimension of hidden unit. + num_layers (int, optional): number of highway layers. Defaults to 4. + """ super(Highwaynet, self).__init__() self.num_units = num_units self.num_layers = num_layers @@ -249,6 +266,15 @@ class Highwaynet(dg.Layer): self.add_sublayer("gates_{}".format(i), gate) def forward(self, input_): + """ + Compute result of Highway network. + + Args: + input_(Variable): shape(B, T, C), dtype float32, the sequential input. + + Returns: + out(Variable): the Highway output. + """ out = input_ for linear, gate in zip(self.linears, self.gates): diff --git a/parakeet/models/transformer_tts/decoder.py b/parakeet/models/transformer_tts/decoder.py index 5b17a7a..0b47e4f 100644 --- a/parakeet/models/transformer_tts/decoder.py +++ b/parakeet/models/transformer_tts/decoder.py @@ -22,7 +22,15 @@ from parakeet.models.transformer_tts.post_convnet import PostConvNet class Decoder(dg.Layer): - def __init__(self, num_hidden, config, num_head=4): + def __init__(self, num_hidden, config, num_head=4, n_layers=3): + """Decoder layer of TransformerTTS. + + Args: + num_hidden (int): the size of hidden layer in network. + config: the yaml configs used in decoder. + n_layers (int, optional): the layers number of multihead attention. Defaults to 3. + num_head (int, optional): the head number of multihead attention. Defaults to 4. + """ super(Decoder, self).__init__() self.num_hidden = num_hidden self.num_head = num_head @@ -58,20 +66,20 @@ class Decoder(dg.Layer): self.selfattn_layers = [ MultiheadAttention(num_hidden, num_hidden // num_head, - num_hidden // num_head) for _ in range(3) + num_hidden // num_head) for _ in range(n_layers) ] for i, layer in enumerate(self.selfattn_layers): self.add_sublayer("self_attn_{}".format(i), layer) self.attn_layers = [ MultiheadAttention(num_hidden, num_hidden // num_head, - num_hidden // num_head) for _ in range(3) + num_hidden // num_head) for _ in range(n_layers) ] for i, layer in enumerate(self.attn_layers): self.add_sublayer("attn_{}".format(i), layer) self.ffns = [ PositionwiseFeedForward( num_hidden, num_hidden * num_head, filter_size=1) - for _ in range(3) + for _ in range(n_layers) ] for i, layer in enumerate(self.ffns): self.add_sublayer("ffns_{}".format(i), layer) @@ -108,6 +116,28 @@ class Decoder(dg.Layer): m_mask=None, m_self_mask=None, zero_mask=None): + """ + Compute decoder outputs.
+ + Args: + key (Variable): shape(B, T_text, C), dtype float32, the input key of decoder, + where T_text means the timesteps of input text, + value (Variable): shape(B, T_text, C), dtype float32, the input value of decoder. + query (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder, + where T_mel means the timesteps of input spectrum, + positional (Variable): shape(B, T_mel), dtype int64, the spectrum position. + mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention. + m_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None. + m_self_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self attention. Defaults to None. + zero_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, query mask of encoder-decoder attention. Defaults to None. + + Returns: + mel_out (Variable): shape(B, T_mel, C), the decoder output after mel linear projection. + out (Variable): shape(B, T_mel, C), the decoder output after post mel network. + stop_tokens (Variable): shape(B, T_mel, 1), the stop tokens of output. + attn_list (list[Variable]): len(n_layers), the encoder-decoder attention list. + selfattn_list (list[Variable]): len(n_layers), the decoder self attention list. + """ # get decoder mask with triangular matrix @@ -121,7 +151,7 @@ class Decoder(dg.Layer): else: m_mask, m_self_mask, zero_mask = None, None, None -# Decoder pre-network + # Decoder pre-network query = self.decoder_prenet(query) # Centered position diff --git a/parakeet/models/transformer_tts/encoder.py b/parakeet/models/transformer_tts/encoder.py index ef3821f..48fc6c1 100644 --- a/parakeet/models/transformer_tts/encoder.py +++ b/parakeet/models/transformer_tts/encoder.py @@ -20,7 +20,15 @@ from parakeet.models.transformer_tts.encoderprenet import EncoderPrenet class Encoder(dg.Layer): - def __init__(self, embedding_size, num_hidden, num_head=4): + def __init__(self, embedding_size, num_hidden, num_head=4, n_layers=3): + """Encoder layer of TransformerTTS. + + Args: + embedding_size (int): the size of position embedding. + num_hidden (int): the size of hidden layer in network. + n_layers (int, optional): the layers number of multihead attention. Defaults to 4. + num_head (int, optional): the head number of multihead attention. Defaults to 3. + """ super(Encoder, self).__init__() self.num_hidden = num_hidden self.num_head = num_head @@ -42,7 +50,7 @@ class Encoder(dg.Layer): use_cudnn=True) self.layers = [ MultiheadAttention(num_hidden, num_hidden // num_head, - num_hidden // num_head) for _ in range(3) + num_hidden // num_head) for _ in range(n_layers) ] for i, layer in enumerate(self.layers): self.add_sublayer("self_attn_{}".format(i), layer) @@ -51,12 +59,26 @@ class Encoder(dg.Layer): num_hidden, num_hidden * num_head, filter_size=1, - use_cudnn=True) for _ in range(3) + use_cudnn=True) for _ in range(n_layers) ] for i, layer in enumerate(self.ffns): self.add_sublayer("ffns_{}".format(i), layer) def forward(self, x, positional, mask=None, query_mask=None): + """ + Encode text sequence. + + Args: + x (Variable): shape(B, T_text), dtype float32, the input character, + where T_text means the timesteps of input text, + positional (Variable): shape(B, T_text), dtype int64, the characters position. + mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self attention. Defaults to None. 
+ query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self attention. Defaults to None. + + Returns: + x (Variable): shape(B, T_text, C), the encoder output. + attentions (list[Variable]): len(n_layers), the encoder self attention list. + """ if fluid.framework._dygraph_tracer()._train_mode: seq_len_key = x.shape[1] @@ -66,12 +88,12 @@ class Encoder(dg.Layer): else: query_mask, mask = None, None # Encoder pre_network - x = self.encoder_prenet(x) #(N,T,C) + x = self.encoder_prenet(x) # Get positional encoding positional = self.pos_emb(positional) - x = positional * self.alpha + x #(N, T, C) + x = positional * self.alpha + x # Positional dropout x = layers.dropout(x, 0.1, dropout_implementation='upscale_in_train') diff --git a/parakeet/models/transformer_tts/encoderprenet.py b/parakeet/models/transformer_tts/encoderprenet.py index e953dab..d692cea 100644 --- a/parakeet/models/transformer_tts/encoderprenet.py +++ b/parakeet/models/transformer_tts/encoderprenet.py @@ -22,6 +22,13 @@ import numpy as np class EncoderPrenet(dg.Layer): def __init__(self, embedding_size, num_hidden, use_cudnn=True): + """ Encoder prenet layer of TransformerTTS. + + Args: + embedding_size (int): the size of embedding. + num_hidden (int): the size of hidden layer in network. + use_cudnn (bool, optional): use cudnn or not. Defaults to True. + """ super(EncoderPrenet, self).__init__() self.embedding_size = embedding_size self.num_hidden = num_hidden @@ -81,8 +88,17 @@ class EncoderPrenet(dg.Layer): low=-k, high=k))) def forward(self, x): + """ + Prepare encoder input. + + Args: + x (Variable): shape(B, T_text), dtype float32, the input character, where T_text means the timesteps of input text. + + Returns: + (Variable): shape(B, T_text, C), the encoder prenet output. + """ - x = self.embedding(x) #(batch_size, seq_len, embending_size) + x = self.embedding(x) x = layers.transpose(x, [0, 2, 1]) for batch_norm, conv in zip(self.batch_norm_list, self.conv_list): x = layers.dropout( diff --git a/parakeet/models/transformer_tts/post_convnet.py b/parakeet/models/transformer_tts/post_convnet.py index 60e9382..da458ed 100644 --- a/parakeet/models/transformer_tts/post_convnet.py +++ b/parakeet/models/transformer_tts/post_convnet.py @@ -29,6 +29,19 @@ class PostConvNet(dg.Layer): use_cudnn=True, dropout=0.1, batchnorm_last=False): + """Decocder post conv net of TransformerTTS. + + Args: + n_mels (int, optional): the number of mel bands when calculating mel spectrograms. Defaults to 80. + num_hidden (int, optional): the size of hidden layer in network. Defaults to 512. + filter_size (int, optional): the filter size of Conv. Defaults to 5. + padding (int, optional): the padding size of Conv. Defaults to 0. + num_conv (int, optional): the num of Conv layers in network. Defaults to 5. + outputs_per_step (int, optional): the num of output frames per step . Defaults to 1. + use_cudnn (bool, optional): use cudnn in Conv or not. Defaults to True. + dropout (float, optional): dropout probability. Defaults to 0.1. + batchnorm_last (bool, optional): if batchnorm at last layer or not. Defaults to False. + """ super(PostConvNet, self).__init__() self.dropout = dropout @@ -93,12 +106,13 @@ class PostConvNet(dg.Layer): def forward(self, input): """ - Post Conv Net. + Compute the mel spectrum. Args: - input (Variable): Shape(B, T, C), dtype: float32. The input value. + input (Variable): shape(B, T, C), dtype float32, the result of mel linear projection. 
+ Returns: - output (Variable), Shape(B, T, C), the result after postconvnet. + output (Variable): shape(B, T, C), the result after postconvnet. """ input = layers.transpose(input, [0, 2, 1]) diff --git a/parakeet/models/transformer_tts/prenet.py b/parakeet/models/transformer_tts/prenet.py index b47a9f8..b033860 100644 --- a/parakeet/models/transformer_tts/prenet.py +++ b/parakeet/models/transformer_tts/prenet.py @@ -19,10 +19,13 @@ import paddle.fluid.layers as layers class PreNet(dg.Layer): def __init__(self, input_size, hidden_size, output_size, dropout_rate=0.2): - """ - :param input_size: dimension of input - :param hidden_size: dimension of hidden unit - :param output_size: dimension of output + """Prenet before passing through the network. + + Args: + input_size (int): the input channel size. + hidden_size (int): the size of hidden layer in network. + output_size (int): the output channel size. + dropout_rate (float, optional): dropout probability. Defaults to 0.2. """ super(PreNet, self).__init__() self.input_size = input_size @@ -49,19 +52,20 @@ class PreNet(dg.Layer): def forward(self, x): """ - Pre Net before passing through the network. + Prepare network input. Args: - x (Variable): Shape(B, T, C), dtype: float32. The input value. + x (Variable): shape(B, T, C), dtype float32, the input value. + Returns: - x (Variable), Shape(B, T, C), the result after pernet. + output (Variable): shape(B, T, C), the result after pernet. """ x = layers.dropout( layers.relu(self.linear1(x)), self.dropout_rate, dropout_implementation='upscale_in_train') - x = layers.dropout( + output = layers.dropout( layers.relu(self.linear2(x)), self.dropout_rate, dropout_implementation='upscale_in_train') - return x + return output diff --git a/parakeet/models/transformer_tts/transformer_tts.py b/parakeet/models/transformer_tts/transformer_tts.py index a7fffbd..75b283c 100644 --- a/parakeet/models/transformer_tts/transformer_tts.py +++ b/parakeet/models/transformer_tts/transformer_tts.py @@ -19,6 +19,11 @@ from parakeet.models.transformer_tts.decoder import Decoder class TransformerTTS(dg.Layer): def __init__(self, config): + """TransformerTTS model. + + Args: + config: the yaml configs used in TransformerTTS model. + """ super(TransformerTTS, self).__init__() self.encoder = Encoder(config['embedding_size'], config['hidden_size']) self.decoder = Decoder(config['hidden_size'], config) @@ -35,6 +40,31 @@ class TransformerTTS(dg.Layer): enc_dec_mask=None, dec_query_slf_mask=None, dec_query_mask=None): + """ + TransformerTTS network. + + Args: + characters (Variable): shape(B, T_text), dtype float32, the input character, + where T_text means the timesteps of input text, + mel_input (Variable): shape(B, T_mel, C), dtype float32, the input query of decoder, + where T_mel means the timesteps of input spectrum, + pos_text (Variable): shape(B, T_text), dtype int64, the characters position. + dec_slf_mask (Variable): shape(B, T_mel), dtype int64, the spectrum position. + mask (Variable): shape(B, T_mel, T_mel), dtype int64, the mask of decoder self attention. + enc_slf_mask (Variable, optional): shape(B, T_text, T_text), dtype int64, the mask of encoder self attention. Defaults to None. + enc_query_mask (Variable, optional): shape(B, T_text, 1), dtype int64, the query mask of encoder self attention. Defaults to None. + dec_query_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of encoder-decoder attention. Defaults to None. 
+ dec_query_slf_mask (Variable, optional): shape(B, T_mel, 1), dtype int64, the query mask of decoder self attention. Defaults to None. + enc_dec_mask (Variable, optional): shape(B, T_mel, T_text), dtype int64, query mask of encoder-decoder attention. Defaults to None. + + Returns: + mel_output (Variable): shape(B, T_mel, C), the decoder output after mel linear projection. + postnet_output (Variable): shape(B, T_mel, C), the decoder output after post mel network. + stop_preds (Variable): shape(B, T_mel, 1), the stop tokens of output. + attn_probs (list[Variable]): len(n_layers), the encoder-decoder attention list. + attns_enc (list[Variable]): len(n_layers), the encoder self attention list. + attns_dec (list[Variable]): len(n_layers), the decoder self attention list. + """ key, attns_enc = self.encoder( characters, pos_text, mask=enc_slf_mask, query_mask=enc_query_mask) @@ -48,5 +78,3 @@ class TransformerTTS(dg.Layer): m_self_mask=dec_query_slf_mask, m_mask=dec_query_mask) return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec - - return mel_output, postnet_output, attn_probs, stop_preds, attns_enc, attns_dec diff --git a/parakeet/models/transformer_tts/vocoder.py b/parakeet/models/transformer_tts/vocoder.py index 33ffe1c..93843ca 100644 --- a/parakeet/models/transformer_tts/vocoder.py +++ b/parakeet/models/transformer_tts/vocoder.py @@ -19,11 +19,13 @@ from parakeet.models.transformer_tts.cbhg import CBHG class Vocoder(dg.Layer): - """ - CBHG Network (mel -> linear) - """ - def __init__(self, config, batch_size): + """CBHG Network (mel -> linear) + + Args: + config: the yaml configs used in Vocoder model. + batch_size (int): the batch size of input. + """ super(Vocoder, self).__init__() self.pre_proj = Conv1D( num_channels=config['audio']['num_mels'], @@ -36,6 +38,15 @@ class Vocoder(dg.Layer): filter_size=1) def forward(self, mel): + """ + Compute mel spectrum to linear spectrum. + + Args: + mel (Variable): shape(B, C, T), dtype float32, the input mel spectrum. + + Returns: + mag_pred (Variable): shape(B, T, C), the linear output. + """ mel = layers.transpose(mel, [0, 2, 1]) mel = self.pre_proj(mel) mel = self.cbhg(mel) diff --git a/parakeet/modules/dynamic_gru.py b/parakeet/modules/dynamic_gru.py index 9e55688..b944b92 100644 --- a/parakeet/modules/dynamic_gru.py +++ b/parakeet/modules/dynamic_gru.py @@ -43,9 +43,10 @@ class DynamicGRU(dg.Layer): Dynamic GRU block. Args: - input (Variable): Shape(B, T, C), dtype: float32. The input value. + input (Variable): shape(B, T, C), dtype float32, the input value. + Returns: - output (Variable), Shape(B, T, C), the result compute by GRU. + output (Variable): shape(B, T, C), the result compute by GRU. """ hidden = self.h_0 res = [] diff --git a/parakeet/modules/ffn.py b/parakeet/modules/ffn.py index fe39d3c..199f668 100644 --- a/parakeet/modules/ffn.py +++ b/parakeet/modules/ffn.py @@ -19,8 +19,6 @@ from parakeet.modules.customized import Conv1D class PositionwiseFeedForward(dg.Layer): - ''' A two-feed-forward-layer module ''' - def __init__(self, d_in, num_hidden, @@ -28,6 +26,16 @@ class PositionwiseFeedForward(dg.Layer): padding=0, use_cudnn=True, dropout=0.1): + """A two-feed-forward-layer module. + + Args: + d_in (int): the size of input channel. + num_hidden (int): the size of hidden layer in network. + filter_size (int): the filter size of Conv + padding (int, optional): the padding size of Conv. Defaults to 0. + use_cudnn (bool, optional): use cudnn in Conv or not. Defaults to True. 
+ dropout (float, optional): dropout probability. Defaults to 0.1. + """ super(PositionwiseFeedForward, self).__init__() self.num_hidden = num_hidden self.use_cudnn = use_cudnn @@ -59,12 +67,13 @@ class PositionwiseFeedForward(dg.Layer): def forward(self, input): """ - Feed Forward Network. + Compute feed forward network result. Args: - input (Variable): Shape(B, T, C), dtype: float32. The input value. + input (Variable): shape(B, T, C), dtype float32, the input value. + Returns: - output (Variable), Shape(B, T, C), the result after FFN. + output (Variable): shape(B, T, C), the result after FFN. """ x = layers.transpose(input, [0, 2, 1]) #FFN Networt diff --git a/parakeet/modules/modules.py b/parakeet/modules/modules.py deleted file mode 100644 index 72a8d2d..0000000 --- a/parakeet/modules/modules.py +++ /dev/null @@ -1,610 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import paddle -from paddle import fluid -import paddle.fluid.dygraph as dg - -import numpy as np - -from . import conv -from . import weight_norm - - -def FC(name_scope, - in_features, - size, - num_flatten_dims=1, - relu=False, - dropout=0.0, - epsilon=1e-30, - act=None, - is_test=False, - dtype="float32"): - """ - A special Linear Layer, when it is used with dropout, the weight is - initialized as normal(0, std=np.sqrt((1-dropout) / in_features)) - """ - - # stds - if isinstance(in_features, int): - in_features = [in_features] - - stds = [np.sqrt((1 - dropout) / in_feature) for in_feature in in_features] - if relu: - stds = [std * np.sqrt(2.0) for std in stds] - - weight_inits = [ - fluid.initializer.NormalInitializer(scale=std) for std in stds - ] - bias_init = fluid.initializer.ConstantInitializer(0.0) - - # param attrs - weight_attrs = [fluid.ParamAttr(initializer=init) for init in weight_inits] - bias_attr = fluid.ParamAttr(initializer=bias_init) - - layer = weight_norm.FC(name_scope, - size, - num_flatten_dims=num_flatten_dims, - param_attr=weight_attrs, - bias_attr=bias_attr, - act=act, - dtype=dtype) - return layer - - -def Conv1D(name_scope, - in_channels, - num_filters, - filter_size=3, - dilation=1, - groups=None, - causal=False, - std_mul=1.0, - dropout=0.0, - use_cudnn=True, - act=None, - dtype="float32"): - """ - A special Conv1D Layer, when it is used with dropout, the weight is - initialized as - normal(0, std=np.sqrt(std_mul * (1-dropout) / (filter_size * in_features))) - """ - # std - std = np.sqrt((std_mul * (1 - dropout)) / (filter_size * in_channels)) - weight_init = fluid.initializer.NormalInitializer(loc=0.0, scale=std) - bias_init = fluid.initializer.ConstantInitializer(0.0) - - # param attrs - weight_attr = fluid.ParamAttr(initializer=weight_init) - bias_attr = fluid.ParamAttr(initializer=bias_init) - - layer = conv.Conv1D( - name_scope, - in_channels, - num_filters, - filter_size, - dilation, - groups=groups, - causal=causal, - param_attr=weight_attr, - bias_attr=bias_attr, - use_cudnn=use_cudnn, - act=act, - dtype=dtype) - 
return layer - - -def Embedding(name_scope, - num_embeddings, - embed_dim, - is_sparse=False, - is_distributed=False, - padding_idx=None, - std=0.01, - dtype="float32"): - # param attrs - weight_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal( - scale=std)) - layer = dg.Embedding( - name_scope, (num_embeddings, embed_dim), - padding_idx=padding_idx, - param_attr=weight_attr, - dtype=dtype) - return layer - - -class Conv1DGLU(dg.Layer): - """ - A Convolution 1D block with GLU activation. It also applys dropout for the - input x. It fuses speaker embeddings through a FC activated by softsign. It - has residual connection from the input x, and scale the output by - np.sqrt(0.5). - """ - - def __init__(self, - name_scope, - n_speakers, - speaker_dim, - in_channels, - num_filters, - filter_size, - dilation, - std_mul=4.0, - dropout=0.0, - causal=False, - residual=True, - dtype="float32"): - super(Conv1DGLU, self).__init__(name_scope, dtype=dtype) - - # conv spec - self.in_channels = in_channels - self.n_speakers = n_speakers - self.speaker_dim = speaker_dim - self.num_filters = num_filters - self.filter_size = filter_size - self.dilation = dilation - self.causal = causal - self.residual = residual - - # weight init and dropout - self.std_mul = std_mul - self.dropout = dropout - - if residual: - assert ( - in_channels == num_filters - ), "this block uses residual connection"\ - "the input_channes should equals num_filters" - - self.conv = Conv1D( - self.full_name(), - in_channels, - 2 * num_filters, - filter_size, - dilation, - causal=causal, - std_mul=std_mul, - dropout=dropout, - dtype=dtype) - - if n_speakers > 1: - assert (speaker_dim is not None - ), "speaker embed should not be null in multi-speaker case" - self.fc = Conv1D( - self.full_name(), - speaker_dim, - num_filters, - filter_size=1, - dilation=1, - causal=False, - act="softsign", - dtype=dtype) - - def forward(self, x, speaker_embed_bc1t=None): - """ - Args: - x (Variable): Shape(B, C_in, 1, T), the input of Conv1DGLU - layer, where B means batch_size, C_in means the input channels - T means input time steps. - speaker_embed_bct1 (Variable): Shape(B, C_sp, 1, T), expanded - speaker embed, where C_sp means speaker embedding size. Note - that when using residual connection, the Conv1DGLU does not - change the number of channels, so out channels equals input - channels. - - Returns: - x (Variable): Shape(B, C_out, 1, T), the output of Conv1DGLU, where - C_out means the output channels of Conv1DGLU. 
- """ - - residual = x - x = fluid.layers.dropout(x, self.dropout) - x = self.conv(x) - - content, gate = fluid.layers.split(x, num_or_sections=2, dim=1) - - if speaker_embed_bc1t is not None: - sp = self.fc(speaker_embed_bc1t) - content = content + sp - - # glu - x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), content) - - if self.residual: - x = fluid.layers.scale(x + residual, np.sqrt(0.5)) - return x - - def add_input(self, x, speaker_embed_bc11=None): - """ - Inputs: - x: shape(B, num_filters, 1, time_steps) - speaker_embed_bc11: shape(B, speaker_dim, 1, time_steps) - - Outputs: - out: shape(B, num_filters, 1, time_steps), where time_steps = 1 - """ - - residual = x - - # add step input and produce step output - x = fluid.layers.dropout(x, self.dropout) - x = self.conv.add_input(x) - - content, gate = fluid.layers.split(x, num_or_sections=2, dim=1) - - if speaker_embed_bc11 is not None: - sp = self.fc(speaker_embed_bc11) - content = content + sp - - x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), content) - - if self.residual: - x = fluid.layers.scale(x + residual, np.sqrt(0.5)) - return x - - -def Conv1DTranspose(name_scope, - in_channels, - num_filters, - filter_size, - padding=0, - stride=1, - dilation=1, - groups=None, - std_mul=1.0, - dropout=0.0, - use_cudnn=True, - act=None, - dtype="float32"): - std = np.sqrt(std_mul * (1 - dropout) / (in_channels * filter_size)) - weight_init = fluid.initializer.NormalInitializer(scale=std) - weight_attr = fluid.ParamAttr(initializer=weight_init) - bias_init = fluid.initializer.ConstantInitializer(0.0) - bias_attr = fluid.ParamAttr(initializer=bias_init) - layer = conv.Conv1DTranspose( - name_scope, - in_channels, - num_filters, - filter_size, - padding=padding, - stride=stride, - dilation=dilation, - groups=groups, - param_attr=weight_attr, - bias_attr=bias_attr, - use_cudnn=use_cudnn, - act=act, - dtype=dtype) - return layer - - -def compute_position_embedding(rad): - # rad is a transposed radius, shape(embed_dim, n_vocab) - embed_dim, n_vocab = rad.shape - - even_dims = dg.to_variable(np.arange(0, embed_dim, 2).astype("int32")) - odd_dims = dg.to_variable(np.arange(1, embed_dim, 2).astype("int32")) - - even_rads = fluid.layers.gather(rad, even_dims) - odd_rads = fluid.layers.gather(rad, odd_dims) - - sines = fluid.layers.sin(even_rads) - cosines = fluid.layers.cos(odd_rads) - - temp = fluid.layers.scatter(rad, even_dims, sines) - out = fluid.layers.scatter(temp, odd_dims, cosines) - out = fluid.layers.transpose(out, perm=[1, 0]) - return out - - -def position_encoding_init(n_position, - d_pos_vec, - position_rate=1.0, - sinusoidal=True): - """ Init the sinusoid position encoding table """ - - # keep idx 0 for padding token position encoding zero vector - position_enc = np.array([[ - position_rate * pos / np.power(10000, 2 * (i // 2) / d_pos_vec) - for i in range(d_pos_vec) - ] if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)]) - - if sinusoidal: - position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # dim 2i - position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # dim 2i+1 - - return position_enc - - -class PositionEmbedding(dg.Layer): - def __init__(self, - name_scope, - n_position, - d_pos_vec, - position_rate=1.0, - is_sparse=False, - is_distributed=False, - param_attr=None, - max_norm=None, - padding_idx=None, - dtype="float32"): - super(PositionEmbedding, self).__init__(name_scope, dtype=dtype) - self.embed = dg.Embedding( - self.full_name(), - size=(n_position, d_pos_vec), - 
is_sparse=is_sparse, - is_distributed=is_distributed, - padding_idx=None, - param_attr=param_attr, - dtype=dtype) - self.set_weight( - position_encoding_init( - n_position, - d_pos_vec, - position_rate=position_rate, - sinusoidal=False).astype(dtype)) - - self._is_sparse = is_sparse - self._is_distributed = is_distributed - self._remote_prefetch = self._is_sparse and (not self._is_distributed) - if self._remote_prefetch: - assert self._is_sparse is True and self._is_distributed is False - - self._padding_idx = (-1 if padding_idx is None else padding_idx if - padding_idx >= 0 else (n_position + padding_idx)) - self._position_rate = position_rate - self._max_norm = max_norm - self._dtype = dtype - - def set_weight(self, array): - assert self.embed._w.shape == list(array.shape), "shape does not match" - self.embed._w._ivar.value().get_tensor().set( - array, fluid.framework._current_expected_place()) - - def forward(self, indices, speaker_position_rate=None): - """ - Args: - indices (Variable): Shape (B, T, 1), dtype: int64, position - indices, where B means the batch size, T means the time steps. - speaker_position_rate (Variable | float, optional), position - rate. It can be a float point number or a Variable with - shape (1,), then this speaker_position_rate is used for every - example. It can also be a Variable with shape (B, 1), which - contains a speaker position rate for each speaker. - Returns: - out (Variable): Shape(B, C_pos), position embedding, where C_pos - means position embedding size. - """ - rad = fluid.layers.transpose(self.embed._w, perm=[1, 0]) - batch_size = indices.shape[0] - - if speaker_position_rate is None: - weight = compute_position_embedding(rad) - out = self._helper.create_variable_for_type_inference(self._dtype) - self._helper.append_op( - type="lookup_table", - inputs={"Ids": indices, - "W": weight}, - outputs={"Out": out}, - attrs={ - "is_sparse": self._is_sparse, - "is_distributed": self._is_distributed, - "remote_prefetch": self._remote_prefetch, - "padding_idx": - self._padding_idx, # special value for lookup table op - }) - return out - - elif (np.isscalar(speaker_position_rate) or - isinstance(speaker_position_rate, fluid.framework.Variable) and - speaker_position_rate.shape == [1, 1]): - # # make a weight - # scale the weight (the operand for sin & cos) - if np.isscalar(speaker_position_rate): - scaled_rad = fluid.layers.scale(rad, speaker_position_rate) - else: - scaled_rad = fluid.layers.elementwise_mul( - rad, speaker_position_rate[0]) - weight = compute_position_embedding(scaled_rad) - out = self._helper.create_variable_for_type_inference(self._dtype) - self._helper.append_op( - type="lookup_table", - inputs={"Ids": indices, - "W": weight}, - outputs={"Out": out}, - attrs={ - "is_sparse": self._is_sparse, - "is_distributed": self._is_distributed, - "remote_prefetch": self._remote_prefetch, - "padding_idx": - self._padding_idx, # special value for lookup table op - }) - return out - - elif np.prod(speaker_position_rate.shape) > 1: - assert speaker_position_rate.shape == [batch_size, 1] - outputs = [] - for i in range(batch_size): - rate = speaker_position_rate[i] # rate has shape [1] - scaled_rad = fluid.layers.elementwise_mul(rad, rate) - weight = compute_position_embedding(scaled_rad) - out = self._helper.create_variable_for_type_inference( - self._dtype) - sequence = indices[i] - self._helper.append_op( - type="lookup_table", - inputs={"Ids": sequence, - "W": weight}, - outputs={"Out": out}, - attrs={ - "is_sparse": self._is_sparse, - 
"is_distributed": self._is_distributed, - "remote_prefetch": self._remote_prefetch, - "padding_idx": -1, - }) - outputs.append(out) - out = fluid.layers.stack(outputs) - return out - else: - raise Exception("Then you can just use position rate at init") - - -class Conv1D_GU(dg.Layer): - def __init__(self, - name_scope, - conditioner_dim, - in_channels, - num_filters, - filter_size, - dilation, - causal=False, - residual=True, - dtype="float32"): - super(Conv1D_GU, self).__init__(name_scope, dtype=dtype) - - self.conditioner_dim = conditioner_dim - self.in_channels = in_channels - self.num_filters = num_filters - self.filter_size = filter_size - self.dilation = dilation - self.causal = causal - self.residual = residual - - if residual: - assert ( - in_channels == num_filters - ), "this block uses residual connection"\ - "the input_channels should equals num_filters" - - self.conv = Conv1D( - self.full_name(), - in_channels, - 2 * num_filters, - filter_size, - dilation, - causal=causal, - dtype=dtype) - - self.fc = Conv1D( - self.full_name(), - conditioner_dim, - 2 * num_filters, - filter_size=1, - dilation=1, - causal=False, - dtype=dtype) - - def forward(self, x, skip=None, conditioner=None): - """ - Args: - x (Variable): Shape(B, C_in, 1, T), the input of Conv1D_GU - layer, where B means batch_size, C_in means the input channels - T means input time steps. - skip (Variable): Shape(B, C_in, 1, T), skip connection. - conditioner (Variable): Shape(B, C_con, 1, T), expanded mel - conditioner, where C_con is conditioner hidden dim which - equals the num of mel bands. Note that when using residual - connection, the Conv1D_GU does not change the number of - channels, so out channels equals input channels. - Returns: - x (Variable): Shape(B, C_out, 1, T), the output of Conv1D_GU, where - C_out means the output channels of Conv1D_GU. - skip (Variable): Shape(B, C_out, 1, T), skip connection. - """ - residual = x - x = self.conv(x) - - if conditioner is not None: - cond_bias = self.fc(conditioner) - x += cond_bias - - content, gate = fluid.layers.split(x, num_or_sections=2, dim=1) - - # Gated Unit. - x = fluid.layers.elementwise_mul( - fluid.layers.sigmoid(gate), fluid.layers.tanh(content)) - - if skip is None: - skip = x - else: - skip = fluid.layers.scale(skip + x, np.sqrt(0.5)) - - if self.residual: - x = fluid.layers.scale(residual + x, np.sqrt(0.5)) - - return x, skip - - def add_input(self, x, skip=None, conditioner=None): - """ - Inputs: - x: shape(B, num_filters, 1, time_steps) - skip: shape(B, num_filters, 1, time_steps), skip connection - conditioner: shape(B, conditioner_dim, 1, time_steps) - Outputs: - x: shape(B, num_filters, 1, time_steps), where time_steps = 1 - skip: skip connection, same shape as x - """ - residual = x - - # add step input and produce step output - x = self.conv.add_input(x) - - if conditioner is not None: - cond_bias = self.fc(conditioner) - x += cond_bias - - content, gate = fluid.layers.split(x, num_or_sections=2, dim=1) - - # Gated Unit. 
- x = fluid.layers.elementwise_mul( - fluid.layers.sigmoid(gate), fluid.layers.tanh(content)) - - if skip is None: - skip = x - else: - skip = fluid.layers.scale(skip + x, np.sqrt(0.5)) - - if self.residual: - x = fluid.layers.scale(residual + x, np.sqrt(0.5)) - - return x, skip - - -def Conv2DTranspose(name_scope, - num_filters, - filter_size, - padding=0, - stride=1, - dilation=1, - use_cudnn=True, - act=None, - dtype="float32"): - val = 1.0 / (filter_size[0] * filter_size[1]) - weight_init = fluid.initializer.ConstantInitializer(val) - weight_attr = fluid.ParamAttr(initializer=weight_init) - - layer = weight_norm.Conv2DTranspose( - name_scope, - num_filters, - filter_size=filter_size, - padding=padding, - stride=stride, - dilation=dilation, - param_attr=weight_attr, - use_cudnn=use_cudnn, - act=act, - dtype=dtype) - - return layer diff --git a/parakeet/modules/multihead_attention.py b/parakeet/modules/multihead_attention.py index 624d3ae..c6907e8 100644 --- a/parakeet/modules/multihead_attention.py +++ b/parakeet/modules/multihead_attention.py @@ -50,6 +50,11 @@ class Linear(dg.Layer): class ScaledDotProductAttention(dg.Layer): def __init__(self, d_key): + """Scaled dot product attention module. + + Args: + d_key (int): the dim of key in multihead attention. + """ super(ScaledDotProductAttention, self).__init__() self.d_key = d_key @@ -63,18 +68,18 @@ class ScaledDotProductAttention(dg.Layer): query_mask=None, dropout=0.1): """ - Scaled Dot Product Attention. + Compute scaled dot product attention. Args: - key (Variable): Shape(B, T, C), dtype: float32. The input key of attention. - value (Variable): Shape(B, T, C), dtype: float32. The input value of attention. - query (Variable): Shape(B, T, C), dtype: float32. The input query of attention. - mask (Variable): Shape(B, len_q, len_k), dtype: float32. The mask of key. - query_mask (Variable): Shape(B, len_q, 1), dtype: float32. The mask of query. - dropout (Constant): dtype: float32. The probability of dropout. + key (Variable): shape(B, T, C), dtype float32, the input key of scaled dot product attention. + value (Variable): shape(B, T, C), dtype float32, the input value of scaled dot product attention. + query (Variable): shape(B, T, C), dtype float32, the input query of scaled dot product attention. + mask (Variable, optional): shape(B, T_q, T_k), dtype float32, the mask of key. Defaults to None. + query_mask (Variable, optional): shape(B, T_q, T_q), dtype float32, the mask of query. Defaults to None. + dropout (float32, optional): the probability of dropout. Defaults to 0.1. Returns: - result (Variable), Shape(B, T, C), the result of mutihead attention. - attention (Variable), Shape(n_head * B, T, C), the attention of key. + result (Variable): shape(B, T, C), the result of mutihead attention. + attention (Variable): shape(n_head * B, T, C), the attention of key. """ # Compute attention score attention = layers.matmul( @@ -105,6 +110,17 @@ class MultiheadAttention(dg.Layer): is_bias=False, dropout=0.1, is_concat=True): + """Multihead Attention. + + Args: + num_hidden (int): the number of hidden layer in network. + d_k (int): the dim of key in multihead attention. + d_q (int): the dim of query in multihead attention. + num_head (int, optional): the head number of multihead attention. Defaults to 4. + is_bias (bool, optional): whether have bias in linear layers. Default to False. + dropout (float, optional): dropout probability of FFTBlock. Defaults to 0.1. + is_concat (bool, optional): whether concat query and result. Default to True. 
+ """ super(MultiheadAttention, self).__init__() self.num_hidden = num_hidden self.num_head = num_head @@ -128,17 +144,18 @@ class MultiheadAttention(dg.Layer): def forward(self, key, value, query_input, mask=None, query_mask=None): """ - Multihead Attention. + Compute attention. Args: - key (Variable): Shape(B, T, C), dtype: float32. The input key of attention. - value (Variable): Shape(B, T, C), dtype: float32. The input value of attention. - query_input (Variable): Shape(B, T, C), dtype: float32. The input query of attention. - mask (Variable): Shape(B, len_q, len_k), dtype: float32. The mask of key. - query_mask (Variable): Shape(B, len_q, 1), dtype: float32. The mask of query. + key (Variable): shape(B, T, C), dtype float32, the input key of attention. + value (Variable): shape(B, T, C), dtype float32, the input value of attention. + query_input (Variable): shape(B, T, C), dtype float32, the input query of attention. + mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of key. Defaults to None. + query_mask (Variable, optional): shape(B, T_query, T_key), dtype float32, the mask of query. Defaults to None. + Returns: - result (Variable), Shape(B, T, C), the result of mutihead attention. - attention (Variable), Shape(n_head * B, T, C), the attention of key. + result (Variable): shape(B, T, C), the result of mutihead attention. + attention (Variable): shape(num_head * B, T, C), the attention of key and query. """ batch_size = key.shape[0] @@ -146,7 +163,6 @@ class MultiheadAttention(dg.Layer): seq_len_query = query_input.shape[1] # Make multihead attention - # key & value.shape = (batch_size, seq_len, feature)(feature = num_head * num_hidden_per_attn) key = layers.reshape( self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k]) value = layers.reshape( @@ -168,18 +184,6 @@ class MultiheadAttention(dg.Layer): result, attention = self.scal_attn( key, value, query, mask=mask, query_mask=query_mask) - key = layers.reshape( - layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k]) - value = layers.reshape( - layers.transpose(value, [2, 0, 1, 3]), - [-1, seq_len_key, self.d_k]) - query = layers.reshape( - layers.transpose(query, [2, 0, 1, 3]), - [-1, seq_len_query, self.d_q]) - - result, attention = self.scal_attn( - key, value, query, mask=mask, query_mask=query_mask) - # concat all multihead result result = layers.reshape( result, [self.num_head, batch_size, seq_len_query, self.d_q])