# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import argparse
import os
import time
import math
from pathlib import Path
from parse import add_config_options_to_parser
from pprint import pprint
from ruamel import yaml
from tqdm import tqdm
from matplotlib import cm
from collections import OrderedDict
from tensorboardX import SummaryWriter
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers
import paddle.fluid as fluid
from parakeet.models.transformer_tts.transformer_tts import TransformerTTS
from parakeet.models.fastspeech.fastspeech import FastSpeech
from parakeet.models.fastspeech.utils import get_alignment
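# The LJSpeech data loader is shared with the transformer_tts example,
# so make that directory importable before loading it.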
import sys
sys.path.append("../transformer_tts")
from data import LJSpeechLoader


def load_checkpoint(step, model_path):
    model_dict, opti_dict = fluid.dygraph.load_dygraph(
        os.path.join(model_path, step))
    new_state_dict = OrderedDict()
    for param in model_dict:
        # Strip the '_layers.' prefix that is prepended to parameter names
        # when a checkpoint is saved from a DataParallel-wrapped model.
        if param.startswith('_layers.'):
            new_state_dict[param[8:]] = model_dict[param]
        else:
            new_state_dict[param] = model_dict[param]
    return new_state_dict, opti_dict


def main(args):
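    """Train FastSpeech on LJSpeech, using a pre-trained TransformerTTS
    to extract per-phoneme durations from its attention maps."""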
    local_rank = dg.parallel.Env().local_rank if args.use_data_parallel else 0
    nranks = dg.parallel.Env().nranks if args.use_data_parallel else 1

    with open(args.config_path) as f:
        cfg = yaml.load(f, Loader=yaml.Loader)

    global_step = 0
    place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
             if args.use_data_parallel else fluid.CUDAPlace(0)
             if args.use_gpu else fluid.CPUPlace())

    if not os.path.exists(args.log_dir):
        os.mkdir(args.log_dir)
    path = os.path.join(args.log_dir, 'fastspeech')

    writer = SummaryWriter(path) if local_rank == 0 else None
    with dg.guard(place):
        with fluid.unique_name.guard():
            transformer_tts = TransformerTTS(cfg)
            model_dict, _ = load_checkpoint(
                str(args.transformer_step),
                os.path.join(args.transtts_path, "transformer"))
            transformer_tts.set_dict(model_dict)
            transformer_tts.eval()

        model = FastSpeech(cfg)
        model.train()
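        # Adam with a Noam-style warm-up: the learning rate rises for
        # `warm_up_step` steps (peaking near args.lr) and then decays
        # with the inverse square root of the step count.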
        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=dg.NoamDecay(1 / (
                cfg['warm_up_step'] * (args.lr**2)), cfg['warm_up_step']),
            parameter_list=model.parameters())
        reader = LJSpeechLoader(
            cfg, args, nranks, local_rank, shuffle=True).reader()

        if args.checkpoint_path is not None:
            model_dict, opti_dict = load_checkpoint(
                str(args.fastspeech_step),
                os.path.join(args.checkpoint_path, "fastspeech"))
            model.set_dict(model_dict)
            optimizer.set_dict(opti_dict)
            global_step = args.fastspeech_step
            print("load checkpoint!!!")

        if args.use_data_parallel:
            strategy = dg.parallel.prepare_context()
            model = fluid.dygraph.parallel.DataParallel(model, strategy)

        for epoch in range(args.epochs):
            pbar = tqdm(reader)
            for i, data in enumerate(pbar):
                pbar.set_description('Processing at epoch %d' % epoch)
                (character, mel, mel_input, pos_text, pos_mel, text_length,
                 mel_lens, enc_slf_mask, enc_query_mask, dec_slf_mask,
                 enc_dec_mask, dec_query_slf_mask, dec_query_mask) = data

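                # Run the frozen TransformerTTS teacher to obtain its
                # encoder-decoder attention, then convert the attention into
                # per-phoneme durations used as targets for FastSpeech.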
                _, _, attn_probs, _, _, _ = transformer_tts(
                    character,
                    mel_input,
                    pos_text,
                    pos_mel,
                    dec_slf_mask=dec_slf_mask,
                    enc_slf_mask=enc_slf_mask,
                    enc_query_mask=enc_query_mask,
                    enc_dec_mask=enc_dec_mask,
                    dec_query_slf_mask=dec_query_slf_mask,
                    dec_query_mask=dec_query_mask)
                alignment, max_attn = get_alignment(attn_probs, mel_lens,
                                                    cfg['transformer_head'])
                alignment = dg.to_variable(alignment).astype(np.float32)

                if local_rank == 0 and global_step % 5 == 1:
                    x = np.uint8(
                        cm.viridis(max_attn[8, :mel_lens.numpy()[8]]) * 255)
                    writer.add_image(
                        'Attention_%d_0' % global_step,
                        x,
                        0,
                        dataformats="HWC")

                global_step += 1

                # Forward
                result = model(
                    character,
                    pos_text,
                    mel_pos=pos_mel,
                    length_target=alignment,
                    enc_non_pad_mask=enc_query_mask,
                    enc_slf_attn_mask=enc_slf_mask,
                    dec_non_pad_mask=dec_query_slf_mask,
                    dec_slf_attn_mask=dec_slf_mask)
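                # Losses: MSE on the mel outputs before and after the postnet,
                # plus an L1 loss between predicted and extracted durations.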
                mel_output, mel_output_postnet, duration_predictor_output, _, _ = result
                mel_loss = layers.mse_loss(mel_output, mel)
                mel_postnet_loss = layers.mse_loss(mel_output_postnet, mel)
                duration_loss = layers.mean(
                    layers.abs(
                        layers.elementwise_sub(duration_predictor_output,
                                               alignment)))
                total_loss = mel_loss + mel_postnet_loss + duration_loss

                if local_rank == 0:
                    writer.add_scalar('mel_loss',
                                      mel_loss.numpy(), global_step)
                    writer.add_scalar('post_mel_loss',
                                      mel_postnet_loss.numpy(), global_step)
                    writer.add_scalar('duration_loss',
                                      duration_loss.numpy(), global_step)
                    writer.add_scalar('learning_rate',
                                      optimizer._learning_rate.step().numpy(),
                                      global_step)

                if args.use_data_parallel:
                    total_loss = model.scale_loss(total_loss)
                    total_loss.backward()
                    model.apply_collective_grads()
                else:
                    total_loss.backward()
                optimizer.minimize(
                    total_loss,
                    grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg[
                        'grad_clip_thresh']))
                model.clear_gradients()

                # save checkpoint
                if local_rank == 0 and global_step % args.save_step == 0:
                    if not os.path.exists(args.save_path):
                        os.mkdir(args.save_path)
                    save_path = os.path.join(args.save_path,
                                             'fastspeech/%d' % global_step)
                    dg.save_dygraph(model.state_dict(), save_path)
                    dg.save_dygraph(optimizer.state_dict(), save_path)

        if local_rank == 0:
            writer.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train Fastspeech model")
    add_config_options_to_parser(parser)
    args = parser.parse_args()
    # Print the whole config setting.
    pprint(args)
    main(args)