ParakeetRebeccaRosario/parakeet/training/trainer.py

187 lines
6.6 KiB
Python

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import six
import traceback
from pathlib import Path
from collections import OrderedDict
from typing import Callable, Union, List
import tqdm
from parakeet.training.trigger import get_trigger, IntervalTrigger, LimitTrigger
from parakeet.training.updater import UpdaterBase
from parakeet.training.reporter import scope
from parakeet.training.extension import Extension, PRIORITY_READER
class _ExtensionEntry(object):
def __init__(self, extension, trigger, priority):
self.extension = extension
self.trigger = trigger
self.priority = priority
class Trainer(object):
    """Run a training loop and invoke registered extensions around it.

    Each loop iteration calls ``updater.update()`` with ``self.observation``
    installed as the report scope, then calls every extension whose trigger
    fires for this trainer. The loop ends when ``stop_trigger(trainer)``
    returns a true value.
    """

    def __init__(self,
                 updater: "UpdaterBase",
                 stop_trigger: Union[Callable, tuple, None] = None,
                 out: Union[str, Path] = 'result',
                 extensions: "List[Extension]" = None):
        """Set up the trainer.

        Project types are annotated as forward references so that defining
        the class does not evaluate them.

        Args:
            updater: Object providing ``update()`` and ``state.iteration``;
                an optional ``updates_per_epoch`` attribute is used to size
                the progress bar.
            stop_trigger: When to stop. ``None`` means never stop, a
                callable is used as is (it is called with the trainer), and
                any other value is unpacked into ``LimitTrigger``.
            out: Output directory; created by ``run()`` if missing.
            extensions: Extensions to register via ``extend()``.
        """
        self.updater = updater
        self.extensions = OrderedDict()
        if stop_trigger is None:
            # The default used to crash on ``LimitTrigger(*None)``.
            self.stop_trigger = self._never_stop
        elif callable(stop_trigger):
            self.stop_trigger = stop_trigger
        else:
            self.stop_trigger = LimitTrigger(*stop_trigger)
        self.out = Path(out)
        # Empty until the first update; ``run()`` replaces it every iteration.
        self.observation = {}
        self._done = False
        if extensions:
            for ext in extensions:
                self.extend(ext)

    @staticmethod
    def _never_stop(trainer):
        """Stop trigger used when none is given: never request a stop."""
        return False

    @property
    def is_before_training(self):
        """Whether the updater's iteration counter is still 0."""
        return self.updater.state.iteration == 0

    def extend(self, extension, name=None, trigger=None, priority=None):
        """Register an extension under a unique name.

        Args:
            extension: Callable invoked as ``extension(trainer)``. Its
                ``name`` attribute is overwritten with the registered name.
            name: Name to register under. When omitted, the extension's
                ``name``, ``default_name`` and ``__name__`` attributes are
                tried in that order.
            trigger: Passed through ``get_trigger``. Defaults to the
                extension's ``trigger`` attribute, else ``(1, 'iteration')``.
            priority: Sort key used by ``run()`` (higher runs first).
                Defaults to the extension's ``priority`` attribute, else
                ``PRIORITY_READER``.

        Raises:
            ValueError: If no name can be determined, or the name is the
                reserved word ``'training'``.
        """
        # Name resolution order: argument -> extension's name ->
        # default_name -> __name__ -> error.
        if name is None:
            for attr in ('name', 'default_name', '__name__'):
                name = getattr(extension, attr, None)
                if name is not None:
                    break
            else:
                raise ValueError("Name is not given for the extension.")
        if name == 'training':
            raise ValueError("training is a reserved name.")

        if trigger is None:
            trigger = getattr(extension, 'trigger', (1, 'iteration'))
        trigger = get_trigger(trigger)

        if priority is None:
            priority = getattr(extension, 'priority', PRIORITY_READER)

        # Add a numeric suffix to avoid naming conflicts.
        ordinal = 0
        modified_name = name
        while modified_name in self.extensions:
            ordinal += 1
            modified_name = f"{name}_{ordinal}"
        extension.name = modified_name

        self.extensions[modified_name] = _ExtensionEntry(extension, trigger,
                                                         priority)

    def get_extension(self, name):
        """Return the extension registered as ``name``.

        Raises:
            ValueError: If no extension is registered under ``name``.
        """
        entry = self.extensions.get(name)
        if entry is None:
            raise ValueError(f'extension {name} not found')
        return entry.extension

    def run(self):
        """Run the training loop until the stop trigger fires.

        Extensions are ordered once by descending priority, initialized,
        called after each update when their trigger fires, and finalized
        when the loop exits for any reason.

        Raises:
            RuntimeError: If a previous call already completed.
            Exception: Any exception from the loop is re-raised after the
                extensions' ``on_error`` hooks have been given a chance.
        """
        if self._done:
            raise RuntimeError("Training is already done!.")

        self.out.mkdir(parents=True, exist_ok=True)

        # Sort extensions by priority once; the sort is stable, so equal
        # priorities keep their registration order.
        extensions = sorted(
            self.extensions.items(),
            key=lambda item: item[1].priority,
            reverse=True)

        # Initialize all extensions.
        for _, entry in extensions:
            if hasattr(entry.extension, "initialize"):
                entry.extension.initialize(self)

        update = self.updater.update  # training step
        stop_trigger = self.stop_trigger

        print(self.updater.state)

        # Display a single progress bar; its total is known only when the
        # stop trigger is a LimitTrigger.
        max_iteration = None
        if isinstance(stop_trigger, LimitTrigger):
            # Compare by value: ``is`` on a string literal is unreliable.
            if stop_trigger.unit == 'epoch':
                updates_per_epoch = getattr(self.updater,
                                            "updates_per_epoch", None)
                if updates_per_epoch:
                    max_iteration = stop_trigger.limit * updates_per_epoch
            else:
                max_iteration = stop_trigger.limit

        progress = tqdm.tqdm(
            initial=self.updater.state.iteration, total=max_iteration)

        try:
            while not stop_trigger(self):
                self.observation = {}
                # Set the observation as the report target so that report
                # can be used freely inside Updater.update().
                with scope(self.observation):
                    update()
                    progress.update()

                # Execute extensions when necessary.
                for _, entry in extensions:
                    if entry.trigger(self):
                        entry.extension(self)
        except Exception as e:
            f = sys.stderr
            f.write(f"Exception in main training loop: {e}\n")
            f.write("Traceback (most recent call last):\n")
            traceback.print_tb(e.__traceback__)
            f.write(
                "Trainer extensions will try to handle the exception. Then all extensions will finalize.\n"
            )

            # Let every extension try to handle the exception raised in the
            # main training loop; failures of the handlers are only logged.
            for _, entry in extensions:
                if hasattr(entry.extension, "on_error"):
                    try:
                        entry.extension.on_error(self, e, e.__traceback__)
                    except Exception as ee:
                        f.write(f"Exception in error handler: {ee}\n")
                        f.write('Traceback (most recent call last):\n')
                        traceback.print_tb(ee.__traceback__)

            # Re-raise the exception from the main training loop.
            raise
        finally:
            progress.close()
            for _, entry in extensions:
                if hasattr(entry.extension, "finalize"):
                    entry.extension.finalize(self)

        # Reached only on normal completion; makes the guard above effective.
        self._done = True