add slim/prune
This commit is contained in:
parent
ed6b2f0c71
commit
d4f1758d55
|
@ -0,0 +1,40 @@
|
||||||
|
> 运行示例前请先安装develop版本PaddleSlim
|
||||||
|
|
||||||
|
# 模型裁剪压缩教程
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
该示例使用PaddleSlim提供的[裁剪压缩API](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/)对OCR模型进行压缩。
|
||||||
|
在阅读该示例前,建议您先了解以下内容:
|
||||||
|
|
||||||
|
- [OCR模型的常规训练方法](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/detection.md)
|
||||||
|
- [PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/)
|
||||||
|
|
||||||
|
## 安装PaddleSlim
|
||||||
|
可按照[PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/)中的步骤安装PaddleSlim。
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## 敏感度分析训练
|
||||||
|
|
||||||
|
进入PaddleOCR根目录,通过以下命令对模型进行敏感度分析:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1
|
||||||
|
```
|
||||||
|
|
||||||
|
## 裁剪模型与fine-tune
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## 评估并导出
|
||||||
|
|
||||||
|
在得到裁剪训练保存的模型后,我们可以将其导出为inference_model,用于预测部署:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python deploy/slim/prune/export_prune_model.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.test_batch_size_per_card=1 Global.save_inference_dir=inference_model
|
||||||
|
```
|
|
@ -0,0 +1,156 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
|
||||||
|
__dir__ = os.path.dirname(__file__)
|
||||||
|
sys.path.append(__dir__)
|
||||||
|
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
|
||||||
|
|
||||||
|
__all__ = ['eval_det_run']
|
||||||
|
|
||||||
|
import logging
|
||||||
|
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
|
||||||
|
logging.basicConfig(level=logging.INFO, format=FORMAT)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import json
|
||||||
|
from copy import deepcopy
|
||||||
|
from ppocr.utils.utility import create_module
|
||||||
|
from ppocr.data.reader_main import reader_main
|
||||||
|
from tools.eval_utils.eval_det_iou import DetectionIoUEvaluator
|
||||||
|
|
||||||
|
|
||||||
|
def cal_det_res(exe, config, eval_info_dict):
|
||||||
|
global_params = config['Global']
|
||||||
|
save_res_path = global_params['save_res_path']
|
||||||
|
postprocess_params = deepcopy(config["PostProcess"])
|
||||||
|
postprocess_params.update(global_params)
|
||||||
|
postprocess = create_module(postprocess_params['function']) \
|
||||||
|
(params=postprocess_params)
|
||||||
|
if not os.path.exists(os.path.dirname(save_res_path)):
|
||||||
|
os.makedirs(os.path.dirname(save_res_path))
|
||||||
|
with open(save_res_path, "wb") as fout:
|
||||||
|
tackling_num = 0
|
||||||
|
for data in eval_info_dict['reader']():
|
||||||
|
img_num = len(data)
|
||||||
|
tackling_num = tackling_num + img_num
|
||||||
|
logger.info("test tackling num:%d", tackling_num)
|
||||||
|
img_list = []
|
||||||
|
ratio_list = []
|
||||||
|
img_name_list = []
|
||||||
|
for ino in range(img_num):
|
||||||
|
img_list.append(data[ino][0])
|
||||||
|
ratio_list.append(data[ino][1])
|
||||||
|
img_name_list.append(data[ino][2])
|
||||||
|
try:
|
||||||
|
img_list = np.concatenate(img_list, axis=0)
|
||||||
|
except:
|
||||||
|
err = "concatenate error usually caused by different input image shapes in evaluation or testing.\n \
|
||||||
|
Please set \"test_batch_size_per_card\" in main yml as 1\n \
|
||||||
|
or add \"test_image_shape: [h, w]\" in reader yml for EvalReader."
|
||||||
|
|
||||||
|
raise Exception(err)
|
||||||
|
outs = exe.run(eval_info_dict['program'], \
|
||||||
|
feed={'image': img_list}, \
|
||||||
|
fetch_list=eval_info_dict['fetch_varname_list'])
|
||||||
|
outs_dict = {}
|
||||||
|
for tno in range(len(outs)):
|
||||||
|
fetch_name = eval_info_dict['fetch_name_list'][tno]
|
||||||
|
fetch_value = np.array(outs[tno])
|
||||||
|
outs_dict[fetch_name] = fetch_value
|
||||||
|
dt_boxes_list = postprocess(outs_dict, ratio_list)
|
||||||
|
for ino in range(img_num):
|
||||||
|
dt_boxes = dt_boxes_list[ino]
|
||||||
|
img_name = img_name_list[ino]
|
||||||
|
dt_boxes_json = []
|
||||||
|
for box in dt_boxes:
|
||||||
|
tmp_json = {"transcription": ""}
|
||||||
|
tmp_json['points'] = box.tolist()
|
||||||
|
dt_boxes_json.append(tmp_json)
|
||||||
|
otstr = img_name + "\t" + json.dumps(dt_boxes_json) + "\n"
|
||||||
|
fout.write(otstr.encode())
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def load_label_infor(label_file_path, do_ignore=False):
|
||||||
|
img_name_label_dict = {}
|
||||||
|
with open(label_file_path, "rb") as fin:
|
||||||
|
lines = fin.readlines()
|
||||||
|
for line in lines:
|
||||||
|
substr = line.decode().strip("\n").split("\t")
|
||||||
|
bbox_infor = json.loads(substr[1])
|
||||||
|
bbox_num = len(bbox_infor)
|
||||||
|
for bno in range(bbox_num):
|
||||||
|
text = bbox_infor[bno]['transcription']
|
||||||
|
ignore = False
|
||||||
|
if text == "###" and do_ignore:
|
||||||
|
ignore = True
|
||||||
|
bbox_infor[bno]['ignore'] = ignore
|
||||||
|
img_name_label_dict[os.path.basename(substr[0])] = bbox_infor
|
||||||
|
return img_name_label_dict
|
||||||
|
|
||||||
|
|
||||||
|
def cal_det_metrics(gt_label_path, save_res_path):
|
||||||
|
"""
|
||||||
|
calculate the detection metrics
|
||||||
|
Args:
|
||||||
|
gt_label_path(string): The groundtruth detection label file path
|
||||||
|
save_res_path(string): The saved predicted detection label path
|
||||||
|
return:
|
||||||
|
claculated metrics including Hmean, precision and recall
|
||||||
|
"""
|
||||||
|
evaluator = DetectionIoUEvaluator()
|
||||||
|
gt_label_infor = load_label_infor(gt_label_path, do_ignore=True)
|
||||||
|
dt_label_infor = load_label_infor(save_res_path)
|
||||||
|
results = []
|
||||||
|
for img_name in gt_label_infor:
|
||||||
|
gt_label = gt_label_infor[img_name]
|
||||||
|
if img_name not in dt_label_infor:
|
||||||
|
dt_label = []
|
||||||
|
else:
|
||||||
|
dt_label = dt_label_infor[img_name]
|
||||||
|
result = evaluator.evaluate_image(gt_label, dt_label)
|
||||||
|
results.append(result)
|
||||||
|
methodMetrics = evaluator.combine_results(results)
|
||||||
|
return methodMetrics
|
||||||
|
|
||||||
|
|
||||||
|
def eval_det_run(eval_args, mode='eval'):
|
||||||
|
exe = eval_args['exe']
|
||||||
|
config = eval_args['config']
|
||||||
|
eval_info_dict = eval_args['eval_info_dict']
|
||||||
|
cal_det_res(exe, config, eval_info_dict)
|
||||||
|
|
||||||
|
save_res_path = config['Global']['save_res_path']
|
||||||
|
if mode == "eval":
|
||||||
|
gt_label_path = config['EvalReader']['label_file_path']
|
||||||
|
metrics = cal_det_metrics(gt_label_path, save_res_path)
|
||||||
|
else:
|
||||||
|
gt_label_path = config['TestReader']['label_file_path']
|
||||||
|
do_eval = config['TestReader']['do_eval']
|
||||||
|
if do_eval:
|
||||||
|
metrics = cal_det_metrics(gt_label_path, save_res_path)
|
||||||
|
else:
|
||||||
|
metrics = {}
|
||||||
|
return metrics['hmean']
|
|
@ -0,0 +1,81 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
sys.path.append(__dir__)
|
||||||
|
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
|
||||||
|
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
||||||
|
|
||||||
|
|
||||||
|
def set_paddle_flags(**kwargs):
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if os.environ.get(key, None) is None:
|
||||||
|
os.environ[key] = str(value)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(paddle-dev): All of these flags should be
|
||||||
|
# set before `import paddle`. Otherwise, it would
|
||||||
|
# not take any effect.
|
||||||
|
set_paddle_flags(
|
||||||
|
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
|
||||||
|
)
|
||||||
|
|
||||||
|
import program
|
||||||
|
from paddle import fluid
|
||||||
|
from ppocr.utils.utility import initial_logger
|
||||||
|
logger = initial_logger()
|
||||||
|
from ppocr.utils.save_load import init_model
|
||||||
|
from paddleslim.prune import load_model
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
startup_prog, eval_program, place, config, _ = program.preprocess()
|
||||||
|
|
||||||
|
feeded_var_names, target_vars, fetches_var_name = program.build_export(
|
||||||
|
config, eval_program, startup_prog)
|
||||||
|
eval_program = eval_program.clone(for_test=True)
|
||||||
|
exe = fluid.Executor(place)
|
||||||
|
exe.run(startup_prog)
|
||||||
|
|
||||||
|
if config['Global']['checkpoints'] is not None:
|
||||||
|
path = config['Global']['checkpoints']
|
||||||
|
else:
|
||||||
|
path = config['Global']['pretrain_weights']
|
||||||
|
|
||||||
|
load_model(exe, eval_program, path)
|
||||||
|
|
||||||
|
save_inference_dir = config['Global']['save_inference_dir']
|
||||||
|
if not os.path.exists(save_inference_dir):
|
||||||
|
os.makedirs(save_inference_dir)
|
||||||
|
fluid.io.save_inference_model(
|
||||||
|
dirname=save_inference_dir,
|
||||||
|
feeded_var_names=feeded_var_names,
|
||||||
|
main_program=eval_program,
|
||||||
|
target_vars=target_vars,
|
||||||
|
executor=exe,
|
||||||
|
model_filename='model',
|
||||||
|
params_filename='params')
|
||||||
|
print("inference model saved in {}/model and {}/params".format(
|
||||||
|
save_inference_dir, save_inference_dir))
|
||||||
|
print("save success, output_name_list:", fetches_var_name)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
|
@ -0,0 +1,188 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
__dir__ = os.path.dirname(__file__)
|
||||||
|
sys.path.append(__dir__)
|
||||||
|
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
|
||||||
|
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
||||||
|
|
||||||
|
|
||||||
|
def set_paddle_flags(**kwargs):
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if os.environ.get(key, None) is None:
|
||||||
|
os.environ[key] = str(value)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(paddle-dev): All of these flags should be
|
||||||
|
# set before `import paddle`. Otherwise, it would
|
||||||
|
# not take any effect.
|
||||||
|
set_paddle_flags(
|
||||||
|
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
|
||||||
|
)
|
||||||
|
|
||||||
|
import tools.program as program
|
||||||
|
from paddle import fluid
|
||||||
|
from ppocr.utils.utility import initial_logger
|
||||||
|
logger = initial_logger()
|
||||||
|
from ppocr.data.reader_main import reader_main
|
||||||
|
from ppocr.utils.save_load import init_model
|
||||||
|
from ppocr.utils.character import CharacterOps
|
||||||
|
from ppocr.utils.utility import initial_logger
|
||||||
|
from paddleslim.prune import Pruner, save_model
|
||||||
|
from paddleslim.analysis import flops
|
||||||
|
from paddleslim.core.graph_wrapper import *
|
||||||
|
from paddleslim.prune import load_sensitivities, get_ratios_by_loss, merge_sensitive
|
||||||
|
logger = initial_logger()
|
||||||
|
|
||||||
|
skip_list = [
|
||||||
|
'conv10_linear_weights', 'conv11_linear_weights', 'conv12_expand_weights',
|
||||||
|
'conv12_linear_weights', 'conv12_se_2_weights', 'conv13_linear_weights',
|
||||||
|
'conv2_linear_weights', 'conv4_linear_weights', 'conv5_expand_weights',
|
||||||
|
'conv5_linear_weights', 'conv5_se_2_weights', 'conv6_linear_weights',
|
||||||
|
'conv7_linear_weights', 'conv8_expand_weights', 'conv8_linear_weights',
|
||||||
|
'conv9_expand_weights', 'conv9_linear_weights'
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config = program.load_config(FLAGS.config)
|
||||||
|
program.merge_config(FLAGS.opt)
|
||||||
|
logger.info(config)
|
||||||
|
|
||||||
|
# check if set use_gpu=True in paddlepaddle cpu version
|
||||||
|
use_gpu = config['Global']['use_gpu']
|
||||||
|
program.check_gpu(use_gpu)
|
||||||
|
|
||||||
|
alg = config['Global']['algorithm']
|
||||||
|
assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE']
|
||||||
|
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']:
|
||||||
|
config['Global']['char_ops'] = CharacterOps(config['Global'])
|
||||||
|
|
||||||
|
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
|
||||||
|
startup_program = fluid.Program()
|
||||||
|
train_program = fluid.Program()
|
||||||
|
train_build_outputs = program.build(
|
||||||
|
config, train_program, startup_program, mode='train')
|
||||||
|
train_loader = train_build_outputs[0]
|
||||||
|
train_fetch_name_list = train_build_outputs[1]
|
||||||
|
train_fetch_varname_list = train_build_outputs[2]
|
||||||
|
train_opt_loss_name = train_build_outputs[3]
|
||||||
|
|
||||||
|
eval_program = fluid.Program()
|
||||||
|
eval_build_outputs = program.build(
|
||||||
|
config, eval_program, startup_program, mode='eval')
|
||||||
|
eval_fetch_name_list = eval_build_outputs[1]
|
||||||
|
eval_fetch_varname_list = eval_build_outputs[2]
|
||||||
|
eval_program = eval_program.clone(for_test=True)
|
||||||
|
|
||||||
|
train_reader = reader_main(config=config, mode="train")
|
||||||
|
train_loader.set_sample_list_generator(train_reader, places=place)
|
||||||
|
|
||||||
|
eval_reader = reader_main(config=config, mode="eval")
|
||||||
|
|
||||||
|
exe = fluid.Executor(place)
|
||||||
|
exe.run(startup_program)
|
||||||
|
|
||||||
|
# compile program for multi-devices
|
||||||
|
init_model(config, train_program, exe)
|
||||||
|
|
||||||
|
# params = get_pruned_params(train_program)
|
||||||
|
'''
|
||||||
|
sens_file = ['sensitivities_'+ str(x) for x in range(0,4)]
|
||||||
|
sens = []
|
||||||
|
for f in sens_file:
|
||||||
|
sens.append(load_sensitivities(f+'.data'))
|
||||||
|
sen = merge_sensitive(sens)
|
||||||
|
'''
|
||||||
|
sen = load_sensitivities("sensitivities_0.data")
|
||||||
|
for i in skip_list:
|
||||||
|
sen.pop(i)
|
||||||
|
back_bone_list = ['conv' + str(x) for x in range(1, 5)]
|
||||||
|
for i in back_bone_list:
|
||||||
|
for key in list(sen.keys()):
|
||||||
|
if i + '_' in key:
|
||||||
|
sen.pop(key)
|
||||||
|
ratios = get_ratios_by_loss(sen, 0.03)
|
||||||
|
logger.info("FLOPs before pruning: {}".format(flops(eval_program)))
|
||||||
|
pruner = Pruner(criterion='geometry_median')
|
||||||
|
print("ratios: {}".format(ratios))
|
||||||
|
pruned_val_program, _, _ = pruner.prune(
|
||||||
|
eval_program,
|
||||||
|
fluid.global_scope(),
|
||||||
|
params=ratios.keys(),
|
||||||
|
ratios=ratios.values(),
|
||||||
|
place=place,
|
||||||
|
only_graph=True)
|
||||||
|
|
||||||
|
pruned_program, _, _ = pruner.prune(
|
||||||
|
train_program,
|
||||||
|
fluid.global_scope(),
|
||||||
|
params=ratios.keys(),
|
||||||
|
ratios=ratios.values(),
|
||||||
|
place=place)
|
||||||
|
logger.info("FLOPs after pruning: {}".format(flops(pruned_val_program)))
|
||||||
|
train_compile_program = program.create_multi_devices_program(
|
||||||
|
pruned_program, train_opt_loss_name)
|
||||||
|
|
||||||
|
|
||||||
|
train_info_dict = {'compile_program':train_compile_program,\
|
||||||
|
'train_program':pruned_program,\
|
||||||
|
'reader':train_loader,\
|
||||||
|
'fetch_name_list':train_fetch_name_list,\
|
||||||
|
'fetch_varname_list':train_fetch_varname_list}
|
||||||
|
|
||||||
|
eval_info_dict = {'program':pruned_val_program,\
|
||||||
|
'reader':eval_reader,\
|
||||||
|
'fetch_name_list':eval_fetch_name_list,\
|
||||||
|
'fetch_varname_list':eval_fetch_varname_list}
|
||||||
|
|
||||||
|
if alg in ['EAST', 'DB']:
|
||||||
|
program.train_eval_det_run(
|
||||||
|
config, exe, train_info_dict, eval_info_dict, is_pruning=True)
|
||||||
|
else:
|
||||||
|
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def test_reader():
|
||||||
|
config = program.load_config(FLAGS.config)
|
||||||
|
program.merge_config(FLAGS.opt)
|
||||||
|
print(config)
|
||||||
|
train_reader = reader_main(config=config, mode="train")
|
||||||
|
import time
|
||||||
|
starttime = time.time()
|
||||||
|
count = 0
|
||||||
|
try:
|
||||||
|
for data in train_reader():
|
||||||
|
count += 1
|
||||||
|
if count % 1 == 0:
|
||||||
|
batch_time = time.time() - starttime
|
||||||
|
starttime = time.time()
|
||||||
|
print("reader:", count, len(data), batch_time)
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(e)
|
||||||
|
logger.info("finish reader: {}, Success!".format(count))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = program.ArgsParser()
|
||||||
|
FLAGS = parser.parse_args()
|
||||||
|
main()
|
||||||
|
# test_reader()
|
|
@ -0,0 +1,121 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
__dir__ = os.path.dirname(__file__)
|
||||||
|
sys.path.append(__dir__)
|
||||||
|
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
|
||||||
|
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
||||||
|
|
||||||
|
|
||||||
|
def set_paddle_flags(**kwargs):
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if os.environ.get(key, None) is None:
|
||||||
|
os.environ[key] = str(value)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(paddle-dev): All of these flags should be
|
||||||
|
# set before `import paddle`. Otherwise, it would
|
||||||
|
# not take any effect.
|
||||||
|
set_paddle_flags(
|
||||||
|
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
|
||||||
|
)
|
||||||
|
|
||||||
|
import json
|
||||||
|
import cv2
|
||||||
|
from paddle import fluid
|
||||||
|
import paddleslim as slim
|
||||||
|
from copy import deepcopy
|
||||||
|
from eval_det_utils import eval_det_run
|
||||||
|
|
||||||
|
from tools import program
|
||||||
|
from ppocr.utils.utility import initial_logger
|
||||||
|
from ppocr.data.reader_main import reader_main
|
||||||
|
from ppocr.utils.save_load import init_model
|
||||||
|
from ppocr.utils.character import CharacterOps
|
||||||
|
from ppocr.utils.utility import create_module
|
||||||
|
from ppocr.data.reader_main import reader_main
|
||||||
|
|
||||||
|
logger = initial_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def get_pruned_params(program):
|
||||||
|
params = []
|
||||||
|
for param in program.global_block().all_parameters():
|
||||||
|
if len(
|
||||||
|
param.shape
|
||||||
|
) == 4 and 'depthwise' not in param.name and 'transpose' not in param.name:
|
||||||
|
params.append(param.name)
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config = program.load_config(FLAGS.config)
|
||||||
|
program.merge_config(FLAGS.opt)
|
||||||
|
logger.info(config)
|
||||||
|
|
||||||
|
# check if set use_gpu=True in paddlepaddle cpu version
|
||||||
|
use_gpu = config['Global']['use_gpu']
|
||||||
|
program.check_gpu(use_gpu)
|
||||||
|
|
||||||
|
alg = config['Global']['algorithm']
|
||||||
|
assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE']
|
||||||
|
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']:
|
||||||
|
config['Global']['char_ops'] = CharacterOps(config['Global'])
|
||||||
|
|
||||||
|
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
|
||||||
|
startup_prog = fluid.Program()
|
||||||
|
eval_program = fluid.Program()
|
||||||
|
eval_build_outputs = program.build(
|
||||||
|
config, eval_program, startup_prog, mode='test')
|
||||||
|
eval_fetch_name_list = eval_build_outputs[1]
|
||||||
|
eval_fetch_varname_list = eval_build_outputs[2]
|
||||||
|
eval_program = eval_program.clone(for_test=True)
|
||||||
|
exe = fluid.Executor(place)
|
||||||
|
exe.run(startup_prog)
|
||||||
|
|
||||||
|
init_model(config, eval_program, exe)
|
||||||
|
|
||||||
|
eval_reader = reader_main(config=config, mode="eval")
|
||||||
|
eval_info_dict = {'program':eval_program,\
|
||||||
|
'reader':eval_reader,\
|
||||||
|
'fetch_name_list':eval_fetch_name_list,\
|
||||||
|
'fetch_varname_list':eval_fetch_varname_list}
|
||||||
|
eval_args = dict()
|
||||||
|
eval_args = {'exe': exe, 'config': config, 'eval_info_dict': eval_info_dict}
|
||||||
|
metrics = eval_det_run(eval_args)
|
||||||
|
print("Baseline: {}".format(metrics))
|
||||||
|
|
||||||
|
params = get_pruned_params(eval_program)
|
||||||
|
print('Start to analyze')
|
||||||
|
sens_0 = slim.prune.sensitivity(
|
||||||
|
eval_program,
|
||||||
|
place,
|
||||||
|
params,
|
||||||
|
eval_det_run,
|
||||||
|
sensitivities_file="sensitivities_0.data",
|
||||||
|
pruned_ratios=[0.1],
|
||||||
|
eval_args=eval_args,
|
||||||
|
criterion='geometry_median')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = program.ArgsParser()
|
||||||
|
FLAGS = parser.parse_args()
|
||||||
|
main()
|
|
@ -33,6 +33,7 @@ from eval_utils.eval_rec_utils import eval_rec_run
|
||||||
from ppocr.utils.save_load import save_model
|
from ppocr.utils.save_load import save_model
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps
|
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps
|
||||||
|
import paddleslim as slim
|
||||||
|
|
||||||
|
|
||||||
class ArgsParser(ArgumentParser):
|
class ArgsParser(ArgumentParser):
|
||||||
|
@ -238,7 +239,11 @@ def create_multi_devices_program(program, loss_var_name):
|
||||||
return compile_program
|
return compile_program
|
||||||
|
|
||||||
|
|
||||||
def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
|
def train_eval_det_run(config,
|
||||||
|
exe,
|
||||||
|
train_info_dict,
|
||||||
|
eval_info_dict,
|
||||||
|
is_pruning=False):
|
||||||
train_batch_id = 0
|
train_batch_id = 0
|
||||||
log_smooth_window = config['Global']['log_smooth_window']
|
log_smooth_window = config['Global']['log_smooth_window']
|
||||||
epoch_num = config['Global']['epoch_num']
|
epoch_num = config['Global']['epoch_num']
|
||||||
|
@ -294,7 +299,13 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
|
||||||
best_batch_id = train_batch_id
|
best_batch_id = train_batch_id
|
||||||
best_epoch = epoch
|
best_epoch = epoch
|
||||||
save_path = save_model_dir + "/best_accuracy"
|
save_path = save_model_dir + "/best_accuracy"
|
||||||
save_model(train_info_dict['train_program'], save_path)
|
if is_pruning:
|
||||||
|
slim.prune.save_model(
|
||||||
|
exe, train_info_dict['train_program'],
|
||||||
|
save_path)
|
||||||
|
else:
|
||||||
|
save_model(train_info_dict['train_program'],
|
||||||
|
save_path)
|
||||||
strs = 'Test iter: {}, metrics:{}, best_hmean:{:.6f}, best_epoch:{}, best_batch_id:{}'.format(
|
strs = 'Test iter: {}, metrics:{}, best_hmean:{:.6f}, best_epoch:{}, best_batch_id:{}'.format(
|
||||||
train_batch_id, metrics, best_eval_hmean, best_epoch,
|
train_batch_id, metrics, best_eval_hmean, best_epoch,
|
||||||
best_batch_id)
|
best_batch_id)
|
||||||
|
@ -305,9 +316,17 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
|
||||||
train_loader.reset()
|
train_loader.reset()
|
||||||
if epoch == 0 and save_epoch_step == 1:
|
if epoch == 0 and save_epoch_step == 1:
|
||||||
save_path = save_model_dir + "/iter_epoch_0"
|
save_path = save_model_dir + "/iter_epoch_0"
|
||||||
|
if is_pruning:
|
||||||
|
slim.prune.save_model(exe, train_info_dict['train_program'],
|
||||||
|
save_path)
|
||||||
|
else:
|
||||||
save_model(train_info_dict['train_program'], save_path)
|
save_model(train_info_dict['train_program'], save_path)
|
||||||
if epoch > 0 and epoch % save_epoch_step == 0:
|
if epoch > 0 and epoch % save_epoch_step == 0:
|
||||||
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
||||||
|
if is_pruning:
|
||||||
|
slim.prune.save_model(exe, train_info_dict['train_program'],
|
||||||
|
save_path)
|
||||||
|
else:
|
||||||
save_model(train_info_dict['train_program'], save_path)
|
save_model(train_info_dict['train_program'], save_path)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue