class Clip(object):
    """Collate functor for training vocoders.

    Given a list of examples (each holding a waveform and its frame-level
    features), it drops the examples that are too short, cuts one random
    fixed-size window out of every remaining example and stacks the windows
    into a waveform batch and a matching feature batch.
    """

    def __init__(
            self,
            batch_max_steps=20480,
            hop_size=256,
            aux_context_window=0, ):
        """Initialize customized collater for DataLoader.

        Args:
            batch_max_steps (int): The maximum length of input signal in batch.
                It is rounded down to the nearest multiple of ``hop_size``.
            hop_size (int): Hop size of auxiliary features.
            aux_context_window (int): Context window size for auxiliary feature conv.

        """
        # A window must hold a whole number of frames, so round down to a
        # multiple of hop_size.
        batch_max_steps -= batch_max_steps % hop_size
        self.batch_max_steps = batch_max_steps
        self.batch_max_frames = batch_max_steps // hop_size
        self.hop_size = hop_size
        self.aux_context_window = aux_context_window

        # set useful values in random cutting
        self.start_offset = aux_context_window
        self.end_offset = -(self.batch_max_frames + aux_context_window)
        self.mel_threshold = self.batch_max_frames + 2 * aux_context_window

    def __call__(self, examples):
        """Convert into batch tensors.

        Args:
            examples (list): List of dicts. ``example['wave']`` is the audio,
                shape (T, ); ``example['feats']`` is the feature matrix,
                shape (T', C).

        Returns:
            Tensor: Target signal batch (B, 1, batch_max_steps).
            Tensor: Auxiliary feature batch
                (B, C, batch_max_frames + 2 * aux_context_window).

        Raises:
            ValueError: If no example has more than ``mel_threshold`` frames.

        """
        # keep only the examples long enough to cut a full window plus context
        examples = [
            self._adjust_length(b['wave'], b['feats']) for b in examples
            if b['feats'].shape[0] > self.mel_threshold
        ]
        if not examples:
            # np.stack would raise ValueError on an empty list anyway; fail
            # here with a message that names the actual cause.
            raise ValueError("no example in the batch has more than "
                             f"{self.mel_threshold} feature frames.")
        xs, cs = [b[0] for b in examples], [b[1] for b in examples]

        # make batch with random cut
        c_lengths = [c.shape[0] for c in cs]
        start_frames = np.array([
            np.random.randint(self.start_offset, cl + self.end_offset)
            for cl in c_lengths
        ])
        x_starts = start_frames * self.hop_size
        x_ends = x_starts + self.batch_max_steps

        # the feature window is widened by the context on both sides
        c_starts = start_frames - self.aux_context_window
        c_ends = start_frames + self.batch_max_frames + self.aux_context_window
        y_batch = np.stack(
            [x[start:end] for x, start, end in zip(xs, x_starts, x_ends)])
        c_batch = np.stack(
            [c[start:end] for c, start, end in zip(cs, c_starts, c_ends)])

        # convert each batch to tensor, assume that each item in batch has the same length
        y_batch = paddle.to_tensor(
            y_batch, dtype=paddle.float32).unsqueeze(1)  # (B, 1, T)
        c_batch = paddle.to_tensor(
            c_batch, dtype=paddle.float32).transpose([0, 2, 1])  # (B, C, T')

        return y_batch, c_batch

    def _adjust_length(self, x, c):
        """Adjust the audio and feature lengths.

        Note:
            Basically we assume that the length of x and c are adjusted
            through preprocessing stage, but if we use other library processed
            features, this process will be needed.

        Args:
            x (ndarray): Audio, shape (T, ).
            c (ndarray): Features, shape (T', C).

        Returns:
            tuple: ``(x, c)``; ``x`` is edge-padded at the end when it was
                shorter than ``T' * hop_size``, ``c`` is returned unchanged.

        """
        # c is (T', C): frames live on axis 0, so that axis (not the feature
        # axis) determines how many samples the waveform must have.
        num_samples = c.shape[0] * self.hop_size
        if len(x) < num_samples:
            x = np.pad(x, (0, num_samples - len(x)), mode="edge")

        # check the length is valid
        assert len(
            x
        ) == num_samples, f"wave length: ({len(x)}), mel length: ({c.shape[0]})"

        return x, c
def main():
    """Compute the mean and scale of one dumped feature field and save them.

    Reads the jsonlines metadata given by ``--metadata``, loads the array
    stored under ``--field-name`` for every record, accumulates statistics
    with an incremental ``StandardScaler`` and writes them to
    ``<dumpdir>/stats.npy`` as a float32 array stacked as ``[mean, scale]``.
    """
    parser = argparse.ArgumentParser(
        description="Compute mean and variance of dumped raw features.")
    parser.add_argument(
        "--metadata",
        type=str,
        required=True,  # opened unconditionally below
        help="json file with id and file paths ")
    parser.add_argument(
        "--field-name",
        type=str,
        help="name of the field to compute statistics for.")
    parser.add_argument(
        "--config", type=str, help="yaml format configuration file.")
    parser.add_argument(
        "--dumpdir",
        type=str,
        help="directory to save statistics. if not provided, "
        "stats will be saved in the above root directory. (default=None)")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger: the format is shared, only the level depends on --verbose
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(level=log_level, format=log_format)
    if args.verbose <= 0:
        logging.warning('Skip DEBUG/INFO messages')

    config = get_cfg_default()
    # load config
    if args.config:
        config.merge_from_file(args.config)

    # check directory existence
    if args.dumpdir is None:
        # abspath first: the dirname of a bare file name is "", which
        # os.makedirs rejects.
        args.dumpdir = os.path.dirname(os.path.abspath(args.metadata))
    os.makedirs(args.dumpdir, exist_ok=True)

    with jsonlines.open(args.metadata, 'r') as reader:
        metadata = list(reader)
    dataset = DataTable(
        metadata,
        fields=[args.field_name],
        converters={args.field_name: np.load}, )
    logging.info(f"The number of files = {len(dataset)}.")

    # calculate statistics
    scaler = StandardScaler()
    for datum in tqdm(dataset):
        # StandardScaler supports (*, num_features) by default
        scaler.partial_fit(datum[args.field_name])

    stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
    np.save(
        os.path.join(args.dumpdir, "stats.npy"),
        stats.astype(np.float32),
        allow_pickle=False)
+# This configuration requires 12 GB GPU memory and takes ~3 days on RTX TITAN. + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +sr: 24000 # Sampling rate. +n_fft: 2048 # FFT size. +hop_length: 300 # Hop size. +win_length: 1200 # Window length. + # If set to null, it will be the same as fft_size. +window: "hann" # Window function. +n_mels: 80 # Number of mel basis. +fmin: 80 # Minimum freq in mel basis calculation. +fmax: 7600 # Maximum frequency in mel basis calculation. +# global_gain_scale: 1.0 # Will be multiplied to all of waveform. +trim_silence: false # Whether to trim the start and end of silence. +top_db: 60 # Need to tune carefully if the recording is not good. +trim_frame_length: 2048 # Frame size in trimming.(in samples) +trim_hop_length: 512 # Hop size in trimming.(in samples) + +########################################################### +# GENERATOR NETWORK ARCHITECTURE SETTING # +########################################################### +generator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_size: 3 # Kernel size of dilated convolution. + layers: 30 # Number of residual block layers. + stacks: 3 # Number of stacks i.e., dilation cycles. + residual_channels: 64 # Number of channels in residual conv. + gate_channels: 128 # Number of channels in gated conv. + skip_channels: 64 # Number of channels in skip conv. + aux_channels: 80 # Number of channels for auxiliary feature conv. + # Must be the same as num_mels. + aux_context_window: 2 # Context window size for auxiliary feature. + # If set to 2, previous 2 and future 2 frames will be considered. + dropout: 0.0 # Dropout rate. 0.0 means no dropout applied. + bias: true # use bias in residual blocks + use_weight_norm: true # Whether to use weight norm. + # If set to true, it will be applied to all of the conv layers. 
+ use_causal_conv: false # use causal conv in residual blocks and upsample layers + # upsample_net: "ConvInUpsampleNetwork" # Upsampling network architecture. + upsample_scales: [4, 5, 3, 5] # Upsampling scales. Product of these must be the same as hop size. + interpolate_mode: "nearest" # upsample net interpolate mode + freq_axis_kernel_size: 1 # upsampling net: convolution kernel size in frequency axis + nonlinear_activation: null + nonlinear_activation_params: {} + +########################################################### +# DISCRIMINATOR NETWORK ARCHITECTURE SETTING # +########################################################### +discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_size: 3 # Kernel size of conv layers. + layers: 10 # Number of conv layers. + conv_channels: 64 # Number of channels in conv layers. + bias: true # Whether to use bias parameter in conv. + use_weight_norm: true # Whether to use weight norm. + # If set to true, it will be applied to all of the conv layers. + nonlinear_activation: "LeakyReLU" # Nonlinear function after each conv. + nonlinear_activation_params: # Nonlinear function parameters + negative_slope: 0.2 # Alpha in LeakyReLU. + +########################################################### +# STFT LOSS SETTING # +########################################################### +stft_loss_params: + fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss. + hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss + win_lengths: [600, 1200, 240] # List of window length for STFT-based loss. + window: "hann" # Window function for STFT-based loss + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +lambda_adv: 4.0 # Loss balancing coefficient.
+ +########################################################### +# DATA LOADER SETTING # +########################################################### +batch_size: 6 # Batch size. +batch_max_steps: 25500 # Length of each audio in batch. Make sure it is divisible by hop_size. +pin_memory: true # Whether to pin memory in the DataLoader. +num_workers: 4 # Number of workers in the DataLoader. +remove_short_samples: true # Whether to remove samples whose length is less than batch_max_steps. +allow_cache: true # Whether to allow cache in dataset. If true, it requires cpu memory. + +########################################################### +# OPTIMIZER & SCHEDULER SETTING # +########################################################### +generator_optimizer_params: + epsilon: 1.0e-6 # Generator's epsilon. + weight_decay: 0.0 # Generator's weight decay coefficient. +generator_scheduler_params: + learning_rate: 0.0001 # Generator's learning rate. + step_size: 200000 # Generator's scheduler step size. + gamma: 0.5 # Generator's scheduler gamma. + # At each step size, lr will be multiplied by this parameter. +generator_grad_norm: 10 # Generator's gradient norm. +discriminator_optimizer_params: + epsilon: 1.0e-6 # Discriminator's epsilon. + weight_decay: 0.0 # Discriminator's weight decay coefficient. +discriminator_scheduler_params: + learning_rate: 0.00005 # Discriminator's learning rate. + step_size: 200000 # Discriminator's scheduler step size. + gamma: 0.5 # Discriminator's scheduler gamma. + # At each step size, lr will be multiplied by this parameter. +discriminator_grad_norm: 1 # Discriminator's gradient norm. + +########################################################### +# INTERVAL SETTING # +########################################################### +discriminator_train_start_steps: 100000 # Number of steps to start to train discriminator. +train_max_steps: 400000 # Number of training steps. +save_interval_steps: 5000 # Interval steps to save checkpoint.
from pathlib import Path

# Locate the default configuration next to this module rather than relative
# to the current working directory, so the scripts that import this module
# can be launched from any directory.
_DEFAULT_CONFIG_PATH = Path(__file__).resolve().parent / "conf" / "default.yaml"

with open(_DEFAULT_CONFIG_PATH, 'rt') as f:
    _C = Configuration(yaml.safe_load(f))


def get_cfg_default():
    """Return a clone of the default configuration loaded at import time.

    A clone is returned so that callers can merge their own settings into
    it without touching the module-level defaults.
    """
    config = _C.clone()
    return config
def main():
    """Normalize dumped features with precomputed statistics and dump them.

    Reads the jsonlines metadata given by ``--metadata``, applies the
    mean/scale stored in ``--stats`` to every ``feats`` array, saves the
    result (and, unless ``--skip-wav-copy`` is set, a float32 copy of the
    waveform) under ``--dumpdir`` and writes a new ``metadata.jsonl`` there.
    """
    parser = argparse.ArgumentParser(
        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
    )
    parser.add_argument(
        "--metadata",
        type=str,
        required=True,
        help="jsonlines metadata file listing utt_id, wave and feats "
        "of the features to be normalized.")
    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump normalized feature files.")
    parser.add_argument(
        "--stats", type=str, required=True, help="statistics file.")
    parser.add_argument(
        "--skip-wav-copy",
        default=False,
        action="store_true",
        help="whether to skip the copy of wav files.")
    parser.add_argument(
        "--config", type=str, help="yaml format configuration file.")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    args = parser.parse_args()

    # set logger: the format is shared, only the level depends on --verbose
    log_format = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 1:
        log_level = logging.DEBUG
    elif args.verbose > 0:
        log_level = logging.INFO
    else:
        log_level = logging.WARN
    logging.basicConfig(level=log_level, format=log_format)
    if args.verbose <= 0:
        logging.warning('Skip DEBUG/INFO messages')

    # load config
    config = get_cfg_default()
    if args.config:
        config.merge_from_file(args.config)

    # check directory existence
    dumpdir = Path(args.dumpdir).resolve()
    dumpdir.mkdir(parents=True, exist_ok=True)

    # get dataset
    with jsonlines.open(args.metadata, 'r') as reader:
        metadata = list(reader)
    dataset = DataTable(
        metadata,
        fields=["utt_id", "wave", "feats"],
        converters={
            'utt_id': None,
            # when the copy is skipped the wave field stays a path string
            'wave': None if args.skip_wav_copy else np.load,
            'feats': np.load,
        })
    logging.info(f"The number of files = {len(dataset)}.")

    # restore scaler; the stats file is stacked as [mean, scale], read it once
    stats = np.load(args.stats)
    scaler = StandardScaler()
    scaler.mean_ = stats[0]
    scaler.scale_ = stats[1]

    # from version 0.23.0, this information is needed
    scaler.n_features_in_ = scaler.mean_.shape[0]

    # process each file
    output_metadata = []

    for item in tqdm(dataset):
        utt_id = item['utt_id']
        wave = item['wave']
        mel = item['feats']
        # normalize
        mel = scaler.transform(mel)

        # save
        mel_path = dumpdir / f"{utt_id}-feats.npy"
        np.save(mel_path, mel.astype(np.float32), allow_pickle=False)
        if not args.skip_wav_copy:
            wav_path = dumpdir / f"{utt_id}-wave.npy"
            np.save(wav_path, wave.astype(np.float32), allow_pickle=False)
        else:
            # keep pointing at the original waveform file
            wav_path = wave
        output_metadata.append({
            'utt_id': utt_id,
            'wave': str(wav_path),
            'feats': str(mel_path),
        })
    output_metadata.sort(key=itemgetter('utt_id'))
    output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
    with jsonlines.open(output_metadata_path, 'w') as writer:
        for item in output_metadata:
            writer.write(item)
    logging.info(f"metadata dumped into {output_metadata_path}")
def logmelfilterbank(audio,
                     sr,
                     n_fft=1024,
                     hop_length=256,
                     win_length=None,
                     window="hann",
                     n_mels=80,
                     fmin=None,
                     fmax=None,
                     eps=1e-10):
    """Compute log-Mel filterbank feature.

    Parameters
    ----------
    audio : ndarray
        Audio signal (T,).
    sr : int
        Sampling rate.
    n_fft : int
        FFT size. (Default value = 1024)
    hop_length : int
        Hop size. (Default value = 256)
    win_length : int
        Window length. If set to None, it will be the same as fft_size. (Default value = None)
    window : str
        Window function type. (Default value = "hann")
    n_mels : int
        Number of mel basis. (Default value = 80)
    fmin : int
        Minimum frequency in mel basis calculation. (Default value = None)
    fmax : int
        Maximum frequency in mel basis calculation. (Default value = None)
    eps : float
        Epsilon value to avoid inf in log calculation. (Default value = 1e-10)

    Returns
    -------
    np.ndarray
        Log Mel filterbank feature (num_mels, #frames).

    """
    # get amplitude spectrogram
    x_stft = librosa.stft(
        audio,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=window,
        pad_mode="reflect")
    spc = np.abs(x_stft)  # (#bins, #frames,)

    # get mel basis
    fmin = 0 if fmin is None else fmin
    fmax = sr / 2 if fmax is None else fmax
    # keyword arguments work on every librosa release; newer releases no
    # longer accept these positionally
    mel_basis = librosa.filters.mel(
        sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)

    return np.log10(np.maximum(eps, np.dot(mel_basis, spc)))


def process_sentence(config: Dict[str, Any],
                     fp: Path,
                     alignment_fp: Path,
                     output_dir: Path):
    """Trim one utterance, extract its log-mel features and dump both.

    The leading and trailing ``sil`` intervals of the alignment are cut off,
    optional energy-based trimming is applied, and the waveform is padded or
    truncated so that ``num_samples == num_frames * hop_length``.

    Returns a metadata record (dict) with ``utt_id``, ``num_samples``,
    ``num_frames`` and the absolute paths of the saved ``feats`` and
    ``wave`` arrays.
    """
    utt_id = fp.stem

    # reading
    y, sr = librosa.load(fp, sr=config.sr)  # resampling may occur
    assert len(y.shape) == 1, f"{utt_id} is not a mono-channel audio."
    assert np.abs(y).max(
    ) <= 1.0, f"{utt_id} is seems to be different that 16 bit PCM."
    duration = librosa.get_duration(y=y, sr=sr)

    # trim according to the alignment file
    alignment = tgio.openTextgrid(alignment_fp)
    intervals = alignment.tierDict[alignment.tierNameList[0]].entryList
    first, last = intervals[0], intervals[-1]
    start = 0
    end = duration
    if first.label == "sil" and first.end < duration:
        start = first.end
    else:
        logging.warning(
            f" There is something wrong with the fisrt interval {first} in utterance: {utt_id}"
        )
    if last.label == "sil" and last.start < duration:
        end = last.start
    else:
        logging.warning(
            f" There is something wrong with the last interval {last} in utterance: {utt_id}"
        )
    # silence trimmed: convert the validated bounds (not the raw interval
    # edges), so the fallbacks chosen above actually take effect
    start, end = librosa.time_to_samples([start, end], sr=sr)
    y = y[start:end]

    # energy based silence trimming
    if config.trim_silence:
        y, _ = librosa.effects.trim(
            y,
            top_db=config.top_db,
            frame_length=config.trim_frame_length,
            hop_length=config.trim_hop_length)

    logmel = logmelfilterbank(
        y,
        sr=sr,
        n_fft=config.n_fft,
        window=config.window,
        win_length=config.win_length,
        hop_length=config.hop_length,
        n_mels=config.n_mels,
        fmin=config.fmin,
        fmax=config.fmax)

    # adjust time to make num_samples == num_frames * hop_length
    num_frames = logmel.shape[1]
    if y.size < num_frames * config.hop_length:
        y = np.pad(y, (0, num_frames * config.hop_length - y.size),
                   mode="reflect")
    else:
        y = y[:num_frames * config.hop_length]
    num_sample = y.shape[0]

    mel_path = output_dir / (utt_id + "_feats.npy")
    wav_path = output_dir / (utt_id + "_wave.npy")
    np.save(wav_path, y)  # (num_samples, )
    np.save(mel_path, logmel.T)  # (num_frames, n_mels)
    record = {
        "utt_id": utt_id,
        "num_samples": num_sample,
        "num_frames": num_frames,
        "feats": str(mel_path.resolve()),
        "wave": str(wav_path.resolve()),
    }
    return record


def process_sentences(config,
                      fps: List[Path],
                      alignment_fps: List[Path],
                      output_dir: Path,
                      nprocs: int=1):
    """Process paired audio/alignment files and write ``metadata.jsonl``.

    ``fps`` and ``alignment_fps`` are paired by position. With ``nprocs``
    equal to 1 the files are processed sequentially, otherwise on a thread
    pool of ``nprocs`` workers. The records are sorted by ``utt_id`` before
    being written to ``output_dir / "metadata.jsonl"``.
    """
    if nprocs == 1:
        results = []
        # zip() has no length, so give tqdm the total explicitly
        for fp, alignment_fp in tqdm.tqdm(
                zip(fps, alignment_fps), total=len(fps)):
            results.append(
                process_sentence(config, fp, alignment_fp, output_dir))
    else:
        with ThreadPoolExecutor(nprocs) as pool:
            futures = []
            with tqdm.tqdm(total=len(fps)) as progress:
                for fp, alignment_fp in zip(fps, alignment_fps):
                    future = pool.submit(process_sentence, config, fp,
                                         alignment_fp, output_dir)
                    future.add_done_callback(lambda p: progress.update())
                    futures.append(future)

                # collect while the progress bar is still open
                results = []
                for ft in futures:
                    results.append(ft.result())

    results.sort(key=itemgetter("utt_id"))
    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
        for item in results:
            writer.write(item)
    print("Done")


def main():
    """Parse arguments, split the dataset and preprocess the three subsets.

    Audio is read from ``<rootdir>/Wave`` and alignments from
    ``<rootdir>/PhoneLabeling``; both lists are sorted and paired by
    position. The first 9800 utterances form the training set, the next 100
    the dev set and the remainder the test set; each subset is dumped to
    ``<dumpdir>/<subset>/raw``.
    """
    # parse config and args
    parser = argparse.ArgumentParser(
        description="Preprocess audio and then extract features (See detail in parallel_wavegan/bin/preprocess.py)."
    )
    parser.add_argument(
        "--rootdir",
        type=str,
        required=True,  # Path(None) below would fail with a TypeError
        help='root directory of the dataset, containing "Wave" and '
        '"PhoneLabeling".')
    parser.add_argument(
        "--dumpdir",
        type=str,
        required=True,
        help="directory to dump feature files.")
    parser.add_argument(
        "--config", type=str, help="yaml format configuration file.")
    parser.add_argument(
        "--verbose",
        type=int,
        default=1,
        help="logging level. higher is more logging. (default=1)")
    parser.add_argument(
        "--num_cpu", type=int, default=1, help="number of process.")
    args = parser.parse_args()

    C = get_cfg_default()
    if args.config:
        C.merge_from_file(args.config)
    C.freeze()

    if args.verbose > 1:
        print(vars(args))
        print(C)

    root_dir = Path(args.rootdir).expanduser()
    dumpdir = Path(args.dumpdir).expanduser()
    dumpdir.mkdir(parents=True, exist_ok=True)

    wav_files = sorted(list((root_dir / "Wave").rglob("*.wav")))
    alignment_files = sorted(
        list((root_dir / "PhoneLabeling").rglob("*.interval")))
    if len(wav_files) != len(alignment_files):
        # the two lists are paired by position, so a count mismatch means
        # some audio would be trimmed with another utterance's alignment
        logging.warning(
            f"found {len(wav_files)} wav files but {len(alignment_files)} "
            "alignment files; pairs may be misaligned.")

    # split data into 3 sections
    num_train = 9800
    num_dev = 100
    sections = [
        ("train", slice(None, num_train)),
        ("dev", slice(num_train, num_train + num_dev)),
        ("test", slice(num_train + num_dev, None)),
    ]

    # process for the 3 sections
    for section, part in sections:
        dump_dir = dumpdir / section / "raw"
        dump_dir.mkdir(parents=True, exist_ok=True)
        process_sentences(
            C,
            wav_files[part],
            alignment_files[part],
            dump_dir,
            nprocs=args.num_cpu)
class PWGUpdater(StandardUpdater):
    """Updater running one Parallel WaveGAN training step per batch.

    The generator is always trained with the multi-resolution STFT loss;
    once the iteration counter exceeds ``discriminator_train_start_steps``
    the adversarial loss is added and the discriminator is trained too.
    """

    def __init__(
            self,
            models: Dict[str, Layer],
            optimizers: Dict[str, Optimizer],
            criterions: Dict[str, Layer],
            schedulers: Dict[str, LRScheduler],
            dataloader: DataLoader,
            discriminator_train_start_steps: int,
            lambda_adv: float, ):
        # networks
        self.models = models
        self.generator: Layer = models['generator']
        self.discriminator: Layer = models['discriminator']

        # one optimizer and one lr scheduler per network
        self.optimizers = optimizers
        self.optimizer_g: Optimizer = optimizers['generator']
        self.optimizer_d: Optimizer = optimizers['discriminator']
        self.schedulers = schedulers
        self.scheduler_g = schedulers['generator']
        self.scheduler_d = schedulers['discriminator']

        # loss functions
        self.criterions = criterions
        self.criterion_stft = criterions['stft']
        self.criterion_mse = criterions['mse']

        self.dataloader = dataloader

        self.discriminator_train_start_steps = discriminator_train_start_steps
        self.lambda_adv = lambda_adv
        self.state = UpdaterState(iteration=0, epoch=0)

        self.train_iterator = iter(self.dataloader)

    def update_core(self, batch):
        """Run the generator step, then (once enabled) the discriminator step."""
        target, mel = batch
        noise = paddle.randn(target.shape)
        # the adversarial part only kicks in after the warm-up iterations
        adversarial = self.state.iteration > self.discriminator_train_start_steps

        # ---------------- generator ----------------
        with timer() as t:
            generated = self.generator(noise, mel)
            logging.debug(f"Generator takes {t.elapse}s.")

        # multi-resolution stft loss
        with timer() as t:
            sc_loss, mag_loss = self.criterion_stft(
                generated.squeeze(1), target.squeeze(1))
            logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")

        report("train/spectral_convergence_loss", float(sc_loss))
        report("train/log_stft_magnitude_loss", float(mag_loss))
        gen_loss = sc_loss + mag_loss

        # adversarial loss
        if adversarial:
            with timer() as t:
                fake_score = self.discriminator(generated)
                adv_loss = self.criterion_mse(fake_score,
                                              paddle.ones_like(fake_score))
                logging.debug(
                    f"Discriminator and adversarial loss takes {t.elapse}s")
            report("train/adversarial_loss", float(adv_loss))
            gen_loss += self.lambda_adv * adv_loss

        report("train/generator_loss", float(gen_loss))

        with timer() as t:
            self.optimizer_g.clear_grad()
            gen_loss.backward()
            logging.debug(f"Backward takes {t.elapse}s.")

        with timer() as t:
            self.optimizer_g.step()
            self.scheduler_g.step()
            logging.debug(f"Update takes {t.elapse}s.")

        if not adversarial:
            return

        # ---------------- discriminator ----------------
        # regenerate with the just-updated generator, without tracking grads
        with paddle.no_grad():
            generated = self.generator(noise, mel)
        real_score = self.discriminator(target)
        fake_score = self.discriminator(generated.detach())
        real_loss = self.criterion_mse(real_score,
                                       paddle.ones_like(real_score))
        fake_loss = self.criterion_mse(fake_score,
                                       paddle.zeros_like(fake_score))
        report("train/real_loss", float(real_loss))
        report("train/fake_loss", float(fake_loss))
        dis_loss = real_loss + fake_loss
        report("train/discriminator_loss", float(dis_loss))

        self.optimizer_d.clear_grad()
        dis_loss.backward()

        self.optimizer_d.step()
        self.scheduler_d.step()


class PWGEvaluator(StandardEvaluator):
    """Evaluator reporting generator and discriminator losses for one batch."""

    def __init__(self, models, criterions, dataloader, lambda_adv):
        self.models = models
        self.generator = models['generator']
        self.discriminator = models['discriminator']

        self.criterions = criterions
        self.criterion_stft = criterions['stft']
        self.criterion_mse = criterions['mse']

        self.dataloader = dataloader
        self.lambda_adv = lambda_adv

    def evaluate_core(self, batch):
        """Compute and report every evaluation loss for ``batch``."""
        logging.debug("Evaluate: ")
        target, mel = batch
        noise = paddle.randn(target.shape)

        with timer() as t:
            generated = self.generator(noise, mel)
            logging.debug(f"Generator takes {t.elapse}s")

        # adversarial loss
        with timer() as t:
            fake_score = self.discriminator(generated)
            adv_loss = self.criterion_mse(fake_score,
                                          paddle.ones_like(fake_score))
            logging.debug(
                f"Discriminator and adversarial loss takes {t.elapse}s")
        report("eval/adversarial_loss", float(adv_loss))
        gen_loss = self.lambda_adv * adv_loss

        # stft loss
        with timer() as t:
            sc_loss, mag_loss = self.criterion_stft(
                generated.squeeze(1), target.squeeze(1))
            logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")

        report("eval/spectral_convergence_loss", float(sc_loss))
        report("eval/log_stft_magnitude_loss", float(mag_loss))
        gen_loss += sc_loss + mag_loss

        report("eval/generator_loss", float(gen_loss))

        # discriminator: the score of the generated audio from above is reused
        real_score = self.discriminator(target)
        real_loss = self.criterion_mse(real_score,
                                       paddle.ones_like(real_score))
        fake_loss = self.criterion_mse(fake_score,
                                       paddle.zeros_like(fake_score))
        report("eval/real_loss", float(real_loss))
        report("eval/fake_loss", float(fake_loss))
        dis_loss = real_loss + fake_loss
        report("eval/discriminator_loss", float(dis_loss))
def main():
    """Synthesize waveforms from mel spectrograms with a trained generator.

    Reads the utterances listed in ``--test-metadata``, runs the Parallel
    WaveGAN generator restored from ``--checkpoint`` on each of them and
    writes ``<utt_id>.wav`` files into ``--output-dir``. Per-utterance and
    overall generation speed are printed.
    """
    parser = argparse.ArgumentParser(
        description="synthesize with parallel wavegan.")
    parser.add_argument(
        "--config", type=str, help="config file to overwrite default config")
    # These three are used unconditionally below; requiring them gives a
    # clear usage error instead of a crash on ``None``.
    parser.add_argument(
        "--checkpoint", type=str, required=True, help="snapshot to load")
    parser.add_argument(
        "--test-metadata", type=str, required=True, help="dev data")
    parser.add_argument(
        "--output-dir", type=str, required=True, help="output dir")
    parser.add_argument(
        "--device", type=str, default="gpu", help="device to run")
    parser.add_argument("--verbose", type=int, default=1, help="verbose")

    args = parser.parse_args()
    config = get_cfg_default()
    if args.config:
        config.merge_from_file(args.config)

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)
    print(
        f"master see the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
    )

    # Restore the generator and put it into inference mode.
    paddle.set_device(args.device)
    generator = PWGGenerator(**config["generator_params"])
    state_dict = paddle.load(args.checkpoint)
    generator.set_state_dict(state_dict["generator_params"])
    generator.remove_weight_norm()
    generator.eval()

    with jsonlines.open(args.test_metadata, 'r') as reader:
        metadata = list(reader)

    test_dataset = DataTable(
        metadata,
        fields=['utt_id', 'feats'],
        converters={
            'utt_id': None,
            'feats': np.load,
        })
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    total_samples = 0
    total_seconds = 0
    for example in test_dataset:
        utt_id = example['utt_id']
        mel = paddle.to_tensor(example['feats'])  # (T, C)
        with timer() as t:
            # No gradients are needed for synthesis.
            with paddle.no_grad():
                wav = generator.inference(c=mel)
            wav = wav.numpy()
            # Read the timer once so that every figure below is derived from
            # the same measurement.
            elapsed = t.elapse
        total_samples += wav.size
        total_seconds += elapsed
        speed = wav.size / elapsed
        print(
            f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {elapsed}s, Hz: {speed}, RTF: {config.sr / speed}."
        )
        # A plain string path is accepted by every soundfile version.
        sf.write(
            str(output_dir / (utt_id + ".wav")), wav, samplerate=config.sr)

    overall_speed = total_samples / total_seconds
    print(
        f"generation speed: {overall_speed}Hz, RTF: {config.sr / overall_speed}"
    )


if __name__ == "__main__":
    main()
def _load_dataset(metadata_path):
    """Build a ``DataTable`` of ``wave``/``feats`` arrays from a jsonlines file."""
    with jsonlines.open(metadata_path, 'r') as reader:
        metadata = list(reader)
    return DataTable(
        data=metadata,
        fields=["wave", "feats"],
        converters={
            "wave": np.load,
            "feats": np.load,
        }, )


def train_sp(args, config):
    """Train Parallel WaveGAN in the current process.

    This is the per-process entry point: it is called directly for single
    process training and through ``dist.spawn`` otherwise.

    Parameters
    ----------
    args : argparse.Namespace
        Command line arguments; ``train_metadata``, ``dev_metadata``,
        ``output_dir`` and ``device`` are read.
    config : CfgNode
        Experiment configuration (see ``config.get_cfg_default``).
    """
    # Decide the device type and whether to run in parallel.
    # ``world_size`` is needed further down regardless of the device, so it
    # is bound before branching.
    world_size = dist.get_world_size()
    # ``is_compiled_with_cuda`` must be *called*: the bare function object is
    # always truthy, which made the CPU branch unreachable. The requested
    # ``--device`` is honored as well.
    use_gpu = (getattr(args, "device", "gpu") != "cpu" and
               paddle.is_compiled_with_cuda())
    if use_gpu:
        paddle.set_device("gpu")
        if world_size > 1:
            dist.init_parallel_env()
    else:
        paddle.set_device("cpu")

    # set the random seed, it is a must for multiprocess training
    seed_everything(config.seed)

    print(
        f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
    )

    # dataloader has been too verbose
    logging.getLogger("DataLoader").disabled = True

    # construct dataset for training and validation
    train_dataset = _load_dataset(args.train_metadata)
    dev_dataset = _load_dataset(args.dev_metadata)

    # collate function and dataloader
    train_sampler = DistributedBatchSampler(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True)
    dev_sampler = DistributedBatchSampler(
        dev_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        drop_last=False)
    print("samplers done!")

    train_batch_fn = Clip(
        batch_max_steps=config.batch_max_steps,
        hop_size=config.hop_length,
        aux_context_window=config.generator_params.aux_context_window)
    train_dataloader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        collate_fn=train_batch_fn,
        num_workers=config.num_workers)
    # The dev set is clipped with the same collate function as training.
    dev_dataloader = DataLoader(
        dev_dataset,
        batch_sampler=dev_sampler,
        collate_fn=train_batch_fn,
        num_workers=config.num_workers)
    print("dataloaders done!")

    generator = PWGGenerator(**config["generator_params"])
    discriminator = PWGDiscriminator(**config["discriminator_params"])
    if world_size > 1:
        generator = DataParallel(generator)
        discriminator = DataParallel(discriminator)
    print("models done!")

    criterion_stft = MultiResolutionSTFTLoss(**config["stft_loss_params"])
    criterion_mse = nn.MSELoss()
    print("criterions done!")

    lr_schedule_g = StepDecay(**config["generator_scheduler_params"])
    gradient_clip_g = nn.ClipGradByGlobalNorm(config["generator_grad_norm"])
    optimizer_g = Adam(
        learning_rate=lr_schedule_g,
        grad_clip=gradient_clip_g,
        parameters=generator.parameters(),
        **config["generator_optimizer_params"])
    lr_schedule_d = StepDecay(**config["discriminator_scheduler_params"])
    gradient_clip_d = nn.ClipGradByGlobalNorm(config[
        "discriminator_grad_norm"])
    optimizer_d = Adam(
        learning_rate=lr_schedule_d,
        grad_clip=gradient_clip_d,
        parameters=discriminator.parameters(),
        **config["discriminator_optimizer_params"])
    print("optimizers done!")

    # Only rank 0 touches the file system for the experiment directory.
    output_dir = Path(args.output_dir)
    checkpoint_dir = output_dir / "checkpoints"
    if dist.get_rank() == 0:
        output_dir.mkdir(parents=True, exist_ok=True)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        with open(output_dir / "config.yaml", 'wt') as f:
            f.write(config.dump(default_flow_style=None))

    models = {
        "generator": generator,
        "discriminator": discriminator,
    }
    criterions = {
        "stft": criterion_stft,
        "mse": criterion_mse,
    }
    updater = PWGUpdater(
        models=models,
        optimizers={
            "generator": optimizer_g,
            "discriminator": optimizer_d,
        },
        criterions=criterions,
        schedulers={
            "generator": lr_schedule_g,
            "discriminator": lr_schedule_d,
        },
        dataloader=train_dataloader,
        discriminator_train_start_steps=config.discriminator_train_start_steps,
        lambda_adv=config.lambda_adv, )

    evaluator = PWGEvaluator(
        models=models,
        criterions=criterions,
        dataloader=dev_dataloader,
        lambda_adv=config.lambda_adv, )
    trainer = Trainer(
        updater,
        stop_trigger=(config.train_max_steps, "iteration"),
        out=output_dir, )

    trainer.extend(
        evaluator, trigger=(config.eval_interval_steps, 'iteration'))
    # Logging and snapshotting are done by rank 0 only.
    if dist.get_rank() == 0:
        writer = LogWriter(str(trainer.out))
        trainer.extend(VisualDL(writer), trigger=(1, 'iteration'))
        trainer.extend(
            Snapshot(max_size=config.num_snapshots),
            trigger=(config.save_interval_steps, 'iteration'))

    print(trainer.extensions.keys())
    print("Trainer Done!")
    trainer.run()


def main():
    """Parse the command line, build the config and dispatch to ``train_sp``.

    Raises
    ------
    RuntimeError
        If more than one process is requested together with ``--device cpu``.
    """
    parser = argparse.ArgumentParser(description="Train a ParallelWaveGAN "
                                     "model with Baker Mandarin TTS dataset.")
    parser.add_argument(
        "--config", type=str, help="config file to overwrite default config")
    parser.add_argument("--train-metadata", type=str, help="training data")
    parser.add_argument("--dev-metadata", type=str, help="dev data")
    parser.add_argument("--output-dir", type=str, help="output dir")
    parser.add_argument(
        "--device", type=str, default="gpu", help="device type to use")
    parser.add_argument(
        "--nprocs", type=int, default=1, help="number of processes")
    parser.add_argument("--verbose", type=int, default=1, help="verbose")

    args = parser.parse_args()
    if args.device == "cpu" and args.nprocs > 1:
        raise RuntimeError("Multiprocess training on CPU is not supported.")
    config = get_cfg_default()
    if args.config:
        config.merge_from_file(args.config)

    print("========Args========")
    print(yaml.safe_dump(vars(args)))
    print("========Config========")
    print(config)
    print(
        f"master see the world size: {dist.get_world_size()}, from pid: {os.getpid()}"
    )

    # dispatch
    if args.nprocs > 1:
        dist.spawn(train_sp, (args, config), nprocs=args.nprocs)
    else:
        train_sp(args, config)


if __name__ == "__main__":
    main()
class DataTable(Dataset):
    """Dataset to load and convert data for general purpose.

    Parameters
    ----------
    data : List[Dict[str, Any]]
        Metadata, a list of meta datum, each of which is composed of
        several fields
    fields : List[str], optional
        Fields to use, if not specified, all the fields in the data are
        used, by default None
    converters : Dict[str, Callable], optional
        Converters used to process each field, by default None
    use_cache : bool, optional
        Whether to use cache, by default False

    Raises
    ------
    ValueError
        If there is some field that does not exist in data.
    ValueError
        If there is some field in converters that does not exist in fields.
    """

    def __init__(self,
                 data: List[Dict[str, Any]],
                 fields: List[str]=None,
                 converters: Dict[str, Callable]=None,
                 use_cache: bool=False):
        # metadata
        self.data = data
        assert len(data) > 0, "This dataset has no examples"

        # Peek at the first example to learn which fields exist. The keys
        # are copied into a list so that the dataset does not keep a live
        # view of that example's dict.
        fields_in_data = list(self.data[0].keys())

        # check all the requested fields exist
        if fields is None:
            self.fields = fields_in_data
        else:
            for field in fields:
                if field not in fields_in_data:
                    raise ValueError(
                        f"The requested field ({field}) is not found "
                        f"in the data. Fields in the data is {fields_in_data}"
                    )
            self.fields = fields

        # check converters
        if converters is None:
            self.converters = {}
        else:
            for field in converters:
                if field not in self.fields:
                    raise ValueError(
                        f"The converter has a non existing field ({field})")
            self.converters = converters

        self.use_cache = use_cache
        if use_cache:
            self._initialize_cache()

    def _initialize_cache(self):
        """Create a ``multiprocessing.Manager`` list with one slot per example."""
        self.manager = Manager()
        self.caches = self.manager.list()
        self.caches += [None for _ in range(len(self))]

    def _get_metadata(self, idx: int) -> Dict[str, Any]:
        """Return a meta-datum given an index."""
        return self.data[idx]

    def _convert(self, meta_datum: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a meta datum to an example by applying the corresponding
        converters to each fields requested.

        Parameters
        ----------
        meta_datum : Dict[str, Any]
            Meta datum

        Returns
        -------
        Dict[str, Any]
            Converted example
        """
        example = {}
        for field in self.fields:
            converter = self.converters.get(field, None)
            value = meta_datum[field]
            # A missing or ``None`` converter passes the value through.
            example[field] = value if converter is None else converter(value)
        return example

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get an example given an index.

        Parameters
        ----------
        idx : int
            Index of the example to get

        Returns
        -------
        Dict[str, Any]
            A converted example
        """
        if self.use_cache and self.caches[idx] is not None:
            return self.caches[idx]

        meta_datum = self._get_metadata(idx)
        example = self._convert(meta_datum)

        if self.use_cache:
            self.caches[idx] = example

        return example

    def __len__(self) -> int:
        """Returns the size of the dataset.

        Returns
        -------
        int
            The length of the dataset
        """
        return len(self.data)
class Stretch2D(nn.Layer):
    def __init__(self, w_scale: int, h_scale: int, mode: str="nearest"):
        """Stretch an image (or image-like object) with some interpolation.

        Parameters
        ----------
        w_scale : int
            Scalar of width.
        h_scale : int
            Scalar of the height.
        mode : str, optional
            Interpolation mode, modes supported are "nearest", "bilinear",
            "trilinear", "bicubic", "linear" and "area", by default "nearest"

            For more details about interpolation, see
            ``paddle.nn.functional.interpolate``.
        """
        super().__init__()
        self.w_scale = w_scale
        self.h_scale = h_scale
        self.mode = mode

    def forward(self, x: Tensor) -> Tensor:
        """
        Parameters
        ----------
        x : Tensor
            Shape (N, C, H, W)

        Returns
        -------
        Tensor
            Shape (N, C, H', W'), where ``H'=h_scale * H``, ``W'=w_scale * W``.
            The stretched image.
        """
        out = F.interpolate(
            x, scale_factor=(self.h_scale, self.w_scale), mode=self.mode)
        return out


class UpsampleNet(nn.Layer):
    """A Layer to upsample spectrogram by applying consecutive stretch and
    convolutions.

    Parameters
    ----------
    upsample_scales : List[int]
        Upsampling factors for each stretch.
    nonlinear_activation : Optional[str], optional
        Activation after each convolution, by default None
    nonlinear_activation_params : Dict[str, Any], optional
        Parameters passed to construct the activation, by default {}
    interpolate_mode : str, optional
        Interpolation mode of the stretch, by default "nearest"
    freq_axis_kernel_size : int, optional
        Convolution kernel size along the frequency axis, by default 1
    use_causal_conv : bool, optional
        Whether to use causal padding before convolution, by default False

        If True, Causal padding is used along the time axis, i.e. padding
        amount is ``receptive field - 1`` and 0 for before and after,
        respectively.

        If False, "same" padding is used along the time axis.
    """

    def __init__(self,
                 upsample_scales: List[int],
                 nonlinear_activation: Optional[str]=None,
                 nonlinear_activation_params: Dict[str, Any]={},
                 interpolate_mode: str="nearest",
                 freq_axis_kernel_size: int=1,
                 use_causal_conv: bool=False):
        # NOTE: the ``{}`` default is only unpacked, never mutated.
        super().__init__()
        self.use_causal_conv = use_causal_conv
        self.up_layers = nn.LayerList()
        for scale in upsample_scales:
            stretch = Stretch2D(scale, 1, interpolate_mode)
            assert freq_axis_kernel_size % 2 == 1
            freq_axis_padding = (freq_axis_kernel_size - 1) // 2
            kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
            if use_causal_conv:
                # Pad by the full receptive field; the surplus on the right
                # is cut off again in ``forward``.
                padding = (freq_axis_padding, scale * 2)
            else:
                padding = (freq_axis_padding, scale)
            conv = nn.Conv2D(
                1, 1, kernel_size, padding=padding, bias_attr=False)
            self.up_layers.extend([stretch, conv])
            if nonlinear_activation is not None:
                nonlinear = getattr(
                    nn, nonlinear_activation)(**nonlinear_activation_params)
                self.up_layers.append(nonlinear)

    def forward(self, c: Tensor) -> Tensor:
        """
        Parameters
        ----------
        c : Tensor
            Shape (N, F, T), spectrogram

        Returns
        -------
        Tensor
            Shape (N, F, T'), where ``T' = upsample_factor * T``, upsampled
            spectrogram
        """
        c = c.unsqueeze(1)
        for f in self.up_layers:
            if self.use_causal_conv and isinstance(f, nn.Conv2D):
                # Keep only the first ``width`` frames so that no output
                # frame depends on future input. This has to be a slice:
                # a bare integer index would pick one frame and drop the
                # time axis.
                width = c.shape[-1]
                c = f(c)[:, :, :, :width]
            else:
                c = f(c)
        return c.squeeze(1)


class ConvInUpsampleNet(nn.Layer):
    """A Layer to upsample spectrogram composed of a convolution and an
    UpsampleNet.

    Parameters
    ----------
    upsample_scales : List[int]
        Upsampling factors for each stretch.
    nonlinear_activation : Optional[str], optional
        Activation after each convolution, by default None
    nonlinear_activation_params : Dict[str, Any], optional
        Parameters passed to construct the activation, by default {}
    interpolate_mode : str, optional
        Interpolation mode of the stretch, by default "nearest"
    freq_axis_kernel_size : int, optional
        Convolution kernel size along the frequency axis, by default 1
    aux_channels : int, optional
        Feature size of the input, by default 80
    aux_context_window : int, optional
        Context window of the first 1D convolution applied to the input. It
        related to the kernel size of the convolution, by default 0

        If use causal convolution, the kernel size is ``window + 1``, else
        the kernel size is ``2 * window + 1``.
    use_causal_conv : bool, optional
        Whether to use causal padding before convolution, by default False

        If True, Causal padding is used along the time axis, i.e. padding
        amount is ``receptive field - 1`` and 0 for before and after,
        respectively.

        If False, "same" padding is used along the time axis.
    """

    def __init__(self,
                 upsample_scales: List[int],
                 nonlinear_activation: Optional[str]=None,
                 nonlinear_activation_params: Dict[str, Any]={},
                 interpolate_mode: str="nearest",
                 freq_axis_kernel_size: int=1,
                 aux_channels: int=80,
                 aux_context_window: int=0,
                 use_causal_conv: bool=False):
        super().__init__()
        self.aux_context_window = aux_context_window
        # Trimming in ``forward`` only makes sense with a non-empty window.
        self.use_causal_conv = use_causal_conv and aux_context_window > 0
        if use_causal_conv:
            kernel_size = aux_context_window + 1
        else:
            kernel_size = 2 * aux_context_window + 1
        # No padding is applied by this convolution.
        self.conv_in = nn.Conv1D(
            aux_channels,
            aux_channels,
            kernel_size=kernel_size,
            bias_attr=False)
        self.upsample = UpsampleNet(
            upsample_scales=upsample_scales,
            nonlinear_activation=nonlinear_activation,
            nonlinear_activation_params=nonlinear_activation_params,
            interpolate_mode=interpolate_mode,
            freq_axis_kernel_size=freq_axis_kernel_size,
            use_causal_conv=use_causal_conv)

    def forward(self, c: Tensor) -> Tensor:
        """
        Parameters
        ----------
        c : Tensor
            Shape (N, F, T), spectrogram

        Returns
        -------
        Tensors
            Shape (N, F, T'), where ``T' = upsample_factor * T``, upsampled
            spectrogram
        """
        c_ = self.conv_in(c)
        c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_
        return self.upsample(c)


class ResidualBlock(nn.Layer):
    """A gated activation unit composed of an 1D convolution, a gated tanh
    unit and parametric residual and skip connections. For more details,
    refer to `WaveNet: A Generative Model for Raw Audio`.

    Parameters
    ----------
    kernel_size : int, optional
        Kernel size of the 1D convolution, by default 3
    residual_channels : int, optional
        Feature size of the residual output (and also the input), by default 64
    gate_channels : int, optional
        Output feature size of the 1D convolution, by default 128
    skip_channels : int, optional
        Feature size of the skip output, by default 64
    aux_channels : int, optional
        Feature size of the auxiliary input (e.g. spectrogram), by default 80
    dropout : float, optional
        Probability of the dropout before the 1D convolution, by default 0.
    dilation : int, optional
        Dilation of the 1D convolution, by default 1
    bias : bool, optional
        Whether to use bias in the 1D convolution, by default True
    use_causal_conv : bool, optional
        Whether to use causal padding for the 1D convolution, by default False
    """

    def __init__(self,
                 kernel_size: int=3,
                 residual_channels: int=64,
                 gate_channels: int=128,
                 skip_channels: int=64,
                 aux_channels: int=80,
                 dropout: float=0.,
                 dilation: int=1,
                 bias: bool=True,
                 use_causal_conv: bool=False):
        super().__init__()
        self.dropout = dropout
        if use_causal_conv:
            # Full receptive field on both sides; the right-hand surplus is
            # removed in ``forward``.
            padding = (kernel_size - 1) * dilation
        else:
            assert kernel_size % 2 == 1
            padding = (kernel_size - 1) // 2 * dilation
        self.use_causal_conv = use_causal_conv

        self.conv = nn.Conv1D(
            residual_channels,
            gate_channels,
            kernel_size,
            padding=padding,
            dilation=dilation,
            bias_attr=bias)
        if aux_channels is not None:
            self.conv1x1_aux = nn.Conv1D(
                aux_channels, gate_channels, kernel_size=1, bias_attr=False)
        else:
            self.conv1x1_aux = None

        # The gated unit halves the channel count (tanh half * sigmoid half).
        gate_out_channels = gate_channels // 2
        self.conv1x1_out = nn.Conv1D(
            gate_out_channels,
            residual_channels,
            kernel_size=1,
            bias_attr=bias)
        self.conv1x1_skip = nn.Conv1D(
            gate_out_channels, skip_channels, kernel_size=1, bias_attr=bias)

    def forward(self, x: Tensor, c: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Parameters
        ----------
        x : Tensor
            Shape (N, C_res, T), the input features.
        c : Tensor
            Shape (N, C_aux, T), the auxiliary input.

        Returns
        -------
        res : Tensor
            Shape (N, C_res, T), the residual output, which is used as the
            input of the next ResidualBlock in a stack of ResidualBlocks.
        skip : Tensor
            Shape (N, C_skip, T), the skip output, which is collected among
            each layer in a stack of ResidualBlocks.
        """
        x_input = x
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.conv(x)
        if self.use_causal_conv:
            # Drop the frames produced by the right-hand padding. This must
            # be a slice; an integer index would select a single frame.
            x = x[:, :, :x_input.shape[-1]]
        if c is not None:
            c = self.conv1x1_aux(c)
            x += c

        a, b = paddle.chunk(x, 2, axis=1)
        x = paddle.tanh(a) * F.sigmoid(b)

        skip = self.conv1x1_skip(x)
        res = (self.conv1x1_out(x) + x_input) * math.sqrt(0.5)
        return res, skip


class PWGGenerator(nn.Layer):
    """Wave Generator for Parallel WaveGAN

    Parameters
    ----------
    in_channels : int, optional
        Number of channels of the input waveform, by default 1
    out_channels : int, optional
        Number of channels of the output waveform, by default 1
    kernel_size : int, optional
        Kernel size of the residual blocks inside, by default 3
    layers : int, optional
        Number of residual blocks inside, by default 30
    stacks : int, optional
        The number of groups to split the residual blocks into, by default 3

        Within each group, the dilation of the residual block grows
        exponentially.
    residual_channels : int, optional
        Residual channel of the residual blocks, by default 64
    gate_channels : int, optional
        Gate channel of the residual blocks, by default 128
    skip_channels : int, optional
        Skip channel of the residual blocks, by default 64
    aux_channels : int, optional
        Auxiliary channel of the residual blocks, by default 80
    aux_context_window : int, optional
        The context window size of the first convolution applied to the
        auxiliary input, by default 2
    dropout : float, optional
        Dropout of the residual blocks, by default 0.
    bias : bool, optional
        Whether to use bias in residual blocks, by default True
    use_weight_norm : bool, optional
        Whether to use weight norm in all convolutions, by default True
    use_causal_conv : bool, optional
        Whether to use causal padding in the upsample network and residual
        blocks, by default False
    upsample_scales : List[int], optional
        Upsample scales of the upsample network, by default [4, 4, 4, 4]
    nonlinear_activation : Optional[str], optional
        Non linear activation in upsample network, by default None
    nonlinear_activation_params : Dict[str, Any], optional
        Parameters passed to the linear activation in the upsample network,
        by default {}
    interpolate_mode : str, optional
        Interpolation mode of the upsample network, by default "nearest"
    freq_axis_kernel_size : int, optional
        Kernel size along the frequency axis of the upsample network, by default 1
    """

    def __init__(self,
                 in_channels: int=1,
                 out_channels: int=1,
                 kernel_size: int=3,
                 layers: int=30,
                 stacks: int=3,
                 residual_channels: int=64,
                 gate_channels: int=128,
                 skip_channels: int=64,
                 aux_channels: int=80,
                 aux_context_window: int=2,
                 dropout: float=0.,
                 bias: bool=True,
                 use_weight_norm: bool=True,
                 use_causal_conv: bool=False,
                 upsample_scales: List[int]=[4, 4, 4, 4],
                 nonlinear_activation: Optional[str]=None,
                 nonlinear_activation_params: Dict[str, Any]={},
                 interpolate_mode: str="nearest",
                 freq_axis_kernel_size: int=1):
        # NOTE: the list/dict defaults are only read, never mutated.
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aux_channels = aux_channels
        self.aux_context_window = aux_context_window
        self.layers = layers
        self.stacks = stacks
        self.kernel_size = kernel_size

        assert layers % stacks == 0
        layers_per_stack = layers // stacks

        self.first_conv = nn.Conv1D(
            in_channels, residual_channels, 1, bias_attr=True)
        self.upsample_net = ConvInUpsampleNet(
            upsample_scales=upsample_scales,
            nonlinear_activation=nonlinear_activation,
            nonlinear_activation_params=nonlinear_activation_params,
            interpolate_mode=interpolate_mode,
            freq_axis_kernel_size=freq_axis_kernel_size,
            aux_channels=aux_channels,
            aux_context_window=aux_context_window,
            use_causal_conv=use_causal_conv)
        # Stored as a plain ``int`` (not a NumPy scalar) because it is used
        # to build a tensor shape in ``inference``.
        self.upsample_factor = int(np.prod(upsample_scales))

        self.conv_layers = nn.LayerList()
        for layer in range(layers):
            # Dilation restarts from 1 at the beginning of every stack.
            dilation = 2**(layer % layers_per_stack)
            conv = ResidualBlock(
                kernel_size=kernel_size,
                residual_channels=residual_channels,
                gate_channels=gate_channels,
                skip_channels=skip_channels,
                aux_channels=aux_channels,
                dilation=dilation,
                dropout=dropout,
                bias=bias,
                use_causal_conv=use_causal_conv)
            self.conv_layers.append(conv)

        self.last_conv_layers = nn.Sequential(
            nn.ReLU(),
            nn.Conv1D(
                skip_channels, skip_channels, 1, bias_attr=True),
            nn.ReLU(),
            nn.Conv1D(
                skip_channels, out_channels, 1, bias_attr=True))

        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x: Tensor, c: Tensor) -> Tensor:
        """Generate waveform.

        Parameters
        ----------
        x : Tensor
            Shape (N, C_in, T), The input waveform.
        c : Tensor
            Shape (N, C_aux, T'). The auxiliary input (e.g. spectrogram). It
            is upsampled to match the time resolution of the input.

        Returns
        -------
        Tensor
            Shape (N, C_out, T), the generated waveform.
        """
        c = self.upsample_net(c)
        assert c.shape[-1] == x.shape[-1]

        x = self.first_conv(x)
        skips = 0
        for f in self.conv_layers:
            x, s = f(x, c)
            skips += s
        # Scale the sum of skip outputs by 1 / sqrt(number of blocks).
        skips *= math.sqrt(1.0 / len(self.conv_layers))

        x = self.last_conv_layers(skips)
        return x

    def apply_weight_norm(self):
        """Recursively apply weight normalization to all the Convolution layers
        in the sublayers.
        """

        def _apply_weight_norm(layer):
            if isinstance(layer, (nn.Conv1D, nn.Conv2D)):
                nn.utils.weight_norm(layer)

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        """Recursively remove weight normalization from all the Convolution
        layers in the sublayers.
        """

        def _remove_weight_norm(layer):
            try:
                nn.utils.remove_weight_norm(layer)
            except ValueError:
                # Deliberate best effort: ``apply`` visits every sublayer
                # and the ValueError raised for some of them is ignored.
                pass

        self.apply(_remove_weight_norm)

    def inference(self, c: Optional[Tensor]=None,
                  x: Optional[Tensor]=None) -> Tensor:
        """Waveform generation. This function is used for single instance
        inference.

        Parameters
        ----------
        c : Tensor, optional
            Shape (T', C_aux), the auxiliary input, by default None
        x : Tensor, optional
            Shape (T, C_in), the noise waveform, by default None
            If not provided, a sample is drawn from a gaussian distribution.

        Returns
        -------
        Tensor
            Shape (T, C_out), the generated waveform
        """
        if x is not None:
            x = paddle.transpose(x, [1, 0]).unsqueeze(0)  # pseudo batch
        else:
            assert c is not None
            x = paddle.randn(
                [1, self.in_channels, c.shape[0] * self.upsample_factor])

        if c is not None:
            c = paddle.transpose(c, [1, 0]).unsqueeze(0)  # pseudo batch
            # Replicate-pad ``aux_context_window`` frames on both ends before
            # the (unpadded) input convolution of the upsample network.
            c = nn.Pad1D(self.aux_context_window, mode='replicate')(c)
        out = self.forward(x, c).squeeze(0).transpose([1, 0])
        return out
+ + Parameters + ---------- + in_channels : int, optional + Number of channels of the input audio, by default 1 + out_channels : int, optional + Output feature size, by default 1 + kernel_size : int, optional + Kernel size of convolutional sublayers, by default 3 + layers : int, optional + Number of layers, by default 10 + conv_channels : int, optional + Feature size of the convolutional sublayers, by default 64 + dilation_factor : int, optional + The factor with which dilation of each convolutional sublayers grows + exponentially if it is greater than 1, else the dilation of each + convolutional sublayers grows linearly, by default 1 + nonlinear_activation : str, optional + The activation after each convolutional sublayer, by default "LeakyReLU" + nonlinear_activation_params : Dict[str, Any], optional + The parameters passed to the activation's initializer, by default + {"negative_slope": 0.2} + bias : bool, optional + Whether to use bias in convolutional sublayers, by default True + use_weight_norm : bool, optional + Whether to use weight normalization at all convolutional sublayers, + by default True + """ + + def __init__(self, + in_channels: int=1, + out_channels: int=1, + kernel_size: int=3, + layers: int=10, + conv_channels: int=64, + dilation_factor: int=1, + nonlinear_activation: str="LeakyReLU", + nonlinear_activation_params: Dict[ + str, Any]={"negative_slope": 0.2}, + bias: bool=True, + use_weight_norm: bool=True): + super().__init__() + assert kernel_size % 2 == 1 + assert dilation_factor > 0 + conv_layers = [] + conv_in_channels = in_channels + for i in range(layers - 1): + if i == 0: + dilation = 1 + else: + dilation = i if dilation_factor == 1 else dilation_factor**i + conv_in_channels = conv_channels + padding = (kernel_size - 1) // 2 * dilation + conv_layer = nn.Conv1D( + conv_in_channels, + conv_channels, + kernel_size, + padding=padding, + dilation=dilation, + bias_attr=bias) + nonlinear = getattr( + nn, 
nonlinear_activation)(**nonlinear_activation_params) + conv_layers.append(conv_layer) + conv_layers.append(nonlinear) + padding = (kernel_size - 1) // 2 + last_conv = nn.Conv1D( + conv_in_channels, + out_channels, + kernel_size, + padding=padding, + bias_attr=bias) + conv_layers.append(last_conv) + self.conv_layers = nn.Sequential(*conv_layers) + + if use_weight_norm: + self.apply_weight_norm() + + def forward(self, x: Tensor) -> Tensor: + """ + Parameters + ---------- + x : Tensor + Shape (N, in_channels, num_samples), the input audio. + + Returns + ------- + Tensor + Shape (N, out_channels, num_samples), the predicted logits. + """ + return self.conv_layers(x) + + def apply_weight_norm(self): + def _apply_weight_norm(layer): + if isinstance(layer, (nn.Conv1D, nn.Conv2D)): + nn.utils.weight_norm(layer) + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + def _remove_weight_norm(layer): + try: + nn.utils.remove_weight_norm(layer) + except ValueError: + pass + + self.apply(_remove_weight_norm) + + +class ResidualPWGDiscriminator(nn.Layer): + """A wavenet-style discriminator for audio. + + Parameters + ---------- + in_channels : int, optional + Number of channels of the input audio, by default 1 + out_channels : int, optional + Output feature size, by default 1 + kernel_size : int, optional + Kernel size of residual blocks, by default 3 + layers : int, optional + Number of residual blocks, by default 30 + stacks : int, optional + Number of groups of residual blocks, within which the dilation + of each residual blocks grows exponentially, by default 3 + residual_channels : int, optional + Residual channels of residual blocks, by default 64 + gate_channels : int, optional + Gate channels of residual blocks, by default 128 + skip_channels : int, optional + Skip channels of residual blocks, by default 64 + dropout : float, optional + Dropout probability of residual blocks, by default 0. 
+ bias : bool, optional + Whether to use bias in residual blocks, by default True + use_weight_norm : bool, optional + Whether to use weight normalization in all convolutional layers, + by default True + use_causal_conv : bool, optional + Whether to use causal convolution in residual blocks, by default False + nonlinear_activation : str, optional + Activation after convolutions other than those in residual blocks, + by default "LeakyReLU" + nonlinear_activation_params : Dict[str, Any], optional + Parameters to pass to the activation, by default {"negative_slope": 0.2} + """ + + def __init__(self, + in_channels: int=1, + out_channels: int=1, + kernel_size: int=3, + layers: int=30, + stacks: int=3, + residual_channels: int=64, + gate_channels: int=128, + skip_channels: int=64, + dropout: float=0., + bias: bool=True, + use_weight_norm: bool=True, + use_causal_conv: bool=False, + nonlinear_activation: str="LeakyReLU", + nonlinear_activation_params: Dict[ + str, Any]={"negative_slope": 0.2}): + super().__init__() + assert kernel_size % 2 == 1 + self.in_channels = in_channels + self.out_channels = out_channels + self.layers = layers + self.stacks = stacks + self.kernel_size = kernel_size + + assert layers % stacks == 0 + layers_per_stack = layers // stacks + + self.first_conv = nn.Sequential( + nn.Conv1D( + in_channels, residual_channels, 1, bias_attr=True), + getattr(nn, nonlinear_activation)(**nonlinear_activation_params)) + + self.conv_layers = nn.LayerList() + for layer in range(layers): + dilation = 2**(layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + residual_channels=residual_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=None, # no auxiliary input + dropout=dropout, + dilation=dilation, + bias=bias, + use_causal_conv=use_causal_conv) + self.conv_layers.append(conv) + + self.last_conv_layers = nn.Sequential( + getattr(nn, nonlinear_activation)(**nonlinear_activation_params), + nn.Conv1D( + 
skip_channels, skip_channels, 1, bias_attr=True), + getattr(nn, nonlinear_activation)(**nonlinear_activation_params), + nn.Conv1D( + skip_channels, out_channels, 1, bias_attr=True)) + + if use_weight_norm: + self.apply_weight_norm() + + def forward(self, x: Tensor) -> Tensor: + """ + Parameters + ---------- + x : Tensor + Shape (N, in_channels, num_samples), the input audio. + + Returns + ------- + Tensor + Shape (N, out_channels, num_samples), the predicted logits. + """ + x = self.first_conv(x) + skip = 0 + for f in self.conv_layers: + x, h = f(x, None) + skip += h + skip *= math.sqrt(1 / len(self.conv_layers)) + + x = skip + x = self.last_conv_layers(x) + return x + + def apply_weight_norm(self): + def _apply_weight_norm(layer): + if isinstance(layer, (nn.Conv1D, nn.Conv2D)): + nn.utils.weight_norm(layer) + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + def _remove_weight_norm(layer): + try: + nn.utils.remove_weight_norm(layer) + except ValueError: + pass + + self.apply(_remove_weight_norm) diff --git a/parakeet/modules/audio.py b/parakeet/modules/audio.py index 16c64a4..c44aa66 100644 --- a/parakeet/modules/audio.py +++ b/parakeet/modules/audio.py @@ -20,7 +20,7 @@ import librosa from librosa.util import pad_center import numpy as np -__all__ = ["quantize", "dequantize", "STFT"] +__all__ = ["quantize", "dequantize", "STFT", "MelScale"] def quantize(values, n_bands): @@ -96,10 +96,10 @@ class STFT(nn.Layer): Defaults to True. pad_mode : string or function - If center=True, this argument is passed to np.pad for padding the edges - of the signal y. By default (pad_mode="reflect"), y is padded on both - sides with its own reflection, mirrored around its first and last - sample respectively. If center=False, this argument is ignored. + If center=True, this argument is passed to np.pad for padding the edges + of the signal y. 
By default (pad_mode="reflect"), y is padded on both + sides with its own reflection, mirrored around its first and last + sample respectively. If center=False, this argument is ignored. @@ -163,17 +163,15 @@ class STFT(nn.Layer): w = np.concatenate([w_real, w_imag], axis=0) w = w * window w = np.expand_dims(w, 1) - self.weight = paddle.cast( - paddle.to_tensor(w), paddle.get_default_dtype()) + weight = paddle.cast(paddle.to_tensor(w), paddle.get_default_dtype()) + self.register_buffer("weight", weight) def forward(self, x): """Compute the stft transform. - Parameters ------------ x : Tensor [shape=(B, T)] The input waveform. - Returns ------------ real : Tensor [shape=(B, C, frames)] @@ -195,36 +193,32 @@ class STFT(nn.Layer): def power(self, x): """Compute the power spectrum. - Parameters ------------ x : Tensor [shape=(B, T)] The input waveform. - Returns ------------ Tensor [shape=(B, C, T)] The power spectrum. """ - real, imag = self(x) + real, imag = self.forward(x) power = real**2 + imag**2 return power def magnitude(self, x): """Compute the magnitude of the spectrum. - Parameters ------------ x : Tensor [shape=(B, T)] The input waveform. - Returns ------------ Tensor [shape=(B, C, T)] The magnitude of the spectrum. 
""" power = self.power(x) - magnitude = paddle.sqrt(power) + magnitude = paddle.sqrt(power) # TODO(chenfeiyu): maybe clipping return magnitude @@ -232,7 +226,9 @@ class MelScale(nn.Layer): def __init__(self, sr, n_fft, n_mels, fmin, fmax): super().__init__() mel_basis = librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax) - self.weight = paddle.to_tensor(mel_basis) + # self.weight = paddle.to_tensor(mel_basis) + weight = paddle.to_tensor(mel_basis, dtype=paddle.get_default_dtype()) + self.register_buffer("weight", weight) def forward(self, spec): # (n_mels, n_freq) * (batch_size, n_freq, n_frames) diff --git a/parakeet/modules/stft_loss.py b/parakeet/modules/stft_loss.py new file mode 100644 index 0000000..cdc066f --- /dev/null +++ b/parakeet/modules/stft_loss.py @@ -0,0 +1,144 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle import nn +from paddle.nn import functional as F + +from parakeet.modules.audio import STFT + + +class SpectralConvergenceLoss(nn.Layer): + """Spectral convergence loss module.""" + + def __init__(self): + """Initilize spectral convergence loss module.""" + super().__init__() + + def forward(self, x_mag, y_mag): + """Calculate forward propagation. + Args: + x_mag (Tensor): Magnitude spectrogram of predicted signal (B, C, T). + y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, C, T). + Returns: + Tensor: Spectral convergence loss value. 
+ """ + return paddle.norm( + y_mag - x_mag, p="fro") / paddle.clip( + paddle.norm( + y_mag, p="fro"), min=1e-10) + + +class LogSTFTMagnitudeLoss(nn.Layer): + """Log STFT magnitude loss module.""" + + def __init__(self, epsilon=1e-10): + """Initialize log STFT magnitude loss module.""" + super().__init__() + self.epsilon = epsilon + + def forward(self, x_mag, y_mag): + """Calculate forward propagation. + Args: + x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + Returns: + Tensor: Log STFT magnitude loss value. + """ + return F.l1_loss( + paddle.log(paddle.clip( + y_mag, min=self.epsilon)), + paddle.log(paddle.clip( + x_mag, min=self.epsilon))) + + +class STFTLoss(nn.Layer): + """STFT loss module.""" + + def __init__(self, + fft_size=1024, + shift_size=120, + win_length=600, + window="hann"): + """Initialize STFT loss module.""" + super().__init__() + self.fft_size = fft_size + self.shift_size = shift_size + self.win_length = win_length + self.stft = STFT( + n_fft=fft_size, + hop_length=shift_size, + win_length=win_length, + window=window) + self.spectral_convergence_loss = SpectralConvergenceLoss() + self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() + + def forward(self, x, y): + """Calculate forward propagation. + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + Returns: + Tensor: Spectral convergence loss value. + Tensor: Log STFT magnitude loss value.
+ """ + x_mag = self.stft.magnitude(x) + y_mag = self.stft.magnitude(y) + sc_loss = self.spectral_convergence_loss(x_mag, y_mag) + mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) + + return sc_loss, mag_loss + + +class MultiResolutionSTFTLoss(nn.Layer): + """Multi resolution STFT loss module.""" + + def __init__( + self, + fft_sizes=[1024, 2048, 512], + hop_sizes=[120, 240, 50], + win_lengths=[600, 1200, 240], + window="hann", ): + """Initialize Multi resolution STFT loss module. + Args: + fft_sizes (list): List of FFT sizes. + hop_sizes (list): List of hop sizes. + win_lengths (list): List of window lengths. + window (str): Window function type. + """ + super().__init__() + assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) + self.stft_losses = nn.LayerList() + for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): + self.stft_losses.append(STFTLoss(fs, ss, wl, window)) + + def forward(self, x, y): + """Calculate forward propagation. + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + Returns: + Tensor: Multi resolution spectral convergence loss value. + Tensor: Multi resolution log STFT magnitude loss value. + """ + sc_loss = 0.0 + mag_loss = 0.0 + for f in self.stft_losses: + sc_l, mag_l = f(x, y) + sc_loss += sc_l + mag_loss += mag_l + sc_loss /= len(self.stft_losses) + mag_loss /= len(self.stft_losses) + + return sc_loss, mag_loss diff --git a/parakeet/training/checkpoint.py b/parakeet/training/checkpoint.py deleted file mode 100644 index 4bb12e2..0000000 --- a/parakeet/training/checkpoint.py +++ /dev/null @@ -1,162 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable, Mapping, List -from pathlib import Path - - -class KBest(object): - """ - A utility class to help save the hard drive by only keeping K best - checkpoints. - - To be as modularized as possible, this class does not assume anything like - a Trainer class or anything like a checkpoint directory, it does not know - about the model or the optimizer, etc. - - It is basically a dynamically mantained K-bset Mapping. When a new item is - added to the map, save_fn is called. And when an item is removed from the - map, del_fn is called. `save_fn` and `del_fn` takes a Path object as input - and returns nothing. - - Though it is designed to control checkpointing behaviors, it can be used - to do something else if you pass some save_fn and del_fn. - - Example - -------- - - >>> from pathlib import Path - >>> import shutil - >>> import paddle - >>> from paddle import nn - - >>> model = nn.Linear(2, 3) - >>> def save_model(path): - ... paddle.save(model.state_dict(), path) - - >>> kbest_manager = KBest(max_size=5, save_fn=save_model) - >>> checkpoint_dir = Path("checkpoints") - >>> shutil.rmtree(checkpoint_dir) - >>> checkpoint_dir.mkdir(parents=True) - >>> a = np.random.rand(20) - >>> for i, score in enumerate(a): - ... path = checkpoint_dir / f"step_{i}" - ... 
kbest_manager.add_checkpoint(score, path) - >>> assert len(list(checkpoint_dir.glob("step_*"))) == 5 - """ - - def __init__(self, - max_size: int=5, - save_fn: Callable[[Path], None]=None, - del_fn: Callable[[Path], None]=lambda f: f.unlink()): - self.best_records: Mapping[Path, float] = {} - self.save_fn = save_fn - self.del_fn = del_fn - self.max_size = max_size - self._save_all = (max_size == -1) - - def should_save(self, metric: float) -> bool: - if not self.full(): - return True - - # already full - worst_record_path = max(self.best_records, key=self.best_records.get) - worst_metric = self.best_records[worst_record_path] - return metric < worst_metric - - def full(self): - return (not self._save_all) and len(self.best_records) == self.max_size - - def add_checkpoint(self, metric, path): - if self.should_save(metric): - self.save_checkpoint_and_update(metric, path) - - def save_checkpoint_and_update(self, metric, path): - # remove the worst - if self.full(): - worst_record_path = max(self.best_records, - key=self.best_records.get) - self.best_records.pop(worst_record_path) - self.del_fn(worst_record_path) - - # add the new one - self.save_fn(path) - self.best_records[path] = metric - - -class KLatest(object): - """ - A utility class to help save the hard drive by only keeping K latest - checkpoints. - - To be as modularized as possible, this class does not assume anything like - a Trainer class or anything like a checkpoint directory, it does not know - about the model or the optimizer, etc. - - It is basically a dynamically mantained Queue. When a new item is - added to the queue, save_fn is called. And when an item is removed from the - queue, del_fn is called. `save_fn` and `del_fn` takes a Path object as input - and returns nothing. - - Though it is designed to control checkpointing behaviors, it can be used - to do something else if you pass some save_fn and del_fn. 
- - Example - -------- - - >>> from pathlib import Path - >>> import shutil - >>> import paddle - >>> from paddle import nn - - >>> model = nn.Linear(2, 3) - >>> def save_model(path): - ... paddle.save(model.state_dict(), path) - - >>> klatest_manager = KLatest(max_size=5, save_fn=save_model) - >>> checkpoint_dir = Path("checkpoints") - >>> shutil.rmtree(checkpoint_dir) - >>> checkpoint_dir.mkdir(parents=True) - >>> for i in range(20): - ... path = checkpoint_dir / f"step_{i}" - ... klatest_manager.add_checkpoint(path) - >>> assert len(list(checkpoint_dir.glob("step_*"))) == 5 - """ - - def __init__(self, - max_size: int=5, - save_fn: Callable[[Path], None]=None, - del_fn: Callable[[Path], None]=lambda f: f.unlink()): - self.latest_records: List[Path] = [] - self.save_fn = save_fn - self.del_fn = del_fn - self.max_size = max_size - self._save_all = (max_size == -1) - - def full(self): - return ( - not self._save_all) and len(self.latest_records) == self.max_size - - def add_checkpoint(self, path): - self.save_checkpoint_and_update(path) - - def save_checkpoint_and_update(self, path): - # remove the earist - if self.full(): - eariest_record_path = self.latest_records.pop(0) - self.del_fn(eariest_record_path) - - # add the new one - self.save_fn(path) - self.latest_records.append(path) diff --git a/parakeet/training/extension.py b/parakeet/training/extension.py new file mode 100644 index 0000000..57c4f29 --- /dev/null +++ b/parakeet/training/extension.py @@ -0,0 +1,80 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable + +PRIORITY_WRITER = 300 +PRIORITY_EDITOR = 200 +PRIORITY_READER = 100 + + +class Extension(object): + """Extension to customize the behavior of Trainer.""" + trigger = (1, 'iteration') + priority = PRIORITY_READER + name = None + + @property + def default_name(self): + """Default name of the extension, class name by default.""" + return type(self).__name__ + + def __call__(self, trainer): + """Main action of the extension. After each update, it is executed + when the trigger fires.""" + raise NotImplementedError( + 'Extension implementation must override __call__.') + + def initialize(self, trainer): + """Action that is executed once to get the correct trainer state. + It is called before training normally, but if the trainer restores + states with a Snapshot extension, this method should also be called. + """ + pass + + def on_error(self, trainer, exc, tb): + """Handles the error raised during training before finalization. + """ + pass + + def finalize(self, trainer): + """Action that is executed when training is done. + For example, visualizers would need to be closed. + """ + pass + + +def make_extension(trigger: Callable=None, + default_name: str=None, + priority: int=None, + finalizer: Callable=None, + initializer: Callable=None, + on_error: Callable=None): + """Make an Extension-like object by injecting required attributes to it.
+ """ + if trigger is None: + trigger = Extension.trigger + if priority is None: + priority = Extension.priority + + def decorator(ext): + ext.trigger = trigger + ext.default_name = default_name or ext.__name__ + ext.priority = priority + ext.finalize = finalizer + ext.on_error = on_error + ext.initialize = initializer + return ext + + return decorator diff --git a/parakeet/training/extensions/evaluator.py b/parakeet/training/extensions/evaluator.py new file mode 100644 index 0000000..6ebaae6 --- /dev/null +++ b/parakeet/training/extensions/evaluator.py @@ -0,0 +1,73 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import defaultdict +from typing import Optional, Callable, Dict + +from tqdm import tqdm +import paddle +from paddle import Tensor +from paddle.nn import Layer +from paddle.io import DataLoader + +from parakeet.training.reporter import scope, report, DictSummary +from parakeet.training import extension + + +class StandardEvaluator(extension.Extension): + + trigger = (1, 'epoch') + default_name = 'validation' + priority = extension.PRIORITY_WRITER + + name = None + + def __init__(self, model: Layer, dataloader: DataLoader): + # it is designed to hold multiple models + models = {"main": model} + self.models: Dict[str, Layer] = models + self.model = model + + # dataloaders + self.dataloader = dataloader + + def evaluate_core(self, batch): + # compute + self.model(batch) # you may report here + + def evaluate(self): + # switch to eval mode + for layer in self.models.values(): + layer.eval() + + # to average evaluation metrics + summary = DictSummary() + for batch in self.dataloader: + observation = {} + with scope(observation): + # main evaluation computation here. + with paddle.no_grad(): + self.evaluate_core(batch) + summary.add(observation) + summary = summary.compute_mean() + return summary + + def __call__(self, trainer=None): + # evaluate and report the averaged metric to current observation + # if it is used to extend a trainer, the metrics are reported to + # the observation of the trainer + # or otherwise, you can use your own observation + summary = self.evaluate() + for k, v in summary.items(): + report(k, v) diff --git a/parakeet/training/extensions/snapshot.py b/parakeet/training/extensions/snapshot.py new file mode 100644 index 0000000..92d74ef --- /dev/null +++ b/parakeet/training/extensions/snapshot.py @@ -0,0 +1,110 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import logging +from pathlib import Path +from datetime import datetime +from typing import List, Dict, Any + +import jsonlines + +from parakeet.utils.mp_tools import rank_zero_only +from parakeet.training.trainer import Trainer +from parakeet.training import extension + + +def load_records(records_fp): + """Load record files (json lines.)""" + with jsonlines.open(records_fp, 'r') as reader: + records = list(reader) + return records + + +class Snapshot(extension.Extension): + """An extension to make snapshot of the updater object inside + the trainer. It is done by calling the updater's `save` method. + + An Updater save its state_dict by default, which contains the + updater state, (i.e. epoch and iteration) and all the model + parameters and optimizer states. If the updater inside the trainer + subclasses StandardUpdater, everything is good to go. + + Parameters + ---------- + checkpoint_dir : Union[str, Path] + The directory to save checkpoints into. + """ + + trigger = (1, 'epoch') + priority = -100 + default_name = "snapshot" + + def __init__(self, max_size: int=5, snapshot_on_error: bool=False): + self.records: List[Dict[str, Any]] = [] + self.max_size = max_size + self._snapshot_on_error = snapshot_on_error + self._save_all = (max_size == -1) + self.checkpoint_dir =... 
+ + def initialize(self, trainer: Trainer): + """Setting up this extension.""" + self.checkpoint_dir = trainer.out / "checkpoints" + + # load existing records + record_path: Path = self.checkpoint_dir / "records.jsonl" + if record_path.exists(): + logging.debug("Loading from an existing checkpoint dir") + self.records = load_records(record_path) + trainer.updater.load(self.records[-1]['path']) + + def on_error(self, trainer, exc, tb): + if self._snapshot_on_error: + self.save_checkpoint_and_update(trainer) + + def __call__(self, trainer: Trainer): + self.save_checkpoint_and_update(trainer) + + def full(self): + """Whether the number of snapshots it keeps track of is greater + than the max_size.""" + return (not self._save_all) and len(self.records) > self.max_size + + @rank_zero_only + def save_checkpoint_and_update(self, trainer: Trainer): + """Saving new snapshot and remove the oldest snapshot if needed.""" + iteration = trainer.updater.state.iteration + path = self.checkpoint_dir / f"snapshot_iter_{iteration}.pdz" + + # add the new one + trainer.updater.save(path) + record = { + "time": str(datetime.now()), + 'path': str(path.resolve()), # use absolute path + 'iteration': iteration + } + self.records.append(record) + + # remove the earliest + if self.full(): + eariest_record = self.records[0] + os.remove(eariest_record["path"]) + self.records.pop(0) + + # update the record file + record_path = self.checkpoint_dir / "records.jsonl" + with jsonlines.open(record_path, 'w') as writer: + for record in self.records: + # jsonlines.open may return a Writer or a Reader + writer.write(record) # pylint: disable=no-member diff --git a/parakeet/training/extensions/visualizer.py b/parakeet/training/extensions/visualizer.py new file mode 100644 index 0000000..138bf1e --- /dev/null +++ b/parakeet/training/extensions/visualizer.py @@ -0,0 +1,40 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from visualdl import LogWriter

from parakeet.training.trainer import Trainer
from parakeet.training import extension


class VisualDL(extension.Extension):
    """Extension that forwards the trainer's observations to a VisualDL writer.

    Every entry of ``trainer.observation`` is written with ``add_scalar``,
    using the updater's ``iteration`` as the global step. The values are
    passed to the writer unchanged, so they are assumed to be scalars.

    Args:
        writer: Log writer exposing ``add_scalar(tag, value, step=...)`` and
            ``close()`` (presumably a ``visualdl.LogWriter`` -- confirm at
            the call site).
    """
    # Class-level defaults describing when and in which order the extension
    # is run.
    trigger = (1, 'iteration')
    default_name = 'visualdl'
    priority = extension.PRIORITY_READER

    def __init__(self, writer):
        self.writer = writer

    def __call__(self, trainer: Trainer):
        """Write every observation of the current step to the log writer."""
        for k, v in trainer.observation.items():
            self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)

    def finalize(self, trainer):
        """Close the underlying writer."""
        self.writer.close()


# diff --git a/parakeet/training/reporter.py b/parakeet/training/reporter.py
# index 3f4d77f..c2f171c 100644
# --- a/parakeet/training/reporter.py
# +++ b/parakeet/training/reporter.py
# @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import contextlib
from collections import defaultdict

OBSERVATIONS = None

# @@ -45,3 +47,113 @@ def report(name, value):
#         return
#     else:
#         observations[name] = value


class Summary(object):
    """Online summarization of a sequence of scalars.

    Keeps weighted running sums of the values and of their squares, so the
    mean and the standard deviation can be computed at any time without
    storing the individual samples.
    """

    def __init__(self):
        self._x = 0.0  # weighted sum of the values
        self._x2 = 0.0  # weighted sum of the squared values
        self._n = 0  # sum of the weights

    def add(self, value, weight=1):
        """Adds a scalar value.

        Args:
            value: Scalar value to accumulate. It is either a NumPy scalar or
                a zero-dimensional array (on CPU or GPU).
            weight: An optional weight for the value. It is a NumPy scalar or
                a zero-dimensional array (on CPU or GPU).
                Default is 1 (integer).

        """
        self._x += weight * value
        self._x2 += weight * value * value
        self._n += weight

    def compute_mean(self):
        """Computes the weighted mean of the accumulated values.

        With plain Python numbers this raises ``ZeroDivisionError`` when
        nothing has been added yet (the total weight is 0).
        """
        x, n = self._x, self._n
        return x / n

    def make_statistics(self):
        """Computes and returns the mean and standard deviation values.

        Returns:
            tuple: Mean and standard deviation values.

        """
        x, n = self._x, self._n
        mean = x / n
        var = self._x2 / n - mean * mean
        # E[x^2] - E[x]^2 can come out slightly negative because of
        # floating-point rounding (e.g. when all values are equal);
        # math.sqrt would then raise "math domain error". Clamp at zero.
        std = math.sqrt(max(var, 0.0))
        return mean, std


class DictSummary(object):
    """Online summarization of a sequence of dictionaries.

    ``DictSummary`` computes the statistics of a given set of scalars online.
    One ``Summary`` is kept per key and is created the first time the key is
    seen.

    """

    def __init__(self):
        self._summaries = defaultdict(Summary)

    def add(self, d):
        """Adds a dictionary of scalars.

        Args:
            d (dict): Dictionary of scalars to accumulate. When a value
                is a tuple, its first element is the scalar and its second
                element is interpreted as a weight; otherwise the weight
                is 1.

        """
        summaries = self._summaries
        for k, v in d.items():
            w = 1
            if isinstance(v, tuple):
                w = v[1]
                v = v[0]
            summaries[k].add(v, weight=w)

    def compute_mean(self):
        """Creates a dictionary of mean values.

        It returns a single dictionary that holds a mean value for each entry
        added to the summary.

        Returns:
            dict: Dictionary of mean values.

        """
        return {
            name: summary.compute_mean()
            for name, summary in self._summaries.items()
        }

    def make_statistics(self):
        """Creates a dictionary of statistics.

        It returns a single dictionary that holds mean and standard deviation
        values for every entry added to the summary. For an entry of name
        ``'key'``, these values are added to the dictionary by names ``'key'``
        and ``'key.std'``, respectively.

        Returns:
            dict: Dictionary of statistics of all entries.

        """
        stats = {}
        for name, summary in self._summaries.items():
            mean, std = summary.make_statistics()
            stats[name] = mean
            stats[name + '.std'] = std

        return stats


# diff --git a/parakeet/training/seeding.py b/parakeet/training/seeding.py
# new file mode 100644
# index 0000000..1663d2d
# --- /dev/null
# +++ b/parakeet/training/seeding.py
# @@ -0,0 +1,27 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ +import random +import logging + +import paddle +import numpy as np + + +def seed_everything(seed: int): + """Seed paddle, random and np.random to help reproductivity.""" + paddle.seed(seed) + random.seed(seed) + np.random.seed(seed) + logging.debug(f"Set the seed of paddle, random, np.random to {seed}.") diff --git a/parakeet/training/trainer.py b/parakeet/training/trainer.py index 544ea8e..484845f 100644 --- a/parakeet/training/trainer.py +++ b/parakeet/training/trainer.py @@ -12,16 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys +import six +import traceback from pathlib import Path -import tqdm -from dataclasses import dataclass +from collections import OrderedDict +from typing import Callable, Union, List -from parakeet.training.trigger import get_trigger, IntervalTrigger +import tqdm + +from parakeet.training.trigger import get_trigger, IntervalTrigger, LimitTrigger from parakeet.training.updater import UpdaterBase from parakeet.training.reporter import scope +from parakeet.training.extension import Extension, PRIORITY_READER -class ExtensionEntry(object): +class _ExtensionEntry(object): def __init__(self, extension, trigger, priority): self.extension = extension self.trigger = trigger @@ -31,31 +37,76 @@ class ExtensionEntry(object): class Trainer(object): def __init__(self, updater: UpdaterBase, - stop_trigger=None, - out='result', - extensions=None): + stop_trigger: Callable=None, + out: Union[str, Path]='result', + extensions: List[Extension]=None): self.updater = updater - self.extensions = {} - self.stop_trigger = get_trigger(stop_trigger) + self.extensions = OrderedDict() + self.stop_trigger = LimitTrigger(*stop_trigger) self.out = Path(out) - self.observation = {} + self.observation =... 
- def setup(self): - pass + self._done = False + if extensions: + for ext in extensions: + self.extend(ext) + + @property + def is_before_training(self): + return self.updater.state.iteration == 0 def extend(self, extension, name=None, trigger=None, priority=None): + # get name for the extension + # argument \ + # -> extention's name \ + # -> default_name (class name, when it is an object) \ + # -> function name when it is a function \ + # -> error + + if name is None: + name = getattr(extension, 'name', None) + if name is None: + name = getattr(extension, 'default_name', None) + if name is None: + name = getattr(extension, '__name__', None) + if name is None: + raise ValueError( + "Name is not given for the extension.") + if name == 'training': + raise ValueError("training is a reserved name.") + + if trigger is None: + trigger = getattr(extension, 'trigger', (1, 'iteration')) trigger = get_trigger(trigger) + if priority is None: + priority = getattr(extension, 'priority', PRIORITY_READER) + + # add suffix to avoid nameing conflict ordinal = 0 modified_name = name - while name in self.extensions: + while modified_name in self.extensions: ordinal += 1 modified_name = f"{name}_{ordinal}" + extension.name = modified_name - self.extensions[modified_name] = ExtensionEntry(extension, trigger, - priority) + self.extensions[modified_name] = _ExtensionEntry(extension, trigger, + priority) + + def get_extension(self, name): + """get extension by name.""" + extensions = self.extensions + if name in extensions: + return extensions[name].extension + else: + raise ValueError(f'extension {name} not found') def run(self): + if self._done: + raise RuntimeError("Training is already done!.") + + self.out.mkdir(parents=True, exist_ok=True) + # sort extensions by priorities once extension_order = sorted( self.extensions.keys(), @@ -64,28 +115,72 @@ class Trainer(object): extensions = [(name, self.extensions[name]) for name in extension_order] - update = self.updater.update + # 
initializing all extensions + for name, entry in extensions: + if hasattr(entry.extension, "initialize"): + entry.extension.initialize(self) + + update = self.updater.update # training step stop_trigger = self.stop_trigger - # TODO(chenfeiyu): display progress bar correctly - # if the trainer is controlled by epoch: use 2 progressbars - # if the trainer is controlled by iteration: use 1 progressbar - if isinstance(stop_trigger, IntervalTrigger): + print(self.updater.state) + + # display only one progress bar + max_iteration = None + if isinstance(stop_trigger, LimitTrigger): if stop_trigger.unit is 'epoch': - max_epoch = self.stop_trigger.period + max_epoch = self.stop_trigger.limit + updates_per_epoch = getattr(self.updater, "updates_per_epoch", + None) + max_iteration = max_epoch * updates_per_epoch if updates_per_epoch else None else: - max_iteration = self.stop_trigger.period + max_iteration = self.stop_trigger.limit - while not stop_trigger(self): - self.observation = {} - # set observation as the report target - # you can use report freely in Updater.update() + p = tqdm.tqdm( + initial=self.updater.state.iteration, total=max_iteration) - # updating parameters and state - with scope(self.observation): - update() + try: + while not stop_trigger(self): + self.observation = {} + # set observation as the report target + # you can use report freely in Updater.update() - # execute extension when necessary + # updating parameters and state + with scope(self.observation): + update() + p.update() + + # execute extension when necessary + for name, entry in extensions: + if entry.trigger(self): + entry.extension(self) + + # print("###", self.observation) + except Exception as e: + f = sys.stderr + f.write(f"Exception in main training loop: {e}\n") + f.write("Traceback (most recent call last):\n") + traceback.print_tb(sys.exc_info()[2]) + f.write( + "Trainer extensions will try to handle the extension. Then all extensions will finalize." 
+ ) + + # capture the exception in the mian training loop + exc_info = sys.exc_info() + + # try to handle it for name, entry in extensions: - if entry.trigger(self): - entry.extension(self) + if hasattr(entry.extension, "on_error"): + try: + entry.extension.on_error(self, e, sys.exc_info()[2]) + except Exception as ee: + f.write(f"Exception in error handler: {ee}\n") + f.write('Traceback (most recent call last):\n') + traceback.print_tb(sys.exc_info()[2]) + + # raise exception in main training loop + six.reraise(*exc_info) + finally: + for name, entry in extensions: + if hasattr(entry.extension, "finalize"): + entry.extension.finalize(self) diff --git a/parakeet/training/trigger.py b/parakeet/training/trigger.py index a7d4ef9..b588512 100644 --- a/parakeet/training/trigger.py +++ b/parakeet/training/trigger.py @@ -12,21 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. - -class IntervalTrigger(object): - def __init__(self, period: int, unit: str): - if unit not in ("iteration", "epoch"): - raise ValueError("unit should be 'iteration' or 'epoch'") - self.period = period - self.unit = unit - - def __call__(self, trainer): - state = trainer.updater.state - if self.unit == "epoch": - fire = not (state.epoch % self.period) - else: - fire = not (state.iteration % self.iteration) - return fire +from parakeet.training.triggers.interval_trigger import IntervalTrigger +from parakeet.training.triggers.limit_trigger import LimitTrigger +from parakeet.training.triggers.time_trigger import TimeTrigger def never_file_trigger(trainer): diff --git a/parakeet/training/triggers/interval_trigger.py b/parakeet/training/triggers/interval_trigger.py new file mode 100644 index 0000000..b88816c --- /dev/null +++ b/parakeet/training/triggers/interval_trigger.py @@ -0,0 +1,31 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class IntervalTrigger(object):
    """A Predicate to do something every N cycle.

    The cycle counter is ``trainer.updater.state.iteration`` or
    ``trainer.updater.state.epoch``, selected by ``unit``. The trigger fires
    when the counter is a multiple of ``period``, and at most once for each
    distinct counter value. The deduplication matters for ``unit='epoch'``:
    the trigger is evaluated after every update while the epoch counter keeps
    the same value for a whole epoch, so a plain modulo test would fire on
    every iteration of a matching epoch.

    Instances are stateful (they remember the last counter value they saw),
    so one trigger object should not be shared between several extensions.

    Args:
        period (int): Interval, a positive integer.
        unit (str): ``'iteration'`` or ``'epoch'``.

    Raises:
        ValueError: If ``unit`` is not one of the two supported names or
            ``period`` is not positive.
    """

    def __init__(self, period: int, unit: str):
        if unit not in ("iteration", "epoch"):
            raise ValueError("unit should be 'iteration' or 'epoch'")
        if period <= 0:
            raise ValueError("period should be a positive integer.")
        self.period = period
        self.unit = unit
        # Counter value seen by the previous call; None before the first call.
        self._last_index = None

    def __call__(self, trainer):
        """Return True when the counter has a new value that is a multiple of period."""
        state = trainer.updater.state
        index = getattr(state, self.unit)
        is_new_index = index != self._last_index
        self._last_index = index
        # A counter value of 0 still counts as a multiple of the period, as
        # it did before; only repeated firing for the same value is removed.
        return is_new_index and index % self.period == 0


# diff --git a/parakeet/training/triggers/limit_trigger.py b/parakeet/training/triggers/limit_trigger.py
# new file mode 100644
# index 0000000..dd7a135
# --- /dev/null
# +++ b/parakeet/training/triggers/limit_trigger.py
# @@ -0,0 +1,31 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ + +class LimitTrigger(object): + """A Predicate to decide whether to stop.""" + + def __init__(self, limit: int, unit: str): + if unit not in ("iteration", "epoch"): + raise ValueError("unit should be 'iteration' or 'epoch'") + if limit <= 0: + raise ValueError("limit should be a positive integer.") + self.limit = limit + self.unit = unit + + def __call__(self, trainer): + state = trainer.updater.state + index = getattr(state, self.unit) + fire = index >= self.limit + return fire diff --git a/parakeet/training/triggers/time_trigger.py b/parakeet/training/triggers/time_trigger.py new file mode 100644 index 0000000..aff9382 --- /dev/null +++ b/parakeet/training/triggers/time_trigger.py @@ -0,0 +1,35 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class TimeTrigger(object): + """Trigger based on a fixed time interval. + + This trigger accepts iterations with a given interval time. + + Args: + period (float): Interval time. It is given in seconds. 
+ + """ + + def __init__(self, period): + self._period = period + self._next_time = self._period + + def __call__(self, trainer): + if self._next_time < trainer.elapsed_time: + self._next_time += self._period + return True + else: + return False diff --git a/parakeet/training/updater.py b/parakeet/training/updater.py index 8359eef..5ec5eec 100644 --- a/parakeet/training/updater.py +++ b/parakeet/training/updater.py @@ -12,12 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from dataclasses import dataclass from typing import Optional +from typing import Dict +from typing import Union +from timer import timer +import paddle +from paddle import Tensor from paddle.nn import Layer from paddle.optimizer import Optimizer from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler + +from parakeet.training.reporter import report @dataclass @@ -56,68 +65,33 @@ class UpdaterBase(object): So the best practice is to define a model and define a updater for it. """ - def update(self): - pass - - def update_core(self): - pass - - -class StandardUpdater(UpdaterBase): - """An example of over-simplification. Things may not be that simple, but - you can subclass it to fit your need. 
- """ - - def __init__(self, - model: Layer, - dataloader: DataLoader, - optimizer: Optimizer, - loss_func=None, - auto_new_epoch: bool=True, - init_state: Optional[UpdaterState]=None): - self.model = model - self.dataloader = dataloader - self.optimizer = optimizer - self.loss_func = loss_func - self.auto_new_epoch = auto_new_epoch - self.iterator = iter(dataloader) - + def __init__(self, init_state=None): if init_state is None: self.state = UpdaterState() else: self.state = init_state - def update(self): - self.update_core() - self.state.iteration += 1 + def update(self, batch): + raise NotImplementedError( + "Implement your own `update` method for training a step.") - def new_epoch(self): - self.iterator = iter(self.dataloader) - self.state.epoch += 1 + def state_dict(self): + state_dict = { + "epoch": self.state.epoch, + "iteration": self.state.iteration, + } + return state_dict - def update_core(self): - model = self.model - optimizer = self.optimizer - loss_func = self.loss_func + def set_state_dict(self, state_dict): + self.state.epoch = state_dict["epoch"] + self.state.iteration = state_dict["iteration"] - model.train() - optimizer.clear_grad() + def save(self, path): + logging.debug(f"Saving to {path}.") + archive = self.state_dict() + paddle.save(archive, str(path)) - # fetch a batch - try: - batch = next(self.iterator) - except StopIteration as e: - if self.auto_new_epoch: - self.new_epoch() - - # forward - if self.loss_func is not None: - loss = loss_func(batch) - else: - loss = model(batch) - - # backward - loss.backward() - - # update parameters - optimizer.step() + def load(self, path): + logging.debug(f"Loading from {path}.") + archive = paddle.load(str(path)) + self.set_state_dict(archive) diff --git a/parakeet/training/updaters/standard_updater.py b/parakeet/training/updaters/standard_updater.py new file mode 100644 index 0000000..e39b758 --- /dev/null +++ b/parakeet/training/updaters/standard_updater.py @@ -0,0 +1,190 @@ +# Copyright (c) 2021 
# PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from dataclasses import dataclass
from typing import Optional
from typing import Dict
from typing import Union

from timer import timer
import paddle
from paddle import Tensor
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler

from parakeet.training.reporter import report
from parakeet.training.updater import UpdaterBase, UpdaterState


class StandardUpdater(UpdaterBase):
    """An example of over-simplification. Things may not be that simple, but
    you can subclass it to fit your need.

    It trains a single model with a single optimizer on batches read from a
    single dataloader, and keeps the iteration/epoch counters up to date.

    Args:
        model (Layer): The model to train; called as ``model(*batch)``.
        optimizer (Optimizer): The optimizer updating ``model``.
        dataloader (DataLoader): Source of training batches.
        init_state (Optional[UpdaterState]): Initial counters; a fresh
            ``UpdaterState()`` is created when omitted.
    """

    def __init__(self,
                 model: Layer,
                 optimizer: Optimizer,
                 dataloader: DataLoader,
                 init_state: Optional[UpdaterState]=None):
        # it is designed to hold multiple models
        models = {"main": model}
        self.models: Dict[str, Layer] = models
        self.model = model

        # it is designed to hold multiple optimizers
        optimizers = {"main": optimizer}
        self.optimizer = optimizer
        self.optimizers: Dict[str, Optimizer] = optimizers

        # dataloaders
        self.dataloader = dataloader

        # init state
        if init_state is None:
            self.state = UpdaterState()
        else:
            self.state = init_state

        self.train_iterator = iter(dataloader)

    def update(self):
        """Run one training step and advance the iteration/epoch counters."""
        # The iteration index is increased after the parameter update and
        # before any extension runs. The reasons are:
        #
        # 0. Snapshotting (as well as other extensions, like the visualizer)
        #    is executed after a step of updating.
        # 1. The iteration index is therefore increased after updating and
        #    before any extension is executed.
        # 2. It is not increased after the extensions because a consistent
        #    resume behavior is preferred: after loading
        #    `snapshot_iter_100.pdz` the next step to train is naturally
        #    `101`. If the index were increased after the extensions
        #    (including the snapshot), a `snapshot_iter_99` would be loaded
        #    and an extra increment would be needed before training to avoid
        #    repeating iteration `99`, which was already done before
        #    snapshotting.
        # 3. Thus the iteration index represents "how many iterations have
        #    been done so far".
        #
        # NOTE: use report to capture the correct value. If you want to
        # report the learning rate used for a step, you must report it before
        # the learning rate scheduler's step() has been called. In paddle's
        # convention, we do not use an extension to change the learning rate,
        # so if you want to report it, do it in the updater.
        #
        # When is the proper time to increase the epoch index? Since all
        # extensions are executed after updating, it is also increased right
        # after updating:
        #
        # 1. If the epoch index were increased before updating, an extension
        #    based on epochs would miss the correct timing; it could only be
        #    triggered after an extra update.
        # 2. Theoretically, when an epoch is done the epoch index should be
        #    increased, so it is increased after updating.
        # 3. Thus the epoch index represents "how many epochs have been done
        #    so far", and it starts from 0.

        # switch to training mode
        for layer in self.models.values():
            layer.train()

        # training for a step is implemented here
        batch = self.read_batch()
        self.update_core(batch)

        self.state.iteration += 1
        updates_per_epoch = self.updates_per_epoch
        if updates_per_epoch is not None:
            if self.state.iteration % updates_per_epoch == 0:
                self.state.epoch += 1

    def update_core(self, batch):
        """A simple case for a training step. Basic assumptions are:
        Single model;
        Single optimizer;
        A batch from the dataloader is just the input of the model;
        The model returns a single loss, or a dict containing several losses.
        Parameters are updated at every batch, no gradient accumulation.

        When the model returns a dict without a ``"main"`` entry, the sum of
        all its values is used as the main loss. Every loss is reported
        under its name as a Python float.
        """
        loss = self.model(*batch)

        if isinstance(loss, Tensor):
            loss_dict = {"main": loss}
        else:
            # Dict[str, Tensor]; work on a copy so that the mapping returned
            # by the model is not modified when "main" is added below.
            loss_dict = dict(loss)
            if "main" not in loss_dict:
                main_loss = 0
                for loss_item in loss.values():
                    main_loss += loss_item
                loss_dict["main"] = main_loss

        for name, loss_item in loss_dict.items():
            report(name, float(loss_item))

        # paddle optimizers expose clear_grad()/step(); they have no
        # clear_gradient()/update() methods.
        self.optimizer.clear_grad()
        loss_dict["main"].backward()
        self.optimizer.step()

    @property
    def updates_per_epoch(self):
        """Number of updates per epoch, determined by the length of the
        dataloader; ``None`` when the dataloader cannot report a length."""
        try:
            return len(self.dataloader)
        except Exception as e:
            # Best effort, as before: a loader without a usable length just
            # means the epoch size is unknown. Unlike a `return` inside
            # `finally`, this does not swallow KeyboardInterrupt/SystemExit.
            logging.debug(f"This dataloader has no __len__. ({e})")
            return None

    @property
    def updaters_per_epoch(self):
        """Alias of ``updates_per_epoch``, kept for backward compatibility."""
        return self.updates_per_epoch

    def new_epoch(self):
        """Start a new epoch."""
        # NOTE: all batch sampler for distributed training should
        # subclass DistributedBatchSampler and implement `set_epoch` method
        batch_sampler = self.dataloader.batch_sampler
        if isinstance(batch_sampler, DistributedBatchSampler):
            batch_sampler.set_epoch(self.state.epoch)
        self.train_iterator = iter(self.dataloader)

    def read_batch(self):
        """Read a batch from the data loader, auto renew when data is exhausted."""
        with timer() as t:
            try:
                batch = next(self.train_iterator)
            except StopIteration:
                self.new_epoch()
                batch = next(self.train_iterator)
            logging.debug(f"Read a batch takes {t.elapse}s.")
        return batch

    def state_dict(self):
        """State dict of a Updater, model, optimizer and updater state are included."""
        state_dict = super().state_dict()
        for name, layer in self.models.items():
            state_dict[f"{name}_params"] = layer.state_dict()
        for name, optim in self.optimizers.items():
            state_dict[f"{name}_optimizer"] = optim.state_dict()
        return state_dict

    def set_state_dict(self, state_dict):
        """Set state dict for a Updater. Parameters of models, states for
        optimizers and UpdaterState are restored."""
        for name, layer in self.models.items():
            layer.set_state_dict(state_dict[f"{name}_params"])
        for name, optim in self.optimizers.items():
            optim.set_state_dict(state_dict[f"{name}_optimizer"])
        super().set_state_dict(state_dict)


# diff --git a/parakeet/utils/h5_utils.py b/parakeet/utils/h5_utils.py
# new file mode 100644
# index 0000000..cd0c670
# --- /dev/null
# +++ b/parakeet/utils/h5_utils.py
# @@ -0,0 +1,105 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Union, Any
import sys
import logging
import h5py
import numpy as np


def read_hdf5(filename: Union[Path, str], dataset_name: str) -> Any:
    """Read a dataset from a HDF5 file.

    The process exits with status 1 (after logging an error) when the file
    or the dataset does not exist.

    Parameters
    ----------
    filename : Union[Path, str]
        Path of the HDF5 file.
    dataset_name : str
        Name of the dataset to read.

    Returns
    -------
    Any
        The retrieved dataset.
    """
    filename = Path(filename)

    if not filename.exists():
        logging.error(f"There is no such a hdf5 file ({filename}).")
        sys.exit(1)

    # The context manager closes the file on every path, including the
    # SystemExit raised below.
    with h5py.File(filename, "r") as hdf5_file:
        if dataset_name not in hdf5_file:
            logging.error(
                f"There is no such a data in hdf5 file. ({dataset_name})")
            sys.exit(1)

        # [()]: a special syntax of h5py to get the dataset as-is
        hdf5_data = hdf5_file[dataset_name][()]

    return hdf5_data


def write_hdf5(filename: Union[Path, str],
               dataset_name: str,
               write_data: np.ndarray,
               is_overwrite: bool=True) -> None:
    """Write dataset to HDF5 file.

    Missing parent directories are created. When the dataset already exists
    and ``is_overwrite`` is False, an error is logged and the process exits
    with status 1.

    Parameters
    ----------
    filename : Union[Path, str]
        Path of the HDF5 file.
    dataset_name : str
        Name of the dataset to write to.
    write_data : np.ndarray
        The data to write.
    is_overwrite : bool, optional
        Whether to overwrite, by default True
    """
    # convert to numpy array
    filename = Path(filename)
    write_data = np.array(write_data)

    # check folder existence
    filename.parent.mkdir(parents=True, exist_ok=True)

    # if the file already exists, open with r+ mode, otherwise with w mode
    mode = "r+" if filename.exists() else "w"

    # The context manager closes the file even if writing fails.
    with h5py.File(filename, mode) as hdf5_file:
        # check dataset existence
        if dataset_name in hdf5_file:
            if is_overwrite:
                logging.warning("Dataset in hdf5 file already exists. "
                                "recreate dataset in hdf5.")
                del hdf5_file[dataset_name]
            else:
                logging.error(
                    "Dataset in hdf5 file already exists. "
                    "if you want to overwrite, please set is_overwrite = True.")
                sys.exit(1)

        # write data to hdf5
        hdf5_file.create_dataset(dataset_name, data=write_data)
        hdf5_file.flush()


# diff --git a/parakeet/utils/profile.py b/parakeet/utils/profile.py
# new file mode 100644
# index 0000000..cfffb4b
# --- /dev/null
# +++ b/parakeet/utils/profile.py
# @@ -0,0 +1,34 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.framework import core
from paddle.framework import CUDAPlace
from contextlib import contextmanager


def synchronize():
    """Trigger cuda synchronization for better timing.

    Does nothing unless the current expected place is a ``CUDAPlace``.
    """
    place = paddle.fluid.framework._current_expected_place()
    if isinstance(place, CUDAPlace):
        paddle.fluid.core._cuda_synchronize(place)


@contextmanager
def nvtx_span(name):
    """Context manager wrapping its body in an NVTX range called ``name``."""
    # Push before entering the try block, so that a failed push is not
    # followed by a pop of a range that was never opened.
    core.nvprof_nvtx_push(name)
    try:
        yield
    finally:
        core.nvprof_nvtx_pop()


# diff --git a/parakeet/utils/timeline.py b/parakeet/utils/timeline.py
# new file mode 100644
# index 0000000..2a399b7
# --- /dev/null
# +++ b/parakeet/utils/timeline.py
# @@ -0,0 +1,319 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import six +import sys +import unittest + +import google.protobuf.text_format as text_format +import paddle.fluid.proto.profiler.profiler_pb2 as profiler_pb2 + +parser = argparse.ArgumentParser(description=__doc__) +parser.add_argument( + '--profile_path', + type=str, + default='', + help='Input profile file name. If there are multiple file, the format ' + 'should be trainer1=file1,trainer2=file2,ps=file3') +parser.add_argument( + '--timeline_path', type=str, default='', help='Output timeline file name.') +args = parser.parse_args() + + +class _ChromeTraceFormatter(object): + def __init__(self): + self._events = [] + self._metadata = [] + + def _create_event(self, ph, category, name, pid, tid, timestamp): + """Creates a new Chrome Trace event. + + For details of the file format, see: + https://github.com/catapult-project/catapult/blob/master/tracing/README.md + + Args: + ph: The type of event - usually a single character. + category: The event category as a string. + name: The event name as a string. + pid: Identifier of the process generating this event as an integer. + tid: Identifier of the thread generating this event as an integer. + timestamp: The timestamp of this event as a long integer. + + Returns: + A JSON compatible event object. 
+ """ + event = {} + event['ph'] = ph + event['cat'] = category + event['name'] = name.replace("ParallelExecutor::Run/", "") + event['pid'] = pid + event['tid'] = tid + event['ts'] = timestamp + return event + + def emit_pid(self, name, pid): + """Adds a process metadata event to the trace. + + Args: + name: The process name as a string. + pid: Identifier of the process as an integer. + """ + event = {} + event['name'] = 'process_name' + event['ph'] = 'M' + event['pid'] = pid + event['args'] = {'name': name} + self._metadata.append(event) + + def emit_region(self, timestamp, duration, pid, tid, category, name, args): + """Adds a region event to the trace. + + Args: + timestamp: The start timestamp of this region as a long integer. + duration: The duration of this region as a long integer. + pid: Identifier of the process generating this event as an integer. + tid: Identifier of the thread generating this event as an integer. + category: The event category as a string. + name: The event name as a string. + args: A JSON-compatible dictionary of event arguments. + """ + event = self._create_event('X', category, name, pid, tid, timestamp) + event['dur'] = duration + event['args'] = args + self._events.append(event) + + def emit_counter(self, category, name, pid, timestamp, counter, value): + """Emits a record for a single counter. + + Args: + category: The event category as string + name: The event name as string + pid: Identifier of the process generating this event as integer + timestamp: The timestamps of this event as long integer + counter: Name of the counter as string + value: Value of the counter as integer + tid: Thread id of the allocation as integer + """ + event = self._create_event('C', category, name, pid, 0, timestamp) + event['args'] = {counter: value} + self._events.append(event) + + def format_to_string(self, pretty=False): + """Formats the chrome trace to a string. + + Args: + pretty: (Optional.) If True, produce human-readable JSON output. 
+ + Returns: + A JSON-formatted string in Chrome Trace format. + """ + trace = {} + trace['traceEvents'] = self._metadata + self._events + if pretty: + return json.dumps(trace, indent=4, separators=(',', ': ')) + else: + return json.dumps(trace, separators=(',', ':')) + + +class Timeline(object): + def __init__(self, profile_dict): + self._profile_dict = profile_dict + self._pid = 0 + self._devices = dict() + self._mem_devices = dict() + self._chrome_trace = _ChromeTraceFormatter() + + def _allocate_pid(self): + cur_pid = self._pid + self._pid += 1 + return cur_pid + + def _allocate_pids(self): + for k, profile_pb in six.iteritems(self._profile_dict): + for event in profile_pb.events: + if event.type == profiler_pb2.Event.CPU: + if (k, event.device_id, "CPU") not in self._devices: + pid = self._allocate_pid() + self._devices[(k, event.device_id, "CPU")] = pid + # -1 device id represents CUDA API(RunTime) call.(e.g. cudaLaunch, cudaMemcpy) + if event.device_id == -1: + self._chrome_trace.emit_pid("%s:cuda_api" % k, pid) + else: + self._chrome_trace.emit_pid( + "%s:cpu:block:%d" % (k, event.device_id), pid) + elif event.type == profiler_pb2.Event.GPUKernel: + if (k, event.device_id, "GPUKernel") not in self._devices: + pid = self._allocate_pid() + self._devices[(k, event.device_id, "GPUKernel")] = pid + self._chrome_trace.emit_pid("%s:gpu:%d" % + (k, event.device_id), pid) + if not hasattr(profile_pb, "mem_events"): + continue + for mevent in profile_pb.mem_events: + if mevent.place == profiler_pb2.MemEvent.CUDAPlace: + if (k, mevent.device_id, "GPU") not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, mevent.device_id, "GPU")] = pid + self._chrome_trace.emit_pid( + "memory usage on %s:gpu:%d" % (k, mevent.device_id), + pid) + elif mevent.place == profiler_pb2.MemEvent.CPUPlace: + if (k, mevent.device_id, "CPU") not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, mevent.device_id, "CPU")] = pid + 
self._chrome_trace.emit_pid( + "memory usage on %s:cpu:%d" % (k, mevent.device_id), + pid) + elif mevent.place == profiler_pb2.MemEvent.CUDAPinnedPlace: + if (k, mevent.device_id, "CUDAPinnedPlace" + ) not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, mevent.device_id, + "CUDAPinnedPlace")] = pid + self._chrome_trace.emit_pid( + "memory usage on %s:cudapinnedplace:%d" % + (k, mevent.device_id), pid) + elif mevent.place == profiler_pb2.MemEvent.NPUPlace: + if (k, mevent.device_id, "NPU") not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, mevent.device_id, "NPU")] = pid + self._chrome_trace.emit_pid( + "memory usage on %s:npu:%d" % (k, mevent.device_id), + pid) + if (k, 0, "CPU") not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, 0, "CPU")] = pid + self._chrome_trace.emit_pid("memory usage on %s:cpu:%d" % + (k, 0), pid) + if (k, 0, "GPU") not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, 0, "GPU")] = pid + self._chrome_trace.emit_pid("memory usage on %s:gpu:%d" % + (k, 0), pid) + if (k, 0, "CUDAPinnedPlace") not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, 0, "CUDAPinnedPlace")] = pid + self._chrome_trace.emit_pid( + "memory usage on %s:cudapinnedplace:%d" % (k, 0), pid) + if (k, 0, "NPU") not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, 0, "NPU")] = pid + self._chrome_trace.emit_pid("memory usage on %s:npu:%d" % + (k, 0), pid) + + def _allocate_events(self): + for k, profile_pb in six.iteritems(self._profile_dict): + for event in profile_pb.events: + if event.type == profiler_pb2.Event.CPU: + type = "CPU" + elif event.type == profiler_pb2.Event.GPUKernel: + type = "GPUKernel" + pid = self._devices[(k, event.device_id, type)] + args = {'name': event.name} + if event.memcopy.bytes > 0: + args['mem_bytes'] = event.memcopy.bytes + if hasattr(event, "detail_info") and event.detail_info: + 
args['detail_info'] = event.detail_info + # TODO(panyx0718): Chrome tracing only handles ms. However, some + # ops takes micro-seconds. Hence, we keep the ns here. + self._chrome_trace.emit_region( + event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid, + event.sub_device_id, 'Op', event.name, args) + + def _allocate_memory_event(self): + if not hasattr(profiler_pb2, "MemEvent"): + return + place_to_str = { + profiler_pb2.MemEvent.CPUPlace: "CPU", + profiler_pb2.MemEvent.CUDAPlace: "GPU", + profiler_pb2.MemEvent.CUDAPinnedPlace: "CUDAPinnedPlace", + profiler_pb2.MemEvent.NPUPlace: "NPU" + } + for k, profile_pb in six.iteritems(self._profile_dict): + mem_list = [] + end_profiler = 0 + for mevent in profile_pb.mem_events: + crt_info = dict() + crt_info['time'] = mevent.start_ns + crt_info['size'] = mevent.bytes + if mevent.place in place_to_str: + place = place_to_str[mevent.place] + else: + place = "UnDefine" + crt_info['place'] = place + pid = self._mem_devices[(k, mevent.device_id, place)] + crt_info['pid'] = pid + crt_info['thread_id'] = mevent.thread_id + crt_info['device_id'] = mevent.device_id + mem_list.append(crt_info) + crt_info = dict() + crt_info['place'] = place + crt_info['pid'] = pid + crt_info['thread_id'] = mevent.thread_id + crt_info['device_id'] = mevent.device_id + crt_info['time'] = mevent.end_ns + crt_info['size'] = -mevent.bytes + mem_list.append(crt_info) + end_profiler = max(end_profiler, crt_info['time']) + mem_list.sort(key=lambda tmp: (tmp.get('time', 0))) + i = 0 + total_size = 0 + while i < len(mem_list): + total_size += mem_list[i]['size'] + while i < len(mem_list) - 1 and mem_list[i]['time'] == mem_list[ + i + 1]['time']: + total_size += mem_list[i + 1]['size'] + i += 1 + + self._chrome_trace.emit_counter( + "Memory", "Memory", mem_list[i]['pid'], mem_list[i]['time'], + 0, total_size) + i += 1 + + def generate_chrome_trace(self): + self._allocate_pids() + self._allocate_events() + self._allocate_memory_event() + return 
self._chrome_trace.format_to_string() + + +profile_path = '/tmp/profile' +if args.profile_path: + profile_path = args.profile_path +timeline_path = '/tmp/timeline' +if args.timeline_path: + timeline_path = args.timeline_path + +profile_paths = profile_path.split(',') +profile_dict = dict() +if len(profile_paths) == 1: + with open(profile_path, 'rb') as f: + profile_s = f.read() + profile_pb = profiler_pb2.Profile() + profile_pb.ParseFromString(profile_s) + profile_dict['trainer'] = profile_pb +else: + for profile_path in profile_paths: + k, v = profile_path.split('=') + with open(v, 'rb') as f: + profile_s = f.read() + profile_pb = profiler_pb2.Profile() + profile_pb.ParseFromString(profile_s) + profile_dict[k] = profile_pb + +tl = Timeline(profile_dict) +with open(timeline_path, 'w') as f: + f.write(tl.generate_chrome_trace()) diff --git a/setup.py b/setup.py index eefc922..b7cb4da 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,6 @@ setup_info = dict( 'scipy', 'pandas', 'sox', - # 'opencc', 'soundfile', 'g2p_en', 'yacs', @@ -73,6 +72,9 @@ setup_info = dict( 'webrtcvad', 'g2pM', 'praatio', + "h5py", + "timer", + 'jsonlines', ], extras_require={'doc': ["sphinx", "sphinx-rtd-theme", "numpydoc"], }, diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py deleted file mode 100644 index 9173033..0000000 --- a/tests/test_checkpoint.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from pathlib import Path -import shutil - -import numpy as np -from parakeet.training.checkpoint import KBest, KLatest - - -def test_kbest(): - def save_fn(path): - with open(path, 'wt') as f: - f.write(f"My path is {str(path)}\n") - - K = 1 - kbest_manager = KBest(max_size=K, save_fn=save_fn) - checkpoint_dir = Path("checkpoints") - shutil.rmtree(checkpoint_dir) - checkpoint_dir.mkdir(parents=True) - a = np.random.rand(20) - for i, score in enumerate(a): - path = checkpoint_dir / f"step_{i}" - kbest_manager.add_checkpoint(score, path) - assert len(list(checkpoint_dir.glob("step_*"))) == K - - -def test_klatest(): - def save_fn(path): - with open(path, 'wt') as f: - f.write(f"My path is {str(path)}\n") - - K = 5 - klatest_manager = KLatest(max_size=K, save_fn=save_fn) - checkpoint_dir = Path("checkpoints") - shutil.rmtree(checkpoint_dir) - checkpoint_dir.mkdir(parents=True) - for i in range(20): - path = checkpoint_dir / f"step_{i}" - klatest_manager.add_checkpoint(path) - assert len(list(checkpoint_dir.glob("step_*"))) == K diff --git a/tests/test_data_table.py b/tests/test_data_table.py new file mode 100644 index 0000000..aca0605 --- /dev/null +++ b/tests/test_data_table.py @@ -0,0 +1,22 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from parakeet.datasets.data_tabel import DataTable + + +def test_audio_dataset(): + metadata = [{'name': 'Sonic', 'v': 1000}, {'name': 'Prestol', 'v': 2000}] + converters = {'v': lambda x: x / 1000} + dataset = DataTable(metadata, fields=['v'], converters=converters) + assert dataset[0] == {'v': 1.0} diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 0000000..bdb3d96 --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,39 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +from pathlib import Path + +import paddle +from paddle import nn +from paddle.optimizer import Adam +from paddle.optimizer.lr import StepDecay + + +def test_optimizer(): + model1 = nn.Linear(3, 4) + optim1 = Adam( + parameters=model1.parameters(), learning_rate=StepDecay(0.1, 100)) + + output_dir = Path("temp_test_optimizer") + shutil.rmtree(output_dir, ignore_errors=True) + output_dir.mkdir(exist_ok=True, parents=True) + + # model1.set_state_dict(model1.state_dict()) + optim1.set_state_dict(optim1.state_dict()) + + x = paddle.randn([6, 3]) + y = model1(x).sum() + y.backward() + optim1.step() diff --git a/tests/test_pwg.py b/tests/test_pwg.py new file mode 100644 index 0000000..0978714 --- /dev/null +++ b/tests/test_pwg.py @@ -0,0 +1,240 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import torch +from timer import timer +from parallel_wavegan.layers import upsample, residual_block +from parallel_wavegan.models import parallel_wavegan as pwgan +from parakeet.utils.layer_tools import summary +from parakeet.utils.profile import synchronize + +from parakeet.models.parallel_wavegan import ConvInUpsampleNet, ResidualBlock +from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator, ResidualPWGDiscriminator + +paddle.set_device("gpu:0") +device = torch.device("cuda:0") + + +def test_convin_upsample_net(): + net = ConvInUpsampleNet( + [4, 4, 4, 4], + "LeakyReLU", {"negative_slope": 0.2}, + freq_axis_kernel_size=3, + aux_context_window=0) + net2 = upsample.ConvInUpsampleNetwork( + [4, 4, 4, 4], + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + freq_axis_kernel_size=3, + aux_context_window=0).to(device) + summary(net) + for k, v in net2.named_parameters(): + print(k, v.shape) + net.state_dict()[k].set_value(v.data.cpu().numpy()) + + c = paddle.randn([4, 80, 180]) + synchronize() + with timer(unit='s') as t: + out = net(c) + synchronize() + print(f"paddle conv_in_upsample_net forward takes {t.elapse}s.") + + with timer(unit='s') as t: + out.sum().backward() + synchronize() + print(f"paddle conv_in_upsample_net backward takes {t.elapse}s.") + + c_torch = torch.as_tensor(c.numpy()).to(device) + torch.cuda.synchronize() + with 
timer(unit='s') as t: + out2 = net2(c_torch) + print(f"torch conv_in_upsample_net forward takes {t.elapse}s.") + + with timer(unit='s') as t: + out2.sum().backward() + print(f"torch conv_in_upsample_net backward takes {t.elapse}s.") + + print("forward check") + print(out.numpy()[0]) + print(out2.data.cpu().numpy()[0]) + + print("backward check") + print(net.conv_in.weight.grad.numpy()[0]) + print(net2.conv_in.weight.grad.data.cpu().numpy()[0]) + + +def test_residual_block(): + net = ResidualBlock(dilation=9) + net2 = residual_block.ResidualBlock(dilation=9) + summary(net) + summary(net2) + for k, v in net2.named_parameters(): + net.state_dict()[k].set_value(v.data.cpu().numpy()) + + x = paddle.randn([4, 64, 180]) + c = paddle.randn([4, 80, 180]) + res, skip = net(x, c) + res2, skip2 = net2(torch.as_tensor(x.numpy()), torch.as_tensor(c.numpy())) + + print("forward:") + print(res.numpy()[0]) + print(res2.data.cpu().numpy()[0]) + print(skip.numpy()[0]) + print(skip2.data.cpu().numpy()[0]) + + (res.sum() + skip.sum()).backward() + (res2.sum() + skip2.sum()).backward() + + print("backward:") + print(net.conv.weight.grad.numpy().squeeze()[0]) + print(net2.conv.weight.grad.data.cpu().numpy().squeeze()[0]) + + +def test_pwg_generator(): + net = PWGGenerator( + layers=9, + stacks=3, + upsample_scales=[4, 4, 4, 4], + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.5}, + use_weight_norm=True) + net2 = pwgan.ParallelWaveGANGenerator( + layers=9, + stacks=3, + upsample_params={ + "upsample_scales": [4, 4, 4, 4], + "nonlinear_activation": "LeakyReLU", + "nonlinear_activation_params": { + "negative_slope": 0.5 + } + }, + use_weight_norm=True).to(device) + summary(net) + summary(net2) + for k, v in net2.named_parameters(): + p = net.state_dict()[k] + if k.endswith("_g"): + p.set_value(v.data.cpu().numpy().reshape([-1])) + else: + p.set_value(v.data.cpu().numpy()) + x = paddle.randn([4, 1, 80 * 256]) + c = paddle.randn([4, 80, 80 + 4]) + + 
synchronize() + with timer(unit='s') as t: + out = net(x, c) + synchronize() + print(f"paddle generator forward takes {t.elapse}s.") + + synchronize() + with timer(unit='s') as t: + out.sum().backward() + synchronize() + print(f"paddle generator backward takes {t.elapse}s.") + + x_torch = torch.as_tensor(x.numpy()).to(device) + c_torch = torch.as_tensor(c.numpy()).to(device) + + torch.cuda.synchronize() + with timer(unit='s') as t: + out2 = net2(x_torch, c_torch) + torch.cuda.synchronize() + print(f"torch generator forward takes {t.elapse}s.") + + torch.cuda.synchronize() + with timer(unit='s') as t: + out2.sum().backward() + torch.cuda.synchronize() + print(f"torch generator backward takes {t.elapse}s.") + + print("test forward:") + print(out.numpy()[0]) + print(out2.data.cpu().numpy()[0]) + + print("test backward:") + print("wv") + print(net.first_conv.weight_v.grad.numpy().squeeze()) + print(net2.first_conv.weight_v.grad.data.cpu().numpy().squeeze()) + + print("wg") + print(net.first_conv.weight_g.grad.numpy().squeeze()) + print(net2.first_conv.weight_g.grad.data.cpu().numpy().squeeze()) + # print(out.shape) + + +def test_pwg_discriminator(): + net = PWGDiscriminator() + net2 = pwgan.ParallelWaveGANDiscriminator().to(device) + summary(net) + summary(net2) + for k, v in net2.named_parameters(): + p = net.state_dict()[k] + if k.endswith("_g"): + p.set_value(v.data.cpu().numpy().reshape([-1])) + else: + p.set_value(v.data.cpu().numpy()) + x = paddle.randn([4, 1, 180 * 256]) + + synchronize() + with timer() as t: + y = net(x) + synchronize() + print(f"forward takes {t.elapse}s.") + + synchronize() + with timer() as t: + y.sum().backward() + synchronize() + print(f"backward takes {t.elapse}s.") + + x_torch = torch.as_tensor(x.numpy()).to(device) + torch.cuda.synchronize() + with timer() as t: + y2 = net2(x_torch) + torch.cuda.synchronize() + print(f"forward takes {t.elapse}s.") + + torch.cuda.synchronize() + with timer() as t: + y2.sum().backward() + 
torch.cuda.synchronize() + print(f"backward takes {t.elapse}s.") + + print("test forward:") + print(y.numpy()[0]) + print(y2.data.cpu().numpy()[0]) + + print("test backward:") + print(net.conv_layers[0].weight_v.grad.numpy().squeeze()) + print(net2.conv_layers[0].weight_v.grad.data.cpu().numpy().squeeze()) + + +def test_residual_pwg_discriminator(): + net = ResidualPWGDiscriminator() + net2 = pwgan.ResidualParallelWaveGANDiscriminator() + summary(net) + summary(net2) + for k, v in net2.named_parameters(): + p = net.state_dict()[k] + if k.endswith("_g"): + p.set_value(v.data.cpu().numpy().reshape([-1])) + else: + p.set_value(v.data.cpu().numpy()) + x = paddle.randn([4, 1, 180 * 256]) + y = net(x) + y2 = net2(torch.as_tensor(x.numpy())) + print(y.numpy()[0]) + print(y2.data.cpu().numpy()[0]) + print(y.shape) diff --git a/tests/test_reporter.py b/tests/test_reporter.py new file mode 100644 index 0000000..cd40364 --- /dev/null +++ b/tests/test_reporter.py @@ -0,0 +1,51 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +from parakeet.training.reporter import report, scope +from parakeet.training.reporter import Summary, DictSummary + + +def test_reporter_scope(): + first = {} + second = {} + third = {} + + with scope(first): + report("first_begin", 1) + with scope(second): + report("second_begin", 2) + with scope(third): + report("third_begin", 3) + report("third_end", 4) + report("seconf_end", 5) + report("first_end", 6) + + assert first == {'first_begin': 1, 'first_end': 6} + assert second == {'second_begin': 2, 'seconf_end': 5} + assert third == {'third_begin': 3, 'third_end': 4} + print(first) + print(second) + print(third) + + +def test_summary(): + summary = Summary() + summary.add(1) + summary.add(2) + summary.add(3) + state = summary.make_statistics() + print(state) + np.testing.assert_allclose( + np.array(list(state)), np.array([2.0, np.std([1, 2, 3])])) diff --git a/tests/test_snapshot.py b/tests/test_snapshot.py new file mode 100644 index 0000000..71e422c --- /dev/null +++ b/tests/test_snapshot.py @@ -0,0 +1,55 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pathlib import Path +import shutil + +import numpy as np +import paddle +from paddle import nn +from paddle.optimizer import Adam +from itertools import count + +from parakeet.training.updater import StandardUpdater +from parakeet.training.trainer import Trainer +from parakeet.training.extensions.snapshot import Snapshot + + +def test_snapshot(): + model = nn.Linear(3, 4) + optimizer = Adam(parameters=model.parameters()) + + # use the simplest iterable object as dataloader + dataloader = count() + + # hack the training process: training does nothing except increase the iteration + updater = StandardUpdater(model, optimizer, dataloader=dataloader) + updater.update_core = lambda x: None + + trainer = Trainer( + updater, stop_trigger=(1000, 'iteration'), out='temp_test_snapshot') + shutil.rmtree(trainer.out, ignore_errors=True) + + snap = Snapshot(max_size=5) + trigger = (10, 'iteration') + trainer.extend(snap, name='snapshot', trigger=trigger, priority=0) + + trainer.run() + + checkpoint_dir = trainer.out / "checkpoints" + snapshots = sorted(list(checkpoint_dir.glob("snapshot_iter_*.pdz"))) + for snap in snapshots: + print(snap) + assert len(snapshots) == 5 + shutil.rmtree(trainer.out) diff --git a/tests/test_stft.py b/tests/test_stft.py new file mode 100644 index 0000000..c985235 --- /dev/null +++ b/tests/test_stft.py @@ -0,0 +1,73 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import torch +import librosa +import numpy as np +from parakeet.modules.stft_loss import STFT, MultiResolutionSTFTLoss +from parallel_wavegan.losses import stft_loss as sl +from scipy import signal + + +def test_stft(): + stft = STFT(n_fft=1024, hop_length=256, win_length=1024) + x = paddle.uniform([4, 46080]) + S = stft.magnitude(x) + window = signal.get_window('hann', 1024, fftbins=True) + D2 = torch.stft( + torch.as_tensor(x.numpy()), + n_fft=1024, + hop_length=256, + win_length=1024, + window=torch.as_tensor(window)) + S2 = (D2**2).sum(-1).sqrt() + S3 = np.abs( + librosa.stft( + x.numpy()[0], n_fft=1024, hop_length=256, win_length=1024)) + print(S2.shape) + print(S.numpy()[0]) + print(S2.data.cpu().numpy()[0]) + print(S3) + + +def test_torch_stft(): + # NOTE: torch.stft uses no window by default, so a Hann window is passed explicitly + x = np.random.uniform(-1.0, 1.0, size=(46080, )) + window = signal.get_window('hann', 1024, fftbins=True) + D2 = torch.stft( + torch.as_tensor(x), + n_fft=1024, + hop_length=256, + win_length=1024, + window=torch.as_tensor(window)) + D3 = librosa.stft( + x, n_fft=1024, hop_length=256, win_length=1024, window='hann') + print(D2[:, :, 0].data.cpu().numpy()[:, 30:60]) + print(D3.real[:, 30:60]) + # print(D3.imag[:, 30:60]) + + +def test_multi_resolution_stft_loss(): + net = MultiResolutionSTFTLoss() + net2 = sl.MultiResolutionSTFTLoss() + + x = paddle.uniform([4, 46080]) + y = paddle.uniform([4, 46080]) + sc, m = net(x, y) + sc2, m2 = net2(torch.as_tensor(x.numpy()), torch.as_tensor(y.numpy())) + print(sc.numpy()) + print(sc2.data.cpu().numpy()) + print(m.numpy()) + print(m2.data.cpu().numpy())