checkpoint utility: add support for str in addition to Path object

This commit is contained in:
chenfeiyu 2021-06-10 20:34:59 +08:00
parent 60c16dcfb7
commit 3bf2e71734
2 changed files with 10 additions and 5 deletions

View File

@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Mapping, List
from typing import Callable, Mapping, List, Union
import os
from pathlib import Path
@ -58,8 +59,8 @@ class KBest(object):
def __init__(self,
max_size: int=5,
save_fn: Callable[[Path], None]=None,
del_fn: Callable[[Path], None]=lambda f: f.unlink()):
save_fn: Callable[[Union[Path, str]], None]=None,
del_fn: Callable[[Union[Path, str]], None]=os.remove):
self.best_records: Mapping[Path, float] = {}
self.save_fn = save_fn
self.del_fn = del_fn

View File

@ -27,13 +27,15 @@ def test_kbest():
K = 1
kbest_manager = KBest(max_size=K, save_fn=save_fn)
checkpoint_dir = Path("checkpoints")
shutil.rmtree(checkpoint_dir)
if checkpoint_dir.exists():
shutil.rmtree(checkpoint_dir)
checkpoint_dir.mkdir(parents=True)
a = np.random.rand(20)
for i, score in enumerate(a):
path = checkpoint_dir / f"step_{i}"
kbest_manager.add_checkpoint(score, path)
assert len(list(checkpoint_dir.glob("step_*"))) == K
shutil.rmtree(checkpoint_dir)
def test_klatest():
@ -44,9 +46,11 @@ def test_klatest():
K = 5
klatest_manager = KLatest(max_size=K, save_fn=save_fn)
checkpoint_dir = Path("checkpoints")
shutil.rmtree(checkpoint_dir)
if checkpoint_dir.exists():
shutil.rmtree(checkpoint_dir)
checkpoint_dir.mkdir(parents=True)
for i in range(20):
path = checkpoint_dir / f"step_{i}"
klatest_manager.add_checkpoint(path)
assert len(list(checkpoint_dir.glob("step_*"))) == K
shutil.rmtree(checkpoint_dir)