remove unused code

This commit is contained in:
WenmuZhou 2021-06-08 18:01:39 +08:00
parent 3b81e30490
commit dfba983c82
3 changed files with 10 additions and 30 deletions

View File

@ -382,23 +382,12 @@ class TableLabelDecode(object):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
max_len = self.max_text_length
current_dict = self.dict_idx_character
else:
max_len = self.max_elem_length
current_dict = self.dict_idx_elem
ignored_tokens = self.get_ignored_tokens('elem')
beg_idx, end_idx = ignored_tokens
# select_td_tokens = []
# select_span_tokens = []
# for elem in self.dict_elem:
# # if elem == '<td>' or elem == '<td' or elem == '<tr>'\
# # or 'rowspan' in elem or 'colspan' in elem:
# if elem == '<td>' or elem == '<td' or elem == '<tr>':
# select_td_tokens.append(self.dict_elem[elem])
# if 'rowspan' in elem or 'colspan' in elem:
# select_span_tokens.append(self.dict_elem[elem])
result_list = []
result_pos_list = []
result_score_list = []
@ -415,12 +404,7 @@ class TableLabelDecode(object):
break
if tmp_elem_idx in ignored_tokens:
continue
# if tmp_elem_idx in select_td_tokens:
# total_td_score += structure_probs[batch_idx, idx]
# total_td_num += 1
# if tmp_elem_idx in select_span_tokens:
# total_span_score += structure_probs[batch_idx, idx]
# total_span_num += 1
char_list.append(current_dict[tmp_elem_idx])
elem_pos_list.append(idx)
score_list.append(structure_probs[batch_idx, idx])

View File

@ -38,15 +38,15 @@ def main(gt_path, img_root, args):
pred_htmls = []
gt_htmls = []
for img_name in tqdm(jsons_gt):
# 读取信息
# read image
img = cv2.imread(os.path.join(img_root,img_name))
pred_html = text_sys(img)
pred_htmls.append(pred_html)
gt_structures, gt_bboxes, gt_contents, contents_with_block = jsons_gt[img_name]
gt_html, gt = get_gt_html(gt_structures, contents_with_block) # 获取HTMLgt
gt_html, gt = get_gt_html(gt_structures, contents_with_block)
gt_htmls.append(gt_html)
scores = teds.batch_evaluate_html(gt_htmls, pred_htmls) # 计算teds
scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
print('teds:', sum(scores) / len(scores))

View File

@ -2,14 +2,9 @@ import json
def distance(box_1, box_2):
x1, y1, x2, y2 = box_1
x3, y3, x4, y4 = box_2
# min_x = (x1 + x2) / 2
# min_y = (y1 + y2) / 2
# max_x = (x3 + x4) / 2
# max_y = (y3 + y4) / 2
dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
dis_2 = abs(x3 - x1) + abs(y3 - y1)
dis_3 = abs(x4- x2) + abs(y4 - y2)
#dis = pow(min_x - max_x, 2) + pow(min_y - max_y, 2) + pow(x3 - x1, 2) + pow(y3 - y1, 2) + pow(x4- x2, 2) + pow(y4 - y2, 2) + abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2)
return dis + min(dis_2, dis_3)
def compute_iou(rec1, rec2):
@ -21,7 +16,6 @@ def compute_iou(rec1, rec2):
:return: scala value of IoU
"""
# computing area of each rectangles
rec1, rec2 = rec1 * 1000, rec2 * 1000
S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
@ -36,29 +30,31 @@ def compute_iou(rec1, rec2):
# judge if there is an intersect
if left_line >= right_line or top_line >= bottom_line:
return 0
return 0.0
else:
intersect = (right_line - left_line) * (bottom_line - top_line)
return (intersect / (sum_area - intersect))*1.0
def matcher_merge(ocr_bboxes, pred_bboxes): # ocr_bboxes: OCR pred_bboxes端到端
def matcher_merge(ocr_bboxes, pred_bboxes):
all_dis = []
ious = []
matched = {}
for i, gt_box in enumerate(ocr_bboxes):
distances = []
for j, pred_box in enumerate(pred_bboxes):
distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) #获取两两cell之间的L1距离和 1- IOU
# compute l1 distence and IOU between two boxes
distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))
sorted_distances = distances.copy()
# 根据距离和IOU挑选最"近"的cell
# select nearest cell
sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0]))
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
matched[distances.index(sorted_distances[0])].append(i)
return matched#, sum(ious) / len(ious)
def complex_num(pred_bboxes):
complex_nums = []
for bbox in pred_bboxes: