1. Use relative paths in metadata.jsonl;
2. Support a key=value format for passing extra command-line arguments that override config values;
3. Use a path relative to config.py to locate the default config.
chenfeiyu 2021-08-16 10:01:51 +08:00
parent b452586fcf
commit 9425c779a0
11 changed files with 45 additions and 16 deletions
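
Taken together, change 1 is a write/read round trip: the preprocessing side stores each "feats" entry relative to the directory that holds metadata.jsonl, and every consumer joins it back against that directory before loading. Below is a minimal standalone sketch of the pattern; the dump directory layout and record contents are illustrative, not taken from the diff.

# Sketch of the relative-path round trip (hypothetical paths and records).
from pathlib import Path

import jsonlines
import numpy as np

dump_dir = Path("dump/train")
(dump_dir / "feats").mkdir(parents=True, exist_ok=True)

# write side (preprocess/normalize): store feats relative to dump_dir
mel_path = dump_dir / "feats" / "000001.npy"
np.save(mel_path, np.zeros((10, 80), dtype=np.float32))
with jsonlines.open(dump_dir / "metadata.jsonl", 'w') as writer:
    writer.write({
        "utt_id": "000001",
        "feats": str(mel_path.relative_to(dump_dir)),
    })

# read side (train/synthesize): resolve feats against the metadata file's dir
metadata_path = dump_dir / "metadata.jsonl"
with jsonlines.open(metadata_path, 'r') as reader:
    metadata = list(reader)
metadata_dir = metadata_path.parent
for item in metadata:
    item["feats"] = str(metadata_dir / item["feats"])

The payoff is that a dump directory becomes relocatable: moving or mounting it somewhere else no longer invalidates absolute paths baked into the metadata.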

View File

@@ -89,6 +89,11 @@ def main():
     with jsonlines.open(args.metadata, 'r') as reader:
         metadata = list(reader)
+    metadata_dir = Path(args.metadata).parent
+    for item in metadata:
+        item["feats"] = str(metadata_dir / item["feats"])
     dataset = DataTable(
         metadata,
         fields=[args.field_name],

View File

@@ -14,6 +14,9 @@
 import yaml
 from yacs.config import CfgNode as Configuration
+from pathlib import Path
+
+config_path = (Path(__file__).parent / "conf" / "default.yaml").resolve()

-with open("conf/default.yaml", 'rt') as f:
+with open(config_path, 'rt') as f:
     _C = yaml.safe_load(f)
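
Resolving the YAML against __file__ (change 3) makes the config module importable from any working directory instead of only from the example's root. A sketch of the resulting module follows; only the config_path lines are confirmed by the diff, while the CfgNode wrapping and the get_cfg_default() helper are assumptions based on the usual yacs pattern and the train script's get_cfg_default() call further down.

from pathlib import Path

import yaml
from yacs.config import CfgNode as Configuration

# locate the default config next to this file, not the current working dir
config_path = (Path(__file__).parent / "conf" / "default.yaml").resolve()

with open(config_path, 'rt') as f:
    _C = yaml.safe_load(f)
_C = Configuration(_C)


def get_cfg_default():
    # hand out a mutable copy so callers can merge files/overrides into it
    return _C.clone()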

View File

@@ -13,6 +13,8 @@
 # limitations under the License.
 import re
+from pathlib import Path
+
 import numpy as np
 import paddle
 import pypinyin
@@ -22,10 +24,11 @@ import phkit
 phkit.initialize()
 from parakeet.frontend.vocab import Vocab

-with open("phones.txt", 'rt') as f:
+file_dir = Path(__file__).parent.resolve()
+with open(file_dir / "phones.txt", 'rt') as f:
     phones = [line.strip() for line in f.readlines()]
-with open("tones.txt", 'rt') as f:
+with open(file_dir / "tones.txt", 'rt') as f:
     tones = [line.strip() for line in f.readlines()]
 voc_phones = Vocab(phones, start_symbol=None, end_symbol=None)
 voc_tones = Vocab(tones, start_symbol=None, end_symbol=None)

View File

@@ -33,7 +33,7 @@ def main():
         help="text to synthesize, a 'utt_id sentence' pair per line")
     parser.add_argument("--output-dir", type=str, help="output dir")

-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()

     speedyspeech_config = inference.Config(
         str(Path(args.inference_dir) / "speedyspeech.pdmodel"),
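
parse_known_args() is the hook that makes the new key=value overrides possible: unlike parse_args(), it does not abort on unrecognized tokens but returns them in a second list. A tiny self-contained demonstration (the flag and override names are hypothetical):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--inference-dir", type=str)

# parse_args() would exit with "unrecognized arguments" here;
# parse_known_args() passes the extras through instead
args, rest = parser.parse_known_args(
    ["--inference-dir", "exp/default/inference", "training.lr=1e-4"])
print(args.inference_dir)  # exp/default/inference
print(rest)                # ['training.lr=1e-4']

In scripts that discard the extras, as this one does with `args, _`, the net effect is simply that stray overrides no longer crash the program.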

View File

@@ -96,6 +96,10 @@ def main():
     # get dataset
     with jsonlines.open(args.metadata, 'r') as reader:
         metadata = list(reader)
+    metadata_dir = Path(args.metadata).parent
+    for item in metadata:
+        item["feats"] = str(metadata_dir / item["feats"])
     dataset = DataTable(metadata, converters={'feats': np.load, })

     logging.info(f"The number of files = {len(dataset)}.")
@@ -136,7 +140,7 @@ def main():
             'num_phones': item['num_phones'],
             'num_frames': item['num_frames'],
             'durations': item['durations'],
-            'feats': str(mel_path),
+            'feats': str(mel_path.relative_to(dumpdir)),
         })
     output_metadata.sort(key=itemgetter('utt_id'))
     output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"

View File

@@ -181,7 +181,7 @@ def process_sentence(config: Dict[str, Any],
         "num_phones": len(phones),
         "num_frames": num_frames,
         "durations": durations_frame,
-        "feats": str(mel_path.resolve()),  # use absolute path
+        "feats": mel_path,  # Path object
     }
     return record
@@ -212,8 +212,12 @@ def process_sentences(config,
                 results.append(ft.result())
     results.sort(key=itemgetter("utt_id"))
-    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
+    output_dir = Path(output_dir)
+    metadata_path = output_dir / "metadata.jsonl"
+    # NOTE: use relative path to the meta jsonlines file
+    with jsonlines.open(metadata_path, 'w') as writer:
         for item in results:
+            item["feats"] = str(item["feats"].relative_to(output_dir))
             writer.write(item)
     print("Done")

View File

@@ -70,7 +70,6 @@ class SpeedySpeechUpdater(StandardUpdater):
 class SpeedySpeechEvaluator(StandardEvaluator):
     def evaluate_core(self, batch):
-        print("fire")
         decoded, predicted_durations = self.model(
             text=batch["phones"],
             tones=batch["tones"],

View File

@@ -150,7 +150,7 @@ def main():
         "--device", type=str, default="gpu", help="device type to use")
     parser.add_argument("--verbose", type=int, default=1, help="verbose")

-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()

     with open(args.speedyspeech_config) as f:
         speedyspeech_config = CfgNode(yaml.safe_load(f))
     with open(args.pwg_config) as f:

View File

@@ -152,7 +152,7 @@ def main():
         "--device", type=str, default="gpu", help="device type to use")
     parser.add_argument("--verbose", type=int, default=1, help="verbose")

-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()

     with open(args.speedyspeech_config) as f:
         speedyspeech_config = CfgNode(yaml.safe_load(f))
     with open(args.pwg_config) as f:

View File

@@ -72,6 +72,10 @@ def train_sp(args, config):
     # construct dataset for training and validation
     with jsonlines.open(args.train_metadata, 'r') as reader:
         train_metadata = list(reader)
+    metadata_dir = Path(args.train_metadata).parent
+    for item in train_metadata:
+        item["feats"] = str(metadata_dir / item["feats"])
     train_dataset = DataTable(
         data=train_metadata,
         fields=[
@@ -80,6 +84,9 @@ def train_sp(args, config):
         converters={"feats": np.load, }, )
     with jsonlines.open(args.dev_metadata, 'r') as reader:
         dev_metadata = list(reader)
+    metadata_dir = Path(args.dev_metadata).parent
+    for item in dev_metadata:
+        item["feats"] = str(metadata_dir / item["feats"])
     dev_dataset = DataTable(
         data=dev_metadata,
         fields=[
@@ -113,9 +120,6 @@ def train_sp(args, config):
         num_workers=config.num_workers)
     print("dataloaders done!")
-    # batch = collate_baker_examples([train_dataset[i] for i in range(10)])
-    # # batch = collate_baker_examples([dev_dataset[i] for i in range(10)])
-    # import pdb; pdb.set_trace()
     model = SpeedySpeech(**config["model"])
     if world_size > 1:
         model = DataParallel(model)  # TODO, do not use vocab size from config
@@ -141,7 +145,7 @@ def train_sp(args, config):
         trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
         trainer.extend(
             Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
-    print(trainer.extensions)
+    # print(trainer.extensions)
     trainer.run()
@@ -160,12 +164,18 @@ def main():
         "--nprocs", type=int, default=1, help="number of processes")
     parser.add_argument("--verbose", type=int, default=1, help="verbose")

-    args = parser.parse_args()
+    args, rest = parser.parse_known_args()

     if args.device == "cpu" and args.nprocs > 1:
         raise RuntimeError("Multiprocess training on CPU is not supported.")
     config = get_cfg_default()
     if args.config:
         config.merge_from_file(args.config)
+    if rest:
+        extra = []
+        # to support key=value format
+        for item in rest:
+            extra.extend(item.split("=", maxsplit=1))
+        config.merge_from_list(extra)

     print("========Args========")
     print(yaml.safe_dump(vars(args)))
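
For reference, here is the override handling from the hunk above in isolation. yacs' merge_from_list consumes a flat [key1, value1, key2, value2, ...] list, which is why each key=value token is split exactly once on "="; the config keys below are hypothetical:

from yacs.config import CfgNode

config = CfgNode({"training": {"lr": 1e-3, "max_epoch": 100}})
rest = ["training.lr=1e-4", "training.max_epoch=300"]  # from parse_known_args

extra = []
for item in rest:
    extra.extend(item.split("=", maxsplit=1))
# extra == ['training.lr', '1e-4', 'training.max_epoch', '300']
config.merge_from_list(extra)  # yacs literal-evals values and type-checks them
print(config.training.lr, config.training.max_epoch)  # 0.0001 300

Splitting with maxsplit=1 also keeps values that themselves contain "=" intact.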

View File

@@ -64,17 +64,18 @@ setup_info = dict(
         'scipy',
         'pandas',
         'sox',
-        'soundfile',
+        'soundfile~=0.10',
         'g2p_en',
         'yacs',
         'visualdl',
         'pypinyin',
         'webrtcvad',
         'g2pM',
-        'praatio',
+        'praatio~=4.1',
         "h5py",
         "timer",
         'jsonlines',
         "phkit",
     ],
     extras_require={'doc': ["sphinx", "sphinx-rtd-theme", "numpydoc"], },
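
The new "~=" constraints are PEP 440 compatible-release specifiers: soundfile~=0.10 accepts any 0.x release from 0.10 upward but excludes 1.0, and praatio~=4.1 keeps the dependency within the 4.x series, guarding against breaking changes in a future major release.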