diff --git a/examples/parallelwave_gan/baker/pwg_updater.py b/examples/parallelwave_gan/baker/pwg_updater.py index 3ab79b8..29f313b 100644 --- a/examples/parallelwave_gan/baker/pwg_updater.py +++ b/examples/parallelwave_gan/baker/pwg_updater.py @@ -78,15 +78,15 @@ class PWGUpdater(UpdaterBase): # Generator noise = paddle.randn(wav.shape) + _cuda_synchronize(place) with timer() as t: - _cuda_synchronize(place) wav_ = self.generator(noise, mel) _cuda_synchronize(place) print(f"Generator takes {t.elapse}s") ## Multi-resolution stft loss + _cuda_synchronize(place) with timer() as t: - _cuda_synchronize(place) sc_loss, mag_loss = self.criterion_stft( wav_.squeeze(1), wav.squeeze(1)) _cuda_synchronize(place) @@ -108,15 +108,16 @@ class PWGUpdater(UpdaterBase): gen_loss += self.lambda_adv * adv_loss report("train/generator_loss", float(gen_loss)) + + _cuda_synchronize(place) with timer() as t: - _cuda_synchronize(place) self.optimizer_g.clear_grad() gen_loss.backward() _cuda_synchronize(place) print(f"Backward takes {t.elapse}s.") + _cuda_synchronize(place) with timer() as t: - _cuda_synchronize(place) self.optimizer_g.step() self.scheduler_g.step() _cuda_synchronize(place)