diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..4102b69
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,27 @@
+- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
+  sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
+  hooks:
+  - id: yapf
+    files: \.py$
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  sha: a11d9314b22d8f8c7556443875b731ef05965464
+  hooks:
+  - id: check-merge-conflict
+  - id: check-symlinks
+  - id: detect-private-key
+    files: (?!.*paddle)^.*$
+  - id: end-of-file-fixer
+    files: \.md$
+  - id: trailing-whitespace
+    files: \.md$
+- repo: https://github.com/Lucas-C/pre-commit-hooks
+  sha: v1.0.1
+  hooks:
+  - id: forbid-crlf
+    files: \.md$
+  - id: remove-crlf
+    files: \.md$
+  - id: forbid-tabs
+    files: \.md$
+  - id: remove-tabs
+    files: \.md$
diff --git a/examples/deepvoice3/train.py b/examples/deepvoice3/train.py
index 636032d..6d9aef6 100644
--- a/examples/deepvoice3/train.py
+++ b/examples/deepvoice3/train.py
@@ -28,22 +28,21 @@ if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train a deepvoice 3 model with LJSpeech dataset.")
    parser.add_argument("-c", "--config", type=str, help="experiment config")
-    parser.add_argument("-s",
-                        "--data",
-                        type=str,
-                        default="/workspace/datasets/LJSpeech-1.1/",
-                        help="The path of the LJSpeech dataset.")
+    parser.add_argument(
+        "-s",
+        "--data",
+        type=str,
+        default="/workspace/datasets/LJSpeech-1.1/",
+        help="The path of the LJSpeech dataset.")
    parser.add_argument("-r", "--resume", type=str, help="checkpoint to load")
-    parser.add_argument("-o",
-                        "--output",
-                        type=str,
-                        default="result",
-                        help="The directory to save result.")
-    parser.add_argument("-g",
-                        "--device",
-                        type=int,
-                        default=-1,
-                        help="device to use")
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=str,
+        default="result",
+        help="The directory to save result.")
+    parser.add_argument(
+        "-g", "--device", type=int, default=-1, help="device to use")
    args, _ = parser.parse_known_args()
    with open(args.config, 'rt') as f:
        config = ruamel.yaml.safe_load(f)
@@ -84,18 +83,16 @@ if __name__ == "__main__":
    train_config = config["train"]
    batch_size = train_config["batch_size"]
    text_lengths = [len(example[2]) for example in meta]
-    sampler = PartialyRandomizedSimilarTimeLengthSampler(
-        text_lengths, batch_size)
+    sampler = PartialyRandomizedSimilarTimeLengthSampler(text_lengths,
+                                                         batch_size)

    # some hyperparameters affect how we process data, so create a data collector!
model_config = config["model"] downsample_factor = model_config["downsample_factor"] r = model_config["outputs_per_step"] collector = DataCollector(downsample_factor=downsample_factor, r=r) - ljspeech_loader = DataCargo(ljspeech, - batch_fn=collector, - batch_size=batch_size, - sampler=sampler) + ljspeech_loader = DataCargo( + ljspeech, batch_fn=collector, batch_size=batch_size, sampler=sampler) # =========================model========================= if args.device == -1: @@ -131,15 +128,14 @@ if __name__ == "__main__": window_ahead = model_config["window_ahead"] key_projection = model_config["key_projection"] value_projection = model_config["value_projection"] - dv3 = make_model(n_speakers, speaker_dim, speaker_embed_std, embed_dim, - padding_idx, embedding_std, max_positions, n_vocab, - freeze_embedding, filter_size, encoder_channels, - n_mels, decoder_channels, r, - trainable_positional_encodings, use_memory_mask, - query_position_rate, key_position_rate, - window_backward, window_ahead, key_projection, - value_projection, downsample_factor, linear_dim, - use_decoder_states, converter_channels, dropout) + dv3 = make_model( + n_speakers, speaker_dim, speaker_embed_std, embed_dim, padding_idx, + embedding_std, max_positions, n_vocab, freeze_embedding, + filter_size, encoder_channels, n_mels, decoder_channels, r, + trainable_positional_encodings, use_memory_mask, + query_position_rate, key_position_rate, window_backward, + window_ahead, key_projection, value_projection, downsample_factor, + linear_dim, use_decoder_states, converter_channels, dropout) # =========================loss========================= loss_config = config["loss"] @@ -149,13 +145,14 @@ if __name__ == "__main__": priority_freq_weight = loss_config["priority_freq_weight"] binary_divergence_weight = loss_config["binary_divergence_weight"] guided_attention_sigma = loss_config["guided_attention_sigma"] - criterion = TTSLoss(masked_weight=masked_weight, - priority_bin=priority_bin, - priority_weight=priority_freq_weight, - binary_divergence_weight=binary_divergence_weight, - guided_attention_sigma=guided_attention_sigma, - downsample_factor=downsample_factor, - r=r) + criterion = TTSLoss( + masked_weight=masked_weight, + priority_bin=priority_bin, + priority_weight=priority_freq_weight, + binary_divergence_weight=binary_divergence_weight, + guided_attention_sigma=guided_attention_sigma, + downsample_factor=downsample_factor, + r=r) # =========================lr_scheduler========================= lr_config = config["lr_scheduler"] @@ -169,11 +166,12 @@ if __name__ == "__main__": beta1 = optim_config["beta1"] beta2 = optim_config["beta2"] epsilon = optim_config["epsilon"] - optim = fluid.optimizer.Adam(lr_scheduler, - beta1, - beta2, - epsilon=epsilon, - parameter_list=dv3.parameters()) + optim = fluid.optimizer.Adam( + lr_scheduler, + beta1, + beta2, + epsilon=epsilon, + parameter_list=dv3.parameters()) gradient_clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(0.1) # generation @@ -183,8 +181,8 @@ if __name__ == "__main__": # =========================link(dataloader, paddle)========================= # CAUTION: it does not return a DataLoader - loader = fluid.io.DataLoader.from_generator(capacity=10, - return_list=True) + loader = fluid.io.DataLoader.from_generator( + capacity=10, return_list=True) loader.set_batch_generator(ljspeech_loader, places=place) # tensorboard & checkpoint preparation @@ -247,22 +245,23 @@ if __name__ == "__main__": # TODO: clean code # train state saving, the first sentence in the batch if 
global_step % snap_interval == 0:
-                    save_state(state_dir,
-                               writer,
-                               global_step,
-                               mel_input=downsampled_mel_specs,
-                               mel_output=mel_outputs,
-                               lin_input=lin_specs,
-                               lin_output=linear_outputs,
-                               alignments=alignments,
-                               win_length=win_length,
-                               hop_length=hop_length,
-                               min_level_db=min_level_db,
-                               ref_level_db=ref_level_db,
-                               power=power,
-                               n_iter=n_iter,
-                               preemphasis=preemphasis,
-                               sample_rate=sample_rate)
+                    save_state(
+                        state_dir,
+                        writer,
+                        global_step,
+                        mel_input=downsampled_mel_specs,
+                        mel_output=mel_outputs,
+                        lin_input=lin_specs,
+                        lin_output=linear_outputs,
+                        alignments=alignments,
+                        win_length=win_length,
+                        hop_length=hop_length,
+                        min_level_db=min_level_db,
+                        ref_level_db=ref_level_db,
+                        power=power,
+                        n_iter=n_iter,
+                        preemphasis=preemphasis,
+                        sample_rate=sample_rate)

                # evaluation
                if global_step % eval_interval == 0:
@@ -275,27 +274,28 @@
                        "Some have accepted this as a miracle without any physical explanation.",
                    ]
                    for idx, sent in enumerate(sentences):
-                        wav, attn = eval_model(dv3, sent,
-                                               replace_pronounciation_prob,
-                                               min_level_db, ref_level_db,
-                                               power, n_iter, win_length,
-                                               hop_length, preemphasis)
+                        wav, attn = eval_model(
+                            dv3, sent, replace_pronounciation_prob,
+                            min_level_db, ref_level_db, power, n_iter,
+                            win_length, hop_length, preemphasis)
                        wav_path = os.path.join(
                            state_dir, "waveform",
                            "eval_sample_{:09d}.wav".format(global_step))
                        sf.write(wav_path, wav, sample_rate)
-                        writer.add_audio("eval_sample_{}".format(idx),
-                                         wav,
-                                         global_step,
-                                         sample_rate=sample_rate)
+                        writer.add_audio(
+                            "eval_sample_{}".format(idx),
+                            wav,
+                            global_step,
+                            sample_rate=sample_rate)
                        attn_path = os.path.join(
                            state_dir, "alignments",
                            "eval_sample_attn_{:09d}.png".format(global_step))
                        plot_alignment(attn, attn_path)
-                        writer.add_image("eval_sample_attn{}".format(idx),
-                                         cm.viridis(attn),
-                                         global_step,
-                                         dataformats="HWC")
+                        writer.add_image(
+                            "eval_sample_attn{}".format(idx),
+                            cm.viridis(attn),
+                            global_step,
+                            dataformats="HWC")

                # save checkpoint
                if global_step % save_interval == 0:
@@ -311,4 +311,4 @@ if __name__ == "__main__":
            global_step += 1
        # epoch report
        writer.add_scalar("epoch_average_loss", epoch_loss / i, j)
-    epoch_loss = 0.
\ No newline at end of file
+    epoch_loss = 0.
diff --git a/parakeet/models/waveflow/README.md b/examples/waveflow/README.md
similarity index 97%
rename from parakeet/models/waveflow/README.md
rename to examples/waveflow/README.md
index d8072b1..184396f 100644
--- a/parakeet/models/waveflow/README.md
+++ b/examples/waveflow/README.md
@@ -16,10 +16,10 @@ Paddle fluid implementation of [WaveFlow: A Compact Flow-based Model for Raw Aud

## Usage

-There are many hyperparameters to be tuned depending on the specification of model and dataset you are working on.
+There are many hyperparameters to be tuned depending on the specifications of the model and dataset you are working on.
We provide `waveflow_ljspeech.yaml` as a hyperparameter set that works well on the LJSpeech dataset.

-Note that `train.py`, `synthesis.py`, and `benchmark.py` all accept a `--config` parameter. To ensure consistency, you should use the same config yaml file for both training, synthesizing and benchmarking. You can also overwrite these preset hyperparameters with command line by updating parameters after `--config`.
+Note that `train.py`, `synthesis.py`, and `benchmark.py` all accept a `--config` parameter. To ensure consistency, you should use the same config yaml file for training, synthesis, and benchmarking.
You can also overwrite these preset hyperparameters from the command line by passing parameters after `--config`.
For example `--config=${yaml} --batch_size=8` can overwrite the corresponding hyperparameters in the `${yaml}` config file. For more details about these hyperparameters, check `utils.add_config_options_to_parser`.

Note that you also need to specify some additional parameters for `train.py`, `synthesis.py`, and `benchmark.py`, and the details can be found in `train.add_options_to_parser`, `synthesis.add_options_to_parser`, and `benchmark.add_options_to_parser`, respectively.

@@ -50,10 +50,10 @@ python -u train.py \

#### Save and Load checkpoints

Our model will save model parameters as checkpoints in `./runs/waveflow/${ModelName}/checkpoint/` every 10000 iterations by default.
-The saved checkpoint will have the format of `step-${iteration_number}.pdparams` for model parameters and `step-${iteration_number}.pdopt` for optimizer parameters.
+The saved checkpoints have the format `step-${iteration_number}.pdparams` for model parameters and `step-${iteration_number}.pdopt` for optimizer parameters.

There are three ways to load a checkpoint and resume training (suppose you want to load a checkpoint from iteration 500000):

-1. Use `--checkpoint=./runs/waveflow/${ModelName}/checkpoint/step-500000` to provide a specific path to load. Note that you only need to provide the base name of the parameter file, which is `step-500000`, no extension name `.pdparams` or `.pdopt` is needed.
+1. Use `--checkpoint=./runs/waveflow/${ModelName}/checkpoint/step-500000` to provide a specific path to load. Note that you only need to provide the base name of the parameter file, which is `step-500000`; the `.pdparams` and `.pdopt` extensions are not needed.
2. Use `--iteration=500000`.
3. If you don't specify either `--checkpoint` or `--iteration`, the model will automatically load the latest checkpoint in `./runs/waveflow/${ModelName}/checkpoint`.
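For example, to resume training from the iteration-500000 checkpoint above, an invocation could look like the following sketch (it assumes the same `--config`, `--root`, and `--name` values as the training command shown earlier; adjust them to your setup):

```bash
python -u train.py \
    --config=./configs/waveflow_ljspeech.yaml \
    --root=./data/LJSpeech-1.1 \
    --name=${ModelName} \
    --use_gpu=true \
    --checkpoint=./runs/waveflow/${ModelName}/checkpoint/step-500000
```

Passing `--iteration=500000` instead of `--checkpoint` selects the same checkpoint by iteration number.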
@@ -108,4 +108,4 @@ python -u benchmark.py \ --config=./configs/waveflow_ljspeech.yaml \ --root=./data/LJSpeech-1.1 \ --name=${ModelName} --use_gpu=true -``` \ No newline at end of file +``` diff --git a/parakeet/models/waveflow/benchmark.py b/examples/waveflow/benchmark.py similarity index 66% rename from parakeet/models/waveflow/benchmark.py rename to examples/waveflow/benchmark.py index b2949d2..eb6b6fc 100644 --- a/parakeet/models/waveflow/benchmark.py +++ b/examples/waveflow/benchmark.py @@ -2,35 +2,47 @@ import os import random from pprint import pprint -import jsonargparse +import argparse import numpy as np import paddle.fluid.dygraph as dg from paddle import fluid import utils -from waveflow import WaveFlow +from parakeet.models.waveflow import WaveFlow def add_options_to_parser(parser): - parser.add_argument('--model', type=str, default='waveflow', + parser.add_argument( + '--model', + type=str, + default='waveflow', help="general name of the model") - parser.add_argument('--name', type=str, - help="specific name of the training model") - parser.add_argument('--root', type=str, - help="root path of the LJSpeech dataset") + parser.add_argument( + '--name', type=str, help="specific name of the training model") + parser.add_argument( + '--root', type=str, help="root path of the LJSpeech dataset") - parser.add_argument('--use_gpu', type=bool, default=True, + parser.add_argument( + '--use_gpu', + type=bool, + default=True, help="option to use gpu training") - parser.add_argument('--iteration', type=int, default=None, + parser.add_argument( + '--iteration', + type=int, + default=None, help=("which iteration of checkpoint to load, " "default to load the latest checkpoint")) - parser.add_argument('--checkpoint', type=str, default=None, + parser.add_argument( + '--checkpoint', + type=str, + default=None, help="path of the checkpoint to load") def benchmark(config): - pprint(jsonargparse.namespace_to_dict(config)) + pprint(vars(config)) # Get checkpoint directory path. run_dir = os.path.join("runs", config.model, config.name) @@ -47,7 +59,7 @@ def benchmark(config): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed print("Random Seed: ", seed) - + # Build model. model = WaveFlow(config, checkpoint_dir) model.build(training=False) @@ -58,9 +70,8 @@ def benchmark(config): if __name__ == "__main__": # Create parser. - parser = jsonargparse.ArgumentParser( - description="Synthesize audio using WaveNet model", - formatter_class='default_argparse') + parser = argparse.ArgumentParser( + description="Synthesize audio using WaveNet model") add_options_to_parser(parser) utils.add_config_options_to_parser(parser) @@ -68,4 +79,5 @@ if __name__ == "__main__": # For conflicting updates to the same field, # the preceding update will be overwritten by the following one. 
config = parser.parse_args() + config = utils.add_yaml_config(config) benchmark(config) diff --git a/parakeet/models/waveflow/configs/waveflow_ljspeech.yaml b/examples/waveflow/configs/waveflow_ljspeech.yaml similarity index 100% rename from parakeet/models/waveflow/configs/waveflow_ljspeech.yaml rename to examples/waveflow/configs/waveflow_ljspeech.yaml diff --git a/parakeet/models/waveflow/synthesis.py b/examples/waveflow/synthesis.py similarity index 67% rename from parakeet/models/waveflow/synthesis.py rename to examples/waveflow/synthesis.py index e42e170..1e3fb9e 100644 --- a/parakeet/models/waveflow/synthesis.py +++ b/examples/waveflow/synthesis.py @@ -2,40 +2,58 @@ import os import random from pprint import pprint -import jsonargparse +import argparse import numpy as np import paddle.fluid.dygraph as dg from paddle import fluid import utils -from waveflow import WaveFlow +from parakeet.models.waveflow import WaveFlow def add_options_to_parser(parser): - parser.add_argument('--model', type=str, default='waveflow', + parser.add_argument( + '--model', + type=str, + default='waveflow', help="general name of the model") - parser.add_argument('--name', type=str, - help="specific name of the training model") - parser.add_argument('--root', type=str, - help="root path of the LJSpeech dataset") + parser.add_argument( + '--name', type=str, help="specific name of the training model") + parser.add_argument( + '--root', type=str, help="root path of the LJSpeech dataset") - parser.add_argument('--use_gpu', type=bool, default=True, + parser.add_argument( + '--use_gpu', + type=bool, + default=True, help="option to use gpu training") - parser.add_argument('--iteration', type=int, default=None, + parser.add_argument( + '--iteration', + type=int, + default=None, help=("which iteration of checkpoint to load, " "default to load the latest checkpoint")) - parser.add_argument('--checkpoint', type=str, default=None, + parser.add_argument( + '--checkpoint', + type=str, + default=None, help="path of the checkpoint to load") - parser.add_argument('--output', type=str, default="./syn_audios", + parser.add_argument( + '--output', + type=str, + default="./syn_audios", help="path to write synthesized audio files") - parser.add_argument('--sample', type=int, default=None, + parser.add_argument( + '--sample', + type=int, + default=None, help="which of the valid samples to synthesize audio") def synthesize(config): - pprint(jsonargparse.namespace_to_dict(config)) + pprint(vars(config)) # Get checkpoint directory path. run_dir = os.path.join("runs", config.model, config.name) @@ -52,7 +70,7 @@ def synthesize(config): fluid.default_startup_program().random_seed = seed fluid.default_main_program().random_seed = seed print("Random Seed: ", seed) - + # Build model. model = WaveFlow(config, checkpoint_dir) model.build(training=False) @@ -72,9 +90,8 @@ def synthesize(config): if __name__ == "__main__": # Create parser. - parser = jsonargparse.ArgumentParser( - description="Synthesize audio using WaveNet model", - formatter_class='default_argparse') + parser = argparse.ArgumentParser( + description="Synthesize audio using WaveNet model") add_options_to_parser(parser) utils.add_config_options_to_parser(parser) @@ -82,4 +99,5 @@ if __name__ == "__main__": # For conflicting updates to the same field, # the preceding update will be overwritten by the following one. 
config = parser.parse_args() + config = utils.add_yaml_config(config) synthesize(config) diff --git a/parakeet/models/waveflow/train.py b/examples/waveflow/train.py similarity index 74% rename from parakeet/models/waveflow/train.py rename to examples/waveflow/train.py index 89b787a..e41597e 100644 --- a/parakeet/models/waveflow/train.py +++ b/examples/waveflow/train.py @@ -4,34 +4,48 @@ import subprocess import time from pprint import pprint -import jsonargparse +import argparse import numpy as np import paddle.fluid.dygraph as dg from paddle import fluid from tensorboardX import SummaryWriter -import slurm import utils -from waveflow import WaveFlow +from parakeet.models.waveflow import WaveFlow def add_options_to_parser(parser): - parser.add_argument('--model', type=str, default='waveflow', + parser.add_argument( + '--model', + type=str, + default='waveflow', help="general name of the model") - parser.add_argument('--name', type=str, - help="specific name of the training model") - parser.add_argument('--root', type=str, - help="root path of the LJSpeech dataset") + parser.add_argument( + '--name', type=str, help="specific name of the training model") + parser.add_argument( + '--root', type=str, help="root path of the LJSpeech dataset") - parser.add_argument('--parallel', type=bool, default=True, + parser.add_argument( + '--parallel', + type=utils.str2bool, + default=True, help="option to use data parallel training") - parser.add_argument('--use_gpu', type=bool, default=True, + parser.add_argument( + '--use_gpu', + type=utils.str2bool, + default=True, help="option to use gpu training") - parser.add_argument('--iteration', type=int, default=None, + parser.add_argument( + '--iteration', + type=int, + default=None, help=("which iteration of checkpoint to load, " "default to load the latest checkpoint")) - parser.add_argument('--checkpoint', type=str, default=None, + parser.add_argument( + '--checkpoint', + type=str, + default=None, help="path of the checkpoint to load") @@ -45,12 +59,13 @@ def train(config): if rank == 0: # Print the whole config setting. - pprint(jsonargparse.namespace_to_dict(config)) + pprint(vars(config)) # Make checkpoint directory. run_dir = os.path.join("runs", config.model, config.name) checkpoint_dir = os.path.join(run_dir, "checkpoint") - os.makedirs(checkpoint_dir, exist_ok=True) + if not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir) # Create tensorboard logger. tb = SummaryWriter(os.path.join(run_dir, "logs")) \ @@ -102,8 +117,8 @@ def train(config): if __name__ == "__main__": # Create parser. - parser = jsonargparse.ArgumentParser(description="Train WaveFlow model", - formatter_class='default_argparse') + parser = argparse.ArgumentParser(description="Train WaveFlow model") + #formatter_class='default_argparse') add_options_to_parser(parser) utils.add_config_options_to_parser(parser) @@ -111,4 +126,5 @@ if __name__ == "__main__": # For conflicting updates to the same field, # the preceding update will be overwritten by the following one. 
config = parser.parse_args() - train(config) + config = utils.add_yaml_config(config) + train(config) diff --git a/parakeet/models/waveflow/utils.py b/examples/waveflow/utils.py similarity index 55% rename from parakeet/models/waveflow/utils.py rename to examples/waveflow/utils.py index 3baeb60..c088b1d 100644 --- a/parakeet/models/waveflow/utils.py +++ b/examples/waveflow/utils.py @@ -2,59 +2,96 @@ import itertools import os import time -import jsonargparse +import argparse +import ruamel.yaml import numpy as np import paddle.fluid.dygraph as dg +def str2bool(v): + return v.lower() in ("true", "t", "1") + + def add_config_options_to_parser(parser): - parser.add_argument('--valid_size', type=int, - help="size of the valid dataset") - parser.add_argument('--segment_length', type=int, + parser.add_argument( + '--valid_size', type=int, help="size of the valid dataset") + parser.add_argument( + '--segment_length', + type=int, help="the length of audio clip for training") - parser.add_argument('--sample_rate', type=int, - help="sampling rate of audio data file") - parser.add_argument('--fft_window_shift', type=int, + parser.add_argument( + '--sample_rate', type=int, help="sampling rate of audio data file") + parser.add_argument( + '--fft_window_shift', + type=int, help="the shift of fft window for each frame") - parser.add_argument('--fft_window_size', type=int, + parser.add_argument( + '--fft_window_size', + type=int, help="the size of fft window for each frame") - parser.add_argument('--fft_size', type=int, - help="the size of fft filter on each frame") - parser.add_argument('--mel_bands', type=int, + parser.add_argument( + '--fft_size', type=int, help="the size of fft filter on each frame") + parser.add_argument( + '--mel_bands', + type=int, help="the number of mel bands when calculating mel spectrograms") - parser.add_argument('--mel_fmin', type=float, + parser.add_argument( + '--mel_fmin', + type=float, help="lowest frequency in calculating mel spectrograms") - parser.add_argument('--mel_fmax', type=float, + parser.add_argument( + '--mel_fmax', + type=float, help="highest frequency in calculating mel spectrograms") - parser.add_argument('--seed', type=int, - help="seed of random initialization for the model") + parser.add_argument( + '--seed', type=int, help="seed of random initialization for the model") parser.add_argument('--learning_rate', type=float) - parser.add_argument('--batch_size', type=int, - help="batch size for training") - parser.add_argument('--test_every', type=int, - help="test interval during training") - parser.add_argument('--save_every', type=int, + parser.add_argument( + '--batch_size', type=int, help="batch size for training") + parser.add_argument( + '--test_every', type=int, help="test interval during training") + parser.add_argument( + '--save_every', + type=int, help="checkpointing interval during training") - parser.add_argument('--max_iterations', type=int, - help="maximum training iterations") + parser.add_argument( + '--max_iterations', type=int, help="maximum training iterations") - parser.add_argument('--sigma', type=float, + parser.add_argument( + '--sigma', + type=float, help="standard deviation of the latent Gaussian variable") - parser.add_argument('--n_flows', type=int, - help="number of flows") - parser.add_argument('--n_group', type=int, + parser.add_argument('--n_flows', type=int, help="number of flows") + parser.add_argument( + '--n_group', + type=int, help="number of adjacent audio samples to squeeze into one column") - 
parser.add_argument('--n_layers', type=int, + parser.add_argument( + '--n_layers', + type=int, help="number of conv2d layer in one wavenet-like flow architecture") - parser.add_argument('--n_channels', type=int, - help="number of residual channels in flow") - parser.add_argument('--kernel_h', type=int, + parser.add_argument( + '--n_channels', type=int, help="number of residual channels in flow") + parser.add_argument( + '--kernel_h', + type=int, help="height of the kernel in the conv2d layer") - parser.add_argument('--kernel_w', type=int, - help="width of the kernel in the conv2d layer") + parser.add_argument( + '--kernel_w', type=int, help="width of the kernel in the conv2d layer") - parser.add_argument('--config', action=jsonargparse.ActionConfigFile) + parser.add_argument('--config', type=str, help="Path to the config file.") + + +def add_yaml_config(config): + with open(config.config, 'rt') as f: + yaml_cfg = ruamel.yaml.safe_load(f) + cfg_vars = vars(config) + for k, v in yaml_cfg.items(): + if k in cfg_vars and cfg_vars[k] is not None: + continue + cfg_vars[k] = v + return config def load_latest_checkpoint(checkpoint_dir, rank=0): @@ -84,8 +121,12 @@ def save_latest_checkpoint(checkpoint_dir, iteration): handle.write("model_checkpoint_path: step-{}".format(iteration)) -def load_parameters(checkpoint_dir, rank, model, optimizer=None, - iteration=None, file_path=None): +def load_parameters(checkpoint_dir, + rank, + model, + optimizer=None, + iteration=None, + file_path=None): if file_path is None: if iteration is None: iteration = load_latest_checkpoint(checkpoint_dir, rank) @@ -99,7 +140,7 @@ def load_parameters(checkpoint_dir, rank, model, optimizer=None, if optimizer and optimizer_dict: optimizer.set_dict(optimizer_dict) print("[checkpoint] Rank {}: loaded optimizer state from {}".format( - rank, file_path)) + rank, file_path)) def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None): diff --git a/parakeet/datasets/ljspeech.py b/parakeet/datasets/ljspeech.py index 3ab73c7..7d4dffe 100644 --- a/parakeet/datasets/ljspeech.py +++ b/parakeet/datasets/ljspeech.py @@ -5,24 +5,29 @@ import librosa from .. import g2p from ..data.sampler import SequentialSampler, RandomSampler, BatchSampler -from ..data.dataset import Dataset +from ..data.dataset import DatasetMixin from ..data.datacargo import DataCargo from ..data.batch import TextIDBatcher, SpecBatcher -class LJSpeech(Dataset): +class LJSpeech(DatasetMixin): def __init__(self, root): super(LJSpeech, self).__init__() - assert isinstance(root, (str, Path)), "root should be a string or Path object" + assert isinstance(root, ( + str, Path)), "root should be a string or Path object" self.root = root if isinstance(root, Path) else Path(root) - self.metadata = self._prepare_metadata() - + self.metadata = self._prepare_metadata() + def _prepare_metadata(self): csv_path = self.root.joinpath("metadata.csv") - metadata = pd.read_csv(csv_path, sep="|", header=None, quoting=3, - names=["fname", "raw_text", "normalized_text"]) + metadata = pd.read_csv( + csv_path, + sep="|", + header=None, + quoting=3, + names=["fname", "raw_text", "normalized_text"]) return metadata - + def _get_example(self, metadatum): """All the code for generating an Example from a metadatum. If you want a different preprocessing pipeline, you can override this method. @@ -30,28 +35,32 @@ class LJSpeech(Dataset): In this case, you'd better pass a composed transform and pass it to the init method. 
""" - + fname, raw_text, normalized_text = metadatum wav_path = self.root.joinpath("wavs", fname + ".wav") - + # load -> trim -> preemphasis -> stft -> magnitude -> mel_scale -> logscale -> normalize - wav, sample_rate = librosa.load(wav_path, sr=None) # we would rather use functor to hold its parameters + wav, sample_rate = librosa.load( + wav_path, + sr=None) # we would rather use functor to hold its parameters trimed, _ = librosa.effects.trim(wav) preemphasized = librosa.effects.preemphasis(trimed) D = librosa.stft(preemphasized) mag, phase = librosa.magphase(D) mel = librosa.feature.melspectrogram(S=mag) - + mag = librosa.amplitude_to_db(S=mag) mel = librosa.amplitude_to_db(S=mel) - + ref_db = 20 max_db = 100 mel = np.clip((mel - ref_db + max_db) / max_db, 1e-8, 1) mel = np.clip((mag - ref_db + max_db) / max_db, 1e-8, 1) - phonemes = np.array(g2p.en.text_to_sequence(normalized_text), dtype=np.int64) - return (mag, mel, phonemes) # maybe we need to implement it as a map in the future + phonemes = np.array( + g2p.en.text_to_sequence(normalized_text), dtype=np.int64) + return (mag, mel, phonemes + ) # maybe we need to implement it as a map in the future def _batch_examples(self, minibatch): mag_batch = [] @@ -71,12 +80,10 @@ class LJSpeech(Dataset): metadatum = self.metadata.iloc[index] example = self._get_example(metadatum) return example - + def __iter__(self): for i in range(len(self)): yield self[i] - + def __len__(self): return len(self.metadata) - - diff --git a/parakeet/models/waveflow/__init__.py b/parakeet/models/waveflow/__init__.py new file mode 100644 index 0000000..20475cd --- /dev/null +++ b/parakeet/models/waveflow/__init__.py @@ -0,0 +1 @@ +from parakeet.models.waveflow.waveflow import WaveFlow diff --git a/parakeet/models/waveflow/data.py b/parakeet/models/waveflow/data.py index d89fb7b..b5ad2c9 100644 --- a/parakeet/models/waveflow/data.py +++ b/parakeet/models/waveflow/data.py @@ -5,10 +5,9 @@ import numpy as np from paddle import fluid from parakeet.datasets import ljspeech -from parakeet.data import dataset -from parakeet.data.batch import SpecBatcher, WavBatcher -from parakeet.data.datacargo import DataCargo -from parakeet.data.sampler import DistributedSampler, BatchSampler +from parakeet.data import SpecBatcher, WavBatcher +from parakeet.data import DataCargo, DatasetMixin +from parakeet.data import DistributedSampler, BatchSampler from scipy.io.wavfile import read @@ -27,7 +26,7 @@ class Dataset(ljspeech.LJSpeech): return audio -class Subset(dataset.Dataset): +class Subset(DatasetMixin): def __init__(self, dataset, indices, valid): self.dataset = dataset self.indices = indices @@ -36,18 +35,18 @@ class Subset(dataset.Dataset): def get_mel(self, audio): spectrogram = librosa.core.stft( - audio, n_fft=self.config.fft_size, + audio, + n_fft=self.config.fft_size, hop_length=self.config.fft_window_shift, win_length=self.config.fft_window_size) - spectrogram_magnitude = np.abs(spectrogram) + spectrogram_magnitude = np.abs(spectrogram) # mel_filter_bank shape: [n_mels, 1 + n_fft/2] - mel_filter_bank = librosa.filters.mel( - sr=self.config.sample_rate, - n_fft=self.config.fft_size, - n_mels=self.config.mel_bands, - fmin=self.config.mel_fmin, - fmax=self.config.mel_fmax) + mel_filter_bank = librosa.filters.mel(sr=self.config.sample_rate, + n_fft=self.config.fft_size, + n_mels=self.config.mel_bands, + fmin=self.config.mel_fmin, + fmax=self.config.mel_fmax) # mel shape: [n_mels, num_frames] mel = np.dot(mel_filter_bank, spectrogram_magnitude) @@ -67,13 +66,14 @@ class 
Subset(dataset.Dataset): pass else: # audio shape: [len] - if audio.shape[0] >= segment_length: + if audio.shape[0] >= segment_length: max_audio_start = audio.shape[0] - segment_length audio_start = random.randint(0, max_audio_start) - audio = audio[audio_start : (audio_start + segment_length)] + audio = audio[audio_start:(audio_start + segment_length)] else: audio = np.pad(audio, (0, segment_length - audio.shape[0]), - mode='constant', constant_values=0) + mode='constant', + constant_values=0) # Normalize audio to the [-1, 1] range. audio = audio.astype(np.float32) / 32768.0 @@ -109,17 +109,17 @@ class LJSpeech: # Train dataset. trainset = Subset(ds, train_indices, valid=False) - sampler = DistributedSampler(len(trainset), nranks, rank) + sampler = DistributedSampler(len(trainset), nranks, rank) total_bs = config.batch_size assert total_bs % nranks == 0 - train_sampler = BatchSampler(sampler, total_bs // nranks, - drop_last=True) + train_sampler = BatchSampler( + sampler, total_bs // nranks, drop_last=True) trainloader = DataCargo(trainset, batch_sampler=train_sampler) trainreader = fluid.io.PyReader(capacity=50, return_list=True) trainreader.decorate_batch_generator(trainloader, place) self.trainloader = (data for _ in iter(int, 1) - for data in trainreader()) + for data in trainreader()) # Valid dataset. validset = Subset(ds, valid_indices, valid=True) @@ -127,5 +127,5 @@ class LJSpeech: validloader = DataCargo(validset, batch_size=1, shuffle=False) validreader = fluid.io.PyReader(capacity=20, return_list=True) - validreader.decorate_batch_generator(validloader, place) + validreader.decorate_batch_generator(validloader, place) self.validloader = validreader diff --git a/parakeet/models/waveflow/waveflow.py b/parakeet/models/waveflow/waveflow.py index 4935d42..569086e 100644 --- a/parakeet/models/waveflow/waveflow.py +++ b/parakeet/models/waveflow/waveflow.py @@ -8,13 +8,18 @@ from paddle import fluid from scipy.io.wavfile import write import utils -from data import LJSpeech -from waveflow_modules import WaveFlowLoss, WaveFlowModule +from .data import LJSpeech +from .waveflow_modules import WaveFlowLoss, WaveFlowModule class WaveFlow(): - def __init__(self, config, checkpoint_dir, parallel=False, rank=0, - nranks=1, tb_logger=None): + def __init__(self, + config, + checkpoint_dir, + parallel=False, + rank=0, + nranks=1, + tb_logger=None): self.config = config self.checkpoint_dir = checkpoint_dir self.parallel = parallel @@ -24,12 +29,12 @@ class WaveFlow(): def build(self, training=True): config = self.config - dataset = LJSpeech(config, self.nranks, self.rank) + dataset = LJSpeech(config, self.nranks, self.rank) self.trainloader = dataset.trainloader self.validloader = dataset.validloader - waveflow = WaveFlowModule("waveflow", config) - + waveflow = WaveFlowModule(config) + # Dry run once to create and initalize all necessary parameters. audio = dg.to_variable(np.random.randn(1, 16000).astype(np.float32)) mel = dg.to_variable( @@ -38,29 +43,36 @@ class WaveFlow(): if training: optimizer = fluid.optimizer.AdamOptimizer( - learning_rate=config.learning_rate) - + learning_rate=config.learning_rate, + parameter_list=waveflow.parameters()) + # Load parameters. 
- utils.load_parameters(self.checkpoint_dir, self.rank, - waveflow, optimizer, - iteration=config.iteration, - file_path=config.checkpoint) + utils.load_parameters( + self.checkpoint_dir, + self.rank, + waveflow, + optimizer, + iteration=config.iteration, + file_path=config.checkpoint) print("Rank {}: checkpoint loaded.".format(self.rank)) - + # Data parallelism. if self.parallel: strategy = dg.parallel.prepare_context() waveflow = dg.parallel.DataParallel(waveflow, strategy) - + self.waveflow = waveflow self.optimizer = optimizer self.criterion = WaveFlowLoss(config.sigma) else: # Load parameters. - utils.load_parameters(self.checkpoint_dir, self.rank, waveflow, - iteration=config.iteration, - file_path=config.checkpoint) + utils.load_parameters( + self.checkpoint_dir, + self.rank, + waveflow, + iteration=config.iteration, + file_path=config.checkpoint) print("Rank {}: checkpoint loaded.".format(self.rank)) self.waveflow = waveflow @@ -83,7 +95,8 @@ class WaveFlow(): else: loss.backward() - self.optimizer.minimize(loss, parameter_list=self.waveflow.parameters()) + self.optimizer.minimize( + loss, parameter_list=self.waveflow.parameters()) self.waveflow.clear_gradients() graph_time = time.time() @@ -139,7 +152,8 @@ class WaveFlow(): sample = config.sample output = "{}/{}/iter-{}".format(config.output, config.name, iteration) - os.makedirs(output, exist_ok=True) + if not os.path.exists(output): + os.makedirs(output) mels_list = [mels for _, mels in self.validloader()] if sample is not None: @@ -148,16 +162,16 @@ class WaveFlow(): for sample, mel in enumerate(mels_list): filename = "{}/valid_{}.wav".format(output, sample) print("Synthesize sample {}, save as {}".format(sample, filename)) - + start_time = time.time() audio = self.waveflow.synthesize(mel, sigma=self.config.sigma) syn_time = time.time() - start_time - + audio = audio[0] audio_time = audio.shape[0] / self.config.sample_rate - print("audio time {:.4f}, synthesis time {:.4f}".format( - audio_time, syn_time)) - + print("audio time {:.4f}, synthesis time {:.4f}".format(audio_time, + syn_time)) + # Denormalize audio from [-1, 1] to [-32768, 32768] int16 range. 
audio = audio.numpy() * 32768.0 audio = audio.astype('int16') @@ -180,8 +194,8 @@ class WaveFlow(): syn_time = time.time() - start_time audio_time = audio.shape[1] * batch_size / self.config.sample_rate - print("audio time {:.4f}, synthesis time {:.4f}".format( - audio_time, syn_time)) + print("audio time {:.4f}, synthesis time {:.4f}".format(audio_time, + syn_time)) print("{} X real-time".format(audio_time / syn_time)) def save(self, iteration): diff --git a/parakeet/models/waveflow/waveflow_modules.py b/parakeet/models/waveflow/waveflow_modules.py index 39cb598..c981fe7 100644 --- a/parakeet/models/waveflow/waveflow_modules.py +++ b/parakeet/models/waveflow/waveflow_modules.py @@ -3,26 +3,27 @@ import itertools import numpy as np import paddle.fluid.dygraph as dg from paddle import fluid -from parakeet.modules import conv, modules, weight_norm +from parakeet.modules import weight_norm -def set_param_attr(layer, c_in=1): - if isinstance(layer, (weight_norm.Conv2DTranspose, weight_norm.Conv2D)): - k = np.sqrt(1.0 / (c_in * np.prod(layer._filter_size))) +def get_param_attr(layer_type, filter_size, c_in=1): + if layer_type == "weight_norm": + k = np.sqrt(1.0 / (c_in * np.prod(filter_size))) weight_init = fluid.initializer.UniformInitializer(low=-k, high=k) bias_init = fluid.initializer.UniformInitializer(low=-k, high=k) - elif isinstance(layer, dg.Conv2D): + elif layer_type == "common": weight_init = fluid.initializer.ConstantInitializer(0.0) bias_init = fluid.initializer.ConstantInitializer(0.0) else: raise TypeError("Unsupported layer type.") - layer._param_attr = fluid.ParamAttr(initializer=weight_init) - layer._bias_attr = fluid.ParamAttr(initializer=bias_init) + param_attr = fluid.ParamAttr(initializer=weight_init) + bias_attr = fluid.ParamAttr(initializer=bias_init) + return param_attr, bias_attr def unfold(x, n_group): - length = x.shape[-1] + length = x.shape[-1] new_shape = x.shape[:-1] + [length // n_group, n_group] return fluid.layers.reshape(x, new_shape) @@ -48,20 +49,23 @@ class WaveFlowLoss: class Conditioner(dg.Layer): - def __init__(self, name_scope): - super(Conditioner, self).__init__(name_scope) + def __init__(self): + super(Conditioner, self).__init__() upsample_factors = [16, 16] - + self.upsample_conv2d = [] for s in upsample_factors: in_channel = 1 - conv_trans2d = modules.Conv2DTranspose( - self.full_name(), + param_attr, bias_attr = get_param_attr( + "weight_norm", (3, 2 * s), c_in=in_channel) + conv_trans2d = weight_norm.Conv2DTranspose( + num_channels=in_channel, num_filters=1, filter_size=(3, 2 * s), padding=(1, s // 2), - stride=(1, s)) - set_param_attr(conv_trans2d, c_in=in_channel) + stride=(1, s), + param_attr=param_attr, + bias_attr=bias_attr) self.upsample_conv2d.append(conv_trans2d) for i, layer in enumerate(self.upsample_conv2d): @@ -86,8 +90,8 @@ class Conditioner(dg.Layer): class Flow(dg.Layer): - def __init__(self, name_scope, config): - super(Flow, self).__init__(name_scope) + def __init__(self, config): + super(Flow, self).__init__() self.n_layers = config.n_layers self.n_channels = config.n_channels self.kernel_h = config.kernel_h @@ -95,27 +99,34 @@ class Flow(dg.Layer): # Transform audio: [batch, 1, n_group, time/n_group] # => [batch, n_channels, n_group, time/n_group] + param_attr, bias_attr = get_param_attr("weight_norm", (1, 1), c_in=1) self.start = weight_norm.Conv2D( - self.full_name(), + num_channels=1, num_filters=self.n_channels, - filter_size=(1, 1)) - set_param_attr(self.start, c_in=1) + filter_size=(1, 1), + param_attr=param_attr, + 
bias_attr=bias_attr) # Initializing last layer to 0 makes the affine coupling layers # do nothing at first. This helps with training stability # output shape: [batch, 2, n_group, time/n_group] + param_attr, bias_attr = get_param_attr( + "common", (1, 1), c_in=self.n_channels) self.end = dg.Conv2D( - self.full_name(), + num_channels=self.n_channels, num_filters=2, - filter_size=(1, 1)) - set_param_attr(self.end) + filter_size=(1, 1), + param_attr=param_attr, + bias_attr=bias_attr) # receiptive fileds: (kernel - 1) * sum(dilations) + 1 >= squeeze - dilation_dict = {8: [1, 1, 1, 1, 1, 1, 1, 1], - 16: [1, 1, 1, 1, 1, 1, 1, 1], - 32: [1, 2, 4, 1, 2, 4, 1, 2], - 64: [1, 2, 4, 8, 16, 1, 2, 4], - 128: [1, 2, 4, 8, 16, 32, 64, 1]} + dilation_dict = { + 8: [1, 1, 1, 1, 1, 1, 1, 1], + 16: [1, 1, 1, 1, 1, 1, 1, 1], + 32: [1, 2, 4, 1, 2, 4, 1, 2], + 64: [1, 2, 4, 8, 16, 1, 2, 4], + 128: [1, 2, 4, 8, 16, 32, 64, 1] + } self.dilation_h_list = dilation_dict[config.n_group] self.in_layers = [] @@ -123,32 +134,42 @@ class Flow(dg.Layer): self.res_skip_layers = [] for i in range(self.n_layers): dilation_h = self.dilation_h_list[i] - dilation_w = 2 ** i + dilation_w = 2**i + param_attr, bias_attr = get_param_attr( + "weight_norm", (self.kernel_h, self.kernel_w), + c_in=self.n_channels) in_layer = weight_norm.Conv2D( - self.full_name(), + num_channels=self.n_channels, num_filters=2 * self.n_channels, filter_size=(self.kernel_h, self.kernel_w), - dilation=(dilation_h, dilation_w)) - set_param_attr(in_layer, c_in=self.n_channels) + dilation=(dilation_h, dilation_w), + param_attr=param_attr, + bias_attr=bias_attr) self.in_layers.append(in_layer) + param_attr, bias_attr = get_param_attr( + "weight_norm", (1, 1), c_in=config.mel_bands) cond_layer = weight_norm.Conv2D( - self.full_name(), + num_channels=config.mel_bands, num_filters=2 * self.n_channels, - filter_size=(1, 1)) - set_param_attr(cond_layer, c_in=config.mel_bands) + filter_size=(1, 1), + param_attr=param_attr, + bias_attr=bias_attr) self.cond_layers.append(cond_layer) if i < self.n_layers - 1: res_skip_channels = 2 * self.n_channels else: res_skip_channels = self.n_channels + param_attr, bias_attr = get_param_attr( + "weight_norm", (1, 1), c_in=self.n_channels) res_skip_layer = weight_norm.Conv2D( - self.full_name(), + num_channels=self.n_channels, num_filters=res_skip_channels, - filter_size=(1, 1)) - set_param_attr(res_skip_layer, c_in=self.n_channels) + filter_size=(1, 1), + param_attr=param_attr, + bias_attr=bias_attr) self.res_skip_layers.append(res_skip_layer) self.add_sublayer("in_layer_{}".format(i), in_layer) @@ -162,14 +183,14 @@ class Flow(dg.Layer): for i in range(self.n_layers): dilation_h = self.dilation_h_list[i] - dilation_w = 2 ** i + dilation_w = 2**i # Pad height dim (n_group): causal convolution # Pad width dim (time): dialated non-causal convolution pad_top, pad_bottom = (self.kernel_h - 1) * dilation_h, 0 - pad_left = pad_right = int((self.kernel_w-1) * dilation_w / 2) - audio_pad = fluid.layers.pad2d(audio, - paddings=[pad_top, pad_bottom, pad_left, pad_right]) + pad_left = pad_right = int((self.kernel_w - 1) * dilation_w / 2) + audio_pad = fluid.layers.pad2d( + audio, paddings=[pad_top, pad_bottom, pad_left, pad_right]) hidden = self.in_layers[i](audio_pad) cond_hidden = self.cond_layers[i](mel) @@ -196,7 +217,7 @@ class Flow(dg.Layer): for i in range(self.n_layers): dilation_h = self.dilation_h_list[i] - dilation_w = 2 ** i + dilation_w = 2**i state_size = dilation_h * (self.kernel_h - 1) queue = queues[i] @@ -206,7 +227,7 @@ 
class Flow(dg.Layer): queue.append(fluid.layers.zeros_like(audio)) state = queue[0:state_size] - state = fluid.layers.concat([*state, audio], axis=2) + state = fluid.layers.concat(state + [audio], axis=2) queue.pop(0) queue.append(audio) @@ -214,10 +235,10 @@ class Flow(dg.Layer): # Pad height dim (n_group): causal convolution # Pad width dim (time): dialated non-causal convolution pad_top, pad_bottom = 0, 0 - pad_left = int((self.kernel_w-1) * dilation_w / 2) - pad_right = int((self.kernel_w-1) * dilation_w / 2) - state = fluid.layers.pad2d(state, - paddings=[pad_top, pad_bottom, pad_left, pad_right]) + pad_left = int((self.kernel_w - 1) * dilation_w / 2) + pad_right = int((self.kernel_w - 1) * dilation_w / 2) + state = fluid.layers.pad2d( + state, paddings=[pad_top, pad_bottom, pad_left, pad_right]) hidden = self.in_layers[i](state) cond_hidden = self.cond_layers[i](mel) @@ -241,20 +262,20 @@ class Flow(dg.Layer): class WaveFlowModule(dg.Layer): - def __init__(self, name_scope, config): - super(WaveFlowModule, self).__init__(name_scope) + def __init__(self, config): + super(WaveFlowModule, self).__init__() self.n_flows = config.n_flows self.n_group = config.n_group self.n_layers = config.n_layers assert self.n_group % 2 == 0 assert self.n_flows % 2 == 0 - self.conditioner = Conditioner(self.full_name()) + self.conditioner = Conditioner() self.flows = [] for i in range(self.n_flows): - flow = Flow(self.full_name(), config) + flow = Flow(config) self.flows.append(flow) - self.add_sublayer("flow_{}".format(i), flow) + self.add_sublayer("flow_{}".format(i), flow) self.perms = [] half = self.n_group // 2 @@ -266,7 +287,7 @@ class WaveFlowModule(dg.Layer): perm[:half] = reversed(perm[:half]) perm[half:] = reversed(perm[half:]) self.perms.append(perm) - + def forward(self, audio, mel): mel = self.conditioner(mel) assert mel.shape[2] >= audio.shape[1] @@ -277,14 +298,13 @@ class WaveFlowModule(dg.Layer): audio = audio[:, :pruned_len] if mel.shape[2] > pruned_len: mel = mel[:, :, :pruned_len] - + # From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group] mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2]) # From [bs, time] to [bs, n_group, time/n_group] audio = fluid.layers.transpose(unfold(audio, self.n_group), [0, 2, 1]) # [bs, 1, n_group, time/n_group] audio = fluid.layers.unsqueeze(audio, 1) - log_s_list = [] for i in range(self.n_flows): inputs = audio[:, :, :-1, :] @@ -305,7 +325,6 @@ class WaveFlowModule(dg.Layer): mel = fluid.layers.stack(mel_slices, axis=2) z = fluid.layers.squeeze(audio, [1]) - return z, log_s_list def synthesize(self, mel, sigma=1.0): @@ -331,7 +350,7 @@ class WaveFlowModule(dg.Layer): for h in range(1, self.n_group): inputs = audio_h - conds = mel[:, :, h:(h+1), :] + conds = mel[:, :, h:(h + 1), :] outputs = self.flows[i].infer(inputs, conds, queues) log_s = outputs[:, 0:1, :, :] diff --git a/parakeet/modules/weight_norm.py b/parakeet/modules/weight_norm.py index 8db21c0..992f099 100644 --- a/parakeet/modules/weight_norm.py +++ b/parakeet/modules/weight_norm.py @@ -40,8 +40,8 @@ def norm_except(param, dim, power): def compute_weight(v, g, dim, power): assert len(g.shape) == 1, "magnitude should be a vector" - v_normalized = F.elementwise_div(v, (norm_except(v, dim, power) + 1e-12), - axis=dim) + v_normalized = F.elementwise_div( + v, (norm_except(v, dim, power) + 1e-12), axis=dim) weight = F.elementwise_mul(v_normalized, g, axis=dim) return weight @@ -63,20 +63,21 @@ class WeightNormWrapper(dg.Layer): original_weight = getattr(layer, 
param_name) self.add_parameter( w_v, - self.create_parameter(shape=original_weight.shape, - dtype=original_weight.dtype)) + self.create_parameter( + shape=original_weight.shape, dtype=original_weight.dtype)) F.assign(original_weight, getattr(self, w_v)) delattr(layer, param_name) temp = norm_except(getattr(self, w_v), self.dim, self.power) self.add_parameter( - w_g, self.create_parameter(shape=temp.shape, dtype=temp.dtype)) + w_g, self.create_parameter( + shape=temp.shape, dtype=temp.dtype)) F.assign(temp, getattr(self, w_g)) # also set this when setting up - setattr( - self.layer, self.param_name, - compute_weight(getattr(self, w_v), getattr(self, w_g), self.dim, - self.power)) + setattr(self.layer, self.param_name, + compute_weight( + getattr(self, w_v), + getattr(self, w_g), self.dim, self.power)) self.weigth_norm_applied = True @@ -84,10 +85,10 @@ class WeightNormWrapper(dg.Layer): def hook(self): w_v = self.param_name + "_v" w_g = self.param_name + "_g" - setattr( - self.layer, self.param_name, - compute_weight(getattr(self, w_v), getattr(self, w_g), self.dim, - self.power)) + setattr(self.layer, self.param_name, + compute_weight( + getattr(self, w_v), + getattr(self, w_g), self.dim, self.power)) def remove_weight_norm(self): self.hook()
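For reference, the reparameterization that `compute_weight` implements above is the standard weight normalization w = g * v / ||v||, where the norm is taken over every axis except `dim`. The following NumPy sketch (an illustration only, not part of this patch) mirrors that computation:

```python
import numpy as np


def norm_except(v, dim, power=2.0):
    # Norm over every axis except `dim`, as in norm_except above.
    axes = tuple(i for i in range(v.ndim) if i != dim)
    return (np.abs(v)**power).sum(axis=axes)**(1.0 / power)


def compute_weight(v, g, dim, power=2.0):
    # w = g * v / ||v||, broadcasting g and the norm along `dim`;
    # the 1e-12 guard matches the epsilon used in compute_weight above.
    shape = [1] * v.ndim
    shape[dim] = -1
    norm = norm_except(v, dim, power).reshape(shape)
    return v / (norm + 1e-12) * g.reshape(shape)


# Example: a conv-style weight [out_channels, in_channels, kh, kw],
# normalized along dim=0 so each output filter has magnitude g[i].
v = np.random.randn(8, 4, 3, 3)
g = np.ones(8)
w = compute_weight(v, g, dim=0)
```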