1. Better error handling;
2. Use absolute paths in snapshot records; 3. VisualDL takes a log writer as an init argument rather than creating one.
This commit is contained in:
parent
61c13dd69b
commit
29b8b8b0ea
|
@ -125,3 +125,4 @@ log_interval_steps: 100 # Interval steps to record the training
|
|||
# OTHER SETTING #
|
||||
###########################################################
|
||||
num_save_intermediate_results: 4 # Number of results to be saved as intermediate results.
|
||||
num_snapshots: 10
|
|
@ -24,11 +24,10 @@ from paddle.io import DistributedBatchSampler
|
|||
from timer import timer
|
||||
|
||||
from parakeet.datasets.data_table import DataTable
|
||||
from parakeet.training.updater import UpdaterBase, UpdaterState, StandardUpdater
|
||||
from parakeet.training.updaters.standard_updater import StandardUpdater, UpdaterState
|
||||
from parakeet.training.extensions.evaluator import StandardEvaluator
|
||||
from parakeet.training.trainer import Trainer
|
||||
from parakeet.training.reporter import report
|
||||
from parakeet.training.checkpoint import KBest, KLatest
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
|
||||
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
|
||||
from parakeet.utils.profile import synchronize
|
||||
|
|
|
@ -36,11 +36,11 @@ from parakeet.datasets.data_table import DataTable
|
|||
from parakeet.training.updater import UpdaterBase
|
||||
from parakeet.training.trainer import Trainer
|
||||
from parakeet.training.reporter import report
|
||||
from parakeet.training.checkpoint import KBest, KLatest
|
||||
from parakeet.training import extension
|
||||
from parakeet.training.extensions.snapshot import Snapshot
|
||||
from parakeet.training.extensions.visualizer import VisualDL
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
|
||||
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
|
||||
from parakeet.training.extensions.visualizer import VisualDL
|
||||
from parakeet.training.extensions.snapshot import Snapshot
|
||||
from parakeet.training.seeding import seed_everything
|
||||
|
||||
from batch_fn import Clip
|
||||
|
@ -66,6 +66,9 @@ def train_sp(args, config):
|
|||
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
|
||||
)
|
||||
|
||||
# dataloader has been too verbose
|
||||
logging.getLogger("DataLoader").disabled = True
|
||||
|
||||
# construct dataset for training and validation
|
||||
with jsonlines.open(args.train_metadata, 'r') as reader:
|
||||
train_metadata = list(reader)
|
||||
|
@ -106,12 +109,12 @@ def train_sp(args, config):
|
|||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=train_sampler,
|
||||
collate_fn=train_batch_fn, # TODO(defaine collate fn)
|
||||
collate_fn=train_batch_fn,
|
||||
num_workers=config.num_workers)
|
||||
dev_dataloader = DataLoader(
|
||||
dev_dataset,
|
||||
batch_sampler=dev_sampler,
|
||||
collate_fn=train_batch_fn, # TODO(defaine collate fn)
|
||||
collate_fn=train_batch_fn,
|
||||
num_workers=config.num_workers)
|
||||
print("dataloaders done!")
|
||||
|
||||
|
@ -191,18 +194,14 @@ def train_sp(args, config):
|
|||
trigger=(config.eval_interval_steps, 'iteration'),
|
||||
priority=3)
|
||||
if dist.get_rank() == 0:
|
||||
log_writer = LogWriter(str(output_dir))
|
||||
writer = LogWriter(str(trainer.out))
|
||||
trainer.extend(VisualDL(writer), trigger=(1, 'iteration'))
|
||||
trainer.extend(
|
||||
VisualDL(log_writer), trigger=(1, 'iteration'), priority=1)
|
||||
trainer.extend(
|
||||
Snapshot(checkpoint_dir),
|
||||
trigger=(config.save_interval_steps, 'iteration'),
|
||||
priority=2)
|
||||
print("Trainer Done!")
|
||||
Snapshot(max_size=config.num_snapshots),
|
||||
trigger=(config.save_interval_steps, 'iteration'))
|
||||
|
||||
# with paddle.fluid.profiler.profiler('All', 'total',
|
||||
# str(output_dir / "profiler.log"),
|
||||
# 'Default') as prof:
|
||||
print(trainer.extensions.keys())
|
||||
print("Trainer Done!")
|
||||
trainer.run()
|
||||
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
|
@ -48,25 +49,25 @@ class Snapshot(extension.Extension):
|
|||
|
||||
trigger = (1, 'epoch')
|
||||
priority = -100
|
||||
default_name = "snapshot"
|
||||
|
||||
def __init__(self, max_size: int=5, snapshot_on_error: bool=False):
|
||||
self.records: List[Dict[str, Any]] = []
|
||||
self.max_size = max_size
|
||||
self._snapshot_on_error = snapshot_on_error
|
||||
self._save_all = (max_size == -1)
|
||||
self.save_fn =...
|
||||
self.del_fn = os.remove
|
||||
self.checkpoint_dir =...
|
||||
|
||||
def initialize(self, trainer: Trainer):
|
||||
"""Setting up this extention."""
|
||||
self.save_fn = trainer.updater.save
|
||||
self.checkpoint_dir = trainer.out / "checkpoints"
|
||||
|
||||
# load existing records
|
||||
record_path: Path = self.checkpoint_dir / "records.yaml"
|
||||
record_path: Path = self.checkpoint_dir / "records.jsonl"
|
||||
if record_path.exists():
|
||||
logging.debug("Loading from an existing checkpoint dir")
|
||||
self.records = load_records(record_path)
|
||||
trainer.updater.load(self.records[-1]['path'])
|
||||
|
||||
def on_error(self, trainer, exc, tb):
|
||||
if self._snapshot_on_error:
|
||||
|
@ -87,10 +88,10 @@ class Snapshot(extension.Extension):
|
|||
path = self.checkpoint_dir / f"snapshot_iter_{iteration}.pdz"
|
||||
|
||||
# add the new one
|
||||
self.save_fn(path)
|
||||
trainer.updater.save(path)
|
||||
record = {
|
||||
"time": str(datetime.now()),
|
||||
'path': str(path),
|
||||
'path': str(path.resolve()), # use absolute path
|
||||
'iteration': iteration
|
||||
}
|
||||
self.records.append(record)
|
||||
|
@ -98,7 +99,7 @@ class Snapshot(extension.Extension):
|
|||
# remove the earliest
|
||||
if self.full():
|
||||
eariest_record = self.records[0]
|
||||
self.del_fn(eariest_record["path"])
|
||||
os.remove(eariest_record["path"])
|
||||
self.records.pop(0)
|
||||
|
||||
# update the record file
|
||||
|
|
|
@ -29,11 +29,8 @@ class VisualDL(extension.Extension):
|
|||
default_name = 'visualdl'
|
||||
priority = extension.PRIORITY_READER
|
||||
|
||||
def __init__(self):
|
||||
self.writer =...
|
||||
|
||||
def initialize(self, trainer):
|
||||
self.writer = LogWriter(logdir=str(trainer.out))
|
||||
def __init__(self, writer):
|
||||
self.writer = writer
|
||||
|
||||
def __call__(self, trainer: Trainer):
|
||||
for k, v in trainer.observation.items():
|
||||
|
|
|
@ -12,6 +12,9 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import six
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Union, List
|
||||
|
@ -63,7 +66,7 @@ class Trainer(object):
|
|||
if name is None:
|
||||
name = getattr(extension, 'name', None)
|
||||
if name is None:
|
||||
name = getattr(extenion, 'default_name', None)
|
||||
name = getattr(extension, 'default_name', None)
|
||||
if name is None:
|
||||
name = getattr(extension, '__name__', None)
|
||||
if name is None:
|
||||
|
@ -112,7 +115,7 @@ class Trainer(object):
|
|||
extensions = [(name, self.extensions[name])
|
||||
for name in extension_order]
|
||||
|
||||
print("initializing")
|
||||
# initializing all extensions
|
||||
for name, entry in extensions:
|
||||
if hasattr(entry.extension, "initialize"):
|
||||
entry.extension.initialize(self)
|
||||
|
@ -120,6 +123,8 @@ class Trainer(object):
|
|||
update = self.updater.update # training step
|
||||
stop_trigger = self.stop_trigger
|
||||
|
||||
print(self.updater.state)
|
||||
|
||||
# TODO(chenfeiyu): display progress bar correctly
|
||||
# if the trainer is controlled by epoch: use 2 progressbars
|
||||
# if the trainer is controlled by iteration: use 1 progressbar
|
||||
|
@ -129,24 +134,50 @@ class Trainer(object):
|
|||
else:
|
||||
max_iteration = self.stop_trigger.period
|
||||
|
||||
p = tqdm.tqdm()
|
||||
p = tqdm.tqdm(initial=self.updater.state.iteration)
|
||||
|
||||
while True:
|
||||
self.observation = {}
|
||||
# set observation as the report target
|
||||
# you can use report freely in Updater.update()
|
||||
try:
|
||||
while not stop_trigger(self):
|
||||
self.observation = {}
|
||||
# set observation as the report target
|
||||
# you can use report freely in Updater.update()
|
||||
|
||||
# updating parameters and state
|
||||
with scope(self.observation):
|
||||
update()
|
||||
p.update()
|
||||
print(self.observation)
|
||||
# updating parameters and state
|
||||
with scope(self.observation):
|
||||
update()
|
||||
p.update()
|
||||
|
||||
# execute extension when necessary
|
||||
for name, entry in extensions:
|
||||
if entry.trigger(self):
|
||||
entry.extension(self)
|
||||
# execute extension when necessary
|
||||
for name, entry in extensions:
|
||||
if entry.trigger(self):
|
||||
entry.extension(self)
|
||||
|
||||
if stop_trigger(self):
|
||||
print("Training Done!")
|
||||
break
|
||||
# print("###", self.observation)
|
||||
except Exception as e:
|
||||
f = sys.stderr
|
||||
f.write(f"Exception in main training loop: {e}\n")
|
||||
f.write("Traceback (most recent call last):\n")
|
||||
traceback.print_tb(sys.exc_info()[2])
|
||||
f.write(
|
||||
"Trainer extensions will try to handle the extension. Then all extensions will finalize."
|
||||
)
|
||||
|
||||
# capture the exception in the main training loop
|
||||
exc_info = sys.exc_info()
|
||||
|
||||
# try to handle it
|
||||
for name, entry in extensions:
|
||||
if hasattr(entry.extension, "on_error"):
|
||||
try:
|
||||
entry.extension.on_error(self, e, sys.exc_info()[2])
|
||||
except Exception as ee:
|
||||
f.write(f"Exception in error handler: {ee}\n")
|
||||
f.write('Traceback (most recent call last):\n')
|
||||
traceback.print_tb(sys.exc_info()[2])
|
||||
|
||||
# raise exception in main training loop
|
||||
six.reraise(*exc_info)
|
||||
finally:
|
||||
for name, entry in extensions:
|
||||
if hasattr(entry.extension, "finalize"):
|
||||
entry.extension.finalize(self)
|
||||
|
|
|
@ -87,9 +87,11 @@ class UpdaterBase(object):
|
|||
self.state.iteration = state_dict["iteration"]
|
||||
|
||||
def save(self, path):
|
||||
logging.debug(f"Saving to {path}.")
|
||||
archive = self.state_dict()
|
||||
paddle.save(archive, path)
|
||||
paddle.save(archive, str(path))
|
||||
|
||||
def load(self, path):
|
||||
archive = paddle.load(path)
|
||||
logging.debug(f"Loading from {path}.")
|
||||
archive = paddle.load(str(path))
|
||||
self.set_state_dict(archive)
|
||||
|
|
Loading…
Reference in New Issue