add profiling

chenfeiyu committed 2021-06-16 14:47:30 +08:00
parent 95f64c4f02
commit bbbe5a8b50
2 changed files with 42 additions and 19 deletions


@@ -13,6 +13,7 @@
 # limitations under the License.
 import paddle
+from timer import timer
 from parakeet.datasets.data_table import DataTable
 from parakeet.training.updater import UpdaterBase, UpdaterState
@@ -60,39 +61,54 @@ class PWGUpdater(UpdaterBase):
         self.train_iterator = iter(self.train_dataloader)
 
     def update_core(self):
-        try:
-            batch = next(self.train_iterator)
-        except StopIteration:
-            self.train_iterator = iter(self.train_dataloader)
-            batch = next(self.train_iterator)
+        with timer() as t:
+            try:
+                batch = next(self.train_iterator)
+            except StopIteration:
+                self.train_iterator = iter(self.train_dataloader)
+                batch = next(self.train_iterator)
+        print(f"Loading a batch takes {t.elapse}s")
         wav, mel = batch
 
         # Generator
         noise = paddle.randn(wav.shape)
-        wav_ = self.generator(noise, mel)
+        with timer() as t:
+            wav_ = self.generator(noise, mel)
+        print(f"Generator takes {t.elapse}s")
 
         ## Multi-resolution stft loss
-        sc_loss, mag_loss = self.criterion_stft(
-            wav_.squeeze(1), wav.squeeze(1))
+        with timer() as t:
+            sc_loss, mag_loss = self.criterion_stft(
+                wav_.squeeze(1), wav.squeeze(1))
+        print(f"Multi-resolution STFT loss takes {t.elapse}s")
         report("train/spectral_convergence_loss", float(sc_loss))
         report("train/log_stft_magnitude_loss", float(mag_loss))
         gen_loss = sc_loss + mag_loss
 
         ## Adversarial loss
         if self.state.iteration > self.discriminator_train_start_steps:
-            p_ = self.discriminator(wav_)
-            adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
+            with timer() as t:
+                p_ = self.discriminator(wav_)
+                adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
+            print(f"Discriminator and adversarial loss takes {t.elapse}s")
             report("train/adversarial_loss", float(adv_loss))
             gen_loss += self.lambda_adv * adv_loss
 
         report("train/generator_loss", float(gen_loss))
 
-        self.optimizer_g.clear_grad()
-        gen_loss.backward()
-        self.optimizer_g.step()
-        self.scheduler_g.step()
+        with timer() as t:
+            self.optimizer_g.clear_grad()
+            gen_loss.backward()
+        print(f"Backward takes {t.elapse}s.")
+
+        with timer() as t:
+            self.optimizer_g.step()
+            self.scheduler_g.step()
+        print(f"Update takes {t.elapse}s.")
 
         # Discriminator
         if self.state.iteration > self.discriminator_train_start_steps:
             with paddle.no_grad():
                 wav_ = self.generator(noise, mel)
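
Note: the timer helper imported above is not part of this diff, and its implementation is not shown. A minimal sketch consistent with how it is used here (a context manager whose elapse attribute holds the block's wall-clock duration in seconds once the block exits) might look like the following; the details are assumptions, not Parakeet's actual code.

import time

class timer:
    # Assumed minimal stand-in for the imported helper: times a `with`
    # block and exposes the duration (in seconds) as `elapse` on exit.
    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.elapse = time.time() - self.start
        return False  # do not suppress exceptions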


@@ -122,13 +122,18 @@ def train_sp(args, config):
     print("criterions done!")
 
     lr_schedule_g = StepDecay(**config["generator_scheduler_params"])
     gradient_clip_g = nn.ClipGradByGlobalNorm(config["generator_grad_norm"])
     optimizer_g = Adam(
-        lr_schedule_g,
+        learning_rate=lr_schedule_g,
         grad_clip=gradient_clip_g,
         parameters=generator.parameters(),
         **config["generator_optimizer_params"])
 
     lr_schedule_d = StepDecay(**config["discriminator_scheduler_params"])
     gradient_clip_d = nn.ClipGradByGlobalNorm(config[
         "discriminator_grad_norm"])
     optimizer_d = Adam(
-        lr_schedule_d,
+        learning_rate=lr_schedule_d,
         grad_clip=gradient_clip_d,
         parameters=discriminator.parameters(),
         **config["discriminator_optimizer_params"])
     print("optimizers done!")
@@ -165,9 +170,11 @@ def train_sp(args, config):
     trainer = Trainer(
         updater,
-        stop_trigger=(config.train_max_steps, "iteration"),
+        stop_trigger=(10, "iteration"),  # PROFILING
         out=output_dir, )
 
-    trainer.run()
+    with paddle.fluid.profiler.cuda_profiler(
+            str(output_dir / "profiler.log")) as prof:
+        trainer.run()
 
 
 def main():
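
Note: paddle.fluid.profiler.cuda_profiler is the legacy fluid-era context manager that starts the NVIDIA CUDA profiler around the wrapped region and writes its output to the given file; it needs a GPU build of Paddle, and newer Paddle releases replace it with the paddle.profiler module. Together with the stop_trigger cut to 10 iterations (marked # PROFILING), it bounds the profiled run to a short, fixed window. A minimal standalone sketch of the same pattern, with a toy workload standing in for trainer.run():

import paddle
import paddle.fluid.profiler as profiler

# Toy GPU workload in place of trainer.run(); requires a CUDA device.
x = paddle.randn([1024, 1024])

with profiler.cuda_profiler("profiler.log"):
    for _ in range(10):
        y = paddle.matmul(x, x)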