auto choose device for inference

This commit is contained in:
chenfeiyu 2021-06-30 13:50:41 +08:00
parent af26c1e389
commit a93fad051c
1 changed files with 8 additions and 5 deletions

View File

@ -26,8 +26,6 @@ import numpy as np
import soundfile as sf
from paddle import distributed as dist
paddle.set_device("cpu")
from parakeet.datasets.data_table import DataTable
from parakeet.models.parallel_wavegan import PWGGenerator
@ -37,7 +35,7 @@ parser = argparse.ArgumentParser(
description="synthesize with parallel wavegan.")
parser.add_argument(
"--config", type=str, help="config file to overwrite default config")
parser.add_argument("--params", type=str, help="generator parameter file")
parser.add_argument("--checkpoint", type=str, help="snapshot to load")
parser.add_argument("--test-metadata", type=str, help="dev data")
parser.add_argument("--output-dir", type=str, help="output dir")
parser.add_argument("--verbose", type=int, default=1, help="verbose")
@ -47,6 +45,11 @@ config = get_cfg_default()
if args.config:
config.merge_from_file(args.config)
if not paddle.is_compiled_with_cuda:
paddle.set_device("cpu")
else:
paddle.set_device("gpu:0")
print("========Args========")
print(yaml.safe_dump(vars(args)))
print("========Config========")
@ -56,8 +59,8 @@ print(
)
generator = PWGGenerator(**config["generator_params"])
state_dict = paddle.load(args.params)
generator.set_state_dict(state_dict)
state_dict = paddle.load(args.checkpoint)
generator.set_state_dict(state_dict["generator_params"])
generator.remove_weight_norm()
generator.eval()