diff --git a/examples/speedyspeech/baker/compute_statistics.py b/examples/speedyspeech/baker/compute_statistics.py index 06b9b65..e145974 100644 --- a/examples/speedyspeech/baker/compute_statistics.py +++ b/examples/speedyspeech/baker/compute_statistics.py @@ -16,6 +16,7 @@ import argparse import logging import os +from pathlib import Path import numpy as np import yaml @@ -45,10 +46,10 @@ def main(): parser.add_argument( "--config", type=str, help="yaml format configuration file.") parser.add_argument( - "--dumpdir", + "--output", type=str, - help="directory to save statistics. if not provided, " - "stats will be saved in the above root directory. (default=None)") + help="path to save statistics. if not provided, " + "stats will be saved in the above root directory with name stats.npy") parser.add_argument( "--verbose", type=int, @@ -80,10 +81,11 @@ def main(): config.merge_from_file(args.config) # check directory existence - if args.dumpdir is None: - args.dumpdir = os.path.dirname(args.metadata) - if not os.path.exists(args.dumpdir): - os.makedirs(args.dumpdir) + if args.output is None: + args.output = Path(args.metadata).parent.with_name("stats.npy") + else: + args.output = Path(args.output) + args.output.parent.mkdir(parents=True, exist_ok=True) with jsonlines.open(args.metadata, 'r') as reader: metadata = list(reader) @@ -100,10 +102,7 @@ def main(): scaler.partial_fit(datum[args.field_name]) stats = np.stack([scaler.mean_, scaler.scale_], axis=0) - np.save( - os.path.join(args.dumpdir, "stats.npy"), - stats.astype(np.float32), - allow_pickle=False) + np.save(str(args.output), stats.astype(np.float32), allow_pickle=False) if __name__ == "__main__": diff --git a/examples/speedyspeech/baker/normalize.py b/examples/speedyspeech/baker/normalize.py index 74661f8..daa0a91 100644 --- a/examples/speedyspeech/baker/normalize.py +++ b/examples/speedyspeech/baker/normalize.py @@ -50,6 +50,13 @@ def main(): help="directory to dump normalized feature files.") parser.add_argument( "--stats", type=str, required=True, help="statistics file.") + parser.add_argument( + "--phones", + type=str, + default="phones.txt", + help="phone vocabulary file.") + parser.add_argument( + "--tones", type=str, default="tones.txt", help="tone vocabulary file.") parser.add_argument( "--config", type=str, help="yaml format configuration file.") parser.add_argument( @@ -100,10 +107,10 @@ def main(): # from version 0.23.0, this information is needed scaler.n_features_in_ = scaler.mean_.shape[0] - with open("phones.txt", 'rt') as f: + with open(args.phones, 'rt') as f: phones = [line.strip() for line in f.readlines()] - with open("tones.txt", 'rt') as f: + with open(args.tones, 'rt') as f: tones = [line.strip() for line in f.readlines()] voc_phones = Vocab(phones, start_symbol=None, end_symbol=None) voc_tones = Vocab(tones, start_symbol=None, end_symbol=None) diff --git a/examples/speedyspeech/baker/preprocess.sh b/examples/speedyspeech/baker/preprocess.sh index eab4146..d1f212b 100644 --- a/examples/speedyspeech/baker/preprocess.sh +++ b/examples/speedyspeech/baker/preprocess.sh @@ -1,5 +1,5 @@ python preprocess.py --rootdir=~/datasets/BZNSYP/ --dumpdir=dump --num_cpu=20 -python compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="feats" --dumpdir=dump/train +python compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="feats" --output=dump/train/stats.npy python normalize.py --metadata=dump/train/raw/metadata.jsonl --dumpdir=dump/train/norm --stats=dump/train/stats.npy python normalize.py --metadata=dump/dev/raw/metadata.jsonl --dumpdir=dump/dev/norm --stats=dump/train/stats.npy diff --git a/examples/speedyspeech/baker/speedyspeech_updater.py b/examples/speedyspeech/baker/speedyspeech_updater.py index 46a4777..bc4d0f9 100644 --- a/examples/speedyspeech/baker/speedyspeech_updater.py +++ b/examples/speedyspeech/baker/speedyspeech_updater.py @@ -62,7 +62,6 @@ class SpeedySpeechUpdater(StandardUpdater): loss.backward() optimizer.step() - # import pdb; pdb.set_trace() report("train/loss", float(loss)) report("train/l1_loss", float(l1_loss)) report("train/duration_loss", float(duration_loss))