2021-06-26 22:16:56 +08:00
|
|
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
import shutil
|
2021-08-17 15:29:30 +08:00
|
|
|
from itertools import count
|
2021-06-26 22:16:56 +08:00
|
|
|
|
|
|
|
from paddle import nn
|
|
|
|
from paddle.optimizer import Adam
|
|
|
|
|
|
|
|
from parakeet.training.extensions.snapshot import Snapshot
|
2021-08-17 15:29:30 +08:00
|
|
|
from parakeet.training.trainer import Trainer
|
|
|
|
from parakeet.training.updater import StandardUpdater
|
2021-06-26 22:16:56 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_snapshot():
    """Check that a ``Snapshot(max_size=5)`` extension leaves five snapshots.

    A trainer is built around a no-op updater, with a stop trigger of
    ``(1000, 'iteration')`` and a ``Snapshot`` extension registered with a
    ``(10, 'iteration')`` trigger. After the run, the test expects exactly
    five files matching ``snapshot_iter_*.pdz`` in the ``checkpoints``
    sub-directory of the trainer's output directory.

    The output directory ``temp_test_snapshot`` is removed both before the
    run and afterwards, even when the run or the assertion fails.
    """
    model = nn.Linear(3, 4)
    optimizer = Adam(parameters=model.parameters())

    # use the simplest possible iterable object as the dataloader
    dataloader = count()

    # hack the training process: an update does nothing, so running the
    # trainer does nothing except increase the iteration
    updater = StandardUpdater(model, optimizer, dataloader=dataloader)
    updater.update_core = lambda x: None

    trainer = Trainer(
        updater, stop_trigger=(1000, 'iteration'), out='temp_test_snapshot')
    # start from a clean output directory in case an earlier run left one
    shutil.rmtree(trainer.out, ignore_errors=True)

    try:
        snap = Snapshot(max_size=5)
        trigger = (10, 'iteration')
        trainer.extend(snap, name='snapshot', trigger=trigger, priority=0)

        trainer.run()

        checkpoint_dir = trainer.out / "checkpoints"
        snapshots = sorted(checkpoint_dir.glob("snapshot_iter_*.pdz"))
        # distinct loop name so the Snapshot extension `snap` is not shadowed
        for snapshot_path in snapshots:
            print(snapshot_path)
        assert len(snapshots) == 5
    finally:
        # always clean up; ignore_errors keeps a missing directory from
        # masking the real failure raised inside the try block
        shutil.rmtree(trainer.out, ignore_errors=True)
|