Merge pull request #756 from LDOUBLEV/fixocr

advance db post_process
This commit is contained in:
Double_V 2020-09-20 00:24:37 +08:00 committed by GitHub
commit d64a4c3f6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 18 additions and 12 deletions

View File

@ -108,9 +108,11 @@ void DBDetector::Run(cv::Mat &img,
const double maxvalue = 255;
cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
cv::Mat dilation_map;
cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2,2));
cv::dilate(bit_map, dilation_map, dila_ele);
boxes = post_processor_.BoxesFromBitmap(
pred_map, bit_map, this->det_db_box_thresh_, this->det_db_unclip_ratio_);
pred_map, dilation_map, this->det_db_box_thresh_, this->det_db_unclip_ratio_);
boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);

View File

@ -294,7 +294,7 @@ PostProcessor::FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
pow(boxes[n][0][1] - boxes[n][1][1], 2)));
rect_height = int(sqrt(pow(boxes[n][0][0] - boxes[n][3][0], 2) +
pow(boxes[n][0][1] - boxes[n][3][1], 2)));
if (rect_width <= 10 || rect_height <= 10)
if (rect_width <= 4 || rect_height <= 4)
continue;
root_points.push_back(boxes[n]);
}

View File

@ -10,7 +10,7 @@ use_zero_copy_run 1
max_side_len 960
det_db_thresh 0.3
det_db_box_thresh 0.5
det_db_unclip_ratio 2.0
det_db_unclip_ratio 1.6
det_model_dir ./inference/det_db
# cls config

View File

@ -1,4 +1,4 @@
max_side_len 960
det_db_thresh 0.3
det_db_box_thresh 0.5
det_db_unclip_ratio 2.0
det_db_unclip_ratio 1.6

View File

@ -293,7 +293,7 @@ FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes, float ratio_h,
rect_height =
static_cast<int>(sqrt(pow(boxes[n][0][0] - boxes[n][3][0], 2) +
pow(boxes[n][0][1] - boxes[n][3][1], 2)));
if (rect_width <= 10 || rect_height <= 10)
if (rect_width <= 4 || rect_height <= 4)
continue;
root_points.push_back(boxes[n]);
}

View File

@ -289,8 +289,10 @@ RunDetModel(std::shared_ptr<PaddlePredictor> predictor, cv::Mat img,
const double maxvalue = 255;
cv::Mat bit_map;
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
auto boxes = BoxesFromBitmap(pred_map, bit_map, Config);
cv::Mat dilation_map;
cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2,2));
cv::dilate(bit_map, dilation_map, dila_ele);
auto boxes = BoxesFromBitmap(pred_map, dilation_map, Config);
std::vector<std::vector<std::vector<int>>> filter_boxes =
FilterTagDetRes(boxes, ratio_hw[0], ratio_hw[1], srcimg);

View File

@ -37,6 +37,7 @@ class DBPostProcess(object):
self.max_candidates = params['max_candidates']
self.unclip_ratio = params['unclip_ratio']
self.min_size = 3
self.dilation_kernel = np.array([[1, 1], [1, 1]])
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
@ -140,8 +141,9 @@ class DBPostProcess(object):
boxes_batch = []
for batch_index in range(pred.shape[0]):
height, width = pred.shape[-2:]
tmp_boxes, tmp_scores = self.boxes_from_bitmap(
pred[batch_index], segmentation[batch_index], width, height)
mask = cv2.dilate(np.array(segmentation[batch_index]).astype(np.uint8), self.dilation_kernel)
tmp_boxes, tmp_scores = self.boxes_from_bitmap(pred[batch_index], mask, width, height)
boxes = []
for k in range(len(tmp_boxes)):

View File

@ -47,7 +47,7 @@ def parse_args():
# DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0)
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
# EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
@ -64,7 +64,7 @@ def parse_args():
parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=30)
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument(
"--rec_char_dict_path",