From f423323bae209a33b7b9ee414317b5fbb43e8a5a Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Thu, 18 Feb 2021 19:09:54 +0800 Subject: [PATCH 1/2] fix bugs with multiprocess training. --- parakeet/training/experiment.py | 10 +++++----- parakeet/utils/mp_tools.py | 3 +-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/parakeet/training/experiment.py b/parakeet/training/experiment.py index 16da93d..482e012 100644 --- a/parakeet/training/experiment.py +++ b/parakeet/training/experiment.py @@ -195,17 +195,16 @@ class ExperimentBase(object): self.save() exit(-1) - @mp_tools.rank_zero_only def setup_output_dir(self): """Create a directory used for output. """ # output dir output_dir = Path(self.args.output).expanduser() - output_dir.mkdir(parents=True, exist_ok=True) - + if dist.get_rank() == 0: + output_dir.mkdir(parents=True, exist_ok=True) + self.output_dir = output_dir - @mp_tools.rank_zero_only def setup_checkpointer(self): """Create a directory used to save checkpoints into. @@ -213,7 +212,8 @@ class ExperimentBase(object): """ # checkpoint dir checkpoint_dir = self.output_dir / "checkpoints" - checkpoint_dir.mkdir(exist_ok=True) + if dist.get_rank() == 0: + checkpoint_dir.mkdir(exist_ok=True) self.checkpoint_dir = checkpoint_dir diff --git a/parakeet/utils/mp_tools.py b/parakeet/utils/mp_tools.py index a4bc97a..0b0782c 100644 --- a/parakeet/utils/mp_tools.py +++ b/parakeet/utils/mp_tools.py @@ -20,11 +20,10 @@ __all__ = ["rank_zero_only"] def rank_zero_only(func): - local_rank = dist.get_rank() @wraps(func) def wrapper(*args, **kwargs): - if local_rank != 0: + if dist.get_rank() != 0: return result = func(*args, **kwargs) return result From 0af7402daa94660d3b8002f37be6a5fcf9cce9f4 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Thu, 18 Feb 2021 19:33:41 +0800 Subject: [PATCH 2/2] add rank_zero_only for ExperimentBase.save --- parakeet/training/experiment.py | 1 + 1 file changed, 1 insertion(+) diff --git a/parakeet/training/experiment.py b/parakeet/training/experiment.py index 482e012..94caa66 100644 --- a/parakeet/training/experiment.py +++ b/parakeet/training/experiment.py @@ -124,6 +124,7 @@ class ExperimentBase(object): """ dist.init_parallel_env() + @mp_tools.rank_zero_only def save(self): """Save checkpoint (model parameters and optimizer states). """