diff --git a/examples/parallelwave_gan/baker/pwg_updater.py b/examples/parallelwave_gan/baker/pwg_updater.py index 838416a..dde7773 100644 --- a/examples/parallelwave_gan/baker/pwg_updater.py +++ b/examples/parallelwave_gan/baker/pwg_updater.py @@ -153,17 +153,6 @@ class PWGEvaluator(StandardEvaluator): wav_ = self.generator(noise, mel) logging.debug(f"Generator takes {t.elapse}s") - ## Multi-resolution stft loss - - with timer() as t: - sc_loss, mag_loss = self.criterion_stft( - wav_.squeeze(1), wav.squeeze(1)) - logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s") - - report("eval/spectral_convergence_loss", float(sc_loss)) - report("eval/log_stft_magnitude_loss", float(mag_loss)) - gen_loss = sc_loss + mag_loss - ## Adversarial loss with timer() as t: p_ = self.discriminator(wav_) @@ -171,15 +160,22 @@ class PWGEvaluator(StandardEvaluator): logging.debug( f"Discriminator and adversarial loss takes {t.elapse}s") report("eval/adversarial_loss", float(adv_loss)) - gen_loss += self.lambda_adv * adv_loss + gen_loss = self.lambda_adv * adv_loss + + # stft loss + with timer() as t: + sc_loss, mag_loss = self.criterion_stft( + wav_.squeeze(1), wav.squeeze(1)) + logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s") + + report("eval/spectral_convergence_loss", float(sc_loss)) + report("eval/log_stft_magnitude_loss", float(mag_loss)) + gen_loss += sc_loss + mag_loss report("eval/generator_loss", float(gen_loss)) # Disctiminator - with paddle.no_grad(): - wav_ = self.generator(noise, mel) p = self.discriminator(wav) - p_ = self.discriminator(wav_.detach()) real_loss = self.criterion_mse(p, paddle.ones_like(p)) fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_)) report("eval/real_loss", float(real_loss))