150 lines
5.6 KiB
Python
150 lines
5.6 KiB
Python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import logging
|
|
import numpy as np
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
__all__ = ['eval_det_run']
|
|
|
|
import logging
|
|
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
|
|
logging.basicConfig(level=logging.INFO, format=FORMAT)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
from ppocr.utils.utility import create_module
|
|
from .eval_det_iou import DetectionIoUEvaluator
|
|
import json
|
|
from copy import deepcopy
|
|
import cv2
|
|
from ppocr.data.reader_main import reader_main
|
|
import os
|
|
|
|
|
|
def cal_det_res(exe, config, eval_info_dict):
    """
    Run the detection program on every batch of the reader and save the
    predicted boxes.

    For each batch yielded by ``eval_info_dict['reader']()`` the images are
    concatenated along axis 0, fed to the program under the name 'image',
    the fetched outputs are handed to the post-process module, and one line
    per image is written to ``config['Global']['save_res_path']``: the image
    name, a tab, and a json list of {"transcription": "", "points": [...]}.

    Args:
        exe: executor exposing ``run(program, feed=..., fetch_list=...)``
        config(dict): needs 'Global' (with 'save_res_path') and
            'PostProcess' (with 'function'); the Global entries are merged
            over a copy of the PostProcess entries before the post-process
            module is created
        eval_info_dict(dict): needs 'reader', 'program',
            'fetch_varname_list' and 'fetch_name_list'
    return:
        None, the results are only written to disk
    Raises:
        Exception: if the images of one batch cannot be concatenated
    """
    global_params = config['Global']
    save_res_path = global_params['save_res_path']
    postprocess_params = deepcopy(config["PostProcess"])
    postprocess_params.update(global_params)
    postprocess = create_module(postprocess_params['function'])(
        params=postprocess_params)

    # A bare file name has an empty dirname and os.makedirs('') raises, so
    # the directory is only created when there actually is one.
    save_dir = os.path.dirname(save_res_path)
    if save_dir and not os.path.exists(save_dir):
        os.makedirs(save_dir)

    with open(save_res_path, "wb") as fout:
        tackling_num = 0
        for data in eval_info_dict['reader']():
            img_num = len(data)
            tackling_num += img_num
            logger.info("test tackling num:%d", tackling_num)

            # Each sample is indexed as (image, ratio, image name).
            img_list = [sample[0] for sample in data]
            ratio_list = [sample[1] for sample in data]
            img_name_list = [sample[2] for sample in data]

            try:
                img_list = np.concatenate(img_list, axis=0)
            except Exception:
                # Not a bare except: KeyboardInterrupt / SystemExit must
                # not be reported as a shape problem.
                err = (
                    "concatenate error usually caused by different input "
                    "image shapes in evaluation or testing.\n "
                    "Please set \"test_batch_size_per_card\" in main yml "
                    "as 1\n "
                    "or add \"test_image_shape: [h, w]\" in reader yml "
                    "for EvalReader.")
                raise Exception(err)

            outs = exe.run(
                eval_info_dict['program'],
                feed={'image': img_list},
                fetch_list=eval_info_dict['fetch_varname_list'])

            # Map every fetched output to its readable name; indexing (not
            # zip) keeps the IndexError when the name list is too short.
            outs_dict = {}
            for tno, out in enumerate(outs):
                fetch_name = eval_info_dict['fetch_name_list'][tno]
                outs_dict[fetch_name] = np.array(out)

            dt_boxes_list = postprocess(outs_dict, ratio_list)
            for ino, img_name in enumerate(img_name_list):
                dt_boxes_json = [{
                    "transcription": "",
                    "points": box.tolist()
                } for box in dt_boxes_list[ino]]
                otstr = img_name + "\t" + json.dumps(dt_boxes_json) + "\n"
                fout.write(otstr.encode())
    return
|
|
|
|
|
|
def load_label_infor(label_file_path, do_ignore=False):
    """
    Load a detection label file into a dict keyed by image base name.

    Every non-blank line is split on tabs: the first field is the image
    path (only its base name is kept as key) and the second field is a json
    list of box dicts, each holding a 'transcription' entry. An 'ignore'
    flag is added to every box.

    Args:
        label_file_path(string): path of the label file to read
        do_ignore(bool): when true, boxes whose transcription is "###" get
            'ignore' set to True; otherwise 'ignore' is always False
    return:
        dict mapping image base name to its list of box dicts
    """
    img_name_label_dict = {}
    with open(label_file_path, "rb") as fin:
        for line in fin:
            record = line.decode().strip("\n")
            # Blank lines (typically a trailing empty line) carry no label;
            # without this guard substr[1] below raises IndexError.
            if not record.strip():
                continue
            substr = record.split("\t")
            bbox_infor = json.loads(substr[1])
            for bbox in bbox_infor:
                text = bbox['transcription']
                bbox['ignore'] = text == "###" and bool(do_ignore)
            img_name_label_dict[os.path.basename(substr[0])] = bbox_infor
    return img_name_label_dict
|
|
|
|
|
|
def cal_det_metrics(gt_label_path, save_res_path):
    """
    Compare the saved predictions with the groundtruth and combine the
    per-image results into the final detection metrics.

    Args:
        gt_label_path(string): path of the groundtruth detection label file
        save_res_path(string): path of the saved predicted detection labels
    return:
        the value of DetectionIoUEvaluator.combine_results over all
        groundtruth images (documented upstream as Hmean, precision and
        recall)
    """
    iou_evaluator = DetectionIoUEvaluator()
    groundtruths = load_label_infor(gt_label_path, do_ignore=True)
    predictions = load_label_infor(save_res_path)

    # Images without any prediction are scored against an empty box list.
    per_image_results = [
        iou_evaluator.evaluate_image(gt_boxes, predictions.get(name, []))
        for name, gt_boxes in groundtruths.items()
    ]
    return iou_evaluator.combine_results(per_image_results)
|
|
|
|
|
|
def eval_det_run(exe, config, eval_info_dict, mode):
    """
    Write the detection results to disk and, when requested, score them.

    Args:
        exe: executor forwarded to cal_det_res
        config(dict): needs 'Global' with 'save_res_path', plus 'EvalReader'
            (mode "eval") or 'TestReader' (any other mode) with
            'label_file_path'; 'TestReader' also needs 'do_eval'
        eval_info_dict(dict): forwarded to cal_det_res
        mode(string): "eval" always computes metrics, any other value
            computes them only when TestReader's 'do_eval' is true
    return:
        the metrics from cal_det_metrics, or an empty dict when no
        evaluation was requested
    """
    cal_det_res(exe, config, eval_info_dict)
    predictions_path = config['Global']['save_res_path']

    if mode == "eval":
        label_path = config['EvalReader']['label_file_path']
        return cal_det_metrics(label_path, predictions_path)

    # The label path is looked up before 'do_eval', as in the original.
    label_path = config['TestReader']['label_file_path']
    if not config['TestReader']['do_eval']:
        return {}
    return cal_det_metrics(label_path, predictions_path)
|