ParakeetRebeccaRosario/examples/parallelwave_gan/baker/pwg_updater.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid.core import _cuda_synchronize
from timer import timer
from parakeet.datasets.data_table import DataTable
from parakeet.training.updater import UpdaterBase, UpdaterState
from parakeet.training.trainer import Trainer
from parakeet.training.reporter import report
from parakeet.training.checkpoint import KBest, KLatest
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss


class PWGUpdater(UpdaterBase):
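    """Updater for Parallel WaveGAN training.

    Each call to ``update_core`` draws one batch and runs a generator
    update driven by the multi-resolution STFT loss; once the iteration
    count passes ``discriminator_train_start_steps``, an adversarial term
    weighted by ``lambda_adv`` is added and a discriminator update runs
    as well.
    """
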
    def __init__(
            self,
            models,
            optimizers,
            criterions,
            schedulers,
            dataloaders,
            discriminator_train_start_steps: int,
            lambda_adv: float, ):
        self.models = models
        self.generator = models['generator']
        self.discriminator = models['discriminator']

        self.optimizers = optimizers
        self.optimizer_g = optimizers['generator']
        self.optimizer_d = optimizers['discriminator']

        self.criterions = criterions
        self.criterion_stft = criterions['stft']
        self.criterion_mse = criterions['mse']

        self.schedulers = schedulers
        self.scheduler_g = schedulers['generator']
        self.scheduler_d = schedulers['discriminator']

        self.dataloaders = dataloaders
        self.train_dataloader = dataloaders['train']
        self.dev_dataloader = dataloaders['dev']

        self.discriminator_train_start_steps = discriminator_train_start_steps
        self.lambda_adv = lambda_adv
        self.state = UpdaterState(iteration=0, epoch=0)

        self.train_iterator = iter(self.train_dataloader)

    def update_core(self):
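        # NOTE: the _cuda_synchronize calls in this method exist only to
        # make the timer readings meaningful: GPU kernels launch
        # asynchronously, so without synchronization the timers would
        # mostly measure kernel-launch overhead instead of actual work.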
        place = paddle.fluid.framework._current_expected_place()

        with timer() as t:
            _cuda_synchronize(place)
            try:
                batch = next(self.train_iterator)
            except StopIteration:
                self.train_iterator = iter(self.train_dataloader)
                batch = next(self.train_iterator)
            _cuda_synchronize(place)
        print(f"Loading a batch takes {t.elapse}s")
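        # The try/except above restarts the data iterator when it is
        # exhausted, so training is driven by iteration count, not epochs.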
        wav, mel = batch

        # Generator
        noise = paddle.randn(wav.shape)
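        # The generator maps the noise to a waveform, conditioned on the
        # mel spectrogram (which it upsamples internally to waveform rate).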
        _cuda_synchronize(place)
        with timer() as t:
            wav_ = self.generator(noise, mel)
            _cuda_synchronize(place)
        print(f"Generator takes {t.elapse}s")

        ## Multi-resolution STFT loss
        _cuda_synchronize(place)
        with timer() as t:
            sc_loss, mag_loss = self.criterion_stft(
                wav_.squeeze(1), wav.squeeze(1))
            _cuda_synchronize(place)
        print(f"Multi-resolution STFT loss takes {t.elapse}s")
        report("train/spectral_convergence_loss", float(sc_loss))
        report("train/log_stft_magnitude_loss", float(mag_loss))
        gen_loss = sc_loss + mag_loss

        ## Adversarial loss
        if self.state.iteration > self.discriminator_train_start_steps:
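            # Warm-up schedule: until discriminator_train_start_steps the
            # generator trains on the STFT loss alone; after that the
            # adversarial term is added with weight lambda_adv.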
            with timer() as t:
                _cuda_synchronize(place)
                p_ = self.discriminator(wav_)
                adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
                _cuda_synchronize(place)
            print(f"Discriminator and adversarial loss takes {t.elapse}s")
            report("train/adversarial_loss", float(adv_loss))
            gen_loss += self.lambda_adv * adv_loss

        report("train/generator_loss", float(gen_loss))

        _cuda_synchronize(place)
        with timer() as t:
            self.optimizer_g.clear_grad()
            gen_loss.backward()
            _cuda_synchronize(place)
        print(f"Backward takes {t.elapse}s.")

        _cuda_synchronize(place)
        with timer() as t:
            self.optimizer_g.step()
            self.scheduler_g.step()
            _cuda_synchronize(place)
        print(f"Update takes {t.elapse}s.")

        # Discriminator
        if self.state.iteration > self.discriminator_train_start_steps:
            # Regenerate the fake waveform under no_grad: the graph from the
            # generator pass above was freed by gen_loss.backward(), and no
            # generator gradients are needed for the discriminator step.
            with paddle.no_grad():
                wav_ = self.generator(noise, mel)
            p = self.discriminator(wav)
            p_ = self.discriminator(wav_.detach())
            real_loss = self.criterion_mse(p, paddle.ones_like(p))
            fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
            report("train/real_loss", float(real_loss))
            report("train/fake_loss", float(fake_loss))
            dis_loss = real_loss + fake_loss
            report("train/discriminator_loss", float(dis_loss))

            self.optimizer_d.clear_grad()
            dis_loss.backward()

            self.optimizer_d.step()
            self.scheduler_d.step()
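

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original training pipeline):
# one plausible way to wire up the dict-based constructor above. The
# hyperparameters, batch size, and the assumption that PWGGenerator,
# PWGDiscriminator, and MultiResolutionSTFTLoss are constructible with
# default arguments are all example choices; the real values come from the
# experiment config and train.py.
# ---------------------------------------------------------------------------
def build_updater_example(train_dataset, dev_dataset):
    """Build a PWGUpdater from datasets whose batches are (wav, mel) pairs."""
    generator = PWGGenerator()  # assumed: defaults match the Baker recipe
    discriminator = PWGDiscriminator()  # assumed: default hyperparameters

    criterion_stft = MultiResolutionSTFTLoss()
    criterion_mse = paddle.nn.MSELoss()

    # Step decay every 200k steps; the 100k discriminator warm-up and
    # lambda_adv=4.0 follow the original Parallel WaveGAN recipe, but are
    # assumptions here, not values read from this repository's config.
    scheduler_g = paddle.optimizer.lr.StepDecay(
        learning_rate=1e-4, step_size=200000, gamma=0.5)
    scheduler_d = paddle.optimizer.lr.StepDecay(
        learning_rate=5e-5, step_size=200000, gamma=0.5)
    optimizer_g = paddle.optimizer.Adam(
        learning_rate=scheduler_g, parameters=generator.parameters())
    optimizer_d = paddle.optimizer.Adam(
        learning_rate=scheduler_d, parameters=discriminator.parameters())

    train_dataloader = paddle.io.DataLoader(
        train_dataset, batch_size=8, shuffle=True, drop_last=True)
    dev_dataloader = paddle.io.DataLoader(dev_dataset, batch_size=8)

    return PWGUpdater(
        models={'generator': generator, 'discriminator': discriminator},
        optimizers={'generator': optimizer_g, 'discriminator': optimizer_d},
        criterions={'stft': criterion_stft, 'mse': criterion_mse},
        schedulers={'generator': scheduler_g, 'discriminator': scheduler_d},
        dataloaders={'train': train_dataloader, 'dev': dev_dataloader},
        discriminator_train_start_steps=100000,
        lambda_adv=4.0)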