diff --git a/parakeet/models/tacotron2.py b/parakeet/models/tacotron2.py index f7efd58..59699d9 100644 --- a/parakeet/models/tacotron2.py +++ b/parakeet/models/tacotron2.py @@ -607,7 +607,7 @@ class Tacotron2(nn.Layer): num_layers=postnet_conv_layers, dropout=p_postnet_dropout) - def forward(self, text_inputs, text_lens, mels, output_lens=None, speaker_ids=None, tones=None): + def forward(self, text_inputs, text_lens, mels, output_lens=None, speaker_ids=None, tones=None, global_condition=None): """Calculate forward propagation of tacotron2. Parameters @@ -645,6 +645,8 @@ class Tacotron2(nn.Layer): speaker_embedding = self.speaker_embedding(speaker_ids) speaker_feature = F.softplus(self.speaker_fc(speaker_embedding)) encoder_outputs += speaker_feature.unsqueeze(1) + if global_condition is not None: + encoder_outputs += global_condition.unsqueeze(1) # [B, T_enc, 1] @@ -671,7 +673,7 @@ class Tacotron2(nn.Layer): return outputs @paddle.no_grad() - def infer(self, text_inputs, max_decoder_steps=1000, speaker_ids=None, tones=None): + def infer(self, text_inputs, max_decoder_steps=1000, speaker_ids=None, tones=None, global_condition=None): """Generate the mel sepctrogram of features given the sequences of character ids. Parameters @@ -702,6 +704,8 @@ class Tacotron2(nn.Layer): speaker_embedding = self.speaker_embedding(speaker_ids) speaker_feature = F.softplus(self.speaker_fc(speaker_embedding)) encoder_outputs += speaker_feature.unsqueeze(1) + if global_condition is not None: + encoder_outputs += global_condition.unsqueeze(1) mel_outputs, alignments = self.decoder.infer( encoder_outputs, max_decoder_steps=max_decoder_steps)