From 6420da619704b1107b78f65a67e474c464618d93 Mon Sep 17 00:00:00 2001 From: lfchener Date: Thu, 17 Dec 2020 02:56:45 +0000 Subject: [PATCH] fix some bugs --- parakeet/models/tacotron2.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/parakeet/models/tacotron2.py b/parakeet/models/tacotron2.py index 599e6a2..9949e1d 100644 --- a/parakeet/models/tacotron2.py +++ b/parakeet/models/tacotron2.py @@ -238,10 +238,7 @@ class Tacotron2Decoder(nn.Layer): querys = paddle.concat( [ paddle.zeros( - shape=[ - querys.shape[0], 1, - querys.shape[-1] * self.reduction_factor - ], + shape=[querys.shape[0], 1, querys.shape[-1]], dtype=querys.dtype), querys ], axis=1) @@ -266,7 +263,7 @@ class Tacotron2Decoder(nn.Layer): return mel_outputs, stop_logits, alignments def infer(self, key, stop_threshold=0.5, max_decoder_steps=1000): - decoder_input = paddle.zeros( + query = paddle.zeros( shape=[key.shape[0], self.d_mels * self.reduction_factor], dtype=key.dtype) #[B, C] @@ -275,8 +272,8 @@ class Tacotron2Decoder(nn.Layer): mel_outputs, stop_logits, alignments = [], [], [] while True: - decoder_input = self.prenet(decoder_input) - mel_output, stop_logit, alignment = self._decode(decoder_input) + query = self.prenet(query) + mel_output, stop_logit, alignment = self._decode(query) mel_outputs += [mel_output] stop_logits += [stop_logit] @@ -288,7 +285,7 @@ class Tacotron2Decoder(nn.Layer): print("Warning! Reached max decoder steps!!!") break - decoder_input = mel_output + query = mel_output alignments = paddle.stack(alignments, axis=1) stop_logits = paddle.concat(stop_logits, axis=1) @@ -350,7 +347,7 @@ class Tacotron2(nn.Layer): attention_kernel_size, p_prenet_dropout, p_attention_dropout, p_decoder_dropout) self.postnet = DecoderPostNet( - d_mels=d_mels, + d_mels=d_mels * reduction_factor, d_hidden=d_postnet, kernel_size=postnet_kernel_size, padding=int((postnet_kernel_size - 1) / 2),