Merge pull request #810 from yukavio/develop

update bash of slim pruning
This commit is contained in:
Double_V 2020-09-23 20:20:21 +08:00 committed by GitHub
commit c320457d73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 8 additions and 7 deletions

View File

@ -51,14 +51,14 @@ python setup.py install
进入PaddleOCR根目录通过以下命令对模型进行敏感度分析训练 进入PaddleOCR根目录通过以下命令对模型进行敏感度分析训练
```bash ```bash
python deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights="your trained model" Global.test_batch_size_per_card=1 python deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db_v1.1.yml -o Global.pretrain_weights="your trained model" Global.test_batch_size_per_card=1
``` ```
### 4. 模型裁剪训练 ### 4. 模型裁剪训练
裁剪时通过之前的敏感度分析文件决定每个网络层的裁剪比例。在具体实现时为了尽可能多的保留从图像中提取的低阶特征我们跳过了backbone中靠近输入的4个卷积层。同样为了减少由于裁剪导致的模型性能损失我们通过之前敏感度分析所获得的敏感度表人工挑选出了一些冗余较少对裁剪较为敏感的[网络层](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/deploy/slim/prune/pruning_and_finetune.py#L41)指在较低的裁剪比例下就导致很高性能损失的网络层并在之后的裁剪过程中选择避开这些网络层。裁剪过后finetune的过程沿用OCR检测模型原始的训练策略。 裁剪时通过之前的敏感度分析文件决定每个网络层的裁剪比例。在具体实现时为了尽可能多的保留从图像中提取的低阶特征我们跳过了backbone中靠近输入的4个卷积层。同样为了减少由于裁剪导致的模型性能损失我们通过之前敏感度分析所获得的敏感度表人工挑选出了一些冗余较少对裁剪较为敏感的[网络层](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/deploy/slim/prune/pruning_and_finetune.py#L41)指在较低的裁剪比例下就导致很高性能损失的网络层并在之后的裁剪过程中选择避开这些网络层。裁剪过后finetune的过程沿用OCR检测模型原始的训练策略。
```bash ```bash
python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1 python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db_v1.1.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1
``` ```
通过对比可以发现,经过裁剪训练保存的模型更小。 通过对比可以发现,经过裁剪训练保存的模型更小。
@ -66,7 +66,7 @@ python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db.yml -
在得到裁剪训练保存的模型后我们可以将其导出为inference_model 在得到裁剪训练保存的模型后我们可以将其导出为inference_model
```bash ```bash
python deploy/slim/prune/export_prune_model.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.test_batch_size_per_card=1 Global.save_inference_dir=inference_model python deploy/slim/prune/export_prune_model.py -c configs/det/det_mv3_db_v1.1.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.test_batch_size_per_card=1 Global.save_inference_dir=inference_model
``` ```
inference model的预测和部署参考 inference model的预测和部署参考

View File

@ -55,7 +55,7 @@ Enter the PaddleOCR root directoryperform sensitivity analysis on the model w
```bash ```bash
python deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1 python deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db_v1.1.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1
``` ```
@ -67,7 +67,7 @@ python deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db.yml -o Gl
```bash ```bash
python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1 python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db_v1.1.yml -o Global.pretrain_weights=./deploy/slim/prune/pretrain_models/det_mv3_db/best_accuracy Global.test_batch_size_per_card=1
``` ```
@ -76,7 +76,7 @@ python deploy/slim/prune/pruning_and_finetune.py -c configs/det/det_mv3_db.yml -
We can export the pruned model as inference_model for deployment: We can export the pruned model as inference_model for deployment:
```bash ```bash
python deploy/slim/prune/export_prune_model.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.test_batch_size_per_card=1 Global.save_inference_dir=inference_model python deploy/slim/prune/export_prune_model.py -c configs/det/det_mv3_db_v1.1.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.test_batch_size_per_card=1 Global.save_inference_dir=inference_model
``` ```
Reference for prediction and deployment of inference model: Reference for prediction and deployment of inference model:

View File

@ -92,7 +92,8 @@ def main():
sen = load_sensitivities("sensitivities_0.data") sen = load_sensitivities("sensitivities_0.data")
for i in skip_list: for i in skip_list:
sen.pop(i) if i in sen.keys():
sen.pop(i)
back_bone_list = ['conv' + str(x) for x in range(1, 5)] back_bone_list = ['conv' + str(x) for x in range(1, 5)]
for i in back_bone_list: for i in back_bone_list:
for key in list(sen.keys()): for key in list(sen.keys()):