Merge branch 'develop' into fix_infer
This commit is contained in:
commit
f8515609f3
|
@ -15,7 +15,7 @@ EvalReader:
|
||||||
TestReader:
|
TestReader:
|
||||||
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
|
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
|
||||||
process_function: ppocr.data.det.db_process,DBProcessTest
|
process_function: ppocr.data.det.db_process,DBProcessTest
|
||||||
single_img_path:
|
infer_img:
|
||||||
img_set_dir: ./train_data/icdar2015/text_localization/
|
img_set_dir: ./train_data/icdar2015/text_localization/
|
||||||
label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||||
test_image_shape: [736, 1280]
|
test_image_shape: [736, 1280]
|
||||||
|
|
|
@ -17,7 +17,7 @@ EvalReader:
|
||||||
TestReader:
|
TestReader:
|
||||||
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
|
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
|
||||||
process_function: ppocr.data.det.east_process,EASTProcessTest
|
process_function: ppocr.data.det.east_process,EASTProcessTest
|
||||||
single_img_path:
|
infer_img:
|
||||||
img_set_dir: ./train_data/icdar2015/text_localization/
|
img_set_dir: ./train_data/icdar2015/text_localization/
|
||||||
label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||||
do_eval: True
|
do_eval: True
|
||||||
|
|
|
@ -10,4 +10,4 @@ EvalReader:
|
||||||
TestReader:
|
TestReader:
|
||||||
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
|
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
|
||||||
lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
|
lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
|
||||||
infer_img:
|
infer_img: ./infer_img
|
||||||
|
|
|
@ -79,10 +79,10 @@ python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./ou
|
||||||
|
|
||||||
测试单张图像的检测效果
|
测试单张图像的检测效果
|
||||||
```
|
```
|
||||||
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.single_img_path="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
|
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
|
||||||
```
|
```
|
||||||
|
|
||||||
测试文件夹下所有图像的检测效果
|
测试文件夹下所有图像的检测效果
|
||||||
```
|
```
|
||||||
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.single_img_path="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
|
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
|
||||||
```
|
```
|
||||||
|
|
|
@ -200,7 +200,7 @@ python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model
|
||||||
如果想尝试使用其他检测算法或者识别算法,请参考上述文本检测模型推理和文本识别模型推理,更新相应配置和模型,下面给出基于EAST文本检测和STAR-Net文本识别执行命令:
|
如果想尝试使用其他检测算法或者识别算法,请参考上述文本检测模型推理和文本识别模型推理,更新相应配置和模型,下面给出基于EAST文本检测和STAR-Net文本识别执行命令:
|
||||||
|
|
||||||
```
|
```
|
||||||
python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/rec/" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
|
python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
|
||||||
```
|
```
|
||||||
|
|
||||||
执行命令后,识别结果图像如下:
|
执行命令后,识别结果图像如下:
|
||||||
|
|
|
@ -84,7 +84,7 @@ class EvalTestReader(object):
|
||||||
img_path = os.path.join(img_set_dir, img_name)
|
img_path = os.path.join(img_set_dir, img_name)
|
||||||
img_list.append(img_path)
|
img_list.append(img_path)
|
||||||
else:
|
else:
|
||||||
img_path = self.params['single_img_path']
|
img_path = self.params['infer_img']
|
||||||
img_list = get_image_file_list(img_path)
|
img_list = get_image_file_list(img_path)
|
||||||
|
|
||||||
def batch_iter_reader():
|
def batch_iter_reader():
|
||||||
|
|
|
@ -78,6 +78,7 @@ def main():
|
||||||
'fetch_name_list':eval_fetch_name_list,\
|
'fetch_name_list':eval_fetch_name_list,\
|
||||||
'fetch_varname_list':eval_fetch_varname_list}
|
'fetch_varname_list':eval_fetch_varname_list}
|
||||||
metrics = eval_det_run(exe, config, eval_info_dict, "eval")
|
metrics = eval_det_run(exe, config, eval_info_dict, "eval")
|
||||||
|
print("Eval result", metrics)
|
||||||
else:
|
else:
|
||||||
reader_type = config['Global']['reader_yml']
|
reader_type = config['Global']['reader_yml']
|
||||||
if "benchmark" not in reader_type:
|
if "benchmark" not in reader_type:
|
||||||
|
|
|
@ -34,6 +34,7 @@ import json
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import cv2
|
import cv2
|
||||||
from ppocr.data.reader_main import reader_main
|
from ppocr.data.reader_main import reader_main
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
def cal_det_res(exe, config, eval_info_dict):
|
def cal_det_res(exe, config, eval_info_dict):
|
||||||
|
@ -43,6 +44,8 @@ def cal_det_res(exe, config, eval_info_dict):
|
||||||
postprocess_params.update(global_params)
|
postprocess_params.update(global_params)
|
||||||
postprocess = create_module(postprocess_params['function']) \
|
postprocess = create_module(postprocess_params['function']) \
|
||||||
(params=postprocess_params)
|
(params=postprocess_params)
|
||||||
|
if not os.path.exists(os.path.dirname(save_res_path)):
|
||||||
|
os.makedirs(os.path.dirname(save_res_path))
|
||||||
with open(save_res_path, "wb") as fout:
|
with open(save_res_path, "wb") as fout:
|
||||||
tackling_num = 0
|
tackling_num = 0
|
||||||
for data in eval_info_dict['reader']():
|
for data in eval_info_dict['reader']():
|
||||||
|
@ -93,7 +96,7 @@ def load_label_infor(label_file_path, do_ignore=False):
|
||||||
if text == "###" and do_ignore:
|
if text == "###" and do_ignore:
|
||||||
ignore = True
|
ignore = True
|
||||||
bbox_infor[bno]['ignore'] = ignore
|
bbox_infor[bno]['ignore'] = ignore
|
||||||
img_name_label_dict[substr[0]] = bbox_infor
|
img_name_label_dict[os.path.basename(substr[0])] = bbox_infor
|
||||||
return img_name_label_dict
|
return img_name_label_dict
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,7 @@ class TextRecognizer(object):
|
||||||
image_shape = [int(v) for v in args.rec_image_shape.split(",")]
|
image_shape = [int(v) for v in args.rec_image_shape.split(",")]
|
||||||
self.rec_image_shape = image_shape
|
self.rec_image_shape = image_shape
|
||||||
self.character_type = args.rec_char_type
|
self.character_type = args.rec_char_type
|
||||||
|
self.rec_batch_num = args.rec_batch_num
|
||||||
char_ops_params = {}
|
char_ops_params = {}
|
||||||
char_ops_params["character_type"] = args.rec_char_type
|
char_ops_params["character_type"] = args.rec_char_type
|
||||||
char_ops_params["character_dict_path"] = args.rec_char_dict_path
|
char_ops_params["character_dict_path"] = args.rec_char_dict_path
|
||||||
|
@ -59,8 +60,8 @@ class TextRecognizer(object):
|
||||||
|
|
||||||
def __call__(self, img_list):
|
def __call__(self, img_list):
|
||||||
img_num = len(img_list)
|
img_num = len(img_list)
|
||||||
batch_num = 30
|
|
||||||
rec_res = []
|
rec_res = []
|
||||||
|
batch_num = self.rec_batch_num
|
||||||
predict_time = 0
|
predict_time = 0
|
||||||
for beg_img_no in range(0, img_num, batch_num):
|
for beg_img_no in range(0, img_num, batch_num):
|
||||||
end_img_no = min(img_num, beg_img_no + batch_num)
|
end_img_no = min(img_num, beg_img_no + batch_num)
|
||||||
|
|
|
@ -89,7 +89,7 @@ def sorted_boxes(dt_boxes):
|
||||||
sorted boxes(array) with shape [4, 2]
|
sorted boxes(array) with shape [4, 2]
|
||||||
"""
|
"""
|
||||||
num_boxes = dt_boxes.shape[0]
|
num_boxes = dt_boxes.shape[0]
|
||||||
sorted_boxes = sorted(dt_boxes, key=lambda x: x[0][1])
|
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
|
||||||
_boxes = list(sorted_boxes)
|
_boxes = list(sorted_boxes)
|
||||||
|
|
||||||
for i in range(num_boxes - 1):
|
for i in range(num_boxes - 1):
|
||||||
|
|
|
@ -56,6 +56,7 @@ def parse_args():
|
||||||
parser.add_argument("--rec_model_dir", type=str)
|
parser.add_argument("--rec_model_dir", type=str)
|
||||||
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
|
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
|
||||||
parser.add_argument("--rec_char_type", type=str, default='ch')
|
parser.add_argument("--rec_char_type", type=str, default='ch')
|
||||||
|
parser.add_argument("--rec_batch_num", type=int, default=30)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--rec_char_dict_path",
|
"--rec_char_dict_path",
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -172,7 +173,8 @@ def draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5):
|
||||||
continue
|
continue
|
||||||
font = ImageFont.truetype(
|
font = ImageFont.truetype(
|
||||||
"./doc/simfang.ttf", font_size, encoding="utf-8")
|
"./doc/simfang.ttf", font_size, encoding="utf-8")
|
||||||
new_txt = str(count) + ': ' + txt + ' ' + '%.3f' % (scores[count])
|
new_txt = str(count) + ': ' + txt + ' ' + '%.3f' % (
|
||||||
|
scores[count])
|
||||||
draw_txt.text(
|
draw_txt.text(
|
||||||
(20, gap * (count + 1)), new_txt, txt_color, font=font)
|
(20, gap * (count + 1)), new_txt, txt_color, font=font)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
|
@ -106,7 +106,6 @@ def main():
|
||||||
with open(save_res_path, "wb") as fout:
|
with open(save_res_path, "wb") as fout:
|
||||||
|
|
||||||
test_reader = reader_main(config=config, mode='test')
|
test_reader = reader_main(config=config, mode='test')
|
||||||
# image_file_list = get_image_file_list(args.image_dir)
|
|
||||||
tackling_num = 0
|
tackling_num = 0
|
||||||
for data in test_reader():
|
for data in test_reader():
|
||||||
img_num = len(data)
|
img_num = len(data)
|
||||||
|
@ -135,7 +134,7 @@ def main():
|
||||||
elif config['Global']['algorithm'] == 'DB':
|
elif config['Global']['algorithm'] == 'DB':
|
||||||
dic = {'maps': outs[0]}
|
dic = {'maps': outs[0]}
|
||||||
else:
|
else:
|
||||||
raise Exception("only support algorithm: ['EAST', 'BD']")
|
raise Exception("only support algorithm: ['EAST', 'DB']")
|
||||||
dt_boxes_list = postprocess(dic, ratio_list)
|
dt_boxes_list = postprocess(dic, ratio_list)
|
||||||
for ino in range(img_num):
|
for ino in range(img_num):
|
||||||
dt_boxes = dt_boxes_list[ino]
|
dt_boxes = dt_boxes_list[ino]
|
||||||
|
|
Loading…
Reference in New Issue