avoid duplicated computation in validation, compute adversarial before stft loss.

This commit is contained in:
chenfeiyu 2021-07-01 16:48:34 +08:00
parent 3ebed00c96
commit e41423caf0
1 changed files with 11 additions and 15 deletions

View File

@ -153,17 +153,6 @@ class PWGEvaluator(StandardEvaluator):
wav_ = self.generator(noise, mel)
logging.debug(f"Generator takes {t.elapse}s")
## Multi-resolution stft loss
with timer() as t:
sc_loss, mag_loss = self.criterion_stft(
wav_.squeeze(1), wav.squeeze(1))
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
report("eval/spectral_convergence_loss", float(sc_loss))
report("eval/log_stft_magnitude_loss", float(mag_loss))
gen_loss = sc_loss + mag_loss
## Adversarial loss
with timer() as t:
p_ = self.discriminator(wav_)
@ -171,15 +160,22 @@ class PWGEvaluator(StandardEvaluator):
logging.debug(
f"Discriminator and adversarial loss takes {t.elapse}s")
report("eval/adversarial_loss", float(adv_loss))
gen_loss += self.lambda_adv * adv_loss
gen_loss = self.lambda_adv * adv_loss
# stft loss
with timer() as t:
sc_loss, mag_loss = self.criterion_stft(
wav_.squeeze(1), wav.squeeze(1))
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
report("eval/spectral_convergence_loss", float(sc_loss))
report("eval/log_stft_magnitude_loss", float(mag_loss))
gen_loss += sc_loss + mag_loss
report("eval/generator_loss", float(gen_loss))
# Disctiminator
with paddle.no_grad():
wav_ = self.generator(noise, mel)
p = self.discriminator(wav)
p_ = self.discriminator(wav_.detach())
real_loss = self.criterion_mse(p, paddle.ones_like(p))
fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
report("eval/real_loss", float(real_loss))