2020-05-10 16:26:57 +08:00
|
|
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
from argparse import ArgumentParser, RawDescriptionHelpFormatter
|
|
|
|
import sys
|
|
|
|
import yaml
|
|
|
|
import os
|
|
|
|
from ppocr.utils.utility import create_module
|
|
|
|
from ppocr.utils.utility import initial_logger
|
2020-08-15 21:54:59 +08:00
|
|
|
|
2020-05-10 16:26:57 +08:00
|
|
|
logger = initial_logger()
|
|
|
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
import time
|
|
|
|
from ppocr.utils.stats import TrainingStats
|
|
|
|
from eval_utils.eval_det_utils import eval_det_run
|
|
|
|
from eval_utils.eval_rec_utils import eval_rec_run
|
2020-09-01 13:44:51 +08:00
|
|
|
from eval_utils.eval_cls_utils import eval_cls_run
|
2020-05-10 16:26:57 +08:00
|
|
|
from ppocr.utils.save_load import save_model
|
|
|
|
import numpy as np
|
2020-08-14 16:31:13 +08:00
|
|
|
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps
|
2020-05-10 16:26:57 +08:00
|
|
|
|
|
|
|
|
|
|
|
class ArgsParser(ArgumentParser):
|
|
|
|
def __init__(self):
|
|
|
|
super(ArgsParser, self).__init__(
|
|
|
|
formatter_class=RawDescriptionHelpFormatter)
|
|
|
|
self.add_argument("-c", "--config", help="configuration file to use")
|
|
|
|
self.add_argument(
|
|
|
|
"-o", "--opt", nargs='+', help="set configuration options")
|
|
|
|
|
|
|
|
def parse_args(self, argv=None):
|
|
|
|
args = super(ArgsParser, self).parse_args(argv)
|
|
|
|
assert args.config is not None, \
|
|
|
|
"Please specify --config=configure_file_path."
|
|
|
|
args.opt = self._parse_opt(args.opt)
|
|
|
|
return args
|
|
|
|
|
|
|
|
def _parse_opt(self, opts):
|
|
|
|
config = {}
|
|
|
|
if not opts:
|
|
|
|
return config
|
|
|
|
for s in opts:
|
|
|
|
s = s.strip()
|
|
|
|
k, v = s.split('=')
|
|
|
|
config[k] = yaml.load(v, Loader=yaml.Loader)
|
|
|
|
return config
|
|
|
|
|
|
|
|
|
|
|
|
class AttrDict(dict):
|
|
|
|
"""Single level attribute dict, NOT recursive"""
|
|
|
|
|
|
|
|
def __init__(self, **kwargs):
|
|
|
|
super(AttrDict, self).__init__()
|
|
|
|
super(AttrDict, self).update(kwargs)
|
|
|
|
|
|
|
|
def __getattr__(self, key):
|
|
|
|
if key in self:
|
|
|
|
return self[key]
|
|
|
|
raise AttributeError("object has no attribute '{}'".format(key))
|
|
|
|
|
|
|
|
|
|
|
|
global_config = AttrDict()
|
|
|
|
|
2020-07-11 12:14:05 +08:00
|
|
|
default_config = {'Global': {'debug': False, }}
|
|
|
|
|
2020-05-10 16:26:57 +08:00
|
|
|
|
|
|
|
def load_config(file_path):
|
|
|
|
"""
|
|
|
|
Load config from yml/yaml file.
|
|
|
|
Args:
|
|
|
|
file_path (str): Path of the config file to be loaded.
|
|
|
|
Returns: global config
|
|
|
|
"""
|
2020-07-11 12:14:05 +08:00
|
|
|
merge_config(default_config)
|
2020-05-10 16:26:57 +08:00
|
|
|
_, ext = os.path.splitext(file_path)
|
|
|
|
assert ext in ['.yml', '.yaml'], "only support yaml files for now"
|
|
|
|
merge_config(yaml.load(open(file_path), Loader=yaml.Loader))
|
|
|
|
assert "reader_yml" in global_config['Global'],\
|
|
|
|
"absence reader_yml in global"
|
|
|
|
reader_file_path = global_config['Global']['reader_yml']
|
|
|
|
_, ext = os.path.splitext(reader_file_path)
|
|
|
|
assert ext in ['.yml', '.yaml'], "only support yaml files for reader"
|
|
|
|
merge_config(yaml.load(open(reader_file_path), Loader=yaml.Loader))
|
|
|
|
return global_config
|
|
|
|
|
|
|
|
|
|
|
|
def merge_config(config):
|
|
|
|
"""
|
|
|
|
Merge config into global config.
|
|
|
|
Args:
|
|
|
|
config (dict): Config to be merged.
|
|
|
|
Returns: global config
|
|
|
|
"""
|
|
|
|
for key, value in config.items():
|
|
|
|
if "." not in key:
|
|
|
|
if isinstance(value, dict) and key in global_config:
|
|
|
|
global_config[key].update(value)
|
|
|
|
else:
|
|
|
|
global_config[key] = value
|
|
|
|
else:
|
|
|
|
sub_keys = key.split('.')
|
2020-06-17 16:11:29 +08:00
|
|
|
assert (
|
|
|
|
sub_keys[0] in global_config
|
|
|
|
), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
|
|
|
|
global_config.keys(), sub_keys[0])
|
2020-05-10 16:26:57 +08:00
|
|
|
cur = global_config[sub_keys[0]]
|
|
|
|
for idx, sub_key in enumerate(sub_keys[1:]):
|
|
|
|
assert (sub_key in cur)
|
|
|
|
if idx == len(sub_keys) - 2:
|
|
|
|
cur[sub_key] = value
|
|
|
|
else:
|
|
|
|
cur = cur[sub_key]
|
|
|
|
|
|
|
|
|
|
|
|
def check_gpu(use_gpu):
|
|
|
|
"""
|
|
|
|
Log error and exit when set use_gpu=true in paddlepaddle
|
|
|
|
cpu version.
|
|
|
|
"""
|
|
|
|
err = "Config use_gpu cannot be set as true while you are " \
|
|
|
|
"using paddlepaddle cpu version ! \nPlease try: \n" \
|
|
|
|
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
|
|
|
|
"\t2. Set use_gpu as false in config file to run " \
|
|
|
|
"model on CPU"
|
|
|
|
|
|
|
|
try:
|
|
|
|
if use_gpu and not fluid.is_compiled_with_cuda():
|
|
|
|
logger.error(err)
|
|
|
|
sys.exit(1)
|
|
|
|
except Exception as e:
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
def build(config, main_prog, startup_prog, mode):
|
|
|
|
"""
|
|
|
|
Build a program using a model and an optimizer
|
|
|
|
1. create feeds
|
|
|
|
2. create a dataloader
|
|
|
|
3. create a model
|
|
|
|
4. create fetchs
|
|
|
|
5. create an optimizer
|
|
|
|
Args:
|
|
|
|
config(dict): config
|
|
|
|
main_prog(): main program
|
|
|
|
startup_prog(): startup program
|
|
|
|
is_train(bool): train or valid
|
|
|
|
Returns:
|
|
|
|
dataloader(): a bridge between the model and the data
|
|
|
|
fetchs(dict): dict of model outputs(included loss and measures)
|
|
|
|
"""
|
|
|
|
with fluid.program_guard(main_prog, startup_prog):
|
|
|
|
with fluid.unique_name.guard():
|
|
|
|
func_infor = config['Architecture']['function']
|
|
|
|
model = create_module(func_infor)(params=config)
|
|
|
|
dataloader, outputs = model(mode=mode)
|
|
|
|
fetch_name_list = list(outputs.keys())
|
|
|
|
fetch_varname_list = [outputs[v].name for v in fetch_name_list]
|
|
|
|
opt_loss_name = None
|
2020-08-14 16:31:13 +08:00
|
|
|
model_average = None
|
|
|
|
img_loss_name = None
|
|
|
|
word_loss_name = None
|
2020-05-10 16:26:57 +08:00
|
|
|
if mode == "train":
|
|
|
|
opt_loss = outputs['total_loss']
|
2020-08-14 16:31:13 +08:00
|
|
|
# srn loss
|
|
|
|
#img_loss = outputs['img_loss']
|
|
|
|
#word_loss = outputs['word_loss']
|
|
|
|
#img_loss_name = img_loss.name
|
|
|
|
#word_loss_name = word_loss.name
|
2020-05-10 16:26:57 +08:00
|
|
|
opt_params = config['Optimizer']
|
|
|
|
optimizer = create_module(opt_params['function'])(opt_params)
|
|
|
|
optimizer.minimize(opt_loss)
|
|
|
|
opt_loss_name = opt_loss.name
|
|
|
|
global_lr = optimizer._global_learning_rate()
|
|
|
|
fetch_name_list.insert(0, "lr")
|
|
|
|
fetch_varname_list.insert(0, global_lr.name)
|
2020-08-16 12:53:26 +08:00
|
|
|
if "loss_type" in config["Global"]:
|
|
|
|
if config['Global']["loss_type"] == 'srn':
|
|
|
|
model_average = fluid.optimizer.ModelAverage(
|
|
|
|
config['Global']['average_window'],
|
|
|
|
min_average_window=config['Global'][
|
|
|
|
'min_average_window'],
|
|
|
|
max_average_window=config['Global'][
|
|
|
|
'max_average_window'])
|
2020-08-14 16:31:13 +08:00
|
|
|
|
2020-08-15 12:39:07 +08:00
|
|
|
return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name,
|
|
|
|
model_average)
|
2020-05-10 16:26:57 +08:00
|
|
|
|
|
|
|
|
|
|
|
def build_export(config, main_prog, startup_prog):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
with fluid.program_guard(main_prog, startup_prog):
|
|
|
|
with fluid.unique_name.guard():
|
|
|
|
func_infor = config['Architecture']['function']
|
|
|
|
model = create_module(func_infor)(params=config)
|
2020-09-03 18:59:44 +08:00
|
|
|
algorithm = config['Global']['algorithm']
|
|
|
|
if algorithm == "SRN":
|
2020-09-03 15:51:50 +08:00
|
|
|
image, others, outputs = model(mode='export')
|
|
|
|
else:
|
|
|
|
image, outputs = model(mode='export')
|
2020-05-26 21:02:27 +08:00
|
|
|
fetches_var_name = sorted([name for name in outputs.keys()])
|
2020-05-15 22:25:37 +08:00
|
|
|
fetches_var = [outputs[name] for name in fetches_var_name]
|
2020-09-03 18:59:44 +08:00
|
|
|
if algorithm == "SRN":
|
2020-09-03 15:51:50 +08:00
|
|
|
others_var_names = sorted([name for name in others.keys()])
|
|
|
|
feeded_var_names = [image.name] + others_var_names
|
|
|
|
else:
|
|
|
|
feeded_var_names = [image.name]
|
|
|
|
|
2020-05-10 16:26:57 +08:00
|
|
|
target_vars = fetches_var
|
|
|
|
return feeded_var_names, target_vars, fetches_var_name
|
|
|
|
|
|
|
|
|
2020-09-15 20:17:23 +08:00
|
|
|
def create_multi_devices_program(program, loss_var_name, for_quant=False):
|
2020-05-10 16:26:57 +08:00
|
|
|
build_strategy = fluid.BuildStrategy()
|
|
|
|
build_strategy.memory_optimize = False
|
|
|
|
build_strategy.enable_inplace = True
|
2020-09-15 20:17:23 +08:00
|
|
|
if for_quant:
|
|
|
|
build_strategy.fuse_all_reduce_ops = False
|
2020-05-10 16:26:57 +08:00
|
|
|
exec_strategy = fluid.ExecutionStrategy()
|
|
|
|
exec_strategy.num_iteration_per_drop_scope = 1
|
|
|
|
compile_program = fluid.CompiledProgram(program).with_data_parallel(
|
|
|
|
loss_name=loss_var_name,
|
|
|
|
build_strategy=build_strategy,
|
|
|
|
exec_strategy=exec_strategy)
|
|
|
|
return compile_program
|
|
|
|
|
|
|
|
|
2020-09-15 21:32:14 +08:00
|
|
|
def train_eval_det_run(config,
|
|
|
|
exe,
|
|
|
|
train_info_dict,
|
|
|
|
eval_info_dict,
|
|
|
|
is_pruning=False):
|
2020-05-10 16:26:57 +08:00
|
|
|
train_batch_id = 0
|
|
|
|
log_smooth_window = config['Global']['log_smooth_window']
|
|
|
|
epoch_num = config['Global']['epoch_num']
|
|
|
|
print_batch_step = config['Global']['print_batch_step']
|
|
|
|
eval_batch_step = config['Global']['eval_batch_step']
|
2020-07-07 10:35:17 +08:00
|
|
|
start_eval_step = 0
|
|
|
|
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
|
|
|
|
start_eval_step = eval_batch_step[0]
|
|
|
|
eval_batch_step = eval_batch_step[1]
|
|
|
|
logger.info(
|
|
|
|
"During the training process, after the {}th iteration, an evaluation is run every {} iterations".
|
|
|
|
format(start_eval_step, eval_batch_step))
|
2020-05-10 16:26:57 +08:00
|
|
|
save_epoch_step = config['Global']['save_epoch_step']
|
|
|
|
save_model_dir = config['Global']['save_model_dir']
|
2020-05-13 16:05:00 +08:00
|
|
|
if not os.path.exists(save_model_dir):
|
|
|
|
os.makedirs(save_model_dir)
|
2020-05-10 16:26:57 +08:00
|
|
|
train_stats = TrainingStats(log_smooth_window,
|
|
|
|
train_info_dict['fetch_name_list'])
|
|
|
|
best_eval_hmean = -1
|
|
|
|
best_batch_id = 0
|
|
|
|
best_epoch = 0
|
|
|
|
train_loader = train_info_dict['reader']
|
|
|
|
for epoch in range(epoch_num):
|
|
|
|
train_loader.start()
|
|
|
|
try:
|
|
|
|
while True:
|
|
|
|
t1 = time.time()
|
|
|
|
train_outs = exe.run(
|
|
|
|
program=train_info_dict['compile_program'],
|
|
|
|
fetch_list=train_info_dict['fetch_varname_list'],
|
|
|
|
return_numpy=False)
|
|
|
|
stats = {}
|
|
|
|
for tno in range(len(train_outs)):
|
|
|
|
fetch_name = train_info_dict['fetch_name_list'][tno]
|
|
|
|
fetch_value = np.mean(np.array(train_outs[tno]))
|
|
|
|
stats[fetch_name] = fetch_value
|
|
|
|
t2 = time.time()
|
|
|
|
train_batch_elapse = t2 - t1
|
|
|
|
train_stats.update(stats)
|
2020-07-14 14:52:12 +08:00
|
|
|
if train_batch_id > 0 and train_batch_id \
|
2020-05-10 16:26:57 +08:00
|
|
|
% print_batch_step == 0:
|
|
|
|
logs = train_stats.log()
|
|
|
|
strs = 'epoch: {}, iter: {}, {}, time: {:.3f}'.format(
|
|
|
|
epoch, train_batch_id, logs, train_batch_elapse)
|
|
|
|
logger.info(strs)
|
|
|
|
|
2020-07-14 14:52:12 +08:00
|
|
|
if train_batch_id > start_eval_step and\
|
|
|
|
(train_batch_id - start_eval_step) % eval_batch_step == 0:
|
2020-05-10 16:26:57 +08:00
|
|
|
metrics = eval_det_run(exe, config, eval_info_dict, "eval")
|
|
|
|
hmean = metrics['hmean']
|
|
|
|
if hmean >= best_eval_hmean:
|
|
|
|
best_eval_hmean = hmean
|
|
|
|
best_batch_id = train_batch_id
|
|
|
|
best_epoch = epoch
|
|
|
|
save_path = save_model_dir + "/best_accuracy"
|
2020-09-15 21:32:14 +08:00
|
|
|
if is_pruning:
|
2020-09-16 19:48:51 +08:00
|
|
|
import paddleslim as slim
|
2020-09-15 21:32:14 +08:00
|
|
|
slim.prune.save_model(
|
|
|
|
exe, train_info_dict['train_program'],
|
|
|
|
save_path)
|
|
|
|
else:
|
|
|
|
save_model(train_info_dict['train_program'],
|
|
|
|
save_path)
|
2020-05-10 16:26:57 +08:00
|
|
|
strs = 'Test iter: {}, metrics:{}, best_hmean:{:.6f}, best_epoch:{}, best_batch_id:{}'.format(
|
|
|
|
train_batch_id, metrics, best_eval_hmean, best_epoch,
|
|
|
|
best_batch_id)
|
|
|
|
logger.info(strs)
|
|
|
|
train_batch_id += 1
|
|
|
|
|
|
|
|
except fluid.core.EOFException:
|
|
|
|
train_loader.reset()
|
2020-05-19 11:29:52 +08:00
|
|
|
if epoch == 0 and save_epoch_step == 1:
|
2020-05-19 11:15:51 +08:00
|
|
|
save_path = save_model_dir + "/iter_epoch_0"
|
2020-09-15 21:32:14 +08:00
|
|
|
if is_pruning:
|
2020-09-16 19:48:51 +08:00
|
|
|
import paddleslim as slim
|
2020-09-15 21:32:14 +08:00
|
|
|
slim.prune.save_model(exe, train_info_dict['train_program'],
|
|
|
|
save_path)
|
|
|
|
else:
|
|
|
|
save_model(train_info_dict['train_program'], save_path)
|
2020-05-10 16:26:57 +08:00
|
|
|
if epoch > 0 and epoch % save_epoch_step == 0:
|
|
|
|
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
2020-09-15 21:32:14 +08:00
|
|
|
if is_pruning:
|
2020-09-16 19:48:51 +08:00
|
|
|
import paddleslim as slim
|
2020-09-15 21:32:14 +08:00
|
|
|
slim.prune.save_model(exe, train_info_dict['train_program'],
|
|
|
|
save_path)
|
|
|
|
else:
|
|
|
|
save_model(train_info_dict['train_program'], save_path)
|
2020-05-10 16:26:57 +08:00
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
|
|
|
|
train_batch_id = 0
|
|
|
|
log_smooth_window = config['Global']['log_smooth_window']
|
|
|
|
epoch_num = config['Global']['epoch_num']
|
|
|
|
print_batch_step = config['Global']['print_batch_step']
|
|
|
|
eval_batch_step = config['Global']['eval_batch_step']
|
2020-07-07 10:35:17 +08:00
|
|
|
start_eval_step = 0
|
|
|
|
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
|
|
|
|
start_eval_step = eval_batch_step[0]
|
|
|
|
eval_batch_step = eval_batch_step[1]
|
|
|
|
logger.info(
|
|
|
|
"During the training process, after the {}th iteration, an evaluation is run every {} iterations".
|
|
|
|
format(start_eval_step, eval_batch_step))
|
2020-05-10 16:26:57 +08:00
|
|
|
save_epoch_step = config['Global']['save_epoch_step']
|
|
|
|
save_model_dir = config['Global']['save_model_dir']
|
2020-05-13 16:51:36 +08:00
|
|
|
if not os.path.exists(save_model_dir):
|
2020-05-13 16:39:39 +08:00
|
|
|
os.makedirs(save_model_dir)
|
2020-05-10 16:26:57 +08:00
|
|
|
train_stats = TrainingStats(log_smooth_window, ['loss', 'acc'])
|
|
|
|
best_eval_acc = -1
|
|
|
|
best_batch_id = 0
|
|
|
|
best_epoch = 0
|
|
|
|
train_loader = train_info_dict['reader']
|
|
|
|
for epoch in range(epoch_num):
|
|
|
|
train_loader.start()
|
|
|
|
try:
|
|
|
|
while True:
|
|
|
|
t1 = time.time()
|
|
|
|
train_outs = exe.run(
|
|
|
|
program=train_info_dict['compile_program'],
|
|
|
|
fetch_list=train_info_dict['fetch_varname_list'],
|
|
|
|
return_numpy=False)
|
|
|
|
fetch_map = dict(
|
|
|
|
zip(train_info_dict['fetch_name_list'],
|
|
|
|
range(len(train_outs))))
|
|
|
|
|
|
|
|
loss = np.mean(np.array(train_outs[fetch_map['total_loss']]))
|
|
|
|
lr = np.mean(np.array(train_outs[fetch_map['lr']]))
|
|
|
|
preds_idx = fetch_map['decoded_out']
|
|
|
|
preds = np.array(train_outs[preds_idx])
|
|
|
|
labels_idx = fetch_map['label']
|
|
|
|
labels = np.array(train_outs[labels_idx])
|
|
|
|
|
2020-08-14 16:31:13 +08:00
|
|
|
if config['Global']['loss_type'] != 'srn':
|
|
|
|
preds_lod = train_outs[preds_idx].lod()[0]
|
|
|
|
labels_lod = train_outs[labels_idx].lod()[0]
|
|
|
|
|
|
|
|
acc, acc_num, img_num = cal_predicts_accuracy(
|
|
|
|
config['Global']['char_ops'], preds, preds_lod, labels,
|
|
|
|
labels_lod)
|
|
|
|
else:
|
|
|
|
acc, acc_num, img_num = cal_predicts_accuracy_srn(
|
|
|
|
config['Global']['char_ops'], preds, labels,
|
|
|
|
config['Global']['max_text_length'])
|
2020-05-10 16:26:57 +08:00
|
|
|
t2 = time.time()
|
|
|
|
train_batch_elapse = t2 - t1
|
|
|
|
stats = {'loss': loss, 'acc': acc}
|
|
|
|
train_stats.update(stats)
|
2020-07-07 10:51:50 +08:00
|
|
|
if train_batch_id > start_eval_step and (train_batch_id - start_eval_step) \
|
2020-05-10 16:26:57 +08:00
|
|
|
% print_batch_step == 0:
|
|
|
|
logs = train_stats.log()
|
|
|
|
strs = 'epoch: {}, iter: {}, lr: {:.6f}, {}, time: {:.3f}'.format(
|
|
|
|
epoch, train_batch_id, lr, logs, train_batch_elapse)
|
|
|
|
logger.info(strs)
|
|
|
|
|
|
|
|
if train_batch_id > 0 and\
|
|
|
|
train_batch_id % eval_batch_step == 0:
|
2020-08-14 16:31:13 +08:00
|
|
|
model_average = train_info_dict['model_average']
|
|
|
|
if model_average != None:
|
|
|
|
model_average.apply(exe)
|
2020-05-10 16:26:57 +08:00
|
|
|
metrics = eval_rec_run(exe, config, eval_info_dict, "eval")
|
|
|
|
eval_acc = metrics['avg_acc']
|
|
|
|
eval_sample_num = metrics['total_sample_num']
|
|
|
|
if eval_acc > best_eval_acc:
|
|
|
|
best_eval_acc = eval_acc
|
|
|
|
best_batch_id = train_batch_id
|
|
|
|
best_epoch = epoch
|
|
|
|
save_path = save_model_dir + "/best_accuracy"
|
|
|
|
save_model(train_info_dict['train_program'], save_path)
|
|
|
|
strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format(
|
|
|
|
train_batch_id, eval_acc, best_eval_acc, best_epoch,
|
|
|
|
best_batch_id, eval_sample_num)
|
|
|
|
logger.info(strs)
|
|
|
|
train_batch_id += 1
|
|
|
|
|
|
|
|
except fluid.core.EOFException:
|
|
|
|
train_loader.reset()
|
2020-05-19 11:32:40 +08:00
|
|
|
if epoch == 0 and save_epoch_step == 1:
|
2020-05-19 11:15:51 +08:00
|
|
|
save_path = save_model_dir + "/iter_epoch_0"
|
2020-05-26 21:02:27 +08:00
|
|
|
save_model(train_info_dict['train_program'], save_path)
|
2020-05-10 16:26:57 +08:00
|
|
|
if epoch > 0 and epoch % save_epoch_step == 0:
|
|
|
|
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
|
|
|
save_model(train_info_dict['train_program'], save_path)
|
|
|
|
return
|
2020-08-15 21:54:59 +08:00
|
|
|
|
2020-08-15 12:39:07 +08:00
|
|
|
|
2020-09-01 13:44:51 +08:00
|
|
|
def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
|
|
|
|
train_batch_id = 0
|
|
|
|
log_smooth_window = config['Global']['log_smooth_window']
|
|
|
|
epoch_num = config['Global']['epoch_num']
|
|
|
|
print_batch_step = config['Global']['print_batch_step']
|
|
|
|
eval_batch_step = config['Global']['eval_batch_step']
|
|
|
|
start_eval_step = 0
|
|
|
|
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
|
|
|
|
start_eval_step = eval_batch_step[0]
|
|
|
|
eval_batch_step = eval_batch_step[1]
|
|
|
|
logger.info(
|
|
|
|
"During the training process, after the {}th iteration, an evaluation is run every {} iterations".
|
|
|
|
format(start_eval_step, eval_batch_step))
|
|
|
|
save_epoch_step = config['Global']['save_epoch_step']
|
|
|
|
save_model_dir = config['Global']['save_model_dir']
|
|
|
|
if not os.path.exists(save_model_dir):
|
|
|
|
os.makedirs(save_model_dir)
|
|
|
|
train_stats = TrainingStats(log_smooth_window, ['loss', 'acc'])
|
|
|
|
best_eval_acc = -1
|
|
|
|
best_batch_id = 0
|
|
|
|
best_epoch = 0
|
|
|
|
train_loader = train_info_dict['reader']
|
|
|
|
for epoch in range(epoch_num):
|
|
|
|
train_loader.start()
|
|
|
|
try:
|
|
|
|
while True:
|
|
|
|
t1 = time.time()
|
|
|
|
train_outs = exe.run(
|
|
|
|
program=train_info_dict['compile_program'],
|
|
|
|
fetch_list=train_info_dict['fetch_varname_list'],
|
|
|
|
return_numpy=False)
|
|
|
|
fetch_map = dict(
|
|
|
|
zip(train_info_dict['fetch_name_list'],
|
|
|
|
range(len(train_outs))))
|
|
|
|
|
|
|
|
loss = np.mean(np.array(train_outs[fetch_map['total_loss']]))
|
|
|
|
lr = np.mean(np.array(train_outs[fetch_map['lr']]))
|
|
|
|
acc = np.mean(np.array(train_outs[fetch_map['acc']]))
|
|
|
|
|
|
|
|
t2 = time.time()
|
|
|
|
train_batch_elapse = t2 - t1
|
|
|
|
stats = {'loss': loss, 'acc': acc}
|
|
|
|
train_stats.update(stats)
|
|
|
|
if train_batch_id > start_eval_step and (train_batch_id - start_eval_step) \
|
|
|
|
% print_batch_step == 0:
|
|
|
|
logs = train_stats.log()
|
|
|
|
strs = 'epoch: {}, iter: {}, lr: {:.6f}, {}, time: {:.3f}'.format(
|
|
|
|
epoch, train_batch_id, lr, logs, train_batch_elapse)
|
|
|
|
logger.info(strs)
|
|
|
|
|
|
|
|
if train_batch_id > 0 and\
|
|
|
|
train_batch_id % eval_batch_step == 0:
|
|
|
|
model_average = train_info_dict['model_average']
|
|
|
|
if model_average != None:
|
|
|
|
model_average.apply(exe)
|
|
|
|
metrics = eval_cls_run(exe, eval_info_dict)
|
|
|
|
eval_acc = metrics['avg_acc']
|
|
|
|
eval_sample_num = metrics['total_sample_num']
|
|
|
|
if eval_acc > best_eval_acc:
|
|
|
|
best_eval_acc = eval_acc
|
|
|
|
best_batch_id = train_batch_id
|
|
|
|
best_epoch = epoch
|
|
|
|
save_path = save_model_dir + "/best_accuracy"
|
|
|
|
save_model(train_info_dict['train_program'], save_path)
|
|
|
|
strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format(
|
|
|
|
train_batch_id, eval_acc, best_eval_acc, best_epoch,
|
|
|
|
best_batch_id, eval_sample_num)
|
|
|
|
logger.info(strs)
|
|
|
|
train_batch_id += 1
|
|
|
|
|
|
|
|
except fluid.core.EOFException:
|
|
|
|
train_loader.reset()
|
|
|
|
if epoch == 0 and save_epoch_step == 1:
|
|
|
|
save_path = save_model_dir + "/iter_epoch_0"
|
|
|
|
save_model(train_info_dict['train_program'], save_path)
|
|
|
|
if epoch > 0 and epoch % save_epoch_step == 0:
|
|
|
|
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
|
|
|
save_model(train_info_dict['train_program'], save_path)
|
|
|
|
return
|
|
|
|
|
|
|
|
|
2020-08-15 21:54:59 +08:00
|
|
|
def preprocess():
|
|
|
|
FLAGS = ArgsParser().parse_args()
|
|
|
|
config = load_config(FLAGS.config)
|
|
|
|
merge_config(FLAGS.opt)
|
|
|
|
logger.info(config)
|
|
|
|
|
|
|
|
# check if set use_gpu=True in paddlepaddle cpu version
|
|
|
|
use_gpu = config['Global']['use_gpu']
|
|
|
|
check_gpu(use_gpu)
|
|
|
|
|
|
|
|
alg = config['Global']['algorithm']
|
2020-09-03 15:51:50 +08:00
|
|
|
assert alg in [
|
2020-09-01 13:44:51 +08:00
|
|
|
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
|
2020-09-03 15:51:50 +08:00
|
|
|
]
|
2020-08-15 12:39:07 +08:00
|
|
|
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
|
2020-08-15 21:54:59 +08:00
|
|
|
config['Global']['char_ops'] = CharacterOps(config['Global'])
|
|
|
|
|
|
|
|
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
|
|
|
|
startup_program = fluid.Program()
|
|
|
|
train_program = fluid.Program()
|
|
|
|
|
|
|
|
if alg in ['EAST', 'DB', 'SAST']:
|
|
|
|
train_alg_type = 'det'
|
2020-09-01 13:44:51 +08:00
|
|
|
elif alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
|
2020-08-15 21:54:59 +08:00
|
|
|
train_alg_type = 'rec'
|
2020-09-01 13:44:51 +08:00
|
|
|
else:
|
|
|
|
train_alg_type = 'cls'
|
2020-08-15 21:54:59 +08:00
|
|
|
|
|
|
|
return startup_program, train_program, place, config, train_alg_type
|