Update synthesis script for waveflow
This commit is contained in:
parent
7635493a0a
commit
8b051486f1
|
@ -2,13 +2,13 @@ import os
|
|||
import random
|
||||
from pprint import pprint
|
||||
|
||||
import jsonargparse
|
||||
import argparse
|
||||
import numpy as np
|
||||
import paddle.fluid.dygraph as dg
|
||||
from paddle import fluid
|
||||
|
||||
import utils
|
||||
from waveflow import WaveFlow
|
||||
from parakeet.models.waveflow import WaveFlow
|
||||
|
||||
|
||||
def add_options_to_parser(parser):
|
||||
|
@ -53,7 +53,7 @@ def add_options_to_parser(parser):
|
|||
|
||||
|
||||
def synthesize(config):
|
||||
pprint(jsonargparse.namespace_to_dict(config))
|
||||
pprint(vars(config))
|
||||
|
||||
# Get checkpoint directory path.
|
||||
run_dir = os.path.join("runs", config.model, config.name)
|
||||
|
@ -90,9 +90,8 @@ def synthesize(config):
|
|||
|
||||
if __name__ == "__main__":
|
||||
# Create parser.
|
||||
parser = jsonargparse.ArgumentParser(
|
||||
description="Synthesize audio using WaveNet model",
|
||||
formatter_class='default_argparse')
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Synthesize audio using WaveNet model")
|
||||
add_options_to_parser(parser)
|
||||
utils.add_config_options_to_parser(parser)
|
||||
|
||||
|
@ -100,4 +99,5 @@ if __name__ == "__main__":
|
|||
# For conflicting updates to the same field,
|
||||
# the preceding update will be overwritten by the following one.
|
||||
config = parser.parse_args()
|
||||
config = utils.add_yaml_config(config)
|
||||
synthesize(config)
|
||||
|
|
|
@ -84,7 +84,6 @@ def add_config_options_to_parser(parser):
|
|||
|
||||
|
||||
def add_yaml_config(config):
|
||||
print(config)
|
||||
with open(config.config, 'rt') as f:
|
||||
yaml_cfg = ruamel.yaml.safe_load(f)
|
||||
cfg_vars = vars(config)
|
||||
|
|
|
@ -152,7 +152,8 @@ class WaveFlow():
|
|||
sample = config.sample
|
||||
|
||||
output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
|
||||
os.makedirs(output, exist_ok=True)
|
||||
if not os.path.exists(output):
|
||||
os.makedirs(output)
|
||||
|
||||
mels_list = [mels for _, mels in self.validloader()]
|
||||
if sample is not None:
|
||||
|
|
Loading…
Reference in New Issue