add new trainer
commit 4d3014f4d5 (parent 27e0201d0d)
@@ -607,7 +607,7 @@ class Tacotron2(nn.Layer):
             num_layers=postnet_conv_layers,
             dropout=p_postnet_dropout)

-    def forward(self, text_inputs, text_lens, mels, output_lens=None, speaker_ids=None, tones=Nones):
+    def forward(self, text_inputs, text_lens, mels, output_lens=None, speaker_ids=None, tones=None):
         """Calculate forward propagation of tacotron2.

         Parameters
parakeet/training/reporter.py
@@ -0,0 +1,30 @@
import contextlib

OBSERVATIONS = None


@contextlib.contextmanager
def scope(observations):
    # Make `observations` the target to report to. It is basically a
    # dictionary that stores temporary observations.
    global OBSERVATIONS
    old = OBSERVATIONS
    OBSERVATIONS = observations

    try:
        yield
    finally:
        OBSERVATIONS = old


def get_observations():
    global OBSERVATIONS
    return OBSERVATIONS


def report(name, value):
    # A simple function to report a named value. You can use it anywhere;
    # it gets the default target and writes to it. You can think of it as
    # stdout for metrics.
    observations = get_observations()
    if observations is None:
        return
    else:
        observations[name] = value
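For orientation, a minimal usage sketch of this reporter (the dictionary and the metric names are made up for illustration):

from parakeet.training.reporter import scope, report

observations = {}
with scope(observations):
    # any report() call inside the scope writes into `observations`
    report("loss", 0.521)
    report("grad_norm", 3.2)
# observations == {"loss": 0.521, "grad_norm": 3.2}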
parakeet/training/trainer.py
@@ -0,0 +1,78 @@
from pathlib import Path
import tqdm
from dataclasses import dataclass

from parakeet.training.trigger import get_trigger, IntervalTrigger
from parakeet.training.updater import UpdaterBase
from parakeet.training.reporter import scope


class ExtensionEntry(object):
    def __init__(self, extension, trigger, priority):
        self.extension = extension
        self.trigger = trigger
        self.priority = priority


class Trainer(object):
    def __init__(self,
                 updater: UpdaterBase,
                 stop_trigger=None,
                 out='result',
                 extensions=None):
        self.updater = updater
        self.extensions = {}
        self.stop_trigger = get_trigger(stop_trigger)
        self.out = Path(out)
        self.observation = {}

    def setup(self):
        pass

    def extend(self, extension, name=None, trigger=None, priority=None):
        trigger = get_trigger(trigger)

        # Deduplicate names with an ordinal suffix. Check `modified_name`
        # here, not `name`, otherwise the loop never terminates.
        ordinal = 0
        modified_name = name
        while modified_name in self.extensions:
            ordinal += 1
            modified_name = f"{name}_{ordinal}"

        self.extensions[modified_name] = ExtensionEntry(
            extension, trigger, priority)

    def run(self):
        # sort extensions by priority once
        extension_order = sorted(
            self.extensions.keys(),
            key=lambda name: self.extensions[name].priority,
            reverse=True)
        extensions = [(name, self.extensions[name])
                      for name in extension_order]

        update = self.updater.update
        stop_trigger = self.stop_trigger

        # TODO(chenfeiyu): display the progress bar correctly
        # if the trainer is controlled by epoch: use 2 progress bars
        # if the trainer is controlled by iteration: use 1 progress bar
        if isinstance(stop_trigger, IntervalTrigger):
            if stop_trigger.unit == 'epoch':
                max_epoch = self.stop_trigger.period
            else:
                max_iteration = self.stop_trigger.period

        while not stop_trigger(self):
            self.observation = {}
            # Set the observation as the report target so that `report`
            # can be used freely inside Updater.update().

            # update parameters and state
            with scope(self.observation):
                update()

            # execute extensions whose triggers fire
            for name, entry in extensions:
                if entry.trigger(self):
                    entry.extension(self)
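A sketch of the intended wiring, assuming an `updater` built from the StandardUpdater below; the extension body, names, and trigger values are illustrative:

from parakeet.training.trainer import Trainer

def log_state(trainer):
    # an extension is just a callable that receives the trainer
    print(trainer.updater.state, trainer.observation)

trainer = Trainer(updater, stop_trigger=(10, "epoch"), out="exp/test")
trainer.extend(log_state, name="log_state", trigger=(100, "iteration"), priority=0)
trainer.run()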
parakeet/training/trigger.py
@@ -0,0 +1,28 @@
class IntervalTrigger(object):
    def __init__(self, period: int, unit: str):
        if unit not in ("iteration", "epoch"):
            raise ValueError("unit should be 'iteration' or 'epoch'")
        self.period = period
        self.unit = unit

    def __call__(self, trainer):
        state = trainer.updater.state
        if self.unit == "epoch":
            count = state.epoch
        else:
            count = state.iteration
        # fire every `period` units; the `count != 0` guard keeps a fresh
        # trainer (epoch 0, iteration 0) from firing immediately
        fire = count != 0 and count % self.period == 0
        return fire


def never_fire_trigger(trainer):
    return False


def get_trigger(trigger):
    if trigger is None:
        return never_fire_trigger
    if callable(trigger):
        return trigger
    else:
        trigger = IntervalTrigger(*trigger)
        return trigger
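get_trigger normalizes the three accepted spellings of a trigger; for instance:

from parakeet.training.trigger import get_trigger

t1 = get_trigger(None)                   # never fires
t2 = get_trigger((1000, "iteration"))    # IntervalTrigger(1000, "iteration")
t3 = get_trigger(lambda trainer: True)   # any callable is passed through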
parakeet/training/updater.py
@@ -0,0 +1,107 @@
from dataclasses import dataclass
from typing import Optional

from paddle.nn import Layer
from paddle.optimizer import Optimizer
from paddle.io import DataLoader


@dataclass
class UpdaterState:
    iteration: int = 0
    epoch: int = 0


class UpdaterBase(object):
    """An updater is the abstraction of how a model is trained given the
    dataloader and the optimizer.

    The `update_core` method is one step of the training loop with only the
    necessary operations (get a batch, forward and backward, update the
    parameters).

    Everything else is made an extension: visualization, saving, loading and
    periodical validation and evaluation are not considered here.

    But even in such a simple case, things are not that simple. There are
    attempts to standardize this process that require only the model and the
    dataset and do everything else automatically. But this may hurt
    flexibility.

    If we assume a batch yielded from the dataloader is just the input to the
    model, we will find that some models require more arguments, or just some
    keyword arguments, which prevents us from over-simplifying the interface.

    From another perspective, the batch may include not just the input, but
    also the target. But the model's forward method may just need the input.
    We can pass a dict or a super-long tuple to the model and let it pick what
    it really needs. But this is an abuse of a lazy interface.

    After all, we care about how a model is trained, not just how the model is
    used for inference. We want to control how a model is trained; we just
    don't want to be tangled up with other auxiliary code.

    So the best practice is to define a model and define an updater for it.
    """

    def update(self):
        pass

    def update_core(self):
        pass


class StandardUpdater(UpdaterBase):
    """An example of over-simplification. Things may not be that simple, but
    you can subclass it to fit your needs.
    """

    def __init__(self,
                 model: Layer,
                 dataloader: DataLoader,
                 optimizer: Optimizer,
                 loss_func=None,
                 auto_new_epoch: bool = True,
                 init_state: Optional[UpdaterState] = None):
        self.model = model
        self.dataloader = dataloader
        self.optimizer = optimizer
        self.loss_func = loss_func
        self.auto_new_epoch = auto_new_epoch
        self.iterator = iter(dataloader)

        if init_state is None:
            self.state = UpdaterState()
        else:
            self.state = init_state

    def update(self):
        self.update_core()
        self.state.iteration += 1

    def new_epoch(self):
        self.iterator = iter(self.dataloader)
        self.state.epoch += 1

    def update_core(self):
        model = self.model
        optimizer = self.optimizer
        loss_func = self.loss_func

        model.train()
        optimizer.clear_grad()

        # fetch a batch; start a new epoch when the iterator is exhausted
        # (otherwise `batch` would be undefined after StopIteration)
        try:
            batch = next(self.iterator)
        except StopIteration:
            if not self.auto_new_epoch:
                raise
            self.new_epoch()
            batch = next(self.iterator)

        # forward
        if self.loss_func is not None:
            loss = loss_func(batch)
        else:
            loss = model(batch)

        # backward
        loss.backward()

        # update parameters
        optimizer.step()
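To see how the pieces compose, a hypothetical end-to-end sketch; the toy dataset, model, and loss are invented for illustration, and only the parakeet.training classes come from this commit:

import paddle
from paddle.io import Dataset, DataLoader

from parakeet.training.updater import StandardUpdater
from parakeet.training.trainer import Trainer

class RandomDataset(Dataset):
    # toy dataset: 100 random (feature, target) pairs
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return paddle.randn([8]), paddle.randn([1])

model = paddle.nn.Linear(8, 1)
dataloader = DataLoader(RandomDataset(), batch_size=4)
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

def loss_func(batch):
    # maps a batch to a scalar loss, as StandardUpdater expects
    x, y = batch
    return paddle.nn.functional.mse_loss(model(x), y)

updater = StandardUpdater(model, dataloader, optimizer, loss_func=loss_func)
Trainer(updater, stop_trigger=(2, "epoch")).run()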