Update benchmark script for waveflow

This commit is contained in:
liuyibing01 2020-02-24 02:51:06 +00:00
parent 8b051486f1
commit 6ad45772ab
1 changed files with 6 additions and 6 deletions

View File

@ -2,13 +2,13 @@ import os
import random import random
from pprint import pprint from pprint import pprint
import jsonargparse import argparse
import numpy as np import numpy as np
import paddle.fluid.dygraph as dg import paddle.fluid.dygraph as dg
from paddle import fluid from paddle import fluid
import utils import utils
from waveflow import WaveFlow from parakeet.models.waveflow import WaveFlow
def add_options_to_parser(parser): def add_options_to_parser(parser):
@ -42,7 +42,7 @@ def add_options_to_parser(parser):
def benchmark(config): def benchmark(config):
pprint(jsonargparse.namespace_to_dict(config)) pprint(vars(config))
# Get checkpoint directory path. # Get checkpoint directory path.
run_dir = os.path.join("runs", config.model, config.name) run_dir = os.path.join("runs", config.model, config.name)
@ -70,9 +70,8 @@ def benchmark(config):
if __name__ == "__main__": if __name__ == "__main__":
# Create parser. # Create parser.
parser = jsonargparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Synthesize audio using WaveNet model", description="Synthesize audio using WaveNet model")
formatter_class='default_argparse')
add_options_to_parser(parser) add_options_to_parser(parser)
utils.add_config_options_to_parser(parser) utils.add_config_options_to_parser(parser)
@ -80,4 +79,5 @@ if __name__ == "__main__":
# For conflicting updates to the same field, # For conflicting updates to the same field,
# the preceding update will be overwritten by the following one. # the preceding update will be overwritten by the following one.
config = parser.parse_args() config = parser.parse_args()
config = utils.add_yaml_config(config)
benchmark(config) benchmark(config)