ParakeetRebeccaRosario/parakeet/modules/weight_norm.py


# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddle import fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as F
from parakeet.modules import customized as L


def norm(param, dim, power):
    """Compute the p-norm of ``param`` along ``dim`` (over all elements if
    ``dim`` is None). fp16 inputs are upcast to fp32 for the reduction and
    cast back afterwards."""
    powered = F.pow(param, power)
    in_dtype = powered.dtype
    if in_dtype == fluid.core.VarDesc.VarType.FP16:
        powered = F.cast(powered, "float32")
    powered_norm = F.reduce_sum(powered, dim=dim, keep_dim=False)
    norm_ = F.pow(powered_norm, 1. / power)
    if in_dtype == fluid.core.VarDesc.VarType.FP16:
        norm_ = F.cast(norm_, "float16")
    return norm_
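
# Example (numpy sketch, illustrative only; `x` denotes a hypothetical numpy
# array mirroring the paddle variable): with power=2 and dim=1, norm(x, 1, 2)
# on a matrix of shape (m, n) returns a vector of length m whose i-th entry is
# np.sqrt((x[i] ** 2).sum()).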


def norm_except(param, dim, power):
    """Compute the norm over all dimensions except ``dim``.

    Unlike the pytorch implementation, the reduced dimensions are not kept.
    This difference is related to the broadcast mechanism in paddle; see
    elementwise_mul for details.
    """
    shape = param.shape
    ndim = len(shape)
    if dim is None:
        return norm(param, dim, power)
    elif dim == 0:
        param_matrix = F.reshape(param, (shape[0], np.prod(shape[1:])))
        return norm(param_matrix, dim=1, power=power)
    elif dim == -1 or dim == ndim - 1:
        param_matrix = F.reshape(param, (np.prod(shape[:-1]), shape[-1]))
        return norm(param_matrix, dim=0, power=power)
    else:
        perm = list(range(ndim))
        perm[0] = dim
        perm[dim] = 0
        transposed_param = F.transpose(param, perm)
        return norm_except(transposed_param, dim=0, power=power)
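
# Example (numpy sketch, illustrative only): for a conv-style weight `v` of
# shape (C_out, C_in, K) and dim=0, norm_except(v, 0, 2) returns a vector of
# length C_out equivalent to np.sqrt((v.reshape(C_out, -1) ** 2).sum(axis=1)).
# Unlike the pytorch counterpart, the reduced axes are not kept as size-1 dims.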


def compute_weight(v, g, dim, power):
    """Reconstruct the weight from its direction ``v`` and magnitude ``g``."""
    assert len(g.shape) == 1, "magnitude should be a vector"
    v_normalized = F.elementwise_div(
        v, (norm_except(v, dim, power) + 1e-12), axis=dim)
    weight = F.elementwise_mul(v_normalized, g, axis=dim)
    return weight
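
# The decomposition above is weight normalization (Salimans & Kingma, 2016):
# w = g * v / ||v||, with the norm taken over every axis except `dim`.
# A numpy sketch for a 3-d `v` and dim=0 (illustrative only):
#
#   v_mat = v.reshape(v.shape[0], -1)
#   scale = g / (np.sqrt((v_mat ** 2).sum(axis=1)) + 1e-12)  # shape (C_out,)
#   w = scale.reshape(-1, 1, 1) * v                          # broadcast over axis 0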


def assign_by_cast(i, o):
    """Copy the value of variable ``i`` into variable ``o`` through a cast op,
    so the two variables may have different dtypes."""
    fluid.default_main_program().current_block().append_op(
        type="cast",
        inputs={"X": i},
        outputs={"Out": o},
        attrs={"in_dtype": i.dtype,
               "out_dtype": o.dtype})


class WeightNormWrapper(dg.Layer):
    def __init__(self, layer, param_name="weight", dim=0, power=2):
        super(WeightNormWrapper, self).__init__()
        self.param_name = param_name
        self.dim = dim
        self.power = power
        self.layer = layer

        w_v = param_name + "_v"
        w_g = param_name + "_g"

        # We could also use numpy to compute this; after all, it is run only
        # once at initialization.
        original_weight = getattr(layer, param_name)
        self.add_parameter(
            w_v,
            self.create_parameter(
                shape=original_weight.shape, dtype=original_weight.dtype))
        assign_by_cast(original_weight, getattr(self, w_v))
        delattr(layer, param_name)

        temp = norm_except(getattr(self, w_v), self.dim, self.power)
        self.add_parameter(
            w_g, self.create_parameter(
                shape=temp.shape, dtype=temp.dtype))
        assign_by_cast(temp, getattr(self, w_g))

        # Also compute the weight once at setup time.
        setattr(self.layer, self.param_name,
                compute_weight(
                    getattr(self, w_v),
                    getattr(self, w_g), self.dim, self.power))

        self.weight_norm_applied = True

    # Hook to recompute the weight from v and g.
    def hook(self):
        w_v = self.param_name + "_v"
        w_g = self.param_name + "_g"
        setattr(self.layer, self.param_name,
                compute_weight(
                    getattr(self, w_v),
                    getattr(self, w_g), self.dim, self.power))

    def remove_weight_norm(self):
        self.hook()
        self.weight_norm_applied = False

    def forward(self, *args, **kwargs):
        if self.weight_norm_applied:
            self.hook()
        return self.layer(*args, **kwargs)

    def __getattr__(self, key):
        """Forward attribute access to the wrapped layer."""
        if key in self._parameters:
            return self._parameters[key]
        elif key in self._sub_layers:
            return self._sub_layers[key]
        elif key == "layer":
            return self._sub_layers["layer"]
        else:
            return getattr(
                object.__getattribute__(self, "_sub_layers")["layer"], key)
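
# Typical use (sketch, assuming paddle 1.x dygraph mode; `x` is a hypothetical
# input variable): wrap an existing layer and call it as usual. The weight is
# rebuilt from <param_name>_v and <param_name>_g on every forward pass until
# remove_weight_norm() is called.
#
#   layer = dg.Conv2D(3, 16, 3)
#   layer = WeightNormWrapper(layer, param_name="weight", dim=0)
#   y = layer(x)                  # uses the reparameterized weight
#   layer.remove_weight_norm()    # freeze the current weight; stop recomputing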


def Linear(input_dim,
           output_dim,
           param_attr=None,
           bias_attr=None,
           act=None,
           dtype="float32"):
    # A weight-normalized linear layer. dg.Linear stores its weight with shape
    # (input_dim, output_dim), so dim=1 keeps one magnitude per output feature.
    lin = dg.Linear(input_dim, output_dim, param_attr, bias_attr, act, dtype)
    lin = WeightNormWrapper(lin, dim=1)
    return lin


def Conv1D(num_channels,
           num_filters,
           filter_size,
           stride=1,
           padding=0,
           dilation=1,
           groups=None,
           param_attr=None,
           bias_attr=None,
           use_cudnn=True,
           act=None,
           dtype='float32'):
    conv = L.Conv1D(num_channels, num_filters, filter_size, stride, padding,
                    dilation, groups, param_attr, bias_attr, use_cudnn, act,
                    dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv


def Conv1DTranspose(num_channels,
                    num_filters,
                    filter_size,
                    padding=0,
                    stride=1,
                    dilation=1,
                    groups=None,
                    param_attr=None,
                    bias_attr=None,
                    use_cudnn=True,
                    act=None,
                    dtype='float32'):
    conv = L.Conv1DTranspose(num_channels, num_filters, filter_size, padding,
                             stride, dilation, groups, param_attr, bias_attr,
                             use_cudnn, act, dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv


def Conv1DCell(num_channels,
               num_filters,
               filter_size,
               dilation=1,
               causal=False,
               groups=None,
               param_attr=None,
               bias_attr=None,
               use_cudnn=True,
               act=None,
               dtype='float32'):
    conv = L.Conv1DCell(num_channels, num_filters, filter_size, dilation,
                        causal, groups, param_attr, bias_attr, use_cudnn, act,
                        dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv


def Conv2D(num_channels,
           num_filters,
           filter_size,
           stride=1,
           padding=0,
           dilation=1,
           groups=None,
           param_attr=None,
           bias_attr=None,
           use_cudnn=True,
           act=None,
           dtype='float32'):
    # A conv2d layer with a weight norm wrapper.
    conv = dg.Conv2D(num_channels, num_filters, filter_size, stride, padding,
                     dilation, groups, param_attr, bias_attr, use_cudnn, act,
                     dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv


def Conv2DTranspose(num_channels,
                    num_filters,
                    filter_size,
                    output_size=None,
                    padding=0,
                    stride=1,
                    dilation=1,
                    groups=None,
                    param_attr=None,
                    bias_attr=None,
                    use_cudnn=True,
                    act=None,
                    dtype='float32'):
    # A conv2d transpose layer with a weight norm wrapper.
    conv = dg.Conv2DTranspose(num_channels, num_filters, filter_size,
                              output_size, padding, stride, dilation, groups,
                              param_attr, bias_attr, use_cudnn, act, dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv
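

# Minimal end-to-end sketch (an assumption, not part of the original module:
# paddle 1.x fluid dygraph API on CPU; kept under a __main__ guard so importing
# the module stays side-effect free). It exercises the weight-normalized Linear
# factory defined above.
if __name__ == "__main__":
    with dg.guard(fluid.CPUPlace()):
        lin = Linear(8, 4)
        x = dg.to_variable(np.random.randn(2, 8).astype("float32"))
        y = lin(x)                 # weight rebuilt from weight_v and weight_g
        lin.remove_weight_norm()   # keep the current weight as a plain value
        y2 = lin(x)                # runs with the frozen weight
        print(y.shape, y2.shape)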