# ParakeetEricRoss/parakeet/models/wavenet/wavenet.py

import itertools
import os
import time

import librosa
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid

import utils
from data import LJSpeech
from wavenet_modules import WaveNetModule


class WaveNet():
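    """Training and synthesis runner that wraps WaveNetModule with LJSpeech
    data loading, optimization, validation, checkpointing, and inference."""
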
    def __init__(self, config, checkpoint_dir, parallel=False, rank=0,
                 nranks=1, tb_logger=None):
        # Process config to calculate the context size.
        dilations = list(
            itertools.islice(
                itertools.cycle(config.dilation_block), config.layers))
        config.context_size = sum(dilations) + 1
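        # With the usual kernel-size-2 dilated convolutions, the stack's
        # receptive field is sum(dilations) + 1 samples, hence this formula.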

        self.config = config
        self.checkpoint_dir = checkpoint_dir
        self.parallel = parallel
        self.rank = rank
        self.nranks = nranks
        self.tb_logger = tb_logger

    def build(self, training=True):
        config = self.config
        dataset = LJSpeech(config, self.nranks, self.rank)
        self.trainloader = dataset.trainloader
        self.validloader = dataset.validloader

        wavenet = WaveNetModule("wavenet", config, self.rank)

        # Dry run once to create and initialize all necessary parameters.
        audio = dg.to_variable(np.random.randn(1, 20000).astype(np.float32))
        mel = dg.to_variable(
            np.random.randn(1, 100, self.config.mel_bands).astype(np.float32))
        audio_start = dg.to_variable(np.array([0], dtype=np.int32))
        wavenet(audio, mel, audio_start)
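        # In dygraph mode some parameters are only created on the first
        # forward pass, so this dry run must precede checkpoint loading.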

        if training:
            # Create the learning rate scheduler.
            lr_scheduler = dg.ExponentialDecay(
                learning_rate=config.learning_rate,
                decay_steps=config.anneal.every,
                decay_rate=config.anneal.rate,
                staircase=True)
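            # staircase=True multiplies the LR by anneal.rate once every
            # anneal.every steps (stepwise rather than smooth decay).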
            optimizer = fluid.optimizer.AdamOptimizer(
                learning_rate=lr_scheduler)
            clipper = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
                config.gradient_max_norm)
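            # The clipper is applied at optimizer.minimize time in
            # train_step, not here.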

            # Load parameters.
            utils.load_parameters(self.checkpoint_dir, self.rank,
                                  wavenet, optimizer,
                                  iteration=config.iteration,
                                  file_path=config.checkpoint)
            print("Rank {}: checkpoint loaded.".format(self.rank))

            # Data parallelism.
            if self.parallel:
                strategy = dg.parallel.prepare_context()
                wavenet = dg.parallel.DataParallel(wavenet, strategy)

            self.wavenet = wavenet
            self.optimizer = optimizer
            self.clipper = clipper
        else:
            # Load parameters.
            utils.load_parameters(self.checkpoint_dir, self.rank, wavenet,
                                  iteration=config.iteration,
                                  file_path=config.checkpoint)
            print("Rank {}: checkpoint loaded.".format(self.rank))

            self.wavenet = wavenet

    def train_step(self, iteration):
        self.wavenet.train()

        start_time = time.time()
        audios, mels, audio_starts = next(self.trainloader)
        load_time = time.time()

        loss, _ = self.wavenet(audios, mels, audio_starts)

        if self.parallel:
            # loss = loss / num_trainers
            loss = self.wavenet.scale_loss(loss)
            loss.backward()
            self.wavenet.apply_collective_grads()
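            # apply_collective_grads all-reduces gradients across trainers
            # before the optimizer step.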
        else:
            loss.backward()
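
        # Read the current learning rate for TensorBoard logging below.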
        if isinstance(self.optimizer._learning_rate,
                      fluid.optimizer.LearningRateDecay):
            current_lr = self.optimizer._learning_rate.step().numpy()
        else:
            current_lr = self.optimizer._learning_rate

        self.optimizer.minimize(loss, grad_clip=self.clipper,
                                parameter_list=self.wavenet.parameters())
        self.wavenet.clear_gradients()

        graph_time = time.time()

        if self.rank == 0:
            loss_val = float(loss.numpy()) * self.nranks
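            # Multiplying by nranks undoes the scale_loss division above so
            # the logged value matches the unscaled per-batch loss.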
log = "Rank: {} Step: {:^8d} Loss: {:<8.3f} " \
"Time: {:.3f}/{:.3f}".format(
self.rank, iteration, loss_val,
load_time - start_time, graph_time - load_time)
print(log)
tb = self.tb_logger
tb.add_scalar("Train-Loss-Rank-0", loss_val, iteration)
tb.add_scalar("Learning-Rate", current_lr, iteration)

    @dg.no_grad
    def valid_step(self, iteration):
        self.wavenet.eval()

        total_loss = []
        sample_audios = []
        start_time = time.time()
        for audios, mels, audio_starts in self.validloader():
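            # The trailing `True` asks the module to also return a
            # teacher-forced audio sample for TensorBoard logging below.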
            loss, sample_audio = self.wavenet(audios, mels, audio_starts, True)
            total_loss.append(float(loss.numpy()))
            sample_audios.append(sample_audio)
        total_time = time.time() - start_time

        if self.rank == 0:
            loss_val = np.mean(total_loss)
            log = "Test | Rank: {} AvgLoss: {:<8.3f} Time {:<8.3f}".format(
                self.rank, loss_val, total_time)
            print(log)

            tb = self.tb_logger
            tb.add_scalar("Valid-Avg-Loss", loss_val, iteration)
            tb.add_audio("Teacher-Forced-Audio-0", sample_audios[0].numpy(),
                         iteration, sample_rate=self.config.sample_rate)
            tb.add_audio("Teacher-Forced-Audio-1", sample_audios[1].numpy(),
                         iteration, sample_rate=self.config.sample_rate)

    @dg.no_grad
    def infer(self, iteration):
        self.wavenet.eval()

        config = self.config
        sample = config.sample

        output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
        os.makedirs(output, exist_ok=True)

        filename = "{}/valid_{}.wav".format(output, sample)
        print("Synthesize sample {}, save as {}".format(sample, filename))

        mels_list = [mels for _, mels, _ in self.validloader()]
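        # Materialize mel spectrograms for the whole validation set, then
        # synthesize only the utterance at index config.sample.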
        start_time = time.time()
        syn_audio = self.wavenet.synthesize(mels_list[sample])
        syn_time = time.time() - start_time
        print("audio shape {}, synthesis time {}".format(
            syn_audio.shape, syn_time))
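        # Note: librosa.output.write_wav was removed in librosa 0.8.0; on
        # newer librosa the equivalent is
        # soundfile.write(filename, syn_audio, config.sample_rate).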
        librosa.output.write_wav(filename, syn_audio,
                                 sr=config.sample_rate)

    def save(self, iteration):
        utils.save_latest_parameters(self.checkpoint_dir, iteration,
                                     self.wavenet, self.optimizer)
        utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
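

# A minimal usage sketch (hypothetical driver code; the real entry point is
# the recipe's training script, and max_iterations/test_every/save_every are
# assumed config fields, not guaranteed by this module):
#
#     model = WaveNet(config, checkpoint_dir="checkpoints")
#     model.build(training=True)
#     for iteration in range(1, config.max_iterations + 1):
#         model.train_step(iteration)
#         if iteration % config.test_every == 0:
#             model.valid_step(iteration)
#         if iteration % config.save_every == 0:
#             model.save(iteration)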