ParakeetRebeccaRosario/examples/parallelwave_gan/baker/pwg_updater.py

232 lines
8.4 KiB
Python

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddle.optimizer.lr import LRScheduler
from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.training.reporter import report
from parakeet.training.updaters.standard_updater import StandardUpdater
from parakeet.training.updaters.standard_updater import UpdaterState
from timer import timer
logging.basicConfig(
format='%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s',
datefmt='[%Y-%m-%d %H:%M:%S]')
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class PWGUpdater(StandardUpdater):
def __init__(self,
models: Dict[str, Layer],
optimizers: Dict[str, Optimizer],
criterions: Dict[str, Layer],
schedulers: Dict[str, LRScheduler],
dataloader: DataLoader,
discriminator_train_start_steps: int,
lambda_adv: float,
output_dir=None):
self.models = models
self.generator: Layer = models['generator']
self.discriminator: Layer = models['discriminator']
self.optimizers = optimizers
self.optimizer_g: Optimizer = optimizers['generator']
self.optimizer_d: Optimizer = optimizers['discriminator']
self.criterions = criterions
self.criterion_stft = criterions['stft']
self.criterion_mse = criterions['mse']
self.schedulers = schedulers
self.scheduler_g = schedulers['generator']
self.scheduler_d = schedulers['discriminator']
self.dataloader = dataloader
self.discriminator_train_start_steps = discriminator_train_start_steps
self.lambda_adv = lambda_adv
self.state = UpdaterState(iteration=0, epoch=0)
self.train_iterator = iter(self.dataloader)
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def update_core(self, batch):
self.msg = "Rank: {}, ".format(dist.get_rank())
losses_dict = {}
# parse batch
wav, mel = batch
# Generator
noise = paddle.randn(wav.shape)
with timer() as t:
wav_ = self.generator(noise, mel)
# logging.debug(f"Generator takes {t.elapse}s.")
# initialize
gen_loss = 0.0
## Multi-resolution stft loss
with timer() as t:
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
# logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")
report("train/spectral_convergence_loss", float(sc_loss))
report("train/log_stft_magnitude_loss", float(mag_loss))
losses_dict["spectral_convergence_loss"] = float(sc_loss)
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
gen_loss += sc_loss + mag_loss
## Adversarial loss
if self.state.iteration > self.discriminator_train_start_steps:
with timer() as t:
p_ = self.discriminator(wav_)
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
# logging.debug(
# f"Discriminator and adversarial loss takes {t.elapse}s")
report("train/adversarial_loss", float(adv_loss))
losses_dict["adversarial_loss"] = float(adv_loss)
gen_loss += self.lambda_adv * adv_loss
report("train/generator_loss", float(gen_loss))
losses_dict["generator_loss"] = float(gen_loss)
with timer() as t:
self.optimizer_g.clear_grad()
gen_loss.backward()
# logging.debug(f"Backward takes {t.elapse}s.")
with timer() as t:
self.optimizer_g.step()
self.scheduler_g.step()
# logging.debug(f"Update takes {t.elapse}s.")
# Disctiminator
if self.state.iteration > self.discriminator_train_start_steps:
with paddle.no_grad():
wav_ = self.generator(noise, mel)
p = self.discriminator(wav)
p_ = self.discriminator(wav_.detach())
real_loss = self.criterion_mse(p, paddle.ones_like(p))
fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
dis_loss = real_loss + fake_loss
report("train/real_loss", float(real_loss))
report("train/fake_loss", float(fake_loss))
report("train/discriminator_loss", float(dis_loss))
losses_dict["real_loss"] = float(real_loss)
losses_dict["fake_loss"] = float(fake_loss)
losses_dict["discriminator_loss"] = float(dis_loss)
self.optimizer_d.clear_grad()
dis_loss.backward()
self.optimizer_d.step()
self.scheduler_d.step()
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
class PWGEvaluator(StandardEvaluator):
def __init__(self,
models,
criterions,
dataloader,
lambda_adv,
output_dir=None):
self.models = models
self.generator = models['generator']
self.discriminator = models['discriminator']
self.criterions = criterions
self.criterion_stft = criterions['stft']
self.criterion_mse = criterions['mse']
self.dataloader = dataloader
self.lambda_adv = lambda_adv
log_file = output_dir / 'worker_{}.log'.format(dist.get_rank())
self.filehandler = logging.FileHandler(str(log_file))
logger.addHandler(self.filehandler)
self.logger = logger
self.msg = ""
def evaluate_core(self, batch):
# logging.debug("Evaluate: ")
self.msg = "Evaluate: "
losses_dict = {}
wav, mel = batch
noise = paddle.randn(wav.shape)
with timer() as t:
wav_ = self.generator(noise, mel)
# logging.debug(f"Generator takes {t.elapse}s")
## Adversarial loss
with timer() as t:
p_ = self.discriminator(wav_)
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
# logging.debug(
# f"Discriminator and adversarial loss takes {t.elapse}s")
report("eval/adversarial_loss", float(adv_loss))
losses_dict["adversarial_loss"] = float(adv_loss)
gen_loss = self.lambda_adv * adv_loss
# stft loss
with timer() as t:
sc_loss, mag_loss = self.criterion_stft(wav_, wav)
# logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
report("eval/spectral_convergence_loss", float(sc_loss))
report("eval/log_stft_magnitude_loss", float(mag_loss))
losses_dict["spectral_convergence_loss"] = float(sc_loss)
losses_dict["log_stft_magnitude_loss"] = float(mag_loss)
gen_loss += sc_loss + mag_loss
report("eval/generator_loss", float(gen_loss))
losses_dict["generator_loss"] = float(gen_loss)
# Disctiminator
p = self.discriminator(wav)
real_loss = self.criterion_mse(p, paddle.ones_like(p))
fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
dis_loss = real_loss + fake_loss
report("eval/real_loss", float(real_loss))
report("eval/fake_loss", float(fake_loss))
report("eval/discriminator_loss", float(dis_loss))
losses_dict["real_loss"] = float(real_loss)
losses_dict["fake_loss"] = float(fake_loss)
losses_dict["discriminator_loss"] = float(dis_loss)
self.msg += ', '.join('{}: {:>.6f}'.format(k, v)
for k, v in losses_dict.items())
self.logger.info(self.msg)