add docstring
This commit is contained in:
parent
8083da21ac
commit
4f7ded3c89
|
@ -109,6 +109,16 @@ def add_yaml_config(config):
|
|||
|
||||
|
||||
def load_latest_checkpoint(checkpoint_dir, rank=0):
|
||||
"""Get the iteration number corresponding to the latest saved checkpoint
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): the directory where checkpoint is saved.
|
||||
rank (int, optional): the rank of the process in multi-process setting.
|
||||
Defaults to 0.
|
||||
|
||||
Returns:
|
||||
int: the latest iteration number.
|
||||
"""
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
# Create checkpoint index file if not exist.
|
||||
if (not os.path.isfile(checkpoint_path)) and rank == 0:
|
||||
|
@ -129,6 +139,15 @@ def load_latest_checkpoint(checkpoint_dir, rank=0):
|
|||
|
||||
|
||||
def save_latest_checkpoint(checkpoint_dir, iteration):
|
||||
"""Save the iteration number of the latest model to be checkpointed.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): the directory where checkpoint is saved.
|
||||
iteration (int): the latest iteration number.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
|
||||
# Update the latest checkpoint index.
|
||||
with open(checkpoint_path, "w") as handle:
|
||||
|
@ -142,6 +161,24 @@ def load_parameters(checkpoint_dir,
|
|||
iteration=None,
|
||||
file_path=None,
|
||||
dtype="float32"):
|
||||
"""Load a specific model checkpoint from disk.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): the directory where checkpoint is saved.
|
||||
rank (int): the rank of the process in multi-process setting.
|
||||
model (obj): model to load parameters.
|
||||
optimizer (obj, optional): optimizer to load states if needed.
|
||||
Defaults to None.
|
||||
iteration (int, optional): if specified, load the specific checkpoint,
|
||||
if not specified, load the latest one. Defaults to None.
|
||||
file_path (str, optional): if specified, load the checkpoint
|
||||
stored in the file_path. Defaults to None.
|
||||
dtype (str, optional): precision of the model parameters.
|
||||
Defaults to float32.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
if file_path is None:
|
||||
if iteration is None:
|
||||
iteration = load_latest_checkpoint(checkpoint_dir, rank)
|
||||
|
@ -165,6 +202,18 @@ def load_parameters(checkpoint_dir,
|
|||
|
||||
|
||||
def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
|
||||
"""Checkpoint the latest trained model parameters.
|
||||
|
||||
Args:
|
||||
checkpoint_dir (str): the directory where checkpoint is saved.
|
||||
iteration (int): the latest iteration number.
|
||||
model (obj): model to be checkpointed.
|
||||
optimizer (obj, optional): optimizer to be checkpointed.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
file_path = "{}/step-{}".format(checkpoint_dir, iteration)
|
||||
model_dict = model.state_dict()
|
||||
dg.save_dygraph(model_dict, file_path)
|
||||
|
|
|
@ -80,6 +80,7 @@ class Subset(DatasetMixin):
|
|||
# whole audio for valid set
|
||||
pass
|
||||
else:
|
||||
# Randomly crop segment_length from audios in the training set.
|
||||
# audio shape: [len]
|
||||
if audio.shape[0] >= segment_length:
|
||||
max_audio_start = audio.shape[0] - segment_length
|
||||
|
|
|
@ -28,6 +28,25 @@ from .waveflow_modules import WaveFlowLoss, WaveFlowModule
|
|||
|
||||
|
||||
class WaveFlow():
|
||||
"""Wrapper class of WaveFlow model that supports multiple APIs.
|
||||
|
||||
This module provides APIs for model building, training, validation,
|
||||
inference, benchmarking, and saving.
|
||||
|
||||
Args:
|
||||
config (obj): config info.
|
||||
checkpoint_dir (str): path for checkpointing.
|
||||
parallel (bool, optional): whether use multiple GPUs for training.
|
||||
Defaults to False.
|
||||
rank (int, optional): the rank of the process in a multi-process
|
||||
scenario. Defaults to 0.
|
||||
nranks (int, optional): the total number of processes. Defaults to 1.
|
||||
tb_logger (obj, optional): logger to visualize metrics.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
WaveFlow
|
||||
"""
|
||||
def __init__(self,
|
||||
config,
|
||||
checkpoint_dir,
|
||||
|
@ -44,6 +63,15 @@ class WaveFlow():
|
|||
self.dtype = "float16" if config.use_fp16 else "float32"
|
||||
|
||||
def build(self, training=True):
|
||||
"""Initialize the model.
|
||||
|
||||
Args:
|
||||
training (bool, optional): Whether the model is built for training or inference.
|
||||
Defaults to True.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
config = self.config
|
||||
dataset = LJSpeech(config, self.nranks, self.rank)
|
||||
self.trainloader = dataset.trainloader
|
||||
|
@ -99,6 +127,14 @@ class WaveFlow():
|
|||
self.waveflow = waveflow
|
||||
|
||||
def train_step(self, iteration):
|
||||
"""Train the model for one step.
|
||||
|
||||
Args:
|
||||
iteration (int): current iteration number.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.waveflow.train()
|
||||
|
||||
start_time = time.time()
|
||||
|
@ -135,6 +171,14 @@ class WaveFlow():
|
|||
|
||||
@dg.no_grad
|
||||
def valid_step(self, iteration):
|
||||
"""Run the model on the validation dataset.
|
||||
|
||||
Args:
|
||||
iteration (int): current iteration number.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.waveflow.eval()
|
||||
tb = self.tb_logger
|
||||
|
||||
|
@ -167,6 +211,14 @@ class WaveFlow():
|
|||
|
||||
@dg.no_grad
|
||||
def infer(self, iteration):
|
||||
"""Run the model to synthesize audios.
|
||||
|
||||
Args:
|
||||
iteration (int): iteration number of the loaded checkpoint.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.waveflow.eval()
|
||||
|
||||
config = self.config
|
||||
|
@ -203,6 +255,14 @@ class WaveFlow():
|
|||
|
||||
@dg.no_grad
|
||||
def benchmark(self):
|
||||
"""Run the model to benchmark synthesis speed.
|
||||
|
||||
Args:
|
||||
None
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
self.waveflow.eval()
|
||||
|
||||
mels_list = [mels for _, mels in self.validloader()]
|
||||
|
@ -223,6 +283,14 @@ class WaveFlow():
|
|||
print("{} X real-time".format(audio_time / syn_time))
|
||||
|
||||
def save(self, iteration):
|
||||
"""Save model checkpoint.
|
||||
|
||||
Args:
|
||||
iteration (int): iteration number of the model to be saved.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
utils.save_latest_parameters(self.checkpoint_dir, iteration,
|
||||
self.waveflow, self.optimizer)
|
||||
utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
|
||||
|
|
|
@ -293,6 +293,14 @@ class Flow(dg.Layer):
|
|||
|
||||
|
||||
class WaveFlowModule(dg.Layer):
|
||||
"""WaveFlow model implementation.
|
||||
|
||||
Args:
|
||||
config (obj): model configuration parameters.
|
||||
|
||||
Returns:
|
||||
WaveFlowModule
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(WaveFlowModule, self).__init__()
|
||||
self.n_flows = config.n_flows
|
||||
|
@ -321,6 +329,22 @@ class WaveFlowModule(dg.Layer):
|
|||
self.perms.append(perm)
|
||||
|
||||
def forward(self, audio, mel):
|
||||
"""Training forward pass.
|
||||
|
||||
Use a conditioner to upsample mel spectrograms into hidden states.
|
||||
These hidden states along with the audio are passed to a stack of Flow
|
||||
modules to obtain the final latent variable z and a list of log scaling
|
||||
variables, which are then passed to the WaveFlowLoss module to calculate
|
||||
the negative log likelihood.
|
||||
|
||||
Args:
|
||||
audio (obj): audio samples.
|
||||
mel (obj): mel spectrograms.
|
||||
|
||||
Returns:
|
||||
z (obj): latent variable.
|
||||
log_s_list(list): list of log scaling variables.
|
||||
"""
|
||||
mel = self.conditioner(mel)
|
||||
assert mel.shape[2] >= audio.shape[1]
|
||||
# Prune out the tail of audio/mel so that time/n_group == 0.
|
||||
|
@ -361,6 +385,20 @@ class WaveFlowModule(dg.Layer):
|
|||
return z, log_s_list
|
||||
|
||||
def synthesize(self, mel, sigma=1.0):
|
||||
"""Use model to synthesize waveform.
|
||||
|
||||
Use a conditioner to upsample mel spectrograms into hidden states.
|
||||
These hidden states along with initial random gaussian latent variable
|
||||
are passed to a stack of Flow modules to obtain the audio output.
|
||||
|
||||
Args:
|
||||
mel (obj): mel spectrograms.
|
||||
sigma (float, optional): standard deviation of the guassian latent
|
||||
variable. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
audio (obj): synthesized audio.
|
||||
"""
|
||||
if self.dtype == "float16":
|
||||
mel = fluid.layers.cast(mel, self.dtype)
|
||||
mel = self.conditioner.infer(mel)
|
||||
|
|
Loading…
Reference in New Issue