From 83c9f0aeae9bd0046b222fbd8b8fe1cd7e22b3e4 Mon Sep 17 00:00:00 2001 From: chenfeiyu Date: Mon, 21 Jun 2021 09:56:26 +0000 Subject: [PATCH] add snapshot and visualizer --- examples/parallelwave_gan/baker/normalize.py | 9 - .../parallelwave_gan/baker/pwg_updater.py | 147 +++++--- examples/parallelwave_gan/baker/train.py | 48 ++- parakeet/training/extensions/evaluator.py | 63 ++++ parakeet/training/extensions/snapshot.py | 44 +++ parakeet/training/extensions/visualizer.py | 34 ++ parakeet/training/seeding.py | 26 ++ parakeet/training/trainer.py | 9 +- parakeet/training/updater.py | 165 +++++++-- parakeet/utils/profile.py | 5 +- parakeet/utils/timeline.py | 319 ++++++++++++++++++ 11 files changed, 764 insertions(+), 105 deletions(-) create mode 100644 parakeet/training/extensions/evaluator.py create mode 100644 parakeet/training/extensions/snapshot.py create mode 100644 parakeet/training/extensions/visualizer.py create mode 100644 parakeet/training/seeding.py create mode 100644 parakeet/utils/timeline.py diff --git a/examples/parallelwave_gan/baker/normalize.py b/examples/parallelwave_gan/baker/normalize.py index 6134917..0cf2841 100644 --- a/examples/parallelwave_gan/baker/normalize.py +++ b/examples/parallelwave_gan/baker/normalize.py @@ -11,12 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -# Copyright 2019 Tomoki Hayashi -# MIT License (https://opensource.org/licenses/MIT) """Normalize feature files and dump them.""" import argparse @@ -28,13 +22,10 @@ from pathlib import Path import numpy as np import yaml import jsonlines - from sklearn.preprocessing import StandardScaler from tqdm import tqdm from parakeet.datasets.data_table import DataTable -from parakeet.utils.h5_utils import read_hdf5 -from parakeet.utils.h5_utils import write_hdf5 from config import get_cfg_default diff --git a/examples/parallelwave_gan/baker/pwg_updater.py b/examples/parallelwave_gan/baker/pwg_updater.py index 29f313b..bd7dbeb 100644 --- a/examples/parallelwave_gan/baker/pwg_updater.py +++ b/examples/parallelwave_gan/baker/pwg_updater.py @@ -12,36 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging +from typing import Dict + import paddle -from paddle.fluid.core import _cuda_synchronize +from paddle.nn import Layer +from paddle.optimizer import Optimizer +from paddle.optimizer.lr import LRScheduler +from paddle.io import DataLoader +from paddle.io import DistributedBatchSampler from timer import timer from parakeet.datasets.data_table import DataTable -from parakeet.training.updater import UpdaterBase, UpdaterState +from parakeet.training.updater import UpdaterBase, UpdaterState, StandardUpdater +from parakeet.training.extensions.evaluator import StandardEvaluator from parakeet.training.trainer import Trainer from parakeet.training.reporter import report from parakeet.training.checkpoint import KBest, KLatest from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator from parakeet.modules.stft_loss import MultiResolutionSTFTLoss +from parakeet.utils.profile import synchronize -class PWGUpdater(UpdaterBase): +class PWGUpdater(StandardUpdater): def __init__( self, - models, - optimizers, - criterions, - schedulers, - dataloaders, + models: Dict[str, Layer], + optimizers: Dict[str, Optimizer], + criterions: Dict[str, Layer], + schedulers: Dict[str, LRScheduler], + dataloader: DataLoader, discriminator_train_start_steps: int, lambda_adv: float, ): self.models = models - self.generator = models['generator'] - self.discriminator = models['discriminator'] + self.generator: Layer = models['generator'] + self.discriminator: Layer = models['discriminator'] self.optimizers = optimizers - self.optimizer_g = optimizers['generator'] - self.optimizer_d = optimizers['discriminator'] + self.optimizer_g: Optimizer = optimizers['generator'] + self.optimizer_d: Optimizer = optimizers['discriminator'] self.criterions = criterions self.criterion_stft = criterions['stft'] @@ -51,46 +60,34 @@ class PWGUpdater(UpdaterBase): self.scheduler_g = schedulers['generator'] self.scheduler_d = schedulers['discriminator'] - self.dataloaders = dataloaders - self.train_dataloader = dataloaders['train'] - self.dev_dataloader = dataloaders['dev'] + self.dataloader = dataloader self.discriminator_train_start_steps = discriminator_train_start_steps self.lambda_adv = lambda_adv self.state = UpdaterState(iteration=0, epoch=0) - self.train_iterator = iter(self.train_dataloader) - - def update_core(self): - place = paddle.fluid.framework._current_expected_place() - with timer() as t: - _cuda_synchronize(place) - try: - batch = next(self.train_iterator) - except StopIteration: - self.train_iterator = iter(self.train_dataloader) - batch = next(self.train_iterator) - _cuda_synchronize(place) - print(f"Loading a batch takes {t.elapse}s") + self.train_iterator = iter(self.dataloader) + def update_core(self, batch): + # parse batch wav, mel = batch # Generator noise = paddle.randn(wav.shape) - _cuda_synchronize(place) + synchronize() with timer() as t: wav_ = self.generator(noise, mel) - _cuda_synchronize(place) - print(f"Generator takes {t.elapse}s") + synchronize() + logging.debug(f"Generator takes {t.elapse}s.") ## Multi-resolution stft loss - _cuda_synchronize(place) + synchronize() with timer() as t: sc_loss, mag_loss = self.criterion_stft( wav_.squeeze(1), wav.squeeze(1)) - _cuda_synchronize(place) - print(f"Multi-resolution STFT loss takes {t.elapse}s") + synchronize() + logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.") report("train/spectral_convergence_loss", float(sc_loss)) report("train/log_stft_magnitude_loss", float(mag_loss)) @@ -98,30 +95,31 @@ class 
PWGUpdater(UpdaterBase):
 
         ## Adversarial loss
         if self.state.iteration > self.discriminator_train_start_steps:
+            synchronize()
             with timer() as t:
-                _cuda_synchronize(place)
                 p_ = self.discriminator(wav_)
                 adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
-                _cuda_synchronize(place)
-            print(f"Discriminator and adversarial loss takes {t.elapse}s")
+            synchronize()
+            logging.debug(
+                f"Discriminator and adversarial loss takes {t.elapse}s")
             report("train/adversarial_loss", float(adv_loss))
             gen_loss += self.lambda_adv * adv_loss
 
         report("train/generator_loss", float(gen_loss))
 
-        _cuda_synchronize(place)
+        synchronize()
         with timer() as t:
             self.optimizer_g.clear_grad()
             gen_loss.backward()
-            _cuda_synchronize(place)
-        print(f"Backward takes {t.elapse}s.")
+        synchronize()
+        logging.debug(f"Backward takes {t.elapse}s.")
 
-        _cuda_synchronize(place)
+        synchronize()
         with timer() as t:
             self.optimizer_g.step()
             self.scheduler_g.step()
-            _cuda_synchronize(place)
-        print(f"Update takes {t.elapse}s.")
+        synchronize()
+        logging.debug(f"Update takes {t.elapse}s.")
 
         # Discriminator
         if self.state.iteration > self.discriminator_train_start_steps:
@@ -138,5 +136,68 @@ class PWGUpdater(UpdaterBase):
             self.optimizer_d.clear_grad()
             dis_loss.backward()
+
             self.optimizer_d.step()
             self.scheduler_d.step()
+
+
+class PWGEvaluator(StandardEvaluator):
+    def __init__(self, models, criterions, dataloader, lambda_adv):
+        self.models = models
+        self.generator = models['generator']
+        self.discriminator = models['discriminator']
+
+        self.criterions = criterions
+        self.criterion_stft = criterions['stft']
+        self.criterion_mse = criterions['mse']
+
+        self.dataloader = dataloader
+        self.lambda_adv = lambda_adv
+
+    def evaluate_core(self, batch):
+        logging.debug("Evaluate: ")
+        wav, mel = batch
+        noise = paddle.randn(wav.shape)
+
+        synchronize()
+        with timer() as t:
+            wav_ = self.generator(noise, mel)
+        synchronize()
+        logging.debug(f"Generator takes {t.elapse}s")
+
+        ## Multi-resolution stft loss
+        synchronize()
+        with timer() as t:
+            sc_loss, mag_loss = self.criterion_stft(
+                wav_.squeeze(1), wav.squeeze(1))
+        synchronize()
+        logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
+
+        report("eval/spectral_convergence_loss", float(sc_loss))
+        report("eval/log_stft_magnitude_loss", float(mag_loss))
+        gen_loss = sc_loss + mag_loss
+
+        ## Adversarial loss
+        synchronize()
+        with timer() as t:
+            p_ = self.discriminator(wav_)
+            adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
+        synchronize()
+        logging.debug(
+            f"Discriminator and adversarial loss takes {t.elapse}s")
+        report("eval/adversarial_loss", float(adv_loss))
+        gen_loss += self.lambda_adv * adv_loss
+
+        report("eval/generator_loss", float(gen_loss))
+
+        # Discriminator
+        with paddle.no_grad():
+            wav_ = self.generator(noise, mel)
+        p = self.discriminator(wav)
+        p_ = self.discriminator(wav_.detach())
+        real_loss = self.criterion_mse(p, paddle.ones_like(p))
+        fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
+        report("eval/real_loss", float(real_loss))
+        report("eval/fake_loss", float(fake_loss))
+        dis_loss = real_loss + fake_loss
+        report("eval/discriminator_loss", float(dis_loss))
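
An aside, not part of the patch: the synchronize-then-time pattern repeated throughout PWGUpdater and PWGEvaluator can be read in isolation. The `timed` context manager below is a hypothetical helper, built only from the `timer` package and the `synchronize` helper this patch adds to parakeet/utils/profile.py:

    # a minimal sketch of the timing pattern used above; `timed` is invented
    import logging
    from contextlib import contextmanager

    from timer import timer
    from parakeet.utils.profile import synchronize

    @contextmanager
    def timed(name):
        # synchronize() drains queued CUDA work so that t.elapse measures
        # this stage only; after the profile.py fix below it is a no-op on CPU
        synchronize()
        with timer() as t:
            yield
        synchronize()
        logging.debug(f"{name} takes {t.elapse}s.")

    # usage, mirroring update_core:
    #     with timed("Generator"):
    #         wav_ = self.generator(noise, mel)
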
diff --git a/examples/parallelwave_gan/baker/train.py b/examples/parallelwave_gan/baker/train.py
index 087963e..2854b66 100644
--- a/examples/parallelwave_gan/baker/train.py
+++ b/examples/parallelwave_gan/baker/train.py
@@ -39,10 +39,13 @@ from parakeet.training.reporter import report
 from parakeet.training.checkpoint import KBest, KLatest
 from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
 from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
+from parakeet.training.extensions.visualizer import VisualDL
+from parakeet.training.extensions.snapshot import Snapshot
+from parakeet.training.seeding import seed_everything
 
 from batch_fn import Clip
 from config import get_cfg_default
-from pwg_updater import PWGUpdater
+from pwg_updater import PWGUpdater, PWGEvaluator
 
 
 def train_sp(args, config):
@@ -56,6 +59,9 @@ def train_sp(args, config):
     if world_size > 1:
         paddle.distributed.init_parallel_env()
 
+    # set the random seed; this is required for reproducible multiprocess training
+    seed_everything(42)
+
     print(
         f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
     )
@@ -128,8 +134,7 @@ def train_sp(args, config):
         parameters=generator.parameters(),
         **config["generator_optimizer_params"])
     lr_schedule_d = StepDecay(**config["discriminator_scheduler_params"])
-    gradient_clip_d = nn.ClipGradByGlobalNorm(config[
-        "discriminator_grad_norm"])
+    gradient_clip_d = nn.ClipGradByGlobalNorm(config["discriminator_grad_norm"])
     optimizer_d = Adam(
         learning_rate=lr_schedule_d,
         grad_clip=gradient_clip_d,
@@ -138,10 +143,10 @@ def train_sp(args, config):
     print("optimizers done!")
 
     output_dir = Path(args.output_dir)
-    log_writer = None
+    checkpoint_dir = output_dir / "checkpoints"
     if dist.get_rank() == 0:
         output_dir.mkdir(parents=True, exist_ok=True)
-        log_writer = LogWriter(str(output_dir))
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)
 
     updater = PWGUpdater(
         models={
@@ -160,18 +165,41 @@ def train_sp(args, config):
             "generator": lr_schedule_g,
             "discriminator": lr_schedule_d,
         },
-        dataloaders={
-            "train": train_dataloader,
-            "dev": dev_dataloader,
-        },
+        dataloader=train_dataloader,
         discriminator_train_start_steps=config.discriminator_train_start_steps,
         lambda_adv=config.lambda_adv, )
 
+    evaluator = PWGEvaluator(
+        models={
+            "generator": generator,
+            "discriminator": discriminator,
+        },
+        criterions={
+            "stft": criterion_stft,
+            "mse": criterion_mse,
+        },
+        dataloader=dev_dataloader,
+        lambda_adv=config.lambda_adv, )
+
     trainer = Trainer(
         updater,
-        stop_trigger=(config.train_max_steps, "iteration"),  # PROFILING
+        stop_trigger=(config.train_max_steps, "iteration"),
         out=output_dir, )
 
+    trainer.extend(
+        evaluator,
+        trigger=(config.eval_interval_steps, 'iteration'),
+        priority=3)
+    if dist.get_rank() == 0:
+        log_writer = LogWriter(str(output_dir))
+        trainer.extend(
+            VisualDL(log_writer), trigger=(1, 'iteration'), priority=1)
+        trainer.extend(
+            Snapshot(checkpoint_dir),
+            trigger=(config.save_interval_steps, 'iteration'),
+            priority=2)
+    print("Trainer Done!")
+
     # with paddle.fluid.profiler.profiler('All', 'total',
     #                                     str(output_dir / "profiler.log"),
     #                                     'Default') as prof:
diff --git a/parakeet/training/extensions/evaluator.py b/parakeet/training/extensions/evaluator.py
new file mode 100644
index 0000000..eb3f877
--- /dev/null
+++ b/parakeet/training/extensions/evaluator.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+from typing import Optional, Callable, Dict
+
+from tqdm import tqdm
+import paddle
+from paddle import Tensor
+from paddle.nn import Layer
+from paddle.io import DataLoader
+
+from parakeet.training.reporter import scope, report, DictSummary
+
+
+class StandardEvaluator(object):
+    def __init__(self, model: Layer, dataloader: DataLoader):
+        # it is designed to hold multiple models
+        models = {"main": model}
+        self.models: Dict[str, Layer] = models
+        self.model = model
+
+        # dataloaders
+        self.dataloader = dataloader
+
+    def evaluate_core(self, batch):
+        # compute
+        self.model(batch)  # you may report here
+
+    def evaluate(self):
+        # switch to eval mode
+        for layer in self.models.values():
+            layer.eval()
+
+        summary = DictSummary()
+        for batch in self.dataloader:
+            observation = {}
+            with scope(observation):
+                with paddle.no_grad():
+                    self.evaluate_core(
+                        batch)  # main evaluation computation here.
+            summary.add(observation)
+        summary = summary.compute_mean()
+        return summary
+
+    def __call__(self, trainer=None):
+        self.observation = {}
+        with scope(self.observation):
+            summary = self.evaluate()
+            for k, v in summary.items():
+                report(k, v)
+        print(self.observation)
diff --git a/parakeet/training/extensions/snapshot.py b/parakeet/training/extensions/snapshot.py
new file mode 100644
index 0000000..e31403b
--- /dev/null
+++ b/parakeet/training/extensions/snapshot.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+from pathlib import Path
+
+from parakeet.utils.mp_tools import rank_zero_only
+from parakeet.training.trainer import Trainer
+
+
+class Snapshot(object):
+    """An extension to take snapshots of the updater object inside
+    the trainer, by calling the updater's `save` method.
+
+    An Updater saves its state_dict by default, which contains the
+    updater state (i.e. epoch and iteration), all the model
+    parameters and the optimizer states. If the updater inside the
+    trainer subclasses StandardUpdater, everything works out of the box.
+
+    Parameters
+    ----------
+    checkpoint_dir : Union[str, Path]
+        The directory to save checkpoints into.
+    """
+
+    def __init__(self, checkpoint_dir: Union[str, Path]):
+        self.checkpoint_dir = Path(checkpoint_dir)
+
+    @rank_zero_only
+    def __call__(self, trainer: Trainer):
+        iteration = trainer.updater.state.iteration
+        path = self.checkpoint_dir / f"step_{iteration}.pdz"
+        trainer.updater.save(str(path))
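
An aside, not part of the patch: a snapshot written by this extension can be restored through the same `save`/`load` pair on the updater. A minimal resume sketch, assuming the `updater` built in train.py above and the `checkpoints` directory it creates (the output path here is invented):

    from pathlib import Path

    checkpoint_dir = Path("exp/checkpoints")  # hypothetical output location
    latest = max(checkpoint_dir.glob("step_*.pdz"),
                 key=lambda p: int(p.stem.split("_")[1]))
    # restores iteration/epoch, model parameters and optimizer states
    updater.load(str(latest))
    # training resumed afterwards continues from updater.state.iteration
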
diff --git a/parakeet/training/extensions/visualizer.py b/parakeet/training/extensions/visualizer.py
new file mode 100644
index 0000000..9a1bcbb
--- /dev/null
+++ b/parakeet/training/extensions/visualizer.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from visualdl import LogWriter
+from parakeet.training.trainer import Trainer
+from parakeet.utils.mp_tools import rank_zero_only
+
+
+class VisualDL(object):
+    """A wrapper of the visualdl log writer. It assumes that the metrics
+    to be visualized are all scalars which are recorded into the
+    `.observation` dictionary of the trainer object. The dictionary is
+    renewed at every step, and the updater's `iteration` is used as the
+    global step when adding records.
+    """
+
+    def __init__(self, writer: LogWriter):
+        self.writer = writer
+
+    @rank_zero_only
+    def __call__(self, trainer: Trainer):
+        for k, v in trainer.observation.items():
+            self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)
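
An aside, not part of the patch: how `report` and `scope` cooperate, mirroring their use in evaluator.py above; the reported key and value here are made up:

    from parakeet.training.reporter import scope, report

    observation = {}
    with scope(observation):
        # any report() issued inside the scope lands in `observation`
        report("train/generator_loss", 2.718)

    print(observation)  # {'train/generator_loss': 2.718}
    # the trainer builds such a dictionary every iteration, and the VisualDL
    # extension above writes each pair with add_scalar() at the current step
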
diff --git a/parakeet/training/seeding.py b/parakeet/training/seeding.py
new file mode 100644
index 0000000..1a6660f
--- /dev/null
+++ b/parakeet/training/seeding.py
@@ -0,0 +1,26 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import paddle
+import random
+import numpy as np
+
+
+def seed_everything(seed: int):
+    paddle.seed(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+    logging.debug(f"Set the seed of paddle, random, np.random to {seed}.")
diff --git a/parakeet/training/trainer.py b/parakeet/training/trainer.py
index ad3661f..99e5114 100644
--- a/parakeet/training/trainer.py
+++ b/parakeet/training/trainer.py
@@ -40,15 +40,12 @@ class Trainer(object):
         self.out = Path(out)
         self.observation = {}
 
-    def setup(self):
-        pass
-
     def extend(self, extension, name=None, trigger=None, priority=None):
         trigger = get_trigger(trigger)
 
         ordinal = 0
         modified_name = name
-        while name in self.extensions:
+        while modified_name in self.extensions:
             ordinal += 1
             modified_name = f"{name}_{ordinal}"
 
@@ -61,8 +59,7 @@ class Trainer(object):
             self.extensions.keys(),
             key=lambda name: self.extensions[name].priority,
             reverse=True)
-        extensions = [(name, self.extensions[name])
-                      for name in extension_order]
+        extensions = [(name, self.extensions[name]) for name in extension_order]
 
         update = self.updater.update
         stop_trigger = self.stop_trigger
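
An aside, not part of the patch: the one-line fix in `extend` above matters because the old condition `while name in self.extensions` stayed true forever once a name collided, so registering a second extension under an existing name hung. Testing `modified_name` instead makes the loop suffix an ordinal. A standalone sketch of the fixed behavior, with the dict and helper invented for illustration:

    extensions = {}

    def register(name):
        ordinal = 0
        modified_name = name
        while modified_name in extensions:  # the fix: test the candidate name
            ordinal += 1
            modified_name = f"{name}_{ordinal}"
        extensions[modified_name] = object()
        return modified_name

    print(register("snapshot"))  # snapshot
    print(register("snapshot"))  # snapshot_1 (previously an infinite loop)
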
""" - def update(self): - self.state.iteration += 1 - self.update_core() + def __init__(self, init_state=None): + if init_state is None: + self.state = UpdaterState() + else: + self.state = init_state - def update_core(self): - pass + def update(self, batch): + raise NotImplementedError( + "Implement your own `update` method for training a step.") + + def state_dict(self): + state_dict = { + "epoch": self.state.epoch, + "iteration": self.state.iteration, + } + return state_dict + + def set_state_dict(self, state_dict): + self.state.epoch = state_dict["epoch"] + self.state.iteration = state_dict["iteration"] + + def save(self, path): + archive = self.state_dict() + paddle.save(archive, path) + + def load(self, path): + archive = paddle.load(path) + self.set_state_dict(archive) class StandardUpdater(UpdaterBase): @@ -71,54 +102,116 @@ class StandardUpdater(UpdaterBase): def __init__(self, model: Layer, - dataloader: DataLoader, optimizer: Optimizer, - loss_func=None, - auto_new_epoch: bool=True, + dataloader: DataLoader, init_state: Optional[UpdaterState]=None): + # it is designed to hold multiple models + models = {"main": model} + self.models: Dict[str, Layer] = models self.model = model - self.dataloader = dataloader - self.optimizer = optimizer - self.loss_func = loss_func - self.auto_new_epoch = auto_new_epoch - self.iterator = iter(dataloader) + # it is designed to hold multiple optimizers + optimizers = {"main": optimizer} + self.optimizer = optimizer + self.optimizers: Dict[str, Optimizer] = optimizers + + # dataloaders + self.dataloader = dataloader + + # init state if init_state is None: self.state = UpdaterState() else: self.state = init_state + self.train_iterator = iter(dataloader) + def update(self): - self.update_core() self.state.iteration += 1 + # switch to training mode + for layer in self.models.values(): + layer.train() + + # training for a step is implemented here + batch = self.read_batch() + self.update_core(batch) + + def update_core(self, batch): + """A simple case for a training step. Basic assumptions are: + Single model; + Single optimizer; + A batch from the dataloader is just the input of the model; + The model return a single loss, or a dict containing serval losses. + Parameters updates at every batch, no gradient accumulation. 
+        """
+        loss = self.model(*batch)
+
+        if isinstance(loss, Tensor):
+            loss_dict = {"main": loss}
+        else:
+            # Dict[str, Tensor]
+            loss_dict = loss
+            if "main" not in loss_dict:
+                main_loss = 0
+                for loss_item in loss.values():
+                    main_loss += loss_item
+                loss_dict["main"] = main_loss
+
+        for name, loss_item in loss_dict.items():
+            report(name, float(loss_item))
+
+        self.optimizer.clear_grad()
+        loss_dict["main"].backward()
+        self.optimizer.step()
+
     def new_epoch(self):
-        self.iterator = iter(self.dataloader)
+        """Start a new epoch."""
         self.state.epoch += 1
 
-    def update_core(self):
-        model = self.model
-        optimizer = self.optimizer
-        loss_func = self.loss_func
+        # NOTE: all batch samplers for distributed training should
+        # subclass DistributedBatchSampler and implement `set_epoch`
+        batch_sampler = self.dataloader.batch_sampler
+        if isinstance(batch_sampler, DistributedBatchSampler):
+            batch_sampler.set_epoch(self.state.epoch)
+        self.train_iterator = iter(self.dataloader)
 
-        model.train()
-        optimizer.clear_grad()
-
-        # fetch a batch
-        try:
-            batch = next(self.iterator)
-        except StopIteration as e:
-            if self.auto_new_epoch:
+    def read_batch(self):
+        """Read a batch from the data loader; auto renew when data is exhausted."""
+        with timer() as t:
+            try:
+                batch = next(self.train_iterator)
+            except StopIteration:
                 self.new_epoch()
+                batch = next(self.train_iterator)
+            logging.debug(
+                f"Reading a batch takes {t.elapse}s.")
+        return batch
 
-        # forward
-        if self.loss_func is not None:
-            loss = loss_func(batch)
-        else:
-            loss = model(batch)
+    def state_dict(self):
+        """State dict of an Updater: model, optimizer and updater state are included."""
+        state_dict = super().state_dict()
+        for name, layer in self.models.items():
+            state_dict[f"{name}_params"] = layer.state_dict()
+        for name, optim in self.optimizers.items():
+            state_dict[f"{name}_optimizer"] = optim.state_dict()
+        return state_dict
 
-        # backward
-        loss.backward()
+    def set_state_dict(self, state_dict):
+        """Set state dict for an Updater. Parameters of models, states for
+        optimizers and UpdaterState are restored."""
+        for name, layer in self.models.items():
+            layer.set_state_dict(state_dict[f"{name}_params"])
+        for name, optim in self.optimizers.items():
+            optim.set_state_dict(state_dict[f"{name}_optimizer"])
+        super().set_state_dict(state_dict)
 
-        # update parameters
-        optimizer.step()
+    def save(self, path):
+        """Save Updater state dict."""
+        archive = self.state_dict()
+        paddle.save(archive, path)
+
+    def load(self, path):
+        """Load Updater state dict."""
+        archive = paddle.load(path)
+        self.set_state_dict(archive)
diff --git a/parakeet/utils/profile.py b/parakeet/utils/profile.py
index 1e246eb..29007a9 100644
--- a/parakeet/utils/profile.py
+++ b/parakeet/utils/profile.py
@@ -13,8 +13,11 @@
 # limitations under the License.
 
 import paddle
+from paddle.framework import CUDAPlace
 
 
 def synchronize():
+    """Trigger cuda synchronization for better timing."""
     place = paddle.fluid.framework._current_expected_place()
-    paddle.fluid.core._cuda_synchronize(place)
+    if isinstance(place, CUDAPlace):
+        paddle.fluid.core._cuda_synchronize(place)
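
An aside, not part of the patch: a toy end-to-end use of the reworked StandardUpdater. The model, data and step shape are invented, and it relies on the reporter's scope/report behavior shown in evaluator.py:

    import paddle
    from paddle import nn
    from paddle.io import DataLoader, TensorDataset

    from parakeet.training.reporter import scope
    from parakeet.training.updater import StandardUpdater

    class ToyModel(nn.Layer):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 1)

        def forward(self, x, y):
            # StandardUpdater.update_core treats the return value as the loss
            return nn.functional.mse_loss(self.linear(x), y)

    model = ToyModel()
    dataset = TensorDataset([paddle.randn([16, 4]), paddle.randn([16, 1])])
    loader = DataLoader(dataset, batch_size=4)
    optimizer = paddle.optimizer.Adam(parameters=model.parameters())

    updater = StandardUpdater(model, optimizer, loader)
    observation = {}
    with scope(observation):
        updater.update()       # one call == one optimization step
    print(observation)         # {'main': <loss value>} via report()
    updater.save("toy.pdz")    # same archive the Snapshot extension writes
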
diff --git a/parakeet/utils/timeline.py b/parakeet/utils/timeline.py
new file mode 100644
index 0000000..2a399b7
--- /dev/null
+++ b/parakeet/utils/timeline.py
@@ -0,0 +1,318 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import six
+import sys
+import unittest
+
+import google.protobuf.text_format as text_format
+import paddle.fluid.proto.profiler.profiler_pb2 as profiler_pb2
+
+parser = argparse.ArgumentParser(description=__doc__)
+parser.add_argument(
+    '--profile_path',
+    type=str,
+    default='',
+    help='Input profile file name. If there are multiple files, the format '
+    'should be trainer1=file1,trainer2=file2,ps=file3')
+parser.add_argument(
+    '--timeline_path', type=str, default='', help='Output timeline file name.')
+args = parser.parse_args()
+
+
+class _ChromeTraceFormatter(object):
+    def __init__(self):
+        self._events = []
+        self._metadata = []
+
+    def _create_event(self, ph, category, name, pid, tid, timestamp):
+        """Creates a new Chrome Trace event.
+
+        For details of the file format, see:
+        https://github.com/catapult-project/catapult/blob/master/tracing/README.md
+
+        Args:
+            ph: The type of event - usually a single character.
+            category: The event category as a string.
+            name: The event name as a string.
+            pid: Identifier of the process generating this event as an integer.
+            tid: Identifier of the thread generating this event as an integer.
+            timestamp: The timestamp of this event as a long integer.
+
+        Returns:
+            A JSON compatible event object.
+        """
+        event = {}
+        event['ph'] = ph
+        event['cat'] = category
+        event['name'] = name.replace("ParallelExecutor::Run/", "")
+        event['pid'] = pid
+        event['tid'] = tid
+        event['ts'] = timestamp
+        return event
+
+    def emit_pid(self, name, pid):
+        """Adds a process metadata event to the trace.
+
+        Args:
+            name: The process name as a string.
+            pid: Identifier of the process as an integer.
+        """
+        event = {}
+        event['name'] = 'process_name'
+        event['ph'] = 'M'
+        event['pid'] = pid
+        event['args'] = {'name': name}
+        self._metadata.append(event)
+
+    def emit_region(self, timestamp, duration, pid, tid, category, name, args):
+        """Adds a region event to the trace.
+
+        Args:
+            timestamp: The start timestamp of this region as a long integer.
+            duration: The duration of this region as a long integer.
+            pid: Identifier of the process generating this event as an integer.
+            tid: Identifier of the thread generating this event as an integer.
+            category: The event category as a string.
+            name: The event name as a string.
+            args: A JSON-compatible dictionary of event arguments.
+        """
+        event = self._create_event('X', category, name, pid, tid, timestamp)
+        event['dur'] = duration
+        event['args'] = args
+        self._events.append(event)
+
+    def emit_counter(self, category, name, pid, timestamp, counter, value):
+        """Emits a record for a single counter.
+
+        Args:
+            category: The event category as string.
+            name: The event name as string.
+            pid: Identifier of the process generating this event as integer.
+            timestamp: The timestamp of this event as long integer.
+            counter: Name of the counter as string.
+            value: Value of the counter as integer.
+        """
+        event = self._create_event('C', category, name, pid, 0, timestamp)
+        event['args'] = {counter: value}
+        self._events.append(event)
+
+    def format_to_string(self, pretty=False):
+        """Formats the chrome trace to a string.
+
+        Args:
+            pretty: (Optional.) If True, produce human-readable JSON output.
+
+        Returns:
+            A JSON-formatted string in Chrome Trace format.
+        """
+        trace = {}
+        trace['traceEvents'] = self._metadata + self._events
+        if pretty:
+            return json.dumps(trace, indent=4, separators=(',', ': '))
+        else:
+            return json.dumps(trace, separators=(',', ':'))
+
+
+class Timeline(object):
+    def __init__(self, profile_dict):
+        self._profile_dict = profile_dict
+        self._pid = 0
+        self._devices = dict()
+        self._mem_devices = dict()
+        self._chrome_trace = _ChromeTraceFormatter()
+
+    def _allocate_pid(self):
+        cur_pid = self._pid
+        self._pid += 1
+        return cur_pid
+
+    def _allocate_pids(self):
+        for k, profile_pb in six.iteritems(self._profile_dict):
+            for event in profile_pb.events:
+                if event.type == profiler_pb2.Event.CPU:
+                    if (k, event.device_id, "CPU") not in self._devices:
+                        pid = self._allocate_pid()
+                        self._devices[(k, event.device_id, "CPU")] = pid
+                        # a device id of -1 represents CUDA API (runtime) calls, e.g. cudaLaunch, cudaMemcpy
+                        if event.device_id == -1:
+                            self._chrome_trace.emit_pid("%s:cuda_api" % k, pid)
+                        else:
+                            self._chrome_trace.emit_pid(
+                                "%s:cpu:block:%d" % (k, event.device_id), pid)
+                elif event.type == profiler_pb2.Event.GPUKernel:
+                    if (k, event.device_id, "GPUKernel") not in self._devices:
+                        pid = self._allocate_pid()
+                        self._devices[(k, event.device_id, "GPUKernel")] = pid
+                        self._chrome_trace.emit_pid("%s:gpu:%d" %
+                                                    (k, event.device_id), pid)
+            if not hasattr(profile_pb, "mem_events"):
+                continue
+            for mevent in profile_pb.mem_events:
+                if mevent.place == profiler_pb2.MemEvent.CUDAPlace:
+                    if (k, mevent.device_id, "GPU") not in self._mem_devices:
+                        pid = self._allocate_pid()
+                        self._mem_devices[(k, mevent.device_id, "GPU")] = pid
+                        self._chrome_trace.emit_pid(
+                            "memory usage on %s:gpu:%d" % (k, mevent.device_id),
+                            pid)
+                elif mevent.place == profiler_pb2.MemEvent.CPUPlace:
+                    if (k, mevent.device_id, "CPU") not in self._mem_devices:
+                        pid = self._allocate_pid()
+                        self._mem_devices[(k, mevent.device_id, "CPU")] = pid
+                        self._chrome_trace.emit_pid(
+                            "memory usage on %s:cpu:%d" % (k, mevent.device_id),
+                            pid)
+                elif mevent.place == profiler_pb2.MemEvent.CUDAPinnedPlace:
+                    if (k, mevent.device_id, "CUDAPinnedPlace"
+                        ) not in self._mem_devices:
+                        pid = self._allocate_pid()
+                        self._mem_devices[(k, mevent.device_id,
+                                           "CUDAPinnedPlace")] = pid
+                        self._chrome_trace.emit_pid(
+                            "memory usage on %s:cudapinnedplace:%d" %
+                            (k, mevent.device_id), pid)
+                elif mevent.place == profiler_pb2.MemEvent.NPUPlace:
+                    if (k, mevent.device_id, "NPU") not in self._mem_devices:
+                        pid = self._allocate_pid()
+                        self._mem_devices[(k, mevent.device_id, "NPU")] = pid
+                        self._chrome_trace.emit_pid(
+                            "memory usage on %s:npu:%d" % (k, mevent.device_id),
+                            pid)
+            if (k, 0, "CPU") not in self._mem_devices:
+                pid = self._allocate_pid()
+                self._mem_devices[(k, 0, "CPU")] = pid
+                self._chrome_trace.emit_pid("memory usage on %s:cpu:%d" %
+                                            (k, 0), pid)
+            if (k, 0, "GPU") not in self._mem_devices:
"GPU") not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, 0, "GPU")] = pid + self._chrome_trace.emit_pid("memory usage on %s:gpu:%d" % + (k, 0), pid) + if (k, 0, "CUDAPinnedPlace") not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, 0, "CUDAPinnedPlace")] = pid + self._chrome_trace.emit_pid( + "memory usage on %s:cudapinnedplace:%d" % (k, 0), pid) + if (k, 0, "NPU") not in self._mem_devices: + pid = self._allocate_pid() + self._mem_devices[(k, 0, "NPU")] = pid + self._chrome_trace.emit_pid("memory usage on %s:npu:%d" % + (k, 0), pid) + + def _allocate_events(self): + for k, profile_pb in six.iteritems(self._profile_dict): + for event in profile_pb.events: + if event.type == profiler_pb2.Event.CPU: + type = "CPU" + elif event.type == profiler_pb2.Event.GPUKernel: + type = "GPUKernel" + pid = self._devices[(k, event.device_id, type)] + args = {'name': event.name} + if event.memcopy.bytes > 0: + args['mem_bytes'] = event.memcopy.bytes + if hasattr(event, "detail_info") and event.detail_info: + args['detail_info'] = event.detail_info + # TODO(panyx0718): Chrome tracing only handles ms. However, some + # ops takes micro-seconds. Hence, we keep the ns here. + self._chrome_trace.emit_region( + event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid, + event.sub_device_id, 'Op', event.name, args) + + def _allocate_memory_event(self): + if not hasattr(profiler_pb2, "MemEvent"): + return + place_to_str = { + profiler_pb2.MemEvent.CPUPlace: "CPU", + profiler_pb2.MemEvent.CUDAPlace: "GPU", + profiler_pb2.MemEvent.CUDAPinnedPlace: "CUDAPinnedPlace", + profiler_pb2.MemEvent.NPUPlace: "NPU" + } + for k, profile_pb in six.iteritems(self._profile_dict): + mem_list = [] + end_profiler = 0 + for mevent in profile_pb.mem_events: + crt_info = dict() + crt_info['time'] = mevent.start_ns + crt_info['size'] = mevent.bytes + if mevent.place in place_to_str: + place = place_to_str[mevent.place] + else: + place = "UnDefine" + crt_info['place'] = place + pid = self._mem_devices[(k, mevent.device_id, place)] + crt_info['pid'] = pid + crt_info['thread_id'] = mevent.thread_id + crt_info['device_id'] = mevent.device_id + mem_list.append(crt_info) + crt_info = dict() + crt_info['place'] = place + crt_info['pid'] = pid + crt_info['thread_id'] = mevent.thread_id + crt_info['device_id'] = mevent.device_id + crt_info['time'] = mevent.end_ns + crt_info['size'] = -mevent.bytes + mem_list.append(crt_info) + end_profiler = max(end_profiler, crt_info['time']) + mem_list.sort(key=lambda tmp: (tmp.get('time', 0))) + i = 0 + total_size = 0 + while i < len(mem_list): + total_size += mem_list[i]['size'] + while i < len(mem_list) - 1 and mem_list[i]['time'] == mem_list[ + i + 1]['time']: + total_size += mem_list[i + 1]['size'] + i += 1 + + self._chrome_trace.emit_counter( + "Memory", "Memory", mem_list[i]['pid'], mem_list[i]['time'], + 0, total_size) + i += 1 + + def generate_chrome_trace(self): + self._allocate_pids() + self._allocate_events() + self._allocate_memory_event() + return self._chrome_trace.format_to_string() + + +profile_path = '/tmp/profile' +if args.profile_path: + profile_path = args.profile_path +timeline_path = '/tmp/timeline' +if args.timeline_path: + timeline_path = args.timeline_path + +profile_paths = profile_path.split(',') +profile_dict = dict() +if len(profile_paths) == 1: + with open(profile_path, 'rb') as f: + profile_s = f.read() + profile_pb = profiler_pb2.Profile() + profile_pb.ParseFromString(profile_s) + profile_dict['trainer'] = profile_pb +else: 
+ for profile_path in profile_paths: + k, v = profile_path.split('=') + with open(v, 'rb') as f: + profile_s = f.read() + profile_pb = profiler_pb2.Profile() + profile_pb.ParseFromString(profile_s) + profile_dict[k] = profile_pb + +tl = Timeline(profile_dict) +with open(timeline_path, 'w') as f: + f.write(tl.generate_chrome_trace())
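
For reference, not part of the patch: the vendored tool is driven the same way as Paddle's original tools/timeline.py. The paths below are examples, and the script runs at import time (there is no __main__ guard), so it can be invoked with -m once parakeet is installed:

    # single process
    python -m parakeet.utils.timeline --profile_path /tmp/profile \
        --timeline_path /tmp/timeline

    # multiple processes, named key=file
    python -m parakeet.utils.timeline \
        --profile_path trainer0=/tmp/profile0,trainer1=/tmp/profile1 \
        --timeline_path /tmp/timeline

    # then open chrome://tracing in Chrome and load /tmp/timeline
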