Add more CLI options

This commit is contained in:
chenfeiyu 2021-07-19 13:03:23 +08:00
parent 51397f8500
commit 4ba8e7e342
4 changed files with 20 additions and 15 deletions

View File

@ -16,6 +16,7 @@
import argparse
import logging
import os
from pathlib import Path
import numpy as np
import yaml
@ -45,10 +46,10 @@ def main():
parser.add_argument(
"--config", type=str, help="yaml format configuration file.")
parser.add_argument(
"--dumpdir",
"--output",
type=str,
help="directory to save statistics. if not provided, "
"stats will be saved in the above root directory. (default=None)")
help="path to save statistics. if not provided, "
"stats will be saved in the above root directory with name stats.npy")
parser.add_argument(
"--verbose",
type=int,
@ -80,10 +81,11 @@ def main():
config.merge_from_file(args.config)
# check directory existence
if args.dumpdir is None:
args.dumpdir = os.path.dirname(args.metadata)
if not os.path.exists(args.dumpdir):
os.makedirs(args.dumpdir)
if args.output is None:
args.output = Path(args.metadata).parent.with_name("stats.npy")
else:
args.output = Path(args.output)
args.output.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(args.metadata, 'r') as reader:
metadata = list(reader)
@ -100,10 +102,7 @@ def main():
scaler.partial_fit(datum[args.field_name])
stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
np.save(
os.path.join(args.dumpdir, "stats.npy"),
stats.astype(np.float32),
allow_pickle=False)
np.save(str(args.output), stats.astype(np.float32), allow_pickle=False)
if __name__ == "__main__":

View File

@ -50,6 +50,13 @@ def main():
help="directory to dump normalized feature files.")
parser.add_argument(
"--stats", type=str, required=True, help="statistics file.")
parser.add_argument(
"--phones",
type=str,
default="phones.txt",
help="phone vocabulary file.")
parser.add_argument(
"--tones", type=str, default="tones.txt", help="tone vocabulary file.")
parser.add_argument(
"--config", type=str, help="yaml format configuration file.")
parser.add_argument(
@ -100,10 +107,10 @@ def main():
# from version 0.23.0, this information is needed
scaler.n_features_in_ = scaler.mean_.shape[0]
with open("phones.txt", 'rt') as f:
with open(args.phones, 'rt') as f:
phones = [line.strip() for line in f.readlines()]
with open("tones.txt", 'rt') as f:
with open(args.tones, 'rt') as f:
tones = [line.strip() for line in f.readlines()]
voc_phones = Vocab(phones, start_symbol=None, end_symbol=None)
voc_tones = Vocab(tones, start_symbol=None, end_symbol=None)

View File

@ -1,5 +1,5 @@
python preprocess.py --rootdir=~/datasets/BZNSYP/ --dumpdir=dump --num_cpu=20
python compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="feats" --dumpdir=dump/train
python compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="feats" --output=dump/train/stats.npy
python normalize.py --metadata=dump/train/raw/metadata.jsonl --dumpdir=dump/train/norm --stats=dump/train/stats.npy
python normalize.py --metadata=dump/dev/raw/metadata.jsonl --dumpdir=dump/dev/norm --stats=dump/train/stats.npy

View File

@ -62,7 +62,6 @@ class SpeedySpeechUpdater(StandardUpdater):
loss.backward()
optimizer.step()
# import pdb; pdb.set_trace()
report("train/loss", float(loss))
report("train/l1_loss", float(l1_loss))
report("train/duration_loss", float(duration_loss))