diff --git a/parakeet/modules/losses.py b/parakeet/modules/losses.py index 84afed1..b8bc945 100644 --- a/parakeet/modules/losses.py +++ b/parakeet/modules/losses.py @@ -15,8 +15,7 @@ def weighted_mean(input, weight): Tensor: shape(1,), weighted mean tensor with the same dtype as input. """ weight = paddle.cast(weight, input.dtype) - broadcast_factor = input.numel() / weight.numel() - return paddle.sum(input * weight) / (paddle.sum(weight) * broadcast_factor) + return paddle.mean(input * weight) def masked_l1_loss(prediction, target, mask): abs_error = F.l1_loss(prediction, target, reduction='none')