checkpoint utility: add support for str in addition to Path object
This commit is contained in:
parent
60c16dcfb7
commit
3bf2e71734
|
@ -12,7 +12,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Callable, Mapping, List
|
from typing import Callable, Mapping, List, Union
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@ -58,8 +59,8 @@ class KBest(object):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
max_size: int=5,
|
max_size: int=5,
|
||||||
save_fn: Callable[[Path], None]=None,
|
save_fn: Callable[[Union[Path, str]], None]=None,
|
||||||
del_fn: Callable[[Path], None]=lambda f: f.unlink()):
|
del_fn: Callable[[Union[Path, str]], None]=os.remove):
|
||||||
self.best_records: Mapping[Path, float] = {}
|
self.best_records: Mapping[Path, float] = {}
|
||||||
self.save_fn = save_fn
|
self.save_fn = save_fn
|
||||||
self.del_fn = del_fn
|
self.del_fn = del_fn
|
||||||
|
|
|
@ -27,6 +27,7 @@ def test_kbest():
|
||||||
K = 1
|
K = 1
|
||||||
kbest_manager = KBest(max_size=K, save_fn=save_fn)
|
kbest_manager = KBest(max_size=K, save_fn=save_fn)
|
||||||
checkpoint_dir = Path("checkpoints")
|
checkpoint_dir = Path("checkpoints")
|
||||||
|
if checkpoint_dir.exists():
|
||||||
shutil.rmtree(checkpoint_dir)
|
shutil.rmtree(checkpoint_dir)
|
||||||
checkpoint_dir.mkdir(parents=True)
|
checkpoint_dir.mkdir(parents=True)
|
||||||
a = np.random.rand(20)
|
a = np.random.rand(20)
|
||||||
|
@ -34,6 +35,7 @@ def test_kbest():
|
||||||
path = checkpoint_dir / f"step_{i}"
|
path = checkpoint_dir / f"step_{i}"
|
||||||
kbest_manager.add_checkpoint(score, path)
|
kbest_manager.add_checkpoint(score, path)
|
||||||
assert len(list(checkpoint_dir.glob("step_*"))) == K
|
assert len(list(checkpoint_dir.glob("step_*"))) == K
|
||||||
|
shutil.rmtree(checkpoint_dir)
|
||||||
|
|
||||||
|
|
||||||
def test_klatest():
|
def test_klatest():
|
||||||
|
@ -44,9 +46,11 @@ def test_klatest():
|
||||||
K = 5
|
K = 5
|
||||||
klatest_manager = KLatest(max_size=K, save_fn=save_fn)
|
klatest_manager = KLatest(max_size=K, save_fn=save_fn)
|
||||||
checkpoint_dir = Path("checkpoints")
|
checkpoint_dir = Path("checkpoints")
|
||||||
|
if checkpoint_dir.exists():
|
||||||
shutil.rmtree(checkpoint_dir)
|
shutil.rmtree(checkpoint_dir)
|
||||||
checkpoint_dir.mkdir(parents=True)
|
checkpoint_dir.mkdir(parents=True)
|
||||||
for i in range(20):
|
for i in range(20):
|
||||||
path = checkpoint_dir / f"step_{i}"
|
path = checkpoint_dir / f"step_{i}"
|
||||||
klatest_manager.add_checkpoint(path)
|
klatest_manager.add_checkpoint(path)
|
||||||
assert len(list(checkpoint_dir.glob("step_*"))) == K
|
assert len(list(checkpoint_dir.glob("step_*"))) == K
|
||||||
|
shutil.rmtree(checkpoint_dir)
|
||||||
|
|
Loading…
Reference in New Issue