diff --git a/parakeet/models/__init__.py b/parakeet/models/__init__.py
index 0a32a9d..3db9629 100644
--- a/parakeet/models/__init__.py
+++ b/parakeet/models/__init__.py
@@ -19,3 +19,4 @@ from parakeet.models.waveflow import *
 from parakeet.models.transformer_tts import *
 #from parakeet.models.deepvoice3 import *
 # from parakeet.models.fastspeech import *
+from parakeet.models.tacotron2 import *
diff --git a/parakeet/models/tacotron2.py b/parakeet/models/tacotron2.py
index 9949e1d..e00fe22 100644
--- a/parakeet/models/tacotron2.py
+++ b/parakeet/models/tacotron2.py
@@ -27,6 +27,25 @@ __all__ = ["Tacotron2", "Tacotron2Loss"]
 
 
 class DecoderPreNet(nn.Layer):
+    """
+    Decoder prenet module for Tacotron2.
+
+    Parameters
+    ----------
+    d_input: int
+        input dimension
+
+    d_hidden: int
+        hidden size
+
+    d_output: int
+        output dimension
+
+    dropout_rate: float
+        dropout probability
+
+    """
+
     def __init__(self,
                  d_input: int,
                  d_hidden: int,
@@ -39,23 +58,60 @@ class DecoderPreNet(nn.Layer):
         self.linear2 = nn.Linear(d_hidden, d_output, bias_attr=False)
 
     def forward(self, x):
+        """Calculate forward propagation.
+
+        Parameters
+        ----------
+        x: Tensor[shape=(B, T_mel, C)]
+            batch of the sequences of padded mel spectrogram
+
+        Returns
+        -------
+        output: Tensor[shape=(B, T_mel, C)]
+            batch of the sequences of padded hidden state
+
+        """
+
         x = F.dropout(F.relu(self.linear1(x)), self.dropout_rate)
         output = F.dropout(F.relu(self.linear2(x)), self.dropout_rate)
         return output
 
 
 class DecoderPostNet(nn.Layer):
+    """
+    Decoder postnet module for Tacotron2.
+
+    Parameters
+    ----------
+    d_mels: int
+        number of mel bands
+
+    d_hidden: int
+        hidden size of postnet
+
+    kernel_size: int
+        kernel size of the conv layer in postnet
+
+    num_layers: int
+        number of conv layers in postnet
+
+    dropout: float
+        dropout probability
+
+    """
+
     def __init__(self,
                  d_mels: int=80,
                  d_hidden: int=512,
                  kernel_size: int=5,
-                 padding: int=0,
                  num_layers: int=5,
                  dropout: float=0.1):
         super().__init__()
         self.dropout = dropout
         self.num_layers = num_layers
 
+        padding = int((kernel_size - 1) / 2)
+
         self.conv_batchnorms = nn.LayerList()
         k = math.sqrt(1.0 / (d_mels * kernel_size))
         self.conv_batchnorms.append(
@@ -91,15 +147,47 @@
             data_format='NLC'))
 
     def forward(self, input):
+        """Calculate forward propagation.
+
+        Parameters
+        ----------
+        input: Tensor[shape=(B, T_mel, C)]
+            output sequence of features from decoder
+
+        Returns
+        -------
+        output: Tensor[shape=(B, T_mel, C)]
+            output sequence of features after postnet
+
+        """
+
         for i in range(len(self.conv_batchnorms) - 1):
-            input = F.dropout(
-                F.tanh(self.conv_batchnorms[i](input), self.dropout))
-        input = F.dropout(self.conv_batchnorms[self.num_layers - 1](input),
-                          self.dropout)
-        return input
+            input = F.dropout(
+                F.tanh(self.conv_batchnorms[i](input)), self.dropout)
+        output = F.dropout(self.conv_batchnorms[self.num_layers - 1](input),
+                           self.dropout)
+        return output
 
 
 class Tacotron2Encoder(nn.Layer):
+    """
+    Tacotron2 encoder module.
+
+    Parameters
+    ----------
+    d_hidden: int
+        hidden size in encoder module
+
+    conv_layers: int
+        number of conv layers
+
+    kernel_size: int
+        kernel size of conv layers
+
+    p_dropout: float
+        dropout probability
+    """
+
     def __init__(self,
                  d_hidden: int,
                  conv_layers: int,
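A note on the `padding` fix above: in Python, a trailing comma makes the right-hand side a 1-tuple, so the old line bound `padding` to `(2,)` rather than `2` for the default kernel size. A quick standalone check (plain Python, independent of this diff):

```python
# The trailing comma turns the assignment into a 1-tuple.
kernel_size = 5
padding = int((kernel_size - 1) / 2),   # -> (2,), a tuple
print(type(padding), padding)           # <class 'tuple'> (2,)

# Without the comma we get the intended "same" padding for an odd kernel.
padding = (kernel_size - 1) // 2        # -> 2
print(type(padding), padding)           # <class 'int'> 2
```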
@@ -126,6 +214,22 @@
             d_hidden, self.hidden_size, direction="bidirectional")
 
     def forward(self, x, input_lens=None):
+        """Calculate forward propagation of Tacotron2 encoder.
+
+        Parameters
+        ----------
+        x: Tensor[shape=(B, T, C)]
+            batch of the sequences of padded character embeddings
+
+        input_lens: Tensor[shape=(B,)]
+            batch of lengths of each input text
+
+        Returns
+        -------
+        output: Tensor[shape=(B, T, C)]
+            batch of the sequences of padded hidden states
+
+        """
         for conv_batchnorm in self.conv_batchnorms:
             x = F.dropout(
                 F.relu(conv_batchnorm(x)), self.p_dropout)  #(B, T, C)
@@ -135,6 +239,47 @@
 
 
 class Tacotron2Decoder(nn.Layer):
+    """
+    Tacotron2 decoder module.
+
+    Parameters
+    ----------
+    d_mels: int
+        number of mel bands
+
+    reduction_factor: int
+        reduction factor of tacotron
+
+    d_encoder: int
+        hidden size of encoder
+
+    d_prenet: int
+        hidden size in decoder prenet
+
+    d_attention_rnn: int
+        attention rnn layer hidden size
+
+    d_decoder_rnn: int
+        decoder rnn layer hidden size
+
+    d_attention: int
+        hidden size of the linear layer in location sensitive attention
+
+    attention_filters: int
+        filter size of the conv layer in location sensitive attention
+
+    attention_kernel_size: int
+        kernel size of the conv layer in location sensitive attention
+
+    p_prenet_dropout: float
+        dropout probability in decoder prenet
+
+    p_attention_dropout: float
+        dropout probability in location sensitive attention
+
+    p_decoder_dropout: float
+        dropout probability in decoder
+    """
     def __init__(self,
                  d_mels: int,
                  reduction_factor: int,
@@ -175,6 +320,8 @@
         self.stop_layer = nn.Linear(d_decoder_rnn + d_encoder, 1)
 
     def _initialize_decoder_states(self, key):
+        """Initialize the decoder states.
+        """
         batch_size = key.shape[0]
         MAX_TIME = key.shape[1]
 
@@ -199,6 +346,8 @@
         self.processed_key = self.attention_layer.key_layer(key)  #[B, T, C]
 
     def _decode(self, query):
+        """Decode one time step.
+        """
         cell_input = paddle.concat([query, self.attention_context], axis=-1)
 
         # The first lstm layer
@@ -232,6 +381,31 @@
         return decoder_output, stop_logit, self.attention_weights
 
     def forward(self, keys, querys, mask):
+        """Calculate forward propagation of Tacotron2 decoder.
+
+        Parameters
+        ----------
+        keys: Tensor[shape=(B, T_text, C)]
+            batch of the sequences of padded output from encoder
+
+        querys: Tensor[shape=(B, T_mel, C)]
+            batch of the sequences of padded mel spectrogram
+
+        mask: Tensor[shape=(B, T_text, 1)]
+            mask generated from text lengths
+
+        Returns
+        -------
+        mel_output: Tensor[shape=(B, T_mel, C)]
+            output sequence of features
+
+        stop_logits: Tensor[shape=(B, T_mel)]
+            output sequence of stop logits
+
+        alignments: Tensor[shape=(B, T_mel, T_text)]
+            attention weights
+
+        """
         querys = paddle.reshape(
             querys,
             [querys.shape[0], querys.shape[1] // self.reduction_factor, -1])
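The `reduction_factor` reshape at the start of `Tacotron2Decoder.forward` groups `r` adjacent target frames into a single decoder step. A minimal sketch of the shape arithmetic, with made-up sizes (not values from this diff):

```python
import paddle

B, T_mel, d_mels, r = 2, 8, 80, 2           # illustrative sizes only
querys = paddle.randn([B, T_mel, d_mels])   # padded mel spectrogram batch

# One decoder step now predicts r frames at once:
# (B, T_mel, C) -> (B, T_mel // r, C * r)
grouped = paddle.reshape(
    querys, [querys.shape[0], querys.shape[1] // r, -1])
print(grouped.shape)  # [2, 4, 160]
```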
@@ -263,6 +437,31 @@
         return mel_outputs, stop_logits, alignments
 
     def infer(self, key, stop_threshold=0.5, max_decoder_steps=1000):
+        """Calculate forward propagation of Tacotron2 decoder in inference.
+
+        Parameters
+        ----------
+        key: Tensor[shape=(B, T_text, C)]
+            batch of the sequences of padded output from encoder
+
+        stop_threshold: float
+            stop synthesizing when the stop logit is greater than this threshold
+
+        max_decoder_steps: int
+            maximum number of steps when synthesizing
+
+        Returns
+        -------
+        mel_output: Tensor[shape=(B, T_mel, C)]
+            output sequence of features
+
+        stop_logits: Tensor[shape=(B, T_mel)]
+            output sequence of stop logits
+
+        alignments: Tensor[shape=(B, T_mel, T_text)]
+            attention weights
+
+        """
         query = paddle.zeros(
             shape=[key.shape[0], self.d_mels * self.reduction_factor],
             dtype=key.dtype)  #[B, C]
@@ -296,16 +495,79 @@
 class Tacotron2(nn.Layer):
     """
-    Tacotron2 module for end-to-end text-to-speech (E2E-TTS).
+    Tacotron2 model for end-to-end text-to-speech (E2E-TTS).
 
-    This is a module of Spectrogram prediction network in Tacotron2 described
+    This is a model of the spectrogram prediction network in Tacotron2 described
     in `Natural TTS Synthesis
     by Conditioning WaveNet on Mel Spectrogram Predictions`_,
     which converts the sequence of characters
     into the sequence of mel spectrogram.
 
     .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
        https://arxiv.org/abs/1712.05884
+
+    Parameters
+    ----------
+    frontend: parakeet.frontend.Phonetics
+        frontend used to preprocess text
+
+    d_mels: int
+        number of mel bands
+
+    d_encoder: int
+        hidden size in encoder module
+
+    encoder_conv_layers: int
+        number of conv layers in encoder
+
+    encoder_kernel_size: int
+        kernel size of conv layers in encoder
+
+    d_prenet: int
+        hidden size in decoder prenet
+
+    d_attention_rnn: int
+        attention rnn layer hidden size in decoder
+
+    d_decoder_rnn: int
+        decoder rnn layer hidden size in decoder
+
+    attention_filters: int
+        filter size of the conv layer in location sensitive attention
+
+    attention_kernel_size: int
+        kernel size of the conv layer in location sensitive attention
+
+    d_attention: int
+        hidden size of the linear layer in location sensitive attention
+
+    d_postnet: int
+        hidden size of postnet
+
+    postnet_kernel_size: int
+        kernel size of the conv layer in postnet
+
+    postnet_conv_layers: int
+        number of conv layers in postnet
+
+    reduction_factor: int
+        reduction factor of tacotron
+
+    p_encoder_dropout: float
+        dropout probability in encoder
+
+    p_prenet_dropout: float
+        dropout probability in decoder prenet
+
+    p_attention_dropout: float
+        dropout probability in location sensitive attention
+
+    p_decoder_dropout: float
+        dropout probability in decoder
+
+    p_postnet_dropout: float
+        dropout probability in postnet
+
     """
 
     def __init__(self,
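The `infer` method documented above keeps decoding until the stop token fires or `max_decoder_steps` is reached. Assuming the stop probability comes from a sigmoid over the raw stop logit (consistent with "stop logits before sigmoid" in `Tacotron2Loss` later in this diff), the loop reduces to a sketch like this; `decode_step` is a hypothetical stand-in for `Tacotron2Decoder._decode`, not the actual implementation:

```python
import paddle
import paddle.nn.functional as F

def decode_step(step):
    # Hypothetical stand-in for Tacotron2Decoder._decode: returns one
    # grouped mel frame and one stop logit (values are made up).
    mel_frame = paddle.randn([1, 80])
    stop_logit = paddle.to_tensor([step - 5.0])  # rises until it fires
    return mel_frame, stop_logit

stop_threshold, max_decoder_steps = 0.5, 1000
frames = []
for step in range(max_decoder_steps):
    mel_frame, stop_logit = decode_step(step)
    frames.append(mel_frame)
    if F.sigmoid(stop_logit) > stop_threshold:  # assumed stop criterion
        break
print(len(frames))  # stops after a handful of steps, well short of the cap
```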
@@ -350,11 +612,38 @@
             d_mels=d_mels * reduction_factor,
             d_hidden=d_postnet,
             kernel_size=postnet_kernel_size,
-            padding=int((postnet_kernel_size - 1) / 2),
             num_layers=postnet_conv_layers,
             dropout=p_postnet_dropout)
 
     def forward(self, text_inputs, mels, text_lens, output_lens=None):
+        """Calculate forward propagation of Tacotron2.
+
+        Parameters
+        ----------
+        text_inputs: Tensor[shape=(B, T_text)]
+            batch of the sequences of padded character ids
+
+        mels: Tensor[shape=(B, T_mel, C)]
+            batch of the sequences of padded mel spectrogram
+
+        text_lens: Tensor[shape=(B,)]
+            batch of lengths of each input text
+
+        output_lens: Tensor[shape=(B,)]
+            batch of lengths of each mel spectrogram
+
+        Returns
+        -------
+        outputs: Dict[str, Tensor]
+
+            mel_output: output sequence of features (B, T_mel, C)
+
+            mel_outputs_postnet: output sequence of features after postnet (B, T_mel, C)
+
+            stop_logits: output sequence of stop logits (B, T_mel)
+
+            alignments: attention weights (B, T_mel, T_text)
+        """
         embedded_inputs = self.embedding(text_inputs)
         encoder_outputs = self.encoder(embedded_inputs, text_lens)
 
@@ -386,6 +675,31 @@ class Tacotron2(nn.Layer):
 
     @paddle.no_grad()
     def infer(self, text_inputs, stop_threshold=0.5, max_decoder_steps=1000):
+        """Generate the mel spectrogram of features given the sequences of character ids.
+
+        Parameters
+        ----------
+        text_inputs: Tensor[shape=(B, T_text)]
+            batch of the sequences of padded character ids
+
+        stop_threshold: float
+            stop synthesizing when the stop logit is greater than this threshold
+
+        max_decoder_steps: int
+            maximum number of steps when synthesizing
+
+        Returns
+        -------
+        outputs: Dict[str, Tensor]
+
+            mel_output: output sequence of spectrogram (B, T_mel, C)
+
+            mel_outputs_postnet: output sequence of spectrogram after postnet (B, T_mel, C)
+
+            stop_logits: output sequence of stop logits (B, T_mel)
+
+            alignments: attention weights (B, T_mel, T_text)
+        """
         embedded_inputs = self.embedding(text_inputs)
         encoder_outputs = self.encoder(embedded_inputs)
         mel_outputs, stop_logits, alignments = self.decoder.infer(
@@ -407,7 +721,27 @@ class Tacotron2(nn.Layer):
 
     @paddle.no_grad()
     def predict(self, text, stop_threshold=0.5, max_decoder_steps=1000):
-        # TODO(lifuchen): implement predict function to product mel from texts
+        """Generate the mel spectrogram of features given the sequence of characters.
+
+        Parameters
+        ----------
+        text: str
+            sequence of characters
+
+        stop_threshold: float
+            stop synthesizing when the stop logit is greater than this threshold
+
+        max_decoder_steps: int
+            maximum number of steps when synthesizing
+
+        Returns
+        -------
+        outputs: Dict[str, Tensor]
+
+            mel_outputs_postnet: output sequence of spectrogram after postnet (T_mel, C)
+
+            alignments: attention weights (T_mel, T_text)
+        """
         ids = np.asarray(self.frontend(text))
         ids = paddle.unsqueeze(paddle.to_tensor(ids, dtype='int64'), [0])
         outputs = self.infer(ids, stop_threshold, max_decoder_steps)
@@ -416,6 +750,27 @@ class Tacotron2(nn.Layer):
 
     @classmethod
     def from_pretrained(cls, frontend, config, checkpoint_path):
+        """Build a Tacotron2 model from a pretrained model.
+
+        Parameters
+        ----------
+        frontend: parakeet.frontend.Phonetics
+            frontend used to preprocess text
+
+        config: yacs.config.CfgNode
+            model configs
+
+        checkpoint_path: Path
+            the path of the pretrained model checkpoint
+
+        Returns
+        -------
+        Tacotron2
+            the model built with the given frontend, hyperparameters read
+            from ``config``, and weights loaded from the checkpoint at
+            ``checkpoint_path``
+
+        """
         model = cls(frontend,
                     d_mels=config.data.d_mels,
                     d_encoder=config.model.d_encoder,
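Together, `from_pretrained` and `predict` form the public inference path. A hypothetical usage sketch; the frontend instance, config file name, and checkpoint path below are placeholders, not artifacts shipped with this diff:

```python
from pathlib import Path
from yacs.config import CfgNode

from parakeet.models import Tacotron2

# Placeholders: supply a real parakeet.frontend.Phonetics instance,
# a real config file, and a real checkpoint from a released bundle.
frontend = ...                       # e.g. a character/phoneme frontend
config = CfgNode(new_allowed=True)
config.merge_from_file("tacotron2.yaml")

model = Tacotron2.from_pretrained(frontend, config,
                                  Path("tacotron2_checkpoint.pdparams"))
model.eval()

outputs = model.predict("Hello, world.")
mel = outputs["mel_outputs_postnet"]  # (T_mel, C), per the docstring above
attn = outputs["alignments"]          # (T_mel, T_text)
```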
@@ -442,11 +797,45 @@
 
 
 class Tacotron2Loss(nn.Layer):
+    """Tacotron2 loss module.
+    """
+
     def __init__(self):
         super().__init__()
 
     def forward(self, mel_outputs, mel_outputs_postnet, stop_logits,
                 mel_targets, stop_tokens):
+        """Calculate Tacotron2 loss.
+
+        Parameters
+        ----------
+        mel_outputs: Tensor[shape=(B, T_mel, C)]
+            output mel spectrogram sequence
+
+        mel_outputs_postnet: Tensor[shape=(B, T_mel, C)]
+            output mel spectrogram sequence after postnet
+
+        stop_logits: Tensor[shape=(B, T_mel)]
+            output sequence of stop logits before sigmoid
+
+        mel_targets: Tensor[shape=(B, T_mel, C)]
+            target mel spectrogram sequence
+
+        stop_tokens: Tensor[shape=(B, T_mel)]
+            target stop tokens
+
+        Returns
+        -------
+        losses: Dict[str, Tensor]
+
+            loss: the sum of the other three losses
+
+            mel_loss: MSE loss computed from mel_targets and mel_outputs
+
+            post_mel_loss: MSE loss computed from mel_targets and mel_outputs_postnet
+
+            stop_loss: BCE loss computed from stop_logits and stop_tokens
+        """
         mel_loss = paddle.nn.MSELoss()(mel_outputs, mel_targets)
         post_mel_loss = paddle.nn.MSELoss()(mel_outputs_postnet, mel_targets)
         stop_loss = paddle.nn.BCEWithLogitsLoss()(stop_logits, stop_tokens)
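A minimal, self-contained sketch of how the three loss terms combine, mirroring the code above (shapes and values are illustrative; summing the three terms matches the `loss` entry described in the docstring):

```python
import paddle

B, T_mel, C = 2, 10, 80
mel_outputs = paddle.randn([B, T_mel, C])
mel_outputs_postnet = paddle.randn([B, T_mel, C])
mel_targets = paddle.randn([B, T_mel, C])
stop_logits = paddle.randn([B, T_mel])
stop_tokens = (paddle.randn([B, T_mel]) > 0).astype('float32')

mel_loss = paddle.nn.MSELoss()(mel_outputs, mel_targets)
post_mel_loss = paddle.nn.MSELoss()(mel_outputs_postnet, mel_targets)
stop_loss = paddle.nn.BCEWithLogitsLoss()(stop_logits, stop_tokens)

loss = mel_loss + post_mel_loss + stop_loss  # the "loss" entry of the dict
print(float(loss))
```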