avoid duplicated computation in validation, compute adversarial before stft loss.
This commit is contained in:
parent
3ebed00c96
commit
e41423caf0
|
@ -153,17 +153,6 @@ class PWGEvaluator(StandardEvaluator):
|
||||||
wav_ = self.generator(noise, mel)
|
wav_ = self.generator(noise, mel)
|
||||||
logging.debug(f"Generator takes {t.elapse}s")
|
logging.debug(f"Generator takes {t.elapse}s")
|
||||||
|
|
||||||
## Multi-resolution stft loss
|
|
||||||
|
|
||||||
with timer() as t:
|
|
||||||
sc_loss, mag_loss = self.criterion_stft(
|
|
||||||
wav_.squeeze(1), wav.squeeze(1))
|
|
||||||
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
|
|
||||||
|
|
||||||
report("eval/spectral_convergence_loss", float(sc_loss))
|
|
||||||
report("eval/log_stft_magnitude_loss", float(mag_loss))
|
|
||||||
gen_loss = sc_loss + mag_loss
|
|
||||||
|
|
||||||
## Adversarial loss
|
## Adversarial loss
|
||||||
with timer() as t:
|
with timer() as t:
|
||||||
p_ = self.discriminator(wav_)
|
p_ = self.discriminator(wav_)
|
||||||
|
@ -171,15 +160,22 @@ class PWGEvaluator(StandardEvaluator):
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"Discriminator and adversarial loss takes {t.elapse}s")
|
f"Discriminator and adversarial loss takes {t.elapse}s")
|
||||||
report("eval/adversarial_loss", float(adv_loss))
|
report("eval/adversarial_loss", float(adv_loss))
|
||||||
gen_loss += self.lambda_adv * adv_loss
|
gen_loss = self.lambda_adv * adv_loss
|
||||||
|
|
||||||
|
# stft loss
|
||||||
|
with timer() as t:
|
||||||
|
sc_loss, mag_loss = self.criterion_stft(
|
||||||
|
wav_.squeeze(1), wav.squeeze(1))
|
||||||
|
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
|
||||||
|
|
||||||
|
report("eval/spectral_convergence_loss", float(sc_loss))
|
||||||
|
report("eval/log_stft_magnitude_loss", float(mag_loss))
|
||||||
|
gen_loss += sc_loss + mag_loss
|
||||||
|
|
||||||
report("eval/generator_loss", float(gen_loss))
|
report("eval/generator_loss", float(gen_loss))
|
||||||
|
|
||||||
# Disctiminator
|
# Disctiminator
|
||||||
with paddle.no_grad():
|
|
||||||
wav_ = self.generator(noise, mel)
|
|
||||||
p = self.discriminator(wav)
|
p = self.discriminator(wav)
|
||||||
p_ = self.discriminator(wav_.detach())
|
|
||||||
real_loss = self.criterion_mse(p, paddle.ones_like(p))
|
real_loss = self.criterion_mse(p, paddle.ones_like(p))
|
||||||
fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
|
fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
|
||||||
report("eval/real_loss", float(real_loss))
|
report("eval/real_loss", float(real_loss))
|
||||||
|
|
Loading…
Reference in New Issue