From 3ebed00c964185de3dee32206071a56bdef0fb29 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Thu, 1 Jul 2021 16:14:55 +0800 Subject: [PATCH] minor fixes to refine code. --- examples/parallelwave_gan/baker/batch_fn.py | 2 +- .../baker/compute_statistics.py | 4 +++- .../parallelwave_gan/baker/conf/default.yaml | 11 +++++----- examples/parallelwave_gan/baker/preprocess.py | 22 ++++++++++++------- 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/examples/parallelwave_gan/baker/batch_fn.py b/examples/parallelwave_gan/baker/batch_fn.py index aff647c..22af5af 100644 --- a/examples/parallelwave_gan/baker/batch_fn.py +++ b/examples/parallelwave_gan/baker/batch_fn.py @@ -25,7 +25,7 @@ class Clip(object): batch_max_steps=20480, hop_size=256, aux_context_window=0, ): - """Initialize customized collater for PyTorch DataLoader. + """Initialize customized collater for DataLoader. Args: batch_max_steps (int): The maximum length of input signal in batch. diff --git a/examples/parallelwave_gan/baker/compute_statistics.py b/examples/parallelwave_gan/baker/compute_statistics.py index 4db003f..06b9b65 100644 --- a/examples/parallelwave_gan/baker/compute_statistics.py +++ b/examples/parallelwave_gan/baker/compute_statistics.py @@ -39,7 +39,9 @@ def main(): parser.add_argument( "--metadata", type=str, help="json file with id and file paths ") parser.add_argument( - "--field-name", type=str, help="json file with id and file paths ") + "--field-name", + type=str, + help="name of the field to compute statistics for.") parser.add_argument( "--config", type=str, help="yaml format configuration file.") parser.add_argument( diff --git a/examples/parallelwave_gan/baker/conf/default.yaml b/examples/parallelwave_gan/baker/conf/default.yaml index 877be2c..777b2b0 100644 --- a/examples/parallelwave_gan/baker/conf/default.yaml +++ b/examples/parallelwave_gan/baker/conf/default.yaml @@ -18,9 +18,8 @@ fmax: 7600 # Maximum frequency in mel basis calculation. # global_gain_scale: 1.0 # Will be multiplied to all of waveform. trim_silence: false # Whether to trim the start and end of silence. top_db: 60 # Need to tune carefully if the recording is not good. -trim_frame_length: 2048 # Frame size in trimming. -trim_hop_length: 512 # Hop size in trimming. -# format: "npy" # Feature file format. "npy" or "hdf5" is supported. +trim_frame_length: 2048 # Frame size in trimming.(in samples) +trim_hop_length: 512 # Hop size in trimming.(in samples) ########################################################### # GENERATOR NETWORK ARCHITECTURE SETTING # @@ -119,11 +118,11 @@ discriminator_train_start_steps: 100000 # Number of steps to start to train disc train_max_steps: 400000 # Number of training steps. save_interval_steps: 5000 # Interval steps to save checkpoint. eval_interval_steps: 1000 # Interval steps to evaluate the network. -log_interval_steps: 100 # Interval steps to record the training log. + ########################################################### # OTHER SETTING # ########################################################### num_save_intermediate_results: 4 # Number of results to be saved as intermediate results. -num_snapshots: 10 -seed: 42 \ No newline at end of file +num_snapshots: 10 # max number of snapshots to keep while training +seed: 42 # random seed for paddle, random, and np.random \ No newline at end of file diff --git a/examples/parallelwave_gan/baker/preprocess.py b/examples/parallelwave_gan/baker/preprocess.py index 3ae461b..6144c34 100644 --- a/examples/parallelwave_gan/baker/preprocess.py +++ b/examples/parallelwave_gan/baker/preprocess.py @@ -147,8 +147,11 @@ def process_sentence(config: Dict[str, Any], # adjust time to make num_samples == num_frames * hop_length num_frames = logmel.shape[1] - y = np.pad(y, (0, config.n_fft), mode="reflect") - y = y[:num_frames * config.hop_length] + if y.size < num_frames * config.hop_length: + y = np.pad(y, (0, num_frames * config.hop_length - y.size), + mode="reflect") + else: + y = y[:num_frames * config.hop_length] num_sample = y.shape[0] mel_path = output_dir / (utt_id + "_feats.npy") @@ -241,13 +244,16 @@ def main(): list((root_dir / "PhoneLabeling").rglob("*.interval"))) # split data into 3 sections - train_wav_files = wav_files[:9800] - dev_wav_files = wav_files[9800:9900] - test_wav_files = wav_files[9900:] + num_train = 9800 + num_dev = 100 - train_alignment_files = alignment_files[:9800] - dev_alignment_files = alignment_files[9800:9900] - test_alignment_files = alignment_files[9900:] + train_wav_files = wav_files[:num_train] + dev_wav_files = wav_files[num_train:num_train + num_dev] + test_wav_files = wav_files[num_train + num_dev:] + + train_alignment_files = alignment_files[:num_train] + dev_alignment_files = alignment_files[num_train:num_train + num_dev] + test_alignment_files = alignment_files[num_train + num_dev:] train_dump_dir = dumpdir / "train" / "raw" train_dump_dir.mkdir(parents=True, exist_ok=True)