From ddf1c4f7a7356e61ebdad78ba8a5e0979c25366c Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Wed, 29 Jul 2020 11:54:47 +0800 Subject: [PATCH] 1. fix initializers; 2. use simple random sampler; 3. clean code for gradient clipper. --- examples/deepvoice3/clip.py | 119 +++------------------------- examples/deepvoice3/train.py | 12 +-- parakeet/models/deepvoice3/model.py | 44 +++++----- 3 files changed, 36 insertions(+), 139 deletions(-) diff --git a/examples/deepvoice3/clip.py b/examples/deepvoice3/clip.py index 0a31320..0a4f998 100644 --- a/examples/deepvoice3/clip.py +++ b/examples/deepvoice3/clip.py @@ -13,109 +13,6 @@ from paddle.fluid.dygraph import base as imperative_base from paddle.fluid.clip import GradientClipBase, _correct_clip_op_role_var class DoubleClip(GradientClipBase): - """ - :alias_main: paddle.nn.GradientClipByGlobalNorm - :alias: paddle.nn.GradientClipByGlobalNorm,paddle.nn.clip.GradientClipByGlobalNorm - :old_api: paddle.fluid.clip.GradientClipByGlobalNorm - - Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in - :math:`t\_list` , and limit it to ``clip_norm`` . - - - If the global norm is greater than ``clip_norm`` , all elements of :math:`t\_list` will be compressed by a ratio. - - - If the global norm is less than or equal to ``clip_norm`` , nothing will be done. - - The list of Tensor :math:`t\_list` is not passed from this class, but the gradients of all parameters in ``Program`` . If ``need_clip`` - is not None, then only part of gradients can be selected for gradient clipping. - - Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer`` - (for example: :ref:`api_fluid_optimizer_SGDOptimizer`). - - The clipping formula is: - - .. math:: - - t\_list[i] = t\_list[i] * \\frac{clip\_norm}{\max(global\_norm, clip\_norm)} - - where: - - .. math:: - - global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2} - - Args: - clip_norm (float): The maximum norm value. - group_name (str, optional): The group name for this clip. Default value is ``default_group`` - need_clip (function, optional): Type: function. This function accepts a ``Parameter`` and returns ``bool`` - (True: the gradient of this ``Parameter`` need to be clipped, False: not need). Default: None, - and gradients of all parameters in the network will be clipped. - - Examples: - .. code-block:: python - - # use for Static mode - import paddle - import paddle.fluid as fluid - import numpy as np - - main_prog = fluid.Program() - startup_prog = fluid.Program() - with fluid.program_guard( - main_program=main_prog, startup_program=startup_prog): - image = fluid.data( - name='x', shape=[-1, 2], dtype='float32') - predict = fluid.layers.fc(input=image, size=3, act='relu') # Trainable parameters: fc_0.w.0, fc_0.b.0 - loss = fluid.layers.mean(predict) - - # Clip all parameters in network: - clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0) - - # Clip a part of parameters in network: (e.g. fc_0.w_0) - # pass a function(fileter_func) to need_clip, and fileter_func receive a ParamBase, and return bool - # def fileter_func(Parameter): - # # It can be easily filtered by Parameter.name (name can be set in fluid.ParamAttr, and the default name is fc_0.w_0, fc_0.b_0) - # return Parameter.name=="fc_0.w_0" - # clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0, need_clip=fileter_func) - - sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1, grad_clip=clip) - sgd_optimizer.minimize(loss) - - place = fluid.CPUPlace() - exe = fluid.Executor(place) - x = np.random.uniform(-100, 100, (10, 2)).astype('float32') - exe.run(startup_prog) - out = exe.run(main_prog, feed={'x': x}, fetch_list=loss) - - - # use for Dygraph mode - import paddle - import paddle.fluid as fluid - - with fluid.dygraph.guard(): - linear = fluid.dygraph.Linear(10, 10) # Trainable: linear_0.w.0, linear_0.b.0 - inputs = fluid.layers.uniform_random([32, 10]).astype('float32') - out = linear(fluid.dygraph.to_variable(inputs)) - loss = fluid.layers.reduce_mean(out) - loss.backward() - - # Clip all parameters in network: - clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0) - - # Clip a part of parameters in network: (e.g. linear_0.w_0) - # pass a function(fileter_func) to need_clip, and fileter_func receive a ParamBase, and return bool - # def fileter_func(ParamBase): - # # It can be easily filtered by ParamBase.name(name can be set in fluid.ParamAttr, and the default name is linear_0.w_0, linear_0.b_0) - # return ParamBase.name == "linear_0.w_0" - # # Note: linear.weight and linear.bias can return the weight and bias of dygraph.Linear, respectively, and can be used to filter - # return ParamBase.name == linear.weight.name - # clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0, need_clip=fileter_func) - - sgd_optimizer = fluid.optimizer.SGD( - learning_rate=0.1, parameter_list=linear.parameters(), grad_clip=clip) - sgd_optimizer.minimize(loss) - - """ - def __init__(self, clip_value, clip_norm, group_name="default_group", need_clip=None): super(DoubleClip, self).__init__(need_clip) self.clip_value = float(clip_value) @@ -128,8 +25,13 @@ class DoubleClip(GradientClipBase): @imperative_base.no_grad def _dygraph_clip(self, params_grads): + params_grads = self._dygraph_clip_by_value(params_grads) + params_grads = self._dygraph_clip_by_global_norm(params_grads) + return params_grads + + @imperative_base.no_grad + def _dygraph_clip_by_value(self, params_grads): params_and_grads = [] - # clip by value first for p, g in params_grads: if g is None: continue @@ -138,9 +40,10 @@ class DoubleClip(GradientClipBase): continue new_grad = layers.clip(x=g, min=-self.clip_value, max=self.clip_value) params_and_grads.append((p, new_grad)) - params_grads = params_and_grads - - # clip by global norm + return params_and_grads + + @imperative_base.no_grad + def _dygraph_clip_by_global_norm(self, params_grads): params_and_grads = [] sum_square_list = [] for p, g in params_grads: @@ -178,4 +81,4 @@ class DoubleClip(GradientClipBase): new_grad = layers.elementwise_mul(x=g, y=clip_var) params_and_grads.append((p, new_grad)) - return params_and_grads + return params_and_grads \ No newline at end of file diff --git a/examples/deepvoice3/train.py b/examples/deepvoice3/train.py index 76ced55..07f5c94 100644 --- a/examples/deepvoice3/train.py +++ b/examples/deepvoice3/train.py @@ -7,12 +7,13 @@ import tqdm import paddle from paddle import fluid from paddle.fluid import layers as F +from paddle.fluid import initializer as I from paddle.fluid import dygraph as dg from paddle.fluid.io import DataLoader from tensorboardX import SummaryWriter from parakeet.models.deepvoice3 import Encoder, Decoder, PostNet, SpectraNet -from parakeet.data import SliceDataset, DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler +from parakeet.data import SliceDataset, DataCargo, SequentialSampler, RandomSampler from parakeet.utils.io import save_parameters, load_parameters from parakeet.g2p import en @@ -22,9 +23,9 @@ from clip import DoubleClip def create_model(config): - char_embedding = dg.Embedding((en.n_vocab, config["char_dim"])) + char_embedding = dg.Embedding((en.n_vocab, config["char_dim"]), param_attr=I.Normal(scale=0.1)) multi_speaker = config["n_speakers"] > 1 - speaker_embedding = dg.Embedding((config["n_speakers"], config["speaker_dim"])) \ + speaker_embedding = dg.Embedding((config["n_speakers"], config["speaker_dim"]), param_attr=I.Normal(scale=0.1)) \ if multi_speaker else None encoder = Encoder(config["encoder_layers"], config["char_dim"], config["encoder_dim"], config["kernel_size"], @@ -51,8 +52,7 @@ def create_data(config, data_path): train_dataset = SliceDataset(dataset, config["valid_size"], len(dataset)) train_collator = DataCollector(config["p_pronunciation"]) - train_sampler = PartialyRandomizedSimilarTimeLengthSampler( - dataset.num_frames()[config["valid_size"]:]) + train_sampler = RandomSampler(train_dataset) train_cargo = DataCargo(train_dataset, train_collator, batch_size=config["batch_size"], sampler=train_sampler) train_loader = DataLoader\ @@ -81,7 +81,7 @@ def train(args, config): optim = create_optimizer(model, config) global global_step - max_iteration = 2000000 + max_iteration = 1000000 iterator = iter(tqdm.tqdm(train_loader)) while global_step <= max_iteration: diff --git a/parakeet/models/deepvoice3/model.py b/parakeet/models/deepvoice3/model.py index b6ac702..1da13a9 100644 --- a/parakeet/models/deepvoice3/model.py +++ b/parakeet/models/deepvoice3/model.py @@ -39,15 +39,15 @@ class ConvBlock(dg.Layer): self.has_bias = has_bias std = np.sqrt(4 * keep_prob / (kernel_size * in_channel)) - initializer = I.NormalInitializer(loc=0., scale=std) padding = "valid" if causal else "same" conv = Conv1D(in_channel, 2 * in_channel, (kernel_size, ), padding=padding, data_format="NTC", - param_attr=initializer) + param_attr=I.Normal(scale=std)) self.conv = weight_norm(conv) if has_bias: - self.bias_affine = dg.Linear(bias_dim, 2 * in_channel) + std = np.sqrt(1 / bias_dim) + self.bias_affine = dg.Linear(bias_dim, 2 * in_channel, param_attr=I.Normal(scale=std)) def forward(self, input, bias=None, padding=None): """ @@ -82,11 +82,11 @@ class AffineBlock1(dg.Layer): def __init__(self, in_channel, out_channel, has_bias=False, bias_dim=0): super(AffineBlock1, self).__init__() std = np.sqrt(1.0 / in_channel) - initializer = I.NormalInitializer(loc=0., scale=std) - affine = dg.Linear(in_channel, out_channel, param_attr=initializer) + affine = dg.Linear(in_channel, out_channel, param_attr=I.Normal(scale=std)) self.affine = weight_norm(affine, dim=-1) if has_bias: - self.bias_affine = dg.Linear(bias_dim, out_channel) + std = np.sqrt(1 / bias_dim) + self.bias_affine = dg.Linear(bias_dim, out_channel, param_attr=I.Normal(scale=std)) self.has_bias = has_bias self.bias_dim = bias_dim @@ -110,10 +110,10 @@ class AffineBlock2(dg.Layer): has_bias=False, bias_dim=0, dropout=False, keep_prob=1.): super(AffineBlock2, self).__init__() if has_bias: - self.bias_affine = dg.Linear(bias_dim, in_channel) + std = np.sqrt(1 / bias_dim) + self.bias_affine = dg.Linear(bias_dim, in_channel, param_attr=I.Normal(scale=std)) std = np.sqrt(1.0 / in_channel) - initializer = I.NormalInitializer(loc=0., scale=std) - affine = dg.Linear(in_channel, out_channel, param_attr=initializer) + affine = dg.Linear(in_channel, out_channel, param_attr=I.Normal(scale=std)) self.affine = weight_norm(affine, dim=-1) self.has_bias = has_bias @@ -171,9 +171,8 @@ class AttentionBlock(dg.Layer): # multispeaker case if has_bias: std = np.sqrt(1.0 / bias_dim) - initializer = I.NormalInitializer(loc=0., scale=std) - self.q_pos_affine = dg.Linear(bias_dim, 1, param_attr=initializer) - self.k_pos_affine = dg.Linear(bias_dim, 1, param_attr=initializer) + self.q_pos_affine = dg.Linear(bias_dim, 1, param_attr=I.Normal(scale=std)) + self.k_pos_affine = dg.Linear(bias_dim, 1, param_attr=I.Normal(scale=std)) self.omega_initial = self.create_parameter(shape=[1], attr=I.ConstantInitializer(value=omega_default)) @@ -184,21 +183,17 @@ class AttentionBlock(dg.Layer): scale=np.sqrt(1. / input_dim)) initializer = I.NumpyArrayInitializer(init_weight.astype(np.float32)) # 3 affine transformation to project q, k, v into attention_dim - q_affine = dg.Linear(input_dim, attention_dim, - param_attr=initializer) + q_affine = dg.Linear(input_dim, attention_dim, param_attr=initializer) self.q_affine = weight_norm(q_affine, dim=-1) - k_affine = dg.Linear(input_dim, attention_dim, - param_attr=initializer) + k_affine = dg.Linear(input_dim, attention_dim, param_attr=initializer) self.k_affine = weight_norm(k_affine, dim=-1) std = np.sqrt(1.0 / input_dim) - initializer = I.NormalInitializer(loc=0., scale=std) - v_affine = dg.Linear(input_dim, attention_dim, param_attr=initializer) + v_affine = dg.Linear(input_dim, attention_dim, param_attr=I.Normal(scale=std)) self.v_affine = weight_norm(v_affine, dim=-1) std = np.sqrt(1.0 / attention_dim) - initializer = I.NormalInitializer(loc=0., scale=std) - out_affine = dg.Linear(attention_dim, input_dim, param_attr=initializer) + out_affine = dg.Linear(attention_dim, input_dim, param_attr=I.Normal(scale=std)) self.out_affine = weight_norm(out_affine, dim=-1) self.keep_prob = keep_prob @@ -289,11 +284,11 @@ class Decoder(dg.Layer): # output mel spectrogram output_dim = reduction_factor * in_channels # r * mel_dim std = np.sqrt(1.0 / decoder_dim) - initializer = I.NormalInitializer(loc=0., scale=std) - out_affine = dg.Linear(decoder_dim, output_dim, param_attr=initializer) + out_affine = dg.Linear(decoder_dim, output_dim, param_attr=I.Normal(scale=std)) self.out_affine = weight_norm(out_affine, dim=-1) if has_bias: - self.out_sp_affine = dg.Linear(bias_dim, output_dim) + std = np.sqrt(1 / bias_dim) + self.out_sp_affine = dg.Linear(bias_dim, output_dim, param_attr=I.Normal(scale=std)) self.has_bias = has_bias self.kernel_size = kernel_size @@ -351,8 +346,7 @@ class PostNet(dg.Layer): ConvBlock(postnet_dim, kernel_size, False, has_bias, bias_dim, keep_prob) for _ in range(layers) ]) std = np.sqrt(1.0 / postnet_dim) - initializer = I.NormalInitializer(loc=0., scale=std) - post_affine = dg.Linear(postnet_dim, out_channels, param_attr=initializer) + post_affine = dg.Linear(postnet_dim, out_channels, param_attr=I.Normal(scale=std)) self.post_affine = weight_norm(post_affine, dim=-1) self.upsample_factor = upsample_factor