refined training module

chenfeiyu 2021-06-30 12:30:14 +08:00
parent 29b8b8b0ea
commit ef51e1ab13
13 changed files with 169 additions and 59 deletions

View File

@@ -125,4 +125,5 @@ log_interval_steps: 100 # Interval steps to record the training
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
num_snapshots: 10
num_snapshots: 10
seed: 42
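The new seed entry is consumed by the seed_everything(config.seed) call in the training script later in this commit. As a hedged sketch (the parakeet helper itself is not shown in this diff), such a function typically seeds every RNG source at once:

import random
import numpy as np
import paddle

def seed_everything(seed: int):
    # seed python, numpy and paddle RNGs so multiprocess training is reproducible
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)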

View File

@@ -74,18 +74,15 @@ class PWGUpdater(StandardUpdater):
# Generator
noise = paddle.randn(wav.shape)
synchronize()
with timer() as t:
wav_ = self.generator(noise, mel)
synchronize()
logging.debug(f"Generator takes {t.elapse}s.")
## Multi-resolution stft loss
synchronize()
with timer() as t:
sc_loss, mag_loss = self.criterion_stft(
wav_.squeeze(1), wav.squeeze(1))
synchronize()
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")
report("train/spectral_convergence_loss", float(sc_loss))
@@ -94,11 +91,9 @@ class PWGUpdater(StandardUpdater):
## Adversarial loss
if self.state.iteration > self.discriminator_train_start_steps:
synchronize()
with timer() as t:
p_ = self.discriminator(wav_)
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
synchronize()
logging.debug(
f"Discriminator and adversarial loss takes {t.elapse}s")
report("train/adversarial_loss", float(adv_loss))
@@ -106,18 +101,14 @@ class PWGUpdater(StandardUpdater):
report("train/generator_loss", float(gen_loss))
synchronize()
with timer() as t:
self.optimizer_g.clear_grad()
gen_loss.backward()
synchronize()
logging.debug(f"Backward takes {t.elapse}s.")
synchronize()
with timer() as t:
self.optimizer_g.step()
self.scheduler_g.step()
synchronize()
logging.debug(f"Update takes {t.elapse}s.")
# Discriminator
@@ -158,18 +149,15 @@ class PWGEvaluator(StandardEvaluator):
wav, mel = batch
noise = paddle.randn(wav.shape)
synchronize()
with timer() as t:
wav_ = self.generator(noise, mel)
synchronize()
logging.debug(f"Generator takes {t.elapse}s")
## Multi-resolution stft loss
synchronize()
with timer() as t:
sc_loss, mag_loss = self.criterion_stft(
wav_.squeeze(1), wav.squeeze(1))
synchronize()
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
report("eval/spectral_convergence_loss", float(sc_loss))
@@ -177,11 +165,9 @@ class PWGEvaluator(StandardEvaluator):
gen_loss = sc_loss + mag_loss
## Adversarial loss
synchronize()
with timer() as t:
p_ = self.discriminator(wav_)
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
synchronize()
logging.debug(
f"Discriminator and adversarial loss takes {t.elapse}s")
report("eval/adversarial_loss", float(adv_loss))

View File

@@ -14,6 +14,7 @@
import os
import sys
from timer import timer
import logging
import argparse
from pathlib import Path
@@ -25,6 +26,8 @@ import numpy as np
import soundfile as sf
from paddle import distributed as dist
paddle.set_device("cpu")
from parakeet.datasets.data_table import DataTable
from parakeet.models.parallel_wavegan import PWGGenerator
@@ -71,11 +74,20 @@ test_dataset = DataTable(
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
N = 0
T = 0
for example in test_dataset:
utt_id = example['utt_id']
mel = example['feats']
mel = paddle.to_tensor(mel) # (T, C)
wav = generator.inference(c=mel)
wav = wav.numpy()
print(f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}")
sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=24000)
with timer() as t:
wav = generator.inference(c=mel)
wav = wav.numpy()
N += wav.size
T += t.elapse
speed = wav.size / t.elapse
print(
f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {config.sr / speed}."
)
sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=config.sr)
print(f"generation speed: {N / T}Hz, RTF: {config.sr / (N / T) }")

View File

@@ -60,7 +60,7 @@ def train_sp(args, config):
paddle.distributed.init_parallel_env()
# set the random seed, it is a must for multiprocess training
seed_everything(42)
seed_everything(config.seed)
print(
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
@@ -149,6 +149,8 @@ def train_sp(args, config):
output_dir = Path(args.output_dir)
checkpoint_dir = output_dir / "checkpoints"
if dist.get_rank() == 0:
with open(output_dir / "config.yaml", 'wt') as f:
f.write(config.dump(default_flow_style=None))
output_dir.mkdir(parents=True, exist_ok=True)
checkpoint_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.2.0-beta.0"
__version__ = "0.0.0"
import logging
from parakeet import audio, data, datasets, frontend, models, modules, training, utils
logging.getLogger('parakeet').addHandler(logging.NullHandler())

View File

@@ -43,9 +43,10 @@ class SpectralConvergenceLoss(nn.Layer):
class LogSTFTMagnitudeLoss(nn.Layer):
"""Log STFT magnitude loss module."""
def __init__(self):
def __init__(self, epsilon=1e-10):
"""Initilize los STFT magnitude loss module."""
super().__init__()
self.epsilon = epsilon
def forward(self, x_mag, y_mag):
"""Calculate forward propagation.
@@ -57,9 +58,9 @@ class LogSTFTMagnitudeLoss(nn.Layer):
"""
return F.l1_loss(
paddle.log(paddle.clip(
y_mag, min=1e-10)),
y_mag, min=self.epsilon)),
paddle.log(paddle.clip(
x_mag, min=1e-10)))
x_mag, min=self.epsilon)))
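Clipping the magnitudes to epsilon before taking the log guards against log(0) = -inf on silent frames; the change merely lifts the hard-coded 1e-10 into a constructor argument. A minimal standalone check of the behavior:

import paddle
import paddle.nn.functional as F

x_mag = paddle.to_tensor([[0.0, 1.0]])  # a zero magnitude would otherwise yield -inf
y_mag = paddle.to_tensor([[0.5, 1.0]])
eps = 1e-10
loss = F.l1_loss(
    paddle.log(paddle.clip(y_mag, min=eps)),
    paddle.log(paddle.clip(x_mag, min=eps)))
print(float(loss))  # finite, thanks to the clip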
class STFTLoss(nn.Layer):

View File

@@ -106,4 +106,5 @@ class Snapshot(extension.Extension):
record_path = self.checkpoint_dir / "records.jsonl"
with jsonlines.open(record_path, 'w') as writer:
for record in self.records:
writer.write(record)
# jsonlines.open may return a Writer or a Reader
writer.write(record) # pylint: disable=no-member
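The pylint suppression works around jsonlines.open being typed to return either a Reader or a Writer depending on the mode; opened with 'w' it is a Writer, so write() is available at runtime. Reading the records back is symmetric, e.g.:

import jsonlines

with jsonlines.open("checkpoints/records.jsonl") as reader:  # default mode is 'r'
    for record in reader:  # one JSON object per line
        print(record)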

View File

@@ -21,7 +21,7 @@ from typing import Callable, Union, List
import tqdm
from parakeet.training.trigger import get_trigger, IntervalTrigger
from parakeet.training.trigger import get_trigger, IntervalTrigger, LimitTrigger
from parakeet.training.updater import UpdaterBase
from parakeet.training.reporter import scope
from parakeet.training.extension import Extension, PRIORITY_READER
@@ -42,7 +42,7 @@ class Trainer(object):
extensions: List[Extension]=None):
self.updater = updater
self.extensions = OrderedDict()
self.stop_trigger = get_trigger(stop_trigger)
self.stop_trigger = LimitTrigger(*stop_trigger)
self.out = Path(out)
self.observation =...
@@ -125,16 +125,19 @@ class Trainer(object):
print(self.updater.state)
# TODO(chenfeiyu): display progress bar correctly
# if the trainer is controlled by epoch: use 2 progressbars
# if the trainer is controlled by iteration: use 1 progressbar
if isinstance(stop_trigger, IntervalTrigger):
# display only one progress bar
max_iteration = None
if isinstance(stop_trigger, LimitTrigger):
if stop_trigger.unit == 'epoch':
max_epoch = self.stop_trigger.period
max_epoch = self.stop_trigger.limit
updates_per_epoch = getattr(self.updater, "updates_per_epoch",
None)
max_iteration = max_epoch * updates_per_epoch if updates_per_epoch else None
else:
max_iteration = self.stop_trigger.period
max_iteration = self.stop_trigger.limit
p = tqdm.tqdm(initial=self.updater.state.iteration)
p = tqdm.tqdm(
initial=self.updater.state.iteration, total=max_iteration)
try:
while not stop_trigger(self):
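Passing total=max_iteration lets tqdm render a full progress bar with an ETA; when the epoch length cannot be inferred (updates_per_epoch is None), total stays None and tqdm degrades gracefully to a bare iteration counter.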

View File

@@ -13,6 +13,8 @@
# limitations under the License.
from parakeet.training.triggers.interval_trigger import IntervalTrigger
from parakeet.training.triggers.limit_trigger import LimitTrigger
from parakeet.training.triggers.time_trigger import TimeTrigger
def never_file_trigger(trainer):

View File

@@ -19,17 +19,13 @@ class IntervalTrigger(object):
def __init__(self, period: int, unit: str):
if unit not in ("iteration", "epoch"):
raise ValueError("unit should be 'iteration' or 'epoch'")
if period <= 0:
raise ValueError("period should be a positive integer.")
self.period = period
self.unit = unit
def __call__(self, trainer):
state = trainer.updater.state
# we use a special scheme so we can use iteration % period == 0 as
# the predicate
# increase the iteration then update parameters
# instead of updating then increase iteration
if self.unit == "epoch":
fire = state.epoch % self.period == 0
else:
fire = state.iteration % self.period == 0
index = getattr(state, self.unit)
fire = index % self.period == 0
return fire
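The refactor collapses the unit branch into getattr(state, self.unit), which works because the updater state exposes both an iteration and an epoch attribute. Usage is unchanged, e.g.:

# fire every 1000 iterations, e.g. to drive snapshotting
snapshot_trigger = IntervalTrigger(1000, 'iteration')
# fire once per epoch
eval_trigger = IntervalTrigger(1, 'epoch')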

View File

@@ -0,0 +1,31 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class LimitTrigger(object):
"""A Predicate to decide whether to stop."""
def __init__(self, limit: int, unit: str):
if unit not in ("iteration", "epoch"):
raise ValueError("unit should be 'iteration' or 'epoch'")
if limit <= 0:
raise ValueError("limit should be a positive integer.")
self.limit = limit
self.unit = unit
def __call__(self, trainer):
state = trainer.updater.state
index = getattr(state, self.unit)
fire = index >= self.limit
return fire
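Where IntervalTrigger fires periodically, LimitTrigger fires permanently once the index reaches a fixed limit, which is why the Trainer above now builds its stop condition as LimitTrigger(*stop_trigger). A hedged usage sketch (argument names follow the Trainer shown earlier; the step counts are illustrative):

# stop after 400000 iterations
trainer = Trainer(updater, stop_trigger=(400000, 'iteration'), out='exp/default')
# or stop after 100 epochs
trainer = Trainer(updater, stop_trigger=(100, 'epoch'), out='exp/default')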

View File

@@ -62,7 +62,40 @@ class StandardUpdater(UpdaterBase):
self.train_iterator = iter(dataloader)
def update(self):
self.state.iteration += 1
# We increase the iteration index after updating and before any
# extension is executed. The reasons:
# 1. Snapshotting (as well as other extensions, like the visualizer) is
#    executed after a step of updating, so the index must already count
#    the step that just finished.
# 2. We do not increase the iteration after the extensions because we
#    prefer a consistent resume behavior: when loading from a
#    `snapshot_iter_100.pdz`, the next step to train is naturally `101`.
#    If the iteration were increased after the extensions (including
#    snapshotting), a `snapshot_iter_99` would be loaded instead, and an
#    extra increment of the iteration index would be needed before
#    training to avoid repeating iteration `99`, which was already done
#    before snapshotting.
# 3. Thus the iteration index represents "how many iterations have been
#    done so far."
# NOTE: use report() to capture the correct value. If you want to
# report the learning rate used for a step, you must report it before
# the learning rate scheduler's step() has been called. By paddle's
# convention we do not use an extension to change the learning rate, so
# if you want to report it, do it in the updater.
# The next question: when is the proper time to increase the epoch
# index? Since all extensions are executed after updating, right after
# updating is the proper time for this, too.
# 1. If the epoch index were increased before updating, an extension
#    triggered by epoch would miss the correct timing; it could only
#    fire after one extra update.
# 2. Theoretically, the epoch index should be increased exactly when an
#    epoch is done, that is, after updating.
# 3. Thus the epoch index represents "how many epochs have been done so
#    far," and it starts from 0.
# switch to training mode
for layer in self.models.values():
@@ -72,6 +105,11 @@ class StandardUpdater(UpdaterBase):
batch = self.read_batch()
self.update_core(batch)
self.state.iteration += 1
if self.updaters_per_epoch is not None:
if self.state.iteration % self.updaters_per_epoch == 0:
self.state.epoch += 1
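Concretely: resuming from snapshot_iter_100.pdz restores state.iteration == 100, the next update() call trains step 101 and then bumps the index to 101, and a snapshot taken right afterwards is named snapshot_iter_101.pdz, so the counter, the checkpoints, and the resume point always agree.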
def update_core(self, batch):
"""A simple case for a training step. Basic assumptions are:
Single model;
@@ -100,10 +138,20 @@ class StandardUpdater(UpdaterBase):
loss_dict["main"].backward()
self.optimizer.update()
@property
def updaters_per_epoch(self):
"""Number of updater per epoch, determined by the length of the
dataloader."""
length_of_dataloader = None
try:
length_of_dataloader = len(self.dataloader)
except TypeError:
logging.debug("This dataloader has no __len__.")
finally:
return length_of_dataloader
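len(dataloader) is defined for map-style datasets (the number of batches per epoch) but raises TypeError for iterable-style loaders with no known length; in that case updaters_per_epoch is None and the epoch counter in update() above simply never advances.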
def new_epoch(self):
"""Start a new epoch."""
self.state.epoch += 1
# NOTE: all batch sampler for distributed training should
# subclass DistributedBatchSampler and implement `set_epoch` method
batch_sampler = self.dataloader.batch_sampler
@@ -140,13 +188,3 @@ class StandardUpdater(UpdaterBase):
for name, optim in self.optimizers.items():
optim.set_state_dict(state_dict[f"{name}_optimizer"])
super().set_state_dict(state_dict)
def save(self, path):
"""Save Updater state dict."""
archive = self.state_dict()
paddle.save(archive, path)
def load(self, path):
"""Load Updater state dict."""
archive = paddle.load(path)
self.set_state_dict(archive)

tests/test_optimizer.py Normal file
View File

@@ -0,0 +1,39 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
from pathlib import Path
import paddle
from paddle import nn
from paddle.optimizer import Adam
from paddle.optimizer.lr import StepDecay
def test_optimizer():
model1 = nn.Linear(3, 4)
optim1 = Adam(
parameters=model1.parameters(), learning_rate=StepDecay(0.1, 100))
output_dir = Path("temp_test_optimizer")
shutil.rmtree(output_dir, ignore_errors=True)
output_dir.mkdir(exist_ok=True, parents=True)
# model1.set_state_dict(model1.state_dict())
optim1.set_state_dict(optim1.state_dict())
x = paddle.randn([6, 3])
y = model1(x).sum()
y.backward()
optim1.step()
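A note on the fixture: in paddle's API, StepDecay(0.1, 100) starts the learning rate at 0.1 and multiplies it by a decay factor (gamma, 0.1 by default) every 100 scheduler steps, so the test exercises the set_state_dict round trip on an optimizer that carries scheduler state. With pytest installed it runs as: pytest tests/test_optimizer.py.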