add snapshot and visualizer
This commit is contained in:
parent
a738954001
commit
83c9f0aeae
|
@ -11,12 +11,6 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Tomoki Hayashi
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
"""Normalize feature files and dump them."""
|
||||
|
||||
import argparse
|
||||
|
@ -28,13 +22,10 @@ from pathlib import Path
|
|||
import numpy as np
|
||||
import yaml
|
||||
import jsonlines
|
||||
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from tqdm import tqdm
|
||||
|
||||
from parakeet.datasets.data_table import DataTable
|
||||
from parakeet.utils.h5_utils import read_hdf5
|
||||
from parakeet.utils.h5_utils import write_hdf5
|
||||
|
||||
from config import get_cfg_default
|
||||
|
||||
|
|
|
@ -12,36 +12,45 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
import paddle
|
||||
from paddle.fluid.core import _cuda_synchronize
|
||||
from paddle.nn import Layer
|
||||
from paddle.optimizer import Optimizer
|
||||
from paddle.optimizer.lr import LRScheduler
|
||||
from paddle.io import DataLoader
|
||||
from paddle.io import DistributedBatchSampler
|
||||
from timer import timer
|
||||
|
||||
from parakeet.datasets.data_table import DataTable
|
||||
from parakeet.training.updater import UpdaterBase, UpdaterState
|
||||
from parakeet.training.updater import UpdaterBase, UpdaterState, StandardUpdater
|
||||
from parakeet.training.extensions.evaluator import StandardEvaluator
|
||||
from parakeet.training.trainer import Trainer
|
||||
from parakeet.training.reporter import report
|
||||
from parakeet.training.checkpoint import KBest, KLatest
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
|
||||
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
|
||||
from parakeet.utils.profile import synchronize
|
||||
|
||||
|
||||
class PWGUpdater(UpdaterBase):
|
||||
class PWGUpdater(StandardUpdater):
|
||||
def __init__(
|
||||
self,
|
||||
models,
|
||||
optimizers,
|
||||
criterions,
|
||||
schedulers,
|
||||
dataloaders,
|
||||
models: Dict[str, Layer],
|
||||
optimizers: Dict[str, Optimizer],
|
||||
criterions: Dict[str, Layer],
|
||||
schedulers: Dict[str, LRScheduler],
|
||||
dataloader: DataLoader,
|
||||
discriminator_train_start_steps: int,
|
||||
lambda_adv: float, ):
|
||||
self.models = models
|
||||
self.generator = models['generator']
|
||||
self.discriminator = models['discriminator']
|
||||
self.generator: Layer = models['generator']
|
||||
self.discriminator: Layer = models['discriminator']
|
||||
|
||||
self.optimizers = optimizers
|
||||
self.optimizer_g = optimizers['generator']
|
||||
self.optimizer_d = optimizers['discriminator']
|
||||
self.optimizer_g: Optimizer = optimizers['generator']
|
||||
self.optimizer_d: Optimizer = optimizers['discriminator']
|
||||
|
||||
self.criterions = criterions
|
||||
self.criterion_stft = criterions['stft']
|
||||
|
@ -51,46 +60,34 @@ class PWGUpdater(UpdaterBase):
|
|||
self.scheduler_g = schedulers['generator']
|
||||
self.scheduler_d = schedulers['discriminator']
|
||||
|
||||
self.dataloaders = dataloaders
|
||||
self.train_dataloader = dataloaders['train']
|
||||
self.dev_dataloader = dataloaders['dev']
|
||||
self.dataloader = dataloader
|
||||
|
||||
self.discriminator_train_start_steps = discriminator_train_start_steps
|
||||
self.lambda_adv = lambda_adv
|
||||
self.state = UpdaterState(iteration=0, epoch=0)
|
||||
|
||||
self.train_iterator = iter(self.train_dataloader)
|
||||
|
||||
def update_core(self):
|
||||
place = paddle.fluid.framework._current_expected_place()
|
||||
with timer() as t:
|
||||
_cuda_synchronize(place)
|
||||
try:
|
||||
batch = next(self.train_iterator)
|
||||
except StopIteration:
|
||||
self.train_iterator = iter(self.train_dataloader)
|
||||
batch = next(self.train_iterator)
|
||||
_cuda_synchronize(place)
|
||||
print(f"Loading a batch takes {t.elapse}s")
|
||||
self.train_iterator = iter(self.dataloader)
|
||||
|
||||
def update_core(self, batch):
|
||||
# parse batch
|
||||
wav, mel = batch
|
||||
|
||||
# Generator
|
||||
noise = paddle.randn(wav.shape)
|
||||
|
||||
_cuda_synchronize(place)
|
||||
synchronize()
|
||||
with timer() as t:
|
||||
wav_ = self.generator(noise, mel)
|
||||
_cuda_synchronize(place)
|
||||
print(f"Generator takes {t.elapse}s")
|
||||
synchronize()
|
||||
logging.debug(f"Generator takes {t.elapse}s.")
|
||||
|
||||
## Multi-resolution stft loss
|
||||
_cuda_synchronize(place)
|
||||
synchronize()
|
||||
with timer() as t:
|
||||
sc_loss, mag_loss = self.criterion_stft(
|
||||
wav_.squeeze(1), wav.squeeze(1))
|
||||
_cuda_synchronize(place)
|
||||
print(f"Multi-resolution STFT loss takes {t.elapse}s")
|
||||
synchronize()
|
||||
logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s.")
|
||||
|
||||
report("train/spectral_convergence_loss", float(sc_loss))
|
||||
report("train/log_stft_magnitude_loss", float(mag_loss))
|
||||
|
@ -98,30 +95,31 @@ class PWGUpdater(UpdaterBase):
|
|||
|
||||
## Adversarial loss
|
||||
if self.state.iteration > self.discriminator_train_start_steps:
|
||||
synchronize()
|
||||
with timer() as t:
|
||||
_cuda_synchronize(place)
|
||||
p_ = self.discriminator(wav_)
|
||||
adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
|
||||
_cuda_synchronize(place)
|
||||
print(f"Discriminator and adversarial loss takes {t.elapse}s")
|
||||
synchronize()
|
||||
logging.debug(
|
||||
f"Discriminator and adversarial loss takes {t.elapse}s")
|
||||
report("train/adversarial_loss", float(adv_loss))
|
||||
gen_loss += self.lambda_adv * adv_loss
|
||||
|
||||
report("train/generator_loss", float(gen_loss))
|
||||
|
||||
_cuda_synchronize(place)
|
||||
synchronize()
|
||||
with timer() as t:
|
||||
self.optimizer_g.clear_grad()
|
||||
gen_loss.backward()
|
||||
_cuda_synchronize(place)
|
||||
print(f"Backward takes {t.elapse}s.")
|
||||
synchronize()
|
||||
logging.debug(f"Backward takes {t.elapse}s.")
|
||||
|
||||
_cuda_synchronize(place)
|
||||
synchronize()
|
||||
with timer() as t:
|
||||
self.optimizer_g.step()
|
||||
self.scheduler_g.step()
|
||||
_cuda_synchronize(place)
|
||||
print(f"Update takes {t.elapse}s.")
|
||||
synchronize()
|
||||
logging.debug(f"Update takes {t.elapse}s.")
|
||||
|
||||
# Discriminator
|
||||
if self.state.iteration > self.discriminator_train_start_steps:
|
||||
|
@ -138,5 +136,68 @@ class PWGUpdater(UpdaterBase):
|
|||
|
||||
self.optimizer_d.clear_grad()
|
||||
dis_loss.backward()
|
||||
|
||||
self.optimizer_d.step()
|
||||
self.scheduler_d.step()
|
||||
|
||||
|
||||
class PWGEvaluator(StandardEvaluator):
    """Evaluator for a Parallel WaveGAN generator/discriminator pair.

    For each batch it computes the multi-resolution STFT losses, the
    adversarial loss and the discriminator real/fake losses, and reports
    every value under an ``eval/`` key through ``report``.

    Parameters
    ----------
    models
        Mapping that must contain the keys ``'generator'`` and
        ``'discriminator'``.
    criterions
        Mapping that must contain the keys ``'stft'`` and ``'mse'``.
    dataloader
        Iterable of evaluation batches; each batch is unpacked as
        ``(wav, mel)`` in ``evaluate_core``.
    lambda_adv
        Weight applied to the adversarial loss when it is added to the
        generator loss.
    """

    def __init__(self, models, criterions, dataloader, lambda_adv):
        # NOTE(review): the parent initializer is not called; the attributes
        # the parent's ``evaluate`` reads (``models``, ``dataloader``) are
        # assigned directly here instead.
        self.models = models
        self.generator = models['generator']
        self.discriminator = models['discriminator']

        self.criterions = criterions
        self.criterion_stft = criterions['stft']
        self.criterion_mse = criterions['mse']

        self.dataloader = dataloader
        self.lambda_adv = lambda_adv

    def evaluate_core(self, batch):
        """Compute and report all evaluation losses for one batch.

        Parameters
        ----------
        batch
            A pair ``(wav, mel)``; ``wav`` also fixes the shape of the noise
            fed to the generator.
        """
        logging.debug("Evaluate: ")
        wav, mel = batch
        noise = paddle.randn(wav.shape)

        # Each timed section is bracketed by ``synchronize()`` so the elapsed
        # time logged afterwards is not skewed by pending device work.
        synchronize()
        with timer() as t:
            wav_ = self.generator(noise, mel)
            synchronize()
        logging.debug(f"Generator takes {t.elapse}s")

        ## Multi-resolution stft loss
        synchronize()
        with timer() as t:
            sc_loss, mag_loss = self.criterion_stft(
                wav_.squeeze(1), wav.squeeze(1))
            synchronize()
        logging.debug(f"Multi-resolution STFT loss takes {t.elapse}s")

        report("eval/spectral_convergence_loss", float(sc_loss))
        report("eval/log_stft_magnitude_loss", float(mag_loss))
        gen_loss = sc_loss + mag_loss

        ## Adversarial loss
        # The generator is scored against an all-ones target, i.e. how well
        # the generated audio passes as real.
        synchronize()
        with timer() as t:
            p_ = self.discriminator(wav_)
            adv_loss = self.criterion_mse(p_, paddle.ones_like(p_))
            synchronize()
        logging.debug(
            f"Discriminator and adversarial loss takes {t.elapse}s")
        report("eval/adversarial_loss", float(adv_loss))
        gen_loss += self.lambda_adv * adv_loss

        report("eval/generator_loss", float(gen_loss))

        # Discriminator
        # NOTE(review): the generator is run a second time on the same
        # (noise, mel); this looks redundant with ``wav_`` above -- confirm
        # before removing.
        with paddle.no_grad():
            wav_ = self.generator(noise, mel)
        p = self.discriminator(wav)
        p_ = self.discriminator(wav_.detach())
        # Real audio is scored against ones, generated audio against zeros.
        real_loss = self.criterion_mse(p, paddle.ones_like(p))
        fake_loss = self.criterion_mse(p_, paddle.zeros_like(p_))
        report("eval/real_loss", float(real_loss))
        report("eval/fake_loss", float(fake_loss))
        dis_loss = real_loss + fake_loss
        report("eval/discriminator_loss", float(dis_loss))
|
||||
|
|
|
@ -39,10 +39,13 @@ from parakeet.training.reporter import report
|
|||
from parakeet.training.checkpoint import KBest, KLatest
|
||||
from parakeet.models.parallel_wavegan import PWGGenerator, PWGDiscriminator
|
||||
from parakeet.modules.stft_loss import MultiResolutionSTFTLoss
|
||||
from parakeet.training.extensions.visualizer import VisualDL
|
||||
from parakeet.training.extensions.snapshot import Snapshot
|
||||
from parakeet.training.seeding import seed_everything
|
||||
|
||||
from batch_fn import Clip
|
||||
from config import get_cfg_default
|
||||
from pwg_updater import PWGUpdater
|
||||
from pwg_updater import PWGUpdater, PWGEvaluator
|
||||
|
||||
|
||||
def train_sp(args, config):
|
||||
|
@ -56,6 +59,9 @@ def train_sp(args, config):
|
|||
if world_size > 1:
|
||||
paddle.distributed.init_parallel_env()
|
||||
|
||||
# set the random seed, it is a must for multiprocess training
|
||||
seed_everything(42)
|
||||
|
||||
print(
|
||||
f"rank: {dist.get_rank()}, pid: {os.getpid()}, parent_pid: {os.getppid()}",
|
||||
)
|
||||
|
@ -128,8 +134,7 @@ def train_sp(args, config):
|
|||
parameters=generator.parameters(),
|
||||
**config["generator_optimizer_params"])
|
||||
lr_schedule_d = StepDecay(**config["discriminator_scheduler_params"])
|
||||
gradient_clip_d = nn.ClipGradByGlobalNorm(config[
|
||||
"discriminator_grad_norm"])
|
||||
gradient_clip_d = nn.ClipGradByGlobalNorm(config["discriminator_grad_norm"])
|
||||
optimizer_d = Adam(
|
||||
learning_rate=lr_schedule_d,
|
||||
grad_clip=gradient_clip_d,
|
||||
|
@ -138,10 +143,10 @@ def train_sp(args, config):
|
|||
print("optimizers done!")
|
||||
|
||||
output_dir = Path(args.output_dir)
|
||||
log_writer = None
|
||||
checkpoint_dir = output_dir / "checkpoints"
|
||||
if dist.get_rank() == 0:
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
log_writer = LogWriter(str(output_dir))
|
||||
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
updater = PWGUpdater(
|
||||
models={
|
||||
|
@ -160,18 +165,41 @@ def train_sp(args, config):
|
|||
"generator": lr_schedule_g,
|
||||
"discriminator": lr_schedule_d,
|
||||
},
|
||||
dataloaders={
|
||||
"train": train_dataloader,
|
||||
"dev": dev_dataloader,
|
||||
},
|
||||
dataloader=train_dataloader,
|
||||
discriminator_train_start_steps=config.discriminator_train_start_steps,
|
||||
lambda_adv=config.lambda_adv, )
|
||||
|
||||
evaluator = PWGEvaluator(
|
||||
models={
|
||||
"generator": generator,
|
||||
"discriminator": discriminator,
|
||||
},
|
||||
criterions={
|
||||
"stft": criterion_stft,
|
||||
"mse": criterion_mse,
|
||||
},
|
||||
dataloader=dev_dataloader,
|
||||
lambda_adv=config.lambda_adv, )
|
||||
|
||||
trainer = Trainer(
|
||||
updater,
|
||||
stop_trigger=(config.train_max_steps, "iteration"), # PROFILING
|
||||
stop_trigger=(config.train_max_steps, "iteration"),
|
||||
out=output_dir, )
|
||||
|
||||
trainer.extend(
|
||||
evaluator,
|
||||
trigger=(config.eval_interval_steps, 'iteration'),
|
||||
priority=3)
|
||||
if dist.get_rank() == 0:
|
||||
log_writer = LogWriter(str(output_dir))
|
||||
trainer.extend(
|
||||
VisualDL(log_writer), trigger=(1, 'iteration'), priority=1)
|
||||
trainer.extend(
|
||||
Snapshot(checkpoint_dir),
|
||||
trigger=(config.save_interval_steps, 'iteration'),
|
||||
priority=2)
|
||||
print("Trainer Done!")
|
||||
|
||||
# with paddle.fluid.profiler.profiler('All', 'total',
|
||||
# str(output_dir / "profiler.log"),
|
||||
# 'Default') as prof:
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Optional, Callable, Dict
|
||||
|
||||
from tqdm import tqdm
|
||||
import paddle
|
||||
from paddle import Tensor
|
||||
from paddle.nn import Layer
|
||||
from paddle.io import DataLoader
|
||||
|
||||
from parakeet.training.reporter import scope, report, DictSummary
|
||||
|
||||
|
||||
class StandardEvaluator(object):
    """Run a model over a dataloader and report batch-averaged metrics.

    Parameters
    ----------
    model : Layer
        The model to evaluate. It is also stored under the key ``"main"``
        in ``self.models`` so subclasses can hold several models.
    dataloader : DataLoader
        Source of evaluation batches.
    """

    def __init__(self, model: Layer, dataloader: DataLoader):
        self.model = model
        # Kept as a dict so that subclasses can register multiple models.
        self.models: Dict[str, Layer] = {"main": model}
        self.dataloader = dataloader

    def evaluate_core(self, batch):
        """Evaluate a single batch; metrics may be reported from here."""
        self.model(batch)

    def evaluate(self):
        """Evaluate every batch and return the mean of the observations."""
        # Put every registered model into eval mode first.
        for module in self.models.values():
            module.eval()

        accumulator = DictSummary()
        for batch in self.dataloader:
            batch_observation = {}
            # Metrics reported inside evaluate_core land in this dict;
            # gradients are disabled for the whole computation.
            with scope(batch_observation), paddle.no_grad():
                self.evaluate_core(batch)
            accumulator.add(batch_observation)
        return accumulator.compute_mean()

    def __call__(self, trainer=None):
        """Evaluate, report the averaged metrics and print them."""
        self.observation = {}
        with scope(self.observation):
            averaged = self.evaluate()
            for key, value in averaged.items():
                report(key, value)
        print(self.observation)
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Union
|
||||
from pathlib import Path
|
||||
|
||||
from parakeet.utils.mp_tools import rank_zero_only
|
||||
from parakeet.training.trainer import Trainer
|
||||
|
||||
|
||||
class Snapshot(object):
    """Trainer extension that checkpoints the trainer's updater.

    On every call the updater's ``save`` method is invoked with a path of
    the form ``<checkpoint_dir>/step_<iteration>.pdz``.

    An updater saves its state_dict by default, which contains the updater
    state (i.e. epoch and iteration) and all the model parameters and
    optimizer states. If the updater inside the trainer subclasses
    StandardUpdater, everything is good to go.

    Parameters
    ----------
    checkpoint_dir : Union[str, Path]
        The directory to save checkpoints into.
    """

    def __init__(self, checkpoint_dir: Union[str, Path]):
        self.checkpoint_dir = Path(checkpoint_dir)

    # NOTE(review): presumably restricts saving to the rank-0 process;
    # confirm against parakeet.utils.mp_tools.
    @rank_zero_only
    def __call__(self, trainer: Trainer):
        updater = trainer.updater
        filename = f"step_{updater.state.iteration}.pdz"
        updater.save(str(self.checkpoint_dir / filename))
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from visualdl import LogWriter
|
||||
from parakeet.training.trainer import Trainer
|
||||
from parakeet.utils.mp_tools import rank_zero_only
|
||||
|
||||
|
||||
class VisualDL(object):
    """Trainer extension that forwards scalar metrics to a visualdl writer.

    Every entry of the trainer's ``observation`` dictionary is written with
    ``add_scalar``, using the updater's current ``iteration`` as the global
    step. All observed values are therefore expected to be scalars.

    Parameters
    ----------
    writer : LogWriter
        The visualdl log writer that receives the records.
    """

    def __init__(self, writer: LogWriter):
        self.writer = writer

    # NOTE(review): presumably limits logging to the rank-0 process;
    # confirm against parakeet.utils.mp_tools.
    @rank_zero_only
    def __call__(self, trainer: Trainer):
        for tag, value in trainer.observation.items():
            self.writer.add_scalar(
                tag, value, step=trainer.updater.state.iteration)
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import paddle
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
|
||||
def seed_everything(seed: int):
    """Seed paddle, the ``random`` module and ``numpy.random`` with ``seed``.

    Parameters
    ----------
    seed : int
        The value passed to each of the three seeding functions.
    """
    # Same order as before: paddle, then random, then numpy.
    for set_seed in (paddle.seed, random.seed, np.random.seed):
        set_seed(seed)
    logging.debug(f"Set the seed of paddle, random, np.random to {seed}.")
|
|
@ -40,15 +40,13 @@ class Trainer(object):
|
|||
self.out = Path(out)
|
||||
self.observation = {}
|
||||
|
||||
def setup(self):
|
||||
pass
|
||||
|
||||
def extend(self, extension, name=None, trigger=None, priority=None):
|
||||
trigger = get_trigger(trigger)
|
||||
|
||||
ordinal = 0
|
||||
modified_name = name
|
||||
while name in self.extensions:
|
||||
while modified_name in self.extensions:
|
||||
print(self.extensions.keys())
|
||||
ordinal += 1
|
||||
modified_name = f"{name}_{ordinal}"
|
||||
|
||||
|
@ -61,8 +59,7 @@ class Trainer(object):
|
|||
self.extensions.keys(),
|
||||
key=lambda name: self.extensions[name].priority,
|
||||
reverse=True)
|
||||
extensions = [(name, self.extensions[name])
|
||||
for name in extension_order]
|
||||
extensions = [(name, self.extensions[name]) for name in extension_order]
|
||||
|
||||
update = self.updater.update
|
||||
stop_trigger = self.stop_trigger
|
||||
|
|
|
@ -12,12 +12,21 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import Dict
|
||||
from typing import Union
|
||||
|
||||
from timer import timer
|
||||
import paddle
|
||||
from paddle import Tensor
|
||||
from paddle.nn import Layer
|
||||
from paddle.optimizer import Optimizer
|
||||
from paddle.io import DataLoader
|
||||
from paddle.io import DistributedBatchSampler
|
||||
|
||||
from parakeet.training.reporter import report
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -56,12 +65,34 @@ class UpdaterBase(object):
|
|||
So the best practice is to define a model and define a updater for it.
|
||||
"""
|
||||
|
||||
def update(self):
|
||||
self.state.iteration += 1
|
||||
self.update_core()
|
||||
def __init__(self, init_state=None):
    """Set the updater state.

    Parameters
    ----------
    init_state
        State object to start from. When ``None`` a fresh
        ``UpdaterState()`` is created.
    """
    self.state = UpdaterState() if init_state is None else init_state
|
||||
|
||||
def update_core(self):
|
||||
pass
|
||||
def update(self, batch):
    """Train for one step; must be overridden by subclasses.

    Raises
    ------
    NotImplementedError
        Always, in this base implementation.
    """
    raise NotImplementedError(
        "Implement your own `update` method for training a step.")
|
||||
|
||||
def state_dict(self):
    """Return the updater's epoch and iteration counters as a dict."""
    progress = self.state
    return {"epoch": progress.epoch, "iteration": progress.iteration}
|
||||
|
||||
def set_state_dict(self, state_dict):
    """Restore the epoch and iteration counters from ``state_dict``.

    Parameters
    ----------
    state_dict
        Mapping that must contain the keys ``"epoch"`` and ``"iteration"``.
    """
    # Restored in the same order as they are written: epoch, then iteration.
    for field in ("epoch", "iteration"):
        setattr(self.state, field, state_dict[field])
|
||||
|
||||
def save(self, path):
    """Write this updater's state dict to ``path`` with ``paddle.save``."""
    paddle.save(self.state_dict(), path)
|
||||
|
||||
def load(self, path):
    """Read a state dict from ``path`` with ``paddle.load`` and apply it."""
    self.set_state_dict(paddle.load(path))
|
||||
|
||||
|
||||
class StandardUpdater(UpdaterBase):
|
||||
|
@ -71,54 +102,116 @@ class StandardUpdater(UpdaterBase):
|
|||
|
||||
def __init__(self,
|
||||
model: Layer,
|
||||
dataloader: DataLoader,
|
||||
optimizer: Optimizer,
|
||||
loss_func=None,
|
||||
auto_new_epoch: bool=True,
|
||||
dataloader: DataLoader,
|
||||
init_state: Optional[UpdaterState]=None):
|
||||
# it is designed to hold multiple models
|
||||
models = {"main": model}
|
||||
self.models: Dict[str, Layer] = models
|
||||
self.model = model
|
||||
self.dataloader = dataloader
|
||||
self.optimizer = optimizer
|
||||
self.loss_func = loss_func
|
||||
self.auto_new_epoch = auto_new_epoch
|
||||
self.iterator = iter(dataloader)
|
||||
|
||||
# it is designed to hold multiple optimizers
|
||||
optimizers = {"main": optimizer}
|
||||
self.optimizer = optimizer
|
||||
self.optimizers: Dict[str, Optimizer] = optimizers
|
||||
|
||||
# dataloaders
|
||||
self.dataloader = dataloader
|
||||
|
||||
# init state
|
||||
if init_state is None:
|
||||
self.state = UpdaterState()
|
||||
else:
|
||||
self.state = init_state
|
||||
|
||||
self.train_iterator = iter(dataloader)
|
||||
|
||||
def update(self):
|
||||
self.update_core()
|
||||
self.state.iteration += 1
|
||||
|
||||
# switch to training mode
|
||||
for layer in self.models.values():
|
||||
layer.train()
|
||||
|
||||
# training for a step is implemented here
|
||||
batch = self.read_batch()
|
||||
self.update_core(batch)
|
||||
|
||||
def update_core(self, batch):
    """Run one simple training step.

    Basic assumptions are:
    Single model;
    Single optimizer;
    A batch from the dataloader is just the input of the model;
    The model returns a single loss, or a dict containing several losses;
    Parameters are updated at every batch, no gradient accumulation.

    Parameters
    ----------
    batch
        The batch read from the dataloader; it is unpacked into the
        model's positional arguments.
    """
    loss = self.model(*batch)

    if isinstance(loss, Tensor):
        loss_dict = {"main": loss}
    else:
        # Dict[str, Tensor]
        loss_dict = loss
        if "main" not in loss_dict:
            # No explicit main loss: optimize the sum of all the losses.
            loss_dict["main"] = sum(loss.values())

    for name, loss_item in loss_dict.items():
        report(name, float(loss_item))

    # Use the same optimizer calls as the other updaters in this project
    # (`clear_grad` / `step`); the previous `clear_gradient` / `update`
    # names are not used on an optimizer anywhere else.
    self.optimizer.clear_grad()
    loss_dict["main"].backward()
    self.optimizer.step()
|
||||
|
||||
def new_epoch(self):
|
||||
self.iterator = iter(self.dataloader)
|
||||
"""Start a new epoch."""
|
||||
self.state.epoch += 1
|
||||
|
||||
def update_core(self):
|
||||
model = self.model
|
||||
optimizer = self.optimizer
|
||||
loss_func = self.loss_func
|
||||
# NOTE: all batch sampler for distributed training should
|
||||
# subclass DistributedBatchSampler and implement `set_epoch` method
|
||||
batch_sampler = self.dataloader.batch_sampler
|
||||
if isinstance(batch_sampler, DistributedBatchSampler):
|
||||
batch_sampler.set_epoch(self.state.epoch)
|
||||
self.train_iterator = iter(self.dataloader)
|
||||
|
||||
model.train()
|
||||
optimizer.clear_grad()
|
||||
|
||||
# fetch a batch
|
||||
try:
|
||||
batch = next(self.iterator)
|
||||
except StopIteration as e:
|
||||
if self.auto_new_epoch:
|
||||
def read_batch(self):
    """Read a batch from the data loader, auto renew when data is exhausted.

    Returns
    -------
    The next batch from ``self.train_iterator``. When the iterator is
    exhausted, ``new_epoch`` is called and the read is retried once.
    """
    with timer() as t:
        try:
            batch = next(self.train_iterator)
        except StopIteration:
            # The current pass over the data is finished; start a new epoch
            # and read from the iterator again.
            self.new_epoch()
            batch = next(self.train_iterator)
    logging.debug(
        f"Read a batch takes {t.elapse}s.")  # replace it with logging
    return batch
|
||||
|
||||
# forward
|
||||
if self.loss_func is not None:
|
||||
loss = loss_func(batch)
|
||||
else:
|
||||
loss = model(batch)
|
||||
def state_dict(self):
    """Return the updater state plus every model and optimizer state.

    Model states are stored under ``"<name>_params"`` and optimizer states
    under ``"<name>_optimizer"``, on top of the parent's entries.
    """
    state_dict = super().state_dict()
    state_dict.update({
        f"{name}_params": layer.state_dict()
        for name, layer in self.models.items()
    })
    state_dict.update({
        f"{name}_optimizer": optim.state_dict()
        for name, optim in self.optimizers.items()
    })
    return state_dict
|
||||
|
||||
# backward
|
||||
loss.backward()
|
||||
def set_state_dict(self, state_dict):
    """Restore model parameters, optimizer states and the updater state.

    Parameters
    ----------
    state_dict
        Mapping with a ``"<name>_params"`` entry per model, a
        ``"<name>_optimizer"`` entry per optimizer, and the entries the
        parent class expects.
    """
    for key, module in self.models.items():
        module.set_state_dict(state_dict[f"{key}_params"])
    for key, optimizer in self.optimizers.items():
        optimizer.set_state_dict(state_dict[f"{key}_optimizer"])
    # The epoch/iteration counters are restored last by the parent class.
    super().set_state_dict(state_dict)
|
||||
|
||||
# update parameters
|
||||
optimizer.step()
|
||||
def save(self, path):
    """Save Updater state dict."""
    paddle.save(self.state_dict(), path)
|
||||
|
||||
def load(self, path):
    """Load Updater state dict."""
    self.set_state_dict(paddle.load(path))
|
||||
|
|
|
@ -13,8 +13,11 @@
|
|||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
from paddle.framework import CUDAPlace
|
||||
|
||||
|
||||
def synchronize():
    """Trigger cuda synchronization for better timing.

    The current expected place is looked up and, only when it is a
    ``CUDAPlace``, the device is synchronized. On any other place the
    function does nothing.
    """
    place = paddle.fluid.framework._current_expected_place()
    # The synchronization call must stay behind the CUDAPlace guard: an
    # unconditional call ahead of it would run on non-CUDA places as well,
    # which is exactly what the guard is there to prevent.
    if isinstance(place, CUDAPlace):
        paddle.fluid.core._cuda_synchronize(place)
|
||||
|
|
|
@ -0,0 +1,319 @@
|
|||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import six
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import google.protobuf.text_format as text_format
|
||||
import paddle.fluid.proto.profiler.profiler_pb2 as profiler_pb2
|
||||
|
||||
# Command-line interface. The arguments are parsed at module level, so
# `args` is populated as soon as this file is executed or imported.
parser = argparse.ArgumentParser(description=__doc__)
# One profile file, or several given as comma-separated name=file pairs.
parser.add_argument(
    '--profile_path',
    type=str,
    default='',
    help='Input profile file name. If there are multiple file, the format '
    'should be trainer1=file1,trainer2=file2,ps=file3')
# Where the generated timeline is written.
parser.add_argument(
    '--timeline_path', type=str, default='', help='Output timeline file name.')
args = parser.parse_args()
|
||||
|
||||
|
||||
class _ChromeTraceFormatter(object):
|
||||
def __init__(self):
|
||||
self._events = []
|
||||
self._metadata = []
|
||||
|
||||
def _create_event(self, ph, category, name, pid, tid, timestamp):
|
||||
"""Creates a new Chrome Trace event.
|
||||
|
||||
For details of the file format, see:
|
||||
https://github.com/catapult-project/catapult/blob/master/tracing/README.md
|
||||
|
||||
Args:
|
||||
ph: The type of event - usually a single character.
|
||||
category: The event category as a string.
|
||||
name: The event name as a string.
|
||||
pid: Identifier of the process generating this event as an integer.
|
||||
tid: Identifier of the thread generating this event as an integer.
|
||||
timestamp: The timestamp of this event as a long integer.
|
||||
|
||||
Returns:
|
||||
A JSON compatible event object.
|
||||
"""
|
||||
event = {}
|
||||
event['ph'] = ph
|
||||
event['cat'] = category
|
||||
event['name'] = name.replace("ParallelExecutor::Run/", "")
|
||||
event['pid'] = pid
|
||||
event['tid'] = tid
|
||||
event['ts'] = timestamp
|
||||
return event
|
||||
|
||||
def emit_pid(self, name, pid):
|
||||
"""Adds a process metadata event to the trace.
|
||||
|
||||
Args:
|
||||
name: The process name as a string.
|
||||
pid: Identifier of the process as an integer.
|
||||
"""
|
||||
event = {}
|
||||
event['name'] = 'process_name'
|
||||
event['ph'] = 'M'
|
||||
event['pid'] = pid
|
||||
event['args'] = {'name': name}
|
||||
self._metadata.append(event)
|
||||
|
||||
def emit_region(self, timestamp, duration, pid, tid, category, name, args):
|
||||
"""Adds a region event to the trace.
|
||||
|
||||
Args:
|
||||
timestamp: The start timestamp of this region as a long integer.
|
||||
duration: The duration of this region as a long integer.
|
||||
pid: Identifier of the process generating this event as an integer.
|
||||
tid: Identifier of the thread generating this event as an integer.
|
||||
category: The event category as a string.
|
||||
name: The event name as a string.
|
||||
args: A JSON-compatible dictionary of event arguments.
|
||||
"""
|
||||
event = self._create_event('X', category, name, pid, tid, timestamp)
|
||||
event['dur'] = duration
|
||||
event['args'] = args
|
||||
self._events.append(event)
|
||||
|
||||
def emit_counter(self, category, name, pid, timestamp, counter, value):
|
||||
"""Emits a record for a single counter.
|
||||
|
||||
Args:
|
||||
category: The event category as string
|
||||
name: The event name as string
|
||||
pid: Identifier of the process generating this event as integer
|
||||
timestamp: The timestamps of this event as long integer
|
||||
counter: Name of the counter as string
|
||||
value: Value of the counter as integer
|
||||
tid: Thread id of the allocation as integer
|
||||
"""
|
||||
event = self._create_event('C', category, name, pid, 0, timestamp)
|
||||
event['args'] = {counter: value}
|
||||
self._events.append(event)
|
||||
|
||||
def format_to_string(self, pretty=False):
    """Serialize the collected events as a Chrome Trace JSON string.

    Metadata events are written first, followed by the regular events, all
    under the top-level ``traceEvents`` key.

    Args:
        pretty: (Optional.) If True, produce indented, human-readable JSON;
            otherwise produce the most compact form.

    Returns:
        A JSON-formatted string in Chrome Trace format.
    """
    payload = {'traceEvents': self._metadata + self._events}
    if pretty:
        dump_options = {'indent': 4, 'separators': (',', ': ')}
    else:
        dump_options = {'separators': (',', ':')}
    return json.dumps(payload, **dump_options)
|
||||
|
||||
|
||||
class Timeline(object):
    """Turns parsed profiler protobufs into a Chrome trace JSON string.

    ``profile_dict`` maps a label (for example ``'trainer'``) to a profile
    message exposing ``events`` and, optionally, ``mem_events``. Every
    distinct (label, device id, device kind) combination becomes its own
    process row (pid) in the generated trace.
    """

    def __init__(self, profile_dict):
        """Store the profiles and create an empty trace.

        Args:
            profile_dict: Mapping from a label to a parsed profile message.
        """
        self._profile_dict = profile_dict
        self._pid = 0
        # (label, device_id, kind) -> pid, for compute events.
        self._devices = dict()
        # (label, device_id, place) -> pid, for memory events.
        self._mem_devices = dict()
        self._chrome_trace = _ChromeTraceFormatter()

    def _allocate_pid(self):
        """Return the next unused pid and advance the counter."""
        cur_pid = self._pid
        self._pid += 1
        return cur_pid

    @staticmethod
    def _mem_place_names():
        """Map each known ``MemEvent`` place enum value to its device key."""
        return {
            profiler_pb2.MemEvent.CPUPlace: "CPU",
            profiler_pb2.MemEvent.CUDAPlace: "GPU",
            profiler_pb2.MemEvent.CUDAPinnedPlace: "CUDAPinnedPlace",
            profiler_pb2.MemEvent.NPUPlace: "NPU",
        }

    def _register_mem_device(self, k, device_id, place):
        """Give a memory device a pid and a named trace row, once."""
        key = (k, device_id, place)
        if key in self._mem_devices:
            return
        pid = self._allocate_pid()
        self._mem_devices[key] = pid
        # Row titles use the lower-cased place name, e.g. "gpu" or
        # "cudapinnedplace".
        self._chrome_trace.emit_pid(
            "memory usage on %s:%s:%d" % (k, place.lower(), device_id), pid)

    def _allocate_pids(self):
        """Assign a pid to every compute and memory device in the profiles."""
        for k, profile_pb in self._profile_dict.items():
            for event in profile_pb.events:
                if event.type == profiler_pb2.Event.CPU:
                    if (k, event.device_id, "CPU") not in self._devices:
                        pid = self._allocate_pid()
                        self._devices[(k, event.device_id, "CPU")] = pid
                        # -1 device id represents CUDA API(RunTime) call.
                        # (e.g. cudaLaunch, cudaMemcpy)
                        if event.device_id == -1:
                            self._chrome_trace.emit_pid("%s:cuda_api" % k, pid)
                        else:
                            self._chrome_trace.emit_pid(
                                "%s:cpu:block:%d" % (k, event.device_id), pid)
                elif event.type == profiler_pb2.Event.GPUKernel:
                    if (k, event.device_id, "GPUKernel") not in self._devices:
                        pid = self._allocate_pid()
                        self._devices[(k, event.device_id, "GPUKernel")] = pid
                        self._chrome_trace.emit_pid(
                            "%s:gpu:%d" % (k, event.device_id), pid)
            if not hasattr(profile_pb, "mem_events"):
                continue
            place_names = self._mem_place_names()
            for mevent in profile_pb.mem_events:
                place = place_names.get(mevent.place)
                if place is not None:
                    self._register_mem_device(k, mevent.device_id, place)
                # NOTE(review): the device-0 fallback rows are registered
                # from inside the per-event loop, so a profile without any
                # memory events gets none of them - confirm this is intended.
                for fallback_place in ("CPU", "GPU", "CUDAPinnedPlace", "NPU"):
                    self._register_mem_device(k, 0, fallback_place)

    def _allocate_events(self):
        """Emit one region per CPU / GPU-kernel event of every profile."""
        for k, profile_pb in self._profile_dict.items():
            for event in profile_pb.events:
                if event.type == profiler_pb2.Event.CPU:
                    device_kind = "CPU"
                elif event.type == profiler_pb2.Event.GPUKernel:
                    device_kind = "GPUKernel"
                else:
                    # No pid is ever allocated for other event types, so
                    # skip them rather than reuse the previous event's kind.
                    continue
                pid = self._devices[(k, event.device_id, device_kind)]
                args = {'name': event.name}
                if event.memcopy.bytes > 0:
                    args['mem_bytes'] = event.memcopy.bytes
                if hasattr(event, "detail_info") and event.detail_info:
                    args['detail_info'] = event.detail_info
                # TODO(panyx0718): Chrome tracing only handles ms. However, some
                # ops takes micro-seconds. Hence, we keep the ns here.
                self._chrome_trace.emit_region(
                    event.start_ns, (event.end_ns - event.start_ns) / 1.0, pid,
                    event.sub_device_id, 'Op', event.name, args)

    def _allocate_memory_event(self):
        """Emit "Memory" counter samples built from the memory events.

        Each memory event contributes ``+bytes`` at ``start_ns`` and
        ``-bytes`` at ``end_ns``. The deltas are sorted by time and folded
        into a running total, and one counter sample is emitted per distinct
        timestamp.
        """
        if not hasattr(profiler_pb2, "MemEvent"):
            return
        place_to_str = self._mem_place_names()
        for k, profile_pb in self._profile_dict.items():
            # Same guard as _allocate_pids: some profiles carry no memory
            # events at all.
            if not hasattr(profile_pb, "mem_events"):
                continue
            mem_list = []
            for mevent in profile_pb.mem_events:
                place = place_to_str.get(mevent.place, "UnDefine")
                pid = self._mem_devices.get((k, mevent.device_id, place))
                if pid is None:
                    # Only recognised places receive a pid in
                    # _allocate_pids; skip the rest instead of raising
                    # KeyError.
                    continue
                shared = {
                    'place': place,
                    'pid': pid,
                    'thread_id': mevent.thread_id,
                    'device_id': mevent.device_id,
                }
                mem_list.append(
                    dict(shared, time=mevent.start_ns, size=mevent.bytes))
                mem_list.append(
                    dict(shared, time=mevent.end_ns, size=-mevent.bytes))
            mem_list.sort(key=lambda entry: entry['time'])
            # NOTE(review): the running total is shared by all places and
            # devices of this profile, not tracked per pid - confirm.
            total_size = 0
            for idx, entry in enumerate(mem_list):
                total_size += entry['size']
                closes_timestamp = (idx + 1 == len(mem_list) or
                                    mem_list[idx + 1]['time'] != entry['time'])
                if closes_timestamp:
                    # One sample per timestamp, attributed to the last entry
                    # recorded at that time.
                    self._chrome_trace.emit_counter(
                        "Memory", "Memory", entry['pid'], entry['time'], 0,
                        total_size)

    def generate_chrome_trace(self):
        """Build the full trace and return it as a compact JSON string."""
        self._allocate_pids()
        self._allocate_events()
        self._allocate_memory_event()
        return self._chrome_trace.format_to_string()
|
||||
|
||||
|
||||
def _load_profile(path):
    """Read the serialized profile at ``path`` and return the parsed message."""
    with open(path, 'rb') as f:
        profile_pb = profiler_pb2.Profile()
        profile_pb.ParseFromString(f.read())
    return profile_pb


# Fall back to the default locations when the options are empty or unset.
profile_path = args.profile_path or '/tmp/profile'
timeline_path = args.timeline_path or '/tmp/timeline'

# A single path is labelled 'trainer'; several profiles are passed as a
# comma-separated list of ``label=path`` items.
profile_paths = profile_path.split(',')
profile_dict = dict()
if len(profile_paths) == 1:
    profile_dict['trainer'] = _load_profile(profile_path)
else:
    for labelled_path in profile_paths:
        # Split on the first '=' only, so a file path may itself contain '='.
        k, v = labelled_path.split('=', 1)
        profile_dict[k] = _load_profile(v)

tl = Timeline(profile_dict)
with open(timeline_path, 'w') as f:
    f.write(tl.generate_chrome_trace())
|
Loading…
Reference in New Issue