add fast postprocess
This commit is contained in:
parent
d036c91af1
commit
03895497fa
|
@ -146,7 +146,7 @@ python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img=
|
|||
```
|
||||
|
||||
### 预测推理
|
||||
#### (1).四边形文本检测模型(ICDAR2015)
|
||||
#### (1). 四边形文本检测模型(ICDAR2015)
|
||||
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,以英文数据集训练的模型为例[模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar) ,可以使用如下命令进行转换:
|
||||
```
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
|
||||
|
@ -160,7 +160,7 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
|
|||
|
||||
![](../imgs_results/e2e_res_img_10_pgnet.jpg)
|
||||
|
||||
#### (2).弯曲文本检测模型(Total-Text)
|
||||
#### (2). 弯曲文本检测模型(Total-Text)
|
||||
对于弯曲文本样例
|
||||
|
||||
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`**,可以执行如下命令:
|
||||
|
@ -170,3 +170,8 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
|
|||
可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
|
||||
|
||||
![](../imgs_results/e2e_res_img623_pgnet.jpg)
|
||||
|
||||
#### (3). 精度与FPS
|
||||
|det_precision|det_recall|det_f_score|e2e_precision|e2e_recall|e2e_f_score|FPS|
|
||||
| --- | --- | --- | --- | --- | --- | --- |
|
||||
|87.03|82.48|84.69|61.71|58.43|60.03|62.61|
|
||||
|
|
|
@ -173,3 +173,7 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
|
|||
The visualized end-to-end text results are saved to the `./inference_results` folder by default, and the name of each result file is prefixed with 'e2e_res'. Examples of results are as follows:
|
||||
|
||||
![](../imgs_results/e2e_res_img623_pgnet.jpg)
|
||||
#### (3). Metrics and FPS
|
||||
|det_precision|det_recall|det_f_score|e2e_precision|e2e_recall|e2e_f_score|FPS|
|
||||
| --- | --- | --- | --- | --- | --- | --- |
|
||||
|87.03|82.48|84.69|61.71|58.43|60.03|62.61|
|
||||
|
|
|
@ -21,7 +21,7 @@ import math
|
|||
|
||||
import numpy as np
|
||||
from itertools import groupby
|
||||
from cv2.ximgproc import thinning as thin
|
||||
from skimage.morphology._skeletonize import thin
|
||||
|
||||
|
||||
def get_dict(character_dict_path):
|
||||
|
@ -362,11 +362,10 @@ def generate_pivot_list_fast(p_score,
|
|||
"""
|
||||
p_score = p_score[0]
|
||||
f_direction = f_direction.transpose(1, 2, 0)
|
||||
ret, p_tcl_map = cv2.threshold(p_score, score_thresh, 255,
|
||||
cv2.THRESH_BINARY)
|
||||
skeleton_map = thin(p_tcl_map.astype('uint8'))
|
||||
p_tcl_map = (p_score > score_thresh) * 1.0
|
||||
skeleton_map = thin(p_tcl_map.astype(np.uint8))
|
||||
instance_count, instance_label_map = cv2.connectedComponents(
|
||||
skeleton_map, connectivity=8)
|
||||
skeleton_map.astype(np.uint8), connectivity=8)
|
||||
|
||||
# get TCL Instance
|
||||
all_pos_yxs = []
|
||||
|
|
|
@ -21,7 +21,7 @@ import math
|
|||
|
||||
import numpy as np
|
||||
from itertools import groupby
|
||||
from cv2.ximgproc import thinning as thin
|
||||
from skimage.morphology._skeletonize import thin
|
||||
|
||||
|
||||
def get_dict(character_dict_path):
|
||||
|
@ -35,6 +35,64 @@ def get_dict(character_dict_path):
|
|||
return dict_character
|
||||
|
||||
|
||||
def point_pair2poly(point_pair_list):
    """
    Convert a sequence of point pairs into a single polygon point array.

    The first point of every pair is emitted in input order, followed by the
    second point of every pair in reverse input order, so the result walks
    along one side of the pairs and comes back along the other.

    Args:
        point_pair_list: sequence of pairs; ``pair[0]`` and ``pair[1]`` must
            support subtraction (e.g. numpy arrays holding one point each).

    Returns:
        A tuple ``(poly, pair_info)`` where ``poly`` is the point array
        reshaped to ``(-1, 2)`` and ``pair_info`` is
        ``(max, min, mean)`` of the distances between the two points of
        each pair.
    """
    # Distance between the two points of every pair.
    lengths = np.array(
        [np.linalg.norm(pair[0] - pair[1]) for pair in point_pair_list])
    pair_info = (lengths.max(), lengths.min(), lengths.mean())

    # Forward pass over the first points, backward pass over the second ones.
    forward_side = [pair[0] for pair in point_pair_list]
    backward_side = [pair[1] for pair in point_pair_list][::-1]
    ordered_points = forward_side + backward_side
    return np.array(ordered_points).reshape(-1, 2), pair_info
|
||||
|
||||
|
||||
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
    """
    Cut (or extend) a quad between two ratios along its 0->1 and 3->2 edges.

    Args:
        quad: indexable of four points; points 0->1 form one edge and
            points 3->2 form the opposite edge.
        begin_width_ratio: position on both edges where the new quad starts
            (0 is at points 0/3).
        end_width_ratio: position on both edges where the new quad ends
            (1 is at points 1/2). Ratios outside [0, 1] extrapolate beyond
            the original edges.

    Returns:
        numpy array with the four new points, ordered like the input quad
        (begin of edge 0->1, end of edge 0->1, end of edge 3->2, begin of
        edge 3->2).
    """
    # Column vector so each ratio scales a whole edge vector via broadcasting.
    ratios = np.array(
        [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)

    edge_01 = quad[1] - quad[0]
    edge_32 = quad[2] - quad[3]

    # Row 0 is the point at the begin ratio, row 1 the point at the end ratio.
    on_edge_01 = quad[0] + edge_01 * ratios
    on_edge_32 = quad[3] + edge_32 * ratios

    return np.array(
        [on_edge_01[0], on_edge_01[1], on_edge_32[1], on_edge_32[0]])
|
||||
|
||||
|
||||
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
    """
    Push both ends of a polygon outward along its width.

    Two quads are built from the polygon: one from its first two and last
    two points, and one from the four points around the middle index. Each
    quad is stretched past its outer edge with ``shrink_quad_along_width``,
    and the stretched outer corners are written back into ``poly``.

    Args:
        poly: numpy array of polygon points; indexed by point along axis 0.
            It is modified in place.
        shrink_ratio_of_width: scale applied to the ratio between the quad's
            0->3 distance and its 0->1 distance to decide how far each end
            is pushed out.

    Returns:
        The same ``poly`` object, with points ``0``, ``-1``,
        ``point_num // 2 - 1`` and ``point_num // 2`` replaced.
    """
    half = poly.shape[0] // 2

    # End quad at the start of the polygon: first two and last two points.
    head_quad = np.array(
        [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
    head_side = np.linalg.norm(head_quad[0] - head_quad[3])
    head_span = np.linalg.norm(head_quad[0] - head_quad[1])
    # Negative begin ratio extrapolates before points 0/3; 1e-6 avoids a
    # division by zero when points 0 and 1 coincide.
    head_ratio = -shrink_ratio_of_width * head_side / (head_span + 1e-6)
    head_expanded = shrink_quad_along_width(head_quad, head_ratio, 1.0)

    # End quad around the middle index of the polygon.
    tail_quad = np.array(
        [poly[half - 2], poly[half - 1], poly[half], poly[half + 1]],
        dtype=np.float32)
    tail_side = np.linalg.norm(tail_quad[0] - tail_quad[3])
    tail_span = np.linalg.norm(tail_quad[0] - tail_quad[1])
    # End ratio above 1 extrapolates past points 1/2 of this quad.
    tail_ratio = 1.0 + \
        shrink_ratio_of_width * tail_side / (tail_span + 1e-6)
    tail_expanded = shrink_quad_along_width(tail_quad, 0.0, tail_ratio)

    # Only the outer corners of each stretched quad go back into the polygon.
    poly[0] = head_expanded[0]
    poly[-1] = head_expanded[-1]
    poly[half - 1] = tail_expanded[1]
    poly[half] = tail_expanded[2]
    return poly
|
||||
|
||||
|
||||
def softmax(logits):
|
||||
"""
|
||||
logits: N x d
|
||||
|
|
|
@ -16,9 +16,14 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import paddle
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
from extract_textpoint_slow import *
|
||||
from extract_textpoint_fast import *
|
||||
from extract_textpoint_fast import generate_pivot_list_fast, restore_poly
|
||||
|
||||
|
||||
class PGNet_PostProcess(object):
|
||||
|
|
Loading…
Reference in New Issue