add snapshot and visualizer

commit 83c9f0aeae (parent a738954001)
@@ -11,12 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-# Copyright 2019 Tomoki Hayashi
-# MIT License (https://opensource.org/licenses/MIT)
 """Normalize feature files and dump them."""

 import argparse

@@ -28,13 +22,10 @@ from pathlib import Path
 import numpy as np
 import yaml
 import jsonlines

 from sklearn.preprocessing import StandardScaler
 from tqdm import tqdm

 from parakeet.datasets.data_table import DataTable
-from parakeet.utils.h5_utils import read_hdf5
-from parakeet.utils.h5_utils import write_hdf5
-
 from config import get_cfg_default

@@ -12,36 +12,45 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
+from typing import Dict
+
 import paddle
-from paddle.fluid.core import _cuda_synchronize
+from paddle.nn import Layer
+from paddle.optimizer import Optimizer
+from paddle.optimizer.lr import LRScheduler
+from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
 from timer import timer

 from parakeet.datasets.data_table import DataTable
-from parakeet.training.updater import UpdaterBase, UpdaterState
+from parakeet.training.updater import UpdaterBase, UpdaterState, StandardUpdater
+from parakeet.training.extensions.evaluator import StandardEvaluator
 from parakeet.training.trainer import Trainer
 from parakeet.training.reporter import report
 from parakeet.training.checkpoint import KBest, KLatest
 from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
 from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
+from parakeet.utils.profile import synchronize


-class PWGUpdater(UpdaterBase):
+class PWGUpdater(StandardUpdater):
     def __init__(
             self,
-            models,
-            optimizers,
-            criterions,
-            schedulers,
-            dataloaders,
+            models: Dict[str, Layer],
+            optimizers: Dict[str, Optimizer],
+            criterions: Dict[str, Layer],
+            schedulers: Dict[str, LRScheduler],
+            dataloader: DataLoader,
             discriminator_train_start_steps: int,
             lambda_adv: float, ):
         self.models = models
-        self.generator = models['generator']
-        self.discriminator = models['discriminator']
+        self.generator: Layer = models['generator']
+        self.discriminator: Layer = models['discriminator']

         self.optimizers = optimizers
-        self.optimizer_g = optimizers['generator']
-        self.optimizer_d = optimizers['discriminator']
+        self.optimizer_g: Optimizer = optimizers['generator']
+        self.optimizer_d: Optimizer = optimizers['discriminator']

         self.criterions = criterions
         self.criterion_stft = criterions['stft']
@@ -51,46 +60,34 @@ class PWGUpdater(UpdaterBase):
         self.scheduler_g = schedulers['generator']
         self.scheduler_d = schedulers['discriminator']

-        self.dataloaders = dataloaders
-        self.train_dataloader = dataloaders['train']
-        self.dev_dataloader = dataloaders['dev']
+        self.dataloader = dataloader

         self.discriminator_train_start_steps = discriminator_train_start_steps
         self.lambda_adv = lambda_adv
         self.state = UpdaterState(iteration=0, epoch=0)

-        self.train_iterator = iter(self.train_dataloader)
+        self.train_iterator = iter(self.dataloader)

-    def update_core(self):
-        place = paddle.fluid.framework._current_expected_place()
-        with timer() as t:
-            _cuda_synchronize(place)
-            try:
-                batch = next(self.train_iterator)
-            except StopIteration:
-                self.train_iterator = iter(self.train_dataloader)
-                batch = next(self.train_iterator)
-            _cuda_synchronize(place)
-        print(f"Loading a batch takes {t.elapse}s")
-
+    def update_core(self, batch):
+        # parse batch
         wav, mel = batch

         # Generator
         noise = paddle.randn(wav.shape)

-        _cuda_synchronize(place)
+        synchronize()
         with timer() as t:
             wav_ = self.generator(noise, mel)
-        _cuda_synchronize(place)
-        print(f"Generator takes {t.elapse}s")
+        synchronize()
+        logging.debug(f"Generator takes {t.elapse}s.")

         ## Multi-resolution stft loss
-        _cuda_synchronize(place)
+        synchronize()
         with timer() as t:
             sc_loss, mag_loss = self.criterion_stft(
                 wav_.squeeze(1), wav.squeeze(1))
-        _cuda_synchronize(place)
-        print(f"Multi-resolution STFT loss takes {t.elapse}s")
+        synchronize()
+        logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")

         report("train/spectral_convergence_loss", float(sc_loss))
         report("train/log_stft_magnitude_loss", float(mag_loss))

@@ -98,30 +95,31 @@ class PWGUpdater(UpdaterBase):

         ## Adversarial loss
         if self.state.iteration > self.discriminator_train_start_steps:
+            synchronize()
             with timer() as t:
-                _cuda_synchronize(place)
                 p_ = self.discriminator(wav_)
                 adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
-                _cuda_synchronize(place)
-            print(f"Discriminator and adversarial loss takes {t.elapse}s")
+            synchronize()
+            logging.debug(
+                f"Discriminator and adversarial loss takes {t.elapse}s")
             report("train/adversarial_loss", float(adv_loss))
             gen_loss += self.lambda_adv * adv_loss

         report("train/generator_loss", float(gen_loss))

-        _cuda_synchronize(place)
+        synchronize()
         with timer() as t:
             self.optimizer_g.clear_grad()
             gen_loss.backward()
-        _cuda_synchronize(place)
-        print(f"Backward takes {t.elapse}s.")
+        synchronize()
+        logging.debug(f"Backward takes {t.elapse}s.")

-        _cuda_synchronize(place)
+        synchronize()
         with timer() as t:
             self.optimizer_g.step()
             self.scheduler_g.step()
-        _cuda_synchronize(place)
-        print(f"Update takes {t.elapse}s.")
+        synchronize()
+        logging.debug(f"Update takes {t.elapse}s.")

         # Disctiminator
         if self.state.iteration > self.discriminator_train_start_steps:

@@ -138,5 +136,68 @@ class PWGUpdater(UpdaterBase):

             self.optimizer_d.clear_grad()
             dis_loss.backward()

             self.optimizer_d.step()
             self.scheduler_d.step()
+
+
+class PWGEvaluator(StandardEvaluator):
+    def __init__(self, models, criterions, dataloader, lambda_adv):
+        self.models = models
+        self.generator = models['generator']
+        self.discriminator = models['discriminator']
+
+        self.criterions = criterions
+        self.criterion_stft = criterions['stft']
+        self.criterion_mse = criterions['mse']
+
+        self.dataloader = dataloader
+        self.lambda_adv = lambda_adv
+
+    def evaluate_core(self, batch):
+        logging.debug("Evaluate: ")
+        wav, mel = batch
+        noise = paddle.randn(wav.shape)
+
+        synchronize()
+        with timer() as t:
+            wav_ = self.generator(noise, mel)
+        synchronize()
+        logging.debug(f"Generator takes {t.elapse}s")
+
+        ## Multi-resolution stft loss
+        synchronize()
+        with timer() as t:
+            sc_loss, mag_loss = self.criterion_stft(
+                wav_.squeeze(1), wav.squeeze(1))
+        synchronize()
+        logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")
+
+        report("eval/spectral_convergence_loss", float(sc_loss))
+        report("eval/log_stft_magnitude_loss", float(mag_loss))
+        gen_loss = sc_loss + mag_loss
+
+        ## Adversarial loss
+        synchronize()
+        with timer() as t:
+            p_ = self.discriminator(wav_)
+            adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
+        synchronize()
+        logging.debug(
+            f"Discriminator and adversarial loss takes {t.elapse}s")
+        report("eval/adversarial_loss", float(adv_loss))
+        gen_loss += self.lambda_adv * adv_loss
+
+        report("eval/generator_loss", float(gen_loss))
+
+        # Disctiminator
+        with paddle.no_grad():
+            wav_ = self.generator(noise, mel)
+        p = self.discriminator(wav)
+        p_ = self.discriminator(wav_.detach())
+        real_loss = self.criterion_mse(p, paddle.ones_like(p))
+        fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
+        report("eval/real_loss", float(real_loss))
+        report("eval/fake_loss", float(fake_loss))
+        dis_loss = real_loss + fake_loss
+        report("eval/discriminator_loss", float(dis_loss))

@@ -39,10 +39,13 @@ from parakeet.training.reporter import report
 from parakeet.training.checkpoint import KBest, KLatest
 from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
 from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
+from parakeet.training.extensions.visualizer import VisualDL
+from parakeet.training.extensions.snapshot import Snapshot
+from parakeet.training.seeding import seed_everything

 from batch_fn import Clip
 from config import get_cfg_default
-from pwg_updater import PWGUpdater
+from pwg_updater import PWGUpdater, PWGEvaluator


 def train_sp(args, config):

@@ -56,6 +59,9 @@ def train_sp(args, config):
     if world_size > 1:
         paddle.distributed.init_parallel_env()

+    # set the random seed, it is a must for multiprocess training
+    seed_everything(42)
+
     print(
         f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
     )

@@ -128,8 +134,7 @@ def train_sp(args, config):
         parameters=generator.parameters(),
         **config["generator_optimizer_params"])
     lr_schedule_d = StepDecay(**config["discriminator_scheduler_params"])
-    gradient_clip_d = nn.ClipGradByGlobalNorm(config[
-        "discriminator_grad_norm"])
+    gradient_clip_d = nn.ClipGradByGlobalNorm(config["discriminator_grad_norm"])
     optimizer_d = Adam(
         learning_rate=lr_schedule_d,
         grad_clip=gradient_clip_d,

@@ -138,10 +143,10 @@ def train_sp(args, config):
     print("optimizers done!")

     output_dir = Path(args.output_dir)
-    log_writer = None
+    checkpoint_dir = output_dir / "checkpoints"
     if dist.get_rank() == 0:
         output_dir.mkdir(parents=True, exist_ok=True)
-        log_writer = LogWriter(str(output_dir))
+        checkpoint_dir.mkdir(parents=True, exist_ok=True)

     updater = PWGUpdater(
         models={

@@ -160,18 +165,41 @@ def train_sp(args, config):
             "generator": lr_schedule_g,
             "discriminator": lr_schedule_d,
         },
-        dataloaders={
-            "train": train_dataloader,
-            "dev": dev_dataloader,
-        },
+        dataloader=train_dataloader,
         discriminator_train_start_steps=config.discriminator_train_start_steps,
         lambda_adv=config.lambda_adv, )

+    evaluator = PWGEvaluator(
+        models={
+            "generator": generator,
+            "discriminator": discriminator,
+        },
+        criterions={
+            "stft": criterion_stft,
+            "mse": criterion_mse,
+        },
+        dataloader=dev_dataloader,
+        lambda_adv=config.lambda_adv, )
+
     trainer = Trainer(
         updater,
-        stop_trigger=(config.train_max_steps, "iteration"),  # PROFILING
+        stop_trigger=(config.train_max_steps, "iteration"),
         out=output_dir, )

+    trainer.extend(
+        evaluator,
+        trigger=(config.eval_interval_steps, 'iteration'),
+        priority=3)
+    if dist.get_rank() == 0:
+        log_writer = LogWriter(str(output_dir))
+        trainer.extend(
+            VisualDL(log_writer), trigger=(1, 'iteration'), priority=1)
+        trainer.extend(
+            Snapshot(checkpoint_dir),
+            trigger=(config.save_interval_steps, 'iteration'),
+            priority=2)
+    print("Trainer Done!")

     # with paddle.fluid.profiler.profiler('All', 'total',
     #                                     str(output_dir / "profiler.log"),
     #                                     'Default') as prof:

@@ -0,0 +1,63 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+from typing import Optional, Callable, Dict
+
+from tqdm import tqdm
+import paddle
+from paddle import Tensor
+from paddle.nn import Layer
+from paddle.io import DataLoader
+
+from parakeet.training.reporter import scope, report, DictSummary
+
+
+class StandardEvaluator(object):
+    def __init__(self, model: Layer, dataloader: DataLoader):
+        # it is designed to hold multiple models
+        models = {"main": model}
+        self.models: Dict[str, Layer] = models
+        self.model = model
+
+        # dataloaders
+        self.dataloader = dataloader
+
+    def evaluate_core(self, batch):
+        # compute
+        self.model(batch)  # you may report here
+
+    def evaluate(self):
+        # switch to eval mode
+        for layer in self.models.values():
+            layer.eval()
+
+        summary = DictSummary()
+        for batch in self.dataloader:
+            observation = {}
+            with scope(observation):
+                with paddle.no_grad():
+                    self.evaluate_core(
+                        batch)  # main evaluation computation here.
+            summary.add(observation)
+        summary = summary.compute_mean()
+        return summary
+
+    def __call__(self, trainer=None):
+        self.observation = {}
+        with scope(self.observation):
+            summary = self.evaluate()
+        for k, v in summary.items():
+            report(k, v)
+        print(self.observation)
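Note: a minimal sketch of how this StandardEvaluator is meant to be subclassed and driven on its own (the toy model, data, and metric name below are illustrative assumptions, not part of this commit):

    import paddle
    from paddle import nn
    from paddle.io import DataLoader, TensorDataset
    from parakeet.training.reporter import report
    from parakeet.training.extensions.evaluator import StandardEvaluator

    class MSEEvaluator(StandardEvaluator):
        def evaluate_core(self, batch):
            x, y = batch
            mse = paddle.mean((self.model(x) - y) ** 2)
            report("eval/mse", float(mse))  # averaged over batches by DictSummary

    model = nn.Linear(4, 1)
    dataset = TensorDataset([paddle.randn([16, 4]), paddle.randn([16, 1])])
    evaluator = MSEEvaluator(model, DataLoader(dataset, batch_size=4))
    evaluator()  # switches to eval mode, runs the dataloader, reports mean metrics

When registered on a trainer (as in the training script above), the same __call__ runs on the trigger interval and its reported means land in the trainer's observation dictionary.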
@@ -0,0 +1,44 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+from pathlib import Path
+
+from parakeet.utils.mp_tools import rank_zero_only
+from parakeet.training.trainer import Trainer
+
+
+class Snapshot(object):
+    """An extension to make snapshot of the updater object inside
+    the trainer. It is done by calling the updater's `save` method.
+
+    An Updater save its state_dict by default, which contains the
+    updater state, (i.e. epoch and iteration) and all the model
+    parameters and optimizer states. If the updater inside the trainer
+    subclasses StandardUpdater, everything is good to go.
+
+    Parameters
+    ----------
+    checkpoint_dir : Union[str, Path]
+        The directory to save checkpoints into.
+    """
+
+    def __init__(self, checkpoint_dir: Union[str, Path]):
+        self.checkpoint_dir = Path(checkpoint_dir)
+
+    @rank_zero_only
+    def __call__(self, trainer: Trainer):
+        iteration = trainer.updater.state.iteration
+        path = self.checkpoint_dir / f"step_{iteration}.pdz"
+        trainer.updater.save(str(path))
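Note: as wired up in the training script above, the extension is registered on a trainer with an iteration trigger; a minimal sketch, assuming a `trainer` built as in that script (directory name and interval are illustrative):

    from pathlib import Path
    from parakeet.training.extensions.snapshot import Snapshot

    checkpoint_dir = Path("exp/default/checkpoints")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    # every 1000 iterations, dump the updater state (models, optimizers,
    # iteration/epoch) to checkpoints/step_<iteration>.pdz via updater.save()
    trainer.extend(
        Snapshot(checkpoint_dir),
        trigger=(1000, 'iteration'),
        priority=2)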
@@ -0,0 +1,34 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from visualdl import LogWriter
+from parakeet.training.trainer import Trainer
+from parakeet.utils.mp_tools import rank_zero_only
+
+
+class VisualDL(object):
+    """A wrapper of visualdl log writer. It assumes that the metrics to be visualized
+    are all scalars which are recorded into the `.observation` dictionary of the
+    trainer object. The dictionary is created for each step, thus the visualdl log
+    writer uses the iteration from the updater's `iteration` as the global step to
+    add records.
+    """
+
+    def __init__(self, writer: LogWriter):
+        self.writer = writer
+
+    @rank_zero_only
+    def __call__(self, trainer: Trainer):
+        for k, v in trainer.observation.items():
+            self.writer.add_scalar(k, v, step=trainer.updater.state.iteration)
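Note: mirroring the training script above, the wrapper is fed an existing visualdl LogWriter and registered to run every iteration (the log directory is illustrative, and `trainer` is assumed to be built as in that script):

    from visualdl import LogWriter
    from parakeet.training.extensions.visualizer import VisualDL

    log_writer = LogWriter("exp/default")  # inspect later with: visualdl --logdir exp/default
    trainer.extend(VisualDL(log_writer), trigger=(1, 'iteration'), priority=1)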
@@ -0,0 +1,26 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import paddle
+import random
+import numpy as np
+
+
+def seed_everything(seed: int):
+    paddle.seed(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+    logging.debug(f"Set the seed of paddle, random, np.random to {seed}.")
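Note: the training script above calls this once per process, right after initializing the distributed environment, so that every rank starts from the same random state:

    from parakeet.training.seeding import seed_everything

    # a must for multiprocess training: seeds paddle, random and np.random
    seed_everything(42)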
@@ -40,15 +40,13 @@ class Trainer(object):
         self.out = Path(out)
         self.observation = {}

-    def setup(self):
-        pass
-
     def extend(self, extension, name=None, trigger=None, priority=None):
         trigger = get_trigger(trigger)

         ordinal = 0
         modified_name = name
-        while name in self.extensions:
+        while modified_name in self.extensions:
+            print(self.extensions.keys())
             ordinal += 1
             modified_name = f"{name}_{ordinal}"

@@ -61,8 +59,7 @@ class Trainer(object):
             self.extensions.keys(),
             key=lambda name: self.extensions[name].priority,
             reverse=True)
-        extensions = [(name, self.extensions[name])
-                      for name in extension_order]
+        extensions = [(name, self.extensions[name]) for name in extension_order]

         update = self.updater.update
         stop_trigger = self.stop_trigger

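Note: with the loop condition fixed above, registering two extensions that resolve to the same name no longer loops forever; the later one is stored under a suffixed key. A sketch, assuming a `trainer` and directories as in the training script (the second directory is illustrative):

    trainer.extend(Snapshot(checkpoint_dir), name="snapshot", trigger=(1000, 'iteration'), priority=2)
    trainer.extend(Snapshot(backup_dir), name="snapshot", trigger=(5000, 'iteration'), priority=2)
    # the second registration is kept as "snapshot_1"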
@@ -12,12 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
 from dataclasses import dataclass
 from typing import Optional
+from typing import Dict
+from typing import Union
+
+from timer import timer
+import paddle
+from paddle import Tensor
 from paddle.nn import Layer
 from paddle.optimizer import Optimizer
 from paddle.io import DataLoader
+from paddle.io import DistributedBatchSampler
+
+from parakeet.training.reporter import report


 @dataclass

@@ -56,12 +65,34 @@ class UpdaterBase(object):
     So the best practice is to define a model and define a updater for it.
     """

-    def update(self):
-        self.state.iteration += 1
-        self.update_core()
+    def __init__(self, init_state=None):
+        if init_state is None:
+            self.state = UpdaterState()
+        else:
+            self.state = init_state

-    def update_core(self):
-        pass
+    def update(self, batch):
+        raise NotImplementedError(
+            "Implement your own `update` method for training a step.")
+
+    def state_dict(self):
+        state_dict = {
+            "epoch": self.state.epoch,
+            "iteration": self.state.iteration,
+        }
+        return state_dict
+
+    def set_state_dict(self, state_dict):
+        self.state.epoch = state_dict["epoch"]
+        self.state.iteration = state_dict["iteration"]
+
+    def save(self, path):
+        archive = self.state_dict()
+        paddle.save(archive, path)
+
+    def load(self, path):
+        archive = paddle.load(path)
+        self.set_state_dict(archive)


 class StandardUpdater(UpdaterBase):

@@ -71,54 +102,116 @@ class StandardUpdater(UpdaterBase):

     def __init__(self,
                  model: Layer,
-                 dataloader: DataLoader,
                  optimizer: Optimizer,
-                 loss_func=None,
-                 auto_new_epoch: bool=True,
+                 dataloader: DataLoader,
                  init_state: Optional[UpdaterState]=None):
+        # it is designed to hold multiple models
+        models = {"main": model}
+        self.models: Dict[str, Layer] = models
         self.model = model
-        self.dataloader = dataloader
-        self.optimizer = optimizer
-        self.loss_func = loss_func
-        self.auto_new_epoch = auto_new_epoch
-        self.iterator = iter(dataloader)

+        # it is designed to hold multiple optimizers
+        optimizers = {"main": optimizer}
+        self.optimizer = optimizer
+        self.optimizers: Dict[str, Optimizer] = optimizers
+
+        # dataloaders
+        self.dataloader = dataloader
+
+        # init state
         if init_state is None:
             self.state = UpdaterState()
         else:
             self.state = init_state

+        self.train_iterator = iter(dataloader)
+
     def update(self):
-        self.update_core()
         self.state.iteration += 1

+        # switch to training mode
+        for layer in self.models.values():
+            layer.train()
+
+        # training for a step is implemented here
+        batch = self.read_batch()
+        self.update_core(batch)
+
+    def update_core(self, batch):
+        """A simple case for a training step. Basic assumptions are:
+        Single model;
+        Single optimizer;
+        A batch from the dataloader is just the input of the model;
+        The model return a single loss, or a dict containing serval losses.
+        Parameters updates at every batch, no gradient accumulation.
+        """
+        loss = self.model(*batch)
+
+        if isinstance(loss, Tensor):
+            loss_dict = {"main": loss}
+        else:
+            # Dict[str, Tensor]
+            loss_dict = loss
+            if "main" not in loss_dict:
+                main_loss = 0
+                for loss_item in loss.values():
+                    main_loss += loss_item
+                loss_dict["main"] = main_loss
+
+        for name, loss_item in loss_dict.items():
+            report(name, float(loss_item))
+
+        self.optimizer.clear_gradient()
+        loss_dict["main"].backward()
+        self.optimizer.update()
+
     def new_epoch(self):
-        self.iterator = iter(self.dataloader)
+        """Start a new epoch."""
         self.state.epoch += 1

-    def update_core(self):
-        model = self.model
-        optimizer = self.optimizer
-        loss_func = self.loss_func
-
-        model.train()
-        optimizer.clear_grad()
-
-        # fetch a batch
-        try:
-            batch = next(self.iterator)
-        except StopIteration as e:
-            if self.auto_new_epoch:
-                self.new_epoch()
-
-        # forward
-        if self.loss_func is not None:
-            loss = loss_func(batch)
-        else:
-            loss = model(batch)
-
-        # backward
-        loss.backward()
-
-        # update parameters
-        optimizer.step()
+        # NOTE: all batch sampler for distributed training should
+        # subclass DistributedBatchSampler and implement `set_epoch` method
+        batch_sampler = self.dataloader.batch_sampler
+        if isinstance(batch_sampler, DistributedBatchSampler):
+            batch_sampler.set_epoch(self.state.epoch)
+        self.train_iterator = iter(self.dataloader)
+
+    def read_batch(self):
+        """Read a batch from the data loader, auto renew when data is exhausted."""
+        with timer() as t:
+            try:
+                batch = next(self.train_iterator)
+            except StopIteration:
+                self.new_epoch()
+                batch = next(self.train_iterator)
+            logging.debug(
+                f"Read a batch takes {t.elapse}s.")  # replace it with logging
+        return batch
+
+    def state_dict(self):
+        """State dict of a Updater, model, optimizer and updater state are included."""
+        state_dict = super().state_dict()
+        for name, layer in self.models.items():
+            state_dict[f"{name}_params"] = layer.state_dict()
+        for name, optim in self.optimizers.items():
+            state_dict[f"{name}_optimizer"] = optim.state_dict()
+        return state_dict
+
+    def set_state_dict(self, state_dict):
+        """Set state dict for a Updater. Parameters of models, states for
+        optimizers and UpdaterState are restored."""
+        for name, layer in self.models.items():
+            layer.set_state_dict(state_dict[f"{name}_params"])
+        for name, optim in self.optimizers.items():
+            optim.set_state_dict(state_dict[f"{name}_optimizer"])
+        super().set_state_dict(state_dict)
+
+    def save(self, path):
+        """Save Updater state dict."""
+        archive = self.state_dict()
+        paddle.save(archive, path)
+
+    def load(self, path):
+        """Load Updater state dict."""
+        archive = paddle.load(path)
+        self.set_state_dict(archive)

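Note: a minimal sketch of driving the reworked StandardUpdater with a toy model (the model, optimizer, and data below are illustrative assumptions, not part of this commit; update_core is overridden so the sketch only relies on the documented clear_grad()/step() optimizer API rather than the helper calls used by the default update_core):

    import paddle
    from paddle import nn
    from paddle.io import DataLoader, TensorDataset
    from parakeet.training.updater import StandardUpdater

    class RegressionUpdater(StandardUpdater):
        def update_core(self, batch):
            x, y = batch
            loss = paddle.mean((self.model(x) - y) ** 2)
            self.optimizer.clear_grad()
            loss.backward()
            self.optimizer.step()

    model = nn.Linear(4, 1)
    optimizer = paddle.optimizer.Adam(0.01, parameters=model.parameters())
    dataset = TensorDataset([paddle.randn([32, 4]), paddle.randn([32, 1])])
    updater = RegressionUpdater(model, optimizer, DataLoader(dataset, batch_size=8))

    for _ in range(10):
        updater.update()  # read_batch() + update_core(); the iterator renews each epoch

Saving and restoring go through save()/load(), which bundle the iteration/epoch counters with every registered model's and optimizer's state dict.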
@@ -13,8 +13,11 @@
 # limitations under the License.

 import paddle
+from paddle.framework import CUDAPlace


 def synchronize():
+    """Trigger cuda synchronization for better timing."""
     place = paddle.fluid.framework._current_expected_place()
-    paddle.fluid.core._cuda_synchronize(place)
+    if isinstance(place, CUDAPlace):
+        paddle.fluid.core._cuda_synchronize(place)

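Note: the updaters above pair this with the `timer` context manager so pending GPU kernels are flushed before the elapsed time is read; a minimal sketch (the layer and input shapes are illustrative):

    import paddle
    from timer import timer
    from parakeet.utils.profile import synchronize

    layer = paddle.nn.Conv1D(80, 256, 3)
    x = paddle.randn([4, 80, 1000])

    synchronize()  # drain previously queued work (a no-op on CPU places after this change)
    with timer() as t:
        y = layer(x)
        synchronize()  # make sure the forward pass finished before the timer stops
    print(f"forward takes {t.elapse}s")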
@@ -0,0 +1,319 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import six
+import sys
+import unittest
+
+import google.protobuf.text_format as text_format
+import paddle.fluid.proto.profiler.profiler_pb2 as profiler_pb2
+
+parser = argparse.ArgumentParser(description=__doc__)
+parser.add_argument(
+    '--profile_path',
+    type=str,
+    default='',
+    help='Input profile file name. If there are multiple file, the format '
+    'should be trainer1=file1,trainer2=file2,ps=file3')
+parser.add_argument(
+    '--timeline_path', type=str, default='', help='Output timeline file name.')
+args = parser.parse_args()
+
+
+class _ChromeTraceFormatter(object):
+    def __init__(self):
+        self._events = []
+        self._metadata = []
+
+    def _create_event(self, ph, category, name, pid, tid, timestamp):
+        """Creates a new Chrome Trace event.
+
+        For details of the file format, see:
+        https://github.com/catapult-project/catapult/blob/master/tracing/README.md
+
+        Args:
+            ph: The type of event - usually a single character.
+            category: The event category as a string.
+            name: The event name as a string.
+            pid: Identifier of the process generating this event as an integer.
+            tid: Identifier of the thread generating this event as an integer.
+            timestamp: The timestamp of this event as a long integer.
+
+        Returns:
+            A JSON compatible event object.
+        """
+        event = {}
+        event['ph'] = ph
+        event['cat'] = category
+        event['name'] = name.replace("ParallelExecutor::Run/", "")
+        event['pid'] = pid
+        event['tid'] = tid
+        event['ts'] = timestamp
+        return event
+
+    def emit_pid(self, name, pid):
+        """Adds a process metadata event to the trace.
+
+        Args:
+            name: The process name as a string.
+            pid: Identifier of the process as an integer.
+        """
+        event = {}
+        event['name'] = 'process_name'
+        event['ph'] = 'M'
+        event['pid'] = pid
+        event['args'] = {'name': name}
+        self._metadata.append(event)
+
+    def emit_region(self, timestamp, duration, pid, tid, category, name, args):
+        """Adds a region event to the trace.
+
+        Args:
+            timestamp: The start timestamp of this region as a long integer.
+            duration: The duration of this region as a long integer.
+            pid: Identifier of the process generating this event as an integer.
+            tid: Identifier of the thread generating this event as an integer.
+            category: The event category as a string.
+            name: The event name as a string.
+            args: A JSON-compatible dictionary of event arguments.
+        """
+        event = self._create_event('X', category, name, pid, tid, timestamp)
+        event['dur'] = duration
+        event['args'] = args
+        self._events.append(event)
+
+    def emit_counter(self, category, name, pid, timestamp, counter, value):
+        """Emits a record for a single counter.
+
+        Args:
+            category: The event category as string
+            name: The event name as string
+            pid: Identifier of the process generating this event as integer
+            timestamp: The timestamps of this event as long integer
+            counter: Name of the counter as string
+            value: Value of the counter as integer
+            tid: Thread id of the allocation as integer
+        """
+        event = self._create_event('C', category, name, pid, 0, timestamp)
+        event['args'] = {counter: value}
+        self._events.append(event)
+
+    def format_to_string(self, pretty=False):
+        """Formats the chrome trace to a string.
+
+        Args:
+            pretty: (Optional.) If True, produce human-readable JSON output.
+
+        Returns:
+            A JSON-formatted string in Chrome Trace format.
+        """
+        trace = {}
+        trace['traceEvents'] = self._metadata + self._events
+        if pretty:
+            return json.dumps(trace, indent=4, separators=(',', ': '))
+        else:
+            return json.dumps(trace, separators=(',', ':'))
+
+
+class Timeline(object):
+    def __init__(self, profile_dict):
+        self._profile_dict = profile_dict
+        self._pid = 0
+        self._devices = dict()
+        self._mem_devices = dict()
+        self._chrome_trace = _ChromeTraceFormatter()
+
+    def _allocate_pid(self):
+        cur_pid = self._pid
+        self._pid += 1
+        return cur_pid
+
+    def _allocate_pids(self):
+        for k, profile_pb in six.iteritems(self._profile_dict):
+            for event in profile_pb.events:
+                if event.type == profiler_pb2.Event.CPU:
+                    if (k, event.device_id, "CPU") not in self._devices:
+                        pid = self._allocate_pid()
+                        self._devices[(k, event.device_id, "CPU")] = pid
+                        # -1 device id represents CUDA API(RunTime) call.(e.g. cudaLaunch, cudaMemcpy)
+                        if event.device_id == -1:
+                            self._chrome_trace.emit_pid("%s:cuda_api" % k, pid)
+                        else:
+                            self._chrome_trace.emit_pid(
+                                "%s:cpu:block:%d" % (k, event.device_id), pid)
+                elif event.type == profiler_pb2.Event.GPUKernel:
+                    if (k, event.device_id, "GPUKernel") not in self._devices:
+                        pid = self._allocate_pid()
+                        self._devices[(k, event.device_id, "GPUKernel")] = pid
+                        self._chrome_trace.emit_pid("%s:gpu:%d" %
+                                                    (k, event.device_id), pid)
+            if not hasattr(profile_pb, "mem_events"):
+                continue
+            for mevent in profile_pb.mem_events:
+                if mevent.place == profiler_pb2.MemEvent.CUDAPlace:
+                    if (k, mevent.device_id, "GPU") not in self._mem_devices:
+                        pid = self._allocate_pid()
+                        self._mem_devices[(k, mevent.device_id, "GPU")] = pid
+                        self._chrome_trace.emit_pid(
+                            "memory usage on %s:gpu:%d" % (k, mevent.device_id),
+                            pid)
+                elif mevent.place == profiler_pb2.MemEvent.CPUPlace:
+                    if (k, mevent.device_id, "CPU") not in self._mem_devices:
+                        pid = self._allocate_pid()
+                        self._mem_devices[(k, mevent.device_id, "CPU")] = pid
+                        self._chrome_trace.emit_pid(
+                            "memory usage on %s:cpu:%d" % (k, mevent.device_id),
+                            pid)
+                elif mevent.place == profiler_pb2.MemEvent.CUDAPinnedPlace:
+                    if (k, mevent.device_id, "CUDAPinnedPlace"
+                        ) not in self._mem_devices:
+                        pid = self._allocate_pid()
+                        self._mem_devices[(k, mevent.device_id,
+                                           "CUDAPinnedPlace")] = pid
+                        self._chrome_trace.emit_pid(
+                            "memory usage on %s:cudapinnedplace:%d" %
+                            (k, mevent.device_id), pid)
+                elif mevent.place == profiler_pb2.MemEvent.NPUPlace:
+                    if (k, mevent.device_id, "NPU") not in self._mem_devices:
+                        pid = self._allocate_pid()
+                        self._mem_devices[(k, mevent.device_id, "NPU")] = pid
+                        self._chrome_trace.emit_pid(
+                            "memory usage on %s:npu:%d" % (k, mevent.device_id),
+                            pid)
+            if (k, 0, "CPU") not in self._mem_devices:
+                pid = self._allocate_pid()
+                self._mem_devices[(k, 0, "CPU")] = pid
+                self._chrome_trace.emit_pid("memory usage on %s:cpu:%d" %
+                                            (k, 0), pid)
+            if (k, 0, "GPU") not in self._mem_devices:
+                pid = self._allocate_pid()
+                self._mem_devices[(k, 0, "GPU")] = pid
+                self._chrome_trace.emit_pid("memory usage on %s:gpu:%d" %
+                                            (k, 0), pid)
+            if (k, 0, "CUDAPinnedPlace") not in self._mem_devices:
+                pid = self._allocate_pid()
+                self._mem_devices[(k, 0, "CUDAPinnedPlace")] = pid
+                self._chrome_trace.emit_pid(
+                    "memory usage on %s:cudapinnedplace:%d" % (k, 0), pid)
+            if (k, 0, "NPU") not in self._mem_devices:
+                pid = self._allocate_pid()
+                self._mem_devices[(k, 0, "NPU")] = pid
+                self._chrome_trace.emit_pid("memory usage on %s:npu:%d" %
+                                            (k, 0), pid)
+
+    def _allocate_events(self):
+        for k, profile_pb in six.iteritems(self._profile_dict):
+            for event in profile_pb.events:
+                if event.type == profiler_pb2.Event.CPU:
+                    type = "CPU"
+                elif event.type == profiler_pb2.Event.GPUKernel:
+                    type = "GPUKernel"
+                pid = self._devices[(k, event.device_id, type)]
+                args = {'name': event.name}
+                if event.memcopy.bytes > 0:
+                    args['mem_bytes'] = event.memcopy.bytes
+                if hasattr(event, "detail_info") and event.detail_info:
+                    args['detail_info'] = event.detail_info
+                # TODO(panyx0718): Chrome tracing only handles ms. However, some
+                # ops takes micro-seconds. Hence, we keep the ns here.
+                self._chrome_trace.emit_region(
+                    event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid,
+                    event.sub_device_id, 'Op', event.name, args)
+
+    def _allocate_memory_event(self):
+        if not hasattr(profiler_pb2, "MemEvent"):
+            return
+        place_to_str = {
+            profiler_pb2.MemEvent.CPUPlace: "CPU",
+            profiler_pb2.MemEvent.CUDAPlace: "GPU",
+            profiler_pb2.MemEvent.CUDAPinnedPlace: "CUDAPinnedPlace",
+            profiler_pb2.MemEvent.NPUPlace: "NPU"
+        }
+        for k, profile_pb in six.iteritems(self._profile_dict):
+            mem_list = []
+            end_profiler = 0
+            for mevent in profile_pb.mem_events:
+                crt_info = dict()
+                crt_info['time'] = mevent.start_ns
+                crt_info['size'] = mevent.bytes
+                if mevent.place in place_to_str:
+                    place = place_to_str[mevent.place]
+                else:
+                    place = "UnDefine"
+                crt_info['place'] = place
+                pid = self._mem_devices[(k, mevent.device_id, place)]
+                crt_info['pid'] = pid
+                crt_info['thread_id'] = mevent.thread_id
+                crt_info['device_id'] = mevent.device_id
+                mem_list.append(crt_info)
+                crt_info = dict()
+                crt_info['place'] = place
+                crt_info['pid'] = pid
+                crt_info['thread_id'] = mevent.thread_id
+                crt_info['device_id'] = mevent.device_id
+                crt_info['time'] = mevent.end_ns
+                crt_info['size'] = -mevent.bytes
+                mem_list.append(crt_info)
+                end_profiler = max(end_profiler, crt_info['time'])
+            mem_list.sort(key=lambda tmp: (tmp.get('time', 0)))
+            i = 0
+            total_size = 0
+            while i < len(mem_list):
+                total_size += mem_list[i]['size']
+                while i < len(mem_list) - 1 and mem_list[i]['time'] == mem_list[
+                        i + 1]['time']:
+                    total_size += mem_list[i + 1]['size']
+                    i += 1
+
+                self._chrome_trace.emit_counter(
+                    "Memory", "Memory", mem_list[i]['pid'], mem_list[i]['time'],
+                    0, total_size)
+                i += 1
+
+    def generate_chrome_trace(self):
+        self._allocate_pids()
+        self._allocate_events()
+        self._allocate_memory_event()
+        return self._chrome_trace.format_to_string()
+
+
+profile_path = '/tmp/profile'
+if args.profile_path:
+    profile_path = args.profile_path
+timeline_path = '/tmp/timeline'
+if args.timeline_path:
+    timeline_path = args.timeline_path
+
+profile_paths = profile_path.split(',')
+profile_dict = dict()
+if len(profile_paths) == 1:
+    with open(profile_path, 'rb') as f:
+        profile_s = f.read()
+        profile_pb = profiler_pb2.Profile()
+        profile_pb.ParseFromString(profile_s)
+    profile_dict['trainer'] = profile_pb
+else:
+    for profile_path in profile_paths:
+        k, v = profile_path.split('=')
+        with open(v, 'rb') as f:
+            profile_s = f.read()
+            profile_pb = profiler_pb2.Profile()
+            profile_pb.ParseFromString(profile_s)
+        profile_dict[k] = profile_pb
+
+tl = Timeline(profile_dict)
+with open(timeline_path, 'w') as f:
+    f.write(tl.generate_chrome_trace())
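Note: a typical invocation of this timeline tool (file locations are illustrative): run training with the Paddle profiler writing a profile dump, as hinted by the commented-out profiler block in the training script above, then convert the dump and open the result in Chrome's chrome://tracing viewer:

    python timeline.py \
        --profile_path exp/default/profiler.log \
        --timeline_path exp/default/timeline.json

For multi-process runs, --profile_path also accepts the trainer1=file1,trainer2=file2 form described in its help text.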