add profiling
This commit is contained in:
parent 95f64c4f02
commit bbbe5a8b50
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import paddle
+from timer import timer
 
 from parakeet.datasets.data_table import DataTable
 from parakeet.training.updater import UpdaterBase, UpdaterState
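The timer imported here is a small local helper rather than part of Paddle, so its exact interface is an assumption. A minimal stand-in that matches how it is used in the next hunk (a context manager yielding an object whose elapse attribute holds the elapsed seconds once the block exits) could look like this:

# Hypothetical stand-in for the local "timer" module imported above; the
# only assumption is that it yields an object whose elapse attribute is
# filled with the elapsed wall-clock seconds after the with-block exits.
import time
from contextlib import contextmanager


class _Clock:
    def __init__(self):
        self.elapse = None


@contextmanager
def timer():
    clock = _Clock()
    start = time.perf_counter()
    try:
        yield clock
    finally:
        clock.elapse = time.perf_counter() - start

With this interface, t.elapse is only read after the with-block closes, which matches the dedented print calls added to update_core below.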
@@ -60,39 +61,54 @@ class PWGUpdater(UpdaterBase):
         self.train_iterator = iter(self.train_dataloader)
 
     def update_core(self):
-        try:
-            batch = next(self.train_iterator)
-        except StopIteration:
-            self.train_iterator = iter(self.train_dataloader)
-            batch = next(self.train_iterator)
+        with timer() as t:
+            try:
+                batch = next(self.train_iterator)
+            except StopIteration:
+                self.train_iterator = iter(self.train_dataloader)
+                batch = next(self.train_iterator)
+        print(f"Loading a batch takes {t.elapse}s")
 
         wav, mel = batch
 
         # Generator
         noise = paddle.randn(wav.shape)
-        wav_ = self.generator(noise, mel)
+        with timer() as t:
+            wav_ = self.generator(noise, mel)
+        print(f"Generator takes {t.elapse}s")
 
         ## Multi-resolution stft loss
-        sc_loss, mag_loss = self.criterion_stft(
-            wav_.squeeze(1), wav.squeeze(1))
+        with timer() as t:
+            sc_loss, mag_loss = self.criterion_stft(
+                wav_.squeeze(1), wav.squeeze(1))
+        print(f"Multi-resolution STFT loss takes {t.elapse}s")
 
         report("train/spectral_convergence_loss", float(sc_loss))
         report("train/log_stft_magnitude_loss", float(mag_loss))
         gen_loss = sc_loss + mag_loss
 
         ## Adversarial loss
         if self.state.iteration > self.discriminator_train_start_steps:
-            p_ = self.discriminator(wav_)
-            adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
+            with timer() as t:
+                p_ = self.discriminator(wav_)
+                adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
+            print(f"Discriminator and adversarial loss takes {t.elapse}s")
             report("train/adversarial_loss", float(adv_loss))
             gen_loss += self.lambda_adv * adv_loss
 
         report("train/generator_loss", float(gen_loss))
-        self.optimizer_g.clear_grad()
-        gen_loss.backward()
-        self.optimizer_g.step()
-        self.scheduler_g.step()
+        with timer() as t:
+            self.optimizer_g.clear_grad()
+            gen_loss.backward()
+        print(f"Backward takes {t.elapse}s.")
+
+        with timer() as t:
+            self.optimizer_g.step()
+            self.scheduler_g.step()
+        print(f"Update takes {t.elapse}s.")
 
         # Disctiminator
         if self.state.iteration > self.discriminator_train_start_steps:
             with paddle.no_grad():
                 wav_ = self.generator(noise, mel)
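One caveat with these wall-clock measurements: GPU kernels are launched asynchronously, so a host-side timer around a forward pass can end up measuring little more than kernel-launch overhead. A sketch of a synchronized measurement, assuming the paddle.device.cuda.synchronize() API of Paddle 2.x is available (the layer and shapes are placeholders, not taken from the code above):

# Sketch only: synchronize the device before reading the clock so the
# measured time covers the actual GPU compute, not just kernel launches.
# Assumption: paddle.device.cuda.synchronize() exists (Paddle 2.x API).
import time

import paddle

layer = paddle.nn.Conv1D(80, 80, 3, padding=1)  # placeholder workload
x = paddle.randn([16, 80, 200])

start = time.perf_counter()
y = layer(x)
if paddle.is_compiled_with_cuda():
    paddle.device.cuda.synchronize()  # wait for pending GPU work to finish
elapse = time.perf_counter() - start
print(f"Conv1D forward takes {elapse}s")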
@@ -122,13 +122,18 @@ def train_sp(args, config):
     print("criterions done!")
 
     lr_schedule_g = StepDecay(**config["generator_scheduler_params"])
+    gradient_clip_g = nn.ClipGradByGlobalNorm(config["generator_grad_norm"])
     optimizer_g = Adam(
-        lr_schedule_g,
+        learning_rate=lr_schedule_g,
+        grad_clip=gradient_clip_g,
         parameters=generator.parameters(),
         **config["generator_optimizer_params"])
     lr_schedule_d = StepDecay(**config["discriminator_scheduler_params"])
+    gradient_clip_d = nn.ClipGradByGlobalNorm(config[
+        "discriminator_grad_norm"])
     optimizer_d = Adam(
-        lr_schedule_d,
+        learning_rate=lr_schedule_d,
+        grad_clip=gradient_clip_d,
         parameters=discriminator.parameters(),
         **config["discriminator_optimizer_params"])
     print("optimizers done!")
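The optimizer change above switches the positional learning-rate argument to keyword form and attaches global-norm gradient clipping to each optimizer. A minimal self-contained sketch of the same wiring, with made-up hyperparameter values standing in for the entries read from config:

# Sketch of the Paddle 2.x optimizer wiring used above; the numeric values
# are placeholders, not taken from the actual config file.
import paddle
from paddle import nn
from paddle.optimizer import Adam
from paddle.optimizer.lr import StepDecay

model = nn.Linear(10, 10)  # stands in for the generator/discriminator

lr_schedule = StepDecay(learning_rate=1e-4, step_size=200000, gamma=0.5)
gradient_clip = nn.ClipGradByGlobalNorm(10.0)  # e.g. generator_grad_norm
optimizer = Adam(
    learning_rate=lr_schedule,
    grad_clip=gradient_clip,
    parameters=model.parameters())

In Paddle 2.x the clip object is attached to the optimizer via grad_clip, so clipping is applied automatically on each step() rather than being done by hand in the training loop.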
@@ -165,9 +170,11 @@ def train_sp(args, config):
 
     trainer = Trainer(
         updater,
-        stop_trigger=(config.train_max_steps, "iteration"),
+        stop_trigger=(10, "iteration"), # PROFILING
         out=output_dir, )
-    trainer.run()
+    with paddle.fluid.profiler.cuda_profiler(
+            str(output_dir / "profiler.log")) as prof:
+        trainer.run()
 
 
 def main():
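The profiling run above caps training at 10 iterations (the # PROFILING comment marks the original config.train_max_steps trigger as the value to restore) and wraps trainer.run() in the fluid CUDA profiler. As a hedged alternative on the same assumed paddle.fluid.profiler module, its generic profiler prints an operator-level time summary when the context exits, which can be easier to skim than raw CUDA profiler output:

# Sketch only, assuming the legacy paddle.fluid.profiler module shown in the
# diff is available; newer Paddle releases replace it with paddle.profiler.
import paddle
import paddle.fluid.profiler as profiler

layer = paddle.nn.Conv1D(80, 80, 3, padding=1)  # placeholder workload
x = paddle.randn([16, 80, 200])

# state='All' profiles CPU and GPU; a timing table sorted by total cost is
# printed and also written to /tmp/profile when the context exits.
with profiler.profiler('All', 'total', '/tmp/profile'):
    for _ in range(10):  # mirrors the 10-iteration profiling run above
        y = layer(x)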