From 58a988c7895941b79efc34a7e74200426c26d328 Mon Sep 17 00:00:00 2001
From: chenfeiyu
Date: Wed, 16 Jun 2021 16:18:13 +0000
Subject: [PATCH] add code to compute statistics

---
 .../baker/compute_statistics.py               | 114 ++++++++++++++++++
 examples/parallelwave_gan/baker/normalize.py  |  13 ++
 parakeet/utils/h5_utils.py                    |   2 +-
 setup.py                                      |   2 +
 4 files changed, 130 insertions(+), 1 deletion(-)
 create mode 100644 examples/parallelwave_gan/baker/compute_statistics.py
 create mode 100644 examples/parallelwave_gan/baker/normalize.py

diff --git a/examples/parallelwave_gan/baker/compute_statistics.py b/examples/parallelwave_gan/baker/compute_statistics.py
new file mode 100644
index 0000000..ea696f8
--- /dev/null
+++ b/examples/parallelwave_gan/baker/compute_statistics.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Calculate statistics of feature files."""
+
+import argparse
+import logging
+import os
+
+import numpy as np
+import yaml
+import json
+
+from sklearn.preprocessing import StandardScaler
+from tqdm import tqdm
+
+from parakeet.datasets.data_table import DataTable
+from parakeet.utils.h5_utils import read_hdf5
+from parakeet.utils.h5_utils import write_hdf5
+
+from config import get_cfg_default
+
+
+def main():
+    """Run preprocessing process."""
+    parser = argparse.ArgumentParser(
+        description="Compute mean and variance of dumped raw features.")
+    parser.add_argument(
+        "--metadata",
+        default=None,
+        type=str,
+        help="json file with id and file paths.")
+    parser.add_argument(
+        "--field-name",
+        default=None,
+        type=str,
+        help="name of the field in metadata to compute statistics for.")
+    parser.add_argument(
+        "--config", type=str, help="yaml format configuration file.")
+    parser.add_argument(
+        "--dumpdir",
+        default=None,
+        type=str,
+        help="directory to save statistics. if not provided, "
+        "stats will be saved in the same directory as the metadata. (default=None)")
+    parser.add_argument(
+        "--verbose",
+        type=int,
+        default=1,
+        help="logging level. higher is more logging. (default=1)")
+    args = parser.parse_args()
+
+    # set logger
+    if args.verbose > 1:
+        logging.basicConfig(
+            level=logging.DEBUG,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+        )
+    elif args.verbose > 0:
+        logging.basicConfig(
+            level=logging.INFO,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+        )
+    else:
+        logging.basicConfig(
+            level=logging.WARN,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+        )
+        logging.warning('Skipping DEBUG/INFO messages')
+
+    config = get_cfg_default()
+    # load config
+    if args.config:
+        config.merge_from_file(args.config)
+
+    # check directory existence
+    if args.dumpdir is None:
+        args.dumpdir = os.path.dirname(args.metadata)
+    if not os.path.exists(args.dumpdir):
+        os.makedirs(args.dumpdir)
+
+    with open(args.metadata, 'rt') as f:
+        metadata = json.load(f)
+    dataset = DataTable(
+        metadata,
+        fields=[args.field_name],
+        converters={args.field_name: np.load}, )
+    logging.info(f"The number of files = {len(dataset)}.")
+
+    # calculate statistics
+    scaler = StandardScaler()
+    for datum in tqdm(dataset):
+        # StandardScaler expects inputs of shape (*, num_features), hence the transpose
+        scaler.partial_fit(datum[args.field_name].T)
+
+    stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
+    np.save(
+        os.path.join(args.dumpdir, "stats.npy"),
+        stats.astype(np.float32),
+        allow_pickle=False)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/parallelwave_gan/baker/normalize.py b/examples/parallelwave_gan/baker/normalize.py
new file mode 100644
index 0000000..185a92b
--- /dev/null
+++ b/examples/parallelwave_gan/baker/normalize.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/parakeet/utils/h5_utils.py b/parakeet/utils/h5_utils.py
index 7cdbfc6..cd0c670 100644
--- a/parakeet/utils/h5_utils.py
+++ b/parakeet/utils/h5_utils.py
@@ -57,7 +57,7 @@ def read_hdf5(filename: Union[Path, str], dataset_name: str) -> Any:
 
 def write_hdf5(filename: Union[Path, str],
                dataset_name: str,
-               write_data: np.ndarrays,
+               write_data: np.ndarray,
                is_overwrite: bool=True) -> None:
     """Write dataset to HDF5 file.
 
diff --git a/setup.py b/setup.py
index 6f112cc..3a6c87e 100644
--- a/setup.py
+++ b/setup.py
@@ -72,6 +72,8 @@ setup_info = dict(
         'webrtcvad',
         'g2pM',
         'praatio',
+        "h5py",
+        "timer",
     ],
     extras_require={'doc': ["sphinx", "sphinx-rtd-theme", "numpydoc"], },