32 lines
917 B
Python
32 lines
917 B
Python
import paddle
|
|
from paddle import nn
|
|
from paddle.nn import functional as F
|
|
|
|
def weighted_mean(input, weight):
    """Compute the weighted mean of ``input`` (also usable as a masked mean).

    ``input`` and ``weight`` are multiplied element-wise (with broadcasting)
    and the products are summed. That sum is divided by the total weight,
    counted after ``weight`` has been broadcast to the shape of the product.
    With a 0/1 ``weight`` this is the mean over positions where the mask is 1.

    Args:
        input (Tensor): input tensor, floating point dtype.
        weight (Tensor): weight (or mask) tensor whose shape is broadcastable
            with ``input``. It is cast to ``input.dtype`` before use.

    Returns:
        Tensor: single-element tensor holding the weighted mean.

    Note:
        If the weights sum to zero (e.g. an all-zero mask) the denominator
        is zero and the result is NaN or Inf. Callers must ensure the mask
        selects at least one element.
    """
    weight = paddle.cast(weight, input.dtype)
    weighted = input * weight
    # Broadcasting repeats each element of ``weight`` exactly
    # numel(weighted) / numel(weight) times, so the broadcast weight sums to
    # sum(weight) times that factor. Measuring against the product rather
    # than ``input`` keeps the result correct when ``input`` is the operand
    # being broadcast; when ``weight`` broadcasts to ``input`` the two agree.
    broadcast_factor = weighted.numel() / weight.numel()
    return paddle.sum(weighted) / (paddle.sum(weight) * broadcast_factor)
|
|
|
|
def masked_l1_loss(prediction, target, mask):
    """Masked (or weighted) mean absolute error.

    Args:
        prediction (Tensor): predicted values.
        target (Tensor): ground-truth values. Presumably the same shape as
            ``prediction``; confirm against callers.
        mask (Tensor): weight/mask tensor broadcastable with the element-wise
            absolute error.

    Returns:
        Tensor: the unreduced L1 error averaged by ``weighted_mean``, using
        ``mask`` as the weight.
    """
    # reduction='none' keeps the per-element error so the mask can weight it.
    return weighted_mean(
        F.l1_loss(prediction, target, reduction='none'), mask)
|
|
|
|
def masked_softmax_with_cross_entropy(logits, label, mask, axis=-1):
    """Masked (or weighted) mean of the softmax cross-entropy.

    Args:
        logits (Tensor): unnormalized scores; the softmax is taken along
            ``axis``.
        label (Tensor): target labels, passed unchanged to
            ``F.softmax_with_cross_entropy``.
        mask (Tensor): weight/mask tensor broadcastable with the per-position
            loss. NOTE(review): the loss returned by
            ``F.softmax_with_cross_entropy`` presumably keeps a size-1 dim at
            ``axis``. A mask lacking that dim may broadcast to an unintended
            shape; confirm with callers.
        axis (int): the class axis. Defaults to -1.

    Returns:
        Tensor: the per-position cross-entropy averaged by ``weighted_mean``,
        using ``mask`` as the weight.
    """
    return weighted_mean(
        F.softmax_with_cross_entropy(logits, label, axis=axis), mask)
|
|
|
|
|
|
|
|
|
|
|
|
|