refined training module
commit ef51e1ab13 (parent 29b8b8b0ea)
@@ -125,4 +125,5 @@ log_interval_steps: 100 # Interval steps to record the training
 # OTHER SETTING #
 ###########################################################
 num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
 num_snapshots: 10
+seed: 42
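The new `seed` key is consumed by the training entry point further down (`seed_everything(config.seed)`). For reference, a minimal sketch of reading such a YAML config, assuming a yacs-style `CfgNode` wrapper (the actual loader is not part of this diff, and the path is hypothetical):

```python
import yaml
from yacs.config import CfgNode

with open("conf/default.yaml") as f:
    config = CfgNode(yaml.safe_load(f))

print(config.seed)            # 42
print(config.num_snapshots)   # 10
```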
@@ -74,18 +74,15 @@ class PWGUpdater(StandardUpdater):
         # Generator
         noise = paddle.randn(wav.shape)
 
-        synchronize()
         with timer() as t:
             wav_ = self.generator(noise, mel)
-        synchronize()
         logging.debug(f"Generator takes {t.elapse}s.")
 
         ## Multi-resolution stft loss
-        synchronize()
         with timer() as t:
             sc_loss, mag_loss = self.criterion_stft(
                 wav_.squeeze(1), wav.squeeze(1))
-        synchronize()
         logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")
 
         report("train/spectral_convergence_loss", float(sc_loss))
@@ -94,11 +91,9 @@ class PWGUpdater(StandardUpdater):
 
         ## Adversarial loss
         if self.state.iteration > self.discriminator_train_start_steps:
-            synchronize()
             with timer() as t:
                 p_ = self.discriminator(wav_)
                 adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
-            synchronize()
             logging.debug(
                 f"Discriminator and adversarial loss takes {t.elapse}s")
             report("train/adversarial_loss", float(adv_loss))
@@ -106,18 +101,14 @@ class PWGUpdater(StandardUpdater):
 
         report("train/generator_loss", float(gen_loss))
 
-        synchronize()
         with timer() as t:
             self.optimizer_g.clear_grad()
             gen_loss.backward()
-        synchronize()
         logging.debug(f"Backward takes {t.elapse}s.")
 
-        synchronize()
         with timer() as t:
             self.optimizer_g.step()
             self.scheduler_g.step()
-        synchronize()
         logging.debug(f"Update takes {t.elapse}s.")
 
         # Disctiminator
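The dropped `synchronize()` calls existed so the timer measured completed GPU work rather than just asynchronous kernel launches; removing them simplifies the hot path at the cost of coarser timings. A sketch of the original pattern, assuming the helper wrapped something like `paddle.device.cuda.synchronize()` (its definition is not shown in this diff):

```python
import paddle
from paddle import nn
from timer import timer


def synchronize():
    # Block until queued GPU kernels finish so the timing below is meaningful.
    # Assumed implementation; recent Paddle exposes paddle.device.cuda.synchronize().
    if paddle.is_compiled_with_cuda():
        paddle.device.cuda.synchronize()


layer = nn.Linear(80, 80)
x = paddle.randn([4, 80])

synchronize()
with timer() as t:
    y = layer(x)
synchronize()
print(f"forward takes {t.elapse}s")
```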
@@ -158,18 +149,15 @@ class PWGEvaluator(StandardEvaluator):
         wav, mel = batch
         noise = paddle.randn(wav.shape)
 
-        synchronize()
         with timer() as t:
             wav_ = self.generator(noise, mel)
-        synchronize()
         logging.debug(f"Generator takes {t.elapse}s")
 
         ## Multi-resolution stft loss
-        synchronize()
         with timer() as t:
             sc_loss, mag_loss = self.criterion_stft(
                 wav_.squeeze(1), wav.squeeze(1))
-        synchronize()
         logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
 
         report("eval/spectral_convergence_loss", float(sc_loss))
@@ -177,11 +165,9 @@ class PWGEvaluator(StandardEvaluator):
         gen_loss = sc_loss + mag_loss
 
         ## Adversarial loss
-        synchronize()
         with timer() as t:
             p_ = self.discriminator(wav_)
             adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
-        synchronize()
         logging.debug(
             f"Discriminator and adversarial loss takes {t.elapse}s")
         report("eval/adversarial_loss", float(adv_loss))
@@ -14,6 +14,7 @@
 
 import os
 import sys
+from timer import timer
 import logging
 import argparse
 from pathlib import Path
@@ -25,6 +26,8 @@ import numpy as np
 import soundfile as sf
 from paddle import distributed as dist
 
+paddle.set_device("cpu")
+
 from parakeet.datasets.data_table import DataTable
 from parakeet.models.parallel_wavegan import PWGGenerator
 
@@ -71,11 +74,20 @@ test_dataset = DataTable(
 output_dir = Path(args.output_dir)
 output_dir.mkdir(parents=True, exist_ok=True)
 
+N = 0
+T = 0
 for example in test_dataset:
     utt_id = example['utt_id']
     mel = example['feats']
     mel = paddle.to_tensor(mel) # (T, C)
-    wav = generator.inference(c=mel)
-    wav = wav.numpy()
-    print(f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}")
-    sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=24000)
+    with timer() as t:
+        wav = generator.inference(c=mel)
+    wav = wav.numpy()
+    N += wav.size
+    T += t.elapse
+    speed = wav.size / t.elapse
+    print(
+        f"{utt_id}, mel: {mel.shape}, wave: {wav.shape}, time: {t.elapse}s, Hz: {speed}, RTF: {config.sr / speed}."
+    )
+    sf.write(output_dir / (utt_id + ".wav"), wav, samplerate=config.sr)
+print(f"generation speed: {N / T}Hz, RTF: {config.sr / (N / T) }")
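The printed numbers relate as follows: `speed` is samples generated per wall-clock second, and RTF (real-time factor) is synthesis time divided by audio duration, i.e. `config.sr / speed`. A worked example with made-up timings:

```python
sr = 24000          # sample rate from the config
elapse = 4.0        # hypothetical seconds spent in generator.inference
n_samples = 48000   # samples produced, i.e. 2 seconds of audio

speed = n_samples / elapse   # 12000 Hz of audio per second of compute
rtf = sr / speed             # 2.0 -> generation runs 2x slower than real time
# equivalently: rtf == elapse / (n_samples / sr)
```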
@@ -60,7 +60,7 @@ def train_sp(args, config):
     paddle.distributed.init_parallel_env()
 
     # set the random seed, it is a must for multiprocess training
-    seed_everything(42)
+    seed_everything(config.seed)
 
     print(
         f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
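`seed_everything` is defined elsewhere in this training script; a minimal sketch of what such a helper typically does (assumed, not shown in the diff), so every rank starts from the same initial parameters:

```python
import random

import numpy as np
import paddle


def seed_everything(seed: int):
    """Seed the Python, NumPy and Paddle RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)
```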
@@ -149,6 +149,8 @@ def train_sp(args, config):
     output_dir = Path(args.output_dir)
     checkpoint_dir = output_dir / "checkpoints"
     if dist.get_rank() == 0:
+        with open(output_dir / "config.yaml", 'wt') as f:
+            f.write(config.dump(default_flow_style=None))
         output_dir.mkdir(parents=True, exist_ok=True)
         checkpoint_dir.mkdir(parents=True, exist_ok=True)
 
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__version__ = "0.2.0-beta.0"
+__version__ = "0.0.0"
 
 import logging
 from parakeet import audio, data, datasets, frontend, models, modules, training, utils
-
-logging.getLogger('parakeet').addHandler(logging.NullHandler())
@@ -43,9 +43,10 @@ class SpectralConvergenceLoss(nn.Layer):
 class LogSTFTMagnitudeLoss(nn.Layer):
     """Log STFT magnitude loss module."""
 
-    def __init__(self):
+    def __init__(self, epsilon=1e-10):
         """Initilize los STFT magnitude loss module."""
         super().__init__()
+        self.epsilon = epsilon
 
     def forward(self, x_mag, y_mag):
         """Calculate forward propagation.
@@ -57,9 +58,9 @@ class LogSTFTMagnitudeLoss(nn.Layer):
         """
         return F.l1_loss(
             paddle.log(paddle.clip(
-                y_mag, min=1e-10)),
+                y_mag, min=self.epsilon)),
             paddle.log(paddle.clip(
-                x_mag, min=1e-10)))
+                x_mag, min=self.epsilon)))
 
 
 class STFTLoss(nn.Layer):
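The change only threads the clipping floor through as a constructor argument; the loss itself is unchanged: an L1 distance between log magnitudes, with a floor to avoid log(0). A standalone sketch of the same computation (tensor shapes are illustrative):

```python
import paddle
import paddle.nn.functional as F


def log_stft_magnitude_loss(x_mag, y_mag, epsilon=1e-10):
    # mean(|log y_mag - log x_mag|), clipping both inputs at epsilon
    return F.l1_loss(
        paddle.log(paddle.clip(y_mag, min=epsilon)),
        paddle.log(paddle.clip(x_mag, min=epsilon)))


x_mag = paddle.rand([2, 100, 513])   # predicted magnitudes (B, frames, bins)
y_mag = paddle.rand([2, 100, 513])   # target magnitudes
print(float(log_stft_magnitude_loss(x_mag, y_mag)))
```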
@@ -106,4 +106,5 @@ class Snapshot(extension.Extension):
         record_path = self.checkpoint_dir / "records.jsonl"
         with jsonlines.open(record_path, 'w') as writer:
             for record in self.records:
-                writer.write(record)
+                # jsonlines.open may return a Writer or a Reader
+                writer.write(record)  # pylint: disable=no-member
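For context, `records.jsonl` holds one JSON object per line. A small sketch of writing and reading such a file with the `jsonlines` package (the field names are illustrative, not taken from the Snapshot code):

```python
import jsonlines

records = [
    {"path": "snapshot_iter_100.pdz", "iteration": 100},
    {"path": "snapshot_iter_200.pdz", "iteration": 200},
]

with jsonlines.open("records.jsonl", "w") as writer:
    for record in records:
        writer.write(record)

with jsonlines.open("records.jsonl") as reader:
    for record in reader:
        print(record["path"], record["iteration"])
```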
@@ -21,7 +21,7 @@ from typing import Callable, Union, List
 
 import tqdm
 
-from parakeet.training.trigger import get_trigger, IntervalTrigger
+from parakeet.training.trigger import get_trigger, IntervalTrigger, LimitTrigger
 from parakeet.training.updater import UpdaterBase
 from parakeet.training.reporter import scope
 from parakeet.training.extension import Extension, PRIORITY_READER
@@ -42,7 +42,7 @@ class Trainer(object):
                  extensions: List[Extension]=None):
         self.updater = updater
         self.extensions = OrderedDict()
-        self.stop_trigger = get_trigger(stop_trigger)
+        self.stop_trigger = LimitTrigger(*stop_trigger)
         self.out = Path(out)
         self.observation =...
 
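With this change `stop_trigger` is expected to be a `(limit, unit)` tuple that the trainer unpacks into a `LimitTrigger`, rather than an arbitrary trigger passed through `get_trigger`. A hedged usage sketch (the `updater` and `output_dir` objects are assumed to be built as elsewhere in the training scripts, and `trainer.run()` refers to the loop shown in the next hunk):

```python
trainer = Trainer(
    updater,                             # an UpdaterBase instance built earlier
    stop_trigger=(400000, "iteration"),  # unpacked into LimitTrigger(400000, "iteration")
    out=output_dir,
    extensions=None)
trainer.run()   # loops until the trigger fires, i.e. iteration >= 400000
```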
@@ -125,16 +125,19 @@ class Trainer(object):
 
         print(self.updater.state)
 
-        # TODO(chenfeiyu): display progress bar correctly
-        # if the trainer is controlled by epoch: use 2 progressbars
-        # if the trainer is controlled by iteration: use 1 progressbar
-        if isinstance(stop_trigger, IntervalTrigger):
+        # display only one progress bar
+        max_iteration = None
+        if isinstance(stop_trigger, LimitTrigger):
             if stop_trigger.unit is 'epoch':
-                max_epoch = self.stop_trigger.period
+                max_epoch = self.stop_trigger.limit
+                updates_per_epoch = getattr(self.updater, "updates_per_epoch",
+                                            None)
+                max_iteration = max_epoch * updates_per_epoch if updates_per_epoch else None
             else:
-                max_iteration = self.stop_trigger.period
+                max_iteration = self.stop_trigger.limit
 
-        p = tqdm.tqdm(initial=self.updater.state.iteration)
+        p = tqdm.tqdm(
+            initial=self.updater.state.iteration, total=max_iteration)
 
         try:
             while not stop_trigger(self):
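A quick worked example of the new `total` computation (numbers made up):

```python
# stop_trigger = (100, "epoch") and a sized dataloader with 500 batches:
max_epoch = 100
updates_per_epoch = 500            # len(dataloader) when __len__ is defined
max_iteration = max_epoch * updates_per_epoch if updates_per_epoch else None
print(max_iteration)               # 50000, used as tqdm's total
# with an unsized dataloader updates_per_epoch is None, so total stays None
```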
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 from parakeet.training.triggers.interval_trigger import IntervalTrigger
+from parakeet.training.triggers.limit_trigger import LimitTrigger
+from parakeet.training.triggers.time_trigger import TimeTrigger
 
 
 def never_file_trigger(trainer):
@@ -19,17 +19,13 @@ class IntervalTrigger(object):
     def __init__(self, period: int, unit: str):
         if unit not in ("iteration", "epoch"):
             raise ValueError("unit should be 'iteration' or 'epoch'")
+        if period <= 0:
+            raise ValueError("period should be a positive integer.")
         self.period = period
         self.unit = unit
 
     def __call__(self, trainer):
         state = trainer.updater.state
-        # we use a special scheme so we can use iteration % period == 0 as
-        # the predicate
-        # increase the iteration then update parameters
-        # instead of updating then increase iteration
-        if self.unit == "epoch":
-            fire = state.epoch % self.period == 0
-        else:
-            fire = state.iteration % self.period == 0
+        index = getattr(state, self.unit)
+        fire = index % self.period == 0
         return fire
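A small usage sketch of the simplified predicate: `IntervalTrigger(1000, "iteration")` fires whenever the current iteration index is a multiple of 1000. The trainer stub below is just enough structure to call it:

```python
from types import SimpleNamespace

trigger = IntervalTrigger(period=1000, unit="iteration")
fake_trainer = SimpleNamespace(
    updater=SimpleNamespace(state=SimpleNamespace(iteration=2000, epoch=4)))
print(trigger(fake_trainer))   # True: 2000 % 1000 == 0
fake_trainer.updater.state.iteration = 2001
print(trigger(fake_trainer))   # False
```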
@@ -0,0 +1,31 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class LimitTrigger(object):
+    """A Predicate to decide whether to stop."""
+
+    def __init__(self, limit: int, unit: str):
+        if unit not in ("iteration", "epoch"):
+            raise ValueError("unit should be 'iteration' or 'epoch'")
+        if limit <= 0:
+            raise ValueError("limit should be a positive integer.")
+        self.limit = limit
+        self.unit = unit
+
+    def __call__(self, trainer):
+        state = trainer.updater.state
+        index = getattr(state, self.unit)
+        fire = index >= self.limit
+        return fire
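Unlike `IntervalTrigger`, which fires periodically, `LimitTrigger` fires once the index reaches the limit and keeps firing afterwards, which is exactly what a stop condition needs. An illustrative check with a stub trainer:

```python
from types import SimpleNamespace

stop = LimitTrigger(limit=10000, unit="iteration")
fake_trainer = SimpleNamespace(
    updater=SimpleNamespace(state=SimpleNamespace(iteration=9999, epoch=3)))
print(stop(fake_trainer))   # False: keep training
fake_trainer.updater.state.iteration = 10000
print(stop(fake_trainer))   # True: the run loop exits
```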
@@ -62,7 +62,40 @@ class StandardUpdater(UpdaterBase):
         self.train_iterator = iter(dataloader)
 
     def update(self):
-        self.state.iteration += 1
+        # We increase the iteration index after updating and before extension.
+        # Here are the reasons.
+
+        # 0. Snapshotting (as well as other extensions, like visualizer) is
+        # executed after a step of updating;
+        # 1. We decide to increase the iteration index after updating and
+        # before any extension is executed.
+        # 3. We do not increase the iteration after extension because we
+        # prefer a consistent resume behavior: when loading from
+        # `snapshot_iter_100.pdz`, the next step to train is `101`,
+        # naturally. But if the iteration were increased after
+        # extension (including snapshot), then `snapshot_iter_99` would be
+        # loaded. You would need an extra increase of the iteration index
+        # before training to avoid another iteration `99`, which has been
+        # done before snapshotting.
+        # 4. Thus the iteration index represents "how many iterations have
+        # been done so far".
+        # NOTE: use report to capture the correct value. If you want to
+        # report the learning rate used for a step, you must report it before
+        # the learning rate scheduler's step() has been called. In paddle's
+        # convention, we do not use an extension to change the learning rate,
+        # so if you want to report it, do it in the updater.
+
+        # Then here comes the next question. When is the proper time to
+        # increase the epoch index? Since all extensions are executed after
+        # updating, after updating is also the proper time to
+        # increase the epoch index.
+        # 1. If we increase the epoch index before updating, then an extension
+        # based on epoch would miss the correct timing. It could only be
+        # triggered after an extra update.
+        # 2. Theoretically, when an epoch is done, the epoch index should be
+        # increased, so it is increased after updating.
+        # 3. Thus, the epoch index represents "how many epochs have been
+        # done so far". So it starts from 0.
 
         # switch to training mode
         for layer in self.models.values():
@@ -72,6 +105,11 @@ class StandardUpdater(UpdaterBase):
         batch = self.read_batch()
         self.update_core(batch)
 
+        self.state.iteration += 1
+        if self.updaters_per_epoch is not None:
+            if self.state.iteration % self.updaters_per_epoch == 0:
+                self.state.epoch += 1
+
     def update_core(self, batch):
         """A simple case for a training step. Basic assumptions are:
         Single model;
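A short worked trace of the new bookkeeping (numbers hypothetical, 500 updates per epoch): after the 500th call to `update()` the iteration counter reads 500 and the epoch counter ticks from 0 to 1, so a snapshot taken at that point is named for iteration 500 and resuming from it continues with iteration 501.

```python
updaters_per_epoch = 500   # hypothetical len(dataloader)
epoch = 0
for iteration in range(499, 502):
    if iteration % updaters_per_epoch == 0:
        epoch += 1
    print(iteration, epoch)   # 499 0, 500 1, 501 1
```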
@@ -100,10 +138,20 @@ class StandardUpdater(UpdaterBase):
             loss_dict["main"].backward()
             self.optimizer.update()
 
+    @property
+    def updaters_per_epoch(self):
+        """Number of updates per epoch, determined by the length of the
+        dataloader."""
+        length_of_dataloader = None
+        try:
+            length_of_dataloader = len(self.dataloader)
+        except TypeError:
+            logging.debug("This dataloader has no __len__.")
+        finally:
+            return length_of_dataloader
+
     def new_epoch(self):
         """Start a new epoch."""
-        self.state.epoch += 1
-
         # NOTE: all batch sampler for distributed training should
         # subclass DistributedBatchSampler and implement `set_epoch` method
         batch_sampler = self.dataloader.batch_sampler
@@ -140,13 +188,3 @@ class StandardUpdater(UpdaterBase):
         for name, optim in self.optimizers.items():
             optim.set_state_dict(state_dict[f"{name}_optimizer"])
         super().set_state_dict(state_dict)
-
-    def save(self, path):
-        """Save Updater state dict."""
-        archive = self.state_dict()
-        paddle.save(archive, path)
-
-    def load(self, path):
-        """Load Updater state dict."""
-        archive = paddle.load(path)
-        self.set_state_dict(archive)
@@ -0,0 +1,39 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+from pathlib import Path
+
+import paddle
+from paddle import nn
+from paddle.optimizer import Adam
+from paddle.optimizer.lr import StepDecay
+
+
+def test_optimizer():
+    model1 = nn.Linear(3, 4)
+    optim1 = Adam(
+        parameters=model1.parameters(), learning_rate=StepDecay(0.1, 100))
+
+    output_dir = Path("temp_test_optimizer")
+    shutil.rmtree(output_dir, ignore_errors=True)
+    output_dir.mkdir(exist_ok=True, parents=True)
+
+    # model1.set_state_dict(model1.state_dict())
+    optim1.set_state_dict(optim1.state_dict())
+
+    x = paddle.randn([6, 3])
+    y = model1(x).sum()
+    y.backward()
+    optim1.step()