From 4402e62959babac3fd35cfd89d7af63a55238ab2 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Mon, 9 Nov 2020 18:19:30 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=AD=A3export=5Fmodel=E9=87=8C?= =?UTF-8?q?=E7=9A=84bug=EF=BC=8C=E6=B7=BB=E5=8A=A0predict=5Fdet?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/export_model.py | 24 +++++++++++++++--------- tools/infer/predict_det.py | 28 +++++++++++++++++++--------- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/tools/export_model.py b/tools/export_model.py index 60c05725..cf568884 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -12,6 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + import argparse import paddle @@ -20,14 +27,11 @@ from paddle.jit import to_static from ppocr.modeling.architectures import build_model from ppocr.postprocess import build_post_process from ppocr.utils.save_load import init_model +from ppocr.utils.logging import get_logger from tools.program import load_config -from tools.program import merge_config def parse_args(): - def str2bool(v): - return v.lower() in ("true", "t", "1") - parser = argparse.ArgumentParser() parser.add_argument("-c", "--config", help="configuration file to use") parser.add_argument( @@ -43,7 +47,7 @@ class Model(paddle.nn.Layer): # Please modify the 'shape' according to actual needs @to_static(input_spec=[ paddle.static.InputSpec( - shape=[None, 3, 32, None], dtype='float32') + shape=[None, 3, 640, 640], dtype='float32') ]) def forward(self, inputs): x = self.pre_model(inputs) @@ -53,14 +57,13 @@ class Model(paddle.nn.Layer): def main(): FLAGS = parse_args() config = load_config(FLAGS.config) - merge_config(FLAGS.opt) - + logger = get_logger() # build post process post_process_class = build_post_process(config['PostProcess'], config['Global']) # build model - #for rec algorithm + # for rec algorithm if hasattr(post_process_class, 'character'): char_num = len(getattr(post_process_class, 'character')) config['Architecture']["Head"]['out_channels'] = char_num @@ -69,7 +72,10 @@ def main(): model.eval() model = Model(model) - paddle.jit.save(model, FLAGS.output_path) + save_path = '{}/{}'.format(FLAGS.output_path, + config['Architecture']['model_type']) + paddle.jit.save(model, save_path) + logger.info('inference model is saved to {}'.format(save_path)) if __name__ == "__main__": diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 561627af..a3850028 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -22,7 +22,6 @@ import cv2 import numpy as np import time import sys - import paddle import tools.infer.utility as utility @@ -39,7 +38,7 @@ class TextDetector(object): postprocess_params = {} if self.det_algorithm == "DB": pre_process_list = [{ - 'ResizeForTest': { + 'DetResizeForTest': { 'limit_side_len': args.det_limit_side_len, 'limit_type': args.det_limit_type } @@ -53,7 +52,7 @@ class TextDetector(object): }, { 'ToCHWImage': None }, { - 'keepKeys': { + 'KeepKeys': { 'keep_keys': ['image', 'shape'] } }] @@ -68,8 +67,9 @@ class TextDetector(object): self.preprocess_op = create_operators(pre_process_list) self.postprocess_op = build_post_process(postprocess_params) - self.predictor = paddle.jit.load(args.det_model_dir) - self.predictor.eval() + self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor( + args, 'det', logger) # paddle.jit.load(args.det_model_dir) + # self.predictor.eval() def order_points_clockwise(self, pts): """ @@ -133,11 +133,23 @@ class TextDetector(object): return None, 0 img = np.expand_dims(img, axis=0) shape_list = np.expand_dims(shape_list, axis=0) + img = img.copy() starttime = time.time() - preds = self.predictor(img) - post_result = self.postprocess_op(preds, shape_list) + if self.use_zero_copy_run: + self.input_tensor.copy_from_cpu(img) + self.predictor.zero_copy_run() + else: + im = paddle.fluid.core.PaddleTensor(img) + self.predictor.run([im]) + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + preds = outputs[0] + # preds = self.predictor(img) + post_result = self.postprocess_op(preds, shape_list) dt_boxes = post_result[0]['points'] dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape) elapse = time.time() - starttime @@ -146,8 +158,6 @@ class TextDetector(object): if __name__ == "__main__": args = utility.parse_args() - place = paddle.CPUPlace() - paddle.disable_static(place) image_file_list = get_image_file_list(args.image_dir) logger = get_logger()