ParakeetRebeccaRosario/examples/deepvoice3/train.py

import os
import argparse
import ruamel.yaml
import numpy as np
from matplotlib import cm
import matplotlib.pyplot as plt
import tqdm
import librosa
from librosa import display
import soundfile as sf
from tensorboardX import SummaryWriter
from paddle import fluid
import paddle.fluid.layers as F
import paddle.fluid.dygraph as dg
from parakeet.g2p import en
from parakeet.data import FilterDataset, TransformDataset
from parakeet.data import DataCargo, PartialyRandomizedSimilarTimeLengthSampler, SequentialSampler
from parakeet.models.deepvoice3 import Encoder, Decoder, Converter, DeepVoice3, ConvSpec
from parakeet.models.deepvoice3.loss import TTSLoss
from parakeet.utils.layer_tools import summary
from data import LJSpeechMetaData, DataCollector, Transform
from utils import make_model, eval_model, save_state, make_output_tree, plot_alignment
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train a deepvoice 3 model with LJSpeech dataset.")
    parser.add_argument("-c", "--config", type=str, help="experiment config")
    parser.add_argument("-s",
                        "--data",
                        type=str,
                        default="/workspace/datasets/LJSpeech-1.1/",
                        help="The path of the LJSpeech dataset.")
    parser.add_argument("-r", "--resume", type=str, help="checkpoint to load")
    parser.add_argument("-o",
                        "--output",
                        type=str,
                        default="result",
                        help="The directory to save result.")
    parser.add_argument("-g",
                        "--device",
                        type=int,
                        default=-1,
                        help="device to use")
    args, _ = parser.parse_known_args()
    with open(args.config, 'rt') as f:
        config = ruamel.yaml.safe_load(f)

    # =========================dataset=========================
    # construct meta data
    data_root = args.data
    meta = LJSpeechMetaData(data_root)

    # filter out examples whose text is too short
    min_text_length = config["meta_data"]["min_text_length"]
    meta = FilterDataset(meta, lambda x: len(x[2]) >= min_text_length)

    # transform meta data into training examples
    transform_config = config["transform"]
    replace_pronunciation_prob = transform_config[
        "replace_pronunciation_prob"]
    sample_rate = transform_config["sample_rate"]
    preemphasis = transform_config["preemphasis"]
    n_fft = transform_config["n_fft"]
    win_length = transform_config["win_length"]
    hop_length = transform_config["hop_length"]
    fmin = transform_config["fmin"]
    fmax = transform_config["fmax"]
    n_mels = transform_config["n_mels"]
    min_level_db = transform_config["min_level_db"]
    ref_level_db = transform_config["ref_level_db"]
    max_norm = transform_config["max_norm"]
    clip_norm = transform_config["clip_norm"]
    transform = Transform(replace_pronunciation_prob, sample_rate,
                          preemphasis, n_fft, win_length, hop_length, fmin,
                          fmax, n_mels, min_level_db, ref_level_db, max_norm,
                          clip_norm)
    ljspeech = TransformDataset(meta, transform)

    # =========================dataiterator=========================
    # use meta data's text length as a sort key for the sampler
    train_config = config["train"]
    batch_size = train_config["batch_size"]
    text_lengths = [len(example[2]) for example in meta]
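    # note: PartialyRandomizedSimilarTimeLengthSampler (name as spelled in
    # parakeet.data) is intended to batch texts of roughly similar length
    # together, reducing padding while keeping some shuffling between buckets.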
    sampler = PartialyRandomizedSimilarTimeLengthSampler(
        text_lengths, batch_size)

    # some hyperparameters affect how we process data, so create a data collector!
    model_config = config["model"]
    downsample_factor = model_config["downsample_factor"]
    r = model_config["outputs_per_step"]
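    # note: the collector pads each field to the batch maximum and builds the
    # decoder positions / done flags; downsample_factor and r determine how
    # many frames one decoder step covers, so they affect how targets are padded.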
    collector = DataCollector(downsample_factor=downsample_factor, r=r)
    ljspeech_loader = DataCargo(ljspeech,
                                batch_fn=collector,
                                batch_size=batch_size,
                                sampler=sampler)

    # =========================model=========================
    if args.device == -1:
        place = fluid.CPUPlace()
    else:
        place = fluid.CUDAPlace(args.device)

    with dg.guard(place):
        # =========================model=========================
        n_speakers = model_config["n_speakers"]
        speaker_dim = model_config["speaker_embed_dim"]
        speaker_embed_std = model_config["speaker_embedding_weight_std"]
        n_vocab = en.n_vocab
        embed_dim = model_config["text_embed_dim"]
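        # note: the linear (STFT) spectrogram has n_fft // 2 + 1 frequency
        # bins, hence linear_dim below.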
        linear_dim = 1 + n_fft // 2
        use_decoder_states = model_config[
            "use_decoder_state_for_postnet_input"]
        filter_size = model_config["kernel_size"]
        encoder_channels = model_config["encoder_channels"]
        decoder_channels = model_config["decoder_channels"]
        converter_channels = model_config["converter_channels"]
        dropout = model_config["dropout"]
        padding_idx = model_config["padding_idx"]
        embedding_std = model_config["embedding_weight_std"]
        max_positions = model_config["max_positions"]
        freeze_embedding = model_config["freeze_embedding"]
        trainable_positional_encodings = model_config[
            "trainable_positional_encodings"]
        use_memory_mask = model_config["use_memory_mask"]
        query_position_rate = model_config["query_position_rate"]
        key_position_rate = model_config["key_position_rate"]
        window_backward = model_config["window_backward"]
        window_ahead = model_config["window_ahead"]
        key_projection = model_config["key_projection"]
        value_projection = model_config["value_projection"]
        dv3 = make_model(n_speakers, speaker_dim, speaker_embed_std, embed_dim,
                         padding_idx, embedding_std, max_positions, n_vocab,
                         freeze_embedding, filter_size, encoder_channels,
                         n_mels, decoder_channels, r,
                         trainable_positional_encodings, use_memory_mask,
                         query_position_rate, key_position_rate,
                         window_backward, window_ahead, key_projection,
                         value_projection, downsample_factor, linear_dim,
                         use_decoder_states, converter_channels, dropout)

        # =========================loss=========================
        loss_config = config["loss"]
        masked_weight = loss_config["masked_loss_weight"]
        priority_freq = loss_config["priority_freq"]  # Hz
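        # note: priority_bin is the number of linear-spectrogram bins that lie
        # below priority_freq (the Nyquist frequency is sample_rate / 2); the
        # loss puts extra weight on these low-frequency bins.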
        priority_bin = int(priority_freq / (0.5 * sample_rate) * linear_dim)
        priority_freq_weight = loss_config["priority_freq_weight"]
        binary_divergence_weight = loss_config["binary_divergence_weight"]
        guided_attention_sigma = loss_config["guided_attention_sigma"]
        criterion = TTSLoss(masked_weight=masked_weight,
                            priority_bin=priority_bin,
                            priority_weight=priority_freq_weight,
                            binary_divergence_weight=binary_divergence_weight,
                            guided_attention_sigma=guided_attention_sigma,
                            downsample_factor=downsample_factor,
                            r=r)

        # =========================lr_scheduler=========================
        lr_config = config["lr_scheduler"]
        warmup_steps = lr_config["warmup_steps"]
        peak_learning_rate = lr_config["peak_learning_rate"]
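        # note: NoamDecay computes lr = d_model^(-0.5) * min(step^(-0.5),
        # step * warmup_steps^(-1.5)).  Passing d_model = 1 / (warmup_steps *
        # peak_learning_rate**2) makes the schedule reach exactly
        # peak_learning_rate at step == warmup_steps and decay as step^(-0.5)
        # afterwards.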
        lr_scheduler = dg.NoamDecay(
            1 / (warmup_steps * (peak_learning_rate)**2), warmup_steps)

        # =========================optimizer=========================
        optim_config = config["optimizer"]
        beta1 = optim_config["beta1"]
        beta2 = optim_config["beta2"]
        epsilon = optim_config["epsilon"]
        optim = fluid.optimizer.Adam(lr_scheduler,
                                     beta1,
                                     beta2,
                                     epsilon=epsilon,
                                     parameter_list=dv3.parameters())
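        # note: with the pre-2.0 dygraph API used here, the clipper is not
        # attached to the optimizer at construction time; it is passed to
        # optim.minimize(..., grad_clip=gradient_clipper) inside the training
        # loop, clipping gradients by their global norm (threshold 0.1).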
        gradient_clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(0.1)

        # generation
        synthesis_config = config["synthesis"]
        power = synthesis_config["power"]
        n_iter = synthesis_config["n_iter"]
        # =========================link(dataloader, paddle)=========================
        # CAUTION: DataCargo is not itself a paddle DataLoader; wrap it below
        loader = fluid.io.DataLoader.from_generator(capacity=10,
                                                    return_list=True)
        loader.set_batch_generator(ljspeech_loader, places=place)

        # tensorboard & checkpoint preparation
        output_dir = args.output
        ckpt_dir = os.path.join(output_dir, "checkpoints")
        log_dir = os.path.join(output_dir, "log")
        state_dir = os.path.join(output_dir, "states")
        make_output_tree(output_dir)
        writer = SummaryWriter(logdir=log_dir)

        # load model parameters
        resume_path = args.resume
        if resume_path is not None:
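            # note: dg.load_dygraph returns (parameter_dict, optimizer_dict);
            # only the model parameters are restored here, the optimizer state
            # is discarded.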
            state, _ = dg.load_dygraph(args.resume)
            dv3.set_dict(state)

        # =========================train=========================
        epoch = train_config["epochs"]
        snap_interval = train_config["snap_interval"]
        save_interval = train_config["save_interval"]
        eval_interval = train_config["eval_interval"]
        global_step = 1

        for j in range(1, 1 + epoch):
            epoch_loss = 0.
            for i, batch in tqdm.tqdm(enumerate(loader, 1)):
                dv3.train()  # CAUTION: don't forget to switch to train
                (text_sequences, text_lengths, text_positions, mel_specs,
                 lin_specs, frames, decoder_positions, done_flags) = batch
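                # note: strided_slice keeps every downsample_factor-th frame
                # along the time axis (axis 1), so the decoder is trained on a
                # coarser mel sequence; the converter is expected to map back
                # to the full-resolution linear spectrogram.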
                downsampled_mel_specs = F.strided_slice(
                    mel_specs,
                    axes=[1],
                    starts=[0],
                    ends=[mel_specs.shape[1]],
                    strides=[downsample_factor])
                mel_outputs, linear_outputs, alignments, done = dv3(
                    text_sequences, text_positions, text_lengths, None,
                    downsampled_mel_specs, decoder_positions)
                losses = criterion(mel_outputs, linear_outputs, done,
                                   alignments, downsampled_mel_specs,
                                   lin_specs, done_flags, text_lengths, frames)
                l = losses["loss"]
                l.backward()

                # record learning rate before updating
                writer.add_scalar("learning_rate",
                                  optim._learning_rate.step().numpy(),
                                  global_step)
                optim.minimize(l, grad_clip=gradient_clipper)
                optim.clear_gradients()

                # ==================all kinds of tedious things=================
                # record step loss into tensorboard
                epoch_loss += l.numpy()[0]
                step_loss = {k: v.numpy()[0] for k, v in losses.items()}
                for k, v in step_loss.items():
                    writer.add_scalar(k, v, global_step)

                # TODO: clean code
                # train state saving, the first sentence in the batch
                if global_step % snap_interval == 0:
                    save_state(state_dir,
                               writer,
                               global_step,
                               mel_input=downsampled_mel_specs,
                               mel_output=mel_outputs,
                               lin_input=lin_specs,
                               lin_output=linear_outputs,
                               alignments=alignments,
                               win_length=win_length,
                               hop_length=hop_length,
                               min_level_db=min_level_db,
                               ref_level_db=ref_level_db,
                               power=power,
                               n_iter=n_iter,
                               preemphasis=preemphasis,
                               sample_rate=sample_rate)

                # evaluation
                if global_step % eval_interval == 0:
                    sentences = [
                        "Scientists at the CERN laboratory say they have discovered a new particle.",
                        "There's a way to measure the acute emotional intelligence that has never gone out of style.",
                        "President Trump met with other leaders at the Group of 20 conference.",
                        "Generative adversarial network or variational auto-encoder.",
                        "Please call Stella.",
                        "Some have accepted this as a miracle without any physical explanation.",
                    ]
                    for idx, sent in enumerate(sentences):
                        wav, attn = eval_model(dv3, sent,
                                               replace_pronunciation_prob,
                                               min_level_db, ref_level_db,
                                               power, n_iter, win_length,
                                               hop_length, preemphasis)
                        wav_path = os.path.join(
                            state_dir, "waveform",
                            "eval_sample_{:09d}.wav".format(global_step))
                        sf.write(wav_path, wav, sample_rate)
                        writer.add_audio("eval_sample_{}".format(idx),
                                         wav,
                                         global_step,
                                         sample_rate=sample_rate)
                        attn_path = os.path.join(
                            state_dir, "alignments",
                            "eval_sample_attn_{:09d}.png".format(global_step))
                        plot_alignment(attn, attn_path)
                        writer.add_image("eval_sample_attn_{}".format(idx),
                                         cm.viridis(attn),
                                         global_step,
                                         dataformats="HWC")

                # save checkpoint
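                # note: dg.save_dygraph writes a parameter dict to a
                # ".pdparams" file and an optimizer state dict to a ".pdopt"
                # file, so both calls below can share the same path prefix.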
                if global_step % save_interval == 0:
                    dg.save_dygraph(
                        dv3.state_dict(),
                        os.path.join(ckpt_dir,
                                     "model_step_{}".format(global_step)))
                    dg.save_dygraph(
                        optim.state_dict(),
                        os.path.join(ckpt_dir,
                                     "model_step_{}".format(global_step)))

                global_step += 1

            # epoch report
            writer.add_scalar("epoch_average_loss", epoch_loss / i, j)
            epoch_loss = 0.