# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as F

from parakeet.modules import customized as L


def norm(param, dim, power):
    powered = F.pow(param, power)
    in_dtype = powered.dtype
    if in_dtype == fluid.core.VarDesc.VarType.FP16:
        powered = F.cast(powered, "float32")
    powered_norm = F.reduce_sum(powered, dim=dim, keep_dim=False)
    norm_ = F.pow(powered_norm, 1. / power)
    if in_dtype == fluid.core.VarDesc.VarType.FP16:
        norm_ = F.cast(norm_, "float16")
    return norm_


def norm_except(param, dim, power):
    """Computes the norm over all dimensions except dim.
    It differs from pytorch implementation that it does not keep dim.
    This difference is related with the broadcast mechanism in paddle.
    Read elementeise_mul for more.
    """
    shape = param.shape
    ndim = len(shape)

    if dim is None:
        return norm(param, dim, power)
    elif dim == 0:
        param_matrix = F.reshape(param, (shape[0], -1))
        return norm(param_matrix, dim=1, power=power)
    elif dim == -1 or dim == ndim - 1:
        param_matrix = F.reshape(param, (-1, shape[-1]))
        return norm(param_matrix, dim=0, power=power)
    else:
        perm = list(range(ndim))
        perm[0] = dim
        perm[dim] = 0
        transposed_param = F.transpose(param, perm)
        return norm_except(transposed_param, dim=0, power=power)


def compute_l2_normalized_weight(v, g, dim):
    shape = v.shape
    ndim = len(shape)

    if dim is None:
        v_normalized = v / (F.sqrt(F.reduce_sum(F.square(v))) + 1e-12)
    elif dim == 0:
        param_matrix = F.reshape(v, (shape[0], -1))
        v_normalized = F.l2_normalize(param_matrix, axis=1)
        v_normalized = F.reshape(v_normalized, shape)
    elif dim == -1 or dim == ndim - 1:
        param_matrix = F.reshape(v, (-1, shape[-1]))
        v_normalized = F.l2_normalize(param_matrix, axis=0)
        v_normalized = F.reshape(v_normalized, shape)
    else:
        perm = list(range(ndim))
        perm[0] = dim
        perm[dim] = 0
        transposed_param = F.transpose(v, perm)
        transposed_shape = transposed_param.shape
        param_matrix = F.reshape(transposed_param,
                                 (transposed_param.shape[0], -1))
        v_normalized = F.l2_normalize(param_matrix, axis=1)
        v_normalized = F.reshape(v_normalized, transposed_shape)
        v_normalized = F.transpose(v_normalized, perm)
    weight = F.elementwise_mul(v_normalized, g, axis=dim)
    return weight


def compute_weight(v, g, dim, power):
    assert len(g.shape) == 1, "magnitude should be a vector"
    if power == 2:
        in_dtype = v.dtype
        if in_dtype == fluid.core.VarDesc.VarType.FP16:
            v = F.cast(v, "float32")
            g = F.cast(g, "float32")
        weight = compute_l2_normalized_weight(v, g, dim)
        if in_dtype == fluid.core.VarDesc.VarType.FP16:
            weight = F.cast(weight, "float16")
        return weight
    else:
        v_normalized = F.elementwise_div(
            v, (norm_except(v, dim, power) + 1e-12), axis=dim)
        weight = F.elementwise_mul(v_normalized, g, axis=dim)
        return weight


class WeightNormWrapper(dg.Layer):
    def __init__(self, layer, param_name="weight", dim=0, power=2):
        super(WeightNormWrapper, self).__init__()

        self.param_name = param_name
        self.dim = dim
        self.power = power
        self.layer = layer

        w_v = param_name + "_v"
        w_g = param_name + "_g"

        # we could also use numpy to compute this, after all, it is run only once
        # at initialization.
        original_weight = getattr(layer, param_name)
        self.add_parameter(
            w_v,
            self.create_parameter(
                shape=original_weight.shape, dtype=original_weight.dtype))
        with dg.no_grad():
            F.assign(original_weight, getattr(self, w_v))
        delattr(layer, param_name)
        temp = norm_except(getattr(self, w_v), self.dim, self.power)
        self.add_parameter(
            w_g, self.create_parameter(
                shape=temp.shape, dtype=temp.dtype))
        with dg.no_grad():
            F.assign(temp, getattr(self, w_g))

        # also set this when setting up
        setattr(self.layer, self.param_name,
                compute_weight(
                    getattr(self, w_v),
                    getattr(self, w_g), self.dim, self.power))

        self.weigth_norm_applied = True

    # hook to compute weight with v & g
    def hook(self):
        w_v = self.param_name + "_v"
        w_g = self.param_name + "_g"
        setattr(self.layer, self.param_name,
                compute_weight(
                    getattr(self, w_v),
                    getattr(self, w_g), self.dim, self.power))

    def remove_weight_norm(self):
        self.hook()
        self.weigth_norm_applied = False

    def forward(self, *args, **kwargs):
        if self.weigth_norm_applied == True:
            self.hook()
        return self.layer(*args, **kwargs)

    def __getattr__(self, key):
        """
        this is used for attr forwarding.
        """
        if key in self._parameters:
            return self._parameters[key]
        elif key in self._sub_layers:
            return self._sub_layers[key]
        elif key is "layer":
            return self._sub_layers["layer"]
        else:
            return getattr(
                object.__getattribute__(self, "_sub_layers")["layer"], key)


def Linear(input_dim,
           output_dim,
           param_attr=None,
           bias_attr=None,
           act=None,
           dtype="float32"):
    # a weight norm applied linear layer.
    lin = dg.Linear(input_dim, output_dim, param_attr, bias_attr, act, dtype)
    lin = WeightNormWrapper(lin, dim=1)
    return lin


def Conv1D(num_channels,
           num_filters,
           filter_size,
           stride=1,
           padding=0,
           dilation=1,
           groups=1,
           param_attr=None,
           bias_attr=None,
           use_cudnn=True,
           act=None,
           dtype='float32'):
    conv = L.Conv1D(num_channels, num_filters, filter_size, stride, padding,
                    dilation, groups, param_attr, bias_attr, use_cudnn, act,
                    dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv


def Conv1DTranspose(num_channels,
                    num_filters,
                    filter_size,
                    padding=0,
                    stride=1,
                    dilation=1,
                    groups=1,
                    param_attr=None,
                    bias_attr=None,
                    use_cudnn=True,
                    act=None,
                    dtype='float32'):
    conv = L.Conv1DTranspose(num_channels, num_filters, filter_size, padding,
                             stride, dilation, groups, param_attr, bias_attr,
                             use_cudnn, act, dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv


def Conv1DCell(num_channels,
               num_filters,
               filter_size,
               dilation=1,
               causal=False,
               groups=1,
               param_attr=None,
               bias_attr=None,
               use_cudnn=True,
               act=None,
               dtype='float32'):
    conv = L.Conv1DCell(num_channels, num_filters, filter_size, dilation,
                        causal, groups, param_attr, bias_attr, use_cudnn, act,
                        dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv


def Conv2D(num_channels,
           num_filters,
           filter_size,
           stride=1,
           padding=0,
           dilation=1,
           groups=1,
           param_attr=None,
           bias_attr=None,
           use_cudnn=True,
           act=None,
           dtype='float32'):
    # a conv2d layer with weight norm wrapper
    conv = dg.Conv2D(num_channels, num_filters, filter_size, stride, padding,
                     dilation, groups, param_attr, bias_attr, use_cudnn, act,
                     dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv


def Conv2DTranspose(num_channels,
                    num_filters,
                    filter_size,
                    output_size=None,
                    padding=0,
                    stride=1,
                    dilation=1,
                    groups=1,
                    param_attr=None,
                    bias_attr=None,
                    use_cudnn=True,
                    act=None,
                    dtype='float32'):
    # a conv2d transpose layer with weight norm wrapper.
    conv = dg.Conv2DTranspose(num_channels, num_filters, filter_size,
                              output_size, padding, stride, dilation, groups,
                              param_attr, bias_attr, use_cudnn, act, dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv