fix structure pipeline infer bug

This commit is contained in:
WenmuZhou 2021-06-10 17:12:14 +08:00
parent 330f08ffc7
commit 5cce398ebb
5 changed files with 18 additions and 10 deletions

View File

@ -151,8 +151,8 @@ class PaddleOCR(predict_system.TextSystem):
""" """
params = parse_args(mMain=False) params = parse_args(mMain=False)
params.__dict__.update(**kwargs) params.__dict__.update(**kwargs)
if params.show_log: if not params.show_log:
logger.setLevel(logging.DEBUG) logger.setLevel(logging.INFO)
self.use_angle_cls = params.use_angle_cls self.use_angle_cls = params.use_angle_cls
lang = params.lang lang = params.lang
latin_lang = [ latin_lang = [

View File

@ -22,7 +22,7 @@ logger_initialized = {}
@functools.lru_cache() @functools.lru_cache()
def get_logger(name='root', log_file=None, log_level=logging.INFO): def get_logger(name='root', log_file=None, log_level=logging.DEBUG):
"""Initialize and get a logger by name. """Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will logger by adding one or two handlers, otherwise the initialized logger will

View File

@ -66,8 +66,8 @@ class PaddleStructure(OCRSystem):
def __init__(self, **kwargs): def __init__(self, **kwargs):
params = parse_args(mMain=False) params = parse_args(mMain=False)
params.__dict__.update(**kwargs) params.__dict__.update(**kwargs)
if params.show_log: if not params.show_log:
logger.setLevel(logging.DEBUG) logger.setLevel(logging.INFO)
params.use_angle_cls = False params.use_angle_cls = False
# init model dir # init model dir
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir, params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,

View File

@ -24,6 +24,7 @@ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2 import cv2
import numpy as np import numpy as np
import time import time
import logging
import layoutparser as lp import layoutparser as lp
@ -31,7 +32,7 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
from tools.infer.predict_system import TextSystem from tools.infer.predict_system import TextSystem
from ppstructure.table.predict_table import TableSystem, to_excel from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args,draw_result from ppstructure.utility import parse_args, draw_result
logger = get_logger() logger = get_logger()
@ -40,6 +41,8 @@ class OCRSystem(object):
def __init__(self, args): def __init__(self, args):
args.det_limit_type = 'resize_long' args.det_limit_type = 'resize_long'
args.drop_score = 0 args.drop_score = 0
if not args.show_log:
logger.setLevel(logging.INFO)
self.text_system = TextSystem(args) self.text_system = TextSystem(args)
self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer) self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config", self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
@ -67,6 +70,7 @@ class OCRSystem(object):
res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res}) res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
return res_list return res_list
def save_res(res, save_folder, img_name): def save_res(res, save_folder, img_name):
excel_save_folder = os.path.join(save_folder, img_name) excel_save_folder = os.path.join(save_folder, img_name)
os.makedirs(excel_save_folder, exist_ok=True) os.makedirs(excel_save_folder, exist_ok=True)
@ -105,7 +109,7 @@ def main(args):
starttime = time.time() starttime = time.time()
res = structure_sys(img) res = structure_sys(img)
save_res(res, save_folder, img_name) save_res(res, save_folder, img_name)
draw_img = draw_result(img,res, args.vis_font_path) draw_img = draw_result(img, res, args.vis_font_path)
cv2.imwrite(os.path.join(save_folder, img_name, 'show.jpg'), draw_img) cv2.imwrite(os.path.join(save_folder, img_name, 'show.jpg'), draw_img)
logger.info('result save to {}'.format(os.path.join(save_folder, img_name))) logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
elapse = time.time() - starttime elapse = time.time() - starttime

View File

@ -24,6 +24,7 @@ import cv2
import copy import copy
import numpy as np import numpy as np
import time import time
import logging
from PIL import Image from PIL import Image
import tools.infer.utility as utility import tools.infer.utility as utility
import tools.infer.predict_rec as predict_rec import tools.infer.predict_rec as predict_rec
@ -38,6 +39,9 @@ logger = get_logger()
class TextSystem(object): class TextSystem(object):
def __init__(self, args): def __init__(self, args):
if not args.show_log:
logger.setLevel(logging.INFO)
self.text_detector = predict_det.TextDetector(args) self.text_detector = predict_det.TextDetector(args)
self.text_recognizer = predict_rec.TextRecognizer(args) self.text_recognizer = predict_rec.TextRecognizer(args)
self.use_angle_cls = args.use_angle_cls self.use_angle_cls = args.use_angle_cls
@ -88,7 +92,7 @@ class TextSystem(object):
ori_im = img.copy() ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
logger.info("dt_boxes num : {}, elapse : {}".format( logger.debug("dt_boxes num : {}, elapse : {}".format(
len(dt_boxes), elapse)) len(dt_boxes), elapse))
if dt_boxes is None: if dt_boxes is None:
@ -104,11 +108,11 @@ class TextSystem(object):
if self.use_angle_cls and cls: if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier( img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list) img_crop_list)
logger.info("cls num : {}, elapse : {}".format( logger.debug("cls num : {}, elapse : {}".format(
len(img_crop_list), elapse)) len(img_crop_list), elapse))
rec_res, elapse = self.text_recognizer(img_crop_list) rec_res, elapse = self.text_recognizer(img_crop_list)
logger.info("rec_res num : {}, elapse : {}".format( logger.debug("rec_res num : {}, elapse : {}".format(
len(rec_res), elapse)) len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res) # self.print_draw_crop_rec_res(img_crop_list, rec_res)
filter_boxes, filter_rec_res = [], [] filter_boxes, filter_rec_res = [], []