avoid duplicated computation in validation, compute adversarial before stft loss.
This commit is contained in:
parent
3ebed00c96
commit
e41423caf0
|
@ -153,17 +153,6 @@ class PWGEvaluator(StandardEvaluator):
|
|||
wav_ = self.generator(noise, mel)
|
||||
logging.debug(f"Generator takes {t.elapse}s")
|
||||
|
||||
## Multi-resolution stft loss
|
||||
|
||||
with timer() as t:
|
||||
sc_loss, mag_loss = self.criterion_stft(
|
||||
wav_.squeeze(1), wav.squeeze(1))
|
||||
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
|
||||
|
||||
report("eval/spectral_convergence_loss", float(sc_loss))
|
||||
report("eval/log_stft_magnitude_loss", float(mag_loss))
|
||||
gen_loss = sc_loss + mag_loss
|
||||
|
||||
## Adversarial loss
|
||||
with timer() as t:
|
||||
p_ = self.discriminator(wav_)
|
||||
|
@ -171,15 +160,22 @@ class PWGEvaluator(StandardEvaluator):
|
|||
logging.debug(
|
||||
f"Discriminator and adversarial loss takes {t.elapse}s")
|
||||
report("eval/adversarial_loss", float(adv_loss))
|
||||
gen_loss += self.lambda_adv * adv_loss
|
||||
gen_loss = self.lambda_adv * adv_loss
|
||||
|
||||
# stft loss
|
||||
with timer() as t:
|
||||
sc_loss, mag_loss = self.criterion_stft(
|
||||
wav_.squeeze(1), wav.squeeze(1))
|
||||
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
|
||||
|
||||
report("eval/spectral_convergence_loss", float(sc_loss))
|
||||
report("eval/log_stft_magnitude_loss", float(mag_loss))
|
||||
gen_loss += sc_loss + mag_loss
|
||||
|
||||
report("eval/generator_loss", float(gen_loss))
|
||||
|
||||
# Disctiminator
|
||||
with paddle.no_grad():
|
||||
wav_ = self.generator(noise, mel)
|
||||
p = self.discriminator(wav)
|
||||
p_ = self.discriminator(wav_.detach())
|
||||
real_loss = self.criterion_mse(p, paddle.ones_like(p))
|
||||
fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
|
||||
report("eval/real_loss", float(real_loss))
|
||||
|
|
Loading…
Reference in New Issue