2019-12-13 09:58:10 +08:00
|
|
|
import os
|
|
|
|
import random
|
|
|
|
from pprint import pprint
|
|
|
|
|
2020-02-24 10:35:19 +08:00
|
|
|
import argparse
|
2019-12-13 09:58:10 +08:00
|
|
|
import numpy as np
|
|
|
|
import paddle.fluid.dygraph as dg
|
|
|
|
from paddle import fluid
|
|
|
|
|
|
|
|
import utils
|
2020-02-24 10:35:19 +08:00
|
|
|
from parakeet.models.waveflow import WaveFlow
|
2019-12-13 09:58:10 +08:00
|
|
|
|
|
|
|
|
|
|
|
def add_options_to_parser(parser):
    """Register the synthesis command line options on ``parser``.

    Options are added in a fixed order, which is also the order they appear
    in ``--help`` output.

    Args:
        parser: an ``argparse.ArgumentParser`` (or any object exposing a
            compatible ``add_argument``) that receives the options.
    """
    option_specs = [
        ("--model", dict(
            type=str,
            default="waveflow",
            help="general name of the model")),
        ("--name", dict(
            type=str,
            help="specific name of the training model")),
        ("--root", dict(
            type=str,
            help="root path of the LJSpeech dataset")),
        ("--use_gpu", dict(
            type=utils.str2bool,
            default=True,
            help="option to use gpu training")),
        ("--use_fp16", dict(
            type=utils.str2bool,
            default=True,
            help="option to use fp16 for inference")),
        ("--iteration", dict(
            type=int,
            default=None,
            help=("which iteration of checkpoint to load, "
                  "default to load the latest checkpoint"))),
        ("--checkpoint", dict(
            type=str,
            default=None,
            help="path of the checkpoint to load")),
        ("--output", dict(
            type=str,
            default="./syn_audios",
            help="path to write synthesized audio files")),
        ("--sample", dict(
            type=int,
            default=None,
            help="which of the valid samples to synthesize audio")),
    ]

    for flag, options in option_specs:
        parser.add_argument(flag, **options)
|
|
|
|
|
|
|
|
|
|
|
|
def _iteration_from_checkpoint_path(checkpoint_path):
    """Return the iteration number encoded in a checkpoint path.

    The iteration is the text after the last ``-`` in the final path
    component. The path is normalized first, so a trailing separator
    (``"ckpt/step-100/"``) or OS-native separators do not produce an empty
    component.

    Raises:
        ValueError: if that text is not an integer.
    """
    basename = os.path.basename(os.path.normpath(checkpoint_path))
    return int(basename.split('-')[-1])


def synthesize(config):
    """Restore a WaveFlow model from a checkpoint and run inference with it.

    Args:
        config: parsed option namespace (command line merged with the yaml
            config). Fields read here are ``model``, ``name``, ``use_gpu``,
            ``seed``, ``checkpoint`` and ``iteration``. The whole namespace is
            also passed to ``WaveFlow``.
    """
    pprint(vars(config))

    # Get checkpoint directory path.
    run_dir = os.path.join("runs", config.model, config.name)
    checkpoint_dir = os.path.join(run_dir, "checkpoint")

    # Configure device (always GPU 0 when use_gpu is set).
    place = fluid.CUDAPlace(0) if config.use_gpu else fluid.CPUPlace()

    with dg.guard(place):
        # Fix every random seed source so synthesis is reproducible.
        seed = config.seed
        random.seed(seed)
        np.random.seed(seed)
        fluid.default_startup_program().random_seed = seed
        fluid.default_main_program().random_seed = seed
        print("Random Seed: ", seed)

        # Build model.
        model = WaveFlow(config, checkpoint_dir)
        model.build(training=False)

        # Obtain the iteration to restore.
        # Priority: explicit checkpoint path > --iteration > latest in checkpoint_dir.
        if config.checkpoint is not None:
            iteration = _iteration_from_checkpoint_path(config.checkpoint)
        elif config.iteration is not None:
            iteration = config.iteration
        else:
            iteration = utils.load_latest_checkpoint(checkpoint_dir)

        # Run model inference.
        model.infer(iteration)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Create parser. The description names the WaveFlow model this script
    # actually builds (it previously said "WaveNet").
    parser = argparse.ArgumentParser(
        description="Synthesize audio using WaveFlow model")
    add_options_to_parser(parser)
    utils.add_config_options_to_parser(parser)

    # Parse arguments from both the command line and the yaml config file.
    # For conflicting updates to the same field,
    # the preceding update will be overwritten by the following one.
    config = parser.parse_args()
    config = utils.add_yaml_config(config)

    synthesize(config)
|