From e53b9a0745794ea5a043c39f3778e015be86690f Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Mon, 11 Jan 2021 17:14:48 +0800 Subject: [PATCH] fix: the condition to init DataParallel --- examples/waveflow/train.py | 2 +- examples/wavenet/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/waveflow/train.py b/examples/waveflow/train.py index 443cc8b..c64ace6 100644 --- a/examples/waveflow/train.py +++ b/examples/waveflow/train.py @@ -46,7 +46,7 @@ class Experiment(ExperimentBase): n_mels=config.data.n_mels, kernel_size=config.model.kernel_size) - if self.parallel > 1: + if self.parallel: model = paddle.DataParallel(model) optimizer = paddle.optimizer.Adam( config.training.lr, parameters=model.parameters()) diff --git a/examples/wavenet/train.py b/examples/wavenet/train.py index 166e23d..51d000a 100644 --- a/examples/wavenet/train.py +++ b/examples/wavenet/train.py @@ -49,7 +49,7 @@ class Experiment(ExperimentBase): loss_type=config.model.loss_type, log_scale_min=config.model.log_scale_min) - if self.parallel > 1: + if self.parallel: model = paddle.DataParallel(model) lr_scheduler = paddle.optimizer.lr.StepDecay(