1. add compute_statistics and normalize;
2. use jsonlines to read and write metadata by default;
3. use a thread pool instead of a process pool in preprocessing, because it is faster.
parent 30045cf602
commit fbc7e51fc9
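For point 2 of the commit message, here is a minimal sketch of the jsonlines read/write pattern the diff below adopts for metadata (the file name and record fields are illustrative assumptions; the real records carry utt_id, num_samples, num_frames, feats and wave):

import jsonlines

# hypothetical records; the actual ones are produced by preprocessing
records = [{"utt_id": "000001", "feats": "dump/000001_feats.npy"}]

# write one JSON object per line into metadata.jsonl
with jsonlines.open("metadata.jsonl", mode="w") as writer:
    for record in records:
        writer.write(record)

# read it back as a list of dicts, as compute_statistics and normalize do
with jsonlines.open("metadata.jsonl", mode="r") as reader:
    metadata = list(reader)

Compared with a single json.dump'ed list, one record per line lets metadata be appended and streamed without loading the whole file.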
@@ -20,6 +20,7 @@ import os
 import numpy as np
 import yaml
 import json
+import jsonlines

 from sklearn.preprocessing import StandardScaler
 from tqdm import tqdm
@@ -36,20 +37,13 @@ def main():
     parser = argparse.ArgumentParser(
         description="Compute mean and variance of dumped raw features.")
     parser.add_argument(
-        "--metadata",
-        default=None,
-        type=str,
-        help="json file with id and file paths ")
+        "--metadata", type=str, help="json file with id and file paths ")
     parser.add_argument(
-        "--field-name",
-        default=None,
-        type=str,
-        help="json file with id and file paths ")
+        "--field-name", type=str, help="json file with id and file paths ")
     parser.add_argument(
         "--config", type=str, help="yaml format configuration file.")
     parser.add_argument(
         "--dumpdir",
-        default=None,
         type=str,
         help="directory to save statistics. if not provided, "
         "stats will be saved in the above root directory. (default=None)")
@@ -89,8 +83,8 @@ def main():
     if not os.path.exists(args.dumpdir):
         os.makedirs(args.dumpdir)

-    with open(args.metadata, 'rt') as f:
-        metadata = json.load(f)
+    with jsonlines.open(args.metadata, 'r') as reader:
+        metadata = list(reader)
     dataset = DataTable(
         metadata,
         fields=[args.field_name],
@@ -101,7 +95,7 @@ def main():
     scaler = StandardScaler()
     for datum in tqdm(dataset):
         # StandardScalar supports (*, num_features) by default
-        scaler.partial_fit(datum[args.field_name].T)
+        scaler.partial_fit(datum[args.field_name])

     stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
     np.save(

@@ -11,3 +11,144 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+"""Normalize feature files and dump them."""
+
+import argparse
+import logging
+import os
+from operator import itemgetter
+from pathlib import Path
+
+import numpy as np
+import yaml
+import jsonlines
+
+from sklearn.preprocessing import StandardScaler
+from tqdm import tqdm
+
+from parakeet.datasets.data_table import DataTable
+from parakeet.utils.h5_utils import read_hdf5
+from parakeet.utils.h5_utils import write_hdf5
+
+from config import get_cfg_default
+
+
+def main():
+    """Run preprocessing process."""
+    parser = argparse.ArgumentParser(
+        description="Normalize dumped raw features (See detail in parallel_wavegan/bin/normalize.py)."
+    )
+    parser.add_argument(
+        "--metadata",
+        type=str,
+        required=True,
+        help="directory including feature files to be normalized. "
+        "you need to specify either *-scp or rootdir.")
+    parser.add_argument(
+        "--dumpdir",
+        type=str,
+        required=True,
+        help="directory to dump normalized feature files.")
+    parser.add_argument(
+        "--stats", type=str, required=True, help="statistics file.")
+    parser.add_argument(
+        "--skip-wav-copy",
+        default=False,
+        action="store_true",
+        help="whether to skip the copy of wav files.")
+    parser.add_argument(
+        "--config", type=str, help="yaml format configuration file.")
+    parser.add_argument(
+        "--verbose",
+        type=int,
+        default=1,
+        help="logging level. higher is more logging. (default=1)")
+    args = parser.parse_args()
+
+    # set logger
+    if args.verbose > 1:
+        logging.basicConfig(
+            level=logging.DEBUG,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+        )
+    elif args.verbose > 0:
+        logging.basicConfig(
+            level=logging.INFO,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+        )
+    else:
+        logging.basicConfig(
+            level=logging.WARN,
+            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+        )
+        logging.warning('Skip DEBUG/INFO messages')
+
+    # load config
+    config = get_cfg_default()
+    if args.config:
+        config.merge_from_file(args.config)
+
+    # check directory existence
+    dumpdir = Path(args.dumpdir).resolve()
+    dumpdir.mkdir(parents=True, exist_ok=True)
+
+    # get dataset
+    with jsonlines.open(args.metadata, 'r') as reader:
+        metadata = list(reader)
+    dataset = DataTable(
+        metadata,
+        fields=["utt_id", "wave", "feats"],
+        converters={
+            'utt_id': None,
+            'wave': None if args.skip_wav_copy else np.load,
+            'feats': np.load,
+        })
+    logging.info(f"The number of files = {len(dataset)}.")
+
+    # restore scaler
+    scaler = StandardScaler()
+    scaler.mean_ = np.load(args.stats)[0]
+    scaler.scale_ = np.load(args.stats)[1]
+
+    # from version 0.23.0, this information is needed
+    scaler.n_features_in_ = scaler.mean_.shape[0]
+
+    # process each file
+    output_metadata = []
+
+    for item in tqdm(dataset):
+        utt_id = item['utt_id']
+        wave = item['wave']
+        mel = item['feats']
+        # normalize
+        mel = scaler.transform(mel)
+
+        # save
+        mel_path = dumpdir / f"{utt_id}-feats.npy"
+        np.save(mel_path, mel.astype(np.float32), allow_pickle=False)
+        if not args.skip_wav_copy:
+            wav_path = dumpdir / f"{utt_id}-wave.npy"
+            np.save(wav_path, wave.astype(np.float32), allow_pickle=False)
+        else:
+            wav_path = wave
+        output_metadata.append({
+            'utt_id': utt_id,
+            'wave': str(wav_path),
+            'feats': str(mel_path),
+        })
+    output_metadata.sort(key=itemgetter('utt_id'))
+    output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"
+    with jsonlines.open(output_metadata_path, 'w') as writer:
+        for item in output_metadata:
+            writer.write(item)
+    logging.info(f"metadata dumped into {output_metadata_path}")
+
+
+if __name__ == "__main__":
+    main()

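A side note on how the two new scripts fit together: compute_statistics stacks the fitted mean and standard deviation into one .npy array, and normalize reloads that array to rebuild the scaler. A small round-trip sketch, assuming hypothetical (num_frames, 80)-shaped mel features:

import numpy as np
from sklearn.preprocessing import StandardScaler

# fit on dummy (num_frames, n_mels) features, as compute_statistics does
scaler = StandardScaler()
scaler.partial_fit(np.random.rand(100, 80))

# stats.npy holds row 0 = mean_, row 1 = scale_, shape (2, n_mels)
np.save("stats.npy", np.stack([scaler.mean_, scaler.scale_], axis=0))

# restore the scaler from the file, as normalize does, then transform features
restored = StandardScaler()
restored.mean_, restored.scale_ = np.load("stats.npy")
restored.n_features_in_ = restored.mean_.shape[0]  # needed since scikit-learn 0.23
normalized = restored.transform(np.random.rand(50, 80))
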
@@ -19,6 +19,7 @@ import numpy as np
 import argparse
 import yaml
 import json
+import jsonlines
 import concurrent.futures
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from pathlib import Path
@@ -152,14 +153,14 @@ def process_sentence(config: Dict[str, Any],

     mel_path = output_dir / (utt_id + "_feats.npy")
     wav_path = output_dir / (utt_id + "_wave.npy")
-    np.save(wav_path, y)
-    np.save(mel_path, logmel)
+    np.save(wav_path, y)  # (num_samples, )
+    np.save(mel_path, logmel.T)  # (num_frames, n_mels)
     record = {
         "utt_id": utt_id,
         "num_samples": num_sample,
         "num_frames": num_frames,
-        "feats_path": str(mel_path.resolve()),
-        "wave_path": str(wav_path.resolve()),
+        "feats": str(mel_path.resolve()),
+        "wave": str(wav_path.resolve()),
     }
     return record

@@ -175,7 +176,7 @@ def process_sentences(config,
            results.append(
                process_sentence(config, fp, alignment_fp, output_dir))
    else:
-        with ProcessPoolExecutor(nprocs) as pool:
+        with ThreadPoolExecutor(nprocs) as pool:
            futures = []
            with tqdm.tqdm(total=len(fps)) as progress:
                for fp, alignment_fp in zip(fps, alignment_fps):
@@ -189,8 +190,9 @@ def process_sentences(config,
                    results.append(ft.result())

    results.sort(key=itemgetter("utt_id"))
-    with open(output_dir / "metadata.json", 'wt') as f:
-        json.dump(results, f)
+    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
+        for item in results:
+            writer.write(item)
    print("Done")


@@ -247,11 +249,11 @@ def main():
     dev_alignment_files = alignment_files[9800:9900]
     test_alignment_files = alignment_files[9900:]

-    train_dump_dir = dumpdir / "train"
+    train_dump_dir = dumpdir / "train" / "raw"
     train_dump_dir.mkdir(parents=True, exist_ok=True)
-    dev_dump_dir = dumpdir / "dev"
+    dev_dump_dir = dumpdir / "dev" / "raw"
     dev_dump_dir.mkdir(parents=True, exist_ok=True)
-    test_dump_dir = dumpdir / "test"
+    test_dump_dir = dumpdir / "test" / "raw"
     test_dump_dir.mkdir(parents=True, exist_ok=True)

     # process for the 3 sections
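For point 3 of the commit message: ProcessPoolExecutor and ThreadPoolExecutor share the concurrent.futures interface, so the swap in process_sentences is a one-word change; the speedup presumably comes from thread workers not having to pickle the config and audio arrays between processes. Below is a minimal sketch of the submit/result pattern the script keeps using, with a hypothetical stand-in for process_sentence:

from concurrent.futures import ThreadPoolExecutor, as_completed

def process_one(path):
    # hypothetical stand-in for process_sentence
    return {"utt_id": path, "num_frames": 0}

paths = ["a.wav", "b.wav"]
results = []
with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(process_one, p) for p in paths]
    for future in as_completed(futures):
        results.append(future.result())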