fix del fn
This commit is contained in:
parent
13323bdf6a
commit
0114a808a2
|
@ -88,7 +88,7 @@ class KBest(object):
|
||||||
worst_record_path = max(self.best_records,
|
worst_record_path = max(self.best_records,
|
||||||
key=self.best_records.get)
|
key=self.best_records.get)
|
||||||
self.best_records.pop(worst_record_path)
|
self.best_records.pop(worst_record_path)
|
||||||
self.del_fn(path)
|
self.del_fn(worst_record_path)
|
||||||
|
|
||||||
# add the new one
|
# add the new one
|
||||||
self.save_fn(path)
|
self.save_fn(path)
|
||||||
|
|
|
@ -23,7 +23,8 @@ def test_kbest():
|
||||||
with open(path, 'wt') as f:
|
with open(path, 'wt') as f:
|
||||||
f.write(f"My path is {str(path)}\n")
|
f.write(f"My path is {str(path)}\n")
|
||||||
|
|
||||||
kbest_manager = KBest(max_size=5, save_fn=save_fn)
|
K = 1
|
||||||
|
kbest_manager = KBest(max_size=K, save_fn=save_fn)
|
||||||
checkpoint_dir = Path("checkpoints")
|
checkpoint_dir = Path("checkpoints")
|
||||||
shutil.rmtree(checkpoint_dir)
|
shutil.rmtree(checkpoint_dir)
|
||||||
checkpoint_dir.mkdir(parents=True)
|
checkpoint_dir.mkdir(parents=True)
|
||||||
|
@ -31,4 +32,4 @@ def test_kbest():
|
||||||
for i, score in enumerate(a):
|
for i, score in enumerate(a):
|
||||||
path = checkpoint_dir / f"step_{i}"
|
path = checkpoint_dir / f"step_{i}"
|
||||||
kbest_manager.add_checkpoint(score, path)
|
kbest_manager.add_checkpoint(score, path)
|
||||||
assert len(list(checkpoint_dir.glob("step_*"))) == 5
|
assert len(list(checkpoint_dir.glob("step_*"))) == K
|
||||||
|
|
Loading…
Reference in New Issue