diff --git a/examples/parallelwave_gan/baker/preprocess.py b/examples/parallelwave_gan/baker/preprocess.py index 9bf9623..23b5f05 100644 --- a/examples/parallelwave_gan/baker/preprocess.py +++ b/examples/parallelwave_gan/baker/preprocess.py @@ -16,12 +16,9 @@ from typing import List, Dict, Any import soundfile as sf import librosa import numpy as np -from config import get_cfg_default import argparse import yaml import json -import dacite -import dataclasses import concurrent.futures from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from pathlib import Path @@ -30,6 +27,8 @@ from operator import itemgetter from praatio import tgio import logging +from config import get_cfg_default + def logmelfilterbank(audio, sr, @@ -229,7 +228,7 @@ def main(): if args.verbose > 1: print(vars(args)) - print(yaml.dump(dataclasses.asdict(C))) + print(C) root_dir = Path(args.rootdir) dumpdir = Path(args.dumpdir) diff --git a/examples/parallelwave_gan/baker/pwg_updater.py b/examples/parallelwave_gan/baker/pwg_updater.py index 70507f2..3ab79b8 100644 --- a/examples/parallelwave_gan/baker/pwg_updater.py +++ b/examples/parallelwave_gan/baker/pwg_updater.py @@ -13,6 +13,7 @@ # limitations under the License. import paddle +from paddle.fluid.core import _cuda_synchronize from timer import timer from parakeet.datasets.data_table import DataTable @@ -61,12 +62,15 @@ class PWGUpdater(UpdaterBase): self.train_iterator = iter(self.train_dataloader) def update_core(self): + place = paddle.fluid.framework._current_expected_place() with timer() as t: + _cuda_synchronize(place) try: batch = next(self.train_iterator) except StopIteration: self.train_iterator = iter(self.train_dataloader) batch = next(self.train_iterator) + _cuda_synchronize(place) print(f"Loading a batch takes {t.elapse}s") wav, mel = batch @@ -75,13 +79,17 @@ class PWGUpdater(UpdaterBase): noise = paddle.randn(wav.shape) with timer() as t: + _cuda_synchronize(place) wav_ = self.generator(noise, mel) + _cuda_synchronize(place) print(f"Generator takes {t.elapse}s") ## Multi-resolution stft loss with timer() as t: + _cuda_synchronize(place) sc_loss, mag_loss = self.criterion_stft( wav_.squeeze(1), wav.squeeze(1)) + _cuda_synchronize(place) print(f"Multi-resolution STFT loss takes {t.elapse}s") report("train/spectral_convergence_loss", float(sc_loss)) @@ -91,24 +99,30 @@ class PWGUpdater(UpdaterBase): ## Adversarial loss if self.state.iteration > self.discriminator_train_start_steps: with timer() as t: + _cuda_synchronize(place) p_ = self.discriminator(wav_) adv_loss = self.criterion_mse(p_, paddle.ones_like(p_)) + _cuda_synchronize(place) print(f"Discriminator and adversarial loss takes {t.elapse}s") report("train/adversarial_loss", float(adv_loss)) gen_loss += self.lambda_adv * adv_loss report("train/generator_loss", float(gen_loss)) with timer() as t: + _cuda_synchronize(place) self.optimizer_g.clear_grad() gen_loss.backward() + _cuda_synchronize(place) print(f"Backward takes {t.elapse}s.") with timer() as t: + _cuda_synchronize(place) self.optimizer_g.step() self.scheduler_g.step() + _cuda_synchronize(place) print(f"Update takes {t.elapse}s.") -# Disctiminator + # Disctiminator if self.state.iteration > self.discriminator_train_start_steps: with paddle.no_grad(): wav_ = self.generator(noise, mel) diff --git a/examples/parallelwave_gan/baker/train.py b/examples/parallelwave_gan/baker/train.py index 2e1af14..830bc83 100644 --- a/examples/parallelwave_gan/baker/train.py +++ b/examples/parallelwave_gan/baker/train.py @@ -20,7 +20,6 @@ import dataclasses from pathlib import Path import yaml -import dacite import json import paddle import numpy as np diff --git a/parakeet/modules/stft_loss.py b/parakeet/modules/stft_loss.py index 6c5bc9f..6531010 100644 --- a/parakeet/modules/stft_loss.py +++ b/parakeet/modules/stft_loss.py @@ -64,7 +64,7 @@ class STFTLoss(nn.Layer): fft_size=1024, shift_size=120, win_length=600, - window="hann_window"): + window="hann"): """Initialize STFT loss module.""" super().__init__() self.fft_size = fft_size diff --git a/setup.py b/setup.py index eefc922..6f112cc 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,6 @@ setup_info = dict( 'scipy', 'pandas', 'sox', - # 'opencc', 'soundfile', 'g2p_en', 'yacs',