From b93d6db94f035426648a978999e1e8934f854cb9 Mon Sep 17 00:00:00 2001
From: chenfeiyu
Date: Thu, 5 Mar 2020 02:51:16 +0000
Subject: [PATCH] add code in examples/clarinet

---
 examples/clarinet/synthesis.py | 151 ++++++++++++++++++++++
 examples/clarinet/train.py     | 220 +++++++++++++++++++++++++++++++++
 examples/clarinet/utils.py     |  96 ++++++++++++++
 3 files changed, 467 insertions(+)
 create mode 100644 examples/clarinet/synthesis.py
 create mode 100644 examples/clarinet/train.py
 create mode 100644 examples/clarinet/utils.py

diff --git a/examples/clarinet/synthesis.py b/examples/clarinet/synthesis.py
new file mode 100644
index 0000000..e227237
--- /dev/null
+++ b/examples/clarinet/synthesis.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import argparse
+import ruamel.yaml
+import random
+from tqdm import tqdm
+import pickle
+import numpy as np
+from tensorboardX import SummaryWriter
+
+import paddle.fluid.dygraph as dg
+from paddle import fluid
+
+from parakeet.models.wavenet import WaveNet, UpsampleNet
+from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
+from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
+from parakeet.utils.layer_tools import summary, freeze
+
+from utils import valid_model, eval_model, save_checkpoint, load_checkpoint, load_model
+sys.path.append("../wavenet")
+from data import LJSpeechMetaData, Transform, DataCollector
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="synthesize audio files from mel spectrograms in the validation set."
+    )
+    parser.add_argument("--config", type=str, help="path of the config file.")
+    parser.add_argument(
+        "--device", type=int, default=-1, help="device to use (-1 for CPU).")
+    parser.add_argument("--data", type=str, help="path of the LJSpeech dataset.")
+    parser.add_argument(
+        "checkpoint", type=str, help="checkpoint to load from.")
+    parser.add_argument(
+        "output",
+        type=str,
+        nargs="?",
+        default="experiment",
+        help="path to save the synthesized audio.")
+
+    args = parser.parse_args()
+    with open(args.config, 'rt') as f:
+        config = ruamel.yaml.safe_load(f)
+
+    ljspeech_meta = LJSpeechMetaData(args.data)
+
+    data_config = config["data"]
+    sample_rate = data_config["sample_rate"]
+    n_fft = data_config["n_fft"]
+    win_length = data_config["win_length"]
+    hop_length = data_config["hop_length"]
+    n_mels = data_config["n_mels"]
+    train_clip_seconds = data_config["train_clip_seconds"]
+    transform = Transform(sample_rate, n_fft, win_length, hop_length, n_mels)
+    ljspeech = TransformDataset(ljspeech_meta, transform)
+
+    valid_size = data_config["valid_size"]
+    ljspeech_valid = SliceDataset(ljspeech, 0, valid_size)
+    ljspeech_train = SliceDataset(ljspeech, valid_size, len(ljspeech))
+
+    teacher_config = config["teacher"]
+    n_loop = teacher_config["n_loop"]
+    n_layer = teacher_config["n_layer"]
+    filter_size = teacher_config["filter_size"]
+    # receptive field of the teacher WaveNet, in samples
+    context_size = 1 + n_layer * sum([filter_size**i for i in range(n_loop)])
+    print("context size is {} samples".format(context_size))
+    train_batch_fn = DataCollector(context_size, sample_rate, hop_length,
+                                   train_clip_seconds)
+    valid_batch_fn = DataCollector(
+        context_size, sample_rate, hop_length, train_clip_seconds, valid=True)
+
+    batch_size = data_config["batch_size"]
+    train_cargo = DataCargo(
+        ljspeech_train,
+        train_batch_fn,
+        batch_size,
+        sampler=RandomSampler(ljspeech_train))
+
+    # only batch_size=1 is supported for validation
+    valid_cargo = DataCargo(
+        ljspeech_valid,
+        valid_batch_fn,
+        batch_size=1,
+        sampler=SequentialSampler(ljspeech_valid))
+
+    if args.device == -1:
+        place = fluid.CPUPlace()
+    else:
+        place = fluid.CUDAPlace(args.device)
+
+    with dg.guard(place):
+        # conditioner (upsampling net)
+        conditioner_config = config["conditioner"]
+        upsampling_factors = conditioner_config["upsampling_factors"]
+        upsample_net = UpsampleNet(upscale_factors=upsampling_factors)
+        freeze(upsample_net)
+
+        residual_channels = teacher_config["residual_channels"]
+        loss_type = teacher_config["loss_type"]
+        output_dim = teacher_config["output_dim"]
+        log_scale_min = teacher_config["log_scale_min"]
+        assert loss_type == "mog" and output_dim == 3, \
+            "the teacher wavenet should be a wavenet with a single Gaussian output"
+
+        teacher = WaveNet(n_loop, n_layer, residual_channels, output_dim,
+                          n_mels, filter_size, loss_type, log_scale_min)
+        # freeze the teacher; its weights (and the conditioner's) are
+        # restored later when the whole model is loaded from the checkpoint
+        freeze(teacher)
+
+        student_config = config["student"]
+        n_loops = student_config["n_loops"]
+        n_layers = student_config["n_layers"]
+        student_residual_channels = student_config["residual_channels"]
+        student_filter_size = student_config["filter_size"]
+        student_log_scale_min = student_config["log_scale_min"]
+        student = ParallelWaveNet(n_loops, n_layers, student_residual_channels,
+                                  n_mels, student_filter_size)
+
+        stft_config = config["stft"]
+        stft = STFT(
+            n_fft=stft_config["n_fft"],
+            hop_length=stft_config["hop_length"],
+            win_length=stft_config["win_length"])
+
+        lmd = config["loss"]["lmd"]
+        model = Clarinet(upsample_net, teacher, student, stft,
+                         student_log_scale_min, lmd)
+        summary(model)
+        load_model(model, args.checkpoint)
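+        # only the conditioner and the student are needed to generate audio;
+        # the teacher is part of the checkpoint but is not used at synthesis time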
+
+        # loader
+        train_loader = fluid.io.DataLoader.from_generator(
+            capacity=10, return_list=True)
+        train_loader.set_batch_generator(train_cargo, place)
+
+        valid_loader = fluid.io.DataLoader.from_generator(
+            capacity=10, return_list=True)
+        valid_loader.set_batch_generator(valid_cargo, place)
+
+        if not os.path.exists(args.output):
+            os.makedirs(args.output)
+        eval_model(model, valid_loader, args.output, sample_rate)
diff --git a/examples/clarinet/train.py b/examples/clarinet/train.py
new file mode 100644
index 0000000..1ceb05c
--- /dev/null
+++ b/examples/clarinet/train.py
@@ -0,0 +1,220 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import argparse
+import ruamel.yaml
+import random
+from tqdm import tqdm
+import pickle
+import numpy as np
+from tensorboardX import SummaryWriter
+
+import paddle.fluid.dygraph as dg
+from paddle import fluid
+
+from parakeet.models.wavenet import WaveNet, UpsampleNet
+from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet
+from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo
+from parakeet.utils.layer_tools import summary, freeze
+
+from utils import make_output_tree, valid_model, save_checkpoint, load_checkpoint, load_wavenet
+sys.path.append("../wavenet")
+from data import LJSpeechMetaData, Transform, DataCollector
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="train a ClariNet model with LJSpeech and a trained WaveNet model."
+    )
+    parser.add_argument("--config", type=str, help="path of the config file.")
+    parser.add_argument(
+        "--device", type=int, default=-1, help="device to use (-1 for CPU).")
+    parser.add_argument(
+        "--output",
+        type=str,
+        default="experiment",
+        help="path to save the experiment (checkpoints, states and logs).")
+    parser.add_argument("--data", type=str, help="path of the LJSpeech dataset.")
+    parser.add_argument("--resume", type=str, help="checkpoint to resume from.")
+    parser.add_argument(
+        "--wavenet", type=str,
+        help="trained wavenet checkpoint to use as the teacher.")
+    args = parser.parse_args()
+    with open(args.config, 'rt') as f:
+        config = ruamel.yaml.safe_load(f)
+
+    ljspeech_meta = LJSpeechMetaData(args.data)
+
+    data_config = config["data"]
+    sample_rate = data_config["sample_rate"]
+    n_fft = data_config["n_fft"]
+    win_length = data_config["win_length"]
+    hop_length = data_config["hop_length"]
+    n_mels = data_config["n_mels"]
+    train_clip_seconds = data_config["train_clip_seconds"]
+    transform = Transform(sample_rate, n_fft, win_length, hop_length, n_mels)
+    ljspeech = TransformDataset(ljspeech_meta, transform)
+
+    valid_size = data_config["valid_size"]
+    ljspeech_valid = SliceDataset(ljspeech, 0, valid_size)
+    ljspeech_train = SliceDataset(ljspeech, valid_size, len(ljspeech))
+
+    teacher_config = config["teacher"]
+    n_loop = teacher_config["n_loop"]
+    n_layer = teacher_config["n_layer"]
+    filter_size = teacher_config["filter_size"]
+    # receptive field of the teacher WaveNet, in samples
+    context_size = 1 + n_layer * sum([filter_size**i for i in range(n_loop)])
+    print("context size is {} samples".format(context_size))
+    train_batch_fn = DataCollector(context_size, sample_rate, hop_length,
+                                   train_clip_seconds)
+    valid_batch_fn = DataCollector(
+        context_size, sample_rate, hop_length, train_clip_seconds, valid=True)
+
+    batch_size = data_config["batch_size"]
+    train_cargo = DataCargo(
+        ljspeech_train,
+        train_batch_fn,
+        batch_size,
+        sampler=RandomSampler(ljspeech_train))
+
+    # only batch_size=1 is supported for validation
+    valid_cargo = DataCargo(
+        ljspeech_valid,
+        valid_batch_fn,
+        batch_size=1,
+        sampler=SequentialSampler(ljspeech_valid))
+
+    make_output_tree(args.output)
+
+    if args.device == -1:
+        place = fluid.CPUPlace()
+    else:
+        place = fluid.CUDAPlace(args.device)
+
+    with dg.guard(place):
+        # conditioner (upsampling net)
+        conditioner_config = config["conditioner"]
+        upsampling_factors = conditioner_config["upsampling_factors"]
+        upsample_net = UpsampleNet(upscale_factors=upsampling_factors)
+        freeze(upsample_net)
+
+        residual_channels = teacher_config["residual_channels"]
+        loss_type = teacher_config["loss_type"]
+        output_dim = teacher_config["output_dim"]
+        log_scale_min = teacher_config["log_scale_min"]
+        assert loss_type == "mog" and output_dim == 3, \
+            "the teacher wavenet should be a wavenet with a single Gaussian output"
+
+        teacher = WaveNet(n_loop, n_layer, residual_channels, output_dim,
+                          n_mels, filter_size, loss_type, log_scale_min)
+        freeze(teacher)
+
+        student_config = config["student"]
+        n_loops = student_config["n_loops"]
+        n_layers = student_config["n_layers"]
+        student_residual_channels = student_config["residual_channels"]
+        student_filter_size = student_config["filter_size"]
+        student_log_scale_min = student_config["log_scale_min"]
+        student = ParallelWaveNet(n_loops, n_layers, student_residual_channels,
+                                  n_mels, student_filter_size)
+
+        stft_config = config["stft"]
+        stft = STFT(
+            n_fft=stft_config["n_fft"],
+            hop_length=stft_config["hop_length"],
+            win_length=stft_config["win_length"])
+
+        lmd = config["loss"]["lmd"]
+        model = Clarinet(upsample_net, teacher, student, stft,
+                         student_log_scale_min, lmd)
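+        # distillation setup: the frozen teacher provides the density against
+        # which the student's samples are scored (the KL term), and an
+        # STFT-based spectrogram frame loss, weighted by lmd, is computed
+        # against the ground-truth audio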
+        summary(model)
+
+        # optim
+        train_config = config["train"]
+        learning_rate = train_config["learning_rate"]
+        anneal_rate = train_config["anneal_rate"]
+        anneal_interval = train_config["anneal_interval"]
+        lr_scheduler = dg.ExponentialDecay(
+            learning_rate, anneal_interval, anneal_rate, staircase=True)
+        optim = fluid.optimizer.Adam(
+            lr_scheduler, parameter_list=model.parameters())
+        gradient_max_norm = train_config["gradient_max_norm"]
+        clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
+            gradient_max_norm)
+
+        assert args.wavenet or args.resume, "you should load from a trained wavenet or resume training; training without a trained wavenet is not recommended."
+        if args.wavenet:
+            load_wavenet(model, args.wavenet)
+
+        if args.resume:
+            load_checkpoint(model, optim, args.resume)
+
+        # loader
+        train_loader = fluid.io.DataLoader.from_generator(
+            capacity=10, return_list=True)
+        train_loader.set_batch_generator(train_cargo, place)
+
+        valid_loader = fluid.io.DataLoader.from_generator(
+            capacity=10, return_list=True)
+        valid_loader.set_batch_generator(valid_cargo, place)
+
+        # train
+        max_iterations = train_config["max_iterations"]
+        checkpoint_interval = train_config["checkpoint_interval"]
+        eval_interval = train_config["eval_interval"]
+        checkpoint_dir = os.path.join(args.output, "checkpoints")
+        state_dir = os.path.join(args.output, "states")
+        log_dir = os.path.join(args.output, "log")
+        writer = SummaryWriter(log_dir)
+
+        # training loop
+        global_step = 1
+        global_epoch = 1
+        while global_step < max_iterations:
+            epoch_loss = 0.
+            for j, batch in tqdm(enumerate(train_loader), desc="[train]"):
+                audios, mels, audio_starts = batch
+                model.train()
+                # clip the KL term early in training to stabilize distillation
+                loss_dict = model(
+                    audios, mels, audio_starts, clip_kl=global_step > 500)
+
+                writer.add_scalar("learning_rate",
+                                  optim._learning_rate.step().numpy()[0],
+                                  global_step)
+                for k, v in loss_dict.items():
+                    writer.add_scalar("loss/{}".format(k),
+                                      v.numpy()[0], global_step)
+
+                l = loss_dict["loss"]
+                step_loss = l.numpy()[0]
+                print("[train] loss: {:<8.6f}".format(step_loss))
+                epoch_loss += step_loss
+
+                l.backward()
+                optim.minimize(l, grad_clip=clipper)
+                optim.clear_gradients()
+
+                if global_step % eval_interval == 0:
+                    # evaluate on the validation set
+                    valid_model(model, valid_loader, state_dir, global_step,
+                                sample_rate)
+                if global_step % checkpoint_interval == 0:
+                    save_checkpoint(model, optim, checkpoint_dir, global_step)
+
+                global_step += 1
+
+            # epoch loss (enumerate starts at 0, so the step count is j + 1)
+            average_loss = epoch_loss / (j + 1)
+            writer.add_scalar("average_loss", average_loss, global_epoch)
+            global_epoch += 1
diff --git a/examples/clarinet/utils.py b/examples/clarinet/utils.py
new file mode 100644
index 0000000..a0ec746
--- /dev/null
+++ b/examples/clarinet/utils.py
@@ -0,0 +1,96 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
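+
+# Helpers shared by train.py and synthesis.py: output directory setup,
+# periodic synthesis on the validation set, checkpoint saving/loading,
+# and loading a trained WaveNet's weights into a ClariNet model.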
+
+import os
+import soundfile as sf
+from tensorboardX import SummaryWriter
+from collections import OrderedDict
+
+from paddle import fluid
+import paddle.fluid.dygraph as dg
+
+
+def make_output_tree(output_dir):
+    checkpoint_dir = os.path.join(output_dir, "checkpoints")
+    if not os.path.exists(checkpoint_dir):
+        os.makedirs(checkpoint_dir)
+
+    state_dir = os.path.join(output_dir, "states")
+    if not os.path.exists(state_dir):
+        os.makedirs(state_dir)
+
+
+def valid_model(model, valid_loader, output_dir, global_step, sample_rate):
+    model.eval()
+    for i, batch in enumerate(valid_loader):
+        path = os.path.join(output_dir,
+                            "step_{}_sentence_{}.wav".format(global_step, i))
+        audio_clips, mel_specs, audio_starts = batch
+        wav_var = model.synthesis(mel_specs)
+        wav_np = wav_var.numpy()[0]
+        sf.write(path, wav_np, samplerate=sample_rate)
+        print("generated {}".format(path))
+
+
+def eval_model(model, valid_loader, output_dir, sample_rate):
+    model.eval()
+    for i, batch in enumerate(valid_loader):
+        path = os.path.join(output_dir, "sentence_{}.wav".format(i))
+        audio_clips, mel_specs, audio_starts = batch
+        wav_var = model.synthesis(mel_specs)
+        wav_np = wav_var.numpy()[0]
+        sf.write(path, wav_np, samplerate=sample_rate)
+        print("generated {}".format(path))
+
+
+def save_checkpoint(model, optim, checkpoint_dir, global_step):
+    path = os.path.join(checkpoint_dir, "step_{}".format(global_step))
+    dg.save_dygraph(model.state_dict(), path)
+    print("saving model to {}".format(path + ".pdparams"))
+    if optim:
+        dg.save_dygraph(optim.state_dict(), path)
+        print("saving optimizer to {}".format(path + ".pdopt"))
+
+
+def load_model(model, path):
+    model_dict, _ = dg.load_dygraph(path)
+    model.set_dict(model_dict)
+    print("loaded model from {}.pdparams".format(path))
+
+
+def load_checkpoint(model, optim, path):
+    model_dict, optim_dict = dg.load_dygraph(path)
+    model.set_dict(model_dict)
+    print("loaded model from {}.pdparams".format(path))
+    if optim_dict:
+        optim.set_dict(optim_dict)
+        print("loaded optimizer from {}.pdopt".format(path))
+
+
+def load_wavenet(model, path):
+    wavenet_dict, _ = dg.load_dygraph(path)
+    encoder_dict = OrderedDict()
+    teacher_dict = OrderedDict()
+    for k, v in wavenet_dict.items():
+        if k.startswith("encoder."):
+            encoder_dict[k.split('.', 1)[1]] = v
+        else:
+            # keys that start with "decoder." belong to the teacher
+            teacher_dict[k.split('.', 1)[1]] = v
+
+    model.encoder.set_dict(encoder_dict)
+    model.teacher.set_dict(teacher_dict)
+    print("loaded the encoder part and teacher part from wavenet model.")
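
The scripts above read all hyperparameters from a YAML file passed via
--config. The patch itself does not add a config file, so the sketch below
only lists the keys the code reads; every value is an illustrative
placeholder (typical LJSpeech-style settings), not part of this patch. The
teacher section must satisfy the assert in the scripts (loss_type "mog"
with output_dim 3, i.e. a single Gaussian), and the conditioner's
upsampling factors should multiply to data.hop_length.

    data:
      sample_rate: 22050
      n_fft: 2048
      win_length: 1024
      hop_length: 256
      n_mels: 80
      train_clip_seconds: 0.5
      valid_size: 16
      batch_size: 8
    conditioner:
      upsampling_factors: [16, 16]
    teacher:
      n_loop: 10
      n_layer: 3
      filter_size: 2
      residual_channels: 128
      loss_type: mog
      output_dim: 3
      log_scale_min: -9.0
    student:
      n_loops: [1, 1, 1, 1, 1, 1]
      n_layers: [10, 10, 10, 10, 10, 10]
      residual_channels: 64
      filter_size: 3
      log_scale_min: -7.0
    stft:
      n_fft: 2048
      hop_length: 256
      win_length: 1024
    loss:
      lmd: 4.0
    train:
      learning_rate: 0.0005
      anneal_rate: 0.5
      anneal_interval: 200000
      gradient_max_norm: 100.0
      checkpoint_interval: 1000
      eval_interval: 1000
      max_iterations: 2000000

Assuming such a config, a typical invocation would look like the following
(all paths are placeholders):

    python train.py --config clarinet.yaml --device 0 \
        --data /path/to/LJSpeech-1.1 \
        --wavenet /path/to/wavenet/checkpoints/step_1000000 \
        --output experiment
    python synthesis.py --config clarinet.yaml --device 0 \
        --data /path/to/LJSpeech-1.1 \
        experiment/checkpoints/step_500000 generated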