import paddle
from paddle import nn
from paddle.nn import functional as F


def weighted_mean(input, weight):
    """Compute the weighted mean of ``input`` (usable as a masked mean).

    The result is ``sum(input * weight) / sum(weight)``, where ``weight`` is
    first broadcast against ``input``. Elements whose weight is zero therefore
    contribute to neither the numerator nor the denominator, so with a 0/1
    mask the result is the mean over the unmasked elements only.

    Args:
        input (Tensor): input tensor, floating point dtype.
        weight (Tensor): weight tensor with broadcastable shape. It is cast
            to the dtype of ``input`` before use.

    Returns:
        Tensor: single-element tensor holding the weighted mean, with the
        same dtype as ``input``.

    Note:
        If the weights sum to zero (e.g. an all-zero mask) the division is
        0 / 0 and the result is not finite.
    """
    weight = paddle.cast(weight, input.dtype)
    weighted_sum = paddle.sum(input * weight)
    # Multiplying by ones_like(input) expands the weights to the same
    # broadcast shape as ``input * weight``, so every position that was
    # summed in the numerator is also counted once in the denominator.
    total_weight = paddle.sum(weight * paddle.ones_like(input))
    # Normalize by the total weight rather than by the number of elements;
    # dividing by the element count would let masked-out (zero-weight)
    # positions shrink the result.
    return weighted_sum / total_weight


def masked_l1_loss(prediction, target, mask):
    """Compute the L1 loss averaged with ``mask`` as per-element weights.

    Args:
        prediction (Tensor): predicted values.
        target (Tensor): reference values, passed to ``F.l1_loss`` together
            with ``prediction``.
        mask (Tensor): weights (typically a 0/1 mask) broadcastable against
            the element-wise error.

    Returns:
        Tensor: single-element tensor, the weighted mean of the absolute
        errors.
    """
    # reduction='none' keeps the per-element errors so the mask can be applied.
    abs_error = F.l1_loss(prediction, target, reduction='none')
    return weighted_mean(abs_error, mask)


def masked_softmax_with_cross_entropy(logits, label, mask, axis=-1):
    """Compute softmax cross entropy averaged with ``mask`` as weights.

    Args:
        logits (Tensor): unnormalized scores; softmax is taken along ``axis``.
        label (Tensor): target labels, passed to
            ``F.softmax_with_cross_entropy``.
        mask (Tensor): weights (typically a 0/1 mask) broadcastable against
            the per-position cross entropy.
        axis (int, optional): axis of ``logits`` holding the class scores.
            Defaults to -1.

    Returns:
        Tensor: single-element tensor, the weighted mean of the cross entropy.
    """
    ce = F.softmax_with_cross_entropy(logits, label, axis=axis)
    # NOTE(review): the cross entropy presumably keeps ``axis`` as a size-1
    # dimension, so ``mask`` should carry a matching trailing dimension to
    # avoid unintended broadcasting -- confirm against callers.
    return weighted_mean(ce, mask)