Merge pull request #94 from iclementine/develop
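Fix bugs with multiprocess training: `rank_zero_only` now calls `dist.get_rank()` inside the wrapper (at call time) instead of once when the function is decorated; `save` is now decorated with `rank_zero_only`; and `setup_output_dir` / `setup_checkpointer` drop the decorator in favor of an explicit `if dist.get_rank() == 0:` check before creating the directory.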
This commit is contained in:
commit
a3de28cbe0
|
@@ -124,6 +124,7 @@ class ExperimentBase(object):
|
||||||
"""
|
"""
|
||||||
dist.init_parallel_env()
|
dist.init_parallel_env()
|
||||||
|
|
||||||
|
@mp_tools.rank_zero_only
|
||||||
def save(self):
|
def save(self):
|
||||||
"""Save checkpoint (model parameters and optimizer states).
|
"""Save checkpoint (model parameters and optimizer states).
|
||||||
"""
|
"""
|
||||||
|
@@ -195,17 +196,16 @@ class ExperimentBase(object):
|
||||||
self.save()
|
self.save()
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
@mp_tools.rank_zero_only
|
|
||||||
def setup_output_dir(self):
|
def setup_output_dir(self):
|
||||||
"""Create a directory used for output.
|
"""Create a directory used for output.
|
||||||
"""
|
"""
|
||||||
# output dir
|
# output dir
|
||||||
output_dir = Path(self.args.output).expanduser()
|
output_dir = Path(self.args.output).expanduser()
|
||||||
|
if dist.get_rank() == 0:
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
self.output_dir = output_dir
|
self.output_dir = output_dir
|
||||||
|
|
||||||
@mp_tools.rank_zero_only
|
|
||||||
def setup_checkpointer(self):
|
def setup_checkpointer(self):
|
||||||
"""Create a directory used to save checkpoints into.
|
"""Create a directory used to save checkpoints into.
|
||||||
|
|
||||||
|
@@ -213,6 +213,7 @@ class ExperimentBase(object):
|
||||||
"""
|
"""
|
||||||
# checkpoint dir
|
# checkpoint dir
|
||||||
checkpoint_dir = self.output_dir / "checkpoints"
|
checkpoint_dir = self.output_dir / "checkpoints"
|
||||||
|
if dist.get_rank() == 0:
|
||||||
checkpoint_dir.mkdir(exist_ok=True)
|
checkpoint_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
self.checkpoint_dir = checkpoint_dir
|
self.checkpoint_dir = checkpoint_dir
|
||||||
|
|
|
@@ -20,11 +20,10 @@ __all__ = ["rank_zero_only"]
|
||||||
|
|
||||||
|
|
||||||
def rank_zero_only(func):
|
def rank_zero_only(func):
|
||||||
local_rank = dist.get_rank()
|
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
if local_rank != 0:
|
if dist.get_rank() != 0:
|
||||||
return
|
return
|
||||||
result = func(*args, **kwargs)
|
result = func(*args, **kwargs)
|
||||||
return result
|
return result
|
||||||
|
|
Loading…
Reference in New Issue