add better training utility code

This commit is contained in:
iclementine 2021-06-27 08:32:37 +08:00
parent f9105db727
commit 61c13dd69b
14 changed files with 452 additions and 408 deletions

View File

@@ -14,4 +14,7 @@
__version__ = "0.2.0-beta.0"
import logging
from parakeet import audio, data, datasets, frontend, models, modules, training, utils
logging.getLogger('parakeet').addHandler(logging.NullHandler())

View File

@@ -1,163 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Mapping, List, Union
import os
from pathlib import Path
class KBest(object):
"""
A utility class to help save disk space by keeping only the K best
checkpoints.
To be as modular as possible, this class does not assume anything about a
Trainer class or a checkpoint directory, and it knows nothing about the
model or the optimizer.
It is basically a dynamically maintained K-best mapping. When a new item is
added to the mapping, save_fn is called, and when an item is removed from
it, del_fn is called. `save_fn` and `del_fn` take a Path object as input
and return nothing.
Though it is designed to control checkpointing behavior, it can be used
for other purposes by passing suitable save_fn and del_fn.
Example
--------
>>> from pathlib import Path
>>> import shutil
>>> import numpy as np
>>> import paddle
>>> from paddle import nn
>>> model = nn.Linear(2, 3)
>>> def save_model(path):
... paddle.save(model.state_dict(), path)
>>> kbest_manager = KBest(max_size=5, save_fn=save_model)
>>> checkpoint_dir = Path("checkpoints")
>>> shutil.rmtree(checkpoint_dir)
>>> checkpoint_dir.mkdir(parents=True)
>>> a = np.random.rand(20)
>>> for i, score in enumerate(a):
... path = checkpoint_dir / f"step_{i}"
... kbest_manager.add_checkpoint(score, path)
>>> assert len(list(checkpoint_dir.glob("step_*"))) == 5
"""
def __init__(self,
max_size: int=5,
save_fn: Callable[[Union[Path, str]], None]=None,
del_fn: Callable[[Union[Path, str]], None]=os.remove):
self.best_records: Mapping[Path, float] = {}
self.save_fn = save_fn
self.del_fn = del_fn
self.max_size = max_size
self._save_all = (max_size == -1)
def should_save(self, metric: float) -> bool:
if not self.full():
return True
# already full
worst_record_path = max(self.best_records, key=self.best_records.get)
worst_metric = self.best_records[worst_record_path]
return metric < worst_metric
def full(self):
return (not self._save_all) and len(self.best_records) == self.max_size
def add_checkpoint(self, metric, path):
if self.should_save(metric):
self.save_checkpoint_and_update(metric, path)
def save_checkpoint_and_update(self, metric, path):
# remove the worst
if self.full():
worst_record_path = max(self.best_records,
key=self.best_records.get)
self.best_records.pop(worst_record_path)
self.del_fn(worst_record_path)
# add the new one
self.save_fn(path)
self.best_records[path] = metric
class KLatest(object):
"""
A utility class to help save disk space by keeping only the K latest
checkpoints.
To be as modular as possible, this class does not assume anything about a
Trainer class or a checkpoint directory, and it knows nothing about the
model or the optimizer.
It is basically a dynamically maintained queue. When a new item is
added to the queue, save_fn is called, and when an item is removed from the
queue, del_fn is called. `save_fn` and `del_fn` take a Path object as input
and return nothing.
Though it is designed to control checkpointing behavior, it can be used
for other purposes by passing suitable save_fn and del_fn.
Example
--------
>>> from pathlib import Path
>>> import shutil
>>> import paddle
>>> from paddle import nn
>>> model = nn.Linear(2, 3)
>>> def save_model(path):
... paddle.save(model.state_dict(), path)
>>> klatest_manager = KLatest(max_size=5, save_fn=save_model)
>>> checkpoint_dir = Path("checkpoints")
>>> shutil.rmtree(checkpoint_dir)
>>> checkpoint_dir.mkdir(parents=True)
>>> for i in range(20):
... path = checkpoint_dir / f"step_{i}"
... klatest_manager.add_checkpoint(path)
>>> assert len(list(checkpoint_dir.glob("step_*"))) == 5
"""
def __init__(self,
max_size: int=5,
save_fn: Callable[[Path], None]=None,
del_fn: Callable[[Path], None]=lambda f: f.unlink()):
self.latest_records: List[Path] = []
self.save_fn = save_fn
self.del_fn = del_fn
self.max_size = max_size
self._save_all = (max_size == -1)
def full(self):
return (
not self._save_all) and len(self.latest_records) == self.max_size
def add_checkpoint(self, path):
self.save_checkpoint_and_update(path)
def save_checkpoint_and_update(self, path):
# remove the earliest
if self.full():
eariest_record_path = self.latest_records.pop(0)
self.del_fn(eariest_record_path)
# add the new one
self.save_fn(path)
self.latest_records.append(path)

View File

@@ -0,0 +1,80 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable
PRIORITY_WRITER = 300
PRIORITY_EDITOR = 200
PRIORITY_READER = 100
class Extension(object):
"""Extension to customize the behavior of Trainer."""
trigger = (1, 'iteration')
priority = PRIORITY_READER
name = None
@property
def default_name(self):
"""Default name of the extension, class name by default."""
return type(self).__name__
def __call__(self, trainer):
"""Main action of the extention. After each update, it is executed
when the trigger fires."""
raise NotImplementedError(
'Extension implementation must override __call__.')
def initialize(self, trainer):
"""Action that is executed once to get the corect trainer state.
It is called before training normally, but if the trainer restores
states with an Snapshot extension, this method should also be called.g
"""
pass
def on_error(self, trainer, exc, tb):
"""Handles the error raised during training before finalization.
"""
pass
def finalize(self, trainer):
"""Action that is executed when training is done.
For example, visualizers would need to be closed.
"""
pass
def make_extension(trigger: Callable=None,
default_name: str=None,
priority: int=None,
finalizer: Callable=None,
initializer: Callable=None,
on_error: Callable=None):
"""Make an Extension-like object by injecting required attributes to it.
"""
if trigger is None:
trigger = Extension.trigger
if priority is None:
priority = Extension.priority
def decorator(ext):
ext.trigger = trigger
ext.default_name = default_name or ext.__name__
ext.priority = priority
ext.finalize = finalizer
ext.on_error = on_error
ext.initialize = initializer
return ext
return decorator
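As a usage note (not part of this commit), the decorator above lets a plain function stand in for an Extension. A minimal sketch; the attributes read from `trainer` follow the other files in this diff:

from parakeet.training.extension import make_extension, PRIORITY_READER

@make_extension(trigger=(100, 'iteration'), priority=PRIORITY_READER)
def print_progress(trainer):
    # runs every 100 iterations once registered via trainer.extend(print_progress)
    state = trainer.updater.state
    print(f"iteration: {state.iteration}, epoch: {state.epoch}")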

View File

@@ -22,9 +22,17 @@ from paddle.nn import Layer
from paddle.io import DataLoader
from parakeet.training.reporter import scope, report, DictSummary
from parakeet.training import extension
class StandardEvaluator(object):
class StandardEvaluator(extension.Extension):
trigger = (1, 'epoch')
default_name = 'validation'
priority = extension.PRIORITY_WRITER
name = None
def __init__(self, model: Layer, dataloader: DataLoader):
# it is designed to hold multiple models
models = {"main": model}
@@ -43,21 +51,23 @@ class StandardEvaluator(object):
for layer in self.models.values():
layer.eval()
# to average evaluation metrics
summary = DictSummary()
for batch in self.dataloader:
observation = {}
with scope(observation):
# main evaluation computation here.
with paddle.no_grad():
self.evaluate_core(
batch) # main evaluation computation here.
self.evaluate_core(batch)
summary.add(observation)
summary = summary.compute_mean()
return summary
def __call__(self, trainer=None):
self.observation = {}
with scope(self.observation):
summary = self.evaluate()
for k, v in summary.items():
report(k, v)
print(self.observation)
# evaluate and report the averaged metrics to the current observation
# if it is used to extend a trainer, the metrics are reported to
# the observation of the trainer
# otherwise, you can use your own observation
summary = self.evaluate()
for k, v in summary.items():
report(k, v)
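For context, a concrete evaluator only needs to provide `evaluate_core` and report its metrics; `evaluate` above handles the scoping and averaging. A hedged sketch, where the model, the batch layout and the metric name are illustrative assumptions:

import paddle.nn.functional as F
from parakeet.training.reporter import report

class MSEEvaluator(StandardEvaluator):
    def evaluate_core(self, batch):
        # assumes the dataloader yields (input, target) pairs
        x, y = batch
        loss = F.mse_loss(self.models["main"](x), y)
        report("eval/mse", float(loss))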

View File

@@ -12,18 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, List, Dict, Any
from pathlib import Path
import jsonlines
import os
from pathlib import Path
from datetime import datetime
import logging
from typing import List, Dict, Any
import jsonlines
from parakeet.utils.mp_tools import rank_zero_only
from parakeet.training.trainer import Trainer
from parakeet.training import extension
class Snapshot(object):
def load_records(records_fp):
"""Load record files (json lines.)"""
with jsonlines.open(records_fp, 'r') as reader:
records = list(reader)
return records
class Snapshot(extension.Extension):
"""An extension to make snapshot of the updater object inside
the trainer. It is done by calling the updater's `save` method.
@@ -38,34 +46,46 @@ class Snapshot(object):
The directory to save checkpoints into.
"""
def __init__(self, max_size: int=5):
trigger = (1, 'epoch')
priority = -100
def __init__(self, max_size: int=5, snapshot_on_error: bool=False):
self.records: List[Dict[str, Any]] = []
self.max_size = max_size
self._snapshot_on_error = snapshot_on_error
self._save_all = (max_size == -1)
self.save_fn =...
self.del_fn =...
self.del_fn = os.remove
self.checkpoint_dir =...
def initialize(self, trainer):
"""setting up this extention."""
def initialize(self, trainer: Trainer):
"""Setting up this extention."""
self.save_fn = trainer.updater.save
self.del_fn = os.remove
self.checkpoint_dir = trainer.out / "checkpoints"
# load existing records
record_path: Path = self.checkpoint_dir / "records.jsonl"
if record_path.exists():
self.records = load_records(record_path)
def on_error(self, trainer, exc, tb):
if self._snapshot_on_error:
self.save_checkpoint_and_update(trainer)
def __call__(self, trainer: Trainer):
self.save_checkpoint_and_update(trainer)
def full(self):
return (not self._save_all) and len(self.records) >= self.max_size
"""Whether the number of snapshots it keeps track of is greater
than the max_size."""
return (not self._save_all) and len(self.records) > self.max_size
@rank_zero_only
def save_checkpoint_and_update(self, trainer):
def save_checkpoint_and_update(self, trainer: Trainer):
"""Saving new snapshot and remove the oldest snapshot if needed."""
iteration = trainer.updater.state.iteration
path = self.checkpoint_dir / f"snapshot_iter_{iteration}.pdz"
# remove the earliest
if self.full():
eariest_record = self.records[0]
self.del_fn(eariest_record["path"])
self.records.pop(0)
# add the new one
self.save_fn(path)
record = {
@@ -75,10 +95,14 @@ class Snapshot(object):
}
self.records.append(record)
# update the record
with jsonlines.open(self.checkpoint_dir / "records.jsonl", 'w') as f:
for record in self.records:
f.write(record)
# remove the earliest
if self.full():
eariest_record = self.records[0]
self.del_fn(eariest_record["path"])
self.records.pop(0)
def __call__(self, trainer):
self.save_checkpoint_and_update(trainer)
# update the record file
record_path = self.checkpoint_dir / "records.jsonl"
with jsonlines.open(record_path, 'w') as writer:
for record in self.records:
writer.write(record)
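To tie this together, a hedged sketch of registering the extension and inspecting what it keeps on disk afterwards; the trainer construction is omitted, and only the `path` field of each record is relied on since the other record fields are not visible in this hunk:

trainer.extend(Snapshot(max_size=5, snapshot_on_error=True), trigger=(1, 'epoch'))
trainer.run()

# list the snapshots that are still kept (at most max_size of them)
for record in load_records(trainer.out / "checkpoints" / "records.jsonl"):
    print(record["path"])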

View File

@@ -13,22 +13,31 @@
# limitations under the License.
from visualdl import LogWriter
from parakeet.training.trainer import Trainer
from parakeet.utils.mp_tools import rank_zero_only
from parakeet.training import extension
class VisualDL(object):
class VisualDL(extension.Extension):
"""A wrapper of visualdl log writer. It assumes that the metrics to be visualized
are all scalars which are recorded into the `.observation` dictionary of the
trainer object. The dictionary is created for each step, thus the visualdl log
are all scalars which are recorded into the `.observation` dictionary of the
trainer object. The dictionary is created for each step, thus the visualdl log
writer uses the iteration from the updater's `iteration` as the global step to
add records.
"""
trigger = (1, 'iteration')
default_name = 'visualdl'
priority = extension.PRIORITY_READER
def __init__(self, writer: LogWriter):
self.writer = writer
def __init__(self):
self.writer =...
def initialize(self, trainer):
self.writer = LogWriter(logdir=str(trainer.out))
@rank_zero_only
def __call__(self, trainer: Trainer):
for k, v in trainer.observation.items():
self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)
def finalize(self, trainer):
self.writer.close()
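A brief usage sketch (assuming a trainer object already exists): because the writer is created in `initialize` from `trainer.out`, registering the extension is all that is needed:

trainer.extend(VisualDL(), trigger=(1, 'iteration'))
trainer.run()
# scalars reported into trainer.observation are written under trainer.out
# and can be browsed with `visualdl --logdir <trainer.out>`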

View File

@@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import logging
import paddle
import random
import numpy as np
def seed_everything(seed: int):
"""Seed paddle, random and np.random to help reproductivity."""
paddle.seed(seed)
random.seed(seed)
np.random.seed(seed)
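Typical usage is a single call at the top of a training script, before the model, optimizer and dataloader are built, so that parameter initialization and shuffling are reproducible. The import path below is an assumption, since this hunk does not show the file name:

from parakeet.training.seeding import seed_everything  # module path assumed

seed_everything(42)  # seed paddle, random and np.random before building anything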

View File

@@ -13,15 +13,18 @@
# limitations under the License.
from pathlib import Path
from collections import OrderedDict
from typing import Callable, Union, List
import tqdm
from dataclasses import dataclass
from parakeet.training.trigger import get_trigger, IntervalTrigger
from parakeet.training.updater import UpdaterBase
from parakeet.training.reporter import scope
from parakeet.training.extension import Extension, PRIORITY_READER
class ExtensionEntry(object):
class _ExtensionEntry(object):
def __init__(self, extension, trigger, priority):
self.extension = extension
self.trigger = trigger
@@ -31,29 +34,76 @@ class ExtensionEntry(object):
class Trainer(object):
def __init__(self,
updater: UpdaterBase,
stop_trigger=None,
out='result',
extensions=None):
stop_trigger: Callable=None,
out: Union[str, Path]='result',
extensions: List[Extension]=None):
self.updater = updater
self.extensions = {}
self.extensions = OrderedDict()
self.stop_trigger = get_trigger(stop_trigger)
self.out = Path(out)
self.observation = {}
self.observation =...
self._done = False
if extensions:
for ext in extensions:
self.extend(ext)
@property
def is_before_training(self):
return self.updater.state.iteration == 0
def extend(self, extension, name=None, trigger=None, priority=None):
# get a name for the extension, resolved in this order:
# the `name` argument
# -> the extension's `name` attribute
# -> `default_name` (the class name, when it is an object)
# -> `__name__` when it is a function
# -> otherwise an error is raised
if name is None:
name = getattr(extension, 'name', None)
if name is None:
name = getattr(extension, 'default_name', None)
if name is None:
name = getattr(extension, '__name__', None)
if name is None:
raise ValueError(
"Name is not given for the extension.")
if name == 'training':
raise ValueError("training is a reserved name.")
if trigger is None:
trigger = getattr(extension, 'trigger', (1, 'iteration'))
trigger = get_trigger(trigger)
if priority is None:
priority = getattr(extension, 'priority', PRIORITY_READER)
# add a suffix to avoid naming conflicts
ordinal = 0
modified_name = name
while modified_name in self.extensions:
print(self.extensions.keys())
ordinal += 1
modified_name = f"{name}_{ordinal}"
extension.name = modified_name
self.extensions[modified_name] = ExtensionEntry(extension, trigger,
priority)
self.extensions[modified_name] = _ExtensionEntry(extension, trigger,
priority)
def get_extension(self, name):
"""get extension by name."""
extensions = self.extensions
if name in extensions:
return extensions[name].extension
else:
raise ValueError(f'extension {name} not found')
def run(self):
if self._done:
raise RuntimeError("Training is already done!.")
self.out.mkdir(parents=True, exist_ok=True)
# sort extensions by priorities once
extension_order = sorted(
self.extensions.keys(),
@@ -67,7 +117,7 @@ class Trainer(object):
if hasattr(entry.extension, "initialize"):
entry.extension.initialize(self)
update = self.updater.update
update = self.updater.update # training step
stop_trigger = self.stop_trigger
# TODO(chenfeiyu): display progress bar correctly
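Putting the pieces together, a hedged end-to-end sketch of how a training script might use these utilities; model, optimizer and dataloader construction are omitted, and the StandardUpdater import path is an assumption since the new file's name is not shown:

from parakeet.training.trainer import Trainer
from parakeet.training.updaters.standard_updater import StandardUpdater  # path assumed

updater = StandardUpdater(model, optimizer, train_loader)
trainer = Trainer(
    updater,
    stop_trigger=(10000, 'iteration'),  # stop after 10k iterations
    out='exp/default')
trainer.extend(Snapshot(max_size=5), trigger=(1, 'epoch'))
trainer.extend(VisualDL(), trigger=(1, 'iteration'))
trainer.run()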

View File

@@ -12,21 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
class IntervalTrigger(object):
def __init__(self, period: int, unit: str):
if unit not in ("iteration", "epoch"):
raise ValueError("unit should be 'iteration' or 'epoch'")
self.period = period
self.unit = unit
def __call__(self, trainer):
state = trainer.updater.state
if self.unit == "epoch":
fire = state.epoch % self.period == 0
else:
fire = state.iteration % self.period == 0
return fire
from parakeet.training.triggers.interval_trigger import IntervalTrigger
def never_file_trigger(trainer):

View File

@@ -0,0 +1,35 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class IntervalTrigger(object):
"""A Predicate to do something every N cycle."""
def __init__(self, period: int, unit: str):
if unit not in ("iteration", "epoch"):
raise ValueError("unit should be 'iteration' or 'epoch'")
self.period = period
self.unit = unit
def __call__(self, trainer):
state = trainer.updater.state
# a special scheme is used here: the iteration counter is increased
# before the parameters are updated (rather than after), so that
# `iteration % period == 0` can be used directly as the predicate
if self.unit == "epoch":
fire = state.epoch % self.period == 0
else:
fire = state.iteration % self.period == 0
return fire
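For example, an extension registered with `trigger=(1000, 'iteration')` presumably ends up guarded by an instance equivalent to the one below (illustrative only):

trigger = IntervalTrigger(1000, 'iteration')
# extensions guarded by this trigger run only when trigger(trainer) is True,
# i.e. every 1000 iterations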

View File

@@ -0,0 +1,35 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TimeTrigger(object):
"""Trigger based on a fixed time interval.
This trigger fires based on elapsed training time, once every `period` seconds.
Args:
period (float): Interval time. It is given in seconds.
"""
def __init__(self, period):
self._period = period
self._next_time = self._period
def __call__(self, trainer):
if self._next_time < trainer.elapsed_time:
self._next_time += self._period
return True
else:
return False
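A small sketch of wall-clock-based triggering; it assumes the trainer exposes an `elapsed_time` attribute (referenced above but not defined in this commit) and that `get_trigger` passes callables through unchanged:

# snapshot roughly every 10 minutes of training time instead of every epoch
trainer.extend(Snapshot(max_size=-1), trigger=TimeTrigger(10 * 60))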

View File

@@ -93,125 +93,3 @@ class UpdaterBase(object):
def load(self, path):
archive = paddle.load(path)
self.set_state_dict(archive)
class StandardUpdater(UpdaterBase):
"""An example of over-simplification. Things may not be that simple, but
you can subclass it to fit your need.
"""
def __init__(self,
model: Layer,
optimizer: Optimizer,
dataloader: DataLoader,
init_state: Optional[UpdaterState]=None):
# it is designed to hold multiple models
models = {"main": model}
self.models: Dict[str, Layer] = models
self.model = model
# it is designed to hold multiple optimizers
optimizers = {"main": optimizer}
self.optimizer = optimizer
self.optimizers: Dict[str, Optimizer] = optimizers
# dataloaders
self.dataloader = dataloader
# init state
if init_state is None:
self.state = UpdaterState()
else:
self.state = init_state
self.train_iterator = iter(dataloader)
def update(self):
self.state.iteration += 1
# switch to training mode
for layer in self.models.values():
layer.train()
# training for a step is implemented here
batch = self.read_batch()
self.update_core(batch)
def update_core(self, batch):
"""A simple case for a training step. Basic assumptions are:
Single model;
Single optimizer;
A batch from the dataloader is just the input of the model;
The model returns a single loss, or a dict containing several losses.
Parameters are updated at every batch; no gradient accumulation.
"""
loss = self.model(*batch)
if isinstance(loss, Tensor):
loss_dict = {"main": loss}
else:
# Dict[str, Tensor]
loss_dict = loss
if "main" not in loss_dict:
main_loss = 0
for loss_item in loss.values():
main_loss += loss_item
loss_dict["main"] = main_loss
for name, loss_item in loss_dict.items():
report(name, float(loss_item))
self.optimizer.clear_gradient()
loss_dict["main"].backward()
self.optimizer.update()
def new_epoch(self):
"""Start a new epoch."""
self.state.epoch += 1
# NOTE: all batch sampler for distributed training should
# subclass DistributedBatchSampler and implement `set_epoch` method
batch_sampler = self.dataloader.batch_sampler
if isinstance(batch_sampler, DistributedBatchSampler):
batch_sampler.set_epoch(self.state.epoch)
self.train_iterator = iter(self.dataloader)
def read_batch(self):
"""Read a batch from the data loader, auto renew when data is exhausted."""
with timer() as t:
try:
batch = next(self.train_iterator)
except StopIteration:
self.new_epoch()
batch = next(self.train_iterator)
logging.debug(f"Reading a batch took {t.elapse}s.")
return batch
def state_dict(self):
"""State dict of a Updater, model, optimizer and updater state are included."""
state_dict = super().state_dict()
for name, layer in self.models.items():
state_dict[f"{name}_params"] = layer.state_dict()
for name, optim in self.optimizers.items():
state_dict[f"{name}_optimizer"] = optim.state_dict()
return state_dict
def set_state_dict(self, state_dict):
"""Set state dict for a Updater. Parameters of models, states for
optimizers and UpdaterState are restored."""
for name, layer in self.models.items():
layer.set_state_dict(state_dict[f"{name}_params"])
for name, optim in self.optimizers.items():
optim.set_state_dict(state_dict[f"{name}_optimizer"])
super().set_state_dict(state_dict)
def save(self, path):
"""Save Updater state dict."""
archive = self.state_dict()
paddle.save(archive, path)
def load(self, path):
"""Load Updater state dict."""
archive = paddle.load(path)
self.set_state_dict(archive)

View File

@@ -0,0 +1,152 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass
from typing import Optional
from typing import Dict
from typing import Union
from timer import timer
import paddle
from paddle import Tensor
from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler
from parakeet.training.reporter import report
from parakeet.training.updater import UpdaterBase, UpdaterState
class StandardUpdater(UpdaterBase):
"""An example of over-simplification. Things may not be that simple, but
you can subclass it to fit your need.
"""
def __init__(self,
model: Layer,
optimizer: Optimizer,
dataloader: DataLoader,
init_state: Optional[UpdaterState]=None):
# it is designed to hold multiple models
models = {"main": model}
self.models: Dict[str, Layer] = models
self.model = model
# it is designed to hold multiple optimizers
optimizers = {"main": optimizer}
self.optimizer = optimizer
self.optimizers: Dict[str, Optimizer] = optimizers
# dataloaders
self.dataloader = dataloader
# init state
if init_state is None:
self.state = UpdaterState()
else:
self.state = init_state
self.train_iterator = iter(dataloader)
def update(self):
self.state.iteration += 1
# switch to training mode
for layer in self.models.values():
layer.train()
# training for a step is implemented here
batch = self.read_batch()
self.update_core(batch)
def update_core(self, batch):
"""A simple case for a training step. Basic assumptions are:
Single model;
Single optimizer;
A batch from the dataloader is just the input of the model;
The model returns a single loss, or a dict containing several losses.
Parameters are updated at every batch; no gradient accumulation.
"""
loss = self.model(*batch)
if isinstance(loss, Tensor):
loss_dict = {"main": loss}
else:
# Dict[str, Tensor]
loss_dict = loss
if "main" not in loss_dict:
main_loss = 0
for loss_item in loss.values():
main_loss += loss_item
loss_dict["main"] = main_loss
for name, loss_item in loss_dict.items():
report(name, float(loss_item))
self.optimizer.clear_gradient()
loss_dict["main"].backward()
self.optimizer.update()
def new_epoch(self):
"""Start a new epoch."""
self.state.epoch += 1
# NOTE: all batch sampler for distributed training should
# subclass DistributedBatchSampler and implement `set_epoch` method
batch_sampler = self.dataloader.batch_sampler
if isinstance(batch_sampler, DistributedBatchSampler):
batch_sampler.set_epoch(self.state.epoch)
self.train_iterator = iter(self.dataloader)
def read_batch(self):
"""Read a batch from the data loader, auto renew when data is exhausted."""
with timer() as t:
try:
batch = next(self.train_iterator)
except StopIteration:
self.new_epoch()
batch = next(self.train_iterator)
logging.debug(f"Reading a batch took {t.elapse}s.")
return batch
def state_dict(self):
"""State dict of a Updater, model, optimizer and updater state are included."""
state_dict = super().state_dict()
for name, layer in self.models.items():
state_dict[f"{name}_params"] = layer.state_dict()
for name, optim in self.optimizers.items():
state_dict[f"{name}_optimizer"] = optim.state_dict()
return state_dict
def set_state_dict(self, state_dict):
"""Set state dict for a Updater. Parameters of models, states for
optimizers and UpdaterState are restored."""
for name, layer in self.models.items():
layer.set_state_dict(state_dict[f"{name}_params"])
for name, optim in self.optimizers.items():
optim.set_state_dict(state_dict[f"{name}_optimizer"])
super().set_state_dict(state_dict)
def save(self, path):
"""Save Updater state dict."""
archive = self.state_dict()
paddle.save(archive, path)
def load(self, path):
"""Load Updater state dict."""
archive = paddle.load(path)
self.set_state_dict(archive)
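As the `update_core` docstring above states, the model may return a single Tensor or a dict of losses; when the dict lacks a "main" entry, the entries are summed into one and every entry is reported. A hedged sketch of a toy model that fits this contract (everything below is illustrative, not part of the commit):

import paddle
from paddle import nn
import paddle.nn.functional as F

class ToyModel(nn.Layer):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x, y):
        # batches are unpacked as model(*batch), so the dataloader
        # should yield (x, y) pairs for this model
        out = self.proj(x)
        return {
            "l1": paddle.mean(paddle.abs(out - y)),
            "mse": F.mse_loss(out, y),
        }

StandardUpdater would then report "l1", "mse" and the synthesized "main" loss at every step.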

View File

@@ -1,56 +0,0 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import shutil
import numpy as np
from parakeet.training.checkpoint import KBest, KLatest
def test_kbest():
def save_fn(path):
with open(path, 'wt') as f:
f.write(f"My path is {str(path)}\n")
K = 1
kbest_manager = KBest(max_size=K, save_fn=save_fn)
checkpoint_dir = Path("checkpoints")
if checkpoint_dir.exists():
shutil.rmtree(checkpoint_dir)
checkpoint_dir.mkdir(parents=True)
a = np.random.rand(20)
for i, score in enumerate(a):
path = checkpoint_dir / f"step_{i}"
kbest_manager.add_checkpoint(score, path)
assert len(list(checkpoint_dir.glob("step_*"))) == K
shutil.rmtree(checkpoint_dir)
def test_klatest():
def save_fn(path):
with open(path, 'wt') as f:
f.write(f"My path is {str(path)}\n")
K = 5
klatest_manager = KLatest(max_size=K, save_fn=save_fn)
checkpoint_dir = Path("checkpoints")
if checkpoint_dir.exists():
shutil.rmtree(checkpoint_dir)
checkpoint_dir.mkdir(parents=True)
for i in range(20):
path = checkpoint_dir / f"step_{i}"
klatest_manager.add_checkpoint(path)
assert len(list(checkpoint_dir.glob("step_*"))) == K
shutil.rmtree(checkpoint_dir)