Update benchmark script for waveflow
This commit is contained in:
parent
8b051486f1
commit
6ad45772ab
|
@ -2,13 +2,13 @@ import os
|
||||||
import random
|
import random
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
||||||
import jsonargparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle.fluid.dygraph as dg
|
import paddle.fluid.dygraph as dg
|
||||||
from paddle import fluid
|
from paddle import fluid
|
||||||
|
|
||||||
import utils
|
import utils
|
||||||
from waveflow import WaveFlow
|
from parakeet.models.waveflow import WaveFlow
|
||||||
|
|
||||||
|
|
||||||
def add_options_to_parser(parser):
|
def add_options_to_parser(parser):
|
||||||
|
@ -42,7 +42,7 @@ def add_options_to_parser(parser):
|
||||||
|
|
||||||
|
|
||||||
def benchmark(config):
|
def benchmark(config):
|
||||||
pprint(jsonargparse.namespace_to_dict(config))
|
pprint(vars(config))
|
||||||
|
|
||||||
# Get checkpoint directory path.
|
# Get checkpoint directory path.
|
||||||
run_dir = os.path.join("runs", config.model, config.name)
|
run_dir = os.path.join("runs", config.model, config.name)
|
||||||
|
@ -70,9 +70,8 @@ def benchmark(config):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Create parser.
|
# Create parser.
|
||||||
parser = jsonargparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Synthesize audio using WaveNet model",
|
description="Synthesize audio using WaveNet model")
|
||||||
formatter_class='default_argparse')
|
|
||||||
add_options_to_parser(parser)
|
add_options_to_parser(parser)
|
||||||
utils.add_config_options_to_parser(parser)
|
utils.add_config_options_to_parser(parser)
|
||||||
|
|
||||||
|
@ -80,4 +79,5 @@ if __name__ == "__main__":
|
||||||
# For conflicting updates to the same field,
|
# For conflicting updates to the same field,
|
||||||
# the preceding update will be overwritten by the following one.
|
# the preceding update will be overwritten by the following one.
|
||||||
config = parser.parse_args()
|
config = parser.parse_args()
|
||||||
|
config = utils.add_yaml_config(config)
|
||||||
benchmark(config)
|
benchmark(config)
|
||||||
|
|
Loading…
Reference in New Issue