Merge pull request #94 from iclementine/develop
fix bugs with multiprocess training.
This commit is contained in:
commit
a3de28cbe0
|
@ -124,6 +124,7 @@ class ExperimentBase(object):
|
|||
"""
|
||||
dist.init_parallel_env()
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
def save(self):
|
||||
"""Save checkpoint (model parameters and optimizer states).
|
||||
"""
|
||||
|
@ -195,17 +196,16 @@ class ExperimentBase(object):
|
|||
self.save()
|
||||
exit(-1)
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
def setup_output_dir(self):
|
||||
"""Create a directory used for output.
|
||||
"""
|
||||
# output dir
|
||||
output_dir = Path(self.args.output).expanduser()
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.output_dir = output_dir
|
||||
|
||||
@mp_tools.rank_zero_only
|
||||
def setup_checkpointer(self):
|
||||
"""Create a directory used to save checkpoints into.
|
||||
|
||||
|
@ -213,7 +213,8 @@ class ExperimentBase(object):
|
|||
"""
|
||||
# checkpoint dir
|
||||
checkpoint_dir = self.output_dir / "checkpoints"
|
||||
checkpoint_dir.mkdir(exist_ok=True)
|
||||
if dist.get_rank() == 0:
|
||||
checkpoint_dir.mkdir(exist_ok=True)
|
||||
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
|
||||
|
|
|
@ -20,11 +20,10 @@ __all__ = ["rank_zero_only"]
|
|||
|
||||
|
||||
def rank_zero_only(func):
|
||||
local_rank = dist.get_rank()
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if local_rank != 0:
|
||||
if dist.get_rank() != 0:
|
||||
return
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
|
|
Loading…
Reference in New Issue