From fffd556cab25c0839a396c935fd8122c83b04f04 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Fri, 3 Sep 2021 12:09:50 +0000 Subject: [PATCH] fix distill model predict --- tools/infer_det.py | 55 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/tools/infer_det.py b/tools/infer_det.py index a964cd28..ce16da8d 100755 --- a/tools/infer_det.py +++ b/tools/infer_det.py @@ -34,23 +34,21 @@ import paddle from ppocr.data import create_operators, transform from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import init_model +from ppocr.utils.save_load import init_model, load_dygraph_params from ppocr.utils.utility import get_image_file_list import tools.program as program -def draw_det_res(dt_boxes, config, img, img_name): +def draw_det_res(dt_boxes, config, img, img_name, save_path): if len(dt_boxes) > 0: import cv2 src_im = img for box in dt_boxes: box = box.astype(np.int32).reshape((-1, 1, 2)) cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) - save_det_path = os.path.dirname(config['Global'][ - 'save_res_path']) + "/det_results/" - if not os.path.exists(save_det_path): - os.makedirs(save_det_path) - save_path = os.path.join(save_det_path, os.path.basename(img_name)) + if not os.path.exists(save_path): + os.makedirs(save_path) + save_path = os.path.join(save_path, os.path.basename(img_name)) cv2.imwrite(save_path, src_im) logger.info("The detected Image saved in {}".format(save_path)) @@ -61,8 +59,7 @@ def main(): # build model model = build_model(config['Architecture']) - init_model(config, model) - + _ = load_dygraph_params(config, model, logger, None) # build post process post_process_class = build_post_process(config['PostProcess']) @@ -96,17 +93,41 @@ def main(): images = paddle.to_tensor(images) preds = model(images) post_result = post_process_class(preds, shape_list) - boxes = post_result[0]['points'] - # write result + + src_img = cv2.imread(file) + dt_boxes_json = [] - for box in boxes: - tmp_json = {"transcription": ""} - tmp_json['points'] = box.tolist() - dt_boxes_json.append(tmp_json) + # parser boxes if post_result is dict + if isinstance(post_result, dict): + det_box_json = {} + for k in post_result.keys(): + boxes = post_result[k][0]['points'] + dt_boxes_list = [] + for box in boxes: + tmp_json = {"transcription": ""} + tmp_json['points'] = box.tolist() + dt_boxes_list.append(tmp_json) + det_box_json[k] = dt_boxes_list + save_det_path = os.path.dirname(config['Global'][ + 'save_res_path']) + "/det_results_{}/".format(k) + draw_det_res(boxes, config, src_img, file, save_det_path) + else: + boxes = post_result[0]['points'] + dt_boxes_json = [] + # write result + for box in boxes: + tmp_json = {"transcription": ""} + tmp_json['points'] = box.tolist() + dt_boxes_json.append(tmp_json) + save_det_path = os.path.dirname(config['Global'][ + 'save_res_path']) + "/det_results/" + draw_det_res(boxes, config, src_img, file, save_det_path) otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n" fout.write(otstr.encode()) - src_img = cv2.imread(file) - draw_det_res(boxes, config, src_img, file) + + save_det_path = os.path.dirname(config['Global'][ + 'save_res_path']) + "/det_results/" + draw_det_res(boxes, config, src_img, file, save_det_path) logger.info("success!")