init commit for paddlestructure
This commit is contained in:
parent
a5f7511505
commit
bc0d766425
|
@ -1,7 +1,7 @@
|
||||||
include LICENSE.txt
|
include LICENSE.txt
|
||||||
include README.md
|
include README.md
|
||||||
|
|
||||||
recursive-include ppocr/utils *.txt utility.py logging.py
|
recursive-include ppocr/utils *.txt utility.py logging.py network.py
|
||||||
recursive-include ppocr/data/ *.py
|
recursive-include ppocr/data/ *.py
|
||||||
recursive-include ppocr/postprocess *.py
|
recursive-include ppocr/postprocess *.py
|
||||||
recursive-include tools/infer *.py
|
recursive-include tools/infer *.py
|
||||||
|
|
|
@ -19,6 +19,7 @@ __dir__ = os.path.dirname(__file__)
|
||||||
sys.path.append(os.path.join(__dir__, ''))
|
sys.path.append(os.path.join(__dir__, ''))
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
@ -150,6 +151,8 @@ class PaddleOCR(predict_system.TextSystem):
|
||||||
"""
|
"""
|
||||||
params = parse_args(mMain=False)
|
params = parse_args(mMain=False)
|
||||||
params.__dict__.update(**kwargs)
|
params.__dict__.update(**kwargs)
|
||||||
|
if params.show_log:
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
self.use_angle_cls = params.use_angle_cls
|
self.use_angle_cls = params.use_angle_cls
|
||||||
lang = params.lang
|
lang = params.lang
|
||||||
latin_lang = [
|
latin_lang = [
|
||||||
|
|
|
@ -33,8 +33,7 @@ D
|
||||||
Π
|
Π
|
||||||
H
|
H
|
||||||
║
|
║
|
||||||
</
|
</strike>
|
||||||
>
|
|
||||||
L
|
L
|
||||||
Φ
|
Φ
|
||||||
Χ
|
Χ
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
include LICENSE.txt
|
||||||
|
include README.md
|
||||||
|
|
||||||
|
recursive-include ppocr/utils *.txt utility.py logging.py network.py
|
||||||
|
recursive-include ppocr/data/ *.py
|
||||||
|
recursive-include ppocr/postprocess *.py
|
||||||
|
recursive-include tools/infer *.py
|
||||||
|
recursive-include table *.py
|
||||||
|
recursive-include ppstructure *.py
|
|
@ -0,0 +1,161 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
__dir__ = os.path.dirname(__file__)
|
||||||
|
sys.path.append(os.path.join(__dir__, ''))
|
||||||
|
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from ppocr.utils.logging import get_logger
|
||||||
|
from predict_system import OCRSystem, save_res
|
||||||
|
from utility import init_args
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
|
||||||
|
from ppocr.utils.network import maybe_download, download_with_progressbar
|
||||||
|
|
||||||
|
__all__ = ['PaddleStructure']
|
||||||
|
|
||||||
|
VERSION = '2.1'
|
||||||
|
BASE_DIR = os.path.expanduser("~/.paddlestructure/")
|
||||||
|
|
||||||
|
model_urls = {
|
||||||
|
'det': {
|
||||||
|
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
|
||||||
|
},
|
||||||
|
'rec': {
|
||||||
|
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
|
||||||
|
},
|
||||||
|
'structure': {
|
||||||
|
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args(mMain=True):
|
||||||
|
import argparse
|
||||||
|
parser = init_args()
|
||||||
|
parser.add_help = mMain
|
||||||
|
|
||||||
|
for action in parser._actions:
|
||||||
|
if action.dest in ['rec_char_dict_path', 'structure_char_dict_path']:
|
||||||
|
action.default = None
|
||||||
|
if mMain:
|
||||||
|
return parser.parse_args()
|
||||||
|
else:
|
||||||
|
inference_args_dict = {}
|
||||||
|
for action in parser._actions:
|
||||||
|
inference_args_dict[action.dest] = action.default
|
||||||
|
return argparse.Namespace(**inference_args_dict)
|
||||||
|
|
||||||
|
|
||||||
|
class PaddleStructure(OCRSystem):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
params = parse_args(mMain=False)
|
||||||
|
params.__dict__.update(**kwargs)
|
||||||
|
if params.show_log:
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
params.use_angle_cls = False
|
||||||
|
# init model dir
|
||||||
|
if params.det_model_dir is None:
|
||||||
|
params.det_model_dir = os.path.join(BASE_DIR, VERSION, 'det')
|
||||||
|
if params.rec_model_dir is None:
|
||||||
|
params.rec_model_dir = os.path.join(BASE_DIR, VERSION, 'rec')
|
||||||
|
if params.structure_model_dir is None:
|
||||||
|
params.structure_model_dir = os.path.join(BASE_DIR, VERSION, 'structure')
|
||||||
|
# download model
|
||||||
|
maybe_download(params.det_model_dir, model_urls['det'])
|
||||||
|
maybe_download(params.det_model_dir, model_urls['rec'])
|
||||||
|
maybe_download(params.det_model_dir, model_urls['structure'])
|
||||||
|
|
||||||
|
if params.rec_char_dict_path is None:
|
||||||
|
params.rec_char_type = 'EN'
|
||||||
|
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')):
|
||||||
|
params.rec_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_dict.txt')
|
||||||
|
else:
|
||||||
|
params.rec_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_dict.txt')
|
||||||
|
if params.structure_char_dict_path is None:
|
||||||
|
if os.path.exists(str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')):
|
||||||
|
params.structure_char_dict_path = str(Path(__file__).parent / 'ppocr/utils/dict/table_structure_dict.txt')
|
||||||
|
else:
|
||||||
|
params.structure_char_dict_path = str(Path(__file__).parent.parent / 'ppocr/utils/dict/table_structure_dict.txt')
|
||||||
|
|
||||||
|
print(params)
|
||||||
|
super().__init__(params)
|
||||||
|
|
||||||
|
def __call__(self, img):
|
||||||
|
if isinstance(img, str):
|
||||||
|
# download net image
|
||||||
|
if img.startswith('http'):
|
||||||
|
download_with_progressbar(img, 'tmp.jpg')
|
||||||
|
img = 'tmp.jpg'
|
||||||
|
image_file = img
|
||||||
|
img, flag = check_and_read_gif(image_file)
|
||||||
|
if not flag:
|
||||||
|
with open(image_file, 'rb') as f:
|
||||||
|
np_arr = np.frombuffer(f.read(), dtype=np.uint8)
|
||||||
|
img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
|
||||||
|
if img is None:
|
||||||
|
logger.error("error in loading image:{}".format(image_file))
|
||||||
|
return None
|
||||||
|
if isinstance(img, np.ndarray) and len(img.shape) == 2:
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
|
|
||||||
|
res = super().__call__(img)
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# for cmd
|
||||||
|
args = parse_args(mMain=True)
|
||||||
|
image_dir = args.image_dir
|
||||||
|
save_folder = args.output
|
||||||
|
if image_dir.startswith('http'):
|
||||||
|
download_with_progressbar(image_dir, 'tmp.jpg')
|
||||||
|
image_file_list = ['tmp.jpg']
|
||||||
|
else:
|
||||||
|
image_file_list = get_image_file_list(args.image_dir)
|
||||||
|
if len(image_file_list) == 0:
|
||||||
|
logger.error('no images find in {}'.format(args.image_dir))
|
||||||
|
return
|
||||||
|
|
||||||
|
structure_engine = PaddleStructure(**(args.__dict__))
|
||||||
|
for img_path in image_file_list:
|
||||||
|
img_name = os.path.basename(img_path).split('.')[0]
|
||||||
|
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
|
||||||
|
result = structure_engine(img_path)
|
||||||
|
save_res(result, args.output, os.path.basename(img_path).split('.')[0])
|
||||||
|
for item in result:
|
||||||
|
logger.info(item['res'])
|
||||||
|
save_res(result, save_folder, img_name)
|
||||||
|
logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
table_engine = PaddleStructure(det_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_det_infer',
|
||||||
|
rec_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_rec_infer',
|
||||||
|
structure_model_dir='../inference/table/ch_ppocr_mobile_v2.0_table_structure_infer',
|
||||||
|
output='/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/output/table',
|
||||||
|
show_log=True)
|
||||||
|
img = cv2.imread('/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/ppstructure/test_imgs/table_1.png')
|
||||||
|
result = table_engine(img)
|
||||||
|
for line in result:
|
||||||
|
print(line)
|
|
@ -18,97 +18,93 @@ import subprocess
|
||||||
|
|
||||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||||
sys.path.append(__dir__)
|
sys.path.append(__dir__)
|
||||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||||
|
|
||||||
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
||||||
import cv2
|
import cv2
|
||||||
import copy
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
import tools.infer.utility as utility
|
|
||||||
from tools.infer.predict_system import TextSystem
|
import layoutparser as lp
|
||||||
from ppstructure.table.predict_table import TableSystem, to_excel
|
|
||||||
from ppstructure.layout.predict_layout import LayoutDetector
|
|
||||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||||
from ppocr.utils.logging import get_logger
|
from ppocr.utils.logging import get_logger
|
||||||
|
from tools.infer.predict_system import TextSystem
|
||||||
|
from ppstructure.table.predict_table import TableSystem, to_excel
|
||||||
|
from ppstructure.utility import parse_args
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
class OCRSystem(object):
|
||||||
parser = utility.init_args()
|
|
||||||
|
|
||||||
# params for output
|
|
||||||
parser.add_argument("--table_output", type=str, default='output/table')
|
|
||||||
# params for table structure
|
|
||||||
parser.add_argument("--table_max_len", type=int, default=488)
|
|
||||||
parser.add_argument("--table_max_text_length", type=int, default=100)
|
|
||||||
parser.add_argument("--table_max_elem_length", type=int, default=800)
|
|
||||||
parser.add_argument("--table_max_cell_num", type=int, default=500)
|
|
||||||
parser.add_argument("--table_model_dir", type=str)
|
|
||||||
parser.add_argument("--table_char_type", type=str, default='en')
|
|
||||||
parser.add_argument("--table_char_dict_path", type=str, default="./ppocr/utils/dict/table_structure_dict.txt")
|
|
||||||
|
|
||||||
# params for layout detector
|
|
||||||
parser.add_argument("--layout_model_dir", type=str)
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
class OCRSystem():
|
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
self.text_system = TextSystem(args)
|
self.text_system = TextSystem(args)
|
||||||
self.table_system = TableSystem(args)
|
self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
|
||||||
self.table_layout = LayoutDetector(args)
|
self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
|
||||||
|
threshold=0.5, enable_mkldnn=args.enable_mkldnn,
|
||||||
|
enforce_cpu=not args.use_gpu)
|
||||||
self.use_angle_cls = args.use_angle_cls
|
self.use_angle_cls = args.use_angle_cls
|
||||||
self.drop_score = args.drop_score
|
self.drop_score = args.drop_score
|
||||||
|
|
||||||
def __call__(self, img):
|
def __call__(self, img):
|
||||||
ori_im = img.copy()
|
ori_im = img.copy()
|
||||||
layout_res = self.table_layout(copy.deepcopy(img))
|
layout_res = self.table_layout.detect(img[..., ::-1])
|
||||||
|
res_list = []
|
||||||
for region in layout_res:
|
for region in layout_res:
|
||||||
x1, y1, x2, y2 = region['bbox']
|
x1, y1, x2, y2 = region.coordinates
|
||||||
|
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
|
||||||
roi_img = ori_im[y1:y2, x1:x2, :]
|
roi_img = ori_im[y1:y2, x1:x2, :]
|
||||||
if region['label'] == 'table':
|
if region.type == 'Table':
|
||||||
res = self.text_system(roi_img)
|
res = self.table_system(roi_img)
|
||||||
|
elif region.type == 'Figure':
|
||||||
|
continue
|
||||||
else:
|
else:
|
||||||
res = self.text_system(roi_img)
|
filter_boxes, filter_rec_res = self.text_system(roi_img)
|
||||||
region['res'] = res
|
filter_boxes = [x.reshape(-1).tolist() for x in filter_boxes]
|
||||||
return layout_res
|
res = (filter_boxes, filter_rec_res)
|
||||||
|
res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
|
||||||
|
return res_list
|
||||||
|
|
||||||
|
|
||||||
|
def save_res(res, save_folder, img_name):
|
||||||
|
excel_save_folder = os.path.join(save_folder, img_name)
|
||||||
|
os.makedirs(excel_save_folder, exist_ok=True)
|
||||||
|
# save res
|
||||||
|
for region in res:
|
||||||
|
if region['type'] == 'Table':
|
||||||
|
excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
|
||||||
|
to_excel(region['res'], excel_path)
|
||||||
|
elif region['type'] == 'Figure':
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
with open(os.path.join(excel_save_folder, 'res.txt'), 'a', encoding='utf8') as f:
|
||||||
|
for box, rec_res in zip(*region['res']):
|
||||||
|
f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
image_file_list = get_image_file_list(args.image_dir)
|
image_file_list = get_image_file_list(args.image_dir)
|
||||||
|
image_file_list = image_file_list
|
||||||
image_file_list = image_file_list[args.process_id::args.total_process_num]
|
image_file_list = image_file_list[args.process_id::args.total_process_num]
|
||||||
save_folder = args.table_output
|
save_folder = args.output
|
||||||
os.makedirs(save_folder, exist_ok=True)
|
os.makedirs(save_folder, exist_ok=True)
|
||||||
|
|
||||||
text_sys = OCRSystem(args)
|
structure_sys = OCRSystem(args)
|
||||||
img_num = len(image_file_list)
|
img_num = len(image_file_list)
|
||||||
for i, image_file in enumerate(image_file_list):
|
for i, image_file in enumerate(image_file_list):
|
||||||
logger.info("[{}/{}] {}".format(i, img_num, image_file))
|
logger.info("[{}/{}] {}".format(i, img_num, image_file))
|
||||||
img, flag = check_and_read_gif(image_file)
|
img, flag = check_and_read_gif(image_file)
|
||||||
img_name = os.path.basename(image_file).split('.')[0]
|
img_name = os.path.basename(image_file).split('.')[0]
|
||||||
# excel_path = os.path.join(excel_save_folder, + '.xlsx')
|
|
||||||
if not flag:
|
if not flag:
|
||||||
img = cv2.imread(image_file)
|
img = cv2.imread(image_file)
|
||||||
if img is None:
|
if img is None:
|
||||||
logger.info("error in loading image:{}".format(image_file))
|
logger.error("error in loading image:{}".format(image_file))
|
||||||
continue
|
continue
|
||||||
starttime = time.time()
|
starttime = time.time()
|
||||||
res = text_sys(img)
|
res = structure_sys(img)
|
||||||
|
save_res(res, save_folder, img_name)
|
||||||
excel_save_folder = os.path.join(save_folder, img_name)
|
logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
|
||||||
os.makedirs(excel_save_folder, exist_ok=True)
|
|
||||||
# save res
|
|
||||||
for region in res:
|
|
||||||
if region['label'] == 'table':
|
|
||||||
excel_path = os.path.join(excel_save_folder, '{}.xlsx'.format(region['bbox']))
|
|
||||||
to_excel(region['res'], excel_path)
|
|
||||||
else:
|
|
||||||
with open(os.path.join(excel_save_folder, 'res.txt'),'a',encoding='utf8') as f:
|
|
||||||
for box, rec_res in zip(*region['res']):
|
|
||||||
f.write('{}\t{}\n'.format(np.array(box).reshape(-1).tolist(), rec_res))
|
|
||||||
logger.info(res)
|
|
||||||
elapse = time.time() - starttime
|
elapse = time.time() - starttime
|
||||||
logger.info("Predict time : {:.3f}s".format(elapse))
|
logger.info("Predict time : {:.3f}s".format(elapse))
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from setuptools import setup
|
||||||
|
from io import open
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
with open('../requirements.txt', encoding="utf-8-sig") as f:
|
||||||
|
requirements = f.readlines()
|
||||||
|
requirements.append('tqdm')
|
||||||
|
requirements.append('layoutparser')
|
||||||
|
|
||||||
|
|
||||||
|
def readme():
|
||||||
|
with open('README_ch.md', encoding="utf-8-sig") as f:
|
||||||
|
README = f.read()
|
||||||
|
return README
|
||||||
|
|
||||||
|
shutil.copytree('../ppocr','./ppocr')
|
||||||
|
shutil.copytree('../tools','./tools')
|
||||||
|
shutil.copytree('../ppstructure','./ppstructure')
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='paddlestructure',
|
||||||
|
packages=['paddlestructure'],
|
||||||
|
package_dir={'paddlestructure': ''},
|
||||||
|
include_package_data=True,
|
||||||
|
entry_points={"console_scripts": ["paddlestructure= paddlestructure.paddlestructure:main"]},
|
||||||
|
version='2.0.6',
|
||||||
|
install_requires=requirements,
|
||||||
|
license='Apache License 2.0',
|
||||||
|
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
|
||||||
|
long_description=readme(),
|
||||||
|
long_description_content_type='text/markdown',
|
||||||
|
url='https://github.com/PaddlePaddle/PaddleOCR',
|
||||||
|
download_url='https://github.com/PaddlePaddle/PaddleOCR.git',
|
||||||
|
keywords=[
|
||||||
|
'ocr textdetection textrecognition paddleocr crnn east star-net rosetta ocrlite db chineseocr chinesetextdetection chinesetextrecognition'
|
||||||
|
],
|
||||||
|
classifiers=[
|
||||||
|
'Intended Audience :: Developers', 'Operating System :: OS Independent',
|
||||||
|
'Natural Language :: Chinese (Simplified)',
|
||||||
|
'Programming Language :: Python :: 3',
|
||||||
|
'Programming Language :: Python :: 3.2',
|
||||||
|
'Programming Language :: Python :: 3.3',
|
||||||
|
'Programming Language :: Python :: 3.4',
|
||||||
|
'Programming Language :: Python :: 3.5',
|
||||||
|
'Programming Language :: Python :: 3.6',
|
||||||
|
'Programming Language :: Python :: 3.7', 'Topic :: Utilities'
|
||||||
|
], )
|
||||||
|
|
||||||
|
shutil.rmtree('ppocr')
|
||||||
|
shutil.rmtree('tools')
|
||||||
|
shutil.rmtree('ppstructure')
|
|
@ -0,0 +1,15 @@
|
||||||
|
# 表格结构和内容预测
|
||||||
|
|
||||||
|
先cd到PaddleOCR/ppstructure目录下
|
||||||
|
|
||||||
|
预测
|
||||||
|
```python
|
||||||
|
python3 table/predict_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs/PMC3006023_004_00.png --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --table_output ../output/table
|
||||||
|
```
|
||||||
|
运行完成后,每张图片的excel表格会保存到table_output字段指定的目录下
|
||||||
|
|
||||||
|
eval
|
||||||
|
|
||||||
|
```python
|
||||||
|
python3 table/eval_table.py --det_model_dir=../inference/db --rec_model_dir=../inference/rec_mv3_large1.0/infer --table_model_dir=../inference/explite3/infer --image_dir=../table/imgs --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --gt_path=path/to/gt.json
|
||||||
|
```
|
|
@ -15,16 +15,21 @@ import os
|
||||||
import sys
|
import sys
|
||||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||||
sys.path.append(__dir__)
|
sys.path.append(__dir__)
|
||||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import json
|
import json
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from ppstructure.table.table_metric import TEDS
|
from ppstructure.table.table_metric import TEDS
|
||||||
from ppstructure.table.predict_table import TableSystem
|
from ppstructure.table.predict_table import TableSystem
|
||||||
from ppstructure.predict_system import parse_args
|
from ppstructure.utility import init_args
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = init_args()
|
||||||
|
parser.add_argument("--gt_path", type=str)
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
def main(gt_path, img_root, args):
|
def main(gt_path, img_root, args):
|
||||||
teds = TEDS(n_jobs=16)
|
teds = TEDS(n_jobs=16)
|
||||||
|
|
||||||
|
@ -33,6 +38,8 @@ def main(gt_path, img_root, args):
|
||||||
pred_htmls = []
|
pred_htmls = []
|
||||||
gt_htmls = []
|
gt_htmls = []
|
||||||
for img_name in tqdm(jsons_gt):
|
for img_name in tqdm(jsons_gt):
|
||||||
|
if img_name != 'PMC1064865_002_00.png':
|
||||||
|
continue
|
||||||
# 读取信息
|
# 读取信息
|
||||||
img = cv2.imread(os.path.join(img_root,img_name))
|
img = cv2.imread(os.path.join(img_root,img_name))
|
||||||
pred_html = text_sys(img)
|
pred_html = text_sys(img)
|
||||||
|
@ -61,6 +68,4 @@ def get_gt_html(gt_structures, contents_with_block):
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
gt_path = 'table/match_code/f_gt_bbox.json'
|
main(args.gt_path,args.image_dir, args)
|
||||||
img_root = 'table/imgs'
|
|
||||||
main(gt_path,img_root, args)
|
|
||||||
|
|
|
@ -194,21 +194,3 @@ def matcher_structure(gt_bboxes, pred_bboxes_rows, pred_bboxes):
|
||||||
matched[index].append(i)
|
matched[index].append(i)
|
||||||
pre_bbox = gt_box
|
pre_bbox = gt_box
|
||||||
return matched
|
return matched
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
detect_bboxes = json.load(open('./f_detecion_bbox.json'))
|
|
||||||
gt_bboxes = json.load(open('./f_gt_bbox.json'))
|
|
||||||
all_node = 0
|
|
||||||
matched_right = 0
|
|
||||||
key = 'PMC4796501_003_00.png'
|
|
||||||
print(key)
|
|
||||||
gt_bbox = gt_bboxes[key]
|
|
||||||
pred_bbox = detect_bboxes[key]
|
|
||||||
matched = matcher(gt_bbox, pred_bbox)
|
|
||||||
print(matched)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ class TableStructurer(object):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
pre_process_list = [{
|
pre_process_list = [{
|
||||||
'ResizeTableImage': {
|
'ResizeTableImage': {
|
||||||
'max_len': args.table_max_len
|
'max_len': args.structure_max_len
|
||||||
}
|
}
|
||||||
}, {
|
}, {
|
||||||
'NormalizeImage': {
|
'NormalizeImage': {
|
||||||
|
@ -60,17 +60,17 @@ class TableStructurer(object):
|
||||||
}]
|
}]
|
||||||
postprocess_params = {
|
postprocess_params = {
|
||||||
'name': 'TableLabelDecode',
|
'name': 'TableLabelDecode',
|
||||||
"character_type": args.table_char_type,
|
"character_type": args.structure_char_type,
|
||||||
"character_dict_path": args.table_char_dict_path,
|
"character_dict_path": args.structure_char_dict_path,
|
||||||
"max_text_length": args.table_max_text_length,
|
"max_text_length": args.structure_max_text_length,
|
||||||
"max_elem_length": args.table_max_elem_length,
|
"max_elem_length": args.structure_max_elem_length,
|
||||||
"max_cell_num": args.table_max_cell_num
|
"max_cell_num": args.structure_max_cell_num
|
||||||
}
|
}
|
||||||
|
|
||||||
self.preprocess_op = create_operators(pre_process_list)
|
self.preprocess_op = create_operators(pre_process_list)
|
||||||
self.postprocess_op = build_post_process(postprocess_params)
|
self.postprocess_op = build_post_process(postprocess_params)
|
||||||
self.predictor, self.input_tensor, self.output_tensors = \
|
self.predictor, self.input_tensor, self.output_tensors = \
|
||||||
utility.create_predictor(args, 'table', logger)
|
utility.create_predictor(args, 'structure', logger)
|
||||||
|
|
||||||
def __call__(self, img):
|
def __call__(self, img):
|
||||||
ori_im = img.copy()
|
ori_im = img.copy()
|
||||||
|
|
|
@ -18,6 +18,7 @@ import subprocess
|
||||||
|
|
||||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||||
sys.path.append(__dir__)
|
sys.path.append(__dir__)
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||||
|
|
||||||
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
||||||
|
@ -25,13 +26,13 @@ import cv2
|
||||||
import copy
|
import copy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
import tools.infer.utility as utility
|
|
||||||
import tools.infer.predict_rec as predict_rec
|
import tools.infer.predict_rec as predict_rec
|
||||||
import tools.infer.predict_det as predict_det
|
import tools.infer.predict_det as predict_det
|
||||||
import ppstructure.table.predict_structure as predict_strture
|
import ppstructure.table.predict_structure as predict_strture
|
||||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||||
from ppocr.utils.logging import get_logger
|
from ppocr.utils.logging import get_logger
|
||||||
from ppstructure.table.matcher import distance, compute_iou
|
from matcher import distance, compute_iou
|
||||||
|
from ppstructure.utility import parse_args
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
|
||||||
|
@ -52,12 +53,10 @@ def expand(pix, det_box, shape):
|
||||||
|
|
||||||
|
|
||||||
class TableSystem(object):
|
class TableSystem(object):
|
||||||
def __init__(self, args):
|
def __init__(self, args, text_detector=None, text_recognizer=None):
|
||||||
self.text_detector = predict_det.TextDetector(args)
|
self.text_detector = predict_det.TextDetector(args) if text_detector is None else text_detector
|
||||||
self.text_recognizer = predict_rec.TextRecognizer(args)
|
self.text_recognizer = predict_rec.TextRecognizer(args) if text_recognizer is None else text_recognizer
|
||||||
self.table_structurer = predict_strture.TableStructurer(args)
|
self.table_structurer = predict_strture.TableStructurer(args)
|
||||||
self.use_angle_cls = args.use_angle_cls
|
|
||||||
self.drop_score = args.drop_score
|
|
||||||
|
|
||||||
def __call__(self, img):
|
def __call__(self, img):
|
||||||
ori_im = img.copy()
|
ori_im = img.copy()
|
||||||
|
@ -75,8 +74,8 @@ class TableSystem(object):
|
||||||
r_boxes.append(box)
|
r_boxes.append(box)
|
||||||
dt_boxes = np.array(r_boxes)
|
dt_boxes = np.array(r_boxes)
|
||||||
|
|
||||||
# logger.info("dt_boxes num : {}, elapse : {}".format(
|
logger.debug("dt_boxes num : {}, elapse : {}".format(
|
||||||
# len(dt_boxes), elapse))
|
len(dt_boxes), elapse))
|
||||||
if dt_boxes is None:
|
if dt_boxes is None:
|
||||||
return None, None
|
return None, None
|
||||||
img_crop_list = []
|
img_crop_list = []
|
||||||
|
@ -87,8 +86,8 @@ class TableSystem(object):
|
||||||
text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
|
text_rect = ori_im[int(y0):int(y1), int(x0):int(x1), :]
|
||||||
img_crop_list.append(text_rect)
|
img_crop_list.append(text_rect)
|
||||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||||||
# logger.info("rec_res num : {}, elapse : {}".format(
|
logger.debug("rec_res num : {}, elapse : {}".format(
|
||||||
# len(rec_res), elapse))
|
len(rec_res), elapse))
|
||||||
|
|
||||||
pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
|
pred_html, pred = self.rebuild_table(structure_res, dt_boxes, rec_res)
|
||||||
return pred_html
|
return pred_html
|
||||||
|
@ -172,6 +171,7 @@ def sorted_boxes(dt_boxes):
|
||||||
_boxes[i + 1] = tmp
|
_boxes[i + 1] = tmp
|
||||||
return _boxes
|
return _boxes
|
||||||
|
|
||||||
|
|
||||||
def to_excel(html_table, excel_path):
|
def to_excel(html_table, excel_path):
|
||||||
from tablepyxl import tablepyxl
|
from tablepyxl import tablepyxl
|
||||||
tablepyxl.document_to_xl(html_table, excel_path)
|
tablepyxl.document_to_xl(html_table, excel_path)
|
||||||
|
@ -180,19 +180,18 @@ def to_excel(html_table, excel_path):
|
||||||
def main(args):
|
def main(args):
|
||||||
image_file_list = get_image_file_list(args.image_dir)
|
image_file_list = get_image_file_list(args.image_dir)
|
||||||
image_file_list = image_file_list[args.process_id::args.total_process_num]
|
image_file_list = image_file_list[args.process_id::args.total_process_num]
|
||||||
excel_save_folder = 'output/table'
|
os.makedirs(args.output, exist_ok=True)
|
||||||
os.makedirs(excel_save_folder, exist_ok=True)
|
|
||||||
|
|
||||||
text_sys = TableSystem(args)
|
text_sys = TableSystem(args)
|
||||||
img_num = len(image_file_list)
|
img_num = len(image_file_list)
|
||||||
for i, image_file in enumerate(image_file_list):
|
for i, image_file in enumerate(image_file_list):
|
||||||
logger.info("[{}/{}] {}".format(i, img_num, image_file))
|
logger.info("[{}/{}] {}".format(i, img_num, image_file))
|
||||||
img, flag = check_and_read_gif(image_file)
|
img, flag = check_and_read_gif(image_file)
|
||||||
excel_path = os.path.join(excel_save_folder, os.path.basename(image_file).split('.')[0] + '.xlsx')
|
excel_path = os.path.join(args.table_output, os.path.basename(image_file).split('.')[0] + '.xlsx')
|
||||||
if not flag:
|
if not flag:
|
||||||
img = cv2.imread(image_file)
|
img = cv2.imread(image_file)
|
||||||
if img is None:
|
if img is None:
|
||||||
logger.info("error in loading image:{}".format(image_file))
|
logger.error("error in loading image:{}".format(image_file))
|
||||||
continue
|
continue
|
||||||
starttime = time.time()
|
starttime = time.time()
|
||||||
pred_html = text_sys(img)
|
pred_html = text_sys(img)
|
||||||
|
@ -205,7 +204,7 @@ def main(args):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = utility.parse_args()
|
args = parse_args()
|
||||||
if args.use_mp:
|
if args.use_mp:
|
||||||
p_list = []
|
p_list = []
|
||||||
total_process_num = args.total_process_num
|
total_process_num = args.total_process_num
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from tools.infer.utility import str2bool, init_args as infer_args
|
||||||
|
|
||||||
|
|
||||||
|
def init_args():
|
||||||
|
parser = infer_args()
|
||||||
|
|
||||||
|
# params for output
|
||||||
|
parser.add_argument("--output", type=str, default='./output/table')
|
||||||
|
# params for table structure
|
||||||
|
parser.add_argument("--structure_max_len", type=int, default=488)
|
||||||
|
parser.add_argument("--structure_max_text_length", type=int, default=100)
|
||||||
|
parser.add_argument("--structure_max_elem_length", type=int, default=800)
|
||||||
|
parser.add_argument("--structure_max_cell_num", type=int, default=500)
|
||||||
|
parser.add_argument("--structure_model_dir", type=str)
|
||||||
|
parser.add_argument("--structure_char_type", type=str, default='en')
|
||||||
|
parser.add_argument("--structure_char_dict_path", type=str, default="../ppocr/utils/dict/table_structure_dict.txt")
|
||||||
|
|
||||||
|
# params for layout detector
|
||||||
|
parser.add_argument("--layout_model_dir", type=str)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = init_args()
|
||||||
|
return parser.parse_args()
|
|
@ -88,7 +88,7 @@ class TextSystem(object):
|
||||||
def __call__(self, img, cls=True):
|
def __call__(self, img, cls=True):
|
||||||
ori_im = img.copy()
|
ori_im = img.copy()
|
||||||
dt_boxes, elapse = self.text_detector(img)
|
dt_boxes, elapse = self.text_detector(img)
|
||||||
logger.info("dt_boxes num : {}, elapse : {}".format(
|
logger.debug("dt_boxes num : {}, elapse : {}".format(
|
||||||
len(dt_boxes), elapse))
|
len(dt_boxes), elapse))
|
||||||
if dt_boxes is None:
|
if dt_boxes is None:
|
||||||
return None, None
|
return None, None
|
||||||
|
@ -103,11 +103,11 @@ class TextSystem(object):
|
||||||
if self.use_angle_cls and cls:
|
if self.use_angle_cls and cls:
|
||||||
img_crop_list, angle_list, elapse = self.text_classifier(
|
img_crop_list, angle_list, elapse = self.text_classifier(
|
||||||
img_crop_list)
|
img_crop_list)
|
||||||
logger.info("cls num : {}, elapse : {}".format(
|
logger.debug("cls num : {}, elapse : {}".format(
|
||||||
len(img_crop_list), elapse))
|
len(img_crop_list), elapse))
|
||||||
|
|
||||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||||||
logger.info("rec_res num : {}, elapse : {}".format(
|
logger.debug("rec_res num : {}, elapse : {}".format(
|
||||||
len(rec_res), elapse))
|
len(rec_res), elapse))
|
||||||
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
|
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
|
||||||
filter_boxes, filter_rec_res = [], []
|
filter_boxes, filter_rec_res = [], []
|
||||||
|
@ -152,7 +152,7 @@ def main(args):
|
||||||
if not flag:
|
if not flag:
|
||||||
img = cv2.imread(image_file)
|
img = cv2.imread(image_file)
|
||||||
if img is None:
|
if img is None:
|
||||||
logger.info("error in loading image:{}".format(image_file))
|
logger.error("error in loading image:{}".format(image_file))
|
||||||
continue
|
continue
|
||||||
starttime = time.time()
|
starttime = time.time()
|
||||||
dt_boxes, rec_res = text_sys(img)
|
dt_boxes, rec_res = text_sys(img)
|
||||||
|
|
|
@ -109,7 +109,7 @@ def init_args():
|
||||||
parser.add_argument("--use_mp", type=str2bool, default=False)
|
parser.add_argument("--use_mp", type=str2bool, default=False)
|
||||||
parser.add_argument("--total_process_num", type=int, default=1)
|
parser.add_argument("--total_process_num", type=int, default=1)
|
||||||
parser.add_argument("--process_id", type=int, default=0)
|
parser.add_argument("--process_id", type=int, default=0)
|
||||||
|
parser.add_argument("--show_log", type=str2bool, default=True)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
@ -125,8 +125,8 @@ def create_predictor(args, mode, logger):
|
||||||
model_dir = args.cls_model_dir
|
model_dir = args.cls_model_dir
|
||||||
elif mode == 'rec':
|
elif mode == 'rec':
|
||||||
model_dir = args.rec_model_dir
|
model_dir = args.rec_model_dir
|
||||||
elif mode == 'table':
|
elif mode == 'structure':
|
||||||
model_dir = args.table_model_dir
|
model_dir = args.structure_model_dir
|
||||||
else:
|
else:
|
||||||
model_dir = args.e2e_model_dir
|
model_dir = args.e2e_model_dir
|
||||||
|
|
||||||
|
@ -246,7 +246,8 @@ def create_predictor(args, mode, logger):
|
||||||
|
|
||||||
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
|
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
|
||||||
config.switch_use_feed_fetch_ops(False)
|
config.switch_use_feed_fetch_ops(False)
|
||||||
if mode == 'table':
|
config.switch_ir_optim(True)
|
||||||
|
if mode == 'structure':
|
||||||
config.switch_ir_optim(False)
|
config.switch_ir_optim(False)
|
||||||
# create predictor
|
# create predictor
|
||||||
predictor = inference.create_predictor(config)
|
predictor = inference.create_predictor(config)
|
||||||
|
|
Loading…
Reference in New Issue