add global condition support for tacotron2
parent 5011f16c10
commit dc3b798f82
@@ -607,7 +607,7 @@ class Tacotron2(nn.Layer):
             num_layers=postnet_conv_layers,
             dropout=p_postnet_dropout)

-    def forward(self, text_inputs, text_lens, mels, output_lens=None, speaker_ids=None, tones=None):
+    def forward(self, text_inputs, text_lens, mels, output_lens=None, speaker_ids=None, tones=None, global_condition=None):
         """Calculate forward propagation of tacotron2.

         Parameters
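For context, a minimal sketch of how the extended forward signature might be called during training. `model`, `text_ids`, `text_lens`, `mel_specs`, `mel_lens`, and `utt_embedding` are assumed placeholder names; only the new `global_condition` keyword comes from this commit.

# Hypothetical training-time call (names and shapes are assumptions).
# `utt_embedding` stands for any per-utterance vector of shape [B, d_encoder].
outputs = model(
    text_ids,                          # [B, T_text] token ids
    text_lens,                         # [B] valid text lengths
    mel_specs,                         # [B, T_mel, n_mels] target mels
    output_lens=mel_lens,              # [B] valid mel lengths
    global_condition=utt_embedding)    # new keyword added in this commit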
@@ -645,6 +645,8 @@ class Tacotron2(nn.Layer):
             speaker_embedding = self.speaker_embedding(speaker_ids)
             speaker_feature = F.softplus(self.speaker_fc(speaker_embedding))
             encoder_outputs += speaker_feature.unsqueeze(1)
+        if global_condition is not None:
+            encoder_outputs += global_condition.unsqueeze(1)


         # [B, T_enc, 1]
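The two added lines broadcast one condition vector per utterance over every encoder time step, the same way the speaker feature is applied just above. A self-contained sketch of that broadcasting with illustrative shapes (all names here are assumptions, not part of the diff):

import paddle

# Illustrative shapes: batch of 2, 50 encoder steps, 512-dim encoder output.
batch_size, enc_steps, d_encoder = 2, 50, 512
encoder_outputs = paddle.randn([batch_size, enc_steps, d_encoder])  # [B, T_enc, C]
global_condition = paddle.randn([batch_size, d_encoder])            # [B, C]

# unsqueeze(1) turns [B, C] into [B, 1, C]; the addition then broadcasts the
# same vector across all T_enc positions, which is the pattern the diff adds.
conditioned = encoder_outputs + global_condition.unsqueeze(1)
print(conditioned.shape)  # [2, 50, 512]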
@@ -671,7 +673,7 @@ class Tacotron2(nn.Layer):
         return outputs

     @paddle.no_grad()
-    def infer(self, text_inputs, max_decoder_steps=1000, speaker_ids=None, tones=None):
+    def infer(self, text_inputs, max_decoder_steps=1000, speaker_ids=None, tones=None, global_condition=None):
         """Generate the mel sepctrogram of features given the sequences of character ids.

         Parameters
@@ -702,6 +704,8 @@ class Tacotron2(nn.Layer):
             speaker_embedding = self.speaker_embedding(speaker_ids)
             speaker_feature = F.softplus(self.speaker_fc(speaker_embedding))
             encoder_outputs += speaker_feature.unsqueeze(1)
+        if global_condition is not None:
+            encoder_outputs += global_condition.unsqueeze(1)

         mel_outputs, alignments = self.decoder.infer(
             encoder_outputs, max_decoder_steps=max_decoder_steps)
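Correspondingly, a hedged sketch of inference with the new argument; `model`, `text_ids`, and `style_vector` are assumed names, and only the `global_condition` and `max_decoder_steps` keywords come from the diff.

# Hypothetical synthesis call; `style_vector` is any [B, d_encoder] embedding,
# e.g. an averaged reference embedding (an assumption, not defined by this commit).
outputs = model.infer(
    text_ids,
    max_decoder_steps=1000,
    global_condition=style_vector)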