# Parakeet/parakeet/models/waveflow/waveflow.py
#
# 297 lines
# 9.6 KiB
# Python

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import os
import time
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
from scipy.io.wavfile import write
import utils
from parakeet.modules import weight_norm
from .data import LJSpeech
from .waveflow_modules import WaveFlowLoss, WaveFlowModule
class WaveFlow:
    """Wrapper class of WaveFlow model that supports multiple APIs.

    This module provides APIs for model building, training, validation,
    inference, benchmarking, and saving.

    Args:
        config (obj): config info. The attributes read by this class are
            ``use_fp16``, ``mel_bands``, ``learning_rate``, ``iteration``,
            ``checkpoint``, ``sigma``, ``sample``, ``output``, ``name`` and
            ``sample_rate``; the whole object is also forwarded to
            ``LJSpeech`` and ``WaveFlowModule``.
        checkpoint_dir (str): path for checkpointing.
        parallel (bool, optional): whether use multiple GPUs for training.
            Defaults to False.
        rank (int, optional): the rank of the process in a multi-process
            scenario. Defaults to 0.
        nranks (int, optional): the total number of processes. Defaults to 1.
        tb_logger (obj, optional): logger to visualize metrics. Must expose
            ``add_scalar`` and ``add_histogram``. When None, metrics are only
            printed to stdout. Defaults to None.
    """

    def __init__(self,
                 config,
                 checkpoint_dir,
                 parallel=False,
                 rank=0,
                 nranks=1,
                 tb_logger=None):
        self.config = config
        self.checkpoint_dir = checkpoint_dir
        self.parallel = parallel
        self.rank = rank
        self.nranks = nranks
        self.tb_logger = tb_logger
        # Precision used for the dry-run inputs and for inference-time
        # checkpoint loading.
        self.dtype = "float16" if config.use_fp16 else "float32"

    def build(self, training=True):
        """Initialize the model.

        Creates the data loaders and the network, runs one forward pass on
        random inputs, and loads parameters through ``utils.load_parameters``.
        In training mode this sets ``self.waveflow``, ``self.optimizer`` and
        ``self.criterion``; in inference mode only ``self.waveflow`` is set,
        after removing weight normalization from every wrapped sublayer.

        Args:
            training (bool, optional): Whether the model is built for
                training or inference. Defaults to True.

        Returns:
            None
        """
        config = self.config

        dataset = LJSpeech(config, self.nranks, self.rank)
        self.trainloader = dataset.trainloader
        self.validloader = dataset.validloader

        waveflow = WaveFlowModule(config)

        # Dry run once to create and initialize all necessary parameters.
        audio = dg.to_variable(np.random.randn(1, 16000).astype(self.dtype))
        mel = dg.to_variable(
            np.random.randn(1, config.mel_bands, 63).astype(self.dtype))
        waveflow(audio, mel)

        if training:
            optimizer = fluid.optimizer.AdamOptimizer(
                learning_rate=config.learning_rate,
                parameter_list=waveflow.parameters())

            # Load parameters (model and optimizer state).
            utils.load_parameters(
                self.checkpoint_dir,
                self.rank,
                waveflow,
                optimizer,
                iteration=config.iteration,
                file_path=config.checkpoint)
            print("Rank {}: checkpoint loaded.".format(self.rank))

            # Data parallelism: wrap only after the parameters are loaded.
            if self.parallel:
                strategy = dg.parallel.prepare_context()
                waveflow = dg.parallel.DataParallel(waveflow, strategy)

            self.waveflow = waveflow
            self.optimizer = optimizer
            self.criterion = WaveFlowLoss(config.sigma)
        else:
            # Load parameters (model only, in the configured precision).
            utils.load_parameters(
                self.checkpoint_dir,
                self.rank,
                waveflow,
                iteration=config.iteration,
                file_path=config.checkpoint,
                dtype=self.dtype)
            print("Rank {}: checkpoint loaded.".format(self.rank))

            # Weight normalization is stripped from the inference model.
            for layer in waveflow.sublayers():
                if isinstance(layer, weight_norm.WeightNormWrapper):
                    layer.remove_weight_norm()

            self.waveflow = waveflow

    def train_step(self, iteration):
        """Train the model for one step.

        Pulls one batch from ``self.trainloader``, runs forward/backward,
        applies the optimizer and clears gradients. Rank 0 prints the loss
        and timing and, if a ``tb_logger`` was supplied, logs the loss to it.

        Args:
            iteration (int): current iteration number.

        Returns:
            None
        """
        self.waveflow.train()

        start_time = time.time()
        audios, mels = next(self.trainloader)
        load_time = time.time()

        outputs = self.waveflow(audios, mels)
        loss = self.criterion(outputs)

        if self.parallel:
            # loss = loss / num_trainers
            loss = self.waveflow.scale_loss(loss)
            loss.backward()
            self.waveflow.apply_collective_grads()
        else:
            loss.backward()

        self.optimizer.minimize(
            loss, parameter_list=self.waveflow.parameters())
        self.waveflow.clear_gradients()

        graph_time = time.time()

        if self.rank != 0:
            return

        # Multiply by nranks to undo the 1/num_trainers scaling noted above
        # before reporting.
        loss_val = float(loss.numpy()) * self.nranks
        log = "Rank: {} Step: {:^8d} Loss: {:<8.3f} " \
              "Time: {:.3f}/{:.3f}".format(
                  self.rank, iteration, loss_val,
                  load_time - start_time, graph_time - load_time)
        print(log)

        # tb_logger is optional (defaults to None); skip instead of crashing.
        if self.tb_logger is not None:
            self.tb_logger.add_scalar("Train-Loss-Rank-0", loss_val,
                                      iteration)

    @dg.no_grad
    def valid_step(self, iteration):
        """Run the model on the validation dataset.

        Iterates over ``self.validloader()``, accumulates the per-batch loss,
        and on rank 0 prints the average loss and elapsed time. When a
        ``tb_logger`` was supplied, rank 0 also logs histograms of the first
        batch's latent ``z`` and per-flow ``log_s`` plus the average loss.

        Args:
            iteration (int): current iteration number.

        Returns:
            None
        """
        self.waveflow.eval()
        tb = self.tb_logger
        is_chief = self.rank == 0
        # tb_logger is optional (defaults to None); only visualize if given.
        can_visualize = is_chief and tb is not None

        total_loss = []
        start_time = time.time()

        for i, batch in enumerate(self.validloader()):
            audios, mels = batch
            valid_outputs = self.waveflow(audios, mels)
            valid_z, valid_log_s_list = valid_outputs

            # Visualize latent z and scale log_s (first batch only).
            if can_visualize and i == 0:
                tb.add_histogram("Valid-Latent_z", valid_z.numpy(), iteration)
                for j, valid_log_s in enumerate(valid_log_s_list):
                    hist_name = "Valid-{}th-Flow-Log_s".format(j)
                    tb.add_histogram(hist_name, valid_log_s.numpy(),
                                     iteration)

            valid_loss = self.criterion(valid_outputs)
            total_loss.append(float(valid_loss.numpy()))

        total_time = time.time() - start_time

        if is_chief:
            loss_val = np.mean(total_loss)
            log = "Test | Rank: {} AvgLoss: {:<8.3f} Time {:<8.3f}".format(
                self.rank, loss_val, total_time)
            print(log)
            if tb is not None:
                tb.add_scalar("Valid-Avg-Loss", loss_val, iteration)

    @dg.no_grad
    def infer(self, iteration):
        """Run the model to synthesize audios.

        Synthesizes audio from the mel inputs yielded by
        ``self.validloader()`` (all of them, or only index ``config.sample``
        when it is not None) and writes each result as a 16-bit wav file
        under ``{config.output}/{config.name}/iter-{iteration}``.

        Args:
            iteration (int): iteration number of the loaded checkpoint.

        Returns:
            None
        """
        self.waveflow.eval()

        config = self.config
        sample = config.sample

        output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
        if not os.path.exists(output):
            os.makedirs(output)

        mels_list = [mels for _, mels in self.validloader()]
        if sample is not None:
            mels_list = [mels_list[sample]]
        else:
            # No specific sample requested: synthesize all, numbering from 0.
            sample = 0

        for idx, mel in enumerate(mels_list):
            abs_idx = sample + idx
            filename = "{}/valid_{}.wav".format(output, abs_idx)
            print("Synthesize sample {}, save as {}".format(abs_idx, filename))

            start_time = time.time()
            audio = self.waveflow.synthesize(mel, sigma=self.config.sigma)
            syn_time = time.time() - start_time

            # Keep only the first item of the returned batch.
            audio = audio[0]
            audio_time = audio.shape[0] / self.config.sample_rate
            print("audio time {:.4f}, synthesis time {:.4f}".format(audio_time,
                                                                    syn_time))

            # Denormalize audio from [-1, 1] to the int16 range
            # [-32768, 32767]. Clip first: a sample at +1.0 (or any
            # out-of-range value) would otherwise overflow in the int16 cast
            # and wrap around to a large value of the opposite sign.
            audio = audio.numpy().astype("float32") * 32768.0
            audio = np.clip(audio, -32768.0, 32767.0)
            audio = audio.astype('int16')
            write(filename, config.sample_rate, audio)

    @dg.no_grad
    def benchmark(self):
        """Run the model to benchmark synthesis speed.

        Concatenates all validation mels along axis 2, keeps the first 864
        positions on that axis, expands the first dimension by a factor of 8
        and times 10 synthesis runs, printing the audio duration, synthesis
        time and the real-time factor of each run.

        Args:
            None

        Returns:
            None
        """
        self.waveflow.eval()

        mels_list = [mels for _, mels in self.validloader()]
        mel = fluid.layers.concat(mels_list, axis=2)
        mel = mel[:, :, :864]
        batch_size = 8
        mel = fluid.layers.expand(mel, [batch_size, 1, 1])

        for _ in range(10):
            start_time = time.time()
            audio = self.waveflow.synthesize(mel, sigma=self.config.sigma)
            # Stop the clock before printing so console I/O is not counted
            # as synthesis time.
            syn_time = time.time() - start_time
            print("audio.shape = ", audio.shape)

            audio_time = audio.shape[1] * batch_size / self.config.sample_rate
            print("audio time {:.4f}, synthesis time {:.4f}".format(audio_time,
                                                                    syn_time))
            print("{} X real-time".format(audio_time / syn_time))

    def save(self, iteration):
        """Save model checkpoint.

        Writes the model and optimizer state through
        ``utils.save_latest_parameters`` and then records the iteration via
        ``utils.save_latest_checkpoint``. Requires ``self.optimizer``, which
        is only set by ``build(training=True)``.

        Args:
            iteration (int): iteration number of the model to be saved.

        Returns:
            None
        """
        utils.save_latest_parameters(self.checkpoint_dir, iteration,
                                     self.waveflow, self.optimizer)
        utils.save_latest_checkpoint(self.checkpoint_dir, iteration)