diff --git a/README.md b/README.md index 9eddef4..aef1963 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ Parakeet aims to provide a flexible, efficient and state-of-the-art text-to-spee In particular, it features the latest [WaveFlow] (https://arxiv.org/abs/1912.01219) model proposed by Baidu Research. - WaveFlow can synthesize 22.05 kHz high-fidelity speech around 40x faster than real-time on a Nvidia V100 GPU without engineered inference kernels, which is faster than [WaveGlow] (https://github.com/NVIDIA/waveglow) and serveral orders of magnitude faster than WaveNet. - WaveFlow is a small-footprint flow-based model for raw audio. It has only 5.9M parameters, which is 15x smalller than WaveGlow (87.9M) and comparable to WaveNet (4.6M). -- WaveFlow is directly trained with maximum likelihood without probability density distillation and auxiliary losses as used in Parallel WaveNet and ClariNet, which simplifies the training pipeline and reduces the cost of development. +- WaveFlow is directly trained with maximum likelihood without probability density distillation and auxiliary losses as used in Parallel WaveNet and ClariNet, which simplifies the training pipeline and reduces the cost of development. ### Setup @@ -45,8 +45,10 @@ nltk.download("cmudict") - [Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning](https://arxiv.org/abs/1710.07654) - [Neural Speech Synthesis with Transformer Network](https://arxiv.org/abs/1809.08895) -- [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263). +- [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263) - [WaveFlow: A Compact Flow-based Model for Raw Audio](https://arxiv.org/abs/1912.01219) +- [WaveNet: A Generative Model for Raw Audio](https://arxiv.org/abs/1609.03499) +- [ClariNet: Parallel Wave Generation in End-to-End Text-to-Speech](https://arxiv.org/abs/1807.07281) ## Examples @@ -54,6 +56,8 @@ nltk.download("cmudict") - [Train a TransformerTTS model with ljspeech dataset](./examples/transformer_tts) - [Train a FastSpeech model with ljspeech dataset](./examples/fastspeech) - [Train a WaveFlow model with ljspeech dataset](./examples/waveflow) +- [Train a WaveNet model with ljspeech dataset](./examples/wavenet) +- [Train a Clarinet model with ljspeech dataset](./examples/clarinet) ## Copyright and License diff --git a/examples/clarinet/configs/clarinet_ljspeech.yaml b/examples/clarinet/configs/clarinet_ljspeech.yaml index 7ceedcc..2e571e5 100644 --- a/examples/clarinet/configs/clarinet_ljspeech.yaml +++ b/examples/clarinet/configs/clarinet_ljspeech.yaml @@ -1,5 +1,5 @@ data: - batch_size: 4 + batch_size: 8 train_clip_seconds: 0.5 sample_rate: 22050 hop_length: 256 diff --git a/examples/clarinet/synthesis.py b/examples/clarinet/synthesis.py new file mode 100644 index 0000000..e227237 --- /dev/null +++ b/examples/clarinet/synthesis.py @@ -0,0 +1,151 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import argparse +import ruamel.yaml +import random +from tqdm import tqdm +import pickle +import numpy as np +from tensorboardX import SummaryWriter + +import paddle.fluid.dygraph as dg +from paddle import fluid + +from parakeet.models.wavenet import WaveNet, UpsampleNet +from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet +from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo +from parakeet.utils.layer_tools import summary, freeze + +from utils import valid_model, eval_model, save_checkpoint, load_checkpoint, load_model +sys.path.append("../wavenet") +from data import LJSpeechMetaData, Transform, DataCollector + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="synthesize audio files from mel spectrogram in the validation set." + ) + parser.add_argument("--config", type=str, help="path of the config file.") + parser.add_argument( + "--device", type=int, default=-1, help="device to use.") + parser.add_argument("--data", type=str, help="path of LJspeech dataset.") + parser.add_argument( + "checkpoint", type=str, help="checkpoint to load from.") + parser.add_argument( + "output", type=str, default="experiment", help="path to save student.") + + args = parser.parse_args() + with open(args.config, 'rt') as f: + config = ruamel.yaml.safe_load(f) + + ljspeech_meta = LJSpeechMetaData(args.data) + + data_config = config["data"] + sample_rate = data_config["sample_rate"] + n_fft = data_config["n_fft"] + win_length = data_config["win_length"] + hop_length = data_config["hop_length"] + n_mels = data_config["n_mels"] + train_clip_seconds = data_config["train_clip_seconds"] + transform = Transform(sample_rate, n_fft, win_length, hop_length, n_mels) + ljspeech = TransformDataset(ljspeech_meta, transform) + + valid_size = data_config["valid_size"] + ljspeech_valid = SliceDataset(ljspeech, 0, valid_size) + ljspeech_train = SliceDataset(ljspeech, valid_size, len(ljspeech)) + + teacher_config = config["teacher"] + n_loop = teacher_config["n_loop"] + n_layer = teacher_config["n_layer"] + filter_size = teacher_config["filter_size"] + context_size = 1 + n_layer * sum([filter_size**i for i in range(n_loop)]) + print("context size is {} samples".format(context_size)) + train_batch_fn = DataCollector(context_size, sample_rate, hop_length, + train_clip_seconds) + valid_batch_fn = DataCollector( + context_size, sample_rate, hop_length, train_clip_seconds, valid=True) + + batch_size = data_config["batch_size"] + train_cargo = DataCargo( + ljspeech_train, + train_batch_fn, + batch_size, + sampler=RandomSampler(ljspeech_train)) + + # only batch=1 for validation is enabled + valid_cargo = DataCargo( + ljspeech_valid, + valid_batch_fn, + batch_size=1, + sampler=SequentialSampler(ljspeech_valid)) + + if args.device == -1: + place = fluid.CPUPlace() + else: + place = fluid.CUDAPlace(args.device) + + with dg.guard(place): + # conditioner(upsampling net) + conditioner_config = config["conditioner"] + upsampling_factors = conditioner_config["upsampling_factors"] + upsample_net = UpsampleNet(upscale_factors=upsampling_factors) + freeze(upsample_net) + + residual_channels = teacher_config["residual_channels"] + loss_type = teacher_config["loss_type"] + output_dim = teacher_config["output_dim"] + log_scale_min = teacher_config["log_scale_min"] + assert loss_type == "mog" and output_dim == 3, \ + "the teacher wavenet should be a wavenet with single gaussian output" + + teacher = WaveNet(n_loop, n_layer, residual_channels, output_dim, + n_mels, filter_size, loss_type, log_scale_min) + # load & freeze upsample_net & teacher + freeze(teacher) + + student_config = config["student"] + n_loops = student_config["n_loops"] + n_layers = student_config["n_layers"] + student_residual_channels = student_config["residual_channels"] + student_filter_size = student_config["filter_size"] + student_log_scale_min = student_config["log_scale_min"] + student = ParallelWaveNet(n_loops, n_layers, student_residual_channels, + n_mels, student_filter_size) + + stft_config = config["stft"] + stft = STFT( + n_fft=stft_config["n_fft"], + hop_length=stft_config["hop_length"], + win_length=stft_config["win_length"]) + + lmd = config["loss"]["lmd"] + model = Clarinet(upsample_net, teacher, student, stft, + student_log_scale_min, lmd) + summary(model) + load_model(model, args.checkpoint) + + # loader + train_loader = fluid.io.DataLoader.from_generator( + capacity=10, return_list=True) + train_loader.set_batch_generator(train_cargo, place) + + valid_loader = fluid.io.DataLoader.from_generator( + capacity=10, return_list=True) + valid_loader.set_batch_generator(valid_cargo, place) + + if not os.path.exists(args.output): + os.makedirs(args.output) + eval_model(model, valid_loader, args.output, sample_rate) diff --git a/examples/clarinet/train.py b/examples/clarinet/train.py new file mode 100644 index 0000000..1ceb05c --- /dev/null +++ b/examples/clarinet/train.py @@ -0,0 +1,220 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import argparse +import ruamel.yaml +import random +from tqdm import tqdm +import pickle +import numpy as np +from tensorboardX import SummaryWriter + +import paddle.fluid.dygraph as dg +from paddle import fluid + +from parakeet.models.wavenet import WaveNet, UpsampleNet +from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet +from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo +from parakeet.utils.layer_tools import summary, freeze + +from utils import make_output_tree, valid_model, save_checkpoint, load_checkpoint, load_wavenet +sys.path.append("../wavenet") +from data import LJSpeechMetaData, Transform, DataCollector + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="train a clarinet model with LJspeech and a trained wavenet model." + ) + parser.add_argument("--config", type=str, help="path of the config file.") + parser.add_argument( + "--device", type=int, default=-1, help="device to use.") + parser.add_argument( + "--output", + type=str, + default="experiment", + help="path to save student.") + parser.add_argument("--data", type=str, help="path of LJspeech dataset.") + parser.add_argument("--resume", type=str, help="checkpoint to load from.") + parser.add_argument( + "--wavenet", type=str, help="wavenet checkpoint to use.") + args = parser.parse_args() + with open(args.config, 'rt') as f: + config = ruamel.yaml.safe_load(f) + + ljspeech_meta = LJSpeechMetaData(args.data) + + data_config = config["data"] + sample_rate = data_config["sample_rate"] + n_fft = data_config["n_fft"] + win_length = data_config["win_length"] + hop_length = data_config["hop_length"] + n_mels = data_config["n_mels"] + train_clip_seconds = data_config["train_clip_seconds"] + transform = Transform(sample_rate, n_fft, win_length, hop_length, n_mels) + ljspeech = TransformDataset(ljspeech_meta, transform) + + valid_size = data_config["valid_size"] + ljspeech_valid = SliceDataset(ljspeech, 0, valid_size) + ljspeech_train = SliceDataset(ljspeech, valid_size, len(ljspeech)) + + teacher_config = config["teacher"] + n_loop = teacher_config["n_loop"] + n_layer = teacher_config["n_layer"] + filter_size = teacher_config["filter_size"] + context_size = 1 + n_layer * sum([filter_size**i for i in range(n_loop)]) + print("context size is {} samples".format(context_size)) + train_batch_fn = DataCollector(context_size, sample_rate, hop_length, + train_clip_seconds) + valid_batch_fn = DataCollector( + context_size, sample_rate, hop_length, train_clip_seconds, valid=True) + + batch_size = data_config["batch_size"] + train_cargo = DataCargo( + ljspeech_train, + train_batch_fn, + batch_size, + sampler=RandomSampler(ljspeech_train)) + + # only batch=1 for validation is enabled + valid_cargo = DataCargo( + ljspeech_valid, + valid_batch_fn, + batch_size=1, + sampler=SequentialSampler(ljspeech_valid)) + + make_output_tree(args.output) + + if args.device == -1: + place = fluid.CPUPlace() + else: + place = fluid.CUDAPlace(args.device) + + with dg.guard(place): + # conditioner(upsampling net) + conditioner_config = config["conditioner"] + upsampling_factors = conditioner_config["upsampling_factors"] + upsample_net = UpsampleNet(upscale_factors=upsampling_factors) + freeze(upsample_net) + + residual_channels = teacher_config["residual_channels"] + loss_type = teacher_config["loss_type"] + output_dim = teacher_config["output_dim"] + log_scale_min = teacher_config["log_scale_min"] + assert loss_type == "mog" and output_dim == 3, \ + "the teacher wavenet should be a wavenet with single gaussian output" + + teacher = WaveNet(n_loop, n_layer, residual_channels, output_dim, + n_mels, filter_size, loss_type, log_scale_min) + freeze(teacher) + + student_config = config["student"] + n_loops = student_config["n_loops"] + n_layers = student_config["n_layers"] + student_residual_channels = student_config["residual_channels"] + student_filter_size = student_config["filter_size"] + student_log_scale_min = student_config["log_scale_min"] + student = ParallelWaveNet(n_loops, n_layers, student_residual_channels, + n_mels, student_filter_size) + + stft_config = config["stft"] + stft = STFT( + n_fft=stft_config["n_fft"], + hop_length=stft_config["hop_length"], + win_length=stft_config["win_length"]) + + lmd = config["loss"]["lmd"] + model = Clarinet(upsample_net, teacher, student, stft, + student_log_scale_min, lmd) + summary(model) + + # optim + train_config = config["train"] + learning_rate = train_config["learning_rate"] + anneal_rate = train_config["anneal_rate"] + anneal_interval = train_config["anneal_interval"] + lr_scheduler = dg.ExponentialDecay( + learning_rate, anneal_interval, anneal_rate, staircase=True) + optim = fluid.optimizer.Adam( + lr_scheduler, parameter_list=model.parameters()) + gradiant_max_norm = train_config["gradient_max_norm"] + clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm( + gradiant_max_norm) + + assert args.wavenet or args.resume, "you should load from a trained wavenet or resume training; training without a trained wavenet is not recommended." + if args.wavenet: + load_wavenet(model, args.wavenet) + + if args.resume: + load_checkpoint(model, optim, args.resume) + + # loader + train_loader = fluid.io.DataLoader.from_generator( + capacity=10, return_list=True) + train_loader.set_batch_generator(train_cargo, place) + + valid_loader = fluid.io.DataLoader.from_generator( + capacity=10, return_list=True) + valid_loader.set_batch_generator(valid_cargo, place) + + # train + max_iterations = train_config["max_iterations"] + checkpoint_interval = train_config["checkpoint_interval"] + eval_interval = train_config["eval_interval"] + checkpoint_dir = os.path.join(args.output, "checkpoints") + state_dir = os.path.join(args.output, "states") + log_dir = os.path.join(args.output, "log") + writer = SummaryWriter(log_dir) + + # training loop + global_step = 1 + global_epoch = 1 + while global_step < max_iterations: + epoch_loss = 0. + for j, batch in tqdm(enumerate(train_loader), desc="[train]"): + audios, mels, audio_starts = batch + model.train() + loss_dict = model( + audios, mels, audio_starts, clip_kl=global_step > 500) + + writer.add_scalar("learning_rate", + optim._learning_rate.step().numpy()[0], + global_step) + for k, v in loss_dict.items(): + writer.add_scalar("loss/{}".format(k), + v.numpy()[0], global_step) + + l = loss_dict["loss"] + step_loss = l.numpy()[0] + print("[train] loss: {:<8.6f}".format(step_loss)) + epoch_loss += step_loss + + l.backward() + optim.minimize(l, grad_clip=clipper) + optim.clear_gradients() + + if global_step % eval_interval == 0: + # evaluate on valid dataset + valid_model(model, valid_loader, state_dir, global_step, + sample_rate) + if global_step % checkpoint_interval == 0: + save_checkpoint(model, optim, checkpoint_dir, global_step) + + global_step += 1 + + # epoch loss + average_loss = epoch_loss / j + writer.add_scalar("average_loss", average_loss, global_epoch) + global_epoch += 1 diff --git a/examples/clarinet/utils.py b/examples/clarinet/utils.py new file mode 100644 index 0000000..a0ec746 --- /dev/null +++ b/examples/clarinet/utils.py @@ -0,0 +1,96 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import soundfile as sf +from tensorboardX import SummaryWriter +from collections import OrderedDict + +from paddle import fluid +import paddle.fluid.dygraph as dg + + +def make_output_tree(output_dir): + checkpoint_dir = os.path.join(output_dir, "checkpoints") + if not os.path.exists(checkpoint_dir): + os.makedirs(checkpoint_dir) + + state_dir = os.path.join(output_dir, "states") + if not os.path.exists(state_dir): + os.makedirs(state_dir) + + +def valid_model(model, valid_loader, output_dir, global_step, sample_rate): + model.eval() + for i, batch in enumerate(valid_loader): + # print("sentence {}".format(i)) + path = os.path.join(output_dir, + "step_{}_sentence_{}.wav".format(global_step, i)) + audio_clips, mel_specs, audio_starts = batch + wav_var = model.synthesis(mel_specs) + wav_np = wav_var.numpy()[0] + sf.write(path, wav_np, samplerate=sample_rate) + print("generated {}".format(path)) + + +def eval_model(model, valid_loader, output_dir, sample_rate): + model.eval() + for i, batch in enumerate(valid_loader): + # print("sentence {}".format(i)) + path = os.path.join(output_dir, "sentence_{}.wav".format(i)) + audio_clips, mel_specs, audio_starts = batch + wav_var = model.synthesis(mel_specs) + wav_np = wav_var.numpy()[0] + sf.write(path, wav_np, samplerate=sample_rate) + print("generated {}".format(path)) + + +def save_checkpoint(model, optim, checkpoint_dir, global_step): + path = os.path.join(checkpoint_dir, "step_{}".format(global_step)) + dg.save_dygraph(model.state_dict(), path) + print("saving model to {}".format(path + ".pdparams")) + if optim: + dg.save_dygraph(optim.state_dict(), path) + print("saving optimizer to {}".format(path + ".pdopt")) + + +def load_model(model, path): + model_dict, _ = dg.load_dygraph(path) + model.state_dict(model_dict) + print("loaded model from {}.pdparams".format(path)) + + +def load_checkpoint(model, optim, path): + model_dict, optim_dict = dg.load_dygraph(path) + model.state_dict(model_dict) + print("loaded model from {}.pdparams".format(path)) + if optim_dict: + optim.set_dict(optim_dict) + print("loaded optimizer from {}.pdparams".format(path)) + + +def load_wavenet(model, path): + wavenet_dict, _ = dg.load_dygraph(path) + encoder_dict = OrderedDict() + teacher_dict = OrderedDict() + for k, v in wavenet_dict.items(): + if k.startswith("encoder."): + encoder_dict[k.split('.', 1)[1]] = v + else: + # k starts with "decoder." + teacher_dict[k.split('.', 1)[1]] = v + + model.encoder.set_dict(encoder_dict) + model.teacher.set_dict(teacher_dict) + print("loaded the encoder part and teacher part from wavenet model.") diff --git a/examples/deepvoice3/README.md b/examples/deepvoice3/README.md index 80434ce..fa7a5e4 100644 --- a/examples/deepvoice3/README.md +++ b/examples/deepvoice3/README.md @@ -23,7 +23,7 @@ The model consists of an encoder, a decoder and a converter (and a speaker embed ```text ├── data.py data_processing -├── ljspeech.yaml (example) configuration file +├── configs/ (example) configuration files ├── sentences.txt sample sentences ├── synthesis.py script to synthesize waveform from text ├── train.py script to train a model @@ -72,7 +72,7 @@ optional arguments: Example script: ```bash -python train.py --config=./ljspeech.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0 +python train.py --config=configs/ljspeech.yaml --data=./LJSpeech-1.1/ --output=experiment --device=0 ``` You can monitor training log via tensorboard, using the script below. @@ -110,5 +110,5 @@ optional arguments: Example script: ```bash -python synthesis.py --config=./ljspeech.yaml --device=0 experiment/checkpoints/model_step_005000000 sentences.txt generated +python synthesis.py --config=configs/ljspeech.yaml --device=0 experiment/checkpoints/model_step_005000000 sentences.txt generated ``` diff --git a/examples/deepvoice3/ljspeech.yaml b/examples/deepvoice3/configs/ljspeech.yaml similarity index 100% rename from examples/deepvoice3/ljspeech.yaml rename to examples/deepvoice3/configs/ljspeech.yaml diff --git a/examples/wavenet/configs/wavenet_mixture_of_gaussians.yaml b/examples/wavenet/configs/wavenet_mixture_of_gaussians.yaml index a848a52..68936ee 100644 --- a/examples/wavenet/configs/wavenet_mixture_of_gaussians.yaml +++ b/examples/wavenet/configs/wavenet_mixture_of_gaussians.yaml @@ -1,5 +1,5 @@ data: - batch_size: 4 + batch_size: 16 train_clip_seconds: 0.5 sample_rate: 22050 hop_length: 256 @@ -30,7 +30,7 @@ train: snap_interval: 10000 eval_interval: 10000 - max_iterations: 200000 + max_iterations: 2000000 diff --git a/examples/wavenet/configs/wavenet_single_gaussian.yaml b/examples/wavenet/configs/wavenet_single_gaussian.yaml index 8e33349..484db0b 100644 --- a/examples/wavenet/configs/wavenet_single_gaussian.yaml +++ b/examples/wavenet/configs/wavenet_single_gaussian.yaml @@ -1,5 +1,5 @@ data: - batch_size: 4 + batch_size: 16 train_clip_seconds: 0.5 sample_rate: 22050 hop_length: 256 @@ -30,7 +30,7 @@ train: snap_interval: 10000 eval_interval: 10000 - max_iterations: 200000 + max_iterations: 2000000 diff --git a/examples/wavenet/configs/wavenet_softmax.yaml b/examples/wavenet/configs/wavenet_softmax.yaml index 98018ee..7e9d756 100644 --- a/examples/wavenet/configs/wavenet_softmax.yaml +++ b/examples/wavenet/configs/wavenet_softmax.yaml @@ -1,5 +1,5 @@ data: - batch_size: 4 + batch_size: 16 train_clip_seconds: 0.5 sample_rate: 22050 hop_length: 256 @@ -30,7 +30,7 @@ train: snap_interval: 10000 eval_interval: 10000 - max_iterations: 200000 + max_iterations: 2000000