refactor wavenet

This commit is contained in:
chenfeiyu 2020-02-26 15:06:48 +00:00
parent faa725bad9
commit a012825423
21 changed files with 1316 additions and 1514 deletions

View File

@ -0,0 +1,97 @@
# Wavenet
Paddle implementation of wavenet in dynamic graph, a convolutional network based vocoder. Wavenet is proposed in [WaveNet: A Generative Model for Raw Audio](https://arxiv.org/abs/1609.03499), but in thie experiment, the implementation follows the teacher model in [ClariNet: Parallel Wave Generation in End-to-End Text-to-Speech](arxiv.org/abs/1807.07281).
## Dataset
We experiment with the LJSpeech dataset. Download and unzip [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).
```bash
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
tar xjvf LJSpeech-1.1.tar.bz2
```
## Project Structure
```text
├── data.py data_processing
├── configs/ (example) configuration file
├── synthesis.py script to synthesize waveform from mel_spectrogram
├── train.py script to train a model
└── utils.py utility functions
```
## Train
Train the model using train.py, follow the usage displayed by `python train.py --help`.
```text
usage: train.py [-h] [--data DATA] [--config CONFIG] [--output OUTPUT]
[--device DEVICE] [--resume RESUME]
Train a wavenet model with LJSpeech.
optional arguments:
-h, --help show this help message and exit
--data DATA path of the LJspeech dataset.
--config CONFIG path of the config file.
--output OUTPUT path to save results.
--device DEVICE device to use.
--resume RESUME checkpoint to resume from.
```
1. `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config.
2. `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt).
3. `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig.
4. `--output` is the directory to save results, all result are saved in this directory. The structure of the output directory is shown below.
```text
├── checkpoints # checkpoint
└── log # tensorboard log
```
5. `--device` is the device (gpu id) to use for training. `-1` means CPU.
example script:
```bash
python train.py --config=./configs/wavenet_single_gaussian.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0
```
You can monitor training log via tensorboard, using the script below.
```bash
cd experiment/log
tensorboard --logdir=.
```
## Synthesis
```text
usage: synthesis.py [-h] [--data DATA] [--config CONFIG] [--device DEVICE]
checkpoint output
Synthesize valid data from LJspeech with a wavenet model.
positional arguments:
checkpoint checkpoint to load.
output path to save results.
optional arguments:
-h, --help show this help message and exit
--data DATA path of the LJspeech dataset.
--config CONFIG path of the config file.
--device DEVICE device to use.
```
1. `--config` is the configuration file to use. You should use the same configuration with which you train you model.
2. `--data` is the path of the LJspeech dataset. A dataset is not needed for synthesis, but since the input is mel spectrogram, we need to get mel spectrogram from audio files.
3. `checkpoint` is the checkpoint to load.
4. `output_path` is the directory to save results. The output path contains the generated audio files (`*.wav`).
5. `--device` is the device (gpu id) to use for training. `-1` means CPU.
example script:
```bash
python synthesis.py --config=./configs/wavenet_single_gaussian.yaml --data=./LJSpeech-1.1/ --device=0 experiment/checkpoints/step_500000 generated
```

View File

@ -0,0 +1,37 @@
data:
root: "/workspace/datasets/LJSpeech-1.1/"
batch_size: 4
train_clip_seconds: 0.5
sample_rate: 22050
hop_length: 256
win_length: 1024
n_fft: 2048
n_mels: 80
valid_size: 16
model:
upsampling_factors: [16, 16]
n_loop: 10
n_layer: 3
filter_size: 2
residual_channels: 128
loss_type: "mog"
output_dim: 30
log_scale_min: -9
train:
learning_rate: 0.001
anneal_rate: 0.5
anneal_interval: 200000
gradient_max_norm: 100.0
checkpoint_interval: 10000
snap_interval: 10000
eval_interval: 10000
max_iterations: 200000

View File

@ -0,0 +1,37 @@
data:
root: "/workspace/datasets/LJSpeech-1.1/"
batch_size: 4
train_clip_seconds: 0.5
sample_rate: 22050
hop_length: 256
win_length: 1024
n_fft: 2048
n_mels: 80
valid_size: 16
model:
upsampling_factors: [16, 16]
n_loop: 10
n_layer: 3
filter_size: 2
residual_channels: 128
loss_type: "mog"
output_dim: 3
log_scale_min: -9
train:
learning_rate: 0.001
anneal_rate: 0.5
anneal_interval: 200000
gradient_max_norm: 100.0
checkpoint_interval: 10000
snap_interval: 10000
eval_interval: 10000
max_iterations: 200000

View File

@ -0,0 +1,37 @@
data:
root: "/workspace/datasets/LJSpeech-1.1/"
batch_size: 4
train_clip_seconds: 0.5
sample_rate: 22050
hop_length: 256
win_length: 1024
n_fft: 2048
n_mels: 80
valid_size: 16
model:
upsampling_factors: [16, 16]
n_loop: 10
n_layer: 3
filter_size: 2
residual_channels: 128
loss_type: "softmax"
output_dim: 2048
log_scale_min: -9
train:
learning_rate: 0.001
anneal_rate: 0.5
anneal_interval: 200000
gradient_max_norm: 100.0
checkpoint_interval: 10000
snap_interval: 10000
eval_interval: 10000
max_iterations: 200000

163
examples/wavenet/data.py Normal file
View File

@ -0,0 +1,163 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import numpy as np
import librosa
from pathlib import Path
import pandas as pd
from parakeet.data import batch_spec, batch_wav
from parakeet.data import DatasetMixin
class LJSpeechMetaData(DatasetMixin):
def __init__(self, root):
self.root = Path(root)
self._wav_dir = self.root.joinpath("wavs")
csv_path = self.root.joinpath("metadata.csv")
self._table = pd.read_csv(
csv_path,
sep="|",
header=None,
quoting=csv.QUOTE_NONE,
names=["fname", "raw_text", "normalized_text"])
def get_example(self, i):
fname, raw_text, normalized_text = self._table.iloc[i]
fname = str(self._wav_dir.joinpath(fname + ".wav"))
return fname, raw_text, normalized_text
def __len__(self):
return len(self._table)
class Transform(object):
def __init__(self, sample_rate, n_fft, win_length, hop_length, n_mels):
self.sample_rate = sample_rate
self.n_fft = n_fft
self.win_length = win_length
self.hop_length = hop_length
self.n_mels = n_mels
def __call__(self, example):
wav_path, _, _ = example
sr = self.sample_rate
n_fft = self.n_fft
win_length = self.win_length
hop_length = self.hop_length
n_mels = self.n_mels
wav, loaded_sr = librosa.load(wav_path, sr=None)
assert loaded_sr == sr, "sample rate does not match, resampling applied"
# Pad audio to the right size.
frames = int(np.ceil(float(wav.size) / hop_length))
fft_padding = (n_fft - hop_length) // 2 # sound
desired_length = frames * hop_length + fft_padding * 2
pad_amount = (desired_length - wav.size) // 2
if wav.size % 2 == 0:
wav = np.pad(wav, (pad_amount, pad_amount), mode='reflect')
else:
wav = np.pad(wav, (pad_amount, pad_amount + 1), mode='reflect')
# Normalize audio.
wav = wav / np.abs(wav).max() * 0.999
# Compute mel-spectrogram.
# Turn center to False to prevent internal padding.
spectrogram = librosa.core.stft(
wav,
hop_length=hop_length,
win_length=win_length,
n_fft=n_fft,
center=False)
spectrogram_magnitude = np.abs(spectrogram)
# Compute mel-spectrograms.
mel_filter_bank = librosa.filters.mel(sr=sr,
n_fft=n_fft,
n_mels=n_mels)
mel_spectrogram = np.dot(mel_filter_bank, spectrogram_magnitude)
mel_spectrogram = mel_spectrogram
# Rescale mel_spectrogram.
min_level, ref_level = 1e-5, 20 # hard code it
mel_spectrogram = 20 * np.log10(np.maximum(min_level, mel_spectrogram))
mel_spectrogram = mel_spectrogram - ref_level
mel_spectrogram = np.clip((mel_spectrogram + 100) / 100, 0, 1)
# Extract the center of audio that corresponds to mel spectrograms.
audio = wav[fft_padding:-fft_padding]
assert mel_spectrogram.shape[1] * hop_length == audio.size
# there is no clipping here
return audio, mel_spectrogram
class DataCollector(object):
def __init__(self,
context_size,
sample_rate,
hop_length,
train_clip_seconds,
valid=False):
frames_per_second = sample_rate // hop_length
train_clip_frames = int(
np.ceil(train_clip_seconds * frames_per_second))
context_frames = context_size // hop_length
self.num_frames = train_clip_frames + context_frames
self.sample_rate = sample_rate
self.hop_length = hop_length
self.valid = valid
def random_crop(self, sample):
audio, mel_spectrogram = sample
audio_frames = int(audio.size) // self.hop_length
max_start_frame = audio_frames - self.num_frames
assert max_start_frame >= 0, "audio is too short to be cropped"
frame_start = np.random.randint(0, max_start_frame)
# frame_start = 0 # norandom
frame_end = frame_start + self.num_frames
audio_start = frame_start * self.hop_length
audio_end = frame_end * self.hop_length
audio = audio[audio_start:audio_end]
return audio, mel_spectrogram, audio_start
def __call__(self, samples):
# transform them first
if self.valid:
samples = [(audio, mel_spectrogram, 0)
for audio, mel_spectrogram in samples]
else:
samples = [self.random_crop(sample) for sample in samples]
# batch them
audios = [sample[0] for sample in samples]
audio_starts = [sample[2] for sample in samples]
mels = [sample[1] for sample in samples]
mels = batch_spec(mels)
if self.valid:
audios = batch_wav(audios, dtype=np.float32)
else:
audios = np.array(audios, dtype=np.float32)
audio_starts = np.array(audio_starts, dtype=np.int64)
return audios, mels, audio_starts

View File

@ -0,0 +1,124 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import ruamel.yaml
import argparse
from tqdm import tqdm
from tensorboardX import SummaryWriter
from paddle import fluid
import paddle.fluid.dygraph as dg
from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
from parakeet.utils.layer_tools import summary
from data import LJSpeechMetaData, Transform, DataCollector
from utils import make_output_tree, valid_model, eval_model, save_checkpoint
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Synthesize valid data from LJspeech with a wavenet model.")
parser.add_argument(
"--data", type=str, help="path of the LJspeech dataset.")
parser.add_argument("--config", type=str, help="path of the config file.")
parser.add_argument(
"--device", type=int, default=-1, help="device to use.")
parser.add_argument("checkpoint", type=str, help="checkpoint to load.")
parser.add_argument(
"output", type=str, default="experiment", help="path to save results.")
args = parser.parse_args()
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
ljspeech_meta = LJSpeechMetaData(args.data)
data_config = config["data"]
sample_rate = data_config["sample_rate"]
n_fft = data_config["n_fft"]
win_length = data_config["win_length"]
hop_length = data_config["hop_length"]
n_mels = data_config["n_mels"]
train_clip_seconds = data_config["train_clip_seconds"]
transform = Transform(sample_rate, n_fft, win_length, hop_length, n_mels)
ljspeech = TransformDataset(ljspeech_meta, transform)
valid_size = data_config["valid_size"]
ljspeech_valid = SliceDataset(ljspeech, 0, valid_size)
ljspeech_train = SliceDataset(ljspeech, valid_size, len(ljspeech))
model_config = config["model"]
n_loop = model_config["n_loop"]
n_layer = model_config["n_layer"]
filter_size = model_config["filter_size"]
context_size = 1 + n_layer * sum([filter_size**i for i in range(n_loop)])
print("context size is {} samples".format(context_size))
train_batch_fn = DataCollector(context_size, sample_rate, hop_length,
train_clip_seconds)
valid_batch_fn = DataCollector(
context_size, sample_rate, hop_length, train_clip_seconds, valid=True)
batch_size = data_config["batch_size"]
train_cargo = DataCargo(
ljspeech_train,
train_batch_fn,
batch_size,
sampler=RandomSampler(ljspeech_train))
# only batch=1 for validation is enabled
valid_cargo = DataCargo(
ljspeech_valid,
valid_batch_fn,
batch_size=1,
sampler=SequentialSampler(ljspeech_valid))
make_output_tree(args.output)
if args.device == -1:
place = fluid.CPUPlace()
else:
place = fluid.CUDAPlace(args.device)
with dg.guard(place):
model_config = config["model"]
upsampling_factors = model_config["upsampling_factors"]
encoder = UpsampleNet(upsampling_factors)
n_loop = model_config["n_loop"]
n_layer = model_config["n_layer"]
residual_channels = model_config["residual_channels"]
output_dim = model_config["output_dim"]
loss_type = model_config["loss_type"]
log_scale_min = model_config["log_scale_min"]
decoder = WaveNet(n_loop, n_layer, residual_channels, output_dim,
n_mels, filter_size, loss_type, log_scale_min)
model = ConditionalWavenet(encoder, decoder)
summary(model)
model_dict, _ = dg.load_dygraph(args.checkpoint)
print("Loading from {}.pdparams".format(args.checkpoint))
model.set_dict(model_dict)
train_loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True)
train_loader.set_batch_generator(train_cargo, place)
valid_loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True)
valid_loader.set_batch_generator(valid_cargo, place)
eval_model(model, valid_loader, args.output, sample_rate)

181
examples/wavenet/train.py Normal file
View File

@ -0,0 +1,181 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import ruamel.yaml
import argparse
from tqdm import tqdm
from tensorboardX import SummaryWriter
from paddle import fluid
import paddle.fluid.dygraph as dg
from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler
from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet
from parakeet.utils.layer_tools import summary
from data import LJSpeechMetaData, Transform, DataCollector
from utils import make_output_tree, valid_model, save_checkpoint
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Train a wavenet model with LJSpeech.")
parser.add_argument(
"--data", type=str, help="path of the LJspeech dataset.")
parser.add_argument("--config", type=str, help="path of the config file.")
parser.add_argument(
"--output",
type=str,
default="experiment",
help="path to save results.")
parser.add_argument(
"--device", type=int, default=-1, help="device to use.")
parser.add_argument(
"--resume", type=str, help="checkpoint to resume from.")
args = parser.parse_args()
with open(args.config, 'rt') as f:
config = ruamel.yaml.safe_load(f)
ljspeech_meta = LJSpeechMetaData(args.data)
data_config = config["data"]
sample_rate = data_config["sample_rate"]
n_fft = data_config["n_fft"]
win_length = data_config["win_length"]
hop_length = data_config["hop_length"]
n_mels = data_config["n_mels"]
train_clip_seconds = data_config["train_clip_seconds"]
transform = Transform(sample_rate, n_fft, win_length, hop_length, n_mels)
ljspeech = TransformDataset(ljspeech_meta, transform)
valid_size = data_config["valid_size"]
ljspeech_valid = SliceDataset(ljspeech, 0, valid_size)
ljspeech_train = SliceDataset(ljspeech, valid_size, len(ljspeech))
model_config = config["model"]
n_loop = model_config["n_loop"]
n_layer = model_config["n_layer"]
filter_size = model_config["filter_size"]
context_size = 1 + n_layer * sum([filter_size**i for i in range(n_loop)])
print("context size is {} samples".format(context_size))
train_batch_fn = DataCollector(context_size, sample_rate, hop_length,
train_clip_seconds)
valid_batch_fn = DataCollector(
context_size, sample_rate, hop_length, train_clip_seconds, valid=True)
batch_size = data_config["batch_size"]
train_cargo = DataCargo(
ljspeech_train,
train_batch_fn,
batch_size,
sampler=RandomSampler(ljspeech_train))
# only batch=1 for validation is enabled
valid_cargo = DataCargo(
ljspeech_valid,
valid_batch_fn,
batch_size=1,
sampler=SequentialSampler(ljspeech_valid))
make_output_tree(args.output)
if args.device == -1:
place = fluid.CPUPlace()
else:
place = fluid.CUDAPlace(args.device)
with dg.guard(place):
model_config = config["model"]
upsampling_factors = model_config["upsampling_factors"]
encoder = UpsampleNet(upsampling_factors)
n_loop = model_config["n_loop"]
n_layer = model_config["n_layer"]
residual_channels = model_config["residual_channels"]
output_dim = model_config["output_dim"]
loss_type = model_config["loss_type"]
log_scale_min = model_config["log_scale_min"]
decoder = WaveNet(n_loop, n_layer, residual_channels, output_dim,
n_mels, filter_size, loss_type, log_scale_min)
model = ConditionalWavenet(encoder, decoder)
summary(model)
train_config = config["train"]
learning_rate = train_config["learning_rate"]
anneal_rate = train_config["anneal_rate"]
anneal_interval = train_config["anneal_interval"]
lr_scheduler = dg.ExponentialDecay(
learning_rate, anneal_interval, anneal_rate, staircase=True)
optim = fluid.optimizer.Adam(
lr_scheduler, parameter_list=model.parameters())
gradiant_max_norm = train_config["gradient_max_norm"]
clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
gradiant_max_norm)
if args.resume:
model_dict, optim_dict = dg.load_dygraph(args.resume)
print("Loading from {}.pdparams".format(args.resume))
model.set_dict(model_dict)
if optim_dict:
optim.set_dict(optim_dict)
print("Loading from {}.pdopt".format(args.resume))
train_loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True)
train_loader.set_batch_generator(train_cargo, place)
valid_loader = fluid.io.DataLoader.from_generator(
capacity=10, return_list=True)
valid_loader.set_batch_generator(valid_cargo, place)
max_iterations = train_config["max_iterations"]
checkpoint_interval = train_config["checkpoint_interval"]
snap_interval = train_config["snap_interval"]
eval_interval = train_config["eval_interval"]
checkpoint_dir = os.path.join(args.output, "checkpoints")
log_dir = os.path.join(args.output, "log")
writer = SummaryWriter(log_dir)
global_step = 1
while global_step <= max_iterations:
epoch_loss = 0.
for i, batch in tqdm(enumerate(train_loader)):
audio_clips, mel_specs, audio_starts = batch
model.train()
y_var = model(audio_clips, mel_specs, audio_starts)
loss_var = model.loss(y_var, audio_clips)
loss_var.backward()
loss_np = loss_var.numpy()
epoch_loss += loss_np[0]
writer.add_scalar("loss", loss_np[0], global_step)
writer.add_scalar("learning_rate",
optim._learning_rate.step().numpy()[0],
global_step)
optim.minimize(loss_var, grad_clip=clipper)
optim.clear_gradients()
print("loss: {:<8.6f}".format(loss_np[0]))
if global_step % snap_interval == 0:
valid_model(model, valid_loader, writer, global_step,
sample_rate)
if global_step % checkpoint_interval == 0:
save_checkpoint(model, optim, checkpoint_dir, global_step)
global_step += 1

67
examples/wavenet/utils.py Normal file
View File

@ -0,0 +1,67 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import soundfile as sf
import paddle.fluid.dygraph as dg
def make_output_tree(output_dir):
checkpoint_dir = os.path.join(output_dir, "checkpoints")
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
state_dir = os.path.join(output_dir, "states")
if not os.path.exists(state_dir):
os.makedirs(state_dir)
def valid_model(model, valid_loader, writer, global_step, sample_rate):
loss = []
wavs = []
model.eval()
for i, batch in enumerate(valid_loader):
# print("sentence {}".format(i))
audio_clips, mel_specs, audio_starts = batch
y_var = model(audio_clips, mel_specs, audio_starts)
wav_var = model.sample(y_var)
loss_var = model.loss(y_var, audio_clips)
loss.append(loss_var.numpy()[0])
wavs.append(wav_var.numpy()[0])
average_loss = np.mean(loss)
writer.add_scalar("valid_loss", average_loss, global_step)
for i, wav in enumerate(wavs):
writer.add_audio("valid/sample_{}".format(i), wav, global_step,
sample_rate)
def eval_model(model, valid_loader, output_dir, sample_rate):
model.eval()
for i, batch in enumerate(valid_loader):
# print("sentence {}".format(i))
path = os.path.join(output_dir, "sentence_{}.wav".format(i))
audio_clips, mel_specs, audio_starts = batch
wav_var = model.synthesis(mel_specs)
wav_np = wav_var.numpy()[0]
sf.write(wav_np, path, samplerate=sample_rate)
print("generated {}".format(path))
def save_checkpoint(model, optim, checkpoint_dir, global_step):
checkpoint_path = os.path.join(checkpoint_dir,
"step_{:09d}".format(global_step))
dg.save_dygraph(model.state_dict(), checkpoint_path)
dg.save_dygraph(optim.state_dict(), checkpoint_path)

View File

@ -1,97 +0,0 @@
# WaveNet with Paddle Fluid
Paddle fluid implementation of WaveNet, a deep generative model of raw audio waveforms.
WaveNet model is originally proposed in [WaveNet: A Generative Model for Raw Audio](https://arxiv.org/abs/1609.03499).
Our implementation is based on the WaveNet architecture described in [ClariNet: Parallel Wave Generation in End-to-End Text-to-Speech](https://arxiv.org/abs/1807.07281) and can provide various output distributions, including single Gaussian, mixture of Gaussian, and softmax with linearly quantized channels.
We implement WaveNet model in paddle fluid with dynamic graph, which is convenient for flexible network architectures.
## Project Structure
```text
├── configs # yaml configuration files of preset model hyperparameters
├── data.py # dataset and dataloader settings for LJSpeech
├── slurm.py # optional slurm helper functions if you use slurm to train model
├── synthesis.py # script for speech synthesis
├── train.py # script for model training
├── utils.py # helper functions for e.g., model checkpointing
├── wavenet.py # WaveNet model high level APIs
└── wavenet_modules.py # WaveNet model implementation
```
## Usage
There are many hyperparameters to be tuned depending on the specification of model and dataset you are working on. Hyperparameters that are known to work good for the LJSpeech dataset are provided as yaml files in `./configs/` folder. Specifically, we provide `wavenet_ljspeech_single_gaussian.yaml`, `wavenet_ljspeech_mix_gaussian.yaml`, and `wavenet_ljspeech_softmax.yaml` config files for WaveNet with single Gaussian, 10-component mixture of Gaussians, and softmax (with 2048 linearly quantized channels) output distributions, respectively.
Note that `train.py` and `synthesis.py` all accept a `--config` parameter. To ensure consistency, you should use the same config yaml file for both training and synthesizing. You can also overwrite these preset hyperparameters with command line by updating parameters after `--config`. For example `--config=${yaml} --batch_size=8 --layers=20` can overwrite the corresponding hyperparameters in the `${yaml}` config file. For more details about these hyperparameters, check `utils.add_config_options_to_parser`.
Note that you also need to specify some additional parameters for `train.py` and `synthesis.py`, and the details can be found in `train.add_options_to_parser` and `synthesis.add_options_to_parser`, respectively.
### Dataset
Download and unzip [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).
```bash
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
tar xjvf LJSpeech-1.1.tar.bz2
```
In this example, assume that the path of unzipped LJSpeech dataset is `./data/LJSpeech-1.1`.
### Train on single GPU
```bash
export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
export CUDA_VISIBLE_DEVICES=0
python -u train.py --config=${yaml} \
--root=./data/LJSpeech-1.1 \
--name=${ModelName} --batch_size=4 \
--parallel=false --use_gpu=true
```
#### Save and Load checkpoints
Our model will save model parameters as checkpoints in `./runs/wavenet/${ModelName}/checkpoint/` every 10000 iterations by default.
The saved checkpoint will have the format of `step-${iteration_number}.pdparams` for model parameters and `step-${iteration_number}.pdopt` for optimizer parameters.
There are three ways to load a checkpoint and resume training (take an example that you want to load a 500000-iteration checkpoint):
1. Use `--checkpoint=./runs/wavenet/${ModelName}/checkpoint/step-500000` to provide a specific path to load. Note that you only need to provide the base name of the parameter file, which is `step-500000`, no extension name `.pdparams` or `.pdopt` is needed.
2. Use `--iteration=500000`.
3. If you don't specify either `--checkpoint` or `--iteration`, the model will automatically load the latest checkpoint in `./runs/wavenet/${ModelName}/checkpoint`.
### Train on multiple GPUs
```bash
export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -u -m paddle.distributed.launch train.py \
--config=${yaml} \
--root=./data/LJSpeech-1.1 \
--name=${ModelName} --parallel=true --use_gpu=true
```
Use `export CUDA_VISIBLE_DEVICES=0,1,2,3` to set the GPUs that you want to use to be visible. Then the `paddle.distributed.launch` module will use these visible GPUs to do data parallel training in multiprocessing mode.
### Monitor with Tensorboard
By default, the logs are saved in `./runs/wavenet/${ModelName}/logs/`. You can monitor logs by tensorboard.
```bash
tensorboard --logdir=${log_dir} --port=8888
```
### Synthesize from a checkpoint
Check the [Save and load checkpoint](#save-and-load-checkpoints) section on how to load a specific checkpoint.
The following example will automatically load the latest checkpoint:
```bash
export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
export CUDA_VISIBLE_DEVICES=0
python -u synthesis.py --config=${yaml} \
--root=./data/LJSpeech-1.1 \
--name=${ModelName} --use_gpu=true \
--output=./syn_audios \
--sample=${SAMPLE}
```
In this example, `--output` specifies where to save the synthesized audios and `--sample` specifies which sample in the valid dataset (a split from the whole LJSpeech dataset, by default contains the first 16 audio samples) to synthesize based on the mel-spectrograms computed from the ground truth sample audio, e.g., `--sample=0` means to synthesize the first audio in the valid dataset.

View File

@ -0,0 +1,16 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .net import *
from .wavenet import *

View File

@ -1,32 +0,0 @@
valid_size: 16
train_clip_second: 0.5
sample_rate: 22050
fft_window_shift: 256
fft_window_size: 1024
fft_size: 2048
mel_bands: 80
seed: 1
batch_size: 8
test_every: 2000
save_every: 10000
max_iterations: 2000000
layers: 30
kernel_width: 2
dilation_block: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
residual_channels: 128
skip_channels: 128
loss_type: mix-gaussian-pdf
num_mixtures: 10
log_scale_min: -9.0
conditioner:
filter_sizes: [[32, 3], [32, 3]]
upsample_factors: [16, 16]
learning_rate: 0.001
gradient_max_norm: 100.0
anneal:
every: 200000
rate: 0.5

View File

@ -1,32 +0,0 @@
valid_size: 16
train_clip_second: 0.5
sample_rate: 22050
fft_window_shift: 256
fft_window_size: 1024
fft_size: 2048
mel_bands: 80
seed: 1
batch_size: 8
test_every: 2000
save_every: 10000
max_iterations: 2000000
layers: 30
kernel_width: 2
dilation_block: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
residual_channels: 128
skip_channels: 128
loss_type: mix-gaussian-pdf
num_mixtures: 1
log_scale_min: -9.0
conditioner:
filter_sizes: [[32, 3], [32, 3]]
upsample_factors: [16, 16]
learning_rate: 0.001
gradient_max_norm: 100.0
anneal:
every: 200000
rate: 0.5

View File

@ -1,31 +0,0 @@
valid_size: 16
train_clip_second: 0.5
sample_rate: 22050
fft_window_shift: 256
fft_window_size: 1024
fft_size: 2048
mel_bands: 80
seed: 1
batch_size: 8
test_every: 2000
save_every: 10000
max_iterations: 2000000
layers: 30
kernel_width: 2
dilation_block: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
residual_channels: 128
skip_channels: 128
loss_type: softmax
num_channels: 2048
conditioner:
filter_sizes: [[32, 3], [32, 3]]
upsample_factors: [16, 16]
learning_rate: 0.001
gradient_max_norm: 100.0
anneal:
every: 200000
rate: 0.5

View File

@ -1,178 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import librosa
import numpy as np
from paddle import fluid
import utils
from parakeet.datasets import ljspeech
from parakeet.data import dataset
from parakeet.data.sampler import DistributedSampler, BatchSampler
from parakeet.data.datacargo import DataCargo
class Dataset(ljspeech.LJSpeech):
def __init__(self, config):
super(Dataset, self).__init__(config.root)
self.config = config
self.fft_window_shift = config.fft_window_shift
# Calculate context frames.
frames_per_second = config.sample_rate // self.fft_window_shift
train_clip_frames = int(
np.ceil(config.train_clip_second * frames_per_second))
context_frames = config.context_size // self.fft_window_shift
self.num_frames = train_clip_frames + context_frames
def _get_example(self, metadatum):
fname, _, _ = metadatum
wav_path = self.root.joinpath("wavs", fname + ".wav")
config = self.config
sr = config.sample_rate
fft_window_shift = config.fft_window_shift
fft_window_size = config.fft_window_size
fft_size = config.fft_size
audio, loaded_sr = librosa.load(wav_path, sr=None)
assert loaded_sr == sr
# Pad audio to the right size.
frames = int(np.ceil(float(audio.size) / fft_window_shift))
fft_padding = (fft_size - fft_window_shift) // 2
desired_length = frames * fft_window_shift + fft_padding * 2
pad_amount = (desired_length - audio.size) // 2
if audio.size % 2 == 0:
audio = np.pad(audio, (pad_amount, pad_amount), mode='reflect')
else:
audio = np.pad(audio, (pad_amount, pad_amount + 1), mode='reflect')
# Normalize audio.
audio = audio / np.abs(audio).max() * 0.999
# Compute mel-spectrogram.
# Turn center to False to prevent internal padding.
spectrogram = librosa.core.stft(
audio,
hop_length=fft_window_shift,
win_length=fft_window_size,
n_fft=fft_size,
center=False)
spectrogram_magnitude = np.abs(spectrogram)
# Compute mel-spectrograms.
mel_filter_bank = librosa.filters.mel(sr=sr,
n_fft=fft_size,
n_mels=config.mel_bands)
mel_spectrogram = np.dot(mel_filter_bank, spectrogram_magnitude)
mel_spectrogram = mel_spectrogram.T
# Rescale mel_spectrogram.
min_level, ref_level = 1e-5, 20
mel_spectrogram = 20 * np.log10(np.maximum(min_level, mel_spectrogram))
mel_spectrogram = mel_spectrogram - ref_level
mel_spectrogram = np.clip((mel_spectrogram + 100) / 100, 0, 1)
# Extract the center of audio that corresponds to mel spectrograms.
audio = audio[fft_padding:-fft_padding]
assert mel_spectrogram.shape[0] * fft_window_shift == audio.size
return audio, mel_spectrogram
class Subset(dataset.Dataset):
def __init__(self, dataset, indices, valid):
self.dataset = dataset
self.indices = indices
self.valid = valid
def __getitem__(self, idx):
fft_window_shift = self.dataset.fft_window_shift
num_frames = self.dataset.num_frames
audio, mel = self.dataset[self.indices[idx]]
if self.valid:
audio_start = 0
else:
# Randomly crop context + train_clip_second of audio.
audio_frames = int(audio.size) // fft_window_shift
max_start_frame = audio_frames - num_frames
assert max_start_frame >= 0, "audio {} is too short".format(idx)
frame_start = random.randint(0, max_start_frame)
frame_end = frame_start + num_frames
audio_start = frame_start * fft_window_shift
audio_end = frame_end * fft_window_shift
audio = audio[audio_start:audio_end]
return audio, mel, audio_start
def _batch_examples(self, batch):
audios = [sample[0] for sample in batch]
audio_starts = [sample[2] for sample in batch]
# mels shape [num_frames, mel_bands]
max_frames = max(sample[1].shape[0] for sample in batch)
mels = [utils.pad_to_size(sample[1], max_frames) for sample in batch]
audios = np.array(audios, dtype=np.float32)
mels = np.array(mels, dtype=np.float32)
audio_starts = np.array(audio_starts, dtype=np.int32)
return audios, mels, audio_starts
def __len__(self):
return len(self.indices)
class LJSpeech:
def __init__(self, config, nranks, rank):
place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace()
# Whole LJSpeech dataset.
ds = Dataset(config)
# Split into train and valid dataset.
indices = list(range(len(ds)))
train_indices = indices[config.valid_size:]
valid_indices = indices[:config.valid_size]
random.shuffle(train_indices)
# Train dataset.
trainset = Subset(ds, train_indices, valid=False)
sampler = DistributedSampler(len(trainset), nranks, rank)
total_bs = config.batch_size
assert total_bs % nranks == 0
train_sampler = BatchSampler(
sampler, total_bs // nranks, drop_last=True)
trainloader = DataCargo(trainset, batch_sampler=train_sampler)
trainreader = fluid.io.PyReader(capacity=50, return_list=True)
trainreader.decorate_batch_generator(trainloader, place)
self.trainloader = (data for _ in iter(int, 1)
for data in trainreader())
# Valid dataset.
validset = Subset(ds, valid_indices, valid=True)
# Currently only support batch_size = 1 for valid loader.
validloader = DataCargo(validset, batch_size=1, shuffle=False)
validreader = fluid.io.PyReader(capacity=20, return_list=True)
validreader.decorate_batch_generator(validloader, place)
self.validloader = validreader

View File

@ -0,0 +1,174 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import numpy as np
from scipy import signal
from tqdm import trange
import paddle.fluid.layers as F
import paddle.fluid.dygraph as dg
import paddle.fluid.initializer as I
import paddle.fluid.layers.distributions as D
from parakeet.modules.weight_norm import Conv2DTranspose
from parakeet.models.wavenet.wavenet import WaveNet
def crop(x, audio_start, audio_length):
"""Crop mel spectrogram.
Args:
x (Variable): shape(batch_size, channels, time_steps), the condition, upsampled mel spectrogram.
audio_start (int): starting point.
audio_length (int): length.
Returns:
out: cropped condition.
"""
# crop audio
slices = [] # for each example
starts = audio_start.numpy()
for i in range(x.shape[0]):
start = starts[i]
end = start + audio_length
slice = F.slice(x[i], axes=[1], starts=[start], ends=[end])
slices.append(slice)
out = F.stack(slices)
return out
class UpsampleNet(dg.Layer):
"""A upsampling net (bridge net) in clarinet to upsample spectrograms from frame level to sample level.
It consists of several(2) layers of transposed_conv2d. in time and frequency.
The time dim is dilated hop_length times. The frequency bands retains.
"""
def __init__(self, upscale_factors=[16, 16]):
super().__init__()
self.upscale_factors = list(upscale_factors)
self.upsample_convs = dg.LayerList()
for i, factor in enumerate(upscale_factors):
self.upsample_convs.append(
Conv2DTranspose(
1,
1,
filter_size=(3, 2 * factor),
stride=(1, factor),
padding=(1, factor // 2)))
@property
def upscale_factor(self):
return np.prod(self.upscale_factors)
def forward(self, x):
"""upsample local condition to match time steps of input signals. i.e. upsample mel spectrogram to match time steps for waveform, for each layer of a wavenet.
Arguments:
x {Variable} -- shape(batch_size, frequency, time_steps), local condition
Returns:
Variable -- shape(batch_size, frequency, time_steps * np.prod(upscale_factors)), upsampled condition for each layer.
"""
x = F.unsqueeze(x, axes=[1])
for sublayer in self.upsample_convs:
x = F.leaky_relu(sublayer(x), alpha=.4)
x = F.squeeze(x, [1])
return x
# AutoRegressive Model
class ConditionalWavenet(dg.Layer):
def __init__(self, encoder: UpsampleNet, decoder: WaveNet):
super().__init__()
self.encoder = encoder
self.decoder = decoder
def forward(self, audio, mel, audio_start):
"""forward
Arguments:
audio {Variable} -- shape(batch_size, time_steps), waveform of 0.5 seconds
mel {Variable} -- shape(batch_size, frequency_bands, frames), mel spectrogram of the whole sentence
audio_start {Variable} -- shape(batch_size, ), audio start positions
Returns:
Variable -- shape(batch_size, time_steps - 1, output_dim), output distribution parameters
"""
audio_length = audio.shape[1] # audio clip's length
condition = self.encoder(mel)
condition_slice = crop(condition, audio_start,
audio_length) # crop audio
# shifting 1 step
audio = audio[:, :-1]
condition_slice = condition_slice[:, :, 1:]
y = self.decoder(audio, condition_slice)
return y
def loss(self, y, t):
"""compute loss
Arguments:
y {Variable} -- shape(batch_size, time_steps - 1, output_dim), output distribution parameters
t {Variable} -- shape(batch_size, time_steps), target waveform
Returns:
Variable -- shape(1, ), reduced loss
"""
t = t[:, 1:]
loss = self.decoder.loss(y, t)
return loss
def sample(self, y):
"""sample from output distribution
Arguments:
y {Variable} -- shape(batch_size, time_steps, output_dim), output distribution parameters
Returns:
Variable -- shape(batch_size, time_steps) samples
"""
samples = self.decoder.sample(y)
return samples
@dg.no_grad
def synthesis(self, mel):
"""synthesize waveform from mel spectrogram
Arguments:
mel {Variable} -- shape(batch_size, frequency_bands, frames), mel-spectrogram
Returns:
Variable -- shape(batch_size, time_steps), synthesized waveform.
"""
condition = self.encoder(mel)
batch_size, _, time_steps = condition.shape
samples = []
self.decoder.start_sequence()
x_t = F.zeros((batch_size, 1), dtype="float32")
for i in trange(time_steps):
c_t = condition[:, :, i:i + 1]
y_t = self.decoder.add_input(x_t, c_t)
x_t = self.sample(y_t)
samples.append(x_t)
samples = F.concat(samples, axis=-1)
return samples

View File

@ -1,128 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility module for restarting training when using SLURM.
"""
import subprocess
import os
import sys
import shlex
import re
import time
def job_info():
"""Get information about the current job using `scontrol show job`.
Returns a dict mapping parameter names (e.g. "UserId", "RunTime", etc) to
their values, both as strings.
"""
job_id = int(os.environ["SLURM_JOB_ID"])
command = ["scontrol", "show", "job", str(job_id)]
output = subprocess.check_output(command).decode("utf-8")
# Use a regex to extract the parameter names and values
pattern = "([A-Za-z/]*)=([^ \t\n]*)"
return dict(re.findall(pattern, output))
def parse_hours(text):
"""Parse a time format HH or DD-HH into a number of hours."""
hour_chunks = text.split("-")
if len(hour_chunks) == 1:
return int(hour_chunks[0])
elif len(hour_chunks) == 2:
return 24 * int(hour_chunks[0]) + int(hour_chunks[1])
else:
raise ValueError("Unexpected hour format (expected HH or "
"DD-HH, but got {}).".format(text))
def parse_time(text):
"""Convert slurm time to an integer.
Expects time to be of the form:
"hours:minutes:seconds" or "day-hours:minutes:seconds".
"""
hours, minutes, seconds = text.split(":")
try:
return parse_hours(hours) * 3600 + int(minutes) * 60 + int(seconds)
except ValueError as e:
raise ValueError("Error parsing time {}. Got error {}.".format(text,
str(e)))
def restart_command():
"""Using the environment and SLURM command, create a command that, when,
run, will enqueue a repeat of the current job using `sbatch`.
Return the command as a list of strings, suitable for passing to
`subprocess.check_call` or similar functions.
Returns:
resume_command: list<str>, command to run to restart job.
end_time: int or None; the time the job will end or None
if the job has unlimited runtime.
"""
# Make sure `RunTime` could be parsed correctly.
while job_info()["RunTime"] == "INVALID":
time.sleep(1)
# Get all the necessary information by querying SLURM with this job id
info = job_info()
try:
num_cpus = int(info["CPUs/Task"])
except KeyError:
num_cpus = int(os.environ["SLURM_CPUS_PER_TASK"])
num_tasks = int(os.environ["SLURM_NTASKS"])
nodes = info["NumNodes"]
gres, partition = info.get("Gres"), info.get("Partition")
stderr, stdout = info.get("StdErr"), info.get("StdOut")
job_name = info.get("JobName")
command = [
"sbatch", "--job-name={}".format(job_name),
"--ntasks={}".format(num_tasks)
]
if partition:
command.extend(["--partition", partition])
if gres and gres != "(null)":
command.extend(["--gres", gres])
num_gpu = int(gres.split(':')[-1])
print("number of gpu assigned by slurm is {}".format(num_gpu))
if stderr:
command.extend(["--error", stderr])
if stdout:
command.extend(["--output", stdout])
python = subprocess.check_output(
["/usr/bin/which", "python3"]).decode("utf-8").strip()
dist_setting = ['-m', 'paddle.distributed.launch']
wrap_cmd = ["srun", python, '-u'] + dist_setting + sys.argv
command.append("--wrap={}".format(" ".join(
shlex.quote(arg) for arg in wrap_cmd)))
time_limit_string = info["TimeLimit"]
if time_limit_string.lower() == "unlimited":
print(
"UNLIMITED detected: restart OFF, infinite learning ON.",
flush=True)
return command, None
time_limit = parse_time(time_limit_string)
runtime = parse_time(info["RunTime"])
end_time = time.time() + time_limit - runtime
return command, end_time

View File

@ -1,116 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
from pprint import pprint
import jsonargparse
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
import utils
from wavenet import WaveNet
def add_options_to_parser(parser):
parser.add_argument(
'--model',
type=str,
default='wavenet',
help="general name of the model")
parser.add_argument(
'--name', type=str, help="specific name of the training model")
parser.add_argument(
'--root', type=str, help="root path of the LJSpeech dataset")
parser.add_argument(
'--use_gpu',
type=bool,
default=True,
help="option to use gpu training")
parser.add_argument(
'--iteration',
type=int,
default=None,
help=("which iteration of checkpoint to load, "
"default to load the latest checkpoint"))
parser.add_argument(
'--checkpoint',
type=str,
default=None,
help="path of the checkpoint to load")
parser.add_argument(
'--output',
type=str,
default="./syn_audios",
help="path to write synthesized audio files")
parser.add_argument(
'--sample',
type=int,
help="which of the valid samples to synthesize audio")
def synthesize(config):
pprint(jsonargparse.namespace_to_dict(config))
# Get checkpoint directory path.
run_dir = os.path.join("runs", config.model, config.name)
checkpoint_dir = os.path.join(run_dir, "checkpoint")
# Configurate device.
place = fluid.CUDAPlace(0) if config.use_gpu else fluid.CPUPlace()
with dg.guard(place):
# Fix random seed.
seed = config.seed
random.seed(seed)
np.random.seed(seed)
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
print("Random Seed: ", seed)
# Build model.
model = WaveNet(config, checkpoint_dir)
model.build(training=False)
# Obtain the current iteration.
if config.checkpoint is None:
if config.iteration is None:
iteration = utils.load_latest_checkpoint(checkpoint_dir)
else:
iteration = config.iteration
else:
iteration = int(config.checkpoint.split('/')[-1].split('-')[-1])
# Run model inference.
model.infer(iteration)
if __name__ == "__main__":
# Create parser.
parser = jsonargparse.ArgumentParser(
description="Synthesize audio using WaveNet model",
formatter_class='default_argparse')
add_options_to_parser(parser)
utils.add_config_options_to_parser(parser)
# Parse argument from both command line and yaml config file.
# For conflicting updates to the same field,
# the preceding update will be overwritten by the following one.
config = parser.parse_args()
synthesize(config)

View File

@ -1,171 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import subprocess
import time
from pprint import pprint
import jsonargparse
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
from tensorboardX import SummaryWriter
import slurm
import utils
from wavenet import WaveNet
MAXIMUM_SAVE_TIME = 10 * 60
def add_options_to_parser(parser):
parser.add_argument(
'--model',
type=str,
default='wavenet',
help="general name of the model")
parser.add_argument(
'--name', type=str, help="specific name of the training model")
parser.add_argument(
'--root', type=str, help="root path of the LJSpeech dataset")
parser.add_argument(
'--parallel',
type=bool,
default=True,
help="option to use data parallel training")
parser.add_argument(
'--use_gpu',
type=bool,
default=True,
help="option to use gpu training")
parser.add_argument(
'--iteration',
type=int,
default=None,
help=("which iteration of checkpoint to load, "
"default to load the latest checkpoint"))
parser.add_argument(
'--checkpoint',
type=str,
default=None,
help="path of the checkpoint to load")
parser.add_argument(
'--slurm',
type=bool,
default=False,
help="whether you are using slurm to submit training jobs")
def train(config):
use_gpu = config.use_gpu
parallel = config.parallel if use_gpu else False
# Get the rank of the current training process.
rank = dg.parallel.Env().local_rank if parallel else 0
nranks = dg.parallel.Env().nranks if parallel else 1
if rank == 0:
# Print the whole config setting.
pprint(jsonargparse.namespace_to_dict(config))
# Make checkpoint directory.
run_dir = os.path.join("runs", config.model, config.name)
checkpoint_dir = os.path.join(run_dir, "checkpoint")
os.makedirs(checkpoint_dir, exist_ok=True)
# Create tensorboard logger.
tb = SummaryWriter(os.path.join(run_dir, "logs")) \
if rank == 0 else None
# Configurate device
place = fluid.CUDAPlace(rank) if use_gpu else fluid.CPUPlace()
with dg.guard(place):
# Fix random seed.
seed = config.seed
random.seed(seed)
np.random.seed(seed)
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
print("Random Seed: ", seed)
# Build model.
model = WaveNet(config, checkpoint_dir, parallel, rank, nranks, tb)
model.build()
# Obtain the current iteration.
if config.checkpoint is None:
if config.iteration is None:
iteration = utils.load_latest_checkpoint(checkpoint_dir, rank)
else:
iteration = config.iteration
else:
iteration = int(config.checkpoint.split('/')[-1].split('-')[-1])
# Get restart command if using slurm.
if config.slurm:
resume_command, death_time = slurm.restart_command()
if rank == 0:
print("Restart command:", " ".join(resume_command))
done = False
while iteration < config.max_iterations:
# Run one single training step.
model.train_step(iteration)
iteration += 1
if iteration % config.test_every == 0:
# Run validation step.
model.valid_step(iteration)
# Check whether reaching the time limit.
if config.slurm:
done = (death_time is not None and
death_time - time.time() < MAXIMUM_SAVE_TIME)
if rank == 0 and done:
print("Saving progress before exiting.")
model.save(iteration)
print("Running restart command:", " ".join(resume_command))
# Submit restart command.
subprocess.check_call(resume_command)
break
if rank == 0 and iteration % config.save_every == 0:
# Save parameters.
model.save(iteration)
# Close TensorBoard.
if rank == 0:
tb.close()
if __name__ == "__main__":
# Create parser.
parser = jsonargparse.ArgumentParser(
description="Train WaveNet model", formatter_class='default_argparse')
add_options_to_parser(parser)
utils.add_config_options_to_parser(parser)
# Parse argument from both command line and yaml config file.
# For conflicting updates to the same field,
# the preceding update will be overwritten by the following one.
config = parser.parse_args()
train(config)

View File

@ -1,186 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import os
import time
import jsonargparse
import numpy as np
import paddle.fluid.dygraph as dg
def add_config_options_to_parser(parser):
parser.add_argument(
'--valid_size', type=int, help="size of the valid dataset")
parser.add_argument(
'--train_clip_second',
type=float,
help="the length of audio clip for training")
parser.add_argument(
'--sample_rate', type=int, help="sampling rate of audio data file")
parser.add_argument(
'--fft_window_shift',
type=int,
help="the shift of fft window for each frame")
parser.add_argument(
'--fft_window_size',
type=int,
help="the size of fft window for each frame")
parser.add_argument(
'--fft_size', type=int, help="the size of fft filter on each frame")
parser.add_argument(
'--mel_bands',
type=int,
help="the number of mel bands when calculating mel spectrograms")
parser.add_argument(
'--seed', type=int, help="seed of random initialization for the model")
parser.add_argument(
'--batch_size', type=int, help="batch size for training")
parser.add_argument(
'--test_every', type=int, help="test interval during training")
parser.add_argument(
'--save_every',
type=int,
help="checkpointing interval during training")
parser.add_argument(
'--max_iterations', type=int, help="maximum training iterations")
parser.add_argument(
'--layers', type=int, help="number of dilated convolution layers")
parser.add_argument(
'--kernel_width', type=int, help="dilated convolution kernel width")
parser.add_argument(
'--dilation_block', type=list, help="dilated convolution kernel width")
parser.add_argument('--residual_channels', type=int)
parser.add_argument('--skip_channels', type=int)
parser.add_argument(
'--loss_type', type=str, help="mix-gaussian-pdf or softmax")
parser.add_argument(
'--num_channels',
type=int,
default=None,
help="number of channels for softmax output")
parser.add_argument(
'--num_mixtures',
type=int,
default=None,
help="number of gaussian mixtures for gaussian output")
parser.add_argument(
'--log_scale_min',
type=float,
default=None,
help="minimum clip value of log variance of gaussian output")
parser.add_argument(
'--conditioner.filter_sizes',
type=list,
help="conv2d tranpose op filter sizes for building conditioner")
parser.add_argument(
'--conditioner.upsample_factors',
type=list,
help="list of upsample factors for building conditioner")
parser.add_argument('--learning_rate', type=float)
parser.add_argument('--gradient_max_norm', type=float)
parser.add_argument(
'--anneal.every',
type=int,
help="step interval for annealing learning rate")
parser.add_argument('--anneal.rate', type=float)
parser.add_argument('--config', action=jsonargparse.ActionConfigFile)
def pad_to_size(array, length, pad_with=0.0):
"""
Pad an array on the first (length) axis to a given length.
"""
padding = length - array.shape[0]
assert padding >= 0, "Padding required was less than zero"
paddings = [(0, 0)] * len(array.shape)
paddings[0] = (0, padding)
return np.pad(array, paddings, mode='constant', constant_values=pad_with)
def calculate_context_size(config):
dilations = list(
itertools.islice(
itertools.cycle(config.dilation_block), config.layers))
config.context_size = sum(dilations) + 1
print("Context size is", config.context_size)
def load_latest_checkpoint(checkpoint_dir, rank=0):
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
# Create checkpoint index file if not exist.
if (not os.path.isfile(checkpoint_path)) and rank == 0:
with open(checkpoint_path, "w") as handle:
handle.write("model_checkpoint_path: step-0")
# Make sure that other process waits until checkpoint file is created
# by process 0.
while not os.path.isfile(checkpoint_path):
time.sleep(1)
# Fetch the latest checkpoint index.
with open(checkpoint_path, "r") as handle:
latest_checkpoint = handle.readline().split()[-1]
iteration = int(latest_checkpoint.split("-")[-1])
return iteration
def save_latest_checkpoint(checkpoint_dir, iteration):
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
# Update the latest checkpoint index.
with open(checkpoint_path, "w") as handle:
handle.write("model_checkpoint_path: step-{}".format(iteration))
def load_parameters(checkpoint_dir,
rank,
model,
optimizer=None,
iteration=None,
file_path=None):
if file_path is None:
if iteration is None:
iteration = load_latest_checkpoint(checkpoint_dir, rank)
if iteration == 0:
return
file_path = "{}/step-{}".format(checkpoint_dir, iteration)
model_dict, optimizer_dict = dg.load_dygraph(file_path)
model.set_dict(model_dict)
print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path))
if optimizer and optimizer_dict:
optimizer.set_dict(optimizer_dict)
print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
rank, file_path))
def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
file_path = "{}/step-{}".format(checkpoint_dir, iteration)
model_dict = model.state_dict()
dg.save_dygraph(model_dict, file_path)
print("[checkpoint] Saved model to {}".format(file_path))
if optimizer:
opt_dict = optimizer.state_dict()
dg.save_dygraph(opt_dict, file_path)
print("[checkpoint] Saved optimzier state to {}".format(file_path))

View File

@ -12,197 +12,425 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import os
import math
import time
import librosa
import itertools
import numpy as np
import paddle.fluid.layers as F
import paddle.fluid.dygraph as dg
from paddle import fluid
import paddle.fluid.initializer as I
import paddle.fluid.layers.distributions as D
import utils
from data import LJSpeech
from wavenet_modules import WaveNetModule
from parakeet.modules.weight_norm import Linear, Conv1D, Conv1DCell, Conv2DTranspose
class WaveNet():
def __init__(self,
config,
checkpoint_dir,
parallel=False,
rank=0,
nranks=1,
tb_logger=None):
# Process config to calculate the context size
dilations = list(
itertools.islice(
itertools.cycle(config.dilation_block), config.layers))
config.context_size = sum(dilations) + 1
self.config = config
self.checkpoint_dir = checkpoint_dir
self.parallel = parallel
self.rank = rank
self.nranks = nranks
self.tb_logger = tb_logger
# for wavenet with softmax loss
def quantize(values, n_bands):
quantized = F.cast((values + 1.0) / 2.0 * n_bands, "int64")
return quantized
def build(self, training=True):
config = self.config
dataset = LJSpeech(config, self.nranks, self.rank)
self.trainloader = dataset.trainloader
self.validloader = dataset.validloader
wavenet = WaveNetModule("wavenet", config, self.rank)
def dequantize(quantized, n_bands):
value = (F.cast(quantized, "float32") + 0.5) * (2.0 / n_bands) - 1.0
return value
# Dry run once to create and initalize all necessary parameters.
audio = dg.to_variable(np.random.randn(1, 20000).astype(np.float32))
mel = dg.to_variable(
np.random.randn(1, 100, self.config.mel_bands).astype(np.float32))
audio_start = dg.to_variable(np.array([0], dtype=np.int32))
wavenet(audio, mel, audio_start)
if training:
# Create Learning rate scheduler.
lr_scheduler = dg.ExponentialDecay(
learning_rate=config.learning_rate,
decay_steps=config.anneal.every,
decay_rate=config.anneal.rate,
staircase=True)
class ResidualBlock(dg.Layer):
def __init__(self, residual_channels, condition_dim, filter_size,
dilation):
super().__init__()
dilated_channels = 2 * residual_channels
# following clarinet's implementation, we do not have parametric residual
# & skip connection.
optimizer = fluid.optimizer.AdamOptimizer(
learning_rate=lr_scheduler)
std = np.sqrt(1 / (filter_size * residual_channels))
self.conv = Conv1DCell(
residual_channels,
dilated_channels,
filter_size,
dilation=dilation,
causal=True,
param_attr=I.Normal(scale=std))
clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
config.gradient_max_norm)
std = np.sqrt(1 / condition_dim)
self.condition_proj = Conv1D(
condition_dim, dilated_channels, 1, param_attr=I.Normal(scale=std))
# Load parameters.
utils.load_parameters(
self.checkpoint_dir,
self.rank,
wavenet,
optimizer,
iteration=config.iteration,
file_path=config.checkpoint)
print("Rank {}: checkpoint loaded.".format(self.rank))
self.filter_size = filter_size
self.dilation = dilation
self.dilated_channels = dilated_channels
self.residual_channels = residual_channels
self.condition_dim = condition_dim
# Data parallelism.
if self.parallel:
strategy = dg.parallel.prepare_context()
wavenet = dg.parallel.DataParallel(wavenet, strategy)
def forward(self, x, condition=None):
"""Conv1D gated tanh Block
self.wavenet = wavenet
self.optimizer = optimizer
self.clipper = clipper
Arguments:
x {Variable} -- shape(batch_size, residual_channels, time_steps), the input.
Keyword Arguments:
condition {Variable} -- shape(batch_size, condition_dim, time_steps), upsampled local condition, it has the shape time steps as the input x. (default: {None})
Returns:
Variable -- shape(batch_size, residual_channels, time_steps), the output which is used as the input of the next layer.
Variable -- shape(batch_size, residual_channels, time_steps), the output which is stacked alongside with other layers' as the output of wavenet.
"""
time_steps = x.shape[-1]
h = x
# dilated conv
h = self.conv(h)
if h.shape[-1] != time_steps:
h = h[:, :, :time_steps]
# condition
if condition:
h += self.condition_proj(condition)
# gated tanh
content, gate = F.split(h, 2, dim=1)
z = F.sigmoid(gate) * F.tanh(content)
# projection
residual = F.scale(z + x, math.sqrt(.5))
skip_connection = z
return residual, skip_connection
def start_sequence(self):
self.conv.start_sequence()
def add_input(self, x, condition=None):
"""add a step input.
Arguments:
x {Variable} -- shape(batch_size, in_channels, time_steps=1), step input
Keyword Arguments:
condition {Variable} -- shape(batch_size, condition_dim, time_steps=1) (default: {None})
Returns:
Variable -- shape(batch_size, in_channels, time_steps=1), residual connection, which is the input for the next layer
Variable -- shape(batch_size, in_channels, time_steps=1), skip connection
"""
h = x
# dilated conv
h = self.conv.add_input(h)
# condition
if condition is not None:
h += self.condition_proj(condition)
# gated tanh
content, gate = F.split(h, 2, dim=1)
z = F.sigmoid(gate) * F.tanh(content)
# projection
residual = F.scale(z + x, np.sqrt(0.5))
skip_connection = z
return residual, skip_connection
class ResidualNet(dg.Layer):
def __init__(self, n_loop, n_layer, residual_channels, condition_dim,
filter_size):
super().__init__()
# double the dilation at each layer in a loop(n_loop layers)
dilations = [2**i for i in range(n_loop)] * n_layer
self.context_size = 1 + sum(dilations)
self.residual_blocks = dg.LayerList([
ResidualBlock(residual_channels, condition_dim, filter_size,
dilation) for dilation in dilations
])
def forward(self, x, condition=None):
"""n_layer layers of n_loop Residual Blocks.
Arguments:
x {Variable} -- shape(batch_size, residual_channels, time_steps), input of the residual net.
Keyword Arguments:
condition {Variable} -- shape(batch_size, condition_dim, time_steps), upsampled conditions, which has the same time steps as the input. (default: {None})
Returns:
Variable -- shape(batch_size, skip_channels, time_steps), output of the residual net.
"""
#before_resnet = time.time()
for i, func in enumerate(self.residual_blocks):
x, skip = func(x, condition)
if i == 0:
skip_connections = skip
else:
# Load parameters.
utils.load_parameters(
self.checkpoint_dir,
self.rank,
wavenet,
iteration=config.iteration,
file_path=config.checkpoint)
print("Rank {}: checkpoint loaded.".format(self.rank))
skip_connections = F.scale(skip_connections + skip,
np.sqrt(0.5))
#print("resnet: ", time.time() - before_resnet)
return skip_connections
self.wavenet = wavenet
def start_sequence(self):
for block in self.residual_blocks:
block.start_sequence()
def train_step(self, iteration):
self.wavenet.train()
def add_input(self, x, condition=None):
"""add step input and return step output.
start_time = time.time()
audios, mels, audio_starts = next(self.trainloader)
load_time = time.time()
Arguments:
x {Variable} -- shape(batch_size, residual_channels, time_steps=1), step input.
loss, _ = self.wavenet(audios, mels, audio_starts)
Keyword Arguments:
condition {Variable} -- shape(batch_size, condition_dim, time_steps=1), step condition (default: {None})
if self.parallel:
# loss = loss / num_trainers
loss = self.wavenet.scale_loss(loss)
loss.backward()
self.wavenet.apply_collective_grads()
Returns:
Variable -- shape(batch_size, skip_channels, time_steps=1), step output, parameters of the output distribution.
"""
for i, func in enumerate(self.residual_blocks):
x, skip = func.add_input(x, condition)
if i == 0:
skip_connections = skip
else:
loss.backward()
skip_connections = F.scale(skip_connections + skip,
np.sqrt(0.5))
return skip_connections
if isinstance(self.optimizer._learning_rate,
fluid.optimizer.LearningRateDecay):
current_lr = self.optimizer._learning_rate.step().numpy()
class WaveNet(dg.Layer):
def __init__(self, n_loop, n_layer, residual_channels, output_dim,
condition_dim, filter_size, loss_type, log_scale_min):
super().__init__()
if loss_type not in ["softmax", "mog"]:
raise ValueError("loss_type {} is not supported".format(loss_type))
if loss_type == "softmax":
self.embed = dg.Embedding((output_dim, residual_channels))
else:
current_lr = self.optimizer._learning_rate
assert output_dim % 3 == 0, "with MoG output, the output dim must be divided by 3"
self.embed = Linear(1, residual_channels)
self.optimizer.minimize(
loss,
grad_clip=self.clipper,
parameter_list=self.wavenet.parameters())
self.wavenet.clear_gradients()
self.resnet = ResidualNet(n_loop, n_layer, residual_channels,
condition_dim, filter_size)
self.context_size = self.resnet.context_size
graph_time = time.time()
skip_channels = residual_channels # assume the same channel
self.proj1 = Linear(skip_channels, skip_channels)
self.proj2 = Linear(skip_channels, skip_channels)
# if loss_type is softmax, output_dim is n_vocab of waveform magnitude.
# if loss_type is mog, output_dim is 3 * gaussian, (weight, mean and stddev)
self.proj3 = Linear(skip_channels, output_dim)
if self.rank == 0:
loss_val = float(loss.numpy()) * self.nranks
log = "Rank: {} Step: {:^8d} Loss: {:<8.3f} " \
"Time: {:.3f}/{:.3f}".format(
self.rank, iteration, loss_val,
load_time - start_time, graph_time - load_time)
print(log)
self.loss_type = loss_type
self.output_dim = output_dim
self.input_dim = 1
self.skip_channels = skip_channels
self.log_scale_min = log_scale_min
tb = self.tb_logger
tb.add_scalar("Train-Loss-Rank-0", loss_val, iteration)
tb.add_scalar("Learning-Rate", current_lr, iteration)
def forward(self, x, condition=None):
"""(Possibly) Conditonal Wavenet.
@dg.no_grad
def valid_step(self, iteration):
self.wavenet.eval()
Arguments:
x {Variable} -- shape(batch_size, time_steps), the input signal of wavenet. The waveform in 0.5 seconds.
total_loss = []
sample_audios = []
start_time = time.time()
for audios, mels, audio_starts in self.validloader():
loss, sample_audio = self.wavenet(audios, mels, audio_starts, True)
total_loss.append(float(loss.numpy()))
sample_audios.append(sample_audio)
total_time = time.time() - start_time
Keyword Arguments:
conditions {Variable} -- shape(batch_size, condition_dim, 1, time_steps), the upsampled local condition. (default: {None})
if self.rank == 0:
loss_val = np.mean(total_loss)
log = "Test | Rank: {} AvgLoss: {:<8.3f} Time {:<8.3f}".format(
self.rank, loss_val, total_time)
print(log)
Returns:
Variable -- shape(batch_size, time_steps, output_dim), output distributions at each time_steps.
"""
tb = self.tb_logger
tb.add_scalar("Valid-Avg-Loss", loss_val, iteration)
tb.add_audio(
"Teacher-Forced-Audio-0",
sample_audios[0].numpy(),
iteration,
sample_rate=self.config.sample_rate)
tb.add_audio(
"Teacher-Forced-Audio-1",
sample_audios[1].numpy(),
iteration,
sample_rate=self.config.sample_rate)
# CAUTION: rank-4 condition here
# Causal Conv
if self.loss_type == "softmax":
x = F.clip(x, min=-1., max=0.99999)
x = quantize(x, self.output_dim)
x = self.embed(x) # (B, T, C)
else:
x = F.unsqueeze(x, axes=[-1]) # (B, T, 1)
x = self.embed(x) # (B, T, C)
x = F.transpose(x, perm=[0, 2, 1]) # (B, C, T)
@dg.no_grad
def infer(self, iteration):
self.wavenet.eval()
# Residual & Skip-conenection & linears
z = self.resnet(x, condition)
config = self.config
sample = config.sample
z = F.transpose(z, [0, 2, 1])
z = F.relu(self.proj2(F.relu(self.proj1(z))))
output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
os.makedirs(output, exist_ok=True)
y = self.proj3(z)
return y
filename = "{}/valid_{}.wav".format(output, sample)
print("Synthesize sample {}, save as {}".format(sample, filename))
def start_sequence(self):
self.resnet.start_sequence()
mels_list = [mels for _, mels, _ in self.validloader()]
start_time = time.time()
syn_audio = self.wavenet.synthesize(mels_list[sample])
syn_time = time.time() - start_time
print("audio shape {}, synthesis time {}".format(syn_audio.shape,
syn_time))
librosa.output.write_wav(filename, syn_audio, sr=config.sample_rate)
def add_input(self, x, condition=None):
"""add step input
def save(self, iteration):
utils.save_latest_parameters(self.checkpoint_dir, iteration,
self.wavenet, self.optimizer)
utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
Arguments:
x {Variable} -- shape(batch_size, time_steps=1), step input.
Keyword Arguments:
condition {Variable} -- shape(batch_size, condition_dim , 1, time_steps=1) (default: {None})
Returns:
Variable -- ouput parameter for the distribution.
"""
# Causal Conv
if self.loss_type == "softmax":
x = quantize(x, self.output_dim)
x = self.embed(x) # (B, T, C), T=1
else:
x = F.unsqueeze(x, axes=[-1]) # (B, T, 1), T=1
x = self.embed(x) # (B, T, C)
x = F.transpose(x, perm=[0, 2, 1])
# Residual & Skip-conenection & linears
z = self.resnet.add_input(x, condition)
z = F.transpose(z, [0, 2, 1])
z = F.relu(self.proj2(F.relu(self.proj1(z)))) # (B, T, C)
# Output
y = self.proj3(z)
return y
def compute_softmax_loss(self, y, t):
"""compute loss, it is basically a language_model-like loss.
Arguments:
y {Variable} -- shape(batch_size, time_steps - 1, output_dim), output distribution of multinomial distribution.
t {Variable} -- shape(batch_size, time_steps - 1), target waveform.
Returns:
Variable -- shape(1,), loss
"""
# context size is not taken into account
y = y[:, self.context_size:, :]
t = t[:, self.context_size:]
t = F.clip(t, min=-1.0, max=0.99999)
quantized = quantize(t, n_bands=self.output_dim)
label = F.unsqueeze(quantized, axes=[-1])
loss = F.softmax_with_cross_entropy(y, label)
reduced_loss = F.reduce_mean(loss)
return reduced_loss
def sample_from_softmax(self, y):
"""sample from output distribution.
Arguments:
y {Variable} -- shape(batch_size, time_steps - 1, output_dim), output distribution.
Returns:
Variable -- shape(batch_size, time_steps - 1), samples.
"""
# dequantize
batch_size, time_steps, output_dim, = y.shape
y = F.reshape(y, (batch_size * time_steps, output_dim))
prob = F.softmax(y)
quantized = F.sampling_id(prob)
samples = dequantize(quantized, n_bands=self.output_dim)
samples = F.reshape(samples, (batch_size, -1))
return samples
def compute_mog_loss(self, y, t):
"""compute the loss with an mog output distribution.
WARNING: this is not a legal probability, but a density. so it might be greater than 1.
Arguments:
y {Variable} -- shape(batch_size, time_steps, output_dim), output distribution's parameter. To represent a mixture of Gaussians. The output for each example at each time_step consists of 3 parts. The mean, the stddev, and a weight for that gaussian.
t {Variable} -- shape(batch_size, time_steps), target waveform.
Returns:
Variable -- loss, note that it is computed with the pdf of the MoG distribution.
"""
n_mixture = self.output_dim // 3
# context size is not taken in to account
y = y[:, self.context_size:, :]
t = t[:, self.context_size:]
w, mu, log_std = F.split(y, 3, dim=2)
# 100.0 is just a large float
log_std = F.clip(log_std, min=self.log_scale_min, max=100.)
inv_std = F.exp(-log_std)
p_mixture = F.softmax(w, axis=-1)
t = F.unsqueeze(t, axes=[-1])
if n_mixture > 1:
# t = F.expand_as(t, log_std)
t = F.expand(t, [1, 1, n_mixture])
x_std = inv_std * (t - mu)
exponent = F.exp(-0.5 * x_std * x_std)
pdf_x = 1.0 / np.sqrt(2.0 * np.pi) * inv_std * exponent
pdf_x = p_mixture * pdf_x
# pdf_x: [bs, len]
pdf_x = F.reduce_sum(pdf_x, dim=-1)
per_sample_loss = -F.log(pdf_x + 1e-9)
loss = F.reduce_mean(per_sample_loss)
return loss
def sample_from_mog(self, y):
"""sample from output distribution.
Arguments:
y {Variable} -- shape(batch_size, time_steps - 1, output_dim), output distribution.
Returns:
Variable -- shape(batch_size, time_steps - 1), samples.
"""
batch_size, time_steps, output_dim = y.shape
n_mixture = output_dim // 3
w, mu, log_std = F.split(y, 3, dim=-1)
reshaped_w = F.reshape(w, (batch_size * time_steps, n_mixture))
prob_ids = F.sampling_id(F.softmax(reshaped_w))
prob_ids = F.reshape(prob_ids, (batch_size, time_steps))
prob_ids = prob_ids.numpy()
index = np.array([[[b, t, prob_ids[b, t]] for t in range(time_steps)]
for b in range(batch_size)]).astype("int32")
index_var = dg.to_variable(index)
mu_ = F.gather_nd(mu, index_var)
log_std_ = F.gather_nd(log_std, index_var)
dist = D.Normal(mu_, F.exp(log_std_))
samples = dist.sample(shape=[])
samples = F.clip(samples, min=-1., max=1.)
return samples
def sample(self, y):
"""sample from output distribution.
Arguments:
y {Variable} -- shape(batch_size, time_steps - 1, output_dim), output distribution.
Returns:
Variable -- shape(batch_size, time_steps - 1), samples.
"""
if self.loss_type == "softmax":
return self.sample_from_softmax(y)
else:
return self.sample_from_mog(y)
def loss(self, y, t):
"""compute loss.
Arguments:
y {Variable} -- shape(batch_size, time_steps - 1, output_dim), output distribution of multinomial distribution.
t {Variable} -- shape(batch_size, time_steps - 1), target waveform.
Returns:
Variable -- shape(1,), loss
"""
if self.loss_type == "softmax":
return self.compute_softmax_loss(y, t)
else:
return self.compute_mog_loss(y, t)

View File

@ -1,388 +0,0 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
from parakeet.modules import conv, modules
def get_padding(filter_size, stride, padding_type='same'):
if padding_type == 'same':
padding = [(x - y) // 2 for x, y in zip(filter_size, stride)]
else:
raise ValueError("Only support same padding")
return padding
def extract_slices(x, audio_starts, audio_length, rank):
slices = []
for i in range(x.shape[0]):
start = audio_starts.numpy()[i]
end = start + audio_length
slice = fluid.layers.slice(
x, axes=[0, 1], starts=[i, start], ends=[i + 1, end])
slices.append(fluid.layers.squeeze(slice, [0]))
x = fluid.layers.stack(slices, axis=0)
return x
class Conditioner(dg.Layer):
def __init__(self, name_scope, config):
super(Conditioner, self).__init__(name_scope)
upsample_factors = config.conditioner.upsample_factors
filter_sizes = config.conditioner.filter_sizes
assert np.prod(upsample_factors) == config.fft_window_shift
self.deconvs = []
for i, up_scale in enumerate(upsample_factors):
stride = (up_scale, 1)
padding = get_padding(filter_sizes[i], stride)
self.deconvs.append(
modules.Conv2DTranspose(
self.full_name(),
num_filters=1,
filter_size=filter_sizes[i],
padding=padding,
stride=stride))
# Register python list as parameters.
for i, layer in enumerate(self.deconvs):
self.add_sublayer("conv_transpose_{}".format(i), layer)
def forward(self, x):
x = fluid.layers.unsqueeze(x, 1)
for layer in self.deconvs:
x = fluid.layers.leaky_relu(layer(x), alpha=0.4)
return fluid.layers.squeeze(x, [1])
class WaveNetModule(dg.Layer):
def __init__(self, name_scope, config, rank):
super(WaveNetModule, self).__init__(name_scope)
self.rank = rank
self.conditioner = Conditioner(self.full_name(), config)
self.dilations = list(
itertools.islice(
itertools.cycle(config.dilation_block), config.layers))
self.context_size = sum(self.dilations) + 1
self.log_scale_min = config.log_scale_min
self.config = config
print("dilations", self.dilations)
print("context_size", self.context_size)
if config.loss_type == "softmax":
self.embedding_fc = modules.Embedding(
self.full_name(),
num_embeddings=config.num_channels,
embed_dim=config.residual_channels,
std=0.1)
elif config.loss_type == "mix-gaussian-pdf":
self.embedding_fc = modules.FC(self.full_name(),
in_features=1,
size=config.residual_channels,
num_flatten_dims=2,
relu=False)
else:
raise ValueError("loss_type {} is unsupported!".format(loss_type))
self.dilated_causal_convs = []
for dilation in self.dilations:
self.dilated_causal_convs.append(
modules.Conv1D_GU(
self.full_name(),
conditioner_dim=config.mel_bands,
in_channels=config.residual_channels,
num_filters=config.residual_channels,
filter_size=config.kernel_width,
dilation=dilation,
causal=True))
for i, layer in enumerate(self.dilated_causal_convs):
self.add_sublayer("dilated_causal_conv_{}".format(i), layer)
self.fc1 = modules.FC(self.full_name(),
in_features=config.residual_channels,
size=config.skip_channels,
num_flatten_dims=2,
relu=True,
act="relu")
self.fc2 = modules.FC(self.full_name(),
in_features=config.skip_channels,
size=config.skip_channels,
num_flatten_dims=2,
relu=True,
act="relu")
if config.loss_type == "softmax":
self.fc3 = modules.FC(self.full_name(),
in_features=config.skip_channels,
size=config.num_channels,
num_flatten_dims=2,
relu=False)
elif config.loss_type == "mix-gaussian-pdf":
self.fc3 = modules.FC(self.full_name(),
in_features=config.skip_channels,
size=3 * config.num_mixtures,
num_flatten_dims=2,
relu=False)
else:
raise ValueError("loss_type {} is unsupported!".format(loss_type))
def sample_softmax(self, mix_parameters):
batch, length, hidden = mix_parameters.shape
mix_param_2d = fluid.layers.reshape(mix_parameters,
[batch * length, hidden])
mix_param_2d = fluid.layers.softmax(mix_param_2d, axis=-1)
# quantized: [batch * length]
quantized = fluid.layers.cast(
fluid.layers.sampling_id(mix_param_2d), dtype="float32")
samples = (quantized + 0.5) * (2.0 / self.config.num_channels) - 1.0
# samples: [batch * length]
return samples
def sample_mix_gaussian(self, mix_parameters):
# mix_parameters reshape from [bs, len, 3 * num_mixtures]
# to [bs * len, 3 * num_mixtures].
batch, length, hidden = mix_parameters.shape
mix_param_2d = fluid.layers.reshape(mix_parameters,
[batch * length, hidden])
K = hidden // 3
# Unpack the parameters of the mixture of gaussian.
logits_pi = mix_param_2d[:, 0:K]
mu = mix_param_2d[:, K:2 * K]
log_s = mix_param_2d[:, 2 * K:3 * K]
s = fluid.layers.exp(log_s)
pi = fluid.layers.softmax(logits_pi, axis=-1)
comp_samples = fluid.layers.sampling_id(pi)
row_idx = dg.to_variable(np.arange(batch * length))
comp_samples = fluid.layers.stack([row_idx, comp_samples], axis=-1)
mu_comp = fluid.layers.gather_nd(mu, comp_samples)
s_comp = fluid.layers.gather_nd(s, comp_samples)
# N(0, 1) normal sample.
u = fluid.layers.gaussian_random(shape=[batch * length])
samples = mu_comp + u * s_comp
samples = fluid.layers.clip(samples, min=-1.0, max=1.0)
return samples
def softmax_loss(self, targets, mix_parameters):
targets = targets[:, self.context_size:]
mix_parameters = mix_parameters[:, self.context_size:, :]
# Quantized audios to integral values with range [0, num_channels)
num_channels = self.config.num_channels
targets = fluid.layers.clip(targets, min=-1.0, max=0.99999)
quantized = fluid.layers.cast(
(targets + 1.0) / 2.0 * num_channels, dtype="int64")
# per_sample_loss shape: [bs, len, 1]
per_sample_loss = fluid.layers.softmax_with_cross_entropy(
logits=mix_parameters, label=fluid.layers.unsqueeze(quantized, 2))
loss = fluid.layers.reduce_mean(per_sample_loss)
return loss
def mixture_density_loss(self, targets, mix_parameters, log_scale_min):
# targets: [bs, len]
# mix_params: [bs, len, 3 * num_mixture]
targets = targets[:, self.context_size:]
mix_parameters = mix_parameters[:, self.context_size:, :]
# log_s: [bs, len, num_mixture]
logits_pi, mu, log_s = fluid.layers.split(
mix_parameters, num_or_sections=3, dim=-1)
pi = fluid.layers.softmax(logits_pi, axis=-1)
log_s = fluid.layers.clip(log_s, min=log_scale_min, max=100.0)
inv_s = fluid.layers.exp(0.0 - log_s)
# Calculate gaussian loss.
targets = fluid.layers.unsqueeze(targets, -1)
targets = fluid.layers.expand(targets,
[1, 1, self.config.num_mixtures])
x_std = inv_s * (targets - mu)
exponent = fluid.layers.exp(-0.5 * x_std * x_std)
pdf_x = 1.0 / np.sqrt(2.0 * np.pi) * inv_s * exponent
pdf_x = pi * pdf_x
# pdf_x: [bs, len]
pdf_x = fluid.layers.reduce_sum(pdf_x, dim=-1)
per_sample_loss = 0.0 - fluid.layers.log(pdf_x + 1e-9)
loss = fluid.layers.reduce_mean(per_sample_loss)
return loss
def forward(self, audios, mels, audio_starts, sample=False):
# Build conditioner based on mels.
full_conditioner = self.conditioner(mels)
# Slice conditioners.
audio_length = audios.shape[1]
conditioner = extract_slices(full_conditioner, audio_starts,
audio_length, self.rank)
# input_audio, target_audio: [bs, len]
input_audios = audios[:, :-1]
target_audios = audios[:, 1:]
# conditioner: [bs, len, mel_bands]
conditioner = conditioner[:, 1:, :]
loss_type = self.config.loss_type
if loss_type == "softmax":
input_audios = fluid.layers.clip(
input_audios, min=-1.0, max=0.99999)
# quantized have values in [0, num_channels)
quantized = fluid.layers.cast(
(input_audios + 1.0) / 2.0 * self.config.num_channels,
dtype="int64")
layer_input = self.embedding_fc(
fluid.layers.unsqueeze(quantized, 2))
elif loss_type == "mix-gaussian-pdf":
layer_input = self.embedding_fc(
fluid.layers.unsqueeze(input_audios, 2))
else:
raise ValueError("loss_type {} is unsupported!".format(loss_type))
# layer_input: [bs, res_channel, 1, len]
layer_input = fluid.layers.unsqueeze(
fluid.layers.transpose(
layer_input, perm=[0, 2, 1]), 2)
# conditioner: [bs, mel_bands, 1, len]
conditioner = fluid.layers.unsqueeze(
fluid.layers.transpose(
conditioner, perm=[0, 2, 1]), 2)
skip = None
for i, layer in enumerate(self.dilated_causal_convs):
# layer_input: [bs, res_channel, 1, len]
# skip: [bs, res_channel, 1, len]
layer_input, skip = layer(layer_input, skip, conditioner)
# Reshape skip to [bs, len, res_channel]
skip = fluid.layers.transpose(
fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1])
mix_parameters = self.fc3(self.fc2(self.fc1(skip)))
# Sample teacher-forced audio.
sample_audios = None
if sample:
if loss_type == "softmax":
sample_audios = self.sample_softmax(mix_parameters)
elif loss_type == "mix-gaussian-pdf":
sample_audios = self.sample_mix_gaussian(mix_parameters)
else:
raise ValueError("loss_type {} is unsupported!".format(
loss_type))
if loss_type == "softmax":
loss = self.softmax_loss(target_audios, mix_parameters)
elif loss_type == "mix-gaussian-pdf":
loss = self.mixture_density_loss(target_audios, mix_parameters,
self.log_scale_min)
else:
raise ValueError("loss_type {} is unsupported!".format(loss_type))
return loss, sample_audios
def synthesize(self, mels):
self.start_new_sequence()
bs, n_frames, mel_bands = mels.shape
conditioner = self.conditioner(mels)
time_steps = conditioner.shape[1]
print("input mels shape", mels.shape)
print("Total synthesis steps", time_steps)
loss_type = self.config.loss_type
audio_samples = []
current_sample = fluid.layers.zeros(shape=[bs, 1, 1], dtype="float32")
for i in range(time_steps):
if i % 100 == 0:
print("Step", i)
# Convert from real value sample to audio embedding.
# audio_input: [bs, 1, channel]
if loss_type == "softmax":
current_sample = fluid.layers.clip(
current_sample, min=-1.0, max=0.99999)
# quantized have values in [0, num_channels)
quantized = fluid.layers.cast(
(current_sample + 1.0) / 2.0 * self.config.num_channels,
dtype="int64")
audio_input = self.embedding_fc(quantized)
elif loss_type == "mix-gaussian-pdf":
audio_input = self.embedding_fc(current_sample)
else:
raise ValueError("loss_type {} is unsupported!".format(
loss_type))
# [bs, channel, 1, 1]
audio_input = fluid.layers.unsqueeze(
fluid.layers.transpose(
audio_input, perm=[0, 2, 1]), 2)
# [bs, mel_bands]
cond_input = conditioner[:, i, :]
# [bs, mel_bands, 1, 1]
cond_input = fluid.layers.reshape(cond_input,
cond_input.shape + [1, 1])
skip = None
for layer in self.dilated_causal_convs:
audio_input, skip = layer.add_input(audio_input, skip,
cond_input)
# [bs, 1, channel]
skip = fluid.layers.transpose(
fluid.layers.squeeze(skip, [2]), perm=[0, 2, 1])
mix_parameters = self.fc3(self.fc2(self.fc1(skip)))
if loss_type == "softmax":
sample = self.sample_softmax(mix_parameters)
elif loss_type == "mix-gaussian-pdf":
sample = self.sample_mix_gaussian(mix_parameters)
else:
raise ValueError("loss_type {} is unsupported!".format(
loss_type))
audio_samples.append(sample)
# [bs]
current_sample = audio_samples[-1]
# [bs, 1, 1]
current_sample = fluid.layers.reshape(
current_sample, current_sample.shape + [1, 1])
# syn_audio: [num_samples]
syn_audio = fluid.layers.concat(audio_samples, axis=0).numpy()
return syn_audio
def start_new_sequence(self):
for layer in self.sublayers():
if isinstance(layer, conv.Conv1D):
layer.start_new_sequence()