diff --git a/parakeet/models/waveflow/README.md b/parakeet/models/waveflow/README.md
new file mode 100644
index 0000000..d8072b1
--- /dev/null
+++ b/parakeet/models/waveflow/README.md
@@ -0,0 +1,111 @@
+# WaveFlow with Paddle Fluid
+
+Paddle Fluid implementation of [WaveFlow: A Compact Flow-based Model for Raw Audio](https://arxiv.org/abs/1912.01219).
+
+## Project Structure
+```text
+├── configs                # yaml configuration files of preset model hyperparameters
+├── benchmark.py           # benchmark code to test the speed of batched speech synthesis
+├── data.py                # dataset and dataloader settings for LJSpeech
+├── synthesis.py           # script for speech synthesis
+├── train.py               # script for model training
+├── utils.py               # helper functions, e.g., for model checkpointing
+├── waveflow.py            # WaveFlow model high-level APIs
+└── waveflow_modules.py    # WaveFlow model implementation
+```
+
+## Usage
+
+There are many hyperparameters to tune, depending on the model specification and the dataset you are working with.
+We provide `waveflow_ljspeech.yaml` as a hyperparameter set that works well on the LJSpeech dataset.
+
+Note that `train.py`, `synthesis.py`, and `benchmark.py` all accept a `--config` parameter. To ensure consistency, you should use the same config yaml file for training, synthesis, and benchmarking. You can also override these preset hyperparameters on the command line by adding parameters after `--config`.
+For example, `--config=${yaml} --batch_size=8` overrides the corresponding hyperparameter in the `${yaml}` config file. For more details about these hyperparameters, check `utils.add_config_options_to_parser`.
+
+Note that you also need to specify some additional parameters for `train.py`, `synthesis.py`, and `benchmark.py`; the details can be found in `train.add_options_to_parser`, `synthesis.add_options_to_parser`, and `benchmark.add_options_to_parser`, respectively.
+
+### Dataset
+
+Download and unzip [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).
+
+```bash
+wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
+tar xjvf LJSpeech-1.1.tar.bz2
+```
+
+In the examples below, we assume the unzipped LJSpeech dataset is located at `./data/LJSpeech-1.1`.
+
+### Train on single GPU
+
+```bash
+export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
+export CUDA_VISIBLE_DEVICES=0
+python -u train.py \
+    --config=./configs/waveflow_ljspeech.yaml \
+    --root=./data/LJSpeech-1.1 \
+    --name=${ModelName} --batch_size=4 \
+    --parallel=false --use_gpu=true
+```
+
+#### Save and Load checkpoints
+
+By default, the model saves its parameters as checkpoints in `./runs/waveflow/${ModelName}/checkpoint/` every 10000 iterations.
+Each checkpoint consists of `step-${iteration_number}.pdparams` for model parameters and `step-${iteration_number}.pdopt` for optimizer state.
+
+There are three ways to load a checkpoint and resume training (for example, to load a 500000-iteration checkpoint):
+1. Use `--checkpoint=./runs/waveflow/${ModelName}/checkpoint/step-500000` to provide a specific path to load. Note that only the base name `step-500000` is needed; the `.pdparams` and `.pdopt` extensions should be omitted.
+2. Use `--iteration=500000`.
+3. If you specify neither `--checkpoint` nor `--iteration`, the model automatically loads the latest checkpoint in `./runs/waveflow/${ModelName}/checkpoint`.
+
+### Train on multiple GPUs
+
+```bash
+export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
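+# To resume from a saved checkpoint (see "Save and Load checkpoints" above),
+# append either flag to the command below, e.g.
+#   --checkpoint=./runs/waveflow/${ModelName}/checkpoint/step-500000
+# or --iteration=500000; with neither flag, the latest checkpoint is loaded.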
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python -u -m paddle.distributed.launch train.py \
+    --config=./configs/waveflow_ljspeech.yaml \
+    --root=./data/LJSpeech-1.1 \
+    --name=${ModelName} --parallel=true --use_gpu=true
+```
+
+Use `export CUDA_VISIBLE_DEVICES=0,1,2,3` to make the GPUs you want to use visible. The `paddle.distributed.launch` module then uses these visible GPUs for data-parallel training in multiprocessing mode.
+
+### Monitor with TensorBoard
+
+By default, logs are saved in `./runs/waveflow/${ModelName}/logs/`. You can monitor them with TensorBoard.
+
+```bash
+tensorboard --logdir=${log_dir} --port=8888
+```
+
+### Synthesize from a checkpoint
+
+Check the [Save and load checkpoints](#save-and-load-checkpoints) section on how to load a specific checkpoint.
+The following example automatically loads the latest checkpoint:
+
+```bash
+export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
+export CUDA_VISIBLE_DEVICES=0
+python -u synthesis.py \
+    --config=./configs/waveflow_ljspeech.yaml \
+    --root=./data/LJSpeech-1.1 \
+    --name=${ModelName} --use_gpu=true \
+    --output=./syn_audios \
+    --sample=${SAMPLE} \
+    --sigma=1.0
+```
+
+In this example, `--output` specifies where to save the synthesized audio and `--sample` selects which sample in the valid dataset to synthesize (the valid dataset is a split of the whole LJSpeech dataset that by default contains the first 16 audio clips). Synthesis is conditioned on the mel-spectrogram computed from the corresponding ground-truth audio; for example, `--sample=0` synthesizes the first audio clip in the valid dataset.
+
+### Benchmarking
+
+Use the following example to benchmark the speed of batched speech synthesis; it reports how many times faster than real time the synthesis runs:
+
+```bash
+export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
+export CUDA_VISIBLE_DEVICES=0
+python -u benchmark.py \
+    --config=./configs/waveflow_ljspeech.yaml \
+    --root=./data/LJSpeech-1.1 \
+    --name=${ModelName} --use_gpu=true
+```
\ No newline at end of file
diff --git a/parakeet/models/waveflow/benchmark.py b/parakeet/models/waveflow/benchmark.py
new file mode 100644
index 0000000..b2949d2
--- /dev/null
+++ b/parakeet/models/waveflow/benchmark.py
@@ -0,0 +1,71 @@
+import os
+import random
+from pprint import pprint
+
+import jsonargparse
+import numpy as np
+import paddle.fluid.dygraph as dg
+from paddle import fluid
+
+import utils
+from waveflow import WaveFlow
+
+
+def add_options_to_parser(parser):
+    parser.add_argument('--model', type=str, default='waveflow',
+        help="general name of the model")
+    parser.add_argument('--name', type=str,
+        help="specific name of the training model")
+    parser.add_argument('--root', type=str,
+        help="root path of the LJSpeech dataset")
+
+    parser.add_argument('--use_gpu', type=bool, default=True,
+        help="option to use gpu training")
+
+    parser.add_argument('--iteration', type=int, default=None,
+        help=("which iteration of checkpoint to load, "
+              "default to load the latest checkpoint"))
+    parser.add_argument('--checkpoint', type=str, default=None,
+        help="path of the checkpoint to load")
+
+
+def benchmark(config):
+    pprint(jsonargparse.namespace_to_dict(config))
+
+    # Get checkpoint directory path.
+    run_dir = os.path.join("runs", config.model, config.name)
+    checkpoint_dir = os.path.join(run_dir, "checkpoint")
+
+    # Configure device.
+    place = fluid.CUDAPlace(0) if config.use_gpu else fluid.CPUPlace()
+
+    with dg.guard(place):
+        # Fix random seed.
+ seed = config.seed + random.seed(seed) + np.random.seed(seed) + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + print("Random Seed: ", seed) + + # Build model. + model = WaveFlow(config, checkpoint_dir) + model.build(training=False) + + # Run model inference. + model.benchmark() + + +if __name__ == "__main__": + # Create parser. + parser = jsonargparse.ArgumentParser( + description="Synthesize audio using WaveNet model", + formatter_class='default_argparse') + add_options_to_parser(parser) + utils.add_config_options_to_parser(parser) + + # Parse argument from both command line and yaml config file. + # For conflicting updates to the same field, + # the preceding update will be overwritten by the following one. + config = parser.parse_args() + benchmark(config) diff --git a/parakeet/models/waveflow/configs/waveflow_ljspeech.yaml b/parakeet/models/waveflow/configs/waveflow_ljspeech.yaml new file mode 100644 index 0000000..d3548c4 --- /dev/null +++ b/parakeet/models/waveflow/configs/waveflow_ljspeech.yaml @@ -0,0 +1,24 @@ +valid_size: 16 +segment_length: 16000 +sample_rate: 22050 +fft_window_shift: 256 +fft_window_size: 1024 +fft_size: 1024 +mel_bands: 80 +mel_fmin: 0.0 +mel_fmax: 8000.0 + +seed: 1234 +learning_rate: 0.0002 +batch_size: 8 +test_every: 2000 +save_every: 10000 +max_iterations: 3000000 + +sigma: 1.0 +n_flows: 8 +n_group: 16 +n_layers: 8 +n_channels: 64 +kernel_h: 3 +kernel_w: 3 diff --git a/parakeet/models/waveflow/data.py b/parakeet/models/waveflow/data.py new file mode 100644 index 0000000..d89fb7b --- /dev/null +++ b/parakeet/models/waveflow/data.py @@ -0,0 +1,131 @@ +import random + +import librosa +import numpy as np +from paddle import fluid + +from parakeet.datasets import ljspeech +from parakeet.data import dataset +from parakeet.data.batch import SpecBatcher, WavBatcher +from parakeet.data.datacargo import DataCargo +from parakeet.data.sampler import DistributedSampler, BatchSampler +from scipy.io.wavfile import read + + +class Dataset(ljspeech.LJSpeech): + def __init__(self, config): + super(Dataset, self).__init__(config.root) + self.config = config + + def _get_example(self, metadatum): + fname, _, _ = metadatum + wav_path = self.root.joinpath("wavs", fname + ".wav") + + loaded_sr, audio = read(wav_path) + assert loaded_sr == self.config.sample_rate + + return audio + + +class Subset(dataset.Dataset): + def __init__(self, dataset, indices, valid): + self.dataset = dataset + self.indices = indices + self.valid = valid + self.config = dataset.config + + def get_mel(self, audio): + spectrogram = librosa.core.stft( + audio, n_fft=self.config.fft_size, + hop_length=self.config.fft_window_shift, + win_length=self.config.fft_window_size) + spectrogram_magnitude = np.abs(spectrogram) + + # mel_filter_bank shape: [n_mels, 1 + n_fft/2] + mel_filter_bank = librosa.filters.mel( + sr=self.config.sample_rate, + n_fft=self.config.fft_size, + n_mels=self.config.mel_bands, + fmin=self.config.mel_fmin, + fmax=self.config.mel_fmax) + # mel shape: [n_mels, num_frames] + mel = np.dot(mel_filter_bank, spectrogram_magnitude) + + # Normalize mel. 
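+        # This is standard log dynamic-range compression: magnitudes are
+        # clipped at 1e-5 before taking the natural log so that near-silent
+        # frames do not produce log(0); ref_constant = 1 applies no extra
+        # reference scaling.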
+ clip_val = 1e-5 + ref_constant = 1 + mel = np.log(np.clip(mel, a_min=clip_val, a_max=None) * ref_constant) + + return mel + + def __getitem__(self, idx): + audio = self.dataset[self.indices[idx]] + segment_length = self.config.segment_length + + if self.valid: + # whole audio for valid set + pass + else: + # audio shape: [len] + if audio.shape[0] >= segment_length: + max_audio_start = audio.shape[0] - segment_length + audio_start = random.randint(0, max_audio_start) + audio = audio[audio_start : (audio_start + segment_length)] + else: + audio = np.pad(audio, (0, segment_length - audio.shape[0]), + mode='constant', constant_values=0) + + # Normalize audio to the [-1, 1] range. + audio = audio.astype(np.float32) / 32768.0 + mel = self.get_mel(audio) + + return audio, mel + + def _batch_examples(self, batch): + audios = [sample[0] for sample in batch] + mels = [sample[1] for sample in batch] + + audios = WavBatcher(pad_value=0.0)(audios) + mels = SpecBatcher(pad_value=0.0)(mels) + + return audios, mels + + def __len__(self): + return len(self.indices) + + +class LJSpeech: + def __init__(self, config, nranks, rank): + place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace() + + # Whole LJSpeech dataset. + ds = Dataset(config) + + # Split into train and valid dataset. + indices = list(range(len(ds))) + train_indices = indices[config.valid_size:] + valid_indices = indices[:config.valid_size] + random.shuffle(train_indices) + + # Train dataset. + trainset = Subset(ds, train_indices, valid=False) + sampler = DistributedSampler(len(trainset), nranks, rank) + total_bs = config.batch_size + assert total_bs % nranks == 0 + train_sampler = BatchSampler(sampler, total_bs // nranks, + drop_last=True) + trainloader = DataCargo(trainset, batch_sampler=train_sampler) + + trainreader = fluid.io.PyReader(capacity=50, return_list=True) + trainreader.decorate_batch_generator(trainloader, place) + self.trainloader = (data for _ in iter(int, 1) + for data in trainreader()) + + # Valid dataset. + validset = Subset(ds, valid_indices, valid=True) + # Currently only support batch_size = 1 for valid loader. 
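+        # Valid samples keep their full, variable length (they are not cropped
+        # to segment_length above), so they cannot be stacked into larger
+        # batches; hence batch_size=1 and shuffle=False below.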
+ validloader = DataCargo(validset, batch_size=1, shuffle=False) + + validreader = fluid.io.PyReader(capacity=20, return_list=True) + validreader.decorate_batch_generator(validloader, place) + self.validloader = validreader diff --git a/parakeet/models/waveflow/synthesis.py b/parakeet/models/waveflow/synthesis.py new file mode 100644 index 0000000..e42e170 --- /dev/null +++ b/parakeet/models/waveflow/synthesis.py @@ -0,0 +1,85 @@ +import os +import random +from pprint import pprint + +import jsonargparse +import numpy as np +import paddle.fluid.dygraph as dg +from paddle import fluid + +import utils +from waveflow import WaveFlow + + +def add_options_to_parser(parser): + parser.add_argument('--model', type=str, default='waveflow', + help="general name of the model") + parser.add_argument('--name', type=str, + help="specific name of the training model") + parser.add_argument('--root', type=str, + help="root path of the LJSpeech dataset") + + parser.add_argument('--use_gpu', type=bool, default=True, + help="option to use gpu training") + + parser.add_argument('--iteration', type=int, default=None, + help=("which iteration of checkpoint to load, " + "default to load the latest checkpoint")) + parser.add_argument('--checkpoint', type=str, default=None, + help="path of the checkpoint to load") + + parser.add_argument('--output', type=str, default="./syn_audios", + help="path to write synthesized audio files") + parser.add_argument('--sample', type=int, default=None, + help="which of the valid samples to synthesize audio") + + +def synthesize(config): + pprint(jsonargparse.namespace_to_dict(config)) + + # Get checkpoint directory path. + run_dir = os.path.join("runs", config.model, config.name) + checkpoint_dir = os.path.join(run_dir, "checkpoint") + + # Configurate device. + place = fluid.CUDAPlace(0) if config.use_gpu else fluid.CPUPlace() + + with dg.guard(place): + # Fix random seed. + seed = config.seed + random.seed(seed) + np.random.seed(seed) + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + print("Random Seed: ", seed) + + # Build model. + model = WaveFlow(config, checkpoint_dir) + model.build(training=False) + + # Obtain the current iteration. + if config.checkpoint is None: + if config.iteration is None: + iteration = utils.load_latest_checkpoint(checkpoint_dir) + else: + iteration = config.iteration + else: + iteration = int(config.checkpoint.split('/')[-1].split('-')[-1]) + + # Run model inference. + model.infer(iteration) + + +if __name__ == "__main__": + # Create parser. + parser = jsonargparse.ArgumentParser( + description="Synthesize audio using WaveNet model", + formatter_class='default_argparse') + add_options_to_parser(parser) + utils.add_config_options_to_parser(parser) + + # Parse argument from both command line and yaml config file. + # For conflicting updates to the same field, + # the preceding update will be overwritten by the following one. 
+ config = parser.parse_args() + synthesize(config) diff --git a/parakeet/models/waveflow/train.py b/parakeet/models/waveflow/train.py new file mode 100644 index 0000000..89b787a --- /dev/null +++ b/parakeet/models/waveflow/train.py @@ -0,0 +1,114 @@ +import os +import random +import subprocess +import time +from pprint import pprint + +import jsonargparse +import numpy as np +import paddle.fluid.dygraph as dg +from paddle import fluid +from tensorboardX import SummaryWriter + +import slurm +import utils +from waveflow import WaveFlow + + +def add_options_to_parser(parser): + parser.add_argument('--model', type=str, default='waveflow', + help="general name of the model") + parser.add_argument('--name', type=str, + help="specific name of the training model") + parser.add_argument('--root', type=str, + help="root path of the LJSpeech dataset") + + parser.add_argument('--parallel', type=bool, default=True, + help="option to use data parallel training") + parser.add_argument('--use_gpu', type=bool, default=True, + help="option to use gpu training") + + parser.add_argument('--iteration', type=int, default=None, + help=("which iteration of checkpoint to load, " + "default to load the latest checkpoint")) + parser.add_argument('--checkpoint', type=str, default=None, + help="path of the checkpoint to load") + + +def train(config): + use_gpu = config.use_gpu + parallel = config.parallel if use_gpu else False + + # Get the rank of the current training process. + rank = dg.parallel.Env().local_rank if parallel else 0 + nranks = dg.parallel.Env().nranks if parallel else 1 + + if rank == 0: + # Print the whole config setting. + pprint(jsonargparse.namespace_to_dict(config)) + + # Make checkpoint directory. + run_dir = os.path.join("runs", config.model, config.name) + checkpoint_dir = os.path.join(run_dir, "checkpoint") + os.makedirs(checkpoint_dir, exist_ok=True) + + # Create tensorboard logger. + tb = SummaryWriter(os.path.join(run_dir, "logs")) \ + if rank == 0 else None + + # Configurate device + place = fluid.CUDAPlace(rank) if use_gpu else fluid.CPUPlace() + + with dg.guard(place): + # Fix random seed. + seed = config.seed + random.seed(seed) + np.random.seed(seed) + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + print("Random Seed: ", seed) + + # Build model. + model = WaveFlow(config, checkpoint_dir, parallel, rank, nranks, tb) + model.build() + + # Obtain the current iteration. + if config.checkpoint is None: + if config.iteration is None: + iteration = utils.load_latest_checkpoint(checkpoint_dir, rank) + else: + iteration = config.iteration + else: + iteration = int(config.checkpoint.split('/')[-1].split('-')[-1]) + + while iteration < config.max_iterations: + # Run one single training step. + model.train_step(iteration) + + iteration += 1 + + if iteration % config.test_every == 0: + # Run validation step. + model.valid_step(iteration) + + if rank == 0 and iteration % config.save_every == 0: + # Save parameters. + model.save(iteration) + + # Close TensorBoard. + if rank == 0: + tb.close() + + +if __name__ == "__main__": + # Create parser. + parser = jsonargparse.ArgumentParser(description="Train WaveFlow model", + formatter_class='default_argparse') + add_options_to_parser(parser) + utils.add_config_options_to_parser(parser) + + # Parse argument from both command line and yaml config file. + # For conflicting updates to the same field, + # the preceding update will be overwritten by the following one. 
+ config = parser.parse_args() + train(config) diff --git a/parakeet/models/waveflow/utils.py b/parakeet/models/waveflow/utils.py new file mode 100644 index 0000000..3baeb60 --- /dev/null +++ b/parakeet/models/waveflow/utils.py @@ -0,0 +1,114 @@ +import itertools +import os +import time + +import jsonargparse +import numpy as np +import paddle.fluid.dygraph as dg + + +def add_config_options_to_parser(parser): + parser.add_argument('--valid_size', type=int, + help="size of the valid dataset") + parser.add_argument('--segment_length', type=int, + help="the length of audio clip for training") + parser.add_argument('--sample_rate', type=int, + help="sampling rate of audio data file") + parser.add_argument('--fft_window_shift', type=int, + help="the shift of fft window for each frame") + parser.add_argument('--fft_window_size', type=int, + help="the size of fft window for each frame") + parser.add_argument('--fft_size', type=int, + help="the size of fft filter on each frame") + parser.add_argument('--mel_bands', type=int, + help="the number of mel bands when calculating mel spectrograms") + parser.add_argument('--mel_fmin', type=float, + help="lowest frequency in calculating mel spectrograms") + parser.add_argument('--mel_fmax', type=float, + help="highest frequency in calculating mel spectrograms") + + parser.add_argument('--seed', type=int, + help="seed of random initialization for the model") + parser.add_argument('--learning_rate', type=float) + parser.add_argument('--batch_size', type=int, + help="batch size for training") + parser.add_argument('--test_every', type=int, + help="test interval during training") + parser.add_argument('--save_every', type=int, + help="checkpointing interval during training") + parser.add_argument('--max_iterations', type=int, + help="maximum training iterations") + + parser.add_argument('--sigma', type=float, + help="standard deviation of the latent Gaussian variable") + parser.add_argument('--n_flows', type=int, + help="number of flows") + parser.add_argument('--n_group', type=int, + help="number of adjacent audio samples to squeeze into one column") + parser.add_argument('--n_layers', type=int, + help="number of conv2d layer in one wavenet-like flow architecture") + parser.add_argument('--n_channels', type=int, + help="number of residual channels in flow") + parser.add_argument('--kernel_h', type=int, + help="height of the kernel in the conv2d layer") + parser.add_argument('--kernel_w', type=int, + help="width of the kernel in the conv2d layer") + + parser.add_argument('--config', action=jsonargparse.ActionConfigFile) + + +def load_latest_checkpoint(checkpoint_dir, rank=0): + checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") + # Create checkpoint index file if not exist. + if (not os.path.isfile(checkpoint_path)) and rank == 0: + with open(checkpoint_path, "w") as handle: + handle.write("model_checkpoint_path: step-0") + + # Make sure that other process waits until checkpoint file is created + # by process 0. + while not os.path.isfile(checkpoint_path): + time.sleep(1) + + # Fetch the latest checkpoint index. + with open(checkpoint_path, "r") as handle: + latest_checkpoint = handle.readline().split()[-1] + iteration = int(latest_checkpoint.split("-")[-1]) + + return iteration + + +def save_latest_checkpoint(checkpoint_dir, iteration): + checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") + # Update the latest checkpoint index. 
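+    # The index file holds a single line of the form
+    # "model_checkpoint_path: step-<iteration>", which load_latest_checkpoint()
+    # above parses to recover the latest iteration number.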
+ with open(checkpoint_path, "w") as handle: + handle.write("model_checkpoint_path: step-{}".format(iteration)) + + +def load_parameters(checkpoint_dir, rank, model, optimizer=None, + iteration=None, file_path=None): + if file_path is None: + if iteration is None: + iteration = load_latest_checkpoint(checkpoint_dir, rank) + if iteration == 0: + return + file_path = "{}/step-{}".format(checkpoint_dir, iteration) + + model_dict, optimizer_dict = dg.load_dygraph(file_path) + model.set_dict(model_dict) + print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path)) + if optimizer and optimizer_dict: + optimizer.set_dict(optimizer_dict) + print("[checkpoint] Rank {}: loaded optimizer state from {}".format( + rank, file_path)) + + +def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None): + file_path = "{}/step-{}".format(checkpoint_dir, iteration) + model_dict = model.state_dict() + dg.save_dygraph(model_dict, file_path) + print("[checkpoint] Saved model to {}".format(file_path)) + + if optimizer: + opt_dict = optimizer.state_dict() + dg.save_dygraph(opt_dict, file_path) + print("[checkpoint] Saved optimzier state to {}".format(file_path)) diff --git a/parakeet/models/waveflow/waveflow.py b/parakeet/models/waveflow/waveflow.py new file mode 100644 index 0000000..4935d42 --- /dev/null +++ b/parakeet/models/waveflow/waveflow.py @@ -0,0 +1,190 @@ +import itertools +import os +import time + +import numpy as np +import paddle.fluid.dygraph as dg +from paddle import fluid +from scipy.io.wavfile import write + +import utils +from data import LJSpeech +from waveflow_modules import WaveFlowLoss, WaveFlowModule + + +class WaveFlow(): + def __init__(self, config, checkpoint_dir, parallel=False, rank=0, + nranks=1, tb_logger=None): + self.config = config + self.checkpoint_dir = checkpoint_dir + self.parallel = parallel + self.rank = rank + self.nranks = nranks + self.tb_logger = tb_logger + + def build(self, training=True): + config = self.config + dataset = LJSpeech(config, self.nranks, self.rank) + self.trainloader = dataset.trainloader + self.validloader = dataset.validloader + + waveflow = WaveFlowModule("waveflow", config) + + # Dry run once to create and initalize all necessary parameters. + audio = dg.to_variable(np.random.randn(1, 16000).astype(np.float32)) + mel = dg.to_variable( + np.random.randn(1, config.mel_bands, 63).astype(np.float32)) + waveflow(audio, mel) + + if training: + optimizer = fluid.optimizer.AdamOptimizer( + learning_rate=config.learning_rate) + + # Load parameters. + utils.load_parameters(self.checkpoint_dir, self.rank, + waveflow, optimizer, + iteration=config.iteration, + file_path=config.checkpoint) + print("Rank {}: checkpoint loaded.".format(self.rank)) + + # Data parallelism. + if self.parallel: + strategy = dg.parallel.prepare_context() + waveflow = dg.parallel.DataParallel(waveflow, strategy) + + self.waveflow = waveflow + self.optimizer = optimizer + self.criterion = WaveFlowLoss(config.sigma) + + else: + # Load parameters. 
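+            # For inference no optimizer is passed in, so only the model
+            # weights (the .pdparams file) are restored from the checkpoint.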
+ utils.load_parameters(self.checkpoint_dir, self.rank, waveflow, + iteration=config.iteration, + file_path=config.checkpoint) + print("Rank {}: checkpoint loaded.".format(self.rank)) + + self.waveflow = waveflow + + def train_step(self, iteration): + self.waveflow.train() + + start_time = time.time() + audios, mels = next(self.trainloader) + load_time = time.time() + + outputs = self.waveflow(audios, mels) + loss = self.criterion(outputs) + + if self.parallel: + # loss = loss / num_trainers + loss = self.waveflow.scale_loss(loss) + loss.backward() + self.waveflow.apply_collective_grads() + else: + loss.backward() + + self.optimizer.minimize(loss, parameter_list=self.waveflow.parameters()) + self.waveflow.clear_gradients() + + graph_time = time.time() + + if self.rank == 0: + loss_val = float(loss.numpy()) * self.nranks + log = "Rank: {} Step: {:^8d} Loss: {:<8.3f} " \ + "Time: {:.3f}/{:.3f}".format( + self.rank, iteration, loss_val, + load_time - start_time, graph_time - load_time) + print(log) + + tb = self.tb_logger + tb.add_scalar("Train-Loss-Rank-0", loss_val, iteration) + + @dg.no_grad + def valid_step(self, iteration): + self.waveflow.eval() + tb = self.tb_logger + + total_loss = [] + sample_audios = [] + start_time = time.time() + + for i, batch in enumerate(self.validloader()): + audios, mels = batch + valid_outputs = self.waveflow(audios, mels) + valid_z, valid_log_s_list = valid_outputs + + # Visualize latent z and scale log_s. + if self.rank == 0 and i == 0: + tb.add_histogram("Valid-Latent_z", valid_z.numpy(), iteration) + for j, valid_log_s in enumerate(valid_log_s_list): + hist_name = "Valid-{}th-Flow-Log_s".format(j) + tb.add_histogram(hist_name, valid_log_s.numpy(), iteration) + + valid_loss = self.criterion(valid_outputs) + total_loss.append(float(valid_loss.numpy())) + + total_time = time.time() - start_time + if self.rank == 0: + loss_val = np.mean(total_loss) + log = "Test | Rank: {} AvgLoss: {:<8.3f} Time {:<8.3f}".format( + self.rank, loss_val, total_time) + print(log) + tb.add_scalar("Valid-Avg-Loss", loss_val, iteration) + + @dg.no_grad + def infer(self, iteration): + self.waveflow.eval() + + config = self.config + sample = config.sample + + output = "{}/{}/iter-{}".format(config.output, config.name, iteration) + os.makedirs(output, exist_ok=True) + + mels_list = [mels for _, mels in self.validloader()] + if sample is not None: + mels_list = [mels_list[sample]] + + for sample, mel in enumerate(mels_list): + filename = "{}/valid_{}.wav".format(output, sample) + print("Synthesize sample {}, save as {}".format(sample, filename)) + + start_time = time.time() + audio = self.waveflow.synthesize(mel, sigma=self.config.sigma) + syn_time = time.time() - start_time + + audio = audio[0] + audio_time = audio.shape[0] / self.config.sample_rate + print("audio time {:.4f}, synthesis time {:.4f}".format( + audio_time, syn_time)) + + # Denormalize audio from [-1, 1] to [-32768, 32768] int16 range. 
+ audio = audio.numpy() * 32768.0 + audio = audio.astype('int16') + write(filename, config.sample_rate, audio) + + @dg.no_grad + def benchmark(self): + self.waveflow.eval() + + mels_list = [mels for _, mels in self.validloader()] + mel = fluid.layers.concat(mels_list, axis=2) + mel = mel[:, :, :864] + batch_size = 8 + mel = fluid.layers.expand(mel, [batch_size, 1, 1]) + + for i in range(10): + start_time = time.time() + audio = self.waveflow.synthesize(mel, sigma=self.config.sigma) + print("audio.shape = ", audio.shape) + syn_time = time.time() - start_time + + audio_time = audio.shape[1] * batch_size / self.config.sample_rate + print("audio time {:.4f}, synthesis time {:.4f}".format( + audio_time, syn_time)) + print("{} X real-time".format(audio_time / syn_time)) + + def save(self, iteration): + utils.save_latest_parameters(self.checkpoint_dir, iteration, + self.waveflow, self.optimizer) + utils.save_latest_checkpoint(self.checkpoint_dir, iteration) diff --git a/parakeet/models/waveflow/waveflow_modules.py b/parakeet/models/waveflow/waveflow_modules.py new file mode 100644 index 0000000..39cb598 --- /dev/null +++ b/parakeet/models/waveflow/waveflow_modules.py @@ -0,0 +1,351 @@ +import itertools + +import numpy as np +import paddle.fluid.dygraph as dg +from paddle import fluid +from parakeet.modules import conv, modules, weight_norm + + +def set_param_attr(layer, c_in=1): + if isinstance(layer, (weight_norm.Conv2DTranspose, weight_norm.Conv2D)): + k = np.sqrt(1.0 / (c_in * np.prod(layer._filter_size))) + weight_init = fluid.initializer.UniformInitializer(low=-k, high=k) + bias_init = fluid.initializer.UniformInitializer(low=-k, high=k) + elif isinstance(layer, dg.Conv2D): + weight_init = fluid.initializer.ConstantInitializer(0.0) + bias_init = fluid.initializer.ConstantInitializer(0.0) + else: + raise TypeError("Unsupported layer type.") + + layer._param_attr = fluid.ParamAttr(initializer=weight_init) + layer._bias_attr = fluid.ParamAttr(initializer=bias_init) + + +def unfold(x, n_group): + length = x.shape[-1] + new_shape = x.shape[:-1] + [length // n_group, n_group] + return fluid.layers.reshape(x, new_shape) + + +class WaveFlowLoss: + def __init__(self, sigma=1.0): + self.sigma = sigma + + def __call__(self, model_output): + z, log_s_list = model_output + for i, log_s in enumerate(log_s_list): + if i == 0: + log_s_total = fluid.layers.reduce_sum(log_s) + else: + log_s_total = log_s_total + fluid.layers.reduce_sum(log_s) + + loss = fluid.layers.reduce_sum(z * z) / (2 * self.sigma * self.sigma) \ + - log_s_total + loss = loss / np.prod(z.shape) + const = 0.5 * np.log(2 * np.pi) + np.log(self.sigma) + + return loss + const + + +class Conditioner(dg.Layer): + def __init__(self, name_scope): + super(Conditioner, self).__init__(name_scope) + upsample_factors = [16, 16] + + self.upsample_conv2d = [] + for s in upsample_factors: + in_channel = 1 + conv_trans2d = modules.Conv2DTranspose( + self.full_name(), + num_filters=1, + filter_size=(3, 2 * s), + padding=(1, s // 2), + stride=(1, s)) + set_param_attr(conv_trans2d, c_in=in_channel) + self.upsample_conv2d.append(conv_trans2d) + + for i, layer in enumerate(self.upsample_conv2d): + self.add_sublayer("conv2d_transpose_{}".format(i), layer) + + def forward(self, x): + x = fluid.layers.unsqueeze(x, 1) + for layer in self.upsample_conv2d: + x = fluid.layers.leaky_relu(layer(x), alpha=0.4) + + return fluid.layers.squeeze(x, [1]) + + def infer(self, x): + x = fluid.layers.unsqueeze(x, 1) + for layer in self.upsample_conv2d: + x = layer(x) + # Trim 
conv artifacts. + time_cutoff = layer._filter_size[1] - layer._stride[1] + x = fluid.layers.leaky_relu(x[:, :, :, :-time_cutoff], alpha=0.4) + + return fluid.layers.squeeze(x, [1]) + + +class Flow(dg.Layer): + def __init__(self, name_scope, config): + super(Flow, self).__init__(name_scope) + self.n_layers = config.n_layers + self.n_channels = config.n_channels + self.kernel_h = config.kernel_h + self.kernel_w = config.kernel_w + + # Transform audio: [batch, 1, n_group, time/n_group] + # => [batch, n_channels, n_group, time/n_group] + self.start = weight_norm.Conv2D( + self.full_name(), + num_filters=self.n_channels, + filter_size=(1, 1)) + set_param_attr(self.start, c_in=1) + + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. This helps with training stability + # output shape: [batch, 2, n_group, time/n_group] + self.end = dg.Conv2D( + self.full_name(), + num_filters=2, + filter_size=(1, 1)) + set_param_attr(self.end) + + # receiptive fileds: (kernel - 1) * sum(dilations) + 1 >= squeeze + dilation_dict = {8: [1, 1, 1, 1, 1, 1, 1, 1], + 16: [1, 1, 1, 1, 1, 1, 1, 1], + 32: [1, 2, 4, 1, 2, 4, 1, 2], + 64: [1, 2, 4, 8, 16, 1, 2, 4], + 128: [1, 2, 4, 8, 16, 32, 64, 1]} + self.dilation_h_list = dilation_dict[config.n_group] + + self.in_layers = [] + self.cond_layers = [] + self.res_skip_layers = [] + for i in range(self.n_layers): + dilation_h = self.dilation_h_list[i] + dilation_w = 2 ** i + + in_layer = weight_norm.Conv2D( + self.full_name(), + num_filters=2 * self.n_channels, + filter_size=(self.kernel_h, self.kernel_w), + dilation=(dilation_h, dilation_w)) + set_param_attr(in_layer, c_in=self.n_channels) + self.in_layers.append(in_layer) + + cond_layer = weight_norm.Conv2D( + self.full_name(), + num_filters=2 * self.n_channels, + filter_size=(1, 1)) + set_param_attr(cond_layer, c_in=config.mel_bands) + self.cond_layers.append(cond_layer) + + if i < self.n_layers - 1: + res_skip_channels = 2 * self.n_channels + else: + res_skip_channels = self.n_channels + res_skip_layer = weight_norm.Conv2D( + self.full_name(), + num_filters=res_skip_channels, + filter_size=(1, 1)) + set_param_attr(res_skip_layer, c_in=self.n_channels) + self.res_skip_layers.append(res_skip_layer) + + self.add_sublayer("in_layer_{}".format(i), in_layer) + self.add_sublayer("cond_layer_{}".format(i), cond_layer) + self.add_sublayer("res_skip_layer_{}".format(i), res_skip_layer) + + def forward(self, audio, mel): + # audio: [bs, 1, n_group, time/group] + # mel: [bs, mel_bands, n_group, time/n_group] + audio = self.start(audio) + + for i in range(self.n_layers): + dilation_h = self.dilation_h_list[i] + dilation_w = 2 ** i + + # Pad height dim (n_group): causal convolution + # Pad width dim (time): dialated non-causal convolution + pad_top, pad_bottom = (self.kernel_h - 1) * dilation_h, 0 + pad_left = pad_right = int((self.kernel_w-1) * dilation_w / 2) + audio_pad = fluid.layers.pad2d(audio, + paddings=[pad_top, pad_bottom, pad_left, pad_right]) + + hidden = self.in_layers[i](audio_pad) + cond_hidden = self.cond_layers[i](mel) + in_acts = hidden + cond_hidden + out_acts = fluid.layers.tanh(in_acts[:, :self.n_channels, :]) * \ + fluid.layers.sigmoid(in_acts[:, self.n_channels:, :]) + res_skip_acts = self.res_skip_layers[i](out_acts) + + if i < self.n_layers - 1: + audio += res_skip_acts[:, :self.n_channels, :, :] + skip_acts = res_skip_acts[:, self.n_channels:, :, :] + else: + skip_acts = res_skip_acts + + if i == 0: + output = skip_acts + else: + output += skip_acts + + return 
self.end(output) + + def infer(self, audio, mel, queues): + audio = self.start(audio) + + for i in range(self.n_layers): + dilation_h = self.dilation_h_list[i] + dilation_w = 2 ** i + + state_size = dilation_h * (self.kernel_h - 1) + queue = queues[i] + + if len(queue) == 0: + for j in range(state_size): + queue.append(fluid.layers.zeros_like(audio)) + + state = queue[0:state_size] + state = fluid.layers.concat([*state, audio], axis=2) + + queue.pop(0) + queue.append(audio) + + # Pad height dim (n_group): causal convolution + # Pad width dim (time): dialated non-causal convolution + pad_top, pad_bottom = 0, 0 + pad_left = int((self.kernel_w-1) * dilation_w / 2) + pad_right = int((self.kernel_w-1) * dilation_w / 2) + state = fluid.layers.pad2d(state, + paddings=[pad_top, pad_bottom, pad_left, pad_right]) + + hidden = self.in_layers[i](state) + cond_hidden = self.cond_layers[i](mel) + in_acts = hidden + cond_hidden + out_acts = fluid.layers.tanh(in_acts[:, :self.n_channels, :]) * \ + fluid.layers.sigmoid(in_acts[:, self.n_channels:, :]) + res_skip_acts = self.res_skip_layers[i](out_acts) + + if i < self.n_layers - 1: + audio += res_skip_acts[:, :self.n_channels, :, :] + skip_acts = res_skip_acts[:, self.n_channels:, :, :] + else: + skip_acts = res_skip_acts + + if i == 0: + output = skip_acts + else: + output += skip_acts + + return self.end(output) + + +class WaveFlowModule(dg.Layer): + def __init__(self, name_scope, config): + super(WaveFlowModule, self).__init__(name_scope) + self.n_flows = config.n_flows + self.n_group = config.n_group + self.n_layers = config.n_layers + assert self.n_group % 2 == 0 + assert self.n_flows % 2 == 0 + + self.conditioner = Conditioner(self.full_name()) + self.flows = [] + for i in range(self.n_flows): + flow = Flow(self.full_name(), config) + self.flows.append(flow) + self.add_sublayer("flow_{}".format(i), flow) + + self.perms = [] + half = self.n_group // 2 + for i in range(self.n_flows): + perm = list(range(self.n_group)) + if i < self.n_flows // 2: + perm = perm[::-1] + else: + perm[:half] = reversed(perm[:half]) + perm[half:] = reversed(perm[half:]) + self.perms.append(perm) + + def forward(self, audio, mel): + mel = self.conditioner(mel) + assert mel.shape[2] >= audio.shape[1] + # Prune out the tail of audio/mel so that time/n_group == 0. + pruned_len = audio.shape[1] // self.n_group * self.n_group + + if audio.shape[1] > pruned_len: + audio = audio[:, :pruned_len] + if mel.shape[2] > pruned_len: + mel = mel[:, :, :pruned_len] + + # From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group] + mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2]) + # From [bs, time] to [bs, n_group, time/n_group] + audio = fluid.layers.transpose(unfold(audio, self.n_group), [0, 2, 1]) + # [bs, 1, n_group, time/n_group] + audio = fluid.layers.unsqueeze(audio, 1) + + log_s_list = [] + for i in range(self.n_flows): + inputs = audio[:, :, :-1, :] + conds = mel[:, :, 1:, :] + outputs = self.flows[i](inputs, conds) + log_s = outputs[:, :1, :, :] + b = outputs[:, 1:, :, :] + log_s_list.append(log_s) + + audio_0 = audio[:, :, :1, :] + audio_out = audio[:, :, 1:, :] * fluid.layers.exp(log_s) + b + audio = fluid.layers.concat([audio_0, audio_out], axis=2) + + # Permute over the height dim. 
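+            # self.perms (built in __init__) reverses all n_group rows for the
+            # first half of the flows and reverses each half of the rows
+            # separately for the second half. synthesize() applies the same
+            # permutations (each is its own inverse) while visiting the flows
+            # in reverse order, which keeps the transform invertible.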
+ audio_slices = [audio[:, :, j, :] for j in self.perms[i]] + audio = fluid.layers.stack(audio_slices, axis=2) + mel_slices = [mel[:, :, j, :] for j in self.perms[i]] + mel = fluid.layers.stack(mel_slices, axis=2) + + z = fluid.layers.squeeze(audio, [1]) + + return z, log_s_list + + def synthesize(self, mel, sigma=1.0): + mel = self.conditioner.infer(mel) + # From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group] + mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2]) + + audio = fluid.layers.gaussian_random( + shape=[mel.shape[0], 1, mel.shape[2], mel.shape[3]], std=sigma) + + for i in reversed(range(self.n_flows)): + # Permute over the height dimension. + audio_slices = [audio[:, :, j, :] for j in self.perms[i]] + audio = fluid.layers.stack(audio_slices, axis=2) + mel_slices = [mel[:, :, j, :] for j in self.perms[i]] + mel = fluid.layers.stack(mel_slices, axis=2) + + audio_list = [] + audio_0 = audio[:, :, 0:1, :] + audio_list.append(audio_0) + audio_h = audio_0 + queues = [[] for _ in range(self.n_layers)] + + for h in range(1, self.n_group): + inputs = audio_h + conds = mel[:, :, h:(h+1), :] + outputs = self.flows[i].infer(inputs, conds, queues) + + log_s = outputs[:, 0:1, :, :] + b = outputs[:, 1:, :, :] + audio_h = (audio[:, :, h:(h+1), :] - b) / \ + fluid.layers.exp(log_s) + audio_list.append(audio_h) + + audio = fluid.layers.concat(audio_list, axis=2) + + # audio: [bs, n_group, time/n_group] + audio = fluid.layers.squeeze(audio, [1]) + # audio: [bs, time] + audio = fluid.layers.reshape( + fluid.layers.transpose(audio, [0, 2, 1]), [audio.shape[0], -1]) + + return audio