1. add compute_statistics and normalize;
2. use jsonlines to read and write metadata by default;
3. use a thread pool instead of a process pool in preprocessing because it is faster.
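
A minimal sketch of points 2 and 3, assuming the record layout the preprocessing script writes (utt_id, wave, feats); the helper names process_item, run, and the paths below are illustrative only, not part of the repository:

# Sketch only: shows the jsonlines metadata convention and the thread-pool
# dispatch this commit switches to. Helper names are illustrative.
from concurrent.futures import ThreadPoolExecutor

import jsonlines


def process_item(utt_id):
    # placeholder for the real per-utterance feature extraction
    return {"utt_id": utt_id, "feats": f"{utt_id}-feats.npy", "wave": f"{utt_id}-wave.npy"}


def run(utt_ids, metadata_path, nprocs=4):
    # the commit swaps ProcessPoolExecutor for ThreadPoolExecutor here because it is faster
    with ThreadPoolExecutor(nprocs) as pool:
        records = list(pool.map(process_item, utt_ids))

    # metadata is now written as jsonlines: one JSON object per line
    with jsonlines.open(metadata_path, 'w') as writer:
        for record in records:
            writer.write(record)

    # and read back the same way in compute_statistics / normalize
    with jsonlines.open(metadata_path, 'r') as reader:
        return list(reader)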
This commit is contained in:
parent
30045cf602
commit
fbc7e51fc9
@@ -20,6 +20,7 @@ import os
 import numpy as np
 import yaml
 import json
+import jsonlines
 
 from sklearn.preprocessing import StandardScaler
 from tqdm import tqdm
@@ -36,20 +37,13 @@ def main():
     parser = argparse.ArgumentParser(
         description="Compute mean and variance of dumped raw features.")
     parser.add_argument(
-        "--metadata",
-        default=None,
-        type=str,
-        help="json file with id and file paths ")
+        "--metadata", type=str, help="json file with id and file paths ")
     parser.add_argument(
-        "--field-name",
-        default=None,
-        type=str,
-        help="json file with id and file paths ")
+        "--field-name", type=str, help="json file with id and file paths ")
     parser.add_argument(
         "--config", type=str, help="yaml format configuration file.")
     parser.add_argument(
         "--dumpdir",
         default=None,
         type=str,
         help="directory to save statistics. if not provided, "
         "stats will be saved in the above root directory. (default=None)")
@@ -89,8 +83,8 @@ def main():
     if not os.path.exists(args.dumpdir):
         os.makedirs(args.dumpdir)
 
-    with open(args.metadata, 'rt') as f:
-        metadata = json.load(f)
+    with jsonlines.open(args.metadata, 'r') as reader:
+        metadata = list(reader)
     dataset = DataTable(
         metadata,
         fields=[args.field_name],
@@ -101,7 +95,7 @@ def main():
     scaler = StandardScaler()
     for datum in tqdm(dataset):
         # StandardScalar supports (*, num_features) by default
-        scaler.partial_fit(datum[args.field_name].T)
+        scaler.partial_fit(datum[args.field_name])
 
     stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
     np.save(
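
The statistics file written above is consumed by the normalize script in the next file: a (2, num_features) array with the mean in row 0 and the scale in row 1. A minimal round-trip sketch, with illustrative shapes and file name:

import numpy as np
from sklearn.preprocessing import StandardScaler

# fit incrementally on (num_frames, num_features) feature matrices, as above
scaler = StandardScaler()
for feats in (np.random.rand(100, 80), np.random.rand(120, 80)):  # illustrative data
    scaler.partial_fit(feats)
np.save("stats.npy", np.stack([scaler.mean_, scaler.scale_], axis=0))

# restore the scaler the way the normalize script does
stats = np.load("stats.npy")
restored = StandardScaler()
restored.mean_, restored.scale_ = stats[0], stats[1]
restored.n_features_in_ = restored.mean_.shape[0]  # needed since scikit-learn 0.23
normalized = restored.transform(np.random.rand(50, 80))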
@@ -11,3 +11,144 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+"""Normalize feature files and dump them."""
+
+import argparse
+import logging
+import os
+from operator import itemgetter
+from pathlib import Path
+
+import numpy as np
+import yaml
+import jsonlines
+
+from sklearn.preprocessing import StandardScaler
+from tqdm import tqdm
+
+from parakeet.datasets.data_table import DataTable
+from parakeet.utils.h5_utils import read_hdf5
+from parakeet.utils.h5_utils import write_hdf5
+
+from config import get_cfg_default
+
+
+def main():
+    """Run preprocessing process."""
+    parser = argparse.ArgumentParser(
+        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
+    )
+    parser.add_argument(
+        "--metadata",
+        type=str,
+        required=True,
+        help="directory including feature files to be normalized. "
+        "you need to specify either *-scp or rootdir.")
+    parser.add_argument(
+        "--dumpdir",
+        type=str,
+        required=True,
+        help="directory to dump normalized feature files.")
+    parser.add_argument(
+        "--stats", type=str, required=True, help="statistics file.")
+    parser.add_argument(
+        "--skip-wav-copy",
+        default=False,
+        action="store_true",
+        help="whether to skip the copy of wav files.")
+    parser.add_argument(
+        "--config", type=str, help="yaml format configuration file.")
+    parser.add_argument(
+        "--verbose",
+        type=int,
+        default=1,
+        help="logging level. higher is more logging. (default=1)")
+    args = parser.parse_args()
+
+    # set logger
+    if args.verbose > 1:
+        logging.basicConfig(
+            level=logging.DEBUG,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+        )
+    elif args.verbose > 0:
+        logging.basicConfig(
+            level=logging.INFO,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+        )
+    else:
+        logging.basicConfig(
+            level=logging.WARN,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+        )
+        logging.warning('Skip DEBUG/INFO messages')
+
+    # load config
+    config = get_cfg_default()
+    if args.config:
+        config.merge_from_file(args.config)
+
+    # check directory existence
+    dumpdir = Path(args.dumpdir).resolve()
+    dumpdir.mkdir(parents=True, exist_ok=True)
+
+    # get dataset
+    with jsonlines.open(args.metadata, 'r') as reader:
+        metadata = list(reader)
+    dataset = DataTable(
+        metadata,
+        fields=["utt_id", "wave", "feats"],
+        converters={
+            'utt_id': None,
+            'wave': None if args.skip_wav_copy else np.load,
+            'feats': np.load,
+        })
+    logging.info(f"The number of files = {len(dataset)}.")
+
+    # restore scaler
+    scaler = StandardScaler()
+    scaler.mean_ = np.load(args.stats)[0]
+    scaler.scale_ = np.load(args.stats)[1]
+
+    # from version 0.23.0, this information is needed
+    scaler.n_features_in_ = scaler.mean_.shape[0]
+
+    # process each file
+    output_metadata = []
+
+    for item in tqdm(dataset):
+        utt_id = item['utt_id']
+        wave = item['wave']
+        mel = item['feats']
+        # normalize
+        mel = scaler.transform(mel)
+
+        # save
+        mel_path = dumpdir / f"{utt_id}-feats.npy"
+        np.save(mel_path, mel.astype(np.float32), allow_pickle=False)
+        if not args.skip_wav_copy:
+            wav_path = dumpdir / f"{utt_id}-wave.npy"
+            np.save(wav_path, wave.astype(np.float32), allow_pickle=False)
+        else:
+            wav_path = wave
+        output_metadata.append({
+            'utt_id': utt_id,
+            'wave': str(wav_path),
+            'feats': str(mel_path),
+        })
+    output_metadata.sort(key=itemgetter('utt_id'))
+    output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
+    with jsonlines.open(output_metadata_path, 'w') as writer:
+        for item in output_metadata:
+            writer.write(item)
+    logging.info(f"metadata dumped into {output_metadata_path}")
+
+
+if __name__ == "__main__":
+    main()
@@ -19,6 +19,7 @@ import numpy as np
 import argparse
 import yaml
 import json
+import jsonlines
 import concurrent.futures
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from pathlib import Path
@@ -152,14 +153,14 @@ def process_sentence(config: Dict[str, Any],
 
     mel_path = output_dir / (utt_id + "_feats.npy")
     wav_path = output_dir / (utt_id + "_wave.npy")
-    np.save(wav_path, y)
-    np.save(mel_path, logmel)
+    np.save(wav_path, y) # (num_samples, )
+    np.save(mel_path, logmel.T) # (num_frames, n_mels)
     record = {
         "utt_id": utt_id,
         "num_samples": num_sample,
         "num_frames": num_frames,
-        "feats_path": str(mel_path.resolve()),
-        "wave_path": str(wav_path.resolve()),
+        "feats": str(mel_path.resolve()),
+        "wave": str(wav_path.resolve()),
     }
     return record
 
@@ -175,7 +176,7 @@ def process_sentences(config,
             results.append(
                 process_sentence(config, fp, alignment_fp, output_dir))
     else:
-        with ProcessPoolExecutor(nprocs) as pool:
+        with ThreadPoolExecutor(nprocs) as pool:
             futures = []
             with tqdm.tqdm(total=len(fps)) as progress:
                 for fp, alignment_fp in zip(fps, alignment_fps):
@@ -189,8 +190,9 @@ def process_sentences(config,
                     results.append(ft.result())
 
     results.sort(key=itemgetter("utt_id"))
-    with open(output_dir / "metadata.json", 'wt') as f:
-        json.dump(results, f)
+    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
+        for item in results:
+            writer.write(item)
     print("Done")
 
 
@@ -247,11 +249,11 @@ def main():
     dev_alignment_files = alignment_files[9800:9900]
     test_alignment_files = alignment_files[9900:]
 
-    train_dump_dir = dumpdir / "train"
+    train_dump_dir = dumpdir / "train" / "raw"
     train_dump_dir.mkdir(parents=True, exist_ok=True)
-    dev_dump_dir = dumpdir / "dev"
+    dev_dump_dir = dumpdir / "dev" / "raw"
     dev_dump_dir.mkdir(parents=True, exist_ok=True)
-    test_dump_dir = dumpdir / "test"
+    test_dump_dir = dumpdir / "test" / "raw"
     test_dump_dir.mkdir(parents=True, exist_ok=True)
 
     # process for the 3 sections