add docstring

commit 484465ca1b (parent e82115bc8a)

@@ -109,6 +109,16 @@ def add_yaml_config(config):
 
 
 def load_latest_checkpoint(checkpoint_dir, rank=0):
+    """Get the iteration number corresponding to the latest saved checkpoint.
+
+    Args:
+        checkpoint_dir (str): the directory where checkpoint is saved.
+        rank (int, optional): the rank of the process in multi-process setting.
+            Defaults to 0.
+
+    Returns:
+        int: the latest iteration number.
+    """
     checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
     # Create checkpoint index file if not exist.
     if (not os.path.isfile(checkpoint_path)) and rank == 0:
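The index file's on-disk layout is not shown in this hunk; as a minimal sketch of the parsing the docstring implies, assuming the file holds a single line such as "model_checkpoint_path: step-10000" (the helper name and that line format are illustrative, not taken from the repo):

    import os

    def read_latest_iteration_sketch(checkpoint_dir):
        # Assumed index layout: one line like "model_checkpoint_path: step-10000".
        index_path = os.path.join(checkpoint_dir, "checkpoint")
        if not os.path.isfile(index_path):
            return 0
        with open(index_path, "r") as handle:
            last_token = handle.readline().split()[-1]  # e.g. "step-10000"
        return int(last_token.split("-")[-1])           # -> 10000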
@@ -129,6 +139,15 @@ def load_latest_checkpoint(checkpoint_dir, rank=0):
 
 
 def save_latest_checkpoint(checkpoint_dir, iteration):
+    """Save the iteration number of the latest model to be checkpointed.
+
+    Args:
+        checkpoint_dir (str): the directory where checkpoint is saved.
+        iteration (int): the latest iteration number.
+
+    Returns:
+        None
+    """
     checkpoint_path = os.path.join(checkpoint_dir, "checkpoint")
     # Update the latest checkpoint index.
     with open(checkpoint_path, "w") as handle:
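Together with load_latest_checkpoint above, this gives a simple round trip. A usage sketch, not part of the commit (the directory name is a placeholder):

    import os

    ckpt_dir = "runs/checkpoints"          # illustrative path
    os.makedirs(ckpt_dir, exist_ok=True)
    save_latest_checkpoint(ckpt_dir, iteration=10000)
    assert load_latest_checkpoint(ckpt_dir) == 10000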
@@ -142,6 +161,24 @@ def load_parameters(checkpoint_dir,
                     iteration=None,
                     file_path=None,
                     dtype="float32"):
+    """Load a specific model checkpoint from disk.
+
+    Args:
+        checkpoint_dir (str): the directory where checkpoint is saved.
+        rank (int): the rank of the process in multi-process setting.
+        model (obj): model to load parameters.
+        optimizer (obj, optional): optimizer to load states if needed.
+            Defaults to None.
+        iteration (int, optional): if specified, load the specific checkpoint;
+            if not specified, load the latest one. Defaults to None.
+        file_path (str, optional): if specified, load the checkpoint
+            stored in the file_path. Defaults to None.
+        dtype (str, optional): precision of the model parameters.
+            Defaults to float32.
+
+    Returns:
+        None
+    """
     if file_path is None:
         if iteration is None:
             iteration = load_latest_checkpoint(checkpoint_dir, rank)
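The docstring implies a resolution order: an explicit file_path wins, then an explicit iteration, then the latest checkpoint recorded in the index. A usage sketch (the "step-5000" name follows the "step-{iteration}" pattern visible in save_latest_parameters below):

    # Restore the most recent checkpoint recorded in the index file:
    load_parameters(ckpt_dir, rank=0, model=model, optimizer=optimizer)

    # Restore a specific training step:
    load_parameters(ckpt_dir, rank=0, model=model, iteration=5000)

    # Restore from an explicit file, bypassing the index entirely:
    load_parameters(ckpt_dir, rank=0, model=model,
                    file_path="runs/checkpoints/step-5000")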
@@ -165,6 +202,18 @@ def load_parameters(checkpoint_dir,
 
 
 def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None):
+    """Checkpoint the latest trained model parameters.
+
+    Args:
+        checkpoint_dir (str): the directory where checkpoint is saved.
+        iteration (int): the latest iteration number.
+        model (obj): model to be checkpointed.
+        optimizer (obj, optional): optimizer to be checkpointed.
+            Defaults to None.
+
+    Returns:
+        None
+    """
     file_path = "{}/step-{}".format(checkpoint_dir, iteration)
     model_dict = model.state_dict()
     dg.save_dygraph(model_dict, file_path)
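The two save helpers pair up in a training loop, weights first and index second (the same order WaveFlow.save uses later in this diff), so a crash between the calls never points the index at a missing file. A sketch with an illustrative interval:

    if iteration % 10000 == 0:    # illustrative checkpoint interval
        save_latest_parameters(ckpt_dir, iteration, model, optimizer)
        save_latest_checkpoint(ckpt_dir, iteration)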
@@ -80,6 +80,7 @@ class Subset(DatasetMixin):
             # whole audio for valid set
             pass
         else:
+            # Randomly crop segment_length from audios in the training set.
             # audio shape: [len]
             if audio.shape[0] >= segment_length:
                 max_audio_start = audio.shape[0] - segment_length
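The new comment summarizes the cropping policy; for clarity, a self-contained NumPy sketch of the same idea (not the repo's exact code):

    import numpy as np

    def random_crop_sketch(audio, segment_length):
        # Training-set policy from the comment above: take a random window
        # of segment_length samples; shorter clips are handled separately.
        if audio.shape[0] <= segment_length:
            return audio
        max_audio_start = audio.shape[0] - segment_length
        start = np.random.randint(0, max_audio_start + 1)
        return audio[start:start + segment_length]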
@@ -28,6 +28,25 @@ from .waveflow_modules import WaveFlowLoss, WaveFlowModule
 
 
 class WaveFlow():
+    """Wrapper class of WaveFlow model that supports multiple APIs.
+
+    This module provides APIs for model building, training, validation,
+    inference, benchmarking, and saving.
+
+    Args:
+        config (obj): config info.
+        checkpoint_dir (str): path for checkpointing.
+        parallel (bool, optional): whether to use multiple GPUs for training.
+            Defaults to False.
+        rank (int, optional): the rank of the process in a multi-process
+            scenario. Defaults to 0.
+        nranks (int, optional): the total number of processes. Defaults to 1.
+        tb_logger (obj, optional): logger to visualize metrics.
+            Defaults to None.
+
+    Returns:
+        WaveFlow
+    """
     def __init__(self,
                  config,
                  checkpoint_dir,
@@ -44,6 +63,15 @@ class WaveFlow():
         self.dtype = "float16" if config.use_fp16 else "float32"
 
     def build(self, training=True):
+        """Initialize the model.
+
+        Args:
+            training (bool, optional): whether the model is built for training
+                or inference. Defaults to True.
+
+        Returns:
+            None
+        """
         config = self.config
         dataset = LJSpeech(config, self.nranks, self.rank)
         self.trainloader = dataset.trainloader
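Driver-side construction plus build, as a hedged sketch; the config object and paths come from the training script, which is not part of this diff. The keyword names and defaults follow the class docstring above:

    model = WaveFlow(config, checkpoint_dir="runs/waveflow",
                     parallel=False, rank=0, nranks=1, tb_logger=None)
    model.build(training=True)   # wires up the LJSpeech loaders and the module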
@@ -99,6 +127,14 @@ class WaveFlow():
         self.waveflow = waveflow
 
     def train_step(self, iteration):
+        """Train the model for one step.
+
+        Args:
+            iteration (int): current iteration number.
+
+        Returns:
+            None
+        """
         self.waveflow.train()
 
         start_time = time.time()
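A sketch of the outer loop these step methods are designed for, using valid_step and save documented below; the interval constants are assumptions, not values from the repo:

    for iteration in range(start_iteration, max_iterations):
        model.train_step(iteration)
        if iteration % 1000 == 0:       # assumed validation interval
            model.valid_step(iteration)
        if iteration % 10000 == 0:      # assumed checkpoint interval
            model.save(iteration)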
@@ -135,6 +171,14 @@ class WaveFlow():
 
     @dg.no_grad
     def valid_step(self, iteration):
+        """Run the model on the validation dataset.
+
+        Args:
+            iteration (int): current iteration number.
+
+        Returns:
+            None
+        """
         self.waveflow.eval()
         tb = self.tb_logger
 
@@ -167,6 +211,14 @@ class WaveFlow():
 
     @dg.no_grad
     def infer(self, iteration):
+        """Run the model to synthesize audios.
+
+        Args:
+            iteration (int): iteration number of the loaded checkpoint.
+
+        Returns:
+            None
+        """
         self.waveflow.eval()
 
         config = self.config
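Inference-side usage, as a hypothetical sketch; per the docstring, iteration names the checkpoint whose weights were restored, and how the restore happens is defined by surrounding code not shown in this diff:

    model = WaveFlow(config, "runs/waveflow")   # placeholder path
    model.build(training=False)
    model.infer(iteration=100000)               # illustrative checkpoint step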
@@ -203,6 +255,14 @@ class WaveFlow():
 
     @dg.no_grad
     def benchmark(self):
+        """Run the model to benchmark synthesis speed.
+
+        Args:
+            None
+
+        Returns:
+            None
+        """
         self.waveflow.eval()
 
         mels_list = [mels for _, mels in self.validloader()]
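The figure printed at the end of benchmark (visible in the next hunk) is a real-time factor. Spelled out, with assumed variable names:

    # Real-time factor: seconds of audio produced per second of wall time;
    # e.g. 10.0 s of audio synthesized in 2.0 s prints "5.0 X real-time".
    audio_time = total_samples / sample_rate   # assumed names
    rtf = audio_time / syn_time                # > 1 means faster than real time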
@@ -223,6 +283,14 @@ class WaveFlow():
         print("{} X real-time".format(audio_time / syn_time))
 
     def save(self, iteration):
+        """Save model checkpoint.
+
+        Args:
+            iteration (int): iteration number of the model to be saved.
+
+        Returns:
+            None
+        """
         utils.save_latest_parameters(self.checkpoint_dir, iteration,
                                      self.waveflow, self.optimizer)
         utils.save_latest_checkpoint(self.checkpoint_dir, iteration)
@@ -293,6 +293,14 @@ class Flow(dg.Layer):
 
 
 class WaveFlowModule(dg.Layer):
+    """WaveFlow model implementation.
+
+    Args:
+        config (obj): model configuration parameters.
+
+    Returns:
+        WaveFlowModule
+    """
     def __init__(self, config):
         super(WaveFlowModule, self).__init__()
         self.n_flows = config.n_flows
@@ -321,6 +329,22 @@ class WaveFlowModule(dg.Layer):
         self.perms.append(perm)
 
     def forward(self, audio, mel):
+        """Training forward pass.
+
+        Use a conditioner to upsample mel spectrograms into hidden states.
+        These hidden states along with the audio are passed to a stack of Flow
+        modules to obtain the final latent variable z and a list of log scaling
+        variables, which are then passed to the WaveFlowLoss module to calculate
+        the negative log likelihood.
+
+        Args:
+            audio (obj): audio samples.
+            mel (obj): mel spectrograms.
+
+        Returns:
+            z (obj): latent variable.
+            log_s_list (list): list of log scaling variables.
+        """
         mel = self.conditioner(mel)
         assert mel.shape[2] >= audio.shape[1]
         # Prune out the tail of audio/mel so that time/n_group == 0.
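The docstring says z and log_s_list feed WaveFlowLoss to produce a negative log-likelihood. A framework-free sketch of that standard normalizing-flow objective, assuming a unit-Gaussian prior; the repo's exact normalization constants may differ:

    import numpy as np

    def flow_nll_sketch(z, log_s_list):
        # Change-of-variables NLL: quadratic prior term minus the
        # accumulated log-determinants, averaged per element.
        prior_term = 0.5 * np.sum(z ** 2)
        logdet_term = sum(np.sum(log_s) for log_s in log_s_list)
        return (prior_term - logdet_term) / z.size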
@@ -361,6 +385,20 @@ class WaveFlowModule(dg.Layer):
         return z, log_s_list
 
     def synthesize(self, mel, sigma=1.0):
+        """Use model to synthesize waveform.
+
+        Use a conditioner to upsample mel spectrograms into hidden states.
+        These hidden states along with an initial random Gaussian latent
+        variable are passed to a stack of Flow modules to obtain the audio output.
+
+        Args:
+            mel (obj): mel spectrograms.
+            sigma (float, optional): standard deviation of the Gaussian latent
+                variable. Defaults to 1.0.
+
+        Returns:
+            audio (obj): synthesized audio.
+        """
         if self.dtype == "float16":
             mel = fluid.layers.cast(mel, self.dtype)
         mel = self.conditioner.infer(mel)
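Usage sketch for synthesize: lowering sigma narrows the Gaussian prior that seeds generation, which typically trades sample diversity for cleaner audio.

    # mel comes from the feature pipeline upstream (not shown in this diff);
    # 0.8 is an illustrative value, not a repo default.
    audio = module.synthesize(mel, sigma=0.8)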