add more cli options
This commit is contained in:
parent
51397f8500
commit
4ba8e7e342
|
@ -16,6 +16,7 @@
|
|||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
@ -45,10 +46,10 @@ def main():
|
|||
parser.add_argument(
|
||||
"--config", type=str, help="yaml format configuration file.")
|
||||
parser.add_argument(
|
||||
"--dumpdir",
|
||||
"--output",
|
||||
type=str,
|
||||
help="directory to save statistics. if not provided, "
|
||||
"stats will be saved in the above root directory. (default=None)")
|
||||
help="path to save statistics. if not provided, "
|
||||
"stats will be saved in the above root directory with name stats.npy")
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
type=int,
|
||||
|
@ -80,10 +81,11 @@ def main():
|
|||
config.merge_from_file(args.config)
|
||||
|
||||
# make sure the destination directory exists, creating it if necessary
|
||||
if args.dumpdir is None:
|
||||
args.dumpdir = os.path.dirname(args.metadata)
|
||||
if not os.path.exists(args.dumpdir):
|
||||
os.makedirs(args.dumpdir)
|
||||
if args.output is None:
|
||||
args.output = Path(args.metadata).parent.with_name("stats.npy")
|
||||
else:
|
||||
args.output = Path(args.output)
|
||||
args.output.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with jsonlines.open(args.metadata, 'r') as reader:
|
||||
metadata = list(reader)
|
||||
|
@ -100,10 +102,7 @@ def main():
|
|||
scaler.partial_fit(datum[args.field_name])
|
||||
|
||||
stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
|
||||
np.save(
|
||||
os.path.join(args.dumpdir, "stats.npy"),
|
||||
stats.astype(np.float32),
|
||||
allow_pickle=False)
|
||||
np.save(str(args.output), stats.astype(np.float32), allow_pickle=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -50,6 +50,13 @@ def main():
|
|||
help="directory to dump normalized feature files.")
|
||||
parser.add_argument(
|
||||
"--stats", type=str, required=True, help="statistics file.")
|
||||
parser.add_argument(
|
||||
"--phones",
|
||||
type=str,
|
||||
default="phones.txt",
|
||||
help="phone vocabulary file.")
|
||||
parser.add_argument(
|
||||
"--tones", type=str, default="tones.txt", help="tone vocabulary file.")
|
||||
parser.add_argument(
|
||||
"--config", type=str, help="yaml format configuration file.")
|
||||
parser.add_argument(
|
||||
|
@ -100,10 +107,10 @@ def main():
|
|||
# from version 0.23.0, this information is needed
|
||||
scaler.n_features_in_ = scaler.mean_.shape[0]
|
||||
|
||||
with open("phones.txt", 'rt') as f:
|
||||
with open(args.phones, 'rt') as f:
|
||||
phones = [line.strip() for line in f.readlines()]
|
||||
|
||||
with open("tones.txt", 'rt') as f:
|
||||
with open(args.tones, 'rt') as f:
|
||||
tones = [line.strip() for line in f.readlines()]
|
||||
voc_phones = Vocab(phones, start_symbol=None, end_symbol=None)
|
||||
voc_tones = Vocab(tones, start_symbol=None, end_symbol=None)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
python preprocess.py --rootdir=~/datasets/BZNSYP/ --dumpdir=dump --num_cpu=20
|
||||
python compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="feats" --dumpdir=dump/train
|
||||
python compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="feats" --output=dump/train/stats.npy
|
||||
|
||||
python normalize.py --metadata=dump/train/raw/metadata.jsonl --dumpdir=dump/train/norm --stats=dump/train/stats.npy
|
||||
python normalize.py --metadata=dump/dev/raw/metadata.jsonl --dumpdir=dump/dev/norm --stats=dump/train/stats.npy
|
||||
|
|
|
@ -62,7 +62,6 @@ class SpeedySpeechUpdater(StandardUpdater):
|
|||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# import pdb; pdb.set_trace()
|
||||
report("train/loss", float(loss))
|
||||
report("train/l1_loss", float(l1_loss))
|
||||
report("train/duration_loss", float(duration_loss))
|
||||
|
|
Loading…
Reference in New Issue