diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml index be7529d7..0a232f7a 100644 --- a/configs/e2e/e2e_r50_vd_pg.yml +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -11,7 +11,7 @@ Global: # from static branch, load_static_weights must be set as True. # 2. If you want to finetune the pretrained models we provide in the docs, # you should set load_static_weights as False. - load_static_weights: True + load_static_weights: False cal_metric_during_train: False pretrained_model: checkpoints: @@ -94,7 +94,7 @@ Eval: label_file_list: [./train_data/total_text/test/] transforms: - DecodeImage: # load image - img_mode: BGR + img_mode: RGB channel_first: False - E2ELabelEncode: - E2EResizeForTest: diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index cbb11009..47e0cbf0 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -200,16 +200,18 @@ class E2ELabelEncode(BaseRecLabelEncode): self.pad_num = len(self.dict) # the length to pad def __call__(self, data): + text_label_index_list, temp_text = [], [] texts = data['strs'] - temp_texts = [] for text in texts: text = text.lower() - text = self.encode(text) - if text is None: - return None - text = text + [self.pad_num] * (self.max_text_len - len(text)) - temp_texts.append(text) - data['strs'] = np.array(temp_texts) + temp_text = [] + for c_ in text: + if c_ in self.dict: + temp_text.append(self.dict[c_]) + temp_text = temp_text + [self.pad_num] * (self.max_text_len - + len(temp_text)) + text_label_index_list.append(temp_text) + data['strs'] = np.array(text_label_index_list) return data diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py index 10109512..ae063835 100644 --- a/ppocr/data/pgnet_dataset.py +++ b/ppocr/data/pgnet_dataset.py @@ -24,6 +24,7 @@ class PGDataSet(Dataset): self.logger = logger self.seed = seed + self.mode = mode global_config = config['Global'] dataset_config = config[mode]['dataset'] loader_config = 
config[mode]['loader'] @@ -62,10 +63,13 @@ class PGDataSet(Dataset): with open(poly_txt_path) as f: for line in f.readlines(): poly_str, txt = line.strip().split('\t') - poly = map(float, poly_str.split(',')) + poly = list(map(float, poly_str.split(','))) + if self.mode.lower() == "eval": + while len(poly) < 100: + poly.append(-1) text_polys.append( np.array( - list(poly), dtype=np.float32).reshape(-1, 2)) + poly, dtype=np.float32).reshape(-1, 2)) txts.append(txt) txt_tags.append(txt == '###') @@ -135,8 +139,12 @@ class PGDataSet(Dataset): try: if self.data_format == 'icdar': im_path = os.path.join(data_path, 'rgb', data_line) - poly_path = os.path.join(data_path, 'poly', - data_line.split('.')[0] + '.txt') + if self.mode.lower() == "eval": + poly_path = os.path.join(data_path, 'poly_gt', + data_line.split('.')[0] + '.txt') + else: + poly_path = os.path.join(data_path, 'poly', + data_line.split('.')[0] + '.txt') text_polys, text_tags, text_strs = self.extract_polys(poly_path) else: image_dir = os.path.join(os.path.dirname(data_path), 'image') diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py index 75ffbfb0..684d7742 100644 --- a/ppocr/metrics/e2e_metric.py +++ b/ppocr/metrics/e2e_metric.py @@ -33,10 +33,20 @@ class E2EMetric(object): self.reset() def __call__(self, preds, batch, **kwargs): - gt_polyons_batch = batch[2] + temp_gt_polyons_batch = batch[2] temp_gt_strs_batch = batch[3] ignore_tags_batch = batch[4] + gt_polyons_batch = [] gt_strs_batch = [] + + temp_gt_polyons_batch = temp_gt_polyons_batch[0].tolist() + for temp_list in temp_gt_polyons_batch: + t = [] + for index in temp_list: + if index[0] != -1 and index[1] != -1: + t.append(index) + gt_polyons_batch.append(t) + temp_gt_strs_batch = temp_gt_strs_batch[0].tolist() for temp_list in temp_gt_strs_batch: t = "" @@ -46,7 +56,7 @@ class E2EMetric(object): gt_strs_batch.append(t) for pred, gt_polyons, gt_strs, ignore_tags in zip( - [preds], gt_polyons_batch, [gt_strs_batch], 
ignore_tags_batch): + [preds], [gt_polyons_batch], [gt_strs_batch], ignore_tags_batch): # prepare gt gt_info_list = [{ 'points': gt_polyon, diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py index 2cc7dc24..d9c0048f 100644 --- a/ppocr/postprocess/pg_postprocess.py +++ b/ppocr/postprocess/pg_postprocess.py @@ -23,7 +23,8 @@ __dir__ = os.path.dirname(__file__) sys.path.append(__dir__) sys.path.append(os.path.join(__dir__, '..')) -from ppocr.utils.e2e_utils.extract_textpoint import get_dict, generate_pivot_list, restore_poly +from ppocr.utils.e2e_utils.extract_textpoint import * +from ppocr.utils.e2e_utils.visual import * import paddle @@ -37,6 +38,11 @@ class PGPostProcess(object): self.valid_set = valid_set self.score_thresh = score_thresh + # c++ la-nms is faster, but only support python 3.5 + self.is_python35 = False + if sys.version_info.major == 3 and sys.version_info.minor == 5: + self.is_python35 = True + def __call__(self, outs_dict, shape_list): p_score = outs_dict['f_score'] p_border = outs_dict['f_border'] @@ -52,17 +58,96 @@ class PGPostProcess(object): p_border = p_border[0] p_direction = p_direction[0] p_char = p_char[0] - src_h, src_w, ratio_h, ratio_w = shape_list[0] - instance_yxs_list, seq_strs = generate_pivot_list( + is_curved = self.valid_set == "totaltext" + instance_yxs_list = generate_pivot_list( p_score, p_char, p_direction, - self.Lexicon_Table, - score_thresh=self.score_thresh) - poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs, - p_border, ratio_w, ratio_h, - src_w, src_h, self.valid_set) + score_thresh=self.score_thresh, + is_backbone=True, + is_curved=is_curved) + p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0)) + char_seq_idx_set = [] + for i in range(len(instance_yxs_list)): + gather_info_lod = paddle.to_tensor(instance_yxs_list[i]) + f_char_map = paddle.transpose(p_char, [0, 2, 3, 1]) + feature_seq = paddle.gather_nd(f_char_map, gather_info_lod) + feature_seq = 
np.expand_dims(feature_seq.numpy(), axis=0) + feature_len = [len(feature_seq[0])] + featyre_seq = paddle.to_tensor(feature_seq) + feature_len = np.array([feature_len]).astype(np.int64) + length = paddle.to_tensor(feature_len) + seq_pred = paddle.fluid.layers.ctc_greedy_decoder( + input=featyre_seq, blank=36, input_length=length) + seq_pred_str = seq_pred[0].numpy().tolist()[0] + seq_len = seq_pred[1].numpy()[0][0] + temp_t = [] + for c in seq_pred_str[:seq_len]: + temp_t.append(c) + char_seq_idx_set.append(temp_t) + seq_strs = [] + for char_idx_set in char_seq_idx_set: + pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set]) + seq_strs.append(pr_str) + poly_list = [] + keep_str_list = [] + all_point_list = [] + all_point_pair_list = [] + for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs): + if len(yx_center_line) == 1: + yx_center_line.append(yx_center_line[-1]) + + offset_expand = 1.0 + if self.valid_set == 'totaltext': + offset_expand = 1.2 + + point_pair_list = [] + for batch_id, y, x in yx_center_line: + offset = p_border[:, y, x].reshape(2, 2) + if offset_expand != 1.0: + offset_length = np.linalg.norm( + offset, axis=1, keepdims=True) + expand_length = np.clip( + offset_length * (offset_expand - 1), + a_min=0.5, + a_max=3.0) + offset_detal = offset / offset_length * expand_length + offset = offset + offset_detal + ori_yx = np.array([y, x], dtype=np.float32) + point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array( + [ratio_w, ratio_h]).reshape(-1, 2) + point_pair_list.append(point_pair) + + all_point_list.append([ + int(round(x * 4.0 / ratio_w)), + int(round(y * 4.0 / ratio_h)) + ]) + all_point_pair_list.append(point_pair.round().astype(np.int32) + .tolist()) + + detected_poly, pair_length_info = point_pair2poly(point_pair_list) + detected_poly = expand_poly_along_width( + detected_poly, shrink_ratio_of_width=0.2) + detected_poly[:, 0] = np.clip( + detected_poly[:, 0], a_min=0, a_max=src_w) + detected_poly[:, 1] = np.clip( + 
detected_poly[:, 1], a_min=0, a_max=src_h) + + if len(keep_str) < 2: + continue + + keep_str_list.append(keep_str) + if self.valid_set == 'partvgg': + middle_point = len(detected_poly) // 2 + detected_poly = detected_poly[ + [0, middle_point - 1, middle_point, -1], :] + poly_list.append(detected_poly) + elif self.valid_set == 'totaltext': + poly_list.append(detected_poly) + else: + print('--> Not supported format.') + exit(-1) data = { 'points': poly_list, 'strs': keep_str_list, diff --git a/ppocr/utils/e2e_metric/Deteval.py b/ppocr/utils/e2e_metric/Deteval.py index 37fa5c00..8033a9ff 100755 --- a/ppocr/utils/e2e_metric/Deteval.py +++ b/ppocr/utils/e2e_metric/Deteval.py @@ -35,7 +35,7 @@ def get_socre(gt_dict, pred_dict): gt = [] n = len(gt_dict) for i in range(n): - points = gt_dict[i]['points'].tolist() + points = gt_dict[i]['points'] h = len(points) text = gt_dict[i]['text'] xx = [ @@ -51,7 +51,7 @@ def get_socre(gt_dict, pred_dict): t_y.append(points[j][1]) xx[1] = np.array([t_x], dtype='int16') xx[3] = np.array([t_y], dtype='int16') - if text != "": + if text != "" and "#" not in text: xx[4] = np.array([text], dtype='U{}'.format(len(text))) xx[5] = np.array(['c'], dtype=' tr) - gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[ - 0].shape[0] - gt_matching_qualified_tau_candidates = np.where( - local_tau_table[gt_id, :] > tp) - gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[ - 0].shape[0] - - det_matching_qualified_sigma_candidates = np.where( - local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]] - > tr) - det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[ - 0].shape[0] - det_matching_qualified_tau_candidates = np.where( - local_tau_table[:, gt_matching_qualified_tau_candidates[0]] > - tp) - det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[ - 0].shape[0] - - if (gt_matching_num_qualified_sigma_candidates == 1) and 
(gt_matching_num_qualified_tau_candidates == 1) and \ - (det_matching_num_qualified_sigma_candidates == 1) and ( - det_matching_num_qualified_tau_candidates == 1): - global_accumulative_recall = global_accumulative_recall + 1.0 - global_accumulative_precision = global_accumulative_precision + 1.0 - local_accumulative_recall = local_accumulative_recall + 1.0 - local_accumulative_precision = local_accumulative_precision + 1.0 - - gt_flag[0, gt_id] = 1 - matched_det_id = np.where(local_sigma_table[gt_id, :] > tr) - # recg start - - gt_str_cur = global_gt_str[idy][gt_id] - pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[ - 0]] - - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - else: - if pred_str_cur.lower() == gt_str_cur.lower(): - hit_str_num += 1 - # recg end - det_flag[0, matched_det_id] = 1 - return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num - - def one_to_many(local_sigma_table, local_tau_table, - local_accumulative_recall, local_accumulative_precision, - global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idy): - hit_str_num = 0 - for gt_id in range(num_gt): - # skip the following if the groundtruth was matched - if gt_flag[0, gt_id] > 0: - continue - - non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0) - num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0] - - if num_non_zero_in_sigma >= k: - ####search for all detections that overlaps with this groundtruth - qualified_tau_candidates = np.where((local_tau_table[ - gt_id, :] >= tp) & (det_flag[0, :] == 0)) - num_qualified_tau_candidates = qualified_tau_candidates[ - 0].shape[0] - - if num_qualified_tau_candidates == 1: - if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp) - and - (local_sigma_table[gt_id, qualified_tau_candidates] >= - tr)): - # became an one-to-one case - global_accumulative_recall = global_accumulative_recall + 1.0 - 
global_accumulative_precision = global_accumulative_precision + 1.0 - local_accumulative_recall = local_accumulative_recall + 1.0 - local_accumulative_precision = local_accumulative_precision + 1.0 - - gt_flag[0, gt_id] = 1 - det_flag[0, qualified_tau_candidates] = 1 - # recg start - gt_str_cur = global_gt_str[idy][gt_id] - pred_str_cur = global_pred_str[idy][ - qualified_tau_candidates[0].tolist()[0]] - - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - else: - if pred_str_cur.lower() == gt_str_cur.lower(): - hit_str_num += 1 - # recg end - elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) - >= tr): - gt_flag[0, gt_id] = 1 - det_flag[0, qualified_tau_candidates] = 1 - # recg start - - gt_str_cur = global_gt_str[idy][gt_id] - pred_str_cur = global_pred_str[idy][ - qualified_tau_candidates[0].tolist()[0]] - - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - else: - if pred_str_cur.lower() == gt_str_cur.lower(): - hit_str_num += 1 - # recg end - - global_accumulative_recall = global_accumulative_recall + fsc_k - global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k - - local_accumulative_recall = local_accumulative_recall + fsc_k - local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k - - return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num - - def many_to_one(local_sigma_table, local_tau_table, - local_accumulative_recall, local_accumulative_precision, - global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idy): - hit_str_num = 0 - for det_id in range(num_det): - # skip the following if the detection was matched - if det_flag[0, det_id] > 0: - continue - - non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0) - num_non_zero_in_tau = non_zero_in_tau[0].shape[0] - - if num_non_zero_in_tau >= k: - ####search for all detections 
that overlaps with this groundtruth - qualified_sigma_candidates = np.where(( - local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0)) - num_qualified_sigma_candidates = qualified_sigma_candidates[ - 0].shape[0] - - if num_qualified_sigma_candidates == 1: - if ((local_tau_table[qualified_sigma_candidates, det_id] >= - tp) and - (local_sigma_table[qualified_sigma_candidates, det_id] - >= tr)): - # became an one-to-one case - global_accumulative_recall = global_accumulative_recall + 1.0 - global_accumulative_precision = global_accumulative_precision + 1.0 - local_accumulative_recall = local_accumulative_recall + 1.0 - local_accumulative_precision = local_accumulative_precision + 1.0 - - gt_flag[0, qualified_sigma_candidates] = 1 - det_flag[0, det_id] = 1 - # recg start - pred_str_cur = global_pred_str[idy][det_id] - gt_len = len(qualified_sigma_candidates[0]) - for idx in range(gt_len): - ele_gt_id = qualified_sigma_candidates[0].tolist()[ - idx] - if ele_gt_id not in global_gt_str[idy]: - continue - gt_str_cur = global_gt_str[idy][ele_gt_id] - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - break - else: - if pred_str_cur.lower() == gt_str_cur.lower(): - hit_str_num += 1 - break - # recg end - elif (np.sum(local_tau_table[qualified_sigma_candidates, - det_id]) >= tp): - det_flag[0, det_id] = 1 - gt_flag[0, qualified_sigma_candidates] = 1 - # recg start - - pred_str_cur = global_pred_str[idy][det_id] - gt_len = len(qualified_sigma_candidates[0]) - for idx in range(gt_len): - ele_gt_id = qualified_sigma_candidates[0].tolist()[idx] - if ele_gt_id not in global_gt_str[idy]: - continue - gt_str_cur = global_gt_str[idy][ele_gt_id] - if pred_str_cur == gt_str_cur: - hit_str_num += 1 - break - else: - if pred_str_cur.lower() == gt_str_cur.lower(): - hit_str_num += 1 - break - # recg end - - global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k - global_accumulative_precision = global_accumulative_precision + fsc_k - - 
local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k - local_accumulative_precision = local_accumulative_precision + fsc_k - return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num + global_sigma = local_sigma_table + global_tau = local_tau_table + global_pred_str = local_pred_str + global_gt_str = local_gt_str single_data = {} - for idx in range(len(global_sigma)): - local_sigma_table = global_sigma[idx] - local_tau_table = global_tau[idx] - - num_gt = local_sigma_table.shape[0] - num_det = local_sigma_table.shape[1] - - total_num_gt = total_num_gt + num_gt - total_num_det = total_num_det + num_det - - local_accumulative_recall = 0 - local_accumulative_precision = 0 - gt_flag = np.zeros((1, num_gt)) - det_flag = np.zeros((1, num_det)) - - #######first check for one-to-one case########## - local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ - gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table, - local_accumulative_recall, local_accumulative_precision, - global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idx) - - hit_str_count += hit_str_num - #######then check for one-to-many case########## - local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ - gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table, - local_accumulative_recall, local_accumulative_precision, - global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idx) - hit_str_count += hit_str_num - #######then check for many-to-one case########## - local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ - gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, 
local_tau_table, - local_accumulative_recall, local_accumulative_precision, - global_accumulative_recall, global_accumulative_precision, - gt_flag, det_flag, idx) - - hit_str_count += hit_str_num - - # fid = open(fid_path, 'a+') - try: - local_precision = local_accumulative_precision / num_det - except ZeroDivisionError: - local_precision = 0 - - try: - local_recall = local_accumulative_recall / num_gt - except ZeroDivisionError: - local_recall = 0 - - try: - local_f_score = 2 * local_precision * local_recall / ( - local_precision + local_recall) - except ZeroDivisionError: - local_f_score = 0 - single_data['sigma'] = global_sigma single_data['global_tau'] = global_tau single_data['global_pred_str'] = global_pred_str single_data['global_gt_str'] = global_gt_str - single_data["recall"] = local_recall - single_data['precision'] = local_precision - single_data['f_score'] = local_f_score return single_data @@ -435,10 +163,10 @@ def combine_results(all_data): global_pred_str = [] global_gt_str = [] for data in all_data: - global_sigma.append(data['sigma'][0]) - global_tau.append(data['global_tau'][0]) - global_pred_str.append(data['global_pred_str'][0]) - global_gt_str.append(data['global_gt_str'][0]) + global_sigma.append(data['sigma']) + global_tau.append(data['global_tau']) + global_pred_str.append(data['global_pred_str']) + global_gt_str.append(data['global_gt_str']) global_accumulative_recall = 0 global_accumulative_precision = 0 @@ -676,6 +404,8 @@ def combine_results(all_data): local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, idx) + hit_str_count += hit_str_num + try: recall = global_accumulative_recall / total_num_gt except ZeroDivisionError: diff --git a/ppocr/utils/e2e_utils/extract_textpoint.py b/ppocr/utils/e2e_utils/extract_textpoint.py index d64f1e83..975ca161 100644 --- a/ppocr/utils/e2e_utils/extract_textpoint.py +++ b/ppocr/utils/e2e_utils/extract_textpoint.py @@ -17,9 
+17,11 @@ from __future__ import division from __future__ import print_function import cv2 +import math + import numpy as np from itertools import groupby -from cv2.ximgproc import thinning as thin +from skimage.morphology._skeletonize import thin def get_dict(character_dict_path): @@ -33,39 +35,87 @@ def get_dict(character_dict_path): return dict_character -def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4): +def softmax(logits): + """ + logits: N x d + """ + max_value = np.max(logits, axis=1, keepdims=True) + exp = np.exp(logits - max_value) + exp_sum = np.sum(exp, axis=1, keepdims=True) + dist = exp / exp_sum + return dist + + +def get_keep_pos_idxs(labels, remove_blank=None): + """ + Remove duplicate and get pos idxs of keep items. + The value of keep_blank should be [None, 95]. + """ + duplicate_len_list = [] + keep_pos_idx_list = [] + keep_char_idx_list = [] + for k, v_ in groupby(labels): + current_len = len(list(v_)) + if k != remove_blank: + current_idx = int(sum(duplicate_len_list) + current_len // 2) + keep_pos_idx_list.append(current_idx) + keep_char_idx_list.append(k) + duplicate_len_list.append(current_len) + return keep_char_idx_list, keep_pos_idx_list + + +def remove_blank(labels, blank=0): + new_labels = [x for x in labels if x != blank] + return new_labels + + +def insert_blank(labels, blank=0): + new_labels = [blank] + for l in labels: + new_labels += [l, blank] + return new_labels + + +def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True): + """ + CTC greedy (best path) decoder. + """ + raw_str = np.argmax(np.array(probs_seq), axis=1) + remove_blank_in_pos = None if keep_blank_in_idxs else blank + dedup_str, keep_idx_list = get_keep_pos_idxs( + raw_str, remove_blank=remove_blank_in_pos) + dst_str = remove_blank(dedup_str, blank=blank) + return dst_str, keep_idx_list + + +def instance_ctc_greedy_decoder(gather_info, + logits_map, + keep_blank_in_idxs=True): + """ + gather_info: [[x, y], [x, y] ...] 
+ logits_map: H x W X (n_chars + 1) + """ _, _, C = logits_map.shape ys, xs = zip(*gather_info) - logits_seq = logits_map[list(ys), list(xs)] - probs_seq = logits_seq - labels = np.argmax(probs_seq, axis=1) - dst_str = [k for k, v_ in groupby(labels) if k != C - 1] - detal = len(gather_info) // (pts_num - 1) - keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1] + logits_seq = logits_map[list(ys), list(xs)] # n x 96 + probs_seq = softmax(logits_seq) + dst_str, keep_idx_list = ctc_greedy_decoder( + probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs) keep_gather_list = [gather_info[idx] for idx in keep_idx_list] return dst_str, keep_gather_list -def ctc_decoder_for_image(gather_info_list, - logits_map, - Lexicon_Table, - pts_num=6): +def ctc_decoder_for_image(gather_info_list, logits_map, + keep_blank_in_idxs=True): """ CTC decoder using multiple processes. """ - decoder_str = [] - decoder_xys = [] + decoder_results = [] for gather_info in gather_info_list: - if len(gather_info) < pts_num: - continue - dst_str, xys_list = instance_ctc_greedy_decoder( - gather_info, logits_map, pts_num=pts_num) - dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str]) - if len(dst_str_readable) < 2: - continue - decoder_str.append(dst_str_readable) - decoder_xys.append(xys_list) - return decoder_str, decoder_xys + res = instance_ctc_greedy_decoder( + gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs) + decoder_results.append(res) + return decoder_results def sort_with_direction(pos_list, f_direction): @@ -107,6 +157,58 @@ def sort_with_direction(pos_list, f_direction): return sorted_point, np.array(sorted_direction) +def add_id(pos_list, image_id=0): + """ + Add id for gather feature, for inference. 
+ """ + new_list = [] + for item in pos_list: + new_list.append((image_id, item[0], item[1])) + return new_list + + +def sort_and_expand_with_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...] + """ + h, w, _ = f_direction.shape + sorted_list, point_direction = sort_with_direction(pos_list, f_direction) + + # expand along + point_num = len(sorted_list) + sub_direction_len = max(point_num // 3, 2) + left_direction = point_direction[:sub_direction_len, :] + right_dirction = point_direction[point_num - sub_direction_len:, :] + + left_average_direction = -np.mean(left_direction, axis=0, keepdims=True) + left_average_len = np.linalg.norm(left_average_direction) + left_start = np.array(sorted_list[0]) + left_step = left_average_direction / (left_average_len + 1e-6) + + right_average_direction = np.mean(right_dirction, axis=0, keepdims=True) + right_average_len = np.linalg.norm(right_average_direction) + right_step = right_average_direction / (right_average_len + 1e-6) + right_start = np.array(sorted_list[-1]) + + append_num = max( + int((left_average_len + right_average_len) / 2.0 * 0.15), 1) + left_list = [] + right_list = [] + for i in range(append_num): + ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ly < h and lx < w and (ly, lx) not in left_list: + left_list.append((ly, lx)) + ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype( + 'int32').tolist() + if ry < h and rx < w and (ry, rx) not in right_list: + right_list.append((ry, rx)) + + all_list = left_list[::-1] + sorted_list + right_list + return all_list + + def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map): """ f_direction: h x w x 2 @@ -116,6 +218,7 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map): h, w, _ = f_direction.shape sorted_list, point_direction = sort_with_direction(pos_list, f_direction) + # expand along point_num = 
len(sorted_list) sub_direction_len = max(point_num // 3, 2) left_direction = point_direction[:sub_direction_len, :] @@ -159,108 +262,258 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map): return all_list -def point_pair2poly(point_pair_list): - """ - Transfer vertical point_pairs into poly point in clockwise. - """ - point_num = len(point_pair_list) * 2 - point_list = [0] * point_num - for idx, point_pair in enumerate(point_pair_list): - point_list[idx] = point_pair[0] - point_list[point_num - 1 - idx] = point_pair[1] - return np.array(point_list).reshape(-1, 2) - - -def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.): - ratio_pair = np.array( - [[begin_width_ratio], [end_width_ratio]], dtype=np.float32) - p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair - p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair - return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) - - -def expand_poly_along_width(poly, shrink_ratio_of_width=0.3): - """ - expand poly along width. 
- """ - point_num = poly.shape[0] - left_quad = np.array( - [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) - left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \ - (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) - left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0) - right_quad = np.array( - [ - poly[point_num // 2 - 2], poly[point_num // 2 - 1], - poly[point_num // 2], poly[point_num // 2 + 1] - ], - dtype=np.float32) - right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ - (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) - right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio) - poly[0] = left_quad_expand[0] - poly[-1] = left_quad_expand[-1] - poly[point_num // 2 - 1] = right_quad_expand[1] - poly[point_num // 2] = right_quad_expand[2] - return poly - - -def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w, - src_h, valid_set): - poly_list = [] - keep_str_list = [] - for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs): - if len(keep_str) < 2: - print('--> too short, {}'.format(keep_str)) - continue - - offset_expand = 1.0 - if valid_set == 'totaltext': - offset_expand = 1.2 - - point_pair_list = [] - for y, x in yx_center_line: - offset = p_border[:, y, x].reshape(2, 2) * offset_expand - ori_yx = np.array([y, x], dtype=np.float32) - point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array( - [ratio_w, ratio_h]).reshape(-1, 2) - point_pair_list.append(point_pair) - - detected_poly = point_pair2poly(point_pair_list) - detected_poly = expand_poly_along_width( - detected_poly, shrink_ratio_of_width=0.2) - detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w) - detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h) - - keep_str_list.append(keep_str) - if valid_set == 'partvgg': - middle_point = len(detected_poly) // 2 - detected_poly = detected_poly[ - [0, 
middle_point - 1, middle_point, -1], :] - poly_list.append(detected_poly) - elif valid_set == 'totaltext': - poly_list.append(detected_poly) - else: - print('--> Not supported format.') - exit(-1) - return poly_list, keep_str_list - - -def generate_pivot_list(p_score, - p_char_maps, - f_direction, - Lexicon_Table, - score_thresh=0.5): +def generate_pivot_list_curved(p_score, + p_char_maps, + f_direction, + score_thresh=0.5, + is_expand=True, + is_backbone=False, + image_id=0): """ return center point and end point of TCL instance; filter with the char maps; """ p_score = p_score[0] f_direction = f_direction.transpose(1, 2, 0) - ret, p_tcl_map = cv2.threshold(p_score, score_thresh, 255, - cv2.THRESH_BINARY) - skeleton_map = thin(p_tcl_map.astype('uint8')) + p_tcl_map = (p_score > score_thresh) * 1.0 + skeleton_map = thin(p_tcl_map) instance_count, instance_label_map = cv2.connectedComponents( - skeleton_map, connectivity=8) + skeleton_map.astype(np.uint8), connectivity=8) + + # get TCL Instance + all_pos_yxs = [] + center_pos_yxs = [] + end_points_yxs = [] + instance_center_pos_yxs = [] + if instance_count > 0: + for instance_id in range(1, instance_count): + pos_list = [] + ys, xs = np.where(instance_label_map == instance_id) + pos_list = list(zip(ys, xs)) + + ### FIX-ME, eliminate outlier + if len(pos_list) < 3: + continue + + if is_expand: + pos_list_sorted = sort_and_expand_with_direction_v2( + pos_list, f_direction, p_tcl_map) + else: + pos_list_sorted, _ = sort_with_direction(pos_list, f_direction) + all_pos_yxs.append(pos_list_sorted) + + # use decoder to filter backgroud points. 
+ p_char_maps = p_char_maps.transpose([1, 2, 0]) + decode_res = ctc_decoder_for_image( + all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True) + for decoded_str, keep_yxs_list in decode_res: + if is_backbone: + keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id) + instance_center_pos_yxs.append(keep_yxs_list_with_id) + else: + end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1])) + center_pos_yxs.extend(keep_yxs_list) + + if is_backbone: + return instance_center_pos_yxs + else: + return center_pos_yxs, end_points_yxs + + +def generate_pivot_list_horizontal(p_score, + p_char_maps, + f_direction, + score_thresh=0.5, + is_backbone=False, + image_id=0): + """ + return center point and end point of TCL instance; filter with the char maps; + """ + p_score = p_score[0] + f_direction = f_direction.transpose(1, 2, 0) + p_tcl_map_bi = (p_score > score_thresh) * 1.0 + instance_count, instance_label_map = cv2.connectedComponents( + p_tcl_map_bi.astype(np.uint8), connectivity=8) + + # get TCL Instance + all_pos_yxs = [] + center_pos_yxs = [] + end_points_yxs = [] + instance_center_pos_yxs = [] + + if instance_count > 0: + for instance_id in range(1, instance_count): + pos_list = [] + ys, xs = np.where(instance_label_map == instance_id) + pos_list = list(zip(ys, xs)) + + ### FIX-ME, eliminate outlier + if len(pos_list) < 5: + continue + + # add rule here + main_direction = extract_main_direction(pos_list, + f_direction) # y x + reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x + is_h_angle = abs(np.sum( + main_direction * reference_directin)) < math.cos(math.pi / 180 * + 70) + + point_yxs = np.array(pos_list) + max_y, max_x = np.max(point_yxs, axis=0) + min_y, min_x = np.min(point_yxs, axis=0) + is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x) + + pos_list_final = [] + if is_h_len: + xs = np.unique(xs) + for x in xs: + ys = instance_label_map[:, x].copy().reshape((-1, )) + y = int(np.where(ys == instance_id)[0].mean()) + 
pos_list_final.append((y, x)) + else: + ys = np.unique(ys) + for y in ys: + xs = instance_label_map[y, :].copy().reshape((-1, )) + x = int(np.where(xs == instance_id)[0].mean()) + pos_list_final.append((y, x)) + + pos_list_sorted, _ = sort_with_direction(pos_list_final, + f_direction) + all_pos_yxs.append(pos_list_sorted) + + # use decoder to filter background points. + p_char_maps = p_char_maps.transpose([1, 2, 0]) + decode_res = ctc_decoder_for_image( + all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True) + for decoded_str, keep_yxs_list in decode_res: + if is_backbone: + keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id) + instance_center_pos_yxs.append(keep_yxs_list_with_id) + else: + end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1])) + center_pos_yxs.extend(keep_yxs_list) + + if is_backbone: + return instance_center_pos_yxs + else: + return center_pos_yxs, end_points_yxs + + +def generate_pivot_list(p_score, + p_char_maps, + f_direction, + score_thresh=0.5, + is_backbone=False, + is_curved=True, + image_id=0): + """ + Wrap the curved and horizontal pivot-list generators; is_curved selects which one runs. + """ + if is_curved: + return generate_pivot_list_curved( + p_score, + p_char_maps, + f_direction, + score_thresh=score_thresh, + is_expand=True, + is_backbone=is_backbone, + image_id=image_id) + else: + return generate_pivot_list_horizontal( + p_score, + p_char_maps, + f_direction, + score_thresh=score_thresh, + is_backbone=is_backbone, + image_id=image_id) + + +# for refine module +def extract_main_direction(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[y, x], [y, x], [y, x] ...]
+ """ + pos_list = np.array(pos_list) + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + average_direction = average_direction / ( + np.linalg.norm(average_direction) + 1e-6) + return average_direction + + +def sort_by_direction_with_image_id_deprecated(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[id, y, x], [id, y, x], [id, y, x] ...] + """ + pos_list_full = np.array(pos_list).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + return sorted_list + + +def sort_by_direction_with_image_id(pos_list, f_direction): + """ + f_direction: h x w x 2 + pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
+ """ + + def sort_part_with_direction(pos_list_full, point_direction): + pos_list_full = np.array(pos_list_full).reshape(-1, 3) + pos_list = pos_list_full[:, 1:] + point_direction = np.array(point_direction).reshape(-1, 2) + average_direction = np.mean(point_direction, axis=0, keepdims=True) + pos_proj_leng = np.sum(pos_list * average_direction, axis=1) + sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist() + sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist() + return sorted_list, sorted_direction + + pos_list = np.array(pos_list).reshape(-1, 3) + point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y + point_direction = point_direction[:, ::-1] # x, y -> y, x + sorted_point, sorted_direction = sort_part_with_direction(pos_list, + point_direction) + + point_num = len(sorted_point) + if point_num >= 16: + middle_num = point_num // 2 + first_part_point = sorted_point[:middle_num] + first_point_direction = sorted_direction[:middle_num] + sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction( + first_part_point, first_point_direction) + + last_part_point = sorted_point[middle_num:] + last_point_direction = sorted_direction[middle_num:] + sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction( + last_part_point, last_point_direction) + sorted_point = sorted_fist_part_point + sorted_last_part_point + sorted_direction = sorted_fist_part_direction + sorted_last_part_direction + + return sorted_point + + +def generate_pivot_list_tt_inference(p_score, + p_char_maps, + f_direction, + score_thresh=0.5, + is_backbone=False, + is_curved=True, + image_id=0): + """ + return center point and end point of TCL instance; filter with the char maps; + """ + p_score = p_score[0] + f_direction = f_direction.transpose(1, 2, 0) + p_tcl_map = (p_score > score_thresh) * 1.0 + skeleton_map = thin(p_tcl_map) + instance_count, instance_label_map = cv2.connectedComponents( + skeleton_map.astype(np.uint8), 
connectivity=8) # get TCL Instance all_pos_yxs = [] @@ -269,15 +522,11 @@ def generate_pivot_list(p_score, pos_list = [] ys, xs = np.where(instance_label_map == instance_id) pos_list = list(zip(ys, xs)) - + ### FIX-ME, eliminate outlier if len(pos_list) < 3: continue - pos_list_sorted = sort_and_expand_with_direction_v2( pos_list, f_direction, p_tcl_map) - all_pos_yxs.append(pos_list_sorted) - - p_char_maps = p_char_maps.transpose([1, 2, 0]) - decoded_str, keep_yxs_list = ctc_decoder_for_image( - all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table) - return keep_yxs_list, decoded_str + pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id) + all_pos_yxs.append(pos_list_sorted_with_id) + return all_pos_yxs diff --git a/tools/infer/predict_e2e.py b/tools/infer/predict_e2e.py index 406e1bf3..a5c57914 100755 --- a/tools/infer/predict_e2e.py +++ b/tools/infer/predict_e2e.py @@ -151,7 +151,7 @@ if __name__ == "__main__": src_im = utility.draw_e2e_res(points, strs, image_file) img_name_pure = os.path.split(image_file)[-1] img_path = os.path.join(draw_img_save, - "e2e_res_{}_pgnet".format(img_name_pure)) + "e2e_res_{}".format(img_name_pure)) cv2.imwrite(img_path, src_im) logger.info("The visualized image saved in {}".format(img_path)) if count > 1: diff --git a/train_data/total_text/train/poly/2.txt b/train_data/total_text/train/poly/2.txt deleted file mode 100644 index 961d9680..00000000 --- a/train_data/total_text/train/poly/2.txt +++ /dev/null @@ -1,2 +0,0 @@ -2.0,165.0,20.0,167.0,39.0,170.0,57.0,173.0,76.0,176.0,94.0,179.0,113.0,182.0,109.0,218.0,90.0,215.0,72.0,213.0,54.0,210.0,36.0,208.0,18.0,205.0,0.0,203.0 izza -2.0,411.0,30.0,412.0,58.0,414.0,87.0,416.0,115.0,418.0,143.0,420.0,172.0,422.0,172.0,476.0,143.0,474.0,114.0,472.0,86.0,471.0,57.0,469.0,28.0,467.0,0.0,466.0 ISA diff --git a/train_data/total_text/train/rgb/2.jpg b/train_data/total_text/train/rgb/2.jpg deleted file mode 100644 index f3bc7a06..00000000 Binary files 
a/train_data/total_text/train/rgb/2.jpg and /dev/null differ