add comment db post_process

This commit is contained in:
LDOUBLEV 2020-10-28 10:28:34 +08:00
parent 5299d8e62d
commit 725185cd6a
1 changed files with 31 additions and 11 deletions

View File

@ -24,6 +24,7 @@ import string
import cv2
from shapely.geometry import Polygon
import pyclipper
from copy import deepcopy
class DBPostProcess(object):
@ -39,13 +40,15 @@ class DBPostProcess(object):
self.min_size = 3
self.dilation_kernel = np.array([[1, 1], [1, 1]])
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
def boxes_from_bitmap(self, pred, mask):
"""
Get boxes from the binarized image predicted by DB.
:param pred: the binarized image predicted by DB.
:param mask: new 'pred' after threshold filtering.
:return: (boxes, the score of each boxes)
"""
dest_height, dest_width = pred.shape[-2:]
bitmap = deepcopy(mask)
height, width = bitmap.shape
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
@ -87,6 +90,11 @@ class DBPostProcess(object):
return boxes, scores
def unclip(self, box):
"""
Shrink or expand the boxaccording to 'unclip_ratio'
:param box: The predicted box.
:return: uncliped box
"""
unclip_ratio = self.unclip_ratio
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
@ -96,6 +104,11 @@ class DBPostProcess(object):
return expanded
def get_mini_boxes(self, contour):
"""
Get boxes from the contour or box.
:param contour: The predicted contour.
:return: The predicted box.
"""
bounding_box = cv2.minAreaRect(contour)
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
@ -119,6 +132,12 @@ class DBPostProcess(object):
return box, min(bounding_box[1])
def box_score_fast(self, bitmap, _box):
"""
Calculate the score of box.
:param bitmap: The binarized image predicted by DB.
:param _box: The predicted box
:return: score
"""
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
@ -137,13 +156,14 @@ class DBPostProcess(object):
pred = pred[:, 0, :, :]
segmentation = pred > self.thresh
boxes_batch = []
for batch_index in range(pred.shape[0]):
height, width = pred.shape[-2:]
mask = cv2.dilate(np.array(segmentation[batch_index]).astype(np.uint8), self.dilation_kernel)
tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index], mask, width, height)
mask = cv2.dilate(
np.array(segmentation[batch_index]).astype(np.uint8),
self.dilation_kernel)
tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index],
mask)
boxes = []
for k in range(len(tmp_boxes)):