add table eval and predict script
This commit is contained in: parent ad4853dbe8 · commit 794362481e
@ -44,16 +44,16 @@ class BaseRecLabelDecode(object):
             self.character_str = string.printable[:-6]
             dict_character = list(self.character_str)
         elif character_type in support_character_type:
-            self.character_str = ""
+            self.character_str = []
             assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
                 character_type)
             with open(character_dict_path, "rb") as fin:
                 lines = fin.readlines()
                 for line in lines:
                     line = line.decode('utf-8').strip("\n").strip("\r\n")
-                    self.character_str += line
+                    self.character_str.append(line)
             if use_space_char:
-                self.character_str += " "
+                self.character_str.append(" ")
             dict_character = list(self.character_str)

         else:
@ -288,3 +288,172 @@ class SRNLabelDecode(BaseRecLabelDecode):
            assert False, "unsupport type %s in get_beg_end_flag_idx" \
                % beg_or_end
        return idx


class TableLabelDecode(object):
    """Decode table structure tokens and cell location predictions."""

    def __init__(self,
                 max_text_length,
                 max_elem_length,
                 max_cell_num,
                 character_dict_path,
                 **kwargs):
        self.max_text_length = max_text_length
        self.max_elem_length = max_elem_length
        self.max_cell_num = max_cell_num
        list_character, list_elem = self.load_char_elem_dict(character_dict_path)
        list_character = self.add_special_char(list_character)
        list_elem = self.add_special_char(list_elem)
        self.dict_character = {}
        self.dict_idx_character = {}
        for i, char in enumerate(list_character):
            self.dict_idx_character[i] = char
            self.dict_character[char] = i
        self.dict_elem = {}
        self.dict_idx_elem = {}
        for i, elem in enumerate(list_elem):
            self.dict_idx_elem[i] = elem
            self.dict_elem[elem] = i

    def load_char_elem_dict(self, character_dict_path):
        list_character = []
        list_elem = []
        with open(character_dict_path, "rb") as fin:
            lines = fin.readlines()
            # The first line holds the character and element counts, tab-separated.
            substr = lines[0].decode('utf-8').strip("\n").split("\t")
            character_num = int(substr[0])
            elem_num = int(substr[1])
            for cno in range(1, 1 + character_num):
                character = lines[cno].decode('utf-8').strip("\n")
                list_character.append(character)
            for eno in range(1 + character_num, 1 + character_num + elem_num):
                elem = lines[eno].decode('utf-8').strip("\n")
                list_elem.append(elem)
        return list_character, list_elem

    def add_special_char(self, list_character):
        self.beg_str = "sos"
        self.end_str = "eos"
        list_character = [self.beg_str] + list_character + [self.end_str]
        return list_character

    def get_sp_tokens(self):
        char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
        char_end_idx = self.get_beg_end_flag_idx('end', 'char')
        elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
        elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
        elem_char_idx1 = self.dict_elem['<td>']
        elem_char_idx2 = self.dict_elem['<td']
        sp_tokens = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
                              elem_end_idx, elem_char_idx1, elem_char_idx2,
                              self.max_text_length, self.max_elem_length,
                              self.max_cell_num])
        return sp_tokens

    def __call__(self, preds):
        structure_probs = preds['structure_probs']
        loc_preds = preds['loc_preds']
        if isinstance(structure_probs, paddle.Tensor):
            structure_probs = structure_probs.numpy()
        if isinstance(loc_preds, paddle.Tensor):
            loc_preds = loc_preds.numpy()
        structure_idx = structure_probs.argmax(axis=2)
        structure_probs = structure_probs.max(axis=2)
        structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
            structure_idx, structure_probs, 'elem')
        res_html_code_list = []
        res_loc_list = []
        batch_num = len(structure_str)
        for bno in range(batch_num):
            res_loc = []
            for sno in range(len(structure_str[bno])):
                text = structure_str[bno][sno]
                # Cell locations are only predicted for '<td>'/'<td' tokens.
                if text in ['<td>', '<td']:
                    pos = structure_pos[bno][sno]
                    res_loc.append(loc_preds[bno, pos])
            res_html_code = ''.join(structure_str[bno])
            res_loc = np.array(res_loc)
            res_html_code_list.append(res_html_code)
            res_loc_list.append(res_loc)
        return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list,
                'res_score_list': result_score_list,
                'res_elem_idx_list': result_elem_idx_list,
                'structure_str_list': structure_str}

    def decode(self, text_index, structure_probs, char_or_elem):
        """Convert predicted indices back into structure tokens (or characters)."""
        if char_or_elem == "char":
            max_len = self.max_text_length
            current_dict = self.dict_idx_character
        else:
            max_len = self.max_elem_length
            current_dict = self.dict_idx_elem
        ignored_tokens = self.get_ignored_tokens('elem')
        beg_idx, end_idx = ignored_tokens

        result_list = []
        result_pos_list = []
        result_score_list = []
        result_elem_idx_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            elem_pos_list = []
            elem_idx_list = []
            score_list = []
            for idx in range(len(text_index[batch_idx])):
                tmp_elem_idx = int(text_index[batch_idx][idx])
                # Stop at the first end token after position 0.
                if idx > 0 and tmp_elem_idx == end_idx:
                    break
                if tmp_elem_idx in ignored_tokens:
                    continue
                char_list.append(current_dict[tmp_elem_idx])
                elem_pos_list.append(idx)
                score_list.append(structure_probs[batch_idx, idx])
                elem_idx_list.append(tmp_elem_idx)
            result_list.append(char_list)
            result_pos_list.append(elem_pos_list)
            result_score_list.append(score_list)
            result_elem_idx_list.append(elem_idx_list)
        return result_list, result_pos_list, result_score_list, result_elem_idx_list

    def get_ignored_tokens(self, char_or_elem):
        beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
        end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
        if char_or_elem == "char":
            if beg_or_end == "beg":
                idx = self.dict_character[self.beg_str]
            elif beg_or_end == "end":
                idx = self.dict_character[self.end_str]
            else:
                assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
                    % beg_or_end
        elif char_or_elem == "elem":
            if beg_or_end == "beg":
                idx = self.dict_elem[self.beg_str]
            elif beg_or_end == "end":
                idx = self.dict_elem[self.end_str]
            else:
                assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
                    % beg_or_end
        else:
            assert False, "Unsupport type %s in char_or_elem" \
                % char_or_elem
        return idx
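For orientation, a minimal sketch of how TableLabelDecode is driven; the tensor shapes and the random inputs are illustrative assumptions, not values taken from this commit:

# Hypothetical usage sketch of TableLabelDecode; shapes are assumptions.
import numpy as np

decoder = TableLabelDecode(
    max_text_length=100, max_elem_length=800, max_cell_num=500,
    character_dict_path='ppocr/utils/dict/table_structure_dict.txt')
preds = {
    # (batch, seq_len, num_structure_tokens) and (batch, seq_len, 4)
    'structure_probs': np.random.rand(1, 800, len(decoder.dict_elem)),
    'loc_preds': np.random.rand(1, 800, 4),
}
result = decoder(preds)
# result['res_html_code'][0]: concatenated structure tokens for the first image
# result['res_loc'][0]: normalized (x0, y0, x1, y1) boxes, one per '<td>'/'<td' token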
@ -0,0 +1,278 @@
←
</overline>
☆
─
α


⋅
$
ω
ψ
χ
(
υ
≥
σ
,
ρ
ε
0
■
4
8
✗
b
<
✓
Ψ
Ω
€
D
3
Π
H
║
</
>
L
Φ
Χ
θ
P
κ
λ
μ
T
ξ
X
β
γ
δ
\
ζ
η
`
d
<strike>
h
f
l
Θ
p
√
t
</sub>
x
Β
Γ
Δ

ǂ
ɛ
j
̧
➢

̌
′
«
△
▲
#
</b>
'
Ι
+
¶
/
▼
⇑
□
·
7
▪
;
?
➔
∩
C
÷
G
⇒
K
<sup>
O
S
С
W
Α
[
○
_
●
‡
c
z
g
<i>
o
<sub>
〈
〉
s
⩽
w
φ
ʹ
{
»
∣
̆
e
ˆ
∈
τ
◆
ι
∅
∆
∙
∘
Ø
ß
✔
∞
∑
−
×
◊
∗
∖
˃
˂
∫
"
i
&
π
↔
*
∥
æ
∧
.
⁄
ø
Q
∼
6
⁎
:
★
>
a
B
≈
F
J
̄
N
♯
R
V
<overline>
―
Z
♣
^
¤
¥
§
<underline>
¢
£
≦

≤
‖
Λ
©
n
↓
→
↑
r
°
±
v
<b>
♂
k
♀
~
ᅟ
̇
@
”
♦
ł
®
⊕
„
!
</sup>
%
⇓
)
-
1
5
9
=
А
A
‰
⋆
Σ
E
◦
I
※
M
m
̨
⩾
†
</i>
•
U
Y

]
̸
2
‐
–
‒
̂
—
̀
́
’
‘
⋮
⋯
̊
“
̈
≧
q
u
ı
y
</underline>

̃
}
ν
File diff suppressed because it is too large
@ -0,0 +1,214 @@
import json


def distance(box_1, box_2):
    # L1 distance between two boxes, plus the smaller of the two corner distances
    x1, y1, x2, y2 = box_1
    x3, y3, x4, y4 = box_2
    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
    dis_2 = abs(x3 - x1) + abs(y3 - y1)
    dis_3 = abs(x4 - x2) + abs(y4 - y2)
    return dis + min(dis_2, dis_3)


def compute_iou(rec1, rec2):
    """
    Compute IoU.
    :param rec1: (y0, x0, y1, x1), which reflects (top, left, bottom, right)
    :param rec2: (y0, x0, y1, x1)
    :return: scalar IoU value
    """
    # compute the area of each rectangle
    rec1, rec2 = rec1 * 1000, rec2 * 1000
    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])

    # compute the summed area
    sum_area = S_rec1 + S_rec2

    # find each edge of the intersection rectangle
    left_line = max(rec1[1], rec2[1])
    right_line = min(rec1[3], rec2[3])
    top_line = max(rec1[0], rec2[0])
    bottom_line = min(rec1[2], rec2[2])

    # judge whether there is an intersection
    if left_line >= right_line or top_line >= bottom_line:
        return 0
    else:
        intersect = (right_line - left_line) * (bottom_line - top_line)
        return (intersect / (sum_area - intersect)) * 1.0


def matcher_merge(ocr_bboxes, pred_bboxes):  # ocr_bboxes: OCR results; pred_bboxes: end-to-end predictions
    all_dis = []
    ious = []
    matched = {}
    for i, gt_box in enumerate(ocr_bboxes):
        distances = []
        for j, pred_box in enumerate(pred_bboxes):
            # pairwise L1 distance and (1 - IoU) between cells
            distances.append((distance(gt_box, pred_box),
                              1. - compute_iou(gt_box, pred_box)))
        sorted_distances = distances.copy()
        # pick the "nearest" cell, ranked by IoU first and distance second
        sorted_distances = sorted(sorted_distances, key=lambda item: (item[1], item[0]))
        if distances.index(sorted_distances[0]) not in matched.keys():
            matched[distances.index(sorted_distances[0])] = [i]
        else:
            matched[distances.index(sorted_distances[0])].append(i)
    return matched


def complex_num(pred_bboxes):
    complex_nums = []
    for bbox in pred_bboxes:
        distances = []
        temp_ious = []
        for pred_bbox in pred_bboxes:
            if bbox != pred_bbox:
                distances.append(distance(bbox, pred_bbox))
                temp_ious.append(compute_iou(bbox, pred_bbox))
        complex_nums.append(temp_ious[distances.index(min(distances))])
    return sum(complex_nums) / len(complex_nums)


def get_rows(pred_bboxes):
    pre_bbox = pred_bboxes[0]
    res = []
    step = 0
    for i in range(len(pred_bboxes)):
        bbox = pred_bboxes[i]
        if bbox[1] - pre_bbox[1] > 2 or bbox[0] - pre_bbox[0] < 0:
            break
        else:
            res.append(bbox)
            step += 1
    for i in range(step):
        pred_bboxes.pop(0)
    return res, pred_bboxes


def refine_rows(pred_bboxes):  # nudge the boxes of one row onto a common horizontal line
    ys_1 = []
    ys_2 = []
    for box in pred_bboxes:
        ys_1.append(box[1])
        ys_2.append(box[3])
    min_y_1 = sum(ys_1) / len(ys_1)
    min_y_2 = sum(ys_2) / len(ys_2)
    re_boxes = []
    for box in pred_bboxes:
        box[1] = min_y_1
        box[3] = min_y_2
        re_boxes.append(box)
    return re_boxes


def matcher_refine_row(gt_bboxes, pred_bboxes):
    before_refine_pred_bboxes = pred_bboxes.copy()
    pred_bboxes = []
    while len(before_refine_pred_bboxes) != 0:
        row_bboxes, before_refine_pred_bboxes = get_rows(before_refine_pred_bboxes)
        print(row_bboxes)
        pred_bboxes.extend(refine_rows(row_bboxes))
    all_dis = []
    ious = []
    matched = {}
    for i, gt_box in enumerate(gt_bboxes):
        distances = []
        for j, pred_box in enumerate(pred_bboxes):
            distances.append(distance(gt_box, pred_box))
        if distances.index(min(distances)) not in matched.keys():
            matched[distances.index(min(distances))] = [i]
        else:
            matched[distances.index(min(distances))].append(i)
    return matched


# first pick out a row, then match within it
def matcher_structure_1(gt_bboxes, pred_bboxes_rows, pred_bboxes):
    gt_box_index = 0
    delete_gt_bboxes = gt_bboxes.copy()
    match_bboxes_ready = []
    matched = {}
    while len(delete_gt_bboxes) != 0:
        row_bboxes, delete_gt_bboxes = get_rows(delete_gt_bboxes)
        row_bboxes = sorted(row_bboxes, key=lambda key: key[0])
        if len(pred_bboxes_rows) > 0:
            match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
        print(row_bboxes)
        for i, gt_box in enumerate(row_bboxes):
            pred_distances = []
            distances = []
            for pred_bbox in pred_bboxes:
                pred_distances.append(distance(gt_box, pred_bbox))
            for j, pred_box in enumerate(match_bboxes_ready):
                distances.append(distance(gt_box, pred_box))
            # the minimal distance found among the "ready" boxes is looked up in the
            # full prediction list, which yields that pred box's global index
            index = pred_distances.index(min(distances))
            if index not in matched.keys():
                matched[index] = [gt_box_index]
            else:
                matched[index].append(gt_box_index)
            gt_box_index += 1
    return matched


def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
    '''
    gt_bboxes: sorted
    pred_bboxes:
    '''
    pre_bbox = gt_bboxes[0]
    matched = {}
    match_bboxes_ready = []
    match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
    for i, gt_box in enumerate(gt_bboxes):
        pred_distances = []
        for pred_bbox in pred_bboxes:
            pred_distances.append(distance(gt_box, pred_bbox))
        distances = []
        gap_pre = gt_box[1] - pre_bbox[1]
        gap_pre_1 = gt_box[0] - pre_bbox[2]
        if gap_pre_1 < 0 and len(pred_bboxes_rows) > 0:
            match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
        if len(pred_bboxes_rows) == 1:
            match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
        if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) > 0:
            match_bboxes_ready.extend(pred_bboxes_rows.pop(0))
        if len(match_bboxes_ready) == 0 and len(pred_bboxes_rows) == 0:
            break
        for j, pred_box in enumerate(match_bboxes_ready):
            distances.append(distance(gt_box, pred_box))
        index = pred_distances.index(min(distances))
        print(gt_box, match_bboxes_ready[distances.index(min(distances))])
        if index not in matched.keys():
            matched[index] = [i]
        else:
            matched[index].append(i)
        pre_bbox = gt_box
    return matched


def main():
    detect_bboxes = json.load(open('./f_detecion_bbox.json'))
    gt_bboxes = json.load(open('./f_gt_bbox.json'))
    all_node = 0
    matched_right = 0
    key = 'PMC4796501_003_00.png'
    print(key)
    gt_bbox = gt_bboxes[key]
    pred_bbox = detect_bboxes[key]
    # matcher_merge assumed here; the bare name `matcher` is not defined in this file
    matched = matcher_merge(gt_bbox, pred_bbox)
    print(matched)


if __name__ == "__main__":
    main()
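To make the matching criterion concrete, here is a tiny worked example; the box coordinates are invented for illustration, and the boxes are numpy arrays because compute_iou scales them arithmetically:

# Toy illustration, not from the commit: two OCR boxes vs. two predicted cells.
import numpy as np

ocr_bboxes = [np.array([0.0, 0.0, 0.4, 0.1]), np.array([0.5, 0.0, 0.9, 0.1])]
pred_bboxes = [np.array([0.0, 0.0, 0.45, 0.12]), np.array([0.48, 0.0, 0.95, 0.12])]
print(matcher_merge(ocr_bboxes, pred_bboxes))
# expected: {0: [0], 1: [1]} -- each predicted cell index maps to its OCR box indices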
@ -0,0 +1,123 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import subprocess

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import copy
import numpy as np
import time
import tools.infer.utility as utility
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.layout.predict_layout import LayoutDetector
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger

logger = get_logger()


def parse_args():
    parser = utility.init_args()

    # params for table structure
    parser.add_argument("--table_max_len", type=int, default=488)
    parser.add_argument("--table_max_text_length", type=int, default=100)
    parser.add_argument("--table_max_elem_length", type=int, default=800)
    parser.add_argument("--table_max_cell_num", type=int, default=500)
    parser.add_argument("--table_model_dir", type=str)
    parser.add_argument("--table_char_type", type=str, default='en')
    parser.add_argument("--table_char_dict_path", type=str,
                        default="./ppocr/utils/dict/table_structure_dict.txt")

    # params for layout detector
    parser.add_argument("--layout_model_dir", type=str)
    return parser.parse_args()


class OCRSystem():
    def __init__(self, args):
        self.text_system = TextSystem(args)
        self.table_system = TableSystem(args)
        self.table_layout = LayoutDetector(args)
        self.use_angle_cls = args.use_angle_cls
        self.drop_score = args.drop_score

    def __call__(self, img):
        ori_im = img.copy()
        layout_res = self.table_layout(copy.deepcopy(img))
        for region in layout_res:
            x1, y1, x2, y2 = region['bbox']
            roi_img = ori_im[y1:y2, x1:x2, :]
            # table regions go through table recognition, everything else through OCR
            if region['label'] == 'table':
                res = self.table_system(roi_img)
            else:
                res = self.text_system(roi_img)
            region['res'] = res
        return layout_res


def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    excel_save_folder = 'output/table'
    os.makedirs(excel_save_folder, exist_ok=True)

    text_sys = OCRSystem(args)
    img_num = len(image_file_list)
    for i, image_file in enumerate(image_file_list):
        logger.info("[{}/{}] {}".format(i, img_num, image_file))
        img, flag = check_and_read_gif(image_file)
        imgname = os.path.basename(image_file).split('.')[0]
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        res = text_sys(img)

        # save every recognized table region to its own Excel file
        for region in res:
            if region['label'] == 'table':
                excel_path = os.path.join(
                    excel_save_folder, '{}_{}.xlsx'.format(imgname, region['bbox']))
                to_excel(region['res'], excel_path)
        logger.info(res)
        elapse = time.time() - starttime
        logger.info("Predict time : {:.3f}s".format(elapse))


if __name__ == "__main__":
    args = parse_args()
    if args.use_mp:
        p_list = []
        total_process_num = args.total_process_num
        for process_id in range(total_process_num):
            cmd = [sys.executable, "-u"] + sys.argv + [
                "--process_id={}".format(process_id),
                "--use_mp={}".format(False)
            ]
            p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
            p_list.append(p)
        for p in p_list:
            p.wait()
    else:
        main(args)
@ -0,0 +1,13 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@ -0,0 +1,67 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

import cv2
import json
from tqdm import tqdm
from ppstructure.table.table_metric import TEDS
from ppstructure.table.predict_table import TableSystem, utility


def main(gt_path, img_root, args):
    teds = TEDS(n_jobs=16)

    text_sys = TableSystem(args)
    jsons_gt = json.load(open(gt_path))  # ground truth
    pred_htmls = []
    gt_htmls = []
    for img_name in tqdm(jsons_gt):
        if img_name != 'PMC1064865_002_00.png':  # debug filter left in the commit: only this image is evaluated
            continue
        # read the image and predict its HTML
        img = cv2.imread(os.path.join(img_root, img_name))
        pred_html = text_sys(img)
        pred_htmls.append(pred_html)

        gt_structures, gt_bboxes, gt_contents, contents_with_block = jsons_gt[img_name]
        gt_html, gt = get_gt_html(gt_structures, contents_with_block)  # build the ground-truth HTML
        gt_htmls.append(gt_html)
    scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)  # compute TEDS
    print('teds:', sum(scores) / len(scores))


def get_gt_html(gt_structures, contents_with_block):
    end_html = []
    td_index = 0
    for tag in gt_structures:
        if '</td>' in tag:
            if contents_with_block[td_index] != []:
                end_html.extend(contents_with_block[td_index])
            end_html.append(tag)
            td_index += 1
        else:
            end_html.append(tag)
    return ''.join(end_html), end_html


if __name__ == '__main__':
    args = utility.parse_args()
    gt_path = 'table/match_code/f_gt_bbox.json'
    img_root = 'table/imgs'
    main(gt_path, img_root, args)
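A toy illustration of get_gt_html; the structure tokens and cell contents below are invented for the example:

# Hypothetical inputs: one table row with two cells.
gt_structures = ['<table>', '<tr>', '<td>', '</td>', '<td>', '</td>', '</tr>', '</table>']
contents_with_block = [['cell', ' 1'], ['cell 2']]
html, _ = get_gt_html(gt_structures, contents_with_block)
print(html)  # <table><tr><td>cell 1</td><td>cell 2</td></tr></table>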
@ -0,0 +1,141 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import numpy as np
import math
import time
import traceback
import paddle

import tools.infer.utility as utility
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif

logger = get_logger()


class TableStructurer(object):
    def __init__(self, args):
        pre_process_list = [{
            'ResizeTableImage': {
                'max_len': args.table_max_len
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'PaddingTableImage': None
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image']
            }
        }]
        postprocess_params = {
            'name': 'TableLabelDecode',
            "character_type": args.table_char_type,
            "character_dict_path": args.table_char_dict_path,
            "max_text_length": args.table_max_text_length,
            "max_elem_length": args.table_max_elem_length,
            "max_cell_num": args.table_max_cell_num
        }

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors = \
            utility.create_predictor(args, 'table', logger)

    def __call__(self, img):
        ori_im = img.copy()
        data = {'image': img}
        data = transform(data, self.preprocess_op)
        img = data[0]
        if img is None:
            return None, 0
        img = np.expand_dims(img, axis=0)
        img = img.copy()
        starttime = time.time()

        self.input_tensor.copy_from_cpu(img)
        self.predictor.run()
        outputs = []
        for output_tensor in self.output_tensors:
            output = output_tensor.copy_to_cpu()
            outputs.append(output)

        preds = {}
        preds['structure_probs'] = outputs[1]
        preds['loc_preds'] = outputs[0]

        post_result = self.postprocess_op(preds)

        structure_str_list = post_result['structure_str_list']
        res_loc = post_result['res_loc']
        imgh, imgw = ori_im.shape[0:2]
        res_loc_final = []
        # map normalized cell coordinates back to pixel space
        for rno in range(len(res_loc[0])):
            x0, y0, x1, y1 = res_loc[0][rno]
            left = max(int(imgw * x0), 0)
            top = max(int(imgh * y0), 0)
            right = min(int(imgw * x1), imgw - 1)
            bottom = min(int(imgh * y1), imgh - 1)
            res_loc_final.append([left, top, right, bottom])

        # drop the trailing end token and wrap the structure in a full HTML skeleton
        structure_str_list = structure_str_list[0][:-1]
        structure_str_list = ['<html>', '<body>', '<table>'] + structure_str_list + [
            '</table>', '</body>', '</html>'
        ]

        elapse = time.time() - starttime
        return (structure_str_list, res_loc_final), elapse


def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    table_structurer = TableStructurer(args)
    count = 0
    total_time = 0
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        structure_res, elapse = table_structurer(img)

        logger.info("result: {}".format(structure_res))

        if count > 0:
            total_time += elapse
        count += 1
        logger.info("Predict time of {}: {}".format(image_file, elapse))


if __name__ == "__main__":
    main(utility.parse_args())
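A minimal sketch of driving TableStructurer directly; the image path is a placeholder, and the command-line flags (e.g. --table_model_dir, --table_char_dict_path) must point at an exported table model:

# Sketch only; the image path below is a placeholder.
args = utility.parse_args()
structurer = TableStructurer(args)
img = cv2.imread('path/to/table_image.jpg')
(structure_tokens, cell_boxes), elapse = structurer(img)
# structure_tokens: ['<html>', '<body>', '<table>', ..., '</table>', '</body>', '</html>']
# cell_boxes: pixel-space [left, top, right, bottom], one box per predicted cell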
@ -0,0 +1,222 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import subprocess

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import copy
import numpy as np
import time
import tools.infer.utility as utility
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
import ppstructure.table.predict_structure as predict_structure
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from ppocr.utils.table_utils.matcher import distance, compute_iou

logger = get_logger()


def expand(pix, det_box, shape):
    # enlarge a detection box by `pix` pixels on each side, clipped to the image
    x0, y0, x1, y1 = det_box
    h, w, c = shape
    tmp_x0 = x0 - pix
    tmp_x1 = x1 + pix
    tmp_y0 = y0 - pix
    tmp_y1 = y1 + pix
    x0_ = tmp_x0 if tmp_x0 >= 0 else 0
    x1_ = tmp_x1 if tmp_x1 <= w else w
    y0_ = tmp_y0 if tmp_y0 >= 0 else 0
    y1_ = tmp_y1 if tmp_y1 <= h else h
    return x0_, y0_, x1_, y1_


class TableSystem(object):
    def __init__(self, args):
        self.text_detector = predict_det.TextDetector(args)
        self.text_recognizer = predict_rec.TextRecognizer(args)
        self.table_structurer = predict_structure.TableStructurer(args)
        self.use_angle_cls = args.use_angle_cls
        self.drop_score = args.drop_score

    def __call__(self, img):
        ori_im = img.copy()
        structure_res, elapse = self.table_structurer(copy.deepcopy(img))
        dt_boxes, elapse = self.text_detector(copy.deepcopy(img))
        dt_boxes = sorted_boxes(dt_boxes)

        # convert detected quadrilaterals to slightly enlarged axis-aligned boxes
        r_boxes = []
        for box in dt_boxes:
            x_min = box[:, 0].min() - 1
            x_max = box[:, 0].max() + 1
            y_min = box[:, 1].min() - 1
            y_max = box[:, 1].max() + 1
            box = [x_min, y_min, x_max, y_max]
            r_boxes.append(box)
        dt_boxes = np.array(r_boxes)

        if dt_boxes is None:
            return None
        img_crop_list = []

        for i in range(len(dt_boxes)):
            det_box = dt_boxes[i]
            x0, y0, x1, y1 = expand(2, det_box, ori_im.shape)
            text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
            img_crop_list.append(text_rect)
        rec_res, elapse = self.text_recognizer(img_crop_list)

        pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
        return pred_html

    def rebuild_table(self, structure_res, dt_boxes, rec_res):
        pred_structures, pred_bboxes = structure_res
        matched_index = self.match_result(dt_boxes, pred_bboxes)
        pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
        return pred_html, pred

    def match_result(self, dt_boxes, pred_bboxes):
        matched = {}
        for i, gt_box in enumerate(dt_boxes):
            distances = []
            for j, pred_box in enumerate(pred_bboxes):
                # pairwise L1 distance and (1 - IoU) between cells
                distances.append(
                    (distance(gt_box, pred_box), 1. - compute_iou(gt_box, pred_box)))
            sorted_distances = distances.copy()
            # pick the "nearest" cell, ranked by IoU first and distance second
            sorted_distances = sorted(
                sorted_distances, key=lambda item: (item[1], item[0]))
            if distances.index(sorted_distances[0]) not in matched.keys():
                matched[distances.index(sorted_distances[0])] = [i]
            else:
                matched[distances.index(sorted_distances[0])].append(i)
        return matched

    def get_pred_html(self, pred_structures, matched_index, ocr_contents):
        end_html = []
        td_index = 0
        for tag in pred_structures:
            if '</td>' in tag:
                if td_index in matched_index.keys():
                    b_with = False
                    if '<b>' in ocr_contents[matched_index[td_index][0]] and len(
                            matched_index[td_index]) > 1:
                        b_with = True
                        end_html.extend('<b>')
                    for i, td_index_index in enumerate(matched_index[td_index]):
                        content = ocr_contents[td_index_index][0]
                        if len(matched_index[td_index]) > 1:
                            if len(content) == 0:
                                continue
                            if content[0] == ' ':
                                content = content[1:]
                            if '<b>' in content:
                                content = content[3:]
                            if '</b>' in content:
                                content = content[:-4]
                            if len(content) == 0:
                                continue
                            if i != len(matched_index[td_index]) - 1 and ' ' != content[-1]:
                                content += ' '
                        end_html.extend(content)
                    if b_with:
                        end_html.extend('</b>')

                end_html.append(tag)
                td_index += 1
            else:
                end_html.append(tag)
        return ''.join(end_html), end_html


def sorted_boxes(dt_boxes):
    """
    Sort text boxes in order from top to bottom, left to right.
    args:
        dt_boxes(array): detected text boxes with shape [4, 2]
    return:
        sorted boxes(array) with shape [4, 2]
    """
    num_boxes = dt_boxes.shape[0]
    sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
    _boxes = list(sorted_boxes)

    for i in range(num_boxes - 1):
        # swap neighbours that are on the same line but out of left-to-right order
        if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
                (_boxes[i + 1][0][0] < _boxes[i][0][0]):
            tmp = _boxes[i]
            _boxes[i] = _boxes[i + 1]
            _boxes[i + 1] = tmp
    return _boxes


def to_excel(html_table, excel_path):
    from tablepyxl import tablepyxl
    tablepyxl.document_to_xl(html_table, excel_path)


def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    excel_save_folder = 'output/table'
    os.makedirs(excel_save_folder, exist_ok=True)

    text_sys = TableSystem(args)
    img_num = len(image_file_list)
    for i, image_file in enumerate(image_file_list):
        logger.info("[{}/{}] {}".format(i, img_num, image_file))
        img, flag = check_and_read_gif(image_file)
        excel_path = os.path.join(
            excel_save_folder, os.path.basename(image_file).split('.')[0] + '.xlsx')
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        pred_html = text_sys(img)

        to_excel(pred_html, excel_path)
        logger.info('excel saved to {}'.format(excel_path))
        logger.info(pred_html)
        elapse = time.time() - starttime
        logger.info("Predict time : {:.3f}s".format(elapse))


if __name__ == "__main__":
    args = utility.parse_args()
    if args.use_mp:
        p_list = []
        total_process_num = args.total_process_num
        for process_id in range(total_process_num):
            cmd = [sys.executable, "-u"] + sys.argv + [
                "--process_id={}".format(process_id),
                "--use_mp={}".format(False)
            ]
            p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
            p_list.append(p)
        for p in p_list:
            p.wait()
    else:
        main(args)
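To illustrate the reading-order heuristic in sorted_boxes, a toy case with invented coordinates:

# Toy illustration: box B sits on the same line as box A (y gap < 10) but to its left.
boxes = np.array([
    [[100, 20], [200, 20], [200, 50], [100, 50]],  # A: right box
    [[10, 25], [90, 25], [90, 55], [10, 55]],      # B: left box
])
ordered = sorted_boxes(boxes)
print([b[0].tolist() for b in ordered])  # [[10, 25], [100, 20]] -- B first, then A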
@ -0,0 +1,16 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ['TEDS']
from .table_metric import TEDS
@ -0,0 +1,51 @@
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed


def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
    """
    A parallel version of the map function with a progress bar.

    Args:
        array (array-like): An array to iterate over.
        function (function): A python function to apply to the elements of array
        n_jobs (int, default=16): The number of cores to use
        use_kwargs (boolean, default=False): Whether to consider the elements of
            array as dictionaries of keyword arguments to function
        front_num (int, default=0): The number of iterations to run serially before
            kicking off the parallel job. Useful for catching bugs
    Returns:
        [function(array[0]), function(array[1]), ...]
    """
    # We run the first few iterations serially to catch bugs
    if front_num > 0:
        front = [function(**a) if use_kwargs else function(a)
                 for a in array[:front_num]]
    else:
        front = []
    # If n_jobs is 1, just run a list comprehension. This is useful for
    # benchmarking and debugging.
    if n_jobs == 1:
        return front + [function(**a) if use_kwargs else function(a)
                        for a in tqdm(array[front_num:])]
    # Assemble the workers
    with ProcessPoolExecutor(max_workers=n_jobs) as pool:
        # Pass the elements of array into function
        if use_kwargs:
            futures = [pool.submit(function, **a) for a in array[front_num:]]
        else:
            futures = [pool.submit(function, a) for a in array[front_num:]]
        kwargs = {
            'total': len(futures),
            'unit': 'it',
            'unit_scale': True,
            'leave': True
        }
        # Print out the progress as tasks complete
        for f in tqdm(as_completed(futures), **kwargs):
            pass
    out = []
    # Get the results from the futures
    for i, future in tqdm(enumerate(futures)):
        try:
            out.append(future.result())
        except Exception as e:
            out.append(e)
    return front + out
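A minimal usage sketch of parallel_process; square is a throwaway example function, defined at module level because ProcessPoolExecutor workers must be able to import it:

# Hypothetical usage of parallel_process.
def square(x):
    return x * x

if __name__ == '__main__':
    results = parallel_process(list(range(100)), square, n_jobs=4)
    print(results[:5])  # [0, 1, 4, 9, 16] -- results follow the input order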
@ -125,6 +125,8 @@ def create_predictor(args, mode, logger):
         model_dir = args.cls_model_dir
     elif mode == 'rec':
         model_dir = args.rec_model_dir
+    elif mode == 'table':
+        model_dir = args.table_model_dir
     else:
         model_dir = args.e2e_model_dir
@ -244,7 +246,8 @@ def create_predictor(args, mode, logger):

     config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
     config.switch_use_feed_fetch_ops(False)

+    if mode == 'table':
+        config.switch_ir_optim(False)
     # create predictor
     predictor = inference.create_predictor(config)
     input_names = predictor.get_input_names()