fix bug in predict_det for sast & update docs

This commit is contained in:
licx 2020-08-18 20:32:00 +08:00
parent d0ea95d6e4
commit 0c3b5d8e76
2 changed files with 21 additions and 4 deletions

View File

@ -296,7 +296,11 @@ python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model
<a name="其他模型推理"></a>
### 2. 其他模型推理
如果想尝试使用其他检测算法或者识别算法请参考上述文本检测模型推理和文本识别模型推理更新相应配置和模型下面给出基于EAST文本检测和STAR-Net文本识别执行命令
如果想尝试使用其他检测算法或者识别算法,请参考上述文本检测模型推理和文本识别模型推理,更新相应配置和模型。
**注意由于检测框矫正逻辑的局限性SAST弯曲文本检测模型使用参数`--det_sast_polygon=True`时)暂时无法用来模型串联。**
下面给出基于EAST文本检测和STAR-Net文本识别执行命令
```
python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"

View File

@ -58,7 +58,8 @@ class TextDetector(object):
self.preprocess_op = SASTProcessTest(preprocess_params)
postprocess_params["score_thresh"] = args.det_sast_score_thresh
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
if args.det_sast_polygon:
self.det_sast_polygon = args.det_sast_polygon
if self.det_sast_polygon:
postprocess_params["sample_pts_num"] = 6
postprocess_params["expand_scale"] = 1.2
postprocess_params["shrink_ratio_of_width"] = 0.2
@ -99,7 +100,7 @@ class TextDetector(object):
return rect
def clip_det_res(self, points, img_height, img_width):
for pno in range(4):
for pno in range(points.shape[0]):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
return points
@ -118,6 +119,15 @@ class TextDetector(object):
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
box = self.clip_det_res(box, img_height, img_width)
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def __call__(self, img):
ori_im = img.copy()
im, ratio_list = self.preprocess_op(img)
@ -145,7 +155,10 @@ class TextDetector(object):
dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list])
dt_boxes = dt_boxes_list[0]
# dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
if self.det_algorithm == "SAST" and self.det_sast_polygon:
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
elapse = time.time() - starttime
return dt_boxes, elapse