add more cli options

chenfeiyu 2021-07-19 13:03:23 +08:00
parent 51397f8500
commit 4ba8e7e342
4 changed files with 20 additions and 15 deletions

View File

@@ -16,6 +16,7 @@
 import argparse
 import logging
 import os
+from pathlib import Path
 import numpy as np
 import yaml
@@ -45,10 +46,10 @@ def main():
     parser.add_argument(
         "--config", type=str, help="yaml format configuration file.")
     parser.add_argument(
-        "--dumpdir",
+        "--output",
         type=str,
-        help="directory to save statistics. if not provided, "
-        "stats will be saved in the above root directory. (default=None)")
+        help="path to save statistics. if not provided, "
+        "stats will be saved in the above root directory with name stats.npy")
     parser.add_argument(
         "--verbose",
         type=int,
@ -80,10 +81,11 @@ def main():
config.merge_from_file(args.config) config.merge_from_file(args.config)
# check directory existence # check directory existence
if args.dumpdir is None: if args.output is None:
args.dumpdir = os.path.dirname(args.metadata) args.output = Path(args.metadata).parent.with_name("stats.npy")
if not os.path.exists(args.dumpdir): else:
os.makedirs(args.dumpdir) args.output = Path(args.output)
args.output.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(args.metadata, 'r') as reader: with jsonlines.open(args.metadata, 'r') as reader:
metadata = list(reader) metadata = list(reader)
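
With this change, omitting --output derives a default path from the metadata file instead of requiring --dumpdir. A quick sketch (hypothetical paths, not part of the commit) of what the pathlib expressions resolve to:

from pathlib import Path

# hypothetical metadata path matching the run script further below
metadata = Path("dump/train/raw/metadata.jsonl")

# default output: sibling of the raw/ directory, named stats.npy
output = metadata.parent.with_name("stats.npy")
print(output)                                     # dump/train/stats.npy
output.parent.mkdir(parents=True, exist_ok=True)  # ensures dump/train/ exists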
@ -100,10 +102,7 @@ def main():
scaler.partial_fit(datum[args.field_name]) scaler.partial_fit(datum[args.field_name])
stats = np.stack([scaler.mean_, scaler.scale_], axis=0) stats = np.stack([scaler.mean_, scaler.scale_], axis=0)
np.save( np.save(str(args.output), stats.astype(np.float32), allow_pickle=False)
os.path.join(args.dumpdir, "stats.npy"),
stats.astype(np.float32),
allow_pickle=False)
if __name__ == "__main__": if __name__ == "__main__":
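
The file written here stacks the scaler's mean and scale along axis 0. A minimal sketch (assuming the default dump/train/stats.npy path) of reading it back into a StandardScaler, which is what normalize.py below already does:

import numpy as np
from sklearn.preprocessing import StandardScaler

stats = np.load("dump/train/stats.npy")  # shape (2, n_features): row 0 = mean_, row 1 = scale_
scaler = StandardScaler()
scaler.mean_ = stats[0]
scaler.scale_ = stats[1]
# from sklearn 0.23.0 on, this attribute is also needed (see normalize.py)
scaler.n_features_in_ = scaler.mean_.shape[0]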

View File

@@ -50,6 +50,13 @@ def main():
         help="directory to dump normalized feature files.")
     parser.add_argument(
         "--stats", type=str, required=True, help="statistics file.")
+    parser.add_argument(
+        "--phones",
+        type=str,
+        default="phones.txt",
+        help="phone vocabulary file.")
+    parser.add_argument(
+        "--tones", type=str, default="tones.txt", help="tone vocabulary file.")
     parser.add_argument(
         "--config", type=str, help="yaml format configuration file.")
     parser.add_argument(
@@ -100,10 +107,10 @@ def main():
     # from version 0.23.0, this information is needed
     scaler.n_features_in_ = scaler.mean_.shape[0]

-    with open("phones.txt", 'rt') as f:
+    with open(args.phones, 'rt') as f:
         phones = [line.strip() for line in f.readlines()]
-    with open("tones.txt", 'rt') as f:
+    with open(args.tones, 'rt') as f:
         tones = [line.strip() for line in f.readlines()]

     voc_phones = Vocab(phones, start_symbol=None, end_symbol=None)
     voc_tones = Vocab(tones, start_symbol=None, end_symbol=None)
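
The new --phones/--tones options point at plain-text vocabulary files, one symbol per line. A small sketch (assumed file names and contents) of the kind of lookup the project's Vocab wraps:

from pathlib import Path

phones = [line.strip() for line in Path("phones.txt").read_text().splitlines() if line.strip()]
phone_to_id = {p: i for i, p in enumerate(phones)}  # roughly the symbol-to-index mapping Vocab provides
tones = [line.strip() for line in Path("tones.txt").read_text().splitlines() if line.strip()]
tone_to_id = {t: i for i, t in enumerate(tones)}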

View File

@@ -1,5 +1,5 @@
 python preprocess.py --rootdir=~/datasets/BZNSYP/ --dumpdir=dump --num_cpu=20
-python compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="feats" --dumpdir=dump/train
+python compute_statistics.py --metadata=dump/train/raw/metadata.jsonl --field-name="feats" --output=dump/train/stats.npy
 python normalize.py --metadata=dump/train/raw/metadata.jsonl --dumpdir=dump/train/norm --stats=dump/train/stats.npy
 python normalize.py --metadata=dump/dev/raw/metadata.jsonl --dumpdir=dump/dev/norm --stats=dump/train/stats.npy

View File

@@ -62,7 +62,6 @@ class SpeedySpeechUpdater(StandardUpdater):
         loss.backward()
         optimizer.step()

-        # import pdb; pdb.set_trace()
         report("train/loss", float(loss))
         report("train/l1_loss", float(l1_loss))
         report("train/duration_loss", float(duration_loss))