diff --git a/examples/waveflow/train.py b/examples/waveflow/train.py index c64ace6..ba973a0 100644 --- a/examples/waveflow/train.py +++ b/examples/waveflow/train.py @@ -118,8 +118,9 @@ class Experiment(ExperimentBase): iteration_time) msg += "loss: {:>.6f}".format(loss_value) self.logger.info(msg) - self.visualizer.add_scalar( - "train/loss", loss_value, global_step=self.iteration) + if dist.get_rank() == 0: + self.visualizer.add_scalar( + "train/loss", loss_value, global_step=self.iteration) @mp_tools.rank_zero_only @paddle.no_grad() diff --git a/examples/wavenet/train.py b/examples/wavenet/train.py index b62e4a3..8a42e6f 100644 --- a/examples/wavenet/train.py +++ b/examples/wavenet/train.py @@ -131,8 +131,9 @@ class Experiment(ExperimentBase): iteration_time) msg += "loss: {:>.6f}".format(loss_value) self.logger.info(msg) - self.visualizer.add_scalar( - "train/loss", loss_value, global_step=self.iteration) + if dist.get_rank() == 0: + self.visualizer.add_scalar( + "train/loss", loss_value, global_step=self.iteration) @mp_tools.rank_zero_only @paddle.no_grad()