add slim/prune
This commit is contained in:
parent
ed6b2f0c71
commit
d4f1758d55
|
@ -0,0 +1,40 @@
|
|||
> 运行示例前请先安装develop版本PaddleSlim
|
||||
|
||||
# 模型裁剪压缩教程
|
||||
|
||||
## 概述
|
||||
|
||||
该示例使用PaddleSlim提供的[裁剪压缩API](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/)对OCR模型进行压缩。
|
||||
在阅读该示例前,建议您先了解以下内容:
|
||||
|
||||
- [OCR模型的常规训练方法](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/detection.md)
|
||||
- [PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/)
|
||||
|
||||
## 安装PaddleSlim
|
||||
可按照[PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/)中的步骤安装PaddleSlim。
|
||||
|
||||
|
||||
|
||||
## 敏感度分析训练
|
||||
|
||||
进入PaddleOCR根目录,通过以下命令对模型进行敏感度分析:
|
||||
|
||||
```bash
|
||||
python deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1
|
||||
```
|
||||
|
||||
## 裁剪模型与fine-tune
|
||||
|
||||
```bash
|
||||
python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 评估并导出
|
||||
|
||||
在得到裁剪训练保存的模型后,我们可以将其导出为inference_model,用于预测部署:
|
||||
|
||||
```bash
|
||||
python deploy/slim/prune/export_prune_model.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.test_batch_size_per_card=1 Global.save_inference_dir=inference_model
|
||||
```
|
|
@ -0,0 +1,156 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import numpy as np
|
||||
import paddle.fluid as fluid
|
||||
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
|
||||
|
||||
__all__ = ['eval_det_run']
|
||||
|
||||
import logging
|
||||
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
|
||||
logging.basicConfig(level=logging.INFO, format=FORMAT)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import cv2
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from ppocr.utils.utility import create_module
|
||||
from ppocr.data.reader_main import reader_main
|
||||
from tools.eval_utils.eval_det_iou import DetectionIoUEvaluator
|
||||
|
||||
|
||||
def cal_det_res(exe, config, eval_info_dict):
|
||||
global_params = config['Global']
|
||||
save_res_path = global_params['save_res_path']
|
||||
postprocess_params = deepcopy(config["PostProcess"])
|
||||
postprocess_params.update(global_params)
|
||||
postprocess = create_module(postprocess_params['function']) \
|
||||
(params=postprocess_params)
|
||||
if not os.path.exists(os.path.dirname(save_res_path)):
|
||||
os.makedirs(os.path.dirname(save_res_path))
|
||||
with open(save_res_path, "wb") as fout:
|
||||
tackling_num = 0
|
||||
for data in eval_info_dict['reader']():
|
||||
img_num = len(data)
|
||||
tackling_num = tackling_num + img_num
|
||||
logger.info("test tackling num:%d", tackling_num)
|
||||
img_list = []
|
||||
ratio_list = []
|
||||
img_name_list = []
|
||||
for ino in range(img_num):
|
||||
img_list.append(data[ino][0])
|
||||
ratio_list.append(data[ino][1])
|
||||
img_name_list.append(data[ino][2])
|
||||
try:
|
||||
img_list = np.concatenate(img_list, axis=0)
|
||||
except:
|
||||
err = "concatenate error usually caused by different input image shapes in evaluation or testing.\n \
|
||||
Please set \"test_batch_size_per_card\" in main yml as 1\n \
|
||||
or add \"test_image_shape: [h, w]\" in reader yml for EvalReader."
|
||||
|
||||
raise Exception(err)
|
||||
outs = exe.run(eval_info_dict['program'], \
|
||||
feed={'image': img_list}, \
|
||||
fetch_list=eval_info_dict['fetch_varname_list'])
|
||||
outs_dict = {}
|
||||
for tno in range(len(outs)):
|
||||
fetch_name = eval_info_dict['fetch_name_list'][tno]
|
||||
fetch_value = np.array(outs[tno])
|
||||
outs_dict[fetch_name] = fetch_value
|
||||
dt_boxes_list = postprocess(outs_dict, ratio_list)
|
||||
for ino in range(img_num):
|
||||
dt_boxes = dt_boxes_list[ino]
|
||||
img_name = img_name_list[ino]
|
||||
dt_boxes_json = []
|
||||
for box in dt_boxes:
|
||||
tmp_json = {"transcription": ""}
|
||||
tmp_json['points'] = box.tolist()
|
||||
dt_boxes_json.append(tmp_json)
|
||||
otstr = img_name + "\t" + json.dumps(dt_boxes_json) + "\n"
|
||||
fout.write(otstr.encode())
|
||||
return
|
||||
|
||||
|
||||
def load_label_infor(label_file_path, do_ignore=False):
|
||||
img_name_label_dict = {}
|
||||
with open(label_file_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
substr = line.decode().strip("\n").split("\t")
|
||||
bbox_infor = json.loads(substr[1])
|
||||
bbox_num = len(bbox_infor)
|
||||
for bno in range(bbox_num):
|
||||
text = bbox_infor[bno]['transcription']
|
||||
ignore = False
|
||||
if text == "###" and do_ignore:
|
||||
ignore = True
|
||||
bbox_infor[bno]['ignore'] = ignore
|
||||
img_name_label_dict[os.path.basename(substr[0])] = bbox_infor
|
||||
return img_name_label_dict
|
||||
|
||||
|
||||
def cal_det_metrics(gt_label_path, save_res_path):
|
||||
"""
|
||||
calculate the detection metrics
|
||||
Args:
|
||||
gt_label_path(string): The groundtruth detection label file path
|
||||
save_res_path(string): The saved predicted detection label path
|
||||
return:
|
||||
claculated metrics including Hmean, precision and recall
|
||||
"""
|
||||
evaluator = DetectionIoUEvaluator()
|
||||
gt_label_infor = load_label_infor(gt_label_path, do_ignore=True)
|
||||
dt_label_infor = load_label_infor(save_res_path)
|
||||
results = []
|
||||
for img_name in gt_label_infor:
|
||||
gt_label = gt_label_infor[img_name]
|
||||
if img_name not in dt_label_infor:
|
||||
dt_label = []
|
||||
else:
|
||||
dt_label = dt_label_infor[img_name]
|
||||
result = evaluator.evaluate_image(gt_label, dt_label)
|
||||
results.append(result)
|
||||
methodMetrics = evaluator.combine_results(results)
|
||||
return methodMetrics
|
||||
|
||||
|
||||
def eval_det_run(eval_args, mode='eval'):
|
||||
exe = eval_args['exe']
|
||||
config = eval_args['config']
|
||||
eval_info_dict = eval_args['eval_info_dict']
|
||||
cal_det_res(exe, config, eval_info_dict)
|
||||
|
||||
save_res_path = config['Global']['save_res_path']
|
||||
if mode == "eval":
|
||||
gt_label_path = config['EvalReader']['label_file_path']
|
||||
metrics = cal_det_metrics(gt_label_path, save_res_path)
|
||||
else:
|
||||
gt_label_path = config['TestReader']['label_file_path']
|
||||
do_eval = config['TestReader']['do_eval']
|
||||
if do_eval:
|
||||
metrics = cal_det_metrics(gt_label_path, save_res_path)
|
||||
else:
|
||||
metrics = {}
|
||||
return metrics['hmean']
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
||||
|
||||
|
||||
def set_paddle_flags(**kwargs):
|
||||
for key, value in kwargs.items():
|
||||
if os.environ.get(key, None) is None:
|
||||
os.environ[key] = str(value)
|
||||
|
||||
|
||||
# NOTE(paddle-dev): All of these flags should be
|
||||
# set before `import paddle`. Otherwise, it would
|
||||
# not take any effect.
|
||||
set_paddle_flags(
|
||||
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
|
||||
)
|
||||
|
||||
import program
|
||||
from paddle import fluid
|
||||
from ppocr.utils.utility import initial_logger
|
||||
logger = initial_logger()
|
||||
from ppocr.utils.save_load import init_model
|
||||
from paddleslim.prune import load_model
|
||||
|
||||
|
||||
def main():
|
||||
startup_prog, eval_program, place, config, _ = program.preprocess()
|
||||
|
||||
feeded_var_names, target_vars, fetches_var_name = program.build_export(
|
||||
config, eval_program, startup_prog)
|
||||
eval_program = eval_program.clone(for_test=True)
|
||||
exe = fluid.Executor(place)
|
||||
exe.run(startup_prog)
|
||||
|
||||
if config['Global']['checkpoints'] is not None:
|
||||
path = config['Global']['checkpoints']
|
||||
else:
|
||||
path = config['Global']['pretrain_weights']
|
||||
|
||||
load_model(exe, eval_program, path)
|
||||
|
||||
save_inference_dir = config['Global']['save_inference_dir']
|
||||
if not os.path.exists(save_inference_dir):
|
||||
os.makedirs(save_inference_dir)
|
||||
fluid.io.save_inference_model(
|
||||
dirname=save_inference_dir,
|
||||
feeded_var_names=feeded_var_names,
|
||||
main_program=eval_program,
|
||||
target_vars=target_vars,
|
||||
executor=exe,
|
||||
model_filename='model',
|
||||
params_filename='params')
|
||||
print("inference model saved in {}/model and {}/params".format(
|
||||
save_inference_dir, save_inference_dir))
|
||||
print("save success, output_name_list:", fetches_var_name)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,188 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
||||
|
||||
|
||||
def set_paddle_flags(**kwargs):
|
||||
for key, value in kwargs.items():
|
||||
if os.environ.get(key, None) is None:
|
||||
os.environ[key] = str(value)
|
||||
|
||||
|
||||
# NOTE(paddle-dev): All of these flags should be
|
||||
# set before `import paddle`. Otherwise, it would
|
||||
# not take any effect.
|
||||
set_paddle_flags(
|
||||
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
|
||||
)
|
||||
|
||||
import tools.program as program
|
||||
from paddle import fluid
|
||||
from ppocr.utils.utility import initial_logger
|
||||
logger = initial_logger()
|
||||
from ppocr.data.reader_main import reader_main
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.character import CharacterOps
|
||||
from ppocr.utils.utility import initial_logger
|
||||
from paddleslim.prune import Pruner, save_model
|
||||
from paddleslim.analysis import flops
|
||||
from paddleslim.core.graph_wrapper import *
|
||||
from paddleslim.prune import load_sensitivities, get_ratios_by_loss, merge_sensitive
|
||||
logger = initial_logger()
|
||||
|
||||
skip_list = [
|
||||
'conv10_linear_weights', 'conv11_linear_weights', 'conv12_expand_weights',
|
||||
'conv12_linear_weights', 'conv12_se_2_weights', 'conv13_linear_weights',
|
||||
'conv2_linear_weights', 'conv4_linear_weights', 'conv5_expand_weights',
|
||||
'conv5_linear_weights', 'conv5_se_2_weights', 'conv6_linear_weights',
|
||||
'conv7_linear_weights', 'conv8_expand_weights', 'conv8_linear_weights',
|
||||
'conv9_expand_weights', 'conv9_linear_weights'
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
config = program.load_config(FLAGS.config)
|
||||
program.merge_config(FLAGS.opt)
|
||||
logger.info(config)
|
||||
|
||||
# check if set use_gpu=True in paddlepaddle cpu version
|
||||
use_gpu = config['Global']['use_gpu']
|
||||
program.check_gpu(use_gpu)
|
||||
|
||||
alg = config['Global']['algorithm']
|
||||
assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE']
|
||||
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']:
|
||||
config['Global']['char_ops'] = CharacterOps(config['Global'])
|
||||
|
||||
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
|
||||
startup_program = fluid.Program()
|
||||
train_program = fluid.Program()
|
||||
train_build_outputs = program.build(
|
||||
config, train_program, startup_program, mode='train')
|
||||
train_loader = train_build_outputs[0]
|
||||
train_fetch_name_list = train_build_outputs[1]
|
||||
train_fetch_varname_list = train_build_outputs[2]
|
||||
train_opt_loss_name = train_build_outputs[3]
|
||||
|
||||
eval_program = fluid.Program()
|
||||
eval_build_outputs = program.build(
|
||||
config, eval_program, startup_program, mode='eval')
|
||||
eval_fetch_name_list = eval_build_outputs[1]
|
||||
eval_fetch_varname_list = eval_build_outputs[2]
|
||||
eval_program = eval_program.clone(for_test=True)
|
||||
|
||||
train_reader = reader_main(config=config, mode="train")
|
||||
train_loader.set_sample_list_generator(train_reader, places=place)
|
||||
|
||||
eval_reader = reader_main(config=config, mode="eval")
|
||||
|
||||
exe = fluid.Executor(place)
|
||||
exe.run(startup_program)
|
||||
|
||||
# compile program for multi-devices
|
||||
init_model(config, train_program, exe)
|
||||
|
||||
# params = get_pruned_params(train_program)
|
||||
'''
|
||||
sens_file = ['sensitivities_'+ str(x) for x in range(0,4)]
|
||||
sens = []
|
||||
for f in sens_file:
|
||||
sens.append(load_sensitivities(f+'.data'))
|
||||
sen = merge_sensitive(sens)
|
||||
'''
|
||||
sen = load_sensitivities("sensitivities_0.data")
|
||||
for i in skip_list:
|
||||
sen.pop(i)
|
||||
back_bone_list = ['conv' + str(x) for x in range(1, 5)]
|
||||
for i in back_bone_list:
|
||||
for key in list(sen.keys()):
|
||||
if i + '_' in key:
|
||||
sen.pop(key)
|
||||
ratios = get_ratios_by_loss(sen, 0.03)
|
||||
logger.info("FLOPs before pruning: {}".format(flops(eval_program)))
|
||||
pruner = Pruner(criterion='geometry_median')
|
||||
print("ratios: {}".format(ratios))
|
||||
pruned_val_program, _, _ = pruner.prune(
|
||||
eval_program,
|
||||
fluid.global_scope(),
|
||||
params=ratios.keys(),
|
||||
ratios=ratios.values(),
|
||||
place=place,
|
||||
only_graph=True)
|
||||
|
||||
pruned_program, _, _ = pruner.prune(
|
||||
train_program,
|
||||
fluid.global_scope(),
|
||||
params=ratios.keys(),
|
||||
ratios=ratios.values(),
|
||||
place=place)
|
||||
logger.info("FLOPs after pruning: {}".format(flops(pruned_val_program)))
|
||||
train_compile_program = program.create_multi_devices_program(
|
||||
pruned_program, train_opt_loss_name)
|
||||
|
||||
|
||||
train_info_dict = {'compile_program':train_compile_program,\
|
||||
'train_program':pruned_program,\
|
||||
'reader':train_loader,\
|
||||
'fetch_name_list':train_fetch_name_list,\
|
||||
'fetch_varname_list':train_fetch_varname_list}
|
||||
|
||||
eval_info_dict = {'program':pruned_val_program,\
|
||||
'reader':eval_reader,\
|
||||
'fetch_name_list':eval_fetch_name_list,\
|
||||
'fetch_varname_list':eval_fetch_varname_list}
|
||||
|
||||
if alg in ['EAST', 'DB']:
|
||||
program.train_eval_det_run(
|
||||
config, exe, train_info_dict, eval_info_dict, is_pruning=True)
|
||||
else:
|
||||
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
|
||||
|
||||
|
||||
def test_reader():
|
||||
config = program.load_config(FLAGS.config)
|
||||
program.merge_config(FLAGS.opt)
|
||||
print(config)
|
||||
train_reader = reader_main(config=config, mode="train")
|
||||
import time
|
||||
starttime = time.time()
|
||||
count = 0
|
||||
try:
|
||||
for data in train_reader():
|
||||
count += 1
|
||||
if count % 1 == 0:
|
||||
batch_time = time.time() - starttime
|
||||
starttime = time.time()
|
||||
print("reader:", count, len(data), batch_time)
|
||||
except Exception as e:
|
||||
logger.info(e)
|
||||
logger.info("finish reader: {}, Success!".format(count))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = program.ArgsParser()
|
||||
FLAGS = parser.parse_args()
|
||||
main()
|
||||
# test_reader()
|
|
@ -0,0 +1,121 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
||||
|
||||
|
||||
def set_paddle_flags(**kwargs):
|
||||
for key, value in kwargs.items():
|
||||
if os.environ.get(key, None) is None:
|
||||
os.environ[key] = str(value)
|
||||
|
||||
|
||||
# NOTE(paddle-dev): All of these flags should be
|
||||
# set before `import paddle`. Otherwise, it would
|
||||
# not take any effect.
|
||||
set_paddle_flags(
|
||||
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
|
||||
)
|
||||
|
||||
import json
|
||||
import cv2
|
||||
from paddle import fluid
|
||||
import paddleslim as slim
|
||||
from copy import deepcopy
|
||||
from eval_det_utils import eval_det_run
|
||||
|
||||
from tools import program
|
||||
from ppocr.utils.utility import initial_logger
|
||||
from ppocr.data.reader_main import reader_main
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.character import CharacterOps
|
||||
from ppocr.utils.utility import create_module
|
||||
from ppocr.data.reader_main import reader_main
|
||||
|
||||
logger = initial_logger()
|
||||
|
||||
|
||||
def get_pruned_params(program):
|
||||
params = []
|
||||
for param in program.global_block().all_parameters():
|
||||
if len(
|
||||
param.shape
|
||||
) == 4 and 'depthwise' not in param.name and 'transpose' not in param.name:
|
||||
params.append(param.name)
|
||||
return params
|
||||
|
||||
|
||||
def main():
|
||||
config = program.load_config(FLAGS.config)
|
||||
program.merge_config(FLAGS.opt)
|
||||
logger.info(config)
|
||||
|
||||
# check if set use_gpu=True in paddlepaddle cpu version
|
||||
use_gpu = config['Global']['use_gpu']
|
||||
program.check_gpu(use_gpu)
|
||||
|
||||
alg = config['Global']['algorithm']
|
||||
assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE']
|
||||
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']:
|
||||
config['Global']['char_ops'] = CharacterOps(config['Global'])
|
||||
|
||||
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
|
||||
startup_prog = fluid.Program()
|
||||
eval_program = fluid.Program()
|
||||
eval_build_outputs = program.build(
|
||||
config, eval_program, startup_prog, mode='test')
|
||||
eval_fetch_name_list = eval_build_outputs[1]
|
||||
eval_fetch_varname_list = eval_build_outputs[2]
|
||||
eval_program = eval_program.clone(for_test=True)
|
||||
exe = fluid.Executor(place)
|
||||
exe.run(startup_prog)
|
||||
|
||||
init_model(config, eval_program, exe)
|
||||
|
||||
eval_reader = reader_main(config=config, mode="eval")
|
||||
eval_info_dict = {'program':eval_program,\
|
||||
'reader':eval_reader,\
|
||||
'fetch_name_list':eval_fetch_name_list,\
|
||||
'fetch_varname_list':eval_fetch_varname_list}
|
||||
eval_args = dict()
|
||||
eval_args = {'exe': exe, 'config': config, 'eval_info_dict': eval_info_dict}
|
||||
metrics = eval_det_run(eval_args)
|
||||
print("Baseline: {}".format(metrics))
|
||||
|
||||
params = get_pruned_params(eval_program)
|
||||
print('Start to analyze')
|
||||
sens_0 = slim.prune.sensitivity(
|
||||
eval_program,
|
||||
place,
|
||||
params,
|
||||
eval_det_run,
|
||||
sensitivities_file="sensitivities_0.data",
|
||||
pruned_ratios=[0.1],
|
||||
eval_args=eval_args,
|
||||
criterion='geometry_median')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = program.ArgsParser()
|
||||
FLAGS = parser.parse_args()
|
||||
main()
|
|
@ -33,6 +33,7 @@ from eval_utils.eval_rec_utils import eval_rec_run
|
|||
from ppocr.utils.save_load import save_model
|
||||
import numpy as np
|
||||
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps
|
||||
import paddleslim as slim
|
||||
|
||||
|
||||
class ArgsParser(ArgumentParser):
|
||||
|
@ -238,7 +239,11 @@ def create_multi_devices_program(program, loss_var_name):
|
|||
return compile_program
|
||||
|
||||
|
||||
def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
|
||||
def train_eval_det_run(config,
|
||||
exe,
|
||||
train_info_dict,
|
||||
eval_info_dict,
|
||||
is_pruning=False):
|
||||
train_batch_id = 0
|
||||
log_smooth_window = config['Global']['log_smooth_window']
|
||||
epoch_num = config['Global']['epoch_num']
|
||||
|
@ -294,7 +299,13 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
|
|||
best_batch_id = train_batch_id
|
||||
best_epoch = epoch
|
||||
save_path = save_model_dir + "/best_accuracy"
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
if is_pruning:
|
||||
slim.prune.save_model(
|
||||
exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
else:
|
||||
save_model(train_info_dict['train_program'],
|
||||
save_path)
|
||||
strs = 'Test iter: {}, metrics:{}, best_hmean:{:.6f}, best_epoch:{}, best_batch_id:{}'.format(
|
||||
train_batch_id, metrics, best_eval_hmean, best_epoch,
|
||||
best_batch_id)
|
||||
|
@ -305,10 +316,18 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
|
|||
train_loader.reset()
|
||||
if epoch == 0 and save_epoch_step == 1:
|
||||
save_path = save_model_dir + "/iter_epoch_0"
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
if is_pruning:
|
||||
slim.prune.save_model(exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
else:
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
if epoch > 0 and epoch % save_epoch_step == 0:
|
||||
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
if is_pruning:
|
||||
slim.prune.save_model(exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
else:
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
return
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue