From b15c313423ea90860c145e991784e0f7448c7fa6 Mon Sep 17 00:00:00 2001
From: Kexin Zhao
Date: Mon, 2 Dec 2019 14:00:53 -0800
Subject: [PATCH 1/3] working integration with parakeet

---
 .gitignore                                    |   8 +-
 parakeet/data/datacargo.py                    |   3 +
 parakeet/models/wavenet/README.md             |   1 +
 .../wavenet_ljspeech_single_gaussian.yaml     |  32 +
 parakeet/models/wavenet/data.py               | 191 ++++
 parakeet/models/wavenet/ops.py                | 249 +++++
 parakeet/models/wavenet/slurm.py              | 112 +++
 parakeet/models/wavenet/synthesis.py          |  85 ++
 parakeet/models/wavenet/train.py              | 139 +++
 parakeet/models/wavenet/utils.py              | 143 +++
 parakeet/models/wavenet/wavenet.py            | 188 ++++
 parakeet/models/wavenet/wavenet_modules.py    | 423 ++++++++
 parakeet/models/wavenet/weight_norm.py        | 920 ++++++++++++++++++
 13 files changed, 2493 insertions(+), 1 deletion(-)
 create mode 100644 parakeet/models/wavenet/README.md
 create mode 100644 parakeet/models/wavenet/configs/wavenet_ljspeech_single_gaussian.yaml
 create mode 100644 parakeet/models/wavenet/data.py
 create mode 100644 parakeet/models/wavenet/ops.py
 create mode 100644 parakeet/models/wavenet/slurm.py
 create mode 100644 parakeet/models/wavenet/synthesis.py
 create mode 100644 parakeet/models/wavenet/train.py
 create mode 100644 parakeet/models/wavenet/utils.py
 create mode 100644 parakeet/models/wavenet/wavenet.py
 create mode 100644 parakeet/models/wavenet/wavenet_modules.py
 create mode 100644 parakeet/models/wavenet/weight_norm.py

diff --git a/.gitignore b/.gitignore
index 9e0ff35..13dd63d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -129,4 +129,10 @@ venv.bak/
 dmypy.json
 
 # Pyre type checker
-.pyre/
\ No newline at end of file
+.pyre/
+
+# Shell, vim, and output folder
+*.sh
+*.swp
+runs
+syn_audios

diff --git a/parakeet/data/datacargo.py b/parakeet/data/datacargo.py
index da6bc9a..1d7d8d5 100644
--- a/parakeet/data/datacargo.py
+++ b/parakeet/data/datacargo.py
@@ -31,6 +31,9 @@ class DataCargo(object):
 
     def __iter__(self):
         return DataIterator(self)
+
+    def __call__(self):
+        return DataIterator(self)
 
     @property
     def _auto_collation(self):

diff --git a/parakeet/models/wavenet/README.md b/parakeet/models/wavenet/README.md
new file mode 100644
index 0000000..412a3c8
--- /dev/null
+++ b/parakeet/models/wavenet/README.md
@@ -0,0 +1 @@
+# WaveNet-Paddle
\ No newline at end of file

diff --git a/parakeet/models/wavenet/configs/wavenet_ljspeech_single_gaussian.yaml b/parakeet/models/wavenet/configs/wavenet_ljspeech_single_gaussian.yaml
new file mode 100644
index 0000000..fc7222b
--- /dev/null
+++ b/parakeet/models/wavenet/configs/wavenet_ljspeech_single_gaussian.yaml
@@ -0,0 +1,32 @@
+valid_size: 16
+train_clip_second: 0.5
+sample_rate: 22050
+fft_window_shift: 256
+fft_window_size: 1024
+fft_size: 2048
+mel_bands: 80
+
+seed: 1
+batch_size: 8
+test_every: 2000
+save_every: 10000
+max_iterations: 2000000
+
+layers: 30
+kernel_width: 2
+dilation_block: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
+residual_channels: 128
+skip_channels: 128
+loss_type: mix-gaussian-pdf
+num_mixtures: 1
+log_scale_min: -9.0
+
+conditioner:
+  filter_sizes: [[32, 3], [32, 3]]
+  upsample_factors: [16, 16]
+
+learning_rate: 0.001
+gradient_max_norm: 100.0
+anneal:
+  every: 200000
+  rate: 0.5

diff --git a/parakeet/models/wavenet/data.py b/parakeet/models/wavenet/data.py
new file mode 100644
index 0000000..61cc4ab
--- /dev/null
+++ b/parakeet/models/wavenet/data.py
@@ -0,0 +1,191 @@
+import math
+import os
+import random
+
+import librosa
+import numpy as np
+from paddle import fluid
+
+import utils
+from parakeet.datasets import ljspeech
+from 
parakeet.data import dataset +from parakeet.data.sampler import Sampler, BatchSampler, SequentialSampler +from parakeet.data.datacargo import DataCargo + + +class Dataset(ljspeech.LJSpeech): + def __init__(self, config): + super(Dataset, self).__init__(config.root) + self.config = config + self.fft_window_shift = config.fft_window_shift + # Calculate context frames. + frames_per_second = config.sample_rate // self.fft_window_shift + train_clip_frames = int(math.ceil( + config.train_clip_second * frames_per_second)) + context_frames = config.context_size // self.fft_window_shift + self.num_frames = train_clip_frames + context_frames + + def _get_example(self, metadatum): + fname, _, _ = metadatum + wav_path = self.root.joinpath("wavs", fname + ".wav") + + config = self.config + sr = config.sample_rate + fft_window_shift = config.fft_window_shift + fft_window_size = config.fft_window_size + fft_size = config.fft_size + + audio, loaded_sr = librosa.load(wav_path, sr=None) + assert loaded_sr == sr + + # Pad audio to the right size. + frames = math.ceil(float(audio.size) / fft_window_shift) + fft_padding = (fft_size - fft_window_shift) // 2 + desired_length = frames * fft_window_shift + fft_padding * 2 + pad_amount = (desired_length - audio.size) // 2 + + if audio.size % 2 == 0: + audio = np.pad(audio, (pad_amount, pad_amount), mode='reflect') + else: + audio = np.pad(audio, (pad_amount, pad_amount + 1), mode='reflect') + + # Normalize audio. + audio = audio / np.abs(audio).max() * 0.999 + + # Compute mel-spectrogram. + # Turn center to False to prevent internal padding. + spectrogram = librosa.core.stft( + audio, hop_length=fft_window_shift, + win_length=fft_window_size, n_fft=fft_size, center=False) + spectrogram_magnitude = np.abs(spectrogram) + + # Compute mel-spectrograms. + mel_filter_bank = librosa.filters.mel(sr=sr, n_fft=fft_size, + n_mels=config.mel_bands) + mel_spectrogram = np.dot(mel_filter_bank, spectrogram_magnitude) + mel_spectrogram = mel_spectrogram.T + + # Rescale mel_spectrogram. + min_level, ref_level = 1e-5, 20 + mel_spectrogram = 20 * np.log10(np.maximum(min_level, mel_spectrogram)) + mel_spectrogram = mel_spectrogram - ref_level + mel_spectrogram = np.clip((mel_spectrogram + 100) / 100, 0, 1) + + # Extract the center of audio that corresponds to mel spectrograms. + audio = audio[fft_padding : -fft_padding] + assert mel_spectrogram.shape[0] * fft_window_shift == audio.size + + return audio, mel_spectrogram + + +class Subset(dataset.Dataset): + def __init__(self, dataset, indices, valid): + self.dataset = dataset + self.indices = indices + self.valid = valid + + def __getitem__(self, idx): + fft_window_shift = self.dataset.fft_window_shift + num_frames = self.dataset.num_frames + audio, mel = self.dataset[self.indices[idx]] + + if self.valid: + audio_start = 0 + else: + # Randomly crop context + train_clip_second of audio. 
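+            # num_frames = context_frames + train_clip_frames (see Dataset.__init__),
+            # so each random crop keeps enough left context to cover the model's
+            # dilated receptive field before the train_clip_second-long clip starts.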
+ audio_frames = int(audio.size) // fft_window_shift + max_start_frame = audio_frames - num_frames + assert max_start_frame >= 0, "audio {} is too short".format(idx) + + frame_start = random.randint(0, max_start_frame) + frame_end = frame_start + num_frames + + audio_start = frame_start * fft_window_shift + audio_end = frame_end * fft_window_shift + + audio = audio[audio_start : audio_end] + + return audio, mel, audio_start + + def _batch_examples(self, batch): + audios = [sample[0] for sample in batch] + audio_starts = [sample[2] for sample in batch] + + # mels shape [num_frames, mel_bands] + max_frames = max(sample[1].shape[0] for sample in batch) + mels = [utils.pad_to_size(sample[1], max_frames) for sample in batch] + + audios = np.array(audios, dtype=np.float32) + mels = np.array(mels, dtype=np.float32) + audio_starts = np.array(audio_starts, dtype=np.int32) + + return audios, mels, audio_starts + + def __len__(self): + return len(self.indices) + + +class DistributedSampler(Sampler): + def __init__(self, dataset_size, num_trainers, rank, shuffle=True): + self.dataset_size = dataset_size + self.num_trainers = num_trainers + self.rank = rank + self.num_samples = int(math.ceil(dataset_size / num_trainers)) + self.total_size = self.num_samples * num_trainers + assert self.total_size >= self.dataset_size + self.shuffle = shuffle + + def __iter__(self): + indices = list(range(self.dataset_size)) + if self.shuffle: + random.shuffle(indices) + + # Append extra samples to make it evenly distributed on all trainers. + indices += indices[:(self.total_size - self.dataset_size)] + assert len(indices) == self.total_size + + # Subset samples for each trainer. + indices = indices[self.rank:self.total_size:self.num_trainers] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + +class LJSpeech: + def __init__(self, config, nranks, rank): + place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace() + + # Whole LJSpeech dataset. + ds = Dataset(config) + + # Split into train and valid dataset. + indices = list(range(len(ds))) + train_indices = indices[config.valid_size:] + valid_indices = indices[:config.valid_size] + random.shuffle(train_indices) + + # Train dataset. + trainset = Subset(ds, train_indices, valid=False) + sampler = DistributedSampler(len(trainset), nranks, rank) + total_bs = config.batch_size + assert total_bs % nranks == 0 + train_sampler = BatchSampler(sampler, total_bs // nranks, + drop_last=True) + trainloader = DataCargo(trainset, batch_sampler=train_sampler) + + trainreader = fluid.io.PyReader(capacity=50, return_list=True) + trainreader.decorate_batch_generator(trainloader, place) + self.trainloader = (data for _ in iter(int, 1) + for data in trainreader()) + + # Valid dataset. + validset = Subset(ds, valid_indices, valid=True) + # Currently only support batch_size = 1 for valid loader. 
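+        # Validation clips keep their full length (no random crop), so batching
+        # them would require padding audio to a common size; batch_size = 1 avoids that.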
+ validloader = DataCargo(validset, batch_size=1, shuffle=False) + + validreader = fluid.io.PyReader(capacity=20, return_list=True) + validreader.decorate_batch_generator(validloader, place) + self.validloader = validreader diff --git a/parakeet/models/wavenet/ops.py b/parakeet/models/wavenet/ops.py new file mode 100644 index 0000000..6eda2a9 --- /dev/null +++ b/parakeet/models/wavenet/ops.py @@ -0,0 +1,249 @@ +import paddle +from paddle import fluid +import paddle.fluid.dygraph as dg +import numpy as np + +import weight_norm + + +def Embedding(name_scope, + num_embeddings, + embed_dim, + padding_idx=None, + std=0.1, + dtype="float32"): + # param attrs + weight_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal( + scale=std)) + layer = dg.Embedding( + name_scope, (num_embeddings, embed_dim), + padding_idx=padding_idx, + param_attr=weight_attr, + dtype=dtype) + return layer + + +def FC(name_scope, + in_features, + size, + num_flatten_dims=1, + relu=False, + dropout=0.0, + act=None, + dtype="float32"): + """ + A special Linear Layer, when it is used with dropout, the weight is + initialized as normal(0, std=np.sqrt((1-dropout) / in_features)) + """ + + # stds + if isinstance(in_features, int): + in_features = [in_features] + + stds = [np.sqrt((1.0 - dropout) / in_feature) for in_feature in in_features] + if relu: + stds = [std * np.sqrt(2.0) for std in stds] + + weight_inits = [ + fluid.initializer.NormalInitializer(scale=std) for std in stds + ] + bias_init = fluid.initializer.ConstantInitializer(0.0) + + # param attrs + weight_attrs = [fluid.ParamAttr(initializer=init) for init in weight_inits] + bias_attr = fluid.ParamAttr(initializer=bias_init) + + layer = weight_norm.FC(name_scope, + size, + num_flatten_dims=num_flatten_dims, + param_attr=weight_attrs, + bias_attr=bias_attr, + act=act, + dtype=dtype) + return layer + + +def Conv1D(name_scope, + in_channels, + num_filters, + filter_size=2, + dilation=1, + groups=None, + causal=False, + std_mul=1.0, + dropout=0.0, + use_cudnn=True, + act=None, + dtype="float32"): + """ + A special Conv1D Layer, when it is used with dropout, the weight is + initialized as + normal(0, std=np.sqrt(std_mul * (1-dropout) / (filter_size * in_channels))) + """ + # std + std = np.sqrt((std_mul * (1.0 - dropout)) / (filter_size * in_channels)) + weight_init = fluid.initializer.NormalInitializer(loc=0.0, scale=std) + bias_init = fluid.initializer.ConstantInitializer(0.0) + + # param attrs + weight_attr = fluid.ParamAttr(initializer=weight_init) + bias_attr = fluid.ParamAttr(initializer=bias_init) + + layer = weight_norm.Conv1D( + name_scope, + num_filters, + filter_size, + dilation, + groups=groups, + causal=causal, + param_attr=weight_attr, + bias_attr=bias_attr, + use_cudnn=use_cudnn, + act=act, + dtype=dtype) + return layer + + +class Conv1D_GU(dg.Layer): + def __init__(self, + name_scope, + conditioner_dim, + in_channels, + num_filters, + filter_size, + dilation, + causal=False, + residual=True, + dtype="float32"): + super(Conv1D_GU, self).__init__(name_scope, dtype=dtype) + + self.conditioner_dim = conditioner_dim + self.in_channels = in_channels + self.num_filters = num_filters + self.filter_size = filter_size + self.dilation = dilation + self.causal = causal + self.residual = residual + + if residual: + assert ( + in_channels == num_filters + ), "this block uses residual connection"\ + "the input_channels should equals num_filters" + + self.conv = Conv1D( + self.full_name(), + in_channels, + 2 * num_filters, + filter_size, + dilation, + 
causal=causal, + dtype=dtype) + + self.fc = Conv1D( + self.full_name(), + conditioner_dim, + 2 * num_filters, + filter_size=1, + dilation=1, + causal=False, + dtype=dtype) + + def forward(self, x, skip=None, conditioner=None): + """ + Args: + x (Variable): Shape(B, C_in, 1, T), the input of Conv1DGLU + layer, where B means batch_size, C_in means the input channels + T means input time steps. + conditioner (Variable): Shape(B, C_con, 1, T), expanded mel + conditioner, where C_con is conditioner hidden dim which + equals the num of mel bands. Note that when using residual + connection, the Conv1DGLU does not change the number of + channels, so out channels equals input channels. + Returns: + x (Variable): Shape(B, C_out, 1, T), the output of Conv1DGLU, where + C_out means the output channels of Conv1DGLU. + """ + residual = x + x = self.conv(x) + + if conditioner is not None: + cond_bias = self.fc(conditioner) + x += cond_bias + + content, gate = fluid.layers.split(x, num_or_sections=2, dim=1) + + # Gated Unit. + x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), + fluid.layers.tanh(content)) + + if skip is None: + skip = x + else: + skip = fluid.layers.scale(skip + x, np.sqrt(0.5)) + + if self.residual: + x = fluid.layers.scale(residual + x, np.sqrt(0.5)) + + return x, skip + + def add_input(self, x, skip=None, conditioner=None): + """ + Inputs: + x: shape(B, num_filters, 1, time_steps) + conditioner: shape(B, conditioner_dim, 1, time_steps) + Outputs: + out: shape(B, num_filters, 1, time_steps), where time_steps = 1 + """ + residual = x + + # add step input and produce step output + x = self.conv.add_input(x) + + if conditioner is not None: + cond_bias = self.fc(conditioner) + x += cond_bias + + content, gate = fluid.layers.split(x, num_or_sections=2, dim=1) + + # Gated Unit. + x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), + fluid.layers.tanh(content)) + + if skip is None: + skip = x + else: + skip = fluid.layers.scale(skip + x, np.sqrt(0.5)) + + if self.residual: + x = fluid.layers.scale(residual + x, np.sqrt(0.5)) + + return x, skip + + +def Conv2DTranspose(name_scope, + num_filters, + filter_size, + padding=0, + stride=1, + dilation=1, + use_cudnn=True, + act=None, + dtype="float32"): + val = 1.0 / (filter_size[0] * filter_size[1]) + weight_init = fluid.initializer.ConstantInitializer(val) + weight_attr = fluid.ParamAttr(initializer=weight_init) + + layer = weight_norm.Conv2DTranspose( + name_scope, + num_filters, + filter_size=filter_size, + padding=padding, + stride=stride, + dilation=dilation, + param_attr=weight_attr, + use_cudnn=use_cudnn, + act=act, + dtype=dtype) + + return layer diff --git a/parakeet/models/wavenet/slurm.py b/parakeet/models/wavenet/slurm.py new file mode 100644 index 0000000..47af2dc --- /dev/null +++ b/parakeet/models/wavenet/slurm.py @@ -0,0 +1,112 @@ +""" +Utility module for restarting training when using SLURM. +""" +import subprocess +import os +import sys +import shlex +import re +import time + + +def job_info(): + """Get information about the current job using `scontrol show job`. + Returns a dict mapping parameter names (e.g. "UserId", "RunTime", etc) to + their values, both as strings. 
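+    For example (values illustrative), output such as
+    "JobId=1234 JobName=train RunTime=00:01:30" is returned as
+    {"JobId": "1234", "JobName": "train", "RunTime": "00:01:30"}.
+    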
+ """ + job_id = int(os.environ["SLURM_JOB_ID"]) + + command = ["scontrol", "show", "job", str(job_id)] + output = subprocess.check_output(command).decode("utf-8") + + # Use a regex to extract the parameter names and values + pattern = "([A-Za-z/]*)=([^ \t\n]*)" + return dict(re.findall(pattern, output)) + + +def parse_hours(text): + """Parse a time format HH or DD-HH into a number of hours.""" + hour_chunks = text.split("-") + if len(hour_chunks) == 1: + return int(hour_chunks[0]) + elif len(hour_chunks) == 2: + return 24 * int(hour_chunks[0]) + int(hour_chunks[1]) + else: + raise ValueError("Unexpected hour format (expected HH or " + "DD-HH, but got {}).".format(text)) + + +def parse_time(text): + """Convert slurm time to an integer. + Expects time to be of the form: + "hours:minutes:seconds" or "day-hours:minutes:seconds". + """ + hours, minutes, seconds = text.split(":") + try: + return parse_hours(hours) * 3600 + int(minutes) * 60 + int(seconds) + except ValueError as e: + raise ValueError("Error parsing time {}. Got error {}.".format( + text, str(e))) + + +def restart_command(): + """Using the environment and SLURM command, create a command that, when, + run, will enqueue a repeat of the current job using `sbatch`. + Return the command as a list of strings, suitable for passing to + `subprocess.check_call` or similar functions. + Returns: + resume_command: list, command to run to restart job. + end_time: int or None; the time the job will end or None + if the job has unlimited runtime. + """ + # Make sure `RunTime` could be parsed correctly. + while job_info()["RunTime"] == "INVALID": + time.sleep(1) + + # Get all the necessary information by querying SLURM with this job id + info = job_info() + + try: + num_cpus = int(info["CPUs/Task"]) + except KeyError: + num_cpus = int(os.environ["SLURM_CPUS_PER_TASK"]) + + num_tasks = int(os.environ["SLURM_NTASKS"]) + nodes = info["NumNodes"] + gres, partition = info.get("Gres"), info.get("Partition") + stderr, stdout = info.get("StdErr"), info.get("StdOut") + job_name = info.get("JobName") + command = ["sbatch", "--job-name={}".format(job_name), + "--ntasks={}".format(num_tasks)] + + if partition: + command.extend(["--partition", partition]) + + if gres and gres != "(null)": + command.extend(["--gres", gres]) + num_gpu = int(gres.split(':')[-1]) + print("number of gpu assigned by slurm is {}".format(num_gpu)) + + if stderr: + command.extend(["--error", stderr]) + + if stdout: + command.extend(["--output", stdout]) + + python = subprocess.check_output( + ["/usr/bin/which", "python3"]).decode("utf-8").strip() + dist_setting = ['-m', 'paddle.distributed.launch'] + wrap_cmd = ["srun", python, '-u'] + dist_setting + sys.argv + + command.append( + "--wrap={}".format(" ".join(shlex.quote(arg) for arg in wrap_cmd))) + time_limit_string = info["TimeLimit"] + if time_limit_string.lower() == "unlimited": + print("UNLIMITED detected: restart OFF, infinite learning ON.", + flush=True) + return command, None + time_limit = parse_time(time_limit_string) + runtime = parse_time(info["RunTime"]) + end_time = time.time() + time_limit - runtime + + return command, end_time diff --git a/parakeet/models/wavenet/synthesis.py b/parakeet/models/wavenet/synthesis.py new file mode 100644 index 0000000..d87a188 --- /dev/null +++ b/parakeet/models/wavenet/synthesis.py @@ -0,0 +1,85 @@ +import os +import random +from pprint import pprint + +import jsonargparse +import numpy as np +import paddle.fluid.dygraph as dg +from paddle import fluid + +import utils +from wavenet 
import WaveNet + + +def add_options_to_parser(parser): + parser.add_argument('--model', type=str, default='wavenet', + help="general name of the model") + parser.add_argument('--name', type=str, + help="specific name of the training model") + parser.add_argument('--root', type=str, + help="root path of the LJSpeech dataset") + + parser.add_argument('--use_gpu', type=bool, default=True, + help="option to use gpu training") + + parser.add_argument('--iteration', type=int, default=None, + help=("which iteration of checkpoint to load, " + "default to load the latest checkpoint")) + parser.add_argument('--checkpoint', type=str, default=None, + help="path of the checkpoint to load") + + parser.add_argument('--output', type=str, default="./syn_audios", + help="path to write synthesized audio files") + parser.add_argument('--sample', type=int, + help="which of the valid samples to synthesize audio") + + +def synthesize(config): + pprint(jsonargparse.namespace_to_dict(config)) + + # Get checkpoint directory path. + run_dir = os.path.join("runs", config.model, config.name) + checkpoint_dir = os.path.join(run_dir, "checkpoint") + + # Configurate device. + place = fluid.CUDAPlace(0) if config.use_gpu else fluid.CPUPlace() + + with dg.guard(place): + # Fix random seed. + seed = config.seed + random.seed(seed) + np.random.seed(seed) + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + print("Random Seed: ", seed) + + # Build model. + model = WaveNet(config, checkpoint_dir) + model.build(training=False) + + # Obtain the current iteration. + if config.checkpoint is None: + if config.iteration is None: + iteration = utils.load_latest_checkpoint(checkpoint_dir) + else: + iteration = config.iteration + else: + iteration = int(config.checkpoint.split('/')[-1].split('-')[-1]) + + # Run model inference. + model.infer(iteration) + + +if __name__ == "__main__": + # Create parser. + parser = jsonargparse.ArgumentParser( + description="Synthesize audio using WaveNet model", + formatter_class='default_argparse') + add_options_to_parser(parser) + utils.add_config_options_to_parser(parser) + + # Parse argument from both command line and yaml config file. + # For conflicting updates to the same field, + # the preceding update will be overwritten by the following one. 
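+    # For example, passing `--config some.yaml --batch_size 4` (illustrative)
+    # ends with batch_size = 4, since the later command-line value wins.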
+ config = parser.parse_args() + synthesize(config) diff --git a/parakeet/models/wavenet/train.py b/parakeet/models/wavenet/train.py new file mode 100644 index 0000000..1a17bbd --- /dev/null +++ b/parakeet/models/wavenet/train.py @@ -0,0 +1,139 @@ +import os +import random +import subprocess +import time +from pprint import pprint + +import jsonargparse +import numpy as np +import paddle.fluid.dygraph as dg +from paddle import fluid +from tensorboardX import SummaryWriter + +import slurm +import utils +from wavenet import WaveNet + +MAXIMUM_SAVE_TIME = 10 * 60 + + +def add_options_to_parser(parser): + parser.add_argument('--model', type=str, default='wavenet', + help="general name of the model") + parser.add_argument('--name', type=str, + help="specific name of the training model") + parser.add_argument('--root', type=str, + help="root path of the LJSpeech dataset") + + parser.add_argument('--parallel', type=bool, default=True, + help="option to use data parallel training") + parser.add_argument('--use_gpu', type=bool, default=True, + help="option to use gpu training") + + parser.add_argument('--iteration', type=int, default=None, + help=("which iteration of checkpoint to load, " + "default to load the latest checkpoint")) + parser.add_argument('--checkpoint', type=str, default=None, + help="path of the checkpoint to load") + parser.add_argument('--slurm', type=bool, default=False, + help="whether you are using slurm to submit training jobs") + + +def train(config): + use_gpu = config.use_gpu + parallel = config.parallel if use_gpu else False + + # Get the rank of the current training process. + rank = dg.parallel.Env().local_rank if parallel else 0 + nranks = dg.parallel.Env().nranks if parallel else 1 + + if rank == 0: + # Print the whole config setting. + pprint(jsonargparse.namespace_to_dict(config)) + + # Make checkpoint directory. + run_dir = os.path.join("runs", config.model, config.name) + checkpoint_dir = os.path.join(run_dir, "checkpoint") + os.makedirs(checkpoint_dir, exist_ok=True) + + # Create tensorboard logger. + tb = SummaryWriter(os.path.join(run_dir, "logs")) \ + if rank == 0 else None + + # Configurate device + place = fluid.CUDAPlace(rank) if use_gpu else fluid.CPUPlace() + + with dg.guard(place): + # Fix random seed. + seed = config.seed + random.seed(seed) + np.random.seed(seed) + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + print("Random Seed: ", seed) + + # Build model. + model = WaveNet(config, checkpoint_dir, parallel, rank, nranks, tb) + model.build() + + # Obtain the current iteration. + if config.checkpoint is None: + if config.iteration is None: + iteration = utils.load_latest_checkpoint(checkpoint_dir, rank) + else: + iteration = config.iteration + else: + iteration = int(config.checkpoint.split('/')[-1].split('-')[-1]) + + # Get restart command if using slurm. + if config.slurm: + resume_command, death_time = slurm.restart_command() + if rank == 0: + print("Restart command:", " ".join(resume_command)) + done = False + + while iteration < config.max_iterations: + # Run one single training step. + model.train_step(iteration) + + iteration += 1 + + if iteration % config.test_every == 0: + # Run validation step. + model.valid_step(iteration) + + # Check whether reaching the time limit. 
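+            # Leave at least MAXIMUM_SAVE_TIME (10 minutes) of the SLURM
+            # allocation to checkpoint and resubmit before the job is killed.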
+ if config.slurm: + done = (death_time is not None and death_time - time.time() < + MAXIMUM_SAVE_TIME) + + if rank == 0 and done: + print("Saving progress before exiting.") + model.save(iteration) + + print("Running restart command:", " ".join(resume_command)) + # Submit restart command. + subprocess.check_call(resume_command) + break + + if rank == 0 and iteration % config.save_every == 0: + # Save parameters. + model.save(iteration) + + # Close TensorBoard. + if rank == 0: + tb.close() + + +if __name__ == "__main__": + # Create parser. + parser = jsonargparse.ArgumentParser(description="Train WaveNet model", + formatter_class='default_argparse') + add_options_to_parser(parser) + utils.add_config_options_to_parser(parser) + + # Parse argument from both command line and yaml config file. + # For conflicting updates to the same field, + # the preceding update will be overwritten by the following one. + config = parser.parse_args() + train(config) diff --git a/parakeet/models/wavenet/utils.py b/parakeet/models/wavenet/utils.py new file mode 100644 index 0000000..c2b6601 --- /dev/null +++ b/parakeet/models/wavenet/utils.py @@ -0,0 +1,143 @@ +import itertools +import os +import time + +import jsonargparse +import numpy as np +import paddle.fluid.dygraph as dg + + +def add_config_options_to_parser(parser): + parser.add_argument('--valid_size', type=int, + help="size of the valid dataset") + parser.add_argument('--train_clip_second', type=float, + help="the length of audio clip for training") + parser.add_argument('--sample_rate', type=int, + help="sampling rate of audio data file") + parser.add_argument('--fft_window_shift', type=int, + help="the shift of fft window for each frame") + parser.add_argument('--fft_window_size', type=int, + help="the size of fft window for each frame") + parser.add_argument('--fft_size', type=int, + help="the size of fft filter on each frame") + parser.add_argument('--mel_bands', type=int, + help="the number of mel bands when calculating mel spectrograms") + + parser.add_argument('--seed', type=int, + help="seed of random initialization for the model") + parser.add_argument('--batch_size', type=int, + help="batch size for training") + parser.add_argument('--test_every', type=int, + help="test interval during training") + parser.add_argument('--save_every', type=int, + help="checkpointing interval during training") + parser.add_argument('--max_iterations', type=int, + help="maximum training iterations") + + parser.add_argument('--layers', type=int, + help="number of dilated convolution layers") + parser.add_argument('--kernel_width', type=int, + help="dilated convolution kernel width") + parser.add_argument('--dilation_block', type=list, + help="dilated convolution kernel width") + parser.add_argument('--residual_channels', type=int) + parser.add_argument('--skip_channels', type=int) + parser.add_argument('--loss_type', type=str, + help="mix-gaussian-pdf or softmax") + parser.add_argument('--num_channels', type=int, default=None, + help="number of channels for softmax output") + parser.add_argument('--num_mixtures', type=int, default=None, + help="number of gaussian mixtures for gaussian output") + parser.add_argument('--log_scale_min', type=float, default=None, + help="minimum clip value of log variance of gaussian output") + + parser.add_argument('--conditioner.filter_sizes', type=list, + help="conv2d tranpose op filter sizes for building conditioner") + parser.add_argument('--conditioner.upsample_factors', type=list, + help="list of upsample factors for building 
conditioner") + + parser.add_argument('--learning_rate', type=float) + parser.add_argument('--gradient_max_norm', type=float) + parser.add_argument('--anneal.every', type=int, + help="step interval for annealing learning rate") + parser.add_argument('--anneal.rate', type=float) + + parser.add_argument('--config', action=jsonargparse.ActionConfigFile) + + +def pad_to_size(array, length, pad_with=0.0): + """ + Pad an array on the first (length) axis to a given length. + """ + padding = length - array.shape[0] + assert padding >= 0, "Padding required was less than zero" + + paddings = [(0, 0)] * len(array.shape) + paddings[0] = (0, padding) + + return np.pad(array, paddings, mode='constant', constant_values=pad_with) + + +def calculate_context_size(config): + dilations = list( + itertools.islice( + itertools.cycle(config.dilation_block), config.layers)) + config.context_size = sum(dilations) + 1 + print("Context size is", config.context_size) + + +def load_latest_checkpoint(checkpoint_dir, rank=0): + checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") + # Create checkpoint index file if not exist. + if (not os.path.isfile(checkpoint_path)) and rank == 0: + with open(checkpoint_path, "w") as handle: + handle.write("model_checkpoint_path: step-0") + + # Make sure that other process waits until checkpoint file is created + # by process 0. + while not os.path.isfile(checkpoint_path): + time.sleep(1) + + # Fetch the latest checkpoint index. + with open(checkpoint_path, "r") as handle: + latest_checkpoint = handle.readline().split()[-1] + iteration = int(latest_checkpoint.split("-")[-1]) + + return iteration + + +def save_latest_checkpoint(checkpoint_dir, iteration): + checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") + # Update the latest checkpoint index. 
+    with open(checkpoint_path, "w") as handle:
+        handle.write("model_checkpoint_path: step-{}".format(iteration))
+
+
+def load_parameters(checkpoint_dir, rank, model, optimizer=None,
+                    iteration=None, file_path=None):
+    if file_path is None:
+        if iteration is None:
+            iteration = load_latest_checkpoint(checkpoint_dir, rank)
+        if iteration == 0:
+            return
+        file_path = "{}/step-{}".format(checkpoint_dir, iteration)
+
+    model_dict, optimizer_dict = dg.load_dygraph(file_path)
+    model.set_dict(model_dict)
+    print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path))
+    if optimizer and optimizer_dict:
+        optimizer.set_dict(optimizer_dict)
+        print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
+            rank, file_path))
+
+
+def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
+    file_path = "{}/step-{}".format(checkpoint_dir, iteration)
+    model_dict = model.state_dict()
+    dg.save_dygraph(model_dict, file_path)
+    print("[checkpoint] Saved model to {}".format(file_path))
+
+    if optimizer:
+        opt_dict = optimizer.state_dict()
+        dg.save_dygraph(opt_dict, file_path)
+        print("[checkpoint] Saved optimizer state to {}".format(file_path))
diff --git a/parakeet/models/wavenet/wavenet.py b/parakeet/models/wavenet/wavenet.py
new file mode 100644
index 0000000..acc6e76
--- /dev/null
+++ b/parakeet/models/wavenet/wavenet.py
@@ -0,0 +1,188 @@
+import itertools
+import os
+import time
+
+import librosa
+import numpy as np
+from paddle import fluid
+import paddle.fluid.dygraph as dg
+
+import utils
+from data import LJSpeech
+from wavenet_modules import WaveNetModule, debug
+
+
+class WaveNet():
+    def __init__(self, config, checkpoint_dir, parallel=False, rank=0,
+                 nranks=1, tb_logger=None):
+        # Process config to calculate the context size.
+        dilations = list(
+            itertools.islice(
+                itertools.cycle(config.dilation_block), config.layers))
+        config.context_size = sum(dilations) + 1
+        self.config = config
+        self.checkpoint_dir = checkpoint_dir
+        self.parallel = parallel
+        self.rank = rank
+        self.nranks = nranks
+        self.tb_logger = tb_logger
+
+    def build(self, training=True):
+        config = self.config
+        dataset = LJSpeech(config, self.nranks, self.rank)
+        self.trainloader = dataset.trainloader
+        self.validloader = dataset.validloader
+
+#        if self.rank == 0:
+#            for i, (audios, mels, ids) in enumerate(self.validloader()):
+#                print("audios {}, mels {}, ids {}".format(audios.dtype, mels.dtype, ids.dtype))
+#                print("{}: rank {}, audios {}, mels {}, indices {} / {}".format(
+#                    i, self.rank, audios.shape, mels.shape, ids.shape,
+#                    ids.numpy()))
+#
+#            for i, (audios, mels, ids) in enumerate(self.trainloader):
+#                print("{}: rank {}, audios {}, mels {}, indices {} / {}".format(
+#                    i, self.rank, audios.shape, mels.shape, ids.shape,
+#                    ids.numpy()))
+
+        wavenet = WaveNetModule("wavenet", config, self.rank)
+
+        # Dry run once to create and initialize all necessary parameters.
+        audio = dg.to_variable(np.random.randn(1, 20000).astype(np.float32))
+        mel = dg.to_variable(
+            np.random.randn(1, 100, self.config.mel_bands).astype(np.float32))
+        audio_start = dg.to_variable(np.array([0], dtype=np.int32))
+        wavenet(audio, mel, audio_start)
+
+        if training:
+            # Create learning rate scheduler.
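+            # staircase=True gives a stepwise schedule:
+            #   lr = learning_rate * anneal.rate ** floor(step / anneal.every)
+            # e.g. the shipped config halves the learning rate every 200k steps.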
+ lr_scheduler = dg.ExponentialDecay( + learning_rate = config.learning_rate, + decay_steps = config.anneal.every, + decay_rate = config.anneal.rate, + staircase=True) + + optimizer = fluid.optimizer.AdamOptimizer( + learning_rate=lr_scheduler) + + clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm( + config.gradient_max_norm) + + # Load parameters. + utils.load_parameters(self.checkpoint_dir, self.rank, + wavenet, optimizer, + iteration=config.iteration, + file_path=config.checkpoint) + print("Rank {}: checkpoint loaded.".format(self.rank)) + + # Data parallelism. + if self.parallel: + strategy = dg.parallel.prepare_context() + wavenet = dg.parallel.DataParallel(wavenet, strategy) + + self.wavenet = wavenet + self.optimizer = optimizer + self.clipper = clipper + + else: + # Load parameters. + utils.load_parameters(self.checkpoint_dir, self.rank, wavenet, + iteration=config.iteration, + file_path=config.checkpoint) + print("Rank {}: checkpoint loaded.".format(self.rank)) + + self.wavenet = wavenet + + def train_step(self, iteration): + self.wavenet.train() + + start_time = time.time() + audios, mels, audio_starts = next(self.trainloader) + load_time = time.time() + + loss, _ = self.wavenet(audios, mels, audio_starts) + + if self.parallel: + # loss = loss / num_trainers + loss = self.wavenet.scale_loss(loss) + loss.backward() + self.wavenet.apply_collective_grads() + else: + loss.backward() + + if isinstance(self.optimizer._learning_rate, + fluid.optimizer.LearningRateDecay): + current_lr = self.optimizer._learning_rate.step().numpy() + else: + current_lr = self.optimizer._learning_rate + + self.optimizer.minimize(loss, grad_clip=self.clipper, + parameter_list=self.wavenet.parameters()) + self.wavenet.clear_gradients() + + graph_time = time.time() + + if self.rank == 0: + loss_val = float(loss.numpy()) * self.nranks + log = "Rank: {} Step: {:^8d} Loss: {:<8.3f} " \ + "Time: {:.3f}/{:.3f}".format( + self.rank, iteration, loss_val, + load_time - start_time, graph_time - load_time) + print(log) + + tb = self.tb_logger + tb.add_scalar("Train-Loss-Rank-0", loss_val, iteration) + tb.add_scalar("Learning-Rate", current_lr, iteration) + + @dg.no_grad + def valid_step(self, iteration): + self.wavenet.eval() + + total_loss = [] + start_time = time.time() + sample_audios = [] + for audios, mels, audio_starts in self.validloader(): + loss, sample_audio = self.wavenet(audios, mels, audio_starts, True) + total_loss.append(float(loss.numpy())) + sample_audios.append(sample_audio) + total_time = time.time() - start_time + + if self.rank == 0: + loss_val = np.mean(total_loss) + log = "Test | Rank: {} AvgLoss: {:<8.3f} Time {:<8.3f}".format( + self.rank, loss_val, total_time) + print(log) + + tb = self.tb_logger + tb.add_scalar("Valid-Avg-Loss", loss_val, iteration) + tb.add_audio("Teacher-Forced-Audio-0", sample_audios[0].numpy(), + iteration, sample_rate=self.config.sample_rate) + tb.add_audio("Teacher-Forced-Audio-1", sample_audios[1].numpy(), + iteration, sample_rate=self.config.sample_rate) + + def save(self, iteration): + utils.save_latest_parameters(self.checkpoint_dir, iteration, + self.wavenet, self.optimizer) + utils.save_latest_checkpoint(self.checkpoint_dir, iteration) + + @dg.no_grad + def infer(self, iteration): + self.wavenet.eval() + + config = self.config + sample = config.sample + + output = "{}/{}/iter-{}".format(config.output, config.name, iteration) + os.makedirs(output, exist_ok=True) + + filename = "{}/valid_{}.wav".format(output, sample) + print("Synthesize sample {}, save as 
{}".format(sample, filename)) + + mels_list = [mels for _, mels, _ in self.validloader()] + start_time = time.time() + syn_audio = self.wavenet.synthesize(mels_list[sample]) + syn_time = time.time() - start_time + print("audio shape {}, synthesis time {}".format( + syn_audio.shape, syn_time)) + librosa.output.write_wav(filename, syn_audio, + sr=config.sample_rate) diff --git a/parakeet/models/wavenet/wavenet_modules.py b/parakeet/models/wavenet/wavenet_modules.py new file mode 100644 index 0000000..c5c01e9 --- /dev/null +++ b/parakeet/models/wavenet/wavenet_modules.py @@ -0,0 +1,423 @@ +import itertools +import math + +import numpy as np +from paddle import fluid +import paddle.fluid.dygraph as dg +import ops +import weight_norm + + +def get_padding(filter_size, stride, padding_type='same'): + if padding_type == 'same': + padding = [(x - y) // 2 for x, y in zip(filter_size, stride)] + else: + raise ValueError("Only support same padding") + return padding + + +def debug(x, var_name, rank, verbose=False): + if not verbose and rank != 0: + return + dim = len(x.shape) + if not isinstance(x, np.ndarray): + x = x.numpy() + if dim == 1: + print("Rank {}".format(rank), var_name, "shape {}, value {}".format(x.shape, x)) + elif dim == 2: + print("Rank {}".format(rank), var_name, "shape {}, value {}".format(x.shape, x[:, :5])) + elif dim == 3: + print("Rank {}".format(rank), var_name, "shape {}, value {}".format(x.shape, x[:, :5, 0])) + else: + print("Rank", rank, var_name, "shape", x.shape) + + +def extract_slices(x, audio_starts, audio_length, rank): + slices = [] + for i in range(x.shape[0]): + start = audio_starts.numpy()[i] + end = start + audio_length + slice = fluid.layers.slice( + x, axes=[0, 1], starts=[i, start], ends=[i+1, end]) + slices.append(fluid.layers.squeeze(slice, [0])) + + x = fluid.layers.stack(slices, axis=0) + + return x + + +class Conditioner(dg.Layer): + def __init__(self, name_scope, config): + super(Conditioner, self).__init__(name_scope) + upsample_factors = config.conditioner.upsample_factors + filter_sizes = config.conditioner.filter_sizes + assert np.prod(upsample_factors) == config.fft_window_shift + + self.deconvs = [] + for i, up_scale in enumerate(upsample_factors): + stride = (up_scale, 1) + padding = get_padding(filter_sizes[i], stride) + self.deconvs.append( + ops.Conv2DTranspose( + self.full_name(), + num_filters=1, + filter_size=filter_sizes[i], + padding=padding, + stride=stride)) + + # Register python list as parameters. 
+        for i, layer in enumerate(self.deconvs):
+            self.add_sublayer("conv_transpose_{}".format(i), layer)
+
+    def forward(self, x):
+        x = fluid.layers.unsqueeze(x, 1)
+        for layer in self.deconvs:
+            x = fluid.layers.leaky_relu(layer(x), alpha=0.4)
+
+        return fluid.layers.squeeze(x, [1])
+
+
+class WaveNetModule(dg.Layer):
+    def __init__(self, name_scope, config, rank):
+        super(WaveNetModule, self).__init__(name_scope)
+
+        self.rank = rank
+        self.conditioner = Conditioner(self.full_name(), config)
+        self.dilations = list(
+            itertools.islice(
+                itertools.cycle(config.dilation_block), config.layers))
+        self.context_size = sum(self.dilations) + 1
+        self.log_scale_min = config.log_scale_min
+        self.config = config
+
+        print("dilations", self.dilations)
+        print("context_size", self.context_size)
+
+        if config.loss_type == "softmax":
+            self.embedding_fc = ops.Embedding(
+                self.full_name(),
+                num_embeddings=config.num_channels,
+                embed_dim=config.residual_channels)
+        elif config.loss_type == "mix-gaussian-pdf":
+            self.embedding_fc = ops.FC(
+                self.full_name(),
+                in_features=1,
+                size=config.residual_channels,
+                num_flatten_dims=2,
+                relu=False)
+        else:
+            raise ValueError(
+                "loss_type {} is unsupported!".format(config.loss_type))
+
+        self.dilated_causal_convs = []
+        for dilation in self.dilations:
+            self.dilated_causal_convs.append(
+                ops.Conv1D_GU(
+                    self.full_name(),
+                    conditioner_dim=config.mel_bands,
+                    in_channels=config.residual_channels,
+                    num_filters=config.residual_channels,
+                    filter_size=config.kernel_width,
+                    dilation=dilation,
+                    causal=True
+                )
+            )
+
+        for i, layer in enumerate(self.dilated_causal_convs):
+            self.add_sublayer("dilated_causal_conv_{}".format(i), layer)
+
+        self.fc1 = ops.FC(
+            self.full_name(),
+            in_features=config.residual_channels,
+            size=config.skip_channels,
+            num_flatten_dims=2,
+            relu=True,
+            act="relu")
+
+        self.fc2 = ops.FC(
+            self.full_name(),
+            in_features=config.skip_channels,
+            size=config.skip_channels,
+            num_flatten_dims=2,
+            relu=True,
+            act="relu")
+
+        if config.loss_type == "softmax":
+            self.fc3 = ops.FC(
+                self.full_name(),
+                in_features=config.skip_channels,
+                size=config.num_channels,
+                num_flatten_dims=2,
+                relu=False)
+        elif config.loss_type == "mix-gaussian-pdf":
+            self.fc3 = ops.FC(
+                self.full_name(),
+                in_features=config.skip_channels,
+                size=3 * config.num_mixtures,
+                num_flatten_dims=2,
+                relu=False)
+        else:
+            raise ValueError(
+                "loss_type {} is unsupported!".format(config.loss_type))
+
+    def sample_softmax(self, mix_parameters):
+        batch, length, hidden = mix_parameters.shape
+        mix_param_2d = fluid.layers.reshape(mix_parameters,
+                                            [batch * length, hidden])
+        mix_param_2d = fluid.layers.softmax(mix_param_2d, axis=-1)
+
+        # quantized: [batch * length]
+        quantized = fluid.layers.cast(fluid.layers.sampling_id(mix_param_2d),
+                                      dtype="float32")
+        samples = (quantized + 0.5) * (2.0 / self.config.num_channels) - 1.0
+
+        # samples: [batch * length]
+        return samples
+
+    def sample_mix_gaussian(self, mix_parameters):
+        # mix_parameters reshape from [bs, 13799, 3 * num_mixtures]
+        # to [bs * 13799, 3 * num_mixtures].
+        batch, length, hidden = mix_parameters.shape
+        mix_param_2d = fluid.layers.reshape(mix_parameters,
+                                            [batch * length, hidden])
+        K = hidden // 3
+
+        # Unpack the parameters of the mixture of gaussian.
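+        # Last-axis layout: [mixture logits | means mu | log scales log_s],
+        # each of width K = num_mixtures.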
+ logits_pi = mix_param_2d[:, 0 : K] + mu = mix_param_2d[:, K : 2*K] + log_s = mix_param_2d[:, 2*K : 3*K] + s = fluid.layers.exp(log_s) + + pi = fluid.layers.softmax(logits_pi, axis=-1) + comp_samples = fluid.layers.sampling_id(pi) + + row_idx = dg.to_variable(np.arange(batch * length)) + comp_samples = fluid.layers.stack([row_idx, comp_samples], axis=-1) + + mu_comp = fluid.layers.gather_nd(mu, comp_samples) + s_comp = fluid.layers.gather_nd(s, comp_samples) + + # N(0, 1) Normal Sample. + u = fluid.layers.gaussian_random(shape=[batch * length]) + samples = mu_comp + u * s_comp + samples = fluid.layers.clip(samples, min=-1.0, max=1.0) + + return samples + + def softmax_loss(self, targets, mix_parameters): + # targets: [bs, 13799] -> [bs, 11752] + # mix_params: [bs, 13799, 3] -> [bs, 11752, 3] + targets = targets[:, self.context_size:] + mix_parameters = mix_parameters[:, self.context_size:, :] + + # Quantized audios to integral values with range [0, num_channels) + num_channels = self.config.num_channels + targets = fluid.layers.clip(targets, min=-1.0, max=0.99999) + quantized = fluid.layers.cast( + (targets + 1.0) / 2.0 * num_channels, dtype="int64") + + # per_sample_loss shape: [bs, 17952, 1] + per_sample_loss = fluid.layers.softmax_with_cross_entropy( + logits=mix_parameters, label=fluid.layers.unsqueeze(quantized, 2)) + loss = fluid.layers.reduce_mean(per_sample_loss) + #debug(loss, "softmax loss", self.rank) + + return loss + + def mixture_density_loss(self, targets, mix_parameters, log_scale_min): + # targets: [bs, 13799] -> [bs, 11752] + # mix_params: [bs, 13799, 3] -> [bs, 11752, 3] + targets = targets[:, self.context_size:] + mix_parameters = mix_parameters[:, self.context_size:, :] + + # log_s: [bs, 11752, num_mixture] + logits_pi, mu, log_s = fluid.layers.split(mix_parameters, num_or_sections=3, dim=-1) + + pi = fluid.layers.softmax(logits_pi, axis=-1) + log_s = fluid.layers.clip(log_s, min=log_scale_min, max=100.0) + inv_s = fluid.layers.exp(0.0 - log_s) + + # Calculate gaussian loss. + targets = fluid.layers.unsqueeze(targets, -1) + targets = fluid.layers.expand(targets, [1, 1, self.config.num_mixtures]) + x_std = inv_s * (targets - mu) + exponent = fluid.layers.exp(-0.5 * x_std * x_std) + # pdf_x: [bs, 11752, 1] + pdf_x = 1.0 / np.sqrt(2.0 * np.pi) * inv_s * exponent + pdf_x = pi * pdf_x + # pdf_x: [bs, 11752] + pdf_x = fluid.layers.reduce_sum(pdf_x, dim=-1) + per_sample_loss = 0.0 - fluid.layers.log(pdf_x + 1e-9) + + loss = fluid.layers.reduce_mean(per_sample_loss) + + return loss + + def forward(self, audios, mels, audio_starts, sample=False): + # audios: [bs, 13800], mels: [bs, full_frame_length, 80] + # audio_starts: [bs] + # Build conditioner based on mels. + full_conditioner = self.conditioner(mels) + + # Slice conditioners. 
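+        # The conditioner upsamples mels to one vector per audio sample
+        # (prod(upsample_factors) == fft_window_shift), so slicing with
+        # audio_starts stays aligned with the cropped waveform.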
+ audio_length = audios.shape[1] + conditioner = extract_slices(full_conditioner, + audio_starts, audio_length, self.rank) + + # input_audio, target_audio: [bs, 13799] + input_audios = audios[:, :-1] + target_audios = audios[:, 1:] + # conditioner: [bs, 13799, 80] + conditioner = conditioner[:, 1:, :] + + loss_type = self.config.loss_type + + # layer_input: [bs, 13799, 128] + if loss_type == "softmax": + input_audios = fluid.layers.clip( + input_audios, min=-1.0, max=0.99999) + # quantized have values in [0, num_channels) + quantized = fluid.layers.cast( + (input_audios + 1.0) / 2.0 * self.config.num_channels, + dtype="int64") + layer_input = self.embedding_fc(fluid.layers.unsqueeze(quantized, 2)) + elif loss_type == "mix-gaussian-pdf": + layer_input = self.embedding_fc(fluid.layers.unsqueeze(input_audios, 2)) + else: + raise ValueError( + "loss_type {} is unsupported!".format(loss_type)) + + # layer_input: [bs, res_channel, 1, 13799] + layer_input = fluid.layers.unsqueeze(fluid.layers.transpose(layer_input, perm=[0, 2, 1]), 2) + # conditioner: [bs, mel_bands, 1, 13799] + conditioner = fluid.layers.unsqueeze(fluid.layers.transpose(conditioner, perm=[0, 2, 1]), 2) + + # layer_input: [bs, res_channel, 1, 13799] + # skip: [bs, res_channel, 1, 13799] + skip = None + for i, layer in enumerate(self.dilated_causal_convs): + layer_input, skip = layer(layer_input, skip, conditioner) + #debug(layer_input, "layer_input_" + str(i), self.rank) + #debug(skip, "skip_" + str(i), self.rank) + + # Reshape skip to [bs, 13799, res_channel] + skip = fluid.layers.transpose(fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1]) + #debug(skip, "skip", self.rank) + + # mix_param: [bs, 13799, 3 * num_mixtures] + mix_parameters = self.fc3(self.fc2(self.fc1(skip))) + + # Sample teacher-forced audio. + sample_audios = None + if sample: + if loss_type == "softmax": + sample_audios = self.sample_softmax(mix_parameters) + elif loss_type == "mix-gaussian-pdf": + sample_audios = self.sample_mix_gaussian(mix_parameters) + else: + raise ValueError( + "loss_type {} is unsupported!".format(loss_type)) + #debug(sample_audios, "sample_audios", self.rank) + + # Calculate mix-gaussian density loss. + # padding is all zero. + # target_audio: [bs, 13799]. + # mix_params: [bs, 13799, 3]. + if loss_type == "softmax": + loss = self.softmax_loss(target_audios, mix_parameters) + elif loss_type == "mix-gaussian-pdf": + loss = self.mixture_density_loss(target_audios, + mix_parameters, self.log_scale_min) + else: + raise ValueError( + "loss_type {} is unsupported!".format(loss_type)) + + #print("Rank {}, loss {}".format(self.rank, loss.numpy())) + + return loss, sample_audios + + def synthesize(self, mels): + self.start_new_sequence() + print("input mels shape", mels.shape) + # mels: [bs=1, n_frames, 80] + # conditioner: [1, n_frames * samples_per_frame, 80] + # Should I move forward by one sample? 
No difference.
+        # Optionally pad mels with silent context frames (currently disabled).
+        bs, n_frames, mel_bands = mels.shape
+        #num_pad_frames = int(np.ceil(self.context_size / self.config.fft_window_shift))
+        #silence = fluid.layers.zeros(shape=[bs, num_pad_frames, mel_bands], dtype="float32")
+        #inf_mels = fluid.layers.concat([silence, mels], axis=1)
+        #print("padded mels shape", inf_mels.shape)
+
+        #conditioner = self.conditioner(inf_mels)[:, self.context_size:, :]
+        conditioner = self.conditioner(mels)
+        time_steps = conditioner.shape[1]
+        print("Total steps", time_steps)
+
+        loss_type = self.config.loss_type
+        audio_samples = []
+        current_sample = fluid.layers.zeros(shape=[bs, 1, 1], dtype="float32")
+        for i in range(time_steps):
+            if i % 100 == 0:
+                print("Step", i)
+
+            # Convert from real-valued sample to audio embedding.
+            # [bs, 1, 128]
+            if loss_type == "softmax":
+                current_sample = fluid.layers.clip(
+                    current_sample, min=-1.0, max=0.99999)
+                # quantized has values in [0, num_channels)
+                quantized = fluid.layers.cast(
+                    (current_sample + 1.0) / 2.0 * self.config.num_channels,
+                    dtype="int64")
+                audio_input = self.embedding_fc(quantized)
+            elif loss_type == "mix-gaussian-pdf":
+                audio_input = self.embedding_fc(current_sample)
+            else:
+                raise ValueError(
+                    "loss_type {} is unsupported!".format(loss_type))
+
+            # [bs, 128, 1, 1]
+            audio_input = fluid.layers.unsqueeze(
+                fluid.layers.transpose(audio_input, perm=[0, 2, 1]), 2)
+            # [bs, 80]
+            cond_input = conditioner[:, i, :]
+            # [bs, 80, 1, 1]
+            cond_input = fluid.layers.reshape(
+                cond_input, cond_input.shape + [1, 1])
+
+            skip = None
+            for layer in self.dilated_causal_convs:
+                audio_input, skip = layer.add_input(audio_input, skip, cond_input)
+
+            # [bs, 1, 128]
+            skip = fluid.layers.transpose(
+                fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1])
+            # [bs, 1, 3]
+            mix_parameters = self.fc3(self.fc2(self.fc1(skip)))
+            if loss_type == "softmax":
+                sample = self.sample_softmax(mix_parameters)
+            elif loss_type == "mix-gaussian-pdf":
+                sample = self.sample_mix_gaussian(mix_parameters)
+            else:
+                raise ValueError(
+                    "loss_type {} is unsupported!".format(loss_type))
+            audio_samples.append(sample)
+            # [bs]
+            current_sample = audio_samples[-1]
+            # [bs, 1, 1]
+            current_sample = fluid.layers.reshape(
+                current_sample, current_sample.shape + [1, 1])
+
+        # syn_audio: (num_samples,)
+        syn_audio = fluid.layers.concat(audio_samples, axis=0).numpy()
+
+        return syn_audio
+
+    def start_new_sequence(self):
+        for layer in self.sublayers():
+            if isinstance(layer, weight_norm.Conv1D):
+                layer.start_new_sequence()
diff --git a/parakeet/models/wavenet/weight_norm.py b/parakeet/models/wavenet/weight_norm.py
new file mode 100644
index 0000000..75fe413
--- /dev/null
+++ b/parakeet/models/wavenet/weight_norm.py
@@ -0,0 +1,920 @@
+import math
+from copy import deepcopy
+
+import numpy as np
+import paddle.fluid.dygraph as dg
+from paddle import fluid
+from paddle.fluid import core
+from paddle.fluid.framework import Variable
+from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
+from paddle.fluid.layers import utils
+from six.moves import reduce
+
+
+def _norm(p, dim):
+    """Computes the norm over all dimensions except dim.
+    Unlike the PyTorch implementation, it does not keep dim.
+    This difference is related to the broadcast mechanism in Paddle;
+    see elementwise_mul for more.
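+    For example, for a weight of shape (C_out, C_in, k), _norm(p, dim=0)
+    returns a vector of C_out norms, one per output channel.
+    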
+ """ + if dim is None: + return np.linalg.norm(p, ord=2, axis=None) + elif dim == 0: + p = np.reshape(p, newshape=(p.shape[0], -1)) + return np.linalg.norm(p, ord=2, axis=1) + elif dim == p.ndim - 1: + p = np.reshape(p, newshape=(-1, p.shape[-1])) + return np.linalg.norm(p, ord=2, axis=0) + else: + perm = list(range(p.ndim)) + perm[0] = dim + perm[dim] = 0 + return _norm(np.transpose(p, axes=perm)) + + +class Conv1D(dg.Layer): + """ + A convolution 1D block implemented with Conv2D. Form simplicity and + ensuring the output has the same length as the input, it does not allow + stride > 1. + """ + def __init__(self, + name_scope, + num_filters, + filter_size=3, + dilation=1, + groups=None, + causal=False, + param_attr=None, + bias_attr=None, + use_cudnn=True, + act=None, + dtype="float32"): + super(Conv1D, self).__init__(name_scope, dtype=dtype) + + if causal: + padding = dilation * (filter_size - 1) + else: + padding = (dilation * (filter_size - 1)) // 2 + + self.num_filters = num_filters + self.filter_size = filter_size + self.dilation = dilation + self.causal = causal + self.padding = padding + self.act = act + + self.conv = Conv2D( + self.full_name(), + num_filters=num_filters, + filter_size=(1, filter_size), + stride=(1, 1), + dilation=(1, dilation), + padding=(0, padding), + groups=groups, + param_attr=param_attr, + bias_attr=bias_attr, + use_cudnn=use_cudnn, + act=act, + dtype=dtype) + + def forward(self, x): + """ + Args: + x (Variable): Shape(B, C_in, 1, T), the input, where C_in means + input channels. + Returns: + x (Variable): Shape(B, C_out, 1, T), the outputs, where C_out means + output channels (num_filters). + """ + x = self.conv(x) + if self.filter_size > 1: + if self.causal: + x = fluid.layers.slice( + x, axes=[3], starts=[0], ends=[-self.padding]) + elif self.filter_size % 2 == 0: + x = fluid.layers.slice(x, axes=[3], starts=[0], ends=[-1]) + return x + + def start_new_sequence(self): + self.temp_weight = None + self.input_buffer = None + + def add_input(self, x): + """ + Adding input for a time step and compute an output for a time step. + + Args: + x (Variable): Shape(B, C_in, 1, T), the input, where C_in means + input channels, and T = 1. + Returns: + out (Variable): Shape(B, C_out, 1, T), the outputs, where C_out + means output channels (num_filters), and T = 1. 
+ + """ + if self.temp_weight is None: + self.temp_weight = self._reshaped_weight() + + window_size = 1 + (self.filter_size - 1) * self.dilation + batch_size = x.shape[0] + in_channels = x.shape[1] + + if self.filter_size > 1: + if self.input_buffer is None: + self.input_buffer = fluid.layers.fill_constant( + [batch_size, in_channels, 1, window_size - 1], + dtype=x.dtype, + value=0.0) + else: + self.input_buffer = self.input_buffer[:, :, :, 1:] + self.input_buffer = fluid.layers.concat( + [self.input_buffer, x], axis=3) + x = self.input_buffer + if self.dilation > 1: + if not hasattr(self, "indices"): + self.indices = dg.to_variable( + np.arange(0, window_size, self.dilation)) + tmp = fluid.layers.transpose( + self.input_buffer, perm=[3, 1, 2, 0]) + tmp = fluid.layers.gather(tmp, index=self.indices) + tmp = fluid.layers.transpose(tmp, perm=[3, 1, 2, 0]) + x = tmp + inputs = fluid.layers.reshape( + x, shape=[batch_size, in_channels * 1 * self.filter_size]) + out = fluid.layers.matmul(inputs, self.temp_weight, transpose_y=True) + out = fluid.layers.elementwise_add(out, self.conv._bias_param, axis=-1) + out = fluid.layers.reshape(out, out.shape + [1, 1]) + out = self._helper.append_activation(out, act=self.act) + return out + + def _reshaped_weight(self): + """ + Get the linearized weight of convolution filter, cause it is by nature + a matmul weight. And because the model uses weight norm, compute the + weight by weight_v * weight_g to make it faster. + Returns: + weight_matrix (Variable): Shape(C_out, C_in * 1 * kernel_size) + """ + shape = self.conv._filter_param_v.shape + matrix_shape = [shape[0], np.prod(shape[1:])] + weight_matrix = fluid.layers.reshape( + self.conv._filter_param_v, shape=matrix_shape) + weight_matrix = fluid.layers.elementwise_mul( + fluid.layers.l2_normalize( + weight_matrix, axis=1), + self.conv._filter_param_g, + axis=0) + return weight_matrix + + +class FC(dg.Layer): + """ + **Fully Connected Layer** + This function creates a fully connected layer in the network. It can take + one or multiple tensors as its inputs(input can be a list of Variable, see + Args in detail). It creates a pair of variables called (magnitude(g), + direction(V)) for each input tensor. Elementwise_mul(V, g) represents a fully connected + weight matrix from each input unit to each output unit. + The fully connected layer multiplies each input tensor + with its corresponding weight to produce an output Tensor with shape [M, `size`], + where M is batch size. If multiple input tensors are given, the results of + multiple output tensors with shape [M, `size`] will be summed up. If bias_attr + is not None, a bias variable will be created and added to the output. + Finally, if activation is not None, it will be applied to the output as well. + When the input is single tensor: + .. math:: + Out = Act({X(normalize(V)g) + b}) + When the input are multiple tensors: + .. math:: + Out = Act({\sum_{i=0}^{N-1}X_i(V_ig_i) + b}) + In the above equation: + * :math:`N`: Number of the input. N equals to len(input) if input is list of Variable. + * :math:`X_i`: The i-th input tensor. + * :math:`V_i`: The i-th direction matrix corresponding i-th input tensor. + * :math:`g_i`: The i-th magnitude vector corresponding i-th input tensor. + * :math:`b`: The bias parameter created by this layer (if needed). + * :math:`Act`: The activation function. + * :math:`Out`: The output tensor. + See below for an example. + .. 
+    .. code-block:: text
+        Given:
+            data_1.data = [[[0.1, 0.2],
+                            [0.3, 0.4]]]
+            data_1.shape = (1, 2, 2) # 1 is batch_size
+            data_2 = [[[0.1, 0.2, 0.3]]]
+            data_2.shape = (1, 1, 3)
+            out = fluid.layers.fc(input=[data_1, data_2], size=2)
+        Then:
+            out.data = [[0.18669507, 0.1893476]]
+            out.shape = (1, 2)
+    Args:
+        name_scope(str): The name of this class.
+        size(int): The number of output units in this layer.
+        num_flatten_dims (int): The fc layer can accept an input tensor with more than
+            two dimensions. If this happens, the multidimensional tensor will first be flattened
+            into a 2-dimensional matrix. The parameter `num_flatten_dims` determines how the input
+            tensor is flattened: the first `num_flatten_dims` (inclusive, index starts from 1)
+            dimensions will be flattened to form the first dimension of the final matrix (height of
+            the matrix), and the rest `rank(X) - num_flatten_dims` dimensions are flattened to
+            form the second dimension of the final matrix (width of the matrix). For example, suppose
+            `X` is a 5-dimensional tensor with a shape [2, 3, 4, 5, 6], and `num_flatten_dims` = 3.
+            Then, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. Default: 1
+        param_attr (ParamAttr|list of ParamAttr|None): The parameter attribute for learnable
+            parameters/weights of this layer.
+        bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias
+            of this layer. If it is set to False, no bias will be added to the output units.
+            If it is set to None, the bias is initialized to zero. Default: None.
+        act (str|None): Activation to be applied to the output of this layer.
+        is_test(bool): A flag indicating whether execution is in test phase. Default: False
+        dtype(str): The data type used for the weights.
+    Raises:
+        ValueError: If the rank of the input tensor is less than 2.
+    Examples:
+        ..
code-block:: python + from paddle.fluid.dygraph.base import to_variable + import paddle.fluid as fluid + from paddle.fluid.dygraph import FC + import numpy as np + data = np.random.uniform( -1, 1, [30, 10, 32] ).astype('float32') + with fluid.dygraph.guard(): + fc = FC( "fc", 64, num_flatten_dims=2) + data = to_variable( data ) + conv = fc( data ) + """ + + def __init__(self, + name_scope, + size, + num_flatten_dims=1, + epsilon=1e-30, + param_attr=None, + bias_attr=None, + act=None, + is_test=False, + dtype="float32"): + super(FC, self).__init__(name_scope, dtype) + + self._size = size + self._num_flatten_dims = num_flatten_dims + self._epsilon = epsilon + self._dtype = dtype + self._param_attr = param_attr + self._bias_attr = bias_attr + self._act = act + self.__g = list() + self.__v = list() + + @property + def _v(self, i=0): + return self.__v[i] + + @property + def _g(self, i=0): + return self.__g[i] + + @_v.setter + def _v(self, value, i=0): + assert isinstance(value, Parameter) + self.__v[i] = value + + @_g.setter + def _g(self, value, i=0): + assert isinstance(value, Parameter) + self.__g[i] = value + + def _build_once(self, input): + i = 0 + for inp, param in self._helper.iter_inputs_and_params( + input, self._param_attr): + input_shape = inp.shape + + param_shape = [ + reduce(lambda a, b: a * b, + input_shape[self._num_flatten_dims:], 1) + ] + [self._size] + self.__v.append( + self.add_parameter( + "_v%d" % i, + self.create_parameter( + attr=param, + shape=param_shape, + dtype=self._dtype, + is_bias=False))) + + magnitude_shape = param_shape[1:] + magnitude_value = np.linalg.norm( + self.__v[i].numpy(), ord=2, axis=0) + + self.__g.append( + self.add_parameter( + "_g%d" % i, + self.create_parameter( + attr=fluid.ParamAttr(initializer=fluid.initializer. 
+                                             NumpyArrayInitializer(
+                                                 magnitude_value)),
+                        shape=magnitude_shape,
+                        dtype=self._dtype,
+                        is_bias=False)))
+            i += 1
+
+        size = list([self._size])
+        self._b = self.create_parameter(
+            attr=self._bias_attr, shape=size, dtype=self._dtype, is_bias=True)
+
+    def forward(self, input):
+        mul_results = list()
+        i = 0
+        for inp, param in self._helper.iter_inputs_and_params(
+                input, self._param_attr):
+            v_norm = self._helper.create_variable_for_type_inference(
+                self._dtype)
+            v_normalized = self._helper.create_variable_for_type_inference(
+                self._dtype)
+            self._helper.append_op(
+                type="norm",
+                inputs={"X": self.__v[i]},
+                outputs={"Out": v_normalized,
+                         "Norm": v_norm},
+                attrs={"axis": 0,
+                       "epsilon": self._epsilon})
+            weight = self._helper.create_variable_for_type_inference(
+                self._dtype)
+            self._helper.append_op(
+                type="elementwise_mul",
+                inputs={"X": [v_normalized],
+                        "Y": [self.__g[i]]},
+                outputs={"Out": [weight]},
+                attrs={"axis": 1})
+            tmp = self._helper.create_variable_for_type_inference(self._dtype)
+            self._helper.append_op(
+                type="mul",
+                inputs={"X": inp,
+                        "Y": weight},
+                outputs={"Out": tmp},
+                attrs={
+                    "x_num_col_dims": self._num_flatten_dims,
+                    "y_num_col_dims": 1
+                })
+            i += 1
+            mul_results.append(tmp)
+
+        if len(mul_results) == 1:
+            pre_bias = mul_results[0]
+        else:
+            pre_bias = self._helper.create_variable_for_type_inference(
+                self._dtype)
+            self._helper.append_op(
+                type="sum",
+                inputs={"X": mul_results},
+                outputs={"Out": pre_bias},
+                attrs={"use_mkldnn": False})
+
+        if self._b:
+            pre_activation = self._helper.create_variable_for_type_inference(
+                dtype=self._dtype)
+            self._helper.append_op(
+                type="elementwise_add",
+                inputs={"X": [pre_bias],
+                        "Y": [self._b]},
+                outputs={"Out": [pre_activation]},
+                attrs={"axis": self._num_flatten_dims})
+        else:
+            pre_activation = pre_bias
+        # Currently, we don't support inplace in dygraph mode
+        return self._helper.append_activation(pre_activation, act=self._act)
+
+
+class Conv2D(dg.Layer):
+    """
+    The convolution2D layer calculates the output based on the input, filter
+    and the strides, paddings, dilations, and groups parameters. Input and
+    Output are in NCHW format, where N is batch size, C is the number of
+    channels, H is the height of the feature, and W is the width of the feature.
+    Filter is in MCHW format, where M is the number of output image channels,
+    C is the number of input image channels, H is the height of the filter,
+    and W is the width of the filter. If groups is greater than 1,
+    C will equal the number of input image channels divided by groups.
+    Please refer to UFLDL's `convolution
+    `
+    for more details.
+    If a bias attribute and activation type are provided, bias is added to the
+    output of the convolution, and the corresponding activation function is
+    applied to the final result.
+    For each input :math:`X`, the equation is:
+    .. math::
+        Out = \sigma ((Vg) \\ast X + b)
+    Where:
+    * :math:`X`: Input value, a tensor with NCHW format.
+    * :math:`V`: Filter direction value, a tensor with MCHW format.
+    * :math:`g`: Filter magnitude value, a tensor with M format.
+    * :math:`\\ast`: Convolution operation.
+    * :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
+    * :math:`\\sigma`: Activation function.
+    * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
+    Example:
+        - Input:
+          Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
+          Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
+        - Output:
+          Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
+        Where
+        .. math::
+            H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
+            W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
+    Args:
+        name_scope(str): The name for this class.
+        num_filters(int): The number of filters. It is the same as the number
+            of output image channels.
+        filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
+            it must contain two integers, (filter_size_H, filter_size_W).
+            Otherwise, the filter will be a square.
+        stride (int|tuple): The stride size. If stride is a tuple, it must
+            contain two integers, (stride_H, stride_W). Otherwise, the
+            stride_H = stride_W = stride. Default: stride = 1.
+        padding (int|tuple): The padding size. If padding is a tuple, it must
+            contain two integers, (padding_H, padding_W). Otherwise, the
+            padding_H = padding_W = padding. Default: padding = 0.
+        dilation (int|tuple): The dilation size. If dilation is a tuple, it must
+            contain two integers, (dilation_H, dilation_W). Otherwise, the
+            dilation_H = dilation_W = dilation. Default: dilation = 1.
+        groups (int): The number of groups of the Conv2D layer. According to grouped
+            convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
+            the first half of the filters is only connected to the first half
+            of the input channels, while the second half of the filters is only
+            connected to the second half of the input channels. Default: groups=1.
+        param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
+            of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
+            will create ParamAttr as param_attr. If the Initializer of the param_attr
+            is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
+            and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
+        bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d.
+            If it is set to False, no bias will be added to the output units.
+            If it is set to None or one attribute of ParamAttr, conv2d
+            will create ParamAttr as bias_attr. If the Initializer of the bias_attr
+            is not set, the bias is initialized to zero. Default: None.
+        use_cudnn (bool): Use cudnn kernel or not; it is valid only when the cudnn
+            library is installed. Default: True
+        act (str): Activation type. If it is set to None, activation is not appended.
+            Default: None
+    Raises:
+        ValueError: If the shapes of input, filter_size, stride, padding and
+            groups mismatch.
+    Examples:
+        .. code-block:: python
+            from paddle.fluid.dygraph.base import to_variable
+            import paddle.fluid as fluid
+            from paddle.fluid.dygraph import Conv2D
+            import numpy as np
+            data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
+            with fluid.dygraph.guard():
+                conv2d = Conv2D("conv2d", 2, 3)
+                data = to_variable(data)
+                conv = conv2d(data)
+    """
+
+    def __init__(self,
+                 name_scope,
+                 num_filters,
+                 filter_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=None,
+                 param_attr=None,
+                 bias_attr=None,
+                 use_cudnn=True,
+                 act=None,
+                 epsilon=1e-30,
+                 dtype="float32"):
+        assert param_attr is not False, "param_attr should not be False here."
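+        # This layer keeps its filter in weight-normalized form: a direction
+        # tensor (_filter_param_v) and a per-output-channel magnitude
+        # (_filter_param_g) are created in _build_once, and forward()
+        # recombines them into the effective filter g * v / ||v||.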
+ super(Conv2D, self).__init__(name_scope, dtype) + self._groups = groups + self._stride = utils.convert_to_list(stride, 2, "stride") + self._padding = utils.convert_to_list(padding, 2, "padding") + self._dilation = utils.convert_to_list(dilation, 2, "dilation") + self._act = act + if not isinstance(use_cudnn, bool): + raise ValueError("use_cudnn should be True or False") + self._use_cudnn = use_cudnn + self._filter_size = filter_size + self._num_filters = num_filters + self._param_attr = param_attr + self._bias_attr = bias_attr + self._epsilon = epsilon + self._dtype = dtype + # if (self._num_channels == self._groups and + # num_filters % self._num_channels == 0 and not self._use_cudnn): + # self._l_type = 'depthwise_conv2d' + # else: + # TODO(jiabin): recover the usage of depthwise_conv2d when it's + # kernel fixed https://github.com/PaddlePaddle/Paddle/issues/17275 + self._l_type = "conv2d" + + def _build_once(self, input): + self._num_channels = input.shape[1] + if self._groups is None: + num_filter_channels = self._num_channels + else: + if self._num_channels % self._groups != 0: + raise ValueError("num_channels must be divisible by groups.") + num_filter_channels = self._num_channels // self._groups + filter_size = utils.convert_to_list(self._filter_size, 2, + "filter_size") + filter_shape = [self._num_filters, int(num_filter_channels) + ] + filter_size + + def _get_default_param_initializer(): + filter_elem_num = filter_size[0] * filter_size[ + 1] * self._num_channels + std = (2.0 / filter_elem_num)**0.5 + return Normal(0.0, std, 0) + + # weight_v + self._filter_param_v = self.create_parameter( + attr=self._param_attr, + shape=filter_shape, + dtype=self._dtype, + default_initializer=_get_default_param_initializer()) + + # weight_g + norm_value = _norm( + self._filter_param_v.numpy(), dim=0) # CAUTION: hard-code + self._filter_param_g = self.create_parameter( + attr=fluid.ParamAttr( + initializer=fluid.initializer.NumpyArrayInitializer( + norm_value)), + shape=norm_value.shape, + dtype=self._dtype, + default_initializer=_get_default_param_initializer()) + + if self._use_cudnn: + self.create_variable( + name="kCUDNNFwdAlgoCache", + persistable=True, + type=core.VarDesc.VarType.RAW) + self.create_variable( + name="kCUDNNBwdDataAlgoCache", + persistable=True, + type=core.VarDesc.VarType.RAW) + self.create_variable( + name="kCUDNNBwdFilterAlgoCache", + persistable=True, + type=core.VarDesc.VarType.RAW) + + self._bias_param = self.create_parameter( + attr=self._bias_attr, + shape=[self._num_filters], + dtype=self._dtype, + is_bias=True) + + def forward(self, input): + matrix = self._helper.create_variable_for_type_inference(self._dtype) + tmp = self._helper.create_variable_for_type_inference(self._dtype) + new_shape = [ + self._filter_param_v.shape[0], + reduce(lambda x, y: x * y, self._filter_param_v.shape[1:], 1), + ] + + self._helper.append_op( + type="reshape2", + inputs={"X": self._filter_param_v}, + attrs={"shape": new_shape}, + outputs={"Out": matrix, + "XShape": tmp}) + + m_norm = self._helper.create_variable_for_type_inference(self._dtype) + m_normalized = self._helper.create_variable_for_type_inference( + self._dtype) + self._helper.append_op( + type="norm", + inputs={"X": matrix}, + outputs={"Out": m_normalized, + "Norm": m_norm}, + attrs={"axis": 1, + "epsilon": self._epsilon}) + + v_normalized = self._helper.create_variable_for_type_inference( + self._dtype) + tmp2 = self._helper.create_variable_for_type_inference(self._dtype) + self._helper.append_op( + type="reshape2", + 
inputs={"X": m_normalized},
+            attrs={"shape": self._filter_param_v.shape},
+            outputs={"Out": v_normalized,
+                     "XShape": tmp2})
+
+        filter_param = self._helper.create_variable_for_type_inference(
+            self._dtype)
+        self._helper.append_op(
+            type="elementwise_mul",
+            inputs={"X": [v_normalized],
+                    "Y": [self._filter_param_g]},
+            outputs={"Out": [filter_param]},
+            attrs={"axis": 0},  # CAUTION: hard-code
+        )
+
+        pre_bias = self._helper.create_variable_for_type_inference(
+            dtype=self._dtype)
+
+        self._helper.append_op(
+            type=self._l_type,
+            inputs={"Input": input,
+                    "Filter": filter_param},
+            outputs={"Output": pre_bias},
+            attrs={
+                "strides": self._stride,
+                "paddings": self._padding,
+                "dilations": self._dilation,
+                "groups": self._groups if self._groups else 1,
+                "use_cudnn": self._use_cudnn,
+                "use_mkldnn": False,
+            })
+
+        if self._bias_param is not None:
+            pre_act = self._helper.create_variable_for_type_inference(
+                dtype=self._dtype)
+            self._helper.append_op(
+                type="elementwise_add",
+                inputs={"X": [pre_bias],
+                        "Y": [self._bias_param]},
+                outputs={"Out": [pre_act]},
+                attrs={"axis": 1})
+        else:
+            pre_act = pre_bias
+
+        # Currently, we don't support inplace in dygraph mode
+        return self._helper.append_activation(pre_act, act=self._act)
+
+
+class Conv2DTranspose(dg.Layer):
+    """
+    **Convolution2D transpose layer**
+    The convolution2D transpose layer calculates the output based on the
+    input, filter, and the dilations, strides, and paddings. Input(Input) and
+    output(Output) are in NCHW format, where N is batch size, C is the number
+    of channels, H is the height of the feature, and W is the width of the
+    feature. Parameters(dilations, strides, paddings) are two elements. These
+    two elements represent height and width, respectively. For details of the
+    convolution transpose layer, please refer to the following explanation
+    and the references
+    `therein `_.
+    If a bias attribute and activation type are provided, bias is added to
+    the output of the convolution, and the corresponding activation function
+    is applied to the final result.
+    For each input :math:`X`, the equation is:
+    .. math::
+        Out = \sigma ((Vg) \\ast X + b)
+    Where:
+    * :math:`X`: Input value, a tensor with NCHW format.
+    * :math:`V`: Filter direction value, a tensor with MCHW format.
+    * :math:`g`: Filter magnitude value, a tensor with M format.
+    * :math:`\\ast`: Convolution operation.
+    * :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
+    * :math:`\\sigma`: Activation function.
+    * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
+    Example:
+        - Input:
+          Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
+          Filter shape: :math:`(C_{in}, C_{out}, H_f, W_f)`
+        - Output:
+          Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
+        Where
+        .. math::
+           H^\prime_{out} &= (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\\\
+           W^\prime_{out} &= (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1 \\\\
+           H_{out} &\in [ H^\prime_{out}, H^\prime_{out} + strides[0] ) \\\\
+           W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] )
+    Args:
+        name_scope(str): The name of this class.
+        num_filters(int): The number of filters. It is the same as the number
+            of output image channels.
+        output_size(int|tuple|None): The output image size. If output size is a
+            tuple, it must contain two integers, (image_H, image_W). None if
+            filter_size, padding, and stride are used to calculate output_size.
+            If output_size and filter_size are specified at the same time, they
+            should follow the formula above.
Default: None. + filter_size(int|tuple|None): The filter size. If filter_size is a tuple, + it must contain two integers, (filter_size_H, filter_size_W). + Otherwise, the filter will be a square. None if use output size to + calculate filter_size. Default: None. + padding(int|tuple): The padding size. If padding is a tuple, it must + contain two integers, (padding_H, padding_W). Otherwise, the + padding_H = padding_W = padding. Default: padding = 0. + stride(int|tuple): The stride size. If stride is a tuple, it must + contain two integers, (stride_H, stride_W). Otherwise, the + stride_H = stride_W = stride. Default: stride = 1. + dilation(int|tuple): The dilation size. If dilation is a tuple, it must + contain two integers, (dilation_H, dilation_W). Otherwise, the + dilation_H = dilation_W = dilation. Default: dilation = 1. + groups(int): The groups number of the Conv2d transpose layer. Inspired by + grouped convolution in Alex Krizhevsky's Deep CNN paper, in which + when group=2, the first half of the filters is only connected to the + first half of the input channels, while the second half of the + filters is only connected to the second half of the input channels. + Default: groups = 1. + param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights + of conv2d_transpose. If it is set to None or one attribute of ParamAttr, conv2d_transpose + will create ParamAttr as param_attr. If the Initializer of the param_attr + is not set, the parameter is initialized with Xavier. Default: None. + bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d_transpose. + If it is set to False, no bias will be added to the output units. + If it is set to None or one attribute of ParamAttr, conv2d_transpose + will create ParamAttr as bias_attr. If the Initializer of the bias_attr + is not set, the bias is initialized zero. Default: None. + use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn + library is installed. Default: True. + act (str): Activation type, if it is set to None, activation is not appended. + Default: None. + Returns: + Variable: The tensor variable storing the convolution transpose result. + Raises: + ValueError: If the shapes of input, filter_size, stride, padding and + groups mismatch. + Examples: + .. code-block:: python + import paddle.fluid as fluid + import numpy + with fluid.dygraph.guard(): + data = numpy.random.random((3, 32, 32)).astype('float32') + conv2DTranspose = fluid.dygraph.nn.Conv2DTranspose( + 'Conv2DTranspose', num_filters=2, filter_size=3) + ret = conv2DTranspose(fluid.dygraph.base.to_variable(data)) + """ + + def __init__(self, + name_scope, + num_filters, + output_size=None, + filter_size=None, + padding=0, + stride=1, + dilation=1, + groups=None, + param_attr=None, + bias_attr=None, + use_cudnn=True, + epsilon=1e-30, + act=None, + dtype="float32"): + super(Conv2DTranspose, self).__init__(name_scope, dtype) + assert (param_attr is not False + ), "param_attr should not be False in conv2d_transpose." 
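+        # Same weight-norm scheme as Conv2D above: _img_filter_v stores the
+        # filter direction and _img_filter_g its magnitude; forward() rebuilds
+        # the effective filter before running the transpose convolution.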
+ self._param_attr = param_attr + self._bias_attr = bias_attr + self._groups = groups + self._num_filters = num_filters + self._use_cudnn = use_cudnn + self._padding = padding + self._stride = stride + self._dilation = dilation + self._filter_size = filter_size + self._output_size = output_size + self._op_type = "conv2d_transpose" + self._epsilon = epsilon + + def _build_once(self, input): + input_channel = input.shape[1] + if (input_channel == self._groups and + self._num_filters == input_channel and not self._use_cudnn): + self._op_type = "depthwise_conv2d_transpose" + + if not isinstance(input, Variable): + raise TypeError("Input of conv2d_transpose must be Variable") + + self._padding = utils.convert_to_list(self._padding, 2, "padding") + self._stride = utils.convert_to_list(self._stride, 2, "stride") + self._dilation = utils.convert_to_list(self._dilation, 2, "dilation") + + if not isinstance(self._use_cudnn, bool): + raise ValueError("use_cudnn should be True or False") + + if self._filter_size is None: + if self._output_size is None: + raise ValueError( + "output_size must be set when filter_size is None") + if isinstance(self._output_size, int): + self._output_size = [self._output_size, self._output_size] + + h_in = input.shape[2] + w_in = input.shape[3] + + filter_size_h = (self._output_size[0] - + (h_in - 1) * self._stride[0] + 2 * + self._padding[0] - 1) // self._dilation[0] + 1 + filter_size_w = (self._output_size[1] - + (w_in - 1) * self._stride[1] + 2 * + self._padding[1] - 1) // self._dilation[1] + 1 + self._filter_size = [filter_size_h, filter_size_w] + else: + self._filter_size = utils.convert_to_list( + self._filter_size, 2, "conv2d_transpose.filter_size") + + if self._output_size is None: + self._output_size = [] + elif isinstance(self._output_size, list) or isinstance( + self._output_size, int): + self._output_size = utils.convert_to_list(self._output_size, 2, + "output_size") + else: + raise ValueError("output_size should be list or int") + self._padding = utils.convert_to_list(self._padding, 2, "padding") + self._groups = 1 if self._groups is None else self._groups + filter_shape = [ + input_channel, + self._num_filters // self._groups, + ] + self._filter_size + + # img filter v (direction) + self._img_filter_v = self.create_parameter( + dtype=input.dtype, shape=filter_shape, attr=self._param_attr) + + # img filter g (magnitude) + img_filter_magnitude = _norm( + self._img_filter_v.numpy(), dim=0) # CAUTION: hard-code + self._img_filter_g = self.create_parameter( + dtype=input.dtype, + shape=img_filter_magnitude.shape, + attr=fluid.ParamAttr( + initializer=NumpyArrayInitializer(img_filter_magnitude))) + + self._img_bias = self.create_parameter( + attr=self._bias_attr, + shape=[self._num_filters], + dtype=self._dtype, + is_bias=True) + + def forward(self, input): + matrix = self._helper.create_variable_for_type_inference(self._dtype) + tmp = self._helper.create_variable_for_type_inference(self._dtype) + new_shape = [ + self._img_filter_v.shape[0], + reduce(lambda x, y: x * y, self._img_filter_v.shape[1:], 1), + ] + + self._helper.append_op( + type="reshape2", + inputs={"X": self._img_filter_v}, + attrs={"shape": new_shape}, + outputs={"Out": matrix, + "XShape": tmp}) + + m_norm = self._helper.create_variable_for_type_inference(self._dtype) + m_normalized = self._helper.create_variable_for_type_inference( + self._dtype) + self._helper.append_op( + type="norm", + inputs={"X": matrix}, + outputs={"Out": m_normalized, + "Norm": m_norm}, + attrs={"axis": 1, + "epsilon": 
self._epsilon}) + + v_normalized = self._helper.create_variable_for_type_inference( + self._dtype) + tmp2 = self._helper.create_variable_for_type_inference(self._dtype) + self._helper.append_op( + type="reshape2", + inputs={"X": m_normalized}, + attrs={"shape": self._img_filter_v.shape}, + outputs={"Out": v_normalized, + "XShape": tmp2}) + + img_filter = self._helper.create_variable_for_type_inference( + self._dtype) + self._helper.append_op( + type="elementwise_mul", + inputs={"X": [v_normalized], + "Y": [self._img_filter_g]}, + outputs={"Out": [img_filter]}, + attrs={"axis": 0}, # CAUTION: hard-code + ) + + pre_bias = self._helper.create_variable_for_type_inference( + dtype=input.dtype) + self._helper.append_op( + type=self._op_type, + inputs={"Input": [input], + "Filter": [img_filter]}, + outputs={"Output": pre_bias}, + attrs={ + "output_size": self._output_size, + "strides": self._stride, + "paddings": self._padding, + "dilations": self._dilation, + "groups": self._groups, + "use_cudnn": self._use_cudnn, + }) + + if self._img_bias is not None: + pre_act = self._helper.create_variable_for_type_inference( + dtype=self._dtype) + self._helper.append_op( + type="elementwise_add", + inputs={"X": [pre_bias], + "Y": [self._img_bias]}, + outputs={"Out": [pre_act]}, + attrs={"axis": 1}) + else: + pre_act = pre_bias + + out = self._helper.append_activation(pre_act) + return out From 98841ee48ad8bd819b04930a809a88a71f6d1966 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Mon, 2 Dec 2019 22:58:17 -0800 Subject: [PATCH 2/3] clean code --- parakeet/data/sampler.py | 31 +- .../wavenet_ljspeech_mix_gaussian.yaml | 32 + .../configs/wavenet_ljspeech_softmax.yaml | 31 + parakeet/models/wavenet/data.py | 37 +- parakeet/models/wavenet/ops.py | 249 ----- parakeet/models/wavenet/wavenet.py | 28 +- parakeet/models/wavenet/wavenet_modules.py | 154 ++- parakeet/models/wavenet/weight_norm.py | 920 ------------------ parakeet/modules/modules.py | 154 +++ 9 files changed, 314 insertions(+), 1322 deletions(-) create mode 100644 parakeet/models/wavenet/configs/wavenet_ljspeech_mix_gaussian.yaml create mode 100644 parakeet/models/wavenet/configs/wavenet_ljspeech_softmax.yaml delete mode 100644 parakeet/models/wavenet/ops.py delete mode 100644 parakeet/models/wavenet/weight_norm.py diff --git a/parakeet/data/sampler.py b/parakeet/data/sampler.py index 097cc03..60aa5db 100644 --- a/parakeet/data/sampler.py +++ b/parakeet/data/sampler.py @@ -163,6 +163,35 @@ class WeightedRandomSampler(Sampler): return self.num_samples +class DistributedSampler(Sampler): + def __init__(self, dataset_size, num_trainers, rank, shuffle=True): + self.dataset_size = dataset_size + self.num_trainers = num_trainers + self.rank = rank + self.num_samples = int(np.ceil(dataset_size / num_trainers)) + self.total_size = self.num_samples * num_trainers + assert self.total_size >= self.dataset_size + self.shuffle = shuffle + + def __iter__(self): + indices = list(range(self.dataset_size)) + if self.shuffle: + random.shuffle(indices) + + # Append extra samples to make it evenly distributed on all trainers. + indices += indices[:(self.total_size - self.dataset_size)] + assert len(indices) == self.total_size + + # Subset samples for each trainer. + indices = indices[self.rank:self.total_size:self.num_trainers] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + class BatchSampler(Sampler): r"""Wraps another sampler to yield a mini-batch of indices. 
Args: @@ -206,4 +235,4 @@ class BatchSampler(Sampler): if self.drop_last: return len(self.sampler) // self.batch_size else: - return (len(self.sampler) + self.batch_size - 1) // self.batch_size \ No newline at end of file + return (len(self.sampler) + self.batch_size - 1) // self.batch_size diff --git a/parakeet/models/wavenet/configs/wavenet_ljspeech_mix_gaussian.yaml b/parakeet/models/wavenet/configs/wavenet_ljspeech_mix_gaussian.yaml new file mode 100644 index 0000000..bf19577 --- /dev/null +++ b/parakeet/models/wavenet/configs/wavenet_ljspeech_mix_gaussian.yaml @@ -0,0 +1,32 @@ +valid_size: 16 +train_clip_second: 0.5 +sample_rate: 22050 +fft_window_shift: 256 +fft_window_size: 1024 +fft_size: 2048 +mel_bands: 80 + +seed: 1 +batch_size: 8 +test_every: 2000 +save_every: 10000 +max_iterations: 2000000 + +layers: 30 +kernel_width: 2 +dilation_block: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] +residual_channels: 128 +skip_channels: 128 +loss_type: mix-gaussian-pdf +num_mixtures: 10 +log_scale_min: -9.0 + +conditioner: + filter_sizes: [[32, 3], [32, 3]] + upsample_factors: [16, 16] + +learning_rate: 0.001 +gradient_max_norm: 100.0 +anneal: + every: 200000 + rate: 0.5 diff --git a/parakeet/models/wavenet/configs/wavenet_ljspeech_softmax.yaml b/parakeet/models/wavenet/configs/wavenet_ljspeech_softmax.yaml new file mode 100644 index 0000000..f39de5d --- /dev/null +++ b/parakeet/models/wavenet/configs/wavenet_ljspeech_softmax.yaml @@ -0,0 +1,31 @@ +valid_size: 16 +train_clip_second: 0.5 +sample_rate: 22050 +fft_window_shift: 256 +fft_window_size: 1024 +fft_size: 2048 +mel_bands: 80 + +seed: 1 +batch_size: 8 +test_every: 2000 +save_every: 10000 +max_iterations: 2000000 + +layers: 30 +kernel_width: 2 +dilation_block: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] +residual_channels: 128 +skip_channels: 128 +loss_type: softmax +num_channels: 2048 + +conditioner: + filter_sizes: [[32, 3], [32, 3]] + upsample_factors: [16, 16] + +learning_rate: 0.001 +gradient_max_norm: 100.0 +anneal: + every: 200000 + rate: 0.5 diff --git a/parakeet/models/wavenet/data.py b/parakeet/models/wavenet/data.py index 61cc4ab..a4f1b70 100644 --- a/parakeet/models/wavenet/data.py +++ b/parakeet/models/wavenet/data.py @@ -1,5 +1,3 @@ -import math -import os import random import librosa @@ -9,7 +7,7 @@ from paddle import fluid import utils from parakeet.datasets import ljspeech from parakeet.data import dataset -from parakeet.data.sampler import Sampler, BatchSampler, SequentialSampler +from parakeet.data.sampler import DistributedSampler, BatchSampler from parakeet.data.datacargo import DataCargo @@ -20,7 +18,7 @@ class Dataset(ljspeech.LJSpeech): self.fft_window_shift = config.fft_window_shift # Calculate context frames. frames_per_second = config.sample_rate // self.fft_window_shift - train_clip_frames = int(math.ceil( + train_clip_frames = int(np.ceil( config.train_clip_second * frames_per_second)) context_frames = config.context_size // self.fft_window_shift self.num_frames = train_clip_frames + context_frames @@ -39,7 +37,7 @@ class Dataset(ljspeech.LJSpeech): assert loaded_sr == sr # Pad audio to the right size. 
- frames = math.ceil(float(audio.size) / fft_window_shift) + frames = int(np.ceil(float(audio.size) / fft_window_shift)) fft_padding = (fft_size - fft_window_shift) // 2 desired_length = frames * fft_window_shift + fft_padding * 2 pad_amount = (desired_length - audio.size) // 2 @@ -125,35 +123,6 @@ class Subset(dataset.Dataset): return len(self.indices) -class DistributedSampler(Sampler): - def __init__(self, dataset_size, num_trainers, rank, shuffle=True): - self.dataset_size = dataset_size - self.num_trainers = num_trainers - self.rank = rank - self.num_samples = int(math.ceil(dataset_size / num_trainers)) - self.total_size = self.num_samples * num_trainers - assert self.total_size >= self.dataset_size - self.shuffle = shuffle - - def __iter__(self): - indices = list(range(self.dataset_size)) - if self.shuffle: - random.shuffle(indices) - - # Append extra samples to make it evenly distributed on all trainers. - indices += indices[:(self.total_size - self.dataset_size)] - assert len(indices) == self.total_size - - # Subset samples for each trainer. - indices = indices[self.rank:self.total_size:self.num_trainers] - assert len(indices) == self.num_samples - - return iter(indices) - - def __len__(self): - return self.num_samples - - class LJSpeech: def __init__(self, config, nranks, rank): place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace() diff --git a/parakeet/models/wavenet/ops.py b/parakeet/models/wavenet/ops.py deleted file mode 100644 index 6eda2a9..0000000 --- a/parakeet/models/wavenet/ops.py +++ /dev/null @@ -1,249 +0,0 @@ -import paddle -from paddle import fluid -import paddle.fluid.dygraph as dg -import numpy as np - -import weight_norm - - -def Embedding(name_scope, - num_embeddings, - embed_dim, - padding_idx=None, - std=0.1, - dtype="float32"): - # param attrs - weight_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal( - scale=std)) - layer = dg.Embedding( - name_scope, (num_embeddings, embed_dim), - padding_idx=padding_idx, - param_attr=weight_attr, - dtype=dtype) - return layer - - -def FC(name_scope, - in_features, - size, - num_flatten_dims=1, - relu=False, - dropout=0.0, - act=None, - dtype="float32"): - """ - A special Linear Layer, when it is used with dropout, the weight is - initialized as normal(0, std=np.sqrt((1-dropout) / in_features)) - """ - - # stds - if isinstance(in_features, int): - in_features = [in_features] - - stds = [np.sqrt((1.0 - dropout) / in_feature) for in_feature in in_features] - if relu: - stds = [std * np.sqrt(2.0) for std in stds] - - weight_inits = [ - fluid.initializer.NormalInitializer(scale=std) for std in stds - ] - bias_init = fluid.initializer.ConstantInitializer(0.0) - - # param attrs - weight_attrs = [fluid.ParamAttr(initializer=init) for init in weight_inits] - bias_attr = fluid.ParamAttr(initializer=bias_init) - - layer = weight_norm.FC(name_scope, - size, - num_flatten_dims=num_flatten_dims, - param_attr=weight_attrs, - bias_attr=bias_attr, - act=act, - dtype=dtype) - return layer - - -def Conv1D(name_scope, - in_channels, - num_filters, - filter_size=2, - dilation=1, - groups=None, - causal=False, - std_mul=1.0, - dropout=0.0, - use_cudnn=True, - act=None, - dtype="float32"): - """ - A special Conv1D Layer, when it is used with dropout, the weight is - initialized as - normal(0, std=np.sqrt(std_mul * (1-dropout) / (filter_size * in_channels))) - """ - # std - std = np.sqrt((std_mul * (1.0 - dropout)) / (filter_size * in_channels)) - weight_init = fluid.initializer.NormalInitializer(loc=0.0, scale=std) - 
bias_init = fluid.initializer.ConstantInitializer(0.0) - - # param attrs - weight_attr = fluid.ParamAttr(initializer=weight_init) - bias_attr = fluid.ParamAttr(initializer=bias_init) - - layer = weight_norm.Conv1D( - name_scope, - num_filters, - filter_size, - dilation, - groups=groups, - causal=causal, - param_attr=weight_attr, - bias_attr=bias_attr, - use_cudnn=use_cudnn, - act=act, - dtype=dtype) - return layer - - -class Conv1D_GU(dg.Layer): - def __init__(self, - name_scope, - conditioner_dim, - in_channels, - num_filters, - filter_size, - dilation, - causal=False, - residual=True, - dtype="float32"): - super(Conv1D_GU, self).__init__(name_scope, dtype=dtype) - - self.conditioner_dim = conditioner_dim - self.in_channels = in_channels - self.num_filters = num_filters - self.filter_size = filter_size - self.dilation = dilation - self.causal = causal - self.residual = residual - - if residual: - assert ( - in_channels == num_filters - ), "this block uses residual connection"\ - "the input_channels should equals num_filters" - - self.conv = Conv1D( - self.full_name(), - in_channels, - 2 * num_filters, - filter_size, - dilation, - causal=causal, - dtype=dtype) - - self.fc = Conv1D( - self.full_name(), - conditioner_dim, - 2 * num_filters, - filter_size=1, - dilation=1, - causal=False, - dtype=dtype) - - def forward(self, x, skip=None, conditioner=None): - """ - Args: - x (Variable): Shape(B, C_in, 1, T), the input of Conv1DGLU - layer, where B means batch_size, C_in means the input channels - T means input time steps. - conditioner (Variable): Shape(B, C_con, 1, T), expanded mel - conditioner, where C_con is conditioner hidden dim which - equals the num of mel bands. Note that when using residual - connection, the Conv1DGLU does not change the number of - channels, so out channels equals input channels. - Returns: - x (Variable): Shape(B, C_out, 1, T), the output of Conv1DGLU, where - C_out means the output channels of Conv1DGLU. - """ - residual = x - x = self.conv(x) - - if conditioner is not None: - cond_bias = self.fc(conditioner) - x += cond_bias - - content, gate = fluid.layers.split(x, num_or_sections=2, dim=1) - - # Gated Unit. - x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), - fluid.layers.tanh(content)) - - if skip is None: - skip = x - else: - skip = fluid.layers.scale(skip + x, np.sqrt(0.5)) - - if self.residual: - x = fluid.layers.scale(residual + x, np.sqrt(0.5)) - - return x, skip - - def add_input(self, x, skip=None, conditioner=None): - """ - Inputs: - x: shape(B, num_filters, 1, time_steps) - conditioner: shape(B, conditioner_dim, 1, time_steps) - Outputs: - out: shape(B, num_filters, 1, time_steps), where time_steps = 1 - """ - residual = x - - # add step input and produce step output - x = self.conv.add_input(x) - - if conditioner is not None: - cond_bias = self.fc(conditioner) - x += cond_bias - - content, gate = fluid.layers.split(x, num_or_sections=2, dim=1) - - # Gated Unit. 
- x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), - fluid.layers.tanh(content)) - - if skip is None: - skip = x - else: - skip = fluid.layers.scale(skip + x, np.sqrt(0.5)) - - if self.residual: - x = fluid.layers.scale(residual + x, np.sqrt(0.5)) - - return x, skip - - -def Conv2DTranspose(name_scope, - num_filters, - filter_size, - padding=0, - stride=1, - dilation=1, - use_cudnn=True, - act=None, - dtype="float32"): - val = 1.0 / (filter_size[0] * filter_size[1]) - weight_init = fluid.initializer.ConstantInitializer(val) - weight_attr = fluid.ParamAttr(initializer=weight_init) - - layer = weight_norm.Conv2DTranspose( - name_scope, - num_filters, - filter_size=filter_size, - padding=padding, - stride=stride, - dilation=dilation, - param_attr=weight_attr, - use_cudnn=use_cudnn, - act=act, - dtype=dtype) - - return layer diff --git a/parakeet/models/wavenet/wavenet.py b/parakeet/models/wavenet/wavenet.py index acc6e76..c636c4b 100644 --- a/parakeet/models/wavenet/wavenet.py +++ b/parakeet/models/wavenet/wavenet.py @@ -4,12 +4,12 @@ import time import librosa import numpy as np -from paddle import fluid import paddle.fluid.dygraph as dg +from paddle import fluid import utils from data import LJSpeech -from wavenet_modules import WaveNetModule, debug +from wavenet_modules import WaveNetModule class WaveNet(): @@ -33,18 +33,6 @@ class WaveNet(): self.trainloader = dataset.trainloader self.validloader = dataset.validloader -# if self.rank == 0: -# for i, (audios, mels, ids) in enumerate(self.validloader()): -# print("audios {}, mels {}, ids {}".format(audios.dtype, mels.dtype, ids.dtype)) -# print("{}: rank {}, audios {}, mels {}, indices {} / {}".format( -# i, self.rank, audios.shape, mels.shape, ids.shape, -# ids.numpy())) -# -# for i, (audios, mels, ids) in enumerate(self.trainloader): -# print("{}: rank {}, audios {}, mels {}, indices {} / {}".format( -# i, self.rank, audios.shape, mels.shape, ids.shape, -# ids.numpy())) - wavenet = WaveNetModule("wavenet", config, self.rank) # Dry run once to create and initalize all necessary parameters. 
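The DistributedSampler moved into parakeet/data/sampler.py by this patch
partitions indices by padding them to a multiple of the trainer count and
then taking a rank-strided slice. A minimal, self-contained sketch of that
partitioning logic (plain Python 3 with NumPy; partition_indices is an
illustrative name, not part of the patch, and shuffling is omitted):

    import numpy as np

    def partition_indices(dataset_size, num_trainers, rank):
        # Pad so that every trainer receives the same number of samples.
        num_samples = int(np.ceil(dataset_size / num_trainers))
        total_size = num_samples * num_trainers
        indices = list(range(dataset_size))
        indices += indices[:total_size - dataset_size]
        # Rank-strided subset: trainer r takes r, r + N, r + 2N, ...
        return indices[rank:total_size:num_trainers]

    # Example: 10 samples over 4 trainers -> 3 indices per trainer, with
    # indices 0 and 1 served twice across the cluster.
    assert partition_indices(10, 4, 0) == [0, 4, 8]
    assert all(len(partition_indices(10, 4, r)) == 3 for r in range(4))
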
@@ -139,8 +127,8 @@ class WaveNet(): self.wavenet.eval() total_loss = [] - start_time = time.time() sample_audios = [] + start_time = time.time() for audios, mels, audio_starts in self.validloader(): loss, sample_audio = self.wavenet(audios, mels, audio_starts, True) total_loss.append(float(loss.numpy())) @@ -160,11 +148,6 @@ class WaveNet(): tb.add_audio("Teacher-Forced-Audio-1", sample_audios[1].numpy(), iteration, sample_rate=self.config.sample_rate) - def save(self, iteration): - utils.save_latest_parameters(self.checkpoint_dir, iteration, - self.wavenet, self.optimizer) - utils.save_latest_checkpoint(self.checkpoint_dir, iteration) - @dg.no_grad def infer(self, iteration): self.wavenet.eval() @@ -186,3 +169,8 @@ class WaveNet(): syn_audio.shape, syn_time)) librosa.output.write_wav(filename, syn_audio, sr=config.sample_rate) + + def save(self, iteration): + utils.save_latest_parameters(self.checkpoint_dir, iteration, + self.wavenet, self.optimizer) + utils.save_latest_checkpoint(self.checkpoint_dir, iteration) diff --git a/parakeet/models/wavenet/wavenet_modules.py b/parakeet/models/wavenet/wavenet_modules.py index c5c01e9..fbab741 100644 --- a/parakeet/models/wavenet/wavenet_modules.py +++ b/parakeet/models/wavenet/wavenet_modules.py @@ -1,11 +1,9 @@ import itertools -import math import numpy as np -from paddle import fluid import paddle.fluid.dygraph as dg -import ops -import weight_norm +from paddle import fluid +from parakeet.modules import conv, modules def get_padding(filter_size, stride, padding_type='same'): @@ -16,22 +14,6 @@ def get_padding(filter_size, stride, padding_type='same'): return padding -def debug(x, var_name, rank, verbose=False): - if not verbose and rank != 0: - return - dim = len(x.shape) - if not isinstance(x, np.ndarray): - x = x.numpy() - if dim == 1: - print("Rank {}".format(rank), var_name, "shape {}, value {}".format(x.shape, x)) - elif dim == 2: - print("Rank {}".format(rank), var_name, "shape {}, value {}".format(x.shape, x[:, :5])) - elif dim == 3: - print("Rank {}".format(rank), var_name, "shape {}, value {}".format(x.shape, x[:, :5, 0])) - else: - print("Rank", rank, var_name, "shape", x.shape) - - def extract_slices(x, audio_starts, audio_length, rank): slices = [] for i in range(x.shape[0]): @@ -58,7 +40,7 @@ class Conditioner(dg.Layer): stride = (up_scale, 1) padding = get_padding(filter_sizes[i], stride) self.deconvs.append( - ops.Conv2DTranspose( + modules.Conv2DTranspose( self.full_name(), num_filters=1, filter_size=filter_sizes[i], @@ -94,12 +76,13 @@ class WaveNetModule(dg.Layer): print("context_size", self.context_size) if config.loss_type == "softmax": - self.embedding_fc = ops.Embedding( + self.embedding_fc = modules.Embedding( self.full_name(), num_embeddings=config.num_channels, - embed_dim=config.residual_channels) + embed_dim=config.residual_channels, + std=0.1) elif config.loss_type == "mix-gaussian-pdf": - self.embedding_fc = ops.FC( + self.embedding_fc = modules.FC( self.full_name(), in_features=1, size=config.residual_channels, @@ -112,7 +95,7 @@ class WaveNetModule(dg.Layer): self.dilated_causal_convs = [] for dilation in self.dilations: self.dilated_causal_convs.append( - ops.Conv1D_GU( + modules.Conv1D_GU( self.full_name(), conditioner_dim=config.mel_bands, in_channels=config.residual_channels, @@ -126,7 +109,7 @@ class WaveNetModule(dg.Layer): for i, layer in enumerate(self.dilated_causal_convs): self.add_sublayer("dilated_causal_conv_{}".format(i), layer) - self.fc1 = ops.FC( + self.fc1 = modules.FC( self.full_name(), 
in_features=config.residual_channels, size=config.skip_channels, @@ -134,7 +117,7 @@ class WaveNetModule(dg.Layer): relu=True, act="relu") - self.fc2 = ops.FC( + self.fc2 = modules.FC( self.full_name(), in_features=config.skip_channels, size=config.skip_channels, @@ -143,14 +126,14 @@ class WaveNetModule(dg.Layer): act="relu") if config.loss_type == "softmax": - self.fc3 = ops.FC( + self.fc3 = modules.FC( self.full_name(), in_features=config.skip_channels, size=config.num_channels, num_flatten_dims=2, relu=False) elif config.loss_type == "mix-gaussian-pdf": - self.fc3 = ops.FC( + self.fc3 = modules.FC( self.full_name(), in_features=config.skip_channels, size=3 * config.num_mixtures, @@ -175,8 +158,8 @@ class WaveNetModule(dg.Layer): return samples def sample_mix_gaussian(self, mix_parameters): - # mix_parameters reshape from [bs, 13799, 3 * num_mixtures] - # to [bs * 13799, 3 * num_mixtures]. + # mix_parameters reshape from [bs, len, 3 * num_mixtures] + # to [bs * len, 3 * num_mixtures]. batch, length, hidden = mix_parameters.shape mix_param_2d = fluid.layers.reshape(mix_parameters, [batch * length, hidden]) @@ -197,7 +180,7 @@ class WaveNetModule(dg.Layer): mu_comp = fluid.layers.gather_nd(mu, comp_samples) s_comp = fluid.layers.gather_nd(s, comp_samples) - # N(0, 1) Normal Sample. + # N(0, 1) normal sample. u = fluid.layers.gaussian_random(shape=[batch * length]) samples = mu_comp + u * s_comp samples = fluid.layers.clip(samples, min=-1.0, max=1.0) @@ -205,8 +188,6 @@ class WaveNetModule(dg.Layer): return samples def softmax_loss(self, targets, mix_parameters): - # targets: [bs, 13799] -> [bs, 11752] - # mix_params: [bs, 13799, 3] -> [bs, 11752, 3] targets = targets[:, self.context_size:] mix_parameters = mix_parameters[:, self.context_size:, :] @@ -216,22 +197,22 @@ class WaveNetModule(dg.Layer): quantized = fluid.layers.cast( (targets + 1.0) / 2.0 * num_channels, dtype="int64") - # per_sample_loss shape: [bs, 17952, 1] + # per_sample_loss shape: [bs, len, 1] per_sample_loss = fluid.layers.softmax_with_cross_entropy( logits=mix_parameters, label=fluid.layers.unsqueeze(quantized, 2)) loss = fluid.layers.reduce_mean(per_sample_loss) - #debug(loss, "softmax loss", self.rank) return loss def mixture_density_loss(self, targets, mix_parameters, log_scale_min): - # targets: [bs, 13799] -> [bs, 11752] - # mix_params: [bs, 13799, 3] -> [bs, 11752, 3] + # targets: [bs, len] + # mix_params: [bs, len, 3 * num_mixture] targets = targets[:, self.context_size:] mix_parameters = mix_parameters[:, self.context_size:, :] - # log_s: [bs, 11752, num_mixture] - logits_pi, mu, log_s = fluid.layers.split(mix_parameters, num_or_sections=3, dim=-1) + # log_s: [bs, len, num_mixture] + logits_pi, mu, log_s = fluid.layers.split( + mix_parameters, num_or_sections=3, dim=-1) pi = fluid.layers.softmax(logits_pi, axis=-1) log_s = fluid.layers.clip(log_s, min=log_scale_min, max=100.0) @@ -242,10 +223,9 @@ class WaveNetModule(dg.Layer): targets = fluid.layers.expand(targets, [1, 1, self.config.num_mixtures]) x_std = inv_s * (targets - mu) exponent = fluid.layers.exp(-0.5 * x_std * x_std) - # pdf_x: [bs, 11752, 1] pdf_x = 1.0 / np.sqrt(2.0 * np.pi) * inv_s * exponent pdf_x = pi * pdf_x - # pdf_x: [bs, 11752] + # pdf_x: [bs, len] pdf_x = fluid.layers.reduce_sum(pdf_x, dim=-1) per_sample_loss = 0.0 - fluid.layers.log(pdf_x + 1e-9) @@ -254,8 +234,6 @@ class WaveNetModule(dg.Layer): return loss def forward(self, audios, mels, audio_starts, sample=False): - # audios: [bs, 13800], mels: [bs, full_frame_length, 80] - # 
audio_starts: [bs] # Build conditioner based on mels. full_conditioner = self.conditioner(mels) @@ -264,15 +242,14 @@ class WaveNetModule(dg.Layer): conditioner = extract_slices(full_conditioner, audio_starts, audio_length, self.rank) - # input_audio, target_audio: [bs, 13799] + # input_audio, target_audio: [bs, len] input_audios = audios[:, :-1] target_audios = audios[:, 1:] - # conditioner: [bs, 13799, 80] + # conditioner: [bs, len, mel_bands] conditioner = conditioner[:, 1:, :] loss_type = self.config.loss_type - # layer_input: [bs, 13799, 128] if loss_type == "softmax": input_audios = fluid.layers.clip( input_audios, min=-1.0, max=0.99999) @@ -280,31 +257,31 @@ class WaveNetModule(dg.Layer): quantized = fluid.layers.cast( (input_audios + 1.0) / 2.0 * self.config.num_channels, dtype="int64") - layer_input = self.embedding_fc(fluid.layers.unsqueeze(quantized, 2)) + layer_input = self.embedding_fc( + fluid.layers.unsqueeze(quantized, 2)) elif loss_type == "mix-gaussian-pdf": - layer_input = self.embedding_fc(fluid.layers.unsqueeze(input_audios, 2)) + layer_input = self.embedding_fc( + fluid.layers.unsqueeze(input_audios, 2)) else: raise ValueError( "loss_type {} is unsupported!".format(loss_type)) - # layer_input: [bs, res_channel, 1, 13799] - layer_input = fluid.layers.unsqueeze(fluid.layers.transpose(layer_input, perm=[0, 2, 1]), 2) - # conditioner: [bs, mel_bands, 1, 13799] - conditioner = fluid.layers.unsqueeze(fluid.layers.transpose(conditioner, perm=[0, 2, 1]), 2) + # layer_input: [bs, res_channel, 1, len] + layer_input = fluid.layers.unsqueeze( + fluid.layers.transpose(layer_input, perm=[0, 2, 1]), 2) + # conditioner: [bs, mel_bands, 1, len] + conditioner = fluid.layers.unsqueeze( + fluid.layers.transpose(conditioner, perm=[0, 2, 1]), 2) - # layer_input: [bs, res_channel, 1, 13799] - # skip: [bs, res_channel, 1, 13799] skip = None for i, layer in enumerate(self.dilated_causal_convs): + # layer_input: [bs, res_channel, 1, len] + # skip: [bs, res_channel, 1, len] layer_input, skip = layer(layer_input, skip, conditioner) - #debug(layer_input, "layer_input_" + str(i), self.rank) - #debug(skip, "skip_" + str(i), self.rank) - # Reshape skip to [bs, 13799, res_channel] - skip = fluid.layers.transpose(fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1]) - #debug(skip, "skip", self.rank) - - # mix_param: [bs, 13799, 3 * num_mixtures] + # Reshape skip to [bs, len, res_channel] + skip = fluid.layers.transpose( + fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1]) mix_parameters = self.fc3(self.fc2(self.fc1(skip))) # Sample teacher-forced audio. @@ -317,12 +294,7 @@ class WaveNetModule(dg.Layer): else: raise ValueError( "loss_type {} is unsupported!".format(loss_type)) - #debug(sample_audios, "sample_audios", self.rank) - # Calculate mix-gaussian density loss. - # padding is all zero. - # target_audio: [bs, 13799]. - # mix_params: [bs, 13799, 3]. if loss_type == "softmax": loss = self.softmax_loss(target_audios, mix_parameters) elif loss_type == "mix-gaussian-pdf": @@ -332,27 +304,16 @@ class WaveNetModule(dg.Layer): raise ValueError( "loss_type {} is unsupported!".format(loss_type)) - #print("Rank {}, loss {}".format(self.rank, loss.numpy())) - return loss, sample_audios def synthesize(self, mels): self.start_new_sequence() - print("input mels shape", mels.shape) - # mels: [bs=1, n_frames, 80] - # conditioner: [1, n_frames * samples_per_frame, 80] - # Should I move forward by one sample? 
No difference - # Append context frame to mels bs, n_frames, mel_bands = mels.shape - #num_pad_frames = int(np.ceil(self.context_size / self.config.fft_window_shift)) - #silence = fluid.layers.zeros(shape=[bs, num_pad_frames, mel_bands], dtype="float32") - #inf_mels = fluid.layers.concat([silence, mels], axis=1) - #print("padded mels shape", inf_mels.shape) - - #conditioner = self.conditioner(inf_mels)[:, self.context_size:, :] conditioner = self.conditioner(mels) time_steps = conditioner.shape[1] - print("Total steps", time_steps) + + print("input mels shape", mels.shape) + print("Total synthesis steps", time_steps) loss_type = self.config.loss_type audio_samples = [] @@ -361,8 +322,8 @@ class WaveNetModule(dg.Layer): if i % 100 == 0: print("Step", i) - # convert from real value sample to audio embedding. - # [bs, 1, 128] + # Convert from real value sample to audio embedding. + # audio_input: [bs, 1, channel] if loss_type == "softmax": current_sample = fluid.layers.clip( current_sample, min=-1.0, max=0.99999) @@ -377,21 +338,23 @@ class WaveNetModule(dg.Layer): raise ValueError( "loss_type {} is unsupported!".format(loss_type)) - # [bs, 128, 1, 1] - audio_input = fluid.layers.unsqueeze(fluid.layers.transpose(audio_input, perm=[0, 2, 1]), 2) - # [bs, 80] + # [bs, channel, 1, 1] + audio_input = fluid.layers.unsqueeze( + fluid.layers.transpose(audio_input, perm=[0, 2, 1]), 2) + # [bs, mel_bands] cond_input = conditioner[:, i, :] - # [bs, 80, 1, 1] + # [bs, mel_bands, 1, 1] cond_input = fluid.layers.reshape( cond_input, cond_input.shape + [1, 1]) skip = None for layer in self.dilated_causal_convs: - audio_input, skip = layer.add_input(audio_input, skip, cond_input) + audio_input, skip = layer.add_input( + audio_input, skip, cond_input) - # [bs, 1, 128] - skip = fluid.layers.transpose(fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1]) - # [bs, 1, 3] + # [bs, 1, channel] + skip = fluid.layers.transpose( + fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1]) mix_parameters = self.fc3(self.fc2(self.fc1(skip))) if loss_type == "softmax": sample = self.sample_softmax(mix_parameters) @@ -407,17 +370,12 @@ class WaveNetModule(dg.Layer): current_sample = fluid.layers.reshape(current_sample, current_sample.shape + [1, 1]) - # syn_audio: (num_samples,) + # syn_audio: [num_samples] syn_audio = fluid.layers.concat(audio_samples, axis=0).numpy() return syn_audio def start_new_sequence(self): for layer in self.sublayers(): - if isinstance(layer, weight_norm.Conv1D): + if isinstance(layer, conv.Conv1D): layer.start_new_sequence() - - def save(self, iteration): - utils.save_latest_parameters(self.checkpoint_dir, iteration, - self.wavenet, self.optimizer) - utils.save_latest_checkpoint(self.checkpoint_dir, iteration) diff --git a/parakeet/models/wavenet/weight_norm.py b/parakeet/models/wavenet/weight_norm.py deleted file mode 100644 index 75fe413..0000000 --- a/parakeet/models/wavenet/weight_norm.py +++ /dev/null @@ -1,920 +0,0 @@ -import math -from copy import deepcopy - -import numpy as np -import paddle.fluid.dygraph as dg -from paddle import fluid -from paddle.fluid import core -from paddle.fluid.framework import Variable -from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer -from paddle.fluid.layers import utils -from six.moves import reduce - - -def _norm(p, dim): - """Computes the norm over all dimensions except dim. - It differs from pytorch implementation that it does not keep dim. - This difference is related with the broadcast mechanism in paddle. - Read elementeise_mul for more. 
- """ - if dim is None: - return np.linalg.norm(p, ord=2, axis=None) - elif dim == 0: - p = np.reshape(p, newshape=(p.shape[0], -1)) - return np.linalg.norm(p, ord=2, axis=1) - elif dim == p.ndim - 1: - p = np.reshape(p, newshape=(-1, p.shape[-1])) - return np.linalg.norm(p, ord=2, axis=0) - else: - perm = list(range(p.ndim)) - perm[0] = dim - perm[dim] = 0 - return _norm(np.transpose(p, axes=perm)) - - -class Conv1D(dg.Layer): - """ - A convolution 1D block implemented with Conv2D. Form simplicity and - ensuring the output has the same length as the input, it does not allow - stride > 1. - """ - def __init__(self, - name_scope, - num_filters, - filter_size=3, - dilation=1, - groups=None, - causal=False, - param_attr=None, - bias_attr=None, - use_cudnn=True, - act=None, - dtype="float32"): - super(Conv1D, self).__init__(name_scope, dtype=dtype) - - if causal: - padding = dilation * (filter_size - 1) - else: - padding = (dilation * (filter_size - 1)) // 2 - - self.num_filters = num_filters - self.filter_size = filter_size - self.dilation = dilation - self.causal = causal - self.padding = padding - self.act = act - - self.conv = Conv2D( - self.full_name(), - num_filters=num_filters, - filter_size=(1, filter_size), - stride=(1, 1), - dilation=(1, dilation), - padding=(0, padding), - groups=groups, - param_attr=param_attr, - bias_attr=bias_attr, - use_cudnn=use_cudnn, - act=act, - dtype=dtype) - - def forward(self, x): - """ - Args: - x (Variable): Shape(B, C_in, 1, T), the input, where C_in means - input channels. - Returns: - x (Variable): Shape(B, C_out, 1, T), the outputs, where C_out means - output channels (num_filters). - """ - x = self.conv(x) - if self.filter_size > 1: - if self.causal: - x = fluid.layers.slice( - x, axes=[3], starts=[0], ends=[-self.padding]) - elif self.filter_size % 2 == 0: - x = fluid.layers.slice(x, axes=[3], starts=[0], ends=[-1]) - return x - - def start_new_sequence(self): - self.temp_weight = None - self.input_buffer = None - - def add_input(self, x): - """ - Adding input for a time step and compute an output for a time step. - - Args: - x (Variable): Shape(B, C_in, 1, T), the input, where C_in means - input channels, and T = 1. - Returns: - out (Variable): Shape(B, C_out, 1, T), the outputs, where C_out - means output channels (num_filters), and T = 1. 
- - """ - if self.temp_weight is None: - self.temp_weight = self._reshaped_weight() - - window_size = 1 + (self.filter_size - 1) * self.dilation - batch_size = x.shape[0] - in_channels = x.shape[1] - - if self.filter_size > 1: - if self.input_buffer is None: - self.input_buffer = fluid.layers.fill_constant( - [batch_size, in_channels, 1, window_size - 1], - dtype=x.dtype, - value=0.0) - else: - self.input_buffer = self.input_buffer[:, :, :, 1:] - self.input_buffer = fluid.layers.concat( - [self.input_buffer, x], axis=3) - x = self.input_buffer - if self.dilation > 1: - if not hasattr(self, "indices"): - self.indices = dg.to_variable( - np.arange(0, window_size, self.dilation)) - tmp = fluid.layers.transpose( - self.input_buffer, perm=[3, 1, 2, 0]) - tmp = fluid.layers.gather(tmp, index=self.indices) - tmp = fluid.layers.transpose(tmp, perm=[3, 1, 2, 0]) - x = tmp - inputs = fluid.layers.reshape( - x, shape=[batch_size, in_channels * 1 * self.filter_size]) - out = fluid.layers.matmul(inputs, self.temp_weight, transpose_y=True) - out = fluid.layers.elementwise_add(out, self.conv._bias_param, axis=-1) - out = fluid.layers.reshape(out, out.shape + [1, 1]) - out = self._helper.append_activation(out, act=self.act) - return out - - def _reshaped_weight(self): - """ - Get the linearized weight of convolution filter, cause it is by nature - a matmul weight. And because the model uses weight norm, compute the - weight by weight_v * weight_g to make it faster. - Returns: - weight_matrix (Variable): Shape(C_out, C_in * 1 * kernel_size) - """ - shape = self.conv._filter_param_v.shape - matrix_shape = [shape[0], np.prod(shape[1:])] - weight_matrix = fluid.layers.reshape( - self.conv._filter_param_v, shape=matrix_shape) - weight_matrix = fluid.layers.elementwise_mul( - fluid.layers.l2_normalize( - weight_matrix, axis=1), - self.conv._filter_param_g, - axis=0) - return weight_matrix - - -class FC(dg.Layer): - """ - **Fully Connected Layer** - This function creates a fully connected layer in the network. It can take - one or multiple tensors as its inputs(input can be a list of Variable, see - Args in detail). It creates a pair of variables called (magnitude(g), - direction(V)) for each input tensor. Elementwise_mul(V, g) represents a fully connected - weight matrix from each input unit to each output unit. - The fully connected layer multiplies each input tensor - with its corresponding weight to produce an output Tensor with shape [M, `size`], - where M is batch size. If multiple input tensors are given, the results of - multiple output tensors with shape [M, `size`] will be summed up. If bias_attr - is not None, a bias variable will be created and added to the output. - Finally, if activation is not None, it will be applied to the output as well. - When the input is single tensor: - .. math:: - Out = Act({X(normalize(V)g) + b}) - When the input are multiple tensors: - .. math:: - Out = Act({\sum_{i=0}^{N-1}X_i(V_ig_i) + b}) - In the above equation: - * :math:`N`: Number of the input. N equals to len(input) if input is list of Variable. - * :math:`X_i`: The i-th input tensor. - * :math:`V_i`: The i-th direction matrix corresponding i-th input tensor. - * :math:`g_i`: The i-th magnitude vector corresponding i-th input tensor. - * :math:`b`: The bias parameter created by this layer (if needed). - * :math:`Act`: The activation function. - * :math:`Out`: The output tensor. - See below for an example. - .. 
code-block:: text - Given: - data_1.data = [[[0.1, 0.2], - [0.3, 0.4]]] - data_1.shape = (1, 2, 2) # 1 is batch_size - data_2 = [[[0.1, 0.2, 0.3]]] - data_2.shape = (1, 1, 3) - out = fluid.layers.fc(input=[data_1, data_2], size=2) - Then: - out.data = [[0.18669507, 0.1893476]] - out.shape = (1, 2) - Args: - name_scope(str): The name of this class. - size(int): The number of output units in this layer. - num_flatten_dims (int): The fc layer can accept an input tensor with more than - two dimensions. If this happens, the multidimensional tensor will first be flattened - into a 2-dimensional matrix. The parameter `num_flatten_dims` determines how the input - tensor is flattened: the first `num_flatten_dims` (inclusive, index starts from 1) - dimensions will be flatten to form the first dimension of the final matrix (height of - the matrix), and the rest `rank(X) - num_flatten_dims` dimensions are flattened to - form the second dimension of the final matrix (width of the matrix). For example, suppose - `X` is a 5-dimensional tensor with a shape [2, 3, 4, 5, 6], and `num_flatten_dims` = 3. - Then, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. Default: 1 - param_attr (ParamAttr|list of ParamAttr|None): The parameter attribute for learnable - parameters/weights of this layer. - bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias - of this layer. If it is set to False, no bias will be added to the output units. - If it is set to None, the bias is initialized zero. Default: None. - act (str|None): Activation to be applied to the output of this layer. - is_test(bool): A flag indicating whether execution is in test phase. Default: False - dtype(str): Dtype used for weight - Raises: - ValueError: If rank of the input tensor is less than 2. - Examples: - .. 
code-block:: python - from paddle.fluid.dygraph.base import to_variable - import paddle.fluid as fluid - from paddle.fluid.dygraph import FC - import numpy as np - data = np.random.uniform( -1, 1, [30, 10, 32] ).astype('float32') - with fluid.dygraph.guard(): - fc = FC( "fc", 64, num_flatten_dims=2) - data = to_variable( data ) - conv = fc( data ) - """ - - def __init__(self, - name_scope, - size, - num_flatten_dims=1, - epsilon=1e-30, - param_attr=None, - bias_attr=None, - act=None, - is_test=False, - dtype="float32"): - super(FC, self).__init__(name_scope, dtype) - - self._size = size - self._num_flatten_dims = num_flatten_dims - self._epsilon = epsilon - self._dtype = dtype - self._param_attr = param_attr - self._bias_attr = bias_attr - self._act = act - self.__g = list() - self.__v = list() - - @property - def _v(self, i=0): - return self.__v[i] - - @property - def _g(self, i=0): - return self.__g[i] - - @_v.setter - def _v(self, value, i=0): - assert isinstance(value, Parameter) - self.__v[i] = value - - @_g.setter - def _g(self, value, i=0): - assert isinstance(value, Parameter) - self.__g[i] = value - - def _build_once(self, input): - i = 0 - for inp, param in self._helper.iter_inputs_and_params( - input, self._param_attr): - input_shape = inp.shape - - param_shape = [ - reduce(lambda a, b: a * b, - input_shape[self._num_flatten_dims:], 1) - ] + [self._size] - self.__v.append( - self.add_parameter( - "_v%d" % i, - self.create_parameter( - attr=param, - shape=param_shape, - dtype=self._dtype, - is_bias=False))) - - magnitude_shape = param_shape[1:] - magnitude_value = np.linalg.norm( - self.__v[i].numpy(), ord=2, axis=0) - - self.__g.append( - self.add_parameter( - "_g%d" % i, - self.create_parameter( - attr=fluid.ParamAttr(initializer=fluid.initializer. 
- NumpyArrayInitializer( - magnitude_value)), - shape=magnitude_shape, - dtype=self._dtype, - is_bias=False))) - i += 1 - - size = list([self._size]) - self._b = self.create_parameter( - attr=self._bias_attr, shape=size, dtype=self._dtype, is_bias=True) - - def forward(self, input): - mul_results = list() - i = 0 - for inp, param in self._helper.iter_inputs_and_params( - input, self._param_attr): - v_norm = self._helper.create_variable_for_type_inference( - self._dtype) - v_normalized = self._helper.create_variable_for_type_inference( - self._dtype) - self._helper.append_op( - type="norm", - inputs={"X": self.__v[i]}, - outputs={"Out": v_normalized, - "Norm": v_norm}, - attrs={"axis": 0, - "epsilon": self._epsilon}) - weight = self._helper.create_variable_for_type_inference( - self._dtype) - self._helper.append_op( - type="elementwise_mul", - inputs={"X": [v_normalized], - "Y": [self.__g[i]]}, - outputs={"Out": [weight]}, - attrs={"axis": 1}) - tmp = self._helper.create_variable_for_type_inference(self._dtype) - self._helper.append_op( - type="mul", - inputs={"X": inp, - "Y": weight}, - outputs={"Out": tmp}, - attrs={ - "x_num_col_dims": self._num_flatten_dims, - "y_num_col_dims": 1 - }) - i += 1 - mul_results.append(tmp) - - if len(mul_results) == 1: - pre_bias = mul_results[0] - else: - pre_bias = self._helper.create_variable_for_type_inference( - self._dtype) - self._helper.append_op( - type="sum", - inputs={"X": mul_results}, - outputs={"Out": pre_bias}, - attrs={"use_mkldnn": False}) - - if self._b: - pre_activation = self._helper.create_variable_for_type_inference( - dtype=self._dtype) - self._helper.append_op( - type="elementwise_add", - inputs={"X": [pre_bias], - "Y": [self._b]}, - outputs={"Out": [pre_activation]}, - attrs={"axis": self._num_flatten_dims}) - else: - pre_activation = pre_bias - # Currently, we don't support inplace in dygraph mode - return self._helper.append_activation(pre_activation, act=self._act) - - -class Conv2D(dg.Layer): - """ - The convolution2D layer calculates the output based on the input, filter - and strides, paddings, dilations, groups parameters. Input and - Output are in NCHW format, where N is batch size, C is the number of - channels, H is the height of the feature, and W is the width of the feature. - Filter is in MCHW format, where M is the number of output image channels, - C is the number of input image channels, H is the height of the filter, - and W is the width of the filter. If the groups is greater than 1, - C will equal the number of input image channels divided by the groups. - Please refer to UFLDL's `convolution - ` - for more detials. - If bias attribution and activation type are provided, bias is added to the - output of the convolution, and the corresponding activation function is - applied to the final result. - For each input :math:`X`, the equation is: - .. math:: - Out = \sigma ((Vg) \\ast X + b) - Where: - * :math:`X`: Input value, a tensor with NCHW format. - * :math:`V`: Filter direction value, a tensor with MCHW format. - * :math:`g`: Filter magnitude value, a tensor with M format. - * :math:`\\ast`: Convolution operation. - * :math:`b`: Bias value, a 2-D tensor with shape [M, 1]. - * :math:`\\sigma`: Activation function. - * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different. - Example: - - Input: - Input shape: :math:`(N, C_{in}, H_{in}, W_{in})` - Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)` - - Output: - Output shape: :math:`(N, C_{out}, H_{out}, W_{out})` - Where - .. 
math:: - H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\ - W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1 - Args: - name_scope(str) : The name for this class. - num_filters(int): The number of filter. It is as same as the output - image channel. - filter_size (int|tuple|None): The filter size. If filter_size is a tuple, - it must contain two integers, (filter_size_H, filter_size_W). - Otherwise, the filter will be a square. - stride (int|tuple): The stride size. If stride is a tuple, it must - contain two integers, (stride_H, stride_W). Otherwise, the - stride_H = stride_W = stride. Default: stride = 1. - padding (int|tuple): The padding size. If padding is a tuple, it must - contain two integers, (padding_H, padding_W). Otherwise, the - padding_H = padding_W = padding. Default: padding = 0. - dilation (int|tuple): The dilation size. If dilation is a tuple, it must - contain two integers, (dilation_H, dilation_W). Otherwise, the - dilation_H = dilation_W = dilation. Default: dilation = 1. - groups (int): The groups number of the Conv2d Layer. According to grouped - convolution in Alex Krizhevsky's Deep CNN paper: when group=2, - the first half of the filters is only connected to the first half - of the input channels, while the second half of the filters is only - connected to the second half of the input channels. Default: groups=1. - param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights - of conv2d. If it is set to None or one attribute of ParamAttr, conv2d - will create ParamAttr as param_attr. If the Initializer of the param_attr - is not set, the parameter is initialized with :math:`Normal(0.0, std)`, - and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None. - bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d. - If it is set to False, no bias will be added to the output units. - If it is set to None or one attribute of ParamAttr, conv2d - will create ParamAttr as bias_attr. If the Initializer of the bias_attr - is not set, the bias is initialized zero. Default: None. - use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn - library is installed. Default: True - act (str): Activation type, if it is set to None, activation is not appended. - Default: None - Raises: - ValueError: If the shapes of input, filter_size, stride, padding and - groups mismatch. - Examples: - .. code-block:: python - from paddle.fluid.dygraph.base import to_variable - import paddle.fluid as fluid - from paddle.fluid.dygraph import Conv2D - import numpy as np - data = np.random.uniform( -1, 1, [10, 3, 32, 32] ).astype('float32') - with fluid.dygraph.guard(): - conv2d = Conv2D( "conv2d", 2, 3) - data = to_variable( data ) - conv = conv2d( data ) - """ - - def __init__(self, - name_scope, - num_filters, - filter_size, - stride=1, - padding=0, - dilation=1, - groups=None, - param_attr=None, - bias_attr=None, - use_cudnn=True, - act=None, - epsilon=1e-30, - dtype="float32"): - assert param_attr is not False, "param_attr should not be False here." 
- super(Conv2D, self).__init__(name_scope, dtype) - self._groups = groups - self._stride = utils.convert_to_list(stride, 2, "stride") - self._padding = utils.convert_to_list(padding, 2, "padding") - self._dilation = utils.convert_to_list(dilation, 2, "dilation") - self._act = act - if not isinstance(use_cudnn, bool): - raise ValueError("use_cudnn should be True or False") - self._use_cudnn = use_cudnn - self._filter_size = filter_size - self._num_filters = num_filters - self._param_attr = param_attr - self._bias_attr = bias_attr - self._epsilon = epsilon - self._dtype = dtype - # if (self._num_channels == self._groups and - # num_filters % self._num_channels == 0 and not self._use_cudnn): - # self._l_type = 'depthwise_conv2d' - # else: - # TODO(jiabin): recover the usage of depthwise_conv2d when it's - # kernel fixed https://github.com/PaddlePaddle/Paddle/issues/17275 - self._l_type = "conv2d" - - def _build_once(self, input): - self._num_channels = input.shape[1] - if self._groups is None: - num_filter_channels = self._num_channels - else: - if self._num_channels % self._groups != 0: - raise ValueError("num_channels must be divisible by groups.") - num_filter_channels = self._num_channels // self._groups - filter_size = utils.convert_to_list(self._filter_size, 2, - "filter_size") - filter_shape = [self._num_filters, int(num_filter_channels) - ] + filter_size - - def _get_default_param_initializer(): - filter_elem_num = filter_size[0] * filter_size[ - 1] * self._num_channels - std = (2.0 / filter_elem_num)**0.5 - return Normal(0.0, std, 0) - - # weight_v - self._filter_param_v = self.create_parameter( - attr=self._param_attr, - shape=filter_shape, - dtype=self._dtype, - default_initializer=_get_default_param_initializer()) - - # weight_g - norm_value = _norm( - self._filter_param_v.numpy(), dim=0) # CAUTION: hard-code - self._filter_param_g = self.create_parameter( - attr=fluid.ParamAttr( - initializer=fluid.initializer.NumpyArrayInitializer( - norm_value)), - shape=norm_value.shape, - dtype=self._dtype, - default_initializer=_get_default_param_initializer()) - - if self._use_cudnn: - self.create_variable( - name="kCUDNNFwdAlgoCache", - persistable=True, - type=core.VarDesc.VarType.RAW) - self.create_variable( - name="kCUDNNBwdDataAlgoCache", - persistable=True, - type=core.VarDesc.VarType.RAW) - self.create_variable( - name="kCUDNNBwdFilterAlgoCache", - persistable=True, - type=core.VarDesc.VarType.RAW) - - self._bias_param = self.create_parameter( - attr=self._bias_attr, - shape=[self._num_filters], - dtype=self._dtype, - is_bias=True) - - def forward(self, input): - matrix = self._helper.create_variable_for_type_inference(self._dtype) - tmp = self._helper.create_variable_for_type_inference(self._dtype) - new_shape = [ - self._filter_param_v.shape[0], - reduce(lambda x, y: x * y, self._filter_param_v.shape[1:], 1), - ] - - self._helper.append_op( - type="reshape2", - inputs={"X": self._filter_param_v}, - attrs={"shape": new_shape}, - outputs={"Out": matrix, - "XShape": tmp}) - - m_norm = self._helper.create_variable_for_type_inference(self._dtype) - m_normalized = self._helper.create_variable_for_type_inference( - self._dtype) - self._helper.append_op( - type="norm", - inputs={"X": matrix}, - outputs={"Out": m_normalized, - "Norm": m_norm}, - attrs={"axis": 1, - "epsilon": self._epsilon}) - - v_normalized = self._helper.create_variable_for_type_inference( - self._dtype) - tmp2 = self._helper.create_variable_for_type_inference(self._dtype) - self._helper.append_op( - type="reshape2", - 
inputs={"X": m_normalized}, - attrs={"shape": self._filter_param_v.shape}, - outputs={"Out": v_normalized, - "XShape": tmp2}) - - filter_param = self._helper.create_variable_for_type_inference( - self._dtype) - self._helper.append_op( - type="elementwise_mul", - inputs={"X": [v_normalized], - "Y": [self._filter_param_g]}, - outputs={"Out": [filter_param]}, - attrs={"axis": 0}, # CAUTION: hard-code - ) - - pre_bias = self._helper.create_variable_for_type_inference( - dtype=self._dtype) - - self._helper.append_op( - type=self._l_type, - inputs={"Input": input, - "Filter": filter_param}, - outputs={"Output": pre_bias}, - attrs={ - "strides": self._stride, - "paddings": self._padding, - "dilations": self._dilation, - "groups": self._groups if self._groups else 1, - "use_cudnn": self._use_cudnn, - "use_mkldnn": False, - }) - - if self._bias_param is not None: - pre_act = self._helper.create_variable_for_type_inference( - dtype=self._dtype) - self._helper.append_op( - type="elementwise_add", - inputs={"X": [pre_bias], - "Y": [self._bias_param]}, - outputs={"Out": [pre_act]}, - attrs={"axis": 1}) - else: - pre_act = pre_bias - - # Currently, we don't support inplace in dygraph mode - return self._helper.append_activation(pre_act, act=self._act) - - -class Conv2DTranspose(dg.Layer): - """ - **Convlution2D transpose layer** - The convolution2D transpose layer calculates the output based on the input, - filter, and dilations, strides, paddings. Input(Input) and output(Output) - are in NCHW format. Where N is batch size, C is the number of channels, - H is the height of the feature, and W is the width of the feature. - Parameters(dilations, strides, paddings) are two elements. These two elements - represent height and width, respectively. The details of convolution transpose - layer, please refer to the following explanation and references - `therein `_. - If bias attribution and activation type are provided, bias is added to - the output of the convolution, and the corresponding activation function - is applied to the final result. - For each input :math:`X`, the equation is: - .. math:: - Out = \sigma ((Vg) \\ast X + b) - Where: - * :math:`X`: Input value, a tensor with NCHW format. - * :math:`V`: Filter value, a tensor with MCHW format. - * :math:`g`: Filter value, a tensor with M format. - * :math:`\\ast`: Convolution operation. - * :math:`b`: Bias value, a 2-D tensor with shape [M, 1]. - * :math:`\\sigma`: Activation function. - * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different. - Example: - - Input: - Input shape: :math:`(N, C_{in}, H_{in}, W_{in})` - Filter shape: :math:`(C_{in}, C_{out}, H_f, W_f)` - - Output: - Output shape: :math:`(N, C_{out}, H_{out}, W_{out})` - Where - .. math:: - H^\prime_{out} &= (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\\\ - W^\prime_{out} &= (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1 \\\\ - H_{out} &\in [ H^\prime_{out}, H^\prime_{out} + strides[0] ) \\\\ - W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] ) - Args: - name_scope(str): The name of this class. - num_filters(int): The number of the filter. It is as same as the output - image channel. - output_size(int|tuple|None): The output image size. If output size is a - tuple, it must contain two integers, (image_H, image_W). None if use - filter_size, padding, and stride to calculate output_size. - if output_size and filter_size are specified at the same time, They - should follow the formula above. 
Default: None. - filter_size(int|tuple|None): The filter size. If filter_size is a tuple, - it must contain two integers, (filter_size_H, filter_size_W). - Otherwise, the filter will be a square. None if use output size to - calculate filter_size. Default: None. - padding(int|tuple): The padding size. If padding is a tuple, it must - contain two integers, (padding_H, padding_W). Otherwise, the - padding_H = padding_W = padding. Default: padding = 0. - stride(int|tuple): The stride size. If stride is a tuple, it must - contain two integers, (stride_H, stride_W). Otherwise, the - stride_H = stride_W = stride. Default: stride = 1. - dilation(int|tuple): The dilation size. If dilation is a tuple, it must - contain two integers, (dilation_H, dilation_W). Otherwise, the - dilation_H = dilation_W = dilation. Default: dilation = 1. - groups(int): The groups number of the Conv2d transpose layer. Inspired by - grouped convolution in Alex Krizhevsky's Deep CNN paper, in which - when group=2, the first half of the filters is only connected to the - first half of the input channels, while the second half of the - filters is only connected to the second half of the input channels. - Default: groups = 1. - param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights - of conv2d_transpose. If it is set to None or one attribute of ParamAttr, conv2d_transpose - will create ParamAttr as param_attr. If the Initializer of the param_attr - is not set, the parameter is initialized with Xavier. Default: None. - bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d_transpose. - If it is set to False, no bias will be added to the output units. - If it is set to None or one attribute of ParamAttr, conv2d_transpose - will create ParamAttr as bias_attr. If the Initializer of the bias_attr - is not set, the bias is initialized zero. Default: None. - use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn - library is installed. Default: True. - act (str): Activation type, if it is set to None, activation is not appended. - Default: None. - Returns: - Variable: The tensor variable storing the convolution transpose result. - Raises: - ValueError: If the shapes of input, filter_size, stride, padding and - groups mismatch. - Examples: - .. code-block:: python - import paddle.fluid as fluid - import numpy - with fluid.dygraph.guard(): - data = numpy.random.random((3, 32, 32)).astype('float32') - conv2DTranspose = fluid.dygraph.nn.Conv2DTranspose( - 'Conv2DTranspose', num_filters=2, filter_size=3) - ret = conv2DTranspose(fluid.dygraph.base.to_variable(data)) - """ - - def __init__(self, - name_scope, - num_filters, - output_size=None, - filter_size=None, - padding=0, - stride=1, - dilation=1, - groups=None, - param_attr=None, - bias_attr=None, - use_cudnn=True, - epsilon=1e-30, - act=None, - dtype="float32"): - super(Conv2DTranspose, self).__init__(name_scope, dtype) - assert (param_attr is not False - ), "param_attr should not be False in conv2d_transpose." 
- self._param_attr = param_attr - self._bias_attr = bias_attr - self._groups = groups - self._num_filters = num_filters - self._use_cudnn = use_cudnn - self._padding = padding - self._stride = stride - self._dilation = dilation - self._filter_size = filter_size - self._output_size = output_size - self._op_type = "conv2d_transpose" - self._epsilon = epsilon - - def _build_once(self, input): - input_channel = input.shape[1] - if (input_channel == self._groups and - self._num_filters == input_channel and not self._use_cudnn): - self._op_type = "depthwise_conv2d_transpose" - - if not isinstance(input, Variable): - raise TypeError("Input of conv2d_transpose must be Variable") - - self._padding = utils.convert_to_list(self._padding, 2, "padding") - self._stride = utils.convert_to_list(self._stride, 2, "stride") - self._dilation = utils.convert_to_list(self._dilation, 2, "dilation") - - if not isinstance(self._use_cudnn, bool): - raise ValueError("use_cudnn should be True or False") - - if self._filter_size is None: - if self._output_size is None: - raise ValueError( - "output_size must be set when filter_size is None") - if isinstance(self._output_size, int): - self._output_size = [self._output_size, self._output_size] - - h_in = input.shape[2] - w_in = input.shape[3] - - filter_size_h = (self._output_size[0] - - (h_in - 1) * self._stride[0] + 2 * - self._padding[0] - 1) // self._dilation[0] + 1 - filter_size_w = (self._output_size[1] - - (w_in - 1) * self._stride[1] + 2 * - self._padding[1] - 1) // self._dilation[1] + 1 - self._filter_size = [filter_size_h, filter_size_w] - else: - self._filter_size = utils.convert_to_list( - self._filter_size, 2, "conv2d_transpose.filter_size") - - if self._output_size is None: - self._output_size = [] - elif isinstance(self._output_size, list) or isinstance( - self._output_size, int): - self._output_size = utils.convert_to_list(self._output_size, 2, - "output_size") - else: - raise ValueError("output_size should be list or int") - self._padding = utils.convert_to_list(self._padding, 2, "padding") - self._groups = 1 if self._groups is None else self._groups - filter_shape = [ - input_channel, - self._num_filters // self._groups, - ] + self._filter_size - - # img filter v (direction) - self._img_filter_v = self.create_parameter( - dtype=input.dtype, shape=filter_shape, attr=self._param_attr) - - # img filter g (magnitude) - img_filter_magnitude = _norm( - self._img_filter_v.numpy(), dim=0) # CAUTION: hard-code - self._img_filter_g = self.create_parameter( - dtype=input.dtype, - shape=img_filter_magnitude.shape, - attr=fluid.ParamAttr( - initializer=NumpyArrayInitializer(img_filter_magnitude))) - - self._img_bias = self.create_parameter( - attr=self._bias_attr, - shape=[self._num_filters], - dtype=self._dtype, - is_bias=True) - - def forward(self, input): - matrix = self._helper.create_variable_for_type_inference(self._dtype) - tmp = self._helper.create_variable_for_type_inference(self._dtype) - new_shape = [ - self._img_filter_v.shape[0], - reduce(lambda x, y: x * y, self._img_filter_v.shape[1:], 1), - ] - - self._helper.append_op( - type="reshape2", - inputs={"X": self._img_filter_v}, - attrs={"shape": new_shape}, - outputs={"Out": matrix, - "XShape": tmp}) - - m_norm = self._helper.create_variable_for_type_inference(self._dtype) - m_normalized = self._helper.create_variable_for_type_inference( - self._dtype) - self._helper.append_op( - type="norm", - inputs={"X": matrix}, - outputs={"Out": m_normalized, - "Norm": m_norm}, - attrs={"axis": 1, - "epsilon": 
self._epsilon}) - - v_normalized = self._helper.create_variable_for_type_inference( - self._dtype) - tmp2 = self._helper.create_variable_for_type_inference(self._dtype) - self._helper.append_op( - type="reshape2", - inputs={"X": m_normalized}, - attrs={"shape": self._img_filter_v.shape}, - outputs={"Out": v_normalized, - "XShape": tmp2}) - - img_filter = self._helper.create_variable_for_type_inference( - self._dtype) - self._helper.append_op( - type="elementwise_mul", - inputs={"X": [v_normalized], - "Y": [self._img_filter_g]}, - outputs={"Out": [img_filter]}, - attrs={"axis": 0}, # CAUTION: hard-code - ) - - pre_bias = self._helper.create_variable_for_type_inference( - dtype=input.dtype) - self._helper.append_op( - type=self._op_type, - inputs={"Input": [input], - "Filter": [img_filter]}, - outputs={"Output": pre_bias}, - attrs={ - "output_size": self._output_size, - "strides": self._stride, - "paddings": self._padding, - "dilations": self._dilation, - "groups": self._groups, - "use_cudnn": self._use_cudnn, - }) - - if self._img_bias is not None: - pre_act = self._helper.create_variable_for_type_inference( - dtype=self._dtype) - self._helper.append_op( - type="elementwise_add", - inputs={"X": [pre_bias], - "Y": [self._img_bias]}, - outputs={"Out": [pre_act]}, - attrs={"axis": 1}) - else: - pre_act = pre_bias - - out = self._helper.append_activation(pre_act) - return out diff --git a/parakeet/modules/modules.py b/parakeet/modules/modules.py index 4fb92ed..7aef463 100644 --- a/parakeet/modules/modules.py +++ b/parakeet/modules/modules.py @@ -26,6 +26,7 @@ def FC(name_scope, in_features, size, num_flatten_dims=1, + relu=False, dropout=0.0, epsilon=1e-30, act=None, @@ -39,7 +40,11 @@ def FC(name_scope, # stds if isinstance(in_features, int): in_features = [in_features] + stds = [np.sqrt((1 - dropout) / in_feature) for in_feature in in_features] + if relu: + stds = [std * np.sqrt(2.0) for std in stds] + weight_inits = [ fluid.initializer.NormalInitializer(scale=std) for std in stds ] @@ -456,3 +461,152 @@ class PositionEmbedding(dg.Layer): return out else: raise Exception("Then you can just use position rate at init") + + +class Conv1D_GU(dg.Layer): + def __init__(self, + name_scope, + conditioner_dim, + in_channels, + num_filters, + filter_size, + dilation, + causal=False, + residual=True, + dtype="float32"): + super(Conv1D_GU, self).__init__(name_scope, dtype=dtype) + + self.conditioner_dim = conditioner_dim + self.in_channels = in_channels + self.num_filters = num_filters + self.filter_size = filter_size + self.dilation = dilation + self.causal = causal + self.residual = residual + + if residual: + assert ( + in_channels == num_filters + ), "this block uses residual connection"\ + "the input_channels should equals num_filters" + + self.conv = Conv1D( + self.full_name(), + in_channels, + 2 * num_filters, + filter_size, + dilation, + causal=causal, + dtype=dtype) + + self.fc = Conv1D( + self.full_name(), + conditioner_dim, + 2 * num_filters, + filter_size=1, + dilation=1, + causal=False, + dtype=dtype) + + def forward(self, x, skip=None, conditioner=None): + """ + Args: + x (Variable): Shape(B, C_in, 1, T), the input of Conv1D_GU + layer, where B means batch_size, C_in means the input channels + T means input time steps. + skip (Variable): Shape(B, C_in, 1, T), skip connection. + conditioner (Variable): Shape(B, C_con, 1, T), expanded mel + conditioner, where C_con is conditioner hidden dim which + equals the num of mel bands. 
Note that when using the residual
+                connection, Conv1D_GU does not change the number of
+                channels, so the output channels equal the input channels.
+        Returns:
+            x (Variable): Shape(B, C_out, 1, T), the output of Conv1D_GU, where
+                C_out means the output channels of Conv1D_GU.
+            skip (Variable): Shape(B, C_out, 1, T), skip connection.
+        """
+        residual = x
+        x = self.conv(x)
+
+        if conditioner is not None:
+            cond_bias = self.fc(conditioner)
+            x += cond_bias
+
+        content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
+
+        # Gated Unit.
+        x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
+                                         fluid.layers.tanh(content))
+
+        if skip is None:
+            skip = x
+        else:
+            skip = fluid.layers.scale(skip + x, np.sqrt(0.5))
+
+        if self.residual:
+            x = fluid.layers.scale(residual + x, np.sqrt(0.5))
+
+        return x, skip
+
+    def add_input(self, x, skip=None, conditioner=None):
+        """
+        Inputs:
+            x: shape(B, num_filters, 1, time_steps)
+            skip: shape(B, num_filters, 1, time_steps), skip connection
+            conditioner: shape(B, conditioner_dim, 1, time_steps)
+        Outputs:
+            x: shape(B, num_filters, 1, time_steps), where time_steps = 1
+            skip: skip connection, same shape as x
+        """
+        residual = x
+
+        # Add one step of input and produce one step of output.
+        x = self.conv.add_input(x)
+
+        if conditioner is not None:
+            cond_bias = self.fc(conditioner)
+            x += cond_bias
+
+        content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
+
+        # Gated Unit.
+        x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
+                                         fluid.layers.tanh(content))
+
+        if skip is None:
+            skip = x
+        else:
+            skip = fluid.layers.scale(skip + x, np.sqrt(0.5))
+
+        if self.residual:
+            x = fluid.layers.scale(residual + x, np.sqrt(0.5))
+
+        return x, skip
+
+
+def Conv2DTranspose(name_scope,
+                    num_filters,
+                    filter_size,
+                    padding=0,
+                    stride=1,
+                    dilation=1,
+                    use_cudnn=True,
+                    act=None,
+                    dtype="float32"):
+    val = 1.0 / (filter_size[0] * filter_size[1])
+    weight_init = fluid.initializer.ConstantInitializer(val)
+    weight_attr = fluid.ParamAttr(initializer=weight_init)
+
+    layer = weight_norm.Conv2DTranspose(
+        name_scope,
+        num_filters,
+        filter_size=filter_size,
+        padding=padding,
+        stride=stride,
+        dilation=dilation,
+        param_attr=weight_attr,
+        use_cudnn=use_cudnn,
+        act=act,
+        dtype=dtype)
+
+    return layer

From 862e23164d59e2ae09814e7c0f092ec2b9c9e0ca Mon Sep 17 00:00:00 2001
From: zhaokexin01
Date: Wed, 4 Dec 2019 06:42:27 +0800
Subject: [PATCH 3/3] Update README.md

---
 parakeet/models/wavenet/README.md | 98 ++++++++++++++++++++++++++++++-
 1 file changed, 97 insertions(+), 1 deletion(-)

diff --git a/parakeet/models/wavenet/README.md b/parakeet/models/wavenet/README.md
index 412a3c8..18efd0b 100644
--- a/parakeet/models/wavenet/README.md
+++ b/parakeet/models/wavenet/README.md
@@ -1 +1,97 @@
-# WaveNet-Paddle
\ No newline at end of file
+# WaveNet with Paddle Fluid
+
+A Paddle Fluid implementation of WaveNet, a deep generative model of raw audio waveforms.
+The WaveNet model was originally proposed in [WaveNet: A Generative Model for Raw Audio](https://arxiv.org/abs/1609.03499).
+Our implementation is based on the WaveNet architecture described in [ClariNet: Parallel Wave Generation in End-to-End Text-to-Speech](https://arxiv.org/abs/1807.07281) and can provide various output distributions, including single Gaussian, mixture of Gaussians, and softmax with linearly quantized channels.
+
+We implement the WaveNet model in Paddle Fluid with dynamic graph mode, which is convenient for building flexible network architectures.
+
+## Project Structure
+```text
+├── configs                # yaml configuration files of preset model hyperparameters
+├── data.py                # dataset and dataloader settings for LJSpeech
+├── slurm.py               # optional slurm helper functions if you use slurm to train the model
+├── synthesis.py           # script for speech synthesis
+├── train.py               # script for model training
+├── utils.py               # helper functions, e.g., for model checkpointing
+├── wavenet.py             # WaveNet model high-level APIs
+└── wavenet_modules.py     # WaveNet model implementation
+```
+
+## Usage
+
+There are many hyperparameters to be tuned depending on the model and dataset you are working with. Hyperparameters known to work well for the LJSpeech dataset are provided as yaml files in the `./configs/` folder. Specifically, we provide `wavenet_ljspeech_single_gaussian.yaml`, `wavenet_ljspeech_mix_gaussian.yaml`, and `wavenet_ljspeech_softmax.yaml` config files for WaveNet with single Gaussian, 10-component mixture of Gaussians, and softmax (with 2048 linearly quantized channels) output distributions, respectively.
+
+Note that `train.py` and `synthesis.py` both accept a `--config` parameter. To ensure consistency, you should use the same config yaml file for both training and synthesis. You can also override these preset hyperparameters from the command line by appending parameters after `--config`. For example, `--config=${yaml} --batch_size=8 --layers=20` overrides the corresponding hyperparameters in the `${yaml}` config file. For more details about these hyperparameters, check `utils.add_config_options_to_parser`.
+
+You also need to specify some additional parameters for `train.py` and `synthesis.py`; the details can be found in `train.add_options_to_parser` and `synthesis.add_options_to_parser`, respectively.
+
+### Dataset
+
+Download and extract [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).
+
+```bash
+wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
+tar xjvf LJSpeech-1.1.tar.bz2
+```
+
+In this example, assume that the extracted LJSpeech dataset is located at `./data/LJSpeech-1.1`.
+
+### Train on single GPU
+
+```bash
+export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
+export CUDA_VISIBLE_DEVICES=0
+python -u train.py --config=${yaml} \
+    --root=./data/LJSpeech-1.1 \
+    --name=${ModelName} --batch_size=4 \
+    --parallel=false --use_gpu=true
+```
+
+#### Save and Load checkpoints
+
+By default, the model saves its parameters as checkpoints in `./runs/wavenet/${ModelName}/checkpoint/` every 10000 iterations.
+Saved checkpoints are named `step-${iteration_number}.pdparams` for model parameters and `step-${iteration_number}.pdopt` for optimizer state.
+
+There are three ways to load a checkpoint and resume training (suppose you want to load a 500000-iteration checkpoint; a full resume command is sketched after this list):
+1. Use `--checkpoint=./runs/wavenet/${ModelName}/checkpoint/step-500000` to provide an explicit path to load. Note that you only need to provide the base name of the parameter file, which is `step-500000`; no `.pdparams` or `.pdopt` extension is needed.
+2. Use `--iteration=500000`.
+3. If you specify neither `--checkpoint` nor `--iteration`, the model automatically loads the latest checkpoint in `./runs/wavenet/${ModelName}/checkpoint`.
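+
+For example, resuming from that checkpoint on a single GPU might look like the following (a sketch that reuses the single-GPU training command above; `${yaml}` and `${ModelName}` are placeholders as before):
+
+```bash
+export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
+export CUDA_VISIBLE_DEVICES=0
+# Resume from the 500000-iteration checkpoint (base name only, no extension).
+python -u train.py --config=${yaml} \
+    --root=./data/LJSpeech-1.1 \
+    --name=${ModelName} --batch_size=4 \
+    --parallel=false --use_gpu=true \
+    --checkpoint=./runs/wavenet/${ModelName}/checkpoint/step-500000
+```
+
+Passing `--iteration=500000` instead of `--checkpoint` should pick up the same checkpoint.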
+
+### Train on multiple GPUs
+
+```bash
+export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
+export CUDA_VISIBLE_DEVICES=0,1,2,3
+python -u -m paddle.distributed.launch train.py \
+    --config=${yaml} \
+    --root=./data/LJSpeech-1.1 \
+    --name=${ModelName} --parallel=true --use_gpu=true
+```
+
+Use `export CUDA_VISIBLE_DEVICES=0,1,2,3` to make the GPUs you want to use visible. The `paddle.distributed.launch` module then uses these visible GPUs for data-parallel training in multiprocessing mode.
+
+### Monitor with TensorBoard
+
+By default, the logs are saved in `./runs/wavenet/${ModelName}/logs/`. You can monitor them with TensorBoard.
+
+```bash
+tensorboard --logdir=${log_dir} --port=8888
+```
+
+### Synthesize from a checkpoint
+
+Check the [Save and Load checkpoints](#save-and-load-checkpoints) section on how to load a specific checkpoint.
+The following example automatically loads the latest checkpoint:
+
+```bash
+export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
+export CUDA_VISIBLE_DEVICES=0
+python -u synthesis.py --config=${yaml} \
+    --root=./data/LJSpeech-1.1 \
+    --name=${ModelName} --use_gpu=true \
+    --output=./syn_audios \
+    --sample=${SAMPLE}
+```
+
+In this example, `--output` specifies where to save the synthesized audios and `--sample` specifies which sample in the valid dataset (a split of the whole LJSpeech dataset that by default contains the first 16 audio samples) to synthesize; the mel-spectrogram is computed from the corresponding ground-truth audio. For example, `--sample=0` synthesizes the first audio in the valid dataset.
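+
+To synthesize from a specific checkpoint rather than the latest one, the same loading flags as in training should apply (an illustrative sketch; the 500000-iteration checkpoint is just an example):
+
+```bash
+# Synthesize the first sample in the valid set from a specific checkpoint.
+python -u synthesis.py --config=${yaml} \
+    --root=./data/LJSpeech-1.1 \
+    --name=${ModelName} --use_gpu=true \
+    --output=./syn_audios \
+    --sample=0 \
+    --checkpoint=./runs/wavenet/${ModelName}/checkpoint/step-500000
+```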