fix rec distillation (#3994)
* fix rec distillation * add dist cfg * fix yaml
This commit is contained in:
parent
51f4a2c375
commit
7dc56191f5
|
@ -4,7 +4,7 @@ Global:
|
|||
epoch_num: 800
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec_chinese_lite_distillation_v2.1
|
||||
save_model_dir: ./output/rec_mobile_pp-OCRv2
|
||||
save_epoch_step: 3
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: true
|
||||
|
@ -19,7 +19,7 @@ Global:
|
|||
infer_mode: false
|
||||
use_space_char: true
|
||||
distributed: true
|
||||
save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt
|
||||
save_res_path: ./output/rec/predicts_mobile_pp-OCRv2.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
|
@ -35,79 +35,32 @@ Optimizer:
|
|||
name: L2
|
||||
factor: 2.0e-05
|
||||
|
||||
Architecture:
|
||||
model_type: &model_type "rec"
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Teacher:
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
Student:
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationCTCLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Teacher"]
|
||||
key: head_out
|
||||
- DistillationDMLLoss:
|
||||
weight: 1.0
|
||||
act: "softmax"
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
key: head_out
|
||||
- DistillationDistanceLoss:
|
||||
weight: 1.0
|
||||
mode: "l2"
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
key: backbone_out
|
||||
name: CTCLoss
|
||||
|
||||
PostProcess:
|
||||
name: DistillationCTCLabelDecode
|
||||
model_name: ["Student", "Teacher"]
|
||||
key: head_out
|
||||
name: CTCLabelDecode
|
||||
|
||||
Metric:
|
||||
name: DistillationMetric
|
||||
base_metric_name: RecMetric
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
key: "Student"
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
|
@ -132,7 +85,6 @@ Train:
|
|||
shuffle: true
|
||||
batch_size_per_card: 128
|
||||
drop_last: true
|
||||
num_sections: 1
|
||||
num_workers: 8
|
||||
Eval:
|
||||
dataset:
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
Global:
|
||||
debug: false
|
||||
use_gpu: true
|
||||
epoch_num: 800
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec_pp-OCRv2_distillation
|
||||
save_epoch_step: 3
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: true
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: false
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
|
||||
character_type: ch
|
||||
max_text_length: 25
|
||||
infer_mode: false
|
||||
use_space_char: true
|
||||
distributed: true
|
||||
save_res_path: ./output/rec/predicts_pp-OCRv2_distillation.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Piecewise
|
||||
decay_epochs : [700, 800]
|
||||
values : [0.001, 0.0001]
|
||||
warmup_epoch: 5
|
||||
regularizer:
|
||||
name: L2
|
||||
factor: 2.0e-05
|
||||
|
||||
Architecture:
|
||||
model_type: &model_type "rec"
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Teacher:
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
Student:
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationCTCLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Teacher"]
|
||||
key: head_out
|
||||
- DistillationDMLLoss:
|
||||
weight: 1.0
|
||||
act: "softmax"
|
||||
use_log: true
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
key: head_out
|
||||
- DistillationDistanceLoss:
|
||||
weight: 1.0
|
||||
mode: "l2"
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
key: backbone_out
|
||||
|
||||
PostProcess:
|
||||
name: DistillationCTCLabelDecode
|
||||
model_name: ["Student", "Teacher"]
|
||||
key: head_out
|
||||
|
||||
Metric:
|
||||
name: DistillationMetric
|
||||
base_metric_name: RecMetric
|
||||
main_indicator: acc
|
||||
key: "Student"
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/
|
||||
label_file_list:
|
||||
- ./train_data/train_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- RecAug:
|
||||
- CTCLabelEncode:
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 320]
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label
|
||||
- length
|
||||
loader:
|
||||
shuffle: true
|
||||
batch_size_per_card: 128
|
||||
drop_last: true
|
||||
num_sections: 1
|
||||
num_workers: 8
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data
|
||||
label_file_list:
|
||||
- ./train_data/val_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- CTCLabelEncode:
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 320]
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label
|
||||
- length
|
||||
loader:
|
||||
shuffle: false
|
||||
drop_last: false
|
||||
batch_size_per_card: 128
|
||||
num_workers: 8
|
|
@ -39,7 +39,7 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要
|
|||
|
||||
### 2.1 识别配置文件解析
|
||||
|
||||
配置文件在[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml)。
|
||||
配置文件在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)。
|
||||
|
||||
#### 2.1.1 模型结构
|
||||
|
||||
|
@ -246,6 +246,39 @@ Metric:
|
|||
关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)。
|
||||
|
||||
|
||||
#### 2.1.5 蒸馏模型微调
|
||||
|
||||
对蒸馏得到的识别蒸馏进行微调有2种方式。
|
||||
|
||||
(1)基于知识蒸馏的微调:这种情况比较简单,下载预训练模型,在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)中配置好预训练模型路径以及自己的数据路径,即可进行模型微调训练。
|
||||
|
||||
(2)微调时不使用知识蒸馏:这种情况,需要首先将预训练模型中的学生模型参数提取出来,具体步骤如下。
|
||||
|
||||
* 首先下载预训练模型并解压。
|
||||
```shell
|
||||
# 下面预训练模型并解压
|
||||
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar
|
||||
tar -xf ch_PP-OCRv2_rec_train.tar
|
||||
```
|
||||
|
||||
* 然后使用python,对其中的学生模型参数进行提取
|
||||
|
||||
```python
|
||||
import paddle
|
||||
# 加载预训练模型
|
||||
all_params = paddle.load("ch_PP-OCRv2_rec_train/best_accuracy.pdparams")
|
||||
# 查看权重参数的keys
|
||||
print(all_params.keys())
|
||||
# 学生模型的权重提取
|
||||
s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
|
||||
# 查看学生模型权重参数的keys
|
||||
print(s_params.keys())
|
||||
# 保存
|
||||
paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")
|
||||
```
|
||||
|
||||
转化完成之后,使用[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml),修改预训练模型的路径(为导出的`student.pdparams`模型路径)以及自己的数据路径,即可进行模型微调。
|
||||
|
||||
### 2.2 检测配置文件解析
|
||||
|
||||
* coming soon!
|
||||
|
|
|
@ -56,31 +56,34 @@ class CELoss(nn.Layer):
|
|||
|
||||
class KLJSLoss(object):
|
||||
def __init__(self, mode='kl'):
|
||||
assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
|
||||
assert mode in ['kl', 'js', 'KL', 'JS'
|
||||
], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
|
||||
self.mode = mode
|
||||
|
||||
def __call__(self, p1, p2, reduction="mean"):
|
||||
|
||||
loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5))
|
||||
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
|
||||
|
||||
if self.mode.lower() == "js":
|
||||
loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5))
|
||||
loss += paddle.multiply(
|
||||
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
|
||||
loss *= 0.5
|
||||
if reduction == "mean":
|
||||
loss = paddle.mean(loss, axis=[1,2])
|
||||
elif reduction=="none" or reduction is None:
|
||||
loss = paddle.mean(loss, axis=[1, 2])
|
||||
elif reduction == "none" or reduction is None:
|
||||
return loss
|
||||
else:
|
||||
loss = paddle.sum(loss, axis=[1,2])
|
||||
loss = paddle.sum(loss, axis=[1, 2])
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class DMLLoss(nn.Layer):
|
||||
"""
|
||||
DMLLoss
|
||||
"""
|
||||
|
||||
def __init__(self, act=None):
|
||||
def __init__(self, act=None, use_log=False):
|
||||
super().__init__()
|
||||
if act is not None:
|
||||
assert act in ["softmax", "sigmoid"]
|
||||
|
@ -91,19 +94,23 @@ class DMLLoss(nn.Layer):
|
|||
else:
|
||||
self.act = None
|
||||
|
||||
self.use_log = use_log
|
||||
|
||||
self.jskl_loss = KLJSLoss(mode="js")
|
||||
|
||||
def forward(self, out1, out2):
|
||||
if self.act is not None:
|
||||
out1 = self.act(out1)
|
||||
out2 = self.act(out2)
|
||||
if len(out1.shape) < 2:
|
||||
if self.use_log:
|
||||
# for recognition distillation, log is needed for feature map
|
||||
log_out1 = paddle.log(out1)
|
||||
log_out2 = paddle.log(out2)
|
||||
loss = (F.kl_div(
|
||||
log_out1, out2, reduction='batchmean') + F.kl_div(
|
||||
log_out2, out1, reduction='batchmean')) / 2.0
|
||||
else:
|
||||
# for detection distillation log is not needed
|
||||
loss = self.jskl_loss(out1, out2)
|
||||
return loss
|
||||
|
||||
|
|
|
@ -49,11 +49,15 @@ class CombinedLoss(nn.Layer):
|
|||
loss = loss_func(input, batch, **kargs)
|
||||
if isinstance(loss, paddle.Tensor):
|
||||
loss = {"loss_{}_{}".format(str(loss), idx): loss}
|
||||
|
||||
weight = self.loss_weight[idx]
|
||||
for key in loss.keys():
|
||||
if key == "loss":
|
||||
loss_all += loss[key] * weight
|
||||
else:
|
||||
loss_dict["{}_{}".format(key, idx)] = loss[key]
|
||||
|
||||
loss = {key: loss[key] * weight for key in loss}
|
||||
|
||||
if "loss" in loss:
|
||||
loss_all += loss["loss"]
|
||||
else:
|
||||
loss_all += paddle.add_n(list(loss.values()))
|
||||
loss_dict.update(loss)
|
||||
loss_dict["loss"] = loss_all
|
||||
return loss_dict
|
||||
|
|
|
@ -44,10 +44,11 @@ class DistillationDMLLoss(DMLLoss):
|
|||
def __init__(self,
|
||||
model_name_pairs=[],
|
||||
act=None,
|
||||
use_log=False,
|
||||
key=None,
|
||||
maps_name=None,
|
||||
name="dml"):
|
||||
super().__init__(act=act)
|
||||
super().__init__(act=act, use_log=use_log)
|
||||
assert isinstance(model_name_pairs, list)
|
||||
self.key = key
|
||||
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
|
||||
|
@ -57,7 +58,8 @@ class DistillationDMLLoss(DMLLoss):
|
|||
def _check_model_name_pairs(self, model_name_pairs):
|
||||
if not isinstance(model_name_pairs, list):
|
||||
return []
|
||||
elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str):
|
||||
elif isinstance(model_name_pairs[0], list) and isinstance(
|
||||
model_name_pairs[0][0], str):
|
||||
return model_name_pairs
|
||||
else:
|
||||
return [model_name_pairs]
|
||||
|
@ -112,8 +114,8 @@ class DistillationDMLLoss(DMLLoss):
|
|||
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
||||
0], pair[1], map_name, idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c],
|
||||
idx)] = loss
|
||||
loss_dict["{}_{}_{}".format(self.name, self.maps_name[
|
||||
_c], idx)] = loss
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
|
||||
|
|
|
@ -108,14 +108,15 @@ def load_dygraph_params(config, model, logger, optimizer):
|
|||
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
||||
if list(state_dict[k1].shape) == list(params[k2].shape):
|
||||
new_state_dict[k1] = params[k2]
|
||||
else:
|
||||
logger.info(
|
||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
||||
)
|
||||
model.set_state_dict(new_state_dict)
|
||||
logger.info(f"loaded pretrained_model successful from {pm}")
|
||||
return {}
|
||||
|
||||
|
||||
def load_pretrained_params(model, path):
|
||||
if path is None:
|
||||
return False
|
||||
|
@ -138,6 +139,7 @@ def load_pretrained_params(model, path):
|
|||
print(f"load pretrain successful from {path}")
|
||||
return model
|
||||
|
||||
|
||||
def save_model(model,
|
||||
optimizer,
|
||||
model_path,
|
||||
|
|
Loading…
Reference in New Issue