modify infer tools for sast
This commit is contained in:
parent
c352e176f8
commit
f96b873aa4
|
@ -20,7 +20,5 @@ EvalReader:
|
||||||
TestReader:
|
TestReader:
|
||||||
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
|
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
|
||||||
process_function: ppocr.data.det.sast_process,SASTProcessTest
|
process_function: ppocr.data.det.sast_process,SASTProcessTest
|
||||||
infer_img:
|
infer_img: ./train_data/icdar2015/text_localization/ch4_test_images/img_11.jpg
|
||||||
img_set_dir: ./train_data/icdar2015/text_localization/
|
max_side_len: 1536
|
||||||
label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
|
||||||
do_eval: True
|
|
||||||
|
|
|
@ -20,5 +20,5 @@ EvalReader:
|
||||||
TestReader:
|
TestReader:
|
||||||
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
|
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
|
||||||
process_function: ppocr.data.det.sast_process,SASTProcessTest
|
process_function: ppocr.data.det.sast_process,SASTProcessTest
|
||||||
infer_img:
|
infer_img: ./train_data/afs/total_text/Images/Test/img623.jpg
|
||||||
max_side_len: 768
|
max_side_len: 768
|
||||||
|
|
|
@ -49,7 +49,7 @@ class SASTHead(object):
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
g[i] = deconv_bn_layer(input=h[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g0')
|
g[i] = deconv_bn_layer(input=h[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g0')
|
||||||
print("g[{}] shape: {}".format(i, g[i].shape))
|
#print("g[{}] shape: {}".format(i, g[i].shape))
|
||||||
else:
|
else:
|
||||||
g[i] = fluid.layers.elementwise_add(x=g[i - 1], y=h[i])
|
g[i] = fluid.layers.elementwise_add(x=g[i - 1], y=h[i])
|
||||||
g[i] = fluid.layers.relu(g[i])
|
g[i] = fluid.layers.relu(g[i])
|
||||||
|
@ -58,7 +58,7 @@ class SASTHead(object):
|
||||||
g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i],
|
g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i],
|
||||||
filter_size=3, stride=1, act='relu', name='fpn_up_g%d_1'%i)
|
filter_size=3, stride=1, act='relu', name='fpn_up_g%d_1'%i)
|
||||||
g[i] = deconv_bn_layer(input=g[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g%d_2'%i)
|
g[i] = deconv_bn_layer(input=g[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g%d_2'%i)
|
||||||
print("g[{}] shape: {}".format(i, g[i].shape))
|
#print("g[{}] shape: {}".format(i, g[i].shape))
|
||||||
|
|
||||||
g[4] = fluid.layers.elementwise_add(x=g[3], y=h[4])
|
g[4] = fluid.layers.elementwise_add(x=g[3], y=h[4])
|
||||||
g[4] = fluid.layers.relu(g[4])
|
g[4] = fluid.layers.relu(g[4])
|
||||||
|
|
|
@ -22,10 +22,12 @@ from ppocr.utils.utility import initial_logger
|
||||||
logger = initial_logger()
|
logger = initial_logger()
|
||||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||||
import cv2
|
import cv2
|
||||||
|
from ppocr.data.det.sast_process import SASTProcessTest
|
||||||
from ppocr.data.det.east_process import EASTProcessTest
|
from ppocr.data.det.east_process import EASTProcessTest
|
||||||
from ppocr.data.det.db_process import DBProcessTest
|
from ppocr.data.det.db_process import DBProcessTest
|
||||||
from ppocr.postprocess.db_postprocess import DBPostProcess
|
from ppocr.postprocess.db_postprocess import DBPostProcess
|
||||||
from ppocr.postprocess.east_postprocess import EASTPostPocess
|
from ppocr.postprocess.east_postprocess import EASTPostPocess
|
||||||
|
from ppocr.postprocess.sast_postprocess import SASTPostProcess
|
||||||
import copy
|
import copy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
|
@ -52,6 +54,14 @@ class TextDetector(object):
|
||||||
postprocess_params["cover_thresh"] = args.det_east_cover_thresh
|
postprocess_params["cover_thresh"] = args.det_east_cover_thresh
|
||||||
postprocess_params["nms_thresh"] = args.det_east_nms_thresh
|
postprocess_params["nms_thresh"] = args.det_east_nms_thresh
|
||||||
self.postprocess_op = EASTPostPocess(postprocess_params)
|
self.postprocess_op = EASTPostPocess(postprocess_params)
|
||||||
|
elif self.det_algorithm == "SAST":
|
||||||
|
self.preprocess_op = SASTProcessTest(preprocess_params)
|
||||||
|
postprocess_params["score_thresh"] = args.det_sast_score_thresh
|
||||||
|
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
|
||||||
|
postprocess_params["sample_pts_num"] = args.det_sast_sample_pts_num
|
||||||
|
postprocess_params["expand_scale"] = args.det_sast_expand_scale
|
||||||
|
postprocess_params["shrink_ratio_of_width"] = args.det_sast_shrink_ratio_of_width
|
||||||
|
self.postprocess_op = SASTPostProcess(postprocess_params)
|
||||||
else:
|
else:
|
||||||
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
|
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
@ -120,8 +130,14 @@ class TextDetector(object):
|
||||||
if self.det_algorithm == "EAST":
|
if self.det_algorithm == "EAST":
|
||||||
outs_dict['f_geo'] = outputs[0]
|
outs_dict['f_geo'] = outputs[0]
|
||||||
outs_dict['f_score'] = outputs[1]
|
outs_dict['f_score'] = outputs[1]
|
||||||
|
elif self.det_algorithm == 'SAST':
|
||||||
|
outs_dict['f_border'] = outputs[0]
|
||||||
|
outs_dict['f_score'] = outputs[1]
|
||||||
|
outs_dict['f_tco'] = outputs[2]
|
||||||
|
outs_dict['f_tvo'] = outputs[3]
|
||||||
else:
|
else:
|
||||||
outs_dict['maps'] = outputs[0]
|
outs_dict['maps'] = outputs[0]
|
||||||
|
|
||||||
dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list])
|
dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list])
|
||||||
dt_boxes = dt_boxes_list[0]
|
dt_boxes = dt_boxes_list[0]
|
||||||
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
|
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
|
||||||
|
|
|
@ -53,6 +53,13 @@ def parse_args():
|
||||||
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
|
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
|
||||||
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
|
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
|
||||||
|
|
||||||
|
# SAST params
|
||||||
|
parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
|
||||||
|
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
|
||||||
|
parser.add_argument("--det_sast_sample_pts_num", type=float, default=2)
|
||||||
|
parser.add_argument("--det_sast_expand_scale", type=float, default=1.0)
|
||||||
|
parser.add_argument("--det_sast_shrink_ratio_of_width", type=float, default=0.3)
|
||||||
|
|
||||||
#params for text recognizer
|
#params for text recognizer
|
||||||
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
|
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
|
||||||
parser.add_argument("--rec_model_dir", type=str)
|
parser.add_argument("--rec_model_dir", type=str)
|
||||||
|
|
|
@ -66,6 +66,25 @@ def draw_det_res(dt_boxes, config, img, img_name):
|
||||||
cv2.imwrite(save_path, src_im)
|
cv2.imwrite(save_path, src_im)
|
||||||
logger.info("The detected Image saved in {}".format(save_path))
|
logger.info("The detected Image saved in {}".format(save_path))
|
||||||
|
|
||||||
|
def gen_im_detection(src_im, detections):
    """
    Generate image with detection results.

    Every polygon in ``detections`` is outlined on a copy of ``src_im`` and
    its first vertex is labelled with the character '0'.

    Args:
        src_im: image array exposing ``.copy()`` and ``.shape``; only the
            first two dimensions (height, width) are read, so both 2-D and
            3-D images are accepted. The input itself is not modified.
        detections: iterable of polygons. Each polygon may be anything
            ``np.array`` can turn into a sequence of (x, y) points.

    Returns:
        The annotated copy of ``src_im``.
    """
    im_detection = src_im.copy()

    # Read only height and width so single-channel (2-D) images do not fail
    # on a three-way unpack of ``shape``.
    h, w = im_detection.shape[:2]
    # Stroke width (also reused as the font scale) grows with the image size
    # and is never smaller than 1.
    thickness = int(max((h + w) / 2000, 1))

    for poly in detections:
        # Normalise once so that lists and arrays are handled the same way by
        # both drawing calls below.
        points = np.array(poly).reshape((-1, 2))
        if points.shape[0] == 0:
            # Nothing to draw for an empty polygon.
            continue

        # Draw the first point
        cv2.putText(
            im_detection,
            '0',
            org=(int(points[0, 0]), int(points[0, 1])),
            fontFace=cv2.FONT_HERSHEY_COMPLEX,
            fontScale=thickness,
            color=(255, 0, 0),
            thickness=thickness)

        cv2.polylines(
            im_detection,
            points.reshape((1, -1, 2)).astype(np.int32),
            isClosed=True,
            color=(0, 0, 255),
            thickness=thickness)

    return im_detection
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
config = program.load_config(FLAGS.config)
|
config = program.load_config(FLAGS.config)
|
||||||
|
@ -134,8 +153,10 @@ def main():
|
||||||
dic = {'f_score': outs[0], 'f_geo': outs[1]}
|
dic = {'f_score': outs[0], 'f_geo': outs[1]}
|
||||||
elif config['Global']['algorithm'] == 'DB':
|
elif config['Global']['algorithm'] == 'DB':
|
||||||
dic = {'maps': outs[0]}
|
dic = {'maps': outs[0]}
|
||||||
|
elif config['Global']['algorithm'] == 'SAST':
|
||||||
|
dic = {'f_score': outs[0], 'f_border': outs[1], 'f_tvo': outs[2], 'f_tco': outs[3]}
|
||||||
else:
|
else:
|
||||||
raise Exception("only support algorithm: ['EAST', 'DB']")
|
raise Exception("only support algorithm: ['EAST', 'DB', 'SAST']")
|
||||||
dt_boxes_list = postprocess(dic, ratio_list)
|
dt_boxes_list = postprocess(dic, ratio_list)
|
||||||
for ino in range(img_num):
|
for ino in range(img_num):
|
||||||
dt_boxes = dt_boxes_list[ino]
|
dt_boxes = dt_boxes_list[ino]
|
||||||
|
@ -149,7 +170,7 @@ def main():
|
||||||
fout.write(otstr.encode())
|
fout.write(otstr.encode())
|
||||||
src_img = cv2.imread(img_name)
|
src_img = cv2.imread(img_name)
|
||||||
draw_det_res(dt_boxes, config, src_img, img_name)
|
draw_det_res(dt_boxes, config, src_img, img_name)
|
||||||
|
|
||||||
logger.info("success!")
|
logger.info("success!")
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue