From 0287f465326176ed5ebf3918ea17e391c0e7e4f7 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Sat, 5 Dec 2020 21:08:10 +0800 Subject: [PATCH] switch back to keras style sample weight --- parakeet/modules/losses.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/parakeet/modules/losses.py b/parakeet/modules/losses.py index 84afed1..b8bc945 100644 --- a/parakeet/modules/losses.py +++ b/parakeet/modules/losses.py @@ -15,8 +15,7 @@ def weighted_mean(input, weight): Tensor: shape(1,), weighted mean tensor with the same dtype as input. """ weight = paddle.cast(weight, input.dtype) - broadcast_factor = input.numel() / weight.numel() - return paddle.sum(input * weight) / (paddle.sum(weight) * broadcast_factor) + return paddle.mean(input * weight) def masked_l1_loss(prediction, target, mask): abs_error = F.l1_loss(prediction, target, reduction='none')