add --device to cli argument, use gpu by default

This commit is contained in:
chenfeiyu 2021-06-30 13:57:05 +08:00
parent a93fad051c
commit dd6772bc3e
2 changed files with 6 additions and 5 deletions

View File

@ -38,6 +38,7 @@ parser.add_argument(
parser.add_argument("--checkpoint", type=str, help="snapshot to load") parser.add_argument("--checkpoint", type=str, help="snapshot to load")
parser.add_argument("--test-metadata", type=str, help="dev data") parser.add_argument("--test-metadata", type=str, help="dev data")
parser.add_argument("--output-dir", type=str, help="output dir") parser.add_argument("--output-dir", type=str, help="output dir")
parser.add_argument("--device", type=str, default="gpu", help="device to run")
parser.add_argument("--verbose", type=int, default=1, help="verbose") parser.add_argument("--verbose", type=int, default=1, help="verbose")
args = parser.parse_args() args = parser.parse_args()
@ -45,11 +46,6 @@ config = get_cfg_default()
if args.config: if args.config:
config.merge_from_file(args.config) config.merge_from_file(args.config)
if not paddle.is_compiled_with_cuda:
paddle.set_device("cpu")
else:
paddle.set_device("gpu:0")
print("========Args========") print("========Args========")
print(yaml.safe_dump(vars(args))) print(yaml.safe_dump(vars(args)))
print("========Config========") print("========Config========")
@ -58,6 +54,7 @@ print(
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}" f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
) )
paddle.set_device(args.device)
generator = PWGGenerator(**config["generator_params"]) generator = PWGGenerator(**config["generator_params"])
state_dict = paddle.load(args.checkpoint) state_dict = paddle.load(args.checkpoint)
generator.set_state_dict(state_dict["generator_params"]) generator.set_state_dict(state_dict["generator_params"])

View File

@ -214,11 +214,15 @@ def main():
parser.add_argument("--train-metadata", type=str, help="training data") parser.add_argument("--train-metadata", type=str, help="training data")
parser.add_argument("--dev-metadata", type=str, help="dev data") parser.add_argument("--dev-metadata", type=str, help="dev data")
parser.add_argument("--output-dir", type=str, help="output dir") parser.add_argument("--output-dir", type=str, help="output dir")
parser.add_argument(
"--device", type=str, default="gpu", help="device type to use")
parser.add_argument( parser.add_argument(
"--nprocs", type=int, default=1, help="number of processes") "--nprocs", type=int, default=1, help="number of processes")
parser.add_argument("--verbose", type=int, default=1, help="verbose") parser.add_argument("--verbose", type=int, default=1, help="verbose")
args = parser.parse_args() args = parser.parse_args()
if args.device == "cpu" and args.nprocs > 1:
raise RuntimeError("Multiprocess training on CPU is not supported.")
config = get_cfg_default() config = get_cfg_default()
if args.config: if args.config:
config.merge_from_file(args.config) config.merge_from_file(args.config)