diff --git a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml index f3e75341..38aeffcb 100644 --- a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml +++ b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml @@ -39,6 +39,7 @@ Architecture: Student: pretrained: null freeze_params: false + return_all_feats: true model_type: rec algorithm: CRNN Transform: @@ -57,6 +58,7 @@ Architecture: Teacher: pretrained: null freeze_params: false + return_all_feats: true model_type: rec algorithm: CRNN Transform: @@ -80,18 +82,26 @@ Loss: - DistillationCTCLoss: weight: 1.0 model_name_list: ["Student", "Teacher"] - key: null + key: head_out - DistillationDMLLoss: weight: 1.0 act: "softmax" model_name_pairs: - ["Student", "Teacher"] - key: null + key: head_out + - DistillationDistanceLoss: + weight: 1.0 + mode: "l2" + model_name_pairs: + - ["Student", "Teacher"] + key: backbone_out + + PostProcess: name: DistillationCTCLabelDecode model_name: "Student" - key_out: null + key: head_out Metric: name: RecMetric diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py index 153bf690..022ae5c6 100644 --- a/ppocr/losses/basic_loss.py +++ b/ppocr/losses/basic_loss.py @@ -97,6 +97,7 @@ class DistanceLoss(nn.Layer): """ def __init__(self, mode="l2", name="loss_dist", **kargs): + super().__init__() assert mode in ["l1", "l2", "smooth_l1"] if mode == "l1": self.loss_func = nn.L1Loss(**kargs) diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py index 49012e30..54da7017 100644 --- a/ppocr/losses/combined_loss.py +++ b/ppocr/losses/combined_loss.py @@ -17,6 +17,7 @@ import paddle.nn as nn from .distillation_loss import DistillationCTCLoss from .distillation_loss import DistillationDMLLoss +from .distillation_loss import DistillationDistanceLoss class CombinedLoss(nn.Layer): diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index 40a8da77..a62922f0 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -17,6 +17,7 @@ import paddle.nn as nn from .rec_ctc_loss import CTCLoss from .basic_loss import DMLLoss +from .basic_loss import DistanceLoss class DistillationDMLLoss(DMLLoss): @@ -69,3 +70,36 @@ class DistillationCTCLoss(CTCLoss): else: loss_dict["{}_{}".format(self.name, model_name)] = loss return loss_dict + + +class DistillationDistanceLoss(DistanceLoss): + """ + """ + + def __init__(self, + mode="l2", + model_name_pairs=[], + key=None, + name="loss_distance", + **kargs): + super().__init__(mode=mode, name=name) + assert isinstance(model_name_pairs, list) + self.key = key + self.model_name_pairs = model_name_pairs + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + out1 = predicts[pair[0]] + out2 = predicts[pair[1]] + if self.key is not None: + out1 = out1[self.key] + out2 = out2[self.key] + loss = super().forward(out1, out2) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[ + key] + else: + loss_dict["{}_{}".format(self.name, idx)] = loss + return loss_dict diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index 5a41e507..4c941fcf 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -67,14 +67,23 @@ class BaseModel(nn.Layer): config["Head"]['in_channels'] = in_channels self.head = build_head(config["Head"]) + self.return_all_feats = config.get("return_all_feats", False) + def forward(self, x, data=None): + y = dict() if self.use_transform: x = self.transform(x) x = self.backbone(x) + y["backbone_out"] = x if self.use_neck: x = self.neck(x) + y["neck_out"] = x if data is None: x = self.head(x) else: x = self.head(x, data) - return x + y["head_out"] = x + if self.return_all_feats: + return y + else: + return x diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 5cc7abe7..e5729ea5 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -136,17 +136,17 @@ class DistillationCTCLabelDecode(CTCLabelDecode): character_type='ch', use_space_char=False, model_name="student", - key_out=None, + key=None, **kwargs): super(DistillationCTCLabelDecode, self).__init__( character_dict_path, character_type, use_space_char) self.model_name = model_name - self.key_out = key_out + self.key = key def __call__(self, preds, label=None, *args, **kwargs): pred = preds[self.model_name] - if self.key_out is not None: - pred = pred[self.key_out] + if self.key is not None: + pred = pred[self.key] return super().__call__(pred, label=label, *args, **kwargs)