diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index 0792cde0..b04c8bc3 100644 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -24,6 +24,7 @@ import string import cv2 from shapely.geometry import Polygon import pyclipper +from copy import deepcopy class DBPostProcess(object): @@ -39,13 +40,15 @@ class DBPostProcess(object): self.min_size = 3 self.dilation_kernel = np.array([[1, 1], [1, 1]]) - def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): - ''' - _bitmap: single map with shape (1, H, W), - whose values are binarized as {0, 1} - ''' - - bitmap = _bitmap + def boxes_from_bitmap(self, pred, mask): + """ + Get boxes from the binarized image predicted by DB. + :param pred: the binarized image predicted by DB. + :param mask: new 'pred' after threshold filtering. + :return: (boxes, the score of each boxes) + """ + dest_height, dest_width = pred.shape[-2:] + bitmap = deepcopy(mask) height, width = bitmap.shape outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, @@ -87,6 +90,11 @@ class DBPostProcess(object): return boxes, scores def unclip(self, box): + """ + Shrink or expand the boxaccording to 'unclip_ratio' + :param box: The predicted box. + :return: uncliped box + """ unclip_ratio = self.unclip_ratio poly = Polygon(box) distance = poly.area * unclip_ratio / poly.length @@ -96,6 +104,11 @@ class DBPostProcess(object): return expanded def get_mini_boxes(self, contour): + """ + Get boxes from the contour or box. + :param contour: The predicted contour. + :return: The predicted box. + """ bounding_box = cv2.minAreaRect(contour) points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) @@ -119,6 +132,12 @@ class DBPostProcess(object): return box, min(bounding_box[1]) def box_score_fast(self, bitmap, _box): + """ + Calculate the score of box. + :param bitmap: The binarized image predicted by DB. + :param _box: The predicted box + :return: score + """ h, w = bitmap.shape[:2] box = _box.copy() xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) @@ -137,13 +156,14 @@ class DBPostProcess(object): pred = pred[:, 0, :, :] segmentation = pred > self.thresh - boxes_batch = [] for batch_index in range(pred.shape[0]): - height, width = pred.shape[-2:] - mask = cv2.dilate(np.array(segmentation[batch_index]).astype(np.uint8), self.dilation_kernel) - tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index], mask, width, height) + mask = cv2.dilate( + np.array(segmentation[batch_index]).astype(np.uint8), + self.dilation_kernel) + tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index], + mask) boxes = [] for k in range(len(tmp_boxes)):