ParakeetRebeccaRosario/parakeet/models/deepvoice3/preprocess.py

90 lines
2.8 KiB
Python
Raw Normal View History

2019-11-13 22:22:46 +08:00
# Part of this code was adapted from https://github.com/r9y9/deepvoice3_pytorch/tree/master/preprocess.py
# Copyright (c) 2017: Ryuichi Yamamoto.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import io
import six
import os
from multiprocessing import cpu_count
from tqdm import tqdm
import importlib
from hparams import hparams, hparams_debug_string
def build_parser():
    """Create the command-line parser for the preprocessing script.

    Registers three options (``--num-workers``, ``--hparams`` and the
    mandatory ``--preset``) followed by three positional string arguments
    (``name``, ``in_dir`` and ``out_dir``).

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    cli = argparse.ArgumentParser(description="Data Preprocessing")

    # Options, registered in the order they should show up in --help.
    option_specs = (
        ("--num-workers", dict(type=int, help="Num workers.")),
        ("--hparams",
         dict(type=str, default="", help="Hyper parameters to overwrite.")),
        ("--preset",
         dict(type=str,
              required=True,
              help="Path of preset parameters (json)")),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)

    # Positional arguments; all of them are plain strings.
    positional_specs = (
        ("name", "Dataset name"),
        ("in_dir", "Dataset path."),
        ("out_dir", "Path of preprocessed dataset."),
    )
    for dest, description in positional_specs:
        cli.add_argument(dest, type=str, help=description)

    return cli
def preprocess(mod, in_dir, out_root, num_workers):
    """Run a dataset module's preprocessing and write the metadata file.

    Args:
        mod: Object exposing ``build_from_path(in_dir, out_dir, num_workers,
            tqdm=...)``; its return value is passed to ``write_metadata``.
        in_dir: Dataset path handed to ``mod.build_from_path``.
        out_root: Output directory; created if it does not exist yet.
        num_workers: Worker count handed to ``mod.build_from_path``.
    """
    # Use the ``out_root`` argument itself. Reading a module-level
    # ``out_dir`` here only works when the file is run as a script and
    # fails (or targets the wrong directory) for any other caller.
    if not os.path.exists(out_root):
        os.makedirs(out_root)
    metadata = mod.build_from_path(in_dir, out_root, num_workers, tqdm=tqdm)
    write_metadata(metadata, out_root)
def write_metadata(metadata, out_dir):
    """Write ``train.txt`` into ``out_dir`` and print dataset statistics.

    Every entry of ``metadata`` becomes one line of ``train.txt``: its
    fields are converted to text and joined with ``|``.

    Args:
        metadata: Sized, re-iterable collection of indexable records.
            ``m[2]`` is summed as the frame count and ``len(m[3])`` is
            reported as the input length.
        out_dir: Directory that receives ``train.txt``.

    Raises:
        ValueError: If the interpreter is neither Python 2 nor Python 3.
    """
    if six.PY3:
        string_type = str
    elif six.PY2:
        string_type = unicode  # noqa: F821 -- builtin on Python 2 only
    else:
        raise ValueError("Not running on Python2 or Python 3?")

    with io.open(
            os.path.join(out_dir, 'train.txt'), 'wt', encoding='utf-8') as f:
        for m in metadata:
            f.write(u'|'.join([string_type(x) for x in m]) + '\n')

    num_utterances = len(metadata)
    frames = sum([m[2] for m in metadata])
    frame_shift_ms = hparams.hop_size / hparams.sample_rate * 1000
    hours = frames * frame_shift_ms / (3600 * 1000)

    # max() raises ValueError on an empty sequence; report 0 instead so an
    # empty dataset still produces a complete summary. A conditional is used
    # rather than max(..., default=0) to stay compatible with Python 2.
    if num_utterances:
        max_input_length = max(len(m[3]) for m in metadata)
        max_output_length = max(m[2] for m in metadata)
    else:
        max_input_length = 0
        max_output_length = 0

    print('Wrote %d utterances, %d frames (%.2f hours)' %
          (num_utterances, frames, hours))
    print('Max input length: %d' % max_input_length)
    print('Max output length: %d' % max_output_length)
if __name__ == "__main__":
    # Script entry point: parse the command line, load hyper parameters,
    # then preprocess the requested dataset.
    parser = build_parser()
    # Unknown command-line arguments are tolerated and ignored.
    args, _ = parser.parse_known_args()

    name = args.name
    in_dir = args.in_dir
    out_dir = args.out_dir
    num_workers = args.num_workers
    if num_workers is None:
        # Default to one worker per CPU.
        num_workers = cpu_count()
    preset = args.preset

    # Load preset if specified
    if preset is not None:
        # Read the JSON preset as UTF-8 instead of the locale's default
        # encoding.
        with io.open(preset, encoding='utf-8') as f:
            hparams.parse_json(f.read())

    # Override hyper parameters
    hparams.parse(args.hparams)

    # Explicit checks instead of ``assert``: assertions are removed when
    # Python runs with -O, which would silently skip the validation.
    if hparams.name != "deepvoice3":
        raise ValueError(
            'hparams.name must be "deepvoice3", got %r' % (hparams.name, ))
    print(hparams_debug_string())

    if name not in ["ljspeech"]:
        raise ValueError("now we only supports ljspeech")

    # The dataset name doubles as the name of the module to import.
    mod = importlib.import_module(name)
    preprocess(mod, in_dir, out_dir, num_workers)