init commit for paddlestructure

parent a5f7511505
commit bc0d766425
@@ -1,7 +1,7 @@
include LICENSE.txt
include README.md

recursive-include ppocr/utils *.txt utility.py logging.py
recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py

@@ -19,6 +19,7 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))

import cv2
import logging
import numpy as np
from pathlib import Path

@@ -150,6 +151,8 @@ class PaddleOCR(predict_system.TextSystem):
        """
        params = parse_args(mMain=False)
        params.__dict__.update(**kwargs)
        if params.show_log:
            logger.setLevel(logging.DEBUG)
        self.use_angle_cls = params.use_angle_cls
        lang = params.lang
        latin_lang = [

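The hunk above wires the new `show_log` flag (added to `init_args` further down in this diff) into `PaddleOCR.__init__`: when `show_log` is true the package logger is raised to DEBUG, otherwise it stays at its default level. A minimal usage sketch, assuming the package is installed; the image path is a placeholder:

```python
from paddleocr import PaddleOCR

# show_log=False skips the `logger.setLevel(logging.DEBUG)` branch shown above,
# so only default-level messages are printed.
ocr = PaddleOCR(use_angle_cls=True, lang='en', show_log=False)
result = ocr.ocr('path/to/your_image.jpg', cls=True)  # placeholder path
for line in result:
    print(line)
```
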
@@ -33,8 +33,7 @@ D
Π
H
║
</
>
</strike>
L
Φ
Χ

@@ -0,0 +1,9 @@
include LICENSE.txt
include README.md

recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data/ *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
recursive-include table *.py
recursive-include ppstructure *.py

@@ -0,0 +1,161 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys

__dir__ = os.path.dirname(__file__)
sys.path.append(os.path.join(__dir__, ''))


import cv2
import numpy as np
from pathlib import Path

from ppocr.utils.logging import get_logger
from predict_system import OCRSystem, save_res
from utility import init_args

logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar

__all__ = ['PaddleStructure']

VERSION = '2.1'
BASE_DIR = os.path.expanduser("~/.paddlestructure/")

model_urls = {
    'det': {
        'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
    },
    'rec': {
        'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
    },
    'structure': {
        'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
    },
}


def parse_args(mMain=True):
    import argparse
    parser = init_args()
    parser.add_help = mMain

    for action in parser._actions:
        if action.dest in ['rec_char_dict_path', 'structure_char_dict_path']:
            action.default = None
    if mMain:
        return parser.parse_args()
    else:
        inference_args_dict = {}
        for action in parser._actions:
            inference_args_dict[action.dest] = action.default
        return argparse.Namespace(**inference_args_dict)


class PaddleStructure(OCRSystem):
    def __init__(self, **kwargs):
        params = parse_args(mMain=False)
        params.__dict__.update(**kwargs)
        if params.show_log:
            logger.setLevel(logging.DEBUG)
        params.use_angle_cls = False
        # init model dir
        if params.det_model_dir is None:
            params.det_model_dir = os.path.join(BASE_DIR, VERSION, 'det')
        if params.rec_model_dir is None:
            params.rec_model_dir = os.path.join(BASE_DIR, VERSION, 'rec')
        if params.structure_model_dir is None:
            params.structure_model_dir = os.path.join(BASE_DIR, VERSION, 'structure')
        # download model
        maybe_download(params.det_model_dir, model_urls['det'])
        maybe_download(params.det_model_dir, model_urls['rec'])
        maybe_download(params.det_model_dir, model_urls['structure'])

        if params.rec_char_dict_path is None:
            params.rec_char_type = 'EN'
            if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')):
                params.rec_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')
            else:
                params.rec_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_dict.txt')
        if params.structure_char_dict_path is None:
            if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')):
                params.structure_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')
            else:
                params.structure_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_structure_dict.txt')

        print(params)
        super().__init__(params)

    def __call__(self, img):
        if isinstance(img, str):
            # download net image
            if img.startswith('http'):
                download_with_progressbar(img, 'tmp.jpg')
                img = 'tmp.jpg'
            image_file = img
            img, flag = check_and_read_gif(image_file)
            if not flag:
                with open(image_file, 'rb') as f:
                    np_arr = np.frombuffer(f.read(), dtype=np.uint8)
                    img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
                if img is None:
                    logger.error("error in loading image:{}".format(image_file))
                    return None
        if isinstance(img, np.ndarray) and len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        res = super().__call__(img)
        return res


def main():
    # for cmd
    args = parse_args(mMain=True)
    image_dir = args.image_dir
    save_folder = args.output
    if image_dir.startswith('http'):
        download_with_progressbar(image_dir, 'tmp.jpg')
        image_file_list = ['tmp.jpg']
    else:
        image_file_list = get_image_file_list(args.image_dir)
    if len(image_file_list) == 0:
        logger.error('no images find in {}'.format(args.image_dir))
        return

    structure_engine = PaddleStructure(**(args.__dict__))
    for img_path in image_file_list:
        img_name = os.path.basename(img_path).split('.')[0]
        logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
        result = structure_engine(img_path)
        save_res(result, args.output, os.path.basename(img_path).split('.')[0])
        for item in result:
            logger.info(item['res'])
        save_res(result, save_folder, img_name)
        logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))


if __name__ == '__main__':
    table_engine = PaddleStructure(det_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_det_infer',
                                   rec_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_rec_infer',
                                   structure_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_structure_infer',
                                   output='/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/output/table',
                                   show_log=True)
    img = cv2.imread('/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/ppstructure/test_imgs/table_1.png')
    result = table_engine(img)
    for line in result:
        print(line)

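For orientation, a minimal usage sketch of the `PaddleStructure` class added above, modelled on its own `__main__` block. The model directories and image path are placeholders; if they are omitted, `__init__` falls back to `~/.paddlestructure/<VERSION>/` and calls `maybe_download` as shown. The import assumes the script is run from the directory that contains `paddlestructure.py`:

```python
import cv2

from paddlestructure import PaddleStructure, save_res  # run next to paddlestructure.py

# Placeholder model directories -- substitute your own inference models.
engine = PaddleStructure(det_model_dir='./inference/table_det_infer',
                         rec_model_dir='./inference/table_rec_infer',
                         structure_model_dir='./inference/table_structure_infer',
                         show_log=True)

img = cv2.imread('./test_imgs/table_1.png')   # placeholder image path
result = engine(img)                          # list of {'type', 'bbox', 'res'} dicts
for region in result:
    print(region['type'], region['bbox'])

# Table regions are written to .xlsx files, other regions to res.txt.
save_res(result, './output/table', 'table_1')
```

Once the wheel built by the setup.py later in this diff is installed, the same `main()` is also exposed as a `paddlestructure` console command (see the `entry_points` declaration there).
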
@@ -18,97 +18,93 @@ import subprocess

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import copy
import numpy as np
import time
import tools.infer.utility as utility
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.layout.predict_layout import LayoutDetector

import layoutparser as lp

from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args

logger = get_logger()


def parse_args():
    parser = utility.init_args()

    # params for output
    parser.add_argument("--table_output", type=str, default='output/table')
    # params for table structure
    parser.add_argument("--table_max_len", type=int, default=488)
    parser.add_argument("--table_max_text_length", type=int, default=100)
    parser.add_argument("--table_max_elem_length", type=int, default=800)
    parser.add_argument("--table_max_cell_num", type=int, default=500)
    parser.add_argument("--table_model_dir", type=str)
    parser.add_argument("--table_char_type", type=str, default='en')
    parser.add_argument("--table_char_dict_path", type=str, default="./ppocr/utils/dict/table_structure_dict.txt")

    # params for layout detector
    parser.add_argument("--layout_model_dir", type=str)
    return parser.parse_args()


class OCRSystem():
class OCRSystem(object):
    def __init__(self, args):
        self.text_system = TextSystem(args)
        self.table_system = TableSystem(args)
        self.table_layout = LayoutDetector(args)
        self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
        self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
                                                          threshold=0.5, enable_mkldnn=args.enable_mkldnn,
                                                          enforce_cpu=not args.use_gpu)
        self.use_angle_cls = args.use_angle_cls
        self.drop_score = args.drop_score

    def __call__(self, img):
        ori_im = img.copy()
        layout_res = self.table_layout(copy.deepcopy(img))
        layout_res = self.table_layout.detect(img[..., ::-1])
        res_list = []
        for region in layout_res:
            x1, y1, x2, y2 = region['bbox']
            x1, y1, x2, y2 = region.coordinates
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            roi_img = ori_im[y1:y2, x1:x2, :]
            if region['label'] == 'table':
                res = self.text_system(roi_img)
            if region.type == 'Table':
                res = self.table_system(roi_img)
            elif region.type == 'Figure':
                continue
            else:
                res = self.text_system(roi_img)
            region['res'] = res
        return layout_res
                filter_boxes, filter_rec_res = self.text_system(roi_img)
                filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes]
                res = (filter_boxes, filter_rec_res)
            res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
        return res_list


def save_res(res, save_folder, img_name):
    excel_save_folder = os.path.join(save_folder, img_name)
    os.makedirs(excel_save_folder, exist_ok=True)
    # save res
    for region in res:
        if region['type'] == 'Table':
            excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
            to_excel(region['res'], excel_path)
        elif region['type'] == 'Figure':
            pass
        else:
            with open(os.path.join(excel_save_folder, 'res.txt'), 'a', encoding='utf8') as f:
                for box, rec_res in zip(*region['res']):
                    f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))


def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    save_folder = args.table_output
    save_folder = args.output
    os.makedirs(save_folder, exist_ok=True)

    text_sys = OCRSystem(args)
    structure_sys = OCRSystem(args)
    img_num = len(image_file_list)
    for i, image_file in enumerate(image_file_list):
        logger.info("[{}/{}] {}".format(i, img_num, image_file))
        img, flag = check_and_read_gif(image_file)
        img_name = os.path.basename(image_file).split('.')[0]
        # excel_path = os.path.join(excel_save_folder, + '.xlsx')

        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            logger.error("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        res = text_sys(img)

        excel_save_folder = os.path.join(save_folder, img_name)
        os.makedirs(excel_save_folder, exist_ok=True)
        # save res
        for region in res:
            if region['label'] == 'table':
                excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
                to_excel(region['res'], excel_path)
            else:
                with open(os.path.join(excel_save_folder, 'res.txt'),'a',encoding='utf8') as f:
                    for box, rec_res in zip(*region['res']):
                        f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
        logger.info(res)
        res = structure_sys(img)
        save_res(res, save_folder, img_name)
        logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
        elapse = time.time() - starttime
        logger.info("Predict time : {:.3f}s".format(elapse))

@@ -0,0 +1,65 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from setuptools import setup
from io import open
import shutil

with open('../requirements.txt', encoding="utf-8-sig") as f:
    requirements = f.readlines()
    requirements.append('tqdm')
    requirements.append('layoutparser')


def readme():
    with open('README_ch.md', encoding="utf-8-sig") as f:
        README = f.read()
    return README

shutil.copytree('../ppocr','./ppocr')
shutil.copytree('../tools','./tools')
shutil.copytree('../ppstructure','./ppstructure')

setup(
    name='paddlestructure',
    packages=['paddlestructure'],
    package_dir={'paddlestructure': ''},
    include_package_data=True,
    entry_points={"console_scripts": ["paddlestructure= paddlestructure.paddlestructure:main"]},
    version='2.0.6',
    install_requires=requirements,
    license='Apache License 2.0',
    description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
    long_description=readme(),
    long_description_content_type='text/markdown',
    url='https://github.com/PaddlePaddle/PaddleOCR',
    download_url='https://github.com/PaddlePaddle/PaddleOCR.git',
    keywords=[
        'ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition'
    ],
    classifiers=[
        'Intended Audience :: Developers', 'Operating System :: OS Independent',
        'Natural Language :: Chinese (Simplified)',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.2',
        'Programming Language :: Python :: 3.3',
        'Programming Language :: Python :: 3.4',
        'Programming Language :: Python :: 3.5',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
    ], )

shutil.rmtree('ppocr')
shutil.rmtree('tools')
shutil.rmtree('ppstructure')

@@ -0,0 +1,15 @@
# Table structure and content prediction

First, cd into the PaddleOCR/ppstructure directory.

Prediction

```bash
python3 table/predict_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs/PMC3006023_004_00.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --table_output ../output/table
```
After the run finishes, the Excel table for each image is saved to the directory specified by the table_output argument.

Evaluation

```bash
python3 table/eval_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
```

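As a programmatic alternative to the commands above, the same table pipeline can be driven from Python by building the shared argument namespace and calling `TableSystem` directly. A minimal sketch, assuming the renamed `structure_*` arguments introduced elsewhere in this commit; model, dict, and image paths are placeholders taken from the example command:

```python
import argparse
import os

import cv2

from ppstructure.utility import init_args
from ppstructure.table.predict_table import TableSystem, to_excel

# Build an args namespace from the parser defaults, then override what the
# command line above would have set. All paths below are placeholders.
parser = init_args()
args = argparse.Namespace(**{a.dest: a.default for a in parser._actions})
args.det_model_dir = '../inference/db'
args.rec_model_dir = '../inference/rec_mv3_large1.0/infer'
args.structure_model_dir = '../inference/explite3/infer'
args.rec_char_dict_path = '../ppocr/utils/dict/table_dict.txt'
args.structure_char_dict_path = '../ppocr/utils/dict/table_structure_dict.txt'
args.rec_char_type = 'EN'
args.det_limit_side_len = 736
args.det_limit_type = 'min'

table_sys = TableSystem(args)
img = cv2.imread('../table/imgs/PMC3006023_004_00.png')  # placeholder image
pred_html = table_sys(img)                               # HTML string describing the table

os.makedirs('../output/table', exist_ok=True)
to_excel(pred_html, '../output/table/PMC3006023_004_00.xlsx')
```
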
@@ -15,16 +15,21 @@ import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import cv2
import json
from tqdm import tqdm
from ppstructure.table.table_metric import TEDS
from ppstructure.table.predict_table import TableSystem
from ppstructure.predict_system import parse_args
from ppstructure.utility import init_args


def parse_args():
    parser = init_args()
    parser.add_argument("--gt_path", type=str)
    return parser.parse_args()

def main(gt_path, img_root, args):
    teds = TEDS(n_jobs=16)

@@ -33,6 +38,8 @@ def main(gt_path, img_root, args):
    pred_htmls = []
    gt_htmls = []
    for img_name in tqdm(jsons_gt):
        if img_name != 'PMC1064865_002_00.png':
            continue
        # read information
        img = cv2.imread(os.path.join(img_root,img_name))
        pred_html = text_sys(img)

@@ -61,6 +68,4 @@ def get_gt_html(gt_structures, contents_with_block):

if __name__ == '__main__':
    args = parse_args()
    gt_path = 'table/match_code/f_gt_bbox.json'
    img_root = 'table/imgs'
    main(gt_path,img_root, args)
    main(args.gt_path,args.image_dir, args)

@@ -194,21 +194,3 @@ def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
            matched[index].append(i)
        pre_bbox = gt_box
    return matched


def main():
    detect_bboxes = json.load(open('./f_detecion_bbox.json'))
    gt_bboxes = json.load(open('./f_gt_bbox.json'))
    all_node = 0
    matched_right = 0
    key = 'PMC4796501_003_00.png'
    print(key)
    gt_bbox = gt_bboxes[key]
    pred_bbox = detect_bboxes[key]
    matched = matcher(gt_bbox, pred_bbox)
    print(matched)


if __name__ == "__main__":
    main()

@@ -40,7 +40,7 @@ class TableStructurer(object):
    def __init__(self, args):
        pre_process_list = [{
            'ResizeTableImage': {
                'max_len': args.table_max_len
                'max_len': args.structure_max_len
            }
        }, {
            'NormalizeImage': {

@@ -60,17 +60,17 @@ class TableStructurer(object):
        }]
        postprocess_params = {
            'name': 'TableLabelDecode',
            "character_type": args.table_char_type,
            "character_dict_path": args.table_char_dict_path,
            "max_text_length": args.table_max_text_length,
            "max_elem_length": args.table_max_elem_length,
            "max_cell_num": args.table_max_cell_num
            "character_type": args.structure_char_type,
            "character_dict_path": args.structure_char_dict_path,
            "max_text_length": args.structure_max_text_length,
            "max_elem_length": args.structure_max_elem_length,
            "max_cell_num": args.structure_max_cell_num
        }

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors = \
            utility.create_predictor(args, 'table', logger)
            utility.create_predictor(args, 'structure', logger)

    def __call__(self, img):
        ori_im = img.copy()

@@ -18,6 +18,7 @@ import subprocess

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

@@ -25,13 +26,13 @@ import cv2
import copy
import numpy as np
import time
import tools.infer.utility as utility
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
import ppstructure.table.predict_structure as predict_strture
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
from ppstructure.table.matcher import distance, compute_iou
from matcher import distance, compute_iou
from ppstructure.utility import parse_args

logger = get_logger()

@@ -52,12 +53,10 @@ def expand(pix, det_box, shape):


class TableSystem(object):
    def __init__(self, args):
        self.text_detector = predict_det.TextDetector(args)
        self.text_recognizer = predict_rec.TextRecognizer(args)
    def __init__(self, args, text_detector=None, text_recognizer=None):
        self.text_detector = predict_det.TextDetector(args) if text_detector is None else text_detector
        self.text_recognizer = predict_rec.TextRecognizer(args) if text_recognizer is None else text_recognizer
        self.table_structurer = predict_strture.TableStructurer(args)
        self.use_angle_cls = args.use_angle_cls
        self.drop_score = args.drop_score

    def __call__(self, img):
        ori_im = img.copy()

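The new optional `text_detector` and `text_recognizer` parameters let a caller hand `TableSystem` models that are already loaded instead of constructing them a second time; this is exactly how the rewritten `OCRSystem` earlier in this diff reuses the detector and recognizer owned by `TextSystem`. A minimal sketch of the same pattern, with placeholder model directories:

```python
import argparse

from ppstructure.utility import init_args
from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem

# Build an args namespace from the shared parser defaults; paths are placeholders.
parser = init_args()
args = argparse.Namespace(**{a.dest: a.default for a in parser._actions})
args.det_model_dir = './inference/det'
args.rec_model_dir = './inference/rec'
args.structure_model_dir = './inference/structure'

text_sys = TextSystem(args)

# Reuse the detector/recognizer already held by TextSystem; only the table
# structure predictor is loaded in addition.
table_sys = TableSystem(args,
                        text_detector=text_sys.text_detector,
                        text_recognizer=text_sys.text_recognizer)
```
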
@@ -75,8 +74,8 @@ class TableSystem(object):
            r_boxes.append(box)
        dt_boxes = np.array(r_boxes)

        # logger.info("dt_boxes num : {}, elapse : {}".format(
        # len(dt_boxes), elapse))
        logger.debug("dt_boxes num : {}, elapse : {}".format(
            len(dt_boxes), elapse))
        if dt_boxes is None:
            return None, None
        img_crop_list = []

@@ -87,8 +86,8 @@ class TableSystem(object):
            text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
            img_crop_list.append(text_rect)
        rec_res, elapse = self.text_recognizer(img_crop_list)
        # logger.info("rec_res num : {}, elapse : {}".format(
        # len(rec_res), elapse))
        logger.debug("rec_res num : {}, elapse : {}".format(
            len(rec_res), elapse))

        pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
        return pred_html

@@ -172,6 +171,7 @@ def sorted_boxes(dt_boxes):
            _boxes[i + 1] = tmp
    return _boxes


def to_excel(html_table, excel_path):
    from tablepyxl import tablepyxl
    tablepyxl.document_to_xl(html_table, excel_path)

@@ -180,19 +180,18 @@ def to_excel(html_table, excel_path):
def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    image_file_list = image_file_list[args.process_id::args.total_process_num]
    excel_save_folder = 'output/table'
    os.makedirs(excel_save_folder, exist_ok=True)
    os.makedirs(args.output, exist_ok=True)

    text_sys = TableSystem(args)
    img_num = len(image_file_list)
    for i, image_file in enumerate(image_file_list):
        logger.info("[{}/{}] {}".format(i, img_num, image_file))
        img, flag = check_and_read_gif(image_file)
        excel_path = os.path.join(excel_save_folder, os.path.basename(image_file).split('.')[0] + '.xlsx')
        excel_path = os.path.join(args.table_output, os.path.basename(image_file).split('.')[0] + '.xlsx')
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            logger.error("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        pred_html = text_sys(img)

@ -205,7 +204,7 @@ def main(args):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = utility.parse_args()
|
||||
args = parse_args()
|
||||
if args.use_mp:
|
||||
p_list = []
|
||||
total_process_num = args.total_process_num
|
||||
|
|
|
@@ -0,0 +1,40 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

from tools.infer.utility import str2bool, init_args as infer_args


def init_args():
    parser = infer_args()

    # params for output
    parser.add_argument("--output", type=str, default='./output/table')
    # params for table structure
    parser.add_argument("--structure_max_len", type=int, default=488)
    parser.add_argument("--structure_max_text_length", type=int, default=100)
    parser.add_argument("--structure_max_elem_length", type=int, default=800)
    parser.add_argument("--structure_max_cell_num", type=int, default=500)
    parser.add_argument("--structure_model_dir", type=str)
    parser.add_argument("--structure_char_type", type=str, default='en')
    parser.add_argument("--structure_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt")

    # params for layout detector
    parser.add_argument("--layout_model_dir", type=str)
    return parser


def parse_args():
    parser = init_args()
    return parser.parse_args()

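Note that `init_args` above returns the parser itself rather than parsed arguments, so downstream scripts can layer extra options on top of the shared set before calling `parse_args()`; `eval_table.py` earlier in this diff does exactly that with `--gt_path`. A small sketch of the same pattern, with the extra flag purely illustrative:

```python
from ppstructure.utility import init_args


def parse_eval_args():
    parser = init_args()
    # Illustrative extra flag layered on top of the shared arguments.
    parser.add_argument("--report_path", type=str, default='./output/report.json')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_eval_args()
    print(args.structure_model_dir, args.report_path)
```
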
@@ -88,7 +88,7 @@ class TextSystem(object):
    def __call__(self, img, cls=True):
        ori_im = img.copy()
        dt_boxes, elapse = self.text_detector(img)
        logger.info("dt_boxes num : {}, elapse : {}".format(
        logger.debug("dt_boxes num : {}, elapse : {}".format(
            len(dt_boxes), elapse))
        if dt_boxes is None:
            return None, None

@@ -103,11 +103,11 @@ class TextSystem(object):
        if self.use_angle_cls and cls:
            img_crop_list, angle_list, elapse = self.text_classifier(
                img_crop_list)
            logger.info("cls num : {}, elapse : {}".format(
            logger.debug("cls num : {}, elapse : {}".format(
                len(img_crop_list), elapse))

        rec_res, elapse = self.text_recognizer(img_crop_list)
        logger.info("rec_res num : {}, elapse : {}".format(
        logger.debug("rec_res num : {}, elapse : {}".format(
            len(rec_res), elapse))
        # self.print_draw_crop_rec_res(img_crop_list, rec_res)
        filter_boxes, filter_rec_res = [], []

@@ -152,7 +152,7 @@ def main(args):
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            logger.error("error in loading image:{}".format(image_file))
            continue
        starttime = time.time()
        dt_boxes, rec_res = text_sys(img)

@@ -109,7 +109,7 @@ def init_args():
    parser.add_argument("--use_mp", type=str2bool, default=False)
    parser.add_argument("--total_process_num", type=int, default=1)
    parser.add_argument("--process_id", type=int, default=0)

    parser.add_argument("--show_log", type=str2bool, default=True)
    return parser

@@ -125,8 +125,8 @@ def create_predictor(args, mode, logger):
        model_dir = args.cls_model_dir
    elif mode == 'rec':
        model_dir = args.rec_model_dir
    elif mode == 'table':
        model_dir = args.table_model_dir
    elif mode == 'structure':
        model_dir = args.structure_model_dir
    else:
        model_dir = args.e2e_model_dir

@@ -246,7 +246,8 @@ def create_predictor(args, mode, logger):

    config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
    config.switch_use_feed_fetch_ops(False)
    if mode == 'table':
        config.switch_ir_optim(True)
    if mode == 'structure':
        config.switch_ir_optim(False)
    # create predictor
    predictor = inference.create_predictor(config)