Merge pull request #50 from lfchener/reborn

add from_pretrained function for tacotron2 and support synthesize
Feiyu Chan 2020-12-12 18:16:19 +08:00 committed by GitHub
commit b2bd479f46
1 changed file with 43 additions and 7 deletions
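In short: the decoder's inference helpers become private (`_initialize_decoder_states`, `_decode`), `decoder.inference` is renamed to `decoder.infer`, `Tacotron2` now keeps a reference to its text frontend, `predict` is implemented so it synthesizes a mel spectrogram and attention alignment directly from text, and a `from_pretrained` classmethod builds a model from a config and loads checkpointed weights.

A minimal usage sketch of the new API (the frontend class, config loading, and checkpoint path below are illustrative assumptions, not part of this diff; the config is assumed to be a yacs-style CfgNode exposing `config.data.*` and `config.model.*`):

    from yacs.config import CfgNode
    from parakeet.frontend import EnglishCharacter  # assumed character-level frontend
    from parakeet.models.tacotron2 import Tacotron2

    # from_pretrained reads hyperparameters attribute-style from the config,
    # e.g. config.data.d_mels and config.model.d_encoder.
    with open("config.yaml") as f:  # hypothetical experiment config file
        config = CfgNode.load_cfg(f)

    frontend = EnglishCharacter()
    model = Tacotron2.from_pretrained(
        frontend, config, checkpoint_path="checkpoints/step-100000")

    # predict returns numpy arrays: the post-net mel spectrogram and the
    # attention alignment for a single utterance.
    mel, alignment = model.predict("Hello, world.")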

@@ -13,6 +13,7 @@
 # limitations under the License.
 import math
+import numpy as np
 import paddle
 from paddle import nn
 from paddle.nn import functional as F
@@ -20,6 +21,7 @@ import parakeet
 from parakeet.modules.conv import Conv1dBatchNorm
 from parakeet.modules.attention import LocationSensitiveAttention
 from parakeet.modules import masking
+from parakeet.utils import checkpoint
 
 __all__ = ["Tacotron2", "Tacotron2Loss"]
@@ -268,13 +270,13 @@ class Tacotron2Decoder(nn.Layer):
             shape=[key.shape[0], self.d_mels * self.reduction_factor],
             dtype=key.dtype)  #[B, C]
 
-        self.initialize_decoder_states(key)
+        self._initialize_decoder_states(key)
         self.mask = None
 
         mel_outputs, stop_logits, alignments = [], [], []
         while True:
             decoder_input = self.prenet(decoder_input)
-            mel_output, stop_logit, alignment = self.decode(decoder_input)
+            mel_output, stop_logit, alignment = self._decode(decoder_input)
 
             mel_outputs += [mel_output]
             stop_logits += [stop_logit]
@@ -332,10 +334,11 @@ class Tacotron2(nn.Layer):
                  p_postnet_dropout: float=0.5):
         super().__init__()
-        std = math.sqrt(2.0 / (frontend.vocab_size + d_encoder))
+        self.frontend = frontend
+        std = math.sqrt(2.0 / (self.frontend.vocab_size + d_encoder))
         val = math.sqrt(3.0) * std  # uniform bounds for std
         self.embedding = nn.Embedding(
-            frontend.vocab_size,
+            self.frontend.vocab_size,
             d_encoder,
             weight_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
                 low=-val, high=val)))
@@ -384,10 +387,12 @@ class Tacotron2(nn.Layer):
         return outputs
 
+    @paddle.no_grad()
     def infer(self, text_inputs, stop_threshold=0.5, max_decoder_steps=1000):
+        self.eval()
         embedded_inputs = self.embedding(text_inputs)
         encoder_outputs = self.encoder(embedded_inputs)
-        mel_outputs, stop_logits, alignments = self.decoder.inference(
+        mel_outputs, stop_logits, alignments = self.decoder.infer(
             encoder_outputs,
             stop_threshold=stop_threshold,
             max_decoder_steps=max_decoder_steps)
@@ -404,9 +409,40 @@
         return outputs
 
-    def predict(self, text):
-        # TODO(lifuchen): implement predict function to product mel from texts
-        pass
+    @paddle.no_grad()
+    def predict(self, text, stop_threshold=0.5, max_decoder_steps=1000):
+        ids = np.asarray(self.frontend(text))
+        ids = paddle.unsqueeze(paddle.to_tensor(ids, dtype='int64'), [0])
+        outputs = self.infer(ids, stop_threshold, max_decoder_steps)
+        return outputs['mel_outputs_postnet'][0].numpy(), outputs[
+            'alignments'][0].numpy()
+
+    @classmethod
+    def from_pretrained(cls, frontend, config, checkpoint_path):
+        model = cls(frontend,
+                    d_mels=config.data.d_mels,
+                    d_encoder=config.model.d_encoder,
+                    encoder_conv_layers=config.model.encoder_conv_layers,
+                    encoder_kernel_size=config.model.encoder_kernel_size,
+                    d_prenet=config.model.d_prenet,
+                    d_attention_rnn=config.model.d_attention_rnn,
+                    d_decoder_rnn=config.model.d_decoder_rnn,
+                    attention_filters=config.model.attention_filters,
+                    attention_kernel_size=config.model.attention_kernel_size,
+                    d_attention=config.model.d_attention,
+                    d_postnet=config.model.d_postnet,
+                    postnet_kernel_size=config.model.postnet_kernel_size,
+                    postnet_conv_layers=config.model.postnet_conv_layers,
+                    reduction_factor=config.model.reduction_factor,
+                    p_encoder_dropout=config.model.p_encoder_dropout,
+                    p_prenet_dropout=config.model.p_prenet_dropout,
+                    p_attention_dropout=config.model.p_attention_dropout,
+                    p_decoder_dropout=config.model.p_decoder_dropout,
+                    p_postnet_dropout=config.model.p_postnet_dropout)
+
+        checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
+        return model
 
 
 class Tacotron2Loss(nn.Layer):
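`checkpoint.load_parameters` comes from `parakeet.utils.checkpoint`, imported at the top of this diff. For the model-only case used here, a rough sketch of the Paddle mechanism it relies on (the `.pdparams` suffix is an assumption about how Parakeet names serialized weight files):

    import paddle

    def load_parameters_sketch(model, checkpoint_path):
        # paddle.load / Layer.set_state_dict are Paddle's standard way of
        # restoring a Layer's weights from a serialized state dict.
        state_dict = paddle.load(checkpoint_path + ".pdparams")
        model.set_state_dict(state_dict)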