add profiling
parent bbbe5a8b50
commit 8dbcc9bccb
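The profiling added in this commit wraps each training stage (batch loading, generator forward, multi-resolution STFT loss, backward pass, optimizer step) in a "with timer() as t:" block and calls _cuda_synchronize(place) immediately before and after the timed work. Since CUDA kernel launches are asynchronous, the first synchronize drains previously queued work and the second waits for the timed kernels to finish, so t.elapse reflects actual GPU execution time rather than just launch latency. A minimal sketch of the same pattern, using time.perf_counter in place of the timer package (the timed helper and its label/fn parameters are illustrative, not part of the commit):

import time
import paddle
from paddle.fluid.core import _cuda_synchronize

place = paddle.fluid.framework._current_expected_place()

def timed(label, fn, *args, **kwargs):
    # Drain previously queued kernels so their time is not attributed to fn.
    _cuda_synchronize(place)
    start = time.perf_counter()
    out = fn(*args, **kwargs)
    # Wait for fn's kernels to finish before reading the clock.
    _cuda_synchronize(place)
    print(f"{label} takes {time.perf_counter() - start}s")
    return out

# Usage analogous to the generator timing below:
#     wav_ = timed("Generator", self.generator, noise, mel)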
@@ -16,12 +16,9 @@ from typing import List, Dict, Any
 import soundfile as sf
 import librosa
 import numpy as np
-from config import get_cfg_default
 import argparse
 import yaml
 import json
-import dacite
-import dataclasses
 import concurrent.futures
 from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
 from pathlib import Path
@@ -30,6 +27,8 @@ from operator import itemgetter
 from praatio import tgio
 import logging
 
+from config import get_cfg_default
+
 
 def logmelfilterbank(audio,
                      sr,
@@ -229,7 +228,7 @@ def main():
 
     if args.verbose > 1:
         print(vars(args))
-        print(yaml.dump(dataclasses.asdict(C)))
+        print(C)
 
     root_dir = Path(args.rootdir)
     dumpdir = Path(args.dumpdir)
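An aside on the print(C) change above: if get_cfg_default() returns a yacs-style CfgNode (one plausible reading of dropping dacite/dataclasses while re-importing get_cfg_default; the config implementation is not shown in this diff), printing the node directly already produces a YAML-like dump, which would make yaml.dump(dataclasses.asdict(C)) redundant. A small sketch under that assumption, with hypothetical fields:

from yacs.config import CfgNode

_C = CfgNode()
_C.fft_size = 1024    # illustrative field, not taken from the repo
_C.hop_length = 256   # illustrative field, not taken from the repo

def get_cfg_default():
    return _C.clone()

C = get_cfg_default()
print(C)  # CfgNode's __str__ renders the config in a YAML-like form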
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import paddle
+from paddle.fluid.core import _cuda_synchronize
 from timer import timer
 
 from parakeet.datasets.data_table import DataTable
@@ -61,12 +62,15 @@ class PWGUpdater(UpdaterBase):
         self.train_iterator = iter(self.train_dataloader)
 
     def update_core(self):
+        place = paddle.fluid.framework._current_expected_place()
         with timer() as t:
+            _cuda_synchronize(place)
             try:
                 batch = next(self.train_iterator)
             except StopIteration:
                 self.train_iterator = iter(self.train_dataloader)
                 batch = next(self.train_iterator)
+            _cuda_synchronize(place)
         print(f"Loading a batch takes {t.elapse}s")
 
         wav, mel = batch
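The timer object is only used through "with timer() as t:" and t.elapse; a minimal stand-in with that interface, assuming the real timer package is a plain wall-clock context manager, could look like:

import time

class timer:
    # Wall-clock context timer matching the interface used above
    # (an assumed equivalent, not the actual timer package).
    def __enter__(self):
        self._start = time.perf_counter()
        self.elapse = None
        return self

    def __exit__(self, exc_type, exc, tb):
        self.elapse = time.perf_counter() - self._start
        return False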
@@ -75,13 +79,17 @@ class PWGUpdater(UpdaterBase):
         noise = paddle.randn(wav.shape)
 
         with timer() as t:
+            _cuda_synchronize(place)
             wav_ = self.generator(noise, mel)
+            _cuda_synchronize(place)
         print(f"Generator takes {t.elapse}s")
 
         ## Multi-resolution stft loss
         with timer() as t:
+            _cuda_synchronize(place)
             sc_loss, mag_loss = self.criterion_stft(
                 wav_.squeeze(1), wav.squeeze(1))
+            _cuda_synchronize(place)
         print(f"Multi-resolution STFT loss takes {t.elapse}s")
 
         report("train/spectral_convergence_loss", float(sc_loss))
@@ -91,24 +99,30 @@ class PWGUpdater(UpdaterBase):
         ## Adversarial loss
         if self.state.iteration > self.discriminator_train_start_steps:
             with timer() as t:
+                _cuda_synchronize(place)
                 p_ = self.discriminator(wav_)
                 adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
+                _cuda_synchronize(place)
             print(f"Discriminator and adversarial loss takes {t.elapse}s")
             report("train/adversarial_loss", float(adv_loss))
             gen_loss += self.lambda_adv * adv_loss
 
         report("train/generator_loss", float(gen_loss))
         with timer() as t:
+            _cuda_synchronize(place)
             self.optimizer_g.clear_grad()
             gen_loss.backward()
+            _cuda_synchronize(place)
         print(f"Backward takes {t.elapse}s.")
 
         with timer() as t:
+            _cuda_synchronize(place)
             self.optimizer_g.step()
             self.scheduler_g.step()
+            _cuda_synchronize(place)
         print(f"Update takes {t.elapse}s.")
 
         # Disctiminator
         if self.state.iteration > self.discriminator_train_start_steps:
             with paddle.no_grad():
                 wav_ = self.generator(noise, mel)
@@ -20,7 +20,6 @@ import dataclasses
 from pathlib import Path
 
 import yaml
-import dacite
 import json
 import paddle
 import numpy as np
@@ -64,7 +64,7 @@ class STFTLoss(nn.Layer):
                  fft_size=1024,
                  shift_size=120,
                  win_length=600,
-                 window="hann_window"):
+                 window="hann"):
         """Initialize STFT loss module."""
         super().__init__()
         self.fft_size = fft_size
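The window="hann_window" to window="hann" change above is consistent with passing the window name to scipy.signal.get_window, which expects "hann" ("hann_window" is the torch-style attribute name); this is an inference from the rename, not something stated in the diff. For example:

from scipy import signal

win_length = 600
window = signal.get_window("hann", win_length)  # "hann_window" here would raise ValueError
print(window.shape)  # (600,)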