1. better error handling;

2. use absolute path in snapshot records;
3. VisualDL takes a log writer as an init argument rather than creating one.
This commit is contained in:
chenfeiyu 2021-06-27 18:53:45 +08:00
parent 61c13dd69b
commit 29b8b8b0ea
7 changed files with 80 additions and 50 deletions

View File

@ -125,3 +125,4 @@ log_interval_steps: 100 # Interval steps to record the training
# OTHER SETTING #
###########################################################
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
num_snapshots: 10

View File

@ -24,11 +24,10 @@ from paddle.io import DistributedBatchSampler
from timer import timer
from parakeet.datasets.data_table import DataTable
from parakeet.training.updater import UpdaterBase, UpdaterState, StandardUpdater
from parakeet.training.updaters.standard_updater import StandardUpdater, UpdaterState
from parakeet.training.extensions.evaluator import StandardEvaluator
from parakeet.training.trainer import Trainer
from parakeet.training.reporter import report
from parakeet.training.checkpoint import KBest, KLatest
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
from parakeet.utils.profile import synchronize

View File

@ -36,11 +36,11 @@ from parakeet.datasets.data_table import DataTable
from parakeet.training.updater import UpdaterBase
from parakeet.training.trainer import Trainer
from parakeet.training.reporter import report
from parakeet.training.checkpoint import KBest, KLatest
from parakeet.training import extension
from parakeet.training.extensions.snapshot import Snapshot
from parakeet.training.extensions.visualizer import VisualDL
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
from parakeet.training.extensions.visualizer import VisualDL
from parakeet.training.extensions.snapshot import Snapshot
from parakeet.training.seeding import seed_everything
from batch_fn import Clip
@ -66,6 +66,9 @@ def train_sp(args, config):
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
)
# dataloader has been too verbose
logging.getLogger("DataLoader").disabled = True
# construct dataset for training and validation
with jsonlines.open(args.train_metadata, 'r') as reader:
train_metadata = list(reader)
@ -106,12 +109,12 @@ def train_sp(args, config):
train_dataloader = DataLoader(
train_dataset,
batch_sampler=train_sampler,
collate_fn=train_batch_fn, # TODO(defaine collate fn)
collate_fn=train_batch_fn,
num_workers=config.num_workers)
dev_dataloader = DataLoader(
dev_dataset,
batch_sampler=dev_sampler,
collate_fn=train_batch_fn, # TODO(defaine collate fn)
collate_fn=train_batch_fn,
num_workers=config.num_workers)
print("dataloaders done!")
@ -191,18 +194,14 @@ def train_sp(args, config):
trigger=(config.eval_interval_steps, 'iteration'),
priority=3)
if dist.get_rank() == 0:
log_writer = LogWriter(str(output_dir))
writer = LogWriter(str(trainer.out))
trainer.extend(VisualDL(writer), trigger=(1, 'iteration'))
trainer.extend(
VisualDL(log_writer), trigger=(1, 'iteration'), priority=1)
trainer.extend(
Snapshot(checkpoint_dir),
trigger=(config.save_interval_steps, 'iteration'),
priority=2)
print("Trainer Done!")
Snapshot(max_size=config.num_snapshots),
trigger=(config.save_interval_steps, 'iteration'))
# with paddle.fluid.profiler.profiler('All', 'total',
# str(output_dir / "profiler.log"),
# 'Default') as prof:
print(trainer.extensions.keys())
print("Trainer Done!")
trainer.run()

View File

@ -13,6 +13,7 @@
# limitations under the License.
import os
import logging
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any
@ -48,25 +49,25 @@ class Snapshot(extension.Extension):
trigger = (1, 'epoch')
priority = -100
default_name = "snapshot"
def __init__(self, max_size: int=5, snapshot_on_error: bool=False):
self.records: List[Dict[str, Any]] = []
self.max_size = max_size
self._snapshot_on_error = snapshot_on_error
self._save_all = (max_size == -1)
self.save_fn =...
self.del_fn = os.remove
self.checkpoint_dir =...
def initialize(self, trainer: Trainer):
"""Setting up this extension."""
self.save_fn = trainer.updater.save
self.checkpoint_dir = trainer.out / "checkpoints"
# load existing records
record_path: Path = self.checkpoint_dir / "records.yaml"
record_path: Path = self.checkpoint_dir / "records.jsonl"
if record_path.exists():
logging.debug("Loading from an existing checkpoint dir")
self.records = load_records(record_path)
trainer.updater.load(self.records[-1]['path'])
def on_error(self, trainer, exc, tb):
if self._snapshot_on_error:
@ -87,10 +88,10 @@ class Snapshot(extension.Extension):
path = self.checkpoint_dir / f"snapshot_iter_{iteration}.pdz"
# add the new one
self.save_fn(path)
trainer.updater.save(path)
record = {
"time": str(datetime.now()),
'path': str(path),
'path': str(path.resolve()), # use absolute path
'iteration': iteration
}
self.records.append(record)
@ -98,7 +99,7 @@ class Snapshot(extension.Extension):
# remove the earliest
if self.full():
eariest_record = self.records[0]
self.del_fn(eariest_record["path"])
os.remove(eariest_record["path"])
self.records.pop(0)
# update the record file

View File

@ -29,11 +29,8 @@ class VisualDL(extension.Extension):
default_name = 'visualdl'
priority = extension.PRIORITY_READER
def __init__(self):
self.writer =...
def initialize(self, trainer):
self.writer = LogWriter(logdir=str(trainer.out))
def __init__(self, writer):
self.writer = writer
def __call__(self, trainer: Trainer):
for k, v in trainer.observation.items():

View File

@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import six
import traceback
from pathlib import Path
from collections import OrderedDict
from typing import Callable, Union, List
@ -63,7 +66,7 @@ class Trainer(object):
if name is None:
name = getattr(extension, 'name', None)
if name is None:
name = getattr(extenion, 'default_name', None)
name = getattr(extension, 'default_name', None)
if name is None:
name = getattr(extension, '__name__', None)
if name is None:
@ -112,7 +115,7 @@ class Trainer(object):
extensions = [(name, self.extensions[name])
for name in extension_order]
print("initializing")
# initializing all extensions
for name, entry in extensions:
if hasattr(entry.extension, "initialize"):
entry.extension.initialize(self)
@ -120,6 +123,8 @@ class Trainer(object):
update = self.updater.update # training step
stop_trigger = self.stop_trigger
print(self.updater.state)
# TODO(chenfeiyu): display progress bar correctly
# if the trainer is controlled by epoch: use 2 progressbars
# if the trainer is controlled by iteration: use 1 progressbar
@ -129,24 +134,50 @@ class Trainer(object):
else:
max_iteration = self.stop_trigger.period
p = tqdm.tqdm()
p = tqdm.tqdm(initial=self.updater.state.iteration)
while True:
self.observation = {}
# set observation as the report target
# you can use report freely in Updater.update()
try:
while not stop_trigger(self):
self.observation = {}
# set observation as the report target
# you can use report freely in Updater.update()
# updating parameters and state
with scope(self.observation):
update()
p.update()
print(self.observation)
# updating parameters and state
with scope(self.observation):
update()
p.update()
# execute extension when necessary
for name, entry in extensions:
if entry.trigger(self):
entry.extension(self)
# execute extension when necessary
for name, entry in extensions:
if entry.trigger(self):
entry.extension(self)
if stop_trigger(self):
print("Training Done!")
break
# print("###", self.observation)
except Exception as e:
f = sys.stderr
f.write(f"Exception in main training loop: {e}\n")
f.write("Traceback (most recent call last):\n")
traceback.print_tb(sys.exc_info()[2])
f.write(
"Trainer extensions will try to handle the extension. Then all extensions will finalize."
)
# capture the exception in the main training loop
exc_info = sys.exc_info()
# try to handle it
for name, entry in extensions:
if hasattr(entry.extension, "on_error"):
try:
entry.extension.on_error(self, e, sys.exc_info()[2])
except Exception as ee:
f.write(f"Exception in error handler: {ee}\n")
f.write('Traceback (most recent call last):\n')
traceback.print_tb(sys.exc_info()[2])
# raise exception in main training loop
six.reraise(*exc_info)
finally:
for name, entry in extensions:
if hasattr(entry.extension, "finalize"):
entry.extension.finalize(self)

View File

@ -87,9 +87,11 @@ class UpdaterBase(object):
self.state.iteration = state_dict["iteration"]
def save(self, path):
logging.debug(f"Saving to {path}.")
archive = self.state_dict()
paddle.save(archive, path)
paddle.save(archive, str(path))
def load(self, path):
archive = paddle.load(path)
logging.debug(f"Loading from {path}.")
archive = paddle.load(str(path))
self.set_state_dict(archive)