72 lines
2.3 KiB
Python
Executable File
72 lines
2.3 KiB
Python
Executable File
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import os
|
|
import sys
|
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
|
sys.path.append(__dir__)
|
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
|
|
|
import cv2
|
|
import json
|
|
from tqdm import tqdm
|
|
from ppstructure.table.table_metric import TEDS
|
|
from ppstructure.table.predict_table import TableSystem
|
|
from ppstructure.utility import init_args
|
|
|
|
|
|
def parse_args():
|
|
parser = init_args()
|
|
parser.add_argument("--gt_path", type=str)
|
|
return parser.parse_args()
|
|
|
|
def main(gt_path, img_root, args):
|
|
teds = TEDS(n_jobs=16)
|
|
|
|
text_sys = TableSystem(args)
|
|
jsons_gt = json.load(open(gt_path)) # gt
|
|
pred_htmls = []
|
|
gt_htmls = []
|
|
for img_name in tqdm(jsons_gt):
|
|
if img_name != 'PMC1064865_002_00.png':
|
|
continue
|
|
# 读取信息
|
|
img = cv2.imread(os.path.join(img_root,img_name))
|
|
pred_html = text_sys(img)
|
|
pred_htmls.append(pred_html)
|
|
|
|
gt_structures, gt_bboxes, gt_contents, contents_with_block = jsons_gt[img_name]
|
|
gt_html, gt = get_gt_html(gt_structures, contents_with_block) # 获取HTMLgt
|
|
gt_htmls.append(gt_html)
|
|
scores = teds.batch_evaluate_html(gt_htmls, pred_htmls) # 计算teds
|
|
print('teds:', sum(scores) / len(scores))
|
|
|
|
|
|
def get_gt_html(gt_structures, contents_with_block):
|
|
end_html = []
|
|
td_index = 0
|
|
for tag in gt_structures:
|
|
if '</td>' in tag:
|
|
if contents_with_block[td_index] != []:
|
|
end_html.extend(contents_with_block[td_index])
|
|
end_html.append(tag)
|
|
td_index += 1
|
|
else:
|
|
end_html.append(tag)
|
|
return ''.join(end_html), end_html
|
|
|
|
|
|
if __name__ == '__main__':
|
|
args = parse_args()
|
|
main(args.gt_path,args.image_dir, args)
|