2019-12-13 09:58:10 +08:00
|
|
|
import itertools
|
|
|
|
import os
|
|
|
|
import time
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import paddle.fluid.dygraph as dg
|
|
|
|
from paddle import fluid
|
2019-12-19 16:03:06 +08:00
|
|
|
from scipy.io.wavfile import write
|
2019-12-13 09:58:10 +08:00
|
|
|
|
|
|
|
import utils
|
|
|
|
from data import LJSpeech
|
|
|
|
from waveflow_modules import WaveFlowLoss, WaveFlowModule
|
|
|
|
|
|
|
|
|
|
|
|
class WaveFlow():
    """Training / validation / synthesis driver around a WaveFlow model.

    Bundles a ``WaveFlowModule`` with its LJSpeech data loaders and, in
    training mode, an Adam optimizer and ``WaveFlowLoss`` criterion.

    Args:
        config: experiment configuration object. Attributes read by this
            class include ``mel_bands``, ``learning_rate``, ``iteration``,
            ``checkpoint``, ``sigma``, ``sample``, ``output``, ``name`` and
            ``sample_rate``.
        checkpoint_dir (str): directory checkpoints are loaded from / saved to.
        parallel (bool): wrap the model in ``dg.parallel.DataParallel``.
        rank (int): rank of this process.
        nranks (int): total number of processes.
        tb_logger: optional logger exposing ``add_scalar`` and
            ``add_histogram``; logging is skipped when it is ``None``.
    """

    def __init__(self, config, checkpoint_dir, parallel=False, rank=0,
                 nranks=1, tb_logger=None):
        self.config = config
        self.checkpoint_dir = checkpoint_dir
        self.parallel = parallel
        self.rank = rank
        self.nranks = nranks
        self.tb_logger = tb_logger

    def build(self, training=True):
        """Create data loaders and the model, then restore parameters.

        Args:
            training (bool): when True, also create an Adam optimizer, restore
                its state, optionally wrap the model for data parallelism and
                create the loss. When False only model parameters are
                restored; ``self.optimizer`` and ``self.criterion`` are not
                set, so ``train_step``, ``valid_step`` and ``save`` require
                ``training=True``.
        """
        config = self.config
        dataset = LJSpeech(config, self.nranks, self.rank)
        self.trainloader = dataset.trainloader
        self.validloader = dataset.validloader

        waveflow = WaveFlowModule("waveflow", config)

        # Dry run once to create and initialize all necessary parameters.
        audio = dg.to_variable(np.random.randn(1, 16000).astype(np.float32))
        mel = dg.to_variable(
            np.random.randn(1, config.mel_bands, 63).astype(np.float32))
        waveflow(audio, mel)

        if training:
            optimizer = fluid.optimizer.AdamOptimizer(
                learning_rate=config.learning_rate)

            # Load parameters.
            utils.load_parameters(self.checkpoint_dir, self.rank,
                                  waveflow, optimizer,
                                  iteration=config.iteration,
                                  file_path=config.checkpoint)
            print("Rank {}: checkpoint loaded.".format(self.rank))

            # Data parallelism.
            if self.parallel:
                strategy = dg.parallel.prepare_context()
                waveflow = dg.parallel.DataParallel(waveflow, strategy)

            self.waveflow = waveflow
            self.optimizer = optimizer
            self.criterion = WaveFlowLoss(config.sigma)

        else:
            # Load parameters.
            utils.load_parameters(self.checkpoint_dir, self.rank, waveflow,
                                  iteration=config.iteration,
                                  file_path=config.checkpoint)
            print("Rank {}: checkpoint loaded.".format(self.rank))

            self.waveflow = waveflow

    def train_step(self, iteration):
        """Run one optimization step on the next training batch.

        Args:
            iteration (int): global step number, used only for logging.
        """
        self.waveflow.train()

        start_time = time.time()
        audios, mels = next(self.trainloader)
        load_time = time.time()

        outputs = self.waveflow(audios, mels)
        loss = self.criterion(outputs)

        if self.parallel:
            # scale_loss computes loss / num_trainers before the collective
            # gradient reduction.
            loss = self.waveflow.scale_loss(loss)
            loss.backward()
            self.waveflow.apply_collective_grads()
        else:
            loss.backward()

        self.optimizer.minimize(loss, parameter_list=self.waveflow.parameters())
        self.waveflow.clear_gradients()

        graph_time = time.time()

        if self.rank == 0:
            # Undo the 1 / nranks scaling (a no-op when nranks == 1).
            loss_val = float(loss.numpy()) * self.nranks
            log = "Rank: {} Step: {:^8d} Loss: {:<8.3f} " \
                  "Time: {:.3f}/{:.3f}".format(
                      self.rank, iteration, loss_val,
                      load_time - start_time, graph_time - load_time)
            print(log)

            # tb_logger defaults to None; skip TensorBoard logging then.
            if self.tb_logger is not None:
                self.tb_logger.add_scalar("Train-Loss-Rank-0", loss_val,
                                          iteration)

    @dg.no_grad
    def valid_step(self, iteration):
        """Compute the average loss over the whole validation loader.

        On rank 0 this also prints the average loss. When a logger is
        configured, it records histograms of the latent ``z`` and of each
        flow's ``log_s`` for the first validation batch, plus the average
        loss.

        Args:
            iteration (int): global step number, used only for logging.
        """
        self.waveflow.eval()
        tb = self.tb_logger

        total_loss = []
        start_time = time.time()

        for i, batch in enumerate(self.validloader()):
            audios, mels = batch
            valid_outputs = self.waveflow(audios, mels)
            valid_z, valid_log_s_list = valid_outputs

            # Visualize latent z and scale log_s (first batch only).
            if tb is not None and self.rank == 0 and i == 0:
                tb.add_histogram("Valid-Latent_z", valid_z.numpy(), iteration)
                for j, valid_log_s in enumerate(valid_log_s_list):
                    hist_name = "Valid-{}th-Flow-Log_s".format(j)
                    tb.add_histogram(hist_name, valid_log_s.numpy(), iteration)

            valid_loss = self.criterion(valid_outputs)
            total_loss.append(float(valid_loss.numpy()))

        total_time = time.time() - start_time
        if self.rank == 0:
            loss_val = np.mean(total_loss)
            log = "Test | Rank: {} AvgLoss: {:<8.3f} Time {:<8.3f}".format(
                self.rank, loss_val, total_time)
            print(log)
            if tb is not None:
                tb.add_scalar("Valid-Avg-Loss", loss_val, iteration)

    @staticmethod
    def _to_int16(audio):
        """Convert float audio nominally in [-1, 1] to int16 PCM samples.

        Samples are scaled by 32768 and clipped to [-32768, 32767] before the
        cast. Casting an out-of-range float to int16 is not well defined and
        typically wraps around (e.g. 1.0 -> 32768 -> -32768), which would
        produce loud clicks.
        """
        scaled = np.asarray(audio) * 32768.0
        return np.clip(scaled, -32768.0, 32767.0).astype('int16')

    @dg.no_grad
    def infer(self, iteration):
        """Synthesize validation utterances and save them as int16 WAV files.

        Files go to ``{config.output}/{config.name}/iter-{iteration}`` and are
        named ``valid_{index}.wav``, where ``index`` is the utterance's
        position in the validation loader. When ``config.sample`` is not
        None, only that utterance is synthesized.

        Args:
            iteration (int): step number used in the output directory name.
        """
        self.waveflow.eval()

        config = self.config
        sample = config.sample

        output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
        os.makedirs(output, exist_ok=True)

        # Keep each utterance's loader index. Previously the selected sample
        # was always saved as valid_0.wav because the loop variable shadowed
        # config.sample.
        indexed_mels = list(enumerate(mels for _, mels in self.validloader()))
        if sample is not None:
            indexed_mels = [indexed_mels[sample]]

        for index, mel in indexed_mels:
            filename = "{}/valid_{}.wav".format(output, index)
            print("Synthesize sample {}, save as {}".format(index, filename))

            start_time = time.time()
            audio = self.waveflow.synthesize(mel, sigma=self.config.sigma)
            syn_time = time.time() - start_time

            audio = audio[0]
            audio_time = audio.shape[0] / self.config.sample_rate
            print("audio time {:.4f}, synthesis time {:.4f}".format(
                audio_time, syn_time))

            # Denormalize audio from [-1, 1] to the int16 range
            # [-32768, 32767], clipping out-of-range samples.
            write(filename, config.sample_rate, self._to_int16(audio.numpy()))

    @dg.no_grad
    def benchmark(self, num_frames=864, batch_size=8, num_runs=10):
        """Measure synthesis speed relative to real time.

        Concatenates all validation mels along axis 2 (frames, given the
        ``(1, mel_bands, frames)`` layout used in ``build``), keeps the first
        ``num_frames`` frames, expands the batch axis ``batch_size`` times and
        synthesizes the result ``num_runs`` times. Each run prints the
        real-time factor.

        Args:
            num_frames (int): number of mel frames to synthesize.
            batch_size (int): expansion factor along the batch axis. The
                audio-time computation assumes validation batches of size 1,
                so this equals the synthesized batch size (TODO confirm).
            num_runs (int): number of timed synthesis runs.
        """
        self.waveflow.eval()

        mels_list = [mels for _, mels in self.validloader()]
        mel = fluid.layers.concat(mels_list, axis=2)
        mel = mel[:, :, :num_frames]
        mel = fluid.layers.expand(mel, [batch_size, 1, 1])

        for _ in range(num_runs):
            start_time = time.time()
            audio = self.waveflow.synthesize(mel, sigma=self.config.sigma)
            # Stop the clock before printing so console I/O is not timed.
            syn_time = time.time() - start_time
            print("audio.shape = ", audio.shape)

            audio_time = audio.shape[1] * batch_size / self.config.sample_rate
            print("audio time {:.4f}, synthesis time {:.4f}".format(
                audio_time, syn_time))
            print("{} X real-time".format(audio_time / syn_time))

    def save(self, iteration):
        """Save model and optimizer parameters and mark ``iteration`` as latest.

        Uses ``self.optimizer``, so ``build(training=True)`` must have run.

        Args:
            iteration (int): step number the checkpoint is saved under.
        """
        utils.save_latest_parameters(self.checkpoint_dir, iteration,
                                     self.waveflow, self.optimizer)
        utils.save_latest_checkpoint(self.checkpoint_dir, iteration)