diff --git a/parakeet/data/sampler.py b/parakeet/data/sampler.py
index 097cc03..60aa5db 100644
--- a/parakeet/data/sampler.py
+++ b/parakeet/data/sampler.py
@@ -163,6 +163,35 @@ class WeightedRandomSampler(Sampler):
return self.num_samples
+class DistributedSampler(Sampler):
+ def __init__(self, dataset_size, num_trainers, rank, shuffle=True):
+ self.dataset_size = dataset_size
+ self.num_trainers = num_trainers
+ self.rank = rank
+ self.num_samples = int(np.ceil(dataset_size / num_trainers))
+ self.total_size = self.num_samples * num_trainers
+ assert self.total_size >= self.dataset_size
+ self.shuffle = shuffle
+
+ def __iter__(self):
+ indices = list(range(self.dataset_size))
+ if self.shuffle:
+ random.shuffle(indices)
+
+ # Pad with samples from the head of the list so that every trainer gets the same number of samples.
+ indices += indices[:(self.total_size - self.dataset_size)]
+ assert len(indices) == self.total_size
+
+ # Subset samples for each trainer.
+ indices = indices[self.rank:self.total_size:self.num_trainers]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+
class BatchSampler(Sampler):
r"""Wraps another sampler to yield a mini-batch of indices.
Args:
@@ -206,4 +235,4 @@ class BatchSampler(Sampler):
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
- return (len(self.sampler) + self.batch_size - 1) // self.batch_size
\ No newline at end of file
+ return (len(self.sampler) + self.batch_size - 1) // self.batch_size
diff --git a/parakeet/models/wavenet/configs/wavenet_ljspeech_mix_gaussian.yaml b/parakeet/models/wavenet/configs/wavenet_ljspeech_mix_gaussian.yaml
new file mode 100644
index 0000000..bf19577
--- /dev/null
+++ b/parakeet/models/wavenet/configs/wavenet_ljspeech_mix_gaussian.yaml
@@ -0,0 +1,32 @@
+valid_size: 16
+train_clip_second: 0.5
+sample_rate: 22050
+fft_window_shift: 256
+fft_window_size: 1024
+fft_size: 2048
+mel_bands: 80
+
+seed: 1
+batch_size: 8
+test_every: 2000
+save_every: 10000
+max_iterations: 2000000
+
+layers: 30
+kernel_width: 2
+dilation_block: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
+residual_channels: 128
+skip_channels: 128
+loss_type: mix-gaussian-pdf
+num_mixtures: 10
+log_scale_min: -9.0
+
+conditioner:
+ filter_sizes: [[32, 3], [32, 3]]
+ upsample_factors: [16, 16]
+
+learning_rate: 0.001
+gradient_max_norm: 100.0
+anneal:
+ every: 200000
+ rate: 0.5
diff --git a/parakeet/models/wavenet/configs/wavenet_ljspeech_softmax.yaml b/parakeet/models/wavenet/configs/wavenet_ljspeech_softmax.yaml
new file mode 100644
index 0000000..f39de5d
--- /dev/null
+++ b/parakeet/models/wavenet/configs/wavenet_ljspeech_softmax.yaml
@@ -0,0 +1,31 @@
+valid_size: 16
+train_clip_second: 0.5
+sample_rate: 22050
+fft_window_shift: 256
+fft_window_size: 1024
+fft_size: 2048
+mel_bands: 80
+
+seed: 1
+batch_size: 8
+test_every: 2000
+save_every: 10000
+max_iterations: 2000000
+
+layers: 30
+kernel_width: 2
+dilation_block: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
+residual_channels: 128
+skip_channels: 128
+loss_type: softmax
+num_channels: 2048
+
+conditioner:
+ filter_sizes: [[32, 3], [32, 3]]
+ upsample_factors: [16, 16]
+
+learning_rate: 0.001
+gradient_max_norm: 100.0
+anneal:
+ every: 200000
+ rate: 0.5
diff --git a/parakeet/models/wavenet/data.py b/parakeet/models/wavenet/data.py
index 61cc4ab..a4f1b70 100644
--- a/parakeet/models/wavenet/data.py
+++ b/parakeet/models/wavenet/data.py
@@ -1,5 +1,3 @@
-import math
-import os
import random
import librosa
@@ -9,7 +7,7 @@ from paddle import fluid
import utils
from parakeet.datasets import ljspeech
from parakeet.data import dataset
-from parakeet.data.sampler import Sampler, BatchSampler, SequentialSampler
+from parakeet.data.sampler import DistributedSampler, BatchSampler
from parakeet.data.datacargo import DataCargo
@@ -20,7 +18,7 @@ class Dataset(ljspeech.LJSpeech):
self.fft_window_shift = config.fft_window_shift
# Calculate context frames.
frames_per_second = config.sample_rate // self.fft_window_shift
- train_clip_frames = int(math.ceil(
+ train_clip_frames = int(np.ceil(
config.train_clip_second * frames_per_second))
context_frames = config.context_size // self.fft_window_shift
self.num_frames = train_clip_frames + context_frames
@@ -39,7 +37,7 @@ class Dataset(ljspeech.LJSpeech):
assert loaded_sr == sr
# Pad audio to the right size.
- frames = math.ceil(float(audio.size) / fft_window_shift)
+ frames = int(np.ceil(float(audio.size) / fft_window_shift))
fft_padding = (fft_size - fft_window_shift) // 2
desired_length = frames * fft_window_shift + fft_padding * 2
pad_amount = (desired_length - audio.size) // 2
@@ -125,35 +123,6 @@ class Subset(dataset.Dataset):
return len(self.indices)
-class DistributedSampler(Sampler):
- def __init__(self, dataset_size, num_trainers, rank, shuffle=True):
- self.dataset_size = dataset_size
- self.num_trainers = num_trainers
- self.rank = rank
- self.num_samples = int(math.ceil(dataset_size / num_trainers))
- self.total_size = self.num_samples * num_trainers
- assert self.total_size >= self.dataset_size
- self.shuffle = shuffle
-
- def __iter__(self):
- indices = list(range(self.dataset_size))
- if self.shuffle:
- random.shuffle(indices)
-
- # Append extra samples to make it evenly distributed on all trainers.
- indices += indices[:(self.total_size - self.dataset_size)]
- assert len(indices) == self.total_size
-
- # Subset samples for each trainer.
- indices = indices[self.rank:self.total_size:self.num_trainers]
- assert len(indices) == self.num_samples
-
- return iter(indices)
-
- def __len__(self):
- return self.num_samples
-
-
class LJSpeech:
def __init__(self, config, nranks, rank):
place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace()
diff --git a/parakeet/models/wavenet/ops.py b/parakeet/models/wavenet/ops.py
deleted file mode 100644
index 6eda2a9..0000000
--- a/parakeet/models/wavenet/ops.py
+++ /dev/null
@@ -1,249 +0,0 @@
-import paddle
-from paddle import fluid
-import paddle.fluid.dygraph as dg
-import numpy as np
-
-import weight_norm
-
-
-def Embedding(name_scope,
- num_embeddings,
- embed_dim,
- padding_idx=None,
- std=0.1,
- dtype="float32"):
- # param attrs
- weight_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal(
- scale=std))
- layer = dg.Embedding(
- name_scope, (num_embeddings, embed_dim),
- padding_idx=padding_idx,
- param_attr=weight_attr,
- dtype=dtype)
- return layer
-
-
-def FC(name_scope,
- in_features,
- size,
- num_flatten_dims=1,
- relu=False,
- dropout=0.0,
- act=None,
- dtype="float32"):
- """
- A special Linear Layer, when it is used with dropout, the weight is
- initialized as normal(0, std=np.sqrt((1-dropout) / in_features))
- """
-
- # stds
- if isinstance(in_features, int):
- in_features = [in_features]
-
- stds = [np.sqrt((1.0 - dropout) / in_feature) for in_feature in in_features]
- if relu:
- stds = [std * np.sqrt(2.0) for std in stds]
-
- weight_inits = [
- fluid.initializer.NormalInitializer(scale=std) for std in stds
- ]
- bias_init = fluid.initializer.ConstantInitializer(0.0)
-
- # param attrs
- weight_attrs = [fluid.ParamAttr(initializer=init) for init in weight_inits]
- bias_attr = fluid.ParamAttr(initializer=bias_init)
-
- layer = weight_norm.FC(name_scope,
- size,
- num_flatten_dims=num_flatten_dims,
- param_attr=weight_attrs,
- bias_attr=bias_attr,
- act=act,
- dtype=dtype)
- return layer
-
-
-def Conv1D(name_scope,
- in_channels,
- num_filters,
- filter_size=2,
- dilation=1,
- groups=None,
- causal=False,
- std_mul=1.0,
- dropout=0.0,
- use_cudnn=True,
- act=None,
- dtype="float32"):
- """
- A special Conv1D Layer, when it is used with dropout, the weight is
- initialized as
- normal(0, std=np.sqrt(std_mul * (1-dropout) / (filter_size * in_channels)))
- """
- # std
- std = np.sqrt((std_mul * (1.0 - dropout)) / (filter_size * in_channels))
- weight_init = fluid.initializer.NormalInitializer(loc=0.0, scale=std)
- bias_init = fluid.initializer.ConstantInitializer(0.0)
-
- # param attrs
- weight_attr = fluid.ParamAttr(initializer=weight_init)
- bias_attr = fluid.ParamAttr(initializer=bias_init)
-
- layer = weight_norm.Conv1D(
- name_scope,
- num_filters,
- filter_size,
- dilation,
- groups=groups,
- causal=causal,
- param_attr=weight_attr,
- bias_attr=bias_attr,
- use_cudnn=use_cudnn,
- act=act,
- dtype=dtype)
- return layer
-
-
-class Conv1D_GU(dg.Layer):
- def __init__(self,
- name_scope,
- conditioner_dim,
- in_channels,
- num_filters,
- filter_size,
- dilation,
- causal=False,
- residual=True,
- dtype="float32"):
- super(Conv1D_GU, self).__init__(name_scope, dtype=dtype)
-
- self.conditioner_dim = conditioner_dim
- self.in_channels = in_channels
- self.num_filters = num_filters
- self.filter_size = filter_size
- self.dilation = dilation
- self.causal = causal
- self.residual = residual
-
- if residual:
- assert (
- in_channels == num_filters
- ), "this block uses residual connection"\
- "the input_channels should equals num_filters"
-
- self.conv = Conv1D(
- self.full_name(),
- in_channels,
- 2 * num_filters,
- filter_size,
- dilation,
- causal=causal,
- dtype=dtype)
-
- self.fc = Conv1D(
- self.full_name(),
- conditioner_dim,
- 2 * num_filters,
- filter_size=1,
- dilation=1,
- causal=False,
- dtype=dtype)
-
- def forward(self, x, skip=None, conditioner=None):
- """
- Args:
- x (Variable): Shape(B, C_in, 1, T), the input of Conv1DGLU
- layer, where B means batch_size, C_in means the input channels
- T means input time steps.
- conditioner (Variable): Shape(B, C_con, 1, T), expanded mel
- conditioner, where C_con is conditioner hidden dim which
- equals the num of mel bands. Note that when using residual
- connection, the Conv1DGLU does not change the number of
- channels, so out channels equals input channels.
- Returns:
- x (Variable): Shape(B, C_out, 1, T), the output of Conv1DGLU, where
- C_out means the output channels of Conv1DGLU.
- """
- residual = x
- x = self.conv(x)
-
- if conditioner is not None:
- cond_bias = self.fc(conditioner)
- x += cond_bias
-
- content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
-
- # Gated Unit.
- x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
- fluid.layers.tanh(content))
-
- if skip is None:
- skip = x
- else:
- skip = fluid.layers.scale(skip + x, np.sqrt(0.5))
-
- if self.residual:
- x = fluid.layers.scale(residual + x, np.sqrt(0.5))
-
- return x, skip
-
- def add_input(self, x, skip=None, conditioner=None):
- """
- Inputs:
- x: shape(B, num_filters, 1, time_steps)
- conditioner: shape(B, conditioner_dim, 1, time_steps)
- Outputs:
- out: shape(B, num_filters, 1, time_steps), where time_steps = 1
- """
- residual = x
-
- # add step input and produce step output
- x = self.conv.add_input(x)
-
- if conditioner is not None:
- cond_bias = self.fc(conditioner)
- x += cond_bias
-
- content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
-
- # Gated Unit.
- x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
- fluid.layers.tanh(content))
-
- if skip is None:
- skip = x
- else:
- skip = fluid.layers.scale(skip + x, np.sqrt(0.5))
-
- if self.residual:
- x = fluid.layers.scale(residual + x, np.sqrt(0.5))
-
- return x, skip
-
-
-def Conv2DTranspose(name_scope,
- num_filters,
- filter_size,
- padding=0,
- stride=1,
- dilation=1,
- use_cudnn=True,
- act=None,
- dtype="float32"):
- val = 1.0 / (filter_size[0] * filter_size[1])
- weight_init = fluid.initializer.ConstantInitializer(val)
- weight_attr = fluid.ParamAttr(initializer=weight_init)
-
- layer = weight_norm.Conv2DTranspose(
- name_scope,
- num_filters,
- filter_size=filter_size,
- padding=padding,
- stride=stride,
- dilation=dilation,
- param_attr=weight_attr,
- use_cudnn=use_cudnn,
- act=act,
- dtype=dtype)
-
- return layer
diff --git a/parakeet/models/wavenet/wavenet.py b/parakeet/models/wavenet/wavenet.py
index acc6e76..c636c4b 100644
--- a/parakeet/models/wavenet/wavenet.py
+++ b/parakeet/models/wavenet/wavenet.py
@@ -4,12 +4,12 @@ import time
import librosa
import numpy as np
-from paddle import fluid
import paddle.fluid.dygraph as dg
+from paddle import fluid
import utils
from data import LJSpeech
-from wavenet_modules import WaveNetModule, debug
+from wavenet_modules import WaveNetModule
class WaveNet():
@@ -33,18 +33,6 @@ class WaveNet():
self.trainloader = dataset.trainloader
self.validloader = dataset.validloader
-# if self.rank == 0:
-# for i, (audios, mels, ids) in enumerate(self.validloader()):
-# print("audios {}, mels {}, ids {}".format(audios.dtype, mels.dtype, ids.dtype))
-# print("{}: rank {}, audios {}, mels {}, indices {} / {}".format(
-# i, self.rank, audios.shape, mels.shape, ids.shape,
-# ids.numpy()))
-#
-# for i, (audios, mels, ids) in enumerate(self.trainloader):
-# print("{}: rank {}, audios {}, mels {}, indices {} / {}".format(
-# i, self.rank, audios.shape, mels.shape, ids.shape,
-# ids.numpy()))
-
wavenet = WaveNetModule("wavenet", config, self.rank)
# Dry run once to create and initalize all necessary parameters.
@@ -139,8 +127,8 @@ class WaveNet():
self.wavenet.eval()
total_loss = []
- start_time = time.time()
sample_audios = []
+ start_time = time.time()
for audios, mels, audio_starts in self.validloader():
loss, sample_audio = self.wavenet(audios, mels, audio_starts, True)
total_loss.append(float(loss.numpy()))
@@ -160,11 +148,6 @@ class WaveNet():
tb.add_audio("Teacher-Forced-Audio-1", sample_audios[1].numpy(),
iteration, sample_rate=self.config.sample_rate)
- def save(self, iteration):
- utils.save_latest_parameters(self.checkpoint_dir, iteration,
- self.wavenet, self.optimizer)
- utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
-
@dg.no_grad
def infer(self, iteration):
self.wavenet.eval()
@@ -186,3 +169,8 @@ class WaveNet():
syn_audio.shape, syn_time))
librosa.output.write_wav(filename, syn_audio,
sr=config.sample_rate)
+
+ def save(self, iteration):
+ utils.save_latest_parameters(self.checkpoint_dir, iteration,
+ self.wavenet, self.optimizer)
+ utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
diff --git a/parakeet/models/wavenet/wavenet_modules.py b/parakeet/models/wavenet/wavenet_modules.py
index c5c01e9..fbab741 100644
--- a/parakeet/models/wavenet/wavenet_modules.py
+++ b/parakeet/models/wavenet/wavenet_modules.py
@@ -1,11 +1,9 @@
import itertools
-import math
import numpy as np
-from paddle import fluid
import paddle.fluid.dygraph as dg
-import ops
-import weight_norm
+from paddle import fluid
+from parakeet.modules import conv, modules
def get_padding(filter_size, stride, padding_type='same'):
@@ -16,22 +14,6 @@ def get_padding(filter_size, stride, padding_type='same'):
return padding
-def debug(x, var_name, rank, verbose=False):
- if not verbose and rank != 0:
- return
- dim = len(x.shape)
- if not isinstance(x, np.ndarray):
- x = x.numpy()
- if dim == 1:
- print("Rank {}".format(rank), var_name, "shape {}, value {}".format(x.shape, x))
- elif dim == 2:
- print("Rank {}".format(rank), var_name, "shape {}, value {}".format(x.shape, x[:, :5]))
- elif dim == 3:
- print("Rank {}".format(rank), var_name, "shape {}, value {}".format(x.shape, x[:, :5, 0]))
- else:
- print("Rank", rank, var_name, "shape", x.shape)
-
-
def extract_slices(x, audio_starts, audio_length, rank):
slices = []
for i in range(x.shape[0]):
@@ -58,7 +40,7 @@ class Conditioner(dg.Layer):
stride = (up_scale, 1)
padding = get_padding(filter_sizes[i], stride)
self.deconvs.append(
- ops.Conv2DTranspose(
+ modules.Conv2DTranspose(
self.full_name(),
num_filters=1,
filter_size=filter_sizes[i],
@@ -94,12 +76,13 @@ class WaveNetModule(dg.Layer):
print("context_size", self.context_size)
if config.loss_type == "softmax":
- self.embedding_fc = ops.Embedding(
+ self.embedding_fc = modules.Embedding(
self.full_name(),
num_embeddings=config.num_channels,
- embed_dim=config.residual_channels)
+ embed_dim=config.residual_channels,
+ std=0.1)
elif config.loss_type == "mix-gaussian-pdf":
- self.embedding_fc = ops.FC(
+ self.embedding_fc = modules.FC(
self.full_name(),
in_features=1,
size=config.residual_channels,
@@ -112,7 +95,7 @@ class WaveNetModule(dg.Layer):
self.dilated_causal_convs = []
for dilation in self.dilations:
self.dilated_causal_convs.append(
- ops.Conv1D_GU(
+ modules.Conv1D_GU(
self.full_name(),
conditioner_dim=config.mel_bands,
in_channels=config.residual_channels,
@@ -126,7 +109,7 @@ class WaveNetModule(dg.Layer):
for i, layer in enumerate(self.dilated_causal_convs):
self.add_sublayer("dilated_causal_conv_{}".format(i), layer)
- self.fc1 = ops.FC(
+ self.fc1 = modules.FC(
self.full_name(),
in_features=config.residual_channels,
size=config.skip_channels,
@@ -134,7 +117,7 @@ class WaveNetModule(dg.Layer):
relu=True,
act="relu")
- self.fc2 = ops.FC(
+ self.fc2 = modules.FC(
self.full_name(),
in_features=config.skip_channels,
size=config.skip_channels,
@@ -143,14 +126,14 @@ class WaveNetModule(dg.Layer):
act="relu")
if config.loss_type == "softmax":
- self.fc3 = ops.FC(
+ self.fc3 = modules.FC(
self.full_name(),
in_features=config.skip_channels,
size=config.num_channels,
num_flatten_dims=2,
relu=False)
elif config.loss_type == "mix-gaussian-pdf":
- self.fc3 = ops.FC(
+ self.fc3 = modules.FC(
self.full_name(),
in_features=config.skip_channels,
size=3 * config.num_mixtures,
@@ -175,8 +158,8 @@ class WaveNetModule(dg.Layer):
return samples
def sample_mix_gaussian(self, mix_parameters):
- # mix_parameters reshape from [bs, 13799, 3 * num_mixtures]
- # to [bs * 13799, 3 * num_mixtures].
+ # mix_parameters reshape from [bs, len, 3 * num_mixtures]
+ # to [bs * len, 3 * num_mixtures].
batch, length, hidden = mix_parameters.shape
mix_param_2d = fluid.layers.reshape(mix_parameters,
[batch * length, hidden])
@@ -197,7 +180,7 @@ class WaveNetModule(dg.Layer):
mu_comp = fluid.layers.gather_nd(mu, comp_samples)
s_comp = fluid.layers.gather_nd(s, comp_samples)
- # N(0, 1) Normal Sample.
+ # N(0, 1) normal sample.
u = fluid.layers.gaussian_random(shape=[batch * length])
samples = mu_comp + u * s_comp
samples = fluid.layers.clip(samples, min=-1.0, max=1.0)
@@ -205,8 +188,6 @@ class WaveNetModule(dg.Layer):
return samples
def softmax_loss(self, targets, mix_parameters):
- # targets: [bs, 13799] -> [bs, 11752]
- # mix_params: [bs, 13799, 3] -> [bs, 11752, 3]
targets = targets[:, self.context_size:]
mix_parameters = mix_parameters[:, self.context_size:, :]
@@ -216,22 +197,22 @@ class WaveNetModule(dg.Layer):
quantized = fluid.layers.cast(
(targets + 1.0) / 2.0 * num_channels, dtype="int64")
- # per_sample_loss shape: [bs, 17952, 1]
+ # per_sample_loss shape: [bs, len, 1]
per_sample_loss = fluid.layers.softmax_with_cross_entropy(
logits=mix_parameters, label=fluid.layers.unsqueeze(quantized, 2))
loss = fluid.layers.reduce_mean(per_sample_loss)
- #debug(loss, "softmax loss", self.rank)
return loss
def mixture_density_loss(self, targets, mix_parameters, log_scale_min):
- # targets: [bs, 13799] -> [bs, 11752]
- # mix_params: [bs, 13799, 3] -> [bs, 11752, 3]
+ # targets: [bs, len]
+ # mix_params: [bs, len, 3 * num_mixtures]
targets = targets[:, self.context_size:]
mix_parameters = mix_parameters[:, self.context_size:, :]
- # log_s: [bs, 11752, num_mixture]
- logits_pi, mu, log_s = fluid.layers.split(mix_parameters, num_or_sections=3, dim=-1)
+ # log_s: [bs, len, num_mixtures]
+ logits_pi, mu, log_s = fluid.layers.split(
+ mix_parameters, num_or_sections=3, dim=-1)
pi = fluid.layers.softmax(logits_pi, axis=-1)
log_s = fluid.layers.clip(log_s, min=log_scale_min, max=100.0)
@@ -242,10 +223,9 @@ class WaveNetModule(dg.Layer):
targets = fluid.layers.expand(targets, [1, 1, self.config.num_mixtures])
x_std = inv_s * (targets - mu)
exponent = fluid.layers.exp(-0.5 * x_std * x_std)
- # pdf_x: [bs, 11752, 1]
pdf_x = 1.0 / np.sqrt(2.0 * np.pi) * inv_s * exponent
pdf_x = pi * pdf_x
- # pdf_x: [bs, 11752]
+ # pdf_x: [bs, len]
pdf_x = fluid.layers.reduce_sum(pdf_x, dim=-1)
per_sample_loss = 0.0 - fluid.layers.log(pdf_x + 1e-9)
@@ -254,8 +234,6 @@ class WaveNetModule(dg.Layer):
return loss
def forward(self, audios, mels, audio_starts, sample=False):
- # audios: [bs, 13800], mels: [bs, full_frame_length, 80]
- # audio_starts: [bs]
# Build conditioner based on mels.
full_conditioner = self.conditioner(mels)
@@ -264,15 +242,14 @@ class WaveNetModule(dg.Layer):
conditioner = extract_slices(full_conditioner,
audio_starts, audio_length, self.rank)
- # input_audio, target_audio: [bs, 13799]
+ # input_audio, target_audio: [bs, len]
input_audios = audios[:, :-1]
target_audios = audios[:, 1:]
- # conditioner: [bs, 13799, 80]
+ # conditioner: [bs, len, mel_bands]
conditioner = conditioner[:, 1:, :]
loss_type = self.config.loss_type
- # layer_input: [bs, 13799, 128]
if loss_type == "softmax":
input_audios = fluid.layers.clip(
input_audios, min=-1.0, max=0.99999)
@@ -280,31 +257,31 @@ class WaveNetModule(dg.Layer):
quantized = fluid.layers.cast(
(input_audios + 1.0) / 2.0 * self.config.num_channels,
dtype="int64")
- layer_input = self.embedding_fc(fluid.layers.unsqueeze(quantized, 2))
+ layer_input = self.embedding_fc(
+ fluid.layers.unsqueeze(quantized, 2))
elif loss_type == "mix-gaussian-pdf":
- layer_input = self.embedding_fc(fluid.layers.unsqueeze(input_audios, 2))
+ layer_input = self.embedding_fc(
+ fluid.layers.unsqueeze(input_audios, 2))
else:
raise ValueError(
"loss_type {} is unsupported!".format(loss_type))
- # layer_input: [bs, res_channel, 1, 13799]
- layer_input = fluid.layers.unsqueeze(fluid.layers.transpose(layer_input, perm=[0, 2, 1]), 2)
- # conditioner: [bs, mel_bands, 1, 13799]
- conditioner = fluid.layers.unsqueeze(fluid.layers.transpose(conditioner, perm=[0, 2, 1]), 2)
+ # layer_input: [bs, res_channel, 1, len]
+ layer_input = fluid.layers.unsqueeze(
+ fluid.layers.transpose(layer_input, perm=[0, 2, 1]), 2)
+ # conditioner: [bs, mel_bands, 1, len]
+ conditioner = fluid.layers.unsqueeze(
+ fluid.layers.transpose(conditioner, perm=[0, 2, 1]), 2)
- # layer_input: [bs, res_channel, 1, 13799]
- # skip: [bs, res_channel, 1, 13799]
skip = None
for i, layer in enumerate(self.dilated_causal_convs):
+ # layer_input: [bs, res_channel, 1, len]
+ # skip: [bs, res_channel, 1, len]
layer_input, skip = layer(layer_input, skip, conditioner)
- #debug(layer_input, "layer_input_" + str(i), self.rank)
- #debug(skip, "skip_" + str(i), self.rank)
- # Reshape skip to [bs, 13799, res_channel]
- skip = fluid.layers.transpose(fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1])
- #debug(skip, "skip", self.rank)
-
- # mix_param: [bs, 13799, 3 * num_mixtures]
+ # Reshape skip to [bs, len, res_channel]
+ skip = fluid.layers.transpose(
+ fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1])
mix_parameters = self.fc3(self.fc2(self.fc1(skip)))
# Sample teacher-forced audio.
@@ -317,12 +294,7 @@ class WaveNetModule(dg.Layer):
else:
raise ValueError(
"loss_type {} is unsupported!".format(loss_type))
- #debug(sample_audios, "sample_audios", self.rank)
- # Calculate mix-gaussian density loss.
- # padding is all zero.
- # target_audio: [bs, 13799].
- # mix_params: [bs, 13799, 3].
if loss_type == "softmax":
loss = self.softmax_loss(target_audios, mix_parameters)
elif loss_type == "mix-gaussian-pdf":
@@ -332,27 +304,16 @@ class WaveNetModule(dg.Layer):
raise ValueError(
"loss_type {} is unsupported!".format(loss_type))
- #print("Rank {}, loss {}".format(self.rank, loss.numpy()))
-
return loss, sample_audios
def synthesize(self, mels):
self.start_new_sequence()
- print("input mels shape", mels.shape)
- # mels: [bs=1, n_frames, 80]
- # conditioner: [1, n_frames * samples_per_frame, 80]
- # Should I move forward by one sample? No difference
- # Append context frame to mels
bs, n_frames, mel_bands = mels.shape
- #num_pad_frames = int(np.ceil(self.context_size / self.config.fft_window_shift))
- #silence = fluid.layers.zeros(shape=[bs, num_pad_frames, mel_bands], dtype="float32")
- #inf_mels = fluid.layers.concat([silence, mels], axis=1)
- #print("padded mels shape", inf_mels.shape)
-
- #conditioner = self.conditioner(inf_mels)[:, self.context_size:, :]
conditioner = self.conditioner(mels)
time_steps = conditioner.shape[1]
- print("Total steps", time_steps)
+
+ print("input mels shape", mels.shape)
+ print("Total synthesis steps", time_steps)
loss_type = self.config.loss_type
audio_samples = []
@@ -361,8 +322,8 @@ class WaveNetModule(dg.Layer):
if i % 100 == 0:
print("Step", i)
- # convert from real value sample to audio embedding.
- # [bs, 1, 128]
+ # Convert the real-valued sample to an audio embedding.
+ # audio_input: [bs, 1, channel]
if loss_type == "softmax":
current_sample = fluid.layers.clip(
current_sample, min=-1.0, max=0.99999)
@@ -377,21 +338,23 @@ class WaveNetModule(dg.Layer):
raise ValueError(
"loss_type {} is unsupported!".format(loss_type))
- # [bs, 128, 1, 1]
- audio_input = fluid.layers.unsqueeze(fluid.layers.transpose(audio_input, perm=[0, 2, 1]), 2)
- # [bs, 80]
+ # [bs, channel, 1, 1]
+ audio_input = fluid.layers.unsqueeze(
+ fluid.layers.transpose(audio_input, perm=[0, 2, 1]), 2)
+ # [bs, mel_bands]
cond_input = conditioner[:, i, :]
- # [bs, 80, 1, 1]
+ # [bs, mel_bands, 1, 1]
cond_input = fluid.layers.reshape(
cond_input, cond_input.shape + [1, 1])
skip = None
for layer in self.dilated_causal_convs:
- audio_input, skip = layer.add_input(audio_input, skip, cond_input)
+ audio_input, skip = layer.add_input(
+ audio_input, skip, cond_input)
- # [bs, 1, 128]
- skip = fluid.layers.transpose(fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1])
- # [bs, 1, 3]
+ # [bs, 1, channel]
+ skip = fluid.layers.transpose(
+ fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1])
mix_parameters = self.fc3(self.fc2(self.fc1(skip)))
if loss_type == "softmax":
sample = self.sample_softmax(mix_parameters)
@@ -407,17 +370,12 @@ class WaveNetModule(dg.Layer):
current_sample = fluid.layers.reshape(current_sample,
current_sample.shape + [1, 1])
- # syn_audio: (num_samples,)
+ # syn_audio: [num_samples]
syn_audio = fluid.layers.concat(audio_samples, axis=0).numpy()
return syn_audio
def start_new_sequence(self):
for layer in self.sublayers():
- if isinstance(layer, weight_norm.Conv1D):
+ if isinstance(layer, conv.Conv1D):
layer.start_new_sequence()
-
- def save(self, iteration):
- utils.save_latest_parameters(self.checkpoint_dir, iteration,
- self.wavenet, self.optimizer)
- utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
diff --git a/parakeet/models/wavenet/weight_norm.py b/parakeet/models/wavenet/weight_norm.py
deleted file mode 100644
index 75fe413..0000000
--- a/parakeet/models/wavenet/weight_norm.py
+++ /dev/null
@@ -1,920 +0,0 @@
-import math
-from copy import deepcopy
-
-import numpy as np
-import paddle.fluid.dygraph as dg
-from paddle import fluid
-from paddle.fluid import core
-from paddle.fluid.framework import Variable
-from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
-from paddle.fluid.layers import utils
-from six.moves import reduce
-
-
-def _norm(p, dim):
- """Computes the norm over all dimensions except dim.
- It differs from pytorch implementation that it does not keep dim.
- This difference is related with the broadcast mechanism in paddle.
- Read elementeise_mul for more.
- """
- if dim is None:
- return np.linalg.norm(p, ord=2, axis=None)
- elif dim == 0:
- p = np.reshape(p, newshape=(p.shape[0], -1))
- return np.linalg.norm(p, ord=2, axis=1)
- elif dim == p.ndim - 1:
- p = np.reshape(p, newshape=(-1, p.shape[-1]))
- return np.linalg.norm(p, ord=2, axis=0)
- else:
- perm = list(range(p.ndim))
- perm[0] = dim
- perm[dim] = 0
- return _norm(np.transpose(p, axes=perm))
-
-
-class Conv1D(dg.Layer):
- """
- A convolution 1D block implemented with Conv2D. Form simplicity and
- ensuring the output has the same length as the input, it does not allow
- stride > 1.
- """
- def __init__(self,
- name_scope,
- num_filters,
- filter_size=3,
- dilation=1,
- groups=None,
- causal=False,
- param_attr=None,
- bias_attr=None,
- use_cudnn=True,
- act=None,
- dtype="float32"):
- super(Conv1D, self).__init__(name_scope, dtype=dtype)
-
- if causal:
- padding = dilation * (filter_size - 1)
- else:
- padding = (dilation * (filter_size - 1)) // 2
-
- self.num_filters = num_filters
- self.filter_size = filter_size
- self.dilation = dilation
- self.causal = causal
- self.padding = padding
- self.act = act
-
- self.conv = Conv2D(
- self.full_name(),
- num_filters=num_filters,
- filter_size=(1, filter_size),
- stride=(1, 1),
- dilation=(1, dilation),
- padding=(0, padding),
- groups=groups,
- param_attr=param_attr,
- bias_attr=bias_attr,
- use_cudnn=use_cudnn,
- act=act,
- dtype=dtype)
-
- def forward(self, x):
- """
- Args:
- x (Variable): Shape(B, C_in, 1, T), the input, where C_in means
- input channels.
- Returns:
- x (Variable): Shape(B, C_out, 1, T), the outputs, where C_out means
- output channels (num_filters).
- """
- x = self.conv(x)
- if self.filter_size > 1:
- if self.causal:
- x = fluid.layers.slice(
- x, axes=[3], starts=[0], ends=[-self.padding])
- elif self.filter_size % 2 == 0:
- x = fluid.layers.slice(x, axes=[3], starts=[0], ends=[-1])
- return x
-
- def start_new_sequence(self):
- self.temp_weight = None
- self.input_buffer = None
-
- def add_input(self, x):
- """
- Adding input for a time step and compute an output for a time step.
-
- Args:
- x (Variable): Shape(B, C_in, 1, T), the input, where C_in means
- input channels, and T = 1.
- Returns:
- out (Variable): Shape(B, C_out, 1, T), the outputs, where C_out
- means output channels (num_filters), and T = 1.
-
- """
- if self.temp_weight is None:
- self.temp_weight = self._reshaped_weight()
-
- window_size = 1 + (self.filter_size - 1) * self.dilation
- batch_size = x.shape[0]
- in_channels = x.shape[1]
-
- if self.filter_size > 1:
- if self.input_buffer is None:
- self.input_buffer = fluid.layers.fill_constant(
- [batch_size, in_channels, 1, window_size - 1],
- dtype=x.dtype,
- value=0.0)
- else:
- self.input_buffer = self.input_buffer[:, :, :, 1:]
- self.input_buffer = fluid.layers.concat(
- [self.input_buffer, x], axis=3)
- x = self.input_buffer
- if self.dilation > 1:
- if not hasattr(self, "indices"):
- self.indices = dg.to_variable(
- np.arange(0, window_size, self.dilation))
- tmp = fluid.layers.transpose(
- self.input_buffer, perm=[3, 1, 2, 0])
- tmp = fluid.layers.gather(tmp, index=self.indices)
- tmp = fluid.layers.transpose(tmp, perm=[3, 1, 2, 0])
- x = tmp
- inputs = fluid.layers.reshape(
- x, shape=[batch_size, in_channels * 1 * self.filter_size])
- out = fluid.layers.matmul(inputs, self.temp_weight, transpose_y=True)
- out = fluid.layers.elementwise_add(out, self.conv._bias_param, axis=-1)
- out = fluid.layers.reshape(out, out.shape + [1, 1])
- out = self._helper.append_activation(out, act=self.act)
- return out
-
- def _reshaped_weight(self):
- """
- Get the linearized weight of convolution filter, cause it is by nature
- a matmul weight. And because the model uses weight norm, compute the
- weight by weight_v * weight_g to make it faster.
- Returns:
- weight_matrix (Variable): Shape(C_out, C_in * 1 * kernel_size)
- """
- shape = self.conv._filter_param_v.shape
- matrix_shape = [shape[0], np.prod(shape[1:])]
- weight_matrix = fluid.layers.reshape(
- self.conv._filter_param_v, shape=matrix_shape)
- weight_matrix = fluid.layers.elementwise_mul(
- fluid.layers.l2_normalize(
- weight_matrix, axis=1),
- self.conv._filter_param_g,
- axis=0)
- return weight_matrix
-
-
-class FC(dg.Layer):
- """
- **Fully Connected Layer**
- This function creates a fully connected layer in the network. It can take
- one or multiple tensors as its inputs(input can be a list of Variable, see
- Args in detail). It creates a pair of variables called (magnitude(g),
- direction(V)) for each input tensor. Elementwise_mul(V, g) represents a fully connected
- weight matrix from each input unit to each output unit.
- The fully connected layer multiplies each input tensor
- with its corresponding weight to produce an output Tensor with shape [M, `size`],
- where M is batch size. If multiple input tensors are given, the results of
- multiple output tensors with shape [M, `size`] will be summed up. If bias_attr
- is not None, a bias variable will be created and added to the output.
- Finally, if activation is not None, it will be applied to the output as well.
- When the input is single tensor:
- .. math::
- Out = Act({X(normalize(V)g) + b})
- When the input are multiple tensors:
- .. math::
- Out = Act({\sum_{i=0}^{N-1}X_i(V_ig_i) + b})
- In the above equation:
- * :math:`N`: Number of the input. N equals to len(input) if input is list of Variable.
- * :math:`X_i`: The i-th input tensor.
- * :math:`V_i`: The i-th direction matrix corresponding i-th input tensor.
- * :math:`g_i`: The i-th magnitude vector corresponding i-th input tensor.
- * :math:`b`: The bias parameter created by this layer (if needed).
- * :math:`Act`: The activation function.
- * :math:`Out`: The output tensor.
- See below for an example.
- .. code-block:: text
- Given:
- data_1.data = [[[0.1, 0.2],
- [0.3, 0.4]]]
- data_1.shape = (1, 2, 2) # 1 is batch_size
- data_2 = [[[0.1, 0.2, 0.3]]]
- data_2.shape = (1, 1, 3)
- out = fluid.layers.fc(input=[data_1, data_2], size=2)
- Then:
- out.data = [[0.18669507, 0.1893476]]
- out.shape = (1, 2)
- Args:
- name_scope(str): The name of this class.
- size(int): The number of output units in this layer.
- num_flatten_dims (int): The fc layer can accept an input tensor with more than
- two dimensions. If this happens, the multidimensional tensor will first be flattened
- into a 2-dimensional matrix. The parameter `num_flatten_dims` determines how the input
- tensor is flattened: the first `num_flatten_dims` (inclusive, index starts from 1)
- dimensions will be flatten to form the first dimension of the final matrix (height of
- the matrix), and the rest `rank(X) - num_flatten_dims` dimensions are flattened to
- form the second dimension of the final matrix (width of the matrix). For example, suppose
- `X` is a 5-dimensional tensor with a shape [2, 3, 4, 5, 6], and `num_flatten_dims` = 3.
- Then, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. Default: 1
- param_attr (ParamAttr|list of ParamAttr|None): The parameter attribute for learnable
- parameters/weights of this layer.
- bias_attr (ParamAttr|list of ParamAttr, default None): The parameter attribute for the bias
- of this layer. If it is set to False, no bias will be added to the output units.
- If it is set to None, the bias is initialized zero. Default: None.
- act (str|None): Activation to be applied to the output of this layer.
- is_test(bool): A flag indicating whether execution is in test phase. Default: False
- dtype(str): Dtype used for weight
- Raises:
- ValueError: If rank of the input tensor is less than 2.
- Examples:
- .. code-block:: python
- from paddle.fluid.dygraph.base import to_variable
- import paddle.fluid as fluid
- from paddle.fluid.dygraph import FC
- import numpy as np
- data = np.random.uniform( -1, 1, [30, 10, 32] ).astype('float32')
- with fluid.dygraph.guard():
- fc = FC( "fc", 64, num_flatten_dims=2)
- data = to_variable( data )
- conv = fc( data )
- """
-
- def __init__(self,
- name_scope,
- size,
- num_flatten_dims=1,
- epsilon=1e-30,
- param_attr=None,
- bias_attr=None,
- act=None,
- is_test=False,
- dtype="float32"):
- super(FC, self).__init__(name_scope, dtype)
-
- self._size = size
- self._num_flatten_dims = num_flatten_dims
- self._epsilon = epsilon
- self._dtype = dtype
- self._param_attr = param_attr
- self._bias_attr = bias_attr
- self._act = act
- self.__g = list()
- self.__v = list()
-
- @property
- def _v(self, i=0):
- return self.__v[i]
-
- @property
- def _g(self, i=0):
- return self.__g[i]
-
- @_v.setter
- def _v(self, value, i=0):
- assert isinstance(value, Parameter)
- self.__v[i] = value
-
- @_g.setter
- def _g(self, value, i=0):
- assert isinstance(value, Parameter)
- self.__g[i] = value
-
- def _build_once(self, input):
- i = 0
- for inp, param in self._helper.iter_inputs_and_params(
- input, self._param_attr):
- input_shape = inp.shape
-
- param_shape = [
- reduce(lambda a, b: a * b,
- input_shape[self._num_flatten_dims:], 1)
- ] + [self._size]
- self.__v.append(
- self.add_parameter(
- "_v%d" % i,
- self.create_parameter(
- attr=param,
- shape=param_shape,
- dtype=self._dtype,
- is_bias=False)))
-
- magnitude_shape = param_shape[1:]
- magnitude_value = np.linalg.norm(
- self.__v[i].numpy(), ord=2, axis=0)
-
- self.__g.append(
- self.add_parameter(
- "_g%d" % i,
- self.create_parameter(
- attr=fluid.ParamAttr(initializer=fluid.initializer.
- NumpyArrayInitializer(
- magnitude_value)),
- shape=magnitude_shape,
- dtype=self._dtype,
- is_bias=False)))
- i += 1
-
- size = list([self._size])
- self._b = self.create_parameter(
- attr=self._bias_attr, shape=size, dtype=self._dtype, is_bias=True)
-
- def forward(self, input):
- mul_results = list()
- i = 0
- for inp, param in self._helper.iter_inputs_and_params(
- input, self._param_attr):
- v_norm = self._helper.create_variable_for_type_inference(
- self._dtype)
- v_normalized = self._helper.create_variable_for_type_inference(
- self._dtype)
- self._helper.append_op(
- type="norm",
- inputs={"X": self.__v[i]},
- outputs={"Out": v_normalized,
- "Norm": v_norm},
- attrs={"axis": 0,
- "epsilon": self._epsilon})
- weight = self._helper.create_variable_for_type_inference(
- self._dtype)
- self._helper.append_op(
- type="elementwise_mul",
- inputs={"X": [v_normalized],
- "Y": [self.__g[i]]},
- outputs={"Out": [weight]},
- attrs={"axis": 1})
- tmp = self._helper.create_variable_for_type_inference(self._dtype)
- self._helper.append_op(
- type="mul",
- inputs={"X": inp,
- "Y": weight},
- outputs={"Out": tmp},
- attrs={
- "x_num_col_dims": self._num_flatten_dims,
- "y_num_col_dims": 1
- })
- i += 1
- mul_results.append(tmp)
-
- if len(mul_results) == 1:
- pre_bias = mul_results[0]
- else:
- pre_bias = self._helper.create_variable_for_type_inference(
- self._dtype)
- self._helper.append_op(
- type="sum",
- inputs={"X": mul_results},
- outputs={"Out": pre_bias},
- attrs={"use_mkldnn": False})
-
- if self._b:
- pre_activation = self._helper.create_variable_for_type_inference(
- dtype=self._dtype)
- self._helper.append_op(
- type="elementwise_add",
- inputs={"X": [pre_bias],
- "Y": [self._b]},
- outputs={"Out": [pre_activation]},
- attrs={"axis": self._num_flatten_dims})
- else:
- pre_activation = pre_bias
- # Currently, we don't support inplace in dygraph mode
- return self._helper.append_activation(pre_activation, act=self._act)
-
-
-class Conv2D(dg.Layer):
- """
- The convolution2D layer calculates the output based on the input, filter
- and strides, paddings, dilations, groups parameters. Input and
- Output are in NCHW format, where N is batch size, C is the number of
- channels, H is the height of the feature, and W is the width of the feature.
- Filter is in MCHW format, where M is the number of output image channels,
- C is the number of input image channels, H is the height of the filter,
- and W is the width of the filter. If the groups is greater than 1,
- C will equal the number of input image channels divided by the groups.
- Please refer to UFLDL's `convolution
- `
- for more detials.
- If bias attribution and activation type are provided, bias is added to the
- output of the convolution, and the corresponding activation function is
- applied to the final result.
- For each input :math:`X`, the equation is:
- .. math::
- Out = \sigma ((Vg) \\ast X + b)
- Where:
- * :math:`X`: Input value, a tensor with NCHW format.
- * :math:`V`: Filter direction value, a tensor with MCHW format.
- * :math:`g`: Filter magnitude value, a tensor with M format.
- * :math:`\\ast`: Convolution operation.
- * :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
- * :math:`\\sigma`: Activation function.
- * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
- Example:
- - Input:
- Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
- Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
- - Output:
- Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
- Where
- .. math::
- H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
- W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
- Args:
- name_scope(str) : The name for this class.
- num_filters(int): The number of filter. It is as same as the output
- image channel.
- filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
- it must contain two integers, (filter_size_H, filter_size_W).
- Otherwise, the filter will be a square.
- stride (int|tuple): The stride size. If stride is a tuple, it must
- contain two integers, (stride_H, stride_W). Otherwise, the
- stride_H = stride_W = stride. Default: stride = 1.
- padding (int|tuple): The padding size. If padding is a tuple, it must
- contain two integers, (padding_H, padding_W). Otherwise, the
- padding_H = padding_W = padding. Default: padding = 0.
- dilation (int|tuple): The dilation size. If dilation is a tuple, it must
- contain two integers, (dilation_H, dilation_W). Otherwise, the
- dilation_H = dilation_W = dilation. Default: dilation = 1.
- groups (int): The groups number of the Conv2d Layer. According to grouped
- convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
- the first half of the filters is only connected to the first half
- of the input channels, while the second half of the filters is only
- connected to the second half of the input channels. Default: groups=1.
- param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
- of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
- will create ParamAttr as param_attr. If the Initializer of the param_attr
- is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
- and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
- bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d.
- If it is set to False, no bias will be added to the output units.
- If it is set to None or one attribute of ParamAttr, conv2d
- will create ParamAttr as bias_attr. If the Initializer of the bias_attr
- is not set, the bias is initialized zero. Default: None.
- use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
- library is installed. Default: True
- act (str): Activation type, if it is set to None, activation is not appended.
- Default: None
- Raises:
- ValueError: If the shapes of input, filter_size, stride, padding and
- groups mismatch.
- Examples:
- .. code-block:: python
- from paddle.fluid.dygraph.base import to_variable
- import paddle.fluid as fluid
- from paddle.fluid.dygraph import Conv2D
- import numpy as np
- data = np.random.uniform( -1, 1, [10, 3, 32, 32] ).astype('float32')
- with fluid.dygraph.guard():
- conv2d = Conv2D( "conv2d", 2, 3)
- data = to_variable( data )
- conv = conv2d( data )
- """
-
- def __init__(self,
- name_scope,
- num_filters,
- filter_size,
- stride=1,
- padding=0,
- dilation=1,
- groups=None,
- param_attr=None,
- bias_attr=None,
- use_cudnn=True,
- act=None,
- epsilon=1e-30,
- dtype="float32"):
- assert param_attr is not False, "param_attr should not be False here."
- super(Conv2D, self).__init__(name_scope, dtype)
- self._groups = groups
- self._stride = utils.convert_to_list(stride, 2, "stride")
- self._padding = utils.convert_to_list(padding, 2, "padding")
- self._dilation = utils.convert_to_list(dilation, 2, "dilation")
- self._act = act
- if not isinstance(use_cudnn, bool):
- raise ValueError("use_cudnn should be True or False")
- self._use_cudnn = use_cudnn
- self._filter_size = filter_size
- self._num_filters = num_filters
- self._param_attr = param_attr
- self._bias_attr = bias_attr
- self._epsilon = epsilon
- self._dtype = dtype
- # if (self._num_channels == self._groups and
- # num_filters % self._num_channels == 0 and not self._use_cudnn):
- # self._l_type = 'depthwise_conv2d'
- # else:
- # TODO(jiabin): recover the usage of depthwise_conv2d when it's
- # kernel fixed https://github.com/PaddlePaddle/Paddle/issues/17275
- self._l_type = "conv2d"
-
- def _build_once(self, input):
- self._num_channels = input.shape[1]
- if self._groups is None:
- num_filter_channels = self._num_channels
- else:
- if self._num_channels % self._groups != 0:
- raise ValueError("num_channels must be divisible by groups.")
- num_filter_channels = self._num_channels // self._groups
- filter_size = utils.convert_to_list(self._filter_size, 2,
- "filter_size")
- filter_shape = [self._num_filters, int(num_filter_channels)
- ] + filter_size
-
- def _get_default_param_initializer():
- filter_elem_num = filter_size[0] * filter_size[
- 1] * self._num_channels
- std = (2.0 / filter_elem_num)**0.5
- return Normal(0.0, std, 0)
-
- # weight_v
- self._filter_param_v = self.create_parameter(
- attr=self._param_attr,
- shape=filter_shape,
- dtype=self._dtype,
- default_initializer=_get_default_param_initializer())
-
- # weight_g
- norm_value = _norm(
- self._filter_param_v.numpy(), dim=0) # CAUTION: hard-code
- self._filter_param_g = self.create_parameter(
- attr=fluid.ParamAttr(
- initializer=fluid.initializer.NumpyArrayInitializer(
- norm_value)),
- shape=norm_value.shape,
- dtype=self._dtype,
- default_initializer=_get_default_param_initializer())
-
- if self._use_cudnn:
- self.create_variable(
- name="kCUDNNFwdAlgoCache",
- persistable=True,
- type=core.VarDesc.VarType.RAW)
- self.create_variable(
- name="kCUDNNBwdDataAlgoCache",
- persistable=True,
- type=core.VarDesc.VarType.RAW)
- self.create_variable(
- name="kCUDNNBwdFilterAlgoCache",
- persistable=True,
- type=core.VarDesc.VarType.RAW)
-
- self._bias_param = self.create_parameter(
- attr=self._bias_attr,
- shape=[self._num_filters],
- dtype=self._dtype,
- is_bias=True)
-
- def forward(self, input):
- matrix = self._helper.create_variable_for_type_inference(self._dtype)
- tmp = self._helper.create_variable_for_type_inference(self._dtype)
- new_shape = [
- self._filter_param_v.shape[0],
- reduce(lambda x, y: x * y, self._filter_param_v.shape[1:], 1),
- ]
-
- self._helper.append_op(
- type="reshape2",
- inputs={"X": self._filter_param_v},
- attrs={"shape": new_shape},
- outputs={"Out": matrix,
- "XShape": tmp})
-
- m_norm = self._helper.create_variable_for_type_inference(self._dtype)
- m_normalized = self._helper.create_variable_for_type_inference(
- self._dtype)
- self._helper.append_op(
- type="norm",
- inputs={"X": matrix},
- outputs={"Out": m_normalized,
- "Norm": m_norm},
- attrs={"axis": 1,
- "epsilon": self._epsilon})
-
- v_normalized = self._helper.create_variable_for_type_inference(
- self._dtype)
- tmp2 = self._helper.create_variable_for_type_inference(self._dtype)
- self._helper.append_op(
- type="reshape2",
- inputs={"X": m_normalized},
- attrs={"shape": self._filter_param_v.shape},
- outputs={"Out": v_normalized,
- "XShape": tmp2})
-
- filter_param = self._helper.create_variable_for_type_inference(
- self._dtype)
- self._helper.append_op(
- type="elementwise_mul",
- inputs={"X": [v_normalized],
- "Y": [self._filter_param_g]},
- outputs={"Out": [filter_param]},
- attrs={"axis": 0}, # CAUTION: hard-code
- )
-
- pre_bias = self._helper.create_variable_for_type_inference(
- dtype=self._dtype)
-
- self._helper.append_op(
- type=self._l_type,
- inputs={"Input": input,
- "Filter": filter_param},
- outputs={"Output": pre_bias},
- attrs={
- "strides": self._stride,
- "paddings": self._padding,
- "dilations": self._dilation,
- "groups": self._groups if self._groups else 1,
- "use_cudnn": self._use_cudnn,
- "use_mkldnn": False,
- })
-
- if self._bias_param is not None:
- pre_act = self._helper.create_variable_for_type_inference(
- dtype=self._dtype)
- self._helper.append_op(
- type="elementwise_add",
- inputs={"X": [pre_bias],
- "Y": [self._bias_param]},
- outputs={"Out": [pre_act]},
- attrs={"axis": 1})
- else:
- pre_act = pre_bias
-
- # Currently, we don't support inplace in dygraph mode
- return self._helper.append_activation(pre_act, act=self._act)
-
-
-class Conv2DTranspose(dg.Layer):
- """
- **Convlution2D transpose layer**
- The convolution2D transpose layer calculates the output based on the input,
- filter, and dilations, strides, paddings. Input(Input) and output(Output)
- are in NCHW format. Where N is batch size, C is the number of channels,
- H is the height of the feature, and W is the width of the feature.
- Parameters(dilations, strides, paddings) are two elements. These two elements
- represent height and width, respectively. The details of convolution transpose
- layer, please refer to the following explanation and references
- `therein `_.
- If bias attribution and activation type are provided, bias is added to
- the output of the convolution, and the corresponding activation function
- is applied to the final result.
- For each input :math:`X`, the equation is:
- .. math::
- Out = \sigma ((Vg) \\ast X + b)
- Where:
- * :math:`X`: Input value, a tensor with NCHW format.
- * :math:`V`: Filter value, a tensor with MCHW format.
- * :math:`g`: Filter value, a tensor with M format.
- * :math:`\\ast`: Convolution operation.
- * :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
- * :math:`\\sigma`: Activation function.
- * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
- Example:
- - Input:
- Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
- Filter shape: :math:`(C_{in}, C_{out}, H_f, W_f)`
- - Output:
- Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
- Where
- .. math::
- H^\prime_{out} &= (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\\\
- W^\prime_{out} &= (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1 \\\\
- H_{out} &\in [ H^\prime_{out}, H^\prime_{out} + strides[0] ) \\\\
- W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] )
- Args:
- name_scope(str): The name of this class.
- num_filters(int): The number of the filter. It is as same as the output
- image channel.
- output_size(int|tuple|None): The output image size. If output size is a
- tuple, it must contain two integers, (image_H, image_W). None if use
- filter_size, padding, and stride to calculate output_size.
- if output_size and filter_size are specified at the same time, They
- should follow the formula above. Default: None.
- filter_size(int|tuple|None): The filter size. If filter_size is a tuple,
- it must contain two integers, (filter_size_H, filter_size_W).
- Otherwise, the filter will be a square. None if use output size to
- calculate filter_size. Default: None.
- padding(int|tuple): The padding size. If padding is a tuple, it must
- contain two integers, (padding_H, padding_W). Otherwise, the
- padding_H = padding_W = padding. Default: padding = 0.
- stride(int|tuple): The stride size. If stride is a tuple, it must
- contain two integers, (stride_H, stride_W). Otherwise, the
- stride_H = stride_W = stride. Default: stride = 1.
- dilation(int|tuple): The dilation size. If dilation is a tuple, it must
- contain two integers, (dilation_H, dilation_W). Otherwise, the
- dilation_H = dilation_W = dilation. Default: dilation = 1.
- groups(int): The groups number of the Conv2d transpose layer. Inspired by
- grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
- when group=2, the first half of the filters is only connected to the
- first half of the input channels, while the second half of the
- filters is only connected to the second half of the input channels.
- Default: groups = 1.
- param_attr (ParamAttr|None): The parameter attribute for learnable parameters/weights
- of conv2d_transpose. If it is set to None or one attribute of ParamAttr, conv2d_transpose
- will create ParamAttr as param_attr. If the Initializer of the param_attr
- is not set, the parameter is initialized with Xavier. Default: None.
- bias_attr (ParamAttr|bool|None): The parameter attribute for the bias of conv2d_transpose.
- If it is set to False, no bias will be added to the output units.
- If it is set to None or one attribute of ParamAttr, conv2d_transpose
- will create ParamAttr as bias_attr. If the Initializer of the bias_attr
- is not set, the bias is initialized zero. Default: None.
- use_cudnn(bool): Use cudnn kernel or not, it is valid only when the cudnn
- library is installed. Default: True.
- act (str): Activation type, if it is set to None, activation is not appended.
- Default: None.
- Returns:
- Variable: The tensor variable storing the convolution transpose result.
- Raises:
- ValueError: If the shapes of input, filter_size, stride, padding and
- groups mismatch.
- Examples:
- .. code-block:: python
- import paddle.fluid as fluid
- import numpy
- with fluid.dygraph.guard():
- data = numpy.random.random((3, 32, 32)).astype('float32')
- conv2DTranspose = fluid.dygraph.nn.Conv2DTranspose(
- 'Conv2DTranspose', num_filters=2, filter_size=3)
- ret = conv2DTranspose(fluid.dygraph.base.to_variable(data))
- """
-
- def __init__(self,
- name_scope,
- num_filters,
- output_size=None,
- filter_size=None,
- padding=0,
- stride=1,
- dilation=1,
- groups=None,
- param_attr=None,
- bias_attr=None,
- use_cudnn=True,
- epsilon=1e-30,
- act=None,
- dtype="float32"):
- super(Conv2DTranspose, self).__init__(name_scope, dtype)
- assert (param_attr is not False
- ), "param_attr should not be False in conv2d_transpose."
- self._param_attr = param_attr
- self._bias_attr = bias_attr
- self._groups = groups
- self._num_filters = num_filters
- self._use_cudnn = use_cudnn
- self._padding = padding
- self._stride = stride
- self._dilation = dilation
- self._filter_size = filter_size
- self._output_size = output_size
- self._op_type = "conv2d_transpose"
- self._epsilon = epsilon
-
- def _build_once(self, input):
- input_channel = input.shape[1]
- if (input_channel == self._groups and
- self._num_filters == input_channel and not self._use_cudnn):
- self._op_type = "depthwise_conv2d_transpose"
-
- if not isinstance(input, Variable):
- raise TypeError("Input of conv2d_transpose must be Variable")
-
- self._padding = utils.convert_to_list(self._padding, 2, "padding")
- self._stride = utils.convert_to_list(self._stride, 2, "stride")
- self._dilation = utils.convert_to_list(self._dilation, 2, "dilation")
-
- if not isinstance(self._use_cudnn, bool):
- raise ValueError("use_cudnn should be True or False")
-
- if self._filter_size is None:
- if self._output_size is None:
- raise ValueError(
- "output_size must be set when filter_size is None")
- if isinstance(self._output_size, int):
- self._output_size = [self._output_size, self._output_size]
-
- h_in = input.shape[2]
- w_in = input.shape[3]
-
- filter_size_h = (self._output_size[0] -
- (h_in - 1) * self._stride[0] + 2 *
- self._padding[0] - 1) // self._dilation[0] + 1
- filter_size_w = (self._output_size[1] -
- (w_in - 1) * self._stride[1] + 2 *
- self._padding[1] - 1) // self._dilation[1] + 1
- self._filter_size = [filter_size_h, filter_size_w]
- else:
- self._filter_size = utils.convert_to_list(
- self._filter_size, 2, "conv2d_transpose.filter_size")
-
- if self._output_size is None:
- self._output_size = []
- elif isinstance(self._output_size, list) or isinstance(
- self._output_size, int):
- self._output_size = utils.convert_to_list(self._output_size, 2,
- "output_size")
- else:
- raise ValueError("output_size should be list or int")
- self._padding = utils.convert_to_list(self._padding, 2, "padding")
- self._groups = 1 if self._groups is None else self._groups
- filter_shape = [
- input_channel,
- self._num_filters // self._groups,
- ] + self._filter_size
-
- # img filter v (direction)
- self._img_filter_v = self.create_parameter(
- dtype=input.dtype, shape=filter_shape, attr=self._param_attr)
-
- # img filter g (magnitude)
- img_filter_magnitude = _norm(
- self._img_filter_v.numpy(), dim=0) # CAUTION: hard-code
- self._img_filter_g = self.create_parameter(
- dtype=input.dtype,
- shape=img_filter_magnitude.shape,
- attr=fluid.ParamAttr(
- initializer=NumpyArrayInitializer(img_filter_magnitude)))
-
- self._img_bias = self.create_parameter(
- attr=self._bias_attr,
- shape=[self._num_filters],
- dtype=self._dtype,
- is_bias=True)
-
- def forward(self, input):
- matrix = self._helper.create_variable_for_type_inference(self._dtype)
- tmp = self._helper.create_variable_for_type_inference(self._dtype)
- new_shape = [
- self._img_filter_v.shape[0],
- reduce(lambda x, y: x * y, self._img_filter_v.shape[1:], 1),
- ]
-
- self._helper.append_op(
- type="reshape2",
- inputs={"X": self._img_filter_v},
- attrs={"shape": new_shape},
- outputs={"Out": matrix,
- "XShape": tmp})
-
- m_norm = self._helper.create_variable_for_type_inference(self._dtype)
- m_normalized = self._helper.create_variable_for_type_inference(
- self._dtype)
- self._helper.append_op(
- type="norm",
- inputs={"X": matrix},
- outputs={"Out": m_normalized,
- "Norm": m_norm},
- attrs={"axis": 1,
- "epsilon": self._epsilon})
-
- v_normalized = self._helper.create_variable_for_type_inference(
- self._dtype)
- tmp2 = self._helper.create_variable_for_type_inference(self._dtype)
- self._helper.append_op(
- type="reshape2",
- inputs={"X": m_normalized},
- attrs={"shape": self._img_filter_v.shape},
- outputs={"Out": v_normalized,
- "XShape": tmp2})
-
- img_filter = self._helper.create_variable_for_type_inference(
- self._dtype)
- self._helper.append_op(
- type="elementwise_mul",
- inputs={"X": [v_normalized],
- "Y": [self._img_filter_g]},
- outputs={"Out": [img_filter]},
- attrs={"axis": 0}, # CAUTION: hard-code
- )
-
- pre_bias = self._helper.create_variable_for_type_inference(
- dtype=input.dtype)
- self._helper.append_op(
- type=self._op_type,
- inputs={"Input": [input],
- "Filter": [img_filter]},
- outputs={"Output": pre_bias},
- attrs={
- "output_size": self._output_size,
- "strides": self._stride,
- "paddings": self._padding,
- "dilations": self._dilation,
- "groups": self._groups,
- "use_cudnn": self._use_cudnn,
- })
-
- if self._img_bias is not None:
- pre_act = self._helper.create_variable_for_type_inference(
- dtype=self._dtype)
- self._helper.append_op(
- type="elementwise_add",
- inputs={"X": [pre_bias],
- "Y": [self._img_bias]},
- outputs={"Out": [pre_act]},
- attrs={"axis": 1})
- else:
- pre_act = pre_bias
-
- out = self._helper.append_activation(pre_act)
- return out
diff --git a/parakeet/modules/modules.py b/parakeet/modules/modules.py
index 4fb92ed..7aef463 100644
--- a/parakeet/modules/modules.py
+++ b/parakeet/modules/modules.py
@@ -26,6 +26,7 @@ def FC(name_scope,
in_features,
size,
num_flatten_dims=1,
+ relu=False,
dropout=0.0,
epsilon=1e-30,
act=None,
@@ -39,7 +40,11 @@ def FC(name_scope,
# stds
if isinstance(in_features, int):
in_features = [in_features]
+
stds = [np.sqrt((1 - dropout) / in_feature) for in_feature in in_features]
+ if relu:
+ stds = [std * np.sqrt(2.0) for std in stds]
+
weight_inits = [
fluid.initializer.NormalInitializer(scale=std) for std in stds
]
@@ -456,3 +461,152 @@ class PositionEmbedding(dg.Layer):
return out
else:
raise Exception("Then you can just use position rate at init")
+
+
+class Conv1D_GU(dg.Layer):
+ def __init__(self,
+ name_scope,
+ conditioner_dim,
+ in_channels,
+ num_filters,
+ filter_size,
+ dilation,
+ causal=False,
+ residual=True,
+ dtype="float32"):
+ super(Conv1D_GU, self).__init__(name_scope, dtype=dtype)
+
+ self.conditioner_dim = conditioner_dim
+ self.in_channels = in_channels
+ self.num_filters = num_filters
+ self.filter_size = filter_size
+ self.dilation = dilation
+ self.causal = causal
+ self.residual = residual
+
+ if residual:
+ assert (
+ in_channels == num_filters
+ ), "this block uses residual connection"\
+ "the input_channels should equals num_filters"
+
+ self.conv = Conv1D(
+ self.full_name(),
+ in_channels,
+ 2 * num_filters,
+ filter_size,
+ dilation,
+ causal=causal,
+ dtype=dtype)
+
+ self.fc = Conv1D(
+ self.full_name(),
+ conditioner_dim,
+ 2 * num_filters,
+ filter_size=1,
+ dilation=1,
+ causal=False,
+ dtype=dtype)
+
+ def forward(self, x, skip=None, conditioner=None):
+ """
+ Args:
+ x (Variable): Shape(B, C_in, 1, T), the input of Conv1D_GU
+ layer, where B means batch_size, C_in means the input channels
+ T means input time steps.
+ skip (Variable): Shape(B, C_in, 1, T), skip connection.
+ conditioner (Variable): Shape(B, C_con, 1, T), expanded mel
+ conditioner, where C_con is conditioner hidden dim which
+ equals the num of mel bands. Note that when using residual
+ connection, the Conv1D_GU does not change the number of
+ channels, so out channels equals input channels.
+ Returns:
+ x (Variable): Shape(B, C_out, 1, T), the output of Conv1D_GU, where
+ C_out means the output channels of Conv1D_GU.
+ skip (Variable): Shape(B, C_out, 1, T), skip connection.
+ """
+ residual = x
+ x = self.conv(x)
+
+ if conditioner is not None:
+ cond_bias = self.fc(conditioner)
+ x += cond_bias
+
+ content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
+
+ # Gated Unit.
+ x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
+ fluid.layers.tanh(content))
+
+ if skip is None:
+ skip = x
+ else:
+ skip = fluid.layers.scale(skip + x, np.sqrt(0.5))
+
+ if self.residual:
+ x = fluid.layers.scale(residual + x, np.sqrt(0.5))
+
+ return x, skip
+
+ def add_input(self, x, skip=None, conditioner=None):
+ """
+ Inputs:
+ x: shape(B, num_filters, 1, time_steps)
+ skip: shape(B, num_filters, 1, time_steps), skip connection
+ conditioner: shape(B, conditioner_dim, 1, time_steps)
+ Outputs:
+ x: shape(B, num_filters, 1, time_steps), where time_steps = 1
+ skip: skip connection, same shape as x
+ """
+ residual = x
+
+ # add step input and produce step output
+ x = self.conv.add_input(x)
+
+ if conditioner is not None:
+ cond_bias = self.fc(conditioner)
+ x += cond_bias
+
+ content, gate = fluid.layers.split(x, num_or_sections=2, dim=1)
+
+ # Gated Unit.
+ x = fluid.layers.elementwise_mul(fluid.layers.sigmoid(gate),
+ fluid.layers.tanh(content))
+
+ if skip is None:
+ skip = x
+ else:
+ skip = fluid.layers.scale(skip + x, np.sqrt(0.5))
+
+ if self.residual:
+ x = fluid.layers.scale(residual + x, np.sqrt(0.5))
+
+ return x, skip
+
+
+def Conv2DTranspose(name_scope,
+ num_filters,
+ filter_size,
+ padding=0,
+ stride=1,
+ dilation=1,
+ use_cudnn=True,
+ act=None,
+ dtype="float32"):
+ val = 1.0 / (filter_size[0] * filter_size[1])
+ weight_init = fluid.initializer.ConstantInitializer(val)
+ weight_attr = fluid.ParamAttr(initializer=weight_init)
+
+ layer = weight_norm.Conv2DTranspose(
+ name_scope,
+ num_filters,
+ filter_size=filter_size,
+ padding=padding,
+ stride=stride,
+ dilation=dilation,
+ param_attr=weight_attr,
+ use_cudnn=use_cudnn,
+ act=act,
+ dtype=dtype)
+
+ return layer