commit
42f0219cda
|
@ -0,0 +1,40 @@
|
|||
> 运行示例前请先安装develop版本PaddleSlim
|
||||
|
||||
# 模型裁剪压缩教程
|
||||
|
||||
## 概述
|
||||
|
||||
该示例使用PaddleSlim提供的[裁剪压缩API](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/)对OCR模型进行压缩。
|
||||
在阅读该示例前,建议您先了解以下内容:
|
||||
|
||||
- [OCR模型的常规训练方法](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/detection.md)
|
||||
- [PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/)
|
||||
|
||||
## 安装PaddleSlim
|
||||
可按照[PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/)中的步骤安装PaddleSlim。
|
||||
|
||||
|
||||
|
||||
## 敏感度分析训练
|
||||
|
||||
进入PaddleOCR根目录,通过以下命令对模型进行敏感度分析:
|
||||
|
||||
```bash
|
||||
python deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1
|
||||
```
|
||||
|
||||
## 裁剪模型与fine-tune
|
||||
|
||||
```bash
|
||||
python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 评估并导出
|
||||
|
||||
在得到裁剪训练保存的模型后,我们可以将其导出为inference_model,用于预测部署:
|
||||
|
||||
```bash
|
||||
python deploy/slim/prune/export_prune_model.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.test_batch_size_per_card=1 Global.save_inference_dir=inference_model
|
||||
```
|
|
@ -0,0 +1,67 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
||||
|
||||
import program
|
||||
from paddle import fluid
|
||||
from ppocr.utils.utility import initial_logger
|
||||
logger = initial_logger()
|
||||
from ppocr.utils.save_load import init_model
|
||||
from paddleslim.prune import load_model
|
||||
|
||||
|
||||
def main():
|
||||
startup_prog, eval_program, place, config, _ = program.preprocess()
|
||||
|
||||
feeded_var_names, target_vars, fetches_var_name = program.build_export(
|
||||
config, eval_program, startup_prog)
|
||||
eval_program = eval_program.clone(for_test=True)
|
||||
exe = fluid.Executor(place)
|
||||
exe.run(startup_prog)
|
||||
|
||||
if config['Global']['checkpoints'] is not None:
|
||||
path = config['Global']['checkpoints']
|
||||
else:
|
||||
path = config['Global']['pretrain_weights']
|
||||
|
||||
load_model(exe, eval_program, path)
|
||||
|
||||
save_inference_dir = config['Global']['save_inference_dir']
|
||||
if not os.path.exists(save_inference_dir):
|
||||
os.makedirs(save_inference_dir)
|
||||
fluid.io.save_inference_model(
|
||||
dirname=save_inference_dir,
|
||||
feeded_var_names=feeded_var_names,
|
||||
main_program=eval_program,
|
||||
target_vars=target_vars,
|
||||
executor=exe,
|
||||
model_filename='model',
|
||||
params_filename='params')
|
||||
print("inference model saved in {}/model and {}/params".format(
|
||||
save_inference_dir, save_inference_dir))
|
||||
print("save success, output_name_list:", fetches_var_name)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,145 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
||||
|
||||
import tools.program as program
|
||||
from paddle import fluid
|
||||
from ppocr.utils.utility import initial_logger
|
||||
logger = initial_logger()
|
||||
from ppocr.data.reader_main import reader_main
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.character import CharacterOps
|
||||
from ppocr.utils.utility import initial_logger
|
||||
from paddleslim.prune import Pruner, save_model
|
||||
from paddleslim.analysis import flops
|
||||
from paddleslim.core.graph_wrapper import *
|
||||
from paddleslim.prune import load_sensitivities, get_ratios_by_loss, merge_sensitive
|
||||
logger = initial_logger()
|
||||
|
||||
skip_list = [
|
||||
'conv10_linear_weights', 'conv11_linear_weights', 'conv12_expand_weights',
|
||||
'conv12_linear_weights', 'conv12_se_2_weights', 'conv13_linear_weights',
|
||||
'conv2_linear_weights', 'conv4_linear_weights', 'conv5_expand_weights',
|
||||
'conv5_linear_weights', 'conv5_se_2_weights', 'conv6_linear_weights',
|
||||
'conv7_linear_weights', 'conv8_expand_weights', 'conv8_linear_weights',
|
||||
'conv9_expand_weights', 'conv9_linear_weights'
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
config = program.load_config(FLAGS.config)
|
||||
program.merge_config(FLAGS.opt)
|
||||
logger.info(config)
|
||||
|
||||
# check if set use_gpu=True in paddlepaddle cpu version
|
||||
use_gpu = config['Global']['use_gpu']
|
||||
program.check_gpu(use_gpu)
|
||||
|
||||
alg = config['Global']['algorithm']
|
||||
assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE']
|
||||
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']:
|
||||
config['Global']['char_ops'] = CharacterOps(config['Global'])
|
||||
|
||||
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
|
||||
startup_program = fluid.Program()
|
||||
train_program = fluid.Program()
|
||||
train_build_outputs = program.build(
|
||||
config, train_program, startup_program, mode='train')
|
||||
train_loader = train_build_outputs[0]
|
||||
train_fetch_name_list = train_build_outputs[1]
|
||||
train_fetch_varname_list = train_build_outputs[2]
|
||||
train_opt_loss_name = train_build_outputs[3]
|
||||
|
||||
eval_program = fluid.Program()
|
||||
eval_build_outputs = program.build(
|
||||
config, eval_program, startup_program, mode='eval')
|
||||
eval_fetch_name_list = eval_build_outputs[1]
|
||||
eval_fetch_varname_list = eval_build_outputs[2]
|
||||
eval_program = eval_program.clone(for_test=True)
|
||||
|
||||
train_reader = reader_main(config=config, mode="train")
|
||||
train_loader.set_sample_list_generator(train_reader, places=place)
|
||||
|
||||
eval_reader = reader_main(config=config, mode="eval")
|
||||
|
||||
exe = fluid.Executor(place)
|
||||
exe.run(startup_program)
|
||||
|
||||
# compile program for multi-devices
|
||||
init_model(config, train_program, exe)
|
||||
|
||||
sen = load_sensitivities("sensitivities_0.data")
|
||||
for i in skip_list:
|
||||
sen.pop(i)
|
||||
back_bone_list = ['conv' + str(x) for x in range(1, 5)]
|
||||
for i in back_bone_list:
|
||||
for key in list(sen.keys()):
|
||||
if i + '_' in key:
|
||||
sen.pop(key)
|
||||
ratios = get_ratios_by_loss(sen, 0.03)
|
||||
logger.info("FLOPs before pruning: {}".format(flops(eval_program)))
|
||||
pruner = Pruner(criterion='geometry_median')
|
||||
print("ratios: {}".format(ratios))
|
||||
pruned_val_program, _, _ = pruner.prune(
|
||||
eval_program,
|
||||
fluid.global_scope(),
|
||||
params=ratios.keys(),
|
||||
ratios=ratios.values(),
|
||||
place=place,
|
||||
only_graph=True)
|
||||
|
||||
pruned_program, _, _ = pruner.prune(
|
||||
train_program,
|
||||
fluid.global_scope(),
|
||||
params=ratios.keys(),
|
||||
ratios=ratios.values(),
|
||||
place=place)
|
||||
logger.info("FLOPs after pruning: {}".format(flops(pruned_val_program)))
|
||||
train_compile_program = program.create_multi_devices_program(
|
||||
pruned_program, train_opt_loss_name)
|
||||
|
||||
|
||||
train_info_dict = {'compile_program':train_compile_program,\
|
||||
'train_program':pruned_program,\
|
||||
'reader':train_loader,\
|
||||
'fetch_name_list':train_fetch_name_list,\
|
||||
'fetch_varname_list':train_fetch_varname_list}
|
||||
|
||||
eval_info_dict = {'program':pruned_val_program,\
|
||||
'reader':eval_reader,\
|
||||
'fetch_name_list':eval_fetch_name_list,\
|
||||
'fetch_varname_list':eval_fetch_varname_list}
|
||||
|
||||
if alg in ['EAST', 'DB']:
|
||||
program.train_eval_det_run(
|
||||
config, exe, train_info_dict, eval_info_dict, is_pruning=True)
|
||||
else:
|
||||
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = program.ArgsParser()
|
||||
FLAGS = parser.parse_args()
|
||||
main()
|
|
@ -0,0 +1,115 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
||||
|
||||
import json
|
||||
import cv2
|
||||
from paddle import fluid
|
||||
import paddleslim as slim
|
||||
from copy import deepcopy
|
||||
from tools.eval_utils.eval_det_utils import eval_det_run
|
||||
|
||||
from tools import program
|
||||
from ppocr.utils.utility import initial_logger
|
||||
from ppocr.data.reader_main import reader_main
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.character import CharacterOps
|
||||
from ppocr.utils.utility import create_module
|
||||
from ppocr.data.reader_main import reader_main
|
||||
|
||||
logger = initial_logger()
|
||||
|
||||
|
||||
def get_pruned_params(program):
|
||||
params = []
|
||||
for param in program.global_block().all_parameters():
|
||||
if len(
|
||||
param.shape
|
||||
) == 4 and 'depthwise' not in param.name and 'transpose' not in param.name:
|
||||
params.append(param.name)
|
||||
return params
|
||||
|
||||
|
||||
def eval_function(eval_args, mode='eval'):
|
||||
exe = eval_args['exe']
|
||||
config = eval_args['config']
|
||||
eval_info_dict = eval_args['eval_info_dict']
|
||||
metrics = eval_det_run(exe, config, eval_info_dict, mode=mode)
|
||||
return metrics['hmean']
|
||||
|
||||
|
||||
def main():
|
||||
config = program.load_config(FLAGS.config)
|
||||
program.merge_config(FLAGS.opt)
|
||||
logger.info(config)
|
||||
|
||||
# check if set use_gpu=True in paddlepaddle cpu version
|
||||
use_gpu = config['Global']['use_gpu']
|
||||
program.check_gpu(use_gpu)
|
||||
|
||||
alg = config['Global']['algorithm']
|
||||
assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE']
|
||||
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']:
|
||||
config['Global']['char_ops'] = CharacterOps(config['Global'])
|
||||
|
||||
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
|
||||
startup_prog = fluid.Program()
|
||||
eval_program = fluid.Program()
|
||||
eval_build_outputs = program.build(
|
||||
config, eval_program, startup_prog, mode='test')
|
||||
eval_fetch_name_list = eval_build_outputs[1]
|
||||
eval_fetch_varname_list = eval_build_outputs[2]
|
||||
eval_program = eval_program.clone(for_test=True)
|
||||
exe = fluid.Executor(place)
|
||||
exe.run(startup_prog)
|
||||
|
||||
init_model(config, eval_program, exe)
|
||||
|
||||
eval_reader = reader_main(config=config, mode="eval")
|
||||
eval_info_dict = {'program':eval_program,\
|
||||
'reader':eval_reader,\
|
||||
'fetch_name_list':eval_fetch_name_list,\
|
||||
'fetch_varname_list':eval_fetch_varname_list}
|
||||
eval_args = dict()
|
||||
eval_args = {'exe': exe, 'config': config, 'eval_info_dict': eval_info_dict}
|
||||
metrics = eval_function(eval_args)
|
||||
print("Baseline: {}".format(metrics))
|
||||
|
||||
params = get_pruned_params(eval_program)
|
||||
print('Start to analyze')
|
||||
sens_0 = slim.prune.sensitivity(
|
||||
eval_program,
|
||||
place,
|
||||
params,
|
||||
eval_function,
|
||||
sensitivities_file="sensitivities_0.data",
|
||||
pruned_ratios=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8],
|
||||
eval_args=eval_args,
|
||||
criterion='geometry_median')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = program.ArgsParser()
|
||||
FLAGS = parser.parse_args()
|
||||
main()
|
|
@ -241,7 +241,11 @@ def create_multi_devices_program(program, loss_var_name, for_quant=False):
|
|||
return compile_program
|
||||
|
||||
|
||||
def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
|
||||
def train_eval_det_run(config,
|
||||
exe,
|
||||
train_info_dict,
|
||||
eval_info_dict,
|
||||
is_pruning=False):
|
||||
train_batch_id = 0
|
||||
log_smooth_window = config['Global']['log_smooth_window']
|
||||
epoch_num = config['Global']['epoch_num']
|
||||
|
@ -297,7 +301,14 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
|
|||
best_batch_id = train_batch_id
|
||||
best_epoch = epoch
|
||||
save_path = save_model_dir + "/best_accuracy"
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
if is_pruning:
|
||||
import paddleslim as slim
|
||||
slim.prune.save_model(
|
||||
exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
else:
|
||||
save_model(train_info_dict['train_program'],
|
||||
save_path)
|
||||
strs = 'Test iter: {}, metrics:{}, best_hmean:{:.6f}, best_epoch:{}, best_batch_id:{}'.format(
|
||||
train_batch_id, metrics, best_eval_hmean, best_epoch,
|
||||
best_batch_id)
|
||||
|
@ -308,10 +319,20 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
|
|||
train_loader.reset()
|
||||
if epoch == 0 and save_epoch_step == 1:
|
||||
save_path = save_model_dir + "/iter_epoch_0"
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
if is_pruning:
|
||||
import paddleslim as slim
|
||||
slim.prune.save_model(exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
else:
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
if epoch > 0 and epoch % save_epoch_step == 0:
|
||||
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
if is_pruning:
|
||||
import paddleslim as slim
|
||||
slim.prune.save_model(exe, train_info_dict['train_program'],
|
||||
save_path)
|
||||
else:
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
return
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue