Merge pull request #1299 from WenmuZhou/fix_predict_system
add predict_cls to predict_system
This commit is contained in:
commit
99ee41d8db
|
@ -23,7 +23,7 @@ import copy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
import paddle.fluid as fluid
|
import paddle.fluid as fluid
|
||||||
|
|
||||||
import tools.infer.utility as utility
|
import tools.infer.utility as utility
|
||||||
|
@ -106,10 +106,10 @@ class TextClassifier(object):
|
||||||
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
|
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
|
||||||
self.predictor.run([norm_img_batch])
|
self.predictor.run([norm_img_batch])
|
||||||
prob_out = self.output_tensors[0].copy_to_cpu()
|
prob_out = self.output_tensors[0].copy_to_cpu()
|
||||||
cls_res = self.postprocess_op(prob_out)
|
cls_result = self.postprocess_op(prob_out)
|
||||||
elapse += time.time() - starttime
|
elapse += time.time() - starttime
|
||||||
for rno in range(len(cls_res)):
|
for rno in range(len(cls_result)):
|
||||||
label, score = cls_res[rno]
|
label, score = cls_result[rno]
|
||||||
cls_res[indices[beg_img_no + rno]] = [label, score]
|
cls_res[indices[beg_img_no + rno]] = [label, score]
|
||||||
if '180' in label and score > self.cls_thresh:
|
if '180' in label and score > self.cls_thresh:
|
||||||
img_list[indices[beg_img_no + rno]] = cv2.rotate(
|
img_list[indices[beg_img_no + rno]] = cv2.rotate(
|
||||||
|
@ -133,8 +133,8 @@ def main(args):
|
||||||
img_list.append(img)
|
img_list.append(img)
|
||||||
try:
|
try:
|
||||||
img_list, cls_res, predict_time = text_classifier(img_list)
|
img_list, cls_res, predict_time = text_classifier(img_list)
|
||||||
except Exception as e:
|
except:
|
||||||
print(e)
|
logger.info(traceback.format_exc())
|
||||||
logger.info(
|
logger.info(
|
||||||
"ERROR!!!! \n"
|
"ERROR!!!! \n"
|
||||||
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
|
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
|
||||||
|
@ -143,10 +143,10 @@ def main(args):
|
||||||
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
|
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
|
||||||
exit()
|
exit()
|
||||||
for ino in range(len(img_list)):
|
for ino in range(len(img_list)):
|
||||||
print("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[
|
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[
|
||||||
ino]))
|
ino]))
|
||||||
print("Total predict time for {} images, cost: {:.3f}".format(
|
logger.info("Total predict time for {} images, cost: {:.3f}".format(
|
||||||
len(img_list), predict_time))
|
len(img_list), predict_time))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main(utility.parse_args())
|
main(utility.parse_args())
|
||||||
|
|
|
@ -178,11 +178,12 @@ if __name__ == "__main__":
|
||||||
if count > 0:
|
if count > 0:
|
||||||
total_time += elapse
|
total_time += elapse
|
||||||
count += 1
|
count += 1
|
||||||
print("Predict time of {}: {}".format(image_file, elapse))
|
logger.info("Predict time of {}: {}".format(image_file, elapse))
|
||||||
src_im = utility.draw_text_det_res(dt_boxes, image_file)
|
src_im = utility.draw_text_det_res(dt_boxes, image_file)
|
||||||
img_name_pure = os.path.split(image_file)[-1]
|
img_name_pure = os.path.split(image_file)[-1]
|
||||||
img_path = os.path.join(draw_img_save,
|
img_path = os.path.join(draw_img_save,
|
||||||
"det_res_{}".format(img_name_pure))
|
"det_res_{}".format(img_name_pure))
|
||||||
cv2.imwrite(img_path, src_im)
|
cv2.imwrite(img_path, src_im)
|
||||||
|
logger.info("The visualized image saved in {}".format(img_path))
|
||||||
if count > 1:
|
if count > 1:
|
||||||
print("Avg Time:", total_time / (count - 1))
|
logger.info("Avg Time:", total_time / (count - 1))
|
||||||
|
|
|
@ -22,7 +22,7 @@ import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
import paddle.fluid as fluid
|
import paddle.fluid as fluid
|
||||||
|
|
||||||
import tools.infer.utility as utility
|
import tools.infer.utility as utility
|
||||||
|
@ -135,8 +135,8 @@ def main(args):
|
||||||
img_list.append(img)
|
img_list.append(img)
|
||||||
try:
|
try:
|
||||||
rec_res, predict_time = text_recognizer(img_list)
|
rec_res, predict_time = text_recognizer(img_list)
|
||||||
except Exception as e:
|
except:
|
||||||
print(e)
|
logger.info(traceback.format_exc())
|
||||||
logger.info(
|
logger.info(
|
||||||
"ERROR!!!! \n"
|
"ERROR!!!! \n"
|
||||||
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
|
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
|
||||||
|
@ -145,9 +145,9 @@ def main(args):
|
||||||
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
|
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
|
||||||
exit()
|
exit()
|
||||||
for ino in range(len(img_list)):
|
for ino in range(len(img_list)):
|
||||||
print("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[
|
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[
|
||||||
ino]))
|
ino]))
|
||||||
print("Total predict time for {} images, cost: {:.3f}".format(
|
logger.info("Total predict time for {} images, cost: {:.3f}".format(
|
||||||
len(img_list), predict_time))
|
len(img_list), predict_time))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -23,17 +23,21 @@ import numpy as np
|
||||||
import time
|
import time
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import tools.infer.utility as utility
|
import tools.infer.utility as utility
|
||||||
from tools.infer.utility import draw_ocr
|
|
||||||
import tools.infer.predict_rec as predict_rec
|
import tools.infer.predict_rec as predict_rec
|
||||||
import tools.infer.predict_det as predict_det
|
import tools.infer.predict_det as predict_det
|
||||||
|
import tools.infer.predict_cls as predict_cls
|
||||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||||
from ppocr.utils.logging import get_logger
|
from ppocr.utils.logging import get_logger
|
||||||
|
from tools.infer.utility import draw_ocr_box_txt
|
||||||
|
|
||||||
|
|
||||||
class TextSystem(object):
|
class TextSystem(object):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
self.text_detector = predict_det.TextDetector(args)
|
self.text_detector = predict_det.TextDetector(args)
|
||||||
self.text_recognizer = predict_rec.TextRecognizer(args)
|
self.text_recognizer = predict_rec.TextRecognizer(args)
|
||||||
|
self.use_angle_cls = args.use_angle_cls
|
||||||
|
if self.use_angle_cls:
|
||||||
|
self.text_classifier = predict_cls.TextClassifier(args)
|
||||||
|
|
||||||
def get_rotate_crop_image(self, img, points):
|
def get_rotate_crop_image(self, img, points):
|
||||||
'''
|
'''
|
||||||
|
@ -72,12 +76,12 @@ class TextSystem(object):
|
||||||
bbox_num = len(img_crop_list)
|
bbox_num = len(img_crop_list)
|
||||||
for bno in range(bbox_num):
|
for bno in range(bbox_num):
|
||||||
cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno])
|
cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno])
|
||||||
print(bno, rec_res[bno])
|
logger.info(bno, rec_res[bno])
|
||||||
|
|
||||||
def __call__(self, img):
|
def __call__(self, img):
|
||||||
ori_im = img.copy()
|
ori_im = img.copy()
|
||||||
dt_boxes, elapse = self.text_detector(img)
|
dt_boxes, elapse = self.text_detector(img)
|
||||||
print("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse))
|
logger.info("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse))
|
||||||
if dt_boxes is None:
|
if dt_boxes is None:
|
||||||
return None, None
|
return None, None
|
||||||
img_crop_list = []
|
img_crop_list = []
|
||||||
|
@ -88,8 +92,14 @@ class TextSystem(object):
|
||||||
tmp_box = copy.deepcopy(dt_boxes[bno])
|
tmp_box = copy.deepcopy(dt_boxes[bno])
|
||||||
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
|
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
|
||||||
img_crop_list.append(img_crop)
|
img_crop_list.append(img_crop)
|
||||||
|
if self.use_angle_cls:
|
||||||
|
img_crop_list, angle_list, elapse = self.text_classifier(
|
||||||
|
img_crop_list)
|
||||||
|
logger.info("cls num : {}, elapse : {}".format(
|
||||||
|
len(img_crop_list), elapse))
|
||||||
|
|
||||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||||||
print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse))
|
logger.info("rec_res num : {}, elapse : {}".format(len(rec_res), elapse))
|
||||||
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
|
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
|
||||||
return dt_boxes, rec_res
|
return dt_boxes, rec_res
|
||||||
|
|
||||||
|
@ -119,7 +129,8 @@ def main(args):
|
||||||
image_file_list = get_image_file_list(args.image_dir)
|
image_file_list = get_image_file_list(args.image_dir)
|
||||||
text_sys = TextSystem(args)
|
text_sys = TextSystem(args)
|
||||||
is_visualize = True
|
is_visualize = True
|
||||||
tackle_img_num = 0
|
font_path = args.vis_font_path
|
||||||
|
drop_score = args.drop_score
|
||||||
for image_file in image_file_list:
|
for image_file in image_file_list:
|
||||||
img, flag = check_and_read_gif(image_file)
|
img, flag = check_and_read_gif(image_file)
|
||||||
if not flag:
|
if not flag:
|
||||||
|
@ -128,20 +139,16 @@ def main(args):
|
||||||
logger.info("error in loading image:{}".format(image_file))
|
logger.info("error in loading image:{}".format(image_file))
|
||||||
continue
|
continue
|
||||||
starttime = time.time()
|
starttime = time.time()
|
||||||
tackle_img_num += 1
|
|
||||||
if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0:
|
|
||||||
text_sys = TextSystem(args)
|
|
||||||
dt_boxes, rec_res = text_sys(img)
|
dt_boxes, rec_res = text_sys(img)
|
||||||
elapse = time.time() - starttime
|
elapse = time.time() - starttime
|
||||||
print("Predict time of %s: %.3fs" % (image_file, elapse))
|
logger.info("Predict time of %s: %.3fs" % (image_file, elapse))
|
||||||
|
|
||||||
drop_score = 0.5
|
|
||||||
dt_num = len(dt_boxes)
|
dt_num = len(dt_boxes)
|
||||||
for dno in range(dt_num):
|
for dno in range(dt_num):
|
||||||
text, score = rec_res[dno]
|
text, score = rec_res[dno]
|
||||||
if score >= drop_score:
|
if score >= drop_score:
|
||||||
text_str = "%s, %.3f" % (text, score)
|
text_str = "%s, %.3f" % (text, score)
|
||||||
print(text_str)
|
logger.info(text_str)
|
||||||
|
|
||||||
if is_visualize:
|
if is_visualize:
|
||||||
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
||||||
|
@ -149,15 +156,20 @@ def main(args):
|
||||||
txts = [rec_res[i][0] for i in range(len(rec_res))]
|
txts = [rec_res[i][0] for i in range(len(rec_res))]
|
||||||
scores = [rec_res[i][1] for i in range(len(rec_res))]
|
scores = [rec_res[i][1] for i in range(len(rec_res))]
|
||||||
|
|
||||||
draw_img = draw_ocr(
|
draw_img = draw_ocr_box_txt(
|
||||||
image, boxes, txts, scores, drop_score=drop_score)
|
image,
|
||||||
|
boxes,
|
||||||
|
txts,
|
||||||
|
scores,
|
||||||
|
drop_score=drop_score,
|
||||||
|
font_path=font_path)
|
||||||
draw_img_save = "./inference_results/"
|
draw_img_save = "./inference_results/"
|
||||||
if not os.path.exists(draw_img_save):
|
if not os.path.exists(draw_img_save):
|
||||||
os.makedirs(draw_img_save)
|
os.makedirs(draw_img_save)
|
||||||
cv2.imwrite(
|
cv2.imwrite(
|
||||||
os.path.join(draw_img_save, os.path.basename(image_file)),
|
os.path.join(draw_img_save, os.path.basename(image_file)),
|
||||||
draw_img[:, :, ::-1])
|
draw_img[:, :, ::-1])
|
||||||
print("The visualized image saved in {}".format(
|
logger.info("The visualized image saved in {}".format(
|
||||||
os.path.join(draw_img_save, os.path.basename(image_file))))
|
os.path.join(draw_img_save, os.path.basename(image_file))))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -71,6 +71,7 @@ def parse_args():
|
||||||
parser.add_argument("--use_space_char", type=str2bool, default=True)
|
parser.add_argument("--use_space_char", type=str2bool, default=True)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vis_font_path", type=str, default="./doc/simfang.ttf")
|
"--vis_font_path", type=str, default="./doc/simfang.ttf")
|
||||||
|
parser.add_argument("--drop_score", type=float, default=0.5)
|
||||||
|
|
||||||
# params for text classifier
|
# params for text classifier
|
||||||
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
|
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
|
||||||
|
@ -202,7 +203,12 @@ def draw_ocr(image,
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
def draw_ocr_box_txt(image, boxes, txts):
|
def draw_ocr_box_txt(image,
|
||||||
|
boxes,
|
||||||
|
txts,
|
||||||
|
scores=None,
|
||||||
|
drop_score=0.5,
|
||||||
|
font_path="./doc/simfang.ttf"):
|
||||||
h, w = image.height, image.width
|
h, w = image.height, image.width
|
||||||
img_left = image.copy()
|
img_left = image.copy()
|
||||||
img_right = Image.new('RGB', (w, h), (255, 255, 255))
|
img_right = Image.new('RGB', (w, h), (255, 255, 255))
|
||||||
|
@ -212,7 +218,9 @@ def draw_ocr_box_txt(image, boxes, txts):
|
||||||
random.seed(0)
|
random.seed(0)
|
||||||
draw_left = ImageDraw.Draw(img_left)
|
draw_left = ImageDraw.Draw(img_left)
|
||||||
draw_right = ImageDraw.Draw(img_right)
|
draw_right = ImageDraw.Draw(img_right)
|
||||||
for (box, txt) in zip(boxes, txts):
|
for idx, (box, txt) in enumerate(zip(boxes, txts)):
|
||||||
|
if scores is not None and scores[idx] < drop_score:
|
||||||
|
continue
|
||||||
color = (random.randint(0, 255), random.randint(0, 255),
|
color = (random.randint(0, 255), random.randint(0, 255),
|
||||||
random.randint(0, 255))
|
random.randint(0, 255))
|
||||||
draw_left.polygon(box, fill=color)
|
draw_left.polygon(box, fill=color)
|
||||||
|
@ -222,14 +230,13 @@ def draw_ocr_box_txt(image, boxes, txts):
|
||||||
box[2][1], box[3][0], box[3][1]
|
box[2][1], box[3][0], box[3][1]
|
||||||
],
|
],
|
||||||
outline=color)
|
outline=color)
|
||||||
box_height = math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][
|
box_height = math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][
|
||||||
1])**2)
|
1]) ** 2)
|
||||||
box_width = math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][
|
box_width = math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][
|
||||||
1])**2)
|
1]) ** 2)
|
||||||
if box_height > 2 * box_width:
|
if box_height > 2 * box_width:
|
||||||
font_size = max(int(box_width * 0.9), 10)
|
font_size = max(int(box_width * 0.9), 10)
|
||||||
font = ImageFont.truetype(
|
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
|
||||||
"./doc/simfang.ttf", font_size, encoding="utf-8")
|
|
||||||
cur_y = box[0][1]
|
cur_y = box[0][1]
|
||||||
for c in txt:
|
for c in txt:
|
||||||
char_size = font.getsize(c)
|
char_size = font.getsize(c)
|
||||||
|
@ -238,8 +245,7 @@ def draw_ocr_box_txt(image, boxes, txts):
|
||||||
cur_y += char_size[1]
|
cur_y += char_size[1]
|
||||||
else:
|
else:
|
||||||
font_size = max(int(box_height * 0.8), 10)
|
font_size = max(int(box_height * 0.8), 10)
|
||||||
font = ImageFont.truetype(
|
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
|
||||||
"./doc/simfang.ttf", font_size, encoding="utf-8")
|
|
||||||
draw_right.text(
|
draw_right.text(
|
||||||
[box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
|
[box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
|
||||||
img_left = Image.blend(image, img_left, 0.5)
|
img_left = Image.blend(image, img_left, 0.5)
|
||||||
|
|
Loading…
Reference in New Issue