From a93fad051c37df8548011d5153ba9bbf854d4c2a Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Wed, 30 Jun 2021 13:50:41 +0800 Subject: [PATCH] auto choose device for inference --- examples/parallelwave_gan/baker/synthesize.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/examples/parallelwave_gan/baker/synthesize.py b/examples/parallelwave_gan/baker/synthesize.py index fd84e85..b805d60 100644 --- a/examples/parallelwave_gan/baker/synthesize.py +++ b/examples/parallelwave_gan/baker/synthesize.py @@ -26,8 +26,6 @@ import numpy as np import soundfile as sf from paddle import distributed as dist -paddle.set_device("cpu") - from parakeet.datasets.data_table import DataTable from parakeet.models.parallel_wavegan import PWGGenerator @@ -37,7 +35,7 @@ parser = argparse.ArgumentParser( description="synthesize with parallel wavegan.") parser.add_argument( "--config", type=str, help="config file to overwrite default config") -parser.add_argument("--params", type=str, help="generator parameter file") +parser.add_argument("--checkpoint", type=str, help="snapshot to load") parser.add_argument("--test-metadata", type=str, help="dev data") parser.add_argument("--output-dir", type=str, help="output dir") parser.add_argument("--verbose", type=int, default=1, help="verbose") @@ -47,6 +45,11 @@ config = get_cfg_default() if args.config: config.merge_from_file(args.config) +if not paddle.is_compiled_with_cuda: + paddle.set_device("cpu") +else: + paddle.set_device("gpu:0") + print("========Args========") print(yaml.safe_dump(vars(args))) print("========Config========") @@ -56,8 +59,8 @@ print( ) generator = PWGGenerator(**config["generator_params"]) -state_dict = paddle.load(args.params) -generator.set_state_dict(state_dict) +state_dict = paddle.load(args.checkpoint) +generator.set_state_dict(state_dict["generator_params"]) generator.remove_weight_norm() generator.eval()