2020-07-10 20:22:43 +08:00
|
|
|
import numpy as np
|
|
|
|
from matplotlib import cm
|
|
|
|
import librosa
|
2020-02-13 10:24:34 +08:00
|
|
|
import os
|
2020-07-10 20:22:43 +08:00
|
|
|
import time
|
2020-02-13 10:24:34 +08:00
|
|
|
import tqdm
|
2020-07-10 20:22:43 +08:00
|
|
|
import paddle
|
2020-02-13 10:24:34 +08:00
|
|
|
from paddle import fluid
|
2020-07-10 20:22:43 +08:00
|
|
|
from paddle.fluid import layers as F
|
2020-07-29 11:54:47 +08:00
|
|
|
from paddle.fluid import initializer as I
|
2020-07-10 20:22:43 +08:00
|
|
|
from paddle.fluid import dygraph as dg
|
|
|
|
from paddle.fluid.io import DataLoader
|
2020-08-07 16:28:21 +08:00
|
|
|
from visualdl import LogWriter
|
2020-03-26 10:58:16 +08:00
|
|
|
|
2020-07-10 20:22:43 +08:00
|
|
|
from parakeet.models.deepvoice3 import Encoder, Decoder, PostNet, SpectraNet
|
2020-07-29 11:54:47 +08:00
|
|
|
from parakeet.data import SliceDataset, DataCargo, SequentialSampler, RandomSampler
|
2020-07-10 20:22:43 +08:00
|
|
|
from parakeet.utils.io import save_parameters, load_parameters
|
|
|
|
from parakeet.g2p import en
|
|
|
|
|
|
|
|
from data import LJSpeech, DataCollector
|
|
|
|
from vocoder import WaveflowVocoder, GriffinLimVocoder
|
|
|
|
from clip import DoubleClip
|
|
|
|
|
|
|
|
|
|
|
|
def create_model(config):
    """Build the Deep Voice 3 spectrogram network described by *config*.

    A speaker embedding table and speaker-conditioned biases in the
    encoder / decoder / postnet are only enabled when
    ``config["n_speakers"] > 1``; otherwise no speaker embedding is created.

    Args:
        config: configuration mapping (loaded from YAML by the caller).

    Returns:
        A ``SpectraNet`` assembled from the character embedding, optional
        speaker embedding, encoder, decoder and postnet.
    """
    char_dim = config["char_dim"]
    kernel_size = config["kernel_size"]
    n_mels = config["n_mels"]
    reduction_factor = config["reduction_factor"]

    char_embedding = dg.Embedding((en.n_vocab, char_dim), param_attr=I.Normal(scale=0.1))

    multi_speaker = config["n_speakers"] > 1
    if multi_speaker:
        speaker_embedding = dg.Embedding(
            (config["n_speakers"], config["speaker_dim"]),
            param_attr=I.Normal(scale=0.1))
    else:
        speaker_embedding = None

    # Keyword arguments shared by encoder, decoder and postnet.
    shared_kwargs = dict(
        has_bias=multi_speaker,
        bias_dim=config["speaker_dim"],
        keep_prob=1.0 - config["dropout"],
    )

    encoder = Encoder(config["encoder_layers"], char_dim,
                      config["encoder_dim"], kernel_size,
                      **shared_kwargs)

    prenet_sizes = list(config["prenet_sizes"]) + [char_dim]
    decoder = Decoder(n_mels, reduction_factor,
                      prenet_sizes,
                      config["decoder_layers"], kernel_size,
                      config["attention_dim"],
                      position_encoding_weight=config["position_weight"],
                      omega=config["position_rate"],
                      **shared_kwargs)

    postnet = PostNet(config["postnet_layers"], char_dim,
                      config["postnet_dim"], kernel_size,
                      n_mels, reduction_factor,
                      **shared_kwargs)

    return SpectraNet(char_embedding, speaker_embedding, encoder, decoder, postnet)
|
|
|
|
|
|
|
|
def create_data(config, data_path):
    """Build the training and validation loaders for LJSpeech.

    The first ``config["valid_size"]`` examples of the dataset are held out
    for validation; the remaining examples form the training split.

    Args:
        config: configuration mapping; reads ``valid_size``,
            ``p_pronunciation`` and ``batch_size``.
        data_path: path handed to ``LJSpeech``.

    Returns:
        ``(train_loader, valid_loader)``. The training loader batches
        ``config["batch_size"]`` examples drawn by a ``RandomSampler``; the
        validation loader uses a ``SequentialSampler`` with batch size 1.
    """
    dataset = LJSpeech(data_path)
    split = config["valid_size"]

    def build_loader(subset, collator, batch_size, sampler, capacity):
        # Batch `subset` with a DataCargo and expose it through a
        # generator-backed DataLoader returning lists (return_list=True).
        cargo = DataCargo(subset, collator, batch_size=batch_size, sampler=sampler)
        loader = DataLoader.from_generator(capacity=capacity, return_list=True)
        return loader.set_batch_generator(cargo)

    # training split: everything after the held-out prefix
    train_subset = SliceDataset(dataset, split, len(dataset))
    train_collator = DataCollector(config["p_pronunciation"])
    train_sampler = RandomSampler(train_subset)
    train_loader = build_loader(train_subset, train_collator,
                                config["batch_size"], train_sampler, capacity=10)

    # validation split: the held-out prefix, one example per batch
    valid_subset = SliceDataset(dataset, 0, split)
    valid_collator = DataCollector(1.)
    valid_sampler = SequentialSampler(valid_subset)
    valid_loader = build_loader(valid_subset, valid_collator,
                                1, valid_sampler, capacity=2)

    return train_loader, valid_loader
|
|
|
|
|
|
|
|
def create_optimizer(model, config):
    """Create an Adam optimizer over all of *model*'s parameters.

    Gradients are clipped with ``DoubleClip`` configured from
    ``config["clip_value"]`` and ``config["clip_norm"]``.

    Args:
        model: the dygraph model whose ``parameters()`` are optimized.
        config: configuration mapping; reads ``learning_rate``,
            ``clip_value`` and ``clip_norm``.

    Returns:
        A ``fluid.optimizer.Adam`` instance.
    """
    learning_rate = config["learning_rate"]
    params = model.parameters()
    clipper = DoubleClip(config["clip_value"], config["clip_norm"])
    return fluid.optimizer.Adam(learning_rate,
                                parameter_list=params,
                                grad_clip=clipper)
|
|
|
|
|
|
|
|
def _log_train_report(vocoder, mels, decoded, refined, attentions,
                      text_lengths, num_frames, reduction_factor, step):
    """Log mel images, vocoded audio and attention maps for the first batch item.

    Writes through the module-level ``writer`` (set up in ``__main__``).
    """
    text_length = int(text_lengths.numpy()[0])
    num_frame = int(num_frames.numpy()[0])
    named_mels = (("ground-truth", mels), ("decoded", decoded), ("refined", refined))

    # spectrogram images (transposed so that mel bins run vertically)
    for name, mel in named_mels:
        img = cm.viridis(normalize(mel.numpy()[0, :num_frame].T))
        writer.add_image("train_mel/{}".format(name), img, step=step)

    # audio synthesized from each spectrogram with the WaveFlow vocoder
    for name, mel in named_mels:
        wav = vocoder(F.transpose(mel[0:1, :num_frame, :], (0, 2, 1)))
        writer.add_audio("train_audio/{}-waveflow".format(name), wav.numpy()[0],
                         step=step, sample_rate=22050)

    # Attention maps, one image per layer. The indexing below assumes
    # `attentions` is (layer, batch, decoder_step, encoder_step) and that the
    # decoder emits one step per `reduction_factor` frames -- TODO confirm.
    # The crop previously hard-coded a factor of 4.
    n_decoder_steps = num_frame // reduction_factor
    attentions_np = attentions.numpy()[:, 0, :n_decoder_steps, :text_length]
    for i, attention_layer in enumerate(np.rot90(attentions_np, axes=(1, 2))):
        img = cm.viridis(normalize(attention_layer))
        writer.add_image("train_attention/layer_{}".format(i), img,
                         step=step, dataformats="HWC")


def train(args, config):
    """Train the Deep Voice 3 spectrogram model up to ``config["max_iteration"]``.

    Uses two module-level globals created in ``__main__``: ``global_step``
    (step counter, advanced once per iteration here) and ``writer`` (the
    VisualDL ``LogWriter`` whose ``logdir`` also receives checkpoints).

    Args:
        args: parsed command-line arguments; ``args.input`` is the data path.
        config: configuration mapping; besides the model/data/optimizer keys,
            reads ``max_iteration``, ``report_interval``, ``save_interval``
            and ``reduction_factor``.
    """
    model = create_model(config)
    train_loader, valid_loader = create_data(config, args.input)
    optim = create_optimizer(model, config)

    global global_step
    max_iteration = config["max_iteration"]
    # Built on the first report step and reused afterwards instead of being
    # re-created (and its weights reloaded) at every report.
    vocoder = None

    iterator = iter(tqdm.tqdm(train_loader))
    while global_step <= max_iteration:
        # get inputs; start a new pass over the data when exhausted
        try:
            batch = next(iterator)
        except StopIteration:
            iterator = iter(tqdm.tqdm(train_loader))
            batch = next(iterator)

        # unzip it
        text_seqs, text_lengths, specs, mels, num_frames = batch

        # forward & backward
        model.train()
        outputs = model(text_seqs, text_lengths, speakers=None, mel=mels)
        decoded, refined, attentions, final_state = outputs

        causal_mel_loss = model.spec_loss(decoded, mels, num_frames)
        non_causal_mel_loss = model.spec_loss(refined, mels, num_frames)
        loss = causal_mel_loss + non_causal_mel_loss
        loss.backward()

        # update; in fluid dygraph `minimize` does not reset gradients, so
        # clear them explicitly to keep them from accumulating across steps
        optim.minimize(loss)
        model.clear_gradients()

        # logging
        loss_value = loss.numpy()[0]
        causal_value = causal_mel_loss.numpy()[0]
        non_causal_value = non_causal_mel_loss.numpy()[0]
        tqdm.tqdm.write("[train] step: {}\tloss: {:.6f}\tcausal:{:.6f}\tnon_causal:{:.6f}".format(
            global_step, loss_value, causal_value, non_causal_value))
        writer.add_scalar("loss/causal_mel_loss", causal_value, step=global_step)
        writer.add_scalar("loss/non_causal_mel_loss", non_causal_value, step=global_step)
        writer.add_scalar("loss/loss", loss_value, step=global_step)

        if global_step % config["report_interval"] == 0:
            if vocoder is None:
                vocoder = WaveflowVocoder()
                vocoder.model.eval()
            _log_train_report(vocoder, mels, decoded, refined, attentions,
                              text_lengths, num_frames,
                              config["reduction_factor"], global_step)

        if global_step % config["save_interval"] == 0:
            save_parameters(writer.logdir, global_step, model, optim)

        # global step +1
        global_step += 1
|
2020-07-10 20:22:43 +08:00
|
|
|
|
|
|
|
def normalize(arr):
    """Min-max scale *arr* into the range [0, 1] (used for image logging).

    A constant array (max == min) is mapped to all zeros instead of the NaNs
    that a 0/0 division would produce. The result dtype matches the plain
    division ``(arr - min) / range`` (e.g. float32 stays float32).

    Args:
        arr: numeric array-like; must be non-empty.

    Returns:
        A NumPy array of the same shape as *arr*.
    """
    arr = np.asarray(arr)
    low = arr.min()
    span = arr.max() - low
    # arr - low is all zeros when span == 0, so dividing by 1 yields zeros
    return (arr - low) / (span if span != 0 else 1)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse
    from ruamel import yaml

    parser = argparse.ArgumentParser(description="train a Deep Voice 3 model with LJSpeech")
    parser.add_argument("--config", type=str, required=True, help="config file")
    parser.add_argument("--input", type=str, required=True, help="data path of the original data")
    args = parser.parse_args()

    with open(args.config, 'rt') as f:
        config = yaml.safe_load(f)

    # run in dygraph mode on the first GPU
    dg.enable_dygraph(fluid.CUDAPlace(0))

    # `train` reads and advances these module-level globals via `global`
    # declarations; at module scope a plain assignment defines them (a
    # `global` statement here would be a no-op).
    global_step = 1
    writer = LogWriter()
    print("[Training] tensorboard log and checkpoints are saved in {}".format(
        writer.logdir))
    train(args, config)
|