diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index 1e8aa0d8..1c5d8a2b 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -18,19 +18,60 @@ import paddle.nn as nn from .rec_ctc_loss import CTCLoss from .basic_loss import DMLLoss from .basic_loss import DistanceLoss +from .det_db_loss import DBLoss +from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss -class DistillationDMLLoss(DMLLoss): +def _sum_loss(loss_dict): + if "loss" in loss_dict.keys(): + return loss_dict + else: + loss_dict["loss"] = 0. + for k, value in loss_dict.items(): + if k == "loss": + continue + else: + loss_dict["loss"] += value + return loss_dict + +# class DistillationDMLLoss(DMLLoss): """ """ - def __init__(self, model_name_pairs=[], act=None, key=None, + def __init__(self, + model_name_pairs=[], + act=None, + key=None, + maps_name=None, name="loss_dml"): super().__init__(act=act) assert isinstance(model_name_pairs, list) self.key = key self.model_name_pairs = model_name_pairs self.name = name + self.maps_name = self.maps_name + + def _check_maps_name(self, maps_name): + if maps_name is None: + return None + elif type(maps_name) == str: + return [maps_name] + elif type(maps_name) == list: + return [maps_name] + else: + return None + + def _slice_out(self, outs): + new_outs = {} + for k in self.maps_name: + if k == "thrink_maps": + new_outs[k] = paddle.slice(outs, axes=1, starts=0, ends=1) + elif k == "threshold_maps": + new_outs[k] = paddle.slice(outs, axes=1, starts=1, ends=2) + elif k == "binary_maps": + new_outs[k] = paddle.slice(outs, axes=1, starts=2, ends=3) + else: + continue def forward(self, predicts, batch): loss_dict = dict() @@ -40,13 +81,30 @@ class DistillationDMLLoss(DMLLoss): if self.key is not None: out1 = out1[self.key] out2 = out2[self.key] - loss = super().forward(out1, out2) - if isinstance(loss, dict): - for key in loss: - loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], - idx)] = loss[key] + + if self.maps_name is None: + loss = super().forward(out1, out2) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1], + idx)] = loss[key] + else: + loss_dict["{}_{}".format(self.name, idx)] = loss else: - loss_dict["{}_{}".format(self.name, idx)] = loss + outs1 = self._slice_out(out1) + outs2 = self._slice_out(out2) + for k in outs1.keys(): + loss = super().forward(outs1[k], outs2[k]) + if isinstance(loss, dict): + for key in loss: + loss_dict["{}_{}_{}_{}_{}".format(key, pair[ + 0], pair[1], map_name, idx)] = loss[key] + else: + loss_dict["{}_{}_{}".format(self.name, map_name, + idx)] = loss + + loss_dict = _sum_loss(loss_dict) + return loss_dict @@ -73,6 +131,93 @@ class DistillationCTCLoss(CTCLoss): return loss_dict +""" +class DistillationDBLoss(DBLoss): + def __init__(self, + model_name_list=[], + balance_loss=True, + main_loss_type='DiceLoss', + alpha=5, + beta=10, + ohem_ratio=3, + eps=1e-6, + name="db_loss", + **kwargs): + super().__init__() + self.model_name_list = model_name_list + self.name = name + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, model_name in enumerate(self.model_name_list): + out = predicts[model_name] + if self.key is not None: + out = out[self.key] + + loss = super().forward(out, batch) + + if isinstance(loss, dict): + for key in loss.keys(): + if key == "loss": + continue + loss_dict[f"{self.name}_{model_name}_{key}"] = loss[key] + else: + loss_dict[f"{self.name}_{model_name}"] = loss + + loss_dict = _sum_loss(loss_dict) + return loss_dict + + +class DistillationDilaDBLoss(DBLoss): + def __init__(self, model_name_pairs=[], + balance_loss=True, + main_loss_type='DiceLoss', + alpha=5, + beta=10, + ohem_ratio=3, + eps=1e-6, + name="dila_dbloss"): + super().__init__() + self.model_name_pairs = model_name_pairs + self.name = name + + def forward(self, predicts, batch): + loss_dict = dict() + for idx, pair in enumerate(self.model_name_pairs): + stu_outs = predicts[pair[0]] + tch_outs = predicts[pair[1]] + if self.key is not None: + stu_preds = stu_outs[self.key] + tch_preds = tch_outs[self.key] + + stu_shrink_maps = stu_preds[:, 0, :, :] + stu_binary_maps = stu_preds[:, 2, :, :] + + # dilation to teacher prediction + dilation_w = np.array([[1,1], [1,1]]) + th_shrink_maps = tch_preds[:, 0, :, :] + th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3 + dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32) + for i in range(th_shrink_maps.shape[0]): + dilate_maps[i] = cv2.dilate(th_shrink_maps[i, :, :].astype(np.uint8), dilation_w) + th_shrink_maps = paddle.to_tensor(dilate_maps) + + label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[1:] + + # calculate the shrink map loss + bce_loss = self.alpha * self.bce_loss(stu_shrink_maps, th_shrink_maps, + label_shrink_mask) + loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps, + label_shrink_mask) + + k = f"{self.name}_{pair[0]}_{pair[1]}" + loss_dict[k] = bce_loss + loss_binary_maps + + loss_dict = _sum_loss(loss_dict) + return loss +""" + + class DistillationDistanceLoss(DistanceLoss): """ """