switch back to keras style sample weight
This commit is contained in:
parent
d3761683e1
commit
0287f46532
|
@ -15,8 +15,7 @@ def weighted_mean(input, weight):
|
|||
Tensor: shape(1,), weighted mean tensor with the same dtype as input.
|
||||
"""
|
||||
weight = paddle.cast(weight, input.dtype)
|
||||
broadcast_factor = input.numel() / weight.numel()
|
||||
return paddle.sum(input * weight) / (paddle.sum(weight) * broadcast_factor)
|
||||
return paddle.mean(input * weight)
|
||||
|
||||
def masked_l1_loss(prediction, target, mask):
|
||||
abs_error = F.l1_loss(prediction, target, reduction='none')
|
||||
|
|
Loading…
Reference in New Issue