diff --git a/examples/tacotron2/config.py b/examples/tacotron2/config.py
index b14dbf9..9f4321e 100644
--- a/examples/tacotron2/config.py
+++ b/examples/tacotron2/config.py
@@ -37,9 +37,9 @@ _C.model = CN(
         encoder_kernel_size=5,  # kernel size of conv layers in tacotron2 encoder
         d_prenet=256,  # hidden size of decoder prenet
         d_attention_rnn=1024,  # hidden size of the first rnn layer in tacotron2 decoder
-        d_decoder_rnn=1024,  #hidden size of the second rnn layer in tacotron2 decoder
+        d_decoder_rnn=1024,  # hidden size of the second rnn layer in tacotron2 decoder
         d_attention=128,  # hidden size of decoder location linear layer
-        attention_filters=32,  # number of filter in decoder location conv layer
+        attention_filters=32,  # number of filters in decoder location conv layer
         attention_kernel_size=31,  # kernel size of decoder location conv layer
         d_postnet=512,  # hidden size of decoder postnet
         postnet_kernel_size=5,  # kernel size of conv layers in postnet
@@ -48,7 +48,8 @@ _C.model = CN(
         p_prenet_dropout=0.5,  # droput probability in decoder prenet
         p_attention_dropout=0.1,  # droput probability of first rnn layer in decoder
         p_decoder_dropout=0.1,  # droput probability of second rnn layer in decoder
-        p_postnet_dropout=0.5,  #droput probability in decoder postnet
+        p_postnet_dropout=0.5,  # dropout probability in decoder postnet
+        guided_attn_loss_sigma=0.2  # sigma in guided attention loss
     ))
 
 _C.training = CN(
diff --git a/examples/tacotron2/train.py b/examples/tacotron2/train.py
index bd635e6..b798e84 100644
--- a/examples/tacotron2/train.py
+++ b/examples/tacotron2/train.py
@@ -34,14 +34,14 @@ from ljspeech import LJSpeech, LJSpeechCollector
 
 class Experiment(ExperimentBase):
     def compute_losses(self, inputs, outputs):
-        _, mel_targets, _, _, stop_tokens = inputs
+        _, mel_targets, plens, slens, stop_tokens = inputs
 
         mel_outputs = outputs["mel_output"]
         mel_outputs_postnet = outputs["mel_outputs_postnet"]
-        stop_logits = outputs["stop_logits"]
+        attention_weight = outputs["alignments"]
 
-        losses = self.criterion(mel_outputs, mel_outputs_postnet, stop_logits,
-                                mel_targets, stop_tokens)
+        losses = self.criterion(mel_outputs, mel_outputs_postnet, mel_targets,
+                                attention_weight, slens, plens)
         return losses
 
     def train_batch(self):
@@ -145,7 +145,7 @@ class Experiment(ExperimentBase):
             weight_decay=paddle.regularizer.L2Decay(
                 config.training.weight_decay),
             grad_clip=grad_clip)
-        criterion = Tacotron2Loss()
+        criterion = Tacotron2Loss(config.model.guided_attn_loss_sigma)
         self.model = model
         self.optimizer = optimizer
         self.criterion = criterion
diff --git a/parakeet/models/tacotron2.py b/parakeet/models/tacotron2.py
index 1587108..abdfb7e 100644
--- a/parakeet/models/tacotron2.py
+++ b/parakeet/models/tacotron2.py
@@ -13,14 +13,16 @@
 # limitations under the License.
 
 import math
-import numpy as np
+
 import paddle
 from paddle import nn
 from paddle.nn import functional as F
-import parakeet
+from paddle.nn import initializer as I
+from paddle.fluid.layers import sequence_mask
+
 from parakeet.modules.conv import Conv1dBatchNorm
 from parakeet.modules.attention import LocationSensitiveAttention
-from parakeet.modules import masking
+from parakeet.modules.losses import guided_attention_loss
 from parakeet.utils import checkpoint
 
 __all__ = ["Tacotron2", "Tacotron2Loss"]
@@ -44,11 +46,7 @@ class DecoderPreNet(nn.Layer):
         The droput probability.
""" - - def __init__(self, - d_input: int, - d_hidden: int, - d_output: int, + def __init__(self, d_input: int, d_hidden: int, d_output: int, dropout_rate: float): super().__init__() @@ -63,7 +61,7 @@ class DecoderPreNet(nn.Layer): ---------- x: Tensor [shape=(B, T_mel, C)] Batch of the sequences of padded mel spectrogram. - + Returns ------- output: Tensor [shape=(B, T_mel, C)] @@ -71,10 +69,12 @@ class DecoderPreNet(nn.Layer): """ - x = F.dropout( - F.relu(self.linear1(x)), self.dropout_rate, training=True) - output = F.dropout( - F.relu(self.linear2(x)), self.dropout_rate, training=True) + x = F.dropout(F.relu(self.linear1(x)), + self.dropout_rate, + training=True) + output = F.dropout(F.relu(self.linear2(x)), + self.dropout_rate, + training=True) return output @@ -99,13 +99,8 @@ class DecoderPostNet(nn.Layer): The droput probability. """ - - def __init__(self, - d_mels: int, - d_hidden: int, - kernel_size: int, - num_layers: int, - dropout: float): + def __init__(self, d_mels: int, d_hidden: int, kernel_size: int, + num_layers: int, dropout: float): super().__init__() self.dropout = dropout self.num_layers = num_layers @@ -115,45 +110,40 @@ class DecoderPostNet(nn.Layer): self.conv_batchnorms = nn.LayerList() k = math.sqrt(1.0 / (d_mels * kernel_size)) self.conv_batchnorms.append( - Conv1dBatchNorm( - d_mels, - d_hidden, - kernel_size=kernel_size, - padding=padding, - bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform( - low=-k, high=k)), - data_format='NLC')) + Conv1dBatchNorm(d_mels, + d_hidden, + kernel_size=kernel_size, + padding=padding, + bias_attr=I.Uniform(-k, k), + data_format='NLC')) k = math.sqrt(1.0 / (d_hidden * kernel_size)) self.conv_batchnorms.extend([ - Conv1dBatchNorm( - d_hidden, - d_hidden, - kernel_size=kernel_size, - padding=padding, - bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform( - low=-k, high=k)), - data_format='NLC') for i in range(1, num_layers - 1) + Conv1dBatchNorm(d_hidden, + d_hidden, + kernel_size=kernel_size, + padding=padding, + bias_attr=I.Uniform(-k, k), + data_format='NLC') + for i in range(1, num_layers - 1) ]) self.conv_batchnorms.append( - Conv1dBatchNorm( - d_hidden, - d_mels, - kernel_size=kernel_size, - padding=padding, - bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform( - low=-k, high=k)), - data_format='NLC')) + Conv1dBatchNorm(d_hidden, + d_mels, + kernel_size=kernel_size, + padding=padding, + bias_attr=I.Uniform(-k, k), + data_format='NLC')) - def forward(self, input): + def forward(self, x): """Calculate forward propagation. Parameters ---------- - input: Tensor [shape=(B, T_mel, C)] + x: Tensor [shape=(B, T_mel, C)] Output sequence of features from decoder. - + Returns ------- output: Tensor [shape=(B, T_mel, C)] @@ -162,14 +152,12 @@ class DecoderPostNet(nn.Layer): """ for i in range(len(self.conv_batchnorms) - 1): - input = F.dropout( - F.tanh(self.conv_batchnorms[i](input)), - self.dropout, - training=self.training) - output = F.dropout( - self.conv_batchnorms[self.num_layers - 1](input), - self.dropout, - training=self.training) + x = F.dropout(F.tanh(self.conv_batchnorms[i](x)), + self.dropout, + training=self.training) + output = F.dropout(self.conv_batchnorms[self.num_layers - 1](x), + self.dropout, + training=self.training) return output @@ -180,41 +168,36 @@ class Tacotron2Encoder(nn.Layer): ---------- d_hidden: int The hidden size in encoder module. - + conv_layers: int The number of conv layers. kernel_size: int The kernel size of conv layers. 
- + p_dropout: float The droput probability. """ - - def __init__(self, - d_hidden: int, - conv_layers: int, - kernel_size: int, + def __init__(self, d_hidden: int, conv_layers: int, kernel_size: int, p_dropout: float): super().__init__() k = math.sqrt(1.0 / (d_hidden * kernel_size)) self.conv_batchnorms = paddle.nn.LayerList([ - Conv1dBatchNorm( - d_hidden, - d_hidden, - kernel_size, - stride=1, - padding=int((kernel_size - 1) / 2), - bias_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform( - low=-k, high=k)), - data_format='NLC') for i in range(conv_layers) + Conv1dBatchNorm(d_hidden, + d_hidden, + kernel_size, + stride=1, + padding=int((kernel_size - 1) / 2), + bias_attr=I.Uniform(-k, k), + data_format='NLC') for i in range(conv_layers) ]) self.p_dropout = p_dropout self.hidden_size = int(d_hidden / 2) - self.lstm = nn.LSTM( - d_hidden, self.hidden_size, direction="bidirectional") + self.lstm = nn.LSTM(d_hidden, + self.hidden_size, + direction="bidirectional") def forward(self, x, input_lens=None): """Calculate forward propagation of tacotron2 encoder. @@ -223,10 +206,10 @@ class Tacotron2Encoder(nn.Layer): ---------- x: Tensor [shape=(B, T)] Batch of the sequencees of padded character ids. - + text_lens: Tensor [shape=(B,)], optional Batch of lengths of each text input batch. Defaults to None. - + Returns ------- output : Tensor [shape=(B, T, C)] @@ -234,10 +217,9 @@ class Tacotron2Encoder(nn.Layer): """ for conv_batchnorm in self.conv_batchnorms: - x = F.dropout( - F.relu(conv_batchnorm(x)), - self.p_dropout, - training=self.training) + x = F.dropout(F.relu(conv_batchnorm(x)), + self.p_dropout, + training=self.training) output, _ = self.lstm(inputs=x, sequence_length=input_lens) return output @@ -253,7 +235,7 @@ class Tacotron2Decoder(nn.Layer): reduction_factor: int The reduction factor of tacotron. - + d_encoder: int The hidden size of encoder. @@ -265,13 +247,13 @@ class Tacotron2Decoder(nn.Layer): d_decoder_rnn: int The decoder rnn layer hidden size. - + d_attention: int The hidden size of the linear layer in location sensitive attention. attention_filters: int The filter size of the conv layer in location sensitive attention. - + attention_kernel_size: int The kernel size of the conv layer in location sensitive attention. @@ -284,20 +266,11 @@ class Tacotron2Decoder(nn.Layer): p_decoder_dropout: float The droput probability in decoder. 
""" - - def __init__(self, - d_mels: int, - reduction_factor: int, - d_encoder: int, - d_prenet: int, - d_attention_rnn: int, - d_decoder_rnn: int, - d_attention: int, - attention_filters: int, - attention_kernel_size: int, - p_prenet_dropout: float, - p_attention_dropout: float, - p_decoder_dropout: float): + def __init__(self, d_mels: int, reduction_factor: int, d_encoder: int, + d_prenet: int, d_attention_rnn: int, d_decoder_rnn: int, + d_attention: int, attention_filters: int, + attention_kernel_size: int, p_prenet_dropout: float, + p_attention_dropout: float, p_decoder_dropout: float): super().__init__() self.d_mels = d_mels self.reduction_factor = reduction_factor @@ -307,28 +280,45 @@ class Tacotron2Decoder(nn.Layer): self.p_attention_dropout = p_attention_dropout self.p_decoder_dropout = p_decoder_dropout - self.prenet = DecoderPreNet( - d_mels * reduction_factor, - d_prenet, - d_prenet, - dropout_rate=p_prenet_dropout) + self.prenet = DecoderPreNet(d_mels * reduction_factor, + d_prenet, + d_prenet, + dropout_rate=p_prenet_dropout) + # attention_rnn takes attention's context vector has an + # auxiliary input self.attention_rnn = nn.LSTMCell(d_prenet + d_encoder, d_attention_rnn) self.attention_layer = LocationSensitiveAttention( d_attention_rnn, d_encoder, d_attention, attention_filters, attention_kernel_size) + + # decoder_rnn takes prenet's output and attention_rnn's input + # as input self.decoder_rnn = nn.LSTMCell(d_attention_rnn + d_encoder, d_decoder_rnn) self.linear_projection = nn.Linear(d_decoder_rnn + d_encoder, d_mels * reduction_factor) - self.stop_layer = nn.Linear(d_decoder_rnn + d_encoder, 1) + + # states - temporary attributes + self.attention_hidden = None + self.attention_cell = None + + self.decoder_hidden = None + self.decoder_cell = None + + self.attention_weights = None + self.attention_weights_cum = None + self.attention_context = None + + self.key = None + self.mask = None + self.processed_key = None def _initialize_decoder_states(self, key): """init states be used in decoder """ - batch_size = key.shape[0] - MAX_TIME = key.shape[1] + batch_size, encoder_steps, _ = key.shape self.attention_hidden = paddle.zeros( shape=[batch_size, self.d_attention_rnn], dtype=key.dtype) @@ -341,27 +331,27 @@ class Tacotron2Decoder(nn.Layer): shape=[batch_size, self.d_decoder_rnn], dtype=key.dtype) self.attention_weights = paddle.zeros( - shape=[batch_size, MAX_TIME], dtype=key.dtype) + shape=[batch_size, encoder_steps], dtype=key.dtype) self.attention_weights_cum = paddle.zeros( - shape=[batch_size, MAX_TIME], dtype=key.dtype) + shape=[batch_size, encoder_steps], dtype=key.dtype) self.attention_context = paddle.zeros( shape=[batch_size, self.d_encoder], dtype=key.dtype) - self.key = key #[B, T, C] - self.processed_key = self.attention_layer.key_layer(key) #[B, T, C] + self.key = key # [B, T, C] + # pre-compute projected keys to improve efficiency + self.processed_key = self.attention_layer.key_layer(key) # [B, T, C] def _decode(self, query): """decode one time step """ cell_input = paddle.concat([query, self.attention_context], axis=-1) - # The first lstm layer + # The first lstm layer (or spec encoder lstm) _, (self.attention_hidden, self.attention_cell) = self.attention_rnn( cell_input, (self.attention_hidden, self.attention_cell)) - self.attention_hidden = F.dropout( - self.attention_hidden, - self.p_attention_dropout, - training=self.training) + self.attention_hidden = F.dropout(self.attention_hidden, + self.p_attention_dropout, + training=self.training) # Loaction 
sensitive attention attention_weights_cat = paddle.stack( @@ -371,23 +361,21 @@ class Tacotron2Decoder(nn.Layer): attention_weights_cat, self.mask) self.attention_weights_cum += self.attention_weights - # The second lstm layer + # The second lstm layer (or spec decoder lstm) decoder_input = paddle.concat( [self.attention_hidden, self.attention_context], axis=-1) _, (self.decoder_hidden, self.decoder_cell) = self.decoder_rnn( decoder_input, (self.decoder_hidden, self.decoder_cell)) - self.decoder_hidden = F.dropout( - self.decoder_hidden, - p=self.p_decoder_dropout, - training=self.training) + self.decoder_hidden = F.dropout(self.decoder_hidden, + p=self.p_decoder_dropout, + training=self.training) # decode output one step decoder_hidden_attention_context = paddle.concat( [self.decoder_hidden, self.attention_context], axis=-1) decoder_output = self.linear_projection( decoder_hidden_attention_context) - stop_logit = self.stop_layer(decoder_hidden_attention_context) - return decoder_output, stop_logit, self.attention_weights + return decoder_output, self.attention_weights def forward(self, keys, querys, mask): """Calculate forward propagation of tacotron2 decoder. @@ -396,117 +384,105 @@ class Tacotron2Decoder(nn.Layer): ---------- keys: Tensor[shape=(B, T_key, C)] Batch of the sequences of padded output from encoder. - + querys: Tensor[shape(B, T_query, C)] Batch of the sequences of padded mel spectrogram. - + mask: Tensor - Mask generated with text length. Shape should be (B, T_key, T_query) or broadcastable shape. - + Mask generated with text length. Shape should be (B, T_key, 1). + Returns ------- mel_output: Tensor [shape=(B, T_query, C)] Output sequence of features. - stop_logits: Tensor [shape=(B, T_query)] - Output sequence of stop logits. - alignments: Tensor [shape=(B, T_query, T_key)] Attention weights. """ - querys = paddle.reshape( - querys, - [querys.shape[0], querys.shape[1] // self.reduction_factor, -1]) - querys = paddle.concat( - [ - paddle.zeros( - shape=[querys.shape[0], 1, querys.shape[-1]], - dtype=querys.dtype), querys - ], - axis=1) - querys = self.prenet(querys) - self._initialize_decoder_states(keys) self.mask = mask - mel_outputs, stop_logits, alignments = [], [], [] - while len(mel_outputs) < querys.shape[ - 1] - 1: # Ignore the last time step + querys = paddle.reshape( + querys, + [querys.shape[0], querys.shape[1] // self.reduction_factor, -1]) + start_step = paddle.zeros(shape=[querys.shape[0], 1, querys.shape[-1]], + dtype=querys.dtype) + querys = paddle.concat([start_step, querys], axis=1) + + querys = self.prenet(querys) + + mel_outputs, alignments = [], [] + # Ignore the last time step + while len(mel_outputs) < querys.shape[1] - 1: query = querys[:, len(mel_outputs), :] - mel_output, stop_logit, attention_weights = self._decode(query) - mel_outputs += [mel_output] - stop_logits += [stop_logit] - alignments += [attention_weights] + mel_output, attention_weights = self._decode(query) + mel_outputs.append(mel_output) + alignments.append(attention_weights) alignments = paddle.stack(alignments, axis=1) - stop_logits = paddle.concat(stop_logits, axis=1) mel_outputs = paddle.stack(mel_outputs, axis=1) - return mel_outputs, stop_logits, alignments + return mel_outputs, alignments - def infer(self, key, stop_threshold=0.5, max_decoder_steps=1000): + def infer(self, key, max_decoder_steps=1000): """Calculate forward propagation of tacotron2 decoder. Parameters ---------- keys: Tensor [shape=(B, T_key, C)] Batch of the sequences of padded output from encoder. 
- - stop_threshold: float, optional - Stop synthesize when stop logit is greater than this stop threshold. Defaults to 0.5. - + max_decoder_steps: int, optional Number of max step when synthesize. Defaults to 1000. - + Returns ------- mel_output: Tensor [shape=(B, T_mel, C)] Output sequence of features. - stop_logits: Tensor [shape=(B, T_mel)] - Output sequence of stop logits. - alignments: Tensor [shape=(B, T_mel, T_key)] Attention weights. """ - query = paddle.zeros( - shape=[key.shape[0], self.d_mels * self.reduction_factor], - dtype=key.dtype) #[B, C] - + encoder_steps = key.shape[1] self._initialize_decoder_states(key) - self.mask = None + self.mask = None # mask is not needed for single instance inference - mel_outputs, stop_logits, alignments = [], [], [] + # [B, C] + start_step = paddle.zeros( + shape=[key.shape[0], self.d_mels * self.reduction_factor], + dtype=key.dtype) + query = start_step # [B, C] + + mel_outputs, alignments = [], [] while True: query = self.prenet(query) - mel_output, stop_logit, alignment = self._decode(query) + mel_output, alignment = self._decode(query) - mel_outputs += [mel_output] - stop_logits += [stop_logit] - alignments += [alignment] + mel_outputs.append(mel_output) + alignments.append(alignment) # (B=1, T) - if F.sigmoid(stop_logit) > stop_threshold: + if int(paddle.argmax(alignment[0])) == encoder_steps - 1: + print("Text content exhausted, synthesize stops.") break - elif len(mel_outputs) == max_decoder_steps: + if len(mel_outputs) == max_decoder_steps: print("Warning! Reached max decoder steps!!!") break query = mel_output alignments = paddle.stack(alignments, axis=1) - stop_logits = paddle.concat(stop_logits, axis=1) mel_outputs = paddle.stack(mel_outputs, axis=1) - return mel_outputs, stop_logits, alignments + return mel_outputs, alignments class Tacotron2(nn.Layer): """Tacotron2 model for end-to-end text-to-speech (E2E-TTS). This is a model of Spectrogram prediction network in Tacotron2 described - in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions - `_, + in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram + Predictions `_, which converts the sequence of characters into the sequence of mel spectrogram. @@ -517,10 +493,10 @@ class Tacotron2(nn.Layer): d_mels: int Number of mel bands. - + d_encoder: int Hidden size in encoder module. - + encoder_conv_layers: int Number of conv layers in encoder. @@ -538,7 +514,7 @@ class Tacotron2(nn.Layer): attention_filters: int Filter size of the conv layer in location sensitive attention. - + attention_kernel_size: int Kernel size of the conv layer in location sensitive attention. @@ -573,38 +549,34 @@ class Tacotron2(nn.Layer): Droput probability in postnet. 
""" - def __init__(self, - frontend: parakeet.frontend.Phonetics, - d_mels: int=80, - d_encoder: int=512, - encoder_conv_layers: int=3, - encoder_kernel_size: int=5, - d_prenet: int=256, - d_attention_rnn: int=1024, - d_decoder_rnn: int=1024, - attention_filters: int=32, - attention_kernel_size: int=31, - d_attention: int=128, - d_postnet: int=512, - postnet_kernel_size: int=5, - postnet_conv_layers: int=5, - reduction_factor: int=1, - p_encoder_dropout: float=0.5, - p_prenet_dropout: float=0.5, - p_attention_dropout: float=0.1, - p_decoder_dropout: float=0.1, - p_postnet_dropout: float=0.5): + vocab_size, + d_mels: int = 80, + d_encoder: int = 512, + encoder_conv_layers: int = 3, + encoder_kernel_size: int = 5, + d_prenet: int = 256, + d_attention_rnn: int = 1024, + d_decoder_rnn: int = 1024, + attention_filters: int = 32, + attention_kernel_size: int = 31, + d_attention: int = 128, + d_postnet: int = 512, + postnet_kernel_size: int = 5, + postnet_conv_layers: int = 5, + reduction_factor: int = 1, + p_encoder_dropout: float = 0.5, + p_prenet_dropout: float = 0.5, + p_attention_dropout: float = 0.1, + p_decoder_dropout: float = 0.1, + p_postnet_dropout: float = 0.5): super().__init__() - self.frontend = frontend - std = math.sqrt(2.0 / (self.frontend.vocab_size + d_encoder)) + std = math.sqrt(2.0 / (vocab_size + d_encoder)) val = math.sqrt(3.0) * std # uniform bounds for std - self.embedding = nn.Embedding( - self.frontend.vocab_size, - d_encoder, - weight_attr=paddle.ParamAttr(initializer=nn.initializer.Uniform( - low=-val, high=val))) + self.embedding = nn.Embedding(vocab_size, + d_encoder, + weight_attr=I.Uniform(-val, val)) self.encoder = Tacotron2Encoder(d_encoder, encoder_conv_layers, encoder_kernel_size, p_encoder_dropout) self.decoder = Tacotron2Decoder( @@ -612,12 +584,11 @@ class Tacotron2(nn.Layer): d_decoder_rnn, d_attention, attention_filters, attention_kernel_size, p_prenet_dropout, p_attention_dropout, p_decoder_dropout) - self.postnet = DecoderPostNet( - d_mels=d_mels * reduction_factor, - d_hidden=d_postnet, - kernel_size=postnet_kernel_size, - num_layers=postnet_conv_layers, - dropout=p_postnet_dropout) + self.postnet = DecoderPostNet(d_mels=d_mels * reduction_factor, + d_hidden=d_postnet, + kernel_size=postnet_kernel_size, + num_layers=postnet_conv_layers, + dropout=p_postnet_dropout) def forward(self, text_inputs, mels, text_lens, output_lens=None): """Calculate forward propagation of tacotron2. @@ -626,20 +597,20 @@ class Tacotron2(nn.Layer): ---------- text_inputs: Tensor [shape=(B, T_text)] Batch of the sequencees of padded character ids. - + mels: Tensor [shape(B, T_mel, C)] Batch of the sequences of padded mel spectrogram. - + text_lens: Tensor [shape=(B,)] Batch of lengths of each text input batch. - + output_lens: Tensor [shape=(B,)], optional Batch of lengths of each mels batch. Defaults to None. 
- + Returns ------- outputs : Dict[str, Tensor] - + mel_output: output sequence of features (B, T_mel, C); mel_outputs_postnet: output sequence of features after postnet (B, T_mel, C); @@ -651,47 +622,41 @@ class Tacotron2(nn.Layer): embedded_inputs = self.embedding(text_inputs) encoder_outputs = self.encoder(embedded_inputs, text_lens) - mask = paddle.tensor.unsqueeze( - paddle.fluid.layers.sequence_mask( - x=text_lens, dtype=encoder_outputs.dtype), [-1]) - mel_outputs, stop_logits, alignments = self.decoder( - encoder_outputs, mels, mask=mask) + # [B, T_enc, 1] + mask = paddle.unsqueeze( + sequence_mask(x=text_lens, dtype=encoder_outputs.dtype), [-1]) + mel_outputs, alignments = self.decoder(encoder_outputs, + mels, + mask=mask) mel_outputs_postnet = self.postnet(mel_outputs) mel_outputs_postnet = mel_outputs + mel_outputs_postnet if output_lens is not None: - mask = paddle.tensor.unsqueeze( - paddle.fluid.layers.sequence_mask(x=output_lens), - [-1]) #[B, T, 1] - mel_outputs = mel_outputs * mask #[B, T, C] - mel_outputs_postnet = mel_outputs_postnet * mask #[B, T, C] - stop_logits = stop_logits * mask[:, :, 0] + (1 - mask[:, :, 0] - ) * 1e3 #[B, T] + # [B, T_dec, 1] + mask = paddle.unsqueeze(sequence_mask(x=output_lens), [-1]) + mel_outputs = mel_outputs * mask # [B, T, C] + mel_outputs_postnet = mel_outputs_postnet * mask # [B, T, C] outputs = { "mel_output": mel_outputs, "mel_outputs_postnet": mel_outputs_postnet, - "stop_logits": stop_logits, "alignments": alignments } return outputs @paddle.no_grad() - def infer(self, text_inputs, stop_threshold=0.5, max_decoder_steps=1000): + def infer(self, text_inputs, max_decoder_steps=1000): """Generate the mel sepctrogram of features given the sequences of character ids. Parameters ---------- text_inputs: Tensor [shape=(B, T_text)] Batch of the sequencees of padded character ids. - - stop_threshold: float, optional - Stop synthesize when stop logit is greater than this stop threshold. Defaults to 0.5. - + max_decoder_steps: int, optional Number of max step when synthesize. Defaults to 1000. - + Returns ------- outputs : Dict[str, Tensor] @@ -706,10 +671,8 @@ class Tacotron2(nn.Layer): """ embedded_inputs = self.embedding(text_inputs) encoder_outputs = self.encoder(embedded_inputs) - mel_outputs, stop_logits, alignments = self.decoder.infer( - encoder_outputs, - stop_threshold=stop_threshold, - max_decoder_steps=max_decoder_steps) + mel_outputs, alignments = self.decoder.infer( + encoder_outputs, max_decoder_steps=max_decoder_steps) mel_outputs_postnet = self.postnet(mel_outputs) mel_outputs_postnet = mel_outputs + mel_outputs_postnet @@ -717,62 +680,32 @@ class Tacotron2(nn.Layer): outputs = { "mel_output": mel_outputs, "mel_outputs_postnet": mel_outputs_postnet, - "stop_logits": stop_logits, "alignments": alignments } return outputs - @paddle.no_grad() - def predict(self, text, stop_threshold=0.5, max_decoder_steps=1000): - """Generate the mel sepctrogram of features given the sequenc of characters. - - Parameters - ---------- - text: str - Sequence of characters. - - stop_threshold: float, optional - Stop synthesize when stop logit is greater than this stop threshold. Defaults to 0.5. - - max_decoder_steps: int, optional - Number of max step when synthesize. Defaults to 1000. - - Returns - ------- - outputs : Dict[str, Tensor] - - mel_outputs_postnet: output sequence of sepctrogram after postnet (T_mel, C); - - alignments: attention weights (T_mel, T_text). 
- """ - ids = np.asarray(self.frontend(text)) - ids = paddle.unsqueeze(paddle.to_tensor(ids, dtype='int64'), [0]) - outputs = self.infer(ids, stop_threshold, max_decoder_steps) - return outputs['mel_outputs_postnet'][0].numpy(), outputs[ - 'alignments'][0].numpy() - @classmethod - def from_pretrained(cls, frontend, config, checkpoint_path): + def from_pretrained(cls, config, checkpoint_path): """Build a tacotron2 model from a pretrained model. Parameters ---------- frontend: parakeet.frontend.Phonetics Frontend used to preprocess text. - + config: yacs.config.CfgNode Model configs. - + checkpoint_path: Path or str The path of pretrained model checkpoint, without extension name. - + Returns ------- Tacotron2 The model build from pretrined result. """ - model = cls(frontend, + model = cls(vocab_size=config.model.vocab_size, d_mels=config.data.d_mels, d_encoder=config.model.d_encoder, encoder_conv_layers=config.model.encoder_conv_layers, @@ -800,50 +733,46 @@ class Tacotron2(nn.Layer): class Tacotron2Loss(nn.Layer): """ Tacotron2 Loss module """ - - def __init__(self): + def __init__(self, sigma=0.2): super().__init__() + self.spec_criterion = nn.MSELoss() + self.attn_criterion = guided_attention_loss + self.sigma = sigma - def forward(self, mel_outputs, mel_outputs_postnet, stop_logits, - mel_targets, stop_tokens): + def forward(self, mel_outputs, mel_outputs_postnet, mel_targets, + attention_weights, slens, plens): """Calculate tacotron2 loss. Parameters ---------- mel_outputs: Tensor [shape=(B, T_mel, C)] Output mel spectrogram sequence. - + mel_outputs_postnet: Tensor [shape(B, T_mel, C)] Output mel spectrogram sequence after postnet. - - stop_logits: Tensor [shape=(B, T_mel)] - Output sequence of stop logits befor sigmoid. - + mel_targets: Tensor [shape=(B, T_mel, C)] Target mel spectrogram sequence. - - stop_tokens: Tensor [shape=(B,)] - Target stop token. - + Returns ------- losses : Dict[str, Tensor] - + loss: the sum of the other three losses; mel_loss: MSE loss compute by mel_targets and mel_outputs; post_mel_loss: MSE loss compute by mel_targets and mel_outputs_postnet; - - stop_loss: stop loss computed by stop_logits and stop token. 
""" - mel_loss = paddle.nn.MSELoss()(mel_outputs, mel_targets) - post_mel_loss = paddle.nn.MSELoss()(mel_outputs_postnet, mel_targets) - stop_loss = paddle.nn.BCEWithLogitsLoss()(stop_logits, stop_tokens) - total_loss = mel_loss + post_mel_loss + stop_loss - losses = dict( - loss=total_loss, - mel_loss=mel_loss, - post_mel_loss=post_mel_loss, - stop_loss=stop_loss) + mel_loss = self.spec_criterion(mel_outputs, mel_targets) + post_mel_loss = self.spec_criterion(mel_outputs_postnet, mel_targets) + gal_loss = self.attn_criterion(attention_weights, slens, plens, + self.sigma) + total_loss = mel_loss + post_mel_loss + gal_loss + losses = { + "loss": total_loss, + "mel_loss": mel_loss, + "post_mel_loss": post_mel_loss, + "guided_attn_loss": gal_loss + } return losses diff --git a/parakeet/modules/attention.py b/parakeet/modules/attention.py index aaf0b55..bedc6be 100644 --- a/parakeet/modules/attention.py +++ b/parakeet/modules/attention.py @@ -143,9 +143,9 @@ class MonoheadAttention(nn.Layer): def __init__(self, model_dim: int, - dropout: float=0.0, - k_dim: int=None, - v_dim: int=None): + dropout: float = 0.0, + k_dim: int = None, + v_dim: int = None): super(MonoheadAttention, self).__init__() k_dim = k_dim or model_dim v_dim = v_dim or model_dim @@ -225,9 +225,9 @@ class MultiheadAttention(nn.Layer): def __init__(self, model_dim: int, num_heads: int, - dropout: float=0.0, - k_dim: int=None, - v_dim: int=None): + dropout: float = 0.0, + k_dim: int = None, + v_dim: int = None): super(MultiheadAttention, self).__init__() if model_dim % num_heads != 0: raise ValueError("model_dim must be divisible by num_heads") @@ -316,14 +316,11 @@ class LocationSensitiveAttention(nn.Layer): self.key_layer = nn.Linear(d_key, d_attention, bias_attr=False) self.value = nn.Linear(d_attention, 1, bias_attr=False) - #Location Layer + # Location Layer self.location_conv = nn.Conv1D( - 2, - location_filters, - location_kernel_size, - 1, - int((location_kernel_size - 1) / 2), - 1, + 2, location_filters, + kernel_size=location_kernel_size, + padding=int((location_kernel_size - 1) / 2), bias_attr=False, data_format='NLC') self.location_layer = nn.Linear( @@ -352,21 +349,22 @@ class LocationSensitiveAttention(nn.Layer): Attention weights concat. mask : Tensor, optional - The mask. Shape should be (batch_size, times_steps_q, time_steps_k) or broadcastable shape. + The mask. Shape should be (batch_size, times_steps_k, 1). Defaults to None. Returns ---------- - attention_context : Tensor [shape=(batch_size, time_steps_q, d_attention)] + attention_context : Tensor [shape=(batch_size, d_attention)] The context vector. - attention_weights : Tensor [shape=(batch_size, times_steps_q, time_steps_k)] + attention_weights : Tensor [shape=(batch_size, time_steps_k)] The attention weights. 
""" processed_query = self.query_layer(paddle.unsqueeze(query, axis=[1])) processed_attention_weights = self.location_layer( self.location_conv(attention_weights_cat)) + # (B, T_enc, 1) alignment = self.value( paddle.tanh(processed_attention_weights + processed_key + processed_query)) @@ -378,7 +376,7 @@ class LocationSensitiveAttention(nn.Layer): attention_context = paddle.matmul( attention_weights, value, transpose_x=True) - attention_weights = paddle.squeeze(attention_weights, axis=[-1]) - attention_context = paddle.squeeze(attention_context, axis=[1]) + attention_weights = paddle.squeeze(attention_weights, axis=-1) + attention_context = paddle.squeeze(attention_context, axis=1) return attention_context, attention_weights diff --git a/parakeet/modules/losses.py b/parakeet/modules/losses.py index ab188fd..eb68dae 100644 --- a/parakeet/modules/losses.py +++ b/parakeet/modules/losses.py @@ -17,15 +17,51 @@ import numpy as np import paddle from paddle import nn from paddle.nn import functional as F +from paddle.fluid.layers import sequence_mask __all__ = [ + "guided_attention_loss", "weighted_mean", "masked_l1_loss", "masked_softmax_with_cross_entropy", - "diagonal_loss", ] +def attention_guide(dec_lens, enc_lens, N, T, g, dtype=None): + """Build that W matrix. shape(B, T_dec, T_enc) + W[i, n, t] = 1 - exp(-(n/dec_lens[i] - t/enc_lens[i])**2 / (2g**2)) + + See also: + Tachibana, Hideyuki, Katsuya Uenoyama, and Shunsuke Aihara. 2017. “Efficiently Trainable Text-to-Speech System Based on Deep Convolutional Networks with Guided Attention.” ArXiv:1710.08969 [Cs, Eess], October. http://arxiv.org/abs/1710.08969. + """ + dtype = dtype or paddle.get_default_dtype() + dec_pos = paddle.arange(0, N).astype( + dtype) / dec_lens.unsqueeze(-1) # n/N # shape(B, T_dec) + enc_pos = paddle.arange(0, T).astype( + dtype) / enc_lens.unsqueeze(-1) # t/T # shape(B, T_enc) + W = 1 - paddle.exp(-(dec_pos.unsqueeze(-1) - + enc_pos.unsqueeze(1))**2 / (2 * g ** 2)) + + dec_mask = sequence_mask(dec_lens, maxlen=N) + enc_mask = sequence_mask(enc_lens, maxlen=T) + mask = dec_mask.unsqueeze(-1) * enc_mask.unsqueeze(1) + mask = paddle.cast(mask, W.dtype) + + W *= mask + return W + + +def guided_attention_loss(attention_weight, dec_lens, enc_lens, g): + """Guided attention loss, masked to excluded padding parts.""" + _, N, T = attention_weight.shape + W = attention_guide(dec_lens, enc_lens, N, T, g, attention_weight.dtype) + + total_tokens = (dec_lens * enc_lens).astype(W.dtype) + loss = paddle.mean(paddle.sum( + W * attention_weight, [1, 2]) / total_tokens) + return loss, W + + def weighted_mean(input, weight): """Weighted mean. It can also be used as masked mean. @@ -40,14 +76,10 @@ def weighted_mean(input, weight): ---------- Tensor [shape=(1,)] Weighted mean tensor with the same dtype as input. - - Warnings - --------- - This is not a mathematical weighted mean. It performs weighted sum and - simple average. """ weight = paddle.cast(weight, input.dtype) - return paddle.mean(input * weight) + broadcast_ratio = input.size / weight.size + return paddle.sum(input * weight) / (paddle.sum(weight) * broadcast_ratio) def masked_l1_loss(prediction, target, mask): @@ -101,70 +133,3 @@ def masked_softmax_with_cross_entropy(logits, label, mask, axis=-1): ce = F.softmax_with_cross_entropy(logits, label, axis=axis) loss = weighted_mean(ce, mask) return loss - - -def diagonal_loss(attentions, - input_lengths, - target_lengths, - g=0.2, - multihead=False): - """A metric to evaluate how diagonal a attention distribution is. 
-
-    It is computed for batch attention distributions. For each attention
-    distribution, the valid decoder time steps and encoder time steps may
-    differ.
-
-    Parameters
-    ----------
-    attentions : Tensor [shape=(B, T_dec, T_enc) or (B, H, T_dec, T_dec)]
-        The attention weights from an encoder-decoder structure.
-
-    input_lengths : Tensor [shape=(B,)]
-        The valid length for each encoder output.
-
-    target_lengths : Tensor [shape=(B,)]
-        The valid length for each decoder output.
-
-    g : float, optional
-        [description], by default 0.2.
-
-    multihead : bool, optional
-        A flag indicating whether ``attentions`` is a multihead attention's
-        attention distribution.
-
-        If ``True``, the shape of attention is ``(B, H, T_dec, T_dec)``, by
-        default False.
-
-    Returns
-    -------
-    Tensor [shape=(1,)]
-        The diagonal loss.
-    """
-    W = guided_attentions(input_lengths, target_lengths, g)
-    W_tensor = paddle.to_tensor(W)
-    if not multihead:
-        return paddle.mean(attentions * W_tensor)
-    else:
-        return paddle.mean(attentions * paddle.unsqueeze(W_tensor, 1))
-
-
-@numba.jit(nopython=True)
-def guided_attention(N, max_N, T, max_T, g):
-    W = np.zeros((max_T, max_N), dtype=np.float32)
-    for t in range(T):
-        for n in range(N):
-            W[t, n] = 1 - np.exp(-(n / N - t / T)**2 / (2 * g * g))
-    # (T_dec, T_enc)
-    return W
-
-
-def guided_attentions(input_lengths, target_lengths, g=0.2):
-    B = len(input_lengths)
-    max_input_len = input_lengths.max()
-    max_target_len = target_lengths.max()
-    W = np.zeros((B, max_target_len, max_input_len), dtype=np.float32)
-    for b in range(B):
-        W[b] = guided_attention(input_lengths[b], max_input_len,
-                                target_lengths[b], max_target_len, g)
-    # (B, T_dec, T_enc)
-    return W
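
A minimal usage sketch of the guided attention pieces introduced above. The module paths, function signatures, and the Tacotron2Loss(sigma) constructor are the ones this patch adds; the batch size, sequence lengths, and sigma=0.2 are toy values assumed purely for illustration.

# Sketch: exercising the guided attention loss wiring from this patch.
# Toy tensors stand in for real encoder/decoder outputs.
import paddle
import paddle.nn.functional as F

from parakeet.models.tacotron2 import Tacotron2Loss
from parakeet.modules.losses import attention_guide, guided_attention_loss

B, T_dec, T_enc, C = 2, 20, 15, 80
mel_outputs = paddle.rand([B, T_dec, C])
mel_outputs_postnet = paddle.rand([B, T_dec, C])
mel_targets = paddle.rand([B, T_dec, C])

# "alignments" as returned by Tacotron2Decoder.forward: shape (B, T_dec, T_enc)
alignments = F.softmax(paddle.rand([B, T_dec, T_enc]), axis=-1)
mel_lens = paddle.to_tensor([20, 16], dtype="int64")   # decoder lengths (slens)
text_lens = paddle.to_tensor([15, 11], dtype="int64")  # encoder lengths (plens)

# The penalty matrix W is close to 0 near the diagonal and close to 1 far
# from it, so only off-diagonal attention mass contributes to the loss.
W = attention_guide(mel_lens, text_lens, T_dec, T_enc, 0.2)
gal = guided_attention_loss(alignments, mel_lens, text_lens, 0.2)

# Tacotron2Loss now takes sigma (config.model.guided_attn_loss_sigma) and
# reports mel_loss, post_mel_loss and guided_attn_loss instead of a stop loss.
criterion = Tacotron2Loss(sigma=0.2)
losses = criterion(mel_outputs, mel_outputs_postnet, mel_targets, alignments,
                   mel_lens, text_lens)
print(W.shape, float(gal), {k: float(v) for k, v in losses.items()})

With the stop-token branch gone, inference in this patch halts when the alignment peak reaches the last encoder step (or at max_decoder_steps), so a near-diagonal alignment is what both the training objective and the stopping rule now rely on.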