From a0763986e224e565ff997c168741b6557a5ee125 Mon Sep 17 00:00:00 2001 From: Feiyu Chan Date: Thu, 10 Jun 2021 10:36:14 +0800 Subject: [PATCH] add kbest (#114) * add kbest * fix typos * remove unnecessary imports * fix del fn * fix typos * add k-latest --- parakeet/training/checkpoint.py | 162 ++++++++++++++++++++++++++++++++ tests/test_checkpoint.py | 52 ++++++++++ 2 files changed, 214 insertions(+) create mode 100644 parakeet/training/checkpoint.py create mode 100644 tests/test_checkpoint.py diff --git a/parakeet/training/checkpoint.py b/parakeet/training/checkpoint.py new file mode 100644 index 0000000..4bb12e2 --- /dev/null +++ b/parakeet/training/checkpoint.py @@ -0,0 +1,162 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Mapping, List +from pathlib import Path + + +class KBest(object): + """ + A utility class to help save the hard drive by only keeping K best + checkpoints. + + To be as modularized as possible, this class does not assume anything like + a Trainer class or anything like a checkpoint directory, it does not know + about the model or the optimizer, etc. + + It is basically a dynamically mantained K-bset Mapping. When a new item is + added to the map, save_fn is called. And when an item is removed from the + map, del_fn is called. `save_fn` and `del_fn` takes a Path object as input + and returns nothing. + + Though it is designed to control checkpointing behaviors, it can be used + to do something else if you pass some save_fn and del_fn. + + Example + -------- + + >>> from pathlib import Path + >>> import shutil + >>> import paddle + >>> from paddle import nn + + >>> model = nn.Linear(2, 3) + >>> def save_model(path): + ... paddle.save(model.state_dict(), path) + + >>> kbest_manager = KBest(max_size=5, save_fn=save_model) + >>> checkpoint_dir = Path("checkpoints") + >>> shutil.rmtree(checkpoint_dir) + >>> checkpoint_dir.mkdir(parents=True) + >>> a = np.random.rand(20) + >>> for i, score in enumerate(a): + ... path = checkpoint_dir / f"step_{i}" + ... kbest_manager.add_checkpoint(score, path) + >>> assert len(list(checkpoint_dir.glob("step_*"))) == 5 + """ + + def __init__(self, + max_size: int=5, + save_fn: Callable[[Path], None]=None, + del_fn: Callable[[Path], None]=lambda f: f.unlink()): + self.best_records: Mapping[Path, float] = {} + self.save_fn = save_fn + self.del_fn = del_fn + self.max_size = max_size + self._save_all = (max_size == -1) + + def should_save(self, metric: float) -> bool: + if not self.full(): + return True + + # already full + worst_record_path = max(self.best_records, key=self.best_records.get) + worst_metric = self.best_records[worst_record_path] + return metric < worst_metric + + def full(self): + return (not self._save_all) and len(self.best_records) == self.max_size + + def add_checkpoint(self, metric, path): + if self.should_save(metric): + self.save_checkpoint_and_update(metric, path) + + def save_checkpoint_and_update(self, metric, path): + # remove the worst + if self.full(): + worst_record_path = max(self.best_records, + key=self.best_records.get) + self.best_records.pop(worst_record_path) + self.del_fn(worst_record_path) + + # add the new one + self.save_fn(path) + self.best_records[path] = metric + + +class KLatest(object): + """ + A utility class to help save the hard drive by only keeping K latest + checkpoints. + + To be as modularized as possible, this class does not assume anything like + a Trainer class or anything like a checkpoint directory, it does not know + about the model or the optimizer, etc. + + It is basically a dynamically mantained Queue. When a new item is + added to the queue, save_fn is called. And when an item is removed from the + queue, del_fn is called. `save_fn` and `del_fn` takes a Path object as input + and returns nothing. + + Though it is designed to control checkpointing behaviors, it can be used + to do something else if you pass some save_fn and del_fn. + + Example + -------- + + >>> from pathlib import Path + >>> import shutil + >>> import paddle + >>> from paddle import nn + + >>> model = nn.Linear(2, 3) + >>> def save_model(path): + ... paddle.save(model.state_dict(), path) + + >>> klatest_manager = KLatest(max_size=5, save_fn=save_model) + >>> checkpoint_dir = Path("checkpoints") + >>> shutil.rmtree(checkpoint_dir) + >>> checkpoint_dir.mkdir(parents=True) + >>> for i in range(20): + ... path = checkpoint_dir / f"step_{i}" + ... klatest_manager.add_checkpoint(path) + >>> assert len(list(checkpoint_dir.glob("step_*"))) == 5 + """ + + def __init__(self, + max_size: int=5, + save_fn: Callable[[Path], None]=None, + del_fn: Callable[[Path], None]=lambda f: f.unlink()): + self.latest_records: List[Path] = [] + self.save_fn = save_fn + self.del_fn = del_fn + self.max_size = max_size + self._save_all = (max_size == -1) + + def full(self): + return ( + not self._save_all) and len(self.latest_records) == self.max_size + + def add_checkpoint(self, path): + self.save_checkpoint_and_update(path) + + def save_checkpoint_and_update(self, path): + # remove the earist + if self.full(): + eariest_record_path = self.latest_records.pop(0) + self.del_fn(eariest_record_path) + + # add the new one + self.save_fn(path) + self.latest_records.append(path) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py new file mode 100644 index 0000000..9173033 --- /dev/null +++ b/tests/test_checkpoint.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +import shutil + +import numpy as np +from parakeet.training.checkpoint import KBest, KLatest + + +def test_kbest(): + def save_fn(path): + with open(path, 'wt') as f: + f.write(f"My path is {str(path)}\n") + + K = 1 + kbest_manager = KBest(max_size=K, save_fn=save_fn) + checkpoint_dir = Path("checkpoints") + shutil.rmtree(checkpoint_dir) + checkpoint_dir.mkdir(parents=True) + a = np.random.rand(20) + for i, score in enumerate(a): + path = checkpoint_dir / f"step_{i}" + kbest_manager.add_checkpoint(score, path) + assert len(list(checkpoint_dir.glob("step_*"))) == K + + +def test_klatest(): + def save_fn(path): + with open(path, 'wt') as f: + f.write(f"My path is {str(path)}\n") + + K = 5 + klatest_manager = KLatest(max_size=K, save_fn=save_fn) + checkpoint_dir = Path("checkpoints") + shutil.rmtree(checkpoint_dir) + checkpoint_dir.mkdir(parents=True) + for i in range(20): + path = checkpoint_dir / f"step_{i}" + klatest_manager.add_checkpoint(path) + assert len(list(checkpoint_dir.glob("step_*"))) == K