Parakeet/examples/transformer_tts/train_transformer.py

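"""Train the TransformerTTS acoustic model on LJSpeech.

Reads hyperparameters from a YAML config, optimizes with Adam under a Noam
learning-rate schedule, logs losses and attention maps to TensorBoard, and
periodically saves dygraph checkpoints. Command-line options are defined in
parse.add_config_options_to_parser; a typical single-GPU run (flag names
inferred from the attributes used below) looks roughly like:

    python train_transformer.py --config_path=<config.yaml> --use_gpu=1
"""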
import argparse
import os
from collections import OrderedDict
from pathlib import Path
from pprint import pprint

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers
from matplotlib import cm
from ruamel import yaml
from tensorboardX import SummaryWriter
from tqdm import tqdm

from parakeet.models.dataloader.ljspeech import LJSpeechLoader
from parakeet.models.transformer_tts.transformerTTS import TransformerTTS
from parakeet.modules.utils import cross_entropy

from parse import add_config_options_to_parser
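
# Checkpoints saved from a DataParallel-wrapped model prefix every parameter
# name with '_layers.'; strip that prefix so the weights also load into a
# bare TransformerTTS instance.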
def load_checkpoint(step, model_path):
    model_dict, opti_dict = fluid.dygraph.load_dygraph(os.path.join(model_path, step))
    new_state_dict = OrderedDict()
    for param in model_dict:
        if param.startswith('_layers.'):
            new_state_dict[param[8:]] = model_dict[param]
        else:
            new_state_dict[param] = model_dict[param]
    return new_state_dict, opti_dict

def main(args):
    local_rank = dg.parallel.Env().local_rank if args.use_data_parallel else 0
    nranks = dg.parallel.Env().nranks if args.use_data_parallel else 1

    with open(args.config_path) as f:
        cfg = yaml.load(f, Loader=yaml.Loader)

    global_step = 0
    place = (fluid.CUDAPlace(dg.parallel.Env().dev_id)
             if args.use_data_parallel else fluid.CUDAPlace(0)
             if args.use_gpu else fluid.CPUPlace())

    if not os.path.exists(args.log_dir):
        os.mkdir(args.log_dir)
    path = os.path.join(args.log_dir, 'transformer')

    # Only rank 0 writes TensorBoard summaries.
    writer = SummaryWriter(path) if local_rank == 0 else None
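
    # Everything below runs in dygraph (imperative) mode on the place chosen
    # above; the model, optimizer and training loop all live under the guard.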
    with dg.guard(place):
        model = TransformerTTS(cfg)
        model.train()

        # Noam schedule: linear warmup for `warm_up_step` steps, then inverse
        # square-root decay; args.lr is the peak learning rate reached at the
        # end of warmup.
        optimizer = fluid.optimizer.AdamOptimizer(
            learning_rate=dg.NoamDecay(1 / (cfg['warm_up_step'] * (args.lr**2)),
                                       cfg['warm_up_step']),
            parameter_list=model.parameters())

        reader = LJSpeechLoader(cfg, args, nranks, local_rank, shuffle=True).reader()

        if args.checkpoint_path is not None:
            model_dict, opti_dict = load_checkpoint(
                str(args.transformer_step),
                os.path.join(args.checkpoint_path, "transformer"))
            model.set_dict(model_dict)
            optimizer.set_dict(opti_dict)
            global_step = args.transformer_step
            print("load checkpoint!!!")

        if args.use_data_parallel:
            strategy = dg.parallel.prepare_context()
            model = fluid.dygraph.parallel.DataParallel(model, strategy)
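
        # Training loop: each batch provides character ids, the target mel
        # spectrogram, the shifted decoder input, positional ids and text
        # lengths; the model returns mel predictions before and after the
        # postnet plus attention maps and stop-token logits.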
        for epoch in range(args.epochs):
            pbar = tqdm(reader)
            for i, data in enumerate(pbar):
                pbar.set_description('Processing at epoch %d' % epoch)
                character, mel, mel_input, pos_text, pos_mel, text_length, _ = data

                global_step += 1
                mel_pred, postnet_pred, attn_probs, stop_preds, attn_enc, attn_dec = model(
                    character, mel_input, pos_text, pos_mel)

                # Positions with pos_mel == 0 are padding; they serve as the
                # "stop" targets for the stop-token predictor.
                label = (pos_mel == 0).astype(np.float32)

                # L1 loss on the mel output both before and after the postnet.
                mel_loss = layers.mean(layers.abs(layers.elementwise_sub(mel_pred, mel)))
                post_mel_loss = layers.mean(layers.abs(layers.elementwise_sub(postnet_pred, mel)))
                loss = mel_loss + post_mel_loss
                # Note: training did not converge when the stop-token loss was enabled.
                if args.stop_token:
                    stop_loss = cross_entropy(stop_preds, label)
                    loss = loss + stop_loss
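
                # Losses, alphas and attention maps are only logged on rank 0,
                # where the SummaryWriter exists.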
                if local_rank == 0:
                    writer.add_scalars('training_loss', {
                        'mel_loss': mel_loss.numpy(),
                        'post_mel_loss': post_mel_loss.numpy()
                    }, global_step)

                    if args.stop_token:
                        writer.add_scalar('stop_loss', stop_loss.numpy(), global_step)

                    # Under DataParallel the TransformerTTS layers are wrapped
                    # in model._layers, so unwrap before reading the alphas.
                    tts_model = model._layers if args.use_data_parallel else model
                    writer.add_scalars('alphas', {
                        'encoder_alpha': tts_model.encoder.alpha.numpy(),
                        'decoder_alpha': tts_model.decoder.alpha.numpy(),
                    }, global_step)

                    writer.add_scalar('learning_rate',
                                      optimizer._learning_rate.step().numpy(),
                                      global_step)
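
                    # Every `image_step` steps, render a few decoder-encoder,
                    # encoder self- and decoder self-attention maps with the
                    # viridis colormap and write them as images.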
                    if global_step % args.image_step == 1:
                        for i, prob in enumerate(attn_probs):
                            for j in range(4):
                                x = np.uint8(cm.viridis(prob.numpy()[j * 16]) * 255)
                                writer.add_image('Attention_%d_0' % global_step, x, i * 4 + j, dataformats="HWC")

                        for i, prob in enumerate(attn_enc):
                            for j in range(4):
                                x = np.uint8(cm.viridis(prob.numpy()[j * 16]) * 255)
                                writer.add_image('Attention_enc_%d_0' % global_step, x, i * 4 + j, dataformats="HWC")

                        for i, prob in enumerate(attn_dec):
                            for j in range(4):
                                x = np.uint8(cm.viridis(prob.numpy()[j * 16]) * 255)
                                writer.add_image('Attention_dec_%d_0' % global_step, x, i * 4 + j, dataformats="HWC")
                if args.use_data_parallel:
                    # Scale the loss and all-reduce gradients across ranks.
                    loss = model.scale_loss(loss)
                    loss.backward()
                    model.apply_collective_grads()
                else:
                    loss.backward()
                optimizer.minimize(
                    loss,
                    grad_clip=fluid.dygraph_grad_clip.GradClipByGlobalNorm(cfg['grad_clip_thresh']))
                model.clear_gradients()
                # Save a checkpoint every `save_step` steps (rank 0 only); the
                # step number in the filename is what load_checkpoint expects.
                if local_rank == 0 and global_step % args.save_step == 0:
                    if not os.path.exists(args.save_path):
                        os.mkdir(args.save_path)
                    save_path = os.path.join(args.save_path, 'transformer/%d' % global_step)
                    dg.save_dygraph(model.state_dict(), save_path)
                    dg.save_dygraph(optimizer.state_dict(), save_path)

        if local_rank == 0:
            writer.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Train TransformerTTS model")
    add_config_options_to_parser(parser)
    args = parser.parse_args()
    # Print the whole config setting.
    pprint(args)
    main(args)