add global condition support for tacotron2

This commit is contained in:
iclementine 2021-04-08 04:58:44 +08:00
parent 5011f16c10
commit dc3b798f82
1 changed files with 6 additions and 2 deletions

View File

@ -607,7 +607,7 @@ class Tacotron2(nn.Layer):
num_layers=postnet_conv_layers,
dropout=p_postnet_dropout)
def forward(self, text_inputs, text_lens, mels, output_lens=None, speaker_ids=None, tones=None):
def forward(self, text_inputs, text_lens, mels, output_lens=None, speaker_ids=None, tones=None, global_condition=None):
"""Calculate forward propagation of tacotron2.
Parameters
@ -645,6 +645,8 @@ class Tacotron2(nn.Layer):
speaker_embedding = self.speaker_embedding(speaker_ids)
speaker_feature = F.softplus(self.speaker_fc(speaker_embedding))
encoder_outputs += speaker_feature.unsqueeze(1)
if global_condition is not None:
encoder_outputs += global_condition.unsqueeze(1)
# [B, T_enc, 1]
@ -671,7 +673,7 @@ class Tacotron2(nn.Layer):
return outputs
@paddle.no_grad()
def infer(self, text_inputs, max_decoder_steps=1000, speaker_ids=None, tones=None):
def infer(self, text_inputs, max_decoder_steps=1000, speaker_ids=None, tones=None, global_condition=None):
"""Generate the mel sepctrogram of features given the sequences of character ids.
Parameters
@ -702,6 +704,8 @@ class Tacotron2(nn.Layer):
speaker_embedding = self.speaker_embedding(speaker_ids)
speaker_feature = F.softplus(self.speaker_fc(speaker_embedding))
encoder_outputs += speaker_feature.unsqueeze(1)
if global_condition is not None:
encoder_outputs += global_condition.unsqueeze(1)
mel_outputs, alignments = self.decoder.infer(
encoder_outputs, max_decoder_steps=max_decoder_steps)