fix some bugs

Fix the go-frame shape in Tacotron2Decoder.forward, rename decoder_input to query in Tacotron2Decoder.infer, and scale the postnet's d_mels by reduction_factor in Tacotron2.
parent a5c81c75d5
commit 6420da6197
@@ -238,10 +238,7 @@ class Tacotron2Decoder(nn.Layer):
         querys = paddle.concat(
             [
                 paddle.zeros(
-                    shape=[
-                        querys.shape[0], 1,
-                        querys.shape[-1] * self.reduction_factor
-                    ],
+                    shape=[querys.shape[0], 1, querys.shape[-1]],
                     dtype=querys.dtype), querys
             ],
             axis=1)
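Note on this hunk: querys is presumably already grouped by the reduction factor earlier in forward(), i.e. reshaped to [B, T//r, d_mels * r], so its last axis already carries reduction_factor. Multiplying by self.reduction_factor again gave the zero go-frame a wider last dimension than querys, and paddle.concat along axis=1 requires all other axes to match. A minimal sketch of the fixed padding, with hypothetical sizes (B=4, T//r=10, d_mels=80, r=2):

    import paddle

    querys = paddle.rand([4, 10, 80 * 2])  # assumed already reshaped to [B, T//r, d_mels * r]
    go_frame = paddle.zeros(
        shape=[querys.shape[0], 1, querys.shape[-1]],  # match the last axis as-is
        dtype=querys.dtype)
    querys = paddle.concat([go_frame, querys], axis=1)
    print(querys.shape)  # [4, 11, 160]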
@@ -266,7 +263,7 @@ class Tacotron2Decoder(nn.Layer):
         return mel_outputs, stop_logits, alignments

     def infer(self, key, stop_threshold=0.5, max_decoder_steps=1000):
-        decoder_input = paddle.zeros(
+        query = paddle.zeros(
             shape=[key.shape[0], self.d_mels * self.reduction_factor],
             dtype=key.dtype) #[B, C]
@@ -275,8 +272,8 @@ class Tacotron2Decoder(nn.Layer):

         mel_outputs, stop_logits, alignments = [], [], []
         while True:
-            decoder_input = self.prenet(decoder_input)
-            mel_output, stop_logit, alignment = self._decode(decoder_input)
+            query = self.prenet(query)
+            mel_output, stop_logit, alignment = self._decode(query)

             mel_outputs += [mel_output]
             stop_logits += [stop_logit]
@@ -288,7 +285,7 @@ class Tacotron2Decoder(nn.Layer):
                 print("Warning! Reached max decoder steps!!!")
                 break

-            decoder_input = mel_output
+            query = mel_output

         alignments = paddle.stack(alignments, axis=1)
         stop_logits = paddle.concat(stop_logits, axis=1)
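Note on the three renames above: infer() now calls its autoregressive state query, matching the key/query naming of the attention layer (_decode consumes the query against the encoder key). A self-contained sketch of the renamed feedback loop; prenet and decode are hypothetical stand-ins for self.prenet and self._decode:

    import paddle

    def prenet(x):
        return x  # placeholder for the real prenet

    def decode(x):
        # returns (mel_output, stop_logit); alignment omitted for brevity
        return x, paddle.ones([x.shape[0], 1])

    query = paddle.zeros(shape=[4, 160], dtype='float32')  # [B, d_mels * r]
    mel_outputs, stop_logits = [], []
    for _ in range(3):  # the real loop runs until stop_threshold or max_decoder_steps
        query = prenet(query)
        mel_output, stop_logit = decode(query)
        mel_outputs += [mel_output]
        stop_logits += [stop_logit]
        query = mel_output  # feedback: the last output becomes the next query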
@@ -350,7 +347,7 @@ class Tacotron2(nn.Layer):
             attention_kernel_size, p_prenet_dropout, p_attention_dropout,
             p_decoder_dropout)
         self.postnet = DecoderPostNet(
-            d_mels=d_mels,
+            d_mels=d_mels * reduction_factor,
             d_hidden=d_postnet,
             kernel_size=postnet_kernel_size,
             padding=int((postnet_kernel_size - 1) / 2),
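Note on the last hunk: as the infer() hunk shows, the decoder works on frames of width self.d_mels * self.reduction_factor (r mel frames packed along the feature axis), so the postnet must see d_mels * r channels, not d_mels. A small arithmetic sketch with hypothetical values:

    d_mels, reduction_factor = 80, 2
    postnet_in = d_mels * reduction_factor  # 160 channels, matching the fixed call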