add --device to cli argument, use gpu by default
This commit is contained in:
parent
a93fad051c
commit
dd6772bc3e
|
@ -38,6 +38,7 @@ parser.add_argument(
|
|||
parser.add_argument("--checkpoint", type=str, help="snapshot to load")
|
||||
parser.add_argument("--test-metadata", type=str, help="dev data")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||
parser.add_argument("--device", type=str, default="gpu", help="device to run")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
@ -45,11 +46,6 @@ config = get_cfg_default()
|
|||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
|
||||
if not paddle.is_compiled_with_cuda:
|
||||
paddle.set_device("cpu")
|
||||
else:
|
||||
paddle.set_device("gpu:0")
|
||||
|
||||
print("========Args========")
|
||||
print(yaml.safe_dump(vars(args)))
|
||||
print("========Config========")
|
||||
|
@ -58,6 +54,7 @@ print(
|
|||
f"master see the word size: {dist.get_world_size()}, from pid: {os.getpid()}"
|
||||
)
|
||||
|
||||
paddle.set_device(args.device)
|
||||
generator = PWGGenerator(**config["generator_params"])
|
||||
state_dict = paddle.load(args.checkpoint)
|
||||
generator.set_state_dict(state_dict["generator_params"])
|
||||
|
|
|
@ -214,11 +214,15 @@ def main():
|
|||
parser.add_argument("--train-metadata", type=str, help="training data")
|
||||
parser.add_argument("--dev-metadata", type=str, help="dev data")
|
||||
parser.add_argument("--output-dir", type=str, help="output dir")
|
||||
parser.add_argument(
|
||||
"--device", type=str, default="gpu", help="device type to use")
|
||||
parser.add_argument(
|
||||
"--nprocs", type=int, default=1, help="number of processes")
|
||||
parser.add_argument("--verbose", type=int, default=1, help="verbose")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.device == "cpu" and args.nprocs > 1:
|
||||
raise RuntimeError("Multiprocess training on CPU is not supported.")
|
||||
config = get_cfg_default()
|
||||
if args.config:
|
||||
config.merge_from_file(args.config)
|
||||
|
|
Loading…
Reference in New Issue