Update benchmark script for waveflow
This commit is contained in:
parent
8b051486f1
commit
6ad45772ab
|
@ -2,13 +2,13 @@ import os
|
|||
import random
|
||||
from pprint import pprint
|
||||
|
||||
import jsonargparse
|
||||
import argparse
|
||||
import numpy as np
|
||||
import paddle.fluid.dygraph as dg
|
||||
from paddle import fluid
|
||||
|
||||
import utils
|
||||
from waveflow import WaveFlow
|
||||
from parakeet.models.waveflow import WaveFlow
|
||||
|
||||
|
||||
def add_options_to_parser(parser):
|
||||
|
@ -42,7 +42,7 @@ def add_options_to_parser(parser):
|
|||
|
||||
|
||||
def benchmark(config):
|
||||
pprint(jsonargparse.namespace_to_dict(config))
|
||||
pprint(vars(config))
|
||||
|
||||
# Get checkpoint directory path.
|
||||
run_dir = os.path.join("runs", config.model, config.name)
|
||||
|
@ -70,9 +70,8 @@ def benchmark(config):
|
|||
|
||||
if __name__ == "__main__":
|
||||
# Create parser.
|
||||
parser = jsonargparse.ArgumentParser(
|
||||
description="Synthesize audio using WaveNet model",
|
||||
formatter_class='default_argparse')
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Synthesize audio using WaveNet model")
|
||||
add_options_to_parser(parser)
|
||||
utils.add_config_options_to_parser(parser)
|
||||
|
||||
|
@ -80,4 +79,5 @@ if __name__ == "__main__":
|
|||
# For conflicting updates to the same field,
|
||||
# the preceding update will be overwritten by the following one.
|
||||
config = parser.parse_args()
|
||||
config = utils.add_yaml_config(config)
|
||||
benchmark(config)
|
||||
|
|
Loading…
Reference in New Issue