From a621bef8930ba41ace08a25bef3b52e390a2148f Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Tue, 1 Dec 2020 16:45:45 +0800 Subject: [PATCH] add predict_cls to predict_system --- tools/infer/predict_cls.py | 2 +- tools/infer/predict_system.py | 29 +++++++---------------------- tools/infer/utility.py | 17 ++++++----------- 3 files changed, 14 insertions(+), 34 deletions(-) diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index 7d7e4720..9ec03396 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -37,7 +37,7 @@ logger = get_logger() class TextClassifier(object): def __init__(self, args): self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")] - self.cls_batch_num = args.rec_batch_num + self.cls_batch_num = args.cls_batch_num self.cls_thresh = args.cls_thresh self.use_zero_copy_run = args.use_zero_copy_run postprocess_params = { diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 7ebe3ec3..4e810397 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -23,21 +23,17 @@ import numpy as np import time from PIL import Image import tools.infer.utility as utility +from tools.infer.utility import draw_ocr import tools.infer.predict_rec as predict_rec import tools.infer.predict_det as predict_det -import tools.infer.predict_cls as predict_cls from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.logging import get_logger -from tools.infer.utility import draw_ocr_box_txt class TextSystem(object): def __init__(self, args): self.text_detector = predict_det.TextDetector(args) self.text_recognizer = predict_rec.TextRecognizer(args) - self.use_angle_cls = args.use_angle_cls - if self.use_angle_cls: - self.text_classifier = predict_cls.TextClassifier(args) def get_rotate_crop_image(self, img, points): ''' @@ -92,15 +88,6 @@ class TextSystem(object): tmp_box = copy.deepcopy(dt_boxes[bno]) img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop_list.append(img_crop) - cv2.imwrite( - '/home/zhoujun20/dygraph/PaddleOCR_rc/inference_results/{}.jpg'. - format(bno), img_crop) - if self.use_angle_cls: - img_crop_list, angle_list, elapse = self.text_classifier( - img_crop_list) - print("cls num : {}, elapse : {}".format( - len(img_crop_list), elapse)) - rec_res, elapse = self.text_recognizer(img_crop_list) print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse)) # self.print_draw_crop_rec_res(img_crop_list, rec_res) @@ -132,7 +119,7 @@ def main(args): image_file_list = get_image_file_list(args.image_dir) text_sys = TextSystem(args) is_visualize = True - font_path = args.vis_font_path + tackle_img_num = 0 for image_file in image_file_list: img, flag = check_and_read_gif(image_file) if not flag: @@ -141,6 +128,9 @@ def main(args): logger.info("error in loading image:{}".format(image_file)) continue starttime = time.time() + tackle_img_num += 1 + if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0: + text_sys = TextSystem(args) dt_boxes, rec_res = text_sys(img) elapse = time.time() - starttime print("Predict time of %s: %.3fs" % (image_file, elapse)) @@ -159,13 +149,8 @@ def main(args): txts = [rec_res[i][0] for i in range(len(rec_res))] scores = [rec_res[i][1] for i in range(len(rec_res))] - draw_img = draw_ocr_box_txt( - image, - boxes, - txts, - scores, - drop_score=drop_score, - font_path=font_path) + draw_img = draw_ocr( + image, boxes, txts, scores, drop_score=drop_score) draw_img_save = "./inference_results/" if not os.path.exists(draw_img_save): os.makedirs(draw_img_save) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 1d8cf22a..5a524516 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -202,12 +202,7 @@ def draw_ocr(image, return image -def draw_ocr_box_txt(image, - boxes, - txts, - scores=None, - drop_score=0.5, - font_path="./doc/simfang.ttf"): +def draw_ocr_box_txt(image, boxes, txts): h, w = image.height, image.width img_left = image.copy() img_right = Image.new('RGB', (w, h), (255, 255, 255)) @@ -217,9 +212,7 @@ def draw_ocr_box_txt(image, random.seed(0) draw_left = ImageDraw.Draw(img_left) draw_right = ImageDraw.Draw(img_right) - for idx, (box, txt) in enumerate(zip(boxes, txts)): - if scores is not None and scores[idx] < drop_score: - continue + for (box, txt) in zip(boxes, txts): color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) draw_left.polygon(box, fill=color) @@ -235,7 +228,8 @@ def draw_ocr_box_txt(image, 1])**2) if box_height > 2 * box_width: font_size = max(int(box_width * 0.9), 10) - font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + font = ImageFont.truetype( + "./doc/simfang.ttf", font_size, encoding="utf-8") cur_y = box[0][1] for c in txt: char_size = font.getsize(c) @@ -244,7 +238,8 @@ def draw_ocr_box_txt(image, cur_y += char_size[1] else: font_size = max(int(box_height * 0.8), 10) - font = ImageFont.truetype(font_path, font_size, encoding="utf-8") + font = ImageFont.truetype( + "./doc/simfang.ttf", font_size, encoding="utf-8") draw_right.text( [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) img_left = Image.blend(image, img_left, 0.5)