1. use relative paths in metadata.jsonl;
2. support the key=value format for passing extra command line arguments that override config values;
3. use a path relative to config.py to locate the default config.
This commit is contained in:
parent b452586fcf
commit 9425c779a0
@@ -89,6 +89,11 @@ def main():
     with jsonlines.open(args.metadata, 'r') as reader:
         metadata = list(reader)
+
+    metadata_dir = Path(args.metadata).parent
+    for item in metadata:
+        item["feats"] = str(metadata_dir / item["feats"])
+
     dataset = DataTable(
         metadata,
         fields=[args.field_name],

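Every script that consumes a metadata.jsonl now repeats this resolution step. A minimal, self-contained sketch of the read side, assuming the "feats" entries were dumped relative to the metadata file itself (the load_metadata helper below is illustrative, not part of the repo):

    from pathlib import Path

    import jsonlines


    def load_metadata(metadata_path):
        """Read metadata.jsonl and make relative 'feats' paths usable again."""
        metadata_path = Path(metadata_path)
        metadata_dir = metadata_path.parent
        with jsonlines.open(str(metadata_path), 'r') as reader:
            metadata = list(reader)
        for item in metadata:
            # entries are stored relative to the directory of metadata.jsonl,
            # so joining with its parent restores a full path
            item["feats"] = str(metadata_dir / item["feats"])
        return metadata

Because the paths are interpreted against the file that contains them, a dump directory can be copied or mounted somewhere else and the metadata keeps working.
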
@@ -14,6 +14,9 @@
 import yaml
 from yacs.config import CfgNode as Configuration
+from pathlib import Path
 
+config_path = (Path(__file__).parent / "conf" / "default.yaml").resolve()
+
 with open("conf/default.yaml", 'rt') as f:
     _C = yaml.safe_load(f)

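Resolving the YAML next to config.py makes the default configuration independent of the directory a script is launched from. A minimal sketch of the pattern under that assumption (the exact body of config.py may differ from this):

    from pathlib import Path

    import yaml
    from yacs.config import CfgNode

    # locate the default config relative to this module, not the current working directory
    config_path = (Path(__file__).parent / "conf" / "default.yaml").resolve()


    def get_cfg_default():
        # return a fresh CfgNode so callers can merge file or key=value overrides into it
        with open(config_path, 'rt') as f:
            cfg = CfgNode(yaml.safe_load(f))
        return cfg
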
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 import re
+from pathlib import Path
+
 import numpy as np
 import paddle
 import pypinyin

@@ -22,10 +24,11 @@ import phkit
 phkit.initialize()
 from parakeet.frontend.vocab import Vocab
 
-with open("phones.txt", 'rt') as f:
+file_dir = Path(__file__).parent.resolve()
+with open(file_dir / "phones.txt", 'rt') as f:
     phones = [line.strip() for line in f.readlines()]
 
-with open("tones.txt", 'rt') as f:
+with open(file_dir / "tones.txt", 'rt') as f:
     tones = [line.strip() for line in f.readlines()]
 voc_phones = Vocab(phones, start_symbol=None, end_symbol=None)
 voc_tones = Vocab(tones, start_symbol=None, end_symbol=None)

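The frontend applies the same idea to its data files: phones.txt and tones.txt are looked up next to the module, so importing the frontend from any working directory still finds them. A short sketch of the pattern in isolation (read_lines is an illustrative helper, not the repo's code):

    from pathlib import Path

    file_dir = Path(__file__).parent.resolve()


    def read_lines(name):
        # open the resource relative to this module instead of relying on os.getcwd()
        with open(file_dir / name, 'rt') as f:
            return [line.strip() for line in f.readlines()]

    phones = read_lines("phones.txt")
    tones = read_lines("tones.txt")
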
@@ -33,7 +33,7 @@ def main():
         help="text to synthesize, a 'utt_id sentence' pair per line")
     parser.add_argument("--output-dir", type=str, help="output dir")
 
-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()
 
     speedyspeech_config = inference.Config(
         str(Path(args.inference_dir) / "speedyspeech.pdmodel"),

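parse_known_args() returns the recognized arguments plus a list of leftover tokens instead of exiting on anything unknown, which is what allows extra key=value overrides to be passed along (see the last training hunk below). A tiny illustration with made-up flags:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=str)
    # parse_args() would abort on the unknown token; parse_known_args() keeps it
    args, rest = parser.parse_known_args(["--output-dir", "exp", "model.lr=1e-4"])
    print(args.output_dir)  # exp
    print(rest)             # ['model.lr=1e-4']
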
@@ -96,6 +96,10 @@ def main():
     # get dataset
     with jsonlines.open(args.metadata, 'r') as reader:
         metadata = list(reader)
+    metadata_dir = Path(args.metadata).parent
+    for item in metadata:
+        item["feats"] = str(metadata_dir / item["feats"])
+
     dataset = DataTable(metadata, converters={'feats': np.load, })
     logging.info(f"The number of files = {len(dataset)}.")

@@ -136,7 +140,7 @@ def main():
             'num_phones': item['num_phones'],
             'num_frames': item['num_frames'],
             'durations': item['durations'],
-            'feats': str(mel_path),
+            'feats': str(mel_path.relative_to(dumpdir)),
         })
     output_metadata.sort(key=itemgetter('utt_id'))
     output_metadata_path = Path(args.dumpdir) / "metadata.jsonl"

@@ -181,7 +181,7 @@ def process_sentence(config: Dict[str, Any],
         "num_phones": len(phones),
         "num_frames": num_frames,
         "durations": durations_frame,
-        "feats": str(mel_path.resolve()),  # use absolute path
+        "feats": mel_path,  # Path object
     }
     return record

@@ -212,8 +212,12 @@ def process_sentences(config,
             results.append(ft.result())
 
     results.sort(key=itemgetter("utt_id"))
-    with jsonlines.open(output_dir / "metadata.jsonl", 'w') as writer:
+    output_dir = Path(output_dir)
+    metadata_path = output_dir / "metadata.jsonl"
+    # NOTE: use relative path to the meta jsonlines file
+    with jsonlines.open(metadata_path, 'w') as writer:
         for item in results:
+            item["feats"] = str(item["feats"].relative_to(output_dir))
             writer.write(item)
     print("Done")

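On the write side, each path stored in metadata.jsonl is first made relative to the directory the file lives in. A compact sketch of that convention, assuming every record carries a pathlib.Path in "feats" as process_sentence now returns (write_metadata is a hypothetical helper):

    from pathlib import Path

    import jsonlines


    def write_metadata(records, output_dir):
        """Dump records to output_dir/metadata.jsonl with 'feats' relative to output_dir."""
        output_dir = Path(output_dir)
        with jsonlines.open(str(output_dir / "metadata.jsonl"), 'w') as writer:
            for item in records:
                # Path.relative_to() raises ValueError if 'feats' is not under output_dir,
                # which catches records that were written somewhere unexpected
                item["feats"] = str(Path(item["feats"]).relative_to(output_dir))
                writer.write(item)
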
@@ -70,7 +70,6 @@ class SpeedySpeechUpdater(StandardUpdater):
 
 class SpeedySpeechEvaluator(StandardEvaluator):
     def evaluate_core(self, batch):
-        print("fire")
         decoded, predicted_durations = self.model(
             text=batch["phones"],
             tones=batch["tones"],

@@ -150,7 +150,7 @@ def main():
         "--device", type=str, default="gpu", help="device type to use")
     parser.add_argument("--verbose", type=int, default=1, help="verbose")
 
-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()
     with open(args.speedyspeech_config) as f:
         speedyspeech_config = CfgNode(yaml.safe_load(f))
     with open(args.pwg_config) as f:

@@ -152,7 +152,7 @@ def main():
         "--device", type=str, default="gpu", help="device type to use")
     parser.add_argument("--verbose", type=int, default=1, help="verbose")
 
-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()
    with open(args.speedyspeech_config) as f:
         speedyspeech_config = CfgNode(yaml.safe_load(f))
     with open(args.pwg_config) as f:

@@ -72,6 +72,10 @@ def train_sp(args, config):
     # construct dataset for training and validation
     with jsonlines.open(args.train_metadata, 'r') as reader:
         train_metadata = list(reader)
+    metadata_dir = Path(args.train_metadata).parent
+    for item in train_metadata:
+        item["feats"] = str(metadata_dir / item["feats"])
+
     train_dataset = DataTable(
         data=train_metadata,
         fields=[

@@ -80,6 +84,9 @@ def train_sp(args, config):
         converters={"feats": np.load, }, )
     with jsonlines.open(args.dev_metadata, 'r') as reader:
         dev_metadata = list(reader)
+    metadata_dir = Path(args.dev_metadata).parent
+    for item in dev_metadata:
+        item["feats"] = str(metadata_dir / item["feats"])
     dev_dataset = DataTable(
         data=dev_metadata,
         fields=[

@@ -113,9 +120,6 @@ def train_sp(args, config):
         num_workers=config.num_workers)
     print("dataloaders done!")
 
-    # batch = collate_baker_examples([train_dataset[i] for i in range(10)])
-    # # batch = collate_baker_examples([dev_dataset[i] for i in range(10)])
-    # import pdb; pdb.set_trace()
     model = SpeedySpeech(**config["model"])
     if world_size > 1:
         model = DataParallel(model)  # TODO, do not use vocab size from config

@@ -141,7 +145,7 @@ def train_sp(args, config):
     trainer.extend(VisualDL(writer), trigger=(1, "iteration"))
     trainer.extend(
         Snapshot(max_size=config.num_snapshots), trigger=(1, 'epoch'))
-    print(trainer.extensions)
+    # print(trainer.extensions)
     trainer.run()

@@ -160,12 +164,18 @@ def main():
         "--nprocs", type=int, default=1, help="number of processes")
     parser.add_argument("--verbose", type=int, default=1, help="verbose")
 
-    args = parser.parse_args()
+    args, rest = parser.parse_known_args()
     if args.device == "cpu" and args.nprocs > 1:
         raise RuntimeError("Multiprocess training on CPU is not supported.")
     config = get_cfg_default()
     if args.config:
         config.merge_from_file(args.config)
+    if rest:
+        extra = []
+        # to support key=value format
+        for item in rest:
+            extra.extend(item.split("=", maxsplit=1))
+        config.merge_from_list(extra)
 
     print("========Args========")
     print(yaml.safe_dump(vars(args)))

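Putting the pieces together: the leftovers from parse_known_args() are split on the first '=' and handed to yacs, whose merge_from_list expects a flat [key1, value1, key2, value2, ...] list and casts each value to the type of the matching default. A minimal sketch with made-up option names rather than the repo's real config keys:

    import argparse

    from yacs.config import CfgNode

    _C = CfgNode()
    _C.batch_size = 32          # illustrative defaults
    _C.optimizer = CfgNode()
    _C.optimizer.lr = 1e-3

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default=None)
    args, rest = parser.parse_known_args(["batch_size=64", "optimizer.lr=1e-4"])

    if rest:
        extra = []
        # "optimizer.lr=1e-4" -> ["optimizer.lr", "1e-4"]
        for item in rest:
            extra.extend(item.split("=", maxsplit=1))
        _C.merge_from_list(extra)

    print(_C.batch_size, _C.optimizer.lr)  # 64 0.0001

With this in place a run can be tweaked from the command line, e.g. python train.py --config conf/my.yaml batch_size=64, without editing any YAML file.
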
setup.py
@@ -64,17 +64,18 @@ setup_info = dict(
         'scipy',
         'pandas',
         'sox',
-        'soundfile',
+        'soundfile~=0.10',
         'g2p_en',
         'yacs',
         'visualdl',
         'pypinyin',
         'webrtcvad',
         'g2pM',
-        'praatio',
+        'praatio~=4.1',
         "h5py",
         "timer",
         'jsonlines',
+        "phkit",
     ],
     extras_require={'doc': ["sphinx", "sphinx-rtd-theme", "numpydoc"], },