From fbc7e51fc9699d9a022a33270932eb0a460acc7a Mon Sep 17 00:00:00 2001
From: chenfeiyu
Date: Fri, 18 Jun 2021 02:49:49 +0000
Subject: [PATCH] 1. add compute_statistics and normalize; 2. use jsonlines to
 read and write metadata by default; 3. use ThreadPoolExecutor instead of
 ProcessPoolExecutor in preprocessing because it is faster.

---
 .../baker/compute_statistics.py               |  18 +--
 examples/parallelwave_gan/baker/normalize.py  | 141 ++++++++++++++++++
 examples/parallelwave_gan/baker/preprocess.py |  22 +--
 setup.py                                      |   1 +
 4 files changed, 160 insertions(+), 22 deletions(-)

diff --git a/examples/parallelwave_gan/baker/compute_statistics.py b/examples/parallelwave_gan/baker/compute_statistics.py
index ea696f8..4db003f 100644
--- a/examples/parallelwave_gan/baker/compute_statistics.py
+++ b/examples/parallelwave_gan/baker/compute_statistics.py
@@ -20,6 +20,7 @@ import os
 import numpy as np
 import yaml
 import json
+import jsonlines
 
 from sklearn.preprocessing import StandardScaler
 from tqdm import tqdm
@@ -36,20 +37,13 @@ def main():
     parser = argparse.ArgumentParser(
         description="Compute mean and variance of dumped raw features.")
     parser.add_argument(
-        "--metadata",
-        default=None,
-        type=str,
-        help="json file with id and file paths ")
+        "--metadata", type=str, help="jsonl file with id and file paths.")
     parser.add_argument(
-        "--field-name",
-        default=None,
-        type=str,
-        help="json file with id and file paths ")
+        "--field-name", type=str, help="field to compute statistics for.")
     parser.add_argument(
         "--config", type=str, help="yaml format configuration file.")
     parser.add_argument(
         "--dumpdir",
-        default=None,
         type=str,
         help="directory to save statistics. if not provided, "
         "stats will be saved in the above root directory. (default=None)")
@@ -89,8 +83,8 @@ def main():
     if not os.path.exists(args.dumpdir):
         os.makedirs(args.dumpdir)
 
-    with open(args.metadata, 'rt') as f:
-        metadata = json.load(f)
+    with jsonlines.open(args.metadata, 'r') as reader:
+        metadata = list(reader)
     dataset = DataTable(
         metadata,
         fields=[args.field_name],
@@ -101,7 +95,7 @@ def main():
     scaler = StandardScaler()
     for datum in tqdm(dataset):
         # StandardScalar supports (*, num_features) by default
-        scaler.partial_fit(datum[args.field_name].T)
+        scaler.partial_fit(datum[args.field_name])
 
     stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
     np.save(
diff --git a/examples/parallelwave_gan/baker/normalize.py b/examples/parallelwave_gan/baker/normalize.py
index 185a92b..6134917 100644
--- a/examples/parallelwave_gan/baker/normalize.py
+++ b/examples/parallelwave_gan/baker/normalize.py
@@ -11,3 +11,144 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+"""Normalize feature files and dump them."""
+
+import argparse
+import logging
+import os
+from operator import itemgetter
+from pathlib import Path
+
+import numpy as np
+import yaml
+import jsonlines
+
+from sklearn.preprocessing import StandardScaler
+from tqdm import tqdm
+
+from parakeet.datasets.data_table import DataTable
+from parakeet.utils.h5_utils import read_hdf5
+from parakeet.utils.h5_utils import write_hdf5
+
+from config import get_cfg_default
+
+
+def main():
+    """Run normalization process."""
+    parser = argparse.ArgumentParser(
+        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
+    )
+    parser.add_argument(
+        "--metadata",
+        type=str,
+        required=True,
+        help="jsonl file containing the utterance ids and paths of the "
+        "feature files to be normalized.")
+    parser.add_argument(
+        "--dumpdir",
+        type=str,
+        required=True,
+        help="directory to dump normalized feature files.")
+    parser.add_argument(
+        "--stats", type=str, required=True, help="statistics file.")
+    parser.add_argument(
+        "--skip-wav-copy",
+        default=False,
+        action="store_true",
+        help="whether to skip the copy of wav files.")
+    parser.add_argument(
+        "--config", type=str, help="yaml format configuration file.")
+    parser.add_argument(
+        "--verbose",
+        type=int,
+        default=1,
+        help="logging level. higher is more logging. (default=1)")
+    args = parser.parse_args()
+
+    # set logger
+    if args.verbose > 1:
+        logging.basicConfig(
+            level=logging.DEBUG,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+        )
+    elif args.verbose > 0:
+        logging.basicConfig(
+            level=logging.INFO,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+        )
+    else:
+        logging.basicConfig(
+            level=logging.WARN,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+        )
+        logging.warning('Skip DEBUG/INFO messages')
+
+    # load config
+    config = get_cfg_default()
+    if args.config:
+        config.merge_from_file(args.config)
+
+    # check directory existence
+    dumpdir = Path(args.dumpdir).resolve()
+    dumpdir.mkdir(parents=True, exist_ok=True)
+
+    # get dataset
+    with jsonlines.open(args.metadata, 'r') as reader:
+        metadata = list(reader)
+    dataset = DataTable(
+        metadata,
+        fields=["utt_id", "wave", "feats"],
+        converters={
+            'utt_id': None,
+            'wave': None if args.skip_wav_copy else np.load,
+            'feats': np.load,
+        })
+    logging.info(f"The number of files = {len(dataset)}.")
+
+    # restore scaler
+    scaler = StandardScaler()
+    scaler.mean_ = np.load(args.stats)[0]
+    scaler.scale_ = np.load(args.stats)[1]
+
+    # from version 0.23.0, this information is needed
+    scaler.n_features_in_ = scaler.mean_.shape[0]
+
+    # process each file
+    output_metadata = []
+
+    for item in tqdm(dataset):
+        utt_id = item['utt_id']
+        wave = item['wave']
+        mel = item['feats']
+        # normalize
+        mel = scaler.transform(mel)
+
+        # save
+        mel_path = dumpdir / f"{utt_id}-feats.npy"
+        np.save(mel_path, mel.astype(np.float32), allow_pickle=False)
+        if not args.skip_wav_copy:
+            wav_path = dumpdir / f"{utt_id}-wave.npy"
+            np.save(wav_path, wave.astype(np.float32), allow_pickle=False)
+        else:
+            wav_path = wave
+        output_metadata.append({
+            'utt_id': utt_id,
+            'wave': str(wav_path),
+            'feats': str(mel_path),
+        })
+    output_metadata.sort(key=itemgetter('utt_id'))
+    output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
+    with jsonlines.open(output_metadata_path, 'w') as writer:
+        for item in output_metadata:
+            writer.write(item)
+    logging.info(f"metadata dumped into {output_metadata_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/parallelwave_gan/baker/preprocess.py b/examples/parallelwave_gan/baker/preprocess.py
index 23b5f05..09f2004 100644
--- a/examples/parallelwave_gan/baker/preprocess.py
+++ b/examples/parallelwave_gan/baker/preprocess.py
@@ -19,6 +19,7 @@ import numpy as np
 import argparse
 import yaml
 import json
+import jsonlines
 import concurrent.futures
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from pathlib import Path
@@ -152,14 +153,14 @@ def process_sentence(config: Dict[str, Any],
     mel_path = output_dir / (utt_id + "_feats.npy")
     wav_path = output_dir / (utt_id + "_wave.npy")
-    np.save(wav_path, y)
-    np.save(mel_path, logmel)
+    np.save(wav_path, y)  # (num_samples, )
+    np.save(mel_path, logmel.T)  # (num_frames, n_mels)
 
     record = {
         "utt_id": utt_id,
         "num_samples": num_sample,
         "num_frames": num_frames,
-        "feats_path": str(mel_path.resolve()),
-        "wave_path": str(wav_path.resolve()),
+        "feats": str(mel_path.resolve()),
+        "wave": str(wav_path.resolve()),
     }
 
     return record
@@ -175,7 +176,7 @@ def process_sentences(config,
             results.append(
                 process_sentence(config, fp, alignment_fp, output_dir))
     else:
-        with ProcessPoolExecutor(nprocs) as pool:
+        with ThreadPoolExecutor(nprocs) as pool:
             futures = []
             with tqdm.tqdm(total=len(fps)) as progress:
                 for fp, alignment_fp in zip(fps, alignment_fps):
@@ -189,8 +190,9 @@ def process_sentences(config,
                    results.append(ft.result())
 
     results.sort(key=itemgetter("utt_id"))
-    with open(output_dir / "metadata.json", 'wt') as f:
-        json.dump(results, f)
+    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
+        for item in results:
+            writer.write(item)
 
     print("Done")
 
@@ -247,11 +249,11 @@ def main():
     dev_alignment_files = alignment_files[9800:9900]
     test_alignment_files = alignment_files[9900:]
 
-    train_dump_dir = dumpdir / "train"
+    train_dump_dir = dumpdir / "train" / "raw"
     train_dump_dir.mkdir(parents=True, exist_ok=True)
-    dev_dump_dir = dumpdir / "dev"
+    dev_dump_dir = dumpdir / "dev" / "raw"
     dev_dump_dir.mkdir(parents=True, exist_ok=True)
-    test_dump_dir = dumpdir / "test"
+    test_dump_dir = dumpdir / "test" / "raw"
     test_dump_dir.mkdir(parents=True, exist_ok=True)
 
     # process for the 3 sections
diff --git a/setup.py b/setup.py
index 3a6c87e..b7cb4da 100644
--- a/setup.py
+++ b/setup.py
@@ -74,6 +74,7 @@ setup_info = dict(
         'praatio',
         "h5py",
         "timer",
+        'jsonlines',
     ],
     extras_require={'doc': ["sphinx", "sphinx-rtd-theme", "numpydoc"], },
 
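
Usage sketch (not part of the patch): a minimal example of reading back the metadata.jsonl
records written by preprocess.py and normalize.py above. The dump root "dump/" is a
hypothetical path; the field names ("utt_id", "wave", "feats") follow the records dumped
by this patch.

    import jsonlines
    import numpy as np

    # read the per-utterance records, one JSON object per line
    with jsonlines.open("dump/train/raw/metadata.jsonl", 'r') as reader:
        metadata = list(reader)

    record = metadata[0]
    wave = np.load(record["wave"])    # (num_samples, )
    feats = np.load(record["feats"])  # (num_frames, n_mels)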