Merge pull request #94 from iclementine/develop

Fix bugs with multiprocess training.
This commit is contained in:
Feiyu Chan 2021-02-18 19:48:56 +08:00 committed by GitHub
commit a3de28cbe0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 7 deletions

View File

@ -124,6 +124,7 @@ class ExperimentBase(object):
""" """
dist.init_parallel_env() dist.init_parallel_env()
@mp_tools.rank_zero_only
def save(self): def save(self):
"""Save checkpoint (model parameters and optimizer states). """Save checkpoint (model parameters and optimizer states).
""" """
@ -195,17 +196,16 @@ class ExperimentBase(object):
self.save() self.save()
exit(-1) exit(-1)
@mp_tools.rank_zero_only
def setup_output_dir(self): def setup_output_dir(self):
"""Create a directory used for output. """Create a directory used for output.
""" """
# output dir # output dir
output_dir = Path(self.args.output).expanduser() output_dir = Path(self.args.output).expanduser()
output_dir.mkdir(parents=True, exist_ok=True) if dist.get_rank() == 0:
output_dir.mkdir(parents=True, exist_ok=True)
self.output_dir = output_dir self.output_dir = output_dir
@mp_tools.rank_zero_only
def setup_checkpointer(self): def setup_checkpointer(self):
"""Create a directory used to save checkpoints into. """Create a directory used to save checkpoints into.
@ -213,7 +213,8 @@ class ExperimentBase(object):
""" """
# checkpoint dir # checkpoint dir
checkpoint_dir = self.output_dir / "checkpoints" checkpoint_dir = self.output_dir / "checkpoints"
checkpoint_dir.mkdir(exist_ok=True) if dist.get_rank() == 0:
checkpoint_dir.mkdir(exist_ok=True)
self.checkpoint_dir = checkpoint_dir self.checkpoint_dir = checkpoint_dir

View File

@ -20,11 +20,10 @@ __all__ = ["rank_zero_only"]
def rank_zero_only(func): def rank_zero_only(func):
local_rank = dist.get_rank()
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if local_rank != 0: if dist.get_rank() != 0:
return return
result = func(*args, **kwargs) result = func(*args, **kwargs)
return result return result