# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numba
import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F

__all__ = [
    "weighted_mean",
    "masked_l1_loss",
    "masked_softmax_with_cross_entropy",
    "diagonal_loss",
]


def weighted_mean(input, weight):
    """Weighted mean. It can also be used as masked mean.

    Parameters
    -----------
    input : Tensor 
        The input tensor.
    weight : Tensor
        The weight tensor. Its shape should be broadcastable to that of ``input``.

    Returns
    ----------
    Tensor [shape=(1,)]
        Weighted mean tensor with the same dtype as input.
        
    Warnings
    ---------
    This is not a mathematical weighted mean: the weighted sum is divided by 
    the total number of elements rather than by the sum of the weights.
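
    Examples
    --------
    A minimal sketch of the averaging behavior (values are illustrative):

    >>> import paddle
    >>> x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0])
    >>> mask = paddle.to_tensor([1.0, 1.0, 0.0, 0.0])
    >>> loss = weighted_mean(x, mask)  # (1 + 2) / 4 = 0.75, not (1 + 2) / 2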
    """
    weight = paddle.cast(weight, input.dtype)
    return paddle.mean(input * weight)


def masked_l1_loss(prediction, target, mask):
    """Compute maksed L1 loss.

    Parameters
    ----------
    prediction : Tensor
        The prediction.
        
    target : Tensor
        The target. The shape should be broadcastable to ``prediction``.
        
    mask : Tensor
        The mask. The shape should be broadcastable to the broadcasted shape 
        of ``prediction`` and ``target``.

    Returns
    -------
    Tensor [shape=(1,)]
        The masked L1 loss.
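
    Examples
    --------
    A minimal sketch (values are illustrative); the masked error is averaged 
    over all elements, as in ``weighted_mean``:

    >>> import paddle
    >>> pred = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
    >>> target = paddle.to_tensor([[1.0, 1.0], [1.0, 1.0]])
    >>> mask = paddle.to_tensor([[1.0, 1.0], [1.0, 0.0]])
    >>> loss = masked_l1_loss(pred, target, mask)  # (0 + 1 + 2 + 0) / 4 = 0.75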
    """
    abs_error = F.l1_loss(prediction, target, reduction='none')
    loss = weighted_mean(abs_error, mask)
    return loss


def masked_softmax_with_cross_entropy(logits, label, mask, axis=-1):
    """Compute masked softmax with cross entropy loss.

    Parameters
    ----------
    logits : Tensor
        The logits. The ``axis``-th axis is the class dimension.
        
    label : Tensor [dtype: int]
        The label. The size of the ``axis``-th axis should be 1.
        
    mask : Tensor 
        The mask. The shape should be broadcastable to ``label``.
        
    axis : int, optional
        The index of the class dimension in the shape of ``logits``, by default
        -1.

    Returns
    -------
    Tensor [shape=(1,)]
        The masked softmax with cross entropy loss.
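
    Examples
    --------
    A minimal sketch with hypothetical shapes (4 sequences of length 10 over 
    6 classes, with all positions marked valid):

    >>> import paddle
    >>> logits = paddle.randn([4, 10, 6])
    >>> label = paddle.randint(0, 6, [4, 10, 1])
    >>> mask = paddle.ones([4, 10, 1])
    >>> loss = masked_softmax_with_cross_entropy(logits, label, mask)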
    """
    ce = F.softmax_with_cross_entropy(logits, label, axis=axis)
    loss = weighted_mean(ce, mask)
    return loss


def diagonal_loss(attentions,
                  input_lengths,
                  target_lengths,
                  g=0.2,
                  multihead=False):
    """A metric to evaluate how diagonal a attention distribution is.
    
    It is computed for batch attention distributions. For each attention 
    distribution, the valid decoder time steps and encoder time steps may
    differ.

    Parameters
    ----------
    attentions : Tensor [shape=(B, T_dec, T_enc) or (B, H, T_dec, T_enc)]
        The attention weights from an encoder-decoder structure.
        
    input_lengths : Tensor or ndarray [shape=(B,)]
        The valid length of each encoder output.
        
    target_lengths : Tensor or ndarray [shape=(B,)]
        The valid length of each decoder output.
        
    g : float, optional
        The width of the tolerance band around the diagonal; it is used as 
        the standard deviation of the Gaussian-shaped penalty, so smaller 
        values penalize off-diagonal attention more sharply. By default 0.2.
        
    multihead : bool, optional
        A flag indicating whether ``attentions`` is a multihead attention's
        attention distribution. 
        
        If ``True``, the shape of attention is ``(B, H, T_dec, T_enc)``, by 
        default False.

    Returns
    -------
    Tensor [shape=(1,)]
        The diagonal loss.
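
    Examples
    --------
    A minimal sketch with random attention weights and hypothetical lengths:

    >>> import numpy as np
    >>> import paddle
    >>> attentions = paddle.rand([2, 60, 20])      # (B, T_dec, T_enc)
    >>> input_lengths = np.array([20, 15])
    >>> target_lengths = np.array([60, 45])
    >>> loss = diagonal_loss(attentions, input_lengths, target_lengths)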
    """
    # The guided attention weights are built in numpy (via a numba kernel),
    # so convert length Tensors to numpy arrays first.
    if isinstance(input_lengths, paddle.Tensor):
        input_lengths = input_lengths.numpy()
    if isinstance(target_lengths, paddle.Tensor):
        target_lengths = target_lengths.numpy()
    W = guided_attentions(input_lengths, target_lengths, g)
    W_tensor = paddle.to_tensor(W)
    if not multihead:
        return paddle.mean(attentions * W_tensor)
    else:
        return paddle.mean(attentions * paddle.unsqueeze(W_tensor, 1))


@numba.jit(nopython=True)
def guided_attention(N, max_N, T, max_T, g):
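    """Compute a guided attention weight matrix for a single utterance.

    The weight is close to 0 near the diagonal (where ``n / N`` is close to 
    ``t / T``) and approaches 1 away from it; ``g`` controls how quickly the 
    penalty grows. Positions beyond the valid lengths ``T`` and ``N`` are 
    left as zeros.
    """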
    W = np.zeros((max_T, max_N), dtype=np.float32)
    for t in range(T):
        for n in range(N):
            W[t, n] = 1 - np.exp(-(n / N - t / T)**2 / (2 * g * g))
    # (T_dec, T_enc)
    return W


def guided_attentions(input_lengths, target_lengths, g=0.2):
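    """Compute a batch of guided attention weight matrices, padded with 
    zeros up to the maximum input and target lengths in the batch.

    The lengths below are arbitrary and only illustrate the output shape:

    >>> import numpy as np
    >>> W = guided_attentions(np.array([3, 2]), np.array([4, 3]))
    >>> W.shape
    (2, 4, 3)
    """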
    B = len(input_lengths)
    max_input_len = input_lengths.max()
    max_target_len = target_lengths.max()
    W = np.zeros((B, max_target_len, max_input_len), dtype=np.float32)
    for b in range(B):
        W[b] = guided_attention(input_lengths[b], max_input_len,
                                target_lengths[b], max_target_len, g)
    # (B, T_dec, T_enc)
    return W