fix table infer bug
This commit is contained in:
parent
e836ab7f51
commit
330f08ffc7
|
@ -29,6 +29,7 @@ from .label_ops import *
|
|||
from .east_process import *
|
||||
from .sast_process import *
|
||||
from .pg_process import *
|
||||
from .gen_table_mask import *
|
||||
|
||||
|
||||
def transform(data, ops=None):
|
||||
|
|
|
@ -0,0 +1,244 @@
|
|||
"""
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import sys
|
||||
import six
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
class GenTableMask(object):
    """Generate a text-region mask for a table image from its cell boxes.

    For every cell that carries a ``bbox``, the cell crop is split into text
    lines with projection profiles, every resulting box is shrunk slightly,
    and the shrunk rectangles are painted into a blank mask.
    """

    def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
        """
        Args:
            shrink_h_max: upper bound, in pixels, of the vertical shrink
                applied to each side of a box.
            shrink_w_max: upper bound, in pixels, of the horizontal shrink
                applied to each side of a box.
            mask_type: ``1`` produces a single-channel mask stored under
                ``data['mask_img']``; any other value produces a 3-channel
                mask that replaces ``data['image']``.
            **kwargs: accepted and ignored.
        """
        # Honour the configured limits (previously both were hard-coded to 5
        # regardless of the arguments).
        self.shrink_h_max = shrink_h_max
        self.shrink_w_max = shrink_w_max
        self.mask_type = mask_type

    def projection(self, erosion, h, w, spilt_threshold=0):
        """Split a binary image into runs of consecutive foreground rows.

        Args:
            erosion: 2-D array; pixels equal to 255 count as foreground.
            h: number of rows to scan.
            w: number of columns to scan.
            spilt_threshold: a row belongs to a text run when its foreground
                pixel count is strictly greater than this value.

        Returns:
            A ``(box_list, projection_map)`` tuple. ``box_list`` holds
            ``(start, end)`` row-index pairs, one per run. ``projection_map``
            is an array shaped like ``erosion``, filled with ones, in which
            the first ``count`` entries of every row are set to 0 (a drawing
            of the projection histogram).
        """
        # Horizontal projection: foreground pixels per row, in one pass.
        counts = np.count_nonzero(erosion[:h, :w] == 255, axis=1)

        # Walk the profile and collect the cut points.
        box_list = []
        start_idx = None  # row where the current text run began, if any
        for idx, count in enumerate(counts):
            has_text = count > spilt_threshold
            if start_idx is None and has_text:
                # Entering a text run.
                start_idx = idx
            elif start_idx is not None and not has_text:
                # Entering a blank area; runs of two rows or fewer are noise.
                if idx - start_idx > 2:
                    box_list.append((start_idx, idx + 1))
                start_idx = None

        # A run that reaches the last row is closed at h - 1 and is kept
        # whatever its length.
        if start_idx is not None:
            box_list.append((start_idx, h - 1))

        # Draw the projection histogram.
        projection_map = np.ones_like(erosion)
        for row, count in enumerate(counts):
            projection_map[row, :count] = 0
        return box_list, projection_map

    def projection_cx(self, box_img):
        """Split one cell crop into per-line text boxes.

        Args:
            box_img: crop of a single table cell, passed to
                ``cv2.cvtColor`` with ``COLOR_BGR2GRAY``.

        Returns:
            A list of ``[w_start, h_start, w_end, h_end]`` boxes relative to
            the crop. When fewer than two lines are found the single box
            ``[0, 0, w, h]`` covering the whole crop is returned.
        """
        box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
        h, w = box_gray_img.shape
        # Binarise the grayscale crop (inverted: dark pixels become 255).
        _, thresh1 = cv2.threshold(box_gray_img, 200, 255,
                                   cv2.THRESH_BINARY_INV)
        # Vertical erosion, only for crops that are wider than tall.
        if h < w:
            erode = cv2.erode(thresh1, np.ones((2, 1), np.uint8), iterations=1)
        else:
            erode = thresh1
        # Horizontal dilation.
        erosion = cv2.dilate(erode, np.ones((1, 5), np.uint8), iterations=1)

        # Horizontal projection gives the candidate text lines.
        box_list, _ = self.projection(erosion, h, w)
        if len(box_list) <= 1:
            return [[0, 0, w, h]]

        split_bbox_list = []
        last = len(box_list) - 1
        for i, (h_start, h_end) in enumerate(box_list):
            # Stretch the first line to the crop top and the last line to
            # the crop bottom (the latter compared against len(box_list)
            # before and therefore never fired).
            if i == 0:
                h_start = 0
            if i == last:
                h_end = h
            word_img = erosion[h_start:h_end + 1, :]
            word_h, word_w = word_img.shape
            # Projecting the transposed strip yields its column extent.
            w_split_list, _ = self.projection(word_img.T, word_w, word_h)
            if w_split_list:
                w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
            else:
                # Defensive: no column run found, use the full crop width
                # instead of failing on an empty list.
                w_start, w_end = 0, w
            # Pad the line by one row on each side.
            if h_start > 0:
                h_start -= 1
            h_end += 1
            split_bbox_list.append([w_start, h_start, w_end, h_end])
        return split_bbox_list

    def shrink_bbox(self, bbox):
        """Shrink a box by roughly 10% per side, within the configured limits.

        Args:
            bbox: ``[left, top, right, bottom]``.

        Returns:
            ``[left, top, right, bottom]`` after shrinking. Each side moves
            inwards by at least 1 and at most ``shrink_w_max`` /
            ``shrink_h_max``; an axis whose shrink would collapse the box
            keeps its original coordinates.
        """
        left, top, right, bottom = bbox
        sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
        sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
        left_new = left + sh_w
        right_new = right - sh_w
        top_new = top + sh_h
        bottom_new = bottom - sh_h
        if left_new >= right_new:
            left_new = left
            right_new = right
        if top_new >= bottom_new:
            top_new = top
            bottom_new = bottom
        return [left_new, top_new, right_new, bottom_new]

    def __call__(self, data):
        """Paint the mask for every cell in ``data['cells']`` that has a bbox.

        Args:
            data: dict with ``'image'`` (array indexed as
                ``img[top:bottom, left:right, :]``) and ``'cells'`` (sequence
                of dicts, each optionally holding a
                ``[left, top, right, bottom]`` ``'bbox'``).

        Returns:
            The same ``data`` dict. With ``mask_type == 1`` the mask is
            stored under ``'mask_img'``; otherwise it replaces ``'image'``.
            The result is stored only once a region has been painted, so
            input without any bbox is returned untouched.
        """
        img = data['image']
        cells = data['cells']
        height, width = img.shape[0:2]
        single_channel = self.mask_type == 1
        if single_channel:
            mask_img = np.zeros((height, width), dtype=np.float32)
        else:
            mask_img = np.zeros((height, width, 3), dtype=np.float32)

        for cell in cells:
            if "bbox" not in cell:
                continue
            left, top, right, bottom = cell['bbox']
            # Crops are always taken from the original image, never from
            # the mask being built.
            box_img = img[top:bottom, left:right, :].copy()
            for x0, y0, x1, y1 in self.projection_cx(box_img):
                # Translate from crop coordinates to image coordinates,
                # then shrink.
                m_left, m_top, m_right, m_bottom = self.shrink_bbox(
                    [x0 + left, y0 + top, x1 + left, y1 + top])
                if single_channel:
                    mask_img[m_top:m_bottom, m_left:m_right] = 1.0
                    data['mask_img'] = mask_img
                else:
                    mask_img[m_top:m_bottom, m_left:m_right, :] = (255, 255,
                                                                   255)
                    data['image'] = mask_img
        return data
|
||||
|
||||
class ResizeTableImage(object):
    """Resize a table image so its longer side equals ``max_len``.

    Cell bounding boxes found in ``data['cells']`` are rescaled by the same
    ratio.
    """

    def __init__(self, max_len, **kwargs):
        """
        Args:
            max_len: target length of the longer image side.
            **kwargs: accepted and ignored.
        """
        super(ResizeTableImage, self).__init__()
        self.max_len = max_len

    def get_img_bbox(self, cells):
        """Return the ``'bbox'`` value of every cell that has one, in order."""
        return [cell['bbox'] for cell in cells if "bbox" in cell]

    def resize_img_table(self, img, bbox_list, max_len):
        """Resize ``img`` and scale ``bbox_list`` by the same ratio.

        Args:
            img: image array; only ``img.shape[0:2]`` is read here.
            bbox_list: boxes as ``[left, top, right, bottom]``; each box must
                provide a ``copy()`` method.
            max_len: target length of the longer image side.

        Returns:
            ``(img_new, bbox_list_new)``: the resized image and a new list of
            boxes with every coordinate scaled and truncated to ``int``.
        """
        height, width = img.shape[0:2]
        ratio = max_len / (max(height, width) * 1.0)
        resize_h = int(height * ratio)
        resize_w = int(width * ratio)
        # cv2.resize takes the target size as (width, height).
        img_new = cv2.resize(img, (resize_w, resize_h))

        bbox_list_new = []
        for bbox in bbox_list:
            left, top, right, bottom = bbox.copy()
            bbox_list_new.append([
                int(left * ratio),
                int(top * ratio),
                int(right * ratio),
                int(bottom * ratio),
            ])
        return img_new, bbox_list_new

    def __call__(self, data):
        """Resize ``data['image']`` and update the cell boxes in place.

        Args:
            data: dict with ``'image'`` and, optionally, ``'cells'``.

        Returns:
            The same dict, with ``'image'`` replaced, every existing cell
            ``'bbox'`` rescaled, and ``'max_len'`` set to ``self.max_len``.
        """
        img = data['image']
        cells = data['cells'] if 'cells' in data else []
        bbox_list = self.get_img_bbox(cells)
        img_new, bbox_list_new = self.resize_img_table(img, bbox_list,
                                                       self.max_len)
        data['image'] = img_new

        # The scaled boxes come back in the same order as the cells that own
        # a bbox, so they can be handed out one at a time.
        scaled = iter(bbox_list_new)
        for cell in cells:
            if "bbox" in cell:
                cell['bbox'] = next(scaled)
        data['max_len'] = self.max_len
        return data
|
||||
|
||||
class PaddingTableImage(object):
    """Pad an image into a square ``max_len`` x ``max_len`` float32 canvas."""

    def __init__(self, **kwargs):
        """Accept and ignore any keyword arguments."""
        super(PaddingTableImage, self).__init__()

    def __call__(self, data):
        """Place ``data['image']`` in the top-left corner of a zero canvas.

        Args:
            data: dict with ``'image'`` and ``'max_len'`` (the canvas side
                length).

        Returns:
            The same dict with ``'image'`` replaced by the
            ``(max_len, max_len, 3)`` float32 canvas.
        """
        source = data['image']
        side = data['max_len']
        canvas = np.zeros((side, side, 3), dtype=np.float32)
        rows, cols = source.shape[0:2]
        # Everything outside the copied region stays zero.
        canvas[0:rows, 0:cols, :] = source.copy()
        data['image'] = canvas
        return data
|
||||
|
|
@ -24,7 +24,8 @@ __all__ = ['build_post_process']
|
|||
from .db_postprocess import DBPostProcess
|
||||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
|
||||
TableLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
|
||||
|
@ -33,7 +34,7 @@ def build_post_process(config, global_config=None):
|
|||
support_dict = [
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||
'DistillationCTCLabelDecode'
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -32,6 +32,7 @@ from ppocr.data import create_operators, transform
|
|||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.logging import get_logger
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
from ppstructure.utility import parse_args
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
@ -69,7 +70,7 @@ class TableStructurer(object):
|
|||
|
||||
self.preprocess_op = create_operators(pre_process_list)
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor, self.output_tensors = \
|
||||
self.predictor, self.input_tensor, self.output_tensors, self.config = \
|
||||
utility.create_predictor(args, 'structure', logger)
|
||||
|
||||
def __call__(self, img):
|
||||
|
@ -138,4 +139,4 @@ def main(args):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(utility.parse_args())
|
||||
main(parse_args())
|
||||
|
|
|
@ -187,7 +187,7 @@ def main(args):
|
|||
for i, image_file in enumerate(image_file_list):
|
||||
logger.info("[{}/{}] {}".format(i, img_num, image_file))
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
excel_path = os.path.join(args.table_output, os.path.basename(image_file).split('.')[0] + '.xlsx')
|
||||
excel_path = os.path.join(args.output, os.path.basename(image_file).split('.')[0] + '.xlsx')
|
||||
if not flag:
|
||||
img = cv2.imread(image_file)
|
||||
if img is None:
|
||||
|
|
Loading…
Reference in New Issue