diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 265ab592..48b1e025 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -202,6 +202,12 @@ if __name__ == "__main__": count = 0 total_time = 0 draw_img_save = "./inference_results" + # warmup 10 times + if args.warmup: + img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) + for i in range(10): + res = text_detector(img) + if not os.path.exists(draw_img_save): os.makedirs(draw_img_save) for image_file in image_file_list: diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 24388026..c7808e2e 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -254,6 +254,12 @@ def main(args): total_images_num = 0 valid_image_file_list = [] img_list = [] + # warmup 10 times + if args.warmup: + img = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8) + for i in range(10): + res = text_recognizer([img]) + for idx, image_file in enumerate(image_file_list): img, flag = check_and_read_gif(image_file) if not flag: diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 235a075b..d9433ffb 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -147,6 +147,12 @@ def main(args): is_visualize = True font_path = args.vis_font_path drop_score = args.drop_score + # warm up 10 times + if args.warmup: + img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8) + for i in range(10): + res = text_sys(img) + for image_file in image_file_list: img, flag = check_and_read_gif(image_file) if not flag: diff --git a/tools/infer/utility.py b/tools/infer/utility.py index a558f490..38cd6d76 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -105,6 +105,7 @@ def init_args(): parser.add_argument("--enable_mkldnn", type=str2bool, default=False) parser.add_argument("--cpu_threads", type=int, default=10) parser.add_argument("--use_pdserving", type=str2bool, default=False) + parser.add_argument("--warmup", type=str2bool, default=True) parser.add_argument("--use_mp", type=str2bool, default=False) parser.add_argument("--total_process_num", type=int, default=1)