diff --git a/parakeet/models/transformer_tts.py b/parakeet/models/transformer_tts.py
index 4cc3df3..fedd58e 100644
--- a/parakeet/models/transformer_tts.py
+++ b/parakeet/models/transformer_tts.py
@@ -1,3 +1,17 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 from tqdm import trange
 import paddle
@@ -15,6 +29,7 @@ from parakeet.modules import losses as L
 
 __all__ = ["TransformerTTS", "TransformerTTSLoss"]
 
+
 # Transformer TTS's own implementation of transformer
 class MultiheadAttention(nn.Layer):
     """
@@ -25,7 +40,14 @@ class MultiheadAttention(nn.Layer):
     Another deviation is that it concats the input query and context vector
     before applying the output projection.
     """
-    def __init__(self, model_dim, num_heads, k_dim=None, v_dim=None, k_input_dim=None, v_input_dim=None):
+
+    def __init__(self,
+                 model_dim,
+                 num_heads,
+                 k_dim=None,
+                 v_dim=None,
+                 k_input_dim=None,
+                 v_input_dim=None):
         """
         Args:
             model_dim (int): the feature size of query.
@@ -41,7 +63,7 @@ class MultiheadAttention(nn.Layer):
            ValueError: if model_dim is not divisible by num_heads
         """
         super(MultiheadAttention, self).__init__()
-        if model_dim % num_heads !=0:
+        if model_dim % num_heads != 0:
             raise ValueError("model_dim must be divisible by num_heads")
         depth = model_dim // num_heads
         k_dim = k_dim or depth
@@ -52,10 +74,10 @@ class MultiheadAttention(nn.Layer):
         self.affine_k = nn.Linear(k_input_dim, num_heads * k_dim)
         self.affine_v = nn.Linear(v_input_dim, num_heads * v_dim)
         self.affine_o = nn.Linear(model_dim + num_heads * v_dim, model_dim)
-        
+
         self.num_heads = num_heads
         self.model_dim = model_dim
-        
+
     def forward(self, q, k, v, mask, drop_n_heads=0):
         """
         Compute context vector and attention weights.
@@ -72,17 +94,18 @@ class MultiheadAttention(nn.Layer):
             attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
         """
         q_in = q
-        q = _split_heads(self.affine_q(q), self.num_heads) # (B, h, T, C)
+        q = _split_heads(self.affine_q(q), self.num_heads)  # (B, h, T, C)
         k = _split_heads(self.affine_k(k), self.num_heads)
         v = _split_heads(self.affine_v(v), self.num_heads)
         if mask is not None:
-            mask = paddle.unsqueeze(mask, 1) # unsqueeze for the h dim
-        
+            mask = paddle.unsqueeze(mask, 1)  # unsqueeze for the h dim
+
         context_vectors, attention_weights = scaled_dot_product_attention(
             q, k, v, mask, training=self.training)
-        context_vectors = drop_head(context_vectors, drop_n_heads, self.training)
-        context_vectors = _concat_heads(context_vectors) # (B, T, h*C)
-        
+        context_vectors = drop_head(context_vectors, drop_n_heads,
+                                    self.training)
+        context_vectors = _concat_heads(context_vectors)  # (B, T, h*C)
+
         concat_feature = paddle.concat([q_in, context_vectors], -1)
         out = self.affine_o(concat_feature)
         return out, attention_weights
@@ -92,6 +115,7 @@ class TransformerEncoderLayer(nn.Layer):
     """
     Transformer encoder layer.
     """
+
     def __init__(self, d_model, n_heads, d_ffn, dropout=0.):
         """
         Args:
@@ -114,8 +138,10 @@ class TransformerEncoderLayer(nn.Layer):
         # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
         x_in = x
         x = self.layer_norm1(x)
-        context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads)
-        context_vector = x_in + F.dropout(context_vector, self.dropout, training=self.training)
+        context_vector, attn_weights = self.self_mha(x, x, x, mask,
+                                                     drop_n_heads)
+        context_vector = x_in + F.dropout(
+            context_vector, self.dropout, training=self.training)
         return context_vector, attn_weights
 
     def _forward_ffn(self, x):
@@ -123,9 +149,9 @@ class TransformerEncoderLayer(nn.Layer):
         x_in = x
         x = self.layer_norm2(x)
         x = self.ffn(x)
-        out= x_in + F.dropout(x, self.dropout, training=self.training)
+        out = x_in + F.dropout(x, self.dropout, training=self.training)
         return out
-    
+
     def forward(self, x, mask, drop_n_heads=0):
         """
         Args:
@@ -145,6 +171,7 @@ class TransformerDecoderLayer(nn.Layer):
     """
     Transformer decoder layer.
     """
+
     def __init__(self, d_model, n_heads, d_ffn, dropout=0., d_encoder=None):
         """
         Args:
@@ -157,37 +184,42 @@ class TransformerDecoderLayer(nn.Layer):
         super(TransformerDecoderLayer, self).__init__()
         self.self_mha = MultiheadAttention(d_model, n_heads)
         self.layer_norm1 = nn.LayerNorm([d_model], epsilon=1e-6)
-        
-        self.cross_mha = MultiheadAttention(d_model, n_heads, k_input_dim=d_encoder, v_input_dim=d_encoder)
+
+        self.cross_mha = MultiheadAttention(
+            d_model, n_heads, k_input_dim=d_encoder, v_input_dim=d_encoder)
         self.layer_norm2 = nn.LayerNorm([d_model], epsilon=1e-6)
-        
+
         self.ffn = PositionwiseFFN(d_model, d_ffn, dropout)
         self.layer_norm3 = nn.LayerNorm([d_model], epsilon=1e-6)
         self.dropout = dropout
-    
+
     def _forward_self_mha(self, x, mask, drop_n_heads):
         # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
         x_in = x
         x = self.layer_norm1(x)
-        context_vector, attn_weights = self.self_mha(x, x, x, mask, drop_n_heads)
-        context_vector = x_in + F.dropout(context_vector, self.dropout, training=self.training)
+        context_vector, attn_weights = self.self_mha(x, x, x, mask,
+                                                     drop_n_heads)
+        context_vector = x_in + F.dropout(
+            context_vector, self.dropout, training=self.training)
         return context_vector, attn_weights
 
     def _forward_cross_mha(self, q, k, v, mask, drop_n_heads):
         # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
         q_in = q
         q = self.layer_norm2(q)
-        context_vector, attn_weights = self.cross_mha(q, k, v, mask, drop_n_heads)
-        context_vector = q_in + F.dropout(context_vector, self.dropout, training=self.training)
+        context_vector, attn_weights = self.cross_mha(q, k, v, mask,
+                                                      drop_n_heads)
+        context_vector = q_in + F.dropout(
+            context_vector, self.dropout, training=self.training)
         return context_vector, attn_weights
-    
+
     def _forward_ffn(self, x):
         # PreLN scheme: Norm -> SubLayer -> Dropout -> Residual
         x_in = x
         x = self.layer_norm3(x)
         x = self.ffn(x)
-        out= x_in + F.dropout(x, self.dropout, training=self.training)
+        out = x_in + F.dropout(x, self.dropout, training=self.training)
         return out
 
     def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
@@ -204,8 +236,10 @@ class TransformerDecoderLayer(nn.Layer):
             self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
             cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
         """
-        q, self_attn_weights = self._forward_self_mha(q, decoder_mask, drop_n_heads)
-        q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask, drop_n_heads)
+        q, self_attn_weights = self._forward_self_mha(q, decoder_mask,
+                                                      drop_n_heads)
+        q, cross_attn_weights = self._forward_cross_mha(q, k, v, encoder_mask,
+                                                        drop_n_heads)
         q = self._forward_ffn(q)
         return q, self_attn_weights, cross_attn_weights
 
@@ -214,7 +248,8 @@ class TransformerEncoder(nn.LayerList):
     def __init__(self, d_model, n_heads, d_ffn, n_layers, dropout=0.):
         super(TransformerEncoder, self).__init__()
         for _ in range(n_layers):
-            self.append(TransformerEncoderLayer(d_model, n_heads, d_ffn, dropout))
+            self.append(
+                TransformerEncoderLayer(d_model, n_heads, d_ffn, dropout))
 
     def forward(self, x, mask, drop_n_heads=0):
         """
@@ -236,10 +271,18 @@ class TransformerEncoder(nn.LayerList):
 
 
 class TransformerDecoder(nn.LayerList):
-    def __init__(self, d_model, n_heads, d_ffn, n_layers, dropout=0., d_encoder=None):
+    def __init__(self,
+                 d_model,
+                 n_heads,
+                 d_ffn,
+                 n_layers,
+                 dropout=0.,
+                 d_encoder=None):
         super(TransformerDecoder, self).__init__()
         for _ in range(n_layers):
-            self.append(TransformerDecoderLayer(d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder))
+            self.append(
+                TransformerDecoderLayer(
+                    d_model, n_heads, d_ffn, dropout, d_encoder=d_encoder))
 
     def forward(self, q, k, v, encoder_mask, decoder_mask, drop_n_heads=0):
         """[summary]
@@ -260,7 +303,8 @@ class TransformerDecoder(nn.LayerList):
         self_attention_weights = []
         cross_attention_weights = []
         for layer in self:
-            q, self_attention_weights_i, cross_attention_weights_i = layer(q, k, v, encoder_mask, decoder_mask, drop_n_heads)
+            q, self_attention_weights_i, cross_attention_weights_i = layer(
+                q, k, v, encoder_mask, decoder_mask, drop_n_heads)
             self_attention_weights.append(self_attention_weights_i)
             cross_attention_weights.append(cross_attention_weights_i)
         return q, self_attention_weights, cross_attention_weights
@@ -268,6 +312,7 @@ class TransformerDecoder(nn.LayerList):
 
 class MLPPreNet(nn.Layer):
     """Decoder's prenet."""
+
     def __init__(self, d_input, d_hidden, d_output, dropout):
         # (lin + relu + dropout) * n + last projection
         super(MLPPreNet, self).__init__()
@@ -275,16 +320,24 @@ class MLPPreNet(nn.Layer):
         self.lin2 = nn.Linear(d_hidden, d_hidden)
         self.lin3 = nn.Linear(d_hidden, d_hidden)
         self.dropout = dropout
-    
+
    def forward(self, x, dropout):
-        l1 = F.dropout(F.relu(self.lin1(x)), self.dropout, training=self.training)
-        l2 = F.dropout(F.relu(self.lin2(l1)), self.dropout, training=self.training)
+        l1 = F.dropout(
+            F.relu(self.lin1(x)), self.dropout, training=self.training)
+        l2 = F.dropout(
+            F.relu(self.lin2(l1)), self.dropout, training=self.training)
         l3 = self.lin3(l2)
         return l3
 
 
+# NOTE: not used in
 class CNNPreNet(nn.Layer):
-    def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers,
+    def __init__(self,
+                 d_input,
+                 d_hidden,
+                 d_output,
+                 kernel_size,
+                 n_layers,
                  dropout=0.):
         # (conv + bn + relu + dropout) * n + last projection
         super(CNNPreNet, self).__init__()
@@ -292,16 +345,21 @@ class CNNPreNet(nn.Layer):
         c_in = d_input
         for _ in range(n_layers):
             self.convs.append(
-                Conv1dBatchNorm(c_in, d_hidden, kernel_size,
-                                weight_attr=I.XavierUniform(),
-                                padding="same", data_format="NLC"))
+                Conv1dBatchNorm(
+                    c_in,
+                    d_hidden,
+                    kernel_size,
+                    weight_attr=I.XavierUniform(),
+                    padding="same",
+                    data_format="NLC"))
             c_in = d_hidden
         self.affine_out = nn.Linear(d_hidden, d_output)
         self.dropout = dropout
-    
+
     def forward(self, x):
         for layer in self.convs:
-            x = F.dropout(F.relu(layer(x)), self.dropout, training=self.training)
+            x = F.dropout(
+                F.relu(layer(x)), self.dropout, training=self.training)
         x = self.affine_out(x)
         return x
 
@@ -310,21 +368,25 @@ class CNNPostNet(nn.Layer):
     def __init__(self, d_input, d_hidden, d_output, kernel_size, n_layers):
         super(CNNPostNet, self).__init__()
         self.convs = nn.LayerList()
-        kernel_size = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, )
+        kernel_size = kernel_size if isinstance(kernel_size, (
+            tuple, list)) else (kernel_size, )
         padding = (kernel_size[0] - 1, 0)
         for i in range(n_layers):
             c_in = d_input if i == 0 else d_hidden
             c_out = d_output if i == n_layers - 1 else d_hidden
             self.convs.append(
-                Conv1dBatchNorm(c_in, c_out, kernel_size,
-                                weight_attr=I.XavierUniform(),
-                                padding=padding))
+                Conv1dBatchNorm(
+                    c_in,
+                    c_out,
+                    kernel_size,
+                    weight_attr=I.XavierUniform(),
+                    padding=padding))
         self.last_bn = nn.BatchNorm1D(d_output)
         # for a layer that ends with a normalization layer that is targeted to
         # output a non zero-central output, it may take a long time to
         # train the scale and bias
         # NOTE: it can also be a non-causal conv
-    
+
     def forward(self, x):
         x_in = x
         for i, layer in enumerate(self.convs):
@@ -336,19 +398,19 @@ class CNNPostNet(nn.Layer):
 
 
 class TransformerTTS(nn.Layer):
-    def __init__(self, 
-                 frontend: parakeet.frontend.Phonetics, 
-                 d_encoder: int, 
-                 d_decoder: int, 
-                 d_mel: int, 
+    def __init__(self,
+                 frontend: parakeet.frontend.Phonetics,
+                 d_encoder: int,
+                 d_decoder: int,
+                 d_mel: int,
                  n_heads: int,
                  d_ffn: int,
-                 encoder_layers: int, 
-                 decoder_layers: int, 
-                 d_prenet: int, 
-                 d_postnet: int, 
-                 postnet_layers: int, 
-                 postnet_kernel_size: int, 
+                 encoder_layers: int,
+                 decoder_layers: int,
+                 d_prenet: int,
+                 d_postnet: int,
+                 postnet_layers: int,
+                 postnet_kernel_size: int,
                  max_reduction_factor: int,
                  decoder_prenet_dropout: float,
                  dropout: float):
@@ -359,29 +421,34 @@ class TransformerTTS(nn.Layer):
 
         # encoder
         self.encoder_prenet = nn.Embedding(
-            frontend.vocab_size, d_encoder, 
-            padding_idx=frontend.vocab.padding_index, 
+            frontend.vocab_size,
+            d_encoder,
+            padding_idx=frontend.vocab.padding_index,
             weight_attr=I.Uniform(-0.05, 0.05))
         # position encoding matrix may be extended later
-        self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder) 
+        self.encoder_pe = pe.positional_encoding(0, 1000, d_encoder)
         self.encoder_pe_scalar = self.create_parameter(
             [1], attr=I.Constant(1.))
-        self.encoder = TransformerEncoder(
-            d_encoder, n_heads, d_ffn, encoder_layers, dropout)
-        
+        self.encoder = TransformerEncoder(d_encoder, n_heads, d_ffn,
+                                          encoder_layers, dropout)
+
         # decoder
         self.decoder_prenet = MLPPreNet(d_mel, d_prenet, d_decoder, dropout)
         self.decoder_pe = pe.positional_encoding(0, 1000, d_decoder)
         self.decoder_pe_scalar = self.create_parameter(
             [1], attr=I.Constant(1.))
         self.decoder = TransformerDecoder(
-            d_decoder, n_heads, d_ffn, decoder_layers, dropout, 
+            d_decoder,
+            n_heads,
+            d_ffn,
+            decoder_layers,
+            dropout,
             d_encoder=d_encoder)
         self.final_proj = nn.Linear(d_decoder, max_reduction_factor * d_mel)
-        self.decoder_postnet = CNNPostNet(
-            d_mel, d_postnet, d_mel, postnet_kernel_size, postnet_layers)
+        self.decoder_postnet = CNNPostNet(d_mel, d_postnet, d_mel,
+                                          postnet_kernel_size, postnet_layers)
         self.stop_conditioner = nn.Linear(d_mel, 3)
-        
+
         # specs
         self.padding_idx = frontend.vocab.padding_index
         self.d_encoder = d_encoder
@@ -390,21 +457,22 @@ class TransformerTTS(nn.Layer):
         self.max_r = max_reduction_factor
         self.dropout = dropout
         self.decoder_prenet_dropout = decoder_prenet_dropout
-        
+
         # start and end: though it is only used in predict
         # it can also be used in training
         dtype = paddle.get_default_dtype()
         self.start_vec = paddle.full([1, d_mel], 0.5, dtype=dtype)
         self.end_vec = paddle.full([1, d_mel], -0.5, dtype=dtype)
         self.stop_prob_index = 2
-        
+
         # mutables
-        self.r = max_reduction_factor # set it every call
+        self.r = max_reduction_factor  # set it every call
         self.drop_n_heads = 0
-    
+
     def forward(self, text, mel):
         encoded, encoder_attention_weights, encoder_mask = self.encode(text)
-        mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(encoded, mel, encoder_mask)
+        mel_output, mel_intermediate, cross_attention_weights, stop_logits = self.decode(
+            encoded, mel, encoder_mask)
         outputs = {
             "mel_output": mel_output,
             "mel_intermediate": mel_intermediate,
@@ -420,51 +488,54 @@ class TransformerTTS(nn.Layer):
         if embed.shape[1] > self.encoder_pe.shape[0]:
             new_T = max(embed.shape[1], self.encoder_pe.shape[0] * 2)
             self.encoder_pe = pe.positional_encoding(0, new_T, self.d_encoder)
-        pos_enc = self.encoder_pe[:T_enc, :] # (T, C)
-        x = embed.scale(math.sqrt(self.d_encoder)) + pos_enc * self.encoder_pe_scalar
+        pos_enc = self.encoder_pe[:T_enc, :]  # (T, C)
+        x = embed.scale(math.sqrt(
+            self.d_encoder)) + pos_enc * self.encoder_pe_scalar
         x = F.dropout(x, self.dropout, training=self.training)
         # TODO(chenfeiyu): unsqueeze a decoder_time_steps=1 for the mask
         encoder_padding_mask = paddle.unsqueeze(
-            masking.id_mask(text, self.padding_idx, dtype=x.dtype), 1)
-        x, attention_weights = self.encoder(x, encoder_padding_mask, self.drop_n_heads)
+            masking.id_mask(
+                text, self.padding_idx, dtype=x.dtype), 1)
+        x, attention_weights = self.encoder(x, encoder_padding_mask,
+                                            self.drop_n_heads)
         return x, attention_weights, encoder_padding_mask
-    
+
     def decode(self, encoder_output, input, encoder_padding_mask):
         batch_size, T_dec, mel_dim = input.shape
-        
+
         x = self.decoder_prenet(input, self.decoder_prenet_dropout)
         # twice its length if needed
         if x.shape[1] * self.r > self.decoder_pe.shape[0]:
             new_T = max(x.shape[1] * self.r, self.decoder_pe.shape[0] * 2)
             self.decoder_pe = pe.positional_encoding(0, new_T, self.d_decoder)
-        pos_enc = self.decoder_pe[:T_dec*self.r:self.r, :]
-        x = x.scale(math.sqrt(self.d_decoder)) + pos_enc * self.decoder_pe_scalar
+        pos_enc = self.decoder_pe[:T_dec * self.r:self.r, :]
+        x = x.scale(math.sqrt(
+            self.d_decoder)) + pos_enc * self.decoder_pe_scalar
         x = F.dropout(x, self.dropout, training=self.training)
 
         no_future_mask = masking.future_mask(T_dec, dtype=input.dtype)
-        decoder_padding_mask = masking.feature_mask(input, axis=-1, dtype=input.dtype)
-        decoder_mask = masking.combine_mask(decoder_padding_mask.unsqueeze(-1), no_future_mask)
+        decoder_padding_mask = masking.feature_mask(
+            input, axis=-1, dtype=input.dtype)
+        decoder_mask = masking.combine_mask(
+            decoder_padding_mask.unsqueeze(-1), no_future_mask)
         decoder_output, _, cross_attention_weights = self.decoder(
-            x, 
-            encoder_output, 
-            encoder_output, 
-            encoder_padding_mask, 
-            decoder_mask, 
-            self.drop_n_heads)
+            x, encoder_output, encoder_output, encoder_padding_mask,
+            decoder_mask, self.drop_n_heads)
 
         # use only parts of it
-        output_proj = self.final_proj(decoder_output)[:, :, : self.r * mel_dim]
-        mel_intermediate = paddle.reshape(output_proj, [batch_size, -1, mel_dim])
+        output_proj = self.final_proj(decoder_output)[:, :, :self.r * mel_dim]
+        mel_intermediate = paddle.reshape(output_proj,
+                                          [batch_size, -1, mel_dim])
         stop_logits = self.stop_conditioner(mel_intermediate)
-        
+
         # cnn postnet
         mel_channel_first = paddle.transpose(mel_intermediate, [0, 2, 1])
         mel_output = self.decoder_postnet(mel_channel_first)
         mel_output = paddle.transpose(mel_output, [0, 2, 1])
         return mel_output, mel_intermediate, cross_attention_weights, stop_logits
-    
+
     def predict(self, input, raw_input=True, max_length=1000, verbose=True):
         """Predict log scale magnitude mel spectrogram from text input.
@@ -475,26 +546,32 @@ class TransformerTTS(nn.Layer):
         """
         if raw_input:
             text_ids = paddle.to_tensor(self.frontend(input))
-            text_input = paddle.unsqueeze(text_ids, 0) # (1, T)
+            text_input = paddle.unsqueeze(text_ids, 0)  # (1, T)
         else:
             text_input = input
-        
-        decoder_input = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
-        decoder_output = paddle.unsqueeze(self.start_vec, 0) # (B=1, T, C)
-        
+
+        decoder_input = paddle.unsqueeze(self.start_vec, 0)  # (B=1, T, C)
+        decoder_output = paddle.unsqueeze(self.start_vec, 0)  # (B=1, T, C)
+
         # encoder the text sequence
-        encoder_output, encoder_attentions, encoder_padding_mask = self.encode(text_input)
-        for _ in trange(int(max_length // self.r) + 1):
+        encoder_output, encoder_attentions, encoder_padding_mask = self.encode(
+            text_input)
+        for _ in range(int(max_length // self.r) + 1):
             mel_output, _, cross_attention_weights, stop_logits = self.decode(
                 encoder_output, decoder_input, encoder_padding_mask)
-            
+
             # extract last step and append it to decoder input
-            decoder_input = paddle.concat([decoder_input, mel_output[:, -1:, :]], 1)
+            decoder_input = paddle.concat(
+                [decoder_input, mel_output[:, -1:, :]], 1)
             # extract last r steps and append it to decoder output
-            decoder_output = paddle.concat([decoder_output, mel_output[:, -self.r:, :]], 1)
-            
+            decoder_output = paddle.concat(
+                [decoder_output, mel_output[:, -self.r:, :]], 1)
+
             # stop condition: (if any ouput frame of the output multiframes hits the stop condition)
-            if paddle.any(paddle.argmax(stop_logits[0, -self.r:, :], axis=-1) == self.stop_prob_index):
+            if paddle.any(
+                    paddle.argmax(
+                        stop_logits[0, -self.r:, :], axis=-1) ==
+                    self.stop_prob_index):
                 if verbose:
                     print("Hits stop condition.")
                 break
@@ -516,24 +593,28 @@ class TransformerTTSLoss(nn.Layer):
     def __init__(self, stop_loss_scale):
         super(TransformerTTSLoss, self).__init__()
         self.stop_loss_scale = stop_loss_scale
-    
-    def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs):
-        mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype)
+
+    def forward(self, mel_output, mel_intermediate, mel_target, stop_logits,
+                stop_probs):
+        mask = masking.feature_mask(
+            mel_target, axis=-1, dtype=mel_target.dtype)
         mask1 = paddle.unsqueeze(mask, -1)
         mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
         mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
-        
+
         mel_len = mask.shape[-1]
-        last_position = F.one_hot(mask.sum(-1).astype("int64") - 1, num_classes=mel_len)
-        mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(mask.dtype)
+        last_position = F.one_hot(
+            mask.sum(-1).astype("int64") - 1, num_classes=mel_len)
+        mask2 = mask + last_position.scale(self.stop_loss_scale - 1).astype(
+            mask.dtype)
         stop_loss = L.masked_softmax_with_cross_entropy(
             stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
-        
-        loss = mel_loss1 + mel_loss2 + stop_loss 
+
+        loss = mel_loss1 + mel_loss2 + stop_loss
         losses = dict(
-            loss=loss, # total loss
-            mel_loss1=mel_loss1, # ouput mel loss
-            mel_loss2=mel_loss2, # intermediate mel loss
+            loss=loss,  # total loss
+            mel_loss1=mel_loss1,  # ouput mel loss
+            mel_loss2=mel_loss2,  # intermediate mel loss
             stop_loss=stop_loss  # stop prob loss
         )
         return losses
@@ -542,26 +623,29 @@ class TransformerTTSLoss(nn.Layer):
 class AdaptiveTransformerTTSLoss(nn.Layer):
     def __init__(self):
         super(AdaptiveTransformerTTSLoss, self).__init__()
-    
-    def forward(self, mel_output, mel_intermediate, mel_target, stop_logits, stop_probs):
-        mask = masking.feature_mask(mel_target, axis=-1, dtype=mel_target.dtype)
+
+    def forward(self, mel_output, mel_intermediate, mel_target, stop_logits,
+                stop_probs):
+        mask = masking.feature_mask(
+            mel_target, axis=-1, dtype=mel_target.dtype)
         mask1 = paddle.unsqueeze(mask, -1)
         mel_loss1 = L.masked_l1_loss(mel_output, mel_target, mask1)
         mel_loss2 = L.masked_l1_loss(mel_intermediate, mel_target, mask1)
-        
+
         batch_size, mel_len = mask.shape
         valid_lengths = mask.sum(-1).astype("int64")
         last_position = F.one_hot(valid_lengths - 1, num_classes=mel_len)
         stop_loss_scale = valid_lengths.sum() / batch_size - 1
-        mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(mask.dtype)
+        mask2 = mask + last_position.scale(stop_loss_scale - 1).astype(
+            mask.dtype)
         stop_loss = L.masked_softmax_with_cross_entropy(
             stop_logits, stop_probs.unsqueeze(-1), mask2.unsqueeze(-1))
-        
-        loss = mel_loss1 + mel_loss2 + stop_loss 
+
+        loss = mel_loss1 + mel_loss2 + stop_loss
         losses = dict(
-            loss=loss, # total loss
-            mel_loss1=mel_loss1, # ouput mel loss
-            mel_loss2=mel_loss2, # intermediate mel loss
+            loss=loss,  # total loss
+            mel_loss1=mel_loss1,  # ouput mel loss
+            mel_loss2=mel_loss2,  # intermediate mel loss
             stop_loss=stop_loss  # stop prob loss
         )
-        return losses
\ No newline at end of file
+        return losses
diff --git a/parakeet/modules/attention.py b/parakeet/modules/attention.py
index d7053b4..838f886 100644
--- a/parakeet/modules/attention.py
+++ b/parakeet/modules/attention.py
@@ -1,10 +1,30 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import math
 import numpy as np
 import paddle
 from paddle import nn
 from paddle.nn import functional as F
 
-def scaled_dot_product_attention(q, k, v, mask=None, dropout=0.0, training=True):
+
+def scaled_dot_product_attention(q,
+                                 k,
+                                 v,
+                                 mask=None,
+                                 dropout=0.0,
+                                 training=True):
     """
     scaled dot product attention with mask. Assume q, k, v all have the same
     leader dimensions(denoted as * in descriptions below). Dropout is applied to
@@ -22,18 +42,19 @@ def scaled_dot_product_attention(q, k, v, mask=None, dropout=0.0, training=True)
         out (Tensor): shape(*, T_q, d_v), the context vector.
         attn_weights (Tensor): shape(*, T_q, T_k), the attention weights.
     """
-    d = q.shape[-1] # we only support imperative execution
+    d = q.shape[-1]  # we only support imperative execution
     qk = paddle.matmul(q, k, transpose_y=True)
     scaled_logit = paddle.scale(qk, 1.0 / math.sqrt(d))
-    
+
     if mask is not None:
-        scaled_logit += paddle.scale((1.0 - mask), -1e9) # hard coded here
-    
+        scaled_logit += paddle.scale((1.0 - mask), -1e9)  # hard coded here
+
     attn_weights = F.softmax(scaled_logit, axis=-1)
     attn_weights = F.dropout(attn_weights, dropout, training=training)
     out = paddle.matmul(attn_weights, v)
     return out, attn_weights
 
+
 def drop_head(x, drop_n_heads, training):
     """
     Drop n heads from multiple context vectors.
@@ -48,12 +69,12 @@ def drop_head(x, drop_n_heads, training):
     """
     if not training or (drop_n_heads == 0):
         return x
-    
+
     batch_size, num_heads, _, _ = x.shape
     # drop all heads
     if num_heads == drop_n_heads:
         return paddle.zeros_like(x)
-    
+
     mask = np.ones([batch_size, num_heads])
     mask[:, :drop_n_heads] = 0
     for subarray in mask:
@@ -63,18 +84,21 @@ def drop_head(x, drop_n_heads, training):
     out = x * paddle.to_tensor(mask)
     return out
 
+
 def _split_heads(x, num_heads):
     batch_size, time_steps, _ = x.shape
     x = paddle.reshape(x, [batch_size, time_steps, num_heads, -1])
     x = paddle.transpose(x, [0, 2, 1, 3])
     return x
 
+
 def _concat_heads(x):
     batch_size, _, time_steps, _ = x.shape
     x = paddle.transpose(x, [0, 2, 1, 3])
     x = paddle.reshape(x, [batch_size, time_steps, -1])
     return x
 
+
 # Standard implementations of Monohead Attention & Multihead Attention
 class MonoheadAttention(nn.Layer):
     def __init__(self, model_dim, dropout=0.0, k_dim=None, v_dim=None):
@@ -99,10 +123,10 @@ class MonoheadAttention(nn.Layer):
         self.affine_k = nn.Linear(model_dim, k_dim)
         self.affine_v = nn.Linear(model_dim, v_dim)
         self.affine_o = nn.Linear(v_dim, model_dim)
-        
+
         self.model_dim = model_dim
         self.dropout = dropout
-    
+
     def forward(self, q, k, v, mask):
         """
         Compute context vector and attention weights.
@@ -119,22 +143,28 @@ class MonoheadAttention(nn.Layer):
            out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
            attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
         """
-        q = self.affine_q(q) # (B, T, C)
+        q = self.affine_q(q)  # (B, T, C)
         k = self.affine_k(k)
         v = self.affine_v(v)
-        
+
         context_vectors, attention_weights = scaled_dot_product_attention(
             q, k, v, mask, self.dropout, self.training)
-        
+
         out = self.affine_o(context_vectors)
         return out, attention_weights
-    
+
 
 class MultiheadAttention(nn.Layer):
     """
     Multihead scaled dot product attention.
     """
-    def __init__(self, model_dim, num_heads, dropout=0.0, k_dim=None, v_dim=None):
+
+    def __init__(self,
+                 model_dim,
+                 num_heads,
+                 dropout=0.0,
+                 k_dim=None,
+                 v_dim=None):
         """
         Multihead Attention module.
@@ -154,7 +184,7 @@ class MultiheadAttention(nn.Layer):
             ValueError: if model_dim is not divisible by num_heads
         """
         super(MultiheadAttention, self).__init__()
-        if model_dim % num_heads !=0:
+        if model_dim % num_heads != 0:
             raise ValueError("model_dim must be divisible by num_heads")
         depth = model_dim // num_heads
         k_dim = k_dim or depth
@@ -163,11 +193,11 @@ class MultiheadAttention(nn.Layer):
         self.affine_k = nn.Linear(model_dim, num_heads * k_dim)
         self.affine_v = nn.Linear(model_dim, num_heads * v_dim)
         self.affine_o = nn.Linear(num_heads * v_dim, model_dim)
-        
+
         self.num_heads = num_heads
         self.model_dim = model_dim
         self.dropout = dropout
-    
+
     def forward(self, q, k, v, mask):
         """
         Compute context vector and attention weights.
@@ -184,14 +214,67 @@ class MultiheadAttention(nn.Layer):
            out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
            attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
         """
-        q = _split_heads(self.affine_q(q), self.num_heads) # (B, h, T, C)
+        q = _split_heads(self.affine_q(q), self.num_heads)  # (B, h, T, C)
         k = _split_heads(self.affine_k(k), self.num_heads)
         v = _split_heads(self.affine_v(v), self.num_heads)
-        mask = paddle.unsqueeze(mask, 1) # unsqueeze for the h dim
-        
+        mask = paddle.unsqueeze(mask, 1)  # unsqueeze for the h dim
+
         context_vectors, attention_weights = scaled_dot_product_attention(
             q, k, v, mask, self.dropout, self.training)
         # NOTE: there is more sophisticated implementation: Scheduled DropHead
-        context_vectors = _concat_heads(context_vectors) # (B, T, h*C)
+        context_vectors = _concat_heads(context_vectors)  # (B, T, h*C)
         out = self.affine_o(context_vectors)
         return out, attention_weights
+
+
+class LocationSensitiveAttention(nn.Layer):
+    def __init__(self,
+                 d_query: int,
+                 d_key: int,
+                 d_attention: int,
+                 location_filters: int,
+                 location_kernel_size: int):
+        super().__init__()
+
+        self.query_layer = nn.Linear(d_query, d_attention, bias_attr=False)
+        self.key_layer = nn.Linear(d_key, d_attention, bias_attr=False)
+        self.value = nn.Linear(d_attention, 1, bias_attr=False)
+
+        #Location Layer
+        self.location_conv = nn.Conv1D(
+            2,
+            location_filters,
+            location_kernel_size,
+            1,
+            int((location_kernel_size - 1) / 2),
+            1,
+            bias_attr=False,
+            data_format='NLC')
+        self.location_layer = nn.Linear(
+            location_filters, d_attention, bias_attr=False)
+
+    def forward(self,
+                query,
+                processed_key,
+                value,
+                attention_weights_cat,
+                mask=None):
+
+        processed_query = self.query_layer(paddle.unsqueeze(query, axis=[1]))
+        processed_attention_weights = self.location_layer(
+            self.location_conv(attention_weights_cat))
+        alignment = self.value(
+            paddle.tanh(processed_attention_weights + processed_key +
+                        processed_query))
+
+        if mask is not None:
+            alignment = alignment + (1.0 - mask) * -1e9
+
+        attention_weights = F.softmax(alignment, axis=1)
+        attention_context = paddle.matmul(
+            attention_weights, value, transpose_x=True)
+
+        attention_weights = paddle.squeeze(attention_weights, axis=[-1])
+        attention_context = paddle.squeeze(attention_context, axis=[1])
+
+        return attention_context, attention_weights
diff --git a/parakeet/modules/conv.py b/parakeet/modules/conv.py
index 698cda2..35f17b8 100644
--- a/parakeet/modules/conv.py
+++ b/parakeet/modules/conv.py
@@ -1,6 +1,21 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import paddle
 from paddle import nn
 
+
 class Conv1dCell(nn.Conv1D):
     """
     A subclass of Conv1d layer, which can be used like an RNN cell. It can take
@@ -14,43 +29,46 @@ class Conv1dCell(nn.Conv1D):
     As a result, these arguments are removed form the initializer.
     """
-    def __init__(self, 
+
+    def __init__(self,
                  in_channels,
                  out_channels,
                  kernel_size,
                  dilation=1,
                  weight_attr=None,
                  bias_attr=None):
-        _dilation = dilation[0] if isinstance(dilation, (tuple, list)) else dilation
-        _kernel_size = kernel_size[0] if isinstance(kernel_size, (tuple, list)) else kernel_size
+        _dilation = dilation[0] if isinstance(dilation,
+                                              (tuple, list)) else dilation
+        _kernel_size = kernel_size[0] if isinstance(kernel_size, (
+            tuple, list)) else kernel_size
         self._r = 1 + (_kernel_size - 1) * _dilation
         super(Conv1dCell, self).__init__(
-            in_channels, 
-            out_channels, 
-            kernel_size, 
-            padding=(self._r - 1, 0),
-            dilation=dilation, 
-            weight_attr=weight_attr, 
-            bias_attr=bias_attr, 
+            in_channels,
+            out_channels,
+            kernel_size,
+            padding=(self._r - 1, 0),
+            dilation=dilation,
+            weight_attr=weight_attr,
+            bias_attr=bias_attr,
             data_format="NCL")
 
     @property
     def receptive_field(self):
         return self._r
-    
+
     def start_sequence(self):
         if self.training:
             raise Exception("only use start_sequence in evaluation")
         self._buffer = None
-        self._reshaped_weight = paddle.reshape(
-            self.weight, (self._out_channels, -1))
-    
+        self._reshaped_weight = paddle.reshape(self.weight,
+                                               (self._out_channels, -1))
+
     def initialize_buffer(self, x_t):
         batch_size, _ = x_t.shape
         self._buffer = paddle.zeros(
-            (batch_size, self._in_channels, self.receptive_field), 
+            (batch_size, self._in_channels, self.receptive_field),
             dtype=x_t.dtype)
-    
+
     def update_buffer(self, x_t):
         self._buffer = paddle.concat(
             [self._buffer[:, :, 1:], paddle.unsqueeze(x_t, -1)], -1)
@@ -66,7 +84,7 @@ class Conv1dCell(nn.Conv1D):
         if self.receptive_field > 1:
             if self._buffer is None:
                 self.initialize_buffer(x_t)
-            
+
             # update buffer
             self.update_buffer(x_t)
             if self._dilation[0] > 1:
@@ -82,20 +100,34 @@ class Conv1dCell(nn.Conv1D):
 
 
 class Conv1dBatchNorm(nn.Layer):
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 
-                 weight_attr=None, bias_attr=None, data_format="NCL"):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 padding=0,
+                 weight_attr=None,
+                 bias_attr=None,
+                 data_format="NCL",
+                 momentum=0.9,
+                 epsilon=1e-05):
         super(Conv1dBatchNorm, self).__init__()
-        # TODO(chenfeiyu): carefully initialize Conv1d's weight
-        self.conv = nn.Conv1D(in_channels, out_channels, kernel_size, stride,
-                              padding=padding,
-                              weight_attr=weight_attr,
-                              bias_attr=bias_attr,
-                              data_format=data_format)
-        # TODO: channel last, but BatchNorm1d does not support channel last layout
-        self.bn = nn.BatchNorm1D(out_channels, momentum=0.99, epsilon=1e-3, data_format=data_format)
+        self.conv = nn.Conv1D(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding=padding,
+            weight_attr=weight_attr,
+            bias_attr=bias_attr,
+            data_format=data_format)
+        self.bn = nn.BatchNorm1D(
+            out_channels,
+            momentum=momentum,
+            epsilon=epsilon,
+            data_format=data_format)
 
     def forward(self, x):
         x = self.conv(x)
         x = self.bn(x)
         return x
-