PaddleOCR/tools/program.py

375 lines
13 KiB
Python
Raw Normal View History

2020-05-10 16:26:57 +08:00
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
2020-10-13 17:13:33 +08:00
import os
2020-05-10 16:26:57 +08:00
import sys
import yaml
import time
2020-10-13 17:13:33 +08:00
import shutil
import paddle
import paddle.distributed as dist
from tqdm import tqdm
from argparse import ArgumentParser, RawDescriptionHelpFormatter
2020-05-10 16:26:57 +08:00
from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
2020-11-04 20:43:27 +08:00
from ppocr.utils.utility import print_dict
from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader
import numpy as np
2020-05-10 16:26:57 +08:00
2020-11-05 15:13:36 +08:00
2020-05-10 16:26:57 +08:00
class ArgsParser(ArgumentParser):
    """Command-line parser shared by the PaddleOCR tools.

    Options:
        -c/--config: path of the yaml configuration file (required).
        -o/--opt: zero or more ``key=value`` overrides; each value is parsed
            with yaml, and keys may be dotted (e.g. ``Global.epoch_num=10``).
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")

    def parse_args(self, argv=None):
        """Parse ``argv`` (default: ``sys.argv[1:]``).

        ``args.opt`` is converted from the raw list of ``key=value`` strings
        into a dict (empty when no ``-o`` was given).
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    @staticmethod
    def _split_opt(s):
        """Split one ``key=value`` override at the FIRST '='.

        Splitting only once keeps values that themselves contain '='
        (file paths, expressions) intact.

        Raises:
            ValueError: if ``s`` contains no '='.
        """
        key, sep, value = s.strip().partition('=')
        if not sep:
            raise ValueError(
                "option '{}' must be in the form key=value".format(s))
        return key, value

    def _parse_opt(self, opts):
        """Turn a list of ``key=value`` strings into a dict of parsed values."""
        config = {}
        if not opts:
            return config
        for s in opts:
            k, v = self._split_opt(s)
            # NOTE: yaml.Loader can construct arbitrary Python objects; the
            # command line is assumed to come from a trusted user.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config
class AttrDict(dict):
    """A dict whose top-level keys can also be read as attributes.

    Only the first level is exposed this way: nested dict values are
    returned as-is, so ``d.a.b`` works only if ``d['a']`` itself supports
    attribute access.
    """

    def __init__(self, **kwargs):
        super(AttrDict, self).__init__(**kwargs)

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails; fall back to the
        # dict contents and translate a miss into AttributeError so that
        # getattr(obj, name, default) and hasattr() keep working.
        try:
            return self[key]
        except KeyError:
            pass
        raise AttributeError("object has no attribute '{}'".format(key))
# Process-wide configuration; filled in place by merge_config() and
# returned by load_config().
global_config = AttrDict()

# Baseline values merged into global_config before the user's config file.
default_config = dict(Global=dict(debug=False))
2020-05-10 16:26:57 +08:00
def load_config(file_path):
    """
    Load config from yml/yaml file.

    The module defaults (``default_config``) are merged into ``global_config``
    first, then the contents of ``file_path`` are merged on top.

    Args:
        file_path (str): Path of the config file to be loaded; must end in
            ``.yml`` or ``.yaml``.
    Returns: global config (the module-level ``global_config`` object).
    Raises:
        AssertionError: if the file extension is not .yml/.yaml.
    """
    import copy

    # Merge a deep copy: merge_config stores dict values by reference and
    # later merges update() them in place. Without the copy, the file's
    # 'Global' section would be written into default_config itself.
    merge_config(copy.deepcopy(default_config))
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # NOTE: yaml.Loader can construct arbitrary Python objects; config files
    # are assumed to be trusted.
    with open(file_path, 'rb') as f:
        merge_config(yaml.load(f, Loader=yaml.Loader))
    return global_config
def merge_config(config):
    """
    Merge ``config`` into the module-level ``global_config`` in place.

    * A key without dots overwrites the top-level entry. The exception is a
      dict value for a key that is already present: it is merged into the
      existing entry with ``update()`` (one level deep).
    * A dotted key such as ``"Global.epoch_num"`` walks ``global_config``
      one component at a time and assigns ``value`` at the last component.
      The first component must already be present (AssertionError
      otherwise), and intermediate components must already exist for the
      lookups to succeed.

    Args:
        config (dict): Config to be merged.
    Returns: None; ``global_config`` is modified in place.
    """
    for key, value in config.items():
        if "." in key:
            sub_keys = key.split('.')
            head, path = sub_keys[0], sub_keys[1:]
            assert (
                head in global_config
            ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
                global_config.keys(), head)
            node = global_config[head]
            # Descend through every component except the last, then assign.
            for sub_key in path[:-1]:
                node = node[sub_key]
            node[path[-1]] = value
        elif isinstance(value, dict) and key in global_config:
            global_config[key].update(value)
        else:
            global_config[key] = value
def check_gpu(use_gpu):
    """
    Log error and exit when set use_gpu=true in paddlepaddle
    cpu version.

    The check is best-effort: if the installed paddle cannot answer whether
    it was compiled with CUDA (the query raises), no check is performed.

    Args:
        use_gpu: truthy when the config requests GPU execution.
    """
    err = "Config use_gpu cannot be set as true while you are " \
          "using paddlepaddle cpu version ! \nPlease try: \n" \
          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
          "\t2. Set use_gpu as false in config file to run " \
          "model on CPU"
    if not use_gpu:
        return
    try:
        compiled_with_cuda = paddle.fluid.is_compiled_with_cuda()
    except Exception:
        # Deliberate best-effort: skip the check rather than abort start-up
        # when this paddle build cannot report its CUDA support.
        return
    if not compiled_with_cuda:
        print(err)
        sys.exit(1)
2020-10-13 17:13:33 +08:00
def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          vdl_writer=None):
    """Run the training loop with periodic evaluation and checkpointing.

    Args:
        config (dict): merged config. Reads ``Global.cal_metric_during_train``
            (optional, default False), ``Global.log_smooth_window``,
            ``Global.epoch_num``, ``Global.print_batch_step``,
            ``Global.eval_batch_step``, ``Global.save_epoch_step`` and
            ``Global.save_model_dir``.
        train_dataloader: loader used for the first epoch of this run. Later
            epochs rebuild it with ``build_dataloader(config, 'Train', ...)``.
        valid_dataloader: loader passed to ``eval`` for every evaluation.
        device: forwarded to ``build_dataloader`` when rebuilding.
        model: network, called as ``model(images)``.
        loss_class: callable ``(preds, batch)`` returning a dict. Its
            ``'loss'`` entry is back-propagated and every entry is logged.
        optimizer: provides ``get_lr()``, ``step()`` and ``clear_grad()``.
        lr_scheduler: stepped once per iteration unless it is a plain float.
        post_process_class: turns predictions into results for ``eval_class``.
        eval_class: metric accumulator. Its ``main_indicator`` attribute names
            the metric used to select the best model.
        pre_best_model_dict (dict): state restored from a checkpoint. An
            optional ``'start_epoch'`` entry sets the first epoch.
        logger: logger for progress messages.
        vdl_writer: optional VisualDL writer; used and closed on rank 0 only.

    ``Global.eval_batch_step`` is either an int N (evaluate every N steps) or
    a list ``[start, interval]`` (evaluate every ``interval`` steps after
    step ``start``). Evaluation and checkpoint saving happen on rank 0 only.
    """
    cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                   False)
    log_smooth_window = config['Global']['log_smooth_window']
    epoch_num = config['Global']['epoch_num']
    print_batch_step = config['Global']['print_batch_step']
    eval_batch_step = config['Global']['eval_batch_step']

    global_step = 0
    start_eval_step = 0
    if isinstance(eval_batch_step,
                  (list, tuple)) and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
        logger.info(
            "During the training process, after the {}th iteration, an evaluation is run every {} iterations".
            format(start_eval_step, eval_batch_step))
    save_epoch_step = config['Global']['save_epoch_step']
    save_model_dir = config['Global']['save_model_dir']
    os.makedirs(save_model_dir, exist_ok=True)

    main_indicator = eval_class.main_indicator
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model.train()

    start_epoch = best_model_dict.get('start_epoch', 0)

    for epoch in range(start_epoch, epoch_num):
        # The caller's loader serves the first epoch of this run, and every
        # later epoch gets a rebuilt one. Comparing with start_epoch (not 0)
        # avoids discarding that loader and building a second one when
        # resuming from a checkpoint.
        if epoch > start_epoch:
            train_dataloader = build_dataloader(config, 'Train', device,
                                                logger)
        train_batch_cost = 0.0
        train_reader_cost = 0.0
        batch_sum = 0
        batch_start = time.time()
        for idx, batch in enumerate(train_dataloader):
            # Time spent waiting for the loader to produce this batch.
            train_reader_cost += time.time() - batch_start
            # NOTE(review): presumably guards loaders whose iterator can yield
            # more than len() batches -- confirm.
            if idx >= len(train_dataloader):
                break
            lr = optimizer.get_lr()
            images = batch[0]
            preds = model(images)
            loss = loss_class(preds, batch)
            avg_loss = loss['loss']
            avg_loss.backward()
            optimizer.step()
            optimizer.clear_grad()

            train_batch_cost += time.time() - batch_start
            batch_sum += len(images)

            # A plain float lr_scheduler means a constant learning rate.
            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # logger and visualdl
            stats = {k: v.numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if cal_metric_during_train:  # only rec and cls need this
                batch = [item.numpy() for item in batch]
                post_result = post_process_class(preds, batch[1])
                eval_class(post_result, batch)
                train_metric = eval_class.get_metric()
                train_stats.update(train_metric)

            if vdl_writer is not None and dist.get_rank() == 0:
                for k, v in train_stats.get().items():
                    vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
                vdl_writer.add_scalar('TRAIN/lr', lr, global_step)

            if dist.get_rank() == 0 and global_step > 0 and \
                    global_step % print_batch_step == 0:
                logs = train_stats.log()
                strs = 'epoch: [{}/{}], iter: {}, {}, reader_cost: {:.5f} s, batch_cost: {:.5f} s, samples: {}, ips: {:.5f}'.format(
                    epoch, epoch_num, global_step, logs,
                    train_reader_cost / print_batch_step,
                    train_batch_cost / print_batch_step, batch_sum,
                    batch_sum / train_batch_cost)
                logger.info(strs)
                # Timing and sample counters cover one logging window.
                train_batch_cost = 0.0
                train_reader_cost = 0.0
                batch_sum = 0

            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 \
                    and dist.get_rank() == 0:
                cur_metric = eval(model, valid_dataloader, post_process_class,
                                  eval_class)
                cur_metric_str = 'cur metirc, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                # logger metric
                if vdl_writer is not None:
                    for k, v in cur_metric.items():
                        if isinstance(v, (float, int)):
                            vdl_writer.add_scalar('EVAL/{}'.format(k),
                                                  cur_metric[k], global_step)
                # '>=' so that ties also refresh the best checkpoint.
                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch)
                best_str = 'best metirc, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
                logger.info(best_str)
                # logger best metric
                if vdl_writer is not None:
                    vdl_writer.add_scalar('EVAL/best_{}'.format(main_indicator),
                                          best_model_dict[main_indicator],
                                          global_step)
            global_step += 1
            batch_start = time.time()
        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch)
        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                is_best=False,
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch)
    best_str = 'best metirc, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
    if dist.get_rank() == 0 and vdl_writer is not None:
        vdl_writer.close()
    return
2020-11-09 13:28:46 +08:00
def eval(model, valid_dataloader, post_process_class, eval_class):
    """Run one evaluation pass and return the metric dict.

    The model is switched to eval mode and run without gradients. Each
    batch's predictions are post-processed and fed to ``eval_class``. The
    returned dict is ``eval_class.get_metric()`` plus an ``'fps'`` entry:
    the summed ``len(images)`` divided by the time spent in the forward pass
    and post-processing (0.0 if no time was measured, e.g. an empty loader).
    The model is always switched back to train mode and the progress bar
    closed, even if evaluation raises.

    Note: the name shadows the ``eval`` builtin; it is kept for existing
    callers.
    """
    total_frame = 0.0
    total_time = 0.0
    pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
    model.eval()
    try:
        with paddle.no_grad():
            for idx, batch in enumerate(valid_dataloader):
                if idx >= len(valid_dataloader):
                    break
                images = batch[0]
                start = time.time()
                preds = model(images)
                batch = [item.numpy() for item in batch]
                # Obtain usable results from post-processing methods
                post_result = post_process_class(preds, batch[1])
                total_time += time.time() - start
                # Evaluate the results of the current batch
                eval_class(post_result, batch)
                pbar.update(1)
                total_frame += len(images)
            # Get the final metric, e.g. acc or hmean
            metric = eval_class.get_metric()
    finally:
        pbar.close()
        model.train()
    # Guard against an empty loader (or a zero timer delta), which would
    # otherwise raise ZeroDivisionError.
    metric['fps'] = total_frame / total_time if total_time > 0 else 0.0
    return metric
2020-08-15 21:54:59 +08:00
2020-08-15 12:39:07 +08:00
def preprocess(is_train=False):
    """Parse CLI flags, build the global config and set up runtime services.

    Steps:
        * parse ``-c/-o`` and merge the yaml file and overrides;
        * validate ``Global.use_gpu`` and ``Architecture.algorithm``;
        * select the device and record ``Global.distributed``;
        * when ``is_train``, write ``config.yml`` and log to ``train.log``
          under ``Global.save_model_dir``;
        * when ``Global.use_visualdl`` is set, create a VisualDL writer on
          rank 0 (the only rank on which ``train`` uses it).

    Args:
        is_train (bool): whether the caller is the training entry point.
    Returns:
        tuple: ``(config, device, logger, vdl_writer)``. ``vdl_writer`` is
        None when VisualDL is disabled or on non-zero ranks.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    merge_config(FLAGS.opt)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    check_gpu(use_gpu)

    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
    ]

    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1
    if is_train:
        # save_config
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(name='root', log_file=log_file)

    # train() only writes to the VisualDL writer on rank 0, so other ranks
    # get None. save_model_dir is read here as well, because it is only
    # bound above when is_train is true.
    if config['Global'].get('use_visualdl', False) and dist.get_rank() == 0:
        from visualdl import LogWriter
        save_model_dir = config['Global']['save_model_dir']
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        os.makedirs(vdl_writer_path, exist_ok=True)
        vdl_writer = LogWriter(logdir=vdl_writer_path)
    else:
        vdl_writer = None
    print_dict(config, logger)
    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, vdl_writer