From 83d6a85b8663a0cd700e88ed03f7bde27f8ceacc Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Thu, 12 Dec 2019 17:58:10 -0800 Subject: [PATCH 1/6] add waveflow model valid for training only --- parakeet/models/waveflow/README.md | 97 +++++++ .../waveflow_ljspeech_sqz16_r64_layer8x8.yaml | 24 ++ ...flow_ljspeech_sqz16_r64_layer8x8_s123.yaml | 24 ++ parakeet/models/waveflow/data.py | 139 ++++++++++ parakeet/models/waveflow/slurm.py | 113 ++++++++ parakeet/models/waveflow/synthesis.py | 85 ++++++ parakeet/models/waveflow/train.py | 139 ++++++++++ parakeet/models/waveflow/utils.py | 135 +++++++++ parakeet/models/waveflow/waveflow.py | 174 ++++++++++++ parakeet/models/waveflow/waveflow_modules.py | 256 ++++++++++++++++++ 10 files changed, 1186 insertions(+) create mode 100644 parakeet/models/waveflow/README.md create mode 100644 parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8.yaml create mode 100644 parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8_s123.yaml create mode 100644 parakeet/models/waveflow/data.py create mode 100644 parakeet/models/waveflow/slurm.py create mode 100644 parakeet/models/waveflow/synthesis.py create mode 100644 parakeet/models/waveflow/train.py create mode 100644 parakeet/models/waveflow/utils.py create mode 100644 parakeet/models/waveflow/waveflow.py create mode 100644 parakeet/models/waveflow/waveflow_modules.py diff --git a/parakeet/models/waveflow/README.md b/parakeet/models/waveflow/README.md new file mode 100644 index 0000000..18efd0b --- /dev/null +++ b/parakeet/models/waveflow/README.md @@ -0,0 +1,97 @@ +# WaveNet with Paddle Fluid + +Paddle fluid implementation of WaveNet, a deep generative model of raw audio waveforms. +WaveNet model is originally proposed in [WaveNet: A Generative Model for Raw Audio](https://arxiv.org/abs/1609.03499). +Our implementation is based on the WaveNet architecture described in [ClariNet: Parallel Wave Generation in End-to-End Text-to-Speech](https://arxiv.org/abs/1807.07281) and can provide various output distributions, including single Gaussian, mixture of Gaussian, and softmax with linearly quantized channels. + +We implement WaveNet model in paddle fluid with dynamic graph, which is convenient for flexible network architectures. + +## Project Structure +```text +├── configs # yaml configuration files of preset model hyperparameters +├── data.py # dataset and dataloader settings for LJSpeech +├── slurm.py # optional slurm helper functions if you use slurm to train model +├── synthesis.py # script for speech synthesis +├── train.py # script for model training +├── utils.py # helper functions for e.g., model checkpointing +├── wavenet.py # WaveNet model high level APIs +└── wavenet_modules.py # WaveNet model implementation +``` + +## Usage + +There are many hyperparameters to be tuned depending on the specification of model and dataset you are working on. Hyperparameters that are known to work good for the LJSpeech dataset are provided as yaml files in `./configs/` folder. Specifically, we provide `wavenet_ljspeech_single_gaussian.yaml`, `wavenet_ljspeech_mix_gaussian.yaml`, and `wavenet_ljspeech_softmax.yaml` config files for WaveNet with single Gaussian, 10-component mixture of Gaussians, and softmax (with 2048 linearly quantized channels) output distributions, respectively. + +Note that `train.py` and `synthesis.py` all accept a `--config` parameter. To ensure consistency, you should use the same config yaml file for both training and synthesizing. 
You can also overwrite these preset hyperparameters with command line by updating parameters after `--config`. For example `--config=${yaml} --batch_size=8 --layers=20` can overwrite the corresponding hyperparameters in the `${yaml}` config file. For more details about these hyperparameters, check `utils.add_config_options_to_parser`. + +Note that you also need to specify some additional parameters for `train.py` and `synthesis.py`, and the details can be found in `train.add_options_to_parser` and `synthesis.add_options_to_parser`, respectively. + +### Dataset + +Download and unzip [LJSpeech](https://keithito.com/LJ-Speech-Dataset/). + +```bash +wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2 +tar xjvf LJSpeech-1.1.tar.bz2 +``` + +In this example, assume that the path of unzipped LJSpeech dataset is `./data/LJSpeech-1.1`. + +### Train on single GPU + +```bash +export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.." +export CUDA_VISIBLE_DEVICES=0 +python -u train.py --config=${yaml} \ + --root=./data/LJSpeech-1.1 \ + --name=${ModelName} --batch_size=4 \ + --parallel=false --use_gpu=true +``` + +#### Save and Load checkpoints + +Our model will save model parameters as checkpoints in `./runs/wavenet/${ModelName}/checkpoint/` every 10000 iterations by default. +The saved checkpoint will have the format of `step-${iteration_number}.pdparams` for model parameters and `step-${iteration_number}.pdopt` for optimizer parameters. + +There are three ways to load a checkpoint and resume training (take an example that you want to load a 500000-iteration checkpoint): +1. Use `--checkpoint=./runs/wavenet/${ModelName}/checkpoint/step-500000` to provide a specific path to load. Note that you only need to provide the base name of the parameter file, which is `step-500000`, no extension name `.pdparams` or `.pdopt` is needed. +2. Use `--iteration=500000`. +3. If you don't specify either `--checkpoint` or `--iteration`, the model will automatically load the latest checkpoint in `./runs/wavenet/${ModelName}/checkpoint`. + +### Train on multiple GPUs + +```bash +export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.." +export CUDA_VISIBLE_DEVICES=0,1,2,3 +python -u -m paddle.distributed.launch train.py \ + --config=${yaml} \ + --root=./data/LJSpeech-1.1 \ + --name=${ModelName} --parallel=true --use_gpu=true +``` + +Use `export CUDA_VISIBLE_DEVICES=0,1,2,3` to set the GPUs that you want to use to be visible. Then the `paddle.distributed.launch` module will use these visible GPUs to do data parallel training in multiprocessing mode. + +### Monitor with Tensorboard + +By default, the logs are saved in `./runs/wavenet/${ModelName}/logs/`. You can monitor logs by tensorboard. + +```bash +tensorboard --logdir=${log_dir} --port=8888 +``` + +### Synthesize from a checkpoint + +Check the [Save and load checkpoint](#save-and-load-checkpoints) section on how to load a specific checkpoint. +The following example will automatically load the latest checkpoint: + +```bash +export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.." 
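+# The repository root is put on PYTHONPATH so that `synthesis.py` can import the `parakeet` package.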
+export CUDA_VISIBLE_DEVICES=0 +python -u synthesis.py --config=${yaml} \ + --root=./data/LJSpeech-1.1 \ + --name=${ModelName} --use_gpu=true \ + --output=./syn_audios \ + --sample=${SAMPLE} +``` + +In this example, `--output` specifies where to save the synthesized audios and `--sample` specifies which sample in the valid dataset (a split from the whole LJSpeech dataset, by default contains the first 16 audio samples) to synthesize based on the mel-spectrograms computed from the ground truth sample audio, e.g., `--sample=0` means to synthesize the first audio in the valid dataset. diff --git a/parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8.yaml b/parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8.yaml new file mode 100644 index 0000000..f9bbc83 --- /dev/null +++ b/parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8.yaml @@ -0,0 +1,24 @@ +valid_size: 16 +segment_length: 16000 +sample_rate: 22050 +fft_window_shift: 256 +fft_window_size: 1024 +fft_size: 1024 +mel_bands: 80 +mel_fmin: 0.0 +mel_fmax: 8000.0 + +seed: 1234 +learning_rate: 0.0002 +batch_size: 8 +test_every: 2000 +save_every: 5000 +max_iterations: 2000000 + +sigma: 1.0 +n_flows: 8 +n_group: 16 +n_layers: 8 +n_channels: 64 +kernel_h: 3 +kernel_w: 3 diff --git a/parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8_s123.yaml b/parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8_s123.yaml new file mode 100644 index 0000000..7d45212 --- /dev/null +++ b/parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8_s123.yaml @@ -0,0 +1,24 @@ +valid_size: 16 +segment_length: 16000 +sample_rate: 22050 +fft_window_shift: 256 +fft_window_size: 1024 +fft_size: 1024 +mel_bands: 80 +mel_fmin: 0.0 +mel_fmax: 8000.0 + +seed: 123 +learning_rate: 0.0002 +batch_size: 8 +test_every: 2000 +save_every: 5000 +max_iterations: 2000000 + +sigma: 1.0 +n_flows: 8 +n_group: 16 +n_layers: 8 +n_channels: 64 +kernel_h: 3 +kernel_w: 3 diff --git a/parakeet/models/waveflow/data.py b/parakeet/models/waveflow/data.py new file mode 100644 index 0000000..3c70ce0 --- /dev/null +++ b/parakeet/models/waveflow/data.py @@ -0,0 +1,139 @@ +import random + +import librosa +import numpy as np +from paddle import fluid + +import utils +from parakeet.datasets import ljspeech +from parakeet.data import dataset +from parakeet.data.batch import SpecBatcher, WavBatcher +from parakeet.data.datacargo import DataCargo +from parakeet.data.sampler import DistributedSampler, BatchSampler +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +class Dataset(ljspeech.LJSpeech): + def __init__(self, config): + super(Dataset, self).__init__(config.root) + self.config = config + + def _get_example(self, metadatum): + fname, _, _ = metadatum + wav_path = self.root.joinpath("wavs", fname + ".wav") + + loaded_sr, audio = read(wav_path) + assert loaded_sr == self.config.sample_rate + + return audio + + +class Subset(dataset.Dataset): + def __init__(self, dataset, indices, valid): + self.dataset = dataset + self.indices = indices + self.valid = valid + self.config = dataset.config + + def get_mel(self, audio): + spectrogram = librosa.core.stft( + audio, n_fft=self.config.fft_size, + hop_length=self.config.fft_window_shift, + win_length=self.config.fft_window_size) + spectrogram_magnitude = np.abs(spectrogram) + + # mel_filter_bank shape: [n_mels, 1 + n_fft/2] + mel_filter_bank = librosa.filters.mel( + sr=self.config.sample_rate, + n_fft=self.config.fft_size, + 
n_mels=self.config.mel_bands, + fmin=self.config.mel_fmin, + fmax=self.config.mel_fmax) + # mel shape: [n_mels, num_frames] + mel = np.dot(mel_filter_bank, spectrogram_magnitude) + + # Normalize mel. + clip_val = 1e-5 + ref_constant = 1 + mel = np.log(np.clip(mel, a_min=clip_val, a_max=None) * ref_constant) + + return mel + + def __getitem__(self, idx): + audio = self.dataset[self.indices[idx]] + segment_length = self.config.segment_length + + if self.valid: + # whole audio for valid set + pass + else: + # audio shape: [len] + if audio.shape[0] >= segment_length: + max_audio_start = audio.shape[0] - segment_length + audio_start = random.randint(0, max_audio_start) + audio = audio[audio_start : (audio_start + segment_length)] + else: + audio = np.pad(audio, (0, segment_length - audio.shape[0]), + mode='constant', constant_values=0) + + # Normalize audio. + audio = audio / MAX_WAV_VALUE + mel = self.get_mel(audio) + + return audio, mel + + def _batch_examples(self, batch): + audio_batch = [] + mel_batch = [] + for audio, mel in batch: + audio_batch + + audios = [sample[0] for sample in batch] + mels = [sample[1] for sample in batch] + + audios = WavBatcher(pad_value=0.0)(audios) + mels = SpecBatcher(pad_value=0.0)(mels) + + return audios, mels + + def __len__(self): + return len(self.indices) + + +class LJSpeech: + def __init__(self, config, nranks, rank): + place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace() + + # Whole LJSpeech dataset. + ds = Dataset(config) + + # Split into train and valid dataset. + indices = list(range(len(ds))) + train_indices = indices[config.valid_size:] + valid_indices = indices[:config.valid_size] + random.shuffle(train_indices) + + # Train dataset. + trainset = Subset(ds, train_indices, valid=False) + sampler = DistributedSampler(len(trainset), nranks, rank) + total_bs = config.batch_size + assert total_bs % nranks == 0 + train_sampler = BatchSampler(sampler, total_bs // nranks, + drop_last=True) + trainloader = DataCargo(trainset, batch_sampler=train_sampler) + + trainreader = fluid.io.PyReader(capacity=50, return_list=True) + trainreader.decorate_batch_generator(trainloader, place) + self.trainloader = (data for _ in iter(int, 1) + for data in trainreader()) + + # Valid dataset. + validset = Subset(ds, valid_indices, valid=True) + # Currently only support batch_size = 1 for valid loader. + validloader = DataCargo(validset, batch_size=1, shuffle=False) + + validreader = fluid.io.PyReader(capacity=20, return_list=True) + validreader.decorate_batch_generator(validloader, place) + self.validloader = validreader diff --git a/parakeet/models/waveflow/slurm.py b/parakeet/models/waveflow/slurm.py new file mode 100644 index 0000000..de1818c --- /dev/null +++ b/parakeet/models/waveflow/slurm.py @@ -0,0 +1,113 @@ +""" +Utility module for restarting training when using SLURM. +""" +import subprocess +import os +import sys +import shlex +import re +import time + + +def job_info(): + """Get information about the current job using `scontrol show job`. + Returns a dict mapping parameter names (e.g. "UserId", "RunTime", etc) to + their values, both as strings. 
+ """ + job_id = int(os.environ["SLURM_JOB_ID"]) + + command = ["scontrol", "show", "job", str(job_id)] + output = subprocess.check_output(command).decode("utf-8") + + # Use a regex to extract the parameter names and values + pattern = "([A-Za-z/]*)=([^ \t\n]*)" + return dict(re.findall(pattern, output)) + + +def parse_hours(text): + """Parse a time format HH or DD-HH into a number of hours.""" + hour_chunks = text.split("-") + if len(hour_chunks) == 1: + return int(hour_chunks[0]) + elif len(hour_chunks) == 2: + return 24 * int(hour_chunks[0]) + int(hour_chunks[1]) + else: + raise ValueError("Unexpected hour format (expected HH or " + "DD-HH, but got {}).".format(text)) + + +def parse_time(text): + """Convert slurm time to an integer. + Expects time to be of the form: + "hours:minutes:seconds" or "day-hours:minutes:seconds". + """ + hours, minutes, seconds = text.split(":") + try: + return parse_hours(hours) * 3600 + int(minutes) * 60 + int(seconds) + except ValueError as e: + raise ValueError("Error parsing time {}. Got error {}.".format( + text, str(e))) + + +def restart_command(): + """Using the environment and SLURM command, create a command that, when, + run, will enqueue a repeat of the current job using `sbatch`. + Return the command as a list of strings, suitable for passing to + `subprocess.check_call` or similar functions. + Returns: + resume_command: list, command to run to restart job. + end_time: int or None; the time the job will end or None + if the job has unlimited runtime. + """ + # Make sure `RunTime` could be parsed correctly. + while job_info()["RunTime"] == "INVALID": + time.sleep(1) + + # Get all the necessary information by querying SLURM with this job id + info = job_info() + + try: + num_cpus = int(info["CPUs/Task"]) + except KeyError: + num_cpus = int(os.environ["SLURM_CPUS_PER_TASK"]) + + num_tasks = int(os.environ["SLURM_NTASKS"]) + nodes = info["NumNodes"] + gres, partition = info.get("Gres"), info.get("Partition") + stderr, stdout = info.get("StdErr"), info.get("StdOut") + job_name = info.get("JobName") + command = ["sbatch", "--job-name={}".format(job_name), + "--ntasks={}".format(num_tasks), + "--exclude=asimov-186"] + + if partition: + command.extend(["--partition", partition]) + + if gres and gres != "(null)": + command.extend(["--gres", gres]) + num_gpu = int(gres.split(':')[-1]) + print("number of gpu assigned by slurm is {}".format(num_gpu)) + + if stderr: + command.extend(["--error", stderr]) + + if stdout: + command.extend(["--output", stdout]) + + python = subprocess.check_output( + ["/usr/bin/which", "python3"]).decode("utf-8").strip() + dist_setting = ['-m', 'paddle.distributed.launch'] + wrap_cmd = ["srun", python, '-u'] + dist_setting + sys.argv + + command.append( + "--wrap={}".format(" ".join(shlex.quote(arg) for arg in wrap_cmd))) + time_limit_string = info["TimeLimit"] + if time_limit_string.lower() == "unlimited": + print("UNLIMITED detected: restart OFF, infinite learning ON.", + flush=True) + return command, None + time_limit = parse_time(time_limit_string) + runtime = parse_time(info["RunTime"]) + end_time = time.time() + time_limit - runtime + + return command, end_time diff --git a/parakeet/models/waveflow/synthesis.py b/parakeet/models/waveflow/synthesis.py new file mode 100644 index 0000000..d87a188 --- /dev/null +++ b/parakeet/models/waveflow/synthesis.py @@ -0,0 +1,85 @@ +import os +import random +from pprint import pprint + +import jsonargparse +import numpy as np +import paddle.fluid.dygraph as dg +from paddle import fluid + 
+import utils +from wavenet import WaveNet + + +def add_options_to_parser(parser): + parser.add_argument('--model', type=str, default='wavenet', + help="general name of the model") + parser.add_argument('--name', type=str, + help="specific name of the training model") + parser.add_argument('--root', type=str, + help="root path of the LJSpeech dataset") + + parser.add_argument('--use_gpu', type=bool, default=True, + help="option to use gpu training") + + parser.add_argument('--iteration', type=int, default=None, + help=("which iteration of checkpoint to load, " + "default to load the latest checkpoint")) + parser.add_argument('--checkpoint', type=str, default=None, + help="path of the checkpoint to load") + + parser.add_argument('--output', type=str, default="./syn_audios", + help="path to write synthesized audio files") + parser.add_argument('--sample', type=int, + help="which of the valid samples to synthesize audio") + + +def synthesize(config): + pprint(jsonargparse.namespace_to_dict(config)) + + # Get checkpoint directory path. + run_dir = os.path.join("runs", config.model, config.name) + checkpoint_dir = os.path.join(run_dir, "checkpoint") + + # Configurate device. + place = fluid.CUDAPlace(0) if config.use_gpu else fluid.CPUPlace() + + with dg.guard(place): + # Fix random seed. + seed = config.seed + random.seed(seed) + np.random.seed(seed) + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + print("Random Seed: ", seed) + + # Build model. + model = WaveNet(config, checkpoint_dir) + model.build(training=False) + + # Obtain the current iteration. + if config.checkpoint is None: + if config.iteration is None: + iteration = utils.load_latest_checkpoint(checkpoint_dir) + else: + iteration = config.iteration + else: + iteration = int(config.checkpoint.split('/')[-1].split('-')[-1]) + + # Run model inference. + model.infer(iteration) + + +if __name__ == "__main__": + # Create parser. + parser = jsonargparse.ArgumentParser( + description="Synthesize audio using WaveNet model", + formatter_class='default_argparse') + add_options_to_parser(parser) + utils.add_config_options_to_parser(parser) + + # Parse argument from both command line and yaml config file. + # For conflicting updates to the same field, + # the preceding update will be overwritten by the following one. 
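+    # For example, passing `--batch_size` after `--config` makes the command-line
+    # value override the one read from the yaml file.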
+ config = parser.parse_args() + synthesize(config) diff --git a/parakeet/models/waveflow/train.py b/parakeet/models/waveflow/train.py new file mode 100644 index 0000000..a125d97 --- /dev/null +++ b/parakeet/models/waveflow/train.py @@ -0,0 +1,139 @@ +import os +import random +import subprocess +import time +from pprint import pprint + +import jsonargparse +import numpy as np +import paddle.fluid.dygraph as dg +from paddle import fluid +from tensorboardX import SummaryWriter + +import slurm +import utils +from waveflow import WaveFlow + +MAXIMUM_SAVE_TIME = 10 * 60 + + +def add_options_to_parser(parser): + parser.add_argument('--model', type=str, default='waveflow', + help="general name of the model") + parser.add_argument('--name', type=str, + help="specific name of the training model") + parser.add_argument('--root', type=str, + help="root path of the LJSpeech dataset") + + parser.add_argument('--parallel', type=bool, default=True, + help="option to use data parallel training") + parser.add_argument('--use_gpu', type=bool, default=True, + help="option to use gpu training") + + parser.add_argument('--iteration', type=int, default=None, + help=("which iteration of checkpoint to load, " + "default to load the latest checkpoint")) + parser.add_argument('--checkpoint', type=str, default=None, + help="path of the checkpoint to load") + parser.add_argument('--slurm', type=bool, default=False, + help="whether you are using slurm to submit training jobs") + + +def train(config): + use_gpu = config.use_gpu + parallel = config.parallel if use_gpu else False + + # Get the rank of the current training process. + rank = dg.parallel.Env().local_rank if parallel else 0 + nranks = dg.parallel.Env().nranks if parallel else 1 + + if rank == 0: + # Print the whole config setting. + pprint(jsonargparse.namespace_to_dict(config)) + + # Make checkpoint directory. + run_dir = os.path.join("runs", config.model, config.name) + checkpoint_dir = os.path.join(run_dir, "checkpoint") + os.makedirs(checkpoint_dir, exist_ok=True) + + # Create tensorboard logger. + tb = SummaryWriter(os.path.join(run_dir, "logs")) \ + if rank == 0 else None + + # Configurate device + place = fluid.CUDAPlace(rank) if use_gpu else fluid.CPUPlace() + + with dg.guard(place): + # Fix random seed. + seed = config.seed + random.seed(seed) + np.random.seed(seed) + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + print("Random Seed: ", seed) + + # Build model. + model = WaveFlow(config, checkpoint_dir, parallel, rank, nranks, tb) + model.build() + + # Obtain the current iteration. + if config.checkpoint is None: + if config.iteration is None: + iteration = utils.load_latest_checkpoint(checkpoint_dir, rank) + else: + iteration = config.iteration + else: + iteration = int(config.checkpoint.split('/')[-1].split('-')[-1]) + + # Get restart command if using slurm. + if config.slurm: + resume_command, death_time = slurm.restart_command() + if rank == 0: + print("Restart command:", " ".join(resume_command)) + done = False + + while iteration < config.max_iterations: + # Run one single training step. + model.train_step(iteration) + + iteration += 1 + + if iteration % config.test_every == 0: + # Run validation step. + model.valid_step(iteration) + + # Check whether reaching the time limit. 
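+            # Under SLURM, check whether the job is within MAXIMUM_SAVE_TIME (10 minutes)
+            # of its allocation end; if so, rank 0 saves a checkpoint and resubmits the job
+            # with `resume_command` so training can continue in a fresh allocation.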
+ if config.slurm: + done = (death_time is not None and death_time - time.time() < + MAXIMUM_SAVE_TIME) + + if rank == 0 and done: + print("Saving progress before exiting.") + model.save(iteration) + + print("Running restart command:", " ".join(resume_command)) + # Submit restart command. + subprocess.check_call(resume_command) + break + + if rank == 0 and iteration % config.save_every == 0: + # Save parameters. + model.save(iteration) + + # Close TensorBoard. + if rank == 0: + tb.close() + + +if __name__ == "__main__": + # Create parser. + parser = jsonargparse.ArgumentParser(description="Train WaveFlow model", + formatter_class='default_argparse') + add_options_to_parser(parser) + utils.add_config_options_to_parser(parser) + + # Parse argument from both command line and yaml config file. + # For conflicting updates to the same field, + # the preceding update will be overwritten by the following one. + config = parser.parse_args() + train(config) diff --git a/parakeet/models/waveflow/utils.py b/parakeet/models/waveflow/utils.py new file mode 100644 index 0000000..494a409 --- /dev/null +++ b/parakeet/models/waveflow/utils.py @@ -0,0 +1,135 @@ +import itertools +import os +import time + +import jsonargparse +import numpy as np +import paddle.fluid.dygraph as dg + + +def add_config_options_to_parser(parser): + parser.add_argument('--valid_size', type=int, + help="size of the valid dataset") + parser.add_argument('--segment_length', type=int, + help="the length of audio clip for training") + parser.add_argument('--sample_rate', type=int, + help="sampling rate of audio data file") + parser.add_argument('--fft_window_shift', type=int, + help="the shift of fft window for each frame") + parser.add_argument('--fft_window_size', type=int, + help="the size of fft window for each frame") + parser.add_argument('--fft_size', type=int, + help="the size of fft filter on each frame") + parser.add_argument('--mel_bands', type=int, + help="the number of mel bands when calculating mel spectrograms") + parser.add_argument('--mel_fmin', type=float, + help="lowest frequency in calculating mel spectrograms") + parser.add_argument('--mel_fmax', type=float, + help="highest frequency in calculating mel spectrograms") + + parser.add_argument('--seed', type=int, + help="seed of random initialization for the model") + parser.add_argument('--learning_rate', type=float) + parser.add_argument('--batch_size', type=int, + help="batch size for training") + parser.add_argument('--test_every', type=int, + help="test interval during training") + parser.add_argument('--save_every', type=int, + help="checkpointing interval during training") + parser.add_argument('--max_iterations', type=int, + help="maximum training iterations") + + parser.add_argument('--sigma', type=float, + help="standard deviation of the latent Gaussian variable") + parser.add_argument('--n_flows', type=int, + help="number of flows") + parser.add_argument('--n_group', type=int, + help="number of adjacent audio samples to squeeze into one column") + parser.add_argument('--n_layers', type=int, + help="number of conv2d layer in one wavenet-like flow architecture") + parser.add_argument('--n_channels', type=int, + help="number of residual channels in flow") + parser.add_argument('--kernel_h', type=int, + help="height of the kernel in the conv2d layer") + parser.add_argument('--kernel_w', type=int, + help="width of the kernel in the conv2d layer") + + parser.add_argument('--config', action=jsonargparse.ActionConfigFile) + + +def pad_to_size(array, length, 
pad_with=0.0): + """ + Pad an array on the first (length) axis to a given length. + """ + padding = length - array.shape[0] + assert padding >= 0, "Padding required was less than zero" + + paddings = [(0, 0)] * len(array.shape) + paddings[0] = (0, padding) + + return np.pad(array, paddings, mode='constant', constant_values=pad_with) + + +def calculate_context_size(config): + dilations = list( + itertools.islice( + itertools.cycle(config.dilation_block), config.layers)) + config.context_size = sum(dilations) + 1 + print("Context size is", config.context_size) + + +def load_latest_checkpoint(checkpoint_dir, rank=0): + checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") + # Create checkpoint index file if not exist. + if (not os.path.isfile(checkpoint_path)) and rank == 0: + with open(checkpoint_path, "w") as handle: + handle.write("model_checkpoint_path: step-0") + + # Make sure that other process waits until checkpoint file is created + # by process 0. + while not os.path.isfile(checkpoint_path): + time.sleep(1) + + # Fetch the latest checkpoint index. + with open(checkpoint_path, "r") as handle: + latest_checkpoint = handle.readline().split()[-1] + iteration = int(latest_checkpoint.split("-")[-1]) + + return iteration + + +def save_latest_checkpoint(checkpoint_dir, iteration): + checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") + # Update the latest checkpoint index. + with open(checkpoint_path, "w") as handle: + handle.write("model_checkpoint_path: step-{}".format(iteration)) + + +def load_parameters(checkpoint_dir, rank, model, optimizer=None, + iteration=None, file_path=None): + if file_path is None: + if iteration is None: + iteration = load_latest_checkpoint(checkpoint_dir, rank) + if iteration == 0: + return + file_path = "{}/step-{}".format(checkpoint_dir, iteration) + + model_dict, optimizer_dict = dg.load_dygraph(file_path) + model.set_dict(model_dict) + print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path)) + if optimizer and optimizer_dict: + optimizer.set_dict(optimizer_dict) + print("[checkpoint] Rank {}: loaded optimizer state from {}".format( + rank, file_path)) + + +def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None): + file_path = "{}/step-{}".format(checkpoint_dir, iteration) + model_dict = model.state_dict() + dg.save_dygraph(model_dict, file_path) + print("[checkpoint] Saved model to {}".format(file_path)) + + if optimizer: + opt_dict = optimizer.state_dict() + dg.save_dygraph(opt_dict, file_path) + print("[checkpoint] Saved optimzier state to {}".format(file_path)) diff --git a/parakeet/models/waveflow/waveflow.py b/parakeet/models/waveflow/waveflow.py new file mode 100644 index 0000000..b778497 --- /dev/null +++ b/parakeet/models/waveflow/waveflow.py @@ -0,0 +1,174 @@ +import itertools +import os +import time + +import librosa +import numpy as np +import paddle.fluid.dygraph as dg +from paddle import fluid + +import utils +from data import LJSpeech +from waveflow_modules import WaveFlowLoss, WaveFlowModule + + +class WaveFlow(): + def __init__(self, config, checkpoint_dir, parallel=False, rank=0, + nranks=1, tb_logger=None): + self.config = config + self.checkpoint_dir = checkpoint_dir + self.parallel = parallel + self.rank = rank + self.nranks = nranks + self.tb_logger = tb_logger + + def build(self, training=True): + config = self.config + dataset = LJSpeech(config, self.nranks, self.rank) + self.trainloader = dataset.trainloader + self.validloader = dataset.validloader + +# if self.rank == 0: +# for 
i, (audios, mels) in enumerate(self.validloader()): +# print("audios {}, mels {}".format(audios.dtype, mels.dtype)) +# print("{}: rank {}, audios {}, mels {}".format( +# i, self.rank, audios.shape, mels.shape)) +# +# for i, (audios, mels) in enumerate(self.trainloader): +# print("{}: rank {}, audios {}, mels {}".format( +# i, self.rank, audios.shape, mels.shape)) +# +# exit() + + waveflow = WaveFlowModule("waveflow", config) + + # Dry run once to create and initalize all necessary parameters. + audio = dg.to_variable(np.random.randn(1, 16000).astype(np.float32)) + mel = dg.to_variable( + np.random.randn(1, config.mel_bands, 63).astype(np.float32)) + waveflow(audio, mel) + + if training: + optimizer = fluid.optimizer.AdamOptimizer( + learning_rate=config.learning_rate) + + # Load parameters. + utils.load_parameters(self.checkpoint_dir, self.rank, + waveflow, optimizer, + iteration=config.iteration, + file_path=config.checkpoint) + print("Rank {}: checkpoint loaded.".format(self.rank)) + + # Data parallelism. + if self.parallel: + strategy = dg.parallel.prepare_context() + waveflow = dg.parallel.DataParallel(waveflow, strategy) + + self.waveflow = waveflow + self.optimizer = optimizer + self.criterion = WaveFlowLoss(config.sigma) + + else: + # Load parameters. + utils.load_parameters(self.checkpoint_dir, self.rank, waveflow, + iteration=config.iteration, + file_path=config.checkpoint) + print("Rank {}: checkpoint loaded.".format(self.rank)) + + self.waveflow = waveflow + + def train_step(self, iteration): + self.waveflow.train() + + start_time = time.time() + audios, mels = next(self.trainloader) + load_time = time.time() + + outputs = self.waveflow(audios, mels) + loss = self.criterion(outputs) + + if self.parallel: + # loss = loss / num_trainers + loss = self.waveflow.scale_loss(loss) + loss.backward() + self.waveflow.apply_collective_grads() + else: + loss.backward() + + current_lr = self.optimizer._learning_rate + + self.optimizer.minimize(loss, parameter_list=self.waveflow.parameters()) + self.waveflow.clear_gradients() + + graph_time = time.time() + + if self.rank == 0: + loss_val = float(loss.numpy()) * self.nranks + log = "Rank: {} Step: {:^8d} Loss: {:<8.3f} " \ + "Time: {:.3f}/{:.3f}".format( + self.rank, iteration, loss_val, + load_time - start_time, graph_time - load_time) + print(log) + + tb = self.tb_logger + tb.add_scalar("Train-Loss-Rank-0", loss_val, iteration) + tb.add_scalar("Learning-Rate", current_lr, iteration) + + @dg.no_grad + def valid_step(self, iteration): + self.waveflow.eval() + tb = self.tb_logger + + total_loss = [] + sample_audios = [] + start_time = time.time() + + for i, batch in enumerate(self.validloader()): + audios, mels = batch + valid_outputs = self.waveflow(audios, mels) + valid_z, valid_log_s_list = valid_outputs + + # Visualize latent z and scale log_s. 
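+            # Histograms are logged only for the first validation batch on rank 0 to keep
+            # the TensorBoard logs small; they show the distribution of the latent z and
+            # of the per-flow log_s scaling terms.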
+ if self.rank == 0 and i == 0: + tb.add_histogram("Valid-Latent_z", valid_z.numpy(), iteration) + for j, valid_log_s in enumerate(valid_log_s_list): + hist_name = "Valid-{}th-Flow-Log_s".format(j) + tb.add_histogram(hist_name, valid_log_s.numpy(), iteration) + + valid_loss = self.criterion(valid_outputs) + total_loss.append(float(valid_loss.numpy())) + + total_time = time.time() - start_time + if self.rank == 0: + loss_val = np.mean(total_loss) + log = "Test | Rank: {} AvgLoss: {:<8.3f} Time {:<8.3f}".format( + self.rank, loss_val, total_time) + print(log) + tb.add_scalar("Valid-Avg-Loss", loss_val, iteration) + + @dg.no_grad + def infer(self, iteration): + self.waveflow.eval() + + config = self.config + sample = config.sample + + output = "{}/{}/iter-{}".format(config.output, config.name, iteration) + os.makedirs(output, exist_ok=True) + + filename = "{}/valid_{}.wav".format(output, sample) + print("Synthesize sample {}, save as {}".format(sample, filename)) + + mels_list = [mels for _, mels, _ in self.validloader()] + start_time = time.time() + syn_audio = self.waveflow.synthesize(mels_list[sample]) + syn_time = time.time() - start_time + print("audio shape {}, synthesis time {}".format( + syn_audio.shape, syn_time)) + librosa.output.write_wav(filename, syn_audio, + sr=config.sample_rate) + + def save(self, iteration): + utils.save_latest_parameters(self.checkpoint_dir, iteration, + self.waveflow, self.optimizer) + utils.save_latest_checkpoint(self.checkpoint_dir, iteration) diff --git a/parakeet/models/waveflow/waveflow_modules.py b/parakeet/models/waveflow/waveflow_modules.py new file mode 100644 index 0000000..a4b9c4f --- /dev/null +++ b/parakeet/models/waveflow/waveflow_modules.py @@ -0,0 +1,256 @@ +import itertools + +import numpy as np +import paddle.fluid.dygraph as dg +from paddle import fluid +from parakeet.modules import conv, modules, weight_norm + + +def set_param_attr(layer, c_in=1): + if isinstance(layer, (weight_norm.Conv2DTranspose, weight_norm.Conv2D)): + k = np.sqrt(1.0 / (c_in * np.prod(layer._filter_size))) + weight_init = fluid.initializer.UniformInitializer(low=-k, high=k) + bias_init = fluid.initializer.UniformInitializer(low=-k, high=k) + elif isinstance(layer, dg.Conv2D): + weight_init = fluid.initializer.ConstantInitializer(0.0) + bias_init = fluid.initializer.ConstantInitializer(0.0) + else: + raise TypeError("Unsupported layer type.") + + layer._param_attr = fluid.ParamAttr(initializer=weight_init) + layer._bias_attr = fluid.ParamAttr(initializer=bias_init) + + +def unfold(x, n_group): + length = x.shape[-1] + #assert length % n_group == 0 + new_shape = x.shape[:-1] + [length // n_group, n_group] + return fluid.layers.reshape(x, new_shape) + + +class WaveFlowLoss: + def __init__(self, sigma=1.0): + self.sigma = sigma + + def __call__(self, model_output): + z, log_s_list = model_output + for i, log_s in enumerate(log_s_list): + if i == 0: + log_s_total = fluid.layers.reduce_sum(log_s) + else: + log_s_total = log_s_total + fluid.layers.reduce_sum(log_s) + + loss = fluid.layers.reduce_sum(z * z) / (2 * self.sigma * self.sigma) \ + - log_s_total + loss = loss / np.prod(z.shape) + const = 0.5 * np.log(2 * np.pi) + np.log(self.sigma) + + return loss + const + + +class Conditioner(dg.Layer): + def __init__(self, name_scope): + super(Conditioner, self).__init__(name_scope) + upsample_factors = [16, 16] + + self.upsample_conv2d = [] + for s in upsample_factors: + in_channel = 1 + conv_trans2d = modules.Conv2DTranspose( + self.full_name(), + num_filters=1, + 
filter_size=(3, 2 * s), + padding=(1, s // 2), + stride=(1, s)) + set_param_attr(conv_trans2d, c_in=in_channel) + self.upsample_conv2d.append(conv_trans2d) + + for i, layer in enumerate(self.upsample_conv2d): + self.add_sublayer("conv2d_transpose_{}".format(i), layer) + + def forward(self, x): + x = fluid.layers.unsqueeze(x, 1) + for layer in self.upsample_conv2d: + x = fluid.layers.leaky_relu(layer(x), alpha=0.4) + + return fluid.layers.squeeze(x, [1]) + + +class Flow(dg.Layer): + def __init__(self, name_scope, config): + super(Flow, self).__init__(name_scope) + self.n_layers = config.n_layers + self.n_channels = config.n_channels + self.kernel_h = config.kernel_h + self.kernel_w = config.kernel_w + + # Transform audio: [batch, 1, n_group, time/n_group] + # => [batch, n_channels, n_group, time/n_group] + self.start = weight_norm.Conv2D( + self.full_name(), + num_filters=self.n_channels, + filter_size=(1, 1)) + set_param_attr(self.start, c_in=1) + + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. This helps with training stability + # output shape: [batch, 2, n_group, time/n_group] + self.end = dg.Conv2D( + self.full_name(), + num_filters=2, + filter_size=(1, 1)) + set_param_attr(self.end) + + # receiptive fileds: (kernel - 1) * sum(dilations) + 1 >= squeeze + dilation_dict = {8: [1, 1, 1, 1, 1, 1, 1, 1], + 16: [1, 1, 1, 1, 1, 1, 1, 1], + 32: [1, 2, 4, 1, 2, 4, 1, 2], + 64: [1, 2, 4, 8, 16, 1, 2, 4], + 128: [1, 2, 4, 8, 16, 32, 64, 1]} + self.dilation_h_list = dilation_dict[config.n_group] + + self.in_layers = [] + self.cond_layers = [] + self.res_skip_layers = [] + for i in range(self.n_layers): + dilation_h = self.dilation_h_list[i] + dilation_w = 2 ** i + + in_layer = weight_norm.Conv2D( + self.full_name(), + num_filters=2 * self.n_channels, + filter_size=(self.kernel_h, self.kernel_w), + dilation=(dilation_h, dilation_w)) + set_param_attr(in_layer, c_in=self.n_channels) + self.in_layers.append(in_layer) + + cond_layer = weight_norm.Conv2D( + self.full_name(), + num_filters=2 * self.n_channels, + filter_size=(1, 1)) + set_param_attr(cond_layer, c_in=config.mel_bands) + self.cond_layers.append(cond_layer) + + if i < self.n_layers - 1: + res_skip_channels = 2 * self.n_channels + else: + res_skip_channels = self.n_channels + res_skip_layer = weight_norm.Conv2D( + self.full_name(), + num_filters=res_skip_channels, + filter_size=(1, 1)) + set_param_attr(res_skip_layer, c_in=self.n_channels) + self.res_skip_layers.append(res_skip_layer) + + self.add_sublayer("in_layer_{}".format(i), in_layer) + self.add_sublayer("cond_layer_{}".format(i), cond_layer) + self.add_sublayer("res_skip_layer_{}".format(i), res_skip_layer) + + def forward(self, audio, mel): + # audio: [bs, 1, n_group, time/group] + # mel: [bs, mel_bands, n_group, time/n_group] + audio = self.start(audio) + + for i in range(self.n_layers): + dilation_h = self.dilation_h_list[i] + dilation_w = 2 ** i + + # Pad height dim (n_group): causal convolution + # Pad width dim (time): dialated non-causal convolution + pad_top, pad_bottom = (self.kernel_h - 1) * dilation_h, 0 + pad_left = pad_right = int((self.kernel_w-1) * dilation_w / 2) + audio_pad = fluid.layers.pad2d(audio, + paddings=[pad_top, pad_bottom, pad_left, pad_right]) + + hidden = self.in_layers[i](audio_pad) + cond_hidden = self.cond_layers[i](mel) + in_acts = hidden + cond_hidden + out_acts = fluid.layers.tanh(in_acts[:, :self.n_channels, :]) * \ + fluid.layers.sigmoid(in_acts[:, self.n_channels:, :]) + res_skip_acts = 
self.res_skip_layers[i](out_acts) + + if i < self.n_layers - 1: + audio += res_skip_acts[:, :self.n_channels, :, :] + skip_acts = res_skip_acts[:, self.n_channels:, :, :] + else: + skip_acts = res_skip_acts + + if i == 0: + output = skip_acts + else: + output += skip_acts + + return self.end(output) + + +class WaveFlowModule(dg.Layer): + def __init__(self, name_scope, config): + super(WaveFlowModule, self).__init__(name_scope) + self.n_flows = config.n_flows + self.n_group = config.n_group + assert self.n_group % 2 == 0 + + self.conditioner = Conditioner(self.full_name()) + self.flows = [] + for i in range(self.n_flows): + flow = Flow(self.full_name(), config) + self.flows.append(flow) + self.add_sublayer("flow_{}".format(i), flow) + + self.perms = [[15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8], + [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8], + [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8], + [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8]] + + def forward(self, audio, mel): + mel = self.conditioner(mel) + assert mel.shape[2] >= audio.shape[1] + # Prune out the tail of audio/mel so that time/n_group == 0. + pruned_len = audio.shape[1] // self.n_group * self.n_group + + if audio.shape[1] > pruned_len: + audio = audio[:, :pruned_len] + if mel.shape[2] > pruned_len: + mel = mel[:, :, :pruned_len] + + # From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group] + mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2]) + # From [bs, time] to [bs, n_group, time/n_group] + audio = fluid.layers.transpose(unfold(audio, self.n_group), [0, 2, 1]) + # [bs, 1, n_group, time/n_group] + audio = fluid.layers.unsqueeze(audio, 1) + + log_s_list = [] + for i in range(self.n_flows): + inputs = audio[:, :, :-1, :] + conds = mel[:, :, 1:, :] + outputs = self.flows[i](inputs, conds) + log_s = outputs[:, :1, :, :] + b = outputs[:, 1:, :, :] + log_s_list.append(log_s) + + audio_0 = audio[:, :, :1, :] + audio_out = audio[:, :, 1:, :] * fluid.layers.exp(log_s) + b + audio = fluid.layers.concat([audio_0, audio_out], axis=2) + + # Permute over the height dim. + audio_slices = [audio[:, :, j, :] for j in self.perms[i]] + audio = fluid.layers.stack(audio_slices, axis=2) + mel_slices = [mel[:, :, j, :] for j in self.perms[i]] + mel = fluid.layers.stack(mel_slices, axis=2) + + z = fluid.layers.squeeze(audio, [1]) + + return z, log_s_list + + def synthesize(self, mels): + pass + + def start_new_sequence(self): + for layer in self.sublayers(): + if isinstance(layer, conv.Conv1D): + layer.start_new_sequence() From f6f0a2ca2129b69b64844d62a046e62e3e3bb677 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Thu, 12 Dec 2019 18:11:32 -0800 Subject: [PATCH 2/6] add documentation --- parakeet/models/waveflow/README.md | 67 ++--------------------- parakeet/models/waveflow/requirements.txt | 3 + 2 files changed, 7 insertions(+), 63 deletions(-) create mode 100644 parakeet/models/waveflow/requirements.txt diff --git a/parakeet/models/waveflow/README.md b/parakeet/models/waveflow/README.md index 18efd0b..355ca31 100644 --- a/parakeet/models/waveflow/README.md +++ b/parakeet/models/waveflow/README.md @@ -1,30 +1,6 @@ -# WaveNet with Paddle Fluid +### Install -Paddle fluid implementation of WaveNet, a deep generative model of raw audio waveforms. 
-WaveNet model is originally proposed in [WaveNet: A Generative Model for Raw Audio](https://arxiv.org/abs/1609.03499). -Our implementation is based on the WaveNet architecture described in [ClariNet: Parallel Wave Generation in End-to-End Text-to-Speech](https://arxiv.org/abs/1807.07281) and can provide various output distributions, including single Gaussian, mixture of Gaussian, and softmax with linearly quantized channels. - -We implement WaveNet model in paddle fluid with dynamic graph, which is convenient for flexible network architectures. - -## Project Structure -```text -├── configs # yaml configuration files of preset model hyperparameters -├── data.py # dataset and dataloader settings for LJSpeech -├── slurm.py # optional slurm helper functions if you use slurm to train model -├── synthesis.py # script for speech synthesis -├── train.py # script for model training -├── utils.py # helper functions for e.g., model checkpointing -├── wavenet.py # WaveNet model high level APIs -└── wavenet_modules.py # WaveNet model implementation -``` - -## Usage - -There are many hyperparameters to be tuned depending on the specification of model and dataset you are working on. Hyperparameters that are known to work good for the LJSpeech dataset are provided as yaml files in `./configs/` folder. Specifically, we provide `wavenet_ljspeech_single_gaussian.yaml`, `wavenet_ljspeech_mix_gaussian.yaml`, and `wavenet_ljspeech_softmax.yaml` config files for WaveNet with single Gaussian, 10-component mixture of Gaussians, and softmax (with 2048 linearly quantized channels) output distributions, respectively. - -Note that `train.py` and `synthesis.py` all accept a `--config` parameter. To ensure consistency, you should use the same config yaml file for both training and synthesizing. You can also overwrite these preset hyperparameters with command line by updating parameters after `--config`. For example `--config=${yaml} --batch_size=8 --layers=20` can overwrite the corresponding hyperparameters in the `${yaml}` config file. For more details about these hyperparameters, check `utils.add_config_options_to_parser`. - -Note that you also need to specify some additional parameters for `train.py` and `synthesis.py`, and the details can be found in `train.add_options_to_parser` and `synthesis.add_options_to_parser`, respectively. +pip install -r requirements.txt ### Dataset @@ -48,50 +24,15 @@ python -u train.py --config=${yaml} \ --parallel=false --use_gpu=true ``` -#### Save and Load checkpoints - -Our model will save model parameters as checkpoints in `./runs/wavenet/${ModelName}/checkpoint/` every 10000 iterations by default. -The saved checkpoint will have the format of `step-${iteration_number}.pdparams` for model parameters and `step-${iteration_number}.pdopt` for optimizer parameters. - -There are three ways to load a checkpoint and resume training (take an example that you want to load a 500000-iteration checkpoint): -1. Use `--checkpoint=./runs/wavenet/${ModelName}/checkpoint/step-500000` to provide a specific path to load. Note that you only need to provide the base name of the parameter file, which is `step-500000`, no extension name `.pdparams` or `.pdopt` is needed. -2. Use `--iteration=500000`. -3. If you don't specify either `--checkpoint` or `--iteration`, the model will automatically load the latest checkpoint in `./runs/wavenet/${ModelName}/checkpoint`. - ### Train on multiple GPUs ```bash export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.." 
export CUDA_VISIBLE_DEVICES=0,1,2,3 python -u -m paddle.distributed.launch train.py \ - --config=${yaml} \ + --config=./configs/waveflow_ljspeech_sqz16_r64_layer8x8.yaml \ --root=./data/LJSpeech-1.1 \ - --name=${ModelName} --parallel=true --use_gpu=true + --name=test_speed --parallel=true --use_gpu=true ``` Use `export CUDA_VISIBLE_DEVICES=0,1,2,3` to set the GPUs that you want to use to be visible. Then the `paddle.distributed.launch` module will use these visible GPUs to do data parallel training in multiprocessing mode. - -### Monitor with Tensorboard - -By default, the logs are saved in `./runs/wavenet/${ModelName}/logs/`. You can monitor logs by tensorboard. - -```bash -tensorboard --logdir=${log_dir} --port=8888 -``` - -### Synthesize from a checkpoint - -Check the [Save and load checkpoint](#save-and-load-checkpoints) section on how to load a specific checkpoint. -The following example will automatically load the latest checkpoint: - -```bash -export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.." -export CUDA_VISIBLE_DEVICES=0 -python -u synthesis.py --config=${yaml} \ - --root=./data/LJSpeech-1.1 \ - --name=${ModelName} --use_gpu=true \ - --output=./syn_audios \ - --sample=${SAMPLE} -``` - -In this example, `--output` specifies where to save the synthesized audios and `--sample` specifies which sample in the valid dataset (a split from the whole LJSpeech dataset, by default contains the first 16 audio samples) to synthesize based on the mel-spectrograms computed from the ground truth sample audio, e.g., `--sample=0` means to synthesize the first audio in the valid dataset. diff --git a/parakeet/models/waveflow/requirements.txt b/parakeet/models/waveflow/requirements.txt new file mode 100644 index 0000000..f575339 --- /dev/null +++ b/parakeet/models/waveflow/requirements.txt @@ -0,0 +1,3 @@ +paddlepaddle-gpu==1.6.1.post97 +tensorboardX==1.9 +librosa==0.7.1 From 8c22397b5504345f92068c923d732035aff739d5 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Mon, 16 Dec 2019 16:42:39 -0800 Subject: [PATCH 3/6] add working synthesis code --- parakeet/models/waveflow/data.py | 8 +-- parakeet/models/waveflow/synthesis.py | 8 +-- parakeet/models/waveflow/waveflow.py | 44 +++++++++--- parakeet/models/waveflow/waveflow_modules.py | 70 +++++++++++++++++++- 4 files changed, 106 insertions(+), 24 deletions(-) diff --git a/parakeet/models/waveflow/data.py b/parakeet/models/waveflow/data.py index 3c70ce0..ddaf104 100644 --- a/parakeet/models/waveflow/data.py +++ b/parakeet/models/waveflow/data.py @@ -79,17 +79,13 @@ class Subset(dataset.Dataset): mode='constant', constant_values=0) # Normalize audio. 
- audio = audio / MAX_WAV_VALUE + audio = audio.astype(np.float32) / MAX_WAV_VALUE mel = self.get_mel(audio) + #print("mel = {}, dtype {}, shape {}".format(mel, mel.dtype, mel.shape)) return audio, mel def _batch_examples(self, batch): - audio_batch = [] - mel_batch = [] - for audio, mel in batch: - audio_batch - audios = [sample[0] for sample in batch] mels = [sample[1] for sample in batch] diff --git a/parakeet/models/waveflow/synthesis.py b/parakeet/models/waveflow/synthesis.py index d87a188..e42e170 100644 --- a/parakeet/models/waveflow/synthesis.py +++ b/parakeet/models/waveflow/synthesis.py @@ -8,11 +8,11 @@ import paddle.fluid.dygraph as dg from paddle import fluid import utils -from wavenet import WaveNet +from waveflow import WaveFlow def add_options_to_parser(parser): - parser.add_argument('--model', type=str, default='wavenet', + parser.add_argument('--model', type=str, default='waveflow', help="general name of the model") parser.add_argument('--name', type=str, help="specific name of the training model") @@ -30,7 +30,7 @@ def add_options_to_parser(parser): parser.add_argument('--output', type=str, default="./syn_audios", help="path to write synthesized audio files") - parser.add_argument('--sample', type=int, + parser.add_argument('--sample', type=int, default=None, help="which of the valid samples to synthesize audio") @@ -54,7 +54,7 @@ def synthesize(config): print("Random Seed: ", seed) # Build model. - model = WaveNet(config, checkpoint_dir) + model = WaveFlow(config, checkpoint_dir) model.build(training=False) # Obtain the current iteration. diff --git a/parakeet/models/waveflow/waveflow.py b/parakeet/models/waveflow/waveflow.py index b778497..b362c2d 100644 --- a/parakeet/models/waveflow/waveflow.py +++ b/parakeet/models/waveflow/waveflow.py @@ -2,7 +2,8 @@ import itertools import os import time -import librosa +#import librosa +from scipy.io.wavfile import write import numpy as np import paddle.fluid.dygraph as dg from paddle import fluid @@ -156,17 +157,38 @@ class WaveFlow(): output = "{}/{}/iter-{}".format(config.output, config.name, iteration) os.makedirs(output, exist_ok=True) - filename = "{}/valid_{}.wav".format(output, sample) - print("Synthesize sample {}, save as {}".format(sample, filename)) + mels_list = [mels for _, mels in self.validloader()] + if sample is not None: + mels_list = [mels_list[sample]] - mels_list = [mels for _, mels, _ in self.validloader()] - start_time = time.time() - syn_audio = self.waveflow.synthesize(mels_list[sample]) - syn_time = time.time() - start_time - print("audio shape {}, synthesis time {}".format( - syn_audio.shape, syn_time)) - librosa.output.write_wav(filename, syn_audio, - sr=config.sample_rate) + audio_times = [] + inf_times = [] + for sample, mel in enumerate(mels_list): + filename = "{}/valid_{}.wav".format(output, sample) + print("Synthesize sample {}, save as {}".format(sample, filename)) + + start_time = time.time() + audio = self.waveflow.synthesize(mel) + syn_time = time.time() - start_time + + audio_time = audio.shape[0] / 22050 + print("audio time {}, synthesis time {}, speedup: {}".format( + audio_time, syn_time, audio_time / syn_time)) + + #librosa.output.write_wav(filename, syn_audio, + # sr=config.sample_rate) + audio = audio.numpy() * 32768.0 + audio = audio.astype('int16') + write(filename, config.sample_rate, audio) + + audio_times.append(audio_time) + inf_times.append(syn_time) + + total_audio = sum(audio_times) + total_inf = sum(inf_times) + + print("Total audio: {}, total inf time {}, speedup: 
{}".format( + total_audio, total_inf, total_audio / total_inf)) def save(self, iteration): utils.save_latest_parameters(self.checkpoint_dir, iteration, diff --git a/parakeet/models/waveflow/waveflow_modules.py b/parakeet/models/waveflow/waveflow_modules.py index a4b9c4f..45b46a6 100644 --- a/parakeet/models/waveflow/waveflow_modules.py +++ b/parakeet/models/waveflow/waveflow_modules.py @@ -75,6 +75,16 @@ class Conditioner(dg.Layer): return fluid.layers.squeeze(x, [1]) + def infer(self, x): + x = fluid.layers.unsqueeze(x, 1) + for layer in self.upsample_conv2d: + x = layer(x) + # Trim conv artifacts. + time_cutoff = layer._filter_size[1] - layer._stride[1] + x = fluid.layers.leaky_relu(x[:, :, :, :-time_cutoff], alpha=0.4) + + return fluid.layers.squeeze(x, [1]) + class Flow(dg.Layer): def __init__(self, name_scope, config): @@ -183,6 +193,14 @@ class Flow(dg.Layer): return self.end(output) +def debug(x, msg): + y = x.numpy() + print(msg + " :\n", y) + print("shape: ", y.shape) + print("dtype: ", y.dtype) + print("") + + class WaveFlowModule(dg.Layer): def __init__(self, name_scope, config): super(WaveFlowModule, self).__init__(name_scope) @@ -217,7 +235,7 @@ class WaveFlowModule(dg.Layer): if mel.shape[2] > pruned_len: mel = mel[:, :, :pruned_len] - # From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group] + # From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group] mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2]) # From [bs, time] to [bs, n_group, time/n_group] audio = fluid.layers.transpose(unfold(audio, self.n_group), [0, 2, 1]) @@ -247,8 +265,54 @@ class WaveFlowModule(dg.Layer): return z, log_s_list - def synthesize(self, mels): - pass + def synthesize(self, mel, sigma=1.0): + #debug(mel, "mel") + mel = self.conditioner.infer(mel) + #debug(mel, "mel after conditioner") + + # From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group] + mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2]) + #debug(mel, "after group") + + audio = fluid.layers.gaussian_random( + shape=[mel.shape[0], 1, mel.shape[2], mel.shape[3]], std=sigma) + + #debug(audio, "audio") + + for i in reversed(range(self.n_flows)): + # Permute over the height dimension. 
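+            # Each permutation in self.perms reverses the height axis (or each of its
+            # halves) and is therefore its own inverse, so applying it here undoes the
+            # reordering that forward() applied after flow i.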
+ audio_slices = [audio[:, :, j, :] for j in self.perms[i]] + audio = fluid.layers.stack(audio_slices, axis=2) + mel_slices = [mel[:, :, j, :] for j in self.perms[i]] + mel = fluid.layers.stack(mel_slices, axis=2) + + audio_list = [] + audio_0 = audio[:, :, :1, :] + audio_list.append(audio_0) + + for h in range(1, self.n_group): + # inputs: [bs, 1, h, time/n_group] + inputs = fluid.layers.concat(audio_list, axis=2) + conds = mel[:, :, 1:(h+1), :] + outputs = self.flows[i](inputs, conds) + + log_s = outputs[:, :1, (h-1):h, :] + b = outputs[:, 1:, (h-1):h, :] + audio_h = (audio[:, :, h:(h+1), :] - b) / fluid.layers.exp(log_s) + audio_list.append(audio_h) + + audio = fluid.layers.concat(audio_list, axis=2) + #print("audio.shape =", audio.shape) + + # Assume batch size = 1 + # audio: [n_group, time/n_group] + audio = fluid.layers.squeeze(audio, [0, 1]) + # audio: [time] + audio = fluid.layers.reshape( + fluid.layers.transpose(audio, [1, 0]), [-1]) + #print("audio.shape =", audio.shape) + + return audio def start_new_sequence(self): for layer in self.sublayers(): From 0e18d600572ca1e9461cced0f0a470b503c5c900 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Thu, 19 Dec 2019 00:03:06 -0800 Subject: [PATCH 4/6] refine code --- parakeet/models/waveflow/benchmark.py | 71 +++++++++++ ...4_layer8x8.yaml => waveflow_ljspeech.yaml} | 0 ...flow_ljspeech_sqz16_r64_layer8x8_s123.yaml | 24 ---- parakeet/models/waveflow/data.py | 8 +- parakeet/models/waveflow/requirements.txt | 3 - parakeet/models/waveflow/train.py | 25 ---- parakeet/models/waveflow/utils.py | 21 ---- parakeet/models/waveflow/waveflow.py | 56 ++++----- parakeet/models/waveflow/waveflow_modules.py | 113 +++++++++++------- 9 files changed, 170 insertions(+), 151 deletions(-) create mode 100644 parakeet/models/waveflow/benchmark.py rename parakeet/models/waveflow/configs/{waveflow_ljspeech_sqz16_r64_layer8x8.yaml => waveflow_ljspeech.yaml} (100%) delete mode 100644 parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8_s123.yaml delete mode 100644 parakeet/models/waveflow/requirements.txt diff --git a/parakeet/models/waveflow/benchmark.py b/parakeet/models/waveflow/benchmark.py new file mode 100644 index 0000000..b2949d2 --- /dev/null +++ b/parakeet/models/waveflow/benchmark.py @@ -0,0 +1,71 @@ +import os +import random +from pprint import pprint + +import jsonargparse +import numpy as np +import paddle.fluid.dygraph as dg +from paddle import fluid + +import utils +from waveflow import WaveFlow + + +def add_options_to_parser(parser): + parser.add_argument('--model', type=str, default='waveflow', + help="general name of the model") + parser.add_argument('--name', type=str, + help="specific name of the training model") + parser.add_argument('--root', type=str, + help="root path of the LJSpeech dataset") + + parser.add_argument('--use_gpu', type=bool, default=True, + help="option to use gpu training") + + parser.add_argument('--iteration', type=int, default=None, + help=("which iteration of checkpoint to load, " + "default to load the latest checkpoint")) + parser.add_argument('--checkpoint', type=str, default=None, + help="path of the checkpoint to load") + + +def benchmark(config): + pprint(jsonargparse.namespace_to_dict(config)) + + # Get checkpoint directory path. + run_dir = os.path.join("runs", config.model, config.name) + checkpoint_dir = os.path.join(run_dir, "checkpoint") + + # Configurate device. + place = fluid.CUDAPlace(0) if config.use_gpu else fluid.CPUPlace() + + with dg.guard(place): + # Fix random seed. 
+        seed = config.seed
+        random.seed(seed)
+        np.random.seed(seed)
+        fluid.default_startup_program().random_seed = seed
+        fluid.default_main_program().random_seed = seed
+        print("Random Seed: ", seed)
+
+        # Build model.
+        model = WaveFlow(config, checkpoint_dir)
+        model.build(training=False)
+
+        # Run model inference.
+        model.benchmark()
+
+
+if __name__ == "__main__":
+    # Create parser.
+    parser = jsonargparse.ArgumentParser(
+        description="Benchmark speech synthesis speed of the WaveFlow model",
+        formatter_class='default_argparse')
+    add_options_to_parser(parser)
+    utils.add_config_options_to_parser(parser)
+
+    # Parse arguments from both the command line and the yaml config file.
+    # For conflicting updates to the same field,
+    # the preceding update will be overwritten by the following one.
+    config = parser.parse_args()
+    benchmark(config)
diff --git a/parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8.yaml b/parakeet/models/waveflow/configs/waveflow_ljspeech.yaml
similarity index 100%
rename from parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8.yaml
rename to parakeet/models/waveflow/configs/waveflow_ljspeech.yaml
diff --git a/parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8_s123.yaml b/parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8_s123.yaml
deleted file mode 100644
index 7d45212..0000000
--- a/parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8_s123.yaml
+++ /dev/null
@@ -1,24 +0,0 @@
-valid_size: 16
-segment_length: 16000
-sample_rate: 22050
-fft_window_shift: 256
-fft_window_size: 1024
-fft_size: 1024
-mel_bands: 80
-mel_fmin: 0.0
-mel_fmax: 8000.0
-
-seed: 123
-learning_rate: 0.0002
-batch_size: 8
-test_every: 2000
-save_every: 5000
-max_iterations: 2000000
-
-sigma: 1.0
-n_flows: 8
-n_group: 16
-n_layers: 8
-n_channels: 64
-kernel_h: 3
-kernel_w: 3
diff --git a/parakeet/models/waveflow/data.py b/parakeet/models/waveflow/data.py
index ddaf104..d89fb7b 100644
--- a/parakeet/models/waveflow/data.py
+++ b/parakeet/models/waveflow/data.py
@@ -4,7 +4,6 @@ import librosa
 import numpy as np
 from paddle import fluid
 
-import utils
 from parakeet.datasets import ljspeech
 from parakeet.data import dataset
 from parakeet.data.batch import SpecBatcher, WavBatcher
@@ -12,8 +11,6 @@ from parakeet.data.datacargo import DataCargo
 from parakeet.data.sampler import DistributedSampler, BatchSampler
 from scipy.io.wavfile import read
 
-MAX_WAV_VALUE = 32768.0
-
 
 class Dataset(ljspeech.LJSpeech):
     def __init__(self, config):
@@ -78,10 +75,9 @@ class Subset(dataset.Dataset):
             audio = np.pad(audio, (0, segment_length - audio.shape[0]),
                            mode='constant', constant_values=0)
 
-        # Normalize audio.
-        audio = audio.astype(np.float32) / MAX_WAV_VALUE
+        # Normalize audio to the [-1, 1] range.
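+        # 32768.0 is 2**15, the magnitude of the most negative 16-bit PCM
+        # sample, so the normalized waveform lies in [-1.0, 1.0).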
+ audio = audio.astype(np.float32) / 32768.0 mel = self.get_mel(audio) - #print("mel = {}, dtype {}, shape {}".format(mel, mel.dtype, mel.shape)) return audio, mel diff --git a/parakeet/models/waveflow/requirements.txt b/parakeet/models/waveflow/requirements.txt deleted file mode 100644 index f575339..0000000 --- a/parakeet/models/waveflow/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -paddlepaddle-gpu==1.6.1.post97 -tensorboardX==1.9 -librosa==0.7.1 diff --git a/parakeet/models/waveflow/train.py b/parakeet/models/waveflow/train.py index a125d97..89b787a 100644 --- a/parakeet/models/waveflow/train.py +++ b/parakeet/models/waveflow/train.py @@ -14,8 +14,6 @@ import slurm import utils from waveflow import WaveFlow -MAXIMUM_SAVE_TIME = 10 * 60 - def add_options_to_parser(parser): parser.add_argument('--model', type=str, default='waveflow', @@ -35,8 +33,6 @@ def add_options_to_parser(parser): "default to load the latest checkpoint")) parser.add_argument('--checkpoint', type=str, default=None, help="path of the checkpoint to load") - parser.add_argument('--slurm', type=bool, default=False, - help="whether you are using slurm to submit training jobs") def train(config): @@ -85,13 +81,6 @@ def train(config): else: iteration = int(config.checkpoint.split('/')[-1].split('-')[-1]) - # Get restart command if using slurm. - if config.slurm: - resume_command, death_time = slurm.restart_command() - if rank == 0: - print("Restart command:", " ".join(resume_command)) - done = False - while iteration < config.max_iterations: # Run one single training step. model.train_step(iteration) @@ -102,20 +91,6 @@ def train(config): # Run validation step. model.valid_step(iteration) - # Check whether reaching the time limit. - if config.slurm: - done = (death_time is not None and death_time - time.time() < - MAXIMUM_SAVE_TIME) - - if rank == 0 and done: - print("Saving progress before exiting.") - model.save(iteration) - - print("Running restart command:", " ".join(resume_command)) - # Submit restart command. - subprocess.check_call(resume_command) - break - if rank == 0 and iteration % config.save_every == 0: # Save parameters. model.save(iteration) diff --git a/parakeet/models/waveflow/utils.py b/parakeet/models/waveflow/utils.py index 494a409..3baeb60 100644 --- a/parakeet/models/waveflow/utils.py +++ b/parakeet/models/waveflow/utils.py @@ -57,27 +57,6 @@ def add_config_options_to_parser(parser): parser.add_argument('--config', action=jsonargparse.ActionConfigFile) -def pad_to_size(array, length, pad_with=0.0): - """ - Pad an array on the first (length) axis to a given length. - """ - padding = length - array.shape[0] - assert padding >= 0, "Padding required was less than zero" - - paddings = [(0, 0)] * len(array.shape) - paddings[0] = (0, padding) - - return np.pad(array, paddings, mode='constant', constant_values=pad_with) - - -def calculate_context_size(config): - dilations = list( - itertools.islice( - itertools.cycle(config.dilation_block), config.layers)) - config.context_size = sum(dilations) + 1 - print("Context size is", config.context_size) - - def load_latest_checkpoint(checkpoint_dir, rank=0): checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") # Create checkpoint index file if not exist. 
diff --git a/parakeet/models/waveflow/waveflow.py b/parakeet/models/waveflow/waveflow.py index b362c2d..4935d42 100644 --- a/parakeet/models/waveflow/waveflow.py +++ b/parakeet/models/waveflow/waveflow.py @@ -2,11 +2,10 @@ import itertools import os import time -#import librosa -from scipy.io.wavfile import write import numpy as np import paddle.fluid.dygraph as dg from paddle import fluid +from scipy.io.wavfile import write import utils from data import LJSpeech @@ -29,18 +28,6 @@ class WaveFlow(): self.trainloader = dataset.trainloader self.validloader = dataset.validloader -# if self.rank == 0: -# for i, (audios, mels) in enumerate(self.validloader()): -# print("audios {}, mels {}".format(audios.dtype, mels.dtype)) -# print("{}: rank {}, audios {}, mels {}".format( -# i, self.rank, audios.shape, mels.shape)) -# -# for i, (audios, mels) in enumerate(self.trainloader): -# print("{}: rank {}, audios {}, mels {}".format( -# i, self.rank, audios.shape, mels.shape)) -# -# exit() - waveflow = WaveFlowModule("waveflow", config) # Dry run once to create and initalize all necessary parameters. @@ -96,8 +83,6 @@ class WaveFlow(): else: loss.backward() - current_lr = self.optimizer._learning_rate - self.optimizer.minimize(loss, parameter_list=self.waveflow.parameters()) self.waveflow.clear_gradients() @@ -113,7 +98,6 @@ class WaveFlow(): tb = self.tb_logger tb.add_scalar("Train-Loss-Rank-0", loss_val, iteration) - tb.add_scalar("Learning-Rate", current_lr, iteration) @dg.no_grad def valid_step(self, iteration): @@ -161,34 +145,44 @@ class WaveFlow(): if sample is not None: mels_list = [mels_list[sample]] - audio_times = [] - inf_times = [] for sample, mel in enumerate(mels_list): filename = "{}/valid_{}.wav".format(output, sample) print("Synthesize sample {}, save as {}".format(sample, filename)) start_time = time.time() - audio = self.waveflow.synthesize(mel) + audio = self.waveflow.synthesize(mel, sigma=self.config.sigma) syn_time = time.time() - start_time - audio_time = audio.shape[0] / 22050 - print("audio time {}, synthesis time {}, speedup: {}".format( - audio_time, syn_time, audio_time / syn_time)) + audio = audio[0] + audio_time = audio.shape[0] / self.config.sample_rate + print("audio time {:.4f}, synthesis time {:.4f}".format( + audio_time, syn_time)) - #librosa.output.write_wav(filename, syn_audio, - # sr=config.sample_rate) + # Denormalize audio from [-1, 1] to [-32768, 32768] int16 range. 
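+            # Note that int16 tops out at 32767, so a sample at exactly +1.0
+            # may overflow when cast; clipping could be added as a safeguard.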
audio = audio.numpy() * 32768.0 audio = audio.astype('int16') write(filename, config.sample_rate, audio) - audio_times.append(audio_time) - inf_times.append(syn_time) + @dg.no_grad + def benchmark(self): + self.waveflow.eval() - total_audio = sum(audio_times) - total_inf = sum(inf_times) + mels_list = [mels for _, mels in self.validloader()] + mel = fluid.layers.concat(mels_list, axis=2) + mel = mel[:, :, :864] + batch_size = 8 + mel = fluid.layers.expand(mel, [batch_size, 1, 1]) - print("Total audio: {}, total inf time {}, speedup: {}".format( - total_audio, total_inf, total_audio / total_inf)) + for i in range(10): + start_time = time.time() + audio = self.waveflow.synthesize(mel, sigma=self.config.sigma) + print("audio.shape = ", audio.shape) + syn_time = time.time() - start_time + + audio_time = audio.shape[1] * batch_size / self.config.sample_rate + print("audio time {:.4f}, synthesis time {:.4f}".format( + audio_time, syn_time)) + print("{} X real-time".format(audio_time / syn_time)) def save(self, iteration): utils.save_latest_parameters(self.checkpoint_dir, iteration, diff --git a/parakeet/models/waveflow/waveflow_modules.py b/parakeet/models/waveflow/waveflow_modules.py index 45b46a6..39cb598 100644 --- a/parakeet/models/waveflow/waveflow_modules.py +++ b/parakeet/models/waveflow/waveflow_modules.py @@ -23,7 +23,6 @@ def set_param_attr(layer, c_in=1): def unfold(x, n_group): length = x.shape[-1] - #assert length % n_group == 0 new_shape = x.shape[:-1] + [length // n_group, n_group] return fluid.layers.reshape(x, new_shape) @@ -192,13 +191,53 @@ class Flow(dg.Layer): return self.end(output) + def infer(self, audio, mel, queues): + audio = self.start(audio) -def debug(x, msg): - y = x.numpy() - print(msg + " :\n", y) - print("shape: ", y.shape) - print("dtype: ", y.dtype) - print("") + for i in range(self.n_layers): + dilation_h = self.dilation_h_list[i] + dilation_w = 2 ** i + + state_size = dilation_h * (self.kernel_h - 1) + queue = queues[i] + + if len(queue) == 0: + for j in range(state_size): + queue.append(fluid.layers.zeros_like(audio)) + + state = queue[0:state_size] + state = fluid.layers.concat([*state, audio], axis=2) + + queue.pop(0) + queue.append(audio) + + # Pad height dim (n_group): causal convolution + # Pad width dim (time): dialated non-causal convolution + pad_top, pad_bottom = 0, 0 + pad_left = int((self.kernel_w-1) * dilation_w / 2) + pad_right = int((self.kernel_w-1) * dilation_w / 2) + state = fluid.layers.pad2d(state, + paddings=[pad_top, pad_bottom, pad_left, pad_right]) + + hidden = self.in_layers[i](state) + cond_hidden = self.cond_layers[i](mel) + in_acts = hidden + cond_hidden + out_acts = fluid.layers.tanh(in_acts[:, :self.n_channels, :]) * \ + fluid.layers.sigmoid(in_acts[:, self.n_channels:, :]) + res_skip_acts = self.res_skip_layers[i](out_acts) + + if i < self.n_layers - 1: + audio += res_skip_acts[:, :self.n_channels, :, :] + skip_acts = res_skip_acts[:, self.n_channels:, :, :] + else: + skip_acts = res_skip_acts + + if i == 0: + output = skip_acts + else: + output += skip_acts + + return self.end(output) class WaveFlowModule(dg.Layer): @@ -206,7 +245,9 @@ class WaveFlowModule(dg.Layer): super(WaveFlowModule, self).__init__(name_scope) self.n_flows = config.n_flows self.n_group = config.n_group + self.n_layers = config.n_layers assert self.n_group % 2 == 0 + assert self.n_flows % 2 == 0 self.conditioner = Conditioner(self.full_name()) self.flows = [] @@ -215,14 +256,16 @@ class WaveFlowModule(dg.Layer): self.flows.append(flow) 
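+            # Registering each flow as a sublayer ensures its parameters are
+            # tracked by parameters() and saved in checkpoints.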
self.add_sublayer("flow_{}".format(i), flow) - self.perms = [[15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], - [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], - [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], - [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0], - [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8], - [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8], - [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8], - [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8]] + self.perms = [] + half = self.n_group // 2 + for i in range(self.n_flows): + perm = list(range(self.n_group)) + if i < self.n_flows // 2: + perm = perm[::-1] + else: + perm[:half] = reversed(perm[:half]) + perm[half:] = reversed(perm[half:]) + self.perms.append(perm) def forward(self, audio, mel): mel = self.conditioner(mel) @@ -266,19 +309,13 @@ class WaveFlowModule(dg.Layer): return z, log_s_list def synthesize(self, mel, sigma=1.0): - #debug(mel, "mel") mel = self.conditioner.infer(mel) - #debug(mel, "mel after conditioner") - # From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group] mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2]) - #debug(mel, "after group") audio = fluid.layers.gaussian_random( shape=[mel.shape[0], 1, mel.shape[2], mel.shape[3]], std=sigma) - #debug(audio, "audio") - for i in reversed(range(self.n_flows)): # Permute over the height dimension. audio_slices = [audio[:, :, j, :] for j in self.perms[i]] @@ -287,34 +324,28 @@ class WaveFlowModule(dg.Layer): mel = fluid.layers.stack(mel_slices, axis=2) audio_list = [] - audio_0 = audio[:, :, :1, :] + audio_0 = audio[:, :, 0:1, :] audio_list.append(audio_0) + audio_h = audio_0 + queues = [[] for _ in range(self.n_layers)] for h in range(1, self.n_group): - # inputs: [bs, 1, h, time/n_group] - inputs = fluid.layers.concat(audio_list, axis=2) - conds = mel[:, :, 1:(h+1), :] - outputs = self.flows[i](inputs, conds) + inputs = audio_h + conds = mel[:, :, h:(h+1), :] + outputs = self.flows[i].infer(inputs, conds, queues) - log_s = outputs[:, :1, (h-1):h, :] - b = outputs[:, 1:, (h-1):h, :] - audio_h = (audio[:, :, h:(h+1), :] - b) / fluid.layers.exp(log_s) + log_s = outputs[:, 0:1, :, :] + b = outputs[:, 1:, :, :] + audio_h = (audio[:, :, h:(h+1), :] - b) / \ + fluid.layers.exp(log_s) audio_list.append(audio_h) audio = fluid.layers.concat(audio_list, axis=2) - #print("audio.shape =", audio.shape) - # Assume batch size = 1 - # audio: [n_group, time/n_group] - audio = fluid.layers.squeeze(audio, [0, 1]) - # audio: [time] + # audio: [bs, n_group, time/n_group] + audio = fluid.layers.squeeze(audio, [1]) + # audio: [bs, time] audio = fluid.layers.reshape( - fluid.layers.transpose(audio, [1, 0]), [-1]) - #print("audio.shape =", audio.shape) + fluid.layers.transpose(audio, [0, 2, 1]), [audio.shape[0], -1]) return audio - - def start_new_sequence(self): - for layer in self.sublayers(): - if isinstance(layer, conv.Conv1D): - layer.start_new_sequence() From 4af577ad723788a0dd6c10bba1f357a263a92149 Mon Sep 17 00:00:00 2001 From: zhaokexin01 Date: Thu, 19 Dec 2019 16:34:22 +0800 Subject: [PATCH 5/6] Update README.md --- parakeet/models/waveflow/README.md | 83 ++++++++++++++++++++++++++++-- 1 file changed, 78 insertions(+), 5 deletions(-) diff --git a/parakeet/models/waveflow/README.md b/parakeet/models/waveflow/README.md index 355ca31..d8072b1 100644 --- a/parakeet/models/waveflow/README.md +++ b/parakeet/models/waveflow/README.md @@ -1,6 +1,28 @@ -### Install +# WaveFlow with Paddle Fluid 
-pip install -r requirements.txt
+Paddle fluid implementation of [WaveFlow: A Compact Flow-based Model for Raw Audio](https://arxiv.org/abs/1912.01219).
+
+## Project Structure
+```text
+├── configs                 # yaml configuration files of preset model hyperparameters
+├── benchmark.py            # benchmark code to test the speed of batched speech synthesis
+├── data.py                 # dataset and dataloader settings for LJSpeech
+├── synthesis.py            # script for speech synthesis
+├── train.py                # script for model training
+├── utils.py                # helper functions, e.g. for model checkpointing
+├── waveflow.py             # WaveFlow model high level APIs
+└── waveflow_modules.py     # WaveFlow model implementation
+```
+
+## Usage
+
+There are many hyperparameters to be tuned depending on the specification of model and dataset you are working on.
+We provide `waveflow_ljspeech.yaml` as a hyperparameter set that works well on the LJSpeech dataset.
+
+Note that `train.py`, `synthesis.py`, and `benchmark.py` all accept a `--config` parameter. To ensure consistency, you should use the same config yaml file for training, synthesis, and benchmarking. You can also overwrite these preset hyperparameters from the command line by adding parameters after `--config`.
+For example, `--config=${yaml} --batch_size=8` can overwrite the corresponding hyperparameters in the `${yaml}` config file. For more details about these hyperparameters, check `utils.add_config_options_to_parser`.
+
+Note that you also need to specify some additional parameters for `train.py`, `synthesis.py`, and `benchmark.py`; the details can be found in `train.add_options_to_parser`, `synthesis.add_options_to_parser`, and `benchmark.add_options_to_parser`, respectively.
 
 ### Dataset
 
@@ -18,21 +40,72 @@ In this example, assume that the path of unzipped LJSpeech dataset is `./data/LJ
 ```bash
 export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
 export CUDA_VISIBLE_DEVICES=0
-python -u train.py --config=${yaml} \
+python -u train.py \
+    --config=./configs/waveflow_ljspeech.yaml \
     --root=./data/LJSpeech-1.1 \
     --name=${ModelName} --batch_size=4 \
     --parallel=false --use_gpu=true
 ```
 
+#### Save and Load checkpoints
+
+Our model will save model parameters as checkpoints in `./runs/waveflow/${ModelName}/checkpoint/` every 10000 iterations by default.
+The saved checkpoint will have the format of `step-${iteration_number}.pdparams` for model parameters and `step-${iteration_number}.pdopt` for optimizer parameters.
+
+There are three ways to load a checkpoint and resume training (suppose you want to load a 500000-iteration checkpoint):
+1. Use `--checkpoint=./runs/waveflow/${ModelName}/checkpoint/step-500000` to provide a specific path to load. Note that you only need to provide the base name of the parameter file, which is `step-500000`; no `.pdparams` or `.pdopt` extension is needed.
+2. Use `--iteration=500000`.
+3. If you don't specify either `--checkpoint` or `--iteration`, the model will automatically load the latest checkpoint in `./runs/waveflow/${ModelName}/checkpoint`.
+
 ### Train on multiple GPUs
 
 ```bash
 export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
 export CUDA_VISIBLE_DEVICES=0,1,2,3
 python -u -m paddle.distributed.launch train.py \
-    --config=./configs/waveflow_ljspeech_sqz16_r64_layer8x8.yaml \
+    --config=./configs/waveflow_ljspeech.yaml \
     --root=./data/LJSpeech-1.1 \
-    --name=test_speed --parallel=true --use_gpu=true
+    --name=${ModelName} --parallel=true --use_gpu=true
 ```
 
 Use `export CUDA_VISIBLE_DEVICES=0,1,2,3` to set the GPUs that you want to use to be visible.
Then the `paddle.distributed.launch` module will use these visible GPUs to do data parallel training in multiprocessing mode. + +### Monitor with Tensorboard + +By default, the logs are saved in `./runs/waveflow/${ModelName}/logs/`. You can monitor logs by tensorboard. + +```bash +tensorboard --logdir=${log_dir} --port=8888 +``` + +### Synthesize from a checkpoint + +Check the [Save and load checkpoint](#save-and-load-checkpoints) section on how to load a specific checkpoint. +The following example will automatically load the latest checkpoint: + +```bash +export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.." +export CUDA_VISIBLE_DEVICES=0 +python -u synthesis.py \ + --config=./configs/waveflow_ljspeech.yaml \ + --root=./data/LJSpeech-1.1 \ + --name=${ModelName} --use_gpu=true \ + --output=./syn_audios \ + --sample=${SAMPLE} \ + --sigma=1.0 +``` + +In this example, `--output` specifies where to save the synthesized audios and `--sample` specifies which sample in the valid dataset (a split from the whole LJSpeech dataset, by default contains the first 16 audio samples) to synthesize based on the mel-spectrograms computed from the ground truth sample audio, e.g., `--sample=0` means to synthesize the first audio in the valid dataset. + +### Benchmarking + +Use the following example to benchmark the speed of batched speech synthesis, which reports how many times faster than real-time: + +```bash +export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.." +export CUDA_VISIBLE_DEVICES=0 +python -u benchmark.py \ + --config=./configs/waveflow_ljspeech.yaml \ + --root=./data/LJSpeech-1.1 \ + --name=${ModelName} --use_gpu=true +``` \ No newline at end of file From 91ab2b34c4ca01dc731f8a97cdd3e0e8914c19e5 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Thu, 19 Dec 2019 00:37:43 -0800 Subject: [PATCH 6/6] small change --- .../waveflow/configs/waveflow_ljspeech.yaml | 4 +- parakeet/models/waveflow/slurm.py | 113 ------------------ 2 files changed, 2 insertions(+), 115 deletions(-) delete mode 100644 parakeet/models/waveflow/slurm.py diff --git a/parakeet/models/waveflow/configs/waveflow_ljspeech.yaml b/parakeet/models/waveflow/configs/waveflow_ljspeech.yaml index f9bbc83..d3548c4 100644 --- a/parakeet/models/waveflow/configs/waveflow_ljspeech.yaml +++ b/parakeet/models/waveflow/configs/waveflow_ljspeech.yaml @@ -12,8 +12,8 @@ seed: 1234 learning_rate: 0.0002 batch_size: 8 test_every: 2000 -save_every: 5000 -max_iterations: 2000000 +save_every: 10000 +max_iterations: 3000000 sigma: 1.0 n_flows: 8 diff --git a/parakeet/models/waveflow/slurm.py b/parakeet/models/waveflow/slurm.py deleted file mode 100644 index de1818c..0000000 --- a/parakeet/models/waveflow/slurm.py +++ /dev/null @@ -1,113 +0,0 @@ -""" -Utility module for restarting training when using SLURM. -""" -import subprocess -import os -import sys -import shlex -import re -import time - - -def job_info(): - """Get information about the current job using `scontrol show job`. - Returns a dict mapping parameter names (e.g. "UserId", "RunTime", etc) to - their values, both as strings. 
- """ - job_id = int(os.environ["SLURM_JOB_ID"]) - - command = ["scontrol", "show", "job", str(job_id)] - output = subprocess.check_output(command).decode("utf-8") - - # Use a regex to extract the parameter names and values - pattern = "([A-Za-z/]*)=([^ \t\n]*)" - return dict(re.findall(pattern, output)) - - -def parse_hours(text): - """Parse a time format HH or DD-HH into a number of hours.""" - hour_chunks = text.split("-") - if len(hour_chunks) == 1: - return int(hour_chunks[0]) - elif len(hour_chunks) == 2: - return 24 * int(hour_chunks[0]) + int(hour_chunks[1]) - else: - raise ValueError("Unexpected hour format (expected HH or " - "DD-HH, but got {}).".format(text)) - - -def parse_time(text): - """Convert slurm time to an integer. - Expects time to be of the form: - "hours:minutes:seconds" or "day-hours:minutes:seconds". - """ - hours, minutes, seconds = text.split(":") - try: - return parse_hours(hours) * 3600 + int(minutes) * 60 + int(seconds) - except ValueError as e: - raise ValueError("Error parsing time {}. Got error {}.".format( - text, str(e))) - - -def restart_command(): - """Using the environment and SLURM command, create a command that, when, - run, will enqueue a repeat of the current job using `sbatch`. - Return the command as a list of strings, suitable for passing to - `subprocess.check_call` or similar functions. - Returns: - resume_command: list, command to run to restart job. - end_time: int or None; the time the job will end or None - if the job has unlimited runtime. - """ - # Make sure `RunTime` could be parsed correctly. - while job_info()["RunTime"] == "INVALID": - time.sleep(1) - - # Get all the necessary information by querying SLURM with this job id - info = job_info() - - try: - num_cpus = int(info["CPUs/Task"]) - except KeyError: - num_cpus = int(os.environ["SLURM_CPUS_PER_TASK"]) - - num_tasks = int(os.environ["SLURM_NTASKS"]) - nodes = info["NumNodes"] - gres, partition = info.get("Gres"), info.get("Partition") - stderr, stdout = info.get("StdErr"), info.get("StdOut") - job_name = info.get("JobName") - command = ["sbatch", "--job-name={}".format(job_name), - "--ntasks={}".format(num_tasks), - "--exclude=asimov-186"] - - if partition: - command.extend(["--partition", partition]) - - if gres and gres != "(null)": - command.extend(["--gres", gres]) - num_gpu = int(gres.split(':')[-1]) - print("number of gpu assigned by slurm is {}".format(num_gpu)) - - if stderr: - command.extend(["--error", stderr]) - - if stdout: - command.extend(["--output", stdout]) - - python = subprocess.check_output( - ["/usr/bin/which", "python3"]).decode("utf-8").strip() - dist_setting = ['-m', 'paddle.distributed.launch'] - wrap_cmd = ["srun", python, '-u'] + dist_setting + sys.argv - - command.append( - "--wrap={}".format(" ".join(shlex.quote(arg) for arg in wrap_cmd))) - time_limit_string = info["TimeLimit"] - if time_limit_string.lower() == "unlimited": - print("UNLIMITED detected: restart OFF, infinite learning ON.", - flush=True) - return command, None - time_limit = parse_time(time_limit_string) - runtime = parse_time(info["RunTime"]) - end_time = time.time() + time_limit - runtime - - return command, end_time