working integration with parakeet

This commit is contained in:
Kexin Zhao 2019-12-02 14:00:53 -08:00
parent 8c36f4539c
commit b15c313423
13 changed files with 2493 additions and 1 deletions

8
.gitignore vendored
View File

@ -129,4 +129,10 @@ venv.bak/
dmypy.json
# Pyre type checker
.pyre/
.pyre/
# Shell, vim, and output folder
*.sh
*.swp
runs
syn_audios

View File

@ -31,6 +31,9 @@ class DataCargo(object):
def __iter__(self):
return DataIterator(self)
def __call__(self):
return DataIterator(self)
@property
def _auto_collation(self):

View File

@ -0,0 +1 @@
# WaveNet-Paddle

View File

@ -0,0 +1,32 @@
valid_size: 16
train_clip_second: 0.5
sample_rate: 22050
fft_window_shift: 256
fft_window_size: 1024
fft_size: 2048
mel_bands: 80
seed: 1
batch_size: 8
test_every: 2000
save_every: 10000
max_iterations: 2000000
layers: 30
kernel_width: 2
dilation_block: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
residual_channels: 128
skip_channels: 128
loss_type: mix-gaussian-pdf
num_mixtures: 1
log_scale_min: -9.0
conditioner:
filter_sizes: [[32, 3], [32, 3]]
upsample_factors: [16, 16]
learning_rate: 0.001
gradient_max_norm: 100.0
anneal:
every: 200000
rate: 0.5

View File

@ -0,0 +1,191 @@
import math
import os
import random
import librosa
import numpy as np
from paddle import fluid
import utils
from parakeet.datasets import ljspeech
from parakeet.data import dataset
from parakeet.data.sampler import Sampler, BatchSampler, SequentialSampler
from parakeet.data.datacargo import DataCargo
class Dataset(ljspeech.LJSpeech):
    """LJSpeech examples as (audio, mel-spectrogram) pairs.

    Args:
        config: parsed configuration. This class reads ``root``,
            ``sample_rate``, ``fft_window_shift``, ``fft_window_size``,
            ``fft_size``, ``mel_bands``, ``train_clip_second`` and
            ``context_size`` from it.
    """

    def __init__(self, config):
        super(Dataset, self).__init__(config.root)
        self.config = config
        self.fft_window_shift = config.fft_window_shift

        # Number of spectrogram frames in one training clip: the clip itself
        # plus the frames covered by the model's context size.
        frames_per_second = config.sample_rate // self.fft_window_shift
        train_clip_frames = int(math.ceil(
            config.train_clip_second * frames_per_second))
        context_frames = config.context_size // self.fft_window_shift
        self.num_frames = train_clip_frames + context_frames

    def _get_example(self, metadatum):
        """Load one utterance and compute its rescaled mel-spectrogram.

        Args:
            metadatum: a 3-item record whose first item is the wav file name
                without extension (the other two items are ignored here).

        Returns:
            (audio, mel_spectrogram): ``audio`` is the peak-normalized
            waveform and ``mel_spectrogram`` has one row per frame, so that
            ``mel_spectrogram.shape[0] * fft_window_shift == audio.size``.
        """
        fname, _, _ = metadatum
        # NOTE(review): `self.root` is presumably set by the parent class
        # from `config.root` -- confirm against ljspeech.LJSpeech.
        wav_path = self.root.joinpath("wavs", fname + ".wav")

        config = self.config
        sr = config.sample_rate
        fft_window_shift = config.fft_window_shift
        fft_window_size = config.fft_window_size
        fft_size = config.fft_size

        audio, loaded_sr = librosa.load(wav_path, sr=None)
        assert loaded_sr == sr

        # Pad audio to a whole number of frames plus the STFT margin.
        frames = math.ceil(float(audio.size) / fft_window_shift)
        fft_padding = (fft_size - fft_window_shift) // 2
        desired_length = frames * fft_window_shift + fft_padding * 2
        # Split the padding from the actual remainder, so the padded length
        # is exact for any hop size (not only for even ones).
        total_padding = desired_length - audio.size
        pad_left = total_padding // 2
        pad_right = total_padding - pad_left
        audio = np.pad(audio, (pad_left, pad_right), mode='reflect')

        # Normalize audio.
        audio = audio / np.abs(audio).max() * 0.999

        # Compute the magnitude spectrogram.
        # Turn center to False to prevent internal padding.
        spectrogram = librosa.core.stft(
            audio, hop_length=fft_window_shift,
            win_length=fft_window_size, n_fft=fft_size, center=False)
        spectrogram_magnitude = np.abs(spectrogram)

        # Project onto the mel filter bank and put frames on the first axis.
        mel_filter_bank = librosa.filters.mel(sr=sr, n_fft=fft_size,
                                              n_mels=config.mel_bands)
        mel_spectrogram = np.dot(mel_filter_bank, spectrogram_magnitude)
        mel_spectrogram = mel_spectrogram.T

        # Rescale mel_spectrogram: decibels, reference shift, clip to [0, 1].
        min_level, ref_level = 1e-5, 20
        mel_spectrogram = 20 * np.log10(np.maximum(min_level, mel_spectrogram))
        mel_spectrogram = mel_spectrogram - ref_level
        mel_spectrogram = np.clip((mel_spectrogram + 100) / 100, 0, 1)

        # Extract the center of audio that corresponds to mel spectrograms.
        # An explicit end index keeps this correct when fft_padding is 0
        # (a `-0` stop would yield an empty array).
        audio = audio[fft_padding:audio.size - fft_padding]
        assert mel_spectrogram.shape[0] * fft_window_shift == audio.size

        return audio, mel_spectrogram
class Subset(dataset.Dataset):
    """View over selected indices of a dataset, cropping clips for training.

    Args:
        dataset: the underlying dataset; it must expose ``fft_window_shift``
            and ``num_frames`` and yield ``(audio, mel)`` pairs.
        indices: positions in ``dataset`` that belong to this subset.
        valid: when true, examples are returned uncropped.
    """

    def __init__(self, dataset, indices, valid):
        self.dataset = dataset
        self.indices = indices
        self.valid = valid

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(audio, mel, audio_start)`` for the idx-th subset item."""
        hop = self.dataset.fft_window_shift
        clip_frames = self.dataset.num_frames
        audio, mel = self.dataset[self.indices[idx]]

        # Validation examples keep the whole utterance.
        if self.valid:
            return audio, mel, 0

        # Randomly crop context + train_clip_second of audio.
        last_start_frame = int(audio.size) // hop - clip_frames
        assert last_start_frame >= 0, "audio {} is too short".format(idx)
        first_frame = random.randint(0, last_start_frame)

        audio_start = first_frame * hop
        audio_end = (first_frame + clip_frames) * hop
        return audio[audio_start:audio_end], mel, audio_start

    def _batch_examples(self, batch):
        """Collate ``(audio, mel, audio_start)`` samples into numpy arrays."""
        # mels shape [num_frames, mel_bands]; pad them all to the longest one.
        longest = max(mel.shape[0] for _, mel, _ in batch)

        audios = np.array([audio for audio, _, _ in batch], dtype=np.float32)
        mels = np.array(
            [utils.pad_to_size(mel, longest) for _, mel, _ in batch],
            dtype=np.float32)
        audio_starts = np.array(
            [start for _, _, start in batch], dtype=np.int32)

        return audios, mels, audio_starts
class DistributedSampler(Sampler):
    """Sampler that hands each trainer an equally sized share of the indices.

    Args:
        dataset_size: number of examples in the dataset.
        num_trainers: number of trainer processes sharing the dataset.
        rank: index of the current trainer.
        shuffle: whether to shuffle the indices on every iteration.
    """

    def __init__(self, dataset_size, num_trainers, rank, shuffle=True):
        self.dataset_size = dataset_size
        self.num_trainers = num_trainers
        self.rank = rank
        self.shuffle = shuffle

        # Round the per-trainer share up so that every example is covered.
        self.num_samples = int(math.ceil(dataset_size / num_trainers))
        self.total_size = self.num_samples * num_trainers
        assert self.total_size >= self.dataset_size

    def __len__(self):
        return self.num_samples

    def __iter__(self):
        order = list(range(self.dataset_size))
        if self.shuffle:
            random.shuffle(order)

        # Append extra samples to make it evenly distributed on all trainers.
        shortfall = self.total_size - self.dataset_size
        order += order[:shortfall]
        assert len(order) == self.total_size

        # Subset samples for each trainer: every num_trainers-th index,
        # starting at this trainer's rank.
        share = order[self.rank::self.num_trainers]
        assert len(share) == self.num_samples

        return iter(share)
class LJSpeech:
    """Build the train and valid data loaders for LJSpeech.

    Args:
        config: parsed configuration; ``use_gpu``, ``valid_size`` and
            ``batch_size`` are read here, the rest is passed to ``Dataset``.
        nranks: number of trainer processes.
        rank: index of the current trainer process.

    Attributes:
        trainloader: endless generator of training batches.
        validloader: PyReader over the validation examples.
    """

    def __init__(self, config, nranks, rank):
        if config.use_gpu:
            place = fluid.CUDAPlace(rank)
        else:
            place = fluid.CPUPlace()

        # Whole LJSpeech dataset.
        full_set = Dataset(config)

        # The first valid_size examples are held out for validation, the
        # remaining ones are shuffled and used for training.
        all_indices = list(range(len(full_set)))
        valid_indices = all_indices[:config.valid_size]
        train_indices = all_indices[config.valid_size:]
        random.shuffle(train_indices)

        # Train dataset: config.batch_size is the total batch size, so each
        # trainer gets an equal slice of it.
        trainset = Subset(full_set, train_indices, valid=False)
        sampler = DistributedSampler(len(trainset), nranks, rank)
        total_bs = config.batch_size
        assert total_bs % nranks == 0
        per_rank_bs = total_bs // nranks
        train_sampler = BatchSampler(sampler, per_rank_bs, drop_last=True)
        trainloader = DataCargo(trainset, batch_sampler=train_sampler)

        trainreader = fluid.io.PyReader(capacity=50, return_list=True)
        trainreader.decorate_batch_generator(trainloader, place)

        def _endless_batches():
            # Restart the reader whenever it is exhausted.
            while True:
                for data in trainreader():
                    yield data

        self.trainloader = _endless_batches()

        # Valid dataset.
        # Currently only support batch_size = 1 for valid loader.
        validset = Subset(full_set, valid_indices, valid=True)
        validloader = DataCargo(validset, batch_size=1, shuffle=False)

        validreader = fluid.io.PyReader(capacity=20, return_list=True)
        validreader.decorate_batch_generator(validloader, place)
        self.validloader = validreader

View File

@ -0,0 +1,249 @@
import paddle
from paddle import fluid
import paddle.fluid.dygraph as dg
import numpy as np
import weight_norm
def Embedding(name_scope,
              num_embeddings,
              embed_dim,
              padding_idx=None,
              std=0.1,
              dtype="float32"):
    """Create a ``dg.Embedding`` whose table is initialized from a normal
    distribution with scale ``std``.

    Args:
        name_scope: name scope passed to the layer.
        num_embeddings: number of rows of the embedding table.
        embed_dim: size of each embedding vector.
        padding_idx: forwarded to ``dg.Embedding``.
        std: scale of the normal initializer.
        dtype: data type of the layer.

    Returns:
        The constructed ``dg.Embedding`` layer.
    """
    table_init = fluid.initializer.Normal(scale=std)
    table_attr = fluid.ParamAttr(initializer=table_init)
    table_shape = (num_embeddings, embed_dim)
    return dg.Embedding(
        name_scope,
        table_shape,
        padding_idx=padding_idx,
        param_attr=table_attr,
        dtype=dtype)
def FC(name_scope,
       in_features,
       size,
       num_flatten_dims=1,
       relu=False,
       dropout=0.0,
       act=None,
       dtype="float32"):
    """
    Create a ``weight_norm.FC`` layer with dropout-aware initialization.

    Each weight is initialized as
    normal(0, std=np.sqrt((1-dropout) / in_features)); the std is multiplied
    by sqrt(2) when ``relu`` is true. The bias is initialized to zero.

    Args:
        name_scope: name scope passed to the layer.
        in_features: an int, or an iterable of ints (one per input).
        size: forwarded to ``weight_norm.FC``.
        num_flatten_dims: forwarded to ``weight_norm.FC``.
        relu: whether to apply the sqrt(2) gain to the std.
        dropout: dropout rate used to scale the initial std.
        act: forwarded to ``weight_norm.FC``.
        dtype: data type of the layer.

    Returns:
        The constructed layer.
    """
    fan_ins = [in_features] if isinstance(in_features, int) else in_features

    # One weight attribute per input, each with its own std.
    weight_attrs = []
    for fan_in in fan_ins:
        std = np.sqrt((1.0 - dropout) / fan_in)
        if relu:
            std = std * np.sqrt(2.0)
        initializer = fluid.initializer.NormalInitializer(scale=std)
        weight_attrs.append(fluid.ParamAttr(initializer=initializer))

    zero_bias = fluid.ParamAttr(
        initializer=fluid.initializer.ConstantInitializer(0.0))

    return weight_norm.FC(name_scope,
                          size,
                          num_flatten_dims=num_flatten_dims,
                          param_attr=weight_attrs,
                          bias_attr=zero_bias,
                          act=act,
                          dtype=dtype)
def Conv1D(name_scope,
           in_channels,
           num_filters,
           filter_size=2,
           dilation=1,
           groups=None,
           causal=False,
           std_mul=1.0,
           dropout=0.0,
           use_cudnn=True,
           act=None,
           dtype="float32"):
    """
    Create a ``weight_norm.Conv1D`` layer with dropout-aware initialization.

    The weight is initialized as
    normal(0, std=np.sqrt(std_mul * (1-dropout) / (filter_size * in_channels)))
    and the bias is initialized to zero.

    Args:
        name_scope: name scope passed to the layer.
        in_channels: number of input channels (used for the init std only).
        num_filters, filter_size, dilation, groups, causal, use_cudnn, act:
            forwarded to ``weight_norm.Conv1D``.
        std_mul: multiplier applied to the variance of the initializer.
        dropout: dropout rate used to scale the initial std.
        dtype: data type of the layer.

    Returns:
        The constructed layer.
    """
    fan_in = filter_size * in_channels
    std = np.sqrt((std_mul * (1.0 - dropout)) / fan_in)

    normal_weight = fluid.ParamAttr(
        initializer=fluid.initializer.NormalInitializer(loc=0.0, scale=std))
    zero_bias = fluid.ParamAttr(
        initializer=fluid.initializer.ConstantInitializer(0.0))

    return weight_norm.Conv1D(
        name_scope,
        num_filters,
        filter_size,
        dilation,
        groups=groups,
        causal=causal,
        param_attr=normal_weight,
        bias_attr=zero_bias,
        use_cudnn=use_cudnn,
        act=act,
        dtype=dtype)
class Conv1D_GU(dg.Layer):
    """1-D convolution followed by a gated unit, with optional conditioning,
    skip accumulation and residual connection.

    Args:
        name_scope: name scope of the layer.
        conditioner_dim: number of channels of the conditioner input.
        in_channels: number of input channels.
        num_filters: number of output channels (the inner convolution
            produces twice as many, one half for content and one for gate).
        filter_size: filter size of the convolution.
        dilation: dilation of the convolution.
        causal: forwarded to the convolution.
        residual: whether to add the input back to the output; requires
            ``in_channels == num_filters``.
        dtype: data type of the layer.
    """

    def __init__(self,
                 name_scope,
                 conditioner_dim,
                 in_channels,
                 num_filters,
                 filter_size,
                 dilation,
                 causal=False,
                 residual=True,
                 dtype="float32"):
        super(Conv1D_GU, self).__init__(name_scope, dtype=dtype)

        self.conditioner_dim = conditioner_dim
        self.in_channels = in_channels
        self.num_filters = num_filters
        self.filter_size = filter_size
        self.dilation = dilation
        self.causal = causal
        self.residual = residual

        if residual:
            assert in_channels == num_filters, (
                "this block uses residual connection, "
                "the input_channels should equal num_filters")

        # Convolution over the input; outputs content and gate halves.
        self.conv = Conv1D(
            self.full_name(),
            in_channels,
            2 * num_filters,
            filter_size,
            dilation,
            causal=causal,
            dtype=dtype)

        # Pointwise projection of the conditioner to the same 2x channels.
        self.fc = Conv1D(
            self.full_name(),
            conditioner_dim,
            2 * num_filters,
            filter_size=1,
            dilation=1,
            causal=False,
            dtype=dtype)

    def _gate_and_merge(self, x, residual, skip, conditioner):
        """Apply conditioning, the gated unit, and skip/residual merging.

        Shared by `forward` and `add_input`; ``x`` is the convolution output
        and ``residual`` is the layer input.
        """
        if conditioner is not None:
            cond_bias = self.fc(conditioner)
            x += cond_bias

        content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)

        # Gated Unit.
        x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
                                         fluid.layers.tanh(content))

        # Scaling by sqrt(0.5) after each addition.
        if skip is None:
            skip = x
        else:
            skip = fluid.layers.scale(skip + x, np.sqrt(0.5))

        if self.residual:
            x = fluid.layers.scale(residual + x, np.sqrt(0.5))

        return x, skip

    def forward(self, x, skip=None, conditioner=None):
        """
        Args:
            x (Variable): Shape(B, C_in, 1, T), the input of this layer,
                where B means batch_size, C_in means the input channels
                T means input time steps.
            skip (Variable, optional): accumulated skip output of the
                previous layers, or None for the first layer.
            conditioner (Variable): Shape(B, C_con, 1, T), expanded mel
                conditioner, where C_con is conditioner hidden dim which
                equals the num of mel bands. Note that when using residual
                connection, this layer does not change the number of
                channels, so out channels equals input channels.

        Returns:
            x (Variable): Shape(B, C_out, 1, T), the output of this layer,
                where C_out means the output channels.
            skip (Variable): the updated skip accumulation.
        """
        return self._gate_and_merge(self.conv(x), x, skip, conditioner)

    def add_input(self, x, skip=None, conditioner=None):
        """
        Step-wise counterpart of `forward`, using ``self.conv.add_input``.

        Inputs:
            x: shape(B, num_filters, 1, time_steps)
            skip: accumulated skip output of the previous layers, or None.
            conditioner: shape(B, conditioner_dim, 1, time_steps)

        Outputs:
            out: shape(B, num_filters, 1, time_steps), where time_steps = 1
            skip: the updated skip accumulation.
        """
        # add step input and produce step output
        step_out = self.conv.add_input(x)
        return self._gate_and_merge(step_out, x, skip, conditioner)
def Conv2DTranspose(name_scope,
                    num_filters,
                    filter_size,
                    padding=0,
                    stride=1,
                    dilation=1,
                    use_cudnn=True,
                    act=None,
                    dtype="float32"):
    """Create a ``weight_norm.Conv2DTranspose`` layer whose weight starts as
    the constant ``1 / (filter_size[0] * filter_size[1])``.

    Args:
        name_scope: name scope passed to the layer.
        num_filters: forwarded to the layer.
        filter_size: two-element sequence; also determines the initial value.
        padding, stride, dilation, use_cudnn, act: forwarded to the layer.
        dtype: data type of the layer.

    Returns:
        The constructed layer.
    """
    filter_area = filter_size[0] * filter_size[1]
    uniform_weight = fluid.ParamAttr(
        initializer=fluid.initializer.ConstantInitializer(1.0 / filter_area))

    return weight_norm.Conv2DTranspose(
        name_scope,
        num_filters,
        filter_size=filter_size,
        padding=padding,
        stride=stride,
        dilation=dilation,
        param_attr=uniform_weight,
        use_cudnn=use_cudnn,
        act=act,
        dtype=dtype)

View File

@ -0,0 +1,112 @@
"""
Utility module for restarting training when using SLURM.
"""
import subprocess
import os
import sys
import shlex
import re
import time
def job_info():
    """Query `scontrol show job` for the job named by $SLURM_JOB_ID.

    Returns:
        dict mapping parameter names (e.g. "UserId", "RunTime", etc) to
        their values, both as strings.
    """
    job_id = int(os.environ["SLURM_JOB_ID"])
    raw_output = subprocess.check_output(
        ["scontrol", "show", "job", str(job_id)])
    report = raw_output.decode("utf-8")

    # Every `Name=value` token of the report becomes one dict entry.
    pairs = re.findall("([A-Za-z/]*)=([^ \t\n]*)", report)
    return dict(pairs)
def parse_hours(text):
    """Convert an `HH` or `DD-HH` string into a number of hours.

    Raises:
        ValueError: if there are more than two dash-separated fields, or a
            field is not an integer.
    """
    fields = text.split("-")
    if len(fields) > 2:
        raise ValueError("Unexpected hour format (expected HH or "
                         "DD-HH, but got {}).".format(text))

    if len(fields) == 2:
        days, hours = fields
        return 24 * int(days) + int(hours)

    return int(fields[0])
def parse_time(text):
    """Convert slurm time to an integer number of seconds.

    Expects time to be of the form:
    "hours:minutes:seconds" or "day-hours:minutes:seconds".

    Raises:
        ValueError: if `text` does not have exactly three colon-separated
            fields or one of them cannot be parsed; the message always
            contains the offending `text`.
    """
    try:
        # Unpacking sits inside the `try` so that a wrong number of fields
        # is reported with the same descriptive message as a bad number.
        hours, minutes, seconds = text.split(":")
        return parse_hours(hours) * 3600 + int(minutes) * 60 + int(seconds)
    except ValueError as e:
        raise ValueError("Error parsing time {}. Got error {}.".format(
            text, str(e))) from e
def restart_command():
    """Using the environment and SLURM command, create a command that, when
    run, will enqueue a repeat of the current job using `sbatch`.

    Return the command as a list of strings, suitable for passing to
    `subprocess.check_call` or similar functions.

    Returns:
        resume_command: list<str>, command to run to restart job.
        end_time: float or None; the time (seconds since the epoch) at
            which the job will end, or None if the job has unlimited runtime.
    """
    # Poll SLURM until `RunTime` can be parsed, and keep the snapshot that
    # passed the check, so the value parsed below is the validated one.
    info = job_info()
    while info["RunTime"] == "INVALID":
        time.sleep(1)
        info = job_info()

    num_tasks = int(os.environ["SLURM_NTASKS"])
    gres, partition = info.get("Gres"), info.get("Partition")
    stderr, stdout = info.get("StdErr"), info.get("StdOut")
    job_name = info.get("JobName")

    command = ["sbatch", "--job-name={}".format(job_name),
               "--ntasks={}".format(num_tasks)]
    if partition:
        command.extend(["--partition", partition])
    if gres and gres != "(null)":
        command.extend(["--gres", gres])
        # The GPU count is only logged, so an unexpected gres format must
        # not abort building the restart command.
        try:
            num_gpu = int(gres.split(':')[-1])
        except ValueError:
            print("could not parse number of gpu from gres {}".format(gres))
        else:
            print("number of gpu assigned by slurm is {}".format(num_gpu))
    if stderr:
        command.extend(["--error", stderr])
    if stdout:
        command.extend(["--output", stdout])

    # Re-run this very script through paddle's distributed launcher.
    python = subprocess.check_output(
        ["/usr/bin/which", "python3"]).decode("utf-8").strip()
    dist_setting = ['-m', 'paddle.distributed.launch']
    wrap_cmd = ["srun", python, '-u'] + dist_setting + sys.argv

    command.append(
        "--wrap={}".format(" ".join(shlex.quote(arg) for arg in wrap_cmd)))

    time_limit_string = info["TimeLimit"]
    if time_limit_string.lower() == "unlimited":
        print("UNLIMITED detected: restart OFF, infinite learning ON.",
              flush=True)
        return command, None

    # End time = now + remaining time (limit minus what already elapsed).
    time_limit = parse_time(time_limit_string)
    runtime = parse_time(info["RunTime"])
    end_time = time.time() + time_limit - runtime
    return command, end_time

View File

@ -0,0 +1,85 @@
import os
import random
from pprint import pprint
import jsonargparse
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
import utils
from wavenet import WaveNet
def add_options_to_parser(parser):
    """Register the synthesis command-line options on `parser`.

    Args:
        parser: an argument parser exposing ``add_argument``.
    """
    options = (
        ('--model', dict(type=str, default='wavenet',
                         help="general name of the model")),
        ('--name', dict(type=str,
                        help="specific name of the training model")),
        ('--root', dict(type=str,
                        help="root path of the LJSpeech dataset")),
        ('--use_gpu', dict(type=bool, default=True,
                           help="option to use gpu training")),
        ('--iteration', dict(type=int, default=None,
                             help=("which iteration of checkpoint to load, "
                                   "default to load the latest checkpoint"))),
        ('--checkpoint', dict(type=str, default=None,
                              help="path of the checkpoint to load")),
        ('--output', dict(type=str, default="./syn_audios",
                          help="path to write synthesized audio files")),
        ('--sample', dict(type=int,
                          help="which of the valid samples to synthesize audio")),
    )
    for flag, settings in options:
        parser.add_argument(flag, **settings)
def synthesize(config):
    """Build the WaveNet model in inference mode and run `model.infer`.

    Args:
        config: parsed configuration namespace; ``model``, ``name``,
            ``use_gpu``, ``seed``, ``checkpoint`` and ``iteration`` are
            read here, the rest is consumed by ``WaveNet``.
    """
    pprint(jsonargparse.namespace_to_dict(config))

    # Checkpoints live under runs/<model>/<name>/checkpoint.
    checkpoint_dir = os.path.join(
        os.path.join("runs", config.model, config.name), "checkpoint")

    # Configurate device.
    if config.use_gpu:
        place = fluid.CUDAPlace(0)
    else:
        place = fluid.CPUPlace()

    with dg.guard(place):
        # Fix random seed.
        seed = config.seed
        random.seed(seed)
        np.random.seed(seed)
        fluid.default_startup_program().random_seed = seed
        fluid.default_main_program().random_seed = seed
        print("Random Seed: ", seed)

        # Build model.
        model = WaveNet(config, checkpoint_dir)
        model.build(training=False)

        # Obtain the current iteration: an explicit checkpoint path wins
        # (its name ends in "-<iteration>"), then --iteration, then the
        # latest recorded checkpoint.
        if config.checkpoint is not None:
            iteration = int(config.checkpoint.split('/')[-1].split('-')[-1])
        elif config.iteration is not None:
            iteration = config.iteration
        else:
            iteration = utils.load_latest_checkpoint(checkpoint_dir)

        # Run model inference.
        model.infer(iteration)
if __name__ == "__main__":
    # Command-line parser for the synthesis script.
    parser = jsonargparse.ArgumentParser(
        formatter_class='default_argparse',
        description="Synthesize audio using WaveNet model")

    # Script options first, then the model/config options.
    add_options_to_parser(parser)
    utils.add_config_options_to_parser(parser)

    # Arguments come from both the command line and the yaml config file.
    # When the same field is set more than once, the later value overrides
    # the earlier one.
    config = parser.parse_args()

    synthesize(config)

View File

@ -0,0 +1,139 @@
import os
import random
import subprocess
import time
from pprint import pprint
import jsonargparse
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
from tensorboardX import SummaryWriter
import slurm
import utils
from wavenet import WaveNet
MAXIMUM_SAVE_TIME = 10 * 60
def add_options_to_parser(parser):
    """Register the training command-line options on `parser`.

    Args:
        parser: an argument parser exposing ``add_argument``.
    """
    options = (
        ('--model', dict(type=str, default='wavenet',
                         help="general name of the model")),
        ('--name', dict(type=str,
                        help="specific name of the training model")),
        ('--root', dict(type=str,
                        help="root path of the LJSpeech dataset")),
        ('--parallel', dict(type=bool, default=True,
                            help="option to use data parallel training")),
        ('--use_gpu', dict(type=bool, default=True,
                           help="option to use gpu training")),
        ('--iteration', dict(type=int, default=None,
                             help=("which iteration of checkpoint to load, "
                                   "default to load the latest checkpoint"))),
        ('--checkpoint', dict(type=str, default=None,
                              help="path of the checkpoint to load")),
        ('--slurm', dict(type=bool, default=False,
                         help="whether you are using slurm to submit training jobs")),
    )
    for flag, settings in options:
        parser.add_argument(flag, **settings)
def train(config):
    """Run the WaveNet training loop described by `config`.

    Trains until `config.max_iterations`, validating every
    `config.test_every` steps and checkpointing every `config.save_every`
    steps. When `config.slurm` is set, the job saves and resubmits itself
    shortly before its time limit.

    Args:
        config: parsed configuration namespace (command line + yaml).
    """
    use_gpu = config.use_gpu
    # Data parallelism is only used together with GPU training.
    parallel = config.parallel if use_gpu else False

    # Get the rank of the current training process.
    rank = dg.parallel.Env().local_rank if parallel else 0
    nranks = dg.parallel.Env().nranks if parallel else 1

    if rank == 0:
        # Print the whole config setting.
        pprint(jsonargparse.namespace_to_dict(config))

    # Make checkpoint directory.
    run_dir = os.path.join("runs", config.model, config.name)
    checkpoint_dir = os.path.join(run_dir, "checkpoint")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Create tensorboard logger (rank 0 only; other ranks hold None).
    tb = SummaryWriter(os.path.join(run_dir, "logs")) \
        if rank == 0 else None

    # Configurate device
    place = fluid.CUDAPlace(rank) if use_gpu else fluid.CPUPlace()

    with dg.guard(place):
        # Fix random seed.
        seed = config.seed
        random.seed(seed)
        np.random.seed(seed)
        fluid.default_startup_program().random_seed = seed
        fluid.default_main_program().random_seed = seed
        print("Random Seed: ", seed)

        # Build model.
        model = WaveNet(config, checkpoint_dir, parallel, rank, nranks, tb)
        model.build()

        # Obtain the current iteration: an explicit checkpoint path wins
        # (its name ends in "-<iteration>"), then --iteration, then the
        # latest recorded checkpoint.
        if config.checkpoint is None:
            if config.iteration is None:
                iteration = utils.load_latest_checkpoint(checkpoint_dir, rank)
            else:
                iteration = config.iteration
        else:
            iteration = int(config.checkpoint.split('/')[-1].split('-')[-1])

        # Get restart command if using slurm.
        if config.slurm:
            resume_command, death_time = slurm.restart_command()
            if rank == 0:
                print("Restart command:", " ".join(resume_command))

        done = False
        while iteration < config.max_iterations:
            # Run one single training step.
            model.train_step(iteration)

            iteration += 1

            if iteration % config.test_every == 0:
                # Run validation step.
                model.valid_step(iteration)

            # Check whether reaching the time limit: stop once less than
            # MAXIMUM_SAVE_TIME seconds remain before the job ends.
            if config.slurm:
                done = (death_time is not None and death_time - time.time() <
                        MAXIMUM_SAVE_TIME)

            # NOTE(review): only rank 0 leaves the loop here; other ranks
            # keep iterating -- confirm this is the intended shutdown path.
            if rank == 0 and done:
                print("Saving progress before exiting.")
                model.save(iteration)

                print("Running restart command:", " ".join(resume_command))
                # Submit restart command.
                subprocess.check_call(resume_command)
                break

            if rank == 0 and iteration % config.save_every == 0:
                # Save parameters.
                model.save(iteration)

    # Close TensorBoard.
    if rank == 0:
        tb.close()
if __name__ == "__main__":
    # Command-line parser for the training script.
    parser = jsonargparse.ArgumentParser(
        formatter_class='default_argparse',
        description="Train WaveNet model")

    # Script options first, then the model/config options.
    add_options_to_parser(parser)
    utils.add_config_options_to_parser(parser)

    # Arguments come from both the command line and the yaml config file.
    # When the same field is set more than once, the later value overrides
    # the earlier one.
    config = parser.parse_args()

    train(config)

View File

@ -0,0 +1,143 @@
import itertools
import os
import time
import jsonargparse
import numpy as np
import paddle.fluid.dygraph as dg
def add_config_options_to_parser(parser):
    """Register the data, model and optimization options on `parser`.

    These are the fields normally provided by the yaml config file, which
    is itself passed through the `--config` option registered last.

    Args:
        parser: a jsonargparse argument parser.
    """
    # Data options.
    parser.add_argument('--valid_size', type=int,
                        help="size of the valid dataset")
    parser.add_argument('--train_clip_second', type=float,
                        help="the length of audio clip for training")
    parser.add_argument('--sample_rate', type=int,
                        help="sampling rate of audio data file")
    parser.add_argument('--fft_window_shift', type=int,
                        help="the shift of fft window for each frame")
    parser.add_argument('--fft_window_size', type=int,
                        help="the size of fft window for each frame")
    parser.add_argument('--fft_size', type=int,
                        help="the size of fft filter on each frame")
    parser.add_argument('--mel_bands', type=int,
                        help="the number of mel bands when calculating mel spectrograms")

    # Training schedule options.
    parser.add_argument('--seed', type=int,
                        help="seed of random initialization for the model")
    parser.add_argument('--batch_size', type=int,
                        help="batch size for training")
    parser.add_argument('--test_every', type=int,
                        help="test interval during training")
    parser.add_argument('--save_every', type=int,
                        help="checkpointing interval during training")
    parser.add_argument('--max_iterations', type=int,
                        help="maximum training iterations")

    # Model options.
    parser.add_argument('--layers', type=int,
                        help="number of dilated convolution layers")
    parser.add_argument('--kernel_width', type=int,
                        help="dilated convolution kernel width")
    parser.add_argument('--dilation_block', type=list,
                        help="dilations that are cycled over the convolution layers")
    parser.add_argument('--residual_channels', type=int)
    parser.add_argument('--skip_channels', type=int)
    parser.add_argument('--loss_type', type=str,
                        help="mix-gaussian-pdf or softmax")
    parser.add_argument('--num_channels', type=int, default=None,
                        help="number of channels for softmax output")
    parser.add_argument('--num_mixtures', type=int, default=None,
                        help="number of gaussian mixtures for gaussian output")
    parser.add_argument('--log_scale_min', type=float, default=None,
                        help="minimum clip value of log variance of gaussian output")
    parser.add_argument('--conditioner.filter_sizes', type=list,
                        help="conv2d transpose op filter sizes for building conditioner")
    parser.add_argument('--conditioner.upsample_factors', type=list,
                        help="list of upsample factors for building conditioner")

    # Optimization options.
    parser.add_argument('--learning_rate', type=float)
    parser.add_argument('--gradient_max_norm', type=float)
    parser.add_argument('--anneal.every', type=int,
                        help="step interval for annealing learning rate")
    parser.add_argument('--anneal.rate', type=float)

    # Path of a yaml file providing any of the options above.
    parser.add_argument('--config', action=jsonargparse.ActionConfigFile)
def pad_to_size(array, length, pad_with=0.0):
    """
    Extend `array` at the end of its first axis until it has `length` rows.

    Args:
        array: numpy array with at least one dimension.
        length: target size of the first axis; must not be smaller than
            the current size.
        pad_with: constant used to fill the added entries.

    Returns:
        The padded array; all other axes are left untouched.
    """
    missing = length - array.shape[0]
    assert missing >= 0, "Padding required was less than zero"

    # Pad only after the existing data on axis 0, nothing on other axes.
    untouched_axes = [(0, 0)] * (len(array.shape) - 1)
    pad_widths = [(0, missing)] + untouched_axes
    return np.pad(array, pad_widths, mode='constant',
                  constant_values=pad_with)
def calculate_context_size(config):
    """Compute the context size and store it as `config.context_size`.

    The dilations in `config.dilation_block` are repeated cyclically until
    `config.layers` values are collected; the context size is their sum
    plus one.
    """
    repeated = itertools.cycle(config.dilation_block)
    per_layer = list(itertools.islice(repeated, config.layers))
    config.context_size = 1 + sum(per_layer)
    print("Context size is", config.context_size)
def load_latest_checkpoint(checkpoint_dir, rank=0):
    """Read the latest iteration from the checkpoint index file.

    The index file is `<checkpoint_dir>/checkpoint`. Rank 0 creates it
    (pointing at step 0) when it does not exist; other ranks wait for it.

    Args:
        checkpoint_dir: directory holding the checkpoints.
        rank: rank of the calling process.

    Returns:
        int, the iteration recorded in the index file.
    """
    index_file = os.path.join(checkpoint_dir, "checkpoint")

    # Create checkpoint index file if not exist.
    if rank == 0 and not os.path.isfile(index_file):
        with open(index_file, "w") as handle:
            handle.write("model_checkpoint_path: step-0")

    # Make sure that other process waits until checkpoint file is created
    # by process 0.
    while not os.path.isfile(index_file):
        time.sleep(1)

    # Fetch the latest checkpoint index: the last word of the first line
    # is "step-<iteration>".
    with open(index_file, "r") as handle:
        first_line = handle.readline()
    step_name = first_line.split()[-1]
    return int(step_name.split("-")[-1])
def save_latest_checkpoint(checkpoint_dir, iteration):
    """Record `iteration` as the latest checkpoint in the index file
    `<checkpoint_dir>/checkpoint`, overwriting its previous content."""
    index_file = os.path.join(checkpoint_dir, "checkpoint")
    entry = "model_checkpoint_path: step-{}".format(iteration)
    # Update the latest checkpoint index.
    with open(index_file, "w") as handle:
        handle.write(entry)
def load_parameters(checkpoint_dir, rank, model, optimizer=None,
                    iteration=None, file_path=None):
    """Restore model (and optionally optimizer) state from a checkpoint.

    Args:
        checkpoint_dir: directory holding the checkpoints.
        rank: rank of the calling process (used for lookup and logging).
        model: object exposing ``set_dict``.
        optimizer: optional object exposing ``set_dict``.
        iteration: iteration to load; when None the latest one is looked up.
            Iteration 0 means there is nothing to load.
        file_path: explicit checkpoint path; takes precedence over
            `iteration`.
    """
    if file_path is None:
        if iteration is None:
            iteration = load_latest_checkpoint(checkpoint_dir, rank)
        # Nothing has been saved yet.
        if iteration == 0:
            return
        file_path = "{}/step-{}".format(checkpoint_dir, iteration)

    model_state, optimizer_state = dg.load_dygraph(file_path)

    model.set_dict(model_state)
    print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path))

    if not (optimizer and optimizer_state):
        return
    optimizer.set_dict(optimizer_state)
    print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
        rank, file_path))
def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
    """Save model (and optionally optimizer) state for `iteration`.

    Both state dicts are passed to ``dg.save_dygraph`` with the same path
    `<checkpoint_dir>/step-<iteration>`.

    Args:
        checkpoint_dir: directory holding the checkpoints.
        iteration: iteration the state belongs to.
        model: object exposing ``state_dict``.
        optimizer: optional object exposing ``state_dict``.
    """
    target = "{}/step-{}".format(checkpoint_dir, iteration)

    dg.save_dygraph(model.state_dict(), target)
    print("[checkpoint] Saved model to {}".format(target))

    if not optimizer:
        return
    dg.save_dygraph(optimizer.state_dict(), target)
    print("[checkpoint] Saved optimzier state to {}".format(target))

View File

@ -0,0 +1,188 @@
import itertools
import os
import time
import librosa
import numpy as np
from paddle import fluid
import paddle.fluid.dygraph as dg
import utils
from data import LJSpeech
from wavenet_modules import WaveNetModule, debug
class WaveNet():
    """Training / inference wrapper around `WaveNetModule`.

    Holds the data loaders, the network, the optimizer and the logger, and
    exposes one method per phase: `build`, `train_step`, `valid_step`,
    `save` and `infer`.

    Args:
        config: parsed configuration namespace; `context_size` is added to
            it by the constructor.
        checkpoint_dir: directory holding the checkpoints.
        parallel: whether to use data parallel training.
        rank: rank of the current process.
        nranks: number of training processes.
        tb_logger: tensorboard writer used on rank 0 (may be None elsewhere).
    """

    def __init__(self, config, checkpoint_dir, parallel=False, rank=0,
                 nranks=1, tb_logger=None):
        # Process config to calculate the context size: the dilation block
        # is cycled over all layers, and the context is their sum plus one.
        dilations = list(
            itertools.islice(
                itertools.cycle(config.dilation_block), config.layers))
        config.context_size = sum(dilations) + 1
        self.config = config
        self.checkpoint_dir = checkpoint_dir
        self.parallel = parallel
        self.rank = rank
        self.nranks = nranks
        self.tb_logger = tb_logger

    def build(self, training=True):
        """Create data loaders and the network, then load parameters.

        Args:
            training: when true, also create the learning rate scheduler,
                optimizer and gradient clipper, and wrap the network for
                data parallelism if enabled.
        """
        config = self.config
        dataset = LJSpeech(config, self.nranks, self.rank)
        self.trainloader = dataset.trainloader
        self.validloader = dataset.validloader

        # if self.rank == 0:
        #     for i, (audios, mels, ids) in enumerate(self.validloader()):
        #         print("audios {}, mels {}, ids {}".format(audios.dtype, mels.dtype, ids.dtype))
        #         print("{}: rank {}, audios {}, mels {}, indices {} / {}".format(
        #             i, self.rank, audios.shape, mels.shape, ids.shape,
        #             ids.numpy()))
        #
        #     for i, (audios, mels, ids) in enumerate(self.trainloader):
        #         print("{}: rank {}, audios {}, mels {}, indices {} / {}".format(
        #             i, self.rank, audios.shape, mels.shape, ids.shape,
        #             ids.numpy()))

        wavenet = WaveNetModule("wavenet", config, self.rank)

        # Dry run once to create and initialize all necessary parameters.
        audio = dg.to_variable(np.random.randn(1, 20000).astype(np.float32))
        mel = dg.to_variable(
            np.random.randn(1, 100, self.config.mel_bands).astype(np.float32))
        audio_start = dg.to_variable(np.array([0], dtype=np.int32))
        wavenet(audio, mel, audio_start)

        if training:
            # Create Learning rate scheduler.
            lr_scheduler = dg.ExponentialDecay(
                learning_rate = config.learning_rate,
                decay_steps = config.anneal.every,
                decay_rate = config.anneal.rate,
                staircase=True)

            optimizer = fluid.optimizer.AdamOptimizer(
                learning_rate=lr_scheduler)

            clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
                config.gradient_max_norm)

            # Load parameters.
            utils.load_parameters(self.checkpoint_dir, self.rank,
                                  wavenet, optimizer,
                                  iteration=config.iteration,
                                  file_path=config.checkpoint)
            print("Rank {}: checkpoint loaded.".format(self.rank))

            # Data parallelism (done after loading, on the restored layer).
            if self.parallel:
                strategy = dg.parallel.prepare_context()
                wavenet = dg.parallel.DataParallel(wavenet, strategy)

            self.wavenet = wavenet
            self.optimizer = optimizer
            self.clipper = clipper

        else:
            # Load parameters.
            utils.load_parameters(self.checkpoint_dir, self.rank, wavenet,
                                  iteration=config.iteration,
                                  file_path=config.checkpoint)
            print("Rank {}: checkpoint loaded.".format(self.rank))

            self.wavenet = wavenet

    def train_step(self, iteration):
        """Run one optimization step on the next training batch.

        Args:
            iteration: current iteration, used for logging only.
        """
        self.wavenet.train()

        start_time = time.time()
        audios, mels, audio_starts = next(self.trainloader)
        load_time = time.time()

        loss, _ = self.wavenet(audios, mels, audio_starts)

        if self.parallel:
            # loss = loss / num_trainers
            loss = self.wavenet.scale_loss(loss)
            loss.backward()
            self.wavenet.apply_collective_grads()
        else:
            loss.backward()

        # Read the learning rate used for this step, for logging.
        if isinstance(self.optimizer._learning_rate,
                      fluid.optimizer.LearningRateDecay):
            current_lr = self.optimizer._learning_rate.step().numpy()
        else:
            current_lr = self.optimizer._learning_rate

        self.optimizer.minimize(loss, grad_clip=self.clipper,
                                parameter_list=self.wavenet.parameters())
        self.wavenet.clear_gradients()

        graph_time = time.time()

        if self.rank == 0:
            # Multiply by nranks to report the loss on its original scale
            # (see the `loss / num_trainers` note above).
            loss_val = float(loss.numpy()) * self.nranks
            log = "Rank: {} Step: {:^8d} Loss: {:<8.3f} " \
                  "Time: {:.3f}/{:.3f}".format(
                  self.rank, iteration, loss_val,
                  load_time - start_time, graph_time - load_time)
            print(log)

            tb = self.tb_logger
            tb.add_scalar("Train-Loss-Rank-0", loss_val, iteration)
            tb.add_scalar("Learning-Rate", current_lr, iteration)

    @dg.no_grad
    def valid_step(self, iteration):
        """Evaluate on the whole validation loader and log the results.

        Args:
            iteration: current iteration, used as the logging step.
        """
        self.wavenet.eval()

        total_loss = []
        start_time = time.time()
        sample_audios = []
        for audios, mels, audio_starts in self.validloader():
            loss, sample_audio = self.wavenet(audios, mels, audio_starts, True)
            total_loss.append(float(loss.numpy()))
            sample_audios.append(sample_audio)
        total_time = time.time() - start_time

        if self.rank == 0:
            loss_val = np.mean(total_loss)
            log = "Test | Rank: {} AvgLoss: {:<8.3f} Time {:<8.3f}".format(
                self.rank, loss_val, total_time)
            print(log)

            tb = self.tb_logger
            tb.add_scalar("Valid-Avg-Loss", loss_val, iteration)
            # NOTE(review): indexing [0] and [1] requires at least two
            # validation batches -- confirm valid_size >= 2 is guaranteed.
            tb.add_audio("Teacher-Forced-Audio-0", sample_audios[0].numpy(),
                         iteration, sample_rate=self.config.sample_rate)
            tb.add_audio("Teacher-Forced-Audio-1", sample_audios[1].numpy(),
                         iteration, sample_rate=self.config.sample_rate)

    def save(self, iteration):
        """Save model and optimizer state, then update the checkpoint index.

        NOTE(review): `self.optimizer` is only assigned by
        `build(training=True)`.
        """
        utils.save_latest_parameters(self.checkpoint_dir, iteration,
                                     self.wavenet, self.optimizer)
        utils.save_latest_checkpoint(self.checkpoint_dir, iteration)

    @dg.no_grad
    def infer(self, iteration):
        """Synthesize audio for validation sample `config.sample` and write
        it to `<config.output>/<config.name>/iter-<iteration>/`.

        Args:
            iteration: iteration of the loaded checkpoint, used in the
                output path only.
        """
        self.wavenet.eval()

        config = self.config
        sample = config.sample

        output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
        os.makedirs(output, exist_ok=True)

        filename = "{}/valid_{}.wav".format(output, sample)
        print("Synthesize sample {}, save as {}".format(sample, filename))

        # Collect the mel-spectrograms of all validation examples, then
        # pick the requested one.
        mels_list = [mels for _, mels, _ in self.validloader()]
        start_time = time.time()
        syn_audio = self.wavenet.synthesize(mels_list[sample])
        syn_time = time.time() - start_time
        print("audio shape {}, synthesis time {}".format(
            syn_audio.shape, syn_time))
        librosa.output.write_wav(filename, syn_audio,
                                 sr=config.sample_rate)

View File

@ -0,0 +1,423 @@
import itertools
import math
import numpy as np
from paddle import fluid
import paddle.fluid.dygraph as dg
import ops
import weight_norm
def get_padding(filter_size, stride, padding_type='same'):
    """Return per-dimension padding ``(filter - stride) // 2``.

    Args:
        filter_size: Sequence of filter sizes, one per dimension.
        stride: Sequence of strides, paired element-wise with filter_size.
        padding_type: Only ``'same'`` is accepted.

    Returns:
        list: One padding value per (filter_size, stride) pair.

    Raises:
        ValueError: If padding_type is anything other than ``'same'``.
    """
    if padding_type != 'same':
        raise ValueError("Only support same padding")

    padding = []
    for size, step in zip(filter_size, stride):
        padding.append((size - step) // 2)
    return padding
def debug(x, var_name, rank, verbose=False):
    """Print the shape and a small preview of ``x`` for debugging.

    Nothing is printed on non-zero ranks unless ``verbose`` is set.

    Args:
        x: A numpy array, or any object exposing ``.shape`` and ``.numpy()``.
        var_name: Label printed next to the values.
        rank: Rank of the calling process.
        verbose: When True, print regardless of rank.
    """
    if rank != 0 and not verbose:
        return

    # The rank is taken from the incoming object, before any conversion.
    ndim = len(x.shape)
    values = x if isinstance(x, np.ndarray) else x.numpy()

    if ndim not in (1, 2, 3):
        # Scalars and rank >= 4 tensors: shape only.
        print("Rank", rank, var_name, "shape", values.shape)
        return

    if ndim == 1:
        preview = values
    elif ndim == 2:
        preview = values[:, :5]
    else:
        preview = values[:, :5, 0]
    print("Rank {}".format(rank), var_name,
          "shape {}, value {}".format(values.shape, preview))
def extract_slices(x, audio_starts, audio_length, rank):
    """Cut a fixed-length window out of every row of ``x``.

    For row ``i`` the window covers ``[start_i, start_i + audio_length)``
    along axis 1, where ``start_i = audio_starts[i]``.

    Args:
        x: Batched variable; axis 0 is iterated row by row and axis 1 is
            sliced.
        audio_starts: Variable holding one start offset per row of ``x``.
        audio_length: Length of each extracted window.
        rank: Unused; kept so existing callers keep working.

    Returns:
        The per-row windows stacked along a new axis 0.
    """
    # Convert the offsets once rather than once per row inside the loop.
    starts = audio_starts.numpy()

    windows = []
    for i in range(x.shape[0]):
        start = starts[i]
        end = start + audio_length
        window = fluid.layers.slice(
            x, axes=[0, 1], starts=[i, start], ends=[i + 1, end])
        # Drop the singleton batch axis left by the slice.
        windows.append(fluid.layers.squeeze(window, [0]))
    return fluid.layers.stack(windows, axis=0)
class Conditioner(dg.Layer):
    """Stack of transposed convolutions applied to the mel input.

    One ``ops.Conv2DTranspose`` is built per entry of
    ``config.conditioner.upsample_factors``, each with stride
    ``(factor, 1)``. The product of the factors must equal
    ``config.fft_window_shift``.
    """

    def __init__(self, name_scope, config):
        super(Conditioner, self).__init__(name_scope)
        upsample_factors = config.conditioner.upsample_factors
        filter_sizes = config.conditioner.filter_sizes
        assert np.prod(upsample_factors) == config.fft_window_shift

        self.deconvs = []
        for i, up_scale in enumerate(upsample_factors):
            stride = (up_scale, 1)
            deconv = ops.Conv2DTranspose(
                self.full_name(),
                num_filters=1,
                filter_size=filter_sizes[i],
                padding=get_padding(filter_sizes[i], stride),
                stride=stride)
            self.deconvs.append(deconv)

        # A plain python list is not tracked by the layer, so every
        # element is registered explicitly as a sublayer.
        for i, deconv in enumerate(self.deconvs):
            self.add_sublayer("conv_transpose_{}".format(i), deconv)

    def forward(self, x):
        """Run ``x`` through every transposed conv with leaky-relu.

        A singleton axis is inserted at position 1 before the convolutions
        and removed again afterwards.
        """
        hidden = fluid.layers.unsqueeze(x, 1)
        for deconv in self.deconvs:
            hidden = fluid.layers.leaky_relu(deconv(hidden), alpha=0.4)
        return fluid.layers.squeeze(hidden, [1])
class WaveNetModule(dg.Layer):
def __init__(self, name_scope, config, rank):
    """Build the conditioner, input embedding, dilated stack and output FCs.

    Args:
        name_scope: Name scope forwarded to ``dg.Layer``.
        config: Configuration object. Attributes read here:
            ``dilation_block``, ``layers``, ``log_scale_min``,
            ``loss_type``, ``residual_channels``, ``skip_channels``,
            ``mel_bands``, ``kernel_width``, and either ``num_channels``
            (softmax) or ``num_mixtures`` (mix-gaussian-pdf). It is also
            passed on to ``Conditioner``.
        rank: Stored on ``self.rank``.

    Raises:
        ValueError: If ``config.loss_type`` is neither ``"softmax"`` nor
            ``"mix-gaussian-pdf"``.
    """
    super(WaveNetModule, self).__init__(name_scope)
    self.rank = rank
    self.conditioner = Conditioner(self.full_name(), config)
    # Repeat the dilation block cyclically until `layers` entries exist.
    self.dilations = list(
        itertools.islice(
            itertools.cycle(config.dilation_block), config.layers))
    self.context_size = sum(self.dilations) + 1
    self.log_scale_min = config.log_scale_min
    self.config = config
    print("dilations", self.dilations)
    print("context_size", self.context_size)

    # Bound locally so the error messages below can format it; the name
    # was previously undefined here, turning the intended ValueError into
    # a NameError.
    loss_type = config.loss_type

    if loss_type == "softmax":
        self.embedding_fc = ops.Embedding(
            self.full_name(),
            num_embeddings=config.num_channels,
            embed_dim=config.residual_channels)
    elif loss_type == "mix-gaussian-pdf":
        self.embedding_fc = ops.FC(
            self.full_name(),
            in_features=1,
            size=config.residual_channels,
            num_flatten_dims=2,
            relu=False)
    else:
        raise ValueError(
            "loss_type {} is unsupported!".format(loss_type))

    self.dilated_causal_convs = []
    for dilation in self.dilations:
        self.dilated_causal_convs.append(
            ops.Conv1D_GU(
                self.full_name(),
                conditioner_dim=config.mel_bands,
                in_channels=config.residual_channels,
                num_filters=config.residual_channels,
                filter_size=config.kernel_width,
                dilation=dilation,
                causal=True
            )
        )
    # A plain python list is not tracked by the layer; register each
    # element as a sublayer explicitly.
    for i, layer in enumerate(self.dilated_causal_convs):
        self.add_sublayer("dilated_causal_conv_{}".format(i), layer)

    self.fc1 = ops.FC(
        self.full_name(),
        in_features=config.residual_channels,
        size=config.skip_channels,
        num_flatten_dims=2,
        relu=True,
        act="relu")
    self.fc2 = ops.FC(
        self.full_name(),
        in_features=config.skip_channels,
        size=config.skip_channels,
        num_flatten_dims=2,
        relu=True,
        act="relu")

    # The output width depends on the loss: one logit per quantization
    # channel, or three values per mixture component.
    if loss_type == "softmax":
        self.fc3 = ops.FC(
            self.full_name(),
            in_features=config.skip_channels,
            size=config.num_channels,
            num_flatten_dims=2,
            relu=False)
    elif loss_type == "mix-gaussian-pdf":
        self.fc3 = ops.FC(
            self.full_name(),
            in_features=config.skip_channels,
            size=3 * config.num_mixtures,
            num_flatten_dims=2,
            relu=False)
    else:
        raise ValueError(
            "loss_type {} is unsupported!".format(loss_type))
def sample_softmax(self, mix_parameters):
    """Sample values from per-step categorical logits.

    Args:
        mix_parameters: 3-D variable unpacked as (batch, length, hidden);
            the last axis holds the logits.

    Returns:
        Flat variable of ``batch * length`` samples, obtained by mapping
        each sampled class id to
        ``(id + 0.5) * (2 / num_channels) - 1``.
    """
    batch, length, hidden = mix_parameters.shape
    logits_2d = fluid.layers.reshape(mix_parameters,
                                     [batch * length, hidden])
    probs = fluid.layers.softmax(logits_2d, axis=-1)
    # class_ids: [batch * length]
    class_ids = fluid.layers.sampling_id(probs)
    quantized = fluid.layers.cast(class_ids, dtype="float32")
    bin_width = 2.0 / self.config.num_channels
    # samples: [batch * length]
    return (quantized + 0.5) * bin_width - 1.0
def sample_mix_gaussian(self, mix_parameters):
    """Sample values from a per-step mixture of gaussians.

    Args:
        mix_parameters: 3-D variable unpacked as (batch, length, hidden).
            The last axis is read as three equal chunks: mixture logits,
            means, and log scales.

    Returns:
        Flat variable of ``batch * length`` samples clipped to [-1, 1].
    """
    batch, length, hidden = mix_parameters.shape
    num_rows = batch * length
    # Flatten to [batch * length, 3 * num_mixtures].
    flat_params = fluid.layers.reshape(mix_parameters, [num_rows, hidden])
    K = hidden // 3

    # Unpack the three parameter groups of the mixture.
    logits_pi = flat_params[:, 0:K]
    mu = flat_params[:, K:2 * K]
    log_s = flat_params[:, 2 * K:3 * K]
    s = fluid.layers.exp(log_s)
    pi = fluid.layers.softmax(logits_pi, axis=-1)

    # Pick one component per row, then gather its mean and scale with
    # (row, component) index pairs.
    component = fluid.layers.sampling_id(pi)
    row_idx = dg.to_variable(np.arange(num_rows))
    index = fluid.layers.stack([row_idx, component], axis=-1)
    mu_comp = fluid.layers.gather_nd(mu, index)
    s_comp = fluid.layers.gather_nd(s, index)

    # Standard normal noise scaled and shifted by the chosen component.
    u = fluid.layers.gaussian_random(shape=[num_rows])
    samples = mu_comp + u * s_comp
    return fluid.layers.clip(samples, min=-1.0, max=1.0)
def softmax_loss(self, targets, mix_parameters):
    """Mean cross-entropy between quantized targets and the logits.

    The first ``self.context_size`` time steps of both inputs are dropped
    before the loss is computed.

    Args:
        targets: Variable indexed as [batch, time].
        mix_parameters: Variable indexed as [batch, time, channels].

    Returns:
        The mean of the per-sample cross-entropy.
    """
    ctx = self.context_size
    targets = targets[:, ctx:]
    mix_parameters = mix_parameters[:, ctx:, :]

    # Quantize audio to integer ids in [0, num_channels). The upper clip
    # bound stays just below 1.0 so the id never reaches num_channels.
    num_channels = self.config.num_channels
    clipped = fluid.layers.clip(targets, min=-1.0, max=0.99999)
    quantized = fluid.layers.cast(
        (clipped + 1.0) / 2.0 * num_channels, dtype="int64")
    labels = fluid.layers.unsqueeze(quantized, 2)

    per_sample_loss = fluid.layers.softmax_with_cross_entropy(
        logits=mix_parameters, label=labels)
    return fluid.layers.reduce_mean(per_sample_loss)
def mixture_density_loss(self, targets, mix_parameters, log_scale_min):
    """Mean negative log-likelihood under a mixture of gaussians.

    The first ``self.context_size`` time steps of both inputs are dropped
    before the loss is computed.

    Args:
        targets: Variable indexed as [batch, time].
        mix_parameters: Variable indexed as [batch, time, 3 * mixtures];
            split along the last axis into logits, means and log scales.
        log_scale_min: Lower clip bound applied to the log scales.

    Returns:
        The mean of ``-log(pdf + 1e-9)``.
    """
    ctx = self.context_size
    targets = targets[:, ctx:]
    mix_parameters = mix_parameters[:, ctx:, :]

    logits_pi, mu, log_s = fluid.layers.split(
        mix_parameters, num_or_sections=3, dim=-1)
    pi = fluid.layers.softmax(logits_pi, axis=-1)
    log_s = fluid.layers.clip(log_s, min=log_scale_min, max=100.0)
    inv_s = fluid.layers.exp(0.0 - log_s)

    # Repeat the targets once per mixture component.
    targets = fluid.layers.unsqueeze(targets, -1)
    targets = fluid.layers.expand(targets, [1, 1, self.config.num_mixtures])

    # Gaussian density of every component at the target value.
    x_std = inv_s * (targets - mu)
    exponent = fluid.layers.exp(-0.5 * x_std * x_std)
    gauss_norm = 1.0 / np.sqrt(2.0 * np.pi)
    pdf_x = gauss_norm * inv_s * exponent

    # Weight by the mixture probabilities and sum out the component axis.
    pdf_x = pi * pdf_x
    pdf_x = fluid.layers.reduce_sum(pdf_x, dim=-1)

    # The small constant keeps log() finite when the density underflows.
    per_sample_loss = 0.0 - fluid.layers.log(pdf_x + 1e-9)
    return fluid.layers.reduce_mean(per_sample_loss)
def forward(self, audios, mels, audio_starts, sample=False):
    """Compute the training loss and, optionally, teacher-forced samples.

    The audio is shifted by one step to form (input, target) pairs, the
    mels are upsampled by the conditioner and sliced at ``audio_starts``,
    and the result is run through the dilated causal stack and the three
    output FC layers.

    Args:
        audios: Audio clips, indexed as [batch, time].
        mels: Mel spectrograms passed to the conditioner.
        audio_starts: One start offset per clip, used to slice the
            upsampled conditioner.
        sample: When True, also sample audio from the predicted
            distribution.

    Returns:
        tuple: ``(loss, sample_audios)``; ``sample_audios`` is None when
        ``sample`` is False.

    Raises:
        ValueError: If ``config.loss_type`` is unsupported.
    """
    # audios: [bs, 13800], mels: [bs, full_frame_length, 80]
    # audio_starts: [bs]
    # (The concrete sizes in these comments are example values.)
    # Build conditioner based on mels.
    full_conditioner = self.conditioner(mels)
    # Slice conditioners.
    audio_length = audios.shape[1]
    conditioner = extract_slices(full_conditioner,
                                 audio_starts, audio_length, self.rank)
    # input_audio, target_audio: [bs, 13799]
    # Input is every sample but the last; target is every sample but the
    # first, i.e. the next-step value for each input position.
    input_audios = audios[:, :-1]
    target_audios = audios[:, 1:]
    # conditioner: [bs, 13799, 80]
    # Drop the first step so the conditioner lines up with the targets.
    conditioner = conditioner[:, 1:, :]
    loss_type = self.config.loss_type
    # layer_input: [bs, 13799, 128]
    if loss_type == "softmax":
        input_audios = fluid.layers.clip(
            input_audios, min=-1.0, max=0.99999)
        # quantized have values in [0, num_channels)
        quantized = fluid.layers.cast(
            (input_audios + 1.0) / 2.0 * self.config.num_channels,
            dtype="int64")
        layer_input = self.embedding_fc(fluid.layers.unsqueeze(quantized, 2))
    elif loss_type == "mix-gaussian-pdf":
        layer_input = self.embedding_fc(fluid.layers.unsqueeze(input_audios, 2))
    else:
        raise ValueError(
            "loss_type {} is unsupported!".format(loss_type))
    # layer_input: [bs, res_channel, 1, 13799]
    layer_input = fluid.layers.unsqueeze(fluid.layers.transpose(layer_input, perm=[0, 2, 1]), 2)
    # conditioner: [bs, mel_bands, 1, 13799]
    conditioner = fluid.layers.unsqueeze(fluid.layers.transpose(conditioner, perm=[0, 2, 1]), 2)
    # layer_input: [bs, res_channel, 1, 13799]
    # skip: [bs, res_channel, 1, 13799]
    # Each layer receives the previous layer's output and the running
    # skip value, and returns updated versions of both.
    skip = None
    for i, layer in enumerate(self.dilated_causal_convs):
        layer_input, skip = layer(layer_input, skip, conditioner)
        #debug(layer_input, "layer_input_" + str(i), self.rank)
        #debug(skip, "skip_" + str(i), self.rank)
    # Reshape skip to [bs, 13799, res_channel]
    skip = fluid.layers.transpose(fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1])
    #debug(skip, "skip", self.rank)
    # mix_param: [bs, 13799, 3 * num_mixtures]
    mix_parameters = self.fc3(self.fc2(self.fc1(skip)))
    # Sample teacher-forced audio.
    sample_audios = None
    if sample:
        if loss_type == "softmax":
            sample_audios = self.sample_softmax(mix_parameters)
        elif loss_type == "mix-gaussian-pdf":
            sample_audios = self.sample_mix_gaussian(mix_parameters)
        else:
            raise ValueError(
                "loss_type {} is unsupported!".format(loss_type))
        #debug(sample_audios, "sample_audios", self.rank)
    # Calculate mix-gaussian density loss.
    # padding is all zero.
    # target_audio: [bs, 13799].
    # mix_params: [bs, 13799, 3].
    if loss_type == "softmax":
        loss = self.softmax_loss(target_audios, mix_parameters)
    elif loss_type == "mix-gaussian-pdf":
        loss = self.mixture_density_loss(target_audios,
                                         mix_parameters, self.log_scale_min)
    else:
        raise ValueError(
            "loss_type {} is unsupported!".format(loss_type))
    #print("Rank {}, loss {}".format(self.rank, loss.numpy()))
    return loss, sample_audios
def synthesize(self, mels):
    """Generate audio from mels one sample at a time.

    Resets the incremental state of the convolution layers, upsamples the
    mels with the conditioner, then runs one step per conditioner time
    step, feeding every sampled value back in as the next input. The
    first input is all zeros.

    Args:
        mels: Mel spectrograms unpacked as (bs, n_frames, mel_bands).

    Returns:
        numpy array: all sampled values concatenated along axis 0.

    Raises:
        ValueError: If ``config.loss_type`` is unsupported.
    """
    self.start_new_sequence()
    print("input mels shape", mels.shape)
    # mels: [bs=1, n_frames, 80]
    # conditioner: [1, n_frames * samples_per_frame, 80]
    # Should I move forward by one sample? No difference
    # Append context frame to mels
    bs, n_frames, mel_bands = mels.shape
    #num_pad_frames = int(np.ceil(self.context_size / self.config.fft_window_shift))
    #silence = fluid.layers.zeros(shape=[bs, num_pad_frames, mel_bands], dtype="float32")
    #inf_mels = fluid.layers.concat([silence, mels], axis=1)
    #print("padded mels shape", inf_mels.shape)
    #conditioner = self.conditioner(inf_mels)[:, self.context_size:, :]
    conditioner = self.conditioner(mels)
    time_steps = conditioner.shape[1]
    print("Total steps", time_steps)
    loss_type = self.config.loss_type
    audio_samples = []
    # The very first input sample is zero.
    current_sample = fluid.layers.zeros(shape=[bs, 1, 1], dtype="float32")
    for i in range(time_steps):
        # Progress report every 100 steps.
        if i % 100 == 0:
            print("Step", i)
        # convert from real value sample to audio embedding.
        # [bs, 1, 128]
        if loss_type == "softmax":
            current_sample = fluid.layers.clip(
                current_sample, min=-1.0, max=0.99999)
            # quantized have values in [0, num_channels)
            quantized = fluid.layers.cast(
                (current_sample + 1.0) / 2.0 * self.config.num_channels,
                dtype="int64")
            audio_input = self.embedding_fc(quantized)
        elif loss_type == "mix-gaussian-pdf":
            audio_input = self.embedding_fc(current_sample)
        else:
            raise ValueError(
                "loss_type {} is unsupported!".format(loss_type))
        # [bs, 128, 1, 1]
        audio_input = fluid.layers.unsqueeze(fluid.layers.transpose(audio_input, perm=[0, 2, 1]), 2)
        # [bs, 80]
        cond_input = conditioner[:, i, :]
        # [bs, 80, 1, 1]
        cond_input = fluid.layers.reshape(
            cond_input, cond_input.shape + [1, 1])
        # Single-step counterpart of the layer loop in forward(): each
        # layer consumes this step's input through add_input.
        skip = None
        for layer in self.dilated_causal_convs:
            audio_input, skip = layer.add_input(audio_input, skip, cond_input)
        # [bs, 1, 128]
        skip = fluid.layers.transpose(fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1])
        # [bs, 1, 3]
        mix_parameters = self.fc3(self.fc2(self.fc1(skip)))
        if loss_type == "softmax":
            sample = self.sample_softmax(mix_parameters)
        elif loss_type == "mix-gaussian-pdf":
            sample = self.sample_mix_gaussian(mix_parameters)
        else:
            raise ValueError(
                "loss_type {} is unsupported!".format(loss_type))
        audio_samples.append(sample)
        # [bs]
        # Feed the newest sample back in as the next step's input.
        current_sample = audio_samples[-1]
        # [bs, 1, 1]
        current_sample = fluid.layers.reshape(current_sample,
                                              current_sample.shape + [1, 1])
    # syn_audio: (num_samples,)
    syn_audio = fluid.layers.concat(audio_samples, axis=0).numpy()
    return syn_audio
def start_new_sequence(self):
    """Reset the incremental state of every ``weight_norm.Conv1D`` sublayer."""
    for sublayer in self.sublayers():
        if not isinstance(sublayer, weight_norm.Conv1D):
            continue
        sublayer.start_new_sequence()
def save(self, iteration):
    """Save the latest parameters and checkpoint record via ``utils``.

    NOTE(review): this reads ``self.checkpoint_dir``, ``self.wavenet`` and
    ``self.optimizer`` and calls ``utils``, none of which are set or
    imported anywhere in the visible part of this file. It looks like a
    leftover copy of the trainer's ``save``; confirm whether it is ever
    called.
    """
    utils.save_latest_parameters(self.checkpoint_dir, iteration,
                                 self.wavenet, self.optimizer)
    utils.save_latest_checkpoint(self.checkpoint_dir, iteration)

View File

@ -0,0 +1,920 @@
import math
from copy import deepcopy
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
from paddle.fluid import core
from paddle.fluid.framework import Variable
from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from paddle.fluid.layers import utils
from six.moves import reduce
def _norm(p, dim):
"""Computes the norm over all dimensions except dim.
It differs from pytorch implementation that it does not keep dim.
This difference is related with the broadcast mechanism in paddle.
Read elementeise_mul for more.
"""
if dim is None:
return np.linalg.norm(p, ord=2, axis=None)
elif dim == 0:
p = np.reshape(p, newshape=(p.shape[0], -1))
return np.linalg.norm(p, ord=2, axis=1)
elif dim == p.ndim - 1:
p = np.reshape(p, newshape=(-1, p.shape[-1]))
return np.linalg.norm(p, ord=2, axis=0)
else:
perm = list(range(p.ndim))
perm[0] = dim
perm[dim] = 0
return _norm(np.transpose(p, axes=perm))
class Conv1D(dg.Layer):
"""
A convolution 1D block implemented with Conv2D. For simplicity and
ensuring the output has the same length as the input, it does not allow
stride > 1.
"""
def __init__(self,
             name_scope,
             num_filters,
             filter_size=3,
             dilation=1,
             groups=None,
             causal=False,
             param_attr=None,
             bias_attr=None,
             use_cudnn=True,
             act=None,
             dtype="float32"):
    """Wrap a weight-normalized ``Conv2D`` with a ``(1, filter_size)`` kernel.

    Args:
        name_scope: Name scope forwarded to ``dg.Layer``.
        num_filters: Number of output channels.
        filter_size: Kernel width along the last axis.
        dilation: Dilation along the last axis.
        groups: Forwarded to ``Conv2D``.
        causal: If True, pad by the full ``dilation * (filter_size - 1)``;
            otherwise by half of it (floored).
        param_attr: Forwarded to ``Conv2D``.
        bias_attr: Forwarded to ``Conv2D``.
        use_cudnn: Forwarded to ``Conv2D``.
        act: Activation name, forwarded to ``Conv2D`` and kept on
            ``self.act``.
        dtype: Data type of the layer.
    """
    super(Conv1D, self).__init__(name_scope, dtype=dtype)

    # Extent by which the dilated kernel reaches beyond one position.
    kernel_reach = dilation * (filter_size - 1)
    padding = kernel_reach if causal else kernel_reach // 2

    self.num_filters = num_filters
    self.filter_size = filter_size
    self.dilation = dilation
    self.causal = causal
    self.padding = padding
    self.act = act

    # The 1-D convolution is expressed as a 2-D one over a height-1 input.
    self.conv = Conv2D(
        self.full_name(),
        num_filters=num_filters,
        filter_size=(1, filter_size),
        stride=(1, 1),
        dilation=(1, dilation),
        padding=(0, padding),
        groups=groups,
        param_attr=param_attr,
        bias_attr=bias_attr,
        use_cudnn=use_cudnn,
        act=act,
        dtype=dtype)
def forward(self, x):
    """Apply the convolution and trim the tail of the last axis.

    Args:
        x (Variable): Shape(B, C_in, 1, T), the input, where C_in means
            input channels.

    Returns:
        x (Variable): Shape(B, C_out, 1, T), the outputs, where C_out means
            output channels (num_filters).
    """
    x = self.conv(x)
    if self.filter_size <= 1:
        return x

    if self.causal:
        # Remove the trailing `padding` steps produced by the padding.
        trim = self.padding
    elif self.filter_size % 2 == 0:
        # Even kernels: drop one trailing step.
        trim = 1
    else:
        return x
    return fluid.layers.slice(x, axes=[3], starts=[0], ends=[-trim])
def start_new_sequence(self):
    """Clear the cached weight and input buffer used by ``add_input``."""
    self.input_buffer = None
    self.temp_weight = None
def add_input(self, x):
    """
    Adding input for a time step and compute an output for a time step.

    The convolution is evaluated as a matmul between the buffered inputs
    at the kernel's tap positions and the cached linearized weight.
    ``start_new_sequence`` must have been called beforehand, because this
    method reads ``self.temp_weight`` and ``self.input_buffer``.

    Args:
        x (Variable): Shape(B, C_in, 1, T), the input, where C_in means
            input channels, and T = 1.
    Returns:
        out (Variable): Shape(B, C_out, 1, T), the outputs, where C_out
            means output channels (num_filters), and T = 1.
    """
    # Build the linearized weight on first use and reuse it afterwards.
    if self.temp_weight is None:
        self.temp_weight = self._reshaped_weight()
    # Number of time steps spanned by the dilated kernel.
    window_size = 1 + (self.filter_size - 1) * self.dilation
    batch_size = x.shape[0]
    in_channels = x.shape[1]
    if self.filter_size > 1:
        if self.input_buffer is None:
            # First step: the history is all zeros.
            self.input_buffer = fluid.layers.fill_constant(
                [batch_size, in_channels, 1, window_size - 1],
                dtype=x.dtype,
                value=0.0)
        else:
            # Later steps: drop the oldest time step.
            self.input_buffer = self.input_buffer[:, :, :, 1:]
        # Append the new step so the buffer holds window_size steps.
        self.input_buffer = fluid.layers.concat(
            [self.input_buffer, x], axis=3)
        x = self.input_buffer
        if self.dilation > 1:
            # Keep only the time steps the dilated kernel looks at. The
            # index tensor is created once and cached on the instance.
            if not hasattr(self, "indices"):
                self.indices = dg.to_variable(
                    np.arange(0, window_size, self.dilation))
            # Move the time axis to the front for the gather, then
            # restore the original layout.
            tmp = fluid.layers.transpose(
                self.input_buffer, perm=[3, 1, 2, 0])
            tmp = fluid.layers.gather(tmp, index=self.indices)
            tmp = fluid.layers.transpose(tmp, perm=[3, 1, 2, 0])
            x = tmp
    inputs = fluid.layers.reshape(
        x, shape=[batch_size, in_channels * 1 * self.filter_size])
    out = fluid.layers.matmul(inputs, self.temp_weight, transpose_y=True)
    out = fluid.layers.elementwise_add(out, self.conv._bias_param, axis=-1)
    # Restore the two trailing singleton axes.
    out = fluid.layers.reshape(out, out.shape + [1, 1])
    out = self._helper.append_activation(out, act=self.act)
    return out
def _reshaped_weight(self):
    """
    Build the convolution filter as a 2-D matmul weight.

    The direction parameter is flattened to one row per output channel,
    each row is l2-normalized, and the rows are scaled by the magnitude
    parameter.

    Returns:
        weight_matrix (Variable): Shape(C_out, C_in * 1 * kernel_size)
    """
    weight_v = self.conv._filter_param_v
    v_shape = weight_v.shape
    flat_shape = [v_shape[0], np.prod(v_shape[1:])]

    flat_v = fluid.layers.reshape(weight_v, shape=flat_shape)
    direction = fluid.layers.l2_normalize(flat_v, axis=1)
    return fluid.layers.elementwise_mul(
        direction, self.conv._filter_param_g, axis=0)
class FC(dg.Layer):
"""
**Fully Connected Layer**
This function creates a fully connected layer in the network. It can take
one or multiple tensors as its inputs(input can be a list of Variable, see
Args in detail). It creates a pair of variables called (magnitude(g),
direction(V)) for each input tensor. Elementwise_mul(V, g) represents a fully connected
weight matrix from each input unit to each output unit.
The fully connected layer multiplies each input tensor
with its corresponding weight to produce an output Tensor with shape [M, `size`],
where M is batch size. If multiple input tensors are given, the results of
multiple output tensors with shape [M, `size`] will be summed up. If bias_attr
is not None, a bias variable will be created and added to the output.
Finally, if activation is not None, it will be applied to the output as well.
When the input is single tensor:
.. math::
Out = Act({X(normalize(V)g) + b})
When the input are multiple tensors:
.. math::
Out = Act({\sum_{i=0}^{N-1}X_i(V_ig_i) + b})
In the above equation:
* :math:`N`: Number of the input. N equals to len(input) if input is list of Variable.
* :math:`X_i`: The i-th input tensor.
* :math:`V_i`: The i-th direction matrix corresponding i-th input tensor.
* :math:`g_i`: The i-th magnitude vector corresponding i-th input tensor.
* :math:`b`: The bias parameter created by this layer (if needed).
* :math:`Act`: The activation function.
* :math:`Out`: The output tensor.
See below for an example.
.. code-block:: text
Given:
data_1.data = [[[0.1, 0.2],
[0.3, 0.4]]]
data_1.shape = (1, 2, 2) # 1 is batch_size
data_2 = [[[0.1, 0.2, 0.3]]]
data_2.shape = (1, 1, 3)
out = fluid.layers.fc(input=[data_1, data_2], size=2)
Then:
out.data = [[0.18669507, 0.1893476]]
out.shape = (1, 2)
Args:
name_scope(str): The name of this class.
size(int): The number of output units in this layer.
num_flatten_dims (int): The fc layer can accept an input tensor with more than
two dimensions. If this happens, the multidimensional tensor will first be flattened
into a 2-dimensional matrix. The parameter `num_flatten_dims` determines how the input
tensor is flattened: the first `num_flatten_dims` (inclusive, index starts from 1)
dimensions will be flatten to form the first dimension of the final matrix (height of
the matrix), and the rest `rank(X) - num_flatten_dims` dimensions are flattened to
form the second dimension of the final matrix (width of the matrix). For example, suppose
`X` is a 5-dimensional tensor with a shape [2, 3, 4, 5, 6], and `num_flatten_dims` = 3.
Then, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. Default: 1
param_attr (ParamAttr|list of ParamAttr|None): The parameter attribute for learnable
parameters/weights of this layer.
bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias
of this layer. If it is set to False, no bias will be added to the output units.
If it is set to None, the bias is initialized zero. Default: None.
act (str|None): Activation to be applied to the output of this layer.
is_test(bool): A flag indicating whether execution is in test phase. Default: False
dtype(str): Dtype used for weight
Raises:
ValueError: If rank of the input tensor is less than 2.
Examples:
.. code-block:: python
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
from paddle.fluid.dygraph import FC
import numpy as np
data = np.random.uniform( -1, 1, [30, 10, 32] ).astype('float32')
with fluid.dygraph.guard():
fc = FC( "fc", 64, num_flatten_dims=2)
data = to_variable( data )
conv = fc( data )
"""
def __init__(self,
             name_scope,
             size,
             num_flatten_dims=1,
             epsilon=1e-30,
             param_attr=None,
             bias_attr=None,
             act=None,
             is_test=False,
             dtype="float32"):
    """Store the layer configuration; parameters are created later.

    The direction (v) and magnitude (g) parameter lists start empty and
    are filled by ``_build_once``. ``is_test`` is accepted but not used
    here.
    """
    super(FC, self).__init__(name_scope, dtype)

    self._dtype = dtype
    self._size = size
    self._num_flatten_dims = num_flatten_dims
    self._epsilon = epsilon
    self._act = act
    self._param_attr = param_attr
    self._bias_attr = bias_attr

    # One entry per input tensor, appended in `_build_once`.
    self.__v = []
    self.__g = []
# Accessors for the first direction (v) and magnitude (g) parameter.
# Property access never passes `i`, so it is always 0; the parameter is
# kept only to leave the signatures unchanged.
@property
def _v(self, i=0):
    return self.__v[i]

@property
def _g(self, i=0):
    return self.__g[i]

@_v.setter
def _v(self, value, i=0):
    # Imported here because `Parameter` is not imported at module level;
    # the assertion previously raised NameError.
    from paddle.fluid.framework import Parameter
    assert isinstance(value, Parameter)
    self.__v[i] = value

@_g.setter
def _g(self, value, i=0):
    # See the `_v` setter for why the import is local.
    from paddle.fluid.framework import Parameter
    assert isinstance(value, Parameter)
    self.__g[i] = value
def _build_once(self, input):
    """Create the v/g parameters for every input, plus the bias.

    For input ``i`` a direction parameter ``_v{i}`` of shape
    ``[prod(input.shape[num_flatten_dims:]), size]`` is created, and a
    magnitude parameter ``_g{i}`` is initialized to the L2 norm of
    ``_v{i}`` taken along axis 0.
    """
    pairs = self._helper.iter_inputs_and_params(input, self._param_attr)
    for idx, (inp, param) in enumerate(pairs):
        # Width of the input once the trailing dimensions are flattened.
        flat_dim = reduce(lambda a, b: a * b,
                          inp.shape[self._num_flatten_dims:], 1)
        param_shape = [flat_dim] + [self._size]

        direction = self.add_parameter(
            "_v%d" % idx,
            self.create_parameter(
                attr=param,
                shape=param_shape,
                dtype=self._dtype,
                is_bias=False))
        self.__v.append(direction)

        # Start g at the current norm of v along axis 0.
        magnitude_shape = param_shape[1:]
        magnitude_value = np.linalg.norm(
            self.__v[idx].numpy(), ord=2, axis=0)
        magnitude_init = fluid.initializer.NumpyArrayInitializer(
            magnitude_value)
        magnitude = self.add_parameter(
            "_g%d" % idx,
            self.create_parameter(
                attr=fluid.ParamAttr(initializer=magnitude_init),
                shape=magnitude_shape,
                dtype=self._dtype,
                is_bias=False))
        self.__g.append(magnitude)

    bias_shape = list([self._size])
    self._b = self.create_parameter(
        attr=self._bias_attr, shape=bias_shape, dtype=self._dtype,
        is_bias=True)
def forward(self, input):
    """Apply the weight-normalized fully connected transform.

    For every input the weight is rebuilt as the normalized direction
    ``v`` scaled by the magnitude ``g``, then multiplied with the input.
    The per-input products are summed, the bias is added when present,
    and the activation is applied.

    Args:
        input: A variable or a list of variables, iterated through
            ``iter_inputs_and_params``.

    Returns:
        The activated output variable.
    """
    mul_results = list()
    i = 0
    for inp, param in self._helper.iter_inputs_and_params(
            input, self._param_attr):
        v_norm = self._helper.create_variable_for_type_inference(
            self._dtype)
        v_normalized = self._helper.create_variable_for_type_inference(
            self._dtype)
        # Normalize v along axis 0.
        self._helper.append_op(
            type="norm",
            inputs={"X": self.__v[i]},
            outputs={"Out": v_normalized,
                     "Norm": v_norm},
            attrs={"axis": 0,
                   "epsilon": self._epsilon})
        weight = self._helper.create_variable_for_type_inference(
            self._dtype)
        # Scale the normalized direction by g to get the weight.
        self._helper.append_op(
            type="elementwise_mul",
            inputs={"X": [v_normalized],
                    "Y": [self.__g[i]]},
            outputs={"Out": [weight]},
            attrs={"axis": 1})
        tmp = self._helper.create_variable_for_type_inference(self._dtype)
        # The first `num_flatten_dims` input dimensions are treated as
        # rows of the matrix product.
        self._helper.append_op(
            type="mul",
            inputs={"X": inp,
                    "Y": weight},
            outputs={"Out": tmp},
            attrs={
                "x_num_col_dims": self._num_flatten_dims,
                "y_num_col_dims": 1
            })
        i += 1
        mul_results.append(tmp)
    if len(mul_results) == 1:
        pre_bias = mul_results[0]
    else:
        # Several inputs: sum their individual products.
        pre_bias = self._helper.create_variable_for_type_inference(
            self._dtype)
        self._helper.append_op(
            type="sum",
            inputs={"X": mul_results},
            outputs={"Out": pre_bias},
            attrs={"use_mkldnn": False})
    # NOTE(review): this tests the truthiness of the bias variable,
    # whereas Conv2D in this file compares against None; confirm the two
    # are equivalent for the bias object returned by create_parameter.
    if self._b:
        pre_activation = self._helper.create_variable_for_type_inference(
            dtype=self._dtype)
        self._helper.append_op(
            type="elementwise_add",
            inputs={"X": [pre_bias],
                    "Y": [self._b]},
            outputs={"Out": [pre_activation]},
            attrs={"axis": self._num_flatten_dims})
    else:
        pre_activation = pre_bias
    # Currently, we don't support inplace in dygraph mode
    return self._helper.append_activation(pre_activation, act=self._act)
class Conv2D(dg.Layer):
"""
The convolution2D layer calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. Input and
Output are in NCHW format, where N is batch size, C is the number of
channels, H is the height of the feature, and W is the width of the feature.
Filter is in MCHW format, where M is the number of output image channels,
C is the number of input image channels, H is the height of the filter,
and W is the width of the filter. If the groups is greater than 1,
C will equal the number of input image channels divided by the groups.
Please refer to UFLDL's `convolution
<http://ufldl.stanford.edu/tutorial/supervised/FeatureExtractionUsingConvolution/>`
for more details.
If bias attribution and activation type are provided, bias is added to the
output of the convolution, and the corresponding activation function is
applied to the final result.
For each input :math:`X`, the equation is:
.. math::
Out = \sigma ((Vg) \\ast X + b)
Where:
* :math:`X`: Input value, a tensor with NCHW format.
* :math:`V`: Filter direction value, a tensor with MCHW format.
* :math:`g`: Filter magnitude value, a tensor with M format.
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
* :math:`\\sigma`: Activation function.
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
Example:
- Input:
Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
- Output:
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
Where
.. math::
H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
Args:
name_scope(str) : The name for this class.
num_filters(int): The number of filter. It is as same as the output
image channel.
filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
stride (int|tuple): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: stride = 1.
padding (int|tuple): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: padding = 0.
dilation (int|tuple): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: dilation = 1.
groups (int): The groups number of the Conv2d Layer. According to grouped
convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: groups=1.
param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
act (str): Activation type, if it is set to None, activation is not appended.
Default: None
Raises:
ValueError: If the shapes of input, filter_size, stride, padding and
groups mismatch.
Examples:
.. code-block:: python
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D
import numpy as np
data = np.random.uniform( -1, 1, [10, 3, 32, 32] ).astype('float32')
with fluid.dygraph.guard():
conv2d = Conv2D( "conv2d", 2, 3)
data = to_variable( data )
conv = conv2d( data )
"""
def __init__(self,
             name_scope,
             num_filters,
             filter_size,
             stride=1,
             padding=0,
             dilation=1,
             groups=None,
             param_attr=None,
             bias_attr=None,
             use_cudnn=True,
             act=None,
             epsilon=1e-30,
             dtype="float32"):
    """Store the convolution configuration; parameters are created later.

    ``stride``, ``padding`` and ``dilation`` are normalized to 2-element
    lists. The op type is fixed to ``"conv2d"``.

    Raises:
        ValueError: If ``use_cudnn`` is not a bool.
    """
    assert param_attr is not False, "param_attr should not be False here."
    super(Conv2D, self).__init__(name_scope, dtype)

    self._groups = groups
    # Scalars are expanded to [value, value].
    self._stride = utils.convert_to_list(stride, 2, "stride")
    self._padding = utils.convert_to_list(padding, 2, "padding")
    self._dilation = utils.convert_to_list(dilation, 2, "dilation")
    self._act = act

    if not isinstance(use_cudnn, bool):
        raise ValueError("use_cudnn should be True or False")
    self._use_cudnn = use_cudnn

    self._num_filters = num_filters
    self._filter_size = filter_size
    self._param_attr = param_attr
    self._bias_attr = bias_attr
    self._dtype = dtype
    self._epsilon = epsilon

    # TODO(jiabin): recover the usage of depthwise_conv2d when it's
    # kernel fixed https://github.com/PaddlePaddle/Paddle/issues/17275
    # Until then the depthwise variant is never selected:
    # if (self._num_channels == self._groups and
    #         num_filters % self._num_channels == 0 and not self._use_cudnn):
    #     self._l_type = 'depthwise_conv2d'
    self._l_type = "conv2d"
def _build_once(self, input):
    """Create the filter direction/magnitude parameters and the bias.

    The number of input channels is read from ``input.shape[1]``. The
    direction parameter has the full filter shape, and the magnitude is
    initialized with ``_norm(v, dim=0)`` computed from its initial values.

    Raises:
        ValueError: If the input channels are not divisible by groups.
    """
    self._num_channels = input.shape[1]
    if self._groups is None:
        num_filter_channels = self._num_channels
    else:
        if self._num_channels % self._groups != 0:
            raise ValueError("num_channels must be divisible by groups.")
        num_filter_channels = self._num_channels // self._groups
    filter_size = utils.convert_to_list(self._filter_size, 2,
                                        "filter_size")
    filter_shape = [self._num_filters, int(num_filter_channels)
                    ] + filter_size
    def _get_default_param_initializer():
        # Normal(0, sqrt(2 / fan)), where fan is
        # filter_h * filter_w * input channels.
        filter_elem_num = filter_size[0] * filter_size[
            1] * self._num_channels
        std = (2.0 / filter_elem_num)**0.5
        return Normal(0.0, std, 0)
    # weight_v
    self._filter_param_v = self.create_parameter(
        attr=self._param_attr,
        shape=filter_shape,
        dtype=self._dtype,
        default_initializer=_get_default_param_initializer())
    # weight_g
    # Initialized to the norm of v computed with dim=0.
    norm_value = _norm(
        self._filter_param_v.numpy(), dim=0)  # CAUTION: hard-code
    self._filter_param_g = self.create_parameter(
        attr=fluid.ParamAttr(
            initializer=fluid.initializer.NumpyArrayInitializer(
                norm_value)),
        shape=norm_value.shape,
        dtype=self._dtype,
        default_initializer=_get_default_param_initializer())
    if self._use_cudnn:
        # Persistable RAW variables with cuDNN algorithm-cache names.
        self.create_variable(
            name="kCUDNNFwdAlgoCache",
            persistable=True,
            type=core.VarDesc.VarType.RAW)
        self.create_variable(
            name="kCUDNNBwdDataAlgoCache",
            persistable=True,
            type=core.VarDesc.VarType.RAW)
        self.create_variable(
            name="kCUDNNBwdFilterAlgoCache",
            persistable=True,
            type=core.VarDesc.VarType.RAW)
    # One bias value per output filter.
    self._bias_param = self.create_parameter(
        attr=self._bias_attr,
        shape=[self._num_filters],
        dtype=self._dtype,
        is_bias=True)
def forward(self, input):
    """Apply the weight-normalized 2-D convolution to ``input``.

    The effective filter is rebuilt on every call: ``v`` is flattened to
    ``[v.shape[0], -1]``, passed through the ``norm`` op along axis 1,
    reshaped back, and multiplied by ``g`` along axis 0. The result is used
    as the filter of the convolution op; the bias (when present) and the
    activation are applied afterwards.

    Args:
        input: The input variable passed to the convolution op.

    Returns:
        The output of ``append_activation`` applied to the (optionally
        biased) convolution result.
    """
    helper = self._helper
    dtype = self._dtype
    v_param = self._filter_param_v

    # Flatten v to a 2-D matrix: one row per index of dim 0.
    flat_v = helper.create_variable_for_type_inference(dtype)
    flat_xshape = helper.create_variable_for_type_inference(dtype)
    row_width = 1
    for extent in v_param.shape[1:]:
        row_width = row_width * extent
    helper.append_op(
        type="reshape2",
        inputs={"X": v_param},
        attrs={"shape": [v_param.shape[0], row_width]},
        outputs={"Out": flat_v,
                 "XShape": flat_xshape})

    # Normalize every row of the flattened matrix.
    row_norm = helper.create_variable_for_type_inference(dtype)
    flat_unit_v = helper.create_variable_for_type_inference(dtype)
    helper.append_op(
        type="norm",
        inputs={"X": flat_v},
        outputs={"Out": flat_unit_v,
                 "Norm": row_norm},
        attrs={"axis": 1,
               "epsilon": self._epsilon})

    # Restore the original filter layout.
    unit_v = helper.create_variable_for_type_inference(dtype)
    unit_xshape = helper.create_variable_for_type_inference(dtype)
    helper.append_op(
        type="reshape2",
        inputs={"X": flat_unit_v},
        attrs={"shape": v_param.shape},
        outputs={"Out": unit_v,
                 "XShape": unit_xshape})

    # Scale the normalized direction by the magnitude g.
    weight = helper.create_variable_for_type_inference(dtype)
    helper.append_op(
        type="elementwise_mul",
        inputs={"X": [unit_v],
                "Y": [self._filter_param_g]},
        outputs={"Out": [weight]},
        attrs={"axis": 0},  # CAUTION: hard-code
    )

    conv_out = helper.create_variable_for_type_inference(dtype=dtype)
    conv_attrs = {
        "strides": self._stride,
        "paddings": self._padding,
        "dilations": self._dilation,
        "groups": self._groups or 1,
        "use_cudnn": self._use_cudnn,
        "use_mkldnn": False,
    }
    helper.append_op(
        type=self._l_type,
        inputs={"Input": input,
                "Filter": weight},
        outputs={"Output": conv_out},
        attrs=conv_attrs)

    if self._bias_param is None:
        biased = conv_out
    else:
        biased = helper.create_variable_for_type_inference(dtype=dtype)
        helper.append_op(
            type="elementwise_add",
            inputs={"X": [conv_out],
                    "Y": [self._bias_param]},
            outputs={"Out": [biased]},
            attrs={"axis": 1})

    # Currently, inplace is not supported in dygraph mode.
    return helper.append_activation(biased, act=self._act)
class Conv2DTranspose(dg.Layer):
    r"""
    **Convolution2D transpose layer with weight normalization**

    The convolution2D transpose layer calculates the output based on the input,
    filter, and dilations, strides, paddings. Input(Input) and output(Output)
    are in NCHW format. Where N is batch size, C is the number of channels,
    H is the height of the feature, and W is the width of the feature.
    Parameters(dilations, strides, paddings) are two elements. These two elements
    represent height and width, respectively. The details of convolution transpose
    layer, please refer to the following explanation and references
    `therein <http://www.matthewzeiler.com/wp-content/uploads/2017/07/cvpr2010.pdf>`_.
    If bias attribution and activation type are provided, bias is added to
    the output of the convolution, and the corresponding activation function
    is applied to the final result.

    The filter is reparameterized as a direction :math:`V` and a magnitude
    :math:`g`. On every forward pass :math:`V` is flattened to
    ``[V.shape[0], -1]``, normalized along axis 1, reshaped back and
    multiplied by :math:`g` along axis 0.

    For each input :math:`X`, the equation is:

    .. math::

        W = g \frac{V}{\|V\|}

        Out = \sigma (W \ast X + b)

    Where:

    * :math:`X`: Input value, a tensor with NCHW format.
    * :math:`V`: Filter direction, a tensor with MCHW format.
    * :math:`g`: Filter magnitude, initialized from ``_norm(V, dim=0)``.
    * :math:`\ast`: Convolution operation.
    * :math:`b`: Bias value, created with shape ``[num_filters]``.
    * :math:`\sigma`: Activation function.
    * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.

    Example:

        - Input:

          Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`

          Filter shape: :math:`(C_{in}, C_{out}, H_f, W_f)`

        - Output:

          Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`

        Where

        .. math::

           H^\prime_{out} &= (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\
           W^\prime_{out} &= (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1 \\
           H_{out} &\in [ H^\prime_{out}, H^\prime_{out} + strides[0] ) \\
           W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] )

    Args:
        name_scope(str): The name of this class.
        num_filters(int): The number of the filter. It is as same as the output
            image channel.
        output_size(int|list|tuple|None): The output image size. If output size is a
            list or tuple, it must contain two integers, (image_H, image_W). None if use
            filter_size, padding, and stride to calculate output_size.
            if output_size and filter_size are specified at the same time, They
            should follow the formula above. Default: None.
        filter_size(int|tuple|None): The filter size. If filter_size is a tuple,
            it must contain two integers, (filter_size_H, filter_size_W).
            Otherwise, the filter will be a square. None if use output size to
            calculate filter_size. Default: None.
        padding(int|tuple): The padding size. If padding is a tuple, it must
            contain two integers, (padding_H, padding_W). Otherwise, the
            padding_H = padding_W = padding. Default: padding = 0.
        stride(int|tuple): The stride size. If stride is a tuple, it must
            contain two integers, (stride_H, stride_W). Otherwise, the
            stride_H = stride_W = stride. Default: stride = 1.
        dilation(int|tuple): The dilation size. If dilation is a tuple, it must
            contain two integers, (dilation_H, dilation_W). Otherwise, the
            dilation_H = dilation_W = dilation. Default: dilation = 1.
        groups(int|None): The groups number of the Conv2d transpose layer. Inspired by
            grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
            when group=2, the first half of the filters is only connected to the
            first half of the input channels, while the second half of the
            filters is only connected to the second half of the input channels.
            Default: None, which is treated as groups = 1.
        param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
            of conv2d_transpose. If it is set to None or one attribute of ParamAttr, conv2d_transpose
            will create ParamAttr as param_attr. If the Initializer of the param_attr
            is not set, the parameter is initialized with Xavier. Default: None.
        bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d_transpose.
            If it is set to False, no bias will be added to the output units.
            If it is set to None or one attribute of ParamAttr, conv2d_transpose
            will create ParamAttr as bias_attr. If the Initializer of the bias_attr
            is not set, the bias is initialized zero. Default: None.
        use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
            library is installed. Default: True.
        epsilon(float): The ``epsilon`` attribute passed to the ``norm`` op that
            normalizes the filter direction. Default: 1e-30.
        act (str): Activation type, if it is set to None, activation is not appended.
            Default: None.
        dtype(str): Data type forwarded to the base layer and used for the bias
            and the intermediate variables. Default: "float32".

    Returns:
        Variable: The tensor variable storing the convolution transpose result.

    Raises:
        TypeError: If the input is not a Variable.
        ValueError: If the shapes of input, filter_size, stride, padding and
                    groups mismatch.

    Examples:
       .. code-block:: python

          import paddle.fluid as fluid
          import numpy

          with fluid.dygraph.guard():
              # NCHW input: batch of 3, 32 channels, 32 x 32 feature map.
              data = numpy.random.random((3, 32, 32, 32)).astype('float32')
              conv2DTranspose = Conv2DTranspose(
                    'Conv2DTranspose', num_filters=2, filter_size=3)
              ret = conv2DTranspose(fluid.dygraph.base.to_variable(data))
    """

    def __init__(self,
                 name_scope,
                 num_filters,
                 output_size=None,
                 filter_size=None,
                 padding=0,
                 stride=1,
                 dilation=1,
                 groups=None,
                 param_attr=None,
                 bias_attr=None,
                 use_cudnn=True,
                 epsilon=1e-30,
                 act=None,
                 dtype="float32"):
        super(Conv2DTranspose, self).__init__(name_scope, dtype)
        assert (param_attr is not False
                ), "param_attr should not be False in conv2d_transpose."
        self._param_attr = param_attr
        self._bias_attr = bias_attr
        # Remember the requested activation so that forward() can apply it;
        # without this the `act` argument would be silently ignored.
        self._act = act
        self._groups = groups
        self._num_filters = num_filters
        self._use_cudnn = use_cudnn
        self._padding = padding
        self._stride = stride
        self._dilation = dilation
        self._filter_size = filter_size
        self._output_size = output_size
        self._op_type = "conv2d_transpose"
        self._epsilon = epsilon

    def _build_once(self, input):
        """Validate the configuration and lazily create v, g and the bias.

        Args:
            input: The first input fed to the layer. ``input.shape[1]`` is
                used as the number of input channels; ``input.shape[2:4]``
                is read only when ``filter_size`` has to be derived from
                ``output_size``.

        Raises:
            TypeError: If ``input`` is not a Variable.
            ValueError: If ``use_cudnn`` is not a bool, if both
                ``filter_size`` and ``output_size`` are None, or if
                ``output_size`` is not an int, list or tuple.
        """
        # Check the type before touching input.shape so that a wrong input
        # reports the intended TypeError.
        if not isinstance(input, Variable):
            raise TypeError("Input of conv2d_transpose must be Variable")

        input_channel = input.shape[1]
        if (input_channel == self._groups and
                self._num_filters == input_channel and not self._use_cudnn):
            self._op_type = "depthwise_conv2d_transpose"

        self._padding = utils.convert_to_list(self._padding, 2, "padding")
        self._stride = utils.convert_to_list(self._stride, 2, "stride")
        self._dilation = utils.convert_to_list(self._dilation, 2, "dilation")

        if not isinstance(self._use_cudnn, bool):
            raise ValueError("use_cudnn should be True or False")

        if self._filter_size is None:
            if self._output_size is None:
                raise ValueError(
                    "output_size must be set when filter_size is None")
            if isinstance(self._output_size, int):
                self._output_size = [self._output_size, self._output_size]

            # Invert the output-size formula to recover the filter size.
            h_in = input.shape[2]
            w_in = input.shape[3]
            filter_size_h = (self._output_size[0] -
                             (h_in - 1) * self._stride[0] + 2 *
                             self._padding[0] - 1) // self._dilation[0] + 1
            filter_size_w = (self._output_size[1] -
                             (w_in - 1) * self._stride[1] + 2 *
                             self._padding[1] - 1) // self._dilation[1] + 1
            self._filter_size = [filter_size_h, filter_size_w]
        else:
            self._filter_size = utils.convert_to_list(
                self._filter_size, 2, "conv2d_transpose.filter_size")

        if self._output_size is None:
            self._output_size = []
        elif isinstance(self._output_size, (list, tuple, int)):
            # Tuples are accepted as well, as documented for `output_size`.
            self._output_size = utils.convert_to_list(self._output_size, 2,
                                                      "output_size")
        else:
            raise ValueError("output_size should be int, list or tuple")

        self._groups = 1 if self._groups is None else self._groups
        filter_shape = [
            input_channel,
            self._num_filters // self._groups,
        ] + self._filter_size

        # img filter v (direction)
        self._img_filter_v = self.create_parameter(
            dtype=input.dtype, shape=filter_shape, attr=self._param_attr)

        # img filter g (magnitude), seeded with the norm of the initial v.
        # CAUTION: the normalization dim is hard-coded to 0.
        img_filter_magnitude = _norm(self._img_filter_v.numpy(), dim=0)
        self._img_filter_g = self.create_parameter(
            dtype=input.dtype,
            shape=img_filter_magnitude.shape,
            attr=fluid.ParamAttr(
                initializer=NumpyArrayInitializer(img_filter_magnitude)))

        self._img_bias = self.create_parameter(
            attr=self._bias_attr,
            shape=[self._num_filters],
            dtype=self._dtype,
            is_bias=True)

    def forward(self, input):
        """Apply the weight-normalized transposed convolution to ``input``.

        Args:
            input: The input variable passed to the transposed convolution op.

        Returns:
            The convolution result with the bias (when present) added and
            the configured activation applied.
        """
        helper = self._helper
        v_shape = self._img_filter_v.shape

        # Flatten v to [v_shape[0], prod(v_shape[1:])].
        matrix = helper.create_variable_for_type_inference(self._dtype)
        tmp = helper.create_variable_for_type_inference(self._dtype)
        flat_cols = 1
        for extent in v_shape[1:]:
            flat_cols *= extent
        helper.append_op(
            type="reshape2",
            inputs={"X": self._img_filter_v},
            attrs={"shape": [v_shape[0], flat_cols]},
            outputs={"Out": matrix,
                     "XShape": tmp})

        # Normalize every row of the flattened matrix.
        m_norm = helper.create_variable_for_type_inference(self._dtype)
        m_normalized = helper.create_variable_for_type_inference(self._dtype)
        helper.append_op(
            type="norm",
            inputs={"X": matrix},
            outputs={"Out": m_normalized,
                     "Norm": m_norm},
            attrs={"axis": 1,
                   "epsilon": self._epsilon})

        # Restore the original filter layout.
        v_normalized = helper.create_variable_for_type_inference(self._dtype)
        tmp2 = helper.create_variable_for_type_inference(self._dtype)
        helper.append_op(
            type="reshape2",
            inputs={"X": m_normalized},
            attrs={"shape": v_shape},
            outputs={"Out": v_normalized,
                     "XShape": tmp2})

        # Scale the normalized direction by the magnitude g.
        img_filter = helper.create_variable_for_type_inference(self._dtype)
        helper.append_op(
            type="elementwise_mul",
            inputs={"X": [v_normalized],
                    "Y": [self._img_filter_g]},
            outputs={"Out": [img_filter]},
            attrs={"axis": 0},  # CAUTION: hard-code
        )

        pre_bias = helper.create_variable_for_type_inference(
            dtype=input.dtype)
        helper.append_op(
            type=self._op_type,
            inputs={"Input": [input],
                    "Filter": [img_filter]},
            outputs={"Output": pre_bias},
            attrs={
                "output_size": self._output_size,
                "strides": self._stride,
                "paddings": self._padding,
                "dilations": self._dilation,
                "groups": self._groups,
                "use_cudnn": self._use_cudnn,
            })

        if self._img_bias is not None:
            pre_act = helper.create_variable_for_type_inference(
                dtype=self._dtype)
            helper.append_op(
                type="elementwise_add",
                inputs={"X": [pre_bias],
                        "Y": [self._img_bias]},
                outputs={"Out": [pre_act]},
                attrs={"axis": 1})
        else:
            pre_act = pre_bias

        # Pass the stored activation explicitly; append_activation is called
        # the same way in Conv2D.forward.
        return helper.append_activation(pre_act, act=self._act)