diff --git a/parakeet/training/extensions/snapshot.py b/parakeet/training/extensions/snapshot.py
index e31403b..9cafef1 100644
--- a/parakeet/training/extensions/snapshot.py
+++ b/parakeet/training/extensions/snapshot.py
@@ -12,8 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Union
+from typing import Union, List, Dict, Any
 from pathlib import Path
+import jsonlines
+import os
+from datetime import datetime
+import logging
 
 from parakeet.utils.mp_tools import rank_zero_only
 from parakeet.training.trainer import Trainer
@@ -24,7 +28,7 @@ class Snapshot(object):
     the trainer. It is done by calling the updater's `save` method.
 
     An Updater save its state_dict by default, which contains the
-    updater state, (i.e. epoch and iteration) and all the model 
+    updater state, (i.e. epoch and iteration) and all the model
     parameters and optimizer states. If the updater inside the trainer
     subclasses StandardUpdater, everything is good to go.
 
@@ -34,11 +38,47 @@ class Snapshot(object):
         The directory to save checkpoints into.
     """
 
-    def __init__(self, checkpoint_dir: Union[str, Path]):
-        self.checkpoint_dir = Path(checkpoint_dir)
+    def __init__(self, max_size: int=5):
+        self.records: List[Dict[str, Any]] = []
+        self.max_size = max_size
+        self._save_all = (max_size == -1)
+        self.save_fn =...
+        self.del_fn =...
+        self.checkpoint_dir =...
+
+    def initialize(self, trainer):
+        """setting up this extention."""
+        self.save_fn = trainer.updater.save
+        self.del_fn = os.remove
+        self.checkpoint_dir = trainer.out / "checkpoints"
+
+    def full(self):
+        return (not self._save_all) and len(self.records) >= self.max_size
 
     @rank_zero_only
-    def __call__(self, trainer: Trainer):
+    def save_checkpoint_and_update(self, trainer):
         iteration = trainer.updater.state.iteration
-        path = self.checkpoint_dir / f"step_{iteration}.pdz"
-        trainer.updater.save(str(path))
+        path = self.checkpoint_dir / f"snapshot_iter_{iteration}.pdz"
+
+        # remove the earist
+        if self.full():
+            eariest_record = self.records[0]
+            self.del_fn(eariest_record["path"])
+            self.records.pop(0)
+
+        # add the new one
+        self.save_fn(path)
+        record = {
+            "time": str(datetime.now()),
+            'path': str(path),
+            'iteration': iteration
+        }
+        self.records.append(record)
+
+        # update the record
+        with jsonlines.open(self.checkpoint_dir / "records.jsonl", 'w') as f:
+            for record in self.records:
+                f.write(record)
+
+    def __call__(self, trainer):
+        self.save_checkpoint_and_update(trainer)
diff --git a/parakeet/training/trainer.py b/parakeet/training/trainer.py
index 99e5114..38ccefb 100644
--- a/parakeet/training/trainer.py
+++ b/parakeet/training/trainer.py
@@ -59,7 +59,13 @@ class Trainer(object):
             self.extensions.keys(),
             key=lambda name: self.extensions[name].priority,
             reverse=True)
-        extensions = [(name, self.extensions[name]) for name in extension_order]
+        extensions = [(name, self.extensions[name])
+                      for name in extension_order]
+
+        print("initializing")
+        for name, entry in extensions:
+            if hasattr(entry.extension, "initialize"):
+                entry.extension.initialize(self)
 
         update = self.updater.update
         stop_trigger = self.stop_trigger
diff --git a/parakeet/training/updater.py b/parakeet/training/updater.py
index cb2213c..fb3bb41 100644
--- a/parakeet/training/updater.py
+++ b/parakeet/training/updater.py
@@ -198,7 +198,7 @@ class StandardUpdater(UpdaterBase):
         return state_dict
 
     def set_state_dict(self, state_dict):
-        """Set state dict for a Updater. Parameters of models, states for 
+        """Set state dict for a Updater. Parameters of models, states for
         optimizers and UpdaterState are restored."""
         for name, layer in self.models.items():
             layer.set_state_dict(state_dict[f"{name}_params"])
diff --git a/tests/test_snapshot.py b/tests/test_snapshot.py
new file mode 100644
index 0000000..71e422c
--- /dev/null
+++ b/tests/test_snapshot.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+import shutil
+
+import numpy as np
+import paddle
+from paddle import nn
+from paddle.optimizer import Adam
+from itertools import count
+
+from parakeet.training.updater import StandardUpdater
+from parakeet.training.trainer import Trainer
+from parakeet.training.extensions.snapshot import Snapshot
+
+
+def test_snapshot():
+    model = nn.Linear(3, 4)
+    optimizer = Adam(parameters=model.parameters())
+
+    # use a simplest iterable object as dataloader
+    dataloader = count()
+
+    # hack the training proecss: training does nothing except increse iteration
+    updater = StandardUpdater(model, optimizer, dataloader=dataloader)
+    updater.update_core = lambda x: None
+
+    trainer = Trainer(
+        updater, stop_trigger=(1000, 'iteration'), out='temp_test_snapshot')
+    shutil.rmtree(trainer.out, ignore_errors=True)
+
+    snap = Snapshot(max_size=5)
+    trigger = (10, 'iteration')
+    trainer.extend(snap, name='snapshot', trigger=trigger, priority=0)
+
+    trainer.run()
+
+    checkpoint_dir = trainer.out / "checkpoints"
+    snapshots = sorted(list(checkpoint_dir.glob("snapshot_iter_*.pdz")))
+    for snap in snapshots:
+        print(snap)
+    assert len(snapshots) == 5
+    shutil.rmtree(trainer.out)