switch back to keras style sample weight

This commit is contained in:
chenfeiyu 2020-12-05 21:08:10 +08:00
parent d3761683e1
commit 0287f46532
1 changed files with 1 additions and 2 deletions

View File

@ -15,8 +15,7 @@ def weighted_mean(input, weight):
Tensor: shape(1,), weighted mean tensor with the same dtype as input.
"""
weight = paddle.cast(weight, input.dtype)
broadcast_factor = input.numel() / weight.numel()
return paddle.sum(input * weight) / (paddle.sum(weight) * broadcast_factor)
return paddle.mean(input * weight)
def masked_l1_loss(prediction, target, mask):
abs_error = F.l1_loss(prediction, target, reduction='none')