From 810f979dbafb11505a8131a6995e53044feccae1 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Thu, 3 Dec 2020 15:37:43 +0800 Subject: [PATCH] siwtch to keras style sample_weight in losses --- parakeet/models/transformer_tts.py | 12 ++++++------ parakeet/modules/losses.py | 3 +-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/parakeet/models/transformer_tts.py b/parakeet/models/transformer_tts.py index 6f1b62f..f9b9046 100644 --- a/parakeet/models/transformer_tts.py +++ b/parakeet/models/transformer_tts.py @@ -68,7 +68,6 @@ class MultiheadAttention(nn.Layer): broadcastable shape, dtype: float32 or float64, the mask. Returns: - (out, attention_weights) out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector. attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights. """ @@ -134,7 +133,6 @@ class TransformerEncoderLayer(nn.Layer): mask (Tensor): shape(batch_size, 1, time_steps), the padding mask. Returns: - (x, attn_weights) x (Tensor): shape(batch_size, time_steps, d_model), the decoded. attn_weights (Tensor), shape(batch_size, n_heads, time_steps, time_steps), self attention. """ @@ -202,7 +200,6 @@ class TransformerDecoderLayer(nn.Layer): decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q) or broadcastable shape, decoder padding mask. Returns: - (q, self_attn_weights, cross_attn_weights) q (Tensor): shape(batch_size, time_steps_q, d_model), the decoded. self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention. cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention. @@ -228,7 +225,7 @@ class TransformerEncoder(nn.LayerList): Returns: x (Tensor): shape(batch_size, time_steps, feature_size), the context vector. - attention_weights(list), list of tensors, each of shape + attention_weights(list[Tensor]), each of shape (batch_size, n_heads, time_steps, time_steps), the attention weights. """ attention_weights = [] @@ -256,7 +253,9 @@ class TransformerDecoder(nn.LayerList): drop_n_heads (int, optional): [description]. Defaults to 0. Returns: - [type]: [description] + q (Tensor): shape(batch_size, time_steps_q, d_model), the output. + self_attention_weights (List[Tensor]): shape (batch_size, num_heads, encoder_steps, encoder_steps) + cross_attention_weights (List[Tensor]): shape (batch_size, num_heads, decoder_steps, encoder_steps) """ self_attention_weights = [] cross_attention_weights = [] @@ -268,6 +267,7 @@ class TransformerDecoder(nn.LayerList): class MLPPreNet(nn.Layer): + """Decoder's prenet.""" def __init__(self, d_input, d_hidden, d_output, dropout): # (lin + relu + dropout) * n + last projection super(MLPPreNet, self).__init__() @@ -492,7 +492,7 @@ class TransformerTTS(nn.Layer): # extract last r steps and append it to decoder output decoder_output = paddle.concat([decoder_output, mel_output[:, -self.r:, :]], 1) - # stop condition? + # stop condition: (if any ouput frame of the output multiframes hits the stop condition) if paddle.any(paddle.argmax(stop_logits[0, :, :], axis=-1) == self.stop_prob_index): if verbose: print("Hits stop condition.") diff --git a/parakeet/modules/losses.py b/parakeet/modules/losses.py index 9dd40f0..e7187a8 100644 --- a/parakeet/modules/losses.py +++ b/parakeet/modules/losses.py @@ -13,8 +13,7 @@ def weighted_mean(input, weight): Tensor: shape(1,), weighted mean tensor with the same dtype as input. """ weight = paddle.cast(weight, input.dtype) - broadcast_factor = input.numel() / weight.numel() - return paddle.sum(input * weight) / (paddle.sum(weight) * broadcast_factor) + return paddle.mean(input * weight) def masked_l1_loss(prediction, target, mask): abs_error = F.l1_loss(prediction, target, reduction='none')