diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py index 1c5d8a2b..b19f3f89 100644 --- a/ppocr/losses/distillation_loss.py +++ b/ppocr/losses/distillation_loss.py @@ -34,7 +34,8 @@ def _sum_loss(loss_dict): loss_dict["loss"] += value return loss_dict -# class DistillationDMLLoss(DMLLoss): + +class DistillationDMLLoss(DMLLoss): """ """ @@ -131,93 +132,6 @@ class DistillationCTCLoss(CTCLoss): return loss_dict -""" -class DistillationDBLoss(DBLoss): - def __init__(self, - model_name_list=[], - balance_loss=True, - main_loss_type='DiceLoss', - alpha=5, - beta=10, - ohem_ratio=3, - eps=1e-6, - name="db_loss", - **kwargs): - super().__init__() - self.model_name_list = model_name_list - self.name = name - - def forward(self, predicts, batch): - loss_dict = dict() - for idx, model_name in enumerate(self.model_name_list): - out = predicts[model_name] - if self.key is not None: - out = out[self.key] - - loss = super().forward(out, batch) - - if isinstance(loss, dict): - for key in loss.keys(): - if key == "loss": - continue - loss_dict[f"{self.name}_{model_name}_{key}"] = loss[key] - else: - loss_dict[f"{self.name}_{model_name}"] = loss - - loss_dict = _sum_loss(loss_dict) - return loss_dict - - -class DistillationDilaDBLoss(DBLoss): - def __init__(self, model_name_pairs=[], - balance_loss=True, - main_loss_type='DiceLoss', - alpha=5, - beta=10, - ohem_ratio=3, - eps=1e-6, - name="dila_dbloss"): - super().__init__() - self.model_name_pairs = model_name_pairs - self.name = name - - def forward(self, predicts, batch): - loss_dict = dict() - for idx, pair in enumerate(self.model_name_pairs): - stu_outs = predicts[pair[0]] - tch_outs = predicts[pair[1]] - if self.key is not None: - stu_preds = stu_outs[self.key] - tch_preds = tch_outs[self.key] - - stu_shrink_maps = stu_preds[:, 0, :, :] - stu_binary_maps = stu_preds[:, 2, :, :] - - # dilation to teacher prediction - dilation_w = np.array([[1,1], [1,1]]) - th_shrink_maps = tch_preds[:, 0, :, :] - th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3 - dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32) - for i in range(th_shrink_maps.shape[0]): - dilate_maps[i] = cv2.dilate(th_shrink_maps[i, :, :].astype(np.uint8), dilation_w) - th_shrink_maps = paddle.to_tensor(dilate_maps) - - label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[1:] - - # calculate the shrink map loss - bce_loss = self.alpha * self.bce_loss(stu_shrink_maps, th_shrink_maps, - label_shrink_mask) - loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps, - label_shrink_mask) - - k = f"{self.name}_{pair[0]}_{pair[1]}" - loss_dict[k] = bce_loss + loss_binary_maps - - loss_dict = _sum_loss(loss_dict) - return loss -""" - - class DistillationDistanceLoss(DistanceLoss): """ """