From c306f5c2b339f5b7fd9874e1b3428c45fe10f9e8 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Thu, 10 Jun 2021 03:39:54 +0800 Subject: [PATCH] add k-latest --- parakeet/training/checkpoint.py | 71 ++++++++++++++++++++++++++++++++- tests/test_checkpoint.py | 21 +++++++++- 2 files changed, 88 insertions(+), 4 deletions(-) diff --git a/parakeet/training/checkpoint.py b/parakeet/training/checkpoint.py index 155261d..4bb12e2 100644 --- a/parakeet/training/checkpoint.py +++ b/parakeet/training/checkpoint.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Mapping +from typing import Callable, Mapping, List from pathlib import Path @@ -43,7 +43,7 @@ class KBest(object): >>> model = nn.Linear(2, 3) >>> def save_model(path): - ... paddle.save(model.state_dict()) + ... paddle.save(model.state_dict(), path) >>> kbest_manager = KBest(max_size=5, save_fn=save_model) >>> checkpoint_dir = Path("checkpoints") @@ -93,3 +93,70 @@ class KBest(object): # add the new one self.save_fn(path) self.best_records[path] = metric + + +class KLatest(object): + """ + A utility class to help save hard drive space by only keeping the K latest + checkpoints. + + To be as modularized as possible, this class does not assume anything like + a Trainer class or anything like a checkpoint directory, it does not know + about the model or the optimizer, etc. + + It is basically a dynamically maintained Queue. When a new item is + added to the queue, save_fn is called. And when an item is removed from the + queue, del_fn is called. `save_fn` and `del_fn` take a Path object as input + and return nothing. + + Though it is designed to control checkpointing behaviors, it can be used + to do something else if you pass some save_fn and del_fn. 
+ + Example + -------- + + >>> from pathlib import Path + >>> import shutil + >>> import paddle + >>> from paddle import nn + + >>> model = nn.Linear(2, 3) + >>> def save_model(path): + ... paddle.save(model.state_dict(), path) + + >>> klatest_manager = KLatest(max_size=5, save_fn=save_model) + >>> checkpoint_dir = Path("checkpoints") + >>> shutil.rmtree(checkpoint_dir) + >>> checkpoint_dir.mkdir(parents=True) + >>> for i in range(20): + ... path = checkpoint_dir / f"step_{i}" + ... klatest_manager.add_checkpoint(path) + >>> assert len(list(checkpoint_dir.glob("step_*"))) == 5 + """ + + def __init__(self, + max_size: int=5, + save_fn: Callable[[Path], None]=None, + del_fn: Callable[[Path], None]=lambda f: f.unlink()): + self.latest_records: List[Path] = [] + self.save_fn = save_fn + self.del_fn = del_fn + self.max_size = max_size + self._save_all = (max_size == -1) + + def full(self): + return ( + not self._save_all) and len(self.latest_records) == self.max_size + + def add_checkpoint(self, path): + self.save_checkpoint_and_update(path) + + def save_checkpoint_and_update(self, path): + # remove the earliest + if self.full(): + eariest_record_path = self.latest_records.pop(0) + self.del_fn(eariest_record_path) + + # add the new one + self.save_fn(path) + self.latest_records.append(path) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py index af8df02..9173033 100644 --- a/tests/test_checkpoint.py +++ b/tests/test_checkpoint.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from parakeet.training.checkpoint import KBest -import numpy as np from pathlib import Path import shutil +import numpy as np +from parakeet.training.checkpoint import KBest, KLatest + def test_kbest(): def save_fn(path): @@ -33,3 +34,19 @@ def test_kbest(): path = checkpoint_dir / f"step_{i}" kbest_manager.add_checkpoint(score, path) assert len(list(checkpoint_dir.glob("step_*"))) == K + + +def test_klatest(): + def save_fn(path): + with open(path, 'wt') as f: + f.write(f"My path is {str(path)}\n") + + K = 5 + klatest_manager = KLatest(max_size=K, save_fn=save_fn) + checkpoint_dir = Path("checkpoints") + shutil.rmtree(checkpoint_dir) + checkpoint_dir.mkdir(parents=True) + for i in range(20): + path = checkpoint_dir / f"step_{i}" + klatest_manager.add_checkpoint(path) + assert len(list(checkpoint_dir.glob("step_*"))) == K