add docstring for tacotron2

lfchener 2020-12-18 15:31:40 +08:00
parent d81df88173
commit ecdeb14a40
2 changed files with 399 additions and 9 deletions


@ -19,3 +19,4 @@ from parakeet.models.waveflow import *
from parakeet.models.transformer_tts import *
#from parakeet.models.deepvoice3 import *
# from parakeet.models.fastspeech import *
from parakeet.models.tacotron2 import *


@ -27,6 +27,25 @@ __all__ = ["Tacotron2", "Tacotron2Loss"]
class DecoderPreNet(nn.Layer):
"""
Decoder prenet module for Tacotron2.
Parameters
----------
d_input: int
input dimension
d_hidden: int
hidden size
d_output: int
output dimension
dropout_rate: float
dropout probability
"""
def __init__(self,
d_input: int,
d_hidden: int,
@ -39,23 +58,60 @@ class DecoderPreNet(nn.Layer):
self.linear2 = nn.Linear(d_hidden, d_output, bias_attr=False)
def forward(self, x):
"""Calculate forward propagation.
Parameters
----------
x: Tensor[shape=(B, T_mel, C)]
batch of the sequences of padded mel spectrogram
Returns
-------
output: Tensor[shape=(B, T_mel, C)]
batch of the sequences of padded hidden states
"""
x = F.dropout(F.relu(self.linear1(x)), self.dropout_rate)
output = F.dropout(F.relu(self.linear2(x)), self.dropout_rate)
return output
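# A minimal usage sketch of DecoderPreNet; dimensions are illustrative and
# assume the constructor takes the parameters documented above.
import paddle

prenet = DecoderPreNet(d_input=80, d_hidden=256, d_output=256,
                       dropout_rate=0.5)
x = paddle.randn([4, 100, 80])  # (B, T_mel, d_input)
y = prenet(x)                   # (B, T_mel, d_output): [4, 100, 256]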
class DecoderPostNet(nn.Layer):
"""
Decoder postnet module for Tacotron2.
Parameters
----------
d_mels: int
number of mel bands
d_hidden: int
hidden size of postnet
kernel_size: int
kernel size of the conv layer in postnet
num_layers: int
number of conv layers in postnet
dropout: float
dropout probability
"""
def __init__(self,
d_mels: int=80,
d_hidden: int=512,
kernel_size: int=5,
padding: int=0,
num_layers: int=5,
dropout: float=0.1):
super().__init__()
self.dropout = dropout
self.num_layers = num_layers
padding = int((kernel_size - 1) / 2)  # "same" padding for odd kernel sizes
self.conv_batchnorms = nn.LayerList()
k = math.sqrt(1.0 / (d_mels * kernel_size))
self.conv_batchnorms.append(
@ -91,15 +147,47 @@ class DecoderPostNet(nn.Layer):
data_format='NLC'))
def forward(self, input):
"""Calculate forward propagation.
Parameters
----------
input: Tensor[shape=(B, T_mel, C)]
output sequence of features from decoder
Returns
-------
output: Tensor[shape=(B, T_mel, C)]
output sequence of features after postnet
"""
for i in range(len(self.conv_batchnorms) - 1):
    input = F.dropout(
        F.tanh(self.conv_batchnorms[i](input)), self.dropout)
output = F.dropout(self.conv_batchnorms[self.num_layers - 1](input),
                   self.dropout)
return output
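# A minimal usage sketch of DecoderPostNet, assuming the defaults documented
# above; the postnet output is typically added to the coarse mel output as a
# residual.
import paddle

postnet = DecoderPostNet(d_mels=80, d_hidden=512, kernel_size=5,
                         num_layers=5, dropout=0.1)
mel = paddle.randn([4, 200, 80])  # (B, T_mel, d_mels)
residual = postnet(mel)           # (B, T_mel, d_mels)
mel_refined = mel + residual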
class Tacotron2Encoder(nn.Layer):
"""
Encoder module of Tacotron2.
Parameters
----------
d_hidden: int
hidden size in encoder module
conv_layers: int
number of conv layers
kernel_size: int
kernel size of conv layers
p_dropout: float
dropout probability
"""
def __init__(self,
d_hidden: int,
conv_layers: int,
@ -126,6 +214,22 @@ class Tacotron2Encoder(nn.Layer):
d_hidden, self.hidden_size, direction="bidirectional")
def forward(self, x, input_lens=None):
"""Calculate forward propagation of tacotron2 encoder.
Parameters
----------
x: Tensor[shape=(B, T, C)]
batch of the padded sequences of character embeddings
input_lens: Tensor[shape=(B,)]
batch of lengths of each input sequence
Returns
-------
output : Tensor[shape=(B, T, C)]
batch of the sequences of padded hidden states
"""
for conv_batchnorm in self.conv_batchnorms:
x = F.dropout(F.relu(conv_batchnorm(x)),
self.p_dropout) #(B, T, C)
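# A usage sketch for Tacotron2Encoder, assuming the constructor parameters
# documented above; note that x is the already-embedded character sequence.
import paddle

encoder = Tacotron2Encoder(d_hidden=512, conv_layers=3, kernel_size=5,
                           p_dropout=0.5)
embedded = paddle.randn([4, 50, 512])      # (B, T_text, d_hidden)
lens = paddle.to_tensor([50, 48, 45, 40])  # (B,)
hidden = encoder(embedded, lens)           # (B, T_text, d_hidden)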
@ -135,6 +239,47 @@ class Tacotron2Encoder(nn.Layer):
class Tacotron2Decoder(nn.Layer):
"""
Decoder module of Tacotron2.
Parameters
----------
d_mels: int
number of mel bands
reduction_factor: int
reduction factor of tacotron
d_encoder: int
hidden size of encoder
d_prenet: int
hidden size in decoder prenet
d_attention_rnn: int
attention rnn layer hidden size
d_decoder_rnn: int
decoder rnn layer hidden size
d_attention: int
hidden size of the linear layer in location sensitive attention
attention_filters: int
filter size of the conv layer in location sensitive attention
attention_kernel_size: int
kernel size of the conv layer in location sensitive attention
p_prenet_dropout: float
dropout probability in decoder prenet
p_attention_dropout: float
dropout probability in location sensitive attention
p_decoder_dropout: float
dropout probability in decoder"""
def __init__(self,
d_mels: int,
reduction_factor: int,
@ -175,6 +320,8 @@ class Tacotron2Decoder(nn.Layer):
self.stop_layer = nn.Linear(d_decoder_rnn + d_encoder, 1)
def _initialize_decoder_states(self, key):
"""init states be used in decoder
"""
batch_size = key.shape[0]
MAX_TIME = key.shape[1]
@ -199,6 +346,8 @@ class Tacotron2Decoder(nn.Layer):
self.processed_key = self.attention_layer.key_layer(key) #[B, T, C]
def _decode(self, query):
"""decode one time step
"""
cell_input = paddle.concat([query, self.attention_context], axis=-1)
# The first lstm layer
@ -232,6 +381,31 @@ class Tacotron2Decoder(nn.Layer):
return decoder_output, stop_logit, self.attention_weights
def forward(self, keys, querys, mask):
"""Calculate forward propagation of tacotron2 decoder.
Parameters
----------
keys: Tensor[shape=(B, T_text, C)]
batch of the sequences of padded output from encoder
querys: Tensor[shape=(B, T_mel, C)]
batch of the sequences of padded mel spectrogram
mask: Tensor[shape=(B, T_text, 1)]
mask generated from text lengths
Returns
-------
mel_output: Tensor[shape=(B, T_mel, C)]
output sequence of features
stop_logits: Tensor[shape=(B, T_mel)]
output sequence of stop logits
alignments: Tensor[shape=(B, T_mel, T_text)]
attention weights
"""
querys = paddle.reshape(
querys,
[querys.shape[0], querys.shape[1] // self.reduction_factor, -1])
@ -263,6 +437,31 @@ class Tacotron2Decoder(nn.Layer):
return mel_outputs, stop_logits, alignments
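# A teacher-forced usage sketch for Tacotron2Decoder, assuming the constructor
# parameters documented above (values follow the Tacotron2 paper).
import paddle

decoder = Tacotron2Decoder(
    d_mels=80, reduction_factor=1, d_encoder=512, d_prenet=256,
    d_attention_rnn=1024, d_decoder_rnn=1024, d_attention=128,
    attention_filters=32, attention_kernel_size=31,
    p_prenet_dropout=0.5, p_attention_dropout=0.1, p_decoder_dropout=0.1)
keys = paddle.randn([4, 50, 512])    # encoder outputs (B, T_text, C)
querys = paddle.randn([4, 200, 80])  # target mels (B, T_mel, d_mels)
mask = None                          # a (B, T_text, 1) padding mask; assumed optional here
mel_outputs, stop_logits, alignments = decoder(keys, querys, mask)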
def infer(self, key, stop_threshold=0.5, max_decoder_steps=1000):
"""Calculate forward propagation of tacotron2 decoder.
Parameters
----------
keys: Tensor[shape=(B, T_text, C)]
batch of the sequences of padded output from encoder
stop_threshold: float
stop synthesize when stop logit is greater than this stop threshold
max_decoder_steps: int
number of max step when synthesize
Returns
-------
mel_output: Tensor[shape=(B, T_mel, C)]
output sequence of features
stop_logits: Tensor[shape=(B, T_mel)]
output sequence of stop logits
alignments: Tensor[shape=(B, T_mel, T_text)]
attention weights
"""
query = paddle.zeros(
shape=[key.shape[0], self.d_mels * self.reduction_factor],
dtype=key.dtype) #[B, C]
@ -296,16 +495,79 @@ class Tacotron2Decoder(nn.Layer):
class Tacotron2(nn.Layer):
"""
Tacotron2 model for end-to-end text-to-speech (E2E-TTS).
This is a model of the spectrogram prediction network in Tacotron2 described
in `Natural TTS Synthesis
by Conditioning WaveNet on Mel Spectrogram Predictions`_,
which converts a sequence of characters
into a sequence of mel spectrogram frames.
.. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
https://arxiv.org/abs/1712.05884
Parameters
----------
frontend : parakeet.frontend.Phonetics
frontend used to preprocess text
d_mels: int
number of mel bands
d_encoder: int
hidden size in encoder module
encoder_conv_layers: int
number of conv layers in encoder
encoder_kernel_size: int
kernel size of conv layers in encoder
d_prenet: int
hidden size in decoder prenet
d_attention_rnn: int
attention rnn layer hidden size in decoder
d_decoder_rnn: int
decoder rnn layer hidden size in decoder
attention_filters: int
filter size of the conv layer in location sensitive attention
attention_kernel_size: int
kernel size of the conv layer in location sensitive attention
d_attention: int
hidden size of the linear layer in location sensitive attention
d_postnet: int
hidden size of postnet
postnet_kernel_size: int
kernel size of the conv layer in postnet
postnet_conv_layers: int
number of conv layers in postnet
reduction_factor: int
reduction factor of tacotron
p_encoder_dropout: float
dropout probability in encoder
p_prenet_dropout: float
dropout probability in decoder prenet
p_attention_dropout: float
dropout probability in location sensitive attention
p_decoder_dropout: float
dropout probability in decoder
p_postnet_dropout: float
dropout probability in postnet
"""
def __init__(self,
@ -350,11 +612,38 @@ class Tacotron2(nn.Layer):
d_mels=d_mels * reduction_factor,
d_hidden=d_postnet,
kernel_size=postnet_kernel_size,
padding=int((postnet_kernel_size - 1) / 2),
num_layers=postnet_conv_layers,
dropout=p_postnet_dropout)
def forward(self, text_inputs, mels, text_lens, output_lens=None):
"""Calculate forward propagation of tacotron2.
Parameters
----------
text_inputs: Tensor[shape=(B, T_text)]
batch of the sequences of padded character ids
mels: Tensor[shape=(B, T_mel, C)]
batch of the sequences of padded mel spectrogram
text_lens: Tensor[shape=(B,)]
batch of lengths of each input text
output_lens: Tensor[shape=(B,)]
batch of lengths of each mel spectrogram
Returns
-------
outputs : Dict[str, Tensor]
mel_output: output sequence of features (B, T_mel, C)
mel_outputs_postnet: output sequence of features after postnet (B, T_mel, C)
stop_logits: output sequence of stop logits (B, T_mel)
alignments: attention weights (B, T_mel, T_text)
"""
embedded_inputs = self.embedding(text_inputs)
encoder_outputs = self.encoder(embedded_inputs, text_lens)
@ -386,6 +675,31 @@ class Tacotron2(nn.Layer):
@paddle.no_grad()
def infer(self, text_inputs, stop_threshold=0.5, max_decoder_steps=1000):
"""Generate the mel sepctrogram of features given the sequences of character ids.
Parameters
----------
text_inputs: Tensor[shape=(B, T_text)]
batch of the sequencees of padded character ids
stop_threshold: float
stop synthesize when stop logit is greater than this stop threshold
max_decoder_steps: int
number of max step when synthesize
Returns
-------
outputs : Dict[str, Tensor]
mel_output: output sequence of sepctrogram (B, T_mel, C)
mel_outputs_postnet: output sequence of sepctrogram after postnet (B, T_mel, C)
stop_logits: output sequence of stop logits (B, T_mel)
alignments: attention weights (B, T_mel, T_text)
"""
embedded_inputs = self.embedding(text_inputs)
encoder_outputs = self.encoder(embedded_inputs)
mel_outputs, stop_logits, alignments = self.decoder.infer(
@ -407,7 +721,27 @@ class Tacotron2(nn.Layer):
@paddle.no_grad()
def predict(self, text, stop_threshold=0.5, max_decoder_steps=1000):
# TODO(lifuchen): implement predict function to produce mel from texts
"""Generate the mel spectrogram given a sequence of characters.
Parameters
----------
text: str
sequence of characters
stop_threshold: float
stop synthesizing when the stop logit is greater than this threshold
max_decoder_steps: int
maximum number of decoder steps for synthesis
Returns
-------
outputs : Dict[str, Tensor]
mel_outputs_postnet: output sequence of spectrogram after postnet (T_mel, C)
alignments: attention weights (T_mel, T_text)
"""
ids = np.asarray(self.frontend(text))
ids = paddle.unsqueeze(paddle.to_tensor(ids, dtype='int64'), [0])
outputs = self.infer(ids, stop_threshold, max_decoder_steps)
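# Hypothetical end-to-end usage (the checkpoint path is a placeholder; a
# trained model, a text frontend, and a vocoder for mel-to-waveform are
# assumed):
#
#   model = Tacotron2.from_pretrained(frontend, config, "step-100000.pdparams")
#   outputs = model.predict("Hello, world.", stop_threshold=0.5)
#   mel = outputs["mel_outputs_postnet"]  # (T_mel, C), input to a vocoder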
@ -416,6 +750,27 @@ class Tacotron2(nn.Layer):
@classmethod
def from_pretrained(cls, frontend, config, checkpoint_path):
"""Build a tacotron2 model from a pretrained model.
Parameters
----------
frontend: parakeet.frontend.Phonetics
frontend used to preprocess text
config: yacs.config.CfgNode
model configs
checkpoint_path: Path
the path of pretrained model checkpoint
Returns
-------
Tacotron2
the model built from the pretrained checkpoint
"""
model = cls(frontend,
d_mels=config.data.d_mels,
d_encoder=config.model.d_encoder,
@ -442,11 +797,45 @@ class Tacotron2(nn.Layer):
class Tacotron2Loss(nn.Layer):
""" Tacotron2 Loss module
"""
def __init__(self):
super().__init__()
def forward(self, mel_outputs, mel_outputs_postnet, stop_logits,
mel_targets, stop_tokens):
"""Calculate tacotron2 loss.
Parameters
----------
mel_outputs: Tensor[shape=(B, T_mel, C)]
output mel spectrogram sequence
mel_outputs_postnet: Tensor[shape=(B, T_mel, C)]
output mel spectrogram sequence after postnet
stop_logits: Tensor[shape=(B, T_mel)]
output sequence of stop logits before sigmoid
mel_targets: Tensor[shape=(B, T_mel, C)]
target mel spectrogram sequence
stop_tokens: Tensor[shape=(B, T_mel)]
target stop tokens
Returns
-------
losses : Dict[str, Tensor]
loss: the sum of the other three losses
mel_loss: MSE loss computed from mel_targets and mel_outputs
post_mel_loss: MSE loss computed from mel_targets and mel_outputs_postnet
stop_loss: BCE loss computed from stop_logits and stop_tokens
"""
mel_loss = paddle.nn.MSELoss()(mel_outputs, mel_targets)
post_mel_loss = paddle.nn.MSELoss()(mel_outputs_postnet, mel_targets)
stop_loss = paddle.nn.BCEWithLogitsLoss()(stop_logits, stop_tokens)
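# Return implied by the docstring above: the total loss plus its components.
loss = mel_loss + post_mel_loss + stop_loss
losses = dict(
    loss=loss,
    mel_loss=mel_loss,
    post_mel_loss=post_mel_loss,
    stop_loss=stop_loss)
return losses

# A usage sketch for Tacotron2Loss with random stand-ins for model outputs
# and targets (shapes follow the docstring above):
import paddle

criterion = Tacotron2Loss()
B, T_mel, C = 4, 200, 80
losses = criterion(paddle.randn([B, T_mel, C]),  # mel_outputs
                   paddle.randn([B, T_mel, C]),  # mel_outputs_postnet
                   paddle.randn([B, T_mel]),     # stop_logits
                   paddle.randn([B, T_mel, C]),  # mel_targets
                   paddle.ones([B, T_mel]))      # stop_tokens
print(losses["loss"])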