fix distillation arch and model init
This commit is contained in:
parent
9d1e5d0912
commit
e5d3a2d880
|
@ -4,11 +4,9 @@ Global:
|
|||
epoch_num: 800
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec_D081
|
||||
save_model_dir: ./output/rec_chinese_lite_distillation_v2.1
|
||||
save_epoch_step: 3
|
||||
eval_batch_step:
|
||||
- 0
|
||||
- 2000
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: true
|
||||
pretrained_model: null
|
||||
checkpoints: null
|
||||
|
@ -37,12 +35,10 @@ Optimizer:
|
|||
Architecture:
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
freeze_params:
|
||||
- false
|
||||
- false
|
||||
pretrained: null
|
||||
Models:
|
||||
Student:
|
||||
pretrained: null
|
||||
freeze_params: false
|
||||
model_type: rec
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
|
@ -59,6 +55,8 @@ Architecture:
|
|||
name: CTCHead
|
||||
fc_decay: 0.00001
|
||||
Teacher:
|
||||
pretrained: null
|
||||
freeze_params: false
|
||||
model_type: rec
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
|
@ -85,16 +83,20 @@ Loss:
|
|||
key: null
|
||||
- DistillationDMLLoss:
|
||||
weight: 1.0
|
||||
model_name_list1: ["Student"]
|
||||
model_name_list2: ["Teacher"]
|
||||
act: "softmax"
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
key: null
|
||||
|
||||
PostProcess:
|
||||
name: DistillationCTCLabelDecode
|
||||
model_name: "Student"
|
||||
key_out: null
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
|
@ -108,10 +110,7 @@ Train:
|
|||
- RecAug: null
|
||||
- CTCLabelEncode: null
|
||||
- RecResizeImg:
|
||||
image_shape:
|
||||
- 3
|
||||
- 32
|
||||
- 320
|
||||
image_shape: [3, 32, 320]
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
|
@ -135,10 +134,7 @@ Eval:
|
|||
channel_first: false
|
||||
- CTCLabelEncode: null
|
||||
- RecResizeImg:
|
||||
image_shape:
|
||||
- 3
|
||||
- 32
|
||||
- 320
|
||||
image_shape: [3, 32, 320]
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
|
|
|
@ -62,19 +62,29 @@ class DMLLoss(nn.Layer):
|
|||
DMLLoss
|
||||
"""
|
||||
|
||||
def __init__(self, name="loss_dml"):
|
||||
def __init__(self, act=None, name="loss_dml"):
|
||||
super().__init__()
|
||||
if act is not None:
|
||||
assert act in ["softmax", "sigmoid"]
|
||||
self.name = name
|
||||
if act == "softmax":
|
||||
self.act = nn.Softmax(axis=-1)
|
||||
elif act == "sigmoid":
|
||||
self.act = nn.Sigmoid()
|
||||
else:
|
||||
self.act = None
|
||||
|
||||
def forward(self, out1, out2):
|
||||
loss_dict = {}
|
||||
soft_out1 = F.softmax(out1, axis=-1)
|
||||
log_soft_out1 = paddle.log(soft_out1)
|
||||
soft_out2 = F.softmax(out2, axis=-1)
|
||||
log_soft_out2 = paddle.log(soft_out2)
|
||||
if self.act is not None:
|
||||
out1 = self.act(out1)
|
||||
out2 = self.act(out2)
|
||||
|
||||
log_out1 = paddle.log(out1)
|
||||
log_out2 = paddle.log(out2)
|
||||
loss = (F.kl_div(
|
||||
log_soft_out1, soft_out2, reduction='batchmean') + F.kl_div(
|
||||
log_soft_out2, soft_out1, reduction='batchmean')) / 2.0
|
||||
log_out1, out2, reduction='batchmean') + F.kl_div(
|
||||
log_out2, log_out1, reduction='batchmean')) / 2.0
|
||||
loss_dict[self.name] = loss
|
||||
return loss_dict
|
||||
|
||||
|
@ -90,7 +100,7 @@ class DistanceLoss(nn.Layer):
|
|||
assert mode in ["l1", "l2", "smooth_l1"]
|
||||
if mode == "l1":
|
||||
self.loss_func = nn.L1Loss(**kargs)
|
||||
elif mode == "l1":
|
||||
elif mode == "l2":
|
||||
self.loss_func = nn.MSELoss(**kargs)
|
||||
elif mode == "smooth_l1":
|
||||
self.loss_func = nn.SmoothL1Loss(**kargs)
|
||||
|
|
|
@ -23,35 +23,28 @@ class DistillationDMLLoss(DMLLoss):
|
|||
"""
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_name_list1=[],
|
||||
model_name_list2=[],
|
||||
key=None,
|
||||
def __init__(self, model_name_pairs=[], act=None, key=None,
|
||||
name="loss_dml"):
|
||||
super().__init__(name=name)
|
||||
if not isinstance(model_name_list1, (list, )):
|
||||
model_name_list1 = [model_name_list1]
|
||||
if not isinstance(model_name_list2, (list, )):
|
||||
model_name_list2 = [model_name_list2]
|
||||
|
||||
assert len(model_name_list1) == len(model_name_list2)
|
||||
self.model_name_list1 = model_name_list1
|
||||
self.model_name_list2 = model_name_list2
|
||||
super().__init__(act=act, name=name)
|
||||
assert isinstance(model_name_pairs, list)
|
||||
self.key = key
|
||||
self.model_name_pairs = model_name_pairs
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
for idx in range(len(self.model_name_list1)):
|
||||
out1 = predicts[self.model_name_list1[idx]]
|
||||
out2 = predicts[self.model_name_list2[idx]]
|
||||
for idx, pair in enumerate(self.model_name_pairs):
|
||||
out1 = predicts[pair[0]]
|
||||
out2 = predicts[pair[1]]
|
||||
if self.key is not None:
|
||||
out1 = out1[self.key]
|
||||
out2 = out2[self.key]
|
||||
loss = super().forward(out1, out2)
|
||||
if isinstance(loss, dict):
|
||||
assert len(loss) == 1
|
||||
loss = list(loss.values())[0]
|
||||
loss_dict["{}_{}".format(self.name, idx)] = loss
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
|
||||
key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, idx)] = loss
|
||||
return loss_dict
|
||||
|
||||
|
||||
|
@ -64,13 +57,15 @@ class DistillationCTCLoss(CTCLoss):
|
|||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
for model_name in self.model_name_list:
|
||||
for idx, model_name in enumerate(self.model_name_list):
|
||||
out = predicts[model_name]
|
||||
if self.key is not None:
|
||||
out = out[self.key]
|
||||
loss = super().forward(out, batch)
|
||||
if isinstance(loss, dict):
|
||||
assert len(loss) == 1
|
||||
loss = list(loss.values())[0]
|
||||
loss_dict["{}_{}".format(self.name, model_name)] = loss
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}".format(self.name, model_name,
|
||||
idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, model_name)] = loss
|
||||
return loss_dict
|
||||
|
|
|
@ -34,25 +34,20 @@ class DistillationModel(nn.Layer):
|
|||
config (dict): the super parameters for module.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
freeze_params = config["freeze_params"]
|
||||
pretrained = config["pretrained"]
|
||||
if not isinstance(freeze_params, list):
|
||||
freeze_params = [freeze_params]
|
||||
assert len(config["Models"]) == len(freeze_params)
|
||||
|
||||
if not isinstance(pretrained, list):
|
||||
pretrained = [pretrained] * len(config["Models"])
|
||||
assert len(config["Models"]) == len(pretrained)
|
||||
|
||||
self.model_dict = dict()
|
||||
index = 0
|
||||
for key in config["Models"]:
|
||||
model_config = config["Models"][key]
|
||||
freeze_params = False
|
||||
pretrained = None
|
||||
if "freeze_params" in model_config:
|
||||
freeze_params = model_config.pop("freeze_params")
|
||||
if "pretrained" in model_config:
|
||||
pretrained = model_config.pop("pretrained")
|
||||
model = BaseModel(model_config)
|
||||
if pretrained[index] is not None:
|
||||
if pretrained is not None:
|
||||
load_dygraph_pretrain(model, path=pretrained[index])
|
||||
if freeze_params[index]:
|
||||
if freeze_params:
|
||||
for param in model.parameters():
|
||||
param.trainable = False
|
||||
self.model_dict[key] = self.add_sublayer(key, model)
|
||||
|
|
|
@ -42,38 +42,10 @@ def _mkdir_if_not_exist(path, logger):
|
|||
raise OSError('Failed to mkdir {}'.format(path))
|
||||
|
||||
|
||||
def load_dygraph_pretrain(model,
|
||||
logger=None,
|
||||
path=None,
|
||||
load_static_weights=False):
|
||||
def load_dygraph_pretrain(model, logger=None, path=None):
|
||||
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
|
||||
raise ValueError("Model pretrain path {} does not "
|
||||
"exists.".format(path))
|
||||
if load_static_weights:
|
||||
pre_state_dict = paddle.static.load_program_state(path)
|
||||
param_state_dict = {}
|
||||
model_dict = model.state_dict()
|
||||
for key in model_dict.keys():
|
||||
weight_name = model_dict[key].name
|
||||
weight_name = weight_name.replace('binarize', '').replace(
|
||||
'thresh', '') # for DB
|
||||
if weight_name in pre_state_dict.keys():
|
||||
# logger.info('Load weight: {}, shape: {}'.format(
|
||||
# weight_name, pre_state_dict[weight_name].shape))
|
||||
if 'encoder_rnn' in key:
|
||||
# delete axis which is 1
|
||||
pre_state_dict[weight_name] = pre_state_dict[
|
||||
weight_name].squeeze()
|
||||
# change axis
|
||||
if len(pre_state_dict[weight_name].shape) > 1:
|
||||
pre_state_dict[weight_name] = pre_state_dict[
|
||||
weight_name].transpose((1, 0))
|
||||
param_state_dict[key] = pre_state_dict[weight_name]
|
||||
else:
|
||||
param_state_dict[key] = model_dict[key]
|
||||
model.set_state_dict(param_state_dict)
|
||||
return
|
||||
|
||||
param_state_dict = paddle.load(path + '.pdparams')
|
||||
model.set_state_dict(param_state_dict)
|
||||
return
|
||||
|
@ -108,15 +80,10 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
|
|||
|
||||
logger.info("resume from {}".format(checkpoints))
|
||||
elif pretrained_model:
|
||||
load_static_weights = global_config.get('load_static_weights', False)
|
||||
if not isinstance(pretrained_model, list):
|
||||
pretrained_model = [pretrained_model]
|
||||
if not isinstance(load_static_weights, list):
|
||||
load_static_weights = [load_static_weights] * len(pretrained_model)
|
||||
for idx, pretrained in enumerate(pretrained_model):
|
||||
load_static = load_static_weights[idx]
|
||||
load_dygraph_pretrain(
|
||||
model, logger, path=pretrained, load_static_weights=load_static)
|
||||
for pretrained in pretrained_model:
|
||||
load_dygraph_pretrain(model, logger, path=pretrained)
|
||||
logger.info("load pretrained model from {}".format(
|
||||
pretrained_model))
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
# Python端预测部署
|
||||
|
||||
Python预测可以使用`tools/infer.py`,此种方式依赖PaddleDetection源码;也可以使用本篇教程预测方式,先将模型导出,使用一个独立的文件进行预测。
|
||||
|
||||
|
||||
本篇教程使用AnalysisPredictor对[导出模型](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/deploy/EXPORT_MODEL.md)进行高性能预测。
|
||||
|
||||
在PaddlePaddle中预测引擎和训练引擎底层有着不同的优化方法, 预测引擎使用了AnalysisPredictor,专门针对推理进行了优化,是基于[C++预测库](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_guide/inference_deployment/inference/native_infer.html)的Python接口,该引擎可以对模型进行多项图优化,减少不必要的内存拷贝。如果用户在部署已训练模型的过程中对性能有较高的要求,我们提供了独立于PaddleDetection的预测脚本,方便用户直接集成部署。
|
||||
|
||||
|
||||
主要包含两个步骤:
|
||||
|
||||
- 导出预测模型
|
||||
- 基于Python的预测
|
||||
|
||||
## 1. 导出预测模型
|
||||
|
||||
PaddleDetection在训练过程包括网络的前向和优化器相关参数,而在部署过程中,我们只需要前向参数,具体参考:[导出模型](https://github.com/PaddlePaddle/PaddleDetection/blob/develop/deploy/EXPORT_MODEL.md)
|
||||
|
||||
导出后目录下,包括`infer_cfg.yml`, `model.pdiparams`, `model.pdiparams.info`, `model.pdmodel`四个文件。
|
||||
|
||||
## 2. 基于python的预测
|
||||
|
||||
### 2.1 安装依赖
|
||||
- `PaddlePaddle`的安装:
|
||||
请点击[官方安装文档](https://paddlepaddle.org.cn/install/quick) 选择适合的方式,版本为2.0rc1以上即可
|
||||
- 切换到`PaddleDetection`代码库根目录,执行`pip install -r requirements.txt`安装其它依赖
|
||||
|
||||
### 2.2 执行预测程序
|
||||
在终端输入以下命令进行预测:
|
||||
|
||||
```bash
|
||||
python deploy/python/infer.py --model_dir=/path/to/models --image_file=/path/to/image
|
||||
--use_gpu=(False/True)
|
||||
```
|
||||
|
||||
参数说明如下:
|
||||
|
||||
| 参数 | 是否必须|含义 |
|
||||
|-------|-------|----------|
|
||||
| --model_dir | Yes|上述导出的模型路径 |
|
||||
| --image_file | Option |需要预测的图片 |
|
||||
| --video_file | Option |需要预测的视频 |
|
||||
| --camera_id | Option | 用来预测的摄像头ID,默认为-1(表示不使用摄像头预测,可设置为:0 - (摄像头数目-1) ),预测过程中在可视化界面按`q`退出输出预测结果到:output/output.mp4|
|
||||
| --use_gpu |No|是否GPU,默认为False|
|
||||
| --run_mode |No|使用GPU时,默认为fluid, 可选(fluid/trt_fp32/trt_fp16/trt_int8)|
|
||||
| --threshold |No|预测得分的阈值,默认为0.5|
|
||||
| --output_dir |No|可视化结果保存的根目录,默认为output/|
|
||||
| --run_benchmark |No|是否运行benchmark,同时需指定--image_file|
|
||||
|
||||
说明:
|
||||
|
||||
- run_mode:fluid代表使用AnalysisPredictor,精度float32来推理,其他参数指用AnalysisPredictor,TensorRT不同精度来推理。
|
||||
- PaddlePaddle默认的GPU安装包(<=1.7),不支持基于TensorRT进行预测,如果想基于TensorRT加速预测,需要自行编译,详细可参考[预测库编译教程](https://www.paddlepaddle.org.cn/documentation/docs/zh/advanced_usage/deploy/inference/paddle_tensorrt_infer.html)。
|
||||
|
||||
## 3. 部署性能对比测试
|
||||
对比AnalysisPredictor相对Executor的推理速度
|
||||
|
||||
### 3.1 测试环境:
|
||||
|
||||
- CUDA 9.0
|
||||
- CUDNN 7.5
|
||||
- PaddlePaddle 1.71
|
||||
- GPU: Tesla P40
|
||||
|
||||
### 3.2 测试方式:
|
||||
|
||||
- Batch Size=1
|
||||
- 去掉前100轮warmup时间,测试100轮的平均时间,单位ms/image,只计算模型运行时间,不包括数据的处理和拷贝。
|
||||
|
||||
|
||||
### 3.3 测试结果
|
||||
|
||||
|模型 | AnalysisPredictor | Executor | 输入|
|
||||
|---|----|---|---|
|
||||
| YOLOv3-MobileNetv1 | 15.20 | 19.54 | 608*608
|
||||
| faster_rcnn_r50_fpn_1x | 50.05 | 69.58 |800*1088
|
||||
| faster_rcnn_r50_1x | 326.11 | 347.22 | 800*1067
|
||||
| mask_rcnn_r50_fpn_1x | 67.49 | 91.02 | 800*1088
|
||||
| mask_rcnn_r50_1x | 326.11 | 350.94 | 800*1067
|
Loading…
Reference in New Issue