1. fix initializers;
2. use simple random sampler; 3. clean code for gradient clipper.
This commit is contained in:
parent
e58e927c5e
commit
ddf1c4f7a7
|
@ -13,109 +13,6 @@ from paddle.fluid.dygraph import base as imperative_base
|
|||
from paddle.fluid.clip import GradientClipBase, _correct_clip_op_role_var
|
||||
|
||||
class DoubleClip(GradientClipBase):
|
||||
"""
|
||||
:alias_main: paddle.nn.GradientClipByGlobalNorm
|
||||
:alias: paddle.nn.GradientClipByGlobalNorm,paddle.nn.clip.GradientClipByGlobalNorm
|
||||
:old_api: paddle.fluid.clip.GradientClipByGlobalNorm
|
||||
|
||||
Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
|
||||
:math:`t\_list` , and limit it to ``clip_norm`` .
|
||||
|
||||
- If the global norm is greater than ``clip_norm`` , all elements of :math:`t\_list` will be compressed by a ratio.
|
||||
|
||||
- If the global norm is less than or equal to ``clip_norm`` , nothing will be done.
|
||||
|
||||
The list of Tensor :math:`t\_list` is not passed from this class, but the gradients of all parameters in ``Program`` . If ``need_clip``
|
||||
is not None, then only part of gradients can be selected for gradient clipping.
|
||||
|
||||
Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer``
|
||||
(for example: :ref:`api_fluid_optimizer_SGDOptimizer`).
|
||||
|
||||
The clipping formula is:
|
||||
|
||||
.. math::
|
||||
|
||||
t\_list[i] = t\_list[i] * \\frac{clip\_norm}{\max(global\_norm, clip\_norm)}
|
||||
|
||||
where:
|
||||
|
||||
.. math::
|
||||
|
||||
global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}
|
||||
|
||||
Args:
|
||||
clip_norm (float): The maximum norm value.
|
||||
group_name (str, optional): The group name for this clip. Default value is ``default_group``
|
||||
need_clip (function, optional): Type: function. This function accepts a ``Parameter`` and returns ``bool``
|
||||
(True: the gradient of this ``Parameter`` need to be clipped, False: not need). Default: None,
|
||||
and gradients of all parameters in the network will be clipped.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
# use for Static mode
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
import numpy as np
|
||||
|
||||
main_prog = fluid.Program()
|
||||
startup_prog = fluid.Program()
|
||||
with fluid.program_guard(
|
||||
main_program=main_prog, startup_program=startup_prog):
|
||||
image = fluid.data(
|
||||
name='x', shape=[-1, 2], dtype='float32')
|
||||
predict = fluid.layers.fc(input=image, size=3, act='relu') # Trainable parameters: fc_0.w.0, fc_0.b.0
|
||||
loss = fluid.layers.mean(predict)
|
||||
|
||||
# Clip all parameters in network:
|
||||
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
|
||||
|
||||
# Clip a part of parameters in network: (e.g. fc_0.w_0)
|
||||
# pass a function(fileter_func) to need_clip, and fileter_func receive a ParamBase, and return bool
|
||||
# def fileter_func(Parameter):
|
||||
# # It can be easily filtered by Parameter.name (name can be set in fluid.ParamAttr, and the default name is fc_0.w_0, fc_0.b_0)
|
||||
# return Parameter.name=="fc_0.w_0"
|
||||
# clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0, need_clip=fileter_func)
|
||||
|
||||
sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1, grad_clip=clip)
|
||||
sgd_optimizer.minimize(loss)
|
||||
|
||||
place = fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
x = np.random.uniform(-100, 100, (10, 2)).astype('float32')
|
||||
exe.run(startup_prog)
|
||||
out = exe.run(main_prog, feed={'x': x}, fetch_list=loss)
|
||||
|
||||
|
||||
# use for Dygraph mode
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
|
||||
with fluid.dygraph.guard():
|
||||
linear = fluid.dygraph.Linear(10, 10) # Trainable: linear_0.w.0, linear_0.b.0
|
||||
inputs = fluid.layers.uniform_random([32, 10]).astype('float32')
|
||||
out = linear(fluid.dygraph.to_variable(inputs))
|
||||
loss = fluid.layers.reduce_mean(out)
|
||||
loss.backward()
|
||||
|
||||
# Clip all parameters in network:
|
||||
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
|
||||
|
||||
# Clip a part of parameters in network: (e.g. linear_0.w_0)
|
||||
# pass a function(fileter_func) to need_clip, and fileter_func receive a ParamBase, and return bool
|
||||
# def fileter_func(ParamBase):
|
||||
# # It can be easily filtered by ParamBase.name(name can be set in fluid.ParamAttr, and the default name is linear_0.w_0, linear_0.b_0)
|
||||
# return ParamBase.name == "linear_0.w_0"
|
||||
# # Note: linear.weight and linear.bias can return the weight and bias of dygraph.Linear, respectively, and can be used to filter
|
||||
# return ParamBase.name == linear.weight.name
|
||||
# clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0, need_clip=fileter_func)
|
||||
|
||||
sgd_optimizer = fluid.optimizer.SGD(
|
||||
learning_rate=0.1, parameter_list=linear.parameters(), grad_clip=clip)
|
||||
sgd_optimizer.minimize(loss)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, clip_value, clip_norm, group_name="default_group", need_clip=None):
|
||||
super(DoubleClip, self).__init__(need_clip)
|
||||
self.clip_value = float(clip_value)
|
||||
|
@ -128,8 +25,13 @@ class DoubleClip(GradientClipBase):
|
|||
|
||||
@imperative_base.no_grad
|
||||
def _dygraph_clip(self, params_grads):
|
||||
params_grads = self._dygraph_clip_by_value(params_grads)
|
||||
params_grads = self._dygraph_clip_by_global_norm(params_grads)
|
||||
return params_grads
|
||||
|
||||
@imperative_base.no_grad
|
||||
def _dygraph_clip_by_value(self, params_grads):
|
||||
params_and_grads = []
|
||||
# clip by value first
|
||||
for p, g in params_grads:
|
||||
if g is None:
|
||||
continue
|
||||
|
@ -138,9 +40,10 @@ class DoubleClip(GradientClipBase):
|
|||
continue
|
||||
new_grad = layers.clip(x=g, min=-self.clip_value, max=self.clip_value)
|
||||
params_and_grads.append((p, new_grad))
|
||||
params_grads = params_and_grads
|
||||
|
||||
# clip by global norm
|
||||
return params_and_grads
|
||||
|
||||
@imperative_base.no_grad
|
||||
def _dygraph_clip_by_global_norm(self, params_grads):
|
||||
params_and_grads = []
|
||||
sum_square_list = []
|
||||
for p, g in params_grads:
|
||||
|
@ -178,4 +81,4 @@ class DoubleClip(GradientClipBase):
|
|||
new_grad = layers.elementwise_mul(x=g, y=clip_var)
|
||||
params_and_grads.append((p, new_grad))
|
||||
|
||||
return params_and_grads
|
||||
return params_and_grads
|
|
@ -7,12 +7,13 @@ import tqdm
|
|||
import paddle
|
||||
from paddle import fluid
|
||||
from paddle.fluid import layers as F
|
||||
from paddle.fluid import initializer as I
|
||||
from paddle.fluid import dygraph as dg
|
||||
from paddle.fluid.io import DataLoader
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from parakeet.models.deepvoice3 import Encoder, Decoder, PostNet, SpectraNet
|
||||
from parakeet.data import SliceDataset, DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler
|
||||
from parakeet.data import SliceDataset, DataCargo, SequentialSampler, RandomSampler
|
||||
from parakeet.utils.io import save_parameters, load_parameters
|
||||
from parakeet.g2p import en
|
||||
|
||||
|
@ -22,9 +23,9 @@ from clip import DoubleClip
|
|||
|
||||
|
||||
def create_model(config):
|
||||
char_embedding = dg.Embedding((en.n_vocab, config["char_dim"]))
|
||||
char_embedding = dg.Embedding((en.n_vocab, config["char_dim"]), param_attr=I.Normal(scale=0.1))
|
||||
multi_speaker = config["n_speakers"] > 1
|
||||
speaker_embedding = dg.Embedding((config["n_speakers"], config["speaker_dim"])) \
|
||||
speaker_embedding = dg.Embedding((config["n_speakers"], config["speaker_dim"]), param_attr=I.Normal(scale=0.1)) \
|
||||
if multi_speaker else None
|
||||
encoder = Encoder(config["encoder_layers"], config["char_dim"],
|
||||
config["encoder_dim"], config["kernel_size"],
|
||||
|
@ -51,8 +52,7 @@ def create_data(config, data_path):
|
|||
|
||||
train_dataset = SliceDataset(dataset, config["valid_size"], len(dataset))
|
||||
train_collator = DataCollector(config["p_pronunciation"])
|
||||
train_sampler = PartialyRandomizedSimilarTimeLengthSampler(
|
||||
dataset.num_frames()[config["valid_size"]:])
|
||||
train_sampler = RandomSampler(train_dataset)
|
||||
train_cargo = DataCargo(train_dataset, train_collator,
|
||||
batch_size=config["batch_size"], sampler=train_sampler)
|
||||
train_loader = DataLoader\
|
||||
|
@ -81,7 +81,7 @@ def train(args, config):
|
|||
optim = create_optimizer(model, config)
|
||||
|
||||
global global_step
|
||||
max_iteration = 2000000
|
||||
max_iteration = 1000000
|
||||
|
||||
iterator = iter(tqdm.tqdm(train_loader))
|
||||
while global_step <= max_iteration:
|
||||
|
|
|
@ -39,15 +39,15 @@ class ConvBlock(dg.Layer):
|
|||
self.has_bias = has_bias
|
||||
|
||||
std = np.sqrt(4 * keep_prob / (kernel_size * in_channel))
|
||||
initializer = I.NormalInitializer(loc=0., scale=std)
|
||||
padding = "valid" if causal else "same"
|
||||
conv = Conv1D(in_channel, 2 * in_channel, (kernel_size, ),
|
||||
padding=padding,
|
||||
data_format="NTC",
|
||||
param_attr=initializer)
|
||||
param_attr=I.Normal(scale=std))
|
||||
self.conv = weight_norm(conv)
|
||||
if has_bias:
|
||||
self.bias_affine = dg.Linear(bias_dim, 2 * in_channel)
|
||||
std = np.sqrt(1 / bias_dim)
|
||||
self.bias_affine = dg.Linear(bias_dim, 2 * in_channel, param_attr=I.Normal(scale=std))
|
||||
|
||||
def forward(self, input, bias=None, padding=None):
|
||||
"""
|
||||
|
@ -82,11 +82,11 @@ class AffineBlock1(dg.Layer):
|
|||
def __init__(self, in_channel, out_channel, has_bias=False, bias_dim=0):
|
||||
super(AffineBlock1, self).__init__()
|
||||
std = np.sqrt(1.0 / in_channel)
|
||||
initializer = I.NormalInitializer(loc=0., scale=std)
|
||||
affine = dg.Linear(in_channel, out_channel, param_attr=initializer)
|
||||
affine = dg.Linear(in_channel, out_channel, param_attr=I.Normal(scale=std))
|
||||
self.affine = weight_norm(affine, dim=-1)
|
||||
if has_bias:
|
||||
self.bias_affine = dg.Linear(bias_dim, out_channel)
|
||||
std = np.sqrt(1 / bias_dim)
|
||||
self.bias_affine = dg.Linear(bias_dim, out_channel, param_attr=I.Normal(scale=std))
|
||||
|
||||
self.has_bias = has_bias
|
||||
self.bias_dim = bias_dim
|
||||
|
@ -110,10 +110,10 @@ class AffineBlock2(dg.Layer):
|
|||
has_bias=False, bias_dim=0, dropout=False, keep_prob=1.):
|
||||
super(AffineBlock2, self).__init__()
|
||||
if has_bias:
|
||||
self.bias_affine = dg.Linear(bias_dim, in_channel)
|
||||
std = np.sqrt(1 / bias_dim)
|
||||
self.bias_affine = dg.Linear(bias_dim, in_channel, param_attr=I.Normal(scale=std))
|
||||
std = np.sqrt(1.0 / in_channel)
|
||||
initializer = I.NormalInitializer(loc=0., scale=std)
|
||||
affine = dg.Linear(in_channel, out_channel, param_attr=initializer)
|
||||
affine = dg.Linear(in_channel, out_channel, param_attr=I.Normal(scale=std))
|
||||
self.affine = weight_norm(affine, dim=-1)
|
||||
|
||||
self.has_bias = has_bias
|
||||
|
@ -171,9 +171,8 @@ class AttentionBlock(dg.Layer):
|
|||
# multispeaker case
|
||||
if has_bias:
|
||||
std = np.sqrt(1.0 / bias_dim)
|
||||
initializer = I.NormalInitializer(loc=0., scale=std)
|
||||
self.q_pos_affine = dg.Linear(bias_dim, 1, param_attr=initializer)
|
||||
self.k_pos_affine = dg.Linear(bias_dim, 1, param_attr=initializer)
|
||||
self.q_pos_affine = dg.Linear(bias_dim, 1, param_attr=I.Normal(scale=std))
|
||||
self.k_pos_affine = dg.Linear(bias_dim, 1, param_attr=I.Normal(scale=std))
|
||||
self.omega_initial = self.create_parameter(shape=[1],
|
||||
attr=I.ConstantInitializer(value=omega_default))
|
||||
|
||||
|
@ -184,21 +183,17 @@ class AttentionBlock(dg.Layer):
|
|||
scale=np.sqrt(1. / input_dim))
|
||||
initializer = I.NumpyArrayInitializer(init_weight.astype(np.float32))
|
||||
# 3 affine transformation to project q, k, v into attention_dim
|
||||
q_affine = dg.Linear(input_dim, attention_dim,
|
||||
param_attr=initializer)
|
||||
q_affine = dg.Linear(input_dim, attention_dim, param_attr=initializer)
|
||||
self.q_affine = weight_norm(q_affine, dim=-1)
|
||||
k_affine = dg.Linear(input_dim, attention_dim,
|
||||
param_attr=initializer)
|
||||
k_affine = dg.Linear(input_dim, attention_dim, param_attr=initializer)
|
||||
self.k_affine = weight_norm(k_affine, dim=-1)
|
||||
|
||||
std = np.sqrt(1.0 / input_dim)
|
||||
initializer = I.NormalInitializer(loc=0., scale=std)
|
||||
v_affine = dg.Linear(input_dim, attention_dim, param_attr=initializer)
|
||||
v_affine = dg.Linear(input_dim, attention_dim, param_attr=I.Normal(scale=std))
|
||||
self.v_affine = weight_norm(v_affine, dim=-1)
|
||||
|
||||
std = np.sqrt(1.0 / attention_dim)
|
||||
initializer = I.NormalInitializer(loc=0., scale=std)
|
||||
out_affine = dg.Linear(attention_dim, input_dim, param_attr=initializer)
|
||||
out_affine = dg.Linear(attention_dim, input_dim, param_attr=I.Normal(scale=std))
|
||||
self.out_affine = weight_norm(out_affine, dim=-1)
|
||||
|
||||
self.keep_prob = keep_prob
|
||||
|
@ -289,11 +284,11 @@ class Decoder(dg.Layer):
|
|||
# output mel spectrogram
|
||||
output_dim = reduction_factor * in_channels # r * mel_dim
|
||||
std = np.sqrt(1.0 / decoder_dim)
|
||||
initializer = I.NormalInitializer(loc=0., scale=std)
|
||||
out_affine = dg.Linear(decoder_dim, output_dim, param_attr=initializer)
|
||||
out_affine = dg.Linear(decoder_dim, output_dim, param_attr=I.Normal(scale=std))
|
||||
self.out_affine = weight_norm(out_affine, dim=-1)
|
||||
if has_bias:
|
||||
self.out_sp_affine = dg.Linear(bias_dim, output_dim)
|
||||
std = np.sqrt(1 / bias_dim)
|
||||
self.out_sp_affine = dg.Linear(bias_dim, output_dim, param_attr=I.Normal(scale=std))
|
||||
|
||||
self.has_bias = has_bias
|
||||
self.kernel_size = kernel_size
|
||||
|
@ -351,8 +346,7 @@ class PostNet(dg.Layer):
|
|||
ConvBlock(postnet_dim, kernel_size, False, has_bias, bias_dim, keep_prob) for _ in range(layers)
|
||||
])
|
||||
std = np.sqrt(1.0 / postnet_dim)
|
||||
initializer = I.NormalInitializer(loc=0., scale=std)
|
||||
post_affine = dg.Linear(postnet_dim, out_channels, param_attr=initializer)
|
||||
post_affine = dg.Linear(postnet_dim, out_channels, param_attr=I.Normal(scale=std))
|
||||
self.post_affine = weight_norm(post_affine, dim=-1)
|
||||
self.upsample_factor = upsample_factor
|
||||
|
||||
|
|
Loading…
Reference in New Issue