diff --git a/examples/parallelwave_gan/baker/synthesize.py b/examples/parallelwave_gan/baker/synthesize.py index b805d60..01cfbbf 100644 --- a/examples/parallelwave_gan/baker/synthesize.py +++ b/examples/parallelwave_gan/baker/synthesize.py @@ -38,6 +38,7 @@ parser.add_argument( parser.add_argument("--checkpoint", type=str, help="snapshot to load") parser.add_argument("--test-metadata", type=str, help="dev data") parser.add_argument("--output-dir", type=str, help="output dir") +parser.add_argument("--device", type=str, default="gpu", help="device to run") parser.add_argument("--verbose", type=int, default=1, help="verbose") args = parser.parse_args() @@ -45,11 +46,6 @@ config = get_cfg_default() if args.config: config.merge_from_file(args.config) -if not paddle.is_compiled_with_cuda: - paddle.set_device("cpu") -else: - paddle.set_device("gpu:0") - print("========Args========") print(yaml.safe_dump(vars(args))) print("========Config========") @@ -58,6 +54,7 @@ print( f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}" ) +paddle.set_device(args.device) generator = PWGGenerator(**config["generator_params"]) state_dict = paddle.load(args.checkpoint) generator.set_state_dict(state_dict["generator_params"]) diff --git a/examples/parallelwave_gan/baker/train.py b/examples/parallelwave_gan/baker/train.py index af787bb..3699e6f 100644 --- a/examples/parallelwave_gan/baker/train.py +++ b/examples/parallelwave_gan/baker/train.py @@ -214,11 +214,15 @@ def main(): parser.add_argument("--train-metadata", type=str, help="training data") parser.add_argument("--dev-metadata", type=str, help="dev data") parser.add_argument("--output-dir", type=str, help="output dir") + parser.add_argument( + "--device", type=str, default="gpu", help="device type to use") parser.add_argument( "--nprocs", type=int, default=1, help="number of processes") parser.add_argument("--verbose", type=int, default=1, help="verbose") args = parser.parse_args() + if args.device == "cpu" and args.nprocs > 1: + raise RuntimeError("Multiprocess training on CPU is not supported.") config = get_cfg_default() if args.config: config.merge_from_file(args.config)