add from_pretrained function for tacotron2 and support synthesize
This commit is contained in:
parent
b4533af207
commit
026ae1078b
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
|
@ -20,6 +21,7 @@ import parakeet
|
|||
from parakeet.modules.conv import Conv1dBatchNorm
|
||||
from parakeet.modules.attention import LocationSensitiveAttention
|
||||
from parakeet.modules import masking
|
||||
from parakeet.utils import checkpoint
|
||||
|
||||
__all__ = ["Tacotron2", "Tacotron2Loss"]
|
||||
|
||||
|
@ -268,13 +270,13 @@ class Tacotron2Decoder(nn.Layer):
|
|||
shape=[key.shape[0], self.d_mels * self.reduction_factor],
|
||||
dtype=key.dtype) #[B, C]
|
||||
|
||||
self.initialize_decoder_states(key)
|
||||
self._initialize_decoder_states(key)
|
||||
self.mask = None
|
||||
|
||||
mel_outputs, stop_logits, alignments = [], [], []
|
||||
while True:
|
||||
decoder_input = self.prenet(decoder_input)
|
||||
mel_output, stop_logit, alignment = self.decode(decoder_input)
|
||||
mel_output, stop_logit, alignment = self._decode(decoder_input)
|
||||
|
||||
mel_outputs += [mel_output]
|
||||
stop_logits += [stop_logit]
|
||||
|
@ -332,10 +334,11 @@ class Tacotron2(nn.Layer):
|
|||
p_postnet_dropout: float=0.5):
|
||||
super().__init__()
|
||||
|
||||
std = math.sqrt(2.0 / (frontend.vocab_size + d_encoder))
|
||||
self.frontend = frontend
|
||||
std = math.sqrt(2.0 / (self.frontend.vocab_size + d_encoder))
|
||||
val = math.sqrt(3.0) * std # uniform bounds for std
|
||||
self.embedding = nn.Embedding(
|
||||
frontend.vocab_size,
|
||||
self.frontend.vocab_size,
|
||||
d_encoder,
|
||||
weight_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
|
||||
low=-val, high=val)))
|
||||
|
@ -384,10 +387,12 @@ class Tacotron2(nn.Layer):
|
|||
|
||||
return outputs
|
||||
|
||||
@paddle.no_grad()
|
||||
def infer(self, text_inputs, stop_threshold=0.5, max_decoder_steps=1000):
|
||||
self.eval()
|
||||
embedded_inputs = self.embedding(text_inputs)
|
||||
encoder_outputs = self.encoder(embedded_inputs)
|
||||
mel_outputs, stop_logits, alignments = self.decoder.inference(
|
||||
mel_outputs, stop_logits, alignments = self.decoder.infer(
|
||||
encoder_outputs,
|
||||
stop_threshold=stop_threshold,
|
||||
max_decoder_steps=max_decoder_steps)
|
||||
|
@ -404,9 +409,40 @@ class Tacotron2(nn.Layer):
|
|||
|
||||
return outputs
|
||||
|
||||
def predict(self, text):
|
||||
@paddle.no_grad()
|
||||
def predict(self, text, stop_threshold=0.5, max_decoder_steps=1000):
|
||||
# TODO(lifuchen): implement predict function to product mel from texts
|
||||
pass
|
||||
ids = np.asarray(self.frontend(text))
|
||||
ids = paddle.unsqueeze(paddle.to_tensor(ids, dtype='int64'), [0])
|
||||
outputs = self.infer(ids, stop_threshold, max_decoder_steps)
|
||||
return outputs['mel_outputs_postnet'][0].numpy(), outputs[
|
||||
'alignments'][0].numpy()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, frontend, config, checkpoint_path):
|
||||
model = cls(frontend,
|
||||
d_mels=config.data.d_mels,
|
||||
d_encoder=config.model.d_encoder,
|
||||
encoder_conv_layers=config.model.encoder_conv_layers,
|
||||
encoder_kernel_size=config.model.encoder_kernel_size,
|
||||
d_prenet=config.model.d_prenet,
|
||||
d_attention_rnn=config.model.d_attention_rnn,
|
||||
d_decoder_rnn=config.model.d_decoder_rnn,
|
||||
attention_filters=config.model.attention_filters,
|
||||
attention_kernel_size=config.model.attention_kernel_size,
|
||||
d_attention=config.model.d_attention,
|
||||
d_postnet=config.model.d_postnet,
|
||||
postnet_kernel_size=config.model.postnet_kernel_size,
|
||||
postnet_conv_layers=config.model.postnet_conv_layers,
|
||||
reduction_factor=config.model.reduction_factor,
|
||||
p_encoder_dropout=config.model.p_encoder_dropout,
|
||||
p_prenet_dropout=config.model.p_prenet_dropout,
|
||||
p_attention_dropout=config.model.p_attention_dropout,
|
||||
p_decoder_dropout=config.model.p_decoder_dropout,
|
||||
p_postnet_dropout=config.model.p_postnet_dropout)
|
||||
|
||||
checkpoint.load_parameters(model, checkpoint_path=checkpoint_path)
|
||||
return model
|
||||
|
||||
|
||||
class Tacotron2Loss(nn.Layer):
|
||||
|
|
Loading…
Reference in New Issue