diff --git a/examples/waveflow/synthesis.py b/examples/waveflow/synthesis.py index 3ec0da7..1e3fb9e 100644 --- a/examples/waveflow/synthesis.py +++ b/examples/waveflow/synthesis.py @@ -2,13 +2,13 @@ import os import random from pprint import pprint -import jsonargparse +import argparse import numpy as np import paddle.fluid.dygraph as dg from paddle import fluid import utils -from waveflow import WaveFlow +from parakeet.models.waveflow import WaveFlow def add_options_to_parser(parser): @@ -53,7 +53,7 @@ def add_options_to_parser(parser): def synthesize(config): - pprint(jsonargparse.namespace_to_dict(config)) + pprint(vars(config)) # Get checkpoint directory path. run_dir = os.path.join("runs", config.model, config.name) @@ -90,9 +90,8 @@ def synthesize(config): if __name__ == "__main__": # Create parser. - parser = jsonargparse.ArgumentParser( - description="Synthesize audio using WaveNet model", - formatter_class='default_argparse') + parser = argparse.ArgumentParser( + description="Synthesize audio using WaveNet model") add_options_to_parser(parser) utils.add_config_options_to_parser(parser) @@ -100,4 +99,5 @@ if __name__ == "__main__": # For conflicting updates to the same field, # the preceding update will be overwritten by the following one. config = parser.parse_args() + config = utils.add_yaml_config(config) synthesize(config) diff --git a/examples/waveflow/utils.py b/examples/waveflow/utils.py index e8ed372..c088b1d 100644 --- a/examples/waveflow/utils.py +++ b/examples/waveflow/utils.py @@ -84,7 +84,6 @@ def add_config_options_to_parser(parser): def add_yaml_config(config): - print(config) with open(config.config, 'rt') as f: yaml_cfg = ruamel.yaml.safe_load(f) cfg_vars = vars(config) diff --git a/parakeet/models/waveflow/waveflow.py b/parakeet/models/waveflow/waveflow.py index 351b287..569086e 100644 --- a/parakeet/models/waveflow/waveflow.py +++ b/parakeet/models/waveflow/waveflow.py @@ -152,7 +152,8 @@ class WaveFlow(): sample = config.sample output = "{}/{}/iter-{}".format(config.output, config.name, iteration) - os.makedirs(output, exist_ok=True) + if not os.path.exists(output): + os.makedirs(output) mels_list = [mels for _, mels in self.validloader()] if sample is not None: