import numpy as np
from paddle import fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as F

from parakeet.modules import customized as L

# TODO: just use numpy to init weight norm wrappers


def norm(param, dim, power):
    """Computes the p-norm of `param`, reducing over `dim`."""
    powered = F.pow(param, power)
    powered_norm = F.reduce_sum(powered, dim=dim, keep_dim=False)
    norm_ = F.pow(powered_norm, 1. / power)
    return norm_


def norm_except(param, dim, power):
    """Computes the norm over all dimensions except `dim`.
    It differs from the pytorch implementation in that it does not keep `dim`.
    This difference is related to the broadcast mechanism in paddle.
    Read elementwise_mul for more.
    """
    shape = param.shape
    ndim = len(shape)

    if dim is None:
        return norm(param, dim, power)
    elif dim == 0:
        param_matrix = F.reshape(param, (shape[0], np.prod(shape[1:])))
        return norm(param_matrix, dim=1, power=power)
    elif dim == -1 or dim == ndim - 1:
        param_matrix = F.reshape(param, (np.prod(shape[:-1]), shape[-1]))
        return norm(param_matrix, dim=0, power=power)
    else:
        perm = list(range(ndim))
        perm[0] = dim
        perm[dim] = 0
        transposed_param = F.transpose(param, perm)
        return norm_except(transposed_param, dim=0, power=power)


def compute_weight(v, g, dim, power):
    """Recovers the weight as w = g * v / ||v||, where the norm of v is taken
    over every dimension except `dim`."""
    assert len(g.shape) == 1, "magnitude should be a vector"
    v_normalized = F.elementwise_div(
        v, (norm_except(v, dim, power) + 1e-12), axis=dim)
    weight = F.elementwise_mul(v_normalized, g, axis=dim)
    return weight


class WeightNormWrapper(dg.Layer):
    def __init__(self, layer, param_name="weight", dim=0, power=2):
        super(WeightNormWrapper, self).__init__()

        self.param_name = param_name
        self.dim = dim
        self.power = power
        self.layer = layer

        w_v = param_name + "_v"
        w_g = param_name + "_g"

        # we could also use numpy to compute this, after all, it is run only
        # once at initialization.
        original_weight = getattr(layer, param_name)
        self.add_parameter(
            w_v,
            self.create_parameter(
                shape=original_weight.shape, dtype=original_weight.dtype))
        F.assign(original_weight, getattr(self, w_v))
        delattr(layer, param_name)
        temp = norm_except(getattr(self, w_v), self.dim, self.power)
        self.add_parameter(
            w_g, self.create_parameter(
                shape=temp.shape, dtype=temp.dtype))
        F.assign(temp, getattr(self, w_g))

        # also set the recomputed weight on the wrapped layer at construction
        setattr(
            self.layer, self.param_name,
            compute_weight(
                getattr(self, w_v), getattr(self, w_g), self.dim, self.power))

        self.weight_norm_applied = True

    # hook to compute weight with v & g
    def hook(self):
        w_v = self.param_name + "_v"
        w_g = self.param_name + "_g"
        setattr(
            self.layer, self.param_name,
            compute_weight(
                getattr(self, w_v), getattr(self, w_g), self.dim, self.power))

    def remove_weight_norm(self):
        self.hook()
        self.weight_norm_applied = False

    def forward(self, *args, **kwargs):
        if self.weight_norm_applied:
            self.hook()
        return self.layer(*args, **kwargs)

    def __getattr__(self, key):
        """This is used for attr forwarding to the wrapped layer."""
        if key in self._parameters:
            return self._parameters[key]
        elif key in self._sub_layers:
            return self._sub_layers[key]
        elif key == "layer":
            return self._sub_layers["layer"]
        else:
            return getattr(
                object.__getattribute__(self, "_sub_layers")["layer"], key)
def Linear(input_dim,
           output_dim,
           param_attr=None,
           bias_attr=None,
           act=None,
           dtype="float32"):
    # a weight norm applied linear layer.
    lin = dg.Linear(input_dim, output_dim, param_attr, bias_attr, act, dtype)
    lin = WeightNormWrapper(lin, dim=1)
    return lin


def Conv1D(num_channels,
           num_filters,
           filter_size,
           stride=1,
           padding=0,
           dilation=1,
           groups=None,
           param_attr=None,
           bias_attr=None,
           use_cudnn=True,
           act=None,
           dtype='float32'):
    # a conv1d layer with weight norm wrapper
    conv = L.Conv1D(num_channels, num_filters, filter_size, stride, padding,
                    dilation, groups, param_attr, bias_attr, use_cudnn, act,
                    dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv


def Conv1DTranspose(num_channels,
                    num_filters,
                    filter_size,
                    padding=0,
                    stride=1,
                    dilation=1,
                    groups=None,
                    param_attr=None,
                    bias_attr=None,
                    use_cudnn=True,
                    act=None,
                    dtype='float32'):
    # a conv1d transpose layer with weight norm wrapper
    conv = L.Conv1DTranspose(num_channels, num_filters, filter_size, padding,
                             stride, dilation, groups, param_attr, bias_attr,
                             use_cudnn, act, dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv


def Conv1DCell(num_channels,
               num_filters,
               filter_size,
               dilation=1,
               causal=False,
               groups=None,
               param_attr=None,
               bias_attr=None,
               use_cudnn=True,
               act=None,
               dtype='float32'):
    # a conv1d cell (optionally causal) layer with weight norm wrapper
    conv = L.Conv1DCell(num_channels, num_filters, filter_size, dilation,
                        causal, groups, param_attr, bias_attr, use_cudnn, act,
                        dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv


def Conv2D(num_channels,
           num_filters,
           filter_size,
           stride=1,
           padding=0,
           dilation=1,
           groups=None,
           param_attr=None,
           bias_attr=None,
           use_cudnn=True,
           act=None,
           dtype='float32'):
    # a conv2d layer with weight norm wrapper
    conv = dg.Conv2D(num_channels, num_filters, filter_size, stride, padding,
                     dilation, groups, param_attr, bias_attr, use_cudnn, act,
                     dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv


def Conv2DTranspose(num_channels,
                    num_filters,
                    filter_size,
                    output_size=None,
                    padding=0,
                    stride=1,
                    dilation=1,
                    groups=None,
                    param_attr=None,
                    bias_attr=None,
                    use_cudnn=True,
                    act=None,
                    dtype='float32'):
    # a conv2d transpose layer with weight norm wrapper.
    conv = dg.Conv2DTranspose(num_channels, num_filters, filter_size,
                              output_size, padding, stride, dilation, groups,
                              param_attr, bias_attr, use_cudnn, act, dtype)
    conv = WeightNormWrapper(conv, dim=0)
    return conv
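

# A minimal usage sketch, not part of the original module: it builds a
# weight-norm-wrapped linear layer via the factory above and runs a forward
# pass in dygraph mode. The shapes, batch size, and the fluid.dygraph.guard()
# entry point are illustrative assumptions about the Paddle 1.x dygraph API
# this file targets.
if __name__ == "__main__":
    with dg.guard():
        # the weight is re-derived from (weight_v, weight_g) on every forward
        layer = Linear(8, 16)
        x = dg.to_variable(np.random.randn(4, 8).astype("float32"))
        y = layer(x)
        print(y.shape)  # [4, 16]

        # after removing weight norm, the last computed weight stays on the
        # wrapped layer and is no longer recomputed from (weight_v, weight_g)
        layer.remove_weight_norm()
        print(layer(x).shape)  # [4, 16]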