From 0e18d600572ca1e9461cced0f0a470b503c5c900 Mon Sep 17 00:00:00 2001
From: Kexin Zhao
Date: Thu, 19 Dec 2019 00:03:06 -0800
Subject: [PATCH] refine code

---
 parakeet/models/waveflow/benchmark.py         |  71 +++++++++++
 ...4_layer8x8.yaml => waveflow_ljspeech.yaml} |   0
 ...flow_ljspeech_sqz16_r64_layer8x8_s123.yaml |  24 ----
 parakeet/models/waveflow/data.py              |   8 +-
 parakeet/models/waveflow/requirements.txt     |   3 -
 parakeet/models/waveflow/train.py             |  25 ----
 parakeet/models/waveflow/utils.py             |  21 ----
 parakeet/models/waveflow/waveflow.py          |  56 ++++-----
 parakeet/models/waveflow/waveflow_modules.py  | 113 +++++++++++-------
 9 files changed, 170 insertions(+), 151 deletions(-)
 create mode 100644 parakeet/models/waveflow/benchmark.py
 rename parakeet/models/waveflow/configs/{waveflow_ljspeech_sqz16_r64_layer8x8.yaml => waveflow_ljspeech.yaml} (100%)
 delete mode 100644 parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8_s123.yaml
 delete mode 100644 parakeet/models/waveflow/requirements.txt

diff --git a/parakeet/models/waveflow/benchmark.py b/parakeet/models/waveflow/benchmark.py
new file mode 100644
index 0000000..b2949d2
--- /dev/null
+++ b/parakeet/models/waveflow/benchmark.py
@@ -0,0 +1,71 @@
+import os
+import random
+from pprint import pprint
+
+import jsonargparse
+import numpy as np
+import paddle.fluid.dygraph as dg
+from paddle import fluid
+
+import utils
+from waveflow import WaveFlow
+
+
+def add_options_to_parser(parser):
+    parser.add_argument('--model', type=str, default='waveflow',
+        help="general name of the model")
+    parser.add_argument('--name', type=str,
+        help="specific name of the trained model")
+    parser.add_argument('--root', type=str,
+        help="root path of the LJSpeech dataset")
+
+    parser.add_argument('--use_gpu', type=bool, default=True,
+        help="option to use gpu inference")
+
+    parser.add_argument('--iteration', type=int, default=None,
+        help=("which iteration of checkpoint to load, "
+              "default to load the latest checkpoint"))
+    parser.add_argument('--checkpoint', type=str, default=None,
+        help="path of the checkpoint to load")
+
+
+def benchmark(config):
+    pprint(jsonargparse.namespace_to_dict(config))
+
+    # Get checkpoint directory path.
+    run_dir = os.path.join("runs", config.model, config.name)
+    checkpoint_dir = os.path.join(run_dir, "checkpoint")
+
+    # Configure device.
+    place = fluid.CUDAPlace(0) if config.use_gpu else fluid.CPUPlace()
+
+    with dg.guard(place):
+        # Fix random seed.
+        seed = config.seed
+        random.seed(seed)
+        np.random.seed(seed)
+        fluid.default_startup_program().random_seed = seed
+        fluid.default_main_program().random_seed = seed
+        print("Random Seed: ", seed)
+
+        # Build model.
+        model = WaveFlow(config, checkpoint_dir)
+        model.build(training=False)
+
+        # Run model inference.
+        model.benchmark()
+
+
+if __name__ == "__main__":
+    # Create parser.
+    parser = jsonargparse.ArgumentParser(
+        description="Benchmark inference speed of the WaveFlow model",
+        formatter_class='default_argparse')
+    add_options_to_parser(parser)
+    utils.add_config_options_to_parser(parser)
+
+    # Parse arguments from both command line and yaml config file.
+    # For conflicting updates to the same field,
+    # the preceding update will be overwritten by the following one.
+ config = parser.parse_args() + benchmark(config) diff --git a/parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8.yaml b/parakeet/models/waveflow/configs/waveflow_ljspeech.yaml similarity index 100% rename from parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8.yaml rename to parakeet/models/waveflow/configs/waveflow_ljspeech.yaml diff --git a/parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8_s123.yaml b/parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8_s123.yaml deleted file mode 100644 index 7d45212..0000000 --- a/parakeet/models/waveflow/configs/waveflow_ljspeech_sqz16_r64_layer8x8_s123.yaml +++ /dev/null @@ -1,24 +0,0 @@ -valid_size: 16 -segment_length: 16000 -sample_rate: 22050 -fft_window_shift: 256 -fft_window_size: 1024 -fft_size: 1024 -mel_bands: 80 -mel_fmin: 0.0 -mel_fmax: 8000.0 - -seed: 123 -learning_rate: 0.0002 -batch_size: 8 -test_every: 2000 -save_every: 5000 -max_iterations: 2000000 - -sigma: 1.0 -n_flows: 8 -n_group: 16 -n_layers: 8 -n_channels: 64 -kernel_h: 3 -kernel_w: 3 diff --git a/parakeet/models/waveflow/data.py b/parakeet/models/waveflow/data.py index ddaf104..d89fb7b 100644 --- a/parakeet/models/waveflow/data.py +++ b/parakeet/models/waveflow/data.py @@ -4,7 +4,6 @@ import librosa import numpy as np from paddle import fluid -import utils from parakeet.datasets import ljspeech from parakeet.data import dataset from parakeet.data.batch import SpecBatcher, WavBatcher @@ -12,8 +11,6 @@ from parakeet.data.datacargo import DataCargo from parakeet.data.sampler import DistributedSampler, BatchSampler from scipy.io.wavfile import read -MAX_WAV_VALUE = 32768.0 - class Dataset(ljspeech.LJSpeech): def __init__(self, config): @@ -78,10 +75,9 @@ class Subset(dataset.Dataset): audio = np.pad(audio, (0, segment_length - audio.shape[0]), mode='constant', constant_values=0) - # Normalize audio. - audio = audio.astype(np.float32) / MAX_WAV_VALUE + # Normalize audio to the [-1, 1] range. + audio = audio.astype(np.float32) / 32768.0 mel = self.get_mel(audio) - #print("mel = {}, dtype {}, shape {}".format(mel, mel.dtype, mel.shape)) return audio, mel diff --git a/parakeet/models/waveflow/requirements.txt b/parakeet/models/waveflow/requirements.txt deleted file mode 100644 index f575339..0000000 --- a/parakeet/models/waveflow/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -paddlepaddle-gpu==1.6.1.post97 -tensorboardX==1.9 -librosa==0.7.1 diff --git a/parakeet/models/waveflow/train.py b/parakeet/models/waveflow/train.py index a125d97..89b787a 100644 --- a/parakeet/models/waveflow/train.py +++ b/parakeet/models/waveflow/train.py @@ -14,8 +14,6 @@ import slurm import utils from waveflow import WaveFlow -MAXIMUM_SAVE_TIME = 10 * 60 - def add_options_to_parser(parser): parser.add_argument('--model', type=str, default='waveflow', @@ -35,8 +33,6 @@ def add_options_to_parser(parser): "default to load the latest checkpoint")) parser.add_argument('--checkpoint', type=str, default=None, help="path of the checkpoint to load") - parser.add_argument('--slurm', type=bool, default=False, - help="whether you are using slurm to submit training jobs") def train(config): @@ -85,13 +81,6 @@ def train(config): else: iteration = int(config.checkpoint.split('/')[-1].split('-')[-1]) - # Get restart command if using slurm. 
-    if config.slurm:
-        resume_command, death_time = slurm.restart_command()
-        if rank == 0:
-            print("Restart command:", " ".join(resume_command))
-    done = False
-
     while iteration < config.max_iterations:
         # Run one single training step.
         model.train_step(iteration)
 
         iteration += 1
 
         if iteration % config.test_every == 0:
             # Run validation step.
             model.valid_step(iteration)
 
-        # Check whether reaching the time limit.
-        if config.slurm:
-            done = (death_time is not None and death_time - time.time() <
-                    MAXIMUM_SAVE_TIME)
-
-            if rank == 0 and done:
-                print("Saving progress before exiting.")
-                model.save(iteration)
-
-                print("Running restart command:", " ".join(resume_command))
-                # Submit restart command.
-                subprocess.check_call(resume_command)
-                break
-
         if rank == 0 and iteration % config.save_every == 0:
             # Save parameters.
             model.save(iteration)
diff --git a/parakeet/models/waveflow/utils.py b/parakeet/models/waveflow/utils.py
index 494a409..3baeb60 100644
--- a/parakeet/models/waveflow/utils.py
+++ b/parakeet/models/waveflow/utils.py
@@ -57,27 +57,6 @@ def add_config_options_to_parser(parser):
     parser.add_argument('--config', action=jsonargparse.ActionConfigFile)
 
 
-def pad_to_size(array, length, pad_with=0.0):
-    """
-    Pad an array on the first (length) axis to a given length.
-    """
-    padding = length - array.shape[0]
-    assert padding >= 0, "Padding required was less than zero"
-
-    paddings = [(0, 0)] * len(array.shape)
-    paddings[0] = (0, padding)
-
-    return np.pad(array, paddings, mode='constant', constant_values=pad_with)
-
-
-def calculate_context_size(config):
-    dilations = list(
-        itertools.islice(
-            itertools.cycle(config.dilation_block), config.layers))
-    config.context_size = sum(dilations) + 1
-    print("Context size is", config.context_size)
-
-
 def load_latest_checkpoint(checkpoint_dir, rank=0):
     checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
     # Create checkpoint index file if it does not exist.
diff --git a/parakeet/models/waveflow/waveflow.py b/parakeet/models/waveflow/waveflow.py
index b362c2d..4935d42 100644
--- a/parakeet/models/waveflow/waveflow.py
+++ b/parakeet/models/waveflow/waveflow.py
@@ -2,11 +2,10 @@ import itertools
 import os
 import time
 
-#import librosa
-from scipy.io.wavfile import write
 import numpy as np
 import paddle.fluid.dygraph as dg
 from paddle import fluid
+from scipy.io.wavfile import write
 
 import utils
 from data import LJSpeech
@@ -29,18 +28,6 @@ class WaveFlow():
         self.trainloader = dataset.trainloader
         self.validloader = dataset.validloader
 
-#        if self.rank == 0:
-#            for i, (audios, mels) in enumerate(self.validloader()):
-#                print("audios {}, mels {}".format(audios.dtype, mels.dtype))
-#                print("{}: rank {}, audios {}, mels {}".format(
-#                    i, self.rank, audios.shape, mels.shape))
-#
-#            for i, (audios, mels) in enumerate(self.trainloader):
-#                print("{}: rank {}, audios {}, mels {}".format(
-#                    i, self.rank, audios.shape, mels.shape))
-#
-#            exit()
-
         waveflow = WaveFlowModule("waveflow", config)
 
         # Dry run once to create and initialize all necessary parameters.
@@ -96,8 +83,6 @@ class WaveFlow():
         else:
             loss.backward()
 
-        current_lr = self.optimizer._learning_rate
-
         self.optimizer.minimize(loss, parameter_list=self.waveflow.parameters())
         self.waveflow.clear_gradients()
 
@@ -113,7 +98,6 @@ class WaveFlow():
 
             tb = self.tb_logger
             tb.add_scalar("Train-Loss-Rank-0", loss_val, iteration)
-            tb.add_scalar("Learning-Rate", current_lr, iteration)
 
     @dg.no_grad
     def valid_step(self, iteration):
@@ -161,34 +145,44 @@ class WaveFlow():
         if sample is not None:
             mels_list = [mels_list[sample]]
 
-        audio_times = []
-        inf_times = []
         for sample, mel in enumerate(mels_list):
             filename = "{}/valid_{}.wav".format(output, sample)
             print("Synthesize sample {}, save as {}".format(sample, filename))
 
             start_time = time.time()
-            audio = self.waveflow.synthesize(mel)
+            audio = self.waveflow.synthesize(mel, sigma=self.config.sigma)
             syn_time = time.time() - start_time
 
-            audio_time = audio.shape[0] / 22050
-            print("audio time {}, synthesis time {}, speedup: {}".format(
-                audio_time, syn_time, audio_time / syn_time))
+            audio = audio[0]
+            audio_time = audio.shape[0] / self.config.sample_rate
+            print("audio time {:.4f}, synthesis time {:.4f}".format(
+                audio_time, syn_time))
 
-            #librosa.output.write_wav(filename, syn_audio,
-            #    sr=config.sample_rate)
+            # Denormalize audio from [-1, 1] to the [-32768, 32767] int16 range.
             audio = audio.numpy() * 32768.0
             audio = audio.astype('int16')
             write(filename, config.sample_rate, audio)
 
-            audio_times.append(audio_time)
-            inf_times.append(syn_time)
+    @dg.no_grad
+    def benchmark(self):
+        self.waveflow.eval()
 
-        total_audio = sum(audio_times)
-        total_inf = sum(inf_times)
+        mels_list = [mels for _, mels in self.validloader()]
+        mel = fluid.layers.concat(mels_list, axis=2)
+        mel = mel[:, :, :864]
+        batch_size = 8
+        mel = fluid.layers.expand(mel, [batch_size, 1, 1])
 
-        print("Total audio: {}, total inf time {}, speedup: {}".format(
-            total_audio, total_inf, total_audio / total_inf))
+        for i in range(10):
+            start_time = time.time()
+            audio = self.waveflow.synthesize(mel, sigma=self.config.sigma)
+            print("audio.shape = ", audio.shape)
+            syn_time = time.time() - start_time
+
+            audio_time = audio.shape[1] * batch_size / self.config.sample_rate
+            print("audio time {:.4f}, synthesis time {:.4f}".format(
+                audio_time, syn_time))
+            print("{} X real-time".format(audio_time / syn_time))
 
     def save(self, iteration):
         utils.save_latest_parameters(self.checkpoint_dir, iteration,
diff --git a/parakeet/models/waveflow/waveflow_modules.py b/parakeet/models/waveflow/waveflow_modules.py
index 45b46a6..39cb598 100644
--- a/parakeet/models/waveflow/waveflow_modules.py
+++ b/parakeet/models/waveflow/waveflow_modules.py
@@ -23,7 +23,6 @@ def set_param_attr(layer, c_in=1):
 
 def unfold(x, n_group):
     length = x.shape[-1]
-    #assert length % n_group == 0
     new_shape = x.shape[:-1] + [length // n_group, n_group]
     return fluid.layers.reshape(x, new_shape)
 
@@ -192,13 +191,53 @@ class Flow(dg.Layer):
 
         return self.end(output)
 
+    def infer(self, audio, mel, queues):
+        audio = self.start(audio)
 
-def debug(x, msg):
-    y = x.numpy()
-    print(msg + " :\n", y)
-    print("shape: ", y.shape)
-    print("dtype: ", y.dtype)
-    print("")
+        for i in range(self.n_layers):
+            dilation_h = self.dilation_h_list[i]
+            dilation_w = 2 ** i
+
+            state_size = dilation_h * (self.kernel_h - 1)
+            queue = queues[i]
+
+            if len(queue) == 0:
+                for j in range(state_size):
+                    queue.append(fluid.layers.zeros_like(audio))
+
+            state = queue[0:state_size]
+            state = fluid.layers.concat([*state, audio], axis=2)
+
+            queue.pop(0)
+            queue.append(audio)
+
+            # Pad height dim (n_group): causal convolution
+            # Pad width dim (time): dilated non-causal convolution
+            pad_top, pad_bottom = 0, 0
+            pad_left = int((self.kernel_w - 1) * dilation_w / 2)
+            pad_right = int((self.kernel_w - 1) * dilation_w / 2)
+            state = fluid.layers.pad2d(state,
+                paddings=[pad_top, pad_bottom, pad_left, pad_right])
+
+            hidden = self.in_layers[i](state)
+            cond_hidden = self.cond_layers[i](mel)
+            in_acts = hidden + cond_hidden
+            out_acts = fluid.layers.tanh(in_acts[:, :self.n_channels, :]) * \
+                fluid.layers.sigmoid(in_acts[:, self.n_channels:, :])
+            res_skip_acts = self.res_skip_layers[i](out_acts)
+
+            if i < self.n_layers - 1:
+                audio += res_skip_acts[:, :self.n_channels, :, :]
+                skip_acts = res_skip_acts[:, self.n_channels:, :, :]
+            else:
+                skip_acts = res_skip_acts
+
+            if i == 0:
+                output = skip_acts
+            else:
+                output += skip_acts
+
+        return self.end(output)
 
 
 class WaveFlowModule(dg.Layer):
@@ -206,7 +245,9 @@ class WaveFlowModule(dg.Layer):
         super(WaveFlowModule, self).__init__(name_scope)
         self.n_flows = config.n_flows
         self.n_group = config.n_group
+        self.n_layers = config.n_layers
         assert self.n_group % 2 == 0
+        assert self.n_flows % 2 == 0
 
         self.conditioner = Conditioner(self.full_name())
         self.flows = []
@@ -215,14 +256,16 @@ class WaveFlowModule(dg.Layer):
             self.flows.append(flow)
             self.add_sublayer("flow_{}".format(i), flow)
 
-        self.perms = [[15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
-                      [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
-                      [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
-                      [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0],
-                      [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8],
-                      [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8],
-                      [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8],
-                      [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8]]
+        self.perms = []
+        half = self.n_group // 2
+        for i in range(self.n_flows):
+            perm = list(range(self.n_group))
+            if i < self.n_flows // 2:
+                perm = perm[::-1]
+            else:
+                perm[:half] = reversed(perm[:half])
+                perm[half:] = reversed(perm[half:])
+            self.perms.append(perm)
 
     def forward(self, audio, mel):
         mel = self.conditioner(mel)
@@ -266,19 +309,13 @@ class WaveFlowModule(dg.Layer):
         return z, log_s_list
 
     def synthesize(self, mel, sigma=1.0):
-        #debug(mel, "mel")
         mel = self.conditioner.infer(mel)
-        #debug(mel, "mel after conditioner")
-
         # From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group]
         mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2])
-        #debug(mel, "after group")
 
         audio = fluid.layers.gaussian_random(
             shape=[mel.shape[0], 1, mel.shape[2], mel.shape[3]], std=sigma)
-        #debug(audio, "audio")
-
         for i in reversed(range(self.n_flows)):
             # Permute over the height dimension.
             audio_slices = [audio[:, :, j, :] for j in self.perms[i]]
             audio = fluid.layers.concat(audio_slices, axis=2)
             mel_slices = [mel[:, :, j, :] for j in self.perms[i]]
             mel = fluid.layers.stack(mel_slices, axis=2)
 
             audio_list = []
-            audio_0 = audio[:, :, :1, :]
+            audio_0 = audio[:, :, 0:1, :]
             audio_list.append(audio_0)
+            audio_h = audio_0
+            queues = [[] for _ in range(self.n_layers)]
 
             for h in range(1, self.n_group):
-                # inputs: [bs, 1, h, time/n_group]
-                inputs = fluid.layers.concat(audio_list, axis=2)
-                conds = mel[:, :, 1:(h+1), :]
-                outputs = self.flows[i](inputs, conds)
+                inputs = audio_h
+                conds = mel[:, :, h:(h+1), :]
+                outputs = self.flows[i].infer(inputs, conds, queues)
 
-                log_s = outputs[:, :1, (h-1):h, :]
-                b = outputs[:, 1:, (h-1):h, :]
-                audio_h = (audio[:, :, h:(h+1), :] - b) / fluid.layers.exp(log_s)
+                log_s = outputs[:, 0:1, :, :]
+                b = outputs[:, 1:, :, :]
+                audio_h = (audio[:, :, h:(h+1), :] - b) / \
+                    fluid.layers.exp(log_s)
                 audio_list.append(audio_h)
 
             audio = fluid.layers.concat(audio_list, axis=2)
 
-        #print("audio.shape =", audio.shape)
-        # Assume batch size = 1
-        # audio: [n_group, time/n_group]
-        audio = fluid.layers.squeeze(audio, [0, 1])
-        # audio: [time]
+        # audio: [bs, n_group, time/n_group]
+        audio = fluid.layers.squeeze(audio, [1])
+        # audio: [bs, time]
         audio = fluid.layers.reshape(
-            fluid.layers.transpose(audio, [1, 0]), [-1])
-        #print("audio.shape =", audio.shape)
+            fluid.layers.transpose(audio, [0, 2, 1]), [audio.shape[0], -1])
 
         return audio
-
-    def start_new_sequence(self):
-        for layer in self.sublayers():
-            if isinstance(layer, conv.Conv1D):
-                layer.start_new_sequence()
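
Note on the WaveFlowModule.__init__ change above: the two hardcoded 16-element
permutation lists are replaced by a loop that derives the same permutations from
n_flows and n_group. A minimal standalone sketch of that scheme (plain Python;
n_flows=8 and n_group=16 taken from the deleted yaml config, assuming the renamed
waveflow_ljspeech.yaml keeps the same values), checked against the removed lists:

    # The first half of the flows reverses the whole group; the second half
    # reverses each half of the group independently.
    n_flows, n_group = 8, 16
    half = n_group // 2
    perms = []
    for i in range(n_flows):
        perm = list(range(n_group))
        if i < n_flows // 2:
            perm = perm[::-1]
        else:
            perm[:half] = reversed(perm[:half])
            perm[half:] = reversed(perm[half:])
        perms.append(perm)

    # Flows 0-3 match the first removed pattern, flows 4-7 the second.
    assert perms[0] == [15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
    assert perms[7] == [7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8]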