diff --git a/examples/parallelwave_gan/baker/conf/default.yaml b/examples/parallelwave_gan/baker/conf/default.yaml index 29d6c0f..ce5b064 100644 --- a/examples/parallelwave_gan/baker/conf/default.yaml +++ b/examples/parallelwave_gan/baker/conf/default.yaml @@ -125,3 +125,4 @@ log_interval_steps: 100 # Interval steps to record the training # OTHER SETTING # ########################################################### num_save_intermediate_results: 4 # Number of results to be saved as intermediate results. +num_snapshots: 10 \ No newline at end of file diff --git a/examples/parallelwave_gan/baker/pwg_updater.py b/examples/parallelwave_gan/baker/pwg_updater.py index bd7dbeb..f7ff916 100644 --- a/examples/parallelwave_gan/baker/pwg_updater.py +++ b/examples/parallelwave_gan/baker/pwg_updater.py @@ -24,11 +24,10 @@ from paddle.io import DistributedBatchSampler from timer import timer from parakeet.datasets.data_table import DataTable -from parakeet.training.updater import UpdaterBase, UpdaterState, StandardUpdater +from parakeet.training.updaters.standard_updater import StandardUpdater, UpdaterState from parakeet.training.extensions.evaluator import StandardEvaluator from parakeet.training.trainer import Trainer from parakeet.training.reporter import report -from parakeet.training.checkpoint import KBest, KLatest from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator from parakeet.modules.stft_loss import MultiResolutionSTFTLoss from parakeet.utils.profile import synchronize diff --git a/examples/parallelwave_gan/baker/train.py b/examples/parallelwave_gan/baker/train.py index 7232f0b..bf8767a 100644 --- a/examples/parallelwave_gan/baker/train.py +++ b/examples/parallelwave_gan/baker/train.py @@ -36,11 +36,11 @@ from parakeet.datasets.data_table import DataTable from parakeet.training.updater import UpdaterBase from parakeet.training.trainer import Trainer from parakeet.training.reporter import report -from parakeet.training.checkpoint import KBest, KLatest +from parakeet.training import extension +from parakeet.training.extensions.snapshot import Snapshot +from parakeet.training.extensions.visualizer import VisualDL from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator from parakeet.modules.stft_loss import MultiResolutionSTFTLoss -from parakeet.training.extensions.visualizer import VisualDL -from parakeet.training.extensions.snapshot import Snapshot from parakeet.training.seeding import seed_everything from batch_fn import Clip @@ -66,6 +66,9 @@ def train_sp(args, config): f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}", ) + # dataloader has been too verbose + logging.getLogger("DataLoader").disabled = True + # construct dataset for training and validation with jsonlines.open(args.train_metadata, 'r') as reader: train_metadata = list(reader) @@ -106,12 +109,12 @@ def train_sp(args, config): train_dataloader = DataLoader( train_dataset, batch_sampler=train_sampler, - collate_fn=train_batch_fn, # TODO(defaine collate fn) + collate_fn=train_batch_fn, num_workers=config.num_workers) dev_dataloader = DataLoader( dev_dataset, batch_sampler=dev_sampler, - collate_fn=train_batch_fn, # TODO(defaine collate fn) + collate_fn=train_batch_fn, num_workers=config.num_workers) print("dataloaders done!") @@ -191,18 +194,14 @@ def train_sp(args, config): trigger=(config.eval_interval_steps, 'iteration'), priority=3) if dist.get_rank() == 0: - log_writer = LogWriter(str(output_dir)) + writer = LogWriter(str(trainer.out)) + trainer.extend(VisualDL(writer), trigger=(1, 'iteration')) trainer.extend( - VisualDL(log_writer), trigger=(1, 'iteration'), priority=1) - trainer.extend( - Snapshot(checkpoint_dir), - trigger=(config.save_interval_steps, 'iteration'), - priority=2) - print("Trainer Done!") + Snapshot(max_size=config.num_snapshots), + trigger=(config.save_interval_steps, 'iteration')) - # with paddle.fluid.profiler.profiler('All', 'total', - # str(output_dir / "profiler.log"), - # 'Default') as prof: + print(trainer.extensions.keys()) + print("Trainer Done!") trainer.run() diff --git a/parakeet/training/extensions/snapshot.py b/parakeet/training/extensions/snapshot.py index 853d62e..a209524 100644 --- a/parakeet/training/extensions/snapshot.py +++ b/parakeet/training/extensions/snapshot.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import logging from pathlib import Path from datetime import datetime from typing import List, Dict, Any @@ -48,25 +49,25 @@ class Snapshot(extension.Extension): trigger = (1, 'epoch') priority = -100 + default_name = "snapshot" def __init__(self, max_size: int=5, snapshot_on_error: bool=False): self.records: List[Dict[str, Any]] = [] self.max_size = max_size self._snapshot_on_error = snapshot_on_error self._save_all = (max_size == -1) - self.save_fn =... - self.del_fn = os.remove self.checkpoint_dir =... def initialize(self, trainer: Trainer): """Setting up this extention.""" - self.save_fn = trainer.updater.save self.checkpoint_dir = trainer.out / "checkpoints" # load existing records - record_path: Path = self.checkpoint_dir / "records.yaml" + record_path: Path = self.checkpoint_dir / "records.jsonl" if record_path.exists(): + logging.debug("Loading from an existing checkpoint dir") self.records = load_records(record_path) + trainer.updater.load(self.records[-1]['path']) def on_error(self, trainer, exc, tb): if self._snapshot_on_error: @@ -87,10 +88,10 @@ class Snapshot(extension.Extension): path = self.checkpoint_dir / f"snapshot_iter_{iteration}.pdz" # add the new one - self.save_fn(path) + trainer.updater.save(path) record = { "time": str(datetime.now()), - 'path': str(path), + 'path': str(path.resolve()), # use absolute path 'iteration': iteration } self.records.append(record) @@ -98,7 +99,7 @@ class Snapshot(extension.Extension): # remove the earist if self.full(): eariest_record = self.records[0] - self.del_fn(eariest_record["path"]) + os.remove(eariest_record["path"]) self.records.pop(0) # update the record file diff --git a/parakeet/training/extensions/visualizer.py b/parakeet/training/extensions/visualizer.py index 2a42ae0..138bf1e 100644 --- a/parakeet/training/extensions/visualizer.py +++ b/parakeet/training/extensions/visualizer.py @@ -29,11 +29,8 @@ class VisualDL(extension.Extension): default_name = 'visualdl' priority = extension.PRIORITY_READER - def __init__(self): - self.writer =... - - def initialize(self, trainer): - self.writer = LogWriter(logdir=str(trainer.out)) + def __init__(self, writer): + self.writer = writer def __call__(self, trainer: Trainer): for k, v in trainer.observation.items(): diff --git a/parakeet/training/trainer.py b/parakeet/training/trainer.py index f4b6fbb..9e90d9a 100644 --- a/parakeet/training/trainer.py +++ b/parakeet/training/trainer.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys +import six +import traceback from pathlib import Path from collections import OrderedDict from typing import Callable, Union, List @@ -63,7 +66,7 @@ class Trainer(object): if name is None: name = getattr(extension, 'name', None) if name is None: - name = getattr(extenion, 'default_name', None) + name = getattr(extension, 'default_name', None) if name is None: name = getattr(extension, '__name__', None) if name is None: @@ -112,7 +115,7 @@ class Trainer(object): extensions = [(name, self.extensions[name]) for name in extension_order] - print("initializing") + # initializing all extensions for name, entry in extensions: if hasattr(entry.extension, "initialize"): entry.extension.initialize(self) @@ -120,6 +123,8 @@ class Trainer(object): update = self.updater.update # training step stop_trigger = self.stop_trigger + print(self.updater.state) + # TODO(chenfeiyu): display progress bar correctly # if the trainer is controlled by epoch: use 2 progressbars # if the trainer is controlled by iteration: use 1 progressbar @@ -129,24 +134,50 @@ class Trainer(object): else: max_iteration = self.stop_trigger.period - p = tqdm.tqdm() + p = tqdm.tqdm(initial=self.updater.state.iteration) - while True: - self.observation = {} - # set observation as the report target - # you can use report freely in Updater.update() + try: + while not stop_trigger(self): + self.observation = {} + # set observation as the report target + # you can use report freely in Updater.update() - # updating parameters and state - with scope(self.observation): - update() - p.update() - print(self.observation) + # updating parameters and state + with scope(self.observation): + update() + p.update() - # execute extension when necessary - for name, entry in extensions: - if entry.trigger(self): - entry.extension(self) + # execute extension when necessary + for name, entry in extensions: + if entry.trigger(self): + entry.extension(self) - if stop_trigger(self): - print("Training Done!") - break + # print("###", self.observation) + except Exception as e: + f = sys.stderr + f.write(f"Exception in main training loop: {e}\n") + f.write("Traceback (most recent call last):\n") + traceback.print_tb(sys.exc_info()[2]) + f.write( + "Trainer extensions will try to handle the extension. Then all extensions will finalize." + ) + + # capture the exception in the mian training loop + exc_info = sys.exc_info() + + # try to handle it + for name, entry in extensions: + if hasattr(entry.extension, "on_error"): + try: + entry.extension.on_error(self, e, sys.exc_info()[2]) + except Exception as ee: + f.write(f"Exception in error handler: {ee}\n") + f.write('Traceback (most recent call last):\n') + traceback.print_tb(sys.exc_info()[2]) + + # raise exception in main training loop + six.reraise(*exc_info) + finally: + for name, entry in extensions: + if hasattr(entry.extension, "finalize"): + entry.extension.finalize(self) diff --git a/parakeet/training/updater.py b/parakeet/training/updater.py index 2d9ec3d..5ec5eec 100644 --- a/parakeet/training/updater.py +++ b/parakeet/training/updater.py @@ -87,9 +87,11 @@ class UpdaterBase(object): self.state.iteration = state_dict["iteration"] def save(self, path): + logging.debug(f"Saving to {path}.") archive = self.state_dict() - paddle.save(archive, path) + paddle.save(archive, str(path)) def load(self, path): - archive = paddle.load(path) + logging.debug(f"Loading from {path}.") + archive = paddle.load(str(path)) self.set_state_dict(archive)