refined training module
parent 29b8b8b0ea · commit ef51e1ab13
@@ -125,4 +125,5 @@ log_interval_steps: 100 # Interval steps to record the training
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
num_snapshots: 10
seed: 42
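The new `seed` key replaces the hard-coded seed in the training script (see the `seed_everything(config.seed)` change below). A minimal sketch of what such a seeding helper typically does, assuming it needs to cover the Python, NumPy, and Paddle RNGs (the actual `seed_everything` implementation is not shown in this diff):

import random

import numpy as np
import paddle


def seed_everything(seed: int):
    """Seed all RNGs used during training (sketch; the real helper may differ)."""
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)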
@@ -74,18 +74,15 @@ class PWGUpdater(StandardUpdater):
# Generator
noise = paddle.randn(wav.shape)

synchronize()
with timer() as t:
    wav_ = self.generator(noise, mel)
synchronize()
logging.debug(f"Generator takes {t.elapse}s.")

## Multi-resolution stft loss
synchronize()
with timer() as t:
    sc_loss, mag_loss = self.criterion_stft(wav_.squeeze(1), wav.squeeze(1))
synchronize()
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")

report("train/spectral_convergence_loss", float(sc_loss))
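The timing above follows a synchronize-then-time pattern: device work is flushed before the clock starts and again before it stops, so `t.elapse` measures only the bracketed stage rather than previously queued asynchronous kernels. A self-contained sketch of the same pattern, assuming `synchronize` is a small helper around `paddle.device.cuda.synchronize()` (the helper's actual definition is not part of this diff):

import paddle
from timer import timer  # same timer utility used by the training code


def synchronize():
    # Assumed helper: wait for queued GPU kernels so wall-clock timing is meaningful.
    if paddle.is_compiled_with_cuda():
        paddle.device.cuda.synchronize()


def timed_call(name, fn, *args):
    synchronize()              # drain pending work before starting the clock
    with timer() as t:
        result = fn(*args)
    synchronize()              # wait for this stage's kernels before reading t.elapse
    print(f"{name} takes {t.elapse}s.")
    return result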
@@ -94,11 +91,9 @@ class PWGUpdater(StandardUpdater):

## Adversarial loss
if self.state.iteration > self.discriminator_train_start_steps:
    synchronize()
    with timer() as t:
        p_ = self.discriminator(wav_)
        adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
    synchronize()
    logging.debug(f"Discriminator and adversarial loss takes {t.elapse}s")
    report("train/adversarial_loss", float(adv_loss))
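The adversarial term is the least-squares (LSGAN-style) generator loss used by Parallel WaveGAN: the discriminator's scores for generated audio are pushed towards 1. Assuming `criterion_mse` is a plain mean-squared-error loss, the line above is equivalent to this small helper:

import paddle
import paddle.nn.functional as F


def generator_adversarial_loss(disc_scores: paddle.Tensor) -> paddle.Tensor:
    # mean((D(G(z, mel)) - 1)^2): minimized when the discriminator rates fakes as real.
    return F.mse_loss(disc_scores, paddle.ones_like(disc_scores))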
@@ -106,18 +101,14 @@ class PWGUpdater(StandardUpdater):

report("train/generator_loss", float(gen_loss))

synchronize()
with timer() as t:
    self.optimizer_g.clear_grad()
    gen_loss.backward()
synchronize()
logging.debug(f"Backward takes {t.elapse}s.")

synchronize()
with timer() as t:
    self.optimizer_g.step()
    self.scheduler_g.step()
synchronize()
logging.debug(f"Update takes {t.elapse}s.")

# Discriminator
@@ -158,18 +149,15 @@ class PWGEvaluator(StandardEvaluator):
wav, mel = batch
noise = paddle.randn(wav.shape)

synchronize()
with timer() as t:
    wav_ = self.generator(noise, mel)
synchronize()
logging.debug(f"Generator takes {t.elapse}s")

## Multi-resolution stft loss
synchronize()
with timer() as t:
    sc_loss, mag_loss = self.criterion_stft(wav_.squeeze(1), wav.squeeze(1))
synchronize()
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")

report("eval/spectral_convergence_loss", float(sc_loss))
@@ -177,11 +165,9 @@ class PWGEvaluator(StandardEvaluator):
gen_loss = sc_loss + mag_loss

## Adversarial loss
synchronize()
with timer() as t:
    p_ = self.discriminator(wav_)
    adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
synchronize()
logging.debug(f"Discriminator and adversarial loss takes {t.elapse}s")
report("eval/adversarial_loss", float(adv_loss))
@@ -14,6 +14,7 @@
import os
import sys
from timer import timer
import logging
import argparse
from pathlib import Path
@@ -25,6 +26,8 @@ import numpy as np
import soundfile as sf
from paddle import distributed as dist

paddle.set_device("cpu")

from parakeet.datasets.data_table import DataTable
from parakeet.models.parallel_wavegan import PWGGenerator
@@ -71,11 +74,20 @@ test_dataset = DataTable(
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

N = 0
T = 0
for example in test_dataset:
    utt_id = example['utt_id']
    mel = example['feats']
    mel = paddle.to_tensor(mel)  # (T, C)
    wav = generator.inference(c=mel)
    wav = wav.numpy()
    print(f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}")
    sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=24000)
    with timer() as t:
        wav = generator.inference(c=mel)
        wav = wav.numpy()
    N += wav.size
    T += t.elapse
    speed = wav.size / t.elapse
    print(
        f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {config.sr / speed}."
    )
    sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=config.sr)
print(f"generation speed: {N / T}Hz, RTF: {config.sr / (N / T)}")
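The reported numbers are the synthesis speed in samples per second (`wav.size / t.elapse`) and the real-time factor, RTF = sample rate / speed: the fraction of real time needed to generate the audio, so values below 1 are faster than real time. For example, generating a 24 kHz waveform at 48,000 samples per second gives RTF = 24000 / 48000 = 0.5. A small helper computing the same quantities:

def synthesis_stats(num_samples: int, seconds: float, sample_rate: int):
    """Return (samples per second, real-time factor) for one generated utterance."""
    speed = num_samples / seconds  # audio samples generated per wall-clock second
    rtf = sample_rate / speed      # < 1.0 means faster than real time
    return speed, rtf


# Example: 72,000 samples of 24 kHz audio generated in 1.5 s -> 48,000 Hz, RTF 0.5.
print(synthesis_stats(72000, 1.5, 24000))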
@@ -60,7 +60,7 @@ def train_sp(args, config):
paddle.distributed.init_parallel_env()

# set the random seed, it is a must for multiprocess training
seed_everything(42)
seed_everything(config.seed)

print(
    f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
@@ -149,6 +149,8 @@ def train_sp(args, config):
output_dir = Path(args.output_dir)
checkpoint_dir = output_dir / "checkpoints"
if dist.get_rank() == 0:
    with open(output_dir / "config.yaml", 'wt') as f:
        f.write(config.dump(default_flow_style=None))
output_dir.mkdir(parents=True, exist_ok=True)
checkpoint_dir.mkdir(parents=True, exist_ok=True)
@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.2.0-beta.0"
__version__ = "0.0.0"

import logging
from parakeet import audio, data, datasets, frontend, models, modules, training, utils

logging.getLogger('parakeet').addHandler(logging.NullHandler())
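Attaching a `NullHandler` to the package logger is the usual library-side logging setup: the package emits no output unless the application configures logging itself. For example, an application that wants to see the `logging.debug` timing messages used throughout the training code above could opt in like this (a minimal sketch, not part of this diff):

import logging

# Enable debug-level output for the process; parakeet itself only attaches a NullHandler.
logging.basicConfig(
    level=logging.DEBUG, format="%(asctime)s %(name)s %(levelname)s %(message)s")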
@@ -43,9 +43,10 @@ class SpectralConvergenceLoss(nn.Layer):
class LogSTFTMagnitudeLoss(nn.Layer):
    """Log STFT magnitude loss module."""

    def __init__(self):
    def __init__(self, epsilon=1e-10):
        """Initialize log STFT magnitude loss module."""
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x_mag, y_mag):
        """Calculate forward propagation.
@@ -57,9 +58,9 @@ class LogSTFTMagnitudeLoss(nn.Layer):
"""
return F.l1_loss(
    paddle.log(paddle.clip(
        y_mag, min=1e-10)),
        y_mag, min=self.epsilon)),
    paddle.log(paddle.clip(
        x_mag, min=1e-10)))
        x_mag, min=self.epsilon)))


class STFTLoss(nn.Layer):
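Clipping both magnitude spectra to a small positive floor before taking the log keeps the loss finite when an STFT bin is exactly zero; the new `epsilon` parameter makes that floor configurable while its default (1e-10) preserves the previous hard-coded behaviour. The computation boils down to:

import paddle
import paddle.nn.functional as F


def log_stft_magnitude_loss(x_mag, y_mag, epsilon=1e-10):
    """L1 distance between log magnitudes; clipping to `epsilon` avoids log(0)."""
    return F.l1_loss(
        paddle.log(paddle.clip(y_mag, min=epsilon)),
        paddle.log(paddle.clip(x_mag, min=epsilon)))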
@@ -106,4 +106,5 @@ class Snapshot(extension.Extension):
record_path = self.checkpoint_dir / "records.jsonl"
with jsonlines.open(record_path, 'w') as writer:
    for record in self.records:
        writer.write(record)
        # jsonlines.open may return a Writer or a Reader
        writer.write(record)  # pylint: disable=no-member
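The records file is plain JSON Lines, so it can be read back with the same library, for example to find the newest snapshot when resuming (a small illustrative sketch; the path matches the `checkpoint_dir / "records.jsonl"` used above):

import jsonlines

# Each line in records.jsonl is one snapshot record written by the extension above.
with jsonlines.open("checkpoints/records.jsonl") as reader:
    records = list(reader)  # pylint: disable=no-member

if records:
    print("latest snapshot record:", records[-1])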
@@ -21,7 +21,7 @@ from typing import Callable, Union, List

import tqdm

from parakeet.training.trigger import get_trigger, IntervalTrigger
from parakeet.training.trigger import get_trigger, IntervalTrigger, LimitTrigger
from parakeet.training.updater import UpdaterBase
from parakeet.training.reporter import scope
from parakeet.training.extension import Extension, PRIORITY_READER
@@ -42,7 +42,7 @@ class Trainer(object):
             extensions: List[Extension]=None):
self.updater = updater
self.extensions = OrderedDict()
self.stop_trigger = get_trigger(stop_trigger)
self.stop_trigger = LimitTrigger(*stop_trigger)
self.out = Path(out)
self.observation = ...
@@ -125,16 +125,19 @@ class Trainer(object):

print(self.updater.state)

# TODO(chenfeiyu): display progress bar correctly
# if the trainer is controlled by epoch: use 2 progressbars
# if the trainer is controlled by iteration: use 1 progressbar
if isinstance(stop_trigger, IntervalTrigger):
    # display only one progress bar
max_iteration = None
if isinstance(stop_trigger, LimitTrigger):
    if stop_trigger.unit == 'epoch':
        max_epoch = self.stop_trigger.period
        max_epoch = self.stop_trigger.limit
        updates_per_epoch = getattr(self.updater, "updaters_per_epoch", None)
        max_iteration = max_epoch * updates_per_epoch if updates_per_epoch else None
    else:
        max_iteration = self.stop_trigger.period
        max_iteration = self.stop_trigger.limit

p = tqdm.tqdm(initial=self.updater.state.iteration)
p = tqdm.tqdm(
    initial=self.updater.state.iteration, total=max_iteration)

try:
    while not stop_trigger(self):
@@ -13,6 +13,8 @@
# limitations under the License.

from parakeet.training.triggers.interval_trigger import IntervalTrigger
from parakeet.training.triggers.limit_trigger import LimitTrigger
from parakeet.training.triggers.time_trigger import TimeTrigger


def never_file_trigger(trainer):
@@ -19,17 +19,13 @@ class IntervalTrigger(object):
def __init__(self, period: int, unit: str):
    if unit not in ("iteration", "epoch"):
        raise ValueError("unit should be 'iteration' or 'epoch'")
    if period <= 0:
        raise ValueError("period should be a positive integer.")
    self.period = period
    self.unit = unit

def __call__(self, trainer):
    state = trainer.updater.state
    # we use a special scheme so we can use iteration % period == 0 as
    # the predicate
    # increase the iteration then update parameters
    # instead of updating then increase iteration
    if self.unit == "epoch":
        fire = state.epoch % self.period == 0
    else:
        fire = state.iteration % self.period == 0
    index = getattr(state, self.unit)
    fire = index % self.period == 0
    return fire
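After the refactor, the predicate is simply `getattr(state, unit) % period == 0`, so the same code path serves both units. A tiny illustration with a stand-in trainer object (the `SimpleNamespace` scaffolding is only for demonstration):

from types import SimpleNamespace

from parakeet.training.trigger import IntervalTrigger

state = SimpleNamespace(iteration=2000, epoch=4)
fake_trainer = SimpleNamespace(updater=SimpleNamespace(state=state))

print(IntervalTrigger(1000, "iteration")(fake_trainer))  # True: 2000 % 1000 == 0
print(IntervalTrigger(3, "epoch")(fake_trainer))         # False: 4 % 3 != 0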
@@ -0,0 +1,31 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class LimitTrigger(object):
    """A Predicate to decide whether to stop."""

    def __init__(self, limit: int, unit: str):
        if unit not in ("iteration", "epoch"):
            raise ValueError("unit should be 'iteration' or 'epoch'")
        if limit <= 0:
            raise ValueError("limit should be a positive integer.")
        self.limit = limit
        self.unit = unit

    def __call__(self, trainer):
        state = trainer.updater.state
        index = getattr(state, self.unit)
        fire = index >= self.limit
        return fire
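This is the predicate the Trainer now uses as its stop trigger (see the `LimitTrigger(*stop_trigger)` change above): it fires once the chosen counter reaches the limit, so `stop_trigger` is passed to the Trainer as a `(limit, unit)` tuple. A quick illustration (the stand-in objects and the 1000-iteration limit are only for demonstration):

from types import SimpleNamespace

from parakeet.training.triggers.limit_trigger import LimitTrigger

stop = LimitTrigger(1000, "iteration")

state = SimpleNamespace(iteration=999, epoch=0)
fake_trainer = SimpleNamespace(updater=SimpleNamespace(state=state))
print(stop(fake_trainer))  # False: keep training

state.iteration = 1000
print(stop(fake_trainer))  # True: the training loop exits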
@@ -62,7 +62,40 @@ class StandardUpdater(UpdaterBase):
self.train_iterator = iter(dataloader)

def update(self):
    self.state.iteration += 1
    # We increase the iteration index after updating and before extensions.
    # Here are the reasons.

    # 1. Snapshotting (as well as other extensions, like the visualizer) is
    #    executed after a step of updating;
    # 2. We decide to increase the iteration index after updating and
    #    before any extension is executed.
    # 3. We do not increase the iteration after the extensions because we
    #    prefer a consistent resume behavior: when loading from
    #    `snapshot_iter_100.pdz`, the next step to train is `101`, naturally.
    #    But if the iteration were increased after the extensions (including
    #    snapshotting), then a `snapshot_iter_99` would be loaded, and you
    #    would need an extra increase of the iteration index before training
    #    to avoid repeating iteration `99`, which was already done before
    #    snapshotting.
    # 4. Thus the iteration index represents "how many iterations have been
    #    done so far."
    # NOTE: use report to capture the correct value. If you want to report
    # the learning rate used for a step, you must report it before the
    # learning rate scheduler's step() has been called. In paddle's
    # convention, we do not use an extension to change the learning rate,
    # so if you want to report it, do it in the updater.

    # Then comes the next question: when is the proper time to increase the
    # epoch index? Since all extensions are executed after updating, right
    # after updating is also the proper time to increase the epoch index.
    # 1. If we increased the epoch index before updating, an extension
    #    triggered on epochs would miss the correct timing; it could only
    #    fire after an extra update.
    # 2. Theoretically, when an epoch is done, the epoch index should be
    #    increased. So it should be increased after updating.
    # 3. Thus the epoch index represents "how many epochs have been done so
    #    far," and it starts from 0.

    # switch to training mode
    for layer in self.models.values():
@@ -72,6 +105,11 @@ class StandardUpdater(UpdaterBase):
    batch = self.read_batch()
    self.update_core(batch)

    self.state.iteration += 1
    if self.updaters_per_epoch is not None:
        if self.state.iteration % self.updaters_per_epoch == 0:
            self.state.epoch += 1

def update_core(self, batch):
    """A simple case for a training step. Basic assumptions are:
    Single model;
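The increment sits after `update_core` and before the extensions run, which is exactly what the long comment above argues for: a snapshot written at this point is named after the number of finished iterations, so resuming from it continues with the next iteration and nothing is repeated or skipped. A toy illustration of that bookkeeping (purely illustrative, not the actual Trainer loop):

class ToyState:
    def __init__(self):
        self.iteration = 0


state = ToyState()
snapshots = []
for _ in range(3):
    # ... update_core(batch): parameters for this step are updated here ...
    state.iteration += 1  # index now counts *finished* iterations
    snapshots.append(f"snapshot_iter_{state.iteration}")  # extensions run after the increment

print(snapshots)  # ['snapshot_iter_1', 'snapshot_iter_2', 'snapshot_iter_3']
# Resuming from snapshot_iter_3 restores iteration == 3; the next trained step is 4.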
@@ -100,10 +138,20 @@ class StandardUpdater(UpdaterBase):
    loss_dict["main"].backward()
    self.optimizer.update()

@property
def updaters_per_epoch(self):
    """Number of updates per epoch, determined by the length of the
    dataloader."""
    length_of_dataloader = None
    try:
        length_of_dataloader = len(self.dataloader)
    except TypeError:
        logging.debug("This dataloader has no __len__.")
    finally:
        return length_of_dataloader

def new_epoch(self):
    """Start a new epoch."""
    self.state.epoch += 1

    # NOTE: all batch samplers for distributed training should
    # subclass DistributedBatchSampler and implement the `set_epoch` method
    batch_sampler = self.dataloader.batch_sampler
@@ -140,13 +188,3 @@ class StandardUpdater(UpdaterBase):
    for name, optim in self.optimizers.items():
        optim.set_state_dict(state_dict[f"{name}_optimizer"])
    super().set_state_dict(state_dict)

def save(self, path):
    """Save Updater state dict."""
    archive = self.state_dict()
    paddle.save(archive, path)

def load(self, path):
    """Load Updater state dict."""
    archive = paddle.load(path)
    self.set_state_dict(archive)
@@ -0,0 +1,39 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import shutil
from pathlib import Path

import paddle
from paddle import nn
from paddle.optimizer import Adam
from paddle.optimizer.lr import StepDecay


def test_optimizer():
    model1 = nn.Linear(3, 4)
    optim1 = Adam(
        parameters=model1.parameters(), learning_rate=StepDecay(0.1, 100))

    output_dir = Path("temp_test_optimizer")
    shutil.rmtree(output_dir, ignore_errors=True)
    output_dir.mkdir(exist_ok=True, parents=True)

    # model1.set_state_dict(model1.state_dict())
    optim1.set_state_dict(optim1.state_dict())

    x = paddle.randn([6, 3])
    y = model1(x).sum()
    y.backward()
    optim1.step()