Parakeet/examples/waveflow/benchmark.py

104 lines
3.0 KiB
Python
Raw Normal View History

2020-02-26 21:03:51 +08:00
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2019-12-19 16:03:06 +08:00
import os
import random
from pprint import pprint
2020-02-24 10:51:06 +08:00
import argparse
2019-12-19 16:03:06 +08:00
import numpy as np
import paddle.fluid.dygraph as dg
from paddle import fluid
import utils
2020-03-22 16:05:05 +08:00
from parakeet.utils import io
2020-03-26 17:27:22 +08:00
from waveflow import WaveFlow
2019-12-19 16:03:06 +08:00
def add_options_to_parser(parser):
    """Register the WaveFlow benchmark command-line options on ``parser``.

    Boolean options are parsed with ``utils.str2bool``.

    Args:
        parser (argparse.ArgumentParser): parser that is extended in place.

    Returns:
        None. ``parser`` is mutated through ``add_argument`` calls.
    """
    parser.add_argument(
        '--model',
        type=str,
        default='waveflow',
        help="general name of the model")
    parser.add_argument(
        '--name', type=str, help="specific name of the training model")
    parser.add_argument(
        '--root', type=str, help="root path of the LJSpeech dataset")

    parser.add_argument(
        '--use_gpu',
        type=utils.str2bool,
        default=True,
        # This script only runs inference (the model is built with
        # training=False), so the help text must not mention training.
        help="option to use gpu for inference")
    parser.add_argument(
        '--use_fp16',
        type=utils.str2bool,
        default=True,
        help="option to use fp16 for inference")

    parser.add_argument(
        '--iteration',
        type=int,
        default=None,
        help=("which iteration of checkpoint to load, "
              "default to load the latest checkpoint"))
    parser.add_argument(
        '--checkpoint',
        type=str,
        default=None,
        help="path of the checkpoint to load")


def benchmark(config):
    """Load a WaveFlow model from its run directory and run its benchmark.

    Args:
        config (argparse.Namespace): merged command-line / yaml options.
            This function reads ``model``, ``name``, ``use_gpu`` and
            ``seed``; the whole namespace is also handed to ``WaveFlow``.

    Raises:
        ValueError: if ``config.name`` is not set, because it is needed to
            build the ``runs/<model>/<name>/checkpoint`` path.
    """
    pprint(vars(config))

    # --name has no default; fail early with a clear message instead of the
    # opaque TypeError os.path.join raises when given None.
    if config.name is None:
        raise ValueError(
            "config.name is not set; pass --name (or set it in the yaml "
            "config) so the checkpoint directory "
            "runs/<model>/<name>/checkpoint can be located")

    # Get checkpoint directory path.
    run_dir = os.path.join("runs", config.model, config.name)
    checkpoint_dir = os.path.join(run_dir, "checkpoint")

    # Configure device.
    place = fluid.CUDAPlace(0) if config.use_gpu else fluid.CPUPlace()

    with dg.guard(place):
        # Fix the random seeds of python, numpy and paddle.
        seed = config.seed
        random.seed(seed)
        np.random.seed(seed)
        fluid.default_startup_program().random_seed = seed
        fluid.default_main_program().random_seed = seed
        print("Random Seed: ", seed)

        # Build the model in inference mode (training=False).
        model = WaveFlow(config, checkpoint_dir)
        model.build(training=False)

        # Run model inference.
        model.benchmark()


if __name__ == "__main__":
    # Create parser. The description previously said "Synthesize audio
    # using WaveNet model", a copy-paste leftover: this script benchmarks
    # the WaveFlow model.
    parser = argparse.ArgumentParser(
        description="Benchmark the WaveFlow model")
    add_options_to_parser(parser)
    utils.add_config_options_to_parser(parser)

    # Parse arguments from both the command line and the yaml config file.
    # For conflicting updates to the same field, the preceding update is
    # overwritten by the following one (the yaml config is applied after
    # the command line here).
    config = parser.parse_args()
    config = io.add_yaml_config_to_args(config)

    benchmark(config)