Parakeet/examples/deepvoice3/train.py

187 lines
7.7 KiB
Python
Raw Normal View History

2020-07-10 20:22:43 +08:00
import numpy as np
from matplotlib import cm
import librosa
2020-02-13 10:24:34 +08:00
import os
2020-07-10 20:22:43 +08:00
import time
2020-02-13 10:24:34 +08:00
import tqdm
2020-07-10 20:22:43 +08:00
import paddle
2020-02-13 10:24:34 +08:00
from paddle import fluid
2020-07-10 20:22:43 +08:00
from paddle.fluid import layers as F
from paddle.fluid import initializer as I
2020-07-10 20:22:43 +08:00
from paddle.fluid import dygraph as dg
from paddle.fluid.io import DataLoader
2020-08-07 16:28:21 +08:00
from visualdl import LogWriter
2020-03-26 10:58:16 +08:00
2020-07-10 20:22:43 +08:00
from parakeet.models.deepvoice3 import Encoder, Decoder, PostNet, SpectraNet
from parakeet.data import SliceDataset, DataCargo, SequentialSampler, RandomSampler
2020-07-10 20:22:43 +08:00
from parakeet.utils.io import save_parameters, load_parameters
from parakeet.g2p import en
from data import LJSpeech, DataCollector
from vocoder import WaveflowVocoder, GriffinLimVocoder
from clip import DoubleClip
def create_model(config):
char_embedding = dg.Embedding((en.n_vocab, config["char_dim"]), param_attr=I.Normal(scale=0.1))
2020-07-10 20:22:43 +08:00
multi_speaker = config["n_speakers"] > 1
speaker_embedding = dg.Embedding((config["n_speakers"], config["speaker_dim"]), param_attr=I.Normal(scale=0.1)) \
2020-07-10 20:22:43 +08:00
if multi_speaker else None
encoder = Encoder(config["encoder_layers"], config["char_dim"],
config["encoder_dim"], config["kernel_size"],
has_bias=multi_speaker, bias_dim=config["speaker_dim"],
keep_prob=1.0 - config["dropout"])
decoder = Decoder(config["n_mels"], config["reduction_factor"],
list(config["prenet_sizes"]) + [config["char_dim"]],
config["decoder_layers"], config["kernel_size"],
config["attention_dim"],
position_encoding_weight=config["position_weight"],
omega=config["position_rate"],
has_bias=multi_speaker, bias_dim=config["speaker_dim"],
keep_prob=1.0 - config["dropout"])
postnet = PostNet(config["postnet_layers"], config["char_dim"],
config["postnet_dim"], config["kernel_size"],
config["n_mels"], config["reduction_factor"],
has_bias=multi_speaker, bias_dim=config["speaker_dim"],
keep_prob=1.0 - config["dropout"])
spectranet = SpectraNet(char_embedding, speaker_embedding, encoder, decoder, postnet)
return spectranet
def create_data(config, data_path):
dataset = LJSpeech(data_path)
train_dataset = SliceDataset(dataset, config["valid_size"], len(dataset))
train_collator = DataCollector(config["p_pronunciation"])
train_sampler = RandomSampler(train_dataset)
2020-07-10 20:22:43 +08:00
train_cargo = DataCargo(train_dataset, train_collator,
batch_size=config["batch_size"], sampler=train_sampler)
train_loader = DataLoader\
.from_generator(capacity=10, return_list=True)\
.set_batch_generator(train_cargo)
valid_dataset = SliceDataset(dataset, 0, config["valid_size"])
valid_collector = DataCollector(1.)
valid_sampler = SequentialSampler(valid_dataset)
valid_cargo = DataCargo(valid_dataset, valid_collector,
batch_size=1, sampler=valid_sampler)
valid_loader = DataLoader\
.from_generator(capacity=2, return_list=True)\
.set_batch_generator(valid_cargo)
return train_loader, valid_loader
def create_optimizer(model, config):
optim = fluid.optimizer.Adam(config["learning_rate"],
parameter_list=model.parameters(),
grad_clip=DoubleClip(config["clip_value"], config["clip_norm"]))
return optim
def train(args, config):
model = create_model(config)
train_loader, valid_loader = create_data(config, args.input)
optim = create_optimizer(model, config)
global global_step
max_iteration = config["max_iteration"]
2020-07-10 20:22:43 +08:00
iterator = iter(tqdm.tqdm(train_loader))
while global_step <= max_iteration:
# get inputs
try:
batch = next(iterator)
2020-07-10 20:22:43 +08:00
except StopIteration:
iterator = iter(tqdm.tqdm(train_loader))
batch = next(iterator)
2020-07-10 20:22:43 +08:00
# unzip it
text_seqs, text_lengths, specs, mels, num_frames = batch
2020-07-10 20:22:43 +08:00
# forward & backward
model.train()
2020-07-10 20:22:43 +08:00
outputs = model(text_seqs, text_lengths, speakers=None, mel=mels)
decoded, refined, attentions, final_state = outputs
causal_mel_loss = model.spec_loss(decoded, mels, num_frames)
non_causal_mel_loss = model.spec_loss(refined, mels, num_frames)
loss = causal_mel_loss + non_causal_mel_loss
loss.backward()
# update
optim.minimize(loss)
# logging
tqdm.tqdm.write("[train] step: {}\tloss: {:.6f}\tcausal:{:.6f}\tnon_causal:{:.6f}".format(
global_step,
loss.numpy()[0],
causal_mel_loss.numpy()[0],
non_causal_mel_loss.numpy()[0]))
2020-08-20 13:26:15 +08:00
writer.add_scalar("loss/causal_mel_loss", causal_mel_loss.numpy()[0], step=global_step)
writer.add_scalar("loss/non_causal_mel_loss", non_causal_mel_loss.numpy()[0], step=global_step)
writer.add_scalar("loss/loss", loss.numpy()[0], step=global_step)
2020-07-10 20:22:43 +08:00
if global_step % config["report_interval"] == 0:
text_length = int(text_lengths.numpy()[0])
num_frame = int(num_frames.numpy()[0])
tag = "train_mel/ground-truth"
img = cm.viridis(normalize(mels.numpy()[0, :num_frame].T))
2020-08-20 13:26:15 +08:00
writer.add_image(tag, img, step=global_step)
2020-07-10 20:22:43 +08:00
tag = "train_mel/decoded"
img = cm.viridis(normalize(decoded.numpy()[0, :num_frame].T))
2020-08-20 13:26:15 +08:00
writer.add_image(tag, img, step=global_step)
2020-07-10 20:22:43 +08:00
tag = "train_mel/refined"
img = cm.viridis(normalize(refined.numpy()[0, :num_frame].T))
2020-08-20 13:26:15 +08:00
writer.add_image(tag, img, step=global_step)
2020-07-10 20:22:43 +08:00
vocoder = WaveflowVocoder()
vocoder.model.eval()
tag = "train_audio/ground-truth-waveflow"
wav = vocoder(F.transpose(mels[0:1, :num_frame, :], (0, 2, 1)))
2020-08-20 13:26:15 +08:00
writer.add_audio(tag, wav.numpy()[0], step=global_step, sample_rate=22050)
2020-07-10 20:22:43 +08:00
tag = "train_audio/decoded-waveflow"
wav = vocoder(F.transpose(decoded[0:1, :num_frame, :], (0, 2, 1)))
2020-08-20 13:26:15 +08:00
writer.add_audio(tag, wav.numpy()[0], step=global_step, sample_rate=22050)
2020-07-10 20:22:43 +08:00
tag = "train_audio/refined-waveflow"
wav = vocoder(F.transpose(refined[0:1, :num_frame, :], (0, 2, 1)))
2020-08-20 13:26:15 +08:00
writer.add_audio(tag, wav.numpy()[0], step=global_step, sample_rate=22050)
2020-07-10 20:22:43 +08:00
attentions_np = attentions.numpy()
attentions_np = attentions_np[:, 0, :num_frame // 4 , :text_length]
for i, attention_layer in enumerate(np.rot90(attentions_np, axes=(1,2))):
tag = "train_attention/layer_{}".format(i)
img = cm.viridis(normalize(attention_layer))
2020-08-20 13:26:15 +08:00
writer.add_image(tag, img, step=global_step, dataformats="HWC")
2020-07-10 20:22:43 +08:00
if global_step % config["save_interval"] == 0:
save_parameters(writer.logdir, global_step, model, optim)
# global step +1
global_step += 1
2020-07-10 20:22:43 +08:00
def normalize(arr):
return (arr - arr.min()) / (arr.max() - arr.min())
if __name__ == "__main__":
import argparse
from ruamel import yaml
parser = argparse.ArgumentParser(description="train a Deep Voice 3 model with LJSpeech")
parser.add_argument("--config", type=str, required=True, help="config file")
parser.add_argument("--input", type=str, required=True, help="data path of the original data")
args = parser.parse_args()
with open(args.config, 'rt') as f:
config = yaml.safe_load(f)
dg.enable_dygraph(fluid.CUDAPlace(0))
global global_step
global_step = 1
global writer
2020-08-07 16:28:21 +08:00
writer = LogWriter()
2020-07-10 20:22:43 +08:00
print("[Training] tensorboard log and checkpoints are save in {}".format(
writer.logdir))
train(args, config)