diff --git a/examples/waveflow/train.py b/examples/waveflow/train.py index 443cc8b..c64ace6 100644 --- a/examples/waveflow/train.py +++ b/examples/waveflow/train.py @@ -46,7 +46,7 @@ class Experiment(ExperimentBase): n_mels=config.data.n_mels, kernel_size=config.model.kernel_size) - if self.parallel > 1: + if self.parallel: model = paddle.DataParallel(model) optimizer = paddle.optimizer.Adam( config.training.lr, parameters=model.parameters()) diff --git a/examples/wavenet/train.py b/examples/wavenet/train.py index 166e23d..51d000a 100644 --- a/examples/wavenet/train.py +++ b/examples/wavenet/train.py @@ -49,7 +49,7 @@ class Experiment(ExperimentBase): loss_type=config.model.loss_type, log_scale_min=config.model.log_scale_min) - if self.parallel > 1: + if self.parallel: model = paddle.DataParallel(model) lr_scheduler = paddle.optimizer.lr.StepDecay(