make a better snapshot extension
This commit is contained in:
parent
3e8a156348
commit
f9105db727
|
@ -12,8 +12,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Union
|
||||
from typing import Union, List, Dict, Any
|
||||
from pathlib import Path
|
||||
import jsonlines
|
||||
import os
|
||||
from datetime import datetime
|
||||
import logging
|
||||
|
||||
from parakeet.utils.mp_tools import rank_zero_only
|
||||
from parakeet.training.trainer import Trainer
|
||||
|
@ -24,7 +28,7 @@ class Snapshot(object):
|
|||
the trainer. It is done by calling the updater's `save` method.
|
||||
|
||||
An Updater saves its state_dict by default, which contains the
|
||||
updater state, (i.e. epoch and iteration) and all the model
|
||||
updater state, (i.e. epoch and iteration) and all the model
|
||||
parameters and optimizer states. If the updater inside the trainer
|
||||
subclasses StandardUpdater, everything is good to go.
|
||||
|
||||
|
@ -34,11 +38,47 @@ class Snapshot(object):
|
|||
The directory to save checkpoints into.
|
||||
"""
|
||||
|
||||
def __init__(self, checkpoint_dir: Union[str, Path]):
|
||||
self.checkpoint_dir = Path(checkpoint_dir)
|
||||
def __init__(self, max_size: int=5):
|
||||
self.records: List[Dict[str, Any]] = []
|
||||
self.max_size = max_size
|
||||
self._save_all = (max_size == -1)
|
||||
self.save_fn =...
|
||||
self.del_fn =...
|
||||
self.checkpoint_dir =...
|
||||
|
||||
def initialize(self, trainer):
|
||||
"""Set up this extension."""
|
||||
self.save_fn = trainer.updater.save
|
||||
self.del_fn = os.remove
|
||||
self.checkpoint_dir = trainer.out / "checkpoints"
|
||||
|
||||
def full(self):
|
||||
return (not self._save_all) and len(self.records) >= self.max_size
|
||||
|
||||
@rank_zero_only
|
||||
def __call__(self, trainer: Trainer):
|
||||
def save_checkpoint_and_update(self, trainer):
|
||||
iteration = trainer.updater.state.iteration
|
||||
path = self.checkpoint_dir / f"step_{iteration}.pdz"
|
||||
trainer.updater.save(str(path))
|
||||
path = self.checkpoint_dir / f"snapshot_iter_{iteration}.pdz"
|
||||
|
||||
# remove the earliest
|
||||
if self.full():
|
||||
eariest_record = self.records[0]
|
||||
self.del_fn(eariest_record["path"])
|
||||
self.records.pop(0)
|
||||
|
||||
# add the new one
|
||||
self.save_fn(path)
|
||||
record = {
|
||||
"time": str(datetime.now()),
|
||||
'path': str(path),
|
||||
'iteration': iteration
|
||||
}
|
||||
self.records.append(record)
|
||||
|
||||
# update the record
|
||||
with jsonlines.open(self.checkpoint_dir / "records.jsonl", 'w') as f:
|
||||
for record in self.records:
|
||||
f.write(record)
|
||||
|
||||
def __call__(self, trainer):
|
||||
self.save_checkpoint_and_update(trainer)
|
||||
|
|
|
@ -59,7 +59,13 @@ class Trainer(object):
|
|||
self.extensions.keys(),
|
||||
key=lambda name: self.extensions[name].priority,
|
||||
reverse=True)
|
||||
extensions = [(name, self.extensions[name]) for name in extension_order]
|
||||
extensions = [(name, self.extensions[name])
|
||||
for name in extension_order]
|
||||
|
||||
print("initializing")
|
||||
for name, entry in extensions:
|
||||
if hasattr(entry.extension, "initialize"):
|
||||
entry.extension.initialize(self)
|
||||
|
||||
update = self.updater.update
|
||||
stop_trigger = self.stop_trigger
|
||||
|
|
|
@ -198,7 +198,7 @@ class StandardUpdater(UpdaterBase):
|
|||
return state_dict
|
||||
|
||||
def set_state_dict(self, state_dict):
|
||||
"""Set state dict for an Updater. Parameters of models, states for
|
||||
"""Set state dict for an Updater. Parameters of models, states for
|
||||
optimizers and UpdaterState are restored."""
|
||||
for name, layer in self.models.items():
|
||||
layer.set_state_dict(state_dict[f"{name}_params"])
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.optimizer import Adam
|
||||
from itertools import count
|
||||
|
||||
from parakeet.training.updater import StandardUpdater
|
||||
from parakeet.training.trainer import Trainer
|
||||
from parakeet.training.extensions.snapshot import Snapshot
|
||||
|
||||
|
||||
def test_snapshot():
|
||||
model = nn.Linear(3, 4)
|
||||
optimizer = Adam(parameters=model.parameters())
|
||||
|
||||
# use the simplest iterable object as the dataloader
|
||||
dataloader = count()
|
||||
|
||||
# hack the training process: training does nothing except increase the iteration
|
||||
updater = StandardUpdater(model, optimizer, dataloader=dataloader)
|
||||
updater.update_core = lambda x: None
|
||||
|
||||
trainer = Trainer(
|
||||
updater, stop_trigger=(1000, 'iteration'), out='temp_test_snapshot')
|
||||
shutil.rmtree(trainer.out, ignore_errors=True)
|
||||
|
||||
snap = Snapshot(max_size=5)
|
||||
trigger = (10, 'iteration')
|
||||
trainer.extend(snap, name='snapshot', trigger=trigger, priority=0)
|
||||
|
||||
trainer.run()
|
||||
|
||||
checkpoint_dir = trainer.out / "checkpoints"
|
||||
snapshots = sorted(list(checkpoint_dir.glob("snapshot_iter_*.pdz")))
|
||||
for snap in snapshots:
|
||||
print(snap)
|
||||
assert len(snapshots) == 5
|
||||
shutil.rmtree(trainer.out)
|
Loading…
Reference in New Issue