add global condition support for tacotron2
This commit is contained in:
parent
5011f16c10
commit
dc3b798f82
|
@ -607,7 +607,7 @@ class Tacotron2(nn.Layer):
|
|||
num_layers=postnet_conv_layers,
|
||||
dropout=p_postnet_dropout)
|
||||
|
||||
def forward(self, text_inputs, text_lens, mels, output_lens=None, speaker_ids=None, tones=None):
|
||||
def forward(self, text_inputs, text_lens, mels, output_lens=None, speaker_ids=None, tones=None, global_condition=None):
|
||||
"""Calculate forward propagation of tacotron2.
|
||||
|
||||
Parameters
|
||||
|
@ -645,6 +645,8 @@ class Tacotron2(nn.Layer):
|
|||
speaker_embedding = self.speaker_embedding(speaker_ids)
|
||||
speaker_feature = F.softplus(self.speaker_fc(speaker_embedding))
|
||||
encoder_outputs += speaker_feature.unsqueeze(1)
|
||||
if global_condition is not None:
|
||||
encoder_outputs += global_condition.unsqueeze(1)
|
||||
|
||||
|
||||
# [B, T_enc, 1]
|
||||
|
@ -671,7 +673,7 @@ class Tacotron2(nn.Layer):
|
|||
return outputs
|
||||
|
||||
@paddle.no_grad()
|
||||
def infer(self, text_inputs, max_decoder_steps=1000, speaker_ids=None, tones=None):
|
||||
def infer(self, text_inputs, max_decoder_steps=1000, speaker_ids=None, tones=None, global_condition=None):
|
||||
"""Generate the mel sepctrogram of features given the sequences of character ids.
|
||||
|
||||
Parameters
|
||||
|
@ -702,6 +704,8 @@ class Tacotron2(nn.Layer):
|
|||
speaker_embedding = self.speaker_embedding(speaker_ids)
|
||||
speaker_feature = F.softplus(self.speaker_fc(speaker_embedding))
|
||||
encoder_outputs += speaker_feature.unsqueeze(1)
|
||||
if global_condition is not None:
|
||||
encoder_outputs += global_condition.unsqueeze(1)
|
||||
|
||||
mel_outputs, alignments = self.decoder.infer(
|
||||
encoder_outputs, max_decoder_steps=max_decoder_steps)
|
||||
|
|
Loading…
Reference in New Issue