fix del fn

This commit is contained in:
chenfeiyu 2021-06-10 00:40:54 +08:00
parent 13323bdf6a
commit 0114a808a2
2 changed files with 4 additions and 3 deletions

View File

@ -88,7 +88,7 @@ class KBest(object):
worst_record_path = max(self.best_records, worst_record_path = max(self.best_records,
key=self.best_records.get) key=self.best_records.get)
self.best_records.pop(worst_record_path) self.best_records.pop(worst_record_path)
self.del_fn(path) self.del_fn(worst_record_path)
# add the new one # add the new one
self.save_fn(path) self.save_fn(path)

View File

@ -23,7 +23,8 @@ def test_kbest():
with open(path, 'wt') as f: with open(path, 'wt') as f:
f.write(f"My path is {str(path)}\n") f.write(f"My path is {str(path)}\n")
kbest_manager = KBest(max_size=5, save_fn=save_fn) K = 1
kbest_manager = KBest(max_size=K, save_fn=save_fn)
checkpoint_dir = Path("checkpoints") checkpoint_dir = Path("checkpoints")
shutil.rmtree(checkpoint_dir) shutil.rmtree(checkpoint_dir)
checkpoint_dir.mkdir(parents=True) checkpoint_dir.mkdir(parents=True)
@ -31,4 +32,4 @@ def test_kbest():
for i, score in enumerate(a): for i, score in enumerate(a):
path = checkpoint_dir / f"step_{i}" path = checkpoint_dir / f"step_{i}"
kbest_manager.add_checkpoint(score, path) kbest_manager.add_checkpoint(score, path)
assert len(list(checkpoint_dir.glob("step_*"))) == 5 assert len(list(checkpoint_dir.glob("step_*"))) == K