PaddleOCR/tools/eval_utils/eval_det_utils.py

140 lines
5.1 KiB
Python

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
import paddle.fluid as fluid
__all__ = ['eval_det_run']
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
from ppocr.utils.utility import create_module
from .eval_det_iou import DetectionIoUEvaluator
import json
from copy import deepcopy
import cv2
from ppocr.data.reader_main import reader_main
def cal_det_res(exe, config, eval_info_dict):
global_params = config['Global']
save_res_path = global_params['save_res_path']
postprocess_params = deepcopy(config["PostProcess"])
postprocess_params.update(global_params)
postprocess = create_module(postprocess_params['function']) \
(params=postprocess_params)
with open(save_res_path, "wb") as fout:
tackling_num = 0
for data in eval_info_dict['reader']():
img_num = len(data)
tackling_num = tackling_num + img_num
logger.info("test tackling num:%d", tackling_num)
img_list = []
ratio_list = []
img_name_list = []
for ino in range(img_num):
img_list.append(data[ino][0])
ratio_list.append(data[ino][1])
img_name_list.append(data[ino][2])
img_list = np.concatenate(img_list, axis=0)
outs = exe.run(eval_info_dict['program'], \
feed={'image': img_list}, \
fetch_list=eval_info_dict['fetch_varname_list'])
outs_dict = {}
for tno in range(len(outs)):
fetch_name = eval_info_dict['fetch_name_list'][tno]
fetch_value = np.array(outs[tno])
outs_dict[fetch_name] = fetch_value
dt_boxes_list = postprocess(outs_dict, ratio_list)
for ino in range(img_num):
dt_boxes = dt_boxes_list[ino]
img_name = img_name_list[ino]
dt_boxes_json = []
for box in dt_boxes:
tmp_json = {"transcription": ""}
tmp_json['points'] = box.tolist()
dt_boxes_json.append(tmp_json)
otstr = img_name + "\t" + json.dumps(dt_boxes_json) + "\n"
fout.write(otstr.encode())
return
def load_label_infor(label_file_path, do_ignore=False):
img_name_label_dict = {}
with open(label_file_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
substr = line.decode().strip("\n").split("\t")
bbox_infor = json.loads(substr[1])
bbox_num = len(bbox_infor)
for bno in range(bbox_num):
text = bbox_infor[bno]['transcription']
ignore = False
if text == "###" and do_ignore:
ignore = True
bbox_infor[bno]['ignore'] = ignore
img_name_label_dict[substr[0]] = bbox_infor
return img_name_label_dict
def cal_det_metrics(gt_label_path, save_res_path):
"""
calculate the detection metrics
Args:
gt_label_path(string): The groundtruth detection label file path
save_res_path(string): The saved predicted detection label path
return:
claculated metrics including Hmean、precision and recall
"""
evaluator = DetectionIoUEvaluator()
gt_label_infor = load_label_infor(gt_label_path, do_ignore=True)
dt_label_infor = load_label_infor(save_res_path)
results = []
for img_name in gt_label_infor:
gt_label = gt_label_infor[img_name]
if img_name not in dt_label_infor:
dt_label = []
else:
dt_label = dt_label_infor[img_name]
result = evaluator.evaluate_image(gt_label, dt_label)
results.append(result)
methodMetrics = evaluator.combine_results(results)
return methodMetrics
def eval_det_run(exe, config, eval_info_dict, mode):
cal_det_res(exe, config, eval_info_dict)
save_res_path = config['Global']['save_res_path']
if mode == "eval":
gt_label_path = config['EvalReader']['label_file_path']
metrics = cal_det_metrics(gt_label_path, save_res_path)
else:
gt_label_path = config['TestReader']['label_file_path']
do_eval = config['TestReader']['do_eval']
if do_eval:
metrics = cal_det_metrics(gt_label_path, save_res_path)
else:
metrics = {}
return metrics