From ef51e1ab13e2bf0c146fab59200bd1a83070ed1b Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Wed, 30 Jun 2021 12:30:14 +0800 Subject: [PATCH] refined training module --- .../parallelwave_gan/baker/conf/default.yaml | 3 +- .../parallelwave_gan/baker/pwg_updater.py | 18 +----- examples/parallelwave_gan/baker/synthesize.py | 20 ++++-- examples/parallelwave_gan/baker/train.py | 4 +- parakeet/__init__.py | 4 +- parakeet/modules/stft_loss.py | 7 +- parakeet/training/extensions/snapshot.py | 3 +- parakeet/training/trainer.py | 21 +++--- parakeet/training/trigger.py | 2 + .../training/triggers/interval_trigger.py | 12 ++-- parakeet/training/triggers/limit_trigger.py | 31 +++++++++ .../training/updaters/standard_updater.py | 64 +++++++++++++++---- tests/test_optimizer.py | 39 +++++++++++ 13 files changed, 169 insertions(+), 59 deletions(-) create mode 100644 parakeet/training/triggers/limit_trigger.py create mode 100644 tests/test_optimizer.py diff --git a/examples/parallelwave_gan/baker/conf/default.yaml b/examples/parallelwave_gan/baker/conf/default.yaml index ce5b064..877be2c 100644 --- a/examples/parallelwave_gan/baker/conf/default.yaml +++ b/examples/parallelwave_gan/baker/conf/default.yaml @@ -125,4 +125,5 @@ log_interval_steps: 100 # Interval steps to record the training # OTHER SETTING # ########################################################### num_save_intermediate_results: 4 # Number of results to be saved as intermediate results. -num_snapshots: 10 \ No newline at end of file +num_snapshots: 10 +seed: 42 \ No newline at end of file diff --git a/examples/parallelwave_gan/baker/pwg_updater.py b/examples/parallelwave_gan/baker/pwg_updater.py index f7ff916..838416a 100644 --- a/examples/parallelwave_gan/baker/pwg_updater.py +++ b/examples/parallelwave_gan/baker/pwg_updater.py @@ -74,18 +74,15 @@ class PWGUpdater(StandardUpdater): # Generator noise = paddle.randn(wav.shape) - synchronize() with timer() as t: wav_ = self.generator(noise, mel) - synchronize() logging.debug(f"Generator takes {t.elapse}s.") ## Multi-resolution stft loss - synchronize() + with timer() as t: sc_loss, mag_loss = self.criterion_stft( wav_.squeeze(1), wav.squeeze(1)) - synchronize() logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.") report("train/spectral_convergence_loss", float(sc_loss)) @@ -94,11 +91,9 @@ class PWGUpdater(StandardUpdater): ## Adversarial loss if self.state.iteration > self.discriminator_train_start_steps: - synchronize() with timer() as t: p_ = self.discriminator(wav_) adv_loss = self.criterion_mse(p_, paddle.ones_like(p_)) - synchronize() logging.debug( f"Discriminator and adversarial loss takes {t.elapse}s") report("train/adversarial_loss", float(adv_loss)) @@ -106,18 +101,14 @@ class PWGUpdater(StandardUpdater): report("train/generator_loss", float(gen_loss)) - synchronize() with timer() as t: self.optimizer_g.clear_grad() gen_loss.backward() - synchronize() logging.debug(f"Backward takes {t.elapse}s.") - synchronize() with timer() as t: self.optimizer_g.step() self.scheduler_g.step() - synchronize() logging.debug(f"Update takes {t.elapse}s.") # Disctiminator @@ -158,18 +149,15 @@ class PWGEvaluator(StandardEvaluator): wav, mel = batch noise = paddle.randn(wav.shape) - synchronize() with timer() as t: wav_ = self.generator(noise, mel) - synchronize() logging.debug(f"Generator takes {t.elapse}s") ## Multi-resolution stft loss - synchronize() + with timer() as t: sc_loss, mag_loss = self.criterion_stft( wav_.squeeze(1), wav.squeeze(1)) - synchronize() 
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s") report("eval/spectral_convergence_loss", float(sc_loss)) @@ -177,11 +165,9 @@ class PWGEvaluator(StandardEvaluator): gen_loss = sc_loss + mag_loss ## Adversarial loss - synchronize() with timer() as t: p_ = self.discriminator(wav_) adv_loss = self.criterion_mse(p_, paddle.ones_like(p_)) - synchronize() logging.debug( f"Discriminator and adversarial loss takes {t.elapse}s") report("eval/adversarial_loss", float(adv_loss)) diff --git a/examples/parallelwave_gan/baker/synthesize.py b/examples/parallelwave_gan/baker/synthesize.py index 4c4e754..fd84e85 100644 --- a/examples/parallelwave_gan/baker/synthesize.py +++ b/examples/parallelwave_gan/baker/synthesize.py @@ -14,6 +14,7 @@ import os import sys +from timer import timer import logging import argparse from pathlib import Path @@ -25,6 +26,8 @@ import numpy as np import soundfile as sf from paddle import distributed as dist +paddle.set_device("cpu") + from parakeet.datasets.data_table import DataTable from parakeet.models.parallel_wavegan import PWGGenerator @@ -71,11 +74,20 @@ test_dataset = DataTable( output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) +N = 0 +T = 0 for example in test_dataset: utt_id = example['utt_id'] mel = example['feats'] mel = paddle.to_tensor(mel) # (T, C) - wav = generator.inference(c=mel) - wav = wav.numpy() - print(f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}") - sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=24000) + with timer() as t: + wav = generator.inference(c=mel) + wav = wav.numpy() + N += wav.size + T += t.elapse + speed = wav.size / t.elapse + print( + f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {config.sr / speed}." + ) + sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=config.sr) +print(f"generation speed: {N / T}Hz, RTF: {config.sr / (N / T) }") diff --git a/examples/parallelwave_gan/baker/train.py b/examples/parallelwave_gan/baker/train.py index bf8767a..ee012e2 100644 --- a/examples/parallelwave_gan/baker/train.py +++ b/examples/parallelwave_gan/baker/train.py @@ -60,7 +60,7 @@ def train_sp(args, config): paddle.distributed.init_parallel_env() # set the random seed, it is a must for multiprocess training - seed_everything(42) + seed_everything(config.seed) print( f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}", @@ -149,6 +149,8 @@ def train_sp(args, config): output_dir = Path(args.output_dir) checkpoint_dir = output_dir / "checkpoints" if dist.get_rank() == 0: + with open(output_dir / "config.yaml", 'wt') as f: + f.write(config.dump(default_flow_style=None)) output_dir.mkdir(parents=True, exist_ok=True) checkpoint_dir.mkdir(parents=True, exist_ok=True) diff --git a/parakeet/__init__.py b/parakeet/__init__.py index cce4fff..f08f907 100644 --- a/parakeet/__init__.py +++ b/parakeet/__init__.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-__version__ = "0.2.0-beta.0" +__version__ = "0.0.0" import logging from parakeet import audio, data, datasets, frontend, models, modules, training, utils - -logging.getLogger('parakeet').addHandler(logging.NullHandler()) diff --git a/parakeet/modules/stft_loss.py b/parakeet/modules/stft_loss.py index f98a3dd..cdc066f 100644 --- a/parakeet/modules/stft_loss.py +++ b/parakeet/modules/stft_loss.py @@ -43,9 +43,10 @@ class SpectralConvergenceLoss(nn.Layer): class LogSTFTMagnitudeLoss(nn.Layer): """Log STFT magnitude loss module.""" - def __init__(self): + def __init__(self, epsilon=1e-10): """Initilize los STFT magnitude loss module.""" super().__init__() + self.epsilon = epsilon def forward(self, x_mag, y_mag): """Calculate forward propagation. @@ -57,9 +58,9 @@ class LogSTFTMagnitudeLoss(nn.Layer): """ return F.l1_loss( paddle.log(paddle.clip( - y_mag, min=1e-10)), + y_mag, min=self.epsilon)), paddle.log(paddle.clip( - x_mag, min=1e-10))) + x_mag, min=self.epsilon))) class STFTLoss(nn.Layer): diff --git a/parakeet/training/extensions/snapshot.py b/parakeet/training/extensions/snapshot.py index a209524..92d74ef 100644 --- a/parakeet/training/extensions/snapshot.py +++ b/parakeet/training/extensions/snapshot.py @@ -106,4 +106,5 @@ class Snapshot(extension.Extension): record_path = self.checkpoint_dir / "records.jsonl" with jsonlines.open(record_path, 'w') as writer: for record in self.records: - writer.write(record) + # jsonlines.open may return a Writer or a Reader + writer.write(record) # pylint: disable=no-member diff --git a/parakeet/training/trainer.py b/parakeet/training/trainer.py index 9e90d9a..484845f 100644 --- a/parakeet/training/trainer.py +++ b/parakeet/training/trainer.py @@ -21,7 +21,7 @@ from typing import Callable, Union, List import tqdm -from parakeet.training.trigger import get_trigger, IntervalTrigger +from parakeet.training.trigger import get_trigger, IntervalTrigger, LimitTrigger from parakeet.training.updater import UpdaterBase from parakeet.training.reporter import scope from parakeet.training.extension import Extension, PRIORITY_READER @@ -42,7 +42,7 @@ class Trainer(object): extensions: List[Extension]=None): self.updater = updater self.extensions = OrderedDict() - self.stop_trigger = get_trigger(stop_trigger) + self.stop_trigger = LimitTrigger(*stop_trigger) self.out = Path(out) self.observation =... @@ -125,16 +125,19 @@ class Trainer(object): print(self.updater.state) - # TODO(chenfeiyu): display progress bar correctly - # if the trainer is controlled by epoch: use 2 progressbars - # if the trainer is controlled by iteration: use 1 progressbar - if isinstance(stop_trigger, IntervalTrigger): + # display only one progress bar + max_iteration = None + if isinstance(stop_trigger, LimitTrigger): if stop_trigger.unit is 'epoch': - max_epoch = self.stop_trigger.period + max_epoch = self.stop_trigger.limit + updates_per_epoch = getattr(self.updater, "updates_per_epoch", + None) + max_iteration = max_epoch * updates_per_epoch if updates_per_epoch else None else: - max_iteration = self.stop_trigger.period + max_iteration = self.stop_trigger.limit - p = tqdm.tqdm(initial=self.updater.state.iteration) + p = tqdm.tqdm( + initial=self.updater.state.iteration, total=max_iteration) try: while not stop_trigger(self): diff --git a/parakeet/training/trigger.py b/parakeet/training/trigger.py index f5834bc..b588512 100644 --- a/parakeet/training/trigger.py +++ b/parakeet/training/trigger.py @@ -13,6 +13,8 @@ # limitations under the License. 
 from parakeet.training.triggers.interval_trigger import IntervalTrigger
+from parakeet.training.triggers.limit_trigger import LimitTrigger
+from parakeet.training.triggers.time_trigger import TimeTrigger
 
 
 def never_file_trigger(trainer):
diff --git a/parakeet/training/triggers/interval_trigger.py b/parakeet/training/triggers/interval_trigger.py
index 82f441e..b88816c 100644
--- a/parakeet/training/triggers/interval_trigger.py
+++ b/parakeet/training/triggers/interval_trigger.py
@@ -19,17 +19,13 @@ class IntervalTrigger(object):
     def __init__(self, period: int, unit: str):
         if unit not in ("iteration", "epoch"):
             raise ValueError("unit should be 'iteration' or 'epoch'")
+        if period <= 0:
+            raise ValueError("period should be a positive integer.")
         self.period = period
         self.unit = unit
 
     def __call__(self, trainer):
         state = trainer.updater.state
-        # we use a special scheme so we can use iteration % period == 0 as
-        # the predicate
-        # increase the iteration then update parameters
-        # instead of updating then increase iteration
-        if self.unit == "epoch":
-            fire = state.epoch % self.period == 0
-        else:
-            fire = state.iteration % self.period == 0
+        index = getattr(state, self.unit)
+        fire = index % self.period == 0
         return fire
diff --git a/parakeet/training/triggers/limit_trigger.py b/parakeet/training/triggers/limit_trigger.py
new file mode 100644
index 0000000..dd7a135
--- /dev/null
+++ b/parakeet/training/triggers/limit_trigger.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class LimitTrigger(object):
+    """A Predicate to decide whether to stop."""
+
+    def __init__(self, limit: int, unit: str):
+        if unit not in ("iteration", "epoch"):
+            raise ValueError("unit should be 'iteration' or 'epoch'")
+        if limit <= 0:
+            raise ValueError("limit should be a positive integer.")
+        self.limit = limit
+        self.unit = unit
+
+    def __call__(self, trainer):
+        state = trainer.updater.state
+        index = getattr(state, self.unit)
+        fire = index >= self.limit
+        return fire
diff --git a/parakeet/training/updaters/standard_updater.py b/parakeet/training/updaters/standard_updater.py
index 5cc0252..e39b758 100644
--- a/parakeet/training/updaters/standard_updater.py
+++ b/parakeet/training/updaters/standard_updater.py
@@ -62,7 +62,40 @@ class StandardUpdater(UpdaterBase):
             self.train_iterator = iter(dataloader)
 
     def update(self):
-        self.state.iteration += 1
+        # We increase the iteration index after updating and before the
+        # extensions run. Here are the reasons.
+
+        # 0. Snapshotting (as well as other extensions, like the visualizer)
+        #    is executed after a step of updating;
+        # 1. We decide to increase the iteration index after updating and
+        #    before any extension is executed.
+        # 2. We do not increase the iteration after the extensions because we
+        #    prefer a consistent resume behavior: when loading from
+        #    `snapshot_iter_100.pdz`, the next step to train is `101`,
+        #    naturally. But if the iteration were increased after the
+        #    extensions (including snapshot), then a `snapshot_iter_99` would
+        #    be loaded, and you would need an extra increase of the iteration
+        #    index before training to avoid repeating iteration `99`, which
+        #    has already been done before snapshotting.
+        # 3. Thus the iteration index represents "currently how many
+        #    iterations have been done."
+        # NOTE: use report to capture the correct value. If you want to
+        # report the learning rate used for a step, you must report it before
+        # the learning rate scheduler's step() has been called. In paddle's
+        # convention, we do not use an extension to change the learning rate,
+        # so if you want to report it, do it in the updater.
+
+        # Then here comes the next question: when is the proper time to
+        # increase the epoch index? Since all extensions are executed after
+        # updating, right after updating is also the proper time to
+        # increase the epoch index.
+        # 1. If we increased the epoch index before updating, an extension
+        #    triggered by epoch would miss the correct timing. It could only
+        #    be triggered after an extra update.
+        # 2. Theoretically, when an epoch is done, the epoch index should be
+        #    increased, so it is increased after updating.
+        # 3. Thus, the epoch index represents "currently how many epochs have
+        #    been done." So it starts from 0.
 
         # switch to training mode
         for layer in self.models.values():
@@ -72,6 +105,11 @@ class StandardUpdater(UpdaterBase):
         batch = self.read_batch()
         self.update_core(batch)
 
+        self.state.iteration += 1
+        if self.updates_per_epoch is not None:
+            if self.state.iteration % self.updates_per_epoch == 0:
+                self.state.epoch += 1
+
     def update_core(self, batch):
         """A simple case for a training step. Basic assumptions are:
         Single model;
@@ -100,10 +138,20 @@ class StandardUpdater(UpdaterBase):
             loss_dict["main"].backward()
             self.optimizer.update()
 
+    @property
+    def updates_per_epoch(self):
+        """Number of updates per epoch, determined by the length of the
+        dataloader."""
+        length_of_dataloader = None
+        try:
+            length_of_dataloader = len(self.dataloader)
+        except TypeError:
+            logging.debug("This dataloader has no __len__.")
+        finally:
+            return length_of_dataloader
+
     def new_epoch(self):
         """Start a new epoch."""
-        self.state.epoch += 1
-
         # NOTE: all batch sampler for distributed training should
         # subclass DistributedBatchSampler and implement `set_epoch` method
         batch_sampler = self.dataloader.batch_sampler
@@ -140,13 +188,3 @@ class StandardUpdater(UpdaterBase):
         for name, optim in self.optimizers.items():
             optim.set_state_dict(state_dict[f"{name}_optimizer"])
         super().set_state_dict(state_dict)
-
-    def save(self, path):
-        """Save Updater state dict."""
-        archive = self.state_dict()
-        paddle.save(archive, path)
-
-    def load(self, path):
-        """Load Updater state dict."""
-        archive = paddle.load(path)
-        self.set_state_dict(archive)
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
new file mode 100644
index 0000000..bdb3d96
--- /dev/null
+++ b/tests/test_optimizer.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +from pathlib import Path + +import paddle +from paddle import nn +from paddle.optimizer import Adam +from paddle.optimizer.lr import StepDecay + + +def test_optimizer(): + model1 = nn.Linear(3, 4) + optim1 = Adam( + parameters=model1.parameters(), learning_rate=StepDecay(0.1, 100)) + + output_dir = Path("temp_test_optimizer") + shutil.rmtree(output_dir, ignore_errors=True) + output_dir.mkdir(exist_ok=True, parents=True) + + # model1.set_state_dict(model1.state_dict()) + optim1.set_state_dict(optim1.state_dict()) + + x = paddle.randn([6, 3]) + y = model1(x).sum() + y.backward() + optim1.step()
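
Note (not part of the patch): with the new post-increment convention in StandardUpdater.update(), a LimitTrigger with limit N stops training right after the N-th update, and a `snapshot_iter_N` snapshot resumes at step N+1. Below is a minimal sketch of the trigger's behavior only, assuming the patched parakeet package is importable; SimpleNamespace is used here as a hypothetical stand-in for the real Trainer/Updater objects, since the trigger only reads `trainer.updater.state`:

    from types import SimpleNamespace

    from parakeet.training.triggers.limit_trigger import LimitTrigger

    # stop once 3 iterations have been done
    stop_trigger = LimitTrigger(limit=3, unit="iteration")

    # stand-in for trainer.updater.state, which the trigger reads via getattr
    state = SimpleNamespace(iteration=0, epoch=0)
    trainer = SimpleNamespace(updater=SimpleNamespace(state=state))

    while not stop_trigger(trainer):
        # ... one training step would run here ...
        state.iteration += 1  # incremented after updating, as in StandardUpdater.update()

    print(state.iteration)  # 3: the loop exits right after the 3rd update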