add docstring for tacotron2
This commit is contained in:
parent
d81df88173
commit
ecdeb14a40
|
@ -19,3 +19,4 @@ from parakeet.models.waveflow import *
|
|||
from parakeet.models.transformer_tts import *
|
||||
#from parakeet.models.deepvoice3 import *
|
||||
# from parakeet.models.fastspeech import *
|
||||
from parakeet.models.tacotron2 import *
|
||||
|
|
|
@ -27,6 +27,25 @@ __all__ = ["Tacotron2", "Tacotron2Loss"]
|
|||
|
||||
|
||||
class DecoderPreNet(nn.Layer):
|
||||
"""
|
||||
Decoder prenet module for Tacotron2.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_input: int
|
||||
input dimension
|
||||
|
||||
d_hidden: int
|
||||
hidden size
|
||||
|
||||
d_output: int
|
||||
output Dimension
|
||||
|
||||
dropout_rate: float
|
||||
droput probability
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
d_input: int,
|
||||
d_hidden: int,
|
||||
|
@ -39,23 +58,60 @@ class DecoderPreNet(nn.Layer):
|
|||
self.linear2 = nn.Linear(d_hidden, d_output, bias_attr=False)
|
||||
|
||||
def forward(self, x):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x: Tensor[shape=(B, T_mel, C)]
|
||||
batch of the sequences of padded mel spectrogram
|
||||
|
||||
Returns
|
||||
-------
|
||||
output: Tensor[shape=(B, T_mel, C)]
|
||||
batch of the sequences of padded hidden state
|
||||
|
||||
"""
|
||||
|
||||
x = F.dropout(F.relu(self.linear1(x)), self.dropout_rate)
|
||||
output = F.dropout(F.relu(self.linear2(x)), self.dropout_rate)
|
||||
return output
|
||||
|
||||
|
||||
class DecoderPostNet(nn.Layer):
|
||||
"""
|
||||
Decoder postnet module for Tacotron2.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_mels: int
|
||||
number of mel bands
|
||||
|
||||
d_hidden: int
|
||||
hidden size of postnet
|
||||
|
||||
kernel_size: int
|
||||
kernel size of the conv layer in postnet
|
||||
|
||||
num_layers: int
|
||||
number of conv layers in postnet
|
||||
|
||||
dropout: float
|
||||
droput probability
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
d_mels: int=80,
|
||||
d_hidden: int=512,
|
||||
kernel_size: int=5,
|
||||
padding: int=0,
|
||||
num_layers: int=5,
|
||||
dropout: float=0.1):
|
||||
super().__init__()
|
||||
self.dropout = dropout
|
||||
self.num_layers = num_layers
|
||||
|
||||
padding = int((kernel_size - 1) / 2),
|
||||
|
||||
self.conv_batchnorms = nn.LayerList()
|
||||
k = math.sqrt(1.0 / (d_mels * kernel_size))
|
||||
self.conv_batchnorms.append(
|
||||
|
@ -91,15 +147,47 @@ class DecoderPostNet(nn.Layer):
|
|||
data_format='NLC'))
|
||||
|
||||
def forward(self, input):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input: Tensor[shape=(B, T_mel, C)]
|
||||
output sequence of features from decoder
|
||||
|
||||
Returns
|
||||
-------
|
||||
output: Tensor[shape=(B, T_mel, C)]
|
||||
output sequence of features after postnet
|
||||
|
||||
"""
|
||||
|
||||
for i in range(len(self.conv_batchnorms) - 1):
|
||||
input = F.dropout(
|
||||
F.tanh(self.conv_batchnorms[i](input), self.dropout))
|
||||
input = F.dropout(self.conv_batchnorms[self.num_layers - 1](input),
|
||||
self.dropout)
|
||||
return input
|
||||
output = F.dropout(self.conv_batchnorms[self.num_layers - 1](input),
|
||||
self.dropout)
|
||||
return output
|
||||
|
||||
|
||||
class Tacotron2Encoder(nn.Layer):
|
||||
"""
|
||||
Tacotron2 encoder module for Tacotron2.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_hidden: int
|
||||
hidden size in encoder module
|
||||
|
||||
conv_layers: int
|
||||
number of conv layers
|
||||
|
||||
kernel_size: int
|
||||
kernel size of conv layers
|
||||
|
||||
p_dropout: float
|
||||
droput probability
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
d_hidden: int,
|
||||
conv_layers: int,
|
||||
|
@ -126,6 +214,22 @@ class Tacotron2Encoder(nn.Layer):
|
|||
d_hidden, self.hidden_size, direction="bidirectional")
|
||||
|
||||
def forward(self, x, input_lens=None):
|
||||
"""Calculate forward propagation of tacotron2 encoder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x: Tensor[shape=(B, T)]
|
||||
batch of the sequencees of padded character ids
|
||||
|
||||
text_lens: Tensor[shape=(B,)]
|
||||
batch of lengths of each text input batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
output : Tensor[shape=(B, T, C)]
|
||||
batch of the sequences of padded hidden states
|
||||
|
||||
"""
|
||||
for conv_batchnorm in self.conv_batchnorms:
|
||||
x = F.dropout(F.relu(conv_batchnorm(x)),
|
||||
self.p_dropout) #(B, T, C)
|
||||
|
@ -135,6 +239,47 @@ class Tacotron2Encoder(nn.Layer):
|
|||
|
||||
|
||||
class Tacotron2Decoder(nn.Layer):
|
||||
"""
|
||||
Tacotron2 decoder module for Tacotron2.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_mels: int
|
||||
number of mel bands
|
||||
|
||||
reduction_factor: int
|
||||
reduction factor of tacotron
|
||||
|
||||
d_encoder: int
|
||||
hidden size of encoder
|
||||
|
||||
d_prenet: int
|
||||
hidden size in decoder prenet
|
||||
|
||||
d_attention_rnn: int
|
||||
attention rnn layer hidden size
|
||||
|
||||
d_decoder_rnn: int
|
||||
decoder rnn layer hidden size
|
||||
|
||||
d_attention: int
|
||||
hidden size of the linear layer in location sensitive attention
|
||||
|
||||
attention_filters: int
|
||||
filter size of the conv layer in location sensitive attention
|
||||
|
||||
attention_kernel_size: int
|
||||
kernel size of the conv layer in location sensitive attention
|
||||
|
||||
p_prenet_dropout: float
|
||||
droput probability in decoder prenet
|
||||
|
||||
p_attention_dropout: float
|
||||
droput probability in location sensitive attention
|
||||
|
||||
p_decoder_dropout: float
|
||||
droput probability in decoder"""
|
||||
|
||||
def __init__(self,
|
||||
d_mels: int,
|
||||
reduction_factor: int,
|
||||
|
@ -175,6 +320,8 @@ class Tacotron2Decoder(nn.Layer):
|
|||
self.stop_layer = nn.Linear(d_decoder_rnn + d_encoder, 1)
|
||||
|
||||
def _initialize_decoder_states(self, key):
|
||||
"""init states be used in decoder
|
||||
"""
|
||||
batch_size = key.shape[0]
|
||||
MAX_TIME = key.shape[1]
|
||||
|
||||
|
@ -199,6 +346,8 @@ class Tacotron2Decoder(nn.Layer):
|
|||
self.processed_key = self.attention_layer.key_layer(key) #[B, T, C]
|
||||
|
||||
def _decode(self, query):
|
||||
"""decode one time step
|
||||
"""
|
||||
cell_input = paddle.concat([query, self.attention_context], axis=-1)
|
||||
|
||||
# The first lstm layer
|
||||
|
@ -232,6 +381,31 @@ class Tacotron2Decoder(nn.Layer):
|
|||
return decoder_output, stop_logit, self.attention_weights
|
||||
|
||||
def forward(self, keys, querys, mask):
|
||||
"""Calculate forward propagation of tacotron2 decoder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
keys: Tensor[shape=(B, T_text, C)]
|
||||
batch of the sequences of padded output from encoder
|
||||
|
||||
querys: Tensor[shape(B, T_mel, C)]
|
||||
batch of the sequences of padded mel spectrogram
|
||||
|
||||
mask: Tensor[shape=(B, T_text, 1)]
|
||||
mask generated with text length
|
||||
|
||||
Returns
|
||||
-------
|
||||
mel_output: Tensor[shape=(B, T_mel, C)]
|
||||
output sequence of features
|
||||
|
||||
stop_logits: Tensor[shape=(B, T_mel)]
|
||||
output sequence of stop logits
|
||||
|
||||
alignments: Tensor[shape=(B, T_mel, T_text)]
|
||||
attention weights
|
||||
|
||||
"""
|
||||
querys = paddle.reshape(
|
||||
querys,
|
||||
[querys.shape[0], querys.shape[1] // self.reduction_factor, -1])
|
||||
|
@ -263,6 +437,31 @@ class Tacotron2Decoder(nn.Layer):
|
|||
return mel_outputs, stop_logits, alignments
|
||||
|
||||
def infer(self, key, stop_threshold=0.5, max_decoder_steps=1000):
|
||||
"""Calculate forward propagation of tacotron2 decoder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
keys: Tensor[shape=(B, T_text, C)]
|
||||
batch of the sequences of padded output from encoder
|
||||
|
||||
stop_threshold: float
|
||||
stop synthesize when stop logit is greater than this stop threshold
|
||||
|
||||
max_decoder_steps: int
|
||||
number of max step when synthesize
|
||||
|
||||
Returns
|
||||
-------
|
||||
mel_output: Tensor[shape=(B, T_mel, C)]
|
||||
output sequence of features
|
||||
|
||||
stop_logits: Tensor[shape=(B, T_mel)]
|
||||
output sequence of stop logits
|
||||
|
||||
alignments: Tensor[shape=(B, T_mel, T_text)]
|
||||
attention weights
|
||||
|
||||
"""
|
||||
query = paddle.zeros(
|
||||
shape=[key.shape[0], self.d_mels * self.reduction_factor],
|
||||
dtype=key.dtype) #[B, C]
|
||||
|
@ -296,16 +495,79 @@ class Tacotron2Decoder(nn.Layer):
|
|||
|
||||
class Tacotron2(nn.Layer):
|
||||
"""
|
||||
Tacotron2 module for end-to-end text-to-speech (E2E-TTS).
|
||||
Tacotron2 model for end-to-end text-to-speech (E2E-TTS).
|
||||
|
||||
This is a module of Spectrogram prediction network in Tacotron2 described
|
||||
This is a model of Spectrogram prediction network in Tacotron2 described
|
||||
in `Natural TTS Synthesis
|
||||
by Conditioning WaveNet on Mel Spectrogram Predictions`_,
|
||||
by Conditioning WaveNet on Mel Spectrogram Predictions`,
|
||||
which converts the sequence of characters
|
||||
into the sequence of mel spectrogram.
|
||||
|
||||
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
|
||||
https://arxiv.org/abs/1712.05884
|
||||
|
||||
Parameters
|
||||
----------
|
||||
frontend : parakeet.frontend.Phonetics
|
||||
frontend used to preprocess text
|
||||
|
||||
d_mels: int
|
||||
number of mel bands
|
||||
|
||||
d_encoder: int
|
||||
hidden size in encoder module
|
||||
|
||||
encoder_conv_layers: int
|
||||
number of conv layers in encoder
|
||||
|
||||
encoder_kernel_size: int
|
||||
kernel size of conv layers in encoder
|
||||
|
||||
d_prenet: int
|
||||
hidden size in decoder prenet
|
||||
|
||||
d_attention_rnn: int
|
||||
attention rnn layer hidden size in decoder
|
||||
|
||||
d_decoder_rnn: int
|
||||
decoder rnn layer hidden size in decoder
|
||||
|
||||
attention_filters: int
|
||||
filter size of the conv layer in location sensitive attention
|
||||
|
||||
attention_kernel_size: int
|
||||
kernel size of the conv layer in location sensitive attention
|
||||
|
||||
d_attention: int
|
||||
hidden size of the linear layer in location sensitive attention
|
||||
|
||||
d_postnet: int
|
||||
hidden size of postnet
|
||||
|
||||
postnet_kernel_size: int
|
||||
kernel size of the conv layer in postnet
|
||||
|
||||
postnet_conv_layers: int
|
||||
number of conv layers in postnet
|
||||
|
||||
reduction_factor: int
|
||||
reduction factor of tacotron
|
||||
|
||||
p_encoder_dropout: float
|
||||
droput probability in encoder
|
||||
|
||||
p_prenet_dropout: float
|
||||
droput probability in decoder prenet
|
||||
|
||||
p_attention_dropout: float
|
||||
droput probability in location sensitive attention
|
||||
|
||||
p_decoder_dropout: float
|
||||
droput probability in decoder
|
||||
|
||||
p_postnet_dropout: float
|
||||
droput probability in postnet
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -350,11 +612,38 @@ class Tacotron2(nn.Layer):
|
|||
d_mels=d_mels * reduction_factor,
|
||||
d_hidden=d_postnet,
|
||||
kernel_size=postnet_kernel_size,
|
||||
padding=int((postnet_kernel_size - 1) / 2),
|
||||
num_layers=postnet_conv_layers,
|
||||
dropout=p_postnet_dropout)
|
||||
|
||||
def forward(self, text_inputs, mels, text_lens, output_lens=None):
|
||||
"""Calculate forward propagation of tacotron2.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text_inputs: Tensor[shape=(B, T_text)]
|
||||
batch of the sequencees of padded character ids
|
||||
|
||||
mels: Tensor[shape(B, T_mel, C)]
|
||||
batch of the sequences of padded mel spectrogram
|
||||
|
||||
text_lens: Tensor[shape=(B,)]
|
||||
batch of lengths of each text input batch.
|
||||
|
||||
output_lens: Tensor[shape=(B,)]
|
||||
batch of lengths of each mels batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
outputs : Dict[str, Tensor]
|
||||
|
||||
mel_output: output sequence of features (B, T_mel, C)
|
||||
|
||||
mel_outputs_postnet: output sequence of features after postnet (B, T_mel, C)
|
||||
|
||||
stop_logits: output sequence of stop logits (B, T_mel)
|
||||
|
||||
alignments: attention weights (B, T_mel, T_text)
|
||||
"""
|
||||
embedded_inputs = self.embedding(text_inputs)
|
||||
encoder_outputs = self.encoder(embedded_inputs, text_lens)
|
||||
|
||||
|
@ -386,6 +675,31 @@ class Tacotron2(nn.Layer):
|
|||
|
||||
@paddle.no_grad()
|
||||
def infer(self, text_inputs, stop_threshold=0.5, max_decoder_steps=1000):
|
||||
"""Generate the mel sepctrogram of features given the sequences of character ids.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text_inputs: Tensor[shape=(B, T_text)]
|
||||
batch of the sequencees of padded character ids
|
||||
|
||||
stop_threshold: float
|
||||
stop synthesize when stop logit is greater than this stop threshold
|
||||
|
||||
max_decoder_steps: int
|
||||
number of max step when synthesize
|
||||
|
||||
Returns
|
||||
-------
|
||||
outputs : Dict[str, Tensor]
|
||||
|
||||
mel_output: output sequence of sepctrogram (B, T_mel, C)
|
||||
|
||||
mel_outputs_postnet: output sequence of sepctrogram after postnet (B, T_mel, C)
|
||||
|
||||
stop_logits: output sequence of stop logits (B, T_mel)
|
||||
|
||||
alignments: attention weights (B, T_mel, T_text)
|
||||
"""
|
||||
embedded_inputs = self.embedding(text_inputs)
|
||||
encoder_outputs = self.encoder(embedded_inputs)
|
||||
mel_outputs, stop_logits, alignments = self.decoder.infer(
|
||||
|
@ -407,7 +721,27 @@ class Tacotron2(nn.Layer):
|
|||
|
||||
@paddle.no_grad()
|
||||
def predict(self, text, stop_threshold=0.5, max_decoder_steps=1000):
|
||||
# TODO(lifuchen): implement predict function to product mel from texts
|
||||
"""Generate the mel sepctrogram of features given the sequenc of characters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text: str
|
||||
sequence of characters
|
||||
|
||||
stop_threshold: float
|
||||
stop synthesize when stop logit is greater than this stop threshold
|
||||
|
||||
max_decoder_steps: int
|
||||
number of max step when synthesize
|
||||
|
||||
Returns
|
||||
-------
|
||||
outputs : Dict[str, Tensor]
|
||||
|
||||
mel_outputs_postnet: output sequence of sepctrogram after postnet (T_mel, C)
|
||||
|
||||
alignments: attention weights (T_mel, T_text)
|
||||
"""
|
||||
ids = np.asarray(self.frontend(text))
|
||||
ids = paddle.unsqueeze(paddle.to_tensor(ids, dtype='int64'), [0])
|
||||
outputs = self.infer(ids, stop_threshold, max_decoder_steps)
|
||||
|
@ -416,6 +750,27 @@ class Tacotron2(nn.Layer):
|
|||
|
||||
@classmethod
|
||||
def from_pretrained(cls, frontend, config, checkpoint_path):
|
||||
"""Build a tacotron2 model from a pretrained model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
frontend: parakeet.frontend.Phonetics
|
||||
frontend used to preprocess text
|
||||
|
||||
config: yacs.config.CfgNode
|
||||
model configs
|
||||
|
||||
checkpoint_path: Path
|
||||
the path of pretrained model checkpoint
|
||||
|
||||
Returns
|
||||
-------
|
||||
mel_outputs_postnet: Tensor[shape=(T_mel, C)]
|
||||
output sequence of sepctrogram after postnet
|
||||
|
||||
alignments: Tensor[shape=(T_mel, T_text)]
|
||||
attention weights
|
||||
"""
|
||||
model = cls(frontend,
|
||||
d_mels=config.data.d_mels,
|
||||
d_encoder=config.model.d_encoder,
|
||||
|
@ -442,11 +797,45 @@ class Tacotron2(nn.Layer):
|
|||
|
||||
|
||||
class Tacotron2Loss(nn.Layer):
|
||||
""" Tacotron2 Loss module
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, mel_outputs, mel_outputs_postnet, stop_logits,
|
||||
mel_targets, stop_tokens):
|
||||
"""Calculate tacotron2 loss.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mel_outputs: Tensor[shape=(B, T_mel, C)]
|
||||
output mel spectrogram sequence
|
||||
|
||||
mel_outputs_postnet: Tensor[shape(B, T_mel, C)]
|
||||
output mel spectrogram sequence after postnet
|
||||
|
||||
stop_logits: Tensor[shape=(B, T_mel)]
|
||||
output sequence of stop logits befor sigmoid
|
||||
|
||||
mel_targets: Tensor[shape=(B,)]
|
||||
target mel spectrogram sequence
|
||||
|
||||
stop_tokens:
|
||||
target stop token
|
||||
|
||||
Returns
|
||||
-------
|
||||
losses : Dict[str, float]
|
||||
|
||||
loss: the sum of the other three losses
|
||||
|
||||
mel_loss: MSE loss compute by mel_targets and mel_outputs
|
||||
|
||||
post_mel_loss: MSE loss compute by mel_targets and mel_outputs_postnet
|
||||
|
||||
stop_loss: stop loss computed by stop_logits and stop token
|
||||
"""
|
||||
mel_loss = paddle.nn.MSELoss()(mel_outputs, mel_targets)
|
||||
post_mel_loss = paddle.nn.MSELoss()(mel_outputs_postnet, mel_targets)
|
||||
stop_loss = paddle.nn.BCEWithLogitsLoss()(stop_logits, stop_tokens)
|
||||
|
|
Loading…
Reference in New Issue