diff --git a/examples/clarinet/README.md b/examples/clarinet/README.md
new file mode 100644
index 0000000..6046327
--- /dev/null
+++ b/examples/clarinet/README.md
@@ -0,0 +1,103 @@
+# Clarinet
+
+Paddle implementation of ClariNet (a convolutional-network-based vocoder) in dynamic graph. The implementation is based on the paper [ClariNet: Parallel Wave Generation in End-to-End Text-to-Speech](https://arxiv.org/abs/1807.07281).
+
+
+## Dataset
+
+We experiment with the LJSpeech dataset. Download and unzip [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).
+
+```bash
+wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
+tar xjvf LJSpeech-1.1.tar.bz2
+```
+
+## Project Structure
+
+```text
+├── data.py          data processing
+├── configs/         (example) configuration files
+├── synthesis.py     script to synthesize waveform from mel spectrogram
+├── train.py         script to train a model
+└── utils.py         utility functions
+```
+
+## Train
+
+Train the model using train.py. Follow the usage displayed by `python train.py --help`.
+
+```text
+usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--output OUTPUT]
+                [--data DATA] [--resume RESUME] [--wavenet WAVENET]
+
+train a clarinet model with LJspeech and a trained wavenet model.
+
+optional arguments:
+  -h, --help         show this help message and exit
+  --config CONFIG    path of the config file.
+  --device DEVICE    device to use.
+  --output OUTPUT    path to save student.
+  --data DATA        path of LJspeech dataset.
+  --resume RESUME    checkpoint to load from.
+  --wavenet WAVENET  wavenet checkpoint to use.
+```
+
+1. `--config` is the configuration file to use. The provided configurations can be used directly, or you can change some values in the configuration file and train the model with a different config.
+2. `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
+3. `--resume` is the path of the checkpoint. If it is provided, the model loads the checkpoint before training.
+4. `--output` is the directory to save results; all results are saved in this directory. The structure of the output directory is shown below.
+
+```text
+├── checkpoints      # checkpoints
+├── states           # audio files generated at validation
+└── log              # tensorboard log
+```
+
+5. `--device` is the device (gpu id) to use for training. `-1` means CPU.
+6. `--wavenet` is the path of the wavenet checkpoint to load. If you do not specify `--resume`, then this must be provided.
+
+
+Before you start training a ClariNet model, you should have trained a WaveNet model with a single Gaussian output distribution. Make sure the teacher's config matches that of the trained WaveNet model.
+
+Example script:
+
+```bash
+python train.py --config=./configs/clarinet_ljspeech.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0 --wavenet=wavenet_checkpoint/teacher
+```
+
+You can monitor the training log via tensorboard, using the commands below.
+
+```bash
+cd experiment/log
+tensorboard --logdir=.
+```
+
+## Synthesis
+```text
+usage: synthesis.py [-h] [--config CONFIG] [--device DEVICE] [--data DATA]
+                    checkpoint output
+
+train a clarinet model with LJspeech and a trained wavenet model.
+
+positional arguments:
+  checkpoint       checkpoint to load from.
+  output           path to save student.
+
+optional arguments:
+  -h, --help       show this help message and exit
+  --config CONFIG  path of the config file.
+  --device DEVICE  device to use.
+  --data DATA      path of LJspeech dataset.
+```
+
+1. `--config` is the configuration file to use. You should use the same configuration with which you trained your model.
+2. `--data` is the path of the LJSpeech dataset. A dataset is not strictly needed for synthesis, but since the input is a mel spectrogram, we extract mel spectrograms from the audio files in the dataset.
+3. `checkpoint` is the checkpoint to load.
+4. `output` is the directory to save results. It contains the generated audio files (`*.wav`).
+5. `--device` is the device (gpu id) to use for synthesis. `-1` means CPU.
+
+Example script:
+
+```bash
+python synthesis.py --config=./configs/clarinet_ljspeech.yaml --data=./LJSpeech-1.1/ --device=0 experiment/checkpoints/step_500000 generated
+```
diff --git a/examples/clarinet/configs/clarinet_ljspeech.yaml b/examples/clarinet/configs/clarinet_ljspeech.yaml
new file mode 100644
index 0000000..7ceedcc
--- /dev/null
+++ b/examples/clarinet/configs/clarinet_ljspeech.yaml
@@ -0,0 +1,52 @@
+data:
+    batch_size: 4
+    train_clip_seconds: 0.5
+    sample_rate: 22050
+    hop_length: 256
+    win_length: 1024
+    n_fft: 2048
+
+    n_mels: 80
+    valid_size: 16
+
+
+conditioner:
+    upsampling_factors: [16, 16]
+
+teacher:
+    n_loop: 10
+    n_layer: 3
+    filter_size: 2
+    residual_channels: 128
+    loss_type: "mog"
+    output_dim: 3
+    log_scale_min: -9
+
+student:
+    n_loops: [10, 10, 10, 10, 10, 10]
+    n_layers: [1, 1, 1, 1, 1, 1]
+    filter_size: 3
+    residual_channels: 64
+    log_scale_min: -7
+
+stft:
+    n_fft: 2048
+    win_length: 1024
+    hop_length: 256
+
+loss:
+    lmd: 4
+
+train:
+    learning_rate: 0.0005
+    anneal_rate: 0.5
+    anneal_interval: 200000
+    gradient_max_norm: 100.0
+
+    checkpoint_interval: 1000
+    eval_interval: 1000
+
+    max_iterations: 2000000
+
+
+
diff --git a/examples/wavenet/configs/wavenet_mixture_of_gaussians.yaml b/examples/wavenet/configs/wavenet_mixture_of_gaussians.yaml
index 427c975..a848a52 100644
--- a/examples/wavenet/configs/wavenet_mixture_of_gaussians.yaml
+++ b/examples/wavenet/configs/wavenet_mixture_of_gaussians.yaml
@@ -1,5 +1,4 @@
 data:
-    root: "/workspace/datasets/LJSpeech-1.1/"
     batch_size: 4
     train_clip_seconds: 0.5
     sample_rate: 22050
diff --git a/examples/wavenet/configs/wavenet_single_gaussian.yaml b/examples/wavenet/configs/wavenet_single_gaussian.yaml
index 8dd8d46..8e33349 100644
--- a/examples/wavenet/configs/wavenet_single_gaussian.yaml
+++ b/examples/wavenet/configs/wavenet_single_gaussian.yaml
@@ -1,5 +1,4 @@
 data:
-    root: "/workspace/datasets/LJSpeech-1.1/"
     batch_size: 4
     train_clip_seconds: 0.5
     sample_rate: 22050
diff --git a/examples/wavenet/configs/wavenet_softmax.yaml b/examples/wavenet/configs/wavenet_softmax.yaml
index 57c36cc..98018ee 100644
--- a/examples/wavenet/configs/wavenet_softmax.yaml
+++ b/examples/wavenet/configs/wavenet_softmax.yaml
@@ -1,5 +1,4 @@
 data:
-    root: "/workspace/datasets/LJSpeech-1.1/"
     batch_size: 4
     train_clip_seconds: 0.5
     sample_rate: 22050
diff --git a/examples/wavenet/utils.py b/examples/wavenet/utils.py
index 82ab553..86c8ebf 100644
--- a/examples/wavenet/utils.py
+++ b/examples/wavenet/utils.py
@@ -56,7 +56,7 @@ def eval_model(model, valid_loader, output_dir, sample_rate):
         audio_clips, mel_specs, audio_starts = batch
         wav_var = model.synthesis(mel_specs)
         wav_np = wav_var.numpy()[0]
-        sf.write(wav_np, path, samplerate=sample_rate)
+        sf.write(path, wav_np, samplerate=sample_rate)
         print("generated {}".format(path))
 
 
diff --git a/parakeet/data/dataset.py b/parakeet/data/dataset.py
index d577f9e..16a58bf 100644
--- a/parakeet/data/dataset.py
+++ b/parakeet/data/dataset.py
@@ -134,7 +134,7 @@ class SliceDataset(DatasetMixin):
                 format(len(order), len(dataset)))
         self._order = order
 
-    def len(self):
+    def __len__(self):
         return self._size
 
     def get_example(self, i):
diff --git a/parakeet/models/clarinet/__init__.py b/parakeet/models/clarinet/__init__.py
new file mode 100644
index 0000000..f3148be
--- /dev/null
+++ b/parakeet/models/clarinet/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .net import *
+from .parallel_wavenet import *
\ No newline at end of file
diff --git a/parakeet/models/clarinet/net.py b/parakeet/models/clarinet/net.py
new file mode 100644
index 0000000..35f0f03
--- /dev/null
+++ b/parakeet/models/clarinet/net.py
@@ -0,0 +1,169 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import numpy as np
+from scipy import signal
+from tqdm import trange
+
+import paddle.fluid.layers as F
+import paddle.fluid.dygraph as dg
+import paddle.fluid.initializer as I
+import paddle.fluid.layers.distributions as D
+
+from parakeet.modules.weight_norm import Conv2DTranspose
+from parakeet.models.wavenet import crop, WaveNet, UpsampleNet
+from parakeet.models.clarinet.parallel_wavenet import ParallelWaveNet
+from parakeet.models.clarinet.utils import conv2d
+
+
+# Gaussian IAF model
+class Clarinet(dg.Layer):
+    def __init__(self,
+                 encoder,
+                 teacher,
+                 student,
+                 stft,
+                 min_log_scale=-6.0,
+                 lmd=4.0):
+        super(Clarinet, self).__init__()
+        self.lmd = lmd
+        self.encoder = encoder
+        self.teacher = teacher
+        self.student = student
+
+        self.min_log_scale = min_log_scale
+        self.stft = stft
+
+    def forward(self, audio, mel, audio_start, clip_kl=True):
+        """Compute the distillation loss for the student model.
+
+        Arguments:
+            audio {Variable} -- shape(batch_size, time_steps), target waveform.
+            mel {Variable} -- shape(batch_size, condition_dim, time_steps // hop_length), original mel spectrogram, not upsampled yet.
+            audio_start {Variable} -- shape(batch_size, ), the index of the start sample.
+            clip_kl (bool) -- whether to clip kl divergence if it is greater than 10.0.
+
+        Returns:
+            Dict[str, Variable] -- the total loss ("loss") and its components ("kl_divergence", "regularization", "stft_loss"), each of shape(1,).
+        """
+
+        batch_size, audio_length = audio.shape  # audio clip's length
+
+        z = F.gaussian_random(audio.shape)
+        condition = self.encoder(mel)  # (B, C, T)
+        condition_slice = crop(condition, audio_start, audio_length)
+
+        x, s_means, s_scales = self.student(z, condition_slice)  # all [0: T]
+        s_means = s_means[:, 1:]  # (B, T-1), time steps [1: T]
+        s_scales = s_scales[:, 1:]  # (B, T-1), time steps [1: T]
+        s_clipped_scales = F.clip(s_scales, self.min_log_scale, 100.)
+
+        # teacher outputs a single gaussian
+        y = self.teacher(x[:, :-1], condition_slice[:, :, 1:])
+        _, t_means, t_scales = F.split(y, 3, -1)  # time steps [1: T]
+        t_means = F.squeeze(t_means, [-1])  # (B, T-1), time steps [1: T]
+        t_scales = F.squeeze(t_scales, [-1])  # (B, T-1), time steps [1: T]
+        t_clipped_scales = F.clip(t_scales, self.min_log_scale, 100.)
+
+        s_distribution = D.Normal(s_means, F.exp(s_clipped_scales))
+        t_distribution = D.Normal(t_means, F.exp(t_clipped_scales))
+
+        # kl divergence loss; closed form between two Gaussians, so no MC sampling is needed
+        kl = s_distribution.kl_divergence(t_distribution)
+        if clip_kl:
+            kl = F.clip(kl, -100., 10.)
+        # context size dropped
+        kl = F.reduce_mean(kl[:, self.teacher.context_size:])
+        # major diff here
+        regularization = F.mse_loss(t_scales[:, self.teacher.context_size:],
+                                    s_scales[:, self.teacher.context_size:])
+
+        # introduce information from real target
+        spectrogram_frame_loss = F.mse_loss(
+            self.stft.magnitude(audio), self.stft.magnitude(x))
+        loss = kl + self.lmd * regularization + spectrogram_frame_loss
+        loss_dict = {
+            "loss": loss,
+            "kl_divergence": kl,
+            "regularization": regularization,
+            "stft_loss": spectrogram_frame_loss
+        }
+        return loss_dict
+
+    @dg.no_grad
+    def synthesis(self, mel):
+        """Synthesize waveform conditioned on the mel spectrogram.
+
+        Arguments:
+            mel {Variable} -- shape(batch_size, frequency_bands, frames)
+
+        Returns:
+            Variable -- shape(batch_size, frames * upsample_factor)
+        """
+        condition = self.encoder(mel)
+        samples_shape = (condition.shape[0], condition.shape[-1])
+        z = F.gaussian_random(samples_shape)
+        x, s_means, s_scales = self.student(z, condition)
+        return x
+
+
+class STFT(dg.Layer):
+    def __init__(self, n_fft, hop_length, win_length, window="hanning"):
+        super(STFT, self).__init__()
+        self.hop_length = hop_length
+        self.n_bin = 1 + n_fft // 2
+        self.n_fft = n_fft
+
+        # calculate window
+        window = signal.get_window(window, win_length)
+        if n_fft != win_length:
+            pad = (n_fft - win_length) // 2
+            window = np.pad(window, ((pad, pad), ), 'constant')
+
+        # calculate weights
+        r = np.arange(0, n_fft)
+        M = np.expand_dims(r, -1) * np.expand_dims(r, 0)
+        w_real = np.reshape(window *
+                            np.cos(2 * np.pi * M / n_fft)[:self.n_bin],
+                            (self.n_bin, 1, 1, self.n_fft)).astype("float32")
+        w_imag = np.reshape(window *
+                            np.sin(-2 * np.pi * M / n_fft)[:self.n_bin],
+                            (self.n_bin, 1, 1, self.n_fft)).astype("float32")
+
+        w = np.concatenate([w_real, w_imag], axis=0)
+        self.weight = dg.to_variable(w)
+
+    def forward(self, x):
+        # x(batch_size, time_steps)
+        # pad it first with reflect mode
+        pad_start = F.reverse(x[:, 1:1 + self.n_fft // 2], axis=1)
+        pad_stop = F.reverse(x[:, -(1 + self.n_fft // 2):-1], axis=1)
+        x = F.concat([pad_start, x, pad_stop], axis=-1)
+
+        # to BC1T, C=1
+        x = F.unsqueeze(x, axes=[1, 2])
+        out = conv2d(x, self.weight, stride=(1, self.hop_length))
+        real, imag = F.split(out, 2, dim=1)  # BC1T
+        return real, imag
+
+    def power(self, x):
+        real, imag = self(x)
+        power = real**2 + imag**2
+        return power
+
+    def magnitude(self, x):
+        power = self.power(x)
+        magnitude = F.sqrt(power)
+        return magnitude
diff --git a/parakeet/models/clarinet/parallel_wavenet.py b/parakeet/models/clarinet/parallel_wavenet.py
new file mode 100644
index 0000000..be30b7b
--- /dev/null
+++ b/parakeet/models/clarinet/parallel_wavenet.py
@@ -0,0 +1,69 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import time
+import itertools
+import numpy as np
+
+import paddle.fluid.layers as F
+import paddle.fluid.dygraph as dg
+import paddle.fluid.initializer as I
+import paddle.fluid.layers.distributions as D
+
+from parakeet.modules.weight_norm import Linear, Conv1D, Conv1DCell, Conv2DTranspose
+from parakeet.models.wavenet import WaveNet
+
+
+class ParallelWaveNet(dg.Layer):
+    def __init__(self, n_loops, n_layers, residual_channels, condition_dim,
+                 filter_size):
+        super(ParallelWaveNet, self).__init__()
+        self.flows = dg.LayerList()
+        for n_loop, n_layer in zip(n_loops, n_layers):
+            # the teacher's log_scale_min does not matter here, -100 is a dummy value
+            self.flows.append(
+                WaveNet(n_loop, n_layer, residual_channels, 3, condition_dim,
+                        filter_size, "mog", -100.0))
+
+    def forward(self, z, condition=None):
+        """Inverse Autoregressive Flow: a stack of several WaveNets.
+
+        Arguments:
+            z {Variable} -- shape(batch_size, time_steps), hidden variable, sampled from a standard normal distribution.
+
+        Keyword Arguments:
+            condition {Variable} -- shape(batch_size, condition_dim, time_steps), condition, basically upsampled mel spectrogram. (default: {None})
+
+        Returns:
+            Variable -- shape(batch_size, time_steps), transformed z.
+            Variable -- shape(batch_size, time_steps), output distribution's mu.
+            Variable -- shape(batch_size, time_steps), output distribution's log_std.
+        """
+
+        for i, flow in enumerate(self.flows):
+            theta = flow(z, condition)  # w, mu, log_std [0: T]
+            w, mu, log_std = F.split(theta, 3, dim=-1)  # (B, T, 1) for each
+            mu = F.squeeze(mu, [-1])  # [0: T]
+            log_std = F.squeeze(log_std, [-1])  # [0: T]
+            z = z * F.exp(log_std) + mu  # [0: T]
+
+            if i == 0:
+                out_mu = mu
+                out_log_std = log_std
+            else:
+                out_mu = out_mu * F.exp(log_std) + mu
+                out_log_std += log_std
+
+        return z, out_mu, out_log_std
diff --git a/parakeet/models/clarinet/utils.py b/parakeet/models/clarinet/utils.py
new file mode 100644
index 0000000..c2d3252
--- /dev/null
+++ b/parakeet/models/clarinet/utils.py
@@ -0,0 +1,48 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle import fluid
+from paddle.fluid.core import ops
+
+
+@fluid.framework.dygraph_only
+def conv2d(input,
+           weight,
+           stride=(1, 1),
+           padding=((0, 0), (0, 0)),
+           dilation=(1, 1),
+           groups=1,
+           use_cudnn=True,
+           data_format="NCHW"):
+    padding = tuple(pad for pad_dim in padding for pad in pad_dim)
+
+    inputs = {
+        'Input': [input],
+        'Filter': [weight],
+    }
+    attrs = {
+        'strides': stride,
+        'paddings': padding,
+        'dilations': dilation,
+        'groups': groups,
+        'use_cudnn': use_cudnn,
+        'use_mkldnn': False,
+        'fuse_relu_before_depthwise_conv': False,
+        "padding_algorithm": "EXPLICIT",
+        "data_format": data_format,
+    }
+
+    outputs = ops.conv2d(inputs, attrs)
+    out = outputs["Output"][0]
+    return out
\ No newline at end of file
diff --git a/parakeet/models/wavenet/net.py b/parakeet/models/wavenet/net.py
index 7bbc67a..72b9ad5 100644
--- a/parakeet/models/wavenet/net.py
+++ b/parakeet/models/wavenet/net.py
@@ -57,7 +57,7 @@ class UpsampleNet(dg.Layer):
     """
 
     def __init__(self, upscale_factors=[16, 16]):
-        super().__init__()
+        super(UpsampleNet, self).__init__()
         self.upscale_factors = list(upscale_factors)
         self.upsample_convs = dg.LayerList()
         for i, factor in enumerate(upscale_factors):
@@ -92,7 +92,7 @@ class UpsampleNet(dg.Layer):
 # AutoRegressive Model
 class ConditionalWavenet(dg.Layer):
     def __init__(self, encoder: UpsampleNet, decoder: WaveNet):
-        super().__init__()
+        super(ConditionalWavenet, self).__init__()
         self.encoder = encoder
         self.decoder = decoder
 
diff --git a/parakeet/models/wavenet/wavenet.py b/parakeet/models/wavenet/wavenet.py
index 289efe7..4c355f4 100644
--- a/parakeet/models/wavenet/wavenet.py
+++ b/parakeet/models/wavenet/wavenet.py
@@ -39,7 +39,7 @@ def dequantize(quantized, n_bands):
 class ResidualBlock(dg.Layer):
     def __init__(self, residual_channels, condition_dim, filter_size,
                  dilation):
-        super().__init__()
+        super(ResidualBlock, self).__init__()
         dilated_channels = 2 * residual_channels
         # following clarinet's implementation, we do not have parametric residual
         # & skip connection.
@@ -135,7 +135,7 @@ class ResidualBlock(dg.Layer):
 class ResidualNet(dg.Layer):
     def __init__(self, n_loop, n_layer, residual_channels, condition_dim,
                  filter_size):
-        super().__init__()
+        super(ResidualNet, self).__init__()
         # double the dilation at each layer in a loop(n_loop layers)
         dilations = [2**i for i in range(n_loop)] * n_layer
         self.context_size = 1 + sum(dilations)
@@ -198,7 +198,7 @@ class ResidualNet(dg.Layer):
 class WaveNet(dg.Layer):
     def __init__(self, n_loop, n_layer, residual_channels, output_dim,
                  condition_dim, filter_size, loss_type, log_scale_min):
-        super().__init__()
+        super(WaveNet, self).__init__()
         if loss_type not in ["softmax", "mog"]:
             raise ValueError("loss_type {} is not supported".format(loss_type))
         if loss_type == "softmax":
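The running `out_mu` / `out_log_std` accumulation in `ParallelWaveNet.forward` above is exactly the composition of the per-flow affine maps `z <- z * exp(log_std) + mu`: after all flows, `z_out = z_in * exp(out_log_std) + out_mu`. Below is a minimal numpy sketch (illustrative only, not part of the patch; the random `mus` and `log_stds` are stand-ins for the per-flow WaveNet outputs) that checks this identity.

```python
import numpy as np

rng = np.random.default_rng(0)
z0 = rng.normal(size=(2, 16))  # (batch_size, time_steps)
mus = [rng.normal(size=z0.shape) for _ in range(3)]
log_stds = [0.1 * rng.normal(size=z0.shape) for _ in range(3)]

z = z0.copy()
out_mu, out_log_std = None, None
for i, (mu, log_std) in enumerate(zip(mus, log_stds)):
    z = z * np.exp(log_std) + mu  # one flow: affine transform of z
    if i == 0:
        out_mu, out_log_std = mu, log_std
    else:
        # same accumulation rule as in ParallelWaveNet.forward
        out_mu = out_mu * np.exp(log_std) + mu
        out_log_std = out_log_std + log_std

# the accumulated parameters describe the composed transform
assert np.allclose(z, z0 * np.exp(out_log_std) + out_mu)
```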