fix del fn
This commit is contained in:
parent
13323bdf6a
commit
0114a808a2
|
@ -88,7 +88,7 @@ class KBest(object):
|
|||
worst_record_path = max(self.best_records,
|
||||
key=self.best_records.get)
|
||||
self.best_records.pop(worst_record_path)
|
||||
self.del_fn(path)
|
||||
self.del_fn(worst_record_path)
|
||||
|
||||
# add the new one
|
||||
self.save_fn(path)
|
||||
|
|
|
@ -23,7 +23,8 @@ def test_kbest():
|
|||
with open(path, 'wt') as f:
|
||||
f.write(f"My path is {str(path)}\n")
|
||||
|
||||
kbest_manager = KBest(max_size=5, save_fn=save_fn)
|
||||
K = 1
|
||||
kbest_manager = KBest(max_size=K, save_fn=save_fn)
|
||||
checkpoint_dir = Path("checkpoints")
|
||||
shutil.rmtree(checkpoint_dir)
|
||||
checkpoint_dir.mkdir(parents=True)
|
||||
|
@ -31,4 +32,4 @@ def test_kbest():
|
|||
for i, score in enumerate(a):
|
||||
path = checkpoint_dir / f"step_{i}"
|
||||
kbest_manager.add_checkpoint(score, path)
|
||||
assert len(list(checkpoint_dir.glob("step_*"))) == 5
|
||||
assert len(list(checkpoint_dir.glob("step_*"))) == K
|
||||
|
|
Loading…
Reference in New Issue