diff --git a/parakeet/__init__.py b/parakeet/__init__.py index d4940a3..cce4fff 100644 --- a/parakeet/__init__.py +++ b/parakeet/__init__.py @@ -14,4 +14,7 @@ __version__ = "0.2.0-beta.0" +import logging from parakeet import audio, data, datasets, frontend, models, modules, training, utils + +logging.getLogger('parakeet').addHandler(logging.NullHandler()) diff --git a/parakeet/training/checkpoint.py b/parakeet/training/checkpoint.py deleted file mode 100644 index bbbbdc0..0000000 --- a/parakeet/training/checkpoint.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Callable, Mapping, List, Union -import os -from pathlib import Path - - -class KBest(object): - """ - A utility class to help save the hard drive by only keeping K best - checkpoints. - - To be as modularized as possible, this class does not assume anything like - a Trainer class or anything like a checkpoint directory, it does not know - about the model or the optimizer, etc. - - It is basically a dynamically mantained K-bset Mapping. When a new item is - added to the map, save_fn is called. And when an item is removed from the - map, del_fn is called. `save_fn` and `del_fn` takes a Path object as input - and returns nothing. - - Though it is designed to control checkpointing behaviors, it can be used - to do something else if you pass some save_fn and del_fn. - - Example - -------- - - >>> from pathlib import Path - >>> import shutil - >>> import paddle - >>> from paddle import nn - - >>> model = nn.Linear(2, 3) - >>> def save_model(path): - ... paddle.save(model.state_dict(), path) - - >>> kbest_manager = KBest(max_size=5, save_fn=save_model) - >>> checkpoint_dir = Path("checkpoints") - >>> shutil.rmtree(checkpoint_dir) - >>> checkpoint_dir.mkdir(parents=True) - >>> a = np.random.rand(20) - >>> for i, score in enumerate(a): - ... path = checkpoint_dir / f"step_{i}" - ... 
kbest_manager.add_checkpoint(score, path) - >>> assert len(list(checkpoint_dir.glob("step_*"))) == 5 - """ - - def __init__(self, - max_size: int=5, - save_fn: Callable[[Union[Path, str]], None]=None, - del_fn: Callable[[Union[Path, str]], None]=os.remove): - self.best_records: Mapping[Path, float] = {} - self.save_fn = save_fn - self.del_fn = del_fn - self.max_size = max_size - self._save_all = (max_size == -1) - - def should_save(self, metric: float) -> bool: - if not self.full(): - return True - - # already full - worst_record_path = max(self.best_records, key=self.best_records.get) - worst_metric = self.best_records[worst_record_path] - return metric < worst_metric - - def full(self): - return (not self._save_all) and len(self.best_records) == self.max_size - - def add_checkpoint(self, metric, path): - if self.should_save(metric): - self.save_checkpoint_and_update(metric, path) - - def save_checkpoint_and_update(self, metric, path): - # remove the worst - if self.full(): - worst_record_path = max(self.best_records, - key=self.best_records.get) - self.best_records.pop(worst_record_path) - self.del_fn(worst_record_path) - - # add the new one - self.save_fn(path) - self.best_records[path] = metric - - -class KLatest(object): - """ - A utility class to help save the hard drive by only keeping K latest - checkpoints. - - To be as modularized as possible, this class does not assume anything like - a Trainer class or anything like a checkpoint directory, it does not know - about the model or the optimizer, etc. - - It is basically a dynamically mantained Queue. When a new item is - added to the queue, save_fn is called. And when an item is removed from the - queue, del_fn is called. `save_fn` and `del_fn` takes a Path object as input - and returns nothing. - - Though it is designed to control checkpointing behaviors, it can be used - to do something else if you pass some save_fn and del_fn. - - Example - -------- - - >>> from pathlib import Path - >>> import shutil - >>> import paddle - >>> from paddle import nn - - >>> model = nn.Linear(2, 3) - >>> def save_model(path): - ... paddle.save(model.state_dict(), path) - - >>> klatest_manager = KLatest(max_size=5, save_fn=save_model) - >>> checkpoint_dir = Path("checkpoints") - >>> shutil.rmtree(checkpoint_dir) - >>> checkpoint_dir.mkdir(parents=True) - >>> for i in range(20): - ... path = checkpoint_dir / f"step_{i}" - ... klatest_manager.add_checkpoint(path) - >>> assert len(list(checkpoint_dir.glob("step_*"))) == 5 - """ - - def __init__(self, - max_size: int=5, - save_fn: Callable[[Path], None]=None, - del_fn: Callable[[Path], None]=lambda f: f.unlink()): - self.latest_records: List[Path] = [] - self.save_fn = save_fn - self.del_fn = del_fn - self.max_size = max_size - self._save_all = (max_size == -1) - - def full(self): - return ( - not self._save_all) and len(self.latest_records) == self.max_size - - def add_checkpoint(self, path): - self.save_checkpoint_and_update(path) - - def save_checkpoint_and_update(self, path): - # remove the earist - if self.full(): - eariest_record_path = self.latest_records.pop(0) - self.del_fn(eariest_record_path) - - # add the new one - self.save_fn(path) - self.latest_records.append(path) diff --git a/parakeet/training/extension.py b/parakeet/training/extension.py new file mode 100644 index 0000000..57c4f29 --- /dev/null +++ b/parakeet/training/extension.py @@ -0,0 +1,80 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable
+
+PRIORITY_WRITER = 300
+PRIORITY_EDITOR = 200
+PRIORITY_READER = 100
+
+
+class Extension(object):
+    """Extension to customize the behavior of Trainer."""
+    trigger = (1, 'iteration')
+    priority = PRIORITY_READER
+    name = None
+
+    @property
+    def default_name(self):
+        """Default name of the extension, class name by default."""
+        return type(self).__name__
+
+    def __call__(self, trainer):
+        """Main action of the extension. After each update, it is executed
+        when the trigger fires."""
+        raise NotImplementedError(
+            'Extension implementation must override __call__.')
+
+    def initialize(self, trainer):
+        """Action that is executed once to get the correct trainer state.
+        It is called before training normally, but if the trainer restores
+        states with a Snapshot extension, this method should also be called.
+        """
+        pass
+
+    def on_error(self, trainer, exc, tb):
+        """Handles the error raised during training before finalization.
+        """
+        pass
+
+    def finalize(self, trainer):
+        """Action that is executed when training is done.
+        For example, visualizers would need to be closed.
+        """
+        pass
+
+
+def make_extension(trigger: Callable=None,
+                   default_name: str=None,
+                   priority: int=None,
+                   finalizer: Callable=None,
+                   initializer: Callable=None,
+                   on_error: Callable=None):
+    """Make an Extension-like object by injecting required attributes into it.
+    """
+    if trigger is None:
+        trigger = Extension.trigger
+    if priority is None:
+        priority = Extension.priority
+
+    def decorator(ext):
+        ext.trigger = trigger
+        ext.default_name = default_name or ext.__name__
+        ext.priority = priority
+        ext.finalize = finalizer
+        ext.on_error = on_error
+        ext.initialize = initializer
+        return ext
+
+    return decorator
diff --git a/parakeet/training/extensions/evaluator.py b/parakeet/training/extensions/evaluator.py
index eb3f877..6ebaae6 100644
--- a/parakeet/training/extensions/evaluator.py
+++ b/parakeet/training/extensions/evaluator.py
@@ -22,9 +22,17 @@ from paddle.nn import Layer
 from paddle.io import DataLoader
 
 from parakeet.training.reporter import scope, report, DictSummary
+from parakeet.training import extension
 
 
-class StandardEvaluator(object):
+class StandardEvaluator(extension.Extension):
+
+    trigger = (1, 'epoch')
+    default_name = 'validation'
+    priority = extension.PRIORITY_WRITER
+
+    name = None
+
     def __init__(self, model: Layer, dataloader: DataLoader):
         # it is designed to hold multiple models
         models = {"main": model}
@@ -43,21 +51,23 @@ class StandardEvaluator(object):
         for layer in self.models.values():
             layer.eval()
 
+        # to average evaluation metrics
         summary = DictSummary()
         for batch in self.dataloader:
             observation = {}
             with scope(observation):
+                # main evaluation computation here.
                 with paddle.no_grad():
-                    self.evaluate_core(
-                        batch)  # main evaluation computation here.
+                    self.evaluate_core(batch)
             summary.add(observation)
         summary = summary.compute_mean()
         return summary
 
     def __call__(self, trainer=None):
-        self.observation = {}
-        with scope(self.observation):
-            summary = self.evaluate()
-            for k, v in summary.items():
-                report(k, v)
-        print(self.observation)
+        # evaluate and report the averaged metrics to the current observation
+        # if it is used to extend a trainer, the metrics are reported to
+        # the observation of the trainer;
+        # otherwise, you can use your own observation
+        summary = self.evaluate()
+        for k, v in summary.items():
+            report(k, v)
diff --git a/parakeet/training/extensions/snapshot.py b/parakeet/training/extensions/snapshot.py
index 9cafef1..853d62e 100644
--- a/parakeet/training/extensions/snapshot.py
+++ b/parakeet/training/extensions/snapshot.py
@@ -12,18 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Union, List, Dict, Any
-from pathlib import Path
-import jsonlines
 import os
+from pathlib import Path
 from datetime import datetime
-import logging
+from typing import List, Dict, Any
+
+import jsonlines
 
 from parakeet.utils.mp_tools import rank_zero_only
 from parakeet.training.trainer import Trainer
+from parakeet.training import extension
 
 
-class Snapshot(object):
+def load_records(records_fp):
+    """Load record file (json lines)."""
+    with jsonlines.open(records_fp, 'r') as reader:
+        records = list(reader)
+    return records
+
+
+class Snapshot(extension.Extension):
     """An extension to make snapshot of the updater object inside the trainer.
     It is done by calling the updater's `save` method.
 
@@ -38,34 +46,46 @@ class Snapshot(object):
         The directory to save checkpoints into.
     """
 
-    def __init__(self, max_size: int=5):
+    trigger = (1, 'epoch')
+    priority = -100
+
+    def __init__(self, max_size: int=5, snapshot_on_error: bool=False):
         self.records: List[Dict[str, Any]] = []
         self.max_size = max_size
+        self._snapshot_on_error = snapshot_on_error
         self._save_all = (max_size == -1)
         self.save_fn = ...
-        self.del_fn = ...
+        self.del_fn = os.remove
         self.checkpoint_dir = ...
 
-    def initialize(self, trainer):
-        """setting up this extention."""
+    def initialize(self, trainer: Trainer):
+        """Setting up this extension."""
         self.save_fn = trainer.updater.save
-        self.del_fn = os.remove
         self.checkpoint_dir = trainer.out / "checkpoints"
+
+        # load existing records
+        record_path: Path = self.checkpoint_dir / "records.jsonl"
+        if record_path.exists():
+            self.records = load_records(record_path)
+
+    def on_error(self, trainer, exc, tb):
+        if self._snapshot_on_error:
+            self.save_checkpoint_and_update(trainer)
+
+    def __call__(self, trainer: Trainer):
+        self.save_checkpoint_and_update(trainer)
+
     def full(self):
-        return (not self._save_all) and len(self.records) >= self.max_size
+        """Whether the number of snapshots it keeps track of is greater
+        than the max_size."""
+        return (not self._save_all) and len(self.records) > self.max_size
 
     @rank_zero_only
-    def save_checkpoint_and_update(self, trainer):
+    def save_checkpoint_and_update(self, trainer: Trainer):
+        """Save a new snapshot and remove the oldest one if needed."""
         iteration = trainer.updater.state.iteration
         path = self.checkpoint_dir / f"snapshot_iter_{iteration}.pdz"
 
-        # remove the earist
-        if self.full():
-            eariest_record = self.records[0]
-            self.del_fn(eariest_record["path"])
-            self.records.pop(0)
-
         # add the new one
         self.save_fn(path)
         record = {
@@ -75,10 +95,14 @@ class Snapshot(object):
         }
         self.records.append(record)
 
-        # update the record
-        with jsonlines.open(self.checkpoint_dir / "records.jsonl", 'w') as f:
-            for record in self.records:
-                f.write(record)
+        # remove the earliest record if needed
+        if self.full():
+            earliest_record = self.records[0]
+            self.del_fn(earliest_record["path"])
+            self.records.pop(0)
 
-    def __call__(self, trainer):
-        self.save_checkpoint_and_update(trainer)
+        # update the record file
+        record_path = self.checkpoint_dir / "records.jsonl"
+        with jsonlines.open(record_path, 'w') as writer:
+            for record in self.records:
+                writer.write(record)
diff --git a/parakeet/training/extensions/visualizer.py b/parakeet/training/extensions/visualizer.py
index 9a1bcbb..2a42ae0 100644
--- a/parakeet/training/extensions/visualizer.py
+++ b/parakeet/training/extensions/visualizer.py
@@ -13,22 +13,31 @@
 # limitations under the License.
 
 from visualdl import LogWriter
+
 from parakeet.training.trainer import Trainer
-from parakeet.utils.mp_tools import rank_zero_only
+from parakeet.training import extension
 
 
-class VisualDL(object):
+class VisualDL(extension.Extension):
     """A wrapper of visualdl log writer. It assumes that the metrics to be visualized
-    are all scalars which are recorded into the `.observation` dictionary of the
-    trainer object. The dictionary is created for each step, thus the visualdl log
+    are all scalars which are recorded into the `.observation` dictionary of the
+    trainer object. The dictionary is created for each step, thus the visualdl log
     writer uses the iteration from the updater's `iteration` as the global step to
     add records.
     """
+    trigger = (1, 'iteration')
+    default_name = 'visualdl'
+    priority = extension.PRIORITY_READER
 
-    def __init__(self, writer: LogWriter):
-        self.writer = writer
+    def __init__(self):
+        self.writer = ...
+
+    def initialize(self, trainer):
+        self.writer = LogWriter(logdir=str(trainer.out))
 
-    @rank_zero_only
     def __call__(self, trainer: Trainer):
         for k, v in trainer.observation.items():
             self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)
+
+    def finalize(self, trainer):
+        self.writer.close()
diff --git a/parakeet/training/seeding.py b/parakeet/training/seeding.py
index 1a6660f..1663d2d 100644
--- a/parakeet/training/seeding.py
+++ b/parakeet/training/seeding.py
@@ -12,14 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import random
 import logging
 import paddle
-import random
 import numpy as np
 
 
 def seed_everything(seed: int):
+    """Seed paddle, random and np.random to help reproducibility."""
     paddle.seed(seed)
     random.seed(seed)
     np.random.seed(seed)
diff --git a/parakeet/training/trainer.py b/parakeet/training/trainer.py
index 38ccefb..f4b6fbb 100644
--- a/parakeet/training/trainer.py
+++ b/parakeet/training/trainer.py
@@ -13,15 +13,18 @@
 # limitations under the License.
 
 from pathlib import Path
+from collections import OrderedDict
+from typing import Callable, Union, List
+
 import tqdm
-from dataclasses import dataclass
 
 from parakeet.training.trigger import get_trigger, IntervalTrigger
 from parakeet.training.updater import UpdaterBase
 from parakeet.training.reporter import scope
+from parakeet.training.extension import Extension, PRIORITY_READER
 
 
-class ExtensionEntry(object):
+class _ExtensionEntry(object):
     def __init__(self, extension, trigger, priority):
         self.extension = extension
         self.trigger = trigger
@@ -31,29 +34,76 @@ class ExtensionEntry(object):
 
 class Trainer(object):
     def __init__(self,
                  updater: UpdaterBase,
-                 stop_trigger=None,
-                 out='result',
-                 extensions=None):
+                 stop_trigger: Callable=None,
+                 out: Union[str, Path]='result',
+                 extensions: List[Extension]=None):
         self.updater = updater
-        self.extensions = {}
+        self.extensions = OrderedDict()
         self.stop_trigger = get_trigger(stop_trigger)
         self.out = Path(out)
-        self.observation = {}
+        self.observation = ...
+
+        self._done = False
+        if extensions:
+            for ext in extensions:
+                self.extend(ext)
+
+    @property
+    def is_before_training(self):
+        return self.updater.state.iteration == 0
 
     def extend(self, extension, name=None, trigger=None, priority=None):
+        # get name for the extension
+        # argument \
+        # -> extension's name \
+        # -> default_name (class name, when it is an object) \
+        # -> function name when it is a function \
+        # -> error
+
+        if name is None:
+            name = getattr(extension, 'name', None)
+        if name is None:
+            name = getattr(extension, 'default_name', None)
+        if name is None:
+            name = getattr(extension, '__name__', None)
+        if name is None:
+            raise ValueError(
+                "Name is not given for the extension.")
+        if name == 'training':
+            raise ValueError("training is a reserved name.")
+
+        if trigger is None:
+            trigger = getattr(extension, 'trigger', (1, 'iteration'))
         trigger = get_trigger(trigger)
 
+        if priority is None:
+            priority = getattr(extension, 'priority', PRIORITY_READER)
+
+        # add suffix to avoid naming conflict
         ordinal = 0
         modified_name = name
         while modified_name in self.extensions:
-            print(self.extensions.keys())
             ordinal += 1
             modified_name = f"{name}_{ordinal}"
+        extension.name = modified_name
 
-        self.extensions[modified_name] = ExtensionEntry(extension, trigger,
-                                                        priority)
+        self.extensions[modified_name] = _ExtensionEntry(extension, trigger,
+                                                         priority)
+
+    def get_extension(self, name):
+        """Get an extension by name."""
+        extensions = self.extensions
+        if name in extensions:
+            return extensions[name].extension
+        else:
+            raise ValueError(f'extension {name} not found')
 
     def run(self):
+        if self._done:
+            raise RuntimeError("Training is already done!")
+
+        self.out.mkdir(parents=True, exist_ok=True)
+
         # sort extensions by priorities once
         extension_order = sorted(
             self.extensions.keys(),
@@ -67,7 +117,7 @@ class Trainer(object):
             if hasattr(entry.extension, "initialize"):
                 entry.extension.initialize(self)
 
-        update = self.updater.update
+        update = self.updater.update  # training step
         stop_trigger = self.stop_trigger
 
         # TODO(chenfeiyu): display progress bar correctly
diff --git a/parakeet/training/trigger.py b/parakeet/training/trigger.py
index 5d165a6..f5834bc 100644
--- a/parakeet/training/trigger.py
+++ b/parakeet/training/trigger.py
@@ -12,21 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-class IntervalTrigger(object):
-    def __init__(self, period: int, unit: str):
-        if unit not in ("iteration", "epoch"):
-            raise ValueError("unit should be 'iteration' or 'epoch'")
-        self.period = period
-        self.unit = unit
-
-    def __call__(self, trainer):
-        state = trainer.updater.state
-        if self.unit == "epoch":
-            fire = state.epoch % self.period == 0
-        else:
-            fire = state.iteration % self.period == 0
-        return fire
+from parakeet.training.triggers.interval_trigger import IntervalTrigger
 
 
 def never_file_trigger(trainer):
diff --git a/parakeet/training/triggers/interval_trigger.py b/parakeet/training/triggers/interval_trigger.py
new file mode 100644
index 0000000..82f441e
--- /dev/null
+++ b/parakeet/training/triggers/interval_trigger.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class IntervalTrigger(object): + """A Predicate to do something every N cycle.""" + + def __init__(self, period: int, unit: str): + if unit not in ("iteration", "epoch"): + raise ValueError("unit should be 'iteration' or 'epoch'") + self.period = period + self.unit = unit + + def __call__(self, trainer): + state = trainer.updater.state + # we use a special scheme so we can use iteration % period == 0 as + # the predicate + # increase the iteration then update parameters + # instead of updating then increase iteration + if self.unit == "epoch": + fire = state.epoch % self.period == 0 + else: + fire = state.iteration % self.period == 0 + return fire diff --git a/parakeet/training/triggers/time_trigger.py b/parakeet/training/triggers/time_trigger.py new file mode 100644 index 0000000..aff9382 --- /dev/null +++ b/parakeet/training/triggers/time_trigger.py @@ -0,0 +1,35 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class TimeTrigger(object): + """Trigger based on a fixed time interval. + + This trigger accepts iterations with a given interval time. + + Args: + period (float): Interval time. It is given in seconds. + + """ + + def __init__(self, period): + self._period = period + self._next_time = self._period + + def __call__(self, trainer): + if self._next_time < trainer.elapsed_time: + self._next_time += self._period + return True + else: + return False diff --git a/parakeet/training/updater.py b/parakeet/training/updater.py index fb3bb41..2d9ec3d 100644 --- a/parakeet/training/updater.py +++ b/parakeet/training/updater.py @@ -93,125 +93,3 @@ class UpdaterBase(object): def load(self, path): archive = paddle.load(path) self.set_state_dict(archive) - - -class StandardUpdater(UpdaterBase): - """An example of over-simplification. Things may not be that simple, but - you can subclass it to fit your need. 
- """ - - def __init__(self, - model: Layer, - optimizer: Optimizer, - dataloader: DataLoader, - init_state: Optional[UpdaterState]=None): - # it is designed to hold multiple models - models = {"main": model} - self.models: Dict[str, Layer] = models - self.model = model - - # it is designed to hold multiple optimizers - optimizers = {"main": optimizer} - self.optimizer = optimizer - self.optimizers: Dict[str, Optimizer] = optimizers - - # dataloaders - self.dataloader = dataloader - - # init state - if init_state is None: - self.state = UpdaterState() - else: - self.state = init_state - - self.train_iterator = iter(dataloader) - - def update(self): - self.state.iteration += 1 - - # switch to training mode - for layer in self.models.values(): - layer.train() - - # training for a step is implemented here - batch = self.read_batch() - self.update_core(batch) - - def update_core(self, batch): - """A simple case for a training step. Basic assumptions are: - Single model; - Single optimizer; - A batch from the dataloader is just the input of the model; - The model return a single loss, or a dict containing serval losses. - Parameters updates at every batch, no gradient accumulation. - """ - loss = self.model(*batch) - - if isinstance(loss, Tensor): - loss_dict = {"main": loss} - else: - # Dict[str, Tensor] - loss_dict = loss - if "main" not in loss_dict: - main_loss = 0 - for loss_item in loss.values(): - main_loss += loss_item - loss_dict["main"] = main_loss - - for name, loss_item in loss_dict.items(): - report(name, float(loss_item)) - - self.optimizer.clear_gradient() - loss_dict["main"].backward() - self.optimizer.update() - - def new_epoch(self): - """Start a new epoch.""" - self.state.epoch += 1 - - # NOTE: all batch sampler for distributed training should - # subclass DistributedBatchSampler and implement `set_epoch` method - batch_sampler = self.dataloader.batch_sampler - if isinstance(batch_sampler, DistributedBatchSampler): - batch_sampler.set_epoch(self.state.epoch) - self.train_iterator = iter(self.dataloader) - - def read_batch(self): - """Read a batch from the data loader, auto renew when data is exhausted.""" - with timer() as t: - try: - batch = next(self.train_iterator) - except StopIteration: - self.new_epoch() - batch = next(self.train_iterator) - logging.debug( - f"Read a batch takes {t.elapse}s.") # replace it with logging - return batch - - def state_dict(self): - """State dict of a Updater, model, optimizer and updater state are included.""" - state_dict = super().state_dict() - for name, layer in self.models.items(): - state_dict[f"{name}_params"] = layer.state_dict() - for name, optim in self.optimizers.items(): - state_dict[f"{name}_optimizer"] = optim.state_dict() - return state_dict - - def set_state_dict(self, state_dict): - """Set state dict for a Updater. 
Parameters of models, states for - optimizers and UpdaterState are restored.""" - for name, layer in self.models.items(): - layer.set_state_dict(state_dict[f"{name}_params"]) - for name, optim in self.optimizers.items(): - optim.set_state_dict(state_dict[f"{name}_optimizer"]) - super().set_state_dict(state_dict) - - def save(self, path): - """Save Updater state dict.""" - archive = self.state_dict() - paddle.save(archive, path) - - def load(self, path): - """Load Updater state dict.""" - archive = paddle.load(path) - self.set_state_dict(archive) diff --git a/parakeet/training/updaters/standard_updater.py b/parakeet/training/updaters/standard_updater.py new file mode 100644 index 0000000..5cc0252 --- /dev/null +++ b/parakeet/training/updaters/standard_updater.py @@ -0,0 +1,152 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass +from typing import Optional +from typing import Dict +from typing import Union + +from timer import timer +import paddle +from paddle import Tensor +from paddle.nn import Layer +from paddle.optimizer import Optimizer +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler + +from parakeet.training.reporter import report +from parakeet.training.updater import UpdaterBase, UpdaterState + + +class StandardUpdater(UpdaterBase): + """An example of over-simplification. Things may not be that simple, but + you can subclass it to fit your need. + """ + + def __init__(self, + model: Layer, + optimizer: Optimizer, + dataloader: DataLoader, + init_state: Optional[UpdaterState]=None): + # it is designed to hold multiple models + models = {"main": model} + self.models: Dict[str, Layer] = models + self.model = model + + # it is designed to hold multiple optimizers + optimizers = {"main": optimizer} + self.optimizer = optimizer + self.optimizers: Dict[str, Optimizer] = optimizers + + # dataloaders + self.dataloader = dataloader + + # init state + if init_state is None: + self.state = UpdaterState() + else: + self.state = init_state + + self.train_iterator = iter(dataloader) + + def update(self): + self.state.iteration += 1 + + # switch to training mode + for layer in self.models.values(): + layer.train() + + # training for a step is implemented here + batch = self.read_batch() + self.update_core(batch) + + def update_core(self, batch): + """A simple case for a training step. Basic assumptions are: + Single model; + Single optimizer; + A batch from the dataloader is just the input of the model; + The model return a single loss, or a dict containing serval losses. + Parameters updates at every batch, no gradient accumulation. 
+ """ + loss = self.model(*batch) + + if isinstance(loss, Tensor): + loss_dict = {"main": loss} + else: + # Dict[str, Tensor] + loss_dict = loss + if "main" not in loss_dict: + main_loss = 0 + for loss_item in loss.values(): + main_loss += loss_item + loss_dict["main"] = main_loss + + for name, loss_item in loss_dict.items(): + report(name, float(loss_item)) + + self.optimizer.clear_gradient() + loss_dict["main"].backward() + self.optimizer.update() + + def new_epoch(self): + """Start a new epoch.""" + self.state.epoch += 1 + + # NOTE: all batch sampler for distributed training should + # subclass DistributedBatchSampler and implement `set_epoch` method + batch_sampler = self.dataloader.batch_sampler + if isinstance(batch_sampler, DistributedBatchSampler): + batch_sampler.set_epoch(self.state.epoch) + self.train_iterator = iter(self.dataloader) + + def read_batch(self): + """Read a batch from the data loader, auto renew when data is exhausted.""" + with timer() as t: + try: + batch = next(self.train_iterator) + except StopIteration: + self.new_epoch() + batch = next(self.train_iterator) + logging.debug( + f"Read a batch takes {t.elapse}s.") # replace it with logging + return batch + + def state_dict(self): + """State dict of a Updater, model, optimizer and updater state are included.""" + state_dict = super().state_dict() + for name, layer in self.models.items(): + state_dict[f"{name}_params"] = layer.state_dict() + for name, optim in self.optimizers.items(): + state_dict[f"{name}_optimizer"] = optim.state_dict() + return state_dict + + def set_state_dict(self, state_dict): + """Set state dict for a Updater. Parameters of models, states for + optimizers and UpdaterState are restored.""" + for name, layer in self.models.items(): + layer.set_state_dict(state_dict[f"{name}_params"]) + for name, optim in self.optimizers.items(): + optim.set_state_dict(state_dict[f"{name}_optimizer"]) + super().set_state_dict(state_dict) + + def save(self, path): + """Save Updater state dict.""" + archive = self.state_dict() + paddle.save(archive, path) + + def load(self, path): + """Load Updater state dict.""" + archive = paddle.load(path) + self.set_state_dict(archive) diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py deleted file mode 100644 index 120115f..0000000 --- a/tests/test_checkpoint.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from pathlib import Path -import shutil - -import numpy as np -from parakeet.training.checkpoint import KBest, KLatest - - -def test_kbest(): - def save_fn(path): - with open(path, 'wt') as f: - f.write(f"My path is {str(path)}\n") - - K = 1 - kbest_manager = KBest(max_size=K, save_fn=save_fn) - checkpoint_dir = Path("checkpoints") - if checkpoint_dir.exists(): - shutil.rmtree(checkpoint_dir) - checkpoint_dir.mkdir(parents=True) - a = np.random.rand(20) - for i, score in enumerate(a): - path = checkpoint_dir / f"step_{i}" - kbest_manager.add_checkpoint(score, path) - assert len(list(checkpoint_dir.glob("step_*"))) == K - shutil.rmtree(checkpoint_dir) - - -def test_klatest(): - def save_fn(path): - with open(path, 'wt') as f: - f.write(f"My path is {str(path)}\n") - - K = 5 - klatest_manager = KLatest(max_size=K, save_fn=save_fn) - checkpoint_dir = Path("checkpoints") - if checkpoint_dir.exists(): - shutil.rmtree(checkpoint_dir) - checkpoint_dir.mkdir(parents=True) - for i in range(20): - path = checkpoint_dir / f"step_{i}" - klatest_manager.add_checkpoint(path) - assert len(list(checkpoint_dir.glob("step_*"))) == K - shutil.rmtree(checkpoint_dir)
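
Note (not part of the diff): the patch replaces the KBest/KLatest checkpoint managers with a Trainer/Extension workflow (StandardUpdater, Snapshot, VisualDL, StandardEvaluator, interval triggers). The following is a minimal wiring sketch of how these pieces appear intended to fit together. The toy model, synthetic data, output directory and trigger values are invented for illustration, and it assumes `get_trigger` accepts `(period, unit)` tuples as the class defaults above suggest; treat it as an assumption, not the project's documented usage.

# wiring sketch under the assumptions stated above
from pathlib import Path

import paddle
import paddle.nn.functional as F
from paddle import nn
from paddle.io import DataLoader, TensorDataset

from parakeet.training.seeding import seed_everything
from parakeet.training.trainer import Trainer
from parakeet.training.extension import make_extension
from parakeet.training.updaters.standard_updater import StandardUpdater
from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.training.extensions.snapshot import Snapshot
from parakeet.training.extensions.visualizer import VisualDL


class ToyModel(nn.Layer):
    """A toy model whose forward returns the loss, as StandardUpdater expects."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x, y):
        return F.mse_loss(self.linear(x), y)


seed_everything(42)
x, y = paddle.randn([64, 4]), paddle.randn([64, 1])
train_loader = DataLoader(TensorDataset([x, y]), batch_size=8, shuffle=True)
dev_loader = DataLoader(TensorDataset([x, y]), batch_size=8)

model = ToyModel()
optimizer = paddle.optimizer.Adam(0.001, parameters=model.parameters())
updater = StandardUpdater(model, optimizer, train_loader)

# stop after 10 iterations, assuming the stop trigger is an IntervalTrigger
trainer = Trainer(updater, stop_trigger=(10, 'iteration'), out='exp/toy')

# class-based extensions carry their own default name, trigger and priority
trainer.extend(StandardEvaluator(model, dev_loader))
trainer.extend(Snapshot(max_size=3, snapshot_on_error=True))
trainer.extend(VisualDL())


# a plain function can also be registered; the no-op initializer/finalizer/
# on_error callbacks are passed explicitly since Trainer.run may call them
@make_extension(trigger=(5, 'iteration'),
                initializer=lambda trainer: None,
                finalizer=lambda trainer: None,
                on_error=lambda trainer, exc, tb: None)
def log_progress(trainer):
    print("iteration:", trainer.updater.state.iteration)


trainer.extend(log_progress)

# Snapshot writes snapshot_iter_*.pdz under <out>/checkpoints; create it
# defensively in case the updater's save function does not make directories
Path('exp/toy/checkpoints').mkdir(parents=True, exist_ok=True)

trainer.run()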