From f9105db72743e7b90f0b77055d3d5335da0014e8 Mon Sep 17 00:00:00 2001 From: iclementine Date: Sat, 26 Jun 2021 22:16:56 +0800 Subject: [PATCH] make a better snapshot extension --- parakeet/training/extensions/snapshot.py | 54 ++++++++++++++++++++--- parakeet/training/trainer.py | 8 +++- parakeet/training/updater.py | 2 +- tests/test_snapshot.py | 55 ++++++++++++++++++++++++ 4 files changed, 110 insertions(+), 9 deletions(-) create mode 100644 tests/test_snapshot.py diff --git a/parakeet/training/extensions/snapshot.py b/parakeet/training/extensions/snapshot.py index e31403b..9cafef1 100644 --- a/parakeet/training/extensions/snapshot.py +++ b/parakeet/training/extensions/snapshot.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Union, List, Dict, Any from pathlib import Path +import jsonlines +import os +from datetime import datetime +import logging from parakeet.utils.mp_tools import rank_zero_only from parakeet.training.trainer import Trainer @@ -24,7 +28,7 @@ class Snapshot(object): the trainer. It is done by calling the updater's `save` method. An Updater save its state_dict by default, which contains the - updater state, (i.e. epoch and iteration) and all the model + updater state, (i.e. epoch and iteration) and all the model parameters and optimizer states. If the updater inside the trainer subclasses StandardUpdater, everything is good to go. @@ -34,11 +38,47 @@ class Snapshot(object): The directory to save checkpoints into. """ - def __init__(self, checkpoint_dir: Union[str, Path]): - self.checkpoint_dir = Path(checkpoint_dir) + def __init__(self, max_size: int=5): + self.records: List[Dict[str, Any]] = [] + self.max_size = max_size + self._save_all = (max_size == -1) + self.save_fn =... + self.del_fn =... + self.checkpoint_dir =... 
+ + def initialize(self, trainer): + """setting up this extension.""" + self.save_fn = trainer.updater.save + self.del_fn = os.remove + self.checkpoint_dir = trainer.out / "checkpoints" + + def full(self): + return (not self._save_all) and len(self.records) >= self.max_size @rank_zero_only - def __call__(self, trainer: Trainer): + def save_checkpoint_and_update(self, trainer): iteration = trainer.updater.state.iteration - path = self.checkpoint_dir / f"step_{iteration}.pdz" - trainer.updater.save(str(path)) + path = self.checkpoint_dir / f"snapshot_iter_{iteration}.pdz" + + # remove the earliest + if self.full(): + eariest_record = self.records[0] + self.del_fn(eariest_record["path"]) + self.records.pop(0) + + # add the new one + self.save_fn(path) + record = { + "time": str(datetime.now()), + 'path': str(path), + 'iteration': iteration + } + self.records.append(record) + + # update the record + with jsonlines.open(self.checkpoint_dir / "records.jsonl", 'w') as f: + for record in self.records: + f.write(record) + + def __call__(self, trainer): + self.save_checkpoint_and_update(trainer) diff --git a/parakeet/training/trainer.py b/parakeet/training/trainer.py index 99e5114..38ccefb 100644 --- a/parakeet/training/trainer.py +++ b/parakeet/training/trainer.py @@ -59,7 +59,13 @@ class Trainer(object): self.extensions.keys(), key=lambda name: self.extensions[name].priority, reverse=True) - extensions = [(name, self.extensions[name]) for name in extension_order] + extensions = [(name, self.extensions[name]) + for name in extension_order] + + print("initializing") + for name, entry in extensions: + if hasattr(entry.extension, "initialize"): + entry.extension.initialize(self) update = self.updater.update stop_trigger = self.stop_trigger diff --git a/parakeet/training/updater.py b/parakeet/training/updater.py index cb2213c..fb3bb41 100644 --- a/parakeet/training/updater.py +++ b/parakeet/training/updater.py @@ -198,7 +198,7 @@ class StandardUpdater(UpdaterBase): return 
state_dict def set_state_dict(self, state_dict): - """Set state dict for a Updater. Parameters of models, states for + """Set state dict for a Updater. Parameters of models, states for optimizers and UpdaterState are restored.""" for name, layer in self.models.items(): layer.set_state_dict(state_dict[f"{name}_params"]) diff --git a/tests/test_snapshot.py b/tests/test_snapshot.py new file mode 100644 index 0000000..71e422c --- /dev/null +++ b/tests/test_snapshot.py @@ -0,0 +1,55 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pathlib import Path +import shutil + +import numpy as np +import paddle +from paddle import nn +from paddle.optimizer import Adam +from itertools import count + +from parakeet.training.updater import StandardUpdater +from parakeet.training.trainer import Trainer +from parakeet.training.extensions.snapshot import Snapshot + + +def test_snapshot(): + model = nn.Linear(3, 4) + optimizer = Adam(parameters=model.parameters()) + + # use the simplest iterable object as dataloader + dataloader = count() + + # hack the training process: training does nothing except increase iteration + updater = StandardUpdater(model, optimizer, dataloader=dataloader) + updater.update_core = lambda x: None + + trainer = Trainer( + updater, stop_trigger=(1000, 'iteration'), out='temp_test_snapshot') + shutil.rmtree(trainer.out, ignore_errors=True) + + snap = Snapshot(max_size=5) + trigger = (10, 'iteration') + trainer.extend(snap, name='snapshot', trigger=trigger, priority=0) + + trainer.run() + + checkpoint_dir = trainer.out / "checkpoints" + snapshots = sorted(list(checkpoint_dir.glob("snapshot_iter_*.pdz"))) + for snap in snapshots: + print(snap) + assert len(snapshots) == 5 + shutil.rmtree(trainer.out)