refactoring code

parent 0aa7088d36
commit 086fbf8e35
@@ -37,9 +37,9 @@ _C.model = CN(
        encoder_kernel_size=5, # kernel size of conv layers in tacotron2 encoder
        d_prenet=256, # hidden size of decoder prenet
        d_attention_rnn=1024, # hidden size of the first rnn layer in tacotron2 decoder
-        d_decoder_rnn=1024, #hidden size of the second rnn layer in tacotron2 decoder
+        d_decoder_rnn=1024, # hidden size of the second rnn layer in tacotron2 decoder
        d_attention=128, # hidden size of decoder location linear layer
        attention_filters=32, # number of filter in decoder location conv layer
        attention_kernel_size=31, # kernel size of decoder location conv layer
        d_postnet=512, # hidden size of decoder postnet
        postnet_kernel_size=5, # kernel size of conv layers in postnet
@@ -48,7 +48,8 @@ _C.model = CN(
        p_prenet_dropout=0.5, # droput probability in decoder prenet
        p_attention_dropout=0.1, # droput probability of first rnn layer in decoder
        p_decoder_dropout=0.1, # droput probability of second rnn layer in decoder
-        p_postnet_dropout=0.5, #droput probability in decoder postnet
+        p_postnet_dropout=0.5, # droput probability in decoder postnet
+        guided_attn_loss_sigma=0.2 # sigma in guided attention loss
    ))

_C.training = CN(
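The new `guided_attn_loss_sigma` entry sits under `_C.model`, next to the other model hyperparameters, so it is read with the usual attribute access. A trimmed, hypothetical sketch of building and querying such a yacs config (only the keys shown above are included):

```python
from yacs.config import CfgNode as CN

_C = CN()
_C.model = CN(dict(
    d_decoder_rnn=1024,
    p_postnet_dropout=0.5,
    guided_attn_loss_sigma=0.2,  # sigma in guided attention loss
))
print(_C.model.guided_attn_loss_sigma)  # 0.2
```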
@@ -34,14 +34,14 @@ from ljspeech import LJSpeech, LJSpeechCollector

class Experiment(ExperimentBase):
    def compute_losses(self, inputs, outputs):
-        _, mel_targets, _, _, stop_tokens = inputs
+        _, mel_targets, plens, slens, stop_tokens = inputs

        mel_outputs = outputs["mel_output"]
        mel_outputs_postnet = outputs["mel_outputs_postnet"]
-        stop_logits = outputs["stop_logits"]
+        attention_weight = outputs["alignments"]

-        losses = self.criterion(mel_outputs, mel_outputs_postnet, stop_logits,
-                                mel_targets, stop_tokens)
+        losses = self.criterion(mel_outputs, mel_outputs_postnet, mel_targets,
+                                attention_weight, slens, plens)
        return losses

    def train_batch(self):
@@ -145,7 +145,7 @@ class Experiment(ExperimentBase):
            weight_decay=paddle.regularizer.L2Decay(
                config.training.weight_decay),
            grad_clip=grad_clip)
-        criterion = Tacotron2Loss()
+        criterion = Tacotron2Loss(config.model.guided_attn_loss_sigma)
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
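With the stop-token head gone, the criterion now scores the decoder alignments directly, so `compute_losses` hands it the attention weights together with the spectrogram lengths (`slens`) and text lengths (`plens`). The snippet below is only a shape sketch of that call; the batch size, lengths and tensor values are made up, and the real call happens inside `Experiment.compute_losses` as shown above:

```python
import paddle

# Hypothetical shapes: B utterances, T_dec mel frames, T_enc text tokens, C mel bands.
B, T_dec, T_enc, C = 2, 6, 4, 80
mel_outputs = paddle.rand([B, T_dec, C])           # outputs["mel_output"]
mel_outputs_postnet = paddle.rand([B, T_dec, C])   # outputs["mel_outputs_postnet"]
mel_targets = paddle.rand([B, T_dec, C])           # from the batch
attention_weight = paddle.rand([B, T_dec, T_enc])  # outputs["alignments"]
slens = paddle.to_tensor([6, 5])                   # valid mel frames per utterance
plens = paddle.to_tensor([4, 3])                   # valid text tokens per utterance

# losses = self.criterion(mel_outputs, mel_outputs_postnet, mel_targets,
#                         attention_weight, slens, plens)
```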
@@ -13,14 +13,16 @@
# limitations under the License.

import math
-import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F
-import parakeet
+from paddle.nn import initializer as I
+from paddle.fluid.layers import sequence_mask

from parakeet.modules.conv import Conv1dBatchNorm
from parakeet.modules.attention import LocationSensitiveAttention
-from parakeet.modules import masking
+from parakeet.modules.losses import guided_attention_loss
from parakeet.utils import checkpoint

__all__ = ["Tacotron2", "Tacotron2Loss"]
@@ -44,11 +46,7 @@ class DecoderPreNet(nn.Layer):
        The droput probability.

    """

-    def __init__(self,
-                 d_input: int,
-                 d_hidden: int,
-                 d_output: int,
+    def __init__(self, d_input: int, d_hidden: int, d_output: int,
                 dropout_rate: float):
        super().__init__()

@@ -63,7 +61,7 @@ class DecoderPreNet(nn.Layer):
        ----------
        x: Tensor [shape=(B, T_mel, C)]
            Batch of the sequences of padded mel spectrogram.

        Returns
        -------
        output: Tensor [shape=(B, T_mel, C)]
@@ -71,10 +69,12 @@ class DecoderPreNet(nn.Layer):

        """

-        x = F.dropout(
-            F.relu(self.linear1(x)), self.dropout_rate, training=True)
-        output = F.dropout(
-            F.relu(self.linear2(x)), self.dropout_rate, training=True)
+        x = F.dropout(F.relu(self.linear1(x)),
+                      self.dropout_rate,
+                      training=True)
+        output = F.dropout(F.relu(self.linear2(x)),
+                           self.dropout_rate,
+                           training=True)
        return output

@@ -99,13 +99,8 @@ class DecoderPostNet(nn.Layer):
        The droput probability.

    """

-    def __init__(self,
-                 d_mels: int,
-                 d_hidden: int,
-                 kernel_size: int,
-                 num_layers: int,
-                 dropout: float):
+    def __init__(self, d_mels: int, d_hidden: int, kernel_size: int,
+                 num_layers: int, dropout: float):
        super().__init__()
        self.dropout = dropout
        self.num_layers = num_layers
@@ -115,45 +110,40 @@ class DecoderPostNet(nn.Layer):
        self.conv_batchnorms = nn.LayerList()
        k = math.sqrt(1.0 / (d_mels * kernel_size))
        self.conv_batchnorms.append(
-            Conv1dBatchNorm(
-                d_mels,
-                d_hidden,
-                kernel_size=kernel_size,
-                padding=padding,
-                bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
-                    low=-k, high=k)),
-                data_format='NLC'))
+            Conv1dBatchNorm(d_mels,
+                            d_hidden,
+                            kernel_size=kernel_size,
+                            padding=padding,
+                            bias_attr=I.Uniform(-k, k),
+                            data_format='NLC'))

        k = math.sqrt(1.0 / (d_hidden * kernel_size))
        self.conv_batchnorms.extend([
-            Conv1dBatchNorm(
-                d_hidden,
-                d_hidden,
-                kernel_size=kernel_size,
-                padding=padding,
-                bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
-                    low=-k, high=k)),
-                data_format='NLC') for i in range(1, num_layers - 1)
+            Conv1dBatchNorm(d_hidden,
+                            d_hidden,
+                            kernel_size=kernel_size,
+                            padding=padding,
+                            bias_attr=I.Uniform(-k, k),
+                            data_format='NLC')
+            for i in range(1, num_layers - 1)
        ])

        self.conv_batchnorms.append(
-            Conv1dBatchNorm(
-                d_hidden,
-                d_mels,
-                kernel_size=kernel_size,
-                padding=padding,
-                bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
-                    low=-k, high=k)),
-                data_format='NLC'))
+            Conv1dBatchNorm(d_hidden,
+                            d_mels,
+                            kernel_size=kernel_size,
+                            padding=padding,
+                            bias_attr=I.Uniform(-k, k),
+                            data_format='NLC'))

-    def forward(self, input):
+    def forward(self, x):
        """Calculate forward propagation.

        Parameters
        ----------
-        input: Tensor [shape=(B, T_mel, C)]
+        x: Tensor [shape=(B, T_mel, C)]
            Output sequence of features from decoder.

        Returns
        -------
        output: Tensor [shape=(B, T_mel, C)]
@@ -162,14 +152,12 @@ class DecoderPostNet(nn.Layer):
        """

        for i in range(len(self.conv_batchnorms) - 1):
-            input = F.dropout(
-                F.tanh(self.conv_batchnorms[i](input)),
-                self.dropout,
-                training=self.training)
-        output = F.dropout(
-            self.conv_batchnorms[self.num_layers - 1](input),
-            self.dropout,
-            training=self.training)
+            x = F.dropout(F.tanh(self.conv_batchnorms[i](x)),
+                          self.dropout,
+                          training=self.training)
+        output = F.dropout(self.conv_batchnorms[self.num_layers - 1](x),
+                           self.dropout,
+                           training=self.training)
        return output

@@ -180,41 +168,36 @@ class Tacotron2Encoder(nn.Layer):
    ----------
    d_hidden: int
        The hidden size in encoder module.

    conv_layers: int
        The number of conv layers.

    kernel_size: int
        The kernel size of conv layers.

    p_dropout: float
        The droput probability.
    """

-    def __init__(self,
-                 d_hidden: int,
-                 conv_layers: int,
-                 kernel_size: int,
+    def __init__(self, d_hidden: int, conv_layers: int, kernel_size: int,
                 p_dropout: float):
        super().__init__()

        k = math.sqrt(1.0 / (d_hidden * kernel_size))
        self.conv_batchnorms = paddle.nn.LayerList([
-            Conv1dBatchNorm(
-                d_hidden,
-                d_hidden,
-                kernel_size,
-                stride=1,
-                padding=int((kernel_size - 1) / 2),
-                bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
-                    low=-k, high=k)),
-                data_format='NLC') for i in range(conv_layers)
+            Conv1dBatchNorm(d_hidden,
+                            d_hidden,
+                            kernel_size,
+                            stride=1,
+                            padding=int((kernel_size - 1) / 2),
+                            bias_attr=I.Uniform(-k, k),
+                            data_format='NLC') for i in range(conv_layers)
        ])
        self.p_dropout = p_dropout

        self.hidden_size = int(d_hidden / 2)
-        self.lstm = nn.LSTM(
-            d_hidden, self.hidden_size, direction="bidirectional")
+        self.lstm = nn.LSTM(d_hidden,
+                            self.hidden_size,
+                            direction="bidirectional")

    def forward(self, x, input_lens=None):
        """Calculate forward propagation of tacotron2 encoder.
@@ -223,10 +206,10 @@ class Tacotron2Encoder(nn.Layer):
        ----------
        x: Tensor [shape=(B, T)]
            Batch of the sequencees of padded character ids.

        text_lens: Tensor [shape=(B,)], optional
            Batch of lengths of each text input batch. Defaults to None.

        Returns
        -------
        output : Tensor [shape=(B, T, C)]
@@ -234,10 +217,9 @@ class Tacotron2Encoder(nn.Layer):

        """
        for conv_batchnorm in self.conv_batchnorms:
-            x = F.dropout(
-                F.relu(conv_batchnorm(x)),
-                self.p_dropout,
-                training=self.training)
+            x = F.dropout(F.relu(conv_batchnorm(x)),
+                          self.p_dropout,
+                          training=self.training)

        output, _ = self.lstm(inputs=x, sequence_length=input_lens)
        return output
@@ -253,7 +235,7 @@ class Tacotron2Decoder(nn.Layer):

    reduction_factor: int
        The reduction factor of tacotron.

    d_encoder: int
        The hidden size of encoder.

@@ -265,13 +247,13 @@ class Tacotron2Decoder(nn.Layer):

    d_decoder_rnn: int
        The decoder rnn layer hidden size.

    d_attention: int
        The hidden size of the linear layer in location sensitive attention.

    attention_filters: int
        The filter size of the conv layer in location sensitive attention.

    attention_kernel_size: int
        The kernel size of the conv layer in location sensitive attention.

@@ -284,20 +266,11 @@ class Tacotron2Decoder(nn.Layer):
    p_decoder_dropout: float
        The droput probability in decoder.
    """

-    def __init__(self,
-                 d_mels: int,
-                 reduction_factor: int,
-                 d_encoder: int,
-                 d_prenet: int,
-                 d_attention_rnn: int,
-                 d_decoder_rnn: int,
-                 d_attention: int,
-                 attention_filters: int,
-                 attention_kernel_size: int,
-                 p_prenet_dropout: float,
-                 p_attention_dropout: float,
-                 p_decoder_dropout: float):
+    def __init__(self, d_mels: int, reduction_factor: int, d_encoder: int,
+                 d_prenet: int, d_attention_rnn: int, d_decoder_rnn: int,
+                 d_attention: int, attention_filters: int,
+                 attention_kernel_size: int, p_prenet_dropout: float,
+                 p_attention_dropout: float, p_decoder_dropout: float):
        super().__init__()
        self.d_mels = d_mels
        self.reduction_factor = reduction_factor
@@ -307,28 +280,45 @@ class Tacotron2Decoder(nn.Layer):
        self.p_attention_dropout = p_attention_dropout
        self.p_decoder_dropout = p_decoder_dropout

-        self.prenet = DecoderPreNet(
-            d_mels * reduction_factor,
-            d_prenet,
-            d_prenet,
-            dropout_rate=p_prenet_dropout)
+        self.prenet = DecoderPreNet(d_mels * reduction_factor,
+                                    d_prenet,
+                                    d_prenet,
+                                    dropout_rate=p_prenet_dropout)

+        # attention_rnn takes attention's context vector has an
+        # auxiliary input
        self.attention_rnn = nn.LSTMCell(d_prenet + d_encoder, d_attention_rnn)

        self.attention_layer = LocationSensitiveAttention(
            d_attention_rnn, d_encoder, d_attention, attention_filters,
            attention_kernel_size)

+        # decoder_rnn takes prenet's output and attention_rnn's input
+        # as input
        self.decoder_rnn = nn.LSTMCell(d_attention_rnn + d_encoder,
                                       d_decoder_rnn)
        self.linear_projection = nn.Linear(d_decoder_rnn + d_encoder,
                                           d_mels * reduction_factor)
-        self.stop_layer = nn.Linear(d_decoder_rnn + d_encoder, 1)

+        # states - temporary attributes
+        self.attention_hidden = None
+        self.attention_cell = None
+
+        self.decoder_hidden = None
+        self.decoder_cell = None
+
+        self.attention_weights = None
+        self.attention_weights_cum = None
+        self.attention_context = None
+
+        self.key = None
+        self.mask = None
+        self.processed_key = None

    def _initialize_decoder_states(self, key):
        """init states be used in decoder
        """
-        batch_size = key.shape[0]
-        MAX_TIME = key.shape[1]
+        batch_size, encoder_steps, _ = key.shape

        self.attention_hidden = paddle.zeros(
            shape=[batch_size, self.d_attention_rnn], dtype=key.dtype)
@@ -341,27 +331,27 @@ class Tacotron2Decoder(nn.Layer):
            shape=[batch_size, self.d_decoder_rnn], dtype=key.dtype)

        self.attention_weights = paddle.zeros(
-            shape=[batch_size, MAX_TIME], dtype=key.dtype)
+            shape=[batch_size, encoder_steps], dtype=key.dtype)
        self.attention_weights_cum = paddle.zeros(
-            shape=[batch_size, MAX_TIME], dtype=key.dtype)
+            shape=[batch_size, encoder_steps], dtype=key.dtype)
        self.attention_context = paddle.zeros(
            shape=[batch_size, self.d_encoder], dtype=key.dtype)

-        self.key = key #[B, T, C]
-        self.processed_key = self.attention_layer.key_layer(key) #[B, T, C]
+        self.key = key  # [B, T, C]
+        # pre-compute projected keys to improve efficiency
+        self.processed_key = self.attention_layer.key_layer(key)  # [B, T, C]

    def _decode(self, query):
        """decode one time step
        """
        cell_input = paddle.concat([query, self.attention_context], axis=-1)

-        # The first lstm layer
+        # The first lstm layer (or spec encoder lstm)
        _, (self.attention_hidden, self.attention_cell) = self.attention_rnn(
            cell_input, (self.attention_hidden, self.attention_cell))
-        self.attention_hidden = F.dropout(
-            self.attention_hidden,
-            self.p_attention_dropout,
-            training=self.training)
+        self.attention_hidden = F.dropout(self.attention_hidden,
+                                          self.p_attention_dropout,
+                                          training=self.training)

        # Loaction sensitive attention
        attention_weights_cat = paddle.stack(
@@ -371,23 +361,21 @@ class Tacotron2Decoder(nn.Layer):
            attention_weights_cat, self.mask)
        self.attention_weights_cum += self.attention_weights

-        # The second lstm layer
+        # The second lstm layer (or spec decoder lstm)
        decoder_input = paddle.concat(
            [self.attention_hidden, self.attention_context], axis=-1)
        _, (self.decoder_hidden, self.decoder_cell) = self.decoder_rnn(
            decoder_input, (self.decoder_hidden, self.decoder_cell))
-        self.decoder_hidden = F.dropout(
-            self.decoder_hidden,
-            p=self.p_decoder_dropout,
-            training=self.training)
+        self.decoder_hidden = F.dropout(self.decoder_hidden,
+                                        p=self.p_decoder_dropout,
+                                        training=self.training)

        # decode output one step
        decoder_hidden_attention_context = paddle.concat(
            [self.decoder_hidden, self.attention_context], axis=-1)
        decoder_output = self.linear_projection(
            decoder_hidden_attention_context)
-        stop_logit = self.stop_layer(decoder_hidden_attention_context)
-        return decoder_output, stop_logit, self.attention_weights
+        return decoder_output, self.attention_weights

    def forward(self, keys, querys, mask):
        """Calculate forward propagation of tacotron2 decoder.
@@ -396,117 +384,105 @@ class Tacotron2Decoder(nn.Layer):
        ----------
        keys: Tensor[shape=(B, T_key, C)]
            Batch of the sequences of padded output from encoder.

        querys: Tensor[shape(B, T_query, C)]
            Batch of the sequences of padded mel spectrogram.

        mask: Tensor
-            Mask generated with text length. Shape should be (B, T_key, T_query) or broadcastable shape.
+            Mask generated with text length. Shape should be (B, T_key, 1).

        Returns
        -------
        mel_output: Tensor [shape=(B, T_query, C)]
            Output sequence of features.

-        stop_logits: Tensor [shape=(B, T_query)]
-            Output sequence of stop logits.
-
        alignments: Tensor [shape=(B, T_query, T_key)]
            Attention weights.
        """
-        querys = paddle.reshape(
-            querys,
-            [querys.shape[0], querys.shape[1] // self.reduction_factor, -1])
-        querys = paddle.concat(
-            [
-                paddle.zeros(
-                    shape=[querys.shape[0], 1, querys.shape[-1]],
-                    dtype=querys.dtype), querys
-            ],
-            axis=1)
-        querys = self.prenet(querys)
-
        self._initialize_decoder_states(keys)
        self.mask = mask

-        mel_outputs, stop_logits, alignments = [], [], []
-        while len(mel_outputs) < querys.shape[
-                1] - 1:  # Ignore the last time step
+        querys = paddle.reshape(
+            querys,
+            [querys.shape[0], querys.shape[1] // self.reduction_factor, -1])
+        start_step = paddle.zeros(shape=[querys.shape[0], 1, querys.shape[-1]],
+                                  dtype=querys.dtype)
+        querys = paddle.concat([start_step, querys], axis=1)
+
+        querys = self.prenet(querys)
+
+        mel_outputs, alignments = [], []
+        # Ignore the last time step
+        while len(mel_outputs) < querys.shape[1] - 1:
            query = querys[:, len(mel_outputs), :]
-            mel_output, stop_logit, attention_weights = self._decode(query)
-            mel_outputs += [mel_output]
-            stop_logits += [stop_logit]
-            alignments += [attention_weights]
+            mel_output, attention_weights = self._decode(query)
+            mel_outputs.append(mel_output)
+            alignments.append(attention_weights)

        alignments = paddle.stack(alignments, axis=1)
-        stop_logits = paddle.concat(stop_logits, axis=1)
        mel_outputs = paddle.stack(mel_outputs, axis=1)

-        return mel_outputs, stop_logits, alignments
+        return mel_outputs, alignments

-    def infer(self, key, stop_threshold=0.5, max_decoder_steps=1000):
+    def infer(self, key, max_decoder_steps=1000):
        """Calculate forward propagation of tacotron2 decoder.

        Parameters
        ----------
        keys: Tensor [shape=(B, T_key, C)]
            Batch of the sequences of padded output from encoder.

-        stop_threshold: float, optional
-            Stop synthesize when stop logit is greater than this stop threshold. Defaults to 0.5.
-
        max_decoder_steps: int, optional
            Number of max step when synthesize. Defaults to 1000.

        Returns
        -------
        mel_output: Tensor [shape=(B, T_mel, C)]
            Output sequence of features.

-        stop_logits: Tensor [shape=(B, T_mel)]
-            Output sequence of stop logits.
-
        alignments: Tensor [shape=(B, T_mel, T_key)]
            Attention weights.

        """
-        query = paddle.zeros(
-            shape=[key.shape[0], self.d_mels * self.reduction_factor],
-            dtype=key.dtype)  #[B, C]
+        encoder_steps = key.shape[1]

        self._initialize_decoder_states(key)
-        self.mask = None
+        self.mask = None  # mask is not needed for single instance inference

-        mel_outputs, stop_logits, alignments = [], [], []
+        # [B, C]
+        start_step = paddle.zeros(
+            shape=[key.shape[0], self.d_mels * self.reduction_factor],
+            dtype=key.dtype)
+        query = start_step  # [B, C]
+
+        mel_outputs, alignments = [], []
        while True:
            query = self.prenet(query)
-            mel_output, stop_logit, alignment = self._decode(query)
+            mel_output, alignment = self._decode(query)

-            mel_outputs += [mel_output]
-            stop_logits += [stop_logit]
-            alignments += [alignment]
+            mel_outputs.append(mel_output)
+            alignments.append(alignment)  # (B=1, T)

-            if F.sigmoid(stop_logit) > stop_threshold:
+            if int(paddle.argmax(alignment[0])) == encoder_steps - 1:
+                print("Text content exhausted, synthesize stops.")
                break
-            elif len(mel_outputs) == max_decoder_steps:
+            if len(mel_outputs) == max_decoder_steps:
                print("Warning! Reached max decoder steps!!!")
                break

            query = mel_output

        alignments = paddle.stack(alignments, axis=1)
-        stop_logits = paddle.concat(stop_logits, axis=1)
        mel_outputs = paddle.stack(mel_outputs, axis=1)

-        return mel_outputs, stop_logits, alignments
+        return mel_outputs, alignments

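Since the stop layer is removed, `Tacotron2Decoder.infer` now ends synthesis when the attention peak reaches the last encoder step (or when `max_decoder_steps` is hit), rather than thresholding a stop logit. A minimal sketch of that check on a made-up single-utterance alignment:

```python
import paddle

# Illustrative values only: one decoding step's alignment over 5 text positions.
encoder_steps = 5
alignment = paddle.to_tensor([[0.01, 0.02, 0.05, 0.12, 0.80]])  # shape (B=1, T_enc)

# Stop when the most-attended encoder position is the final text token.
if int(paddle.argmax(alignment[0])) == encoder_steps - 1:
    print("Text content exhausted, synthesize stops.")
```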
class Tacotron2(nn.Layer):
    """Tacotron2 model for end-to-end text-to-speech (E2E-TTS).

    This is a model of Spectrogram prediction network in Tacotron2 described
-    in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions
-    <https://arxiv.org/abs/1712.05884>`_,
+    in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram
+    Predictions <https://arxiv.org/abs/1712.05884>`_,
    which converts the sequence of characters
    into the sequence of mel spectrogram.

@@ -517,10 +493,10 @@ class Tacotron2(nn.Layer):

    d_mels: int
        Number of mel bands.

    d_encoder: int
        Hidden size in encoder module.

    encoder_conv_layers: int
        Number of conv layers in encoder.

@@ -538,7 +514,7 @@ class Tacotron2(nn.Layer):

    attention_filters: int
        Filter size of the conv layer in location sensitive attention.

    attention_kernel_size: int
        Kernel size of the conv layer in location sensitive attention.

@@ -573,38 +549,34 @@ class Tacotron2(nn.Layer):
        Droput probability in postnet.

    """

    def __init__(self,
-                 frontend: parakeet.frontend.Phonetics,
-                 d_mels: int=80,
-                 d_encoder: int=512,
-                 encoder_conv_layers: int=3,
-                 encoder_kernel_size: int=5,
-                 d_prenet: int=256,
-                 d_attention_rnn: int=1024,
-                 d_decoder_rnn: int=1024,
-                 attention_filters: int=32,
-                 attention_kernel_size: int=31,
-                 d_attention: int=128,
-                 d_postnet: int=512,
-                 postnet_kernel_size: int=5,
-                 postnet_conv_layers: int=5,
-                 reduction_factor: int=1,
-                 p_encoder_dropout: float=0.5,
-                 p_prenet_dropout: float=0.5,
-                 p_attention_dropout: float=0.1,
-                 p_decoder_dropout: float=0.1,
-                 p_postnet_dropout: float=0.5):
+                 vocab_size,
+                 d_mels: int = 80,
+                 d_encoder: int = 512,
+                 encoder_conv_layers: int = 3,
+                 encoder_kernel_size: int = 5,
+                 d_prenet: int = 256,
+                 d_attention_rnn: int = 1024,
+                 d_decoder_rnn: int = 1024,
+                 attention_filters: int = 32,
+                 attention_kernel_size: int = 31,
+                 d_attention: int = 128,
+                 d_postnet: int = 512,
+                 postnet_kernel_size: int = 5,
+                 postnet_conv_layers: int = 5,
+                 reduction_factor: int = 1,
+                 p_encoder_dropout: float = 0.5,
+                 p_prenet_dropout: float = 0.5,
+                 p_attention_dropout: float = 0.1,
+                 p_decoder_dropout: float = 0.1,
+                 p_postnet_dropout: float = 0.5):
        super().__init__()

-        self.frontend = frontend
-        std = math.sqrt(2.0 / (self.frontend.vocab_size + d_encoder))
+        std = math.sqrt(2.0 / (vocab_size + d_encoder))
        val = math.sqrt(3.0) * std  # uniform bounds for std
-        self.embedding = nn.Embedding(
-            self.frontend.vocab_size,
-            d_encoder,
-            weight_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform(
-                low=-val, high=val)))
+        self.embedding = nn.Embedding(vocab_size,
+                                      d_encoder,
+                                      weight_attr=I.Uniform(-val, val))
        self.encoder = Tacotron2Encoder(d_encoder, encoder_conv_layers,
                                        encoder_kernel_size, p_encoder_dropout)
        self.decoder = Tacotron2Decoder(
@@ -612,12 +584,11 @@ class Tacotron2(nn.Layer):
            d_decoder_rnn, d_attention, attention_filters,
            attention_kernel_size, p_prenet_dropout, p_attention_dropout,
            p_decoder_dropout)
-        self.postnet = DecoderPostNet(
-            d_mels=d_mels * reduction_factor,
-            d_hidden=d_postnet,
-            kernel_size=postnet_kernel_size,
-            num_layers=postnet_conv_layers,
-            dropout=p_postnet_dropout)
+        self.postnet = DecoderPostNet(d_mels=d_mels * reduction_factor,
+                                      d_hidden=d_postnet,
+                                      kernel_size=postnet_kernel_size,
+                                      num_layers=postnet_conv_layers,
+                                      dropout=p_postnet_dropout)

    def forward(self, text_inputs, mels, text_lens, output_lens=None):
        """Calculate forward propagation of tacotron2.
@@ -626,20 +597,20 @@ class Tacotron2(nn.Layer):
        ----------
        text_inputs: Tensor [shape=(B, T_text)]
            Batch of the sequencees of padded character ids.

        mels: Tensor [shape(B, T_mel, C)]
            Batch of the sequences of padded mel spectrogram.

        text_lens: Tensor [shape=(B,)]
            Batch of lengths of each text input batch.

        output_lens: Tensor [shape=(B,)], optional
            Batch of lengths of each mels batch. Defaults to None.

        Returns
        -------
        outputs : Dict[str, Tensor]

            mel_output: output sequence of features (B, T_mel, C);

            mel_outputs_postnet: output sequence of features after postnet (B, T_mel, C);

@@ -651,47 +622,41 @@ class Tacotron2(nn.Layer):
        embedded_inputs = self.embedding(text_inputs)
        encoder_outputs = self.encoder(embedded_inputs, text_lens)

-        mask = paddle.tensor.unsqueeze(
-            paddle.fluid.layers.sequence_mask(
-                x=text_lens, dtype=encoder_outputs.dtype), [-1])
-        mel_outputs, stop_logits, alignments = self.decoder(
-            encoder_outputs, mels, mask=mask)
+        # [B, T_enc, 1]
+        mask = paddle.unsqueeze(
+            sequence_mask(x=text_lens, dtype=encoder_outputs.dtype), [-1])
+        mel_outputs, alignments = self.decoder(encoder_outputs,
+                                               mels,
+                                               mask=mask)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        if output_lens is not None:
-            mask = paddle.tensor.unsqueeze(
-                paddle.fluid.layers.sequence_mask(x=output_lens),
-                [-1])  #[B, T, 1]
-            mel_outputs = mel_outputs * mask  #[B, T, C]
-            mel_outputs_postnet = mel_outputs_postnet * mask  #[B, T, C]
-            stop_logits = stop_logits * mask[:, :, 0] + (1 - mask[:, :, 0]
-                                                         ) * 1e3  #[B, T]
+            # [B, T_dec, 1]
+            mask = paddle.unsqueeze(sequence_mask(x=output_lens), [-1])
+            mel_outputs = mel_outputs * mask  # [B, T, C]
+            mel_outputs_postnet = mel_outputs_postnet * mask  # [B, T, C]
        outputs = {
            "mel_output": mel_outputs,
            "mel_outputs_postnet": mel_outputs_postnet,
-            "stop_logits": stop_logits,
            "alignments": alignments
        }

        return outputs

    @paddle.no_grad()
-    def infer(self, text_inputs, stop_threshold=0.5, max_decoder_steps=1000):
+    def infer(self, text_inputs, max_decoder_steps=1000):
        """Generate the mel sepctrogram of features given the sequences of character ids.

        Parameters
        ----------
        text_inputs: Tensor [shape=(B, T_text)]
            Batch of the sequencees of padded character ids.

-        stop_threshold: float, optional
-            Stop synthesize when stop logit is greater than this stop threshold. Defaults to 0.5.
-
        max_decoder_steps: int, optional
            Number of max step when synthesize. Defaults to 1000.

        Returns
        -------
        outputs : Dict[str, Tensor]
@@ -706,10 +671,8 @@ class Tacotron2(nn.Layer):
        """
        embedded_inputs = self.embedding(text_inputs)
        encoder_outputs = self.encoder(embedded_inputs)
-        mel_outputs, stop_logits, alignments = self.decoder.infer(
-            encoder_outputs,
-            stop_threshold=stop_threshold,
-            max_decoder_steps=max_decoder_steps)
+        mel_outputs, alignments = self.decoder.infer(
+            encoder_outputs, max_decoder_steps=max_decoder_steps)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

@@ -717,62 +680,32 @@ class Tacotron2(nn.Layer):
        outputs = {
            "mel_output": mel_outputs,
            "mel_outputs_postnet": mel_outputs_postnet,
-            "stop_logits": stop_logits,
            "alignments": alignments
        }

        return outputs

-    @paddle.no_grad()
-    def predict(self, text, stop_threshold=0.5, max_decoder_steps=1000):
-        """Generate the mel sepctrogram of features given the sequenc of characters.
-
-        Parameters
-        ----------
-        text: str
-            Sequence of characters.
-
-        stop_threshold: float, optional
-            Stop synthesize when stop logit is greater than this stop threshold. Defaults to 0.5.
-
-        max_decoder_steps: int, optional
-            Number of max step when synthesize. Defaults to 1000.
-
-        Returns
-        -------
-        outputs : Dict[str, Tensor]
-
-            mel_outputs_postnet: output sequence of sepctrogram after postnet (T_mel, C);
-
-            alignments: attention weights (T_mel, T_text).
-        """
-        ids = np.asarray(self.frontend(text))
-        ids = paddle.unsqueeze(paddle.to_tensor(ids, dtype='int64'), [0])
-        outputs = self.infer(ids, stop_threshold, max_decoder_steps)
-        return outputs['mel_outputs_postnet'][0].numpy(), outputs[
-            'alignments'][0].numpy()
-
    @classmethod
-    def from_pretrained(cls, frontend, config, checkpoint_path):
+    def from_pretrained(cls, config, checkpoint_path):
        """Build a tacotron2 model from a pretrained model.

        Parameters
        ----------
        frontend: parakeet.frontend.Phonetics
            Frontend used to preprocess text.

        config: yacs.config.CfgNode
            Model configs.

        checkpoint_path: Path or str
            The path of pretrained model checkpoint, without extension name.

        Returns
        -------
        Tacotron2
            The model build from pretrined result.
        """
-        model = cls(frontend,
+        model = cls(vocab_size=config.model.vocab_size,
                    d_mels=config.data.d_mels,
                    d_encoder=config.model.d_encoder,
                    encoder_conv_layers=config.model.encoder_conv_layers,
@@ -800,50 +733,46 @@ class Tacotron2(nn.Layer):
class Tacotron2Loss(nn.Layer):
    """ Tacotron2 Loss module
    """

-    def __init__(self):
+    def __init__(self, sigma=0.2):
        super().__init__()
+        self.spec_criterion = nn.MSELoss()
+        self.attn_criterion = guided_attention_loss
+        self.sigma = sigma

-    def forward(self, mel_outputs, mel_outputs_postnet, stop_logits,
-                mel_targets, stop_tokens):
+    def forward(self, mel_outputs, mel_outputs_postnet, mel_targets,
+                attention_weights, slens, plens):
        """Calculate tacotron2 loss.

        Parameters
        ----------
        mel_outputs: Tensor [shape=(B, T_mel, C)]
            Output mel spectrogram sequence.

        mel_outputs_postnet: Tensor [shape(B, T_mel, C)]
            Output mel spectrogram sequence after postnet.

-        stop_logits: Tensor [shape=(B, T_mel)]
-            Output sequence of stop logits befor sigmoid.
-
        mel_targets: Tensor [shape=(B, T_mel, C)]
            Target mel spectrogram sequence.

-        stop_tokens: Tensor [shape=(B,)]
-            Target stop token.
-
        Returns
        -------
        losses : Dict[str, Tensor]

            loss: the sum of the other three losses;

            mel_loss: MSE loss compute by mel_targets and mel_outputs;

            post_mel_loss: MSE loss compute by mel_targets and mel_outputs_postnet;

-            stop_loss: stop loss computed by stop_logits and stop token.
        """
-        mel_loss = paddle.nn.MSELoss()(mel_outputs, mel_targets)
-        post_mel_loss = paddle.nn.MSELoss()(mel_outputs_postnet, mel_targets)
-        stop_loss = paddle.nn.BCEWithLogitsLoss()(stop_logits, stop_tokens)
-        total_loss = mel_loss + post_mel_loss + stop_loss
-        losses = dict(
-            loss=total_loss,
-            mel_loss=mel_loss,
-            post_mel_loss=post_mel_loss,
-            stop_loss=stop_loss)
+        mel_loss = self.spec_criterion(mel_outputs, mel_targets)
+        post_mel_loss = self.spec_criterion(mel_outputs_postnet, mel_targets)
+        gal_loss = self.attn_criterion(attention_weights, slens, plens,
+                                       self.sigma)
+        total_loss = mel_loss + post_mel_loss + gal_loss
+        losses = {
+            "loss": total_loss,
+            "mel_loss": mel_loss,
+            "post_mel_loss": post_mel_loss,
+            "guided_attn_loss": gal_loss
+        }
        return losses

@@ -143,9 +143,9 @@ class MonoheadAttention(nn.Layer):

    def __init__(self,
                 model_dim: int,
-                 dropout: float=0.0,
-                 k_dim: int=None,
-                 v_dim: int=None):
+                 dropout: float = 0.0,
+                 k_dim: int = None,
+                 v_dim: int = None):
        super(MonoheadAttention, self).__init__()
        k_dim = k_dim or model_dim
        v_dim = v_dim or model_dim
@@ -225,9 +225,9 @@ class MultiheadAttention(nn.Layer):
    def __init__(self,
                 model_dim: int,
                 num_heads: int,
-                 dropout: float=0.0,
-                 k_dim: int=None,
-                 v_dim: int=None):
+                 dropout: float = 0.0,
+                 k_dim: int = None,
+                 v_dim: int = None):
        super(MultiheadAttention, self).__init__()
        if model_dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
@@ -316,14 +316,11 @@ class LocationSensitiveAttention(nn.Layer):
        self.key_layer = nn.Linear(d_key, d_attention, bias_attr=False)
        self.value = nn.Linear(d_attention, 1, bias_attr=False)

-        #Location Layer
+        # Location Layer
        self.location_conv = nn.Conv1D(
-            2,
-            location_filters,
-            location_kernel_size,
-            1,
-            int((location_kernel_size - 1) / 2),
-            1,
+            2, location_filters,
+            kernel_size=location_kernel_size,
+            padding=int((location_kernel_size - 1) / 2),
            bias_attr=False,
            data_format='NLC')
        self.location_layer = nn.Linear(
@@ -352,21 +349,22 @@ class LocationSensitiveAttention(nn.Layer):
            Attention weights concat.

        mask : Tensor, optional
-            The mask. Shape should be (batch_size, times_steps_q, time_steps_k) or broadcastable shape.
+            The mask. Shape should be (batch_size, times_steps_k, 1).
            Defaults to None.

        Returns
        ----------
-        attention_context : Tensor [shape=(batch_size, time_steps_q, d_attention)]
+        attention_context : Tensor [shape=(batch_size, d_attention)]
            The context vector.

-        attention_weights : Tensor [shape=(batch_size, times_steps_q, time_steps_k)]
+        attention_weights : Tensor [shape=(batch_size, time_steps_k)]
            The attention weights.
        """

        processed_query = self.query_layer(paddle.unsqueeze(query, axis=[1]))
        processed_attention_weights = self.location_layer(
            self.location_conv(attention_weights_cat))
+        # (B, T_enc, 1)
        alignment = self.value(
            paddle.tanh(processed_attention_weights + processed_key +
                        processed_query))
@@ -378,7 +376,7 @@ class LocationSensitiveAttention(nn.Layer):
        attention_context = paddle.matmul(
            attention_weights, value, transpose_x=True)

-        attention_weights = paddle.squeeze(attention_weights, axis=[-1])
-        attention_context = paddle.squeeze(attention_context, axis=[1])
+        attention_weights = paddle.squeeze(attention_weights, axis=-1)
+        attention_context = paddle.squeeze(attention_context, axis=1)

        return attention_context, attention_weights

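The attention mask documented above is now a per-key column vector rather than a full (query, key) matrix. A small sketch of building such a `(B, T_key, 1)` mask from text lengths, mirroring what the model's `forward` does; the lengths below are made up:

```python
import paddle
from paddle.fluid.layers import sequence_mask

text_lens = paddle.to_tensor([5, 3], dtype="int64")  # assumed batch of 2 utterances
mask = paddle.unsqueeze(sequence_mask(x=text_lens, dtype="float32"), [-1])
print(mask.shape)  # [2, 5, 1]: one column per encoder step, broadcast over queries
```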
@@ -17,15 +17,51 @@ import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F
+from paddle.fluid.layers import sequence_mask

__all__ = [
+    "guided_attention_loss",
    "weighted_mean",
    "masked_l1_loss",
    "masked_softmax_with_cross_entropy",
-    "diagonal_loss",
]


+def attention_guide(dec_lens, enc_lens, N, T, g, dtype=None):
+    """Build that W matrix. shape(B, T_dec, T_enc)
+    W[i, n, t] = 1 - exp(-(n/dec_lens[i] - t/enc_lens[i])**2 / (2g**2))
+
+    See also:
+    Tachibana, Hideyuki, Katsuya Uenoyama, and Shunsuke Aihara. 2017. "Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention." ArXiv:1710.08969 [Cs, Eess], October. http://arxiv.org/abs/1710.08969.
+    """
+    dtype = dtype or paddle.get_default_dtype()
+    dec_pos = paddle.arange(0, N).astype(
+        dtype) / dec_lens.unsqueeze(-1)  # n/N # shape(B, T_dec)
+    enc_pos = paddle.arange(0, T).astype(
+        dtype) / enc_lens.unsqueeze(-1)  # t/T # shape(B, T_enc)
+    W = 1 - paddle.exp(-(dec_pos.unsqueeze(-1) -
+                         enc_pos.unsqueeze(1))**2 / (2 * g ** 2))
+
+    dec_mask = sequence_mask(dec_lens, maxlen=N)
+    enc_mask = sequence_mask(enc_lens, maxlen=T)
+    mask = dec_mask.unsqueeze(-1) * enc_mask.unsqueeze(1)
+    mask = paddle.cast(mask, W.dtype)
+
+    W *= mask
+    return W
+
+
+def guided_attention_loss(attention_weight, dec_lens, enc_lens, g):
+    """Guided attention loss, masked to excluded padding parts."""
+    _, N, T = attention_weight.shape
+    W = attention_guide(dec_lens, enc_lens, N, T, g, attention_weight.dtype)
+
+    total_tokens = (dec_lens * enc_lens).astype(W.dtype)
+    loss = paddle.mean(paddle.sum(
+        W * attention_weight, [1, 2]) / total_tokens)
+    return loss, W
+
+
def weighted_mean(input, weight):
    """Weighted mean. It can also be used as masked mean.

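To make the added loss concrete, here is a self-contained NumPy sketch of the soft-diagonal guide for a single utterance. It drops the per-utterance length normalization and padding mask that `attention_guide` applies above, and the sizes are invented:

```python
import numpy as np

N, T, g = 6, 4, 0.2                        # decoder steps, encoder steps, sigma
n = np.arange(N)[:, None] / N              # normalized decoder positions, (N, 1)
t = np.arange(T)[None, :] / T              # normalized encoder positions, (1, T)
W = 1.0 - np.exp(-((n - t) ** 2) / (2 * g * g))   # ~0 near the diagonal, ~1 far away

attention = np.full((N, T), 1.0 / T)       # a flat, poorly aligned attention
loss = (W * attention).sum() / (N * T)     # large when attention ignores the diagonal
print(round(float(loss), 4))
```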
@@ -40,14 +76,10 @@ def weighted_mean(input, weight):
    ----------
    Tensor [shape=(1,)]
        Weighted mean tensor with the same dtype as input.

-    Warnings
-    ---------
-    This is not a mathematical weighted mean. It performs weighted sum and
-    simple average.
    """
    weight = paddle.cast(weight, input.dtype)
-    return paddle.mean(input * weight)
+    broadcast_ratio = input.size / weight.size
+    return paddle.sum(input * weight) / (paddle.sum(weight) * broadcast_ratio)


def masked_l1_loss(prediction, target, mask):
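The reworked `weighted_mean` divides by the weight mass instead of the element count, which matters when the weight is a 0/1 padding mask. A toy NumPy comparison, ignoring the `broadcast_ratio` factor (it only compensates for broadcasting between `input` and `weight`):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 0.0])   # last element is padding
w = np.array([1.0, 1.0, 1.0, 0.0])   # 0/1 mask

old = (x * w).mean()                  # 1.5: the padded position dilutes the average
new = (x * w).sum() / w.sum()         # 2.0: mean over the valid elements only
print(old, new)
```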
@@ -101,70 +133,3 @@ def masked_softmax_with_cross_entropy(logits, label, mask, axis=-1):
    ce = F.softmax_with_cross_entropy(logits, label, axis=axis)
    loss = weighted_mean(ce, mask)
    return loss
-
-
-def diagonal_loss(attentions,
-                  input_lengths,
-                  target_lengths,
-                  g=0.2,
-                  multihead=False):
-    """A metric to evaluate how diagonal a attention distribution is.
-
-    It is computed for batch attention distributions. For each attention
-    distribution, the valid decoder time steps and encoder time steps may
-    differ.
-
-    Parameters
-    ----------
-    attentions : Tensor [shape=(B, T_dec, T_enc) or (B, H, T_dec, T_dec)]
-        The attention weights from an encoder-decoder structure.
-
-    input_lengths : Tensor [shape=(B,)]
-        The valid length for each encoder output.
-
-    target_lengths : Tensor [shape=(B,)]
-        The valid length for each decoder output.
-
-    g : float, optional
-        [description], by default 0.2.
-
-    multihead : bool, optional
-        A flag indicating whether ``attentions`` is a multihead attention's
-        attention distribution.
-
-        If ``True``, the shape of attention is ``(B, H, T_dec, T_dec)``, by
-        default False.
-
-    Returns
-    -------
-    Tensor [shape=(1,)]
-        The diagonal loss.
-    """
-    W = guided_attentions(input_lengths, target_lengths, g)
-    W_tensor = paddle.to_tensor(W)
-    if not multihead:
-        return paddle.mean(attentions * W_tensor)
-    else:
-        return paddle.mean(attentions * paddle.unsqueeze(W_tensor, 1))
-
-
-@numba.jit(nopython=True)
-def guided_attention(N, max_N, T, max_T, g):
-    W = np.zeros((max_T, max_N), dtype=np.float32)
-    for t in range(T):
-        for n in range(N):
-            W[t, n] = 1 - np.exp(-(n / N - t / T)**2 / (2 * g * g))
-    # (T_dec, T_enc)
-    return W
-
-
-def guided_attentions(input_lengths, target_lengths, g=0.2):
-    B = len(input_lengths)
-    max_input_len = input_lengths.max()
-    max_target_len = target_lengths.max()
-    W = np.zeros((B, max_target_len, max_input_len), dtype=np.float32)
-    for b in range(B):
-        W[b] = guided_attention(input_lengths[b], max_input_len,
-                                target_lengths[b], max_target_len, g)
-    # (B, T_dec, T_enc)
-    return W