diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 594197a6..85ce580f 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -382,23 +382,12 @@ class TableLabelDecode(object): """convert text-label into text-index. """ if char_or_elem == "char": - max_len = self.max_text_length current_dict = self.dict_idx_character else: - max_len = self.max_elem_length current_dict = self.dict_idx_elem ignored_tokens = self.get_ignored_tokens('elem') beg_idx, end_idx = ignored_tokens - # select_td_tokens = [] - # select_span_tokens = [] - # for elem in self.dict_elem: - # # if elem == '' or elem == ''\ - # # or 'rowspan' in elem or 'colspan' in elem: - # if elem == '' or elem == '': - # select_td_tokens.append(self.dict_elem[elem]) - # if 'rowspan' in elem or 'colspan' in elem: - # select_span_tokens.append(self.dict_elem[elem]) result_list = [] result_pos_list = [] result_score_list = [] @@ -415,12 +404,7 @@ class TableLabelDecode(object): break if tmp_elem_idx in ignored_tokens: continue - # if tmp_elem_idx in select_td_tokens: - # total_td_score += structure_probs[batch_idx, idx] - # total_td_num += 1 - # if tmp_elem_idx in select_span_tokens: - # total_span_score += structure_probs[batch_idx, idx] - # total_span_num += 1 + char_list.append(current_dict[tmp_elem_idx]) elem_pos_list.append(idx) score_list.append(structure_probs[batch_idx, idx]) diff --git a/ppstructure/table/eval_table.py b/ppstructure/table/eval_table.py index 00b9cd51..1bcbaa8d 100755 --- a/ppstructure/table/eval_table.py +++ b/ppstructure/table/eval_table.py @@ -38,15 +38,15 @@ def main(gt_path, img_root, args): pred_htmls = [] gt_htmls = [] for img_name in tqdm(jsons_gt): - # 读取信息 + # read image img = cv2.imread(os.path.join(img_root,img_name)) pred_html = text_sys(img) pred_htmls.append(pred_html) gt_structures, gt_bboxes, gt_contents, contents_with_block = jsons_gt[img_name] - gt_html, gt = get_gt_html(gt_structures, contents_with_block) # 获取HTMLgt + gt_html, gt = get_gt_html(gt_structures, contents_with_block) gt_htmls.append(gt_html) - scores = teds.batch_evaluate_html(gt_htmls, pred_htmls) # 计算teds + scores = teds.batch_evaluate_html(gt_htmls, pred_htmls) print('teds:', sum(scores) / len(scores)) diff --git a/ppstructure/table/matcher.py b/ppstructure/table/matcher.py index b3c70430..c3b56384 100755 --- a/ppstructure/table/matcher.py +++ b/ppstructure/table/matcher.py @@ -2,14 +2,9 @@ import json def distance(box_1, box_2): x1, y1, x2, y2 = box_1 x3, y3, x4, y4 = box_2 - # min_x = (x1 + x2) / 2 - # min_y = (y1 + y2) / 2 - # max_x = (x3 + x4) / 2 - # max_y = (y3 + y4) / 2 dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2) dis_2 = abs(x3 - x1) + abs(y3 - y1) dis_3 = abs(x4- x2) + abs(y4 - y2) - #dis = pow(min_x - max_x, 2) + pow(min_y - max_y, 2) + pow(x3 - x1, 2) + pow(y3 - y1, 2) + pow(x4- x2, 2) + pow(y4 - y2, 2) + abs(x3 - x1) + abs(y3 - y1) + abs(x4- x2) + abs(y4 - y2) return dis + min(dis_2, dis_3) def compute_iou(rec1, rec2): @@ -21,7 +16,6 @@ def compute_iou(rec1, rec2): :return: scala value of IoU """ # computing area of each rectangles - rec1, rec2 = rec1 * 1000, rec2 * 1000 S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]) S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]) @@ -36,29 +30,31 @@ def compute_iou(rec1, rec2): # judge if there is an intersect if left_line >= right_line or top_line >= bottom_line: - return 0 + return 0.0 else: intersect = (right_line - left_line) * (bottom_line - top_line) return (intersect / (sum_area - intersect))*1.0 -def matcher_merge(ocr_bboxes, pred_bboxes): # ocr_bboxes: OCR pred_bboxes:端到端 +def matcher_merge(ocr_bboxes, pred_bboxes): all_dis = [] ious = [] matched = {} for i, gt_box in enumerate(ocr_bboxes): distances = [] for j, pred_box in enumerate(pred_bboxes): - distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) #获取两两cell之间的L1距离和 1- IOU + # compute l1 distence and IOU between two boxes + distances.append((distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box))) sorted_distances = distances.copy() - # 根据距离和IOU挑选最"近"的cell + # select nearest cell sorted_distances = sorted(sorted_distances, key = lambda item: (item[1], item[0])) if distances.index(sorted_distances[0]) not in matched.keys(): matched[distances.index(sorted_distances[0])] = [i] else: matched[distances.index(sorted_distances[0])].append(i) return matched#, sum(ious) / len(ious) + def complex_num(pred_bboxes): complex_nums = [] for bbox in pred_bboxes: