auto choose device for inference
This commit is contained in:
parent
af26c1e389
commit
a93fad051c
|
@ -26,8 +26,6 @@ import numpy as np
|
|||
import soundfile as sf
|
||||
from paddle import distributed as dist
|
||||
|
||||
paddle.set_device("cpu")
|
||||
|
||||
from parakeet.datasets.data_table import DataTable
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator
|
||||
|
||||
|
@ -37,7 +35,7 @@ parser = argparse.ArgumentParser(
|
|||
description="synthesize with parallel wavegan.")
|
||||
parser.add_argument(
|
||||
"--config", type=str, help="config file to overwrite default config")
|
||||
parser.add_argument("--params", type=str, help="generator parameter file")
|
||||
parser.add_argument("--checkpoint", type=str, help="snapshot to load")
|
||||
parser.add_argument("--test-metadata", type=str, help="dev data")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
||||
|
@ -47,6 +45,11 @@ config = get_cfg_default()
|
|||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
|
||||
if not paddle.is_compiled_with_cuda:
|
||||
paddle.set_device("cpu")
|
||||
else:
|
||||
paddle.set_device("gpu:0")
|
||||
|
||||
print("========Args========")
|
||||
print(yaml.safe_dump(vars(args)))
|
||||
print("========Config========")
|
||||
|
@ -56,8 +59,8 @@ print(
|
|||
)
|
||||
|
||||
generator = PWGGenerator(**config["generator_params"])
|
||||
state_dict = paddle.load(args.params)
|
||||
generator.set_state_dict(state_dict)
|
||||
state_dict = paddle.load(args.checkpoint)
|
||||
generator.set_state_dict(state_dict["generator_params"])
|
||||
|
||||
generator.remove_weight_norm()
|
||||
generator.eval()
|
||||
|
|
Loading…
Reference in New Issue