fix rec distillation (#3994)

* fix rec distillation

* add dist cfg

* fix yaml
This commit is contained in:
littletomatodonkey 2021-09-09 13:08:25 +08:00 committed by GitHub
parent 51f4a2c375
commit 7dc56191f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 255 additions and 95 deletions

View File

@ -4,7 +4,7 @@ Global:
epoch_num: 800 epoch_num: 800
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/rec_chinese_lite_distillation_v2.1 save_model_dir: ./output/rec_mobile_pp-OCRv2
save_epoch_step: 3 save_epoch_step: 3
eval_batch_step: [0, 2000] eval_batch_step: [0, 2000]
cal_metric_during_train: true cal_metric_during_train: true
@ -19,7 +19,7 @@ Global:
infer_mode: false infer_mode: false
use_space_char: true use_space_char: true
distributed: true distributed: true
save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt save_res_path: ./output/rec/predicts_mobile_pp-OCRv2.txt
Optimizer: Optimizer:
@ -35,79 +35,32 @@ Optimizer:
name: L2 name: L2
factor: 2.0e-05 factor: 2.0e-05
Architecture:
model_type: &model_type "rec"
name: DistillationModel
algorithm: Distillation
Models:
Teacher:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Architecture:
model_type: rec
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Loss: Loss:
name: CombinedLoss name: CTCLoss
loss_config_list:
- DistillationCTCLoss:
weight: 1.0
model_name_list: ["Student", "Teacher"]
key: head_out
- DistillationDMLLoss:
weight: 1.0
act: "softmax"
model_name_pairs:
- ["Student", "Teacher"]
key: head_out
- DistillationDistanceLoss:
weight: 1.0
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: backbone_out
PostProcess: PostProcess:
name: DistillationCTCLabelDecode name: CTCLabelDecode
model_name: ["Student", "Teacher"]
key: head_out
Metric: Metric:
name: DistillationMetric name: RecMetric
base_metric_name: RecMetric
main_indicator: acc main_indicator: acc
key: "Student"
Train: Train:
dataset: dataset:
@ -132,7 +85,6 @@ Train:
shuffle: true shuffle: true
batch_size_per_card: 128 batch_size_per_card: 128
drop_last: true drop_last: true
num_sections: 1
num_workers: 8 num_workers: 8
Eval: Eval:
dataset: dataset:

View File

@ -0,0 +1,160 @@
Global:
debug: false
use_gpu: true
epoch_num: 800
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_pp-OCRv2_distillation
save_epoch_step: 3
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
character_type: ch
max_text_length: 25
infer_mode: false
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_pp-OCRv2_distillation.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
decay_epochs : [700, 800]
values : [0.001, 0.0001]
warmup_epoch: 5
regularizer:
name: L2
factor: 2.0e-05
Architecture:
model_type: &model_type "rec"
name: DistillationModel
algorithm: Distillation
Models:
Teacher:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Loss:
name: CombinedLoss
loss_config_list:
- DistillationCTCLoss:
weight: 1.0
model_name_list: ["Student", "Teacher"]
key: head_out
- DistillationDMLLoss:
weight: 1.0
act: "softmax"
use_log: true
model_name_pairs:
- ["Student", "Teacher"]
key: head_out
- DistillationDistanceLoss:
weight: 1.0
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: backbone_out
PostProcess:
name: DistillationCTCLabelDecode
model_name: ["Student", "Teacher"]
key: head_out
Metric:
name: DistillationMetric
base_metric_name: RecMetric
main_indicator: acc
key: "Student"
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
label_file_list:
- ./train_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug:
- CTCLabelEncode:
- RecResizeImg:
image_shape: [3, 32, 320]
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: true
batch_size_per_card: 128
drop_last: true
num_sections: 1
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- CTCLabelEncode:
- RecResizeImg:
image_shape: [3, 32, 320]
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
num_workers: 8

View File

@ -39,7 +39,7 @@ PaddleOCR中集成了知识蒸馏的算法具体地有以下几个主要
### 2.1 识别配置文件解析 ### 2.1 识别配置文件解析
配置文件在[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml)。 配置文件在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)。
#### 2.1.1 模型结构 #### 2.1.1 模型结构
@ -246,6 +246,39 @@ Metric:
关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)。 关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)。
#### 2.1.5 蒸馏模型微调
对蒸馏得到的识别蒸馏进行微调有2种方式。
1基于知识蒸馏的微调这种情况比较简单下载预训练模型在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)中配置好预训练模型路径以及自己的数据路径,即可进行模型微调训练。
2微调时不使用知识蒸馏这种情况需要首先将预训练模型中的学生模型参数提取出来具体步骤如下。
* 首先下载预训练模型并解压。
```shell
# 下面预训练模型并解压
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar
tar -xf ch_PP-OCRv2_rec_train.tar
```
* 然后使用python对其中的学生模型参数进行提取
```python
import paddle
# 加载预训练模型
all_params = paddle.load("ch_PP-OCRv2_rec_train/best_accuracy.pdparams")
# 查看权重参数的keys
print(all_params.keys())
# 学生模型的权重提取
s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
# 查看学生模型权重参数的keys
print(s_params.keys())
# 保存
paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")
```
转化完成之后,使用[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml),修改预训练模型的路径(为导出的`student.pdparams`模型路径)以及自己的数据路径,即可进行模型微调。
### 2.2 检测配置文件解析 ### 2.2 检测配置文件解析
* coming soon! * coming soon!

View File

@ -56,31 +56,34 @@ class CELoss(nn.Layer):
class KLJSLoss(object): class KLJSLoss(object):
def __init__(self, mode='kl'): def __init__(self, mode='kl'):
assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']" assert mode in ['kl', 'js', 'KL', 'JS'
], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
self.mode = mode self.mode = mode
def __call__(self, p1, p2, reduction="mean"): def __call__(self, p1, p2, reduction="mean"):
loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5)) loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
if self.mode.lower() == "js": if self.mode.lower() == "js":
loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5)) loss += paddle.multiply(
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
loss *= 0.5 loss *= 0.5
if reduction == "mean": if reduction == "mean":
loss = paddle.mean(loss, axis=[1,2]) loss = paddle.mean(loss, axis=[1, 2])
elif reduction=="none" or reduction is None: elif reduction == "none" or reduction is None:
return loss return loss
else: else:
loss = paddle.sum(loss, axis=[1,2]) loss = paddle.sum(loss, axis=[1, 2])
return loss return loss
class DMLLoss(nn.Layer): class DMLLoss(nn.Layer):
""" """
DMLLoss DMLLoss
""" """
def __init__(self, act=None): def __init__(self, act=None, use_log=False):
super().__init__() super().__init__()
if act is not None: if act is not None:
assert act in ["softmax", "sigmoid"] assert act in ["softmax", "sigmoid"]
@ -91,19 +94,23 @@ class DMLLoss(nn.Layer):
else: else:
self.act = None self.act = None
self.use_log = use_log
self.jskl_loss = KLJSLoss(mode="js") self.jskl_loss = KLJSLoss(mode="js")
def forward(self, out1, out2): def forward(self, out1, out2):
if self.act is not None: if self.act is not None:
out1 = self.act(out1) out1 = self.act(out1)
out2 = self.act(out2) out2 = self.act(out2)
if len(out1.shape) < 2: if self.use_log:
# for recognition distillation, log is needed for feature map
log_out1 = paddle.log(out1) log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2) log_out2 = paddle.log(out2)
loss = (F.kl_div( loss = (F.kl_div(
log_out1, out2, reduction='batchmean') + F.kl_div( log_out1, out2, reduction='batchmean') + F.kl_div(
log_out2, out1, reduction='batchmean')) / 2.0 log_out2, out1, reduction='batchmean')) / 2.0
else: else:
# for detection distillation log is not needed
loss = self.jskl_loss(out1, out2) loss = self.jskl_loss(out1, out2)
return loss return loss

View File

@ -49,11 +49,15 @@ class CombinedLoss(nn.Layer):
loss = loss_func(input, batch, **kargs) loss = loss_func(input, batch, **kargs)
if isinstance(loss, paddle.Tensor): if isinstance(loss, paddle.Tensor):
loss = {"loss_{}_{}".format(str(loss), idx): loss} loss = {"loss_{}_{}".format(str(loss), idx): loss}
weight = self.loss_weight[idx] weight = self.loss_weight[idx]
for key in loss.keys():
if key == "loss": loss = {key: loss[key] * weight for key in loss}
loss_all += loss[key] * weight
else: if "loss" in loss:
loss_dict["{}_{}".format(key, idx)] = loss[key] loss_all += loss["loss"]
else:
loss_all += paddle.add_n(list(loss.values()))
loss_dict.update(loss)
loss_dict["loss"] = loss_all loss_dict["loss"] = loss_all
return loss_dict return loss_dict

View File

@ -44,10 +44,11 @@ class DistillationDMLLoss(DMLLoss):
def __init__(self, def __init__(self,
model_name_pairs=[], model_name_pairs=[],
act=None, act=None,
use_log=False,
key=None, key=None,
maps_name=None, maps_name=None,
name="dml"): name="dml"):
super().__init__(act=act) super().__init__(act=act, use_log=use_log)
assert isinstance(model_name_pairs, list) assert isinstance(model_name_pairs, list)
self.key = key self.key = key
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs) self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
@ -57,7 +58,8 @@ class DistillationDMLLoss(DMLLoss):
def _check_model_name_pairs(self, model_name_pairs): def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list): if not isinstance(model_name_pairs, list):
return [] return []
elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str): elif isinstance(model_name_pairs[0], list) and isinstance(
model_name_pairs[0][0], str):
return model_name_pairs return model_name_pairs
else: else:
return [model_name_pairs] return [model_name_pairs]
@ -112,8 +114,8 @@ class DistillationDMLLoss(DMLLoss):
loss_dict["{}_{}_{}_{}_{}".format(key, pair[ loss_dict["{}_{}_{}_{}_{}".format(key, pair[
0], pair[1], map_name, idx)] = loss[key] 0], pair[1], map_name, idx)] = loss[key]
else: else:
loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c], loss_dict["{}_{}_{}".format(self.name, self.maps_name[
idx)] = loss _c], idx)] = loss
loss_dict = _sum_loss(loss_dict) loss_dict = _sum_loss(loss_dict)

View File

@ -108,14 +108,15 @@ def load_dygraph_params(config, model, logger, optimizer):
for k1, k2 in zip(state_dict.keys(), params.keys()): for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape): if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2] new_state_dict[k1] = params[k2]
else: else:
logger.info( logger.info(
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
) )
model.set_state_dict(new_state_dict) model.set_state_dict(new_state_dict)
logger.info(f"loaded pretrained_model successful from {pm}") logger.info(f"loaded pretrained_model successful from {pm}")
return {} return {}
def load_pretrained_params(model, path): def load_pretrained_params(model, path):
if path is None: if path is None:
return False return False
@ -138,6 +139,7 @@ def load_pretrained_params(model, path):
print(f"load pretrain successful from {path}") print(f"load pretrain successful from {path}")
return model return model
def save_model(model, def save_model(model,
optimizer, optimizer,
model_path, model_path,