add docstring

This commit is contained in:
Kexin Zhao 2020-03-07 23:25:04 -08:00
parent 8083da21ac
commit 4f7ded3c89
4 changed files with 156 additions and 0 deletions

View File

@ -109,6 +109,16 @@ def add_yaml_config(config):
def load_latest_checkpoint(checkpoint_dir, rank=0):
"""Get the iteration number corresponding to the latest saved checkpoint
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
rank (int, optional): the rank of the process in multi-process setting.
Defaults to 0.
Returns:
int: the latest iteration number.
"""
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
# Create checkpoint index file if not exist.
if (not os.path.isfile(checkpoint_path)) and rank == 0:
@ -129,6 +139,15 @@ def load_latest_checkpoint(checkpoint_dir, rank=0):
def save_latest_checkpoint(checkpoint_dir, iteration):
"""Save the iteration number of the latest model to be checkpointed.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
Returns:
None
"""
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
# Update the latest checkpoint index.
with open(checkpoint_path, "w") as handle:
@ -142,6 +161,24 @@ def load_parameters(checkpoint_dir,
iteration=None,
file_path=None,
dtype="float32"):
"""Load a specific model checkpoint from disk.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
rank (int): the rank of the process in multi-process setting.
model (obj): model to load parameters.
optimizer (obj, optional): optimizer to load states if needed.
Defaults to None.
iteration (int, optional): if specified, load the specific checkpoint,
if not specified, load the latest one. Defaults to None.
file_path (str, optional): if specified, load the checkpoint
stored in the file_path. Defaults to None.
dtype (str, optional): precision of the model parameters.
Defaults to float32.
Returns:
None
"""
if file_path is None:
if iteration is None:
iteration = load_latest_checkpoint(checkpoint_dir, rank)
@ -165,6 +202,18 @@ def load_parameters(checkpoint_dir,
def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
"""Checkpoint the latest trained model parameters.
Args:
checkpoint_dir (str): the directory where checkpoint is saved.
iteration (int): the latest iteration number.
model (obj): model to be checkpointed.
optimizer (obj, optional): optimizer to be checkpointed.
Defaults to None.
Returns:
None
"""
file_path = "{}/step-{}".format(checkpoint_dir, iteration)
model_dict = model.state_dict()
dg.save_dygraph(model_dict, file_path)

View File

@ -80,6 +80,7 @@ class Subset(DatasetMixin):
# whole audio for valid set
pass
else:
# Randomly crop segment_length from audios in the training set.
# audio shape: [len]
if audio.shape[0] >= segment_length:
max_audio_start = audio.shape[0] - segment_length

View File

@ -28,6 +28,25 @@ from .waveflow_modules import WaveFlowLoss, WaveFlowModule
class WaveFlow():
"""Wrapper class of WaveFlow model that supports multiple APIs.
This module provides APIs for model building, training, validation,
inference, benchmarking, and saving.
Args:
config (obj): config info.
checkpoint_dir (str): path for checkpointing.
parallel (bool, optional): whether use multiple GPUs for training.
Defaults to False.
rank (int, optional): the rank of the process in a multi-process
scenario. Defaults to 0.
nranks (int, optional): the total number of processes. Defaults to 1.
tb_logger (obj, optional): logger to visualize metrics.
Defaults to None.
Returns:
WaveFlow
"""
def __init__(self,
config,
checkpoint_dir,
@ -44,6 +63,15 @@ class WaveFlow():
self.dtype = "float16" if config.use_fp16 else "float32"
def build(self, training=True):
"""Initialize the model.
Args:
training (bool, optional): Whether the model is built for training or inference.
Defaults to True.
Returns:
None
"""
config = self.config
dataset = LJSpeech(config, self.nranks, self.rank)
self.trainloader = dataset.trainloader
@ -99,6 +127,14 @@ class WaveFlow():
self.waveflow = waveflow
def train_step(self, iteration):
"""Train the model for one step.
Args:
iteration (int): current iteration number.
Returns:
None
"""
self.waveflow.train()
start_time = time.time()
@ -135,6 +171,14 @@ class WaveFlow():
@dg.no_grad
def valid_step(self, iteration):
"""Run the model on the validation dataset.
Args:
iteration (int): current iteration number.
Returns:
None
"""
self.waveflow.eval()
tb = self.tb_logger
@ -167,6 +211,14 @@ class WaveFlow():
@dg.no_grad
def infer(self, iteration):
"""Run the model to synthesize audios.
Args:
iteration (int): iteration number of the loaded checkpoint.
Returns:
None
"""
self.waveflow.eval()
config = self.config
@ -203,6 +255,14 @@ class WaveFlow():
@dg.no_grad
def benchmark(self):
"""Run the model to benchmark synthesis speed.
Args:
None
Returns:
None
"""
self.waveflow.eval()
mels_list = [mels for _, mels in self.validloader()]
@ -223,6 +283,14 @@ class WaveFlow():
print("{} X real-time".format(audio_time / syn_time))
def save(self, iteration):
"""Save model checkpoint.
Args:
iteration (int): iteration number of the model to be saved.
Returns:
None
"""
utils.save_latest_parameters(self.checkpoint_dir, iteration,
self.waveflow, self.optimizer)
utils.save_latest_checkpoint(self.checkpoint_dir, iteration)

View File

@ -293,6 +293,14 @@ class Flow(dg.Layer):
class WaveFlowModule(dg.Layer):
"""WaveFlow model implementation.
Args:
config (obj): model configuration parameters.
Returns:
WaveFlowModule
"""
def __init__(self, config):
super(WaveFlowModule, self).__init__()
self.n_flows = config.n_flows
@ -321,6 +329,22 @@ class WaveFlowModule(dg.Layer):
self.perms.append(perm)
def forward(self, audio, mel):
"""Training forward pass.
Use a conditioner to upsample mel spectrograms into hidden states.
These hidden states along with the audio are passed to a stack of Flow
modules to obtain the final latent variable z and a list of log scaling
variables, which are then passed to the WaveFlowLoss module to calculate
the negative log likelihood.
Args:
audio (obj): audio samples.
mel (obj): mel spectrograms.
Returns:
z (obj): latent variable.
log_s_list(list): list of log scaling variables.
"""
mel = self.conditioner(mel)
assert mel.shape[2] >= audio.shape[1]
# Prune out the tail of audio/mel so that time/n_group == 0.
@ -361,6 +385,20 @@ class WaveFlowModule(dg.Layer):
return z, log_s_list
def synthesize(self, mel, sigma=1.0):
"""Use model to synthesize waveform.
Use a conditioner to upsample mel spectrograms into hidden states.
These hidden states along with initial random gaussian latent variable
are passed to a stack of Flow modules to obtain the audio output.
Args:
mel (obj): mel spectrograms.
sigma (float, optional): standard deviation of the guassian latent
variable. Defaults to 1.0.
Returns:
audio (obj): synthesized audio.
"""
if self.dtype == "float16":
mel = fluid.layers.cast(mel, self.dtype)
mel = self.conditioner.infer(mel)