cherry-pick 3505
This commit is contained in:
parent
6887f45720
commit
44826b515a
|
@ -19,6 +19,7 @@ from __future__ import unicode_literals
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import string
|
import string
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
class ClsLabelEncode(object):
|
class ClsLabelEncode(object):
|
||||||
|
@ -39,7 +40,6 @@ class DetLabelEncode(object):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __call__(self, data):
|
def __call__(self, data):
|
||||||
import json
|
|
||||||
label = data['label']
|
label = data['label']
|
||||||
label = json.loads(label)
|
label = json.loads(label)
|
||||||
nBox = len(label)
|
nBox = len(label)
|
||||||
|
@ -54,8 +54,8 @@ class DetLabelEncode(object):
|
||||||
else:
|
else:
|
||||||
txt_tags.append(False)
|
txt_tags.append(False)
|
||||||
boxes = self.expand_points_num(boxes)
|
boxes = self.expand_points_num(boxes)
|
||||||
boxes = np.array(boxes, dtype=np.float32)
|
#boxes = np.array(boxes, dtype=np.float32)
|
||||||
txt_tags = np.array(txt_tags, dtype=np.bool)
|
#txt_tags = np.array(txt_tags, dtype=np.bool)
|
||||||
|
|
||||||
data['polys'] = boxes
|
data['polys'] = boxes
|
||||||
data['texts'] = txts
|
data['texts'] = txts
|
||||||
|
@ -352,8 +352,10 @@ class SRNLabelEncode(BaseRecLabelEncode):
|
||||||
% beg_or_end
|
% beg_or_end
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
|
|
||||||
class TableLabelEncode(object):
|
class TableLabelEncode(object):
|
||||||
""" Convert between text-label and text-index """
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
max_text_length,
|
max_text_length,
|
||||||
max_elem_length,
|
max_elem_length,
|
||||||
|
@ -364,7 +366,8 @@ class TableLabelEncode(object):
|
||||||
self.max_text_length = max_text_length
|
self.max_text_length = max_text_length
|
||||||
self.max_elem_length = max_elem_length
|
self.max_elem_length = max_elem_length
|
||||||
self.max_cell_num = max_cell_num
|
self.max_cell_num = max_cell_num
|
||||||
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
|
list_character, list_elem = self.load_char_elem_dict(
|
||||||
|
character_dict_path)
|
||||||
list_character = self.add_special_char(list_character)
|
list_character = self.add_special_char(list_character)
|
||||||
list_elem = self.add_special_char(list_elem)
|
list_elem = self.add_special_char(list_elem)
|
||||||
self.dict_character = {}
|
self.dict_character = {}
|
||||||
|
@ -412,18 +415,22 @@ class TableLabelEncode(object):
|
||||||
return None
|
return None
|
||||||
elem_num = len(structure)
|
elem_num = len(structure)
|
||||||
structure = [0] + structure + [len(self.dict_elem) - 1]
|
structure = [0] + structure + [len(self.dict_elem) - 1]
|
||||||
structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
|
structure = structure + [0] * (self.max_elem_length + 2 - len(structure)
|
||||||
|
)
|
||||||
structure = np.array(structure)
|
structure = np.array(structure)
|
||||||
data['structure'] = structure
|
data['structure'] = structure
|
||||||
elem_char_idx1 = self.dict_elem['<td>']
|
elem_char_idx1 = self.dict_elem['<td>']
|
||||||
elem_char_idx2 = self.dict_elem['<td']
|
elem_char_idx2 = self.dict_elem['<td']
|
||||||
span_idx_list = self.get_span_idx_list()
|
span_idx_list = self.get_span_idx_list()
|
||||||
td_idx_list = np.logical_or(structure == elem_char_idx1, structure == elem_char_idx2)
|
td_idx_list = np.logical_or(structure == elem_char_idx1,
|
||||||
|
structure == elem_char_idx2)
|
||||||
td_idx_list = np.where(td_idx_list)[0]
|
td_idx_list = np.where(td_idx_list)[0]
|
||||||
|
|
||||||
structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32)
|
structure_mask = np.ones(
|
||||||
|
(self.max_elem_length + 2, 1), dtype=np.float32)
|
||||||
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
|
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
|
||||||
bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32)
|
bbox_list_mask = np.zeros(
|
||||||
|
(self.max_elem_length + 2, 1), dtype=np.float32)
|
||||||
img_height, img_width, img_ch = data['image'].shape
|
img_height, img_width, img_ch = data['image'].shape
|
||||||
if len(span_idx_list) > 0:
|
if len(span_idx_list) > 0:
|
||||||
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
|
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
|
||||||
|
@ -450,9 +457,11 @@ class TableLabelEncode(object):
|
||||||
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
|
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
|
||||||
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
|
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
|
||||||
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
|
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
|
||||||
data['sp_tokens'] = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
|
data['sp_tokens'] = np.array([
|
||||||
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
|
char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
|
||||||
self.max_elem_length, self.max_cell_num, elem_num])
|
elem_char_idx1, elem_char_idx2, self.max_text_length,
|
||||||
|
self.max_elem_length, self.max_cell_num, elem_num
|
||||||
|
])
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def encode(self, text, char_or_elem):
|
def encode(self, text, char_or_elem):
|
||||||
|
@ -509,4 +518,3 @@ class TableLabelEncode(object):
|
||||||
assert False, "Unsupport type %s in char_or_elem" \
|
assert False, "Unsupport type %s in char_or_elem" \
|
||||||
% char_or_elem
|
% char_or_elem
|
||||||
return idx
|
return idx
|
||||||
|
|
|
@ -24,6 +24,7 @@ from paddle import inference
|
||||||
import time
|
import time
|
||||||
from ppocr.utils.logging import get_logger
|
from ppocr.utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
def str2bool(v):
|
def str2bool(v):
|
||||||
return v.lower() in ("true", "t", "1")
|
return v.lower() in ("true", "t", "1")
|
||||||
|
|
||||||
|
@ -47,8 +48,8 @@ def init_args():
|
||||||
|
|
||||||
# DB parmas
|
# DB parmas
|
||||||
parser.add_argument("--det_db_thresh", type=float, default=0.3)
|
parser.add_argument("--det_db_thresh", type=float, default=0.3)
|
||||||
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
|
parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
|
||||||
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
|
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
|
||||||
parser.add_argument("--max_batch_size", type=int, default=10)
|
parser.add_argument("--max_batch_size", type=int, default=10)
|
||||||
parser.add_argument("--use_dilation", type=bool, default=False)
|
parser.add_argument("--use_dilation", type=bool, default=False)
|
||||||
parser.add_argument("--det_db_score_mode", type=str, default="fast")
|
parser.add_argument("--det_db_score_mode", type=str, default="fast")
|
||||||
|
|
Loading…
Reference in New Issue