add profiling
This commit is contained in:
parent
95f64c4f02
commit
bbbe5a8b50
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
from timer import timer
|
||||
|
||||
from parakeet.datasets.data_table import DataTable
|
||||
from parakeet.training.updater import UpdaterBase, UpdaterState
|
||||
|
@ -60,39 +61,54 @@ class PWGUpdater(UpdaterBase):
|
|||
self.train_iterator = iter(self.train_dataloader)
|
||||
|
||||
def update_core(self):
|
||||
try:
|
||||
batch = next(self.train_iterator)
|
||||
except StopIteration:
|
||||
self.train_iterator = iter(self.train_dataloader)
|
||||
batch = next(self.train_iterator)
|
||||
with timer() as t:
|
||||
try:
|
||||
batch = next(self.train_iterator)
|
||||
except StopIteration:
|
||||
self.train_iterator = iter(self.train_dataloader)
|
||||
batch = next(self.train_iterator)
|
||||
print(f"Loading a batch takes {t.elapse}s")
|
||||
|
||||
wav, mel = batch
|
||||
|
||||
# Generator
|
||||
noise = paddle.randn(wav.shape)
|
||||
wav_ = self.generator(noise, mel)
|
||||
|
||||
with timer() as t:
|
||||
wav_ = self.generator(noise, mel)
|
||||
print(f"Generator takes {t.elapse}s")
|
||||
|
||||
## Multi-resolution stft loss
|
||||
sc_loss, mag_loss = self.criterion_stft(
|
||||
wav_.squeeze(1), wav.squeeze(1))
|
||||
with timer() as t:
|
||||
sc_loss, mag_loss = self.criterion_stft(
|
||||
wav_.squeeze(1), wav.squeeze(1))
|
||||
print(f"Multi-resolution STFT loss takes {t.elapse}s")
|
||||
|
||||
report("train/spectral_convergence_loss", float(sc_loss))
|
||||
report("train/log_stft_magnitude_loss", float(mag_loss))
|
||||
gen_loss = sc_loss + mag_loss
|
||||
|
||||
## Adversarial loss
|
||||
if self.state.iteration > self.discriminator_train_start_steps:
|
||||
p_ = self.discriminator(wav_)
|
||||
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
|
||||
with timer() as t:
|
||||
p_ = self.discriminator(wav_)
|
||||
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
|
||||
print(f"Discriminator and adversarial loss takes {t.elapse}s")
|
||||
report("train/adversarial_loss", float(adv_loss))
|
||||
gen_loss += self.lambda_adv * adv_loss
|
||||
|
||||
report("train/generator_loss", float(gen_loss))
|
||||
self.optimizer_g.clear_grad()
|
||||
gen_loss.backward()
|
||||
self.optimizer_g.step()
|
||||
self.scheduler_g.step()
|
||||
with timer() as t:
|
||||
self.optimizer_g.clear_grad()
|
||||
gen_loss.backward()
|
||||
print(f"Backward takes {t.elapse}s.")
|
||||
|
||||
# Disctiminator
|
||||
with timer() as t:
|
||||
self.optimizer_g.step()
|
||||
self.scheduler_g.step()
|
||||
print(f"Update takes {t.elapse}s.")
|
||||
|
||||
# Disctiminator
|
||||
if self.state.iteration > self.discriminator_train_start_steps:
|
||||
with paddle.no_grad():
|
||||
wav_ = self.generator(noise, mel)
|
||||
|
|
|
@ -122,13 +122,18 @@ def train_sp(args, config):
|
|||
print("criterions done!")
|
||||
|
||||
lr_schedule_g = StepDecay(**config["generator_scheduler_params"])
|
||||
gradient_clip_g = nn.ClipGradByGlobalNorm(config["generator_grad_norm"])
|
||||
optimizer_g = Adam(
|
||||
lr_schedule_g,
|
||||
learning_rate=lr_schedule_g,
|
||||
grad_clip=gradient_clip_g,
|
||||
parameters=generator.parameters(),
|
||||
**config["generator_optimizer_params"])
|
||||
lr_schedule_d = StepDecay(**config["discriminator_scheduler_params"])
|
||||
gradient_clip_d = nn.ClipGradByGlobalNorm(config[
|
||||
"discriminator_grad_norm"])
|
||||
optimizer_d = Adam(
|
||||
lr_schedule_d,
|
||||
learning_rate=lr_schedule_d,
|
||||
grad_clip=gradient_clip_d,
|
||||
parameters=discriminator.parameters(),
|
||||
**config["discriminator_optimizer_params"])
|
||||
print("optimizers done!")
|
||||
|
@ -165,9 +170,11 @@ def train_sp(args, config):
|
|||
|
||||
trainer = Trainer(
|
||||
updater,
|
||||
stop_trigger=(config.train_max_steps, "iteration"),
|
||||
stop_trigger=(10, "iteration"), # PROFILING
|
||||
out=output_dir, )
|
||||
trainer.run()
|
||||
with paddle.fluid.profiler.cuda_profiler(
|
||||
str(output_dir / "profiler.log")) as prof:
|
||||
trainer.run()
|
||||
|
||||
|
||||
def main():
|
||||
|
|
Loading…
Reference in New Issue