add comment db post_process
This commit is contained in:
parent
5299d8e62d
commit
725185cd6a
|
@ -24,6 +24,7 @@ import string
|
|||
import cv2
|
||||
from shapely.geometry import Polygon
|
||||
import pyclipper
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class DBPostProcess(object):
|
||||
|
@ -39,13 +40,15 @@ class DBPostProcess(object):
|
|||
self.min_size = 3
|
||||
self.dilation_kernel = np.array([[1, 1], [1, 1]])
|
||||
|
||||
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
||||
'''
|
||||
_bitmap: single map with shape (1, H, W),
|
||||
whose values are binarized as {0, 1}
|
||||
'''
|
||||
|
||||
bitmap = _bitmap
|
||||
def boxes_from_bitmap(self, pred, mask):
|
||||
"""
|
||||
Get boxes from the binarized image predicted by DB.
|
||||
:param pred: the binarized image predicted by DB.
|
||||
:param mask: new 'pred' after threshold filtering.
|
||||
:return: (boxes, the score of each boxes)
|
||||
"""
|
||||
dest_height, dest_width = pred.shape[-2:]
|
||||
bitmap = deepcopy(mask)
|
||||
height, width = bitmap.shape
|
||||
|
||||
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
|
||||
|
@ -87,6 +90,11 @@ class DBPostProcess(object):
|
|||
return boxes, scores
|
||||
|
||||
def unclip(self, box):
|
||||
"""
|
||||
Shrink or expand the boxaccording to 'unclip_ratio'
|
||||
:param box: The predicted box.
|
||||
:return: uncliped box
|
||||
"""
|
||||
unclip_ratio = self.unclip_ratio
|
||||
poly = Polygon(box)
|
||||
distance = poly.area * unclip_ratio / poly.length
|
||||
|
@ -96,6 +104,11 @@ class DBPostProcess(object):
|
|||
return expanded
|
||||
|
||||
def get_mini_boxes(self, contour):
|
||||
"""
|
||||
Get boxes from the contour or box.
|
||||
:param contour: The predicted contour.
|
||||
:return: The predicted box.
|
||||
"""
|
||||
bounding_box = cv2.minAreaRect(contour)
|
||||
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
|
||||
|
||||
|
@ -119,6 +132,12 @@ class DBPostProcess(object):
|
|||
return box, min(bounding_box[1])
|
||||
|
||||
def box_score_fast(self, bitmap, _box):
|
||||
"""
|
||||
Calculate the score of box.
|
||||
:param bitmap: The binarized image predicted by DB.
|
||||
:param _box: The predicted box
|
||||
:return: score
|
||||
"""
|
||||
h, w = bitmap.shape[:2]
|
||||
box = _box.copy()
|
||||
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
|
||||
|
@ -137,13 +156,14 @@ class DBPostProcess(object):
|
|||
|
||||
pred = pred[:, 0, :, :]
|
||||
segmentation = pred > self.thresh
|
||||
|
||||
boxes_batch = []
|
||||
for batch_index in range(pred.shape[0]):
|
||||
height, width = pred.shape[-2:]
|
||||
|
||||
mask = cv2.dilate(np.array(segmentation[batch_index]).astype(np.uint8), self.dilation_kernel)
|
||||
tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index], mask, width, height)
|
||||
mask = cv2.dilate(
|
||||
np.array(segmentation[batch_index]).astype(np.uint8),
|
||||
self.dilation_kernel)
|
||||
tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index],
|
||||
mask)
|
||||
|
||||
boxes = []
|
||||
for k in range(len(tmp_boxes)):
|
||||
|
|
Loading…
Reference in New Issue