# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numba
import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F

__all__ = [
    "weighted_mean",
    "masked_l1_loss",
    "masked_softmax_with_cross_entropy",
    "diagonal_loss",
]


def weighted_mean(input, weight):
    """Weighted mean. It can also be used as masked mean.

    Parameters
    ----------
    input : Tensor
        The input tensor.
    weight : Tensor
        The weight tensor, with a shape broadcastable to that of ``input``.

    Returns
    -------
    Tensor [shape=(1,)]
        Weighted mean tensor with the same dtype as input.

    Warnings
    --------
    This is not a mathematical weighted mean: it computes the weighted sum of
    ``input`` but divides by the total number of elements instead of the sum
    of the weights.
    """
    weight = paddle.cast(weight, input.dtype)
    return paddle.mean(input * weight)
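
# Illustrative usage of ``weighted_mean`` (example values, not part of the
# module). With a binary mask it acts as a masked mean, except the divisor is
# the total number of elements, not the number of unmasked ones:
#
#   x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
#   m = paddle.to_tensor([[1.0, 1.0], [0.0, 0.0]])
#   weighted_mean(x, m)  # (1 + 2) / 4 = 0.75, not (1 + 2) / 2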


def masked_l1_loss(prediction, target, mask):
    """Compute masked L1 loss.

    Parameters
    ----------
    prediction : Tensor
        The prediction.
    target : Tensor
        The target. The shape should be broadcastable to ``prediction``.
    mask : Tensor
        The mask. The shape should be broadcastable to the broadcasted shape
        of ``prediction`` and ``target``.

    Returns
    -------
    Tensor [shape=(1,)]
        The masked L1 loss.
    """
    abs_error = F.l1_loss(prediction, target, reduction='none')
    loss = weighted_mean(abs_error, mask)
    return loss
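
# Illustrative example (not part of the module): only positions where
# ``mask`` is 1 contribute to the summed error, but the divisor is the total
# element count (see ``weighted_mean`` above):
#
#   pred = paddle.to_tensor([[0.0, 1.0, 2.0]])
#   tgt = paddle.to_tensor([[1.0, 1.0, 1.0]])
#   mask = paddle.to_tensor([[1.0, 1.0, 0.0]])
#   masked_l1_loss(pred, tgt, mask)  # (|0-1| + |1-1|) / 3 ≈ 0.333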


def masked_softmax_with_cross_entropy(logits, label, mask, axis=-1):
    """Compute masked softmax with cross entropy loss.

    Parameters
    ----------
    logits : Tensor
        The logits. The ``axis``-th axis is the class dimension.
    label : Tensor [dtype: int]
        The label. The size of the ``axis``-th axis should be 1.
    mask : Tensor
        The mask. The shape should be broadcastable to ``label``.
    axis : int, optional
        The index of the class dimension in the shape of ``logits``, by
        default -1.

    Returns
    -------
    Tensor [shape=(1,)]
        The masked softmax with cross entropy loss.
    """
    ce = F.softmax_with_cross_entropy(logits, label, axis=axis)
    loss = weighted_mean(ce, mask)
    return loss
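
# Illustrative shapes (hypothetical, for per-step classification over C
# classes):
#
#   logits: (B, T, C) float, label: (B, T, 1) int64, mask: (B, T, 1)
#   loss = masked_softmax_with_cross_entropy(logits, label, mask)  # scalar
#
# ``F.softmax_with_cross_entropy`` returns a per-position cross entropy with
# the same shape as ``label``, which is then mask-averaged by
# ``weighted_mean``.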


def diagonal_loss(attentions,
                  input_lengths,
                  target_lengths,
                  g=0.2,
                  multihead=False):
    """A metric to evaluate how diagonal an attention distribution is.

    It is computed over a batch of attention distributions. For each
    attention distribution, the valid decoder time steps and encoder time
    steps may differ.

    Parameters
    ----------
    attentions : Tensor [shape=(B, T_dec, T_enc) or (B, H, T_dec, T_enc)]
        The attention weights from an encoder-decoder structure.
    input_lengths : Tensor [shape=(B,)]
        The valid length for each encoder output.
    target_lengths : Tensor [shape=(B,)]
        The valid length for each decoder output.
    g : float, optional
        Controls the width of the low-penalty band around the diagonal in the
        guided attention weight (see ``guided_attention`` below); larger
        values tolerate attention farther from the diagonal. By default 0.2.
    multihead : bool, optional
        A flag indicating whether ``attentions`` is a multihead attention's
        attention distribution.

        If ``True``, the shape of attention is ``(B, H, T_dec, T_enc)``, by
        default False.

    Returns
    -------
    Tensor [shape=(1,)]
        The diagonal loss.
    """
    W = guided_attentions(input_lengths, target_lengths, g)
    W_tensor = paddle.to_tensor(W)
    if not multihead:
        return paddle.mean(attentions * W_tensor)
    else:
        return paddle.mean(attentions * paddle.unsqueeze(W_tensor, 1))
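
# The weight ``W`` used in ``diagonal_loss`` is ~0 where n / N ≈ t / T (the
# diagonal) and grows toward 1 elsewhere (see ``guided_attention`` below), so
# averaging ``attentions * W`` penalizes attention mass that strays from a
# roughly diagonal alignment between decoder and encoder time steps.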


@numba.jit(nopython=True)
def guided_attention(N, max_N, T, max_T, g):
    W = np.zeros((max_T, max_N), dtype=np.float32)
    for t in range(T):
        for n in range(N):
            W[t, n] = 1 - np.exp(-(n / N - t / T)**2 / (2 * g * g))
    # (T_dec, T_enc)
    return W
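
# For intuition (rounded values): guided_attention(3, 3, 3, 3, 0.2) gives
#
#   [[0.00, 0.75, 1.00],
#    [0.75, 0.00, 0.75],
#    [1.00, 0.75, 0.00]]
#
# i.e. zero weight on the diagonal and weight approaching 1 away from it.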


def guided_attentions(input_lengths, target_lengths, g=0.2):
    B = len(input_lengths)
    max_input_len = input_lengths.max()
    max_target_len = target_lengths.max()
    W = np.zeros((B, max_target_len, max_input_len), dtype=np.float32)
    for b in range(B):
        W[b] = guided_attention(input_lengths[b], max_input_len,
                                target_lengths[b], max_target_len, g)
    # (B, T_dec, T_enc)
    return W
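
# Illustrative end-to-end usage (hypothetical shapes, not part of the
# module). Since ``guided_attention`` is numba-compiled in nopython mode, the
# lengths are passed here as numpy integer arrays rather than paddle Tensors:
#
#   attn = paddle.rand([4, 60, 20])          # (B, T_dec, T_enc)
#   enc_lens = np.array([20, 18, 15, 12])    # valid encoder steps per item
#   dec_lens = np.array([60, 55, 40, 30])    # valid decoder steps per item
#   loss = diagonal_loss(attn, enc_lens, dec_lens)  # scalar Tensor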