siwtch to keras style sample_weight in losses

This commit is contained in:
chenfeiyu 2020-12-03 15:37:43 +08:00
parent 6edc7d8474
commit 810f979dba
2 changed files with 7 additions and 8 deletions

View File

@ -68,7 +68,6 @@ class MultiheadAttention(nn.Layer):
broadcastable shape, dtype: float32 or float64, the mask.
Returns:
(out, attention_weights)
out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
"""
@ -134,7 +133,6 @@ class TransformerEncoderLayer(nn.Layer):
mask (Tensor): shape(batch_size, 1, time_steps), the padding mask.
Returns:
(x, attn_weights)
x (Tensor): shape(batch_size, time_steps, d_model), the decoded.
attn_weights (Tensor), shape(batch_size, n_heads, time_steps, time_steps), self attention.
"""
@ -202,7 +200,6 @@ class TransformerDecoderLayer(nn.Layer):
decoder_mask (Tensor): shape(batch_size, time_steps_q, time_steps_q) or broadcastable shape, decoder padding mask.
Returns:
(q, self_attn_weights, cross_attn_weights)
q (Tensor): shape(batch_size, time_steps_q, d_model), the decoded.
self_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_q), decoder self attention.
cross_attn_weights (Tensor), shape(batch_size, n_heads, time_steps_q, time_steps_k), decoder-encoder cross attention.
@ -228,7 +225,7 @@ class TransformerEncoder(nn.LayerList):
Returns:
x (Tensor): shape(batch_size, time_steps, feature_size), the context vector.
attention_weights(list), list of tensors, each of shape
attention_weights(list[Tensor]), each of shape
(batch_size, n_heads, time_steps, time_steps), the attention weights.
"""
attention_weights = []
@ -256,7 +253,9 @@ class TransformerDecoder(nn.LayerList):
drop_n_heads (int, optional): [description]. Defaults to 0.
Returns:
[type]: [description]
q (Tensor): shape(batch_size, time_steps_q, d_model), the output.
self_attention_weights (List[Tensor]): shape (batch_size, num_heads, encoder_steps, encoder_steps)
cross_attention_weights (List[Tensor]): shape (batch_size, num_heads, decoder_steps, encoder_steps)
"""
self_attention_weights = []
cross_attention_weights = []
@ -268,6 +267,7 @@ class TransformerDecoder(nn.LayerList):
class MLPPreNet(nn.Layer):
"""Decoder's prenet."""
def __init__(self, d_input, d_hidden, d_output, dropout):
# (lin + relu + dropout) * n + last projection
super(MLPPreNet, self).__init__()
@ -492,7 +492,7 @@ class TransformerTTS(nn.Layer):
# extract last r steps and append it to decoder output
decoder_output = paddle.concat([decoder_output, mel_output[:, -self.r:, :]], 1)
# stop condition?
# stop condition: (if any ouput frame of the output multiframes hits the stop condition)
if paddle.any(paddle.argmax(stop_logits[0, :, :], axis=-1) == self.stop_prob_index):
if verbose:
print("Hits stop condition.")

View File

@ -13,8 +13,7 @@ def weighted_mean(input, weight):
Tensor: shape(1,), weighted mean tensor with the same dtype as input.
"""
weight = paddle.cast(weight, input.dtype)
broadcast_factor = input.numel() / weight.numel()
return paddle.sum(input * weight) / (paddle.sum(weight) * broadcast_factor)
return paddle.mean(input * weight)
def masked_l1_loss(prediction, target, mask):
abs_error = F.l1_loss(prediction, target, reduction='none')