diff --git a/.gitignore b/.gitignore index 9e0ff35..13dd63d 100644 --- a/.gitignore +++ b/.gitignore @@ -129,4 +129,10 @@ venv.bak/ dmypy.json # Pyre type checker -.pyre/ \ No newline at end of file +.pyre/ + +# Shell, vim, and output folder +*.sh +*.swp +runs +syn_audios diff --git a/parakeet/data/datacargo.py b/parakeet/data/datacargo.py index da6bc9a..1d7d8d5 100644 --- a/parakeet/data/datacargo.py +++ b/parakeet/data/datacargo.py @@ -31,6 +31,9 @@ class DataCargo(object): def __iter__(self): return DataIterator(self) + + def __call__(self): + return DataIterator(self) @property def _auto_collation(self): diff --git a/parakeet/models/wavenet/README.md b/parakeet/models/wavenet/README.md new file mode 100644 index 0000000..412a3c8 --- /dev/null +++ b/parakeet/models/wavenet/README.md @@ -0,0 +1 @@ +# WaveNet-Paddle \ No newline at end of file diff --git a/parakeet/models/wavenet/configs/wavenet_ljspeech_single_gaussian.yaml b/parakeet/models/wavenet/configs/wavenet_ljspeech_single_gaussian.yaml new file mode 100644 index 0000000..fc7222b --- /dev/null +++ b/parakeet/models/wavenet/configs/wavenet_ljspeech_single_gaussian.yaml @@ -0,0 +1,32 @@ +valid_size: 16 +train_clip_second: 0.5 +sample_rate: 22050 +fft_window_shift: 256 +fft_window_size: 1024 +fft_size: 2048 +mel_bands: 80 + +seed: 1 +batch_size: 8 +test_every: 2000 +save_every: 10000 +max_iterations: 2000000 + +layers: 30 +kernel_width: 2 +dilation_block: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] +residual_channels: 128 +skip_channels: 128 +loss_type: mix-gaussian-pdf +num_mixtures: 1 +log_scale_min: -9.0 + +conditioner: + filter_sizes: [[32, 3], [32, 3]] + upsample_factors: [16, 16] + +learning_rate: 0.001 +gradient_max_norm: 100.0 +anneal: + every: 200000 + rate: 0.5 diff --git a/parakeet/models/wavenet/data.py b/parakeet/models/wavenet/data.py new file mode 100644 index 0000000..61cc4ab --- /dev/null +++ b/parakeet/models/wavenet/data.py @@ -0,0 +1,191 @@ +import math +import os +import random + +import librosa +import numpy as np +from paddle import fluid + +import utils +from parakeet.datasets import ljspeech +from parakeet.data import dataset +from parakeet.data.sampler import Sampler, BatchSampler, SequentialSampler +from parakeet.data.datacargo import DataCargo + + +class Dataset(ljspeech.LJSpeech): + def __init__(self, config): + super(Dataset, self).__init__(config.root) + self.config = config + self.fft_window_shift = config.fft_window_shift + # Calculate context frames. + frames_per_second = config.sample_rate // self.fft_window_shift + train_clip_frames = int(math.ceil( + config.train_clip_second * frames_per_second)) + context_frames = config.context_size // self.fft_window_shift + self.num_frames = train_clip_frames + context_frames + + def _get_example(self, metadatum): + fname, _, _ = metadatum + wav_path = self.root.joinpath("wavs", fname + ".wav") + + config = self.config + sr = config.sample_rate + fft_window_shift = config.fft_window_shift + fft_window_size = config.fft_window_size + fft_size = config.fft_size + + audio, loaded_sr = librosa.load(wav_path, sr=None) + assert loaded_sr == sr + + # Pad audio to the right size. 
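+        # Worked example under this config (fft_window_shift=256,
+        # fft_size=2048, so fft_padding = (2048 - 256) // 2 = 896): a
+        # 10000-sample clip gives frames = ceil(10000 / 256) = 40 and
+        # desired_length = 40 * 256 + 2 * 896 = 12032, i.e. 1016 samples of
+        # reflect padding per side. An STFT with center=False then yields
+        # exactly 1 + (12032 - 2048) // 256 = 40 frames, one per hop.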
+ frames = math.ceil(float(audio.size) / fft_window_shift) + fft_padding = (fft_size - fft_window_shift) // 2 + desired_length = frames * fft_window_shift + fft_padding * 2 + pad_amount = (desired_length - audio.size) // 2 + + if audio.size % 2 == 0: + audio = np.pad(audio, (pad_amount, pad_amount), mode='reflect') + else: + audio = np.pad(audio, (pad_amount, pad_amount + 1), mode='reflect') + + # Normalize audio. + audio = audio / np.abs(audio).max() * 0.999 + + # Compute mel-spectrogram. + # Turn center to False to prevent internal padding. + spectrogram = librosa.core.stft( + audio, hop_length=fft_window_shift, + win_length=fft_window_size, n_fft=fft_size, center=False) + spectrogram_magnitude = np.abs(spectrogram) + + # Compute mel-spectrograms. + mel_filter_bank = librosa.filters.mel(sr=sr, n_fft=fft_size, + n_mels=config.mel_bands) + mel_spectrogram = np.dot(mel_filter_bank, spectrogram_magnitude) + mel_spectrogram = mel_spectrogram.T + + # Rescale mel_spectrogram. + min_level, ref_level = 1e-5, 20 + mel_spectrogram = 20 * np.log10(np.maximum(min_level, mel_spectrogram)) + mel_spectrogram = mel_spectrogram - ref_level + mel_spectrogram = np.clip((mel_spectrogram + 100) / 100, 0, 1) + + # Extract the center of audio that corresponds to mel spectrograms. + audio = audio[fft_padding : -fft_padding] + assert mel_spectrogram.shape[0] * fft_window_shift == audio.size + + return audio, mel_spectrogram + + +class Subset(dataset.Dataset): + def __init__(self, dataset, indices, valid): + self.dataset = dataset + self.indices = indices + self.valid = valid + + def __getitem__(self, idx): + fft_window_shift = self.dataset.fft_window_shift + num_frames = self.dataset.num_frames + audio, mel = self.dataset[self.indices[idx]] + + if self.valid: + audio_start = 0 + else: + # Randomly crop context + train_clip_second of audio. + audio_frames = int(audio.size) // fft_window_shift + max_start_frame = audio_frames - num_frames + assert max_start_frame >= 0, "audio {} is too short".format(idx) + + frame_start = random.randint(0, max_start_frame) + frame_end = frame_start + num_frames + + audio_start = frame_start * fft_window_shift + audio_end = frame_end * fft_window_shift + + audio = audio[audio_start : audio_end] + + return audio, mel, audio_start + + def _batch_examples(self, batch): + audios = [sample[0] for sample in batch] + audio_starts = [sample[2] for sample in batch] + + # mels shape [num_frames, mel_bands] + max_frames = max(sample[1].shape[0] for sample in batch) + mels = [utils.pad_to_size(sample[1], max_frames) for sample in batch] + + audios = np.array(audios, dtype=np.float32) + mels = np.array(mels, dtype=np.float32) + audio_starts = np.array(audio_starts, dtype=np.int32) + + return audios, mels, audio_starts + + def __len__(self): + return len(self.indices) + + +class DistributedSampler(Sampler): + def __init__(self, dataset_size, num_trainers, rank, shuffle=True): + self.dataset_size = dataset_size + self.num_trainers = num_trainers + self.rank = rank + self.num_samples = int(math.ceil(dataset_size / num_trainers)) + self.total_size = self.num_samples * num_trainers + assert self.total_size >= self.dataset_size + self.shuffle = shuffle + + def __iter__(self): + indices = list(range(self.dataset_size)) + if self.shuffle: + random.shuffle(indices) + + # Append extra samples to make it evenly distributed on all trainers. + indices += indices[:(self.total_size - self.dataset_size)] + assert len(indices) == self.total_size + + # Subset samples for each trainer. 
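+        # Each trainer takes every num_trainers-th index, offset by its rank.
+        # E.g. 8 samples on 2 trainers (before shuffling): rank 0 takes
+        # [0, 2, 4, 6] and rank 1 takes [1, 3, 5, 7], so every trainer sees
+        # a disjoint shard of num_samples items.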
+ indices = indices[self.rank:self.total_size:self.num_trainers] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + +class LJSpeech: + def __init__(self, config, nranks, rank): + place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace() + + # Whole LJSpeech dataset. + ds = Dataset(config) + + # Split into train and valid dataset. + indices = list(range(len(ds))) + train_indices = indices[config.valid_size:] + valid_indices = indices[:config.valid_size] + random.shuffle(train_indices) + + # Train dataset. + trainset = Subset(ds, train_indices, valid=False) + sampler = DistributedSampler(len(trainset), nranks, rank) + total_bs = config.batch_size + assert total_bs % nranks == 0 + train_sampler = BatchSampler(sampler, total_bs // nranks, + drop_last=True) + trainloader = DataCargo(trainset, batch_sampler=train_sampler) + + trainreader = fluid.io.PyReader(capacity=50, return_list=True) + trainreader.decorate_batch_generator(trainloader, place) + self.trainloader = (data for _ in iter(int, 1) + for data in trainreader()) + + # Valid dataset. + validset = Subset(ds, valid_indices, valid=True) + # Currently only support batch_size = 1 for valid loader. + validloader = DataCargo(validset, batch_size=1, shuffle=False) + + validreader = fluid.io.PyReader(capacity=20, return_list=True) + validreader.decorate_batch_generator(validloader, place) + self.validloader = validreader diff --git a/parakeet/models/wavenet/ops.py b/parakeet/models/wavenet/ops.py new file mode 100644 index 0000000..6eda2a9 --- /dev/null +++ b/parakeet/models/wavenet/ops.py @@ -0,0 +1,249 @@ +import paddle +from paddle import fluid +import paddle.fluid.dygraph as dg +import numpy as np + +import weight_norm + + +def Embedding(name_scope, + num_embeddings, + embed_dim, + padding_idx=None, + std=0.1, + dtype="float32"): + # param attrs + weight_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal( + scale=std)) + layer = dg.Embedding( + name_scope, (num_embeddings, embed_dim), + padding_idx=padding_idx, + param_attr=weight_attr, + dtype=dtype) + return layer + + +def FC(name_scope, + in_features, + size, + num_flatten_dims=1, + relu=False, + dropout=0.0, + act=None, + dtype="float32"): + """ + A special Linear Layer, when it is used with dropout, the weight is + initialized as normal(0, std=np.sqrt((1-dropout) / in_features)) + """ + + # stds + if isinstance(in_features, int): + in_features = [in_features] + + stds = [np.sqrt((1.0 - dropout) / in_feature) for in_feature in in_features] + if relu: + stds = [std * np.sqrt(2.0) for std in stds] + + weight_inits = [ + fluid.initializer.NormalInitializer(scale=std) for std in stds + ] + bias_init = fluid.initializer.ConstantInitializer(0.0) + + # param attrs + weight_attrs = [fluid.ParamAttr(initializer=init) for init in weight_inits] + bias_attr = fluid.ParamAttr(initializer=bias_init) + + layer = weight_norm.FC(name_scope, + size, + num_flatten_dims=num_flatten_dims, + param_attr=weight_attrs, + bias_attr=bias_attr, + act=act, + dtype=dtype) + return layer + + +def Conv1D(name_scope, + in_channels, + num_filters, + filter_size=2, + dilation=1, + groups=None, + causal=False, + std_mul=1.0, + dropout=0.0, + use_cudnn=True, + act=None, + dtype="float32"): + """ + A special Conv1D Layer, when it is used with dropout, the weight is + initialized as + normal(0, std=np.sqrt(std_mul * (1-dropout) / (filter_size * in_channels))) + """ + # std + std = np.sqrt((std_mul * (1.0 - dropout)) / (filter_size * 
in_channels)) + weight_init = fluid.initializer.NormalInitializer(loc=0.0, scale=std) + bias_init = fluid.initializer.ConstantInitializer(0.0) + + # param attrs + weight_attr = fluid.ParamAttr(initializer=weight_init) + bias_attr = fluid.ParamAttr(initializer=bias_init) + + layer = weight_norm.Conv1D( + name_scope, + num_filters, + filter_size, + dilation, + groups=groups, + causal=causal, + param_attr=weight_attr, + bias_attr=bias_attr, + use_cudnn=use_cudnn, + act=act, + dtype=dtype) + return layer + + +class Conv1D_GU(dg.Layer): + def __init__(self, + name_scope, + conditioner_dim, + in_channels, + num_filters, + filter_size, + dilation, + causal=False, + residual=True, + dtype="float32"): + super(Conv1D_GU, self).__init__(name_scope, dtype=dtype) + + self.conditioner_dim = conditioner_dim + self.in_channels = in_channels + self.num_filters = num_filters + self.filter_size = filter_size + self.dilation = dilation + self.causal = causal + self.residual = residual + + if residual: + assert ( + in_channels == num_filters + ), "this block uses residual connection"\ + "the input_channels should equals num_filters" + + self.conv = Conv1D( + self.full_name(), + in_channels, + 2 * num_filters, + filter_size, + dilation, + causal=causal, + dtype=dtype) + + self.fc = Conv1D( + self.full_name(), + conditioner_dim, + 2 * num_filters, + filter_size=1, + dilation=1, + causal=False, + dtype=dtype) + + def forward(self, x, skip=None, conditioner=None): + """ + Args: + x (Variable): Shape(B, C_in, 1, T), the input of Conv1DGLU + layer, where B means batch_size, C_in means the input channels + T means input time steps. + conditioner (Variable): Shape(B, C_con, 1, T), expanded mel + conditioner, where C_con is conditioner hidden dim which + equals the num of mel bands. Note that when using residual + connection, the Conv1DGLU does not change the number of + channels, so out channels equals input channels. + Returns: + x (Variable): Shape(B, C_out, 1, T), the output of Conv1DGLU, where + C_out means the output channels of Conv1DGLU. + """ + residual = x + x = self.conv(x) + + if conditioner is not None: + cond_bias = self.fc(conditioner) + x += cond_bias + + content, gate = fluid.layers.split(x, num_or_sections=2, dim=1) + + # Gated Unit. + x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), + fluid.layers.tanh(content)) + + if skip is None: + skip = x + else: + skip = fluid.layers.scale(skip + x, np.sqrt(0.5)) + + if self.residual: + x = fluid.layers.scale(residual + x, np.sqrt(0.5)) + + return x, skip + + def add_input(self, x, skip=None, conditioner=None): + """ + Inputs: + x: shape(B, num_filters, 1, time_steps) + conditioner: shape(B, conditioner_dim, 1, time_steps) + Outputs: + out: shape(B, num_filters, 1, time_steps), where time_steps = 1 + """ + residual = x + + # add step input and produce step output + x = self.conv.add_input(x) + + if conditioner is not None: + cond_bias = self.fc(conditioner) + x += cond_bias + + content, gate = fluid.layers.split(x, num_or_sections=2, dim=1) + + # Gated Unit. 
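+        # Same gated activation as in forward():
+        #     out = tanh(content) * sigmoid(gate)
+        # applied here to a single incremental time step.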
+ x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate), + fluid.layers.tanh(content)) + + if skip is None: + skip = x + else: + skip = fluid.layers.scale(skip + x, np.sqrt(0.5)) + + if self.residual: + x = fluid.layers.scale(residual + x, np.sqrt(0.5)) + + return x, skip + + +def Conv2DTranspose(name_scope, + num_filters, + filter_size, + padding=0, + stride=1, + dilation=1, + use_cudnn=True, + act=None, + dtype="float32"): + val = 1.0 / (filter_size[0] * filter_size[1]) + weight_init = fluid.initializer.ConstantInitializer(val) + weight_attr = fluid.ParamAttr(initializer=weight_init) + + layer = weight_norm.Conv2DTranspose( + name_scope, + num_filters, + filter_size=filter_size, + padding=padding, + stride=stride, + dilation=dilation, + param_attr=weight_attr, + use_cudnn=use_cudnn, + act=act, + dtype=dtype) + + return layer diff --git a/parakeet/models/wavenet/slurm.py b/parakeet/models/wavenet/slurm.py new file mode 100644 index 0000000..47af2dc --- /dev/null +++ b/parakeet/models/wavenet/slurm.py @@ -0,0 +1,112 @@ +""" +Utility module for restarting training when using SLURM. +""" +import subprocess +import os +import sys +import shlex +import re +import time + + +def job_info(): + """Get information about the current job using `scontrol show job`. + Returns a dict mapping parameter names (e.g. "UserId", "RunTime", etc) to + their values, both as strings. + """ + job_id = int(os.environ["SLURM_JOB_ID"]) + + command = ["scontrol", "show", "job", str(job_id)] + output = subprocess.check_output(command).decode("utf-8") + + # Use a regex to extract the parameter names and values + pattern = "([A-Za-z/]*)=([^ \t\n]*)" + return dict(re.findall(pattern, output)) + + +def parse_hours(text): + """Parse a time format HH or DD-HH into a number of hours.""" + hour_chunks = text.split("-") + if len(hour_chunks) == 1: + return int(hour_chunks[0]) + elif len(hour_chunks) == 2: + return 24 * int(hour_chunks[0]) + int(hour_chunks[1]) + else: + raise ValueError("Unexpected hour format (expected HH or " + "DD-HH, but got {}).".format(text)) + + +def parse_time(text): + """Convert slurm time to an integer. + Expects time to be of the form: + "hours:minutes:seconds" or "day-hours:minutes:seconds". + """ + hours, minutes, seconds = text.split(":") + try: + return parse_hours(hours) * 3600 + int(minutes) * 60 + int(seconds) + except ValueError as e: + raise ValueError("Error parsing time {}. Got error {}.".format( + text, str(e))) + + +def restart_command(): + """Using the environment and SLURM command, create a command that, when, + run, will enqueue a repeat of the current job using `sbatch`. + Return the command as a list of strings, suitable for passing to + `subprocess.check_call` or similar functions. + Returns: + resume_command: list, command to run to restart job. + end_time: int or None; the time the job will end or None + if the job has unlimited runtime. + """ + # Make sure `RunTime` could be parsed correctly. 
+ while job_info()["RunTime"] == "INVALID": + time.sleep(1) + + # Get all the necessary information by querying SLURM with this job id + info = job_info() + + try: + num_cpus = int(info["CPUs/Task"]) + except KeyError: + num_cpus = int(os.environ["SLURM_CPUS_PER_TASK"]) + + num_tasks = int(os.environ["SLURM_NTASKS"]) + nodes = info["NumNodes"] + gres, partition = info.get("Gres"), info.get("Partition") + stderr, stdout = info.get("StdErr"), info.get("StdOut") + job_name = info.get("JobName") + command = ["sbatch", "--job-name={}".format(job_name), + "--ntasks={}".format(num_tasks)] + + if partition: + command.extend(["--partition", partition]) + + if gres and gres != "(null)": + command.extend(["--gres", gres]) + num_gpu = int(gres.split(':')[-1]) + print("number of gpu assigned by slurm is {}".format(num_gpu)) + + if stderr: + command.extend(["--error", stderr]) + + if stdout: + command.extend(["--output", stdout]) + + python = subprocess.check_output( + ["/usr/bin/which", "python3"]).decode("utf-8").strip() + dist_setting = ['-m', 'paddle.distributed.launch'] + wrap_cmd = ["srun", python, '-u'] + dist_setting + sys.argv + + command.append( + "--wrap={}".format(" ".join(shlex.quote(arg) for arg in wrap_cmd))) + time_limit_string = info["TimeLimit"] + if time_limit_string.lower() == "unlimited": + print("UNLIMITED detected: restart OFF, infinite learning ON.", + flush=True) + return command, None + time_limit = parse_time(time_limit_string) + runtime = parse_time(info["RunTime"]) + end_time = time.time() + time_limit - runtime + + return command, end_time diff --git a/parakeet/models/wavenet/synthesis.py b/parakeet/models/wavenet/synthesis.py new file mode 100644 index 0000000..d87a188 --- /dev/null +++ b/parakeet/models/wavenet/synthesis.py @@ -0,0 +1,85 @@ +import os +import random +from pprint import pprint + +import jsonargparse +import numpy as np +import paddle.fluid.dygraph as dg +from paddle import fluid + +import utils +from wavenet import WaveNet + + +def add_options_to_parser(parser): + parser.add_argument('--model', type=str, default='wavenet', + help="general name of the model") + parser.add_argument('--name', type=str, + help="specific name of the training model") + parser.add_argument('--root', type=str, + help="root path of the LJSpeech dataset") + + parser.add_argument('--use_gpu', type=bool, default=True, + help="option to use gpu training") + + parser.add_argument('--iteration', type=int, default=None, + help=("which iteration of checkpoint to load, " + "default to load the latest checkpoint")) + parser.add_argument('--checkpoint', type=str, default=None, + help="path of the checkpoint to load") + + parser.add_argument('--output', type=str, default="./syn_audios", + help="path to write synthesized audio files") + parser.add_argument('--sample', type=int, + help="which of the valid samples to synthesize audio") + + +def synthesize(config): + pprint(jsonargparse.namespace_to_dict(config)) + + # Get checkpoint directory path. + run_dir = os.path.join("runs", config.model, config.name) + checkpoint_dir = os.path.join(run_dir, "checkpoint") + + # Configurate device. + place = fluid.CUDAPlace(0) if config.use_gpu else fluid.CPUPlace() + + with dg.guard(place): + # Fix random seed. + seed = config.seed + random.seed(seed) + np.random.seed(seed) + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + print("Random Seed: ", seed) + + # Build model. 
+ model = WaveNet(config, checkpoint_dir) + model.build(training=False) + + # Obtain the current iteration. + if config.checkpoint is None: + if config.iteration is None: + iteration = utils.load_latest_checkpoint(checkpoint_dir) + else: + iteration = config.iteration + else: + iteration = int(config.checkpoint.split('/')[-1].split('-')[-1]) + + # Run model inference. + model.infer(iteration) + + +if __name__ == "__main__": + # Create parser. + parser = jsonargparse.ArgumentParser( + description="Synthesize audio using WaveNet model", + formatter_class='default_argparse') + add_options_to_parser(parser) + utils.add_config_options_to_parser(parser) + + # Parse argument from both command line and yaml config file. + # For conflicting updates to the same field, + # the preceding update will be overwritten by the following one. + config = parser.parse_args() + synthesize(config) diff --git a/parakeet/models/wavenet/train.py b/parakeet/models/wavenet/train.py new file mode 100644 index 0000000..1a17bbd --- /dev/null +++ b/parakeet/models/wavenet/train.py @@ -0,0 +1,139 @@ +import os +import random +import subprocess +import time +from pprint import pprint + +import jsonargparse +import numpy as np +import paddle.fluid.dygraph as dg +from paddle import fluid +from tensorboardX import SummaryWriter + +import slurm +import utils +from wavenet import WaveNet + +MAXIMUM_SAVE_TIME = 10 * 60 + + +def add_options_to_parser(parser): + parser.add_argument('--model', type=str, default='wavenet', + help="general name of the model") + parser.add_argument('--name', type=str, + help="specific name of the training model") + parser.add_argument('--root', type=str, + help="root path of the LJSpeech dataset") + + parser.add_argument('--parallel', type=bool, default=True, + help="option to use data parallel training") + parser.add_argument('--use_gpu', type=bool, default=True, + help="option to use gpu training") + + parser.add_argument('--iteration', type=int, default=None, + help=("which iteration of checkpoint to load, " + "default to load the latest checkpoint")) + parser.add_argument('--checkpoint', type=str, default=None, + help="path of the checkpoint to load") + parser.add_argument('--slurm', type=bool, default=False, + help="whether you are using slurm to submit training jobs") + + +def train(config): + use_gpu = config.use_gpu + parallel = config.parallel if use_gpu else False + + # Get the rank of the current training process. + rank = dg.parallel.Env().local_rank if parallel else 0 + nranks = dg.parallel.Env().nranks if parallel else 1 + + if rank == 0: + # Print the whole config setting. + pprint(jsonargparse.namespace_to_dict(config)) + + # Make checkpoint directory. + run_dir = os.path.join("runs", config.model, config.name) + checkpoint_dir = os.path.join(run_dir, "checkpoint") + os.makedirs(checkpoint_dir, exist_ok=True) + + # Create tensorboard logger. + tb = SummaryWriter(os.path.join(run_dir, "logs")) \ + if rank == 0 else None + + # Configurate device + place = fluid.CUDAPlace(rank) if use_gpu else fluid.CPUPlace() + + with dg.guard(place): + # Fix random seed. + seed = config.seed + random.seed(seed) + np.random.seed(seed) + fluid.default_startup_program().random_seed = seed + fluid.default_main_program().random_seed = seed + print("Random Seed: ", seed) + + # Build model. + model = WaveNet(config, checkpoint_dir, parallel, rank, nranks, tb) + model.build() + + # Obtain the current iteration. 
+ if config.checkpoint is None: + if config.iteration is None: + iteration = utils.load_latest_checkpoint(checkpoint_dir, rank) + else: + iteration = config.iteration + else: + iteration = int(config.checkpoint.split('/')[-1].split('-')[-1]) + + # Get restart command if using slurm. + if config.slurm: + resume_command, death_time = slurm.restart_command() + if rank == 0: + print("Restart command:", " ".join(resume_command)) + done = False + + while iteration < config.max_iterations: + # Run one single training step. + model.train_step(iteration) + + iteration += 1 + + if iteration % config.test_every == 0: + # Run validation step. + model.valid_step(iteration) + + # Check whether reaching the time limit. + if config.slurm: + done = (death_time is not None and death_time - time.time() < + MAXIMUM_SAVE_TIME) + + if rank == 0 and done: + print("Saving progress before exiting.") + model.save(iteration) + + print("Running restart command:", " ".join(resume_command)) + # Submit restart command. + subprocess.check_call(resume_command) + break + + if rank == 0 and iteration % config.save_every == 0: + # Save parameters. + model.save(iteration) + + # Close TensorBoard. + if rank == 0: + tb.close() + + +if __name__ == "__main__": + # Create parser. + parser = jsonargparse.ArgumentParser(description="Train WaveNet model", + formatter_class='default_argparse') + add_options_to_parser(parser) + utils.add_config_options_to_parser(parser) + + # Parse argument from both command line and yaml config file. + # For conflicting updates to the same field, + # the preceding update will be overwritten by the following one. + config = parser.parse_args() + train(config) diff --git a/parakeet/models/wavenet/utils.py b/parakeet/models/wavenet/utils.py new file mode 100644 index 0000000..c2b6601 --- /dev/null +++ b/parakeet/models/wavenet/utils.py @@ -0,0 +1,143 @@ +import itertools +import os +import time + +import jsonargparse +import numpy as np +import paddle.fluid.dygraph as dg + + +def add_config_options_to_parser(parser): + parser.add_argument('--valid_size', type=int, + help="size of the valid dataset") + parser.add_argument('--train_clip_second', type=float, + help="the length of audio clip for training") + parser.add_argument('--sample_rate', type=int, + help="sampling rate of audio data file") + parser.add_argument('--fft_window_shift', type=int, + help="the shift of fft window for each frame") + parser.add_argument('--fft_window_size', type=int, + help="the size of fft window for each frame") + parser.add_argument('--fft_size', type=int, + help="the size of fft filter on each frame") + parser.add_argument('--mel_bands', type=int, + help="the number of mel bands when calculating mel spectrograms") + + parser.add_argument('--seed', type=int, + help="seed of random initialization for the model") + parser.add_argument('--batch_size', type=int, + help="batch size for training") + parser.add_argument('--test_every', type=int, + help="test interval during training") + parser.add_argument('--save_every', type=int, + help="checkpointing interval during training") + parser.add_argument('--max_iterations', type=int, + help="maximum training iterations") + + parser.add_argument('--layers', type=int, + help="number of dilated convolution layers") + parser.add_argument('--kernel_width', type=int, + help="dilated convolution kernel width") + parser.add_argument('--dilation_block', type=list, + help="dilated convolution kernel width") + parser.add_argument('--residual_channels', type=int) + 
parser.add_argument('--skip_channels', type=int)
+    parser.add_argument('--loss_type', type=str,
+        help="mix-gaussian-pdf or softmax")
+    parser.add_argument('--num_channels', type=int, default=None,
+        help="number of channels for softmax output")
+    parser.add_argument('--num_mixtures', type=int, default=None,
+        help="number of gaussian mixtures for gaussian output")
+    parser.add_argument('--log_scale_min', type=float, default=None,
+        help="minimum clip value of log variance of gaussian output")
+
+    parser.add_argument('--conditioner.filter_sizes', type=list,
+        help="conv2d transpose op filter sizes for building conditioner")
+    parser.add_argument('--conditioner.upsample_factors', type=list,
+        help="list of upsample factors for building conditioner")
+
+    parser.add_argument('--learning_rate', type=float)
+    parser.add_argument('--gradient_max_norm', type=float)
+    parser.add_argument('--anneal.every', type=int,
+        help="step interval for annealing learning rate")
+    parser.add_argument('--anneal.rate', type=float)
+
+    parser.add_argument('--config', action=jsonargparse.ActionConfigFile)
+
+
+def pad_to_size(array, length, pad_with=0.0):
+    """
+    Pad an array on the first (length) axis to a given length.
+    """
+    padding = length - array.shape[0]
+    assert padding >= 0, "Padding required was less than zero"
+
+    paddings = [(0, 0)] * len(array.shape)
+    paddings[0] = (0, padding)
+
+    return np.pad(array, paddings, mode='constant', constant_values=pad_with)
+
+
+def calculate_context_size(config):
+    dilations = list(
+        itertools.islice(
+            itertools.cycle(config.dilation_block), config.layers))
+    config.context_size = sum(dilations) + 1
+    print("Context size is", config.context_size)
+
+
+def load_latest_checkpoint(checkpoint_dir, rank=0):
+    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
+    # Create the checkpoint index file if it does not exist.
+    if (not os.path.isfile(checkpoint_path)) and rank == 0:
+        with open(checkpoint_path, "w") as handle:
+            handle.write("model_checkpoint_path: step-0")
+
+    # Make sure other processes wait until the checkpoint file has been
+    # created by process 0.
+    while not os.path.isfile(checkpoint_path):
+        time.sleep(1)
+
+    # Fetch the latest checkpoint index.
+    with open(checkpoint_path, "r") as handle:
+        latest_checkpoint = handle.readline().split()[-1]
+    iteration = int(latest_checkpoint.split("-")[-1])
+
+    return iteration
+
+
+def save_latest_checkpoint(checkpoint_dir, iteration):
+    checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
+    # Update the latest checkpoint index.
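+    # The index file holds a single line such as
+    # "model_checkpoint_path: step-10000", the same format that
+    # load_latest_checkpoint() parses above.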
+    with open(checkpoint_path, "w") as handle:
+        handle.write("model_checkpoint_path: step-{}".format(iteration))
+
+
+def load_parameters(checkpoint_dir, rank, model, optimizer=None,
+                    iteration=None, file_path=None):
+    if file_path is None:
+        if iteration is None:
+            iteration = load_latest_checkpoint(checkpoint_dir, rank)
+        if iteration == 0:
+            return
+        file_path = "{}/step-{}".format(checkpoint_dir, iteration)
+
+    model_dict, optimizer_dict = dg.load_dygraph(file_path)
+    model.set_dict(model_dict)
+    print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path))
+    if optimizer and optimizer_dict:
+        optimizer.set_dict(optimizer_dict)
+        print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
+            rank, file_path))
+
+
+def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
+    file_path = "{}/step-{}".format(checkpoint_dir, iteration)
+    model_dict = model.state_dict()
+    dg.save_dygraph(model_dict, file_path)
+    print("[checkpoint] Saved model to {}".format(file_path))
+
+    if optimizer:
+        opt_dict = optimizer.state_dict()
+        dg.save_dygraph(opt_dict, file_path)
+        print("[checkpoint] Saved optimizer state to {}".format(file_path))
diff --git a/parakeet/models/wavenet/wavenet.py b/parakeet/models/wavenet/wavenet.py
new file mode 100644
index 0000000..acc6e76
--- /dev/null
+++ b/parakeet/models/wavenet/wavenet.py
@@ -0,0 +1,188 @@
+import itertools
+import os
+import time
+
+import librosa
+import numpy as np
+from paddle import fluid
+import paddle.fluid.dygraph as dg
+
+import utils
+from data import LJSpeech
+from wavenet_modules import WaveNetModule, debug
+
+
+class WaveNet():
+    def __init__(self, config, checkpoint_dir, parallel=False, rank=0,
+                 nranks=1, tb_logger=None):
+        # Process config to calculate the context size.
+        dilations = list(
+            itertools.islice(
+                itertools.cycle(config.dilation_block), config.layers))
+        config.context_size = sum(dilations) + 1
+        self.config = config
+        self.checkpoint_dir = checkpoint_dir
+        self.parallel = parallel
+        self.rank = rank
+        self.nranks = nranks
+        self.tb_logger = tb_logger
+
+    def build(self, training=True):
+        config = self.config
+        dataset = LJSpeech(config, self.nranks, self.rank)
+        self.trainloader = dataset.trainloader
+        self.validloader = dataset.validloader
+
+        wavenet = WaveNetModule("wavenet", config, self.rank)
+
+        # Dry run once to create and initialize all necessary parameters.
+        audio = dg.to_variable(np.random.randn(1, 20000).astype(np.float32))
+        mel = dg.to_variable(
+            np.random.randn(1, 100, self.config.mel_bands).astype(np.float32))
+        audio_start = dg.to_variable(np.array([0], dtype=np.int32))
+        wavenet(audio, mel, audio_start)
+
+        if training:
+            # Create learning rate scheduler.
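+            # staircase=True anneals the learning rate in discrete steps:
+            #     lr = learning_rate * anneal.rate ** (step // anneal.every)
+            # e.g. 0.001 -> 0.0005 at iteration 200000 under the default
+            # config.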
+ lr_scheduler = dg.ExponentialDecay( + learning_rate = config.learning_rate, + decay_steps = config.anneal.every, + decay_rate = config.anneal.rate, + staircase=True) + + optimizer = fluid.optimizer.AdamOptimizer( + learning_rate=lr_scheduler) + + clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm( + config.gradient_max_norm) + + # Load parameters. + utils.load_parameters(self.checkpoint_dir, self.rank, + wavenet, optimizer, + iteration=config.iteration, + file_path=config.checkpoint) + print("Rank {}: checkpoint loaded.".format(self.rank)) + + # Data parallelism. + if self.parallel: + strategy = dg.parallel.prepare_context() + wavenet = dg.parallel.DataParallel(wavenet, strategy) + + self.wavenet = wavenet + self.optimizer = optimizer + self.clipper = clipper + + else: + # Load parameters. + utils.load_parameters(self.checkpoint_dir, self.rank, wavenet, + iteration=config.iteration, + file_path=config.checkpoint) + print("Rank {}: checkpoint loaded.".format(self.rank)) + + self.wavenet = wavenet + + def train_step(self, iteration): + self.wavenet.train() + + start_time = time.time() + audios, mels, audio_starts = next(self.trainloader) + load_time = time.time() + + loss, _ = self.wavenet(audios, mels, audio_starts) + + if self.parallel: + # loss = loss / num_trainers + loss = self.wavenet.scale_loss(loss) + loss.backward() + self.wavenet.apply_collective_grads() + else: + loss.backward() + + if isinstance(self.optimizer._learning_rate, + fluid.optimizer.LearningRateDecay): + current_lr = self.optimizer._learning_rate.step().numpy() + else: + current_lr = self.optimizer._learning_rate + + self.optimizer.minimize(loss, grad_clip=self.clipper, + parameter_list=self.wavenet.parameters()) + self.wavenet.clear_gradients() + + graph_time = time.time() + + if self.rank == 0: + loss_val = float(loss.numpy()) * self.nranks + log = "Rank: {} Step: {:^8d} Loss: {:<8.3f} " \ + "Time: {:.3f}/{:.3f}".format( + self.rank, iteration, loss_val, + load_time - start_time, graph_time - load_time) + print(log) + + tb = self.tb_logger + tb.add_scalar("Train-Loss-Rank-0", loss_val, iteration) + tb.add_scalar("Learning-Rate", current_lr, iteration) + + @dg.no_grad + def valid_step(self, iteration): + self.wavenet.eval() + + total_loss = [] + start_time = time.time() + sample_audios = [] + for audios, mels, audio_starts in self.validloader(): + loss, sample_audio = self.wavenet(audios, mels, audio_starts, True) + total_loss.append(float(loss.numpy())) + sample_audios.append(sample_audio) + total_time = time.time() - start_time + + if self.rank == 0: + loss_val = np.mean(total_loss) + log = "Test | Rank: {} AvgLoss: {:<8.3f} Time {:<8.3f}".format( + self.rank, loss_val, total_time) + print(log) + + tb = self.tb_logger + tb.add_scalar("Valid-Avg-Loss", loss_val, iteration) + tb.add_audio("Teacher-Forced-Audio-0", sample_audios[0].numpy(), + iteration, sample_rate=self.config.sample_rate) + tb.add_audio("Teacher-Forced-Audio-1", sample_audios[1].numpy(), + iteration, sample_rate=self.config.sample_rate) + + def save(self, iteration): + utils.save_latest_parameters(self.checkpoint_dir, iteration, + self.wavenet, self.optimizer) + utils.save_latest_checkpoint(self.checkpoint_dir, iteration) + + @dg.no_grad + def infer(self, iteration): + self.wavenet.eval() + + config = self.config + sample = config.sample + + output = "{}/{}/iter-{}".format(config.output, config.name, iteration) + os.makedirs(output, exist_ok=True) + + filename = "{}/valid_{}.wav".format(output, sample) + print("Synthesize sample {}, save as 
{}".format(sample, filename)) + + mels_list = [mels for _, mels, _ in self.validloader()] + start_time = time.time() + syn_audio = self.wavenet.synthesize(mels_list[sample]) + syn_time = time.time() - start_time + print("audio shape {}, synthesis time {}".format( + syn_audio.shape, syn_time)) + librosa.output.write_wav(filename, syn_audio, + sr=config.sample_rate) diff --git a/parakeet/models/wavenet/wavenet_modules.py b/parakeet/models/wavenet/wavenet_modules.py new file mode 100644 index 0000000..c5c01e9 --- /dev/null +++ b/parakeet/models/wavenet/wavenet_modules.py @@ -0,0 +1,423 @@ +import itertools +import math + +import numpy as np +from paddle import fluid +import paddle.fluid.dygraph as dg +import ops +import weight_norm + + +def get_padding(filter_size, stride, padding_type='same'): + if padding_type == 'same': + padding = [(x - y) // 2 for x, y in zip(filter_size, stride)] + else: + raise ValueError("Only support same padding") + return padding + + +def debug(x, var_name, rank, verbose=False): + if not verbose and rank != 0: + return + dim = len(x.shape) + if not isinstance(x, np.ndarray): + x = x.numpy() + if dim == 1: + print("Rank {}".format(rank), var_name, "shape {}, value {}".format(x.shape, x)) + elif dim == 2: + print("Rank {}".format(rank), var_name, "shape {}, value {}".format(x.shape, x[:, :5])) + elif dim == 3: + print("Rank {}".format(rank), var_name, "shape {}, value {}".format(x.shape, x[:, :5, 0])) + else: + print("Rank", rank, var_name, "shape", x.shape) + + +def extract_slices(x, audio_starts, audio_length, rank): + slices = [] + for i in range(x.shape[0]): + start = audio_starts.numpy()[i] + end = start + audio_length + slice = fluid.layers.slice( + x, axes=[0, 1], starts=[i, start], ends=[i+1, end]) + slices.append(fluid.layers.squeeze(slice, [0])) + + x = fluid.layers.stack(slices, axis=0) + + return x + + +class Conditioner(dg.Layer): + def __init__(self, name_scope, config): + super(Conditioner, self).__init__(name_scope) + upsample_factors = config.conditioner.upsample_factors + filter_sizes = config.conditioner.filter_sizes + assert np.prod(upsample_factors) == config.fft_window_shift + + self.deconvs = [] + for i, up_scale in enumerate(upsample_factors): + stride = (up_scale, 1) + padding = get_padding(filter_sizes[i], stride) + self.deconvs.append( + ops.Conv2DTranspose( + self.full_name(), + num_filters=1, + filter_size=filter_sizes[i], + padding=padding, + stride=stride)) + + # Register python list as parameters. 
+        for i, layer in enumerate(self.deconvs):
+            self.add_sublayer("conv_transpose_{}".format(i), layer)
+
+    def forward(self, x):
+        x = fluid.layers.unsqueeze(x, 1)
+        for layer in self.deconvs:
+            x = fluid.layers.leaky_relu(layer(x), alpha=0.4)
+
+        return fluid.layers.squeeze(x, [1])
+
+
+class WaveNetModule(dg.Layer):
+    def __init__(self, name_scope, config, rank):
+        super(WaveNetModule, self).__init__(name_scope)
+
+        self.rank = rank
+        self.conditioner = Conditioner(self.full_name(), config)
+        self.dilations = list(
+            itertools.islice(
+                itertools.cycle(config.dilation_block), config.layers))
+        self.context_size = sum(self.dilations) + 1
+        self.log_scale_min = config.log_scale_min
+        self.config = config
+
+        print("dilations", self.dilations)
+        print("context_size", self.context_size)
+
+        if config.loss_type == "softmax":
+            self.embedding_fc = ops.Embedding(
+                self.full_name(),
+                num_embeddings=config.num_channels,
+                embed_dim=config.residual_channels)
+        elif config.loss_type == "mix-gaussian-pdf":
+            self.embedding_fc = ops.FC(
+                self.full_name(),
+                in_features=1,
+                size=config.residual_channels,
+                num_flatten_dims=2,
+                relu=False)
+        else:
+            raise ValueError(
+                "loss_type {} is unsupported!".format(config.loss_type))
+
+        self.dilated_causal_convs = []
+        for dilation in self.dilations:
+            self.dilated_causal_convs.append(
+                ops.Conv1D_GU(
+                    self.full_name(),
+                    conditioner_dim=config.mel_bands,
+                    in_channels=config.residual_channels,
+                    num_filters=config.residual_channels,
+                    filter_size=config.kernel_width,
+                    dilation=dilation,
+                    causal=True
+                )
+            )
+
+        for i, layer in enumerate(self.dilated_causal_convs):
+            self.add_sublayer("dilated_causal_conv_{}".format(i), layer)
+
+        self.fc1 = ops.FC(
+            self.full_name(),
+            in_features=config.residual_channels,
+            size=config.skip_channels,
+            num_flatten_dims=2,
+            relu=True,
+            act="relu")
+
+        self.fc2 = ops.FC(
+            self.full_name(),
+            in_features=config.skip_channels,
+            size=config.skip_channels,
+            num_flatten_dims=2,
+            relu=True,
+            act="relu")
+
+        if config.loss_type == "softmax":
+            self.fc3 = ops.FC(
+                self.full_name(),
+                in_features=config.skip_channels,
+                size=config.num_channels,
+                num_flatten_dims=2,
+                relu=False)
+        elif config.loss_type == "mix-gaussian-pdf":
+            self.fc3 = ops.FC(
+                self.full_name(),
+                in_features=config.skip_channels,
+                size=3 * config.num_mixtures,
+                num_flatten_dims=2,
+                relu=False)
+        else:
+            raise ValueError(
+                "loss_type {} is unsupported!".format(config.loss_type))
+
+    def sample_softmax(self, mix_parameters):
+        batch, length, hidden = mix_parameters.shape
+        mix_param_2d = fluid.layers.reshape(mix_parameters,
+                                            [batch * length, hidden])
+        mix_param_2d = fluid.layers.softmax(mix_param_2d, axis=-1)
+
+        # quantized: [batch * length]
+        quantized = fluid.layers.cast(fluid.layers.sampling_id(mix_param_2d),
+                                      dtype="float32")
+        samples = (quantized + 0.5) * (2.0 / self.config.num_channels) - 1.0
+
+        # samples: [batch * length]
+        return samples
+
+    def sample_mix_gaussian(self, mix_parameters):
+        # mix_parameters reshape from [bs, 13799, 3 * num_mixtures]
+        # to [bs * 13799, 3 * num_mixtures].
+        batch, length, hidden = mix_parameters.shape
+        mix_param_2d = fluid.layers.reshape(mix_parameters,
+                                            [batch * length, hidden])
+        K = hidden // 3
+
+        # Unpack the parameters of the mixture of gaussian.
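+        # Layout of the 3 * K values per sample: [logit_pi | mu | log_s],
+        # matching the split order assumed by mixture_density_loss() below.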
+ logits_pi = mix_param_2d[:, 0 : K] + mu = mix_param_2d[:, K : 2*K] + log_s = mix_param_2d[:, 2*K : 3*K] + s = fluid.layers.exp(log_s) + + pi = fluid.layers.softmax(logits_pi, axis=-1) + comp_samples = fluid.layers.sampling_id(pi) + + row_idx = dg.to_variable(np.arange(batch * length)) + comp_samples = fluid.layers.stack([row_idx, comp_samples], axis=-1) + + mu_comp = fluid.layers.gather_nd(mu, comp_samples) + s_comp = fluid.layers.gather_nd(s, comp_samples) + + # N(0, 1) Normal Sample. + u = fluid.layers.gaussian_random(shape=[batch * length]) + samples = mu_comp + u * s_comp + samples = fluid.layers.clip(samples, min=-1.0, max=1.0) + + return samples + + def softmax_loss(self, targets, mix_parameters): + # targets: [bs, 13799] -> [bs, 11752] + # mix_params: [bs, 13799, 3] -> [bs, 11752, 3] + targets = targets[:, self.context_size:] + mix_parameters = mix_parameters[:, self.context_size:, :] + + # Quantized audios to integral values with range [0, num_channels) + num_channels = self.config.num_channels + targets = fluid.layers.clip(targets, min=-1.0, max=0.99999) + quantized = fluid.layers.cast( + (targets + 1.0) / 2.0 * num_channels, dtype="int64") + + # per_sample_loss shape: [bs, 17952, 1] + per_sample_loss = fluid.layers.softmax_with_cross_entropy( + logits=mix_parameters, label=fluid.layers.unsqueeze(quantized, 2)) + loss = fluid.layers.reduce_mean(per_sample_loss) + #debug(loss, "softmax loss", self.rank) + + return loss + + def mixture_density_loss(self, targets, mix_parameters, log_scale_min): + # targets: [bs, 13799] -> [bs, 11752] + # mix_params: [bs, 13799, 3] -> [bs, 11752, 3] + targets = targets[:, self.context_size:] + mix_parameters = mix_parameters[:, self.context_size:, :] + + # log_s: [bs, 11752, num_mixture] + logits_pi, mu, log_s = fluid.layers.split(mix_parameters, num_or_sections=3, dim=-1) + + pi = fluid.layers.softmax(logits_pi, axis=-1) + log_s = fluid.layers.clip(log_s, min=log_scale_min, max=100.0) + inv_s = fluid.layers.exp(0.0 - log_s) + + # Calculate gaussian loss. + targets = fluid.layers.unsqueeze(targets, -1) + targets = fluid.layers.expand(targets, [1, 1, self.config.num_mixtures]) + x_std = inv_s * (targets - mu) + exponent = fluid.layers.exp(-0.5 * x_std * x_std) + # pdf_x: [bs, 11752, 1] + pdf_x = 1.0 / np.sqrt(2.0 * np.pi) * inv_s * exponent + pdf_x = pi * pdf_x + # pdf_x: [bs, 11752] + pdf_x = fluid.layers.reduce_sum(pdf_x, dim=-1) + per_sample_loss = 0.0 - fluid.layers.log(pdf_x + 1e-9) + + loss = fluid.layers.reduce_mean(per_sample_loss) + + return loss + + def forward(self, audios, mels, audio_starts, sample=False): + # audios: [bs, 13800], mels: [bs, full_frame_length, 80] + # audio_starts: [bs] + # Build conditioner based on mels. + full_conditioner = self.conditioner(mels) + + # Slice conditioners. 
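+        # The upsampled conditioner covers the whole utterance; cut out the
+        # window matching the cropped audio so that, after upsampling by
+        # fft_window_shift, there is one conditioner row per audio sample.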
+ audio_length = audios.shape[1] + conditioner = extract_slices(full_conditioner, + audio_starts, audio_length, self.rank) + + # input_audio, target_audio: [bs, 13799] + input_audios = audios[:, :-1] + target_audios = audios[:, 1:] + # conditioner: [bs, 13799, 80] + conditioner = conditioner[:, 1:, :] + + loss_type = self.config.loss_type + + # layer_input: [bs, 13799, 128] + if loss_type == "softmax": + input_audios = fluid.layers.clip( + input_audios, min=-1.0, max=0.99999) + # quantized have values in [0, num_channels) + quantized = fluid.layers.cast( + (input_audios + 1.0) / 2.0 * self.config.num_channels, + dtype="int64") + layer_input = self.embedding_fc(fluid.layers.unsqueeze(quantized, 2)) + elif loss_type == "mix-gaussian-pdf": + layer_input = self.embedding_fc(fluid.layers.unsqueeze(input_audios, 2)) + else: + raise ValueError( + "loss_type {} is unsupported!".format(loss_type)) + + # layer_input: [bs, res_channel, 1, 13799] + layer_input = fluid.layers.unsqueeze(fluid.layers.transpose(layer_input, perm=[0, 2, 1]), 2) + # conditioner: [bs, mel_bands, 1, 13799] + conditioner = fluid.layers.unsqueeze(fluid.layers.transpose(conditioner, perm=[0, 2, 1]), 2) + + # layer_input: [bs, res_channel, 1, 13799] + # skip: [bs, res_channel, 1, 13799] + skip = None + for i, layer in enumerate(self.dilated_causal_convs): + layer_input, skip = layer(layer_input, skip, conditioner) + #debug(layer_input, "layer_input_" + str(i), self.rank) + #debug(skip, "skip_" + str(i), self.rank) + + # Reshape skip to [bs, 13799, res_channel] + skip = fluid.layers.transpose(fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1]) + #debug(skip, "skip", self.rank) + + # mix_param: [bs, 13799, 3 * num_mixtures] + mix_parameters = self.fc3(self.fc2(self.fc1(skip))) + + # Sample teacher-forced audio. + sample_audios = None + if sample: + if loss_type == "softmax": + sample_audios = self.sample_softmax(mix_parameters) + elif loss_type == "mix-gaussian-pdf": + sample_audios = self.sample_mix_gaussian(mix_parameters) + else: + raise ValueError( + "loss_type {} is unsupported!".format(loss_type)) + #debug(sample_audios, "sample_audios", self.rank) + + # Calculate mix-gaussian density loss. + # padding is all zero. + # target_audio: [bs, 13799]. + # mix_params: [bs, 13799, 3]. + if loss_type == "softmax": + loss = self.softmax_loss(target_audios, mix_parameters) + elif loss_type == "mix-gaussian-pdf": + loss = self.mixture_density_loss(target_audios, + mix_parameters, self.log_scale_min) + else: + raise ValueError( + "loss_type {} is unsupported!".format(loss_type)) + + #print("Rank {}, loss {}".format(self.rank, loss.numpy())) + + return loss, sample_audios + + def synthesize(self, mels): + self.start_new_sequence() + print("input mels shape", mels.shape) + # mels: [bs=1, n_frames, 80] + # conditioner: [1, n_frames * samples_per_frame, 80] + # Should I move forward by one sample? 
No difference + # Append context frame to mels + bs, n_frames, mel_bands = mels.shape + #num_pad_frames = int(np.ceil(self.context_size / self.config.fft_window_shift)) + #silence = fluid.layers.zeros(shape=[bs, num_pad_frames, mel_bands], dtype="float32") + #inf_mels = fluid.layers.concat([silence, mels], axis=1) + #print("padded mels shape", inf_mels.shape) + + #conditioner = self.conditioner(inf_mels)[:, self.context_size:, :] + conditioner = self.conditioner(mels) + time_steps = conditioner.shape[1] + print("Total steps", time_steps) + + loss_type = self.config.loss_type + audio_samples = [] + current_sample = fluid.layers.zeros(shape=[bs, 1, 1], dtype="float32") + for i in range(time_steps): + if i % 100 == 0: + print("Step", i) + + # convert from real value sample to audio embedding. + # [bs, 1, 128] + if loss_type == "softmax": + current_sample = fluid.layers.clip( + current_sample, min=-1.0, max=0.99999) + # quantized have values in [0, num_channels) + quantized = fluid.layers.cast( + (current_sample + 1.0) / 2.0 * self.config.num_channels, + dtype="int64") + audio_input = self.embedding_fc(quantized) + elif loss_type == "mix-gaussian-pdf": + audio_input = self.embedding_fc(current_sample) + else: + raise ValueError( + "loss_type {} is unsupported!".format(loss_type)) + + # [bs, 128, 1, 1] + audio_input = fluid.layers.unsqueeze(fluid.layers.transpose(audio_input, perm=[0, 2, 1]), 2) + # [bs, 80] + cond_input = conditioner[:, i, :] + # [bs, 80, 1, 1] + cond_input = fluid.layers.reshape( + cond_input, cond_input.shape + [1, 1]) + + skip = None + for layer in self.dilated_causal_convs: + audio_input, skip = layer.add_input(audio_input, skip, cond_input) + + # [bs, 1, 128] + skip = fluid.layers.transpose(fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1]) + # [bs, 1, 3] + mix_parameters = self.fc3(self.fc2(self.fc1(skip))) + if loss_type == "softmax": + sample = self.sample_softmax(mix_parameters) + elif loss_type == "mix-gaussian-pdf": + sample = self.sample_mix_gaussian(mix_parameters) + else: + raise ValueError( + "loss_type {} is unsupported!".format(loss_type)) + audio_samples.append(sample) + # [bs] + current_sample = audio_samples[-1] + # [bs, 1, 1] + current_sample = fluid.layers.reshape(current_sample, + current_sample.shape + [1, 1]) + + # syn_audio: (num_samples,) + syn_audio = fluid.layers.concat(audio_samples, axis=0).numpy() + + return syn_audio + + def start_new_sequence(self): + for layer in self.sublayers(): + if isinstance(layer, weight_norm.Conv1D): + layer.start_new_sequence() + + def save(self, iteration): + utils.save_latest_parameters(self.checkpoint_dir, iteration, + self.wavenet, self.optimizer) + utils.save_latest_checkpoint(self.checkpoint_dir, iteration) diff --git a/parakeet/models/wavenet/weight_norm.py b/parakeet/models/wavenet/weight_norm.py new file mode 100644 index 0000000..75fe413 --- /dev/null +++ b/parakeet/models/wavenet/weight_norm.py @@ -0,0 +1,920 @@ +import math +from copy import deepcopy + +import numpy as np +import paddle.fluid.dygraph as dg +from paddle import fluid +from paddle.fluid import core +from paddle.fluid.framework import Variable +from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer +from paddle.fluid.layers import utils +from six.moves import reduce + + +def _norm(p, dim): + """Computes the norm over all dimensions except dim. + It differs from pytorch implementation that it does not keep dim. + This difference is related with the broadcast mechanism in paddle. + Read elementeise_mul for more. 
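+    For example, for a weight of shape (C_out, C_in, 1, K), dim=0 returns a
+    vector of C_out norms, one per output channel.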
+ """ + if dim is None: + return np.linalg.norm(p, ord=2, axis=None) + elif dim == 0: + p = np.reshape(p, newshape=(p.shape[0], -1)) + return np.linalg.norm(p, ord=2, axis=1) + elif dim == p.ndim - 1: + p = np.reshape(p, newshape=(-1, p.shape[-1])) + return np.linalg.norm(p, ord=2, axis=0) + else: + perm = list(range(p.ndim)) + perm[0] = dim + perm[dim] = 0 + return _norm(np.transpose(p, axes=perm)) + + +class Conv1D(dg.Layer): + """ + A convolution 1D block implemented with Conv2D. Form simplicity and + ensuring the output has the same length as the input, it does not allow + stride > 1. + """ + def __init__(self, + name_scope, + num_filters, + filter_size=3, + dilation=1, + groups=None, + causal=False, + param_attr=None, + bias_attr=None, + use_cudnn=True, + act=None, + dtype="float32"): + super(Conv1D, self).__init__(name_scope, dtype=dtype) + + if causal: + padding = dilation * (filter_size - 1) + else: + padding = (dilation * (filter_size - 1)) // 2 + + self.num_filters = num_filters + self.filter_size = filter_size + self.dilation = dilation + self.causal = causal + self.padding = padding + self.act = act + + self.conv = Conv2D( + self.full_name(), + num_filters=num_filters, + filter_size=(1, filter_size), + stride=(1, 1), + dilation=(1, dilation), + padding=(0, padding), + groups=groups, + param_attr=param_attr, + bias_attr=bias_attr, + use_cudnn=use_cudnn, + act=act, + dtype=dtype) + + def forward(self, x): + """ + Args: + x (Variable): Shape(B, C_in, 1, T), the input, where C_in means + input channels. + Returns: + x (Variable): Shape(B, C_out, 1, T), the outputs, where C_out means + output channels (num_filters). + """ + x = self.conv(x) + if self.filter_size > 1: + if self.causal: + x = fluid.layers.slice( + x, axes=[3], starts=[0], ends=[-self.padding]) + elif self.filter_size % 2 == 0: + x = fluid.layers.slice(x, axes=[3], starts=[0], ends=[-1]) + return x + + def start_new_sequence(self): + self.temp_weight = None + self.input_buffer = None + + def add_input(self, x): + """ + Adding input for a time step and compute an output for a time step. + + Args: + x (Variable): Shape(B, C_in, 1, T), the input, where C_in means + input channels, and T = 1. + Returns: + out (Variable): Shape(B, C_out, 1, T), the outputs, where C_out + means output channels (num_filters), and T = 1. 
+ + """ + if self.temp_weight is None: + self.temp_weight = self._reshaped_weight() + + window_size = 1 + (self.filter_size - 1) * self.dilation + batch_size = x.shape[0] + in_channels = x.shape[1] + + if self.filter_size > 1: + if self.input_buffer is None: + self.input_buffer = fluid.layers.fill_constant( + [batch_size, in_channels, 1, window_size - 1], + dtype=x.dtype, + value=0.0) + else: + self.input_buffer = self.input_buffer[:, :, :, 1:] + self.input_buffer = fluid.layers.concat( + [self.input_buffer, x], axis=3) + x = self.input_buffer + if self.dilation > 1: + if not hasattr(self, "indices"): + self.indices = dg.to_variable( + np.arange(0, window_size, self.dilation)) + tmp = fluid.layers.transpose( + self.input_buffer, perm=[3, 1, 2, 0]) + tmp = fluid.layers.gather(tmp, index=self.indices) + tmp = fluid.layers.transpose(tmp, perm=[3, 1, 2, 0]) + x = tmp + inputs = fluid.layers.reshape( + x, shape=[batch_size, in_channels * 1 * self.filter_size]) + out = fluid.layers.matmul(inputs, self.temp_weight, transpose_y=True) + out = fluid.layers.elementwise_add(out, self.conv._bias_param, axis=-1) + out = fluid.layers.reshape(out, out.shape + [1, 1]) + out = self._helper.append_activation(out, act=self.act) + return out + + def _reshaped_weight(self): + """ + Get the linearized weight of convolution filter, cause it is by nature + a matmul weight. And because the model uses weight norm, compute the + weight by weight_v * weight_g to make it faster. + Returns: + weight_matrix (Variable): Shape(C_out, C_in * 1 * kernel_size) + """ + shape = self.conv._filter_param_v.shape + matrix_shape = [shape[0], np.prod(shape[1:])] + weight_matrix = fluid.layers.reshape( + self.conv._filter_param_v, shape=matrix_shape) + weight_matrix = fluid.layers.elementwise_mul( + fluid.layers.l2_normalize( + weight_matrix, axis=1), + self.conv._filter_param_g, + axis=0) + return weight_matrix + + +class FC(dg.Layer): + """ + **Fully Connected Layer** + This function creates a fully connected layer in the network. It can take + one or multiple tensors as its inputs(input can be a list of Variable, see + Args in detail). It creates a pair of variables called (magnitude(g), + direction(V)) for each input tensor. Elementwise_mul(V, g) represents a fully connected + weight matrix from each input unit to each output unit. + The fully connected layer multiplies each input tensor + with its corresponding weight to produce an output Tensor with shape [M, `size`], + where M is batch size. If multiple input tensors are given, the results of + multiple output tensors with shape [M, `size`] will be summed up. If bias_attr + is not None, a bias variable will be created and added to the output. + Finally, if activation is not None, it will be applied to the output as well. + When the input is single tensor: + .. math:: + Out = Act({X(normalize(V)g) + b}) + When the input are multiple tensors: + .. math:: + Out = Act({\sum_{i=0}^{N-1}X_i(V_ig_i) + b}) + In the above equation: + * :math:`N`: Number of the input. N equals to len(input) if input is list of Variable. + * :math:`X_i`: The i-th input tensor. + * :math:`V_i`: The i-th direction matrix corresponding i-th input tensor. + * :math:`g_i`: The i-th magnitude vector corresponding i-th input tensor. + * :math:`b`: The bias parameter created by this layer (if needed). + * :math:`Act`: The activation function. + * :math:`Out`: The output tensor. + See below for an example. + .. 
code-block:: text + Given: + data_1.data = [[[0.1, 0.2], + [0.3, 0.4]]] + data_1.shape = (1, 2, 2) # 1 is batch_size + data_2 = [[[0.1, 0.2, 0.3]]] + data_2.shape = (1, 1, 3) + out = fluid.layers.fc(input=[data_1, data_2], size=2) + Then: + out.data = [[0.18669507, 0.1893476]] + out.shape = (1, 2) + Args: + name_scope(str): The name of this class. + size(int): The number of output units in this layer. + num_flatten_dims (int): The fc layer can accept an input tensor with more than + two dimensions. If this happens, the multidimensional tensor will first be flattened + into a 2-dimensional matrix. The parameter `num_flatten_dims` determines how the input + tensor is flattened: the first `num_flatten_dims` (inclusive, index starts from 1) + dimensions will be flatten to form the first dimension of the final matrix (height of + the matrix), and the rest `rank(X) - num_flatten_dims` dimensions are flattened to + form the second dimension of the final matrix (width of the matrix). For example, suppose + `X` is a 5-dimensional tensor with a shape [2, 3, 4, 5, 6], and `num_flatten_dims` = 3. + Then, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. Default: 1 + param_attr (ParamAttr|list of ParamAttr|None): The parameter attribute for learnable + parameters/weights of this layer. + bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias + of this layer. If it is set to False, no bias will be added to the output units. + If it is set to None, the bias is initialized zero. Default: None. + act (str|None): Activation to be applied to the output of this layer. + is_test(bool): A flag indicating whether execution is in test phase. Default: False + dtype(str): Dtype used for weight + Raises: + ValueError: If rank of the input tensor is less than 2. + Examples: + .. 
+
+    def __init__(self,
+                 name_scope,
+                 size,
+                 num_flatten_dims=1,
+                 epsilon=1e-30,
+                 param_attr=None,
+                 bias_attr=None,
+                 act=None,
+                 is_test=False,
+                 dtype="float32"):
+        super(FC, self).__init__(name_scope, dtype)
+
+        self._size = size
+        self._num_flatten_dims = num_flatten_dims
+        self._epsilon = epsilon
+        self._dtype = dtype
+        self._param_attr = param_attr
+        self._bias_attr = bias_attr
+        self._act = act
+        self.__g = list()
+        self.__v = list()
+
+    @property
+    def _v(self, i=0):
+        return self.__v[i]
+
+    @property
+    def _g(self, i=0):
+        return self.__g[i]
+
+    @_v.setter
+    def _v(self, value, i=0):
+        assert isinstance(value, Parameter)
+        self.__v[i] = value
+
+    @_g.setter
+    def _g(self, value, i=0):
+        assert isinstance(value, Parameter)
+        self.__g[i] = value
+
+    def _build_once(self, input):
+        i = 0
+        for inp, param in self._helper.iter_inputs_and_params(
+                input, self._param_attr):
+            input_shape = inp.shape
+
+            param_shape = [
+                reduce(lambda a, b: a * b,
+                       input_shape[self._num_flatten_dims:], 1)
+            ] + [self._size]
+            self.__v.append(
+                self.add_parameter(
+                    "_v%d" % i,
+                    self.create_parameter(
+                        attr=param,
+                        shape=param_shape,
+                        dtype=self._dtype,
+                        is_bias=False)))
+
+            magnitude_shape = param_shape[1:]
+            magnitude_value = np.linalg.norm(
+                self.__v[i].numpy(), ord=2, axis=0)
+
+            self.__g.append(
+                self.add_parameter(
+                    "_g%d" % i,
+                    self.create_parameter(
+                        attr=fluid.ParamAttr(initializer=fluid.initializer.
+                                             NumpyArrayInitializer(
+                                                 magnitude_value)),
+                        shape=magnitude_shape,
+                        dtype=self._dtype,
+                        is_bias=False)))
+            i += 1
+
+        size = list([self._size])
+        self._b = self.create_parameter(
+            attr=self._bias_attr, shape=size, dtype=self._dtype, is_bias=True)
+
+    def forward(self, input):
+        mul_results = list()
+        i = 0
+        for inp, param in self._helper.iter_inputs_and_params(
+                input, self._param_attr):
+            v_norm = self._helper.create_variable_for_type_inference(
+                self._dtype)
+            v_normalized = self._helper.create_variable_for_type_inference(
+                self._dtype)
+            self._helper.append_op(
+                type="norm",
+                inputs={"X": self.__v[i]},
+                outputs={"Out": v_normalized,
+                         "Norm": v_norm},
+                attrs={"axis": 0,
+                       "epsilon": self._epsilon})
+            weight = self._helper.create_variable_for_type_inference(
+                self._dtype)
+            self._helper.append_op(
+                type="elementwise_mul",
+                inputs={"X": [v_normalized],
+                        "Y": [self.__g[i]]},
+                outputs={"Out": [weight]},
+                attrs={"axis": 1})
+            tmp = self._helper.create_variable_for_type_inference(self._dtype)
+            self._helper.append_op(
+                type="mul",
+                inputs={"X": inp,
+                        "Y": weight},
+                outputs={"Out": tmp},
+                attrs={
+                    "x_num_col_dims": self._num_flatten_dims,
+                    "y_num_col_dims": 1
+                })
+            i += 1
+            mul_results.append(tmp)
+
+        if len(mul_results) == 1:
+            pre_bias = mul_results[0]
+        else:
+            pre_bias = self._helper.create_variable_for_type_inference(
+                self._dtype)
+            self._helper.append_op(
+                type="sum",
+                inputs={"X": mul_results},
+                outputs={"Out": pre_bias},
+                attrs={"use_mkldnn": False})
+
+        if self._b:
+            pre_activation = self._helper.create_variable_for_type_inference(
+                dtype=self._dtype)
+            self._helper.append_op(
+                type="elementwise_add",
+                inputs={"X": [pre_bias],
+                        "Y": [self._b]},
+                outputs={"Out": [pre_activation]},
+                attrs={"axis": self._num_flatten_dims})
+        else:
+            pre_activation = pre_bias
+        # Currently, we don't support inplace in dygraph mode
+        return self._helper.append_activation(pre_activation, act=self._act)
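+
+
+# Example usage of the weight-normalized FC layer above (a sketch; shapes
+# follow the docstring):
+#
+#     with fluid.dygraph.guard():
+#         fc = FC("fc", size=64, num_flatten_dims=2)
+#         x = dg.to_variable(np.random.randn(30, 10, 32).astype("float32"))
+#         y = fc(x)  # y.shape == [30, 10, 64]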
+
+
+class Conv2D(dg.Layer):
+    """
+    The convolution2D layer calculates the output based on the input, filter
+    and strides, paddings, dilations, groups parameters. Input and
+    Output are in NCHW format, where N is batch size, C is the number of
+    channels, H is the height of the feature, and W is the width of the feature.
+    Filter is in MCHW format, where M is the number of output image channels,
+    C is the number of input image channels, H is the height of the filter,
+    and W is the width of the filter. If groups is greater than 1, C will
+    equal the number of input image channels divided by groups.
+    Please refer to UFLDL's convolution tutorial for more details.
+    If bias attribution and activation type are provided, bias is added to the
+    output of the convolution, and the corresponding activation function is
+    applied to the final result.
+    For each input :math:`X`, the equation is:
+    .. math::
+        Out = \sigma ((Vg) \\ast X + b)
+    Where:
+    * :math:`X`: Input value, a tensor with NCHW format.
+    * :math:`V`: Filter direction value, a tensor with MCHW format.
+    * :math:`g`: Filter magnitude value, a tensor with M format.
+    * :math:`\\ast`: Convolution operation.
+    * :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
+    * :math:`\\sigma`: Activation function.
+    * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
+    Example:
+        - Input:
+            Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
+            Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
+        - Output:
+            Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
+        Where
+        .. math::
+            H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
+            W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
+    Args:
+        name_scope(str): The name for this class.
+        num_filters(int): The number of filters. It is the same as the number
+            of output image channels.
+        filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
+            it must contain two integers, (filter_size_H, filter_size_W).
+            Otherwise, the filter will be a square.
+        stride (int|tuple): The stride size. If stride is a tuple, it must
+            contain two integers, (stride_H, stride_W). Otherwise, the
+            stride_H = stride_W = stride. Default: stride = 1.
+        padding (int|tuple): The padding size. If padding is a tuple, it must
+            contain two integers, (padding_H, padding_W). Otherwise, the
+            padding_H = padding_W = padding. Default: padding = 0.
+        dilation (int|tuple): The dilation size. If dilation is a tuple, it must
+            contain two integers, (dilation_H, dilation_W). Otherwise, the
+            dilation_H = dilation_W = dilation. Default: dilation = 1.
+        groups (int): The groups number of the Conv2d Layer. According to grouped
+            convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
+            the first half of the filters is only connected to the first half
+            of the input channels, while the second half of the filters is only
+            connected to the second half of the input channels. Default: groups=1.
+        param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
+            of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
+            will create ParamAttr as param_attr. If the Initializer of the param_attr
+            is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
+            and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
+        bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d.
+            If it is set to False, no bias will be added to the output units.
+            If it is set to None or one attribute of ParamAttr, conv2d
+            will create ParamAttr as bias_attr. If the Initializer of the bias_attr
+            is not set, the bias is initialized zero. Default: None.
+        use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
+            library is installed. Default: True.
+        act (str): Activation type, if it is set to None, activation is not appended.
+            Default: None.
+    Raises:
+        ValueError: If the shapes of input, filter_size, stride, padding and
+            groups mismatch.
+    Examples:
+        .. code-block:: python
+          from paddle.fluid.dygraph.base import to_variable
+          import paddle.fluid as fluid
+          from paddle.fluid.dygraph import Conv2D
+          import numpy as np
+          data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
+          with fluid.dygraph.guard():
+              conv2d = Conv2D("conv2d", 2, 3)
+              data = to_variable(data)
+              conv = conv2d(data)
+    """
+
+    def __init__(self,
+                 name_scope,
+                 num_filters,
+                 filter_size,
+                 stride=1,
+                 padding=0,
+                 dilation=1,
+                 groups=None,
+                 param_attr=None,
+                 bias_attr=None,
+                 use_cudnn=True,
+                 act=None,
+                 epsilon=1e-30,
+                 dtype="float32"):
+        assert param_attr is not False, "param_attr should not be False here."
+        super(Conv2D, self).__init__(name_scope, dtype)
+        self._groups = groups
+        self._stride = utils.convert_to_list(stride, 2, "stride")
+        self._padding = utils.convert_to_list(padding, 2, "padding")
+        self._dilation = utils.convert_to_list(dilation, 2, "dilation")
+        self._act = act
+        if not isinstance(use_cudnn, bool):
+            raise ValueError("use_cudnn should be True or False")
+        self._use_cudnn = use_cudnn
+        self._filter_size = filter_size
+        self._num_filters = num_filters
+        self._param_attr = param_attr
+        self._bias_attr = bias_attr
+        self._epsilon = epsilon
+        self._dtype = dtype
+        # if (self._num_channels == self._groups and
+        #         num_filters % self._num_channels == 0 and not self._use_cudnn):
+        #     self._l_type = 'depthwise_conv2d'
+        # else:
+        # TODO(jiabin): recover the usage of depthwise_conv2d when its
+        # kernel is fixed https://github.com/PaddlePaddle/Paddle/issues/17275
+        self._l_type = "conv2d"
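+
+    # Note: parameters are created lazily in _build_once, on the first
+    # forward call, because the input channel count is only known once an
+    # input actually arrives.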
+
+    def _build_once(self, input):
+        self._num_channels = input.shape[1]
+        if self._groups is None:
+            num_filter_channels = self._num_channels
+        else:
+            if self._num_channels % self._groups != 0:
+                raise ValueError("num_channels must be divisible by groups.")
+            num_filter_channels = self._num_channels // self._groups
+        filter_size = utils.convert_to_list(self._filter_size, 2,
+                                            "filter_size")
+        filter_shape = [self._num_filters, int(num_filter_channels)
+                        ] + filter_size
+
+        def _get_default_param_initializer():
+            filter_elem_num = filter_size[0] * filter_size[
+                1] * self._num_channels
+            std = (2.0 / filter_elem_num)**0.5
+            return Normal(0.0, std, 0)
+
+        # weight_v
+        self._filter_param_v = self.create_parameter(
+            attr=self._param_attr,
+            shape=filter_shape,
+            dtype=self._dtype,
+            default_initializer=_get_default_param_initializer())
+
+        # weight_g
+        norm_value = _norm(
+            self._filter_param_v.numpy(), dim=0)  # CAUTION: hard-code
+        self._filter_param_g = self.create_parameter(
+            attr=fluid.ParamAttr(
+                initializer=fluid.initializer.NumpyArrayInitializer(
+                    norm_value)),
+            shape=norm_value.shape,
+            dtype=self._dtype,
+            default_initializer=_get_default_param_initializer())
+
+        if self._use_cudnn:
+            self.create_variable(
+                name="kCUDNNFwdAlgoCache",
+                persistable=True,
+                type=core.VarDesc.VarType.RAW)
+            self.create_variable(
+                name="kCUDNNBwdDataAlgoCache",
+                persistable=True,
+                type=core.VarDesc.VarType.RAW)
+            self.create_variable(
+                name="kCUDNNBwdFilterAlgoCache",
+                persistable=True,
+                type=core.VarDesc.VarType.RAW)
+
+        self._bias_param = self.create_parameter(
+            attr=self._bias_attr,
+            shape=[self._num_filters],
+            dtype=self._dtype,
+            is_bias=True)
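+
+    # The forward pass rebuilds the effective filter from (g, v) on every
+    # call: flatten v to [M, C_in/groups * H_f * W_f], L2-normalize each row,
+    # reshape back, and scale row i by g[i]. Roughly, in numpy terms:
+    #
+    #     flat = v.reshape(M, -1)
+    #     w = (g[:, None] * flat
+    #          / np.linalg.norm(flat, axis=1, keepdims=True)).reshape(v.shape)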
+
+    def forward(self, input):
+        matrix = self._helper.create_variable_for_type_inference(self._dtype)
+        tmp = self._helper.create_variable_for_type_inference(self._dtype)
+        new_shape = [
+            self._filter_param_v.shape[0],
+            reduce(lambda x, y: x * y, self._filter_param_v.shape[1:], 1),
+        ]
+
+        self._helper.append_op(
+            type="reshape2",
+            inputs={"X": self._filter_param_v},
+            attrs={"shape": new_shape},
+            outputs={"Out": matrix,
+                     "XShape": tmp})
+
+        m_norm = self._helper.create_variable_for_type_inference(self._dtype)
+        m_normalized = self._helper.create_variable_for_type_inference(
+            self._dtype)
+        self._helper.append_op(
+            type="norm",
+            inputs={"X": matrix},
+            outputs={"Out": m_normalized,
+                     "Norm": m_norm},
+            attrs={"axis": 1,
+                   "epsilon": self._epsilon})
+
+        v_normalized = self._helper.create_variable_for_type_inference(
+            self._dtype)
+        tmp2 = self._helper.create_variable_for_type_inference(self._dtype)
+        self._helper.append_op(
+            type="reshape2",
+            inputs={"X": m_normalized},
+            attrs={"shape": self._filter_param_v.shape},
+            outputs={"Out": v_normalized,
+                     "XShape": tmp2})
+
+        filter_param = self._helper.create_variable_for_type_inference(
+            self._dtype)
+        self._helper.append_op(
+            type="elementwise_mul",
+            inputs={"X": [v_normalized],
+                    "Y": [self._filter_param_g]},
+            outputs={"Out": [filter_param]},
+            attrs={"axis": 0},  # CAUTION: hard-code
+        )
+
+        pre_bias = self._helper.create_variable_for_type_inference(
+            dtype=self._dtype)
+
+        self._helper.append_op(
+            type=self._l_type,
+            inputs={"Input": input,
+                    "Filter": filter_param},
+            outputs={"Output": pre_bias},
+            attrs={
+                "strides": self._stride,
+                "paddings": self._padding,
+                "dilations": self._dilation,
+                "groups": self._groups if self._groups else 1,
+                "use_cudnn": self._use_cudnn,
+                "use_mkldnn": False,
+            })
+
+        if self._bias_param is not None:
+            pre_act = self._helper.create_variable_for_type_inference(
+                dtype=self._dtype)
+            self._helper.append_op(
+                type="elementwise_add",
+                inputs={"X": [pre_bias],
+                        "Y": [self._bias_param]},
+                outputs={"Out": [pre_act]},
+                attrs={"axis": 1})
+        else:
+            pre_act = pre_bias
+
+        # Currently, we don't support inplace in dygraph mode
+        return self._helper.append_activation(pre_act, act=self._act)
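+
+
+# The transposed convolution below is presumably what the WaveNet conditioner
+# uses to upsample mel spectrograms in time. For example, with stride 16,
+# padding 0, dilation 1, and filter height 32, the size formula in the
+# docstring gives
+#     H'_out = (H_in - 1) * 16 - 0 + (32 - 1) + 1 = 16 * H_in + 16,
+# i.e. roughly 16x upsampling per layer.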
+
+
+class Conv2DTranspose(dg.Layer):
+    """
+    **Convolution2D transpose layer**
+    The convolution2D transpose layer calculates the output based on the input,
+    filter, and dilations, strides, paddings. Input(Input) and output(Output)
+    are in NCHW format. Where N is batch size, C is the number of channels,
+    H is the height of the feature, and W is the width of the feature.
+    Parameters(dilations, strides, paddings) are two elements. These two elements
+    represent height and width, respectively. For details of the convolution
+    transpose layer, please refer to the following explanation and the
+    references therein.
+    If bias attribution and activation type are provided, bias is added to
+    the output of the convolution, and the corresponding activation function
+    is applied to the final result.
+    For each input :math:`X`, the equation is:
+    .. math::
+        Out = \sigma ((Vg) \\ast X + b)
+    Where:
+    * :math:`X`: Input value, a tensor with NCHW format.
+    * :math:`V`: Filter direction value, a tensor with MCHW format.
+    * :math:`g`: Filter magnitude value, a tensor with M format.
+    * :math:`\\ast`: Convolution operation.
+    * :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
+    * :math:`\\sigma`: Activation function.
+    * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
+    Example:
+        - Input:
+            Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
+            Filter shape: :math:`(C_{in}, C_{out}, H_f, W_f)`
+        - Output:
+            Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
+        Where
+        .. math::
+            H^\prime_{out} &= (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\\\
+            W^\prime_{out} &= (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1 \\\\
+            H_{out} &\in [ H^\prime_{out}, H^\prime_{out} + strides[0] ) \\\\
+            W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] )
+    Args:
+        name_scope(str): The name of this class.
+        num_filters(int): The number of filters. It is the same as the number
+            of output image channels.
+        output_size(int|tuple|None): The output image size. If output_size is a
+            tuple, it must contain two integers, (image_H, image_W). None if
+            filter_size, padding, and stride are used to calculate output_size.
+            If output_size and filter_size are specified at the same time, they
+            should follow the formula above. Default: None.
+        filter_size(int|tuple|None): The filter size. If filter_size is a tuple,
+            it must contain two integers, (filter_size_H, filter_size_W).
+            Otherwise, the filter will be a square. None if output_size is used
+            to calculate filter_size. Default: None.
+        padding(int|tuple): The padding size. If padding is a tuple, it must
+            contain two integers, (padding_H, padding_W). Otherwise, the
+            padding_H = padding_W = padding. Default: padding = 0.
+        stride(int|tuple): The stride size. If stride is a tuple, it must
+            contain two integers, (stride_H, stride_W). Otherwise, the
+            stride_H = stride_W = stride. Default: stride = 1.
+        dilation(int|tuple): The dilation size. If dilation is a tuple, it must
+            contain two integers, (dilation_H, dilation_W). Otherwise, the
+            dilation_H = dilation_W = dilation. Default: dilation = 1.
+        groups(int): The groups number of the Conv2d transpose layer. Inspired by
+            grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
+            when group=2, the first half of the filters is only connected to the
+            first half of the input channels, while the second half of the
+            filters is only connected to the second half of the input channels.
+            Default: groups = 1.
+        param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
+            of conv2d_transpose. If it is set to None or one attribute of ParamAttr, conv2d_transpose
+            will create ParamAttr as param_attr. If the Initializer of the param_attr
+            is not set, the parameter is initialized with Xavier. Default: None.
+        bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d_transpose.
+            If it is set to False, no bias will be added to the output units.
+            If it is set to None or one attribute of ParamAttr, conv2d_transpose
+            will create ParamAttr as bias_attr. If the Initializer of the bias_attr
+            is not set, the bias is initialized zero. Default: None.
+        use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
+            library is installed. Default: True.
+        act (str): Activation type, if it is set to None, activation is not appended.
+            Default: None.
+    Returns:
+        Variable: The tensor variable storing the convolution transpose result.
+    Raises:
+        ValueError: If the shapes of input, filter_size, stride, padding and
+            groups mismatch.
+    Examples:
+        .. code-block:: python
+          import paddle.fluid as fluid
+          import numpy
+          with fluid.dygraph.guard():
+              data = numpy.random.random((3, 32, 32)).astype('float32')
+              conv2DTranspose = fluid.dygraph.nn.Conv2DTranspose(
+                  'Conv2DTranspose', num_filters=2, filter_size=3)
+              ret = conv2DTranspose(fluid.dygraph.base.to_variable(data))
+    """
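+
+    # filter_size may be omitted: if it is None, _build_once derives it from
+    # output_size by inverting the size formula above. If output_size is
+    # None, the op infers the output shape from filter_size, stride, padding,
+    # and dilation. At least one of the two must be given.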
+
+    def __init__(self,
+                 name_scope,
+                 num_filters,
+                 output_size=None,
+                 filter_size=None,
+                 padding=0,
+                 stride=1,
+                 dilation=1,
+                 groups=None,
+                 param_attr=None,
+                 bias_attr=None,
+                 use_cudnn=True,
+                 epsilon=1e-30,
+                 act=None,
+                 dtype="float32"):
+        super(Conv2DTranspose, self).__init__(name_scope, dtype)
+        assert (param_attr is not False
+                ), "param_attr should not be False in conv2d_transpose."
+        self._param_attr = param_attr
+        self._bias_attr = bias_attr
+        self._act = act
+        self._groups = groups
+        self._num_filters = num_filters
+        self._use_cudnn = use_cudnn
+        self._padding = padding
+        self._stride = stride
+        self._dilation = dilation
+        self._filter_size = filter_size
+        self._output_size = output_size
+        self._op_type = "conv2d_transpose"
+        self._epsilon = epsilon
+
+    def _build_once(self, input):
+        input_channel = input.shape[1]
+        if (input_channel == self._groups and
+                self._num_filters == input_channel and not self._use_cudnn):
+            self._op_type = "depthwise_conv2d_transpose"
+
+        if not isinstance(input, Variable):
+            raise TypeError("Input of conv2d_transpose must be Variable")
+
+        self._padding = utils.convert_to_list(self._padding, 2, "padding")
+        self._stride = utils.convert_to_list(self._stride, 2, "stride")
+        self._dilation = utils.convert_to_list(self._dilation, 2, "dilation")
+
+        if not isinstance(self._use_cudnn, bool):
+            raise ValueError("use_cudnn should be True or False")
+
+        if self._filter_size is None:
+            if self._output_size is None:
+                raise ValueError(
+                    "output_size must be set when filter_size is None")
+            if isinstance(self._output_size, int):
+                self._output_size = [self._output_size, self._output_size]
+
+            h_in = input.shape[2]
+            w_in = input.shape[3]
+
+            filter_size_h = (self._output_size[0] -
+                             (h_in - 1) * self._stride[0] + 2 *
+                             self._padding[0] - 1) // self._dilation[0] + 1
+            filter_size_w = (self._output_size[1] -
+                             (w_in - 1) * self._stride[1] + 2 *
+                             self._padding[1] - 1) // self._dilation[1] + 1
+            self._filter_size = [filter_size_h, filter_size_w]
+        else:
+            self._filter_size = utils.convert_to_list(
+                self._filter_size, 2, "conv2d_transpose.filter_size")
+
+        if self._output_size is None:
+            self._output_size = []
+        elif isinstance(self._output_size, list) or isinstance(
+                self._output_size, int):
+            self._output_size = utils.convert_to_list(self._output_size, 2,
+                                                      "output_size")
+        else:
+            raise ValueError("output_size should be list or int")
+        self._padding = utils.convert_to_list(self._padding, 2, "padding")
+        self._groups = 1 if self._groups is None else self._groups
+        filter_shape = [
+            input_channel,
+            self._num_filters // self._groups,
+        ] + self._filter_size
+
+        # img filter v (direction)
+        self._img_filter_v = self.create_parameter(
+            dtype=input.dtype, shape=filter_shape, attr=self._param_attr)
+
+        # img filter g (magnitude)
+        img_filter_magnitude = _norm(
+            self._img_filter_v.numpy(), dim=0)  # CAUTION: hard-code
+        self._img_filter_g = self.create_parameter(
+            dtype=input.dtype,
+            shape=img_filter_magnitude.shape,
+            attr=fluid.ParamAttr(
+                initializer=NumpyArrayInitializer(img_filter_magnitude)))
+
+        self._img_bias = self.create_parameter(
+            attr=self._bias_attr,
+            shape=[self._num_filters],
+            dtype=self._dtype,
+            is_bias=True)
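+
+    # As in Conv2D.forward, the effective filter is reassembled from (g, v)
+    # on each call; the differences are the conv2d_transpose (or
+    # depthwise_conv2d_transpose) op type and the output_size attribute.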
+
+    def forward(self, input):
+        matrix = self._helper.create_variable_for_type_inference(self._dtype)
+        tmp = self._helper.create_variable_for_type_inference(self._dtype)
+        new_shape = [
+            self._img_filter_v.shape[0],
+            reduce(lambda x, y: x * y, self._img_filter_v.shape[1:], 1),
+        ]
+
+        self._helper.append_op(
+            type="reshape2",
+            inputs={"X": self._img_filter_v},
+            attrs={"shape": new_shape},
+            outputs={"Out": matrix,
+                     "XShape": tmp})
+
+        m_norm = self._helper.create_variable_for_type_inference(self._dtype)
+        m_normalized = self._helper.create_variable_for_type_inference(
+            self._dtype)
+        self._helper.append_op(
+            type="norm",
+            inputs={"X": matrix},
+            outputs={"Out": m_normalized,
+                     "Norm": m_norm},
+            attrs={"axis": 1,
+                   "epsilon": self._epsilon})
+
+        v_normalized = self._helper.create_variable_for_type_inference(
+            self._dtype)
+        tmp2 = self._helper.create_variable_for_type_inference(self._dtype)
+        self._helper.append_op(
+            type="reshape2",
+            inputs={"X": m_normalized},
+            attrs={"shape": self._img_filter_v.shape},
+            outputs={"Out": v_normalized,
+                     "XShape": tmp2})
+
+        img_filter = self._helper.create_variable_for_type_inference(
+            self._dtype)
+        self._helper.append_op(
+            type="elementwise_mul",
+            inputs={"X": [v_normalized],
+                    "Y": [self._img_filter_g]},
+            outputs={"Out": [img_filter]},
+            attrs={"axis": 0},  # CAUTION: hard-code
+        )
+
+        pre_bias = self._helper.create_variable_for_type_inference(
+            dtype=input.dtype)
+        self._helper.append_op(
+            type=self._op_type,
+            inputs={"Input": [input],
+                    "Filter": [img_filter]},
+            outputs={"Output": pre_bias},
+            attrs={
+                "output_size": self._output_size,
+                "strides": self._stride,
+                "paddings": self._padding,
+                "dilations": self._dilation,
+                "groups": self._groups,
+                "use_cudnn": self._use_cudnn,
+            })
+
+        if self._img_bias is not None:
+            pre_act = self._helper.create_variable_for_type_inference(
+                dtype=self._dtype)
+            self._helper.append_op(
+                type="elementwise_add",
+                inputs={"X": [pre_bias],
+                        "Y": [self._img_bias]},
+                outputs={"Out": [pre_act]},
+                attrs={"axis": 1})
+        else:
+            pre_act = pre_bias
+
+        out = self._helper.append_activation(pre_act, act=self._act)
+        return out
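+
+
+# Quick smoke test (a sketch, assuming the module-level imports above):
+#
+#     with fluid.dygraph.guard():
+#         deconv = Conv2DTranspose("deconv", num_filters=1,
+#                                  filter_size=(32, 3), stride=(16, 1),
+#                                  padding=(0, 1))
+#         mel = dg.to_variable(
+#             np.random.randn(1, 1, 20, 80).astype("float32"))
+#         out = deconv(mel)  # time axis: 20 -> (20 - 1) * 16 + 32 = 336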