diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index b96a9fab..06625eaf 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -13,9 +13,9 @@
 # limitations under the License.
 import os
 import sys
-__dir__ = os.path.dirname(__file__)
+__dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
-sys.path.append(os.path.join(__dir__, '../..'))
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
 
 import tools.infer.utility as utility
 from ppocr.utils.utility import initial_logger
@@ -33,14 +33,12 @@ class TextRecognizer(object):
     def __init__(self, args):
         self.predictor, self.input_tensor, self.output_tensors =\
             utility.create_predictor(args, mode="rec")
-        image_shape = [int(v) for v in args.rec_image_shape.split(",")]
-        self.rec_image_shape = image_shape
+        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
         self.character_type = args.rec_char_type
         self.rec_batch_num = args.rec_batch_num
         self.rec_algorithm = args.rec_algorithm
-        char_ops_params = {}
-        char_ops_params["character_type"] = args.rec_char_type
-        char_ops_params["character_dict_path"] = args.rec_char_dict_path
+        char_ops_params = {"character_type": args.rec_char_type,
+                           "character_dict_path": args.rec_char_dict_path}
         if self.rec_algorithm != "RARE":
             char_ops_params['loss_type'] = 'ctc'
             self.loss_type = 'ctc'
@@ -51,16 +49,11 @@ class TextRecognizer(object):
 
     def resize_norm_img(self, img, max_wh_ratio):
         imgC, imgH, imgW = self.rec_image_shape
-        if self.character_type == "ch":
-            imgW = int(32 * max_wh_ratio)
-        h = img.shape[0]
-        w = img.shape[1]
-        ratio = w / float(h)
-        if math.ceil(imgH * ratio) > imgW:
-            resized_w = imgW
-        else:
-            resized_w = int(math.ceil(imgH * ratio))
-        resized_image = cv2.resize(img, (resized_w, imgH))
+        assert imgC == img.shape[2]
+        imgW = int(math.ceil(32 * max_wh_ratio))
+        h, w = img.shape[:2]
+        resized_w = int(math.ceil(imgH * w / float(h)))
+        resized_image = cv2.resize(img, (resized_w, imgH), interpolation=cv2.INTER_CUBIC)
         resized_image = resized_image.astype('float32')
         resized_image = resized_image.transpose((2, 0, 1)) / 255
         resized_image -= 0.5
@@ -71,7 +64,15 @@ class TextRecognizer(object):
 
     def __call__(self, img_list):
         img_num = len(img_list)
-        rec_res = []
+        # Calculate the aspect ratio of all text crops
+        width_list = []
+        for img in img_list:
+            width_list.append(img.shape[1] / float(img.shape[0]))
+        # When there are many text boxes whose lengths differ a lot, sorting before batching can significantly speed up recognition
+        indices = np.argsort(np.array(width_list))
+
+        # rec_res = []
+        rec_res = [['', 0.0]] * img_num
         batch_num = self.rec_batch_num
         predict_time = 0
         for beg_img_no in range(0, img_num, batch_num):
@@ -80,10 +81,12 @@
             max_wh_ratio = 0
             for ino in range(beg_img_no, end_img_no):
-                h, w = img_list[ino].shape[0:2]
+                # h, w = img_list[ino].shape[0:2]
+                h, w = img_list[indices[ino]].shape[0:2]
                 wh_ratio = w * 1.0 / h
                 max_wh_ratio = max(max_wh_ratio, wh_ratio)
             for ino in range(beg_img_no, end_img_no):
-                norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
+                # norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
+                norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
                 norm_img = norm_img[np.newaxis, :]
                 norm_img_batch.append(norm_img)
             norm_img_batch = np.concatenate(norm_img_batch)
@@ -111,7 +114,8 @@ class TextRecognizer(object):
                     blank = probs.shape[1]
                     valid_ind = np.where(ind != (blank - 1))[0]
                     score = np.mean(probs[valid_ind, ind[valid_ind]])
-                    rec_res.append([preds_text, score])
+                    # rec_res.append([preds_text, score])
+                    rec_res[indices[beg_img_no + rno]] = [preds_text, score]
             else:
                 rec_idx_batch = self.output_tensors[0].copy_to_cpu()
                 predict_batch = self.output_tensors[1].copy_to_cpu()
@@ -126,19 +130,19 @@ class TextRecognizer(object):
                         preds = rec_idx_batch[rno, 1:end_pos[1]]
                         score = np.mean(predict_batch[rno, 1:end_pos[1]])
                     preds_text = self.char_ops.decode(preds)
-                    rec_res.append([preds_text, score])
+                    # rec_res.append([preds_text, score])
+                    rec_res[indices[beg_img_no + rno]] = [preds_text, score]
         return rec_res, predict_time
 
 
-if __name__ == "__main__":
-    args = utility.parse_args()
+def main(args):
     image_file_list = get_image_file_list(args.image_dir)
     text_recognizer = TextRecognizer(args)
     valid_image_file_list = []
     img_list = []
     for image_file in image_file_list:
-        img = cv2.imread(image_file)
+        img = cv2.imread(image_file, cv2.IMREAD_COLOR)
         if img is None:
             logger.info("error in loading image:{}".format(image_file))
             continue
@@ -159,3 +163,7 @@ if __name__ == "__main__":
         print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
     print("Total predict time for %d images:%.3f" % (len(img_list),
                                                       predict_time))
+
+
+if __name__ == "__main__":
+    main(utility.parse_args())
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index 8d075502..e96a1934 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -75,6 +75,7 @@ class TextSystem(object):
     def __call__(self, img):
         ori_im = img.copy()
         dt_boxes, elapse = self.text_detector(img)
+        print("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse))
        if dt_boxes is None:
             return None, None
         img_crop_list = []
@@ -86,6 +87,7 @@
             img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
             img_crop_list.append(img_crop)
         rec_res, elapse = self.text_recognizer(img_crop_list)
+        print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse))
         # self.print_draw_crop_rec_res(img_crop_list, rec_res)
         return dt_boxes, rec_res
 
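Note on the batching change: the core of the `predict_rec.py` patch is processing the text crops in ascending width/height order so that each batch pads to a similar target width, then writing each result back to its original slot via the argsort indices. A minimal standalone sketch of that pattern (not part of the patch; `recognize_batch` is a hypothetical stand-in for the predictor call on one normalized batch):

```python
import numpy as np

def batched_in_sorted_order(crops, batch_size, recognize_batch):
    """crops: list of HxWxC ndarrays; recognize_batch: callable taking a list of crops."""
    # Aspect ratio (width / height) of every text crop.
    wh_ratios = np.array([c.shape[1] / float(c.shape[0]) for c in crops])
    indices = np.argsort(wh_ratios)  # group crops of similar width into the same batch
    results = [None] * len(crops)
    for beg in range(0, len(crops), batch_size):
        batch_ids = indices[beg:beg + batch_size]
        batch_out = recognize_batch([crops[i] for i in batch_ids])
        for i, out in zip(batch_ids, batch_out):
            results[i] = out  # write each result back to its original position
    return results
```

Because every batch then shares a similar `max_wh_ratio`, `resize_norm_img` does far less padding, which is where the speedup claimed in the patch comment comes from.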
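For reference, the CTC confidence score that the patch keeps storing into `rec_res` (visible in the context lines of the hunk at old line 111) is a greedy per-timestep score. A small sketch of just that computation, assuming `probs` is the (T, num_classes) softmax output with the blank as the last class; the empty-prediction guard is an addition here, not in the patch:

```python
import numpy as np

def ctc_greedy_score(probs):
    """probs: (T, num_classes) per-timestep softmax output; blank is the last class."""
    ind = np.argmax(probs, axis=1)               # greedy label per timestep
    blank = probs.shape[1]
    valid_ind = np.where(ind != (blank - 1))[0]  # keep timesteps that do not predict blank
    if valid_ind.size == 0:
        return 0.0                               # guard against an all-blank prediction
    return float(np.mean(probs[valid_ind, ind[valid_ind]]))
```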