add different post process
This commit is contained in:
parent
0cd48c35e5
commit
a0d1f9236d
|
@ -11,7 +11,7 @@ Global:
|
|||
# from static branch, load_static_weights must be set as True.
|
||||
# 2. If you want to finetune the pretrained models we provide in the docs,
|
||||
# you should set load_static_weights as False.
|
||||
load_static_weights: True
|
||||
load_static_weights: False
|
||||
cal_metric_during_train: False
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
|
@ -94,7 +94,7 @@ Eval:
|
|||
label_file_list: [./train_data/total_text/test/]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
img_mode: RGB
|
||||
channel_first: False
|
||||
- E2ELabelEncode:
|
||||
- E2EResizeForTest:
|
||||
|
|
|
@ -200,16 +200,18 @@ class E2ELabelEncode(BaseRecLabelEncode):
|
|||
self.pad_num = len(self.dict) # the length to pad
|
||||
|
||||
def __call__(self, data):
|
||||
text_label_index_list, temp_text = [], []
|
||||
texts = data['strs']
|
||||
temp_texts = []
|
||||
for text in texts:
|
||||
text = text.lower()
|
||||
text = self.encode(text)
|
||||
if text is None:
|
||||
return None
|
||||
text = text + [self.pad_num] * (self.max_text_len - len(text))
|
||||
temp_texts.append(text)
|
||||
data['strs'] = np.array(temp_texts)
|
||||
temp_text = []
|
||||
for c_ in text:
|
||||
if c_ in self.dict:
|
||||
temp_text.append(self.dict[c_])
|
||||
temp_text = temp_text + [self.pad_num] * (self.max_text_len -
|
||||
len(temp_text))
|
||||
text_label_index_list.append(temp_text)
|
||||
data['strs'] = np.array(text_label_index_list)
|
||||
return data
|
||||
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ class PGDataSet(Dataset):
|
|||
|
||||
self.logger = logger
|
||||
self.seed = seed
|
||||
self.mode = mode
|
||||
global_config = config['Global']
|
||||
dataset_config = config[mode]['dataset']
|
||||
loader_config = config[mode]['loader']
|
||||
|
@ -62,10 +63,13 @@ class PGDataSet(Dataset):
|
|||
with open(poly_txt_path) as f:
|
||||
for line in f.readlines():
|
||||
poly_str, txt = line.strip().split('\t')
|
||||
poly = map(float, poly_str.split(','))
|
||||
poly = list(map(float, poly_str.split(',')))
|
||||
if self.mode.lower() == "eval":
|
||||
while len(poly) < 100:
|
||||
poly.append(-1)
|
||||
text_polys.append(
|
||||
np.array(
|
||||
list(poly), dtype=np.float32).reshape(-1, 2))
|
||||
poly, dtype=np.float32).reshape(-1, 2))
|
||||
txts.append(txt)
|
||||
txt_tags.append(txt == '###')
|
||||
|
||||
|
@ -135,6 +139,10 @@ class PGDataSet(Dataset):
|
|||
try:
|
||||
if self.data_format == 'icdar':
|
||||
im_path = os.path.join(data_path, 'rgb', data_line)
|
||||
if self.mode.lower() == "eval":
|
||||
poly_path = os.path.join(data_path, 'poly_gt',
|
||||
data_line.split('.')[0] + '.txt')
|
||||
else:
|
||||
poly_path = os.path.join(data_path, 'poly',
|
||||
data_line.split('.')[0] + '.txt')
|
||||
text_polys, text_tags, text_strs = self.extract_polys(poly_path)
|
||||
|
|
|
@ -33,10 +33,20 @@ class E2EMetric(object):
|
|||
self.reset()
|
||||
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
gt_polyons_batch = batch[2]
|
||||
temp_gt_polyons_batch = batch[2]
|
||||
temp_gt_strs_batch = batch[3]
|
||||
ignore_tags_batch = batch[4]
|
||||
gt_polyons_batch = []
|
||||
gt_strs_batch = []
|
||||
|
||||
temp_gt_polyons_batch = temp_gt_polyons_batch[0].tolist()
|
||||
for temp_list in temp_gt_polyons_batch:
|
||||
t = []
|
||||
for index in temp_list:
|
||||
if index[0] != -1 and index[1] != -1:
|
||||
t.append(index)
|
||||
gt_polyons_batch.append(t)
|
||||
|
||||
temp_gt_strs_batch = temp_gt_strs_batch[0].tolist()
|
||||
for temp_list in temp_gt_strs_batch:
|
||||
t = ""
|
||||
|
@ -46,7 +56,7 @@ class E2EMetric(object):
|
|||
gt_strs_batch.append(t)
|
||||
|
||||
for pred, gt_polyons, gt_strs, ignore_tags in zip(
|
||||
[preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch):
|
||||
[preds], [gt_polyons_batch], [gt_strs_batch], ignore_tags_batch):
|
||||
# prepare gt
|
||||
gt_info_list = [{
|
||||
'points': gt_polyon,
|
||||
|
|
|
@ -23,7 +23,8 @@ __dir__ = os.path.dirname(__file__)
|
|||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
|
||||
from ppocr.utils.e2e_utils.extract_textpoint import get_dict, generate_pivot_list, restore_poly
|
||||
from ppocr.utils.e2e_utils.extract_textpoint import *
|
||||
from ppocr.utils.e2e_utils.visual import *
|
||||
import paddle
|
||||
|
||||
|
||||
|
@ -37,6 +38,11 @@ class PGPostProcess(object):
|
|||
self.valid_set = valid_set
|
||||
self.score_thresh = score_thresh
|
||||
|
||||
# c++ la-nms is faster, but only support python 3.5
|
||||
self.is_python35 = False
|
||||
if sys.version_info.major == 3 and sys.version_info.minor == 5:
|
||||
self.is_python35 = True
|
||||
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
p_score = outs_dict['f_score']
|
||||
p_border = outs_dict['f_border']
|
||||
|
@ -52,17 +58,96 @@ class PGPostProcess(object):
|
|||
p_border = p_border[0]
|
||||
p_direction = p_direction[0]
|
||||
p_char = p_char[0]
|
||||
|
||||
src_h, src_w, ratio_h, ratio_w = shape_list[0]
|
||||
instance_yxs_list, seq_strs = generate_pivot_list(
|
||||
is_curved = self.valid_set == "totaltext"
|
||||
instance_yxs_list = generate_pivot_list(
|
||||
p_score,
|
||||
p_char,
|
||||
p_direction,
|
||||
self.Lexicon_Table,
|
||||
score_thresh=self.score_thresh)
|
||||
poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
|
||||
p_border, ratio_w, ratio_h,
|
||||
src_w, src_h, self.valid_set)
|
||||
score_thresh=self.score_thresh,
|
||||
is_backbone=True,
|
||||
is_curved=is_curved)
|
||||
p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
|
||||
char_seq_idx_set = []
|
||||
for i in range(len(instance_yxs_list)):
|
||||
gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
|
||||
f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
|
||||
feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
|
||||
feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
|
||||
feature_len = [len(feature_seq[0])]
|
||||
featyre_seq = paddle.to_tensor(feature_seq)
|
||||
feature_len = np.array([feature_len]).astype(np.int64)
|
||||
length = paddle.to_tensor(feature_len)
|
||||
seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
|
||||
input=featyre_seq, blank=36, input_length=length)
|
||||
seq_pred_str = seq_pred[0].numpy().tolist()[0]
|
||||
seq_len = seq_pred[1].numpy()[0][0]
|
||||
temp_t = []
|
||||
for c in seq_pred_str[:seq_len]:
|
||||
temp_t.append(c)
|
||||
char_seq_idx_set.append(temp_t)
|
||||
seq_strs = []
|
||||
for char_idx_set in char_seq_idx_set:
|
||||
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
|
||||
seq_strs.append(pr_str)
|
||||
poly_list = []
|
||||
keep_str_list = []
|
||||
all_point_list = []
|
||||
all_point_pair_list = []
|
||||
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
|
||||
if len(yx_center_line) == 1:
|
||||
yx_center_line.append(yx_center_line[-1])
|
||||
|
||||
offset_expand = 1.0
|
||||
if self.valid_set == 'totaltext':
|
||||
offset_expand = 1.2
|
||||
|
||||
point_pair_list = []
|
||||
for batch_id, y, x in yx_center_line:
|
||||
offset = p_border[:, y, x].reshape(2, 2)
|
||||
if offset_expand != 1.0:
|
||||
offset_length = np.linalg.norm(
|
||||
offset, axis=1, keepdims=True)
|
||||
expand_length = np.clip(
|
||||
offset_length * (offset_expand - 1),
|
||||
a_min=0.5,
|
||||
a_max=3.0)
|
||||
offset_detal = offset / offset_length * expand_length
|
||||
offset = offset + offset_detal
|
||||
ori_yx = np.array([y, x], dtype=np.float32)
|
||||
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
|
||||
[ratio_w, ratio_h]).reshape(-1, 2)
|
||||
point_pair_list.append(point_pair)
|
||||
|
||||
all_point_list.append([
|
||||
int(round(x * 4.0 / ratio_w)),
|
||||
int(round(y * 4.0 / ratio_h))
|
||||
])
|
||||
all_point_pair_list.append(point_pair.round().astype(np.int32)
|
||||
.tolist())
|
||||
|
||||
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
|
||||
detected_poly = expand_poly_along_width(
|
||||
detected_poly, shrink_ratio_of_width=0.2)
|
||||
detected_poly[:, 0] = np.clip(
|
||||
detected_poly[:, 0], a_min=0, a_max=src_w)
|
||||
detected_poly[:, 1] = np.clip(
|
||||
detected_poly[:, 1], a_min=0, a_max=src_h)
|
||||
|
||||
if len(keep_str) < 2:
|
||||
continue
|
||||
|
||||
keep_str_list.append(keep_str)
|
||||
if self.valid_set == 'partvgg':
|
||||
middle_point = len(detected_poly) // 2
|
||||
detected_poly = detected_poly[
|
||||
[0, middle_point - 1, middle_point, -1], :]
|
||||
poly_list.append(detected_poly)
|
||||
elif self.valid_set == 'totaltext':
|
||||
poly_list.append(detected_poly)
|
||||
else:
|
||||
print('--> Not supported format.')
|
||||
exit(-1)
|
||||
data = {
|
||||
'points': poly_list,
|
||||
'strs': keep_str_list,
|
||||
|
|
|
@ -35,7 +35,7 @@ def get_socre(gt_dict, pred_dict):
|
|||
gt = []
|
||||
n = len(gt_dict)
|
||||
for i in range(n):
|
||||
points = gt_dict[i]['points'].tolist()
|
||||
points = gt_dict[i]['points']
|
||||
h = len(points)
|
||||
text = gt_dict[i]['text']
|
||||
xx = [
|
||||
|
@ -51,7 +51,7 @@ def get_socre(gt_dict, pred_dict):
|
|||
t_y.append(points[j][1])
|
||||
xx[1] = np.array([t_x], dtype='int16')
|
||||
xx[3] = np.array([t_y], dtype='int16')
|
||||
if text != "":
|
||||
if text != "" and "#" not in text:
|
||||
xx[4] = np.array([text], dtype='U{}'.format(len(text)))
|
||||
xx[5] = np.array(['c'], dtype='<U1')
|
||||
gt.append(xx)
|
||||
|
@ -89,17 +89,10 @@ def get_socre(gt_dict, pred_dict):
|
|||
area(det_x, det_y)), 2)
|
||||
|
||||
##############################Initialization###################################
|
||||
global_tp = 0
|
||||
global_fp = 0
|
||||
global_fn = 0
|
||||
global_sigma = []
|
||||
global_tau = []
|
||||
tr = 0.7
|
||||
tp = 0.6
|
||||
fsc_k = 0.8
|
||||
k = 2
|
||||
global_pred_str = []
|
||||
global_gt_str = []
|
||||
# global_sigma = []
|
||||
# global_tau = []
|
||||
# global_pred_str = []
|
||||
# global_gt_str = []
|
||||
###############################################################################
|
||||
|
||||
for input_id in range(allInputs):
|
||||
|
@ -147,281 +140,16 @@ def get_socre(gt_dict, pred_dict):
|
|||
local_pred_str[det_id] = pred_seq_str
|
||||
local_gt_str[gt_id] = gt_seq_str
|
||||
|
||||
global_sigma.append(local_sigma_table)
|
||||
global_tau.append(local_tau_table)
|
||||
global_pred_str.append(local_pred_str)
|
||||
global_gt_str.append(local_gt_str)
|
||||
|
||||
global_accumulative_recall = 0
|
||||
global_accumulative_precision = 0
|
||||
total_num_gt = 0
|
||||
total_num_det = 0
|
||||
hit_str_count = 0
|
||||
hit_count = 0
|
||||
|
||||
def one_to_one(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idy):
|
||||
hit_str_num = 0
|
||||
for gt_id in range(num_gt):
|
||||
gt_matching_qualified_sigma_candidates = np.where(
|
||||
local_sigma_table[gt_id, :] > tr)
|
||||
gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
|
||||
0].shape[0]
|
||||
gt_matching_qualified_tau_candidates = np.where(
|
||||
local_tau_table[gt_id, :] > tp)
|
||||
gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
|
||||
0].shape[0]
|
||||
|
||||
det_matching_qualified_sigma_candidates = np.where(
|
||||
local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
|
||||
> tr)
|
||||
det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
|
||||
0].shape[0]
|
||||
det_matching_qualified_tau_candidates = np.where(
|
||||
local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
|
||||
tp)
|
||||
det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
|
||||
0].shape[0]
|
||||
|
||||
if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
|
||||
(det_matching_num_qualified_sigma_candidates == 1) and (
|
||||
det_matching_num_qualified_tau_candidates == 1):
|
||||
global_accumulative_recall = global_accumulative_recall + 1.0
|
||||
global_accumulative_precision = global_accumulative_precision + 1.0
|
||||
local_accumulative_recall = local_accumulative_recall + 1.0
|
||||
local_accumulative_precision = local_accumulative_precision + 1.0
|
||||
|
||||
gt_flag[0, gt_id] = 1
|
||||
matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
|
||||
# recg start
|
||||
|
||||
gt_str_cur = global_gt_str[idy][gt_id]
|
||||
pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
|
||||
0]]
|
||||
|
||||
if pred_str_cur == gt_str_cur:
|
||||
hit_str_num += 1
|
||||
else:
|
||||
if pred_str_cur.lower() == gt_str_cur.lower():
|
||||
hit_str_num += 1
|
||||
# recg end
|
||||
det_flag[0, matched_det_id] = 1
|
||||
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
|
||||
|
||||
def one_to_many(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idy):
|
||||
hit_str_num = 0
|
||||
for gt_id in range(num_gt):
|
||||
# skip the following if the groundtruth was matched
|
||||
if gt_flag[0, gt_id] > 0:
|
||||
continue
|
||||
|
||||
non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
|
||||
num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
|
||||
|
||||
if num_non_zero_in_sigma >= k:
|
||||
####search for all detections that overlaps with this groundtruth
|
||||
qualified_tau_candidates = np.where((local_tau_table[
|
||||
gt_id, :] >= tp) & (det_flag[0, :] == 0))
|
||||
num_qualified_tau_candidates = qualified_tau_candidates[
|
||||
0].shape[0]
|
||||
|
||||
if num_qualified_tau_candidates == 1:
|
||||
if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
|
||||
and
|
||||
(local_sigma_table[gt_id, qualified_tau_candidates] >=
|
||||
tr)):
|
||||
# became an one-to-one case
|
||||
global_accumulative_recall = global_accumulative_recall + 1.0
|
||||
global_accumulative_precision = global_accumulative_precision + 1.0
|
||||
local_accumulative_recall = local_accumulative_recall + 1.0
|
||||
local_accumulative_precision = local_accumulative_precision + 1.0
|
||||
|
||||
gt_flag[0, gt_id] = 1
|
||||
det_flag[0, qualified_tau_candidates] = 1
|
||||
# recg start
|
||||
gt_str_cur = global_gt_str[idy][gt_id]
|
||||
pred_str_cur = global_pred_str[idy][
|
||||
qualified_tau_candidates[0].tolist()[0]]
|
||||
|
||||
if pred_str_cur == gt_str_cur:
|
||||
hit_str_num += 1
|
||||
else:
|
||||
if pred_str_cur.lower() == gt_str_cur.lower():
|
||||
hit_str_num += 1
|
||||
# recg end
|
||||
elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
|
||||
>= tr):
|
||||
gt_flag[0, gt_id] = 1
|
||||
det_flag[0, qualified_tau_candidates] = 1
|
||||
# recg start
|
||||
|
||||
gt_str_cur = global_gt_str[idy][gt_id]
|
||||
pred_str_cur = global_pred_str[idy][
|
||||
qualified_tau_candidates[0].tolist()[0]]
|
||||
|
||||
if pred_str_cur == gt_str_cur:
|
||||
hit_str_num += 1
|
||||
else:
|
||||
if pred_str_cur.lower() == gt_str_cur.lower():
|
||||
hit_str_num += 1
|
||||
# recg end
|
||||
|
||||
global_accumulative_recall = global_accumulative_recall + fsc_k
|
||||
global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
|
||||
|
||||
local_accumulative_recall = local_accumulative_recall + fsc_k
|
||||
local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
|
||||
|
||||
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
|
||||
|
||||
def many_to_one(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idy):
|
||||
hit_str_num = 0
|
||||
for det_id in range(num_det):
|
||||
# skip the following if the detection was matched
|
||||
if det_flag[0, det_id] > 0:
|
||||
continue
|
||||
|
||||
non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
|
||||
num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
|
||||
|
||||
if num_non_zero_in_tau >= k:
|
||||
####search for all detections that overlaps with this groundtruth
|
||||
qualified_sigma_candidates = np.where((
|
||||
local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
|
||||
num_qualified_sigma_candidates = qualified_sigma_candidates[
|
||||
0].shape[0]
|
||||
|
||||
if num_qualified_sigma_candidates == 1:
|
||||
if ((local_tau_table[qualified_sigma_candidates, det_id] >=
|
||||
tp) and
|
||||
(local_sigma_table[qualified_sigma_candidates, det_id]
|
||||
>= tr)):
|
||||
# became an one-to-one case
|
||||
global_accumulative_recall = global_accumulative_recall + 1.0
|
||||
global_accumulative_precision = global_accumulative_precision + 1.0
|
||||
local_accumulative_recall = local_accumulative_recall + 1.0
|
||||
local_accumulative_precision = local_accumulative_precision + 1.0
|
||||
|
||||
gt_flag[0, qualified_sigma_candidates] = 1
|
||||
det_flag[0, det_id] = 1
|
||||
# recg start
|
||||
pred_str_cur = global_pred_str[idy][det_id]
|
||||
gt_len = len(qualified_sigma_candidates[0])
|
||||
for idx in range(gt_len):
|
||||
ele_gt_id = qualified_sigma_candidates[0].tolist()[
|
||||
idx]
|
||||
if ele_gt_id not in global_gt_str[idy]:
|
||||
continue
|
||||
gt_str_cur = global_gt_str[idy][ele_gt_id]
|
||||
if pred_str_cur == gt_str_cur:
|
||||
hit_str_num += 1
|
||||
break
|
||||
else:
|
||||
if pred_str_cur.lower() == gt_str_cur.lower():
|
||||
hit_str_num += 1
|
||||
break
|
||||
# recg end
|
||||
elif (np.sum(local_tau_table[qualified_sigma_candidates,
|
||||
det_id]) >= tp):
|
||||
det_flag[0, det_id] = 1
|
||||
gt_flag[0, qualified_sigma_candidates] = 1
|
||||
# recg start
|
||||
|
||||
pred_str_cur = global_pred_str[idy][det_id]
|
||||
gt_len = len(qualified_sigma_candidates[0])
|
||||
for idx in range(gt_len):
|
||||
ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
|
||||
if ele_gt_id not in global_gt_str[idy]:
|
||||
continue
|
||||
gt_str_cur = global_gt_str[idy][ele_gt_id]
|
||||
if pred_str_cur == gt_str_cur:
|
||||
hit_str_num += 1
|
||||
break
|
||||
else:
|
||||
if pred_str_cur.lower() == gt_str_cur.lower():
|
||||
hit_str_num += 1
|
||||
break
|
||||
# recg end
|
||||
|
||||
global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
|
||||
global_accumulative_precision = global_accumulative_precision + fsc_k
|
||||
|
||||
local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
|
||||
local_accumulative_precision = local_accumulative_precision + fsc_k
|
||||
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
|
||||
global_sigma = local_sigma_table
|
||||
global_tau = local_tau_table
|
||||
global_pred_str = local_pred_str
|
||||
global_gt_str = local_gt_str
|
||||
|
||||
single_data = {}
|
||||
for idx in range(len(global_sigma)):
|
||||
local_sigma_table = global_sigma[idx]
|
||||
local_tau_table = global_tau[idx]
|
||||
|
||||
num_gt = local_sigma_table.shape[0]
|
||||
num_det = local_sigma_table.shape[1]
|
||||
|
||||
total_num_gt = total_num_gt + num_gt
|
||||
total_num_det = total_num_det + num_det
|
||||
|
||||
local_accumulative_recall = 0
|
||||
local_accumulative_precision = 0
|
||||
gt_flag = np.zeros((1, num_gt))
|
||||
det_flag = np.zeros((1, num_det))
|
||||
|
||||
#######first check for one-to-one case##########
|
||||
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
|
||||
gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idx)
|
||||
|
||||
hit_str_count += hit_str_num
|
||||
#######then check for one-to-many case##########
|
||||
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
|
||||
gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idx)
|
||||
hit_str_count += hit_str_num
|
||||
#######then check for many-to-one case##########
|
||||
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
|
||||
gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idx)
|
||||
|
||||
hit_str_count += hit_str_num
|
||||
|
||||
# fid = open(fid_path, 'a+')
|
||||
try:
|
||||
local_precision = local_accumulative_precision / num_det
|
||||
except ZeroDivisionError:
|
||||
local_precision = 0
|
||||
|
||||
try:
|
||||
local_recall = local_accumulative_recall / num_gt
|
||||
except ZeroDivisionError:
|
||||
local_recall = 0
|
||||
|
||||
try:
|
||||
local_f_score = 2 * local_precision * local_recall / (
|
||||
local_precision + local_recall)
|
||||
except ZeroDivisionError:
|
||||
local_f_score = 0
|
||||
|
||||
single_data['sigma'] = global_sigma
|
||||
single_data['global_tau'] = global_tau
|
||||
single_data['global_pred_str'] = global_pred_str
|
||||
single_data['global_gt_str'] = global_gt_str
|
||||
single_data["recall"] = local_recall
|
||||
single_data['precision'] = local_precision
|
||||
single_data['f_score'] = local_f_score
|
||||
return single_data
|
||||
|
||||
|
||||
|
@ -435,10 +163,10 @@ def combine_results(all_data):
|
|||
global_pred_str = []
|
||||
global_gt_str = []
|
||||
for data in all_data:
|
||||
global_sigma.append(data['sigma'][0])
|
||||
global_tau.append(data['global_tau'][0])
|
||||
global_pred_str.append(data['global_pred_str'][0])
|
||||
global_gt_str.append(data['global_gt_str'][0])
|
||||
global_sigma.append(data['sigma'])
|
||||
global_tau.append(data['global_tau'])
|
||||
global_pred_str.append(data['global_pred_str'])
|
||||
global_gt_str.append(data['global_gt_str'])
|
||||
|
||||
global_accumulative_recall = 0
|
||||
global_accumulative_precision = 0
|
||||
|
@ -676,6 +404,8 @@ def combine_results(all_data):
|
|||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idx)
|
||||
hit_str_count += hit_str_num
|
||||
|
||||
try:
|
||||
recall = global_accumulative_recall / total_num_gt
|
||||
except ZeroDivisionError:
|
||||
|
|
|
@ -17,9 +17,11 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import cv2
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from itertools import groupby
|
||||
from cv2.ximgproc import thinning as thin
|
||||
from skimage.morphology._skeletonize import thin
|
||||
|
||||
|
||||
def get_dict(character_dict_path):
|
||||
|
@ -33,39 +35,87 @@ def get_dict(character_dict_path):
|
|||
return dict_character
|
||||
|
||||
|
||||
def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
|
||||
def softmax(logits):
|
||||
"""
|
||||
logits: N x d
|
||||
"""
|
||||
max_value = np.max(logits, axis=1, keepdims=True)
|
||||
exp = np.exp(logits - max_value)
|
||||
exp_sum = np.sum(exp, axis=1, keepdims=True)
|
||||
dist = exp / exp_sum
|
||||
return dist
|
||||
|
||||
|
||||
def get_keep_pos_idxs(labels, remove_blank=None):
|
||||
"""
|
||||
Remove duplicate and get pos idxs of keep items.
|
||||
The value of keep_blank should be [None, 95].
|
||||
"""
|
||||
duplicate_len_list = []
|
||||
keep_pos_idx_list = []
|
||||
keep_char_idx_list = []
|
||||
for k, v_ in groupby(labels):
|
||||
current_len = len(list(v_))
|
||||
if k != remove_blank:
|
||||
current_idx = int(sum(duplicate_len_list) + current_len // 2)
|
||||
keep_pos_idx_list.append(current_idx)
|
||||
keep_char_idx_list.append(k)
|
||||
duplicate_len_list.append(current_len)
|
||||
return keep_char_idx_list, keep_pos_idx_list
|
||||
|
||||
|
||||
def remove_blank(labels, blank=0):
|
||||
new_labels = [x for x in labels if x != blank]
|
||||
return new_labels
|
||||
|
||||
|
||||
def insert_blank(labels, blank=0):
|
||||
new_labels = [blank]
|
||||
for l in labels:
|
||||
new_labels += [l, blank]
|
||||
return new_labels
|
||||
|
||||
|
||||
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
|
||||
"""
|
||||
CTC greedy (best path) decoder.
|
||||
"""
|
||||
raw_str = np.argmax(np.array(probs_seq), axis=1)
|
||||
remove_blank_in_pos = None if keep_blank_in_idxs else blank
|
||||
dedup_str, keep_idx_list = get_keep_pos_idxs(
|
||||
raw_str, remove_blank=remove_blank_in_pos)
|
||||
dst_str = remove_blank(dedup_str, blank=blank)
|
||||
return dst_str, keep_idx_list
|
||||
|
||||
|
||||
def instance_ctc_greedy_decoder(gather_info,
|
||||
logits_map,
|
||||
keep_blank_in_idxs=True):
|
||||
"""
|
||||
gather_info: [[x, y], [x, y] ...]
|
||||
logits_map: H x W X (n_chars + 1)
|
||||
"""
|
||||
_, _, C = logits_map.shape
|
||||
ys, xs = zip(*gather_info)
|
||||
logits_seq = logits_map[list(ys), list(xs)]
|
||||
probs_seq = logits_seq
|
||||
labels = np.argmax(probs_seq, axis=1)
|
||||
dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
|
||||
detal = len(gather_info) // (pts_num - 1)
|
||||
keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1]
|
||||
logits_seq = logits_map[list(ys), list(xs)] # n x 96
|
||||
probs_seq = softmax(logits_seq)
|
||||
dst_str, keep_idx_list = ctc_greedy_decoder(
|
||||
probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
|
||||
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
|
||||
return dst_str, keep_gather_list
|
||||
|
||||
|
||||
def ctc_decoder_for_image(gather_info_list,
|
||||
logits_map,
|
||||
Lexicon_Table,
|
||||
pts_num=6):
|
||||
def ctc_decoder_for_image(gather_info_list, logits_map,
|
||||
keep_blank_in_idxs=True):
|
||||
"""
|
||||
CTC decoder using multiple processes.
|
||||
"""
|
||||
decoder_str = []
|
||||
decoder_xys = []
|
||||
decoder_results = []
|
||||
for gather_info in gather_info_list:
|
||||
if len(gather_info) < pts_num:
|
||||
continue
|
||||
dst_str, xys_list = instance_ctc_greedy_decoder(
|
||||
gather_info, logits_map, pts_num=pts_num)
|
||||
dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
|
||||
if len(dst_str_readable) < 2:
|
||||
continue
|
||||
decoder_str.append(dst_str_readable)
|
||||
decoder_xys.append(xys_list)
|
||||
return decoder_str, decoder_xys
|
||||
res = instance_ctc_greedy_decoder(
|
||||
gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
|
||||
decoder_results.append(res)
|
||||
return decoder_results
|
||||
|
||||
|
||||
def sort_with_direction(pos_list, f_direction):
|
||||
|
@ -107,6 +157,58 @@ def sort_with_direction(pos_list, f_direction):
|
|||
return sorted_point, np.array(sorted_direction)
|
||||
|
||||
|
||||
def add_id(pos_list, image_id=0):
|
||||
"""
|
||||
Add id for gather feature, for inference.
|
||||
"""
|
||||
new_list = []
|
||||
for item in pos_list:
|
||||
new_list.append((image_id, item[0], item[1]))
|
||||
return new_list
|
||||
|
||||
|
||||
def sort_and_expand_with_direction(pos_list, f_direction):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||
"""
|
||||
h, w, _ = f_direction.shape
|
||||
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
|
||||
|
||||
# expand along
|
||||
point_num = len(sorted_list)
|
||||
sub_direction_len = max(point_num // 3, 2)
|
||||
left_direction = point_direction[:sub_direction_len, :]
|
||||
right_dirction = point_direction[point_num - sub_direction_len:, :]
|
||||
|
||||
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
|
||||
left_average_len = np.linalg.norm(left_average_direction)
|
||||
left_start = np.array(sorted_list[0])
|
||||
left_step = left_average_direction / (left_average_len + 1e-6)
|
||||
|
||||
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
|
||||
right_average_len = np.linalg.norm(right_average_direction)
|
||||
right_step = right_average_direction / (right_average_len + 1e-6)
|
||||
right_start = np.array(sorted_list[-1])
|
||||
|
||||
append_num = max(
|
||||
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
|
||||
left_list = []
|
||||
right_list = []
|
||||
for i in range(append_num):
|
||||
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
|
||||
'int32').tolist()
|
||||
if ly < h and lx < w and (ly, lx) not in left_list:
|
||||
left_list.append((ly, lx))
|
||||
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
|
||||
'int32').tolist()
|
||||
if ry < h and rx < w and (ry, rx) not in right_list:
|
||||
right_list.append((ry, rx))
|
||||
|
||||
all_list = left_list[::-1] + sorted_list + right_list
|
||||
return all_list
|
||||
|
||||
|
||||
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
|
@ -116,6 +218,7 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
|
|||
h, w, _ = f_direction.shape
|
||||
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
|
||||
|
||||
# expand along
|
||||
point_num = len(sorted_list)
|
||||
sub_direction_len = max(point_num // 3, 2)
|
||||
left_direction = point_direction[:sub_direction_len, :]
|
||||
|
@ -159,108 +262,258 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
|
|||
return all_list
|
||||
|
||||
|
||||
def point_pair2poly(point_pair_list):
|
||||
"""
|
||||
Transfer vertical point_pairs into poly point in clockwise.
|
||||
"""
|
||||
point_num = len(point_pair_list) * 2
|
||||
point_list = [0] * point_num
|
||||
for idx, point_pair in enumerate(point_pair_list):
|
||||
point_list[idx] = point_pair[0]
|
||||
point_list[point_num - 1 - idx] = point_pair[1]
|
||||
return np.array(point_list).reshape(-1, 2)
|
||||
|
||||
|
||||
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
|
||||
ratio_pair = np.array(
|
||||
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
||||
|
||||
|
||||
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
|
||||
"""
|
||||
expand poly along width.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
left_quad = np.array(
|
||||
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
|
||||
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
|
||||
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
|
||||
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
|
||||
right_quad = np.array(
|
||||
[
|
||||
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
|
||||
poly[point_num // 2], poly[point_num // 2 + 1]
|
||||
],
|
||||
dtype=np.float32)
|
||||
right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
|
||||
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
|
||||
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
|
||||
poly[0] = left_quad_expand[0]
|
||||
poly[-1] = left_quad_expand[-1]
|
||||
poly[point_num // 2 - 1] = right_quad_expand[1]
|
||||
poly[point_num // 2] = right_quad_expand[2]
|
||||
return poly
|
||||
|
||||
|
||||
def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w,
|
||||
src_h, valid_set):
|
||||
poly_list = []
|
||||
keep_str_list = []
|
||||
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
|
||||
if len(keep_str) < 2:
|
||||
print('--> too short, {}'.format(keep_str))
|
||||
continue
|
||||
|
||||
offset_expand = 1.0
|
||||
if valid_set == 'totaltext':
|
||||
offset_expand = 1.2
|
||||
|
||||
point_pair_list = []
|
||||
for y, x in yx_center_line:
|
||||
offset = p_border[:, y, x].reshape(2, 2) * offset_expand
|
||||
ori_yx = np.array([y, x], dtype=np.float32)
|
||||
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
|
||||
[ratio_w, ratio_h]).reshape(-1, 2)
|
||||
point_pair_list.append(point_pair)
|
||||
|
||||
detected_poly = point_pair2poly(point_pair_list)
|
||||
detected_poly = expand_poly_along_width(
|
||||
detected_poly, shrink_ratio_of_width=0.2)
|
||||
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
|
||||
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
|
||||
|
||||
keep_str_list.append(keep_str)
|
||||
if valid_set == 'partvgg':
|
||||
middle_point = len(detected_poly) // 2
|
||||
detected_poly = detected_poly[
|
||||
[0, middle_point - 1, middle_point, -1], :]
|
||||
poly_list.append(detected_poly)
|
||||
elif valid_set == 'totaltext':
|
||||
poly_list.append(detected_poly)
|
||||
else:
|
||||
print('--> Not supported format.')
|
||||
exit(-1)
|
||||
return poly_list, keep_str_list
|
||||
|
||||
|
||||
def generate_pivot_list(p_score,
|
||||
def generate_pivot_list_curved(p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
Lexicon_Table,
|
||||
score_thresh=0.5):
|
||||
score_thresh=0.5,
|
||||
is_expand=True,
|
||||
is_backbone=False,
|
||||
image_id=0):
|
||||
"""
|
||||
return center point and end point of TCL instance; filter with the char maps;
|
||||
"""
|
||||
p_score = p_score[0]
|
||||
f_direction = f_direction.transpose(1, 2, 0)
|
||||
ret, p_tcl_map = cv2.threshold(p_score, score_thresh, 255,
|
||||
cv2.THRESH_BINARY)
|
||||
skeleton_map = thin(p_tcl_map.astype('uint8'))
|
||||
p_tcl_map = (p_score > score_thresh) * 1.0
|
||||
skeleton_map = thin(p_tcl_map)
|
||||
instance_count, instance_label_map = cv2.connectedComponents(
|
||||
skeleton_map, connectivity=8)
|
||||
skeleton_map.astype(np.uint8), connectivity=8)
|
||||
|
||||
# get TCL Instance
|
||||
all_pos_yxs = []
|
||||
center_pos_yxs = []
|
||||
end_points_yxs = []
|
||||
instance_center_pos_yxs = []
|
||||
if instance_count > 0:
|
||||
for instance_id in range(1, instance_count):
|
||||
pos_list = []
|
||||
ys, xs = np.where(instance_label_map == instance_id)
|
||||
pos_list = list(zip(ys, xs))
|
||||
|
||||
### FIX-ME, eliminate outlier
|
||||
if len(pos_list) < 3:
|
||||
continue
|
||||
|
||||
if is_expand:
|
||||
pos_list_sorted = sort_and_expand_with_direction_v2(
|
||||
pos_list, f_direction, p_tcl_map)
|
||||
else:
|
||||
pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
|
||||
all_pos_yxs.append(pos_list_sorted)
|
||||
|
||||
# use decoder to filter backgroud points.
|
||||
p_char_maps = p_char_maps.transpose([1, 2, 0])
|
||||
decode_res = ctc_decoder_for_image(
|
||||
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
|
||||
for decoded_str, keep_yxs_list in decode_res:
|
||||
if is_backbone:
|
||||
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
|
||||
instance_center_pos_yxs.append(keep_yxs_list_with_id)
|
||||
else:
|
||||
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
|
||||
center_pos_yxs.extend(keep_yxs_list)
|
||||
|
||||
if is_backbone:
|
||||
return instance_center_pos_yxs
|
||||
else:
|
||||
return center_pos_yxs, end_points_yxs
|
||||
|
||||
|
||||
def generate_pivot_list_horizontal(p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=0.5,
|
||||
is_backbone=False,
|
||||
image_id=0):
|
||||
"""
|
||||
return center point and end point of TCL instance; filter with the char maps;
|
||||
"""
|
||||
p_score = p_score[0]
|
||||
f_direction = f_direction.transpose(1, 2, 0)
|
||||
p_tcl_map_bi = (p_score > score_thresh) * 1.0
|
||||
instance_count, instance_label_map = cv2.connectedComponents(
|
||||
p_tcl_map_bi.astype(np.uint8), connectivity=8)
|
||||
|
||||
# get TCL Instance
|
||||
all_pos_yxs = []
|
||||
center_pos_yxs = []
|
||||
end_points_yxs = []
|
||||
instance_center_pos_yxs = []
|
||||
|
||||
if instance_count > 0:
|
||||
for instance_id in range(1, instance_count):
|
||||
pos_list = []
|
||||
ys, xs = np.where(instance_label_map == instance_id)
|
||||
pos_list = list(zip(ys, xs))
|
||||
|
||||
### FIX-ME, eliminate outlier
|
||||
if len(pos_list) < 5:
|
||||
continue
|
||||
|
||||
# add rule here
|
||||
main_direction = extract_main_direction(pos_list,
|
||||
f_direction) # y x
|
||||
reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x
|
||||
is_h_angle = abs(np.sum(
|
||||
main_direction * reference_directin)) < math.cos(math.pi / 180 *
|
||||
70)
|
||||
|
||||
point_yxs = np.array(pos_list)
|
||||
max_y, max_x = np.max(point_yxs, axis=0)
|
||||
min_y, min_x = np.min(point_yxs, axis=0)
|
||||
is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
|
||||
|
||||
pos_list_final = []
|
||||
if is_h_len:
|
||||
xs = np.unique(xs)
|
||||
for x in xs:
|
||||
ys = instance_label_map[:, x].copy().reshape((-1, ))
|
||||
y = int(np.where(ys == instance_id)[0].mean())
|
||||
pos_list_final.append((y, x))
|
||||
else:
|
||||
ys = np.unique(ys)
|
||||
for y in ys:
|
||||
xs = instance_label_map[y, :].copy().reshape((-1, ))
|
||||
x = int(np.where(xs == instance_id)[0].mean())
|
||||
pos_list_final.append((y, x))
|
||||
|
||||
pos_list_sorted, _ = sort_with_direction(pos_list_final,
|
||||
f_direction)
|
||||
all_pos_yxs.append(pos_list_sorted)
|
||||
|
||||
# use decoder to filter backgroud points.
|
||||
p_char_maps = p_char_maps.transpose([1, 2, 0])
|
||||
decode_res = ctc_decoder_for_image(
|
||||
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
|
||||
for decoded_str, keep_yxs_list in decode_res:
|
||||
if is_backbone:
|
||||
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
|
||||
instance_center_pos_yxs.append(keep_yxs_list_with_id)
|
||||
else:
|
||||
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
|
||||
center_pos_yxs.extend(keep_yxs_list)
|
||||
|
||||
if is_backbone:
|
||||
return instance_center_pos_yxs
|
||||
else:
|
||||
return center_pos_yxs, end_points_yxs
|
||||
|
||||
|
||||
def generate_pivot_list(p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=0.5,
|
||||
is_backbone=False,
|
||||
is_curved=True,
|
||||
image_id=0):
|
||||
"""
|
||||
Warp all the function together.
|
||||
"""
|
||||
if is_curved:
|
||||
return generate_pivot_list_curved(
|
||||
p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=score_thresh,
|
||||
is_expand=True,
|
||||
is_backbone=is_backbone,
|
||||
image_id=image_id)
|
||||
else:
|
||||
return generate_pivot_list_horizontal(
|
||||
p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=score_thresh,
|
||||
is_backbone=is_backbone,
|
||||
image_id=image_id)
|
||||
|
||||
|
||||
# for refine module
|
||||
def extract_main_direction(pos_list, f_direction):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||
"""
|
||||
pos_list = np.array(pos_list)
|
||||
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
|
||||
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
||||
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
||||
average_direction = average_direction / (
|
||||
np.linalg.norm(average_direction) + 1e-6)
|
||||
return average_direction
|
||||
|
||||
|
||||
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
|
||||
"""
|
||||
pos_list_full = np.array(pos_list).reshape(-1, 3)
|
||||
pos_list = pos_list_full[:, 1:]
|
||||
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
|
||||
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
||||
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
||||
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
||||
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
|
||||
return sorted_list
|
||||
|
||||
|
||||
def sort_by_direction_with_image_id(pos_list, f_direction):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||
"""
|
||||
|
||||
def sort_part_with_direction(pos_list_full, point_direction):
|
||||
pos_list_full = np.array(pos_list_full).reshape(-1, 3)
|
||||
pos_list = pos_list_full[:, 1:]
|
||||
point_direction = np.array(point_direction).reshape(-1, 2)
|
||||
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
||||
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
||||
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
|
||||
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
|
||||
return sorted_list, sorted_direction
|
||||
|
||||
pos_list = np.array(pos_list).reshape(-1, 3)
|
||||
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
|
||||
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
||||
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
|
||||
point_direction)
|
||||
|
||||
point_num = len(sorted_point)
|
||||
if point_num >= 16:
|
||||
middle_num = point_num // 2
|
||||
first_part_point = sorted_point[:middle_num]
|
||||
first_point_direction = sorted_direction[:middle_num]
|
||||
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
|
||||
first_part_point, first_point_direction)
|
||||
|
||||
last_part_point = sorted_point[middle_num:]
|
||||
last_point_direction = sorted_direction[middle_num:]
|
||||
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
|
||||
last_part_point, last_point_direction)
|
||||
sorted_point = sorted_fist_part_point + sorted_last_part_point
|
||||
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
|
||||
|
||||
return sorted_point
|
||||
|
||||
|
||||
def generate_pivot_list_tt_inference(p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=0.5,
|
||||
is_backbone=False,
|
||||
is_curved=True,
|
||||
image_id=0):
|
||||
"""
|
||||
return center point and end point of TCL instance; filter with the char maps;
|
||||
"""
|
||||
p_score = p_score[0]
|
||||
f_direction = f_direction.transpose(1, 2, 0)
|
||||
p_tcl_map = (p_score > score_thresh) * 1.0
|
||||
skeleton_map = thin(p_tcl_map)
|
||||
instance_count, instance_label_map = cv2.connectedComponents(
|
||||
skeleton_map.astype(np.uint8), connectivity=8)
|
||||
|
||||
# get TCL Instance
|
||||
all_pos_yxs = []
|
||||
|
@ -269,15 +522,11 @@ def generate_pivot_list(p_score,
|
|||
pos_list = []
|
||||
ys, xs = np.where(instance_label_map == instance_id)
|
||||
pos_list = list(zip(ys, xs))
|
||||
|
||||
### FIX-ME, eliminate outlier
|
||||
if len(pos_list) < 3:
|
||||
continue
|
||||
|
||||
pos_list_sorted = sort_and_expand_with_direction_v2(
|
||||
pos_list, f_direction, p_tcl_map)
|
||||
all_pos_yxs.append(pos_list_sorted)
|
||||
|
||||
p_char_maps = p_char_maps.transpose([1, 2, 0])
|
||||
decoded_str, keep_yxs_list = ctc_decoder_for_image(
|
||||
all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table)
|
||||
return keep_yxs_list, decoded_str
|
||||
pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
|
||||
all_pos_yxs.append(pos_list_sorted_with_id)
|
||||
return all_pos_yxs
|
||||
|
|
|
@ -151,7 +151,7 @@ if __name__ == "__main__":
|
|||
src_im = utility.draw_e2e_res(points, strs, image_file)
|
||||
img_name_pure = os.path.split(image_file)[-1]
|
||||
img_path = os.path.join(draw_img_save,
|
||||
"e2e_res_{}_pgnet".format(img_name_pure))
|
||||
"e2e_res_{}".format(img_name_pure))
|
||||
cv2.imwrite(img_path, src_im)
|
||||
logger.info("The visualized image saved in {}".format(img_path))
|
||||
if count > 1:
|
||||
|
|
|
@ -1,2 +0,0 @@
|
|||
2.0,165.0,20.0,167.0,39.0,170.0,57.0,173.0,76.0,176.0,94.0,179.0,113.0,182.0,109.0,218.0,90.0,215.0,72.0,213.0,54.0,210.0,36.0,208.0,18.0,205.0,0.0,203.0 izza
|
||||
2.0,411.0,30.0,412.0,58.0,414.0,87.0,416.0,115.0,418.0,143.0,420.0,172.0,422.0,172.0,476.0,143.0,474.0,114.0,472.0,86.0,471.0,57.0,469.0,28.0,467.0,0.0,466.0 ISA
|
Binary file not shown.
Before Width: | Height: | Size: 41 KiB |
Loading…
Reference in New Issue