"""Training and synthesis driver for the WaveNet vocoder (PaddlePaddle fluid dygraph)."""
import itertools
|
|
import os
|
|
import time
|
|
|
|
import librosa
|
|
import numpy as np
|
|
import paddle.fluid.dygraph as dg
|
|
from paddle import fluid
|
|
|
|
import utils
|
|
from data import LJSpeech
|
|
from wavenet_modules import WaveNetModule
|
|
|
|
|
|
class WaveNet():
    """Trainer/synthesizer wrapper around ``WaveNetModule``.

    Builds the LJSpeech data loaders and the model, optionally sets up the
    optimizer, LR schedule, gradient clipping and data parallelism, and
    drives the train / validation / synthesis / checkpoint steps.

    Args:
        config: parsed configuration; ``config.context_size`` is derived
            here from the dilation schedule.
        checkpoint_dir (str): directory checkpoints are loaded from and
            saved to.
        parallel (bool): whether to run data-parallel across ``nranks``
            devices.
        rank (int): rank of this process.
        nranks (int): total number of parallel processes.
        tb_logger: TensorBoard-style logger; only used on rank 0.
    """

    def __init__(self, config, checkpoint_dir, parallel=False, rank=0,
                 nranks=1, tb_logger=None):
        # Receptive-field context size: one current timestep plus the sum
        # of all layer dilations (the dilation block pattern is cycled
        # until `config.layers` entries are produced).
        dilations = list(
            itertools.islice(
                itertools.cycle(config.dilation_block), config.layers))
        config.context_size = sum(dilations) + 1
        self.config = config
        self.checkpoint_dir = checkpoint_dir
        self.parallel = parallel
        self.rank = rank
        self.nranks = nranks
        self.tb_logger = tb_logger

    def build(self, training=True):
        """Create loaders and the model; load the checkpoint.

        In training mode also creates the LR scheduler, Adam optimizer and
        global-norm gradient clipper, and wraps the model for data
        parallelism when requested.

        Args:
            training (bool): build for training (True) or inference (False).
        """
        config = self.config
        dataset = LJSpeech(config, self.nranks, self.rank)
        self.trainloader = dataset.trainloader
        self.validloader = dataset.validloader

        wavenet = WaveNetModule("wavenet", config, self.rank)

        # Dry run once to create and initialize all necessary parameters
        # (dygraph parameters are created lazily on first forward pass).
        audio = dg.to_variable(np.random.randn(1, 20000).astype(np.float32))
        mel = dg.to_variable(
            np.random.randn(1, 100, self.config.mel_bands).astype(np.float32))
        audio_start = dg.to_variable(np.array([0], dtype=np.int32))
        wavenet(audio, mel, audio_start)

        if training:
            # Staircase decay: multiply the LR by `anneal.rate` every
            # `anneal.every` steps.
            lr_scheduler = dg.ExponentialDecay(
                learning_rate=config.learning_rate,
                decay_steps=config.anneal.every,
                decay_rate=config.anneal.rate,
                staircase=True)

            optimizer = fluid.optimizer.AdamOptimizer(
                learning_rate=lr_scheduler)

            clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
                config.gradient_max_norm)

            # Load model + optimizer state. Done before the DataParallel
            # wrap so every rank starts from identical weights.
            utils.load_parameters(self.checkpoint_dir, self.rank,
                                  wavenet, optimizer,
                                  iteration=config.iteration,
                                  file_path=config.checkpoint)
            print("Rank {}: checkpoint loaded.".format(self.rank))

            # Data parallelism.
            if self.parallel:
                strategy = dg.parallel.prepare_context()
                wavenet = dg.parallel.DataParallel(wavenet, strategy)

            self.wavenet = wavenet
            self.optimizer = optimizer
            self.clipper = clipper
        else:
            # Inference-only build: model weights, no optimizer state.
            utils.load_parameters(self.checkpoint_dir, self.rank, wavenet,
                                  iteration=config.iteration,
                                  file_path=config.checkpoint)
            print("Rank {}: checkpoint loaded.".format(self.rank))

            self.wavenet = wavenet

    def train_step(self, iteration):
        """Run one optimization step; log loss and LR on rank 0.

        Args:
            iteration (int): global step number, used for logging only.
        """
        self.wavenet.train()

        start_time = time.time()
        audios, mels, audio_starts = next(self.trainloader)
        load_time = time.time()

        loss, _ = self.wavenet(audios, mels, audio_starts)

        if self.parallel:
            # loss = loss / num_trainers
            loss = self.wavenet.scale_loss(loss)
            loss.backward()
            self.wavenet.apply_collective_grads()
        else:
            loss.backward()

        # Advance the LR schedule (if one is attached) and capture the
        # current LR for logging; a plain float LR is logged as-is.
        if isinstance(self.optimizer._learning_rate,
                      fluid.optimizer.LearningRateDecay):
            current_lr = self.optimizer._learning_rate.step().numpy()
        else:
            current_lr = self.optimizer._learning_rate

        self.optimizer.minimize(loss, grad_clip=self.clipper,
                                parameter_list=self.wavenet.parameters())
        self.wavenet.clear_gradients()

        graph_time = time.time()

        if self.rank == 0:
            # Undo the 1/nranks scaling so logged losses are comparable
            # between single- and multi-device runs.
            loss_val = float(loss.numpy()) * self.nranks
            log = "Rank: {} Step: {:^8d} Loss: {:<8.3f} " \
                  "Time: {:.3f}/{:.3f}".format(
                      self.rank, iteration, loss_val,
                      load_time - start_time, graph_time - load_time)
            print(log)

            tb = self.tb_logger
            tb.add_scalar("Train-Loss-Rank-0", loss_val, iteration)
            tb.add_scalar("Learning-Rate", current_lr, iteration)

    @dg.no_grad
    def valid_step(self, iteration):
        """Evaluate teacher-forced loss over the full validation set.

        On rank 0, logs the average loss and the first two teacher-forced
        sample audios (assumes the validation set yields >= 2 batches).

        Args:
            iteration (int): global step number, used for logging only.
        """
        self.wavenet.eval()

        total_loss = []
        sample_audios = []
        start_time = time.time()
        for audios, mels, audio_starts in self.validloader():
            loss, sample_audio = self.wavenet(audios, mels, audio_starts, True)
            total_loss.append(float(loss.numpy()))
            sample_audios.append(sample_audio)
        total_time = time.time() - start_time

        if self.rank == 0:
            loss_val = np.mean(total_loss)
            log = "Test | Rank: {} AvgLoss: {:<8.3f} Time {:<8.3f}".format(
                self.rank, loss_val, total_time)
            print(log)

            tb = self.tb_logger
            tb.add_scalar("Valid-Avg-Loss", loss_val, iteration)
            tb.add_audio("Teacher-Forced-Audio-0", sample_audios[0].numpy(),
                         iteration, sample_rate=self.config.sample_rate)
            tb.add_audio("Teacher-Forced-Audio-1", sample_audios[1].numpy(),
                         iteration, sample_rate=self.config.sample_rate)

    @dg.no_grad
    def infer(self, iteration):
        """Synthesize audio for the ``config.sample``-th validation batch.

        Writes the result to
        ``<config.output>/<config.name>/iter-<iteration>/valid_<sample>.wav``.

        Args:
            iteration (int): checkpoint iteration, used to name the
                output directory.
        """
        self.wavenet.eval()

        config = self.config
        sample = config.sample

        output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
        os.makedirs(output, exist_ok=True)

        filename = "{}/valid_{}.wav".format(output, sample)
        print("Synthesize sample {}, save as {}".format(sample, filename))

        # Pull only the requested batch lazily instead of materializing
        # every validation batch just to index one of them.
        _, mels, _ = next(itertools.islice(self.validloader(), sample, None))
        start_time = time.time()
        syn_audio = self.wavenet.synthesize(mels)
        syn_time = time.time() - start_time
        print("audio shape {}, synthesis time {}".format(
            syn_audio.shape, syn_time))
        # NOTE(review): librosa.output.write_wav was removed in
        # librosa >= 0.8 — pin librosa < 0.8 or migrate to soundfile.write.
        librosa.output.write_wav(filename, syn_audio,
                                 sr=config.sample_rate)

    def save(self, iteration):
        """Persist model + optimizer state and update the 'latest' marker.

        Args:
            iteration (int): global step number embedded in the
                checkpoint name.
        """
        utils.save_latest_parameters(self.checkpoint_dir, iteration,
                                     self.wavenet, self.optimizer)
        utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
|