diff --git a/README.md b/README.md
index e32219b..7bb380e 100644
--- a/README.md
+++ b/README.md
@@ -40,7 +40,7 @@ sudo apt-get install libsndfile1
 
 ### Install PaddlePaddle
 
-See [install](https://www.paddlepaddle.org.cn/install/quick) for more details. This repo requires PaddlePaddle **1.7.1** or above.
+See [install](https://www.paddlepaddle.org.cn/install/quick) for more details. This repo requires PaddlePaddle **1.8.0** or above.
 
 ### Install Parakeet
 
diff --git a/examples/clarinet/train.py b/examples/clarinet/train.py
index 82d9aa1..ed55702 100644
--- a/examples/clarinet/train.py
+++ b/examples/clarinet/train.py
@@ -163,11 +163,11 @@ if __name__ == "__main__":
         anneal_interval = train_config["anneal_interval"]
         lr_scheduler = dg.ExponentialDecay(
             learning_rate, anneal_interval, anneal_rate, staircase=True)
-        optim = fluid.optimizer.Adam(
-            lr_scheduler, parameter_list=model.parameters())
         gradiant_max_norm = train_config["gradient_max_norm"]
-        clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
-            gradiant_max_norm)
+        optim = fluid.optimizer.Adam(
+            lr_scheduler,
+            parameter_list=model.parameters(),
+            grad_clip=fluid.clip.GradientClipByGlobalNorm(gradiant_max_norm))
 
         # train
         max_iterations = train_config["max_iterations"]
@@ -229,7 +229,7 @@ if __name__ == "__main__":
                         step_loss))
 
                 l.backward()
-                optim.minimize(l, grad_clip=clipper)
+                optim.minimize(l)
                 optim.clear_gradients()
 
                 if global_step % eval_interval == 0:
diff --git a/examples/deepvoice3/train.py b/examples/deepvoice3/train.py
index d363e6f..0d5a54b 100644
--- a/examples/deepvoice3/train.py
+++ b/examples/deepvoice3/train.py
@@ -196,8 +196,8 @@ if __name__ == "__main__":
         beta1,
         beta2,
         epsilon=epsilon,
-        parameter_list=dv3.parameters())
-    gradient_clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(0.1)
+        parameter_list=dv3.parameters(),
+        grad_clip=fluid.clip.GradientClipByGlobalNorm(0.1))
 
     # generation
     synthesis_config = config["synthesis"]
@@ -258,15 +258,19 @@ if __name__ == "__main__":
                               text_lengths, frames)
             l = losses["loss"]
             l.backward()
+            # record learning rate before updating
             writer.add_scalar("learning_rate",
                               optim._learning_rate.step().numpy(),
                               global_step)
-            optim.minimize(l, grad_clip=gradient_clipper)
+            optim.minimize(l)
             optim.clear_gradients()
 
             # ==================all kinds of tedious things=================
             # record step loss into tensorboard
-            step_loss = {k: v.numpy()[0] for k, v in losses.items()}
+            step_loss = {
+                k: v.numpy()[0]
+                for k, v in losses.items() if v is not None
+            }
             tqdm.tqdm.write("global_step: {}\tloss: {}".format(
                 global_step, step_loss["loss"]))
             for k, v in step_loss.items():
diff --git a/examples/waveflow/synthesis.py b/examples/waveflow/synthesis.py
index 15c4d3b..b9569bf 100644
--- a/examples/waveflow/synthesis.py
+++ b/examples/waveflow/synthesis.py
@@ -93,16 +93,7 @@ def synthesize(config):
 
     # Build model.
     model = WaveFlow(config, checkpoint_dir)
-    model.build(training=False)
-
-    # Obtain the current iteration.
-    if config.checkpoint is None:
-        if config.iteration is None:
-            iteration = io.load_latest_checkpoint(checkpoint_dir)
-        else:
-            iteration = config.iteration
-    else:
-        iteration = int(config.checkpoint.split('/')[-1].split('-')[-1])
+    iteration = model.build(training=False)
 
     # Run model inference.
     model.infer(iteration)
diff --git a/examples/waveflow/waveflow.py b/examples/waveflow/waveflow.py
index 700116b..23c558e 100644
--- a/examples/waveflow/waveflow.py
+++ b/examples/waveflow/waveflow.py
@@ -81,12 +81,6 @@ class WaveFlow():
 
         waveflow = WaveFlowModule(config)
 
-        # Dry run once to create and initalize all necessary parameters.
-        audio = dg.to_variable(np.random.randn(1, 16000).astype(self.dtype))
-        mel = dg.to_variable(
-            np.random.randn(1, config.mel_bands, 63).astype(self.dtype))
-        waveflow(audio, mel)
-
         if training:
             optimizer = fluid.optimizer.AdamOptimizer(
                 learning_rate=config.learning_rate,
diff --git a/examples/wavenet/train.py b/examples/wavenet/train.py
index 14b861b..95e5c0d 100644
--- a/examples/wavenet/train.py
+++ b/examples/wavenet/train.py
@@ -126,12 +126,11 @@ if __name__ == "__main__":
         anneal_interval = train_config["anneal_interval"]
         lr_scheduler = dg.ExponentialDecay(
             learning_rate, anneal_interval, anneal_rate, staircase=True)
-        optim = fluid.optimizer.Adam(
-            lr_scheduler, parameter_list=model.parameters())
-
         gradiant_max_norm = train_config["gradient_max_norm"]
-        clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
-            gradiant_max_norm)
+        optim = fluid.optimizer.Adam(
+            lr_scheduler,
+            parameter_list=model.parameters(),
+            grad_clip=fluid.clip.GradientClipByGlobalNorm(gradiant_max_norm))
 
         train_loader = fluid.io.DataLoader.from_generator(
             capacity=10, return_list=True)
@@ -149,7 +148,7 @@ if __name__ == "__main__":
         log_dir = os.path.join(args.output, "log")
         writer = SummaryWriter(log_dir)
 
-        # load parameters and optimizer, and opdate iterations done sofar
+        # load parameters and optimizer, and update iterations done so far
         if args.checkpoint is not None:
             iteration = io.load_parameters(
                 model, optim, checkpoint_path=args.checkpoint)
@@ -181,7 +180,7 @@ if __name__ == "__main__":
                 writer.add_scalar("learning_rate",
                                   optim._learning_rate.step().numpy()[0],
                                   global_step)
-                optim.minimize(loss_var, grad_clip=clipper)
+                optim.minimize(loss_var)
                 optim.clear_gradients()
                 print("global_step: {}\tloss: {:<8.6f}".format(global_step,
                                                                loss_np[0]))
diff --git a/parakeet/models/clarinet/utils.py b/parakeet/models/clarinet/utils.py
index d5c2b44..6a92b26 100644
--- a/parakeet/models/clarinet/utils.py
+++ b/parakeet/models/clarinet/utils.py
@@ -29,22 +29,10 @@ def conv2d(input,
            data_format="NCHW"):
     padding = tuple(pad for pad_dim in padding for pad in pad_dim)
 
-    inputs = {
-        'Input': [input],
-        'Filter': [weight],
-    }
-    attrs = {
-        'strides': stride,
-        'paddings': padding,
-        'dilations': dilation,
-        'groups': groups,
-        'use_cudnn': use_cudnn,
-        'use_mkldnn': False,
-        'fuse_relu_before_depthwise_conv': False,
-        "padding_algorithm": "EXPLICIT",
-        "data_format": data_format,
-    }
+    attrs = ('strides', stride, 'paddings', padding, 'dilations', dilation,
+             'groups', groups, 'use_cudnn', use_cudnn, 'use_mkldnn', False,
+             'fuse_relu_before_depthwise_conv', False, "padding_algorithm",
+             "EXPLICIT", "data_format", data_format)
 
-    outputs = ops.conv2d(inputs, attrs)
-    out = outputs["Output"][0]
-    return out
\ No newline at end of file
+    out = ops.conv2d(input, weight, *attrs)
+    return out
diff --git a/parakeet/models/deepvoice3/loss.py b/parakeet/models/deepvoice3/loss.py
index abf6d73..8c7029d 100644
--- a/parakeet/models/deepvoice3/loss.py
+++ b/parakeet/models/deepvoice3/loss.py
@@ -262,7 +262,7 @@ class TTSLoss(object):
         if compute_lin_loss:
             lin_hyp = lin_hyp[:, :-self.time_shift, :]
             lin_ref = lin_ref[:, self.time_shift:, :]
-            lin_mask = lin_mask[:, self.time_shift:, :]
+            lin_mask = lin_mask[:, self.time_shift:]
             lin_l1_loss = self.l1_loss(
                 lin_hyp, lin_ref, lin_mask, priority_bin=self.priority_bin)
             lin_bce_loss = self.binary_divergence(lin_hyp, lin_ref, lin_mask)
@@ -273,7 +273,7 @@ class TTSLoss(object):
         if compute_mel_loss:
             mel_hyp = mel_hyp[:, :-self.time_shift, :]
             mel_ref = mel_ref[:, self.time_shift:, :]
-            mel_mask = mel_mask[:, self.time_shift:, :]
+            mel_mask = mel_mask[:, self.time_shift:]
             mel_l1_loss = self.l1_loss(mel_hyp, mel_ref, mel_mask)
             mel_bce_loss = self.binary_divergence(mel_hyp, mel_ref, mel_mask)
             # print("=====>", mel_l1_loss.numpy()[0], mel_bce_loss.numpy()[0])
diff --git a/parakeet/models/deepvoice3/position_embedding.py b/parakeet/models/deepvoice3/position_embedding.py
index 032feff..e76d2c3 100644
--- a/parakeet/models/deepvoice3/position_embedding.py
+++ b/parakeet/models/deepvoice3/position_embedding.py
@@ -31,8 +31,10 @@ def compute_position_embedding(radians, speaker_position_rate):
     """
     _, embed_dim = radians.shape
     batch_size = speaker_position_rate.shape[0]
-    speaker_position_rate = F.unsqueeze(speaker_position_rate, [1, 2])
-    scaled_radians = speaker_position_rate * radians
+    scaled_radians = F.elementwise_mul(
+        F.expand(F.unsqueeze(radians, [0]), [batch_size, 1, 1]),
+        speaker_position_rate,
+        axis=0)
 
     odd_mask = (np.arange(embed_dim) % 2).astype(np.float32)
     odd_mask = dg.to_variable(odd_mask)
diff --git a/parakeet/models/wavenet/wavenet.py b/parakeet/models/wavenet/wavenet.py
index 49778a5..a0296e1 100644
--- a/parakeet/models/wavenet/wavenet.py
+++ b/parakeet/models/wavenet/wavenet.py
@@ -111,7 +111,7 @@ class ResidualBlock(dg.Layer):
             h = h[:, :, :time_steps]
 
         # condition
-        if condition:
+        if condition is not None:
             h += self.condition_proj(condition)
 
         # gated tanh
@@ -398,7 +398,8 @@ class WaveNet(dg.Layer):
 
         x_std = inv_std * (t - mu)
         exponent = F.exp(-0.5 * x_std * x_std)
-        pdf_x = 1.0 / np.sqrt(2.0 * np.pi) * inv_std * exponent
+        pdf_x = 1.0 / math.sqrt(2.0 * math.pi) * inv_std * exponent
+        pdf_x = p_mixture * pdf_x
 
         # pdf_x: [bs, len]
         pdf_x = F.reduce_sum(pdf_x, dim=-1)
diff --git a/parakeet/modules/weight_norm.py b/parakeet/modules/weight_norm.py
index 20af6c0..b48a686 100644
--- a/parakeet/modules/weight_norm.py
+++ b/parakeet/modules/weight_norm.py
@@ -84,13 +84,15 @@ class WeightNormWrapper(dg.Layer):
             w_v,
             self.create_parameter(
                 shape=original_weight.shape, dtype=original_weight.dtype))
-        F.assign(original_weight, getattr(self, w_v))
+        with dg.no_grad():
+            F.assign(original_weight, getattr(self, w_v))
         delattr(layer, param_name)
         temp = norm_except(getattr(self, w_v), self.dim, self.power)
         self.add_parameter(
             w_g,
             self.create_parameter(
                 shape=temp.shape, dtype=temp.dtype))
-        F.assign(temp, getattr(self, w_g))
+        with dg.no_grad():
+            F.assign(temp, getattr(self, w_g))
 
         # also set this when setting up
         setattr(self.layer, self.param_name,
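
Note on the recurring change in the training scripts above: with PaddlePaddle 1.8 the gradient-clipping rule is passed to the optimizer constructor via `grad_clip` (here `fluid.clip.GradientClipByGlobalNorm`), instead of building a `fluid.dygraph_grad_clip.GradClipByGlobalNorm` clipper and handing it to `optimizer.minimize(loss, grad_clip=...)` on every step. Below is a minimal sketch of the new pattern, assuming PaddlePaddle 1.8; the toy Linear layer, random input, and hyperparameters are illustrative only and do not come from this repo.

    import numpy as np
    import paddle.fluid as fluid
    import paddle.fluid.dygraph as dg

    with dg.guard(fluid.CPUPlace()):
        model = dg.Linear(4, 4)  # toy stand-in for the real model
        # Paddle >= 1.8: attach the clip rule to the optimizer once.
        optim = fluid.optimizer.Adam(
            learning_rate=1e-3,
            parameter_list=model.parameters(),
            grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=0.1))

        x = dg.to_variable(np.random.randn(2, 4).astype("float32"))
        loss = fluid.layers.reduce_mean(model(x))
        loss.backward()
        optim.minimize(loss)  # clipping happens inside minimize; no grad_clip argument here
        optim.clear_gradients()

Attaching the clip rule to the optimizer keeps clipping consistent across every minimize call instead of relying on each call site to pass a clipper, which is why the three training scripts drop the per-step argument.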