Update synthesis script for waveflow

This commit is contained in:
liuyibing01 2020-02-24 02:35:19 +00:00
parent 7635493a0a
commit 8b051486f1
3 changed files with 8 additions and 8 deletions

View File

@ -2,13 +2,13 @@ import os
import random
from pprint import pprint
import jsonargparse
import argparse
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
import utils
from waveflow import WaveFlow
from parakeet.models.waveflow import WaveFlow
def add_options_to_parser(parser):
@ -53,7 +53,7 @@ def add_options_to_parser(parser):
def synthesize(config):
pprint(jsonargparse.namespace_to_dict(config))
pprint(vars(config))
# Get checkpoint directory path.
run_dir = os.path.join("runs", config.model, config.name)
@ -90,9 +90,8 @@ def synthesize(config):
if __name__ == "__main__":
# Create parser.
parser = jsonargparse.ArgumentParser(
description="Synthesize audio using WaveNet model",
formatter_class='default_argparse')
parser = argparse.ArgumentParser(
description="Synthesize audio using WaveNet model")
add_options_to_parser(parser)
utils.add_config_options_to_parser(parser)
@ -100,4 +99,5 @@ if __name__ == "__main__":
# For conflicting updates to the same field,
# the preceding update will be overwritten by the following one.
config = parser.parse_args()
config = utils.add_yaml_config(config)
synthesize(config)

View File

@ -84,7 +84,6 @@ def add_config_options_to_parser(parser):
def add_yaml_config(config):
print(config)
with open(config.config, 'rt') as f:
yaml_cfg = ruamel.yaml.safe_load(f)
cfg_vars = vars(config)

View File

@ -152,7 +152,8 @@ class WaveFlow():
sample = config.sample
output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
os.makedirs(output, exist_ok=True)
if not os.path.exists(output):
os.makedirs(output)
mels_list = [mels for _, mels in self.validloader()]
if sample is not None: