fix some bugs

This commit is contained in:
lfchener 2020-12-17 02:56:45 +00:00
parent a5c81c75d5
commit 6420da6197
1 changed files with 6 additions and 9 deletions

View File

@ -238,10 +238,7 @@ class Tacotron2Decoder(nn.Layer):
querys = paddle.concat( querys = paddle.concat(
[ [
paddle.zeros( paddle.zeros(
shape=[ shape=[querys.shape[0], 1, querys.shape[-1]],
querys.shape[0], 1,
querys.shape[-1] * self.reduction_factor
],
dtype=querys.dtype), querys dtype=querys.dtype), querys
], ],
axis=1) axis=1)
@ -266,7 +263,7 @@ class Tacotron2Decoder(nn.Layer):
return mel_outputs, stop_logits, alignments return mel_outputs, stop_logits, alignments
def infer(self, key, stop_threshold=0.5, max_decoder_steps=1000): def infer(self, key, stop_threshold=0.5, max_decoder_steps=1000):
decoder_input = paddle.zeros( query = paddle.zeros(
shape=[key.shape[0], self.d_mels * self.reduction_factor], shape=[key.shape[0], self.d_mels * self.reduction_factor],
dtype=key.dtype) #[B, C] dtype=key.dtype) #[B, C]
@ -275,8 +272,8 @@ class Tacotron2Decoder(nn.Layer):
mel_outputs, stop_logits, alignments = [], [], [] mel_outputs, stop_logits, alignments = [], [], []
while True: while True:
decoder_input = self.prenet(decoder_input) query = self.prenet(query)
mel_output, stop_logit, alignment = self._decode(decoder_input) mel_output, stop_logit, alignment = self._decode(query)
mel_outputs += [mel_output] mel_outputs += [mel_output]
stop_logits += [stop_logit] stop_logits += [stop_logit]
@ -288,7 +285,7 @@ class Tacotron2Decoder(nn.Layer):
print("Warning! Reached max decoder steps!!!") print("Warning! Reached max decoder steps!!!")
break break
decoder_input = mel_output query = mel_output
alignments = paddle.stack(alignments, axis=1) alignments = paddle.stack(alignments, axis=1)
stop_logits = paddle.concat(stop_logits, axis=1) stop_logits = paddle.concat(stop_logits, axis=1)
@ -350,7 +347,7 @@ class Tacotron2(nn.Layer):
attention_kernel_size, p_prenet_dropout, p_attention_dropout, attention_kernel_size, p_prenet_dropout, p_attention_dropout,
p_decoder_dropout) p_decoder_dropout)
self.postnet = DecoderPostNet( self.postnet = DecoderPostNet(
d_mels=d_mels, d_mels=d_mels * reduction_factor,
d_hidden=d_postnet, d_hidden=d_postnet,
kernel_size=postnet_kernel_size, kernel_size=postnet_kernel_size,
padding=int((postnet_kernel_size - 1) / 2), padding=int((postnet_kernel_size - 1) / 2),