Merge pull request #1299 from WenmuZhou/fix_predict_system

add predict_cls to predict_system
zhoujun 2020-12-06 20:28:37 -06:00 committed by GitHub
commit 99ee41d8db
5 changed files with 60 additions and 41 deletions

View File

@@ -23,7 +23,7 @@ import copy
 import numpy as np
 import math
 import time
+import traceback
 import paddle.fluid as fluid
 import tools.infer.utility as utility
@@ -106,10 +106,10 @@ class TextClassifier(object):
             norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
             self.predictor.run([norm_img_batch])
             prob_out = self.output_tensors[0].copy_to_cpu()
-            cls_res = self.postprocess_op(prob_out)
+            cls_result = self.postprocess_op(prob_out)
             elapse += time.time() - starttime
-            for rno in range(len(cls_res)):
-                label, score = cls_res[rno]
+            for rno in range(len(cls_result)):
+                label, score = cls_result[rno]
                 cls_res[indices[beg_img_no + rno]] = [label, score]
                 if '180' in label and score > self.cls_thresh:
                     img_list[indices[beg_img_no + rno]] = cv2.rotate(
@@ -133,8 +133,8 @@ def main(args):
         img_list.append(img)
     try:
         img_list, cls_res, predict_time = text_classifier(img_list)
-    except Exception as e:
-        print(e)
+    except:
+        logger.info(traceback.format_exc())
         logger.info(
             "ERROR!!!! \n"
             "Please read the FAQhttps://github.com/PaddlePaddle/PaddleOCR#faq \n"
@@ -143,10 +143,10 @@ def main(args):
             "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
         exit()
     for ino in range(len(img_list)):
-        print("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[
-            ino]))
-    print("Total predict time for {} images, cost: {:.3f}".format(
-        len(img_list), predict_time))
+        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[
+            ino]))
+    logger.info("Total predict time for {} images, cost: {:.3f}".format(
+        len(img_list), predict_time))


 if __name__ == "__main__":
     main(utility.parse_args())
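
For context, the rotation rule exercised by the classifier hunk above is simple: a crop is flipped only when the predicted label contains '180' and its score clears cls_thresh. A minimal, self-contained sketch of that rule (maybe_flip and the 0.9 default are illustrative assumptions, not part of this change):

import cv2


def maybe_flip(img, cls_result, cls_thresh=0.9):
    # cls_result is one (label, score) pair, mirroring what
    # postprocess_op returns per crop in the diff above.
    label, score = cls_result
    if '180' in label and score > cls_thresh:
        # Confidently upside-down: rotate the crop by 180 degrees.
        return cv2.rotate(img, cv2.ROTATE_180)
    return img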

View File

@@ -178,11 +178,12 @@ if __name__ == "__main__":
         if count > 0:
             total_time += elapse
         count += 1
-        print("Predict time of {}: {}".format(image_file, elapse))
+        logger.info("Predict time of {}: {}".format(image_file, elapse))
         src_im = utility.draw_text_det_res(dt_boxes, image_file)
         img_name_pure = os.path.split(image_file)[-1]
         img_path = os.path.join(draw_img_save,
                                 "det_res_{}".format(img_name_pure))
         cv2.imwrite(img_path, src_im)
+        logger.info("The visualized image saved in {}".format(img_path))
     if count > 1:
-        print("Avg Time:", total_time / (count - 1))
+        logger.info("Avg Time:", total_time / (count - 1))

View File

@@ -22,7 +22,7 @@ import cv2
 import numpy as np
 import math
 import time
+import traceback
 import paddle.fluid as fluid
 import tools.infer.utility as utility
@@ -135,8 +135,8 @@ def main(args):
         img_list.append(img)
     try:
         rec_res, predict_time = text_recognizer(img_list)
-    except Exception as e:
-        print(e)
+    except:
+        logger.info(traceback.format_exc())
         logger.info(
             "ERROR!!!! \n"
             "Please read the FAQhttps://github.com/PaddlePaddle/PaddleOCR#faq \n"
@@ -145,9 +145,9 @@ def main(args):
             "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
         exit()
     for ino in range(len(img_list)):
-        print("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[
-            ino]))
-    print("Total predict time for {} images, cost: {:.3f}".format(
-        len(img_list), predict_time))
+        logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[
+            ino]))
+    logger.info("Total predict time for {} images, cost: {:.3f}".format(
+        len(img_list), predict_time))
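
Both the classifier and recognizer entry points above now log the full traceback via traceback.format_exc() instead of printing the bare exception object. A rough sketch of that pattern (run_safely is a hypothetical helper used only for illustration; it assumes get_logger() from ppocr.utils.logging can be called with its defaults, as in these scripts):

import traceback

from ppocr.utils.logging import get_logger

logger = get_logger()


def run_safely(fn, *args):
    # Mirror the try/except in main(): on failure, log the whole stack
    # trace rather than only the exception message, then back off.
    try:
        return fn(*args)
    except Exception:
        logger.info(traceback.format_exc())
        return None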

View File

@@ -23,17 +23,21 @@ import numpy as np
 import time
 from PIL import Image
 import tools.infer.utility as utility
-from tools.infer.utility import draw_ocr
 import tools.infer.predict_rec as predict_rec
 import tools.infer.predict_det as predict_det
+import tools.infer.predict_cls as predict_cls
 from ppocr.utils.utility import get_image_file_list, check_and_read_gif
 from ppocr.utils.logging import get_logger
+from tools.infer.utility import draw_ocr_box_txt


 class TextSystem(object):
     def __init__(self, args):
         self.text_detector = predict_det.TextDetector(args)
         self.text_recognizer = predict_rec.TextRecognizer(args)
+        self.use_angle_cls = args.use_angle_cls
+        if self.use_angle_cls:
+            self.text_classifier = predict_cls.TextClassifier(args)

     def get_rotate_crop_image(self, img, points):
         '''
@@ -72,12 +76,12 @@ class TextSystem(object):
         bbox_num = len(img_crop_list)
         for bno in range(bbox_num):
             cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno])
-            print(bno, rec_res[bno])
+            logger.info(bno, rec_res[bno])

     def __call__(self, img):
         ori_im = img.copy()
         dt_boxes, elapse = self.text_detector(img)
-        print("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse))
+        logger.info("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse))
         if dt_boxes is None:
             return None, None
         img_crop_list = []
@@ -88,8 +92,14 @@ class TextSystem(object):
             tmp_box = copy.deepcopy(dt_boxes[bno])
             img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
             img_crop_list.append(img_crop)
+        if self.use_angle_cls:
+            img_crop_list, angle_list, elapse = self.text_classifier(
+                img_crop_list)
+            logger.info("cls num : {}, elapse : {}".format(
+                len(img_crop_list), elapse))
         rec_res, elapse = self.text_recognizer(img_crop_list)
-        print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse))
+        logger.info("rec_res num : {}, elapse : {}".format(len(rec_res), elapse))
         # self.print_draw_crop_rec_res(img_crop_list, rec_res)
         return dt_boxes, rec_res
@@ -119,7 +129,8 @@ def main(args):
     image_file_list = get_image_file_list(args.image_dir)
     text_sys = TextSystem(args)
     is_visualize = True
-    tackle_img_num = 0
+    font_path = args.vis_font_path
+    drop_score = args.drop_score
     for image_file in image_file_list:
         img, flag = check_and_read_gif(image_file)
         if not flag:
@@ -128,20 +139,16 @@ def main(args):
             logger.info("error in loading image:{}".format(image_file))
             continue
         starttime = time.time()
-        tackle_img_num += 1
-        if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0:
-            text_sys = TextSystem(args)
         dt_boxes, rec_res = text_sys(img)
         elapse = time.time() - starttime
-        print("Predict time of %s: %.3fs" % (image_file, elapse))
+        logger.info("Predict time of %s: %.3fs" % (image_file, elapse))

-        drop_score = 0.5
         dt_num = len(dt_boxes)
         for dno in range(dt_num):
             text, score = rec_res[dno]
             if score >= drop_score:
                 text_str = "%s, %.3f" % (text, score)
-                print(text_str)
+                logger.info(text_str)

         if is_visualize:
             image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
@@ -149,15 +156,20 @@ def main(args):
             txts = [rec_res[i][0] for i in range(len(rec_res))]
             scores = [rec_res[i][1] for i in range(len(rec_res))]

-            draw_img = draw_ocr(
-                image, boxes, txts, scores, drop_score=drop_score)
+            draw_img = draw_ocr_box_txt(
+                image,
+                boxes,
+                txts,
+                scores,
+                drop_score=drop_score,
+                font_path=font_path)

             draw_img_save = "./inference_results/"
             if not os.path.exists(draw_img_save):
                 os.makedirs(draw_img_save)
             cv2.imwrite(
                 os.path.join(draw_img_save, os.path.basename(image_file)),
                 draw_img[:, :, ::-1])
-            print("The visualized image saved in {}".format(
+            logger.info("The visualized image saved in {}".format(
                 os.path.join(draw_img_save, os.path.basename(image_file))))
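
Taken together, TextSystem now runs detection, then an optional angle-classification pass over the cropped boxes, then recognition. A condensed sketch of that flow, assuming the three sub-predictors keep the interfaces shown in the diff above (process_image is an illustrative wrapper, not part of this change):

import copy


def process_image(text_sys, img, use_angle_cls):
    # Detection: one quadrilateral box per text region.
    ori_im = img.copy()
    dt_boxes, _ = text_sys.text_detector(img)
    crops = [
        text_sys.get_rotate_crop_image(ori_im, copy.deepcopy(box))
        for box in dt_boxes
    ]
    # Optional direction classification: crops predicted as upside-down
    # come back rotated by 180 degrees.
    if use_angle_cls:
        crops, angle_list, _ = text_sys.text_classifier(crops)
    # Recognition on the (possibly rotated) crops.
    rec_res, _ = text_sys.text_recognizer(crops)
    return dt_boxes, rec_res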

View File

@@ -71,6 +71,7 @@ def parse_args():
     parser.add_argument("--use_space_char", type=str2bool, default=True)
     parser.add_argument(
         "--vis_font_path", type=str, default="./doc/simfang.ttf")
+    parser.add_argument("--drop_score", type=float, default=0.5)

     # params for text classifier
     parser.add_argument("--use_angle_cls", type=str2bool, default=False)
@@ -202,7 +203,12 @@ def draw_ocr(image,
     return image


-def draw_ocr_box_txt(image, boxes, txts):
+def draw_ocr_box_txt(image,
+                     boxes,
+                     txts,
+                     scores=None,
+                     drop_score=0.5,
+                     font_path="./doc/simfang.ttf"):
     h, w = image.height, image.width
     img_left = image.copy()
     img_right = Image.new('RGB', (w, h), (255, 255, 255))
@@ -212,7 +218,9 @@ def draw_ocr_box_txt(image, boxes, txts):
     random.seed(0)
     draw_left = ImageDraw.Draw(img_left)
     draw_right = ImageDraw.Draw(img_right)
-    for (box, txt) in zip(boxes, txts):
+    for idx, (box, txt) in enumerate(zip(boxes, txts)):
+        if scores is not None and scores[idx] < drop_score:
+            continue
         color = (random.randint(0, 255), random.randint(0, 255),
                  random.randint(0, 255))
         draw_left.polygon(box, fill=color)
@@ -222,14 +230,13 @@ def draw_ocr_box_txt(image, boxes, txts):
                 box[2][1], box[3][0], box[3][1]
             ],
             outline=color)
-        box_height = math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][
-            1])**2)
-        box_width = math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][
-            1])**2)
+        box_height = math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][
+            1]) ** 2)
+        box_width = math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][
+            1]) ** 2)
         if box_height > 2 * box_width:
             font_size = max(int(box_width * 0.9), 10)
-            font = ImageFont.truetype(
-                "./doc/simfang.ttf", font_size, encoding="utf-8")
+            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
             cur_y = box[0][1]
             for c in txt:
                 char_size = font.getsize(c)
@@ -238,8 +245,7 @@ def draw_ocr_box_txt(image, boxes, txts):
                 cur_y += char_size[1]
         else:
             font_size = max(int(box_height * 0.8), 10)
-            font = ImageFont.truetype(
-                "./doc/simfang.ttf", font_size, encoding="utf-8")
+            font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
             draw_right.text(
                 [box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
     img_left = Image.blend(image, img_left, 0.5)
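
With the new signature, draw_ocr_box_txt can filter low-confidence results and load a configurable font. A small usage sketch, assuming it is run from the repository root so ./doc/simfang.ttf resolves; the canvas, box, text, and score below are made-up placeholders:

from PIL import Image

from tools.infer.utility import draw_ocr_box_txt

# Placeholder inputs: a blank canvas and one hand-made quadrilateral box.
image = Image.new('RGB', (400, 200), (255, 255, 255))
boxes = [[(10, 10), (200, 10), (200, 50), (10, 50)]]
txts = ["hello"]
scores = [0.93]

# Results whose score falls below drop_score are skipped entirely, so
# low-confidence predictions never reach the visualization.
draw_img = draw_ocr_box_txt(
    image, boxes, txts, scores,
    drop_score=0.5, font_path="./doc/simfang.ttf")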