Merge branch 'waveflow' into 'master'

Adding WaveFlow model verified on LJSpeech dataset

See merge request !6
This commit is contained in:
liuyibing01 2019-12-22 12:16:04 +08:00
commit b866556cbd
9 changed files with 1191 additions and 0 deletions

View File

@ -0,0 +1,111 @@
# WaveFlow with Paddle Fluid
Paddle fluid implementation of [WaveFlow: A Compact Flow-based Model for Raw Audio](https://arxiv.org/abs/1912.01219).
## Project Structure
```text
├── configs # yaml configuration files of preset model hyperparameters
├── benchmark.py # benchmark code to test the speed of batched speech synthesis
├── data.py # dataset and dataloader settings for LJSpeech
├── synthesis.py # script for speech synthesis
├── train.py # script for model training
├── utils.py # helper functions for e.g., model checkpointing
├── waveflow.py # WaveFlow model high level APIs
└── waveflow_modules.py # WaveFlow model implementation
```
## Usage
There are many hyperparameters to be tuned depending on the specification of model and dataset you are working on.
We provide `wavenet_ljspeech.yaml` as a hyperparameter set that works well on the LJSpeech dataset.
Note that `train.py`, `synthesis.py`, and `benchmark.py` all accept a `--config` parameter. To ensure consistency, you should use the same config yaml file for both training, synthesizing and benchmarking. You can also overwrite these preset hyperparameters with command line by updating parameters after `--config`.
For example `--config=${yaml} --batch_size=8` can overwrite the corresponding hyperparameters in the `${yaml}` config file. For more details about these hyperparameters, check `utils.add_config_options_to_parser`.
Note that you also need to specify some additional parameters for `train.py`, `synthesis.py`, and `benchmark.py`, and the details can be found in `train.add_options_to_parser`, `synthesis.add_options_to_parser`, and `benchmark.add_options_to_parser`, respectively.
### Dataset
Download and unzip [LJSpeech](https://keithito.com/LJ-Speech-Dataset/).
```bash
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
tar xjvf LJSpeech-1.1.tar.bz2
```
In this example, assume that the path of unzipped LJSpeech dataset is `./data/LJSpeech-1.1`.
### Train on single GPU
```bash
export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
export CUDA_VISIBLE_DEVICES=0
python -u train.py \
--config=./configs/waveflow_ljspeech.yaml \
--root=./data/LJSpeech-1.1 \
--name=${ModelName} --batch_size=4 \
--parallel=false --use_gpu=true
```
#### Save and Load checkpoints
Our model will save model parameters as checkpoints in `./runs/waveflow/${ModelName}/checkpoint/` every 10000 iterations by default.
The saved checkpoint will have the format of `step-${iteration_number}.pdparams` for model parameters and `step-${iteration_number}.pdopt` for optimizer parameters.
There are three ways to load a checkpoint and resume training (take an example that you want to load a 500000-iteration checkpoint):
1. Use `--checkpoint=./runs/waveflow/${ModelName}/checkpoint/step-500000` to provide a specific path to load. Note that you only need to provide the base name of the parameter file, which is `step-500000`, no extension name `.pdparams` or `.pdopt` is needed.
2. Use `--iteration=500000`.
3. If you don't specify either `--checkpoint` or `--iteration`, the model will automatically load the latest checkpoint in `./runs/waveflow/${ModelName}/checkpoint`.
### Train on multiple GPUs
```bash
export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
export CUDA_VISIBLE_DEVICES=0,1,2,3
python -u -m paddle.distributed.launch train.py \
--config=./configs/waveflow_ljspeech.yaml \
--root=./data/LJSpeech-1.1 \
--name=${ModelName} --parallel=true --use_gpu=true
```
Use `export CUDA_VISIBLE_DEVICES=0,1,2,3` to set the GPUs that you want to use to be visible. Then the `paddle.distributed.launch` module will use these visible GPUs to do data parallel training in multiprocessing mode.
### Monitor with Tensorboard
By default, the logs are saved in `./runs/waveflow/${ModelName}/logs/`. You can monitor logs by tensorboard.
```bash
tensorboard --logdir=${log_dir} --port=8888
```
### Synthesize from a checkpoint
Check the [Save and load checkpoint](#save-and-load-checkpoints) section on how to load a specific checkpoint.
The following example will automatically load the latest checkpoint:
```bash
export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
export CUDA_VISIBLE_DEVICES=0
python -u synthesis.py \
--config=./configs/waveflow_ljspeech.yaml \
--root=./data/LJSpeech-1.1 \
--name=${ModelName} --use_gpu=true \
--output=./syn_audios \
--sample=${SAMPLE} \
--sigma=1.0
```
In this example, `--output` specifies where to save the synthesized audios and `--sample` specifies which sample in the valid dataset (a split from the whole LJSpeech dataset, by default contains the first 16 audio samples) to synthesize based on the mel-spectrograms computed from the ground truth sample audio, e.g., `--sample=0` means to synthesize the first audio in the valid dataset.
### Benchmarking
Use the following example to benchmark the speed of batched speech synthesis, which reports how many times faster than real-time:
```bash
export PYTHONPATH="${PYTHONPATH}:${PWD}/../../.."
export CUDA_VISIBLE_DEVICES=0
python -u benchmark.py \
--config=./configs/waveflow_ljspeech.yaml \
--root=./data/LJSpeech-1.1 \
--name=${ModelName} --use_gpu=true
```

View File

@ -0,0 +1,71 @@
import os
import random
from pprint import pprint
import jsonargparse
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
import utils
from waveflow import WaveFlow
def add_options_to_parser(parser):
parser.add_argument('--model', type=str, default='waveflow',
help="general name of the model")
parser.add_argument('--name', type=str,
help="specific name of the training model")
parser.add_argument('--root', type=str,
help="root path of the LJSpeech dataset")
parser.add_argument('--use_gpu', type=bool, default=True,
help="option to use gpu training")
parser.add_argument('--iteration', type=int, default=None,
help=("which iteration of checkpoint to load, "
"default to load the latest checkpoint"))
parser.add_argument('--checkpoint', type=str, default=None,
help="path of the checkpoint to load")
def benchmark(config):
pprint(jsonargparse.namespace_to_dict(config))
# Get checkpoint directory path.
run_dir = os.path.join("runs", config.model, config.name)
checkpoint_dir = os.path.join(run_dir, "checkpoint")
# Configurate device.
place = fluid.CUDAPlace(0) if config.use_gpu else fluid.CPUPlace()
with dg.guard(place):
# Fix random seed.
seed = config.seed
random.seed(seed)
np.random.seed(seed)
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
print("Random Seed: ", seed)
# Build model.
model = WaveFlow(config, checkpoint_dir)
model.build(training=False)
# Run model inference.
model.benchmark()
if __name__ == "__main__":
# Create parser.
parser = jsonargparse.ArgumentParser(
description="Synthesize audio using WaveNet model",
formatter_class='default_argparse')
add_options_to_parser(parser)
utils.add_config_options_to_parser(parser)
# Parse argument from both command line and yaml config file.
# For conflicting updates to the same field,
# the preceding update will be overwritten by the following one.
config = parser.parse_args()
benchmark(config)

View File

@ -0,0 +1,24 @@
valid_size: 16
segment_length: 16000
sample_rate: 22050
fft_window_shift: 256
fft_window_size: 1024
fft_size: 1024
mel_bands: 80
mel_fmin: 0.0
mel_fmax: 8000.0
seed: 1234
learning_rate: 0.0002
batch_size: 8
test_every: 2000
save_every: 10000
max_iterations: 3000000
sigma: 1.0
n_flows: 8
n_group: 16
n_layers: 8
n_channels: 64
kernel_h: 3
kernel_w: 3

View File

@ -0,0 +1,131 @@
import random
import librosa
import numpy as np
from paddle import fluid
from parakeet.datasets import ljspeech
from parakeet.data import dataset
from parakeet.data.batch import SpecBatcher, WavBatcher
from parakeet.data.datacargo import DataCargo
from parakeet.data.sampler import DistributedSampler, BatchSampler
from scipy.io.wavfile import read
class Dataset(ljspeech.LJSpeech):
def __init__(self, config):
super(Dataset, self).__init__(config.root)
self.config = config
def _get_example(self, metadatum):
fname, _, _ = metadatum
wav_path = self.root.joinpath("wavs", fname + ".wav")
loaded_sr, audio = read(wav_path)
assert loaded_sr == self.config.sample_rate
return audio
class Subset(dataset.Dataset):
def __init__(self, dataset, indices, valid):
self.dataset = dataset
self.indices = indices
self.valid = valid
self.config = dataset.config
def get_mel(self, audio):
spectrogram = librosa.core.stft(
audio, n_fft=self.config.fft_size,
hop_length=self.config.fft_window_shift,
win_length=self.config.fft_window_size)
spectrogram_magnitude = np.abs(spectrogram)
# mel_filter_bank shape: [n_mels, 1 + n_fft/2]
mel_filter_bank = librosa.filters.mel(
sr=self.config.sample_rate,
n_fft=self.config.fft_size,
n_mels=self.config.mel_bands,
fmin=self.config.mel_fmin,
fmax=self.config.mel_fmax)
# mel shape: [n_mels, num_frames]
mel = np.dot(mel_filter_bank, spectrogram_magnitude)
# Normalize mel.
clip_val = 1e-5
ref_constant = 1
mel = np.log(np.clip(mel, a_min=clip_val, a_max=None) * ref_constant)
return mel
def __getitem__(self, idx):
audio = self.dataset[self.indices[idx]]
segment_length = self.config.segment_length
if self.valid:
# whole audio for valid set
pass
else:
# audio shape: [len]
if audio.shape[0] >= segment_length:
max_audio_start = audio.shape[0] - segment_length
audio_start = random.randint(0, max_audio_start)
audio = audio[audio_start : (audio_start + segment_length)]
else:
audio = np.pad(audio, (0, segment_length - audio.shape[0]),
mode='constant', constant_values=0)
# Normalize audio to the [-1, 1] range.
audio = audio.astype(np.float32) / 32768.0
mel = self.get_mel(audio)
return audio, mel
def _batch_examples(self, batch):
audios = [sample[0] for sample in batch]
mels = [sample[1] for sample in batch]
audios = WavBatcher(pad_value=0.0)(audios)
mels = SpecBatcher(pad_value=0.0)(mels)
return audios, mels
def __len__(self):
return len(self.indices)
class LJSpeech:
def __init__(self, config, nranks, rank):
place = fluid.CUDAPlace(rank) if config.use_gpu else fluid.CPUPlace()
# Whole LJSpeech dataset.
ds = Dataset(config)
# Split into train and valid dataset.
indices = list(range(len(ds)))
train_indices = indices[config.valid_size:]
valid_indices = indices[:config.valid_size]
random.shuffle(train_indices)
# Train dataset.
trainset = Subset(ds, train_indices, valid=False)
sampler = DistributedSampler(len(trainset), nranks, rank)
total_bs = config.batch_size
assert total_bs % nranks == 0
train_sampler = BatchSampler(sampler, total_bs // nranks,
drop_last=True)
trainloader = DataCargo(trainset, batch_sampler=train_sampler)
trainreader = fluid.io.PyReader(capacity=50, return_list=True)
trainreader.decorate_batch_generator(trainloader, place)
self.trainloader = (data for _ in iter(int, 1)
for data in trainreader())
# Valid dataset.
validset = Subset(ds, valid_indices, valid=True)
# Currently only support batch_size = 1 for valid loader.
validloader = DataCargo(validset, batch_size=1, shuffle=False)
validreader = fluid.io.PyReader(capacity=20, return_list=True)
validreader.decorate_batch_generator(validloader, place)
self.validloader = validreader

View File

@ -0,0 +1,85 @@
import os
import random
from pprint import pprint
import jsonargparse
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
import utils
from waveflow import WaveFlow
def add_options_to_parser(parser):
parser.add_argument('--model', type=str, default='waveflow',
help="general name of the model")
parser.add_argument('--name', type=str,
help="specific name of the training model")
parser.add_argument('--root', type=str,
help="root path of the LJSpeech dataset")
parser.add_argument('--use_gpu', type=bool, default=True,
help="option to use gpu training")
parser.add_argument('--iteration', type=int, default=None,
help=("which iteration of checkpoint to load, "
"default to load the latest checkpoint"))
parser.add_argument('--checkpoint', type=str, default=None,
help="path of the checkpoint to load")
parser.add_argument('--output', type=str, default="./syn_audios",
help="path to write synthesized audio files")
parser.add_argument('--sample', type=int, default=None,
help="which of the valid samples to synthesize audio")
def synthesize(config):
pprint(jsonargparse.namespace_to_dict(config))
# Get checkpoint directory path.
run_dir = os.path.join("runs", config.model, config.name)
checkpoint_dir = os.path.join(run_dir, "checkpoint")
# Configurate device.
place = fluid.CUDAPlace(0) if config.use_gpu else fluid.CPUPlace()
with dg.guard(place):
# Fix random seed.
seed = config.seed
random.seed(seed)
np.random.seed(seed)
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
print("Random Seed: ", seed)
# Build model.
model = WaveFlow(config, checkpoint_dir)
model.build(training=False)
# Obtain the current iteration.
if config.checkpoint is None:
if config.iteration is None:
iteration = utils.load_latest_checkpoint(checkpoint_dir)
else:
iteration = config.iteration
else:
iteration = int(config.checkpoint.split('/')[-1].split('-')[-1])
# Run model inference.
model.infer(iteration)
if __name__ == "__main__":
# Create parser.
parser = jsonargparse.ArgumentParser(
description="Synthesize audio using WaveNet model",
formatter_class='default_argparse')
add_options_to_parser(parser)
utils.add_config_options_to_parser(parser)
# Parse argument from both command line and yaml config file.
# For conflicting updates to the same field,
# the preceding update will be overwritten by the following one.
config = parser.parse_args()
synthesize(config)

View File

@ -0,0 +1,114 @@
import os
import random
import subprocess
import time
from pprint import pprint
import jsonargparse
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
from tensorboardX import SummaryWriter
import slurm
import utils
from waveflow import WaveFlow
def add_options_to_parser(parser):
parser.add_argument('--model', type=str, default='waveflow',
help="general name of the model")
parser.add_argument('--name', type=str,
help="specific name of the training model")
parser.add_argument('--root', type=str,
help="root path of the LJSpeech dataset")
parser.add_argument('--parallel', type=bool, default=True,
help="option to use data parallel training")
parser.add_argument('--use_gpu', type=bool, default=True,
help="option to use gpu training")
parser.add_argument('--iteration', type=int, default=None,
help=("which iteration of checkpoint to load, "
"default to load the latest checkpoint"))
parser.add_argument('--checkpoint', type=str, default=None,
help="path of the checkpoint to load")
def train(config):
use_gpu = config.use_gpu
parallel = config.parallel if use_gpu else False
# Get the rank of the current training process.
rank = dg.parallel.Env().local_rank if parallel else 0
nranks = dg.parallel.Env().nranks if parallel else 1
if rank == 0:
# Print the whole config setting.
pprint(jsonargparse.namespace_to_dict(config))
# Make checkpoint directory.
run_dir = os.path.join("runs", config.model, config.name)
checkpoint_dir = os.path.join(run_dir, "checkpoint")
os.makedirs(checkpoint_dir, exist_ok=True)
# Create tensorboard logger.
tb = SummaryWriter(os.path.join(run_dir, "logs")) \
if rank == 0 else None
# Configurate device
place = fluid.CUDAPlace(rank) if use_gpu else fluid.CPUPlace()
with dg.guard(place):
# Fix random seed.
seed = config.seed
random.seed(seed)
np.random.seed(seed)
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
print("Random Seed: ", seed)
# Build model.
model = WaveFlow(config, checkpoint_dir, parallel, rank, nranks, tb)
model.build()
# Obtain the current iteration.
if config.checkpoint is None:
if config.iteration is None:
iteration = utils.load_latest_checkpoint(checkpoint_dir, rank)
else:
iteration = config.iteration
else:
iteration = int(config.checkpoint.split('/')[-1].split('-')[-1])
while iteration < config.max_iterations:
# Run one single training step.
model.train_step(iteration)
iteration += 1
if iteration % config.test_every == 0:
# Run validation step.
model.valid_step(iteration)
if rank == 0 and iteration % config.save_every == 0:
# Save parameters.
model.save(iteration)
# Close TensorBoard.
if rank == 0:
tb.close()
if __name__ == "__main__":
# Create parser.
parser = jsonargparse.ArgumentParser(description="Train WaveFlow model",
formatter_class='default_argparse')
add_options_to_parser(parser)
utils.add_config_options_to_parser(parser)
# Parse argument from both command line and yaml config file.
# For conflicting updates to the same field,
# the preceding update will be overwritten by the following one.
config = parser.parse_args()
train(config)

View File

@ -0,0 +1,114 @@
import itertools
import os
import time
import jsonargparse
import numpy as np
import paddle.fluid.dygraph as dg
def add_config_options_to_parser(parser):
parser.add_argument('--valid_size', type=int,
help="size of the valid dataset")
parser.add_argument('--segment_length', type=int,
help="the length of audio clip for training")
parser.add_argument('--sample_rate', type=int,
help="sampling rate of audio data file")
parser.add_argument('--fft_window_shift', type=int,
help="the shift of fft window for each frame")
parser.add_argument('--fft_window_size', type=int,
help="the size of fft window for each frame")
parser.add_argument('--fft_size', type=int,
help="the size of fft filter on each frame")
parser.add_argument('--mel_bands', type=int,
help="the number of mel bands when calculating mel spectrograms")
parser.add_argument('--mel_fmin', type=float,
help="lowest frequency in calculating mel spectrograms")
parser.add_argument('--mel_fmax', type=float,
help="highest frequency in calculating mel spectrograms")
parser.add_argument('--seed', type=int,
help="seed of random initialization for the model")
parser.add_argument('--learning_rate', type=float)
parser.add_argument('--batch_size', type=int,
help="batch size for training")
parser.add_argument('--test_every', type=int,
help="test interval during training")
parser.add_argument('--save_every', type=int,
help="checkpointing interval during training")
parser.add_argument('--max_iterations', type=int,
help="maximum training iterations")
parser.add_argument('--sigma', type=float,
help="standard deviation of the latent Gaussian variable")
parser.add_argument('--n_flows', type=int,
help="number of flows")
parser.add_argument('--n_group', type=int,
help="number of adjacent audio samples to squeeze into one column")
parser.add_argument('--n_layers', type=int,
help="number of conv2d layer in one wavenet-like flow architecture")
parser.add_argument('--n_channels', type=int,
help="number of residual channels in flow")
parser.add_argument('--kernel_h', type=int,
help="height of the kernel in the conv2d layer")
parser.add_argument('--kernel_w', type=int,
help="width of the kernel in the conv2d layer")
parser.add_argument('--config', action=jsonargparse.ActionConfigFile)
def load_latest_checkpoint(checkpoint_dir, rank=0):
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
# Create checkpoint index file if not exist.
if (not os.path.isfile(checkpoint_path)) and rank == 0:
with open(checkpoint_path, "w") as handle:
handle.write("model_checkpoint_path: step-0")
# Make sure that other process waits until checkpoint file is created
# by process 0.
while not os.path.isfile(checkpoint_path):
time.sleep(1)
# Fetch the latest checkpoint index.
with open(checkpoint_path, "r") as handle:
latest_checkpoint = handle.readline().split()[-1]
iteration = int(latest_checkpoint.split("-")[-1])
return iteration
def save_latest_checkpoint(checkpoint_dir, iteration):
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
# Update the latest checkpoint index.
with open(checkpoint_path, "w") as handle:
handle.write("model_checkpoint_path: step-{}".format(iteration))
def load_parameters(checkpoint_dir, rank, model, optimizer=None,
iteration=None, file_path=None):
if file_path is None:
if iteration is None:
iteration = load_latest_checkpoint(checkpoint_dir, rank)
if iteration == 0:
return
file_path = "{}/step-{}".format(checkpoint_dir, iteration)
model_dict, optimizer_dict = dg.load_dygraph(file_path)
model.set_dict(model_dict)
print("[checkpoint] Rank {}: loaded model from {}".format(rank, file_path))
if optimizer and optimizer_dict:
optimizer.set_dict(optimizer_dict)
print("[checkpoint] Rank {}: loaded optimizer state from {}".format(
rank, file_path))
def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
file_path = "{}/step-{}".format(checkpoint_dir, iteration)
model_dict = model.state_dict()
dg.save_dygraph(model_dict, file_path)
print("[checkpoint] Saved model to {}".format(file_path))
if optimizer:
opt_dict = optimizer.state_dict()
dg.save_dygraph(opt_dict, file_path)
print("[checkpoint] Saved optimzier state to {}".format(file_path))

View File

@ -0,0 +1,190 @@
import itertools
import os
import time
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
from scipy.io.wavfile import write
import utils
from data import LJSpeech
from waveflow_modules import WaveFlowLoss, WaveFlowModule
class WaveFlow():
def __init__(self, config, checkpoint_dir, parallel=False, rank=0,
nranks=1, tb_logger=None):
self.config = config
self.checkpoint_dir = checkpoint_dir
self.parallel = parallel
self.rank = rank
self.nranks = nranks
self.tb_logger = tb_logger
def build(self, training=True):
config = self.config
dataset = LJSpeech(config, self.nranks, self.rank)
self.trainloader = dataset.trainloader
self.validloader = dataset.validloader
waveflow = WaveFlowModule("waveflow", config)
# Dry run once to create and initalize all necessary parameters.
audio = dg.to_variable(np.random.randn(1, 16000).astype(np.float32))
mel = dg.to_variable(
np.random.randn(1, config.mel_bands, 63).astype(np.float32))
waveflow(audio, mel)
if training:
optimizer = fluid.optimizer.AdamOptimizer(
learning_rate=config.learning_rate)
# Load parameters.
utils.load_parameters(self.checkpoint_dir, self.rank,
waveflow, optimizer,
iteration=config.iteration,
file_path=config.checkpoint)
print("Rank {}: checkpoint loaded.".format(self.rank))
# Data parallelism.
if self.parallel:
strategy = dg.parallel.prepare_context()
waveflow = dg.parallel.DataParallel(waveflow, strategy)
self.waveflow = waveflow
self.optimizer = optimizer
self.criterion = WaveFlowLoss(config.sigma)
else:
# Load parameters.
utils.load_parameters(self.checkpoint_dir, self.rank, waveflow,
iteration=config.iteration,
file_path=config.checkpoint)
print("Rank {}: checkpoint loaded.".format(self.rank))
self.waveflow = waveflow
def train_step(self, iteration):
self.waveflow.train()
start_time = time.time()
audios, mels = next(self.trainloader)
load_time = time.time()
outputs = self.waveflow(audios, mels)
loss = self.criterion(outputs)
if self.parallel:
# loss = loss / num_trainers
loss = self.waveflow.scale_loss(loss)
loss.backward()
self.waveflow.apply_collective_grads()
else:
loss.backward()
self.optimizer.minimize(loss, parameter_list=self.waveflow.parameters())
self.waveflow.clear_gradients()
graph_time = time.time()
if self.rank == 0:
loss_val = float(loss.numpy()) * self.nranks
log = "Rank: {} Step: {:^8d} Loss: {:<8.3f} " \
"Time: {:.3f}/{:.3f}".format(
self.rank, iteration, loss_val,
load_time - start_time, graph_time - load_time)
print(log)
tb = self.tb_logger
tb.add_scalar("Train-Loss-Rank-0", loss_val, iteration)
@dg.no_grad
def valid_step(self, iteration):
self.waveflow.eval()
tb = self.tb_logger
total_loss = []
sample_audios = []
start_time = time.time()
for i, batch in enumerate(self.validloader()):
audios, mels = batch
valid_outputs = self.waveflow(audios, mels)
valid_z, valid_log_s_list = valid_outputs
# Visualize latent z and scale log_s.
if self.rank == 0 and i == 0:
tb.add_histogram("Valid-Latent_z", valid_z.numpy(), iteration)
for j, valid_log_s in enumerate(valid_log_s_list):
hist_name = "Valid-{}th-Flow-Log_s".format(j)
tb.add_histogram(hist_name, valid_log_s.numpy(), iteration)
valid_loss = self.criterion(valid_outputs)
total_loss.append(float(valid_loss.numpy()))
total_time = time.time() - start_time
if self.rank == 0:
loss_val = np.mean(total_loss)
log = "Test | Rank: {} AvgLoss: {:<8.3f} Time {:<8.3f}".format(
self.rank, loss_val, total_time)
print(log)
tb.add_scalar("Valid-Avg-Loss", loss_val, iteration)
@dg.no_grad
def infer(self, iteration):
self.waveflow.eval()
config = self.config
sample = config.sample
output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
os.makedirs(output, exist_ok=True)
mels_list = [mels for _, mels in self.validloader()]
if sample is not None:
mels_list = [mels_list[sample]]
for sample, mel in enumerate(mels_list):
filename = "{}/valid_{}.wav".format(output, sample)
print("Synthesize sample {}, save as {}".format(sample, filename))
start_time = time.time()
audio = self.waveflow.synthesize(mel, sigma=self.config.sigma)
syn_time = time.time() - start_time
audio = audio[0]
audio_time = audio.shape[0] / self.config.sample_rate
print("audio time {:.4f}, synthesis time {:.4f}".format(
audio_time, syn_time))
# Denormalize audio from [-1, 1] to [-32768, 32768] int16 range.
audio = audio.numpy() * 32768.0
audio = audio.astype('int16')
write(filename, config.sample_rate, audio)
@dg.no_grad
def benchmark(self):
self.waveflow.eval()
mels_list = [mels for _, mels in self.validloader()]
mel = fluid.layers.concat(mels_list, axis=2)
mel = mel[:, :, :864]
batch_size = 8
mel = fluid.layers.expand(mel, [batch_size, 1, 1])
for i in range(10):
start_time = time.time()
audio = self.waveflow.synthesize(mel, sigma=self.config.sigma)
print("audio.shape = ", audio.shape)
syn_time = time.time() - start_time
audio_time = audio.shape[1] * batch_size / self.config.sample_rate
print("audio time {:.4f}, synthesis time {:.4f}".format(
audio_time, syn_time))
print("{} X real-time".format(audio_time / syn_time))
def save(self, iteration):
utils.save_latest_parameters(self.checkpoint_dir, iteration,
self.waveflow, self.optimizer)
utils.save_latest_checkpoint(self.checkpoint_dir, iteration)

View File

@ -0,0 +1,351 @@
import itertools
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
from parakeet.modules import conv, modules, weight_norm
def set_param_attr(layer, c_in=1):
if isinstance(layer, (weight_norm.Conv2DTranspose, weight_norm.Conv2D)):
k = np.sqrt(1.0 / (c_in * np.prod(layer._filter_size)))
weight_init = fluid.initializer.UniformInitializer(low=-k, high=k)
bias_init = fluid.initializer.UniformInitializer(low=-k, high=k)
elif isinstance(layer, dg.Conv2D):
weight_init = fluid.initializer.ConstantInitializer(0.0)
bias_init = fluid.initializer.ConstantInitializer(0.0)
else:
raise TypeError("Unsupported layer type.")
layer._param_attr = fluid.ParamAttr(initializer=weight_init)
layer._bias_attr = fluid.ParamAttr(initializer=bias_init)
def unfold(x, n_group):
length = x.shape[-1]
new_shape = x.shape[:-1] + [length // n_group, n_group]
return fluid.layers.reshape(x, new_shape)
class WaveFlowLoss:
def __init__(self, sigma=1.0):
self.sigma = sigma
def __call__(self, model_output):
z, log_s_list = model_output
for i, log_s in enumerate(log_s_list):
if i == 0:
log_s_total = fluid.layers.reduce_sum(log_s)
else:
log_s_total = log_s_total + fluid.layers.reduce_sum(log_s)
loss = fluid.layers.reduce_sum(z * z) / (2 * self.sigma * self.sigma) \
- log_s_total
loss = loss / np.prod(z.shape)
const = 0.5 * np.log(2 * np.pi) + np.log(self.sigma)
return loss + const
class Conditioner(dg.Layer):
def __init__(self, name_scope):
super(Conditioner, self).__init__(name_scope)
upsample_factors = [16, 16]
self.upsample_conv2d = []
for s in upsample_factors:
in_channel = 1
conv_trans2d = modules.Conv2DTranspose(
self.full_name(),
num_filters=1,
filter_size=(3, 2 * s),
padding=(1, s // 2),
stride=(1, s))
set_param_attr(conv_trans2d, c_in=in_channel)
self.upsample_conv2d.append(conv_trans2d)
for i, layer in enumerate(self.upsample_conv2d):
self.add_sublayer("conv2d_transpose_{}".format(i), layer)
def forward(self, x):
x = fluid.layers.unsqueeze(x, 1)
for layer in self.upsample_conv2d:
x = fluid.layers.leaky_relu(layer(x), alpha=0.4)
return fluid.layers.squeeze(x, [1])
def infer(self, x):
x = fluid.layers.unsqueeze(x, 1)
for layer in self.upsample_conv2d:
x = layer(x)
# Trim conv artifacts.
time_cutoff = layer._filter_size[1] - layer._stride[1]
x = fluid.layers.leaky_relu(x[:, :, :, :-time_cutoff], alpha=0.4)
return fluid.layers.squeeze(x, [1])
class Flow(dg.Layer):
def __init__(self, name_scope, config):
super(Flow, self).__init__(name_scope)
self.n_layers = config.n_layers
self.n_channels = config.n_channels
self.kernel_h = config.kernel_h
self.kernel_w = config.kernel_w
# Transform audio: [batch, 1, n_group, time/n_group]
# => [batch, n_channels, n_group, time/n_group]
self.start = weight_norm.Conv2D(
self.full_name(),
num_filters=self.n_channels,
filter_size=(1, 1))
set_param_attr(self.start, c_in=1)
# Initializing last layer to 0 makes the affine coupling layers
# do nothing at first. This helps with training stability
# output shape: [batch, 2, n_group, time/n_group]
self.end = dg.Conv2D(
self.full_name(),
num_filters=2,
filter_size=(1, 1))
set_param_attr(self.end)
# receiptive fileds: (kernel - 1) * sum(dilations) + 1 >= squeeze
dilation_dict = {8: [1, 1, 1, 1, 1, 1, 1, 1],
16: [1, 1, 1, 1, 1, 1, 1, 1],
32: [1, 2, 4, 1, 2, 4, 1, 2],
64: [1, 2, 4, 8, 16, 1, 2, 4],
128: [1, 2, 4, 8, 16, 32, 64, 1]}
self.dilation_h_list = dilation_dict[config.n_group]
self.in_layers = []
self.cond_layers = []
self.res_skip_layers = []
for i in range(self.n_layers):
dilation_h = self.dilation_h_list[i]
dilation_w = 2 ** i
in_layer = weight_norm.Conv2D(
self.full_name(),
num_filters=2 * self.n_channels,
filter_size=(self.kernel_h, self.kernel_w),
dilation=(dilation_h, dilation_w))
set_param_attr(in_layer, c_in=self.n_channels)
self.in_layers.append(in_layer)
cond_layer = weight_norm.Conv2D(
self.full_name(),
num_filters=2 * self.n_channels,
filter_size=(1, 1))
set_param_attr(cond_layer, c_in=config.mel_bands)
self.cond_layers.append(cond_layer)
if i < self.n_layers - 1:
res_skip_channels = 2 * self.n_channels
else:
res_skip_channels = self.n_channels
res_skip_layer = weight_norm.Conv2D(
self.full_name(),
num_filters=res_skip_channels,
filter_size=(1, 1))
set_param_attr(res_skip_layer, c_in=self.n_channels)
self.res_skip_layers.append(res_skip_layer)
self.add_sublayer("in_layer_{}".format(i), in_layer)
self.add_sublayer("cond_layer_{}".format(i), cond_layer)
self.add_sublayer("res_skip_layer_{}".format(i), res_skip_layer)
def forward(self, audio, mel):
# audio: [bs, 1, n_group, time/group]
# mel: [bs, mel_bands, n_group, time/n_group]
audio = self.start(audio)
for i in range(self.n_layers):
dilation_h = self.dilation_h_list[i]
dilation_w = 2 ** i
# Pad height dim (n_group): causal convolution
# Pad width dim (time): dialated non-causal convolution
pad_top, pad_bottom = (self.kernel_h - 1) * dilation_h, 0
pad_left = pad_right = int((self.kernel_w-1) * dilation_w / 2)
audio_pad = fluid.layers.pad2d(audio,
paddings=[pad_top, pad_bottom, pad_left, pad_right])
hidden = self.in_layers[i](audio_pad)
cond_hidden = self.cond_layers[i](mel)
in_acts = hidden + cond_hidden
out_acts = fluid.layers.tanh(in_acts[:, :self.n_channels, :]) * \
fluid.layers.sigmoid(in_acts[:, self.n_channels:, :])
res_skip_acts = self.res_skip_layers[i](out_acts)
if i < self.n_layers - 1:
audio += res_skip_acts[:, :self.n_channels, :, :]
skip_acts = res_skip_acts[:, self.n_channels:, :, :]
else:
skip_acts = res_skip_acts
if i == 0:
output = skip_acts
else:
output += skip_acts
return self.end(output)
def infer(self, audio, mel, queues):
audio = self.start(audio)
for i in range(self.n_layers):
dilation_h = self.dilation_h_list[i]
dilation_w = 2 ** i
state_size = dilation_h * (self.kernel_h - 1)
queue = queues[i]
if len(queue) == 0:
for j in range(state_size):
queue.append(fluid.layers.zeros_like(audio))
state = queue[0:state_size]
state = fluid.layers.concat([*state, audio], axis=2)
queue.pop(0)
queue.append(audio)
# Pad height dim (n_group): causal convolution
# Pad width dim (time): dialated non-causal convolution
pad_top, pad_bottom = 0, 0
pad_left = int((self.kernel_w-1) * dilation_w / 2)
pad_right = int((self.kernel_w-1) * dilation_w / 2)
state = fluid.layers.pad2d(state,
paddings=[pad_top, pad_bottom, pad_left, pad_right])
hidden = self.in_layers[i](state)
cond_hidden = self.cond_layers[i](mel)
in_acts = hidden + cond_hidden
out_acts = fluid.layers.tanh(in_acts[:, :self.n_channels, :]) * \
fluid.layers.sigmoid(in_acts[:, self.n_channels:, :])
res_skip_acts = self.res_skip_layers[i](out_acts)
if i < self.n_layers - 1:
audio += res_skip_acts[:, :self.n_channels, :, :]
skip_acts = res_skip_acts[:, self.n_channels:, :, :]
else:
skip_acts = res_skip_acts
if i == 0:
output = skip_acts
else:
output += skip_acts
return self.end(output)
class WaveFlowModule(dg.Layer):
def __init__(self, name_scope, config):
super(WaveFlowModule, self).__init__(name_scope)
self.n_flows = config.n_flows
self.n_group = config.n_group
self.n_layers = config.n_layers
assert self.n_group % 2 == 0
assert self.n_flows % 2 == 0
self.conditioner = Conditioner(self.full_name())
self.flows = []
for i in range(self.n_flows):
flow = Flow(self.full_name(), config)
self.flows.append(flow)
self.add_sublayer("flow_{}".format(i), flow)
self.perms = []
half = self.n_group // 2
for i in range(self.n_flows):
perm = list(range(self.n_group))
if i < self.n_flows // 2:
perm = perm[::-1]
else:
perm[:half] = reversed(perm[:half])
perm[half:] = reversed(perm[half:])
self.perms.append(perm)
def forward(self, audio, mel):
mel = self.conditioner(mel)
assert mel.shape[2] >= audio.shape[1]
# Prune out the tail of audio/mel so that time/n_group == 0.
pruned_len = audio.shape[1] // self.n_group * self.n_group
if audio.shape[1] > pruned_len:
audio = audio[:, :pruned_len]
if mel.shape[2] > pruned_len:
mel = mel[:, :, :pruned_len]
# From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group]
mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2])
# From [bs, time] to [bs, n_group, time/n_group]
audio = fluid.layers.transpose(unfold(audio, self.n_group), [0, 2, 1])
# [bs, 1, n_group, time/n_group]
audio = fluid.layers.unsqueeze(audio, 1)
log_s_list = []
for i in range(self.n_flows):
inputs = audio[:, :, :-1, :]
conds = mel[:, :, 1:, :]
outputs = self.flows[i](inputs, conds)
log_s = outputs[:, :1, :, :]
b = outputs[:, 1:, :, :]
log_s_list.append(log_s)
audio_0 = audio[:, :, :1, :]
audio_out = audio[:, :, 1:, :] * fluid.layers.exp(log_s) + b
audio = fluid.layers.concat([audio_0, audio_out], axis=2)
# Permute over the height dim.
audio_slices = [audio[:, :, j, :] for j in self.perms[i]]
audio = fluid.layers.stack(audio_slices, axis=2)
mel_slices = [mel[:, :, j, :] for j in self.perms[i]]
mel = fluid.layers.stack(mel_slices, axis=2)
z = fluid.layers.squeeze(audio, [1])
return z, log_s_list
def synthesize(self, mel, sigma=1.0):
mel = self.conditioner.infer(mel)
# From [bs, mel_bands, time] to [bs, mel_bands, n_group, time/n_group]
mel = fluid.layers.transpose(unfold(mel, self.n_group), [0, 1, 3, 2])
audio = fluid.layers.gaussian_random(
shape=[mel.shape[0], 1, mel.shape[2], mel.shape[3]], std=sigma)
for i in reversed(range(self.n_flows)):
# Permute over the height dimension.
audio_slices = [audio[:, :, j, :] for j in self.perms[i]]
audio = fluid.layers.stack(audio_slices, axis=2)
mel_slices = [mel[:, :, j, :] for j in self.perms[i]]
mel = fluid.layers.stack(mel_slices, axis=2)
audio_list = []
audio_0 = audio[:, :, 0:1, :]
audio_list.append(audio_0)
audio_h = audio_0
queues = [[] for _ in range(self.n_layers)]
for h in range(1, self.n_group):
inputs = audio_h
conds = mel[:, :, h:(h+1), :]
outputs = self.flows[i].infer(inputs, conds, queues)
log_s = outputs[:, 0:1, :, :]
b = outputs[:, 1:, :, :]
audio_h = (audio[:, :, h:(h+1), :] - b) / \
fluid.layers.exp(log_s)
audio_list.append(audio_h)
audio = fluid.layers.concat(audio_list, axis=2)
# audio: [bs, n_group, time/n_group]
audio = fluid.layers.squeeze(audio, [1])
# audio: [bs, time]
audio = fluid.layers.reshape(
fluid.layers.transpose(audio, [0, 2, 1]), [audio.shape[0], -1])
return audio