# Parakeet/examples/fastspeech/train.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import math
import os
import time
from collections import OrderedDict
from pathlib import Path
from pprint import pprint

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers
from matplotlib import cm
from ruamel import yaml
from tqdm import tqdm
from visualdl import LogWriter

from parakeet.models.fastspeech.fastspeech import FastSpeech
from parakeet.models.fastspeech.utils import get_alignment
from parakeet.utils import io

from data import LJSpeechLoader

def add_config_options_to_parser(parser):
    """Register the command-line options used by the training script.

    Args:
        parser (argparse.ArgumentParser): parser to extend in place.
    """
    parser.add_argument("--config", type=str, help="path of the config file")
    parser.add_argument("--use_gpu", type=int, default=0, help="device to use")
    parser.add_argument("--data", type=str, help="path of LJspeech dataset")
    parser.add_argument(
        "--alignments_path", type=str, help="path of alignments")

    # --checkpoint and --iteration are two alternative ways of picking a
    # restore point, so they must not be supplied together.
    exclusive = parser.add_mutually_exclusive_group()
    exclusive.add_argument(
        "--checkpoint", type=str, help="checkpoint to resume from")
    exclusive.add_argument(
        "--iteration",
        type=int,
        help="the iteration of the checkpoint to load from output directory")

    parser.add_argument(
        "--output",
        type=str,
        default="experiment",
        help="path to save experiment results")
def main(args):
    """Train a FastSpeech model on LJSpeech.

    Supports single-process and multi-GPU data-parallel training (one process
    per device). Checkpoints and VisualDL logs are written under
    ``args.output``.

    Args:
        args (argparse.Namespace): parsed command-line options; see
            ``add_config_options_to_parser`` for the available flags.
    """
    env = dg.parallel.Env()
    local_rank = env.local_rank
    nranks = env.nranks
    parallel = nranks > 1

    with open(args.config) as f:
        cfg = yaml.load(f, Loader=yaml.Loader)

    global_step = 0
    place = fluid.CUDAPlace(env.dev_id) if args.use_gpu else fluid.CPUPlace()
    fluid.enable_dygraph(place)

    # makedirs(exist_ok=True) is race-free when several ranks start at once
    # and also handles nested output paths (plain mkdir would raise).
    os.makedirs(args.output, exist_ok=True)

    # Only rank 0 writes VisualDL summaries and checkpoints.
    writer = LogWriter(os.path.join(args.output,
                                    'log')) if local_rank == 0 else None

    model = FastSpeech(cfg['network'], num_mels=cfg['audio']['num_mels'])
    model.train()
    optimizer = fluid.optimizer.AdamOptimizer(
        learning_rate=dg.NoamDecay(1 / (cfg['train']['warm_up_step'] *
                                        (cfg['train']['learning_rate']**2)),
                                   cfg['train']['warm_up_step']),
        parameter_list=model.parameters(),
        grad_clip=fluid.clip.GradientClipByGlobalNorm(cfg['train'][
            'grad_clip_thresh']))
    reader = LJSpeechLoader(
        cfg['audio'],
        place,
        args.data,
        args.alignments_path,
        cfg['train']['batch_size'],
        nranks,
        local_rank,
        shuffle=True).reader
    iterator = iter(tqdm(reader))

    # Restore model/optimizer state; returns the step to resume from
    # (0 when neither --checkpoint nor --iteration was given).
    global_step = io.load_parameters(
        model=model,
        optimizer=optimizer,
        checkpoint_dir=os.path.join(args.output, 'checkpoints'),
        iteration=args.iteration,
        checkpoint_path=args.checkpoint)
    print("Rank {}: checkpoint loaded.".format(local_rank))

    if parallel:
        strategy = dg.parallel.prepare_context()
        model = fluid.dygraph.parallel.DataParallel(model, strategy)

    while global_step <= cfg['train']['max_iteration']:
        try:
            batch = next(iterator)
        except StopIteration:
            # Dataset exhausted: restart iteration for the next epoch.
            iterator = iter(tqdm(reader))
            batch = next(iterator)

        (character, mel, pos_text, pos_mel, alignment) = batch
        global_step += 1

        # Forward pass; duration targets come from precomputed alignments.
        result = model(
            character, pos_text, mel_pos=pos_mel, length_target=alignment)
        mel_output, mel_output_postnet, duration_predictor_output, _, _ = result
        mel_loss = layers.mse_loss(mel_output, mel)
        mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
        # L1 loss between predicted durations and the alignment targets.
        duration_loss = layers.mean(
            layers.abs(
                layers.elementwise_sub(duration_predictor_output, alignment)))
        total_loss = mel_loss + mel_postnet_loss + duration_loss

        if local_rank == 0:
            writer.add_scalar('mel_loss', mel_loss.numpy(), global_step)
            writer.add_scalar('post_mel_loss',
                              mel_postnet_loss.numpy(), global_step)
            writer.add_scalar('duration_loss',
                              duration_loss.numpy(), global_step)
            writer.add_scalar('learning_rate',
                              optimizer._learning_rate.step().numpy(),
                              global_step)

        if parallel:
            # Scale the loss and all-reduce gradients across ranks before
            # the optimizer step.
            total_loss = model.scale_loss(total_loss)
            total_loss.backward()
            model.apply_collective_grads()
        else:
            total_loss.backward()
        optimizer.minimize(total_loss)
        model.clear_gradients()

        # Periodic checkpointing, rank 0 only.
        if local_rank == 0 and global_step % cfg['train'][
                'checkpoint_interval'] == 0:
            io.save_parameters(
                os.path.join(args.output, 'checkpoints'), global_step, model,
                optimizer)

    if local_rank == 0:
        writer.close()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train Fastspeech model")
    add_config_options_to_parser(parser)
    args = parser.parse_args()
    # Print the whole config setting.
    pprint(vars(args))
    main(args)