From 8e05ffed7e43057554437925e4a5d061a84eaef5 Mon Sep 17 00:00:00 2001 From: dyning Date: Mon, 13 Jul 2020 17:25:30 +0800 Subject: [PATCH] move out visulization from hubserving --- deploy/hubserving/ocr_det/config.json | 1 - deploy/hubserving/ocr_det/module.py | 38 +++---- deploy/hubserving/ocr_rec/module.py | 16 ++- deploy/hubserving/ocr_system/config.json | 1 - deploy/hubserving/ocr_system/module.py | 59 +++-------- doc/doc_ch/serving.md | 46 ++++----- tools/infer/predict_system.py | 4 - tools/infer/utility.py | 2 +- tools/test_hubserving.py | 121 ++++++++++++++++++++--- 9 files changed, 168 insertions(+), 120 deletions(-) diff --git a/deploy/hubserving/ocr_det/config.json b/deploy/hubserving/ocr_det/config.json index 9f6fd50f..c8ef055e 100644 --- a/deploy/hubserving/ocr_det/config.json +++ b/deploy/hubserving/ocr_det/config.json @@ -6,7 +6,6 @@ "use_gpu": true }, "predict_args": { - "visualization": false } } }, diff --git a/deploy/hubserving/ocr_det/module.py b/deploy/hubserving/ocr_det/module.py index 6b7bafb8..e5fac23f 100644 --- a/deploy/hubserving/ocr_det/module.py +++ b/deploy/hubserving/ocr_det/module.py @@ -19,7 +19,7 @@ import numpy as np import paddle.fluid as fluid import paddlehub as hub -from tools.infer.utility import draw_boxes, base64_to_cv2 +from tools.infer.utility import base64_to_cv2 from tools.infer.predict_det import TextDetector @@ -68,16 +68,12 @@ class OCRDet(hub.Module): def predict(self, images=[], - paths=[], - draw_img_save='ocr_det_result', - visualization=False): + paths=[]): """ Get the text box in the predicted images. Args: images (list(numpy.ndarray)): images data, shape of each is [H, W, C]. If images not paths paths (list[str]): The paths of images. If paths not images - draw_img_save (str): The directory to store output images. - visualization (bool): Whether to save image or not. Returns: res (list): The result of text detection box and save path of images. """ @@ -93,29 +89,21 @@ class OCRDet(hub.Module): all_results = [] for img in predicted_data: - result = {'save_path': ''} if img is None: logger.info("error in loading image") - result['data'] = [] - all_results.append(result) + all_results.append([]) continue dt_boxes, elapse = self.text_detector(img) - print("Predict time : ", elapse) - result['data'] = dt_boxes.astype(np.int).tolist() + logger.info("Predict time : {}".format(elapse)) - if visualization: - image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) - draw_img = draw_boxes(image, dt_boxes) - draw_img = np.array(draw_img) - if not os.path.exists(draw_img_save): - os.makedirs(draw_img_save) - saved_name = 'ndarray_{}.jpg'.format(time.time()) - save_file_path = os.path.join(draw_img_save, saved_name) - cv2.imwrite(save_file_path, draw_img[:, :, ::-1]) - print("The visualized image saved in {}".format(save_file_path)) - result['save_path'] = save_file_path - - all_results.append(result) + rec_res_final = [] + for dno in range(len(dt_boxes)): + rec_res_final.append( + { + 'text_region': dt_boxes[dno].astype(np.int).tolist() + } + ) + all_results.append(rec_res_final) return all_results @serving @@ -134,5 +122,5 @@ if __name__ == '__main__': './doc/imgs/11.jpg', './doc/imgs/12.jpg', ] - res = ocr.predict(paths=image_path, visualization=True) + res = ocr.predict(paths=image_path) print(res) \ No newline at end of file diff --git a/deploy/hubserving/ocr_rec/module.py b/deploy/hubserving/ocr_rec/module.py index 77a907d6..cf612dd7 100644 --- a/deploy/hubserving/ocr_rec/module.py +++ b/deploy/hubserving/ocr_rec/module.py @@ -92,12 +92,24 @@ class OCRRec(hub.Module): if img is None: continue img_list.append(img) + + rec_res_final = [] try: rec_res, predict_time = self.text_recognizer(img_list) + for dno in range(len(rec_res)): + text, score = rec_res[dno] + rec_res_final.append( + { + 'text': text, + 'confidence': float(score), + } + ) except Exception as e: print(e) - return [] - return rec_res + return [[]] + + return [rec_res_final] + @serving def serving_method(self, images, **kwargs): diff --git a/deploy/hubserving/ocr_system/config.json b/deploy/hubserving/ocr_system/config.json index 21c701c6..48e7e154 100644 --- a/deploy/hubserving/ocr_system/config.json +++ b/deploy/hubserving/ocr_system/config.json @@ -6,7 +6,6 @@ "use_gpu": true }, "predict_args": { - "visualization": false } } }, diff --git a/deploy/hubserving/ocr_system/module.py b/deploy/hubserving/ocr_system/module.py index a70697f4..bed0c4e6 100644 --- a/deploy/hubserving/ocr_system/module.py +++ b/deploy/hubserving/ocr_system/module.py @@ -19,7 +19,7 @@ import numpy as np import paddle.fluid as fluid import paddlehub as hub -from tools.infer.utility import draw_ocr, base64_to_cv2 +from tools.infer.utility import base64_to_cv2 from tools.infer.predict_system import TextSystem @@ -68,18 +68,12 @@ class OCRSystem(hub.Module): def predict(self, images=[], - paths=[], - draw_img_save='ocr_result', - visualization=False, - text_thresh=0.5): + paths=[]): """ Get the chinese texts in the predicted images. Args: images (list(numpy.ndarray)): images data, shape of each is [H, W, C]. If images not paths paths (list[str]): The paths of images. If paths not images - draw_img_save (str): The directory to store output images. - visualization (bool): Whether to save image or not. - text_thresh(float): the threshold of the recognize chinese texts' confidence Returns: res (list): The result of chinese texts and save path of images. """ @@ -93,53 +87,30 @@ class OCRSystem(hub.Module): assert predicted_data != [], "There is not any image to be predicted. Please check the input data." - cnt = 0 all_results = [] for img in predicted_data: - result = {'save_path': ''} if img is None: logger.info("error in loading image") - result['data'] = [] - all_results.append(result) + all_results.append([]) continue starttime = time.time() dt_boxes, rec_res = self.text_sys(img) elapse = time.time() - starttime - cnt += 1 - print("Predict time of image %d: %.3fs" % (cnt, elapse)) + logger.info("Predict time: {}".format(elapse)) + dt_num = len(dt_boxes) rec_res_final = [] + for dno in range(dt_num): text, score = rec_res[dno] - # if the recognized text confidence score is lower than text_thresh, then drop it - if score >= text_thresh: - # text_str = "%s, %.3f" % (text, score) - # print(text_str) - rec_res_final.append( - { - 'text': text, - 'confidence': float(score), - 'text_box_position': dt_boxes[dno].astype(np.int).tolist() - } - ) - result['data'] = rec_res_final - - if visualization: - image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) - boxes = dt_boxes - txts = [rec_res[i][0] for i in range(len(rec_res))] - scores = [rec_res[i][1] for i in range(len(rec_res))] - - draw_img = draw_ocr(image, boxes, txts, scores, draw_txt=True, drop_score=0.5) - if not os.path.exists(draw_img_save): - os.makedirs(draw_img_save) - saved_name = 'ndarray_{}.jpg'.format(time.time()) - save_file_path = os.path.join(draw_img_save, saved_name) - cv2.imwrite(save_file_path, draw_img[:, :, ::-1]) - print("The visualized image saved in {}".format(save_file_path)) - result['save_path'] = save_file_path - - all_results.append(result) + rec_res_final.append( + { + 'text': text, + 'confidence': float(score), + 'text_region': dt_boxes[dno].astype(np.int).tolist() + } + ) + all_results.append(rec_res_final) return all_results @serving @@ -158,5 +129,5 @@ if __name__ == '__main__': './doc/imgs/11.jpg', './doc/imgs/12.jpg', ] - res = ocr.predict(paths=image_path, visualization=False) + res = ocr.predict(paths=image_path) print(res) \ No newline at end of file diff --git a/doc/doc_ch/serving.md b/doc/doc_ch/serving.md index 69860e67..743017af 100644 --- a/doc/doc_ch/serving.md +++ b/doc/doc_ch/serving.md @@ -23,8 +23,14 @@ deploy/hubserving/ocr_system/ ## 快速启动服务 以下步骤以检测+识别2阶段串联服务为例,如果只需要检测服务或识别服务,替换相应文件路径即可。 -### 1. 安装paddlehub -```pip3 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple``` +### 1. 准备环境 +```shell +# 安装paddlehub +pip3 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple + +# 设置环境变量 +export PYTHONPATH=. +``` ### 2. 安装服务模块 PaddleOCR提供3种服务模块,根据需要安装所需模块。如: @@ -75,7 +81,6 @@ $ hub serving start --modules [Module1==Version1, Module2==Version2, ...] \ "use_gpu": true }, "predict_args": { - "visualization": false } } }, @@ -99,32 +104,21 @@ hub serving start -c deploy/hubserving/ocr_system/config.json ``` ## 发送预测请求 -配置好服务端,以下数行代码即可实现发送预测请求,获取预测结果: +配置好服务端,可使用以下命令发送预测请求,获取预测结果: -```python -import requests -import json -import cv2 -import base64 +```python tools/test_hubserving.py server_url image_path``` -def cv2_to_base64(image): - return base64.b64encode(image).decode('utf8') +需要给脚本传递2个参数: +- **server_url**:服务地址,格式为 +`http://[ip_address]:[port]/predict/[module_name]` +例如,如果使用配置文件启动检测、识别、检测+识别2阶段服务,那么发送请求的url将分别是: +`http://127.0.0.1:8866/predict/ocr_det` +`http://127.0.0.1:8867/predict/ocr_rec` +`http://127.0.0.1:8868/predict/ocr_system` +- **image_path**:测试图像路径,可以是单张图片路径,也可以是图像集合目录路径 -# 发送HTTP请求 -data = {'images':[cv2_to_base64(open("./doc/imgs/11.jpg", 'rb').read())]} -headers = {"Content-type": "application/json"} -# url = "http://127.0.0.1:8866/predict/ocr_det" -# url = "http://127.0.0.1:8866/predict/ocr_rec" -url = "http://127.0.0.1:8866/predict/ocr_system" -r = requests.post(url=url, headers=headers, data=json.dumps(data)) - -# 打印预测结果 -print(r.json()["results"]) -``` - -你可能需要根据实际情况修改`url`字符串中的端口号和服务模块名称。 - -上面所示代码都已写入测试脚本,可直接运行命令:```python tools/test_hubserving.py``` +访问示例: +```python tools/test_hubserving.py http://127.0.0.1:8868/predict/ocr_system ./doc/imgs/``` ## 自定义修改服务模块 如果需要修改服务逻辑,你一般需要操作以下步骤(以修改`ocr_system`为例): diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 65478b69..e96a1934 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -117,16 +117,12 @@ def main(args): image_file_list = get_image_file_list(args.image_dir) text_sys = TextSystem(args) is_visualize = True - tackle_img_num = 0 for image_file in image_file_list: img = cv2.imread(image_file) if img is None: logger.info("error in loading image:{}".format(image_file)) continue starttime = time.time() - tackle_img_num += 1 - if not args.use_gpu and tackle_img_num % 30 == 0: - text_sys = TextSystem(args) dt_boxes, rec_res = text_sys(img) elapse = time.time() - starttime print("Predict time of %s: %.3fs" % (image_file, elapse)) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index f4361a76..0cf66d4c 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -91,7 +91,7 @@ def create_predictor(args, mode): config.enable_use_gpu(args.gpu_mem, 0) else: config.disable_gpu() - config.enable_mkldnn() + # config.enable_mkldnn() config.set_cpu_math_library_num_threads(4) #config.enable_memory_optim() config.disable_glog_info() diff --git a/tools/test_hubserving.py b/tools/test_hubserving.py index edf6ec8c..ea592906 100644 --- a/tools/test_hubserving.py +++ b/tools/test_hubserving.py @@ -1,25 +1,114 @@ -#!usr/bin/python -# -*- coding: utf-8 -*- +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) + +from ppocr.utils.utility import initial_logger +logger = initial_logger() +import cv2 +import numpy as np +import time +from PIL import Image +from ppocr.utils.utility import get_image_file_list +from tools.infer.utility import draw_ocr, draw_boxes import requests import json -import cv2 import base64 -import time + def cv2_to_base64(image): return base64.b64encode(image).decode('utf8') -start = time.time() -# 发送HTTP请求 -data = {'images':[cv2_to_base64(open("./doc/imgs/11.jpg", 'rb').read())]} -headers = {"Content-type": "application/json"} -# url = "http://127.0.0.1:8866/predict/ocr_det" -# url = "http://127.0.0.1:8866/predict/ocr_rec" -url = "http://127.0.0.1:8866/predict/ocr_system" -r = requests.post(url=url, headers=headers, data=json.dumps(data)) -end = time.time() -# 打印预测结果 -print(r.json()["results"]) -print("time cost: ", end - start) +def draw_server_result(image_file, res): + img = cv2.imread(image_file) + image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + if len(res) == 0: + return np.array(image) + keys = res[0].keys() + if 'text_region' not in keys: # for ocr_rec, draw function is invalid + print("draw function is invalid for ocr_rec!") + return None + elif 'text' not in keys: # for ocr_det + print("draw text boxes only!") + boxes = [] + for dno in range(len(res)): + boxes.append(res[dno]['text_region']) + boxes = np.array(boxes) + draw_img = draw_boxes(image, boxes) + return draw_img + else: # for ocr_system + print("draw boxes and texts!") + boxes = [] + texts = [] + scores = [] + for dno in range(len(res)): + boxes.append(res[dno]['text_region']) + texts.append(res[dno]['text']) + scores.append(res[dno]['confidence']) + boxes = np.array(boxes) + scores = np.array(scores) + draw_img = draw_ocr(image, boxes, texts, scores, draw_txt=True, drop_score=0.5) + return draw_img + + +def main(url, image_path): + image_file_list = get_image_file_list(image_path) + is_visualize = False + headers = {"Content-type": "application/json"} + cnt = 0 + total_time = 0 + for image_file in image_file_list: + img = open(image_file, 'rb').read() + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + + # 发送HTTP请求 + starttime = time.time() + data = {'images':[cv2_to_base64(img)]} + r = requests.post(url=url, headers=headers, data=json.dumps(data)) + elapse = time.time() - starttime + total_time += elapse + print("Predict time of %s: %.3fs" % (image_file, elapse)) + res = r.json()["results"][0] + # print(res) + + if is_visualize: + draw_img = draw_server_result(image_file, res) + if draw_img is not None: + draw_img_save = "./server_results/" + if not os.path.exists(draw_img_save): + os.makedirs(draw_img_save) + cv2.imwrite( + os.path.join(draw_img_save, os.path.basename(image_file)), + draw_img[:, :, ::-1]) + print("The visualized image saved in {}".format( + os.path.join(draw_img_save, os.path.basename(image_file)))) + cnt += 1 + if cnt % 100 == 0: + print(cnt, "processed") + print("avg time cost: ", float(total_time)/cnt) + +if __name__ == '__main__': + if len(sys.argv) != 3: + print("Usage: %s server_url image_path" % sys.argv[0]) + else: + server_url = sys.argv[1] + image_path = sys.argv[2] + main(server_url, image_path) \ No newline at end of file