ParakeetEricRoss/parakeet/models/wavenet/train.py

import os
import random
import subprocess
import time
from pprint import pprint

import jsonargparse
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
from tensorboardX import SummaryWriter

import slurm
import utils
from wavenet import WaveNet
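
# Once fewer than this many seconds remain before the slurm job's death time,
# save a checkpoint and resubmit the job instead of continuing to train.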
MAXIMUM_SAVE_TIME = 10 * 60


def add_options_to_parser(parser):
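    """Add command-line options for model selection, data paths, device and
    parallelism, checkpoint loading, and slurm to the parser."""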
    parser.add_argument('--model', type=str, default='wavenet',
                        help="general name of the model")
    parser.add_argument('--name', type=str,
                        help="specific name of the training model")
    parser.add_argument('--root', type=str,
                        help="root path of the LJSpeech dataset")
    parser.add_argument('--parallel', type=bool, default=True,
                        help="option to use data parallel training")
    parser.add_argument('--use_gpu', type=bool, default=True,
                        help="option to use gpu training")
    parser.add_argument('--iteration', type=int, default=None,
                        help=("which iteration of checkpoint to load, "
                              "default to load the latest checkpoint"))
    parser.add_argument('--checkpoint', type=str, default=None,
                        help="path of the checkpoint to load")
    parser.add_argument('--slurm', type=bool, default=False,
                        help="whether you are using slurm to submit training jobs")


def train(config):
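    """Run the WaveNet training loop described by `config`."""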
    use_gpu = config.use_gpu
    # Data-parallel training is only enabled when training on GPU.
    parallel = config.parallel if use_gpu else False

    # Get the rank of the current training process.
    rank = dg.parallel.Env().local_rank if parallel else 0
    nranks = dg.parallel.Env().nranks if parallel else 1
    if rank == 0:
        # Print the whole config setting.
        pprint(jsonargparse.namespace_to_dict(config))

    # Make checkpoint directory.
    run_dir = os.path.join("runs", config.model, config.name)
    checkpoint_dir = os.path.join(run_dir, "checkpoint")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Create tensorboard logger.
    tb = SummaryWriter(os.path.join(run_dir, "logs")) \
        if rank == 0 else None
    # Configure the device.
    place = fluid.CUDAPlace(rank) if use_gpu else fluid.CPUPlace()
    with dg.guard(place):
        # Fix random seed.
        seed = config.seed
        random.seed(seed)
        np.random.seed(seed)
        fluid.default_startup_program().random_seed = seed
        fluid.default_main_program().random_seed = seed
        print("Random Seed: ", seed)

        # Build model.
        model = WaveNet(config, checkpoint_dir, parallel, rank, nranks, tb)
        model.build()
        # Obtain the current iteration.
        if config.checkpoint is None:
            if config.iteration is None:
                iteration = utils.load_latest_checkpoint(checkpoint_dir, rank)
            else:
                iteration = config.iteration
        else:
            # The iteration number is the trailing "-<number>" part of the
            # checkpoint file name.
            iteration = int(config.checkpoint.split('/')[-1].split('-')[-1])
        # Get restart command if using slurm.
        if config.slurm:
            resume_command, death_time = slurm.restart_command()
            if rank == 0:
                print("Restart command:", " ".join(resume_command))
        done = False
        while iteration < config.max_iterations:
            # Run one single training step.
            model.train_step(iteration)

            iteration += 1

            if iteration % config.test_every == 0:
                # Run validation step.
                model.valid_step(iteration)

            # Check whether the slurm time limit is about to be reached.
            if config.slurm:
                done = (death_time is not None and death_time - time.time() <
                        MAXIMUM_SAVE_TIME)

            if rank == 0 and done:
                print("Saving progress before exiting.")
                model.save(iteration)

                print("Running restart command:", " ".join(resume_command))
                # Submit restart command.
                subprocess.check_call(resume_command)
                break

            if rank == 0 and iteration % config.save_every == 0:
                # Save parameters.
                model.save(iteration)
    # Close TensorBoard.
    if rank == 0:
        tb.close()
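

# Example invocation (illustrative only; the yaml/config options added by
# utils.add_config_options_to_parser are defined elsewhere):
#   python train.py --name=my_run --root=/path/to/LJSpeech-1.1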
if __name__ == "__main__":
    # Create parser.
    parser = jsonargparse.ArgumentParser(description="Train WaveNet model",
                                         formatter_class='default_argparse')
    add_options_to_parser(parser)
    utils.add_config_options_to_parser(parser)

    # Parse arguments from both the command line and the yaml config file.
    # For conflicting updates to the same field,
    # the preceding update will be overwritten by the following one.
    config = parser.parse_args()
    train(config)