Update synthesis script for waveflow
This commit is contained in:
parent
7635493a0a
commit
8b051486f1
|
@ -2,13 +2,13 @@ import os
|
||||||
import random
|
import random
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
||||||
import jsonargparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import paddle.fluid.dygraph as dg
|
import paddle.fluid.dygraph as dg
|
||||||
from paddle import fluid
|
from paddle import fluid
|
||||||
|
|
||||||
import utils
|
import utils
|
||||||
from waveflow import WaveFlow
|
from parakeet.models.waveflow import WaveFlow
|
||||||
|
|
||||||
|
|
||||||
def add_options_to_parser(parser):
|
def add_options_to_parser(parser):
|
||||||
|
@ -53,7 +53,7 @@ def add_options_to_parser(parser):
|
||||||
|
|
||||||
|
|
||||||
def synthesize(config):
|
def synthesize(config):
|
||||||
pprint(jsonargparse.namespace_to_dict(config))
|
pprint(vars(config))
|
||||||
|
|
||||||
# Get checkpoint directory path.
|
# Get checkpoint directory path.
|
||||||
run_dir = os.path.join("runs", config.model, config.name)
|
run_dir = os.path.join("runs", config.model, config.name)
|
||||||
|
@ -90,9 +90,8 @@ def synthesize(config):
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Create parser.
|
# Create parser.
|
||||||
parser = jsonargparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Synthesize audio using WaveNet model",
|
description="Synthesize audio using WaveNet model")
|
||||||
formatter_class='default_argparse')
|
|
||||||
add_options_to_parser(parser)
|
add_options_to_parser(parser)
|
||||||
utils.add_config_options_to_parser(parser)
|
utils.add_config_options_to_parser(parser)
|
||||||
|
|
||||||
|
@ -100,4 +99,5 @@ if __name__ == "__main__":
|
||||||
# For conflicting updates to the same field,
|
# For conflicting updates to the same field,
|
||||||
# the preceding update will be overwritten by the following one.
|
# the preceding update will be overwritten by the following one.
|
||||||
config = parser.parse_args()
|
config = parser.parse_args()
|
||||||
|
config = utils.add_yaml_config(config)
|
||||||
synthesize(config)
|
synthesize(config)
|
||||||
|
|
|
@ -84,7 +84,6 @@ def add_config_options_to_parser(parser):
|
||||||
|
|
||||||
|
|
||||||
def add_yaml_config(config):
|
def add_yaml_config(config):
|
||||||
print(config)
|
|
||||||
with open(config.config, 'rt') as f:
|
with open(config.config, 'rt') as f:
|
||||||
yaml_cfg = ruamel.yaml.safe_load(f)
|
yaml_cfg = ruamel.yaml.safe_load(f)
|
||||||
cfg_vars = vars(config)
|
cfg_vars = vars(config)
|
||||||
|
|
|
@ -152,7 +152,8 @@ class WaveFlow():
|
||||||
sample = config.sample
|
sample = config.sample
|
||||||
|
|
||||||
output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
|
output = "{}/{}/iter-{}".format(config.output, config.name, iteration)
|
||||||
os.makedirs(output, exist_ok=True)
|
if not os.path.exists(output):
|
||||||
|
os.makedirs(output)
|
||||||
|
|
||||||
mels_list = [mels for _, mels in self.validloader()]
|
mels_list = [mels for _, mels in self.validloader()]
|
||||||
if sample is not None:
|
if sample is not None:
|
||||||
|
|
Loading…
Reference in New Issue