add profiling

This commit is contained in:
chenfeiyu 2021-06-16 09:40:47 +00:00
parent bbbe5a8b50
commit 8dbcc9bccb
5 changed files with 19 additions and 8 deletions

View File

@ -16,12 +16,9 @@ from typing import List, Dict, Any
import soundfile as sf
import librosa
import numpy as np
from config import get_cfg_default
import argparse
import yaml
import json
import dacite
import dataclasses
import concurrent.futures
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from pathlib import Path
@ -30,6 +27,8 @@ from operator import itemgetter
from praatio import tgio
import logging
from config import get_cfg_default
def logmelfilterbank(audio,
sr,
@ -229,7 +228,7 @@ def main():
if args.verbose > 1:
print(vars(args))
print(yaml.dump(dataclasses.asdict(C)))
print(C)
root_dir = Path(args.rootdir)
dumpdir = Path(args.dumpdir)

View File

@ -13,6 +13,7 @@
# limitations under the License.
import paddle
from paddle.fluid.core import _cuda_synchronize
from timer import timer
from parakeet.datasets.data_table import DataTable
@ -61,12 +62,15 @@ class PWGUpdater(UpdaterBase):
self.train_iterator = iter(self.train_dataloader)
def update_core(self):
place = paddle.fluid.framework._current_expected_place()
with timer() as t:
_cuda_synchronize(place)
try:
batch = next(self.train_iterator)
except StopIteration:
self.train_iterator = iter(self.train_dataloader)
batch = next(self.train_iterator)
_cuda_synchronize(place)
print(f"Loading a batch takes {t.elapse}s")
wav, mel = batch
@ -75,13 +79,17 @@ class PWGUpdater(UpdaterBase):
noise = paddle.randn(wav.shape)
with timer() as t:
_cuda_synchronize(place)
wav_ = self.generator(noise, mel)
_cuda_synchronize(place)
print(f"Generator takes {t.elapse}s")
## Multi-resolution stft loss
with timer() as t:
_cuda_synchronize(place)
sc_loss, mag_loss = self.criterion_stft(
wav_.squeeze(1), wav.squeeze(1))
_cuda_synchronize(place)
print(f"Multi-resolution STFT loss takes {t.elapse}s")
report("train/spectral_convergence_loss", float(sc_loss))
@ -91,24 +99,30 @@ class PWGUpdater(UpdaterBase):
## Adversarial loss
if self.state.iteration > self.discriminator_train_start_steps:
with timer() as t:
_cuda_synchronize(place)
p_ = self.discriminator(wav_)
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
_cuda_synchronize(place)
print(f"Discriminator and adversarial loss takes {t.elapse}s")
report("train/adversarial_loss", float(adv_loss))
gen_loss += self.lambda_adv * adv_loss
report("train/generator_loss", float(gen_loss))
with timer() as t:
_cuda_synchronize(place)
self.optimizer_g.clear_grad()
gen_loss.backward()
_cuda_synchronize(place)
print(f"Backward takes {t.elapse}s.")
with timer() as t:
_cuda_synchronize(place)
self.optimizer_g.step()
self.scheduler_g.step()
_cuda_synchronize(place)
print(f"Update takes {t.elapse}s.")
# Disctiminator
# Disctiminator
if self.state.iteration > self.discriminator_train_start_steps:
with paddle.no_grad():
wav_ = self.generator(noise, mel)

View File

@ -20,7 +20,6 @@ import dataclasses
from pathlib import Path
import yaml
import dacite
import json
import paddle
import numpy as np

View File

@ -64,7 +64,7 @@ class STFTLoss(nn.Layer):
fft_size=1024,
shift_size=120,
win_length=600,
window="hann_window"):
window="hann"):
"""Initialize STFT loss module."""
super().__init__()
self.fft_size = fft_size

View File

@ -64,7 +64,6 @@ setup_info = dict(
'scipy',
'pandas',
'sox',
# 'opencc',
'soundfile',
'g2p_en',
'yacs',