1. add compute_statistics and normalize;

2. use jsonlines to read and write metadata by default;
3. use threadpool to replace processpool in preprocessing cause it is faster.
This commit is contained in:
chenfeiyu 2021-06-18 02:49:49 +00:00
parent 30045cf602
commit fbc7e51fc9
4 changed files with 160 additions and 22 deletions

View File

@ -20,6 +20,7 @@ import os
import numpy as np import numpy as np
import yaml import yaml
import json import json
import jsonlines
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
from tqdm import tqdm from tqdm import tqdm
@ -36,20 +37,13 @@ def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Compute mean and variance of dumped raw features.") description="Compute mean and variance of dumped raw features.")
parser.add_argument( parser.add_argument(
"--metadata", "--metadata", type=str, help="json file with id and file paths ")
default=None,
type=str,
help="json file with id and file paths ")
parser.add_argument( parser.add_argument(
"--field-name", "--field-name", type=str, help="json file with id and file paths ")
default=None,
type=str,
help="json file with id and file paths ")
parser.add_argument( parser.add_argument(
"--config", type=str, help="yaml format configuration file.") "--config", type=str, help="yaml format configuration file.")
parser.add_argument( parser.add_argument(
"--dumpdir", "--dumpdir",
default=None,
type=str, type=str,
help="directory to save statistics. if not provided, " help="directory to save statistics. if not provided, "
"stats will be saved in the above root directory. (default=None)") "stats will be saved in the above root directory. (default=None)")
@ -89,8 +83,8 @@ def main():
if not os.path.exists(args.dumpdir): if not os.path.exists(args.dumpdir):
os.makedirs(args.dumpdir) os.makedirs(args.dumpdir)
with open(args.metadata, 'rt') as f: with jsonlines.open(args.metadata, 'r') as reader:
metadata = json.load(f) metadata = list(reader)
dataset = DataTable( dataset = DataTable(
metadata, metadata,
fields=[args.field_name], fields=[args.field_name],
@ -101,7 +95,7 @@ def main():
scaler = StandardScaler() scaler = StandardScaler()
for datum in tqdm(dataset): for datum in tqdm(dataset):
# StandardScalar supports (*, num_features) by default # StandardScalar supports (*, num_features) by default
scaler.partial_fit(datum[args.field_name].T) scaler.partial_fit(datum[args.field_name])
stats = np.stack([scaler.mean_, scaler.scale_], axis=0) stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
np.save( np.save(

View File

@ -11,3 +11,144 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
"""Normalize feature files and dump them."""
import argparse
import logging
import os
from operator import itemgetter
from pathlib import Path
import numpy as np
import yaml
import jsonlines
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
from parakeet.datasets.data_table import DataTable
from parakeet.utils.h5_utils import read_hdf5
from parakeet.utils.h5_utils import write_hdf5
from config import get_cfg_default
def main():
"""Run preprocessing process."""
parser = argparse.ArgumentParser(
description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
)
parser.add_argument(
"--metadata",
type=str,
required=True,
help="directory including feature files to be normalized. "
"you need to specify either *-scp or rootdir.")
parser.add_argument(
"--dumpdir",
type=str,
required=True,
help="directory to dump normalized feature files.")
parser.add_argument(
"--stats", type=str, required=True, help="statistics file.")
parser.add_argument(
"--skip-wav-copy",
default=False,
action="store_true",
help="whether to skip the copy of wav files.")
parser.add_argument(
"--config", type=str, help="yaml format configuration file.")
parser.add_argument(
"--verbose",
type=int,
default=1,
help="logging level. higher is more logging. (default=1)")
args = parser.parse_args()
# set logger
if args.verbose > 1:
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
elif args.verbose > 0:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
else:
logging.basicConfig(
level=logging.WARN,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
logging.warning('Skip DEBUG/INFO messages')
# load config
config = get_cfg_default()
if args.config:
config.merge_from_file(args.config)
# check directory existence
dumpdir = Path(args.dumpdir).resolve()
dumpdir.mkdir(parents=True, exist_ok=True)
# get dataset
with jsonlines.open(args.metadata, 'r') as reader:
metadata = list(reader)
dataset = DataTable(
metadata,
fields=["utt_id", "wave", "feats"],
converters={
'utt_id': None,
'wave': None if args.skip_wav_copy else np.load,
'feats': np.load,
})
logging.info(f"The number of files = {len(dataset)}.")
# restore scaler
scaler = StandardScaler()
scaler.mean_ = np.load(args.stats)[0]
scaler.scale_ = np.load(args.stats)[1]
# from version 0.23.0, this information is needed
scaler.n_features_in_ = scaler.mean_.shape[0]
# process each file
output_metadata = []
for item in tqdm(dataset):
utt_id = item['utt_id']
wave = item['wave']
mel = item['feats']
# normalize
mel = scaler.transform(mel)
# save
mel_path = dumpdir / f"{utt_id}-feats.npy"
np.save(mel_path, mel.astype(np.float32), allow_pickle=False)
if not args.skip_wav_copy:
wav_path = dumpdir / f"{utt_id}-wave.npy"
np.save(wav_path, wave.astype(np.float32), allow_pickle=False)
else:
wav_path = wave
output_metadata.append({
'utt_id': utt_id,
'wave': str(wav_path),
'feats': str(mel_path),
})
output_metadata.sort(key=itemgetter('utt_id'))
output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
with jsonlines.open(output_metadata_path, 'w') as writer:
for item in output_metadata:
writer.write(item)
logging.info(f"metadata dumped into {output_metadata_path}")
if __name__ == "__main__":
main()

View File

@ -19,6 +19,7 @@ import numpy as np
import argparse import argparse
import yaml import yaml
import json import json
import jsonlines
import concurrent.futures import concurrent.futures
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from pathlib import Path from pathlib import Path
@ -152,14 +153,14 @@ def process_sentence(config: Dict[str, Any],
mel_path = output_dir / (utt_id + "_feats.npy") mel_path = output_dir / (utt_id + "_feats.npy")
wav_path = output_dir / (utt_id + "_wave.npy") wav_path = output_dir / (utt_id + "_wave.npy")
np.save(wav_path, y) np.save(wav_path, y) # (num_samples, )
np.save(mel_path, logmel) np.save(mel_path, logmel.T) # (num_frames, n_mels)
record = { record = {
"utt_id": utt_id, "utt_id": utt_id,
"num_samples": num_sample, "num_samples": num_sample,
"num_frames": num_frames, "num_frames": num_frames,
"feats_path": str(mel_path.resolve()), "feats": str(mel_path.resolve()),
"wave_path": str(wav_path.resolve()), "wave": str(wav_path.resolve()),
} }
return record return record
@ -175,7 +176,7 @@ def process_sentences(config,
results.append( results.append(
process_sentence(config, fp, alignment_fp, output_dir)) process_sentence(config, fp, alignment_fp, output_dir))
else: else:
with ProcessPoolExecutor(nprocs) as pool: with ThreadPoolExecutor(nprocs) as pool:
futures = [] futures = []
with tqdm.tqdm(total=len(fps)) as progress: with tqdm.tqdm(total=len(fps)) as progress:
for fp, alignment_fp in zip(fps, alignment_fps): for fp, alignment_fp in zip(fps, alignment_fps):
@ -189,8 +190,9 @@ def process_sentences(config,
results.append(ft.result()) results.append(ft.result())
results.sort(key=itemgetter("utt_id")) results.sort(key=itemgetter("utt_id"))
with open(output_dir / "metadata.json", 'wt') as f: with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
json.dump(results, f) for item in results:
writer.write(item)
print("Done") print("Done")
@ -247,11 +249,11 @@ def main():
dev_alignment_files = alignment_files[9800:9900] dev_alignment_files = alignment_files[9800:9900]
test_alignment_files = alignment_files[9900:] test_alignment_files = alignment_files[9900:]
train_dump_dir = dumpdir / "train" train_dump_dir = dumpdir / "train" / "raw"
train_dump_dir.mkdir(parents=True, exist_ok=True) train_dump_dir.mkdir(parents=True, exist_ok=True)
dev_dump_dir = dumpdir / "dev" dev_dump_dir = dumpdir / "dev" / "raw"
dev_dump_dir.mkdir(parents=True, exist_ok=True) dev_dump_dir.mkdir(parents=True, exist_ok=True)
test_dump_dir = dumpdir / "test" test_dump_dir = dumpdir / "test" / "raw"
test_dump_dir.mkdir(parents=True, exist_ok=True) test_dump_dir.mkdir(parents=True, exist_ok=True)
# process for the 3 sections # process for the 3 sections

View File

@ -74,6 +74,7 @@ setup_info = dict(
'praatio', 'praatio',
"h5py", "h5py",
"timer", "timer",
'jsonlines',
], ],
extras_require={'doc': ["sphinx", "sphinx-rtd-theme", "numpydoc"], }, extras_require={'doc': ["sphinx", "sphinx-rtd-theme", "numpydoc"], },