From 8083da21acbf877f2800ff51781fb4ee5eee75d4 Mon Sep 17 00:00:00 2001 From: liuyibing01 Date: Sat, 7 Mar 2020 14:21:35 +0000 Subject: [PATCH 1/2] Fix sample file name --- examples/waveflow/README.md | 4 ++-- parakeet/models/waveflow/waveflow.py | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/waveflow/README.md b/examples/waveflow/README.md index e21039a..d36f0f3 100644 --- a/examples/waveflow/README.md +++ b/examples/waveflow/README.md @@ -4,7 +4,7 @@ PaddlePaddle dynamic graph implementation of [WaveFlow: A Compact Flow-based Mod - WaveFlow can synthesize 22.05 kHz high-fidelity speech around 40x faster than real-time on a Nvidia V100 GPU without engineered inference kernels, which is faster than [WaveGlow] (https://github.com/NVIDIA/waveglow) and serveral orders of magnitude faster than WaveNet. - WaveFlow is a small-footprint flow-based model for raw audio. It has only 5.9M parameters, which is 15x smalller than WaveGlow (87.9M) and comparable to WaveNet (4.6M). -- WaveFlow is directly trained with maximum likelihood without probability density distillation and auxiliary losses as used in Parallel WaveNet and ClariNet, which simplifies the training pipeline and reduces the cost of development. +- WaveFlow is directly trained with maximum likelihood without probability density distillation and auxiliary losses as used in Parallel WaveNet and ClariNet, which simplifies the training pipeline and reduces the cost of development. ## Project Structure ```text @@ -99,7 +99,7 @@ python -u synthesis.py \ --sigma=1.0 ``` -In this example, `--output` specifies where to save the synthesized audios and `--sample` specifies which sample in the valid dataset (a split from the whole LJSpeech dataset, by default contains the first 16 audio samples) to synthesize based on the mel-spectrograms computed from the ground truth sample audio, e.g., `--sample=0` means to synthesize the first audio in the valid dataset. +In this example, `--output` specifies where to save the synthesized audios and `--sample` (<16) specifies which sample in the valid dataset (a split from the whole LJSpeech dataset, by default contains the first 16 audio samples) to synthesize based on the mel-spectrograms computed from the ground truth sample audio, e.g., `--sample=0` means to synthesize the first audio in the valid dataset. ### Benchmarking diff --git a/parakeet/models/waveflow/waveflow.py b/parakeet/models/waveflow/waveflow.py index a8bd8af..4ef1411 100644 --- a/parakeet/models/waveflow/waveflow.py +++ b/parakeet/models/waveflow/waveflow.py @@ -179,10 +179,13 @@ class WaveFlow(): mels_list = [mels for _, mels in self.validloader()] if sample is not None: mels_list = [mels_list[sample]] + else: + sample = 0 - for sample, mel in enumerate(mels_list): - filename = "{}/valid_{}.wav".format(output, sample) - print("Synthesize sample {}, save as {}".format(sample, filename)) + for idx, mel in enumerate(mels_list): + abs_idx = sample + idx + filename = "{}/valid_{}.wav".format(output, abs_idx) + print("Synthesize sample {}, save as {}".format(abs_idx, filename)) start_time = time.time() audio = self.waveflow.synthesize(mel, sigma=self.config.sigma) From 4f7ded3c89081e93f129323e59e99e5f0946f2cf Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Sat, 7 Mar 2020 23:25:04 -0800 Subject: [PATCH 2/2] add docstring --- examples/waveflow/utils.py | 49 ++++++++++++++ parakeet/models/waveflow/data.py | 1 + parakeet/models/waveflow/waveflow.py | 68 ++++++++++++++++++++ parakeet/models/waveflow/waveflow_modules.py | 38 +++++++++++ 4 files changed, 156 insertions(+) diff --git a/examples/waveflow/utils.py b/examples/waveflow/utils.py index da9b4ba..b899073 100644 --- a/examples/waveflow/utils.py +++ b/examples/waveflow/utils.py @@ -109,6 +109,16 @@ def add_yaml_config(config): def load_latest_checkpoint(checkpoint_dir, rank=0): + """Get the iteration number corresponding to the latest saved checkpoint + + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + rank (int, optional): the rank of the process in multi-process setting. + Defaults to 0. + + Returns: + int: the latest iteration number. + """ checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") # Create checkpoint index file if not exist. if (not os.path.isfile(checkpoint_path)) and rank == 0: @@ -129,6 +139,15 @@ def load_latest_checkpoint(checkpoint_dir, rank=0): def save_latest_checkpoint(checkpoint_dir, iteration): + """Save the iteration number of the latest model to be checkpointed. + + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + iteration (int): the latest iteration number. + + Returns: + None + """ checkpoint_path = os.path.join(checkpoint_dir, "checkpoint") # Update the latest checkpoint index. with open(checkpoint_path, "w") as handle: @@ -142,6 +161,24 @@ def load_parameters(checkpoint_dir, iteration=None, file_path=None, dtype="float32"): + """Load a specific model checkpoint from disk. + + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + rank (int): the rank of the process in multi-process setting. + model (obj): model to load parameters. + optimizer (obj, optional): optimizer to load states if needed. + Defaults to None. + iteration (int, optional): if specified, load the specific checkpoint, + if not specified, load the latest one. Defaults to None. + file_path (str, optional): if specified, load the checkpoint + stored in the file_path. Defaults to None. + dtype (str, optional): precision of the model parameters. + Defaults to float32. + + Returns: + None + """ if file_path is None: if iteration is None: iteration = load_latest_checkpoint(checkpoint_dir, rank) @@ -165,6 +202,18 @@ def load_parameters(checkpoint_dir, def save_latest_parameters(checkpoint_dir, iteration, model, optimizer=None): + """Checkpoint the latest trained model parameters. + + Args: + checkpoint_dir (str): the directory where checkpoint is saved. + iteration (int): the latest iteration number. + model (obj): model to be checkpointed. + optimizer (obj, optional): optimizer to be checkpointed. + Defaults to None. + + Returns: + None + """ file_path = "{}/step-{}".format(checkpoint_dir, iteration) model_dict = model.state_dict() dg.save_dygraph(model_dict, file_path) diff --git a/parakeet/models/waveflow/data.py b/parakeet/models/waveflow/data.py index 83438f7..33e2ee5 100644 --- a/parakeet/models/waveflow/data.py +++ b/parakeet/models/waveflow/data.py @@ -80,6 +80,7 @@ class Subset(DatasetMixin): # whole audio for valid set pass else: + # Randomly crop segment_length from audios in the training set. # audio shape: [len] if audio.shape[0] >= segment_length: max_audio_start = audio.shape[0] - segment_length diff --git a/parakeet/models/waveflow/waveflow.py b/parakeet/models/waveflow/waveflow.py index 4ef1411..101bb66 100644 --- a/parakeet/models/waveflow/waveflow.py +++ b/parakeet/models/waveflow/waveflow.py @@ -28,6 +28,25 @@ from .waveflow_modules import WaveFlowLoss, WaveFlowModule class WaveFlow(): + """Wrapper class of WaveFlow model that supports multiple APIs. + + This module provides APIs for model building, training, validation, + inference, benchmarking, and saving. + + Args: + config (obj): config info. + checkpoint_dir (str): path for checkpointing. + parallel (bool, optional): whether use multiple GPUs for training. + Defaults to False. + rank (int, optional): the rank of the process in a multi-process + scenario. Defaults to 0. + nranks (int, optional): the total number of processes. Defaults to 1. + tb_logger (obj, optional): logger to visualize metrics. + Defaults to None. + + Returns: + WaveFlow + """ def __init__(self, config, checkpoint_dir, @@ -44,6 +63,15 @@ class WaveFlow(): self.dtype = "float16" if config.use_fp16 else "float32" def build(self, training=True): + """Initialize the model. + + Args: + training (bool, optional): Whether the model is built for training or inference. + Defaults to True. + + Returns: + None + """ config = self.config dataset = LJSpeech(config, self.nranks, self.rank) self.trainloader = dataset.trainloader @@ -99,6 +127,14 @@ class WaveFlow(): self.waveflow = waveflow def train_step(self, iteration): + """Train the model for one step. + + Args: + iteration (int): current iteration number. + + Returns: + None + """ self.waveflow.train() start_time = time.time() @@ -135,6 +171,14 @@ class WaveFlow(): @dg.no_grad def valid_step(self, iteration): + """Run the model on the validation dataset. + + Args: + iteration (int): current iteration number. + + Returns: + None + """ self.waveflow.eval() tb = self.tb_logger @@ -167,6 +211,14 @@ class WaveFlow(): @dg.no_grad def infer(self, iteration): + """Run the model to synthesize audios. + + Args: + iteration (int): iteration number of the loaded checkpoint. + + Returns: + None + """ self.waveflow.eval() config = self.config @@ -203,6 +255,14 @@ class WaveFlow(): @dg.no_grad def benchmark(self): + """Run the model to benchmark synthesis speed. + + Args: + None + + Returns: + None + """ self.waveflow.eval() mels_list = [mels for _, mels in self.validloader()] @@ -223,6 +283,14 @@ class WaveFlow(): print("{} X real-time".format(audio_time / syn_time)) def save(self, iteration): + """Save model checkpoint. + + Args: + iteration (int): iteration number of the model to be saved. + + Returns: + None + """ utils.save_latest_parameters(self.checkpoint_dir, iteration, self.waveflow, self.optimizer) utils.save_latest_checkpoint(self.checkpoint_dir, iteration) diff --git a/parakeet/models/waveflow/waveflow_modules.py b/parakeet/models/waveflow/waveflow_modules.py index 46dfba7..f480cd9 100644 --- a/parakeet/models/waveflow/waveflow_modules.py +++ b/parakeet/models/waveflow/waveflow_modules.py @@ -293,6 +293,14 @@ class Flow(dg.Layer): class WaveFlowModule(dg.Layer): + """WaveFlow model implementation. + + Args: + config (obj): model configuration parameters. + + Returns: + WaveFlowModule + """ def __init__(self, config): super(WaveFlowModule, self).__init__() self.n_flows = config.n_flows @@ -321,6 +329,22 @@ class WaveFlowModule(dg.Layer): self.perms.append(perm) def forward(self, audio, mel): + """Training forward pass. + + Use a conditioner to upsample mel spectrograms into hidden states. + These hidden states along with the audio are passed to a stack of Flow + modules to obtain the final latent variable z and a list of log scaling + variables, which are then passed to the WaveFlowLoss module to calculate + the negative log likelihood. + + Args: + audio (obj): audio samples. + mel (obj): mel spectrograms. + + Returns: + z (obj): latent variable. + log_s_list(list): list of log scaling variables. + """ mel = self.conditioner(mel) assert mel.shape[2] >= audio.shape[1] # Prune out the tail of audio/mel so that time/n_group == 0. @@ -361,6 +385,20 @@ class WaveFlowModule(dg.Layer): return z, log_s_list def synthesize(self, mel, sigma=1.0): + """Use model to synthesize waveform. + + Use a conditioner to upsample mel spectrograms into hidden states. + These hidden states along with initial random gaussian latent variable + are passed to a stack of Flow modules to obtain the audio output. + + Args: + mel (obj): mel spectrograms. + sigma (float, optional): standard deviation of the guassian latent + variable. Defaults to 1.0. + + Returns: + audio (obj): synthesized audio. + """ if self.dtype == "float16": mel = fluid.layers.cast(mel, self.dtype) mel = self.conditioner.infer(mel)