147 lines
6.0 KiB
Python
147 lines
6.0 KiB
Python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import logging
|
|
import os
|
|
import sys
|
|
|
|
__dir__ = os.path.dirname(__file__)
|
|
sys.path.append(__dir__)
|
|
sys.path.append(os.path.join(__dir__, '..'))
|
|
|
|
import cv2
|
|
import numpy as np
|
|
from pathlib import Path
|
|
|
|
from ppocr.utils.logging import get_logger
|
|
from ppstructure.predict_system import OCRSystem, save_res
|
|
from ppstructure.utility import init_args, draw_result
|
|
|
|
logger = get_logger()
|
|
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
|
|
from ppocr.utils.network import maybe_download, download_with_progressbar, confirm_model_dir_url, is_link
|
|
|
|
__all__ = ['PaddleStructure', 'draw_result', 'save_res']
|
|
|
|
# Release tag; used as a sub-directory of BASE_DIR when default model paths are built.
VERSION = '2.1'
# Per-user root directory for model files ("~" is expanded once, at import time).
BASE_DIR = os.path.expanduser("~/.paddlestructure/")

# Download locations of the default inference archives, keyed by model role.
model_urls = dict(
    det='https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar',
    rec='https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
    table='https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar',
)
|
|
|
|
|
|
def parse_args(mMain=True):
    """Build the argument namespace for command-line or programmatic use.

    The parser is obtained from ``init_args()``. Before anything is read from
    it, the defaults of ``rec_char_dict_path`` and ``table_char_dict_path``
    are reset to ``None``.

    Args:
        mMain: When true, parse the process command line and return the
            result. When false, skip command-line parsing and return a
            namespace holding the default of every registered action.

    Returns:
        argparse.Namespace: the parsed arguments (``mMain`` true) or the
        parser defaults only (``mMain`` false).
    """
    import argparse

    parser = init_args()
    parser.add_help = mMain

    dict_path_dests = ('rec_char_dict_path', 'table_char_dict_path')
    for act in parser._actions:
        if act.dest in dict_path_dests:
            act.default = None

    if not mMain:
        # Library use: never touch sys.argv, just expose the defaults.
        defaults = {act.dest: act.default for act in parser._actions}
        return argparse.Namespace(**defaults)
    return parser.parse_args()
|
|
|
|
|
|
class PaddleStructure(OCRSystem):
    """``OCRSystem`` front end that fills in default models and dictionaries.

    Construction starts from the parser defaults, applies keyword overrides,
    resolves the det/rec/table model directories and chooses dictionary files
    when none were supplied. Calling the instance accepts an image path, an
    ``http`` URL, or an already loaded image.
    """

    def __init__(self, **kwargs):
        """Prepare the parameter namespace and initialise the parent class.

        Args:
            **kwargs: Values overriding the defaults returned by
                ``parse_args(mMain=False)``, keyed by argument name.
        """
        params = parse_args(mMain=False)
        vars(params).update(kwargs)
        if not params.show_log:
            logger.setLevel(logging.INFO)
        params.use_angle_cls = False

        # init model dir: resolve all three directories first ...
        resolved = []
        for kind in ('det', 'rec', 'table'):
            attr = '{}_model_dir'.format(kind)
            model_dir, url = confirm_model_dir_url(
                getattr(params, attr),
                os.path.join(BASE_DIR, VERSION, kind),
                model_urls[kind])
            setattr(params, attr, model_dir)
            resolved.append((model_dir, url))
        # download model: ... and only then hand each one to maybe_download.
        for model_dir, url in resolved:
            maybe_download(model_dir, url)

        # Dictionary files are looked up next to this file first, then one
        # directory above it.
        here = Path(__file__).parent
        if params.rec_char_dict_path is None:
            params.rec_char_type = 'EN'
            rec_rel = 'ppocr/utils/dict/table_dict.txt'
            rec_local = str(here / rec_rel)
            if os.path.exists(rec_local):
                params.rec_char_dict_path = rec_local
            else:
                params.rec_char_dict_path = str(here.parent / rec_rel)
        if params.table_char_dict_path is None:
            table_rel = 'ppocr/utils/dict/table_structure_dict.txt'
            table_local = str(here / table_rel)
            if os.path.exists(table_local):
                params.table_char_dict_path = table_local
            else:
                params.table_char_dict_path = str(here.parent / table_rel)

        # NOTE(review): printed unconditionally, independent of show_log.
        print(params)
        super().__init__(params)

    def __call__(self, img):
        """Run the parent pipeline on an image.

        Args:
            img: A string (local file path, or a URL starting with ``http``
                which is first downloaded to ``tmp.jpg`` in the working
                directory) or an already loaded image. A 2-D ``numpy`` array
                is converted from grayscale to BGR before processing.

        Returns:
            Whatever ``OCRSystem.__call__`` returns, or ``None`` when a
            string input could not be loaded as an image.
        """
        if isinstance(img, str):
            image_file = img
            # download net image
            if image_file.startswith('http'):
                download_with_progressbar(image_file, 'tmp.jpg')
                image_file = 'tmp.jpg'
            img, handled = check_and_read_gif(image_file)
            if not handled:
                # Read the bytes ourselves and let OpenCV decode the buffer.
                with open(image_file, 'rb') as f:
                    raw = f.read()
                img = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR)
            if img is None:
                logger.error("error in loading image:{}".format(image_file))
                return None

        if isinstance(img, np.ndarray) and img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        return super().__call__(img)
|
|
|
|
|
|
def main():
    """Command-line entry point: run the structure engine over the inputs.

    Reads ``image_dir`` and ``output`` from the parsed command line. When
    ``image_dir`` starts with ``http`` it is downloaded to ``tmp.jpg`` and
    processed as a single image; otherwise the list of files comes from
    ``get_image_file_list``. For each image the per-item ``'res'`` entries are
    logged and the result is passed to ``save_res``.

    Images that the engine cannot load are skipped so that one unreadable
    file does not abort the rest of the batch.
    """
    # for cmd
    args = parse_args(mMain=True)
    image_dir = args.image_dir
    save_folder = args.output
    if image_dir.startswith('http'):
        download_with_progressbar(image_dir, 'tmp.jpg')
        image_file_list = ['tmp.jpg']
    else:
        image_file_list = get_image_file_list(args.image_dir)
    if not image_file_list:
        logger.error('no images find in {}'.format(args.image_dir))
        return

    structure_engine = PaddleStructure(**(args.__dict__))
    for img_path in image_file_list:
        # NOTE(review): everything after the first dot is dropped, so
        # "a.1.jpg" and "a.2.jpg" share the output name "a".
        img_name = os.path.basename(img_path).split('.')[0]
        logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
        result = structure_engine(img_path)
        if result is None:
            # The engine returns None (and has already logged the error) when
            # the image could not be loaded; iterating it would raise.
            continue
        for item in result:
            logger.info(item['res'])
        save_res(result, save_folder, img_name)
        logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))