diff --git a/examples/clarinet/README.md b/examples/clarinet/README.md index 459e2f5..9b79897 100644 --- a/examples/clarinet/README.md +++ b/examples/clarinet/README.md @@ -28,24 +28,24 @@ Train the model using train.py, follow the usage displayed by `python train.py - ```text usage: train.py [-h] [--config CONFIG] [--device DEVICE] [--output OUTPUT] - [--data DATA] [--resume RESUME] [--wavenet WAVENET] + [--data DATA] [--checkpoint CHECKPOINT] [--wavenet WAVENET] train a ClariNet model with LJspeech and a trained WaveNet model. optional arguments: - -h, --help show this help message and exit - --config CONFIG path of the config file. - --device DEVICE device to use. - --output OUTPUT path to save student. - --data DATA path of LJspeech dataset. - --resume RESUME checkpoint to load from. - --wavenet WAVENET wavenet checkpoint to use. + -h, --help show this help message and exit + --config CONFIG path of the config file. + --device DEVICE device to use. + --output OUTPUT path to save student. + --data DATA path of LJspeech dataset. + --checkpoint CHECKPOINT checkpoint to load from. + --wavenet WAVENET wavenet checkpoint to use. ``` - `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config. - `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt). -- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig. -- `--output` is the directory to save results, all result are saved in this directory. The structure of the output directory is shown below. +- `--checkpoint` is the path of the checkpoint. If it is provided, the model will load the checkpoint before training. +- `--output` is the directory to save results; all results are saved in this directory. The structure of the output directory is shown below. ```text ├── checkpoints # checkpoint @@ -53,6 +53,8 @@ optional arguments: └── log # tensorboard log ``` +If `checkpoints` is not empty and `--checkpoint` is not specified, training will automatically resume from the latest checkpoint in it. + - `--device` is the device (gpu id) to use for training. `-1` means CPU. -- `--wavenet` is the path of the wavenet checkpoint to load. If you do not specify `--resume`, then this must be provided. +- `--wavenet` is the path of the wavenet checkpoint to load. If you do not specify `--checkpoint`, then this must be provided.
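The resume behaviour described above boils down to one decision: an explicit `--checkpoint` wins, otherwise the newest checkpoint in `checkpoints/` is used, and an empty directory (iteration 0) means training starts from scratch, in which case the clarinet trainer requires `--wavenet`. Below is a minimal sketch of that decision, mirroring the logic added to `train.py`; it is not the `parakeet.utils.io` implementation, and the `step-<iteration>` file naming and the helper name are assumptions made for illustration.

```python
import os


def resolve_start_iteration(checkpoint_dir, checkpoint_path=None):
    # An explicit --checkpoint always wins; the iteration is parsed from the
    # file name, e.g. ".../step-500000" -> 500000 (naming scheme assumed).
    if checkpoint_path is not None:
        return int(os.path.basename(checkpoint_path).split("-")[-1])
    # Otherwise fall back to the newest checkpoint in the directory, if any.
    if not os.path.isdir(checkpoint_dir):
        return 0
    steps = [
        int(name.split("-")[-1]) for name in os.listdir(checkpoint_dir)
        if name.split("-")[-1].isdigit()
    ]
    return max(steps) if steps else 0  # 0 means "train from scratch"
```

Training then continues at `global_step = iteration + 1`, so a run resumed from a step-500000 checkpoint picks up at step 500001.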
diff --git a/examples/clarinet/synthesis.py b/examples/clarinet/synthesis.py index db12035..ce16fc1 100644 --- a/examples/clarinet/synthesis.py +++ b/examples/clarinet/synthesis.py @@ -31,7 +31,7 @@ from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo from parakeet.utils.layer_tools import summary, freeze -from utils import valid_model, eval_model, save_checkpoint, load_checkpoint, load_model +from utils import valid_model, eval_model, load_model sys.path.append("../wavenet") from data import LJSpeechMetaData, Transform, DataCollector diff --git a/examples/clarinet/train.py b/examples/clarinet/train.py index c6039b3..dcfff9b 100644 --- a/examples/clarinet/train.py +++ b/examples/clarinet/train.py @@ -30,14 +30,15 @@ from parakeet.models.wavenet import WaveNet, UpsampleNet from parakeet.models.clarinet import STFT, Clarinet, ParallelWaveNet from parakeet.data import TransformDataset, SliceDataset, RandomSampler, SequentialSampler, DataCargo from parakeet.utils.layer_tools import summary, freeze +from parakeet.utils import io -from utils import make_output_tree, valid_model, save_checkpoint, load_checkpoint, load_wavenet +from utils import make_output_tree, valid_model, load_wavenet sys.path.append("../wavenet") from data import LJSpeechMetaData, Transform, DataCollector if __name__ == "__main__": parser = argparse.ArgumentParser( - description="train a clarinet model with LJspeech and a trained wavenet model." + description="train a ClariNet model with LJspeech and a trained WaveNet model." ) parser.add_argument("--config", type=str, help="path of the config file.") parser.add_argument( @@ -48,13 +49,18 @@ if __name__ == "__main__": default="experiment", help="path to save student.") parser.add_argument("--data", type=str, help="path of LJspeech dataset.") - parser.add_argument("--resume", type=str, help="checkpoint to load from.") + parser.add_argument( + "--checkpoint", type=str, help="checkpoint to load from.") parser.add_argument( "--wavenet", type=str, help="wavenet checkpoint to use.") args = parser.parse_args() with open(args.config, 'rt') as f: config = ruamel.yaml.safe_load(f) + print("Command Line args: ") + for k, v in vars(args).items(): + print("{}: {}".format(k, v)) + ljspeech_meta = LJSpeechMetaData(args.data) data_config = config["data"] @@ -154,12 +160,38 @@ if __name__ == "__main__": clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm( gradiant_max_norm) - assert args.wavenet or args.resume, "you should load from a trained wavenet or resume training; training without a trained wavenet is not recommended." - if args.wavenet: + # train + max_iterations = train_config["max_iterations"] + checkpoint_interval = train_config["checkpoint_interval"] + eval_interval = train_config["eval_interval"] + checkpoint_dir = os.path.join(args.output, "checkpoints") + state_dir = os.path.join(args.output, "states") + log_dir = os.path.join(args.output, "log") + writer = SummaryWriter(log_dir) + + # load wavenet/checkpoint, determine iterations done + if args.checkpoint is not None: + iteration = int(os.path.basename(args.checkpoint).split('-')[-1]) + else: + iteration = io.load_latest_checkpoint(checkpoint_dir) + + if iteration == 0 and args.wavenet is None: + raise Exception( + "you should load from a trained wavenet or resume training; training without a trained wavenet is not recommended." 
+ ) + + if args.wavenet is not None and iteration > 0: + if args.checkpoint is None: + print("Resume training, --wavenet ignored") + else: + print("--checkpoint provided, --wavenet ignored") + + if args.wavenet is not None and iteration == 0: load_wavenet(model, args.wavenet) - if args.resume: - load_checkpoint(model, optim, args.resume) + # it may overwrite the wavenet loaded + io.load_parameters( + checkpoint_dir, 0, model, optim, file_path=args.checkpoint) # loader train_loader = fluid.io.DataLoader.from_generator( @@ -170,52 +202,43 @@ if __name__ == "__main__": capacity=10, return_list=True) valid_loader.set_batch_generator(valid_cargo, place) - # train - max_iterations = train_config["max_iterations"] - checkpoint_interval = train_config["checkpoint_interval"] - eval_interval = train_config["eval_interval"] - checkpoint_dir = os.path.join(args.output, "checkpoints") - state_dir = os.path.join(args.output, "states") - log_dir = os.path.join(args.output, "log") - writer = SummaryWriter(log_dir) - # training loop - global_step = 1 - global_epoch = 1 + global_step = iteration + 1 + iterator = iter(tqdm(train_loader)) while global_step < max_iterations: - epoch_loss = 0. - for j, batch in tqdm(enumerate(train_loader), desc="[train]"): - audios, mels, audio_starts = batch - model.train() - loss_dict = model( - audios, mels, audio_starts, clip_kl=global_step > 500) + try: + batch = next(iterator) + except StopIteration as e: + iterator = iter(tqdm(train_loader)) + batch = next(iterator) - writer.add_scalar("learning_rate", - optim._learning_rate.step().numpy()[0], - global_step) - for k, v in loss_dict.items(): - writer.add_scalar("loss/{}".format(k), - v.numpy()[0], global_step) + audios, mels, audio_starts = batch + model.train() + loss_dict = model( + audios, mels, audio_starts, clip_kl=global_step > 500) - l = loss_dict["loss"] - step_loss = l.numpy()[0] - print("[train] loss: {:<8.6f}".format(step_loss)) - epoch_loss += step_loss + writer.add_scalar("learning_rate", + optim._learning_rate.step().numpy()[0], + global_step) + for k, v in loss_dict.items(): + writer.add_scalar("loss/{}".format(k), + v.numpy()[0], global_step) - l.backward() - optim.minimize(l, grad_clip=clipper) - optim.clear_gradients() + l = loss_dict["loss"] + step_loss = l.numpy()[0] + print("[train] loss: {:<8.6f}".format(step_loss)) - if global_step % eval_interval == 0: - # evaluate on valid dataset - valid_model(model, valid_loader, state_dir, global_step, - sample_rate) - if global_step % checkpoint_interval == 0: - save_checkpoint(model, optim, checkpoint_dir, global_step) + l.backward() + optim.minimize(l, grad_clip=clipper) + optim.clear_gradients() - global_step += 1 + if global_step % eval_interval == 0: + # evaluate on valid dataset + valid_model(model, valid_loader, state_dir, global_step, + sample_rate) + if global_step % checkpoint_interval == 0: + io.save_latest_parameters(checkpoint_dir, global_step, model, + optim) + io.save_latest_checkpoint(checkpoint_dir, global_step) - # epoch loss - average_loss = epoch_loss / j - writer.add_scalar("average_loss", average_loss, global_epoch) - global_epoch += 1 + global_step += 1 diff --git a/examples/deepvoice3/README.md b/examples/deepvoice3/README.md index fa7a5e4..7c2ad77 100644 --- a/examples/deepvoice3/README.md +++ b/examples/deepvoice3/README.md @@ -35,26 +35,23 @@ The model consists of an encoder, a decoder and a converter (and a speaker embed Train the model using train.py, follow the usage displayed by `python train.py --help`. 
```text -usage: train.py [-h] [-c CONFIG] [-s DATA] [-r RESUME] [-o OUTPUT] [-g DEVICE] +usage: train.py [-h] [-c CONFIG] [-s DATA] [--checkpoint CHECKPOINT] + [-o OUTPUT] [-g DEVICE] Train a Deep Voice 3 model with LJSpeech dataset. optional arguments: - -h, --help show this help message and exit - -c CONFIG, --config CONFIG - experimrnt config - -s DATA, --data DATA The path of the LJSpeech dataset. - -r RESUME, --resume RESUME - checkpoint to load - -o OUTPUT, --output OUTPUT - The directory to save result. - -g DEVICE, --device DEVICE - device to use + -h, --help show this help message and exit + -c CONFIG, --config CONFIG experiment config + -s DATA, --data DATA The path of the LJSpeech dataset. + --checkpoint CHECKPOINT checkpoint to load + -o OUTPUT, --output OUTPUT The directory to save result. + -g DEVICE, --device DEVICE device to use ``` - `--config` is the configuration file to use. The provided `ljspeech.yaml` can be used directly. And you can change some values in the configuration file and train the model with a different config. - `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt). -- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before trainig. +- `--checkpoint` is the path of the checkpoint. If it is provided, the model will load the checkpoint before training. - `--output` is the directory to save results, all results are saved in this directory. The structure of the output directory is shown below. ```text @@ -67,6 +64,8 @@ optional arguments: └── waveform # waveform (.wav files) ``` +If `checkpoints` is not empty and `--checkpoint` is not specified, training will automatically resume from the latest checkpoint in it. + - `--device` is the device (gpu id) to use for training. `-1` means CPU.
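In this patch the actual loading of `--checkpoint` (or of the latest checkpoint) is delegated to `parakeet.utils.io.load_parameters`. As a rough picture of what such a helper has to do — an illustrative stand-in, not the real implementation, with the `step-*` file naming assumed — restoring a dygraph checkpoint for both the model and the optimizer looks like this:

```python
import os

import paddle.fluid.dygraph as dg


def load_parameters_sketch(checkpoint_dir, model, optim=None, file_path=None):
    # Prefer an explicitly given checkpoint; otherwise fall back to the
    # newest "step-*" checkpoint in checkpoint_dir (naming scheme assumed).
    if file_path is None:
        candidates = [
            name[:-len(".pdparams")] for name in os.listdir(checkpoint_dir)
            if name.startswith("step-") and name.endswith(".pdparams")
        ]
        if not candidates:
            return  # nothing to restore, train from scratch
        latest = max(candidates, key=lambda name: int(name.split("-")[-1]))
        file_path = os.path.join(checkpoint_dir, latest)
    # dg.load_dygraph returns (parameter dict, optimizer dict or None).
    model_dict, optim_dict = dg.load_dygraph(file_path)
    model.set_dict(model_dict)
    if optim is not None and optim_dict:
        optim.set_dict(optim_dict)
```

This is essentially what the removed per-script loading code did with `dg.load_dygraph`, consolidated behind one shared helper.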
Example script: diff --git a/examples/deepvoice3/configs/ljspeech.yaml b/examples/deepvoice3/configs/ljspeech.yaml index 8aa6b5a..b270719 100644 --- a/examples/deepvoice3/configs/ljspeech.yaml +++ b/examples/deepvoice3/configs/ljspeech.yaml @@ -83,7 +83,7 @@ lr_scheduler: train: batch_size: 16 - epochs: 2000 + max_iteration: 2000000 snap_interval: 1000 eval_interval: 10000 diff --git a/examples/deepvoice3/synthesis.py b/examples/deepvoice3/synthesis.py index 0631dae..d3cd9f0 100644 --- a/examples/deepvoice3/synthesis.py +++ b/examples/deepvoice3/synthesis.py @@ -25,8 +25,9 @@ import paddle.fluid.dygraph as dg from tensorboardX import SummaryWriter from parakeet.g2p import en -from parakeet.utils.layer_tools import summary from parakeet.modules.weight_norm import WeightNormWrapper +from parakeet.utils.layer_tools import summary +from parakeet.utils.io import load_parameters from utils import make_model, eval_model, plot_alignment @@ -44,6 +45,10 @@ if __name__ == "__main__": with open(args.config, 'rt') as f: config = ruamel.yaml.safe_load(f) + print("Command Line Args: ") + for k, v in vars(args).items(): + print("{}: {}".format(k, v)) + if args.device == -1: place = fluid.CPUPlace() else: diff --git a/examples/deepvoice3/train.py b/examples/deepvoice3/train.py index 11f8407..6e0a9ba 100644 --- a/examples/deepvoice3/train.py +++ b/examples/deepvoice3/train.py @@ -17,6 +17,8 @@ import os import argparse import ruamel.yaml import numpy as np +import matplotlib +matplotlib.use("agg") from matplotlib import cm import matplotlib.pyplot as plt import tqdm @@ -35,13 +37,14 @@ from parakeet.data import DataCargo, PartialyRandomizedSimilarTimeLengthSampler, from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, ConvSpec from parakeet.models.deepvoice3.loss import TTSLoss from parakeet.utils.layer_tools import summary +from parakeet.utils import io from data import LJSpeechMetaData, DataCollector, Transform from utils import make_model, eval_model, save_state, make_output_tree, plot_alignment if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Train a deepvoice 3 model with LJSpeech dataset.") + description="Train a Deep Voice 3 model with LJSpeech dataset.") parser.add_argument("-c", "--config", type=str, help="experimrnt config") parser.add_argument( "-s", @@ -49,7 +52,7 @@ if __name__ == "__main__": type=str, default="/workspace/datasets/LJSpeech-1.1/", help="The path of the LJSpeech dataset.") - parser.add_argument("-r", "--resume", type=str, help="checkpoint to load") + parser.add_argument("--checkpoint", type=str, help="checkpoint to load") parser.add_argument( "-o", "--output", @@ -62,6 +65,10 @@ if __name__ == "__main__": with open(args.config, 'rt') as f: config = ruamel.yaml.safe_load(f) + print("Command Line Args: ") + for k, v in vars(args).items(): + print("{}: {}".format(k, v)) + # =========================dataset========================= # construct meta data data_root = args.data @@ -151,6 +158,7 @@ if __name__ == "__main__": query_position_rate, key_position_rate, window_backward, window_ahead, key_projection, value_projection, downsample_factor, linear_dim, use_decoder_states, converter_channels, dropout) + summary(dv3) # =========================loss========================= loss_config = config["loss"] @@ -195,7 +203,6 @@ if __name__ == "__main__": n_iter = synthesis_config["n_iter"] # =========================link(dataloader, paddle)========================= - # CAUTION: it does not return a DataLoader loader = 
fluid.io.DataLoader.from_generator( capacity=10, return_list=True) loader.set_batch_generator(ljspeech_loader, places=place) @@ -208,122 +215,117 @@ if __name__ == "__main__": make_output_tree(output_dir) writer = SummaryWriter(logdir=log_dir) - # load model parameters - resume_path = args.resume - if resume_path is not None: - state, _ = dg.load_dygraph(args.resume) - dv3.set_dict(state) + # load parameters and optimizer, and update the number of iterations done so far + io.load_parameters(ckpt_dir, 0, dv3, optim, file_path=args.checkpoint) + if args.checkpoint is not None: + iteration = int(os.path.basename(args.checkpoint).split("-")[-1]) + else: + iteration = io.load_latest_checkpoint(ckpt_dir) # =========================train========================= - epoch = train_config["epochs"] + max_iter = train_config["max_iteration"] snap_interval = train_config["snap_interval"] save_interval = train_config["save_interval"] eval_interval = train_config["eval_interval"] - global_step = 1 + global_step = iteration + 1 + iterator = iter(tqdm.tqdm(loader)) + while global_step <= max_iter: + try: + batch = next(iterator) + except StopIteration as e: + iterator = iter(tqdm.tqdm(loader)) + batch = next(iterator) - for j in range(1, 1 + epoch): - epoch_loss = 0. - for i, batch in tqdm.tqdm(enumerate(loader, 1)): - dv3.train() # CAUTION: don't forget to switch to train - (text_sequences, text_lengths, text_positions, mel_specs, - lin_specs, frames, decoder_positions, done_flags) = batch - downsampled_mel_specs = F.strided_slice( - mel_specs, - axes=[1], - starts=[0], - ends=[mel_specs.shape[1]], - strides=[downsample_factor]) - mel_outputs, linear_outputs, alignments, done = dv3( - text_sequences, text_positions, text_lengths, None, - downsampled_mel_specs, decoder_positions) + dv3.train() + (text_sequences, text_lengths, text_positions, mel_specs, + lin_specs, frames, decoder_positions, done_flags) = batch + downsampled_mel_specs = F.strided_slice( + mel_specs, + axes=[1], + starts=[0], + ends=[mel_specs.shape[1]], + strides=[downsample_factor]) + mel_outputs, linear_outputs, alignments, done = dv3( + text_sequences, text_positions, text_lengths, None, + downsampled_mel_specs, decoder_positions) - losses = criterion(mel_outputs, linear_outputs, done, - alignments, downsampled_mel_specs, - lin_specs, done_flags, text_lengths, frames) - l = losses["loss"] - l.backward() - # record learning rate before updating - writer.add_scalar("learning_rate", - optim._learning_rate.step().numpy(), - global_step) - optim.minimize(l, grad_clip=gradient_clipper) - optim.clear_gradients() + losses = criterion(mel_outputs, linear_outputs, done, alignments, + downsampled_mel_specs, lin_specs, done_flags, + text_lengths, frames) + l = losses["loss"] + l.backward() + # record learning rate before updating + writer.add_scalar("learning_rate", + optim._learning_rate.step().numpy(), global_step) + optim.minimize(l, grad_clip=gradient_clipper) + optim.clear_gradients() - # ==================all kinds of tedious things================= - # record step loss into tensorboard - epoch_loss += l.numpy()[0] - step_loss = {k: v.numpy()[0] for k, v in losses.items()} - for k, v in step_loss.items(): - writer.add_scalar(k, v, global_step) + # ==================all kinds of tedious things================= + # record step loss into tensorboard + step_loss = {k: v.numpy()[0] for k, v in losses.items()} + tqdm.tqdm.write("global_step: {}\tloss: {}".format( + global_step, step_loss["loss"])) + for k, v in step_loss.items(): + writer.add_scalar(k, v, 
global_step) - # TODO: clean code - # train state saving, the first sentence in the batch - if global_step % snap_interval == 0: - save_state( - state_dir, - writer, + # train state saving, the first sentence in the batch + if global_step % snap_interval == 0: + save_state( + state_dir, + writer, + global_step, + mel_input=downsampled_mel_specs, + mel_output=mel_outputs, + lin_input=lin_specs, + lin_output=linear_outputs, + alignments=alignments, + win_length=win_length, + hop_length=hop_length, + min_level_db=min_level_db, + ref_level_db=ref_level_db, + power=power, + n_iter=n_iter, + preemphasis=preemphasis, + sample_rate=sample_rate) + + # evaluation + if global_step % eval_interval == 0: + sentences = [ + "Scientists at the CERN laboratory say they have discovered a new particle.", + "There's a way to measure the acute emotional intelligence that has never gone out of style.", + "President Trump met with other leaders at the Group of 20 conference.", + "Generative adversarial network or variational auto-encoder.", + "Please call Stella.", + "Some have accepted this as a miracle without any physical explanation.", + ] + for idx, sent in enumerate(sentences): + wav, attn = eval_model( + dv3, sent, replace_pronounciation_prob, min_level_db, + ref_level_db, power, n_iter, win_length, hop_length, + preemphasis) + wav_path = os.path.join( + state_dir, "waveform", + "eval_sample_{:09d}.wav".format(global_step)) + sf.write(wav_path, wav, sample_rate) + writer.add_audio( + "eval_sample_{}".format(idx), + wav, global_step, - mel_input=downsampled_mel_specs, - mel_output=mel_outputs, - lin_input=lin_specs, - lin_output=linear_outputs, - alignments=alignments, - win_length=win_length, - hop_length=hop_length, - min_level_db=min_level_db, - ref_level_db=ref_level_db, - power=power, - n_iter=n_iter, - preemphasis=preemphasis, sample_rate=sample_rate) + attn_path = os.path.join( + state_dir, "alignments", + "eval_sample_attn_{:09d}.png".format(global_step)) + plot_alignment(attn, attn_path) + writer.add_image( + "eval_sample_attn{}".format(idx), + cm.viridis(attn), + global_step, + dataformats="HWC") - # evaluation - if global_step % eval_interval == 0: - sentences = [ - "Scientists at the CERN laboratory say they have discovered a new particle.", - "There's a way to measure the acute emotional intelligence that has never gone out of style.", - "President Trump met with other leaders at the Group of 20 conference.", - "Generative adversarial network or variational auto-encoder.", - "Please call Stella.", - "Some have accepted this as a miracle without any physical explanation.", - ] - for idx, sent in enumerate(sentences): - wav, attn = eval_model( - dv3, sent, replace_pronounciation_prob, - min_level_db, ref_level_db, power, n_iter, - win_length, hop_length, preemphasis) - wav_path = os.path.join( - state_dir, "waveform", - "eval_sample_{:09d}.wav".format(global_step)) - sf.write(wav_path, wav, sample_rate) - writer.add_audio( - "eval_sample_{}".format(idx), - wav, - global_step, - sample_rate=sample_rate) - attn_path = os.path.join( - state_dir, "alignments", - "eval_sample_attn_{:09d}.png".format(global_step)) - plot_alignment(attn, attn_path) - writer.add_image( - "eval_sample_attn{}".format(idx), - cm.viridis(attn), - global_step, - dataformats="HWC") + # save checkpoint + if global_step % save_interval == 0: + io.save_latest_parameters(ckpt_dir, global_step, dv3, optim) + io.save_latest_checkpoint(ckpt_dir, global_step) - # save checkpoint - if global_step % save_interval == 0: - 
dg.save_dygraph( - dv3.state_dict(), - os.path.join(ckpt_dir, - "model_step_{}".format(global_step))) - dg.save_dygraph( - optim.state_dict(), - os.path.join(ckpt_dir, - "model_step_{}".format(global_step))) - - global_step += 1 - # epoch report - writer.add_scalar("epoch_average_loss", epoch_loss / i, j) - epoch_loss = 0. + global_step += 1 diff --git a/examples/wavenet/README.md b/examples/wavenet/README.md index 5114182..af34457 100644 --- a/examples/wavenet/README.md +++ b/examples/wavenet/README.md @@ -28,22 +28,22 @@ Train the model using train.py. For help on usage, try `python train.py --help`. ```text usage: train.py [-h] [--data DATA] [--config CONFIG] [--output OUTPUT] - [--device DEVICE] [--resume RESUME] + [--device DEVICE] [--checkpoint CHECKPOINT] Train a WaveNet model with LJSpeech. optional arguments: - -h, --help show this help message and exit - --data DATA path of the LJspeech dataset. - --config CONFIG path of the config file. - --output OUTPUT path to save results. - --device DEVICE device to use. - --resume RESUME checkpoint to resume from. + -h, --help show this help message and exit + --data DATA path of the LJspeech dataset. + --config CONFIG path of the config file. + --output OUTPUT path to save results. + --device DEVICE device to use. + --checkpoint CHECKPOINT checkpoint to resume from. ``` - `--config` is the configuration file to use. The provided configurations can be used directly. And you can change some values in the configuration file and train the model with a different config. - `--data` is the path of the LJSpeech dataset, the extracted folder from the downloaded archive (the folder which contains metadata.txt). -- `--resume` is the path of the checkpoint. If it is provided, the model would load the checkpoint before training. +- `--checkpoint` is the path of the checkpoint. If it is provided, the model will load the checkpoint before training. - `--output` is the directory to save results, all result are saved in this directory. The structure of the output directory is shown below. ```text @@ -51,6 +51,8 @@ optional arguments: └── log # tensorboard log ``` +If `checkpoints` is not empty and `--checkpoint` is not specified, training will automatically resume from the latest checkpoint in it. + - `--device` is the device (gpu id) to use for training. `-1` means CPU. 
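On the saving side, the old `save_checkpoint` helper (removed from `examples/wavenet/utils.py` at the end of this patch) wrote the model and optimizer state with `dg.save_dygraph`; the patch replaces it with `io.save_latest_parameters` plus `io.save_latest_checkpoint`, which additionally keeps track of the most recent step so that the automatic resume described above can find it. A rough, illustrative equivalent is sketched below; the `step-<iteration>` naming and the plain-text record file are assumptions, not the actual `parakeet.utils.io` format.

```python
import os

import paddle.fluid.dygraph as dg


def save_checkpoint_sketch(checkpoint_dir, global_step, model, optim):
    # Write parameters and optimizer state, e.g. step-100000.pdparams/.pdopt.
    path = os.path.join(checkpoint_dir, "step-{}".format(global_step))
    dg.save_dygraph(model.state_dict(), path)
    dg.save_dygraph(optim.state_dict(), path)
    # Record the latest step so a later run can resume without --checkpoint
    # (the record file name and format here are assumed for illustration).
    with open(os.path.join(checkpoint_dir, "checkpoint"), "wt") as f:
        f.write(str(global_step))
```

Whatever the exact on-disk format, the important point is that parameters and the "latest step" record are written as a pair at every `checkpoint_interval`.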
Example script: diff --git a/examples/wavenet/synthesis.py b/examples/wavenet/synthesis.py index f3d4c93..5edb1ed 100644 --- a/examples/wavenet/synthesis.py +++ b/examples/wavenet/synthesis.py @@ -27,7 +27,7 @@ from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet from parakeet.utils.layer_tools import summary from data import LJSpeechMetaData, Transform, DataCollector -from utils import make_output_tree, valid_model, eval_model, save_checkpoint +from utils import make_output_tree, valid_model, eval_model if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -87,7 +87,8 @@ if __name__ == "__main__": batch_size=1, sampler=SequentialSampler(ljspeech_valid)) - make_output_tree(args.output) + if not os.path.exists(args.output): + os.makedirs(args.output) if args.device == -1: place = fluid.CPUPlace() else: diff --git a/examples/wavenet/train.py b/examples/wavenet/train.py index df24b10..3fdfaeb 100644 --- a/examples/wavenet/train.py +++ b/examples/wavenet/train.py @@ -16,7 +16,7 @@ from __future__ import division import os import ruamel.yaml import argparse -from tqdm import tqdm +import tqdm from tensorboardX import SummaryWriter from paddle import fluid import paddle.fluid.dygraph as dg @@ -24,13 +24,14 @@ import paddle.fluid.dygraph as dg from parakeet.data import SliceDataset, TransformDataset, DataCargo, SequentialSampler, RandomSampler from parakeet.models.wavenet import UpsampleNet, WaveNet, ConditionalWavenet from parakeet.utils.layer_tools import summary +from parakeet.utils import io from data import LJSpeechMetaData, Transform, DataCollector -from utils import make_output_tree, valid_model, save_checkpoint +from utils import make_output_tree, valid_model if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Train a wavenet model with LJSpeech.") + description="Train a WaveNet model with LJSpeech.") parser.add_argument( "--data", type=str, help="path of the LJspeech dataset.") parser.add_argument("--config", type=str, help="path of the config file.") @@ -42,12 +43,16 @@ if __name__ == "__main__": parser.add_argument( "--device", type=int, default=-1, help="device to use.") parser.add_argument( - "--resume", type=str, help="checkpoint to resume from.") + "--checkpoint", type=str, help="checkpoint to resume from.") args = parser.parse_args() with open(args.config, 'rt') as f: config = ruamel.yaml.safe_load(f) + print("Command Line Args: ") + for k, v in vars(args).items(): + print("{}: {}".format(k, v)) + ljspeech_meta = LJSpeechMetaData(args.data) data_config = config["data"] @@ -126,14 +131,6 @@ if __name__ == "__main__": clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm( gradiant_max_norm) - if args.resume: - model_dict, optim_dict = dg.load_dygraph(args.resume) - print("Loading from {}.pdparams".format(args.resume)) - model.set_dict(model_dict) - if optim_dict: - optim.set_dict(optim_dict) - print("Loading from {}.pdopt".format(args.resume)) - train_loader = fluid.io.DataLoader.from_generator( capacity=10, return_list=True) train_loader.set_batch_generator(train_cargo, place) @@ -150,33 +147,48 @@ if __name__ == "__main__": log_dir = os.path.join(args.output, "log") writer = SummaryWriter(log_dir) - global_step = 1 + # load parameters and optimizer, and update the number of iterations done so far + io.load_parameters( + checkpoint_dir, 0, model, optim, file_path=args.checkpoint) + if args.checkpoint is not None: + iteration = int(os.path.basename(args.checkpoint).split("-")[-1]) + else: + iteration = io.load_latest_checkpoint(checkpoint_dir) + 
+ global_step = iteration + 1 + iterator = iter(tqdm.tqdm(train_loader)) while global_step <= max_iterations: - epoch_loss = 0. - for i, batch in tqdm(enumerate(train_loader)): - audio_clips, mel_specs, audio_starts = batch + print(global_step) + try: + batch = next(iterator) + except StopIteration as e: + iterator = iter(tqdm.tqdm(train_loader)) + batch = next(iterator) - model.train() - y_var = model(audio_clips, mel_specs, audio_starts) - loss_var = model.loss(y_var, audio_clips) - loss_var.backward() - loss_np = loss_var.numpy() + audio_clips, mel_specs, audio_starts = batch - epoch_loss += loss_np[0] + model.train() + y_var = model(audio_clips, mel_specs, audio_starts) + loss_var = model.loss(y_var, audio_clips) + loss_var.backward() + loss_np = loss_var.numpy() - writer.add_scalar("loss", loss_np[0], global_step) - writer.add_scalar("learning_rate", - optim._learning_rate.step().numpy()[0], - global_step) - optim.minimize(loss_var, grad_clip=clipper) - optim.clear_gradients() - print("loss: {:<8.6f}".format(loss_np[0])) + writer.add_scalar("loss", loss_np[0], global_step) + writer.add_scalar("learning_rate", + optim._learning_rate.step().numpy()[0], + global_step) + optim.minimize(loss_var, grad_clip=clipper) + optim.clear_gradients() + print("global_step: {}\tloss: {:<8.6f}".format(global_step, + loss_np[0])) - if global_step % snap_interval == 0: - valid_model(model, valid_loader, writer, global_step, - sample_rate) + if global_step % snap_interval == 0: + valid_model(model, valid_loader, writer, global_step, + sample_rate) - if global_step % checkpoint_interval == 0: - save_checkpoint(model, optim, checkpoint_dir, global_step) + if global_step % checkpoint_interval == 0: + io.save_latest_parameters(checkpoint_dir, global_step, model, + optim) + io.save_latest_checkpoint(checkpoint_dir, global_step) - global_step += 1 + global_step += 1 diff --git a/examples/wavenet/utils.py b/examples/wavenet/utils.py index bae186f..cb71acd 100644 --- a/examples/wavenet/utils.py +++ b/examples/wavenet/utils.py @@ -59,10 +59,3 @@ def eval_model(model, valid_loader, output_dir, sample_rate): wav_np = wav_var.numpy()[0] sf.write(path, wav_np, samplerate=sample_rate) print("generated {}".format(path)) - - -def save_checkpoint(model, optim, checkpoint_dir, global_step): - checkpoint_path = os.path.join(checkpoint_dir, - "step_{:09d}".format(global_step)) - dg.save_dygraph(model.state_dict(), checkpoint_path) - dg.save_dygraph(optim.state_dict(), checkpoint_path)
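All three training scripts move from epoch-based loops to a single step-based loop that re-creates the data iterator whenever it is exhausted (the `try`/`except StopIteration` blocks above). If that pattern ever needs to be factored out, a small generator would do the same job; this is a possible refactor shown for illustration, not something the patch itself adds.

```python
import tqdm


def endless_batches(data_loader):
    """Yield batches forever, restarting the loader when it is exhausted.

    Equivalent to the try/except StopIteration pattern used in the
    step-based training loops of this patch.
    """
    while True:
        for batch in tqdm.tqdm(data_loader):
            yield batch


# Usage sketch:
#     batches = endless_batches(train_loader)
#     while global_step <= max_iterations:
#         batch = next(batches)
#         ...
#         global_step += 1
```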