add inference to serving model tool
This commit is contained in:
parent
e53c427330
commit
0005f4d171
|
@ -22,9 +22,9 @@ import time
|
|||
import re
|
||||
import base64
|
||||
from tools.infer.predict_cls import TextClassifier
|
||||
import tools.infer.utility as utility
|
||||
from params import read_params
|
||||
|
||||
global_args = utility.parse_args()
|
||||
global_args = read_params()
|
||||
if global_args.use_gpu:
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
else:
|
||||
|
|
|
@ -22,9 +22,9 @@ import time
|
|||
import re
|
||||
import base64
|
||||
from tools.infer.predict_cls import TextClassifier
|
||||
import tools.infer.utility as utility
|
||||
from params import read_params
|
||||
|
||||
global_args = utility.parse_args()
|
||||
global_args = read_params()
|
||||
if global_args.use_gpu:
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
else:
|
||||
|
|
|
@ -21,9 +21,9 @@ import time
|
|||
import re
|
||||
import base64
|
||||
from tools.infer.predict_det import TextDetector
|
||||
import tools.infer.utility as utility
|
||||
from params import read_params
|
||||
|
||||
global_args = utility.parse_args()
|
||||
global_args = read_params()
|
||||
if global_args.use_gpu:
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
else:
|
||||
|
|
|
@ -21,9 +21,9 @@ import time
|
|||
import re
|
||||
import base64
|
||||
from tools.infer.predict_det import TextDetector
|
||||
import tools.infer.utility as utility
|
||||
from params import read_params
|
||||
|
||||
global_args = utility.parse_args()
|
||||
global_args = read_params()
|
||||
if global_args.use_gpu:
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
else:
|
||||
|
|
|
@ -24,12 +24,13 @@ import base64
|
|||
from clas_local_server import TextClassifierHelper
|
||||
from det_local_server import TextDetectorHelper
|
||||
from rec_local_server import TextRecognizerHelper
|
||||
import tools.infer.utility as utility
|
||||
from tools.infer.predict_system import TextSystem, sorted_boxes
|
||||
from paddle_serving_app.local_predict import Debugger
|
||||
import copy
|
||||
from params import read_params
|
||||
|
||||
global_args = read_params()
|
||||
|
||||
global_args = utility.parse_args()
|
||||
if global_args.use_gpu:
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
else:
|
||||
|
@ -84,8 +85,7 @@ class TextSystemHelper(TextSystem):
|
|||
|
||||
class OCRService(WebService):
|
||||
def init_rec(self):
|
||||
args = utility.parse_args()
|
||||
self.text_system = TextSystemHelper(args)
|
||||
self.text_system = TextSystemHelper(global_args)
|
||||
|
||||
def preprocess(self, feed=[], fetch=[]):
|
||||
# TODO: to handle batch rec images
|
||||
|
|
|
@ -24,11 +24,11 @@ import base64
|
|||
from clas_rpc_server import TextClassifierHelper
|
||||
from det_rpc_server import TextDetectorHelper
|
||||
from rec_rpc_server import TextRecognizerHelper
|
||||
import tools.infer.utility as utility
|
||||
from tools.infer.predict_system import TextSystem, sorted_boxes
|
||||
import copy
|
||||
from params import read_params
|
||||
|
||||
global_args = utility.parse_args()
|
||||
global_args = read_params()
|
||||
if global_args.use_gpu:
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
else:
|
||||
|
@ -87,8 +87,7 @@ class TextSystemHelper(TextSystem):
|
|||
|
||||
class OCRService(WebService):
|
||||
def init_rec(self):
|
||||
args = utility.parse_args()
|
||||
self.text_system = TextSystemHelper(args)
|
||||
self.text_system = TextSystemHelper(global_args)
|
||||
|
||||
def preprocess(self, feed=[], fetch=[]):
|
||||
# TODO: to handle batch rec images
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
# -*- coding:utf-8 -*-
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
class Config(object):
|
||||
pass
|
||||
|
||||
def read_params():
|
||||
cfg = Config()
|
||||
#use gpu
|
||||
cfg.use_gpu = False
|
||||
cfg.use_pdserving = True
|
||||
|
||||
#params for text detector
|
||||
cfg.det_algorithm = "DB"
|
||||
cfg.det_model_dir = "./det_mv_server/"
|
||||
cfg.det_max_side_len = 960
|
||||
|
||||
#DB parmas
|
||||
cfg.det_db_thresh =0.3
|
||||
cfg.det_db_box_thresh =0.5
|
||||
cfg.det_db_unclip_ratio =2.0
|
||||
|
||||
#EAST parmas
|
||||
cfg.det_east_score_thresh = 0.8
|
||||
cfg.det_east_cover_thresh = 0.1
|
||||
cfg.det_east_nms_thresh = 0.2
|
||||
|
||||
#params for text recognizer
|
||||
cfg.rec_algorithm = "CRNN"
|
||||
cfg.rec_model_dir = "./ocr_rec_server/"
|
||||
|
||||
cfg.rec_image_shape = "3, 32, 320"
|
||||
cfg.rec_char_type = 'ch'
|
||||
cfg.rec_batch_num = 30
|
||||
cfg.max_text_length = 25
|
||||
|
||||
cfg.rec_char_dict_path = "./ppocr_keys_v1.txt"
|
||||
cfg.use_space_char = True
|
||||
|
||||
#params for text classifier
|
||||
cfg.use_angle_cls = True
|
||||
cfg.cls_model_dir = "./ocr_clas_server/"
|
||||
cfg.cls_image_shape = "3, 48, 192"
|
||||
cfg.label_list = ['0', '180']
|
||||
cfg.cls_batch_num = 30
|
||||
cfg.cls_thresh = 0.9
|
||||
|
||||
return cfg
|
|
@ -22,9 +22,10 @@ import time
|
|||
import re
|
||||
import base64
|
||||
from tools.infer.predict_rec import TextRecognizer
|
||||
import tools.infer.utility as utility
|
||||
from params import read_params
|
||||
|
||||
global_args = read_params()
|
||||
|
||||
global_args = utility.parse_args()
|
||||
if global_args.use_gpu:
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
else:
|
||||
|
|
|
@ -22,9 +22,9 @@ import time
|
|||
import re
|
||||
import base64
|
||||
from tools.infer.predict_rec import TextRecognizer
|
||||
import tools.infer.utility as utility
|
||||
from params import read_params
|
||||
|
||||
global_args = utility.parse_args()
|
||||
global_args = read_params()
|
||||
if global_args.use_gpu:
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
else:
|
||||
|
|
|
@ -10,117 +10,100 @@
|
|||
|
||||
## 一、训练模型转Serving模型
|
||||
|
||||
### 检测模型转Serving模型
|
||||
在前序文档 [基于Python预测引擎推理](./inference.md) 中,我们提供了如何把训练的checkpoint转换成Paddle模型。Paddle模型通常由一个文件夹构成,内含模型结构描述文件`model`和模型参数文件`params`。Serving模型由两个文件夹构成,用于存放客户端和服务端的配置。
|
||||
|
||||
下载超轻量级中文检测模型:
|
||||
我们以`ch_rec_r34_vd_crnn`模型作为例子,下载链接在:
|
||||
|
||||
```
|
||||
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar && tar xf ./ch_lite/ch_det_mv3_db.tar -C ./ch_lite/
|
||||
wget --no-check-certificate https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_infer.tar
|
||||
tar xf ch_rec_r34_vd_crnn_infer.tar
|
||||
```
|
||||
|
||||
上述模型是以MobileNetV3为backbone训练的DB算法,将训练好的模型转换成Serving模型只需要运行如下命令:
|
||||
因此我们按照Serving模型转换教程,运行下列python文件。
|
||||
```
|
||||
python tools/inference_to_serving.py --model_dir ch_rec_r34_vd_crnn
|
||||
```
|
||||
最终会在`serving_client_dir`和`serving_server_dir`生成客户端和服务端的模型配置。其中`serving_server_dir`和`serving_client_dir`的名字可以自定义。最终文件结构如下
|
||||
|
||||
```
|
||||
# -c后面设置训练算法的yml配置文件
|
||||
# -o配置可选参数
|
||||
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
|
||||
# Global.save_inference_dir参数设置转换的模型将保存的地址。
|
||||
|
||||
python tools/export_serving_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./ch_lite/det_mv3_db/best_accuracy Global.save_inference_dir=./inference/det_db/
|
||||
```
|
||||
|
||||
转Serving模型时,使用的配置文件和训练时使用的配置文件相同。另外,还需要设置配置文件中的`Global.checkpoints`、`Global.save_inference_dir`参数。 其中`Global.checkpoints`指向训练中保存的模型参数文件,`Global.save_inference_dir`是生成的inference模型要保存的目录。 转换成功后,在`save_inference_dir`目录下有两个文件:
|
||||
|
||||
```
|
||||
inference/det_db/
|
||||
├── serving_client_dir # 客户端配置文件夹
|
||||
└── serving_server_dir # 服务端配置文件夹
|
||||
|
||||
```
|
||||
|
||||
### 识别模型转Serving模型
|
||||
|
||||
下载超轻量中文识别模型:
|
||||
|
||||
```
|
||||
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar && tar xf ./ch_lite/ch_rec_mv3_crnn.tar -C ./ch_lite/
|
||||
```
|
||||
|
||||
识别模型转inference模型与检测的方式相同,如下:
|
||||
|
||||
```
|
||||
# -c后面设置训练算法的yml配置文件
|
||||
# -o配置可选参数
|
||||
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
|
||||
# Global.save_inference_dir参数设置转换的模型将保存的地址。
|
||||
|
||||
python3 tools/export_serving_model.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints=./ch_lite/rec_mv3_crnn/best_accuracy \
|
||||
Global.save_inference_dir=./inference/rec_crnn/
|
||||
```
|
||||
|
||||
**注意:**如果您是在自己的数据集上训练的模型,并且调整了中文字符的字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。
|
||||
|
||||
转换成功后,在目录下有两个文件:
|
||||
|
||||
```
|
||||
/inference/rec_crnn/
|
||||
/ch_rec_r34_vd_crnn/
|
||||
├── serving_client_dir # 客户端配置文件夹
|
||||
└── serving_server_dir # 服务端配置文件夹
|
||||
```
|
||||
|
||||
### 方向分类模型转Serving模型
|
||||
|
||||
下载方向分类模型:
|
||||
|
||||
```
|
||||
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile-v1.1.cls_pre.tar && tar xf ./ch_lite/ch_ppocr_mobile-v1.1.cls_pre.tar -C ./ch_lite/
|
||||
```
|
||||
|
||||
方向分类模型转inference模型与检测的方式相同,如下:
|
||||
|
||||
```
|
||||
# -c后面设置训练算法的yml配置文件
|
||||
# -o配置可选参数
|
||||
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
|
||||
# Global.save_inference_dir参数设置转换的模型将保存的地址。
|
||||
|
||||
python3 tools/export_serving_model.py -c configs/cls/cls_mv3.yml -o Global.checkpoints=./ch_lite/cls_model/best_accuracy \
|
||||
Global.save_inference_dir=./inference/cls/
|
||||
```
|
||||
|
||||
转换成功后,在目录下有两个文件:
|
||||
|
||||
```
|
||||
/inference/cls/
|
||||
├── serving_client_dir # 客户端配置文件夹
|
||||
└── serving_server_dir # 服务端配置文件夹
|
||||
```
|
||||
|
||||
在接下来的教程中,我们将给出推理的demo模型下载链接。
|
||||
|
||||
```
|
||||
wget --no-check-certificate https://paddleocr.bj.bcebos.com/deploy/pdserving/ocr_pdserving_suite.tar.gz
|
||||
tar zxf ocr_pdserving_suite.tar.gz
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 二、文本检测模型Serving推理
|
||||
|
||||
文本检测模型推理,默认使用DB模型的配置参数。当不使用DB模型时,在推理时,需要通过传入相应的参数进行算法适配,细节参考下文。
|
||||
启动服务可以根据实际需求选择启动`标准版`或者`快速版`,两种方式的对比如下表:
|
||||
|
||||
与本地预测不同的是,Serving预测需要一个客户端和一个服务端,因此接下来的教程都是两行代码。所有的
|
||||
|版本|特点|适用场景|
|
||||
|-|-|-|
|
||||
|标准版|稳定性高,分布式部署|适用于吞吐量大,需要跨机房部署的情况|
|
||||
|快速版|部署方便,预测速度快|适用于对预测速度要求高,迭代速度快的场景,Windows用户只能选择快速版|
|
||||
|
||||
接下来的命令中,我们会指定快速版和标准版的命令。需要说明的是,标准版只能用Linux平台,快速版可以支持Linux/Windows。
|
||||
文本检测模型推理,默认使用DB模型的配置参数,识别默认为CRNN。
|
||||
|
||||
配置文件在`params.py`中,我们贴出配置部分,如果需要做改动,也在这个文件内部进行修改。
|
||||
|
||||
```
|
||||
def read_params():
|
||||
cfg = Config()
|
||||
#use gpu
|
||||
cfg.use_gpu = False # 是否使用GPU
|
||||
cfg.use_pdserving = True # 是否使用paddleserving,必须为True
|
||||
|
||||
#params for text detector
|
||||
cfg.det_algorithm = "DB" # 检测算法, DB/EAST等
|
||||
cfg.det_model_dir = "./det_mv_server/" # 检测算法模型路径
|
||||
cfg.det_max_side_len = 960
|
||||
|
||||
#DB params
|
||||
cfg.det_db_thresh =0.3
|
||||
cfg.det_db_box_thresh =0.5
|
||||
cfg.det_db_unclip_ratio =2.0
|
||||
|
||||
#EAST params
|
||||
cfg.det_east_score_thresh = 0.8
|
||||
cfg.det_east_cover_thresh = 0.1
|
||||
cfg.det_east_nms_thresh = 0.2
|
||||
|
||||
#params for text recognizer
|
||||
cfg.rec_algorithm = "CRNN" # 识别算法, CRNN/RARE等
|
||||
cfg.rec_model_dir = "./ocr_rec_server/" # 识别算法模型路径
|
||||
|
||||
cfg.rec_image_shape = "3, 32, 320"
|
||||
cfg.rec_char_type = 'ch'
|
||||
cfg.rec_batch_num = 30
|
||||
cfg.max_text_length = 25
|
||||
|
||||
cfg.rec_char_dict_path = "./ppocr_keys_v1.txt" # 识别算法字典文件
|
||||
cfg.use_space_char = True
|
||||
|
||||
#params for text classifier
|
||||
cfg.use_angle_cls = True # 是否启用分类算法
|
||||
cfg.cls_model_dir = "./ocr_clas_server/" # 分类算法模型路径
|
||||
cfg.cls_image_shape = "3, 48, 192"
|
||||
cfg.label_list = ['0', '180']
|
||||
cfg.cls_batch_num = 30
|
||||
cfg.cls_thresh = 0.9
|
||||
|
||||
return cfg
|
||||
```
|
||||
与本地预测不同的是,Serving预测需要一个客户端和一个服务端,因此接下来的教程都是两行代码。
|
||||
|
||||
在正式执行服务端启动命令之前,先export PYTHONPATH到工程主目录下。
|
||||
```
|
||||
export PYTHONPATH=$PWD:$PYTHONPATH
|
||||
cd deploy/pdserving
|
||||
```
|
||||
### 1. 超轻量中文检测模型推理
|
||||
|
||||
超轻量中文检测模型推理,可以执行如下命令启动服务端:
|
||||
|
||||
```
|
||||
#根据环境只需要启动其中一个就可以
|
||||
python det_rpc_server.py --use_pdserving True --det_model_dir det_mv_server #标准版,Linux用户
|
||||
python det_local_server.py --use_pdserving True --det_model_dir det_mv_server #快速版,Windows/Linux用户
|
||||
python det_rpc_server.py #标准版,Linux用户
|
||||
python det_local_server.py #快速版,Windows/Linux用户
|
||||
```
|
||||
如果需要使用CPU版本,还需增加 `--use_gpu False`。
|
||||
|
||||
客户端
|
||||
|
||||
|
@ -129,23 +112,8 @@ python det_web_client.py
|
|||
```
|
||||
|
||||
|
||||
|
||||
Serving的推测和本地预测不同点在于,客户端发送请求到服务端,服务端需要检测到文字框之后返回框的坐标,此处没有后处理的图片,只能看到坐标值。
|
||||
|
||||
### 2. DB文本检测模型推理
|
||||
|
||||
首先将DB文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)),可以使用如下命令进行转换:
|
||||
|
||||
```
|
||||
# -c后面设置训练算法的yml配置文件
|
||||
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
|
||||
# Global.save_inference_dir参数设置转换的模型将保存的地址。
|
||||
|
||||
python3 tools/export_serving_model.py -c configs/det/det_r50_vd_db.yml -o Global.checkpoints="./models/det_r50_vd_db/best_accuracy" Global.save_inference_dir="./inference/det_db"
|
||||
```
|
||||
|
||||
经过转换之后,会在`./inference/det_db` 目录下出现`serving_server_dir`和`serving_client_dir`,然后指定`det_model_dir` 。
|
||||
|
||||
## 三、文本识别模型Serving推理
|
||||
|
||||
下面将介绍超轻量中文识别模型推理、基于CTC损失的识别模型推理和基于Attention损失的识别模型推理。对于中文文本识别,建议优先选择基于CTC损失的识别模型,实践中也发现基于Attention损失的效果不如基于CTC损失的识别模型。此外,如果训练时修改了文本的字典,请参考下面的自定义文本识别字典的推理。
|
||||
|
@ -153,11 +121,11 @@ python3 tools/export_serving_model.py -c configs/det/det_r50_vd_db.yml -o Global
|
|||
### 1. 超轻量中文识别模型推理
|
||||
|
||||
超轻量中文识别模型推理,可以执行如下命令启动服务端:
|
||||
|
||||
需要注意params.py中的`--use_gpu`的值
|
||||
```
|
||||
#根据环境只需要启动其中一个就可以
|
||||
python rec_rpc_server.py --use_pdserving True --rec_model_dir ocr_rec_server #标准版,Linux用户
|
||||
python rec_local_server.py --use_pdserving True --rec_model_dir ocr_rec_server #快速版,Windows/Linux用户
|
||||
python rec_rpc_server.py #标准版,Linux用户
|
||||
python rec_local_server.py #快速版,Windows/Linux用户
|
||||
```
|
||||
如果需要使用CPU版本,还需增加 `--use_gpu False`。
|
||||
|
||||
|
@ -186,13 +154,12 @@ python rec_web_client.py
|
|||
### 1. 方向分类模型推理
|
||||
|
||||
方向分类模型推理, 可以执行如下命令启动服务端:
|
||||
|
||||
需要注意params.py中的`--use_gpu`的值
|
||||
```
|
||||
#根据环境只需要启动其中一个就可以
|
||||
python clas_rpc_server.py --use_pdserving True --cls_model_dir ocr_clas_server #标准版,Linux用户
|
||||
python clas_local_server.py --use_pdserving True --cls_model_dir ocr_clas_server #快速版,Windows/Linux用户
|
||||
python clas_rpc_server.py #标准版,Linux用户
|
||||
python clas_local_server.py #快速版,Windows/Linux用户
|
||||
```
|
||||
如果需要使用CPU版本,还需增加 `--use_gpu False`。
|
||||
|
||||
客户端
|
||||
|
||||
|
@ -216,20 +183,20 @@ python rec_web_client.py
|
|||
在执行预测时,需要通过参数`image_dir`指定单张图像或者图像集合的路径、参数`det_model_dir`,`cls_model_dir`和`rec_model_dir`分别指定检测,方向分类和识别的inference模型路径。参数`use_angle_cls`用于控制是否启用方向分类模型。与本地预测不同的是,为了减少网络传输耗时,可视化识别结果目前不做处理,用户收到的是推理得到的文字字段。
|
||||
|
||||
执行如下命令启动服务端:
|
||||
|
||||
需要注意params.py中的`--use_gpu`的值
|
||||
```
|
||||
#标准版,Linux用户
|
||||
#GPU用户
|
||||
python -m paddle_serving_server_gpu.serve --model det_mv_server --port 9293 --gpu_id 0
|
||||
python -m paddle_serving_server_gpu.serve --model ocr_cls_server --port 9294 --gpu_id 0
|
||||
python ocr_rpc_server.py --use_pdserving True --use_gpu True --rec_model_dir ocr_rec_server
|
||||
python ocr_rpc_server.py
|
||||
#CPU用户
|
||||
python -m paddle_serving_server.serve --model det_mv_server --port 9293
|
||||
python -m paddle_serving_server.serve --model ocr_cls_server --port 9294
|
||||
python ocr_rpc_server.py --use_pdserving True --use_gpu False --rec_model_dir ocr_rec_server
|
||||
python ocr_rpc_server.py
|
||||
|
||||
#快速版,Windows/Linux用户
|
||||
python ocr_local_server.py --use_gpu False --use_pdserving True --rec_model_dir ocr_rec_server/ --det_model_dir det_mv_server/ --cls_model_dir ocr_clas_server/ --rec_char_dict_path ppocr_keys_v1.txt --use_angle_cls True
|
||||
python ocr_local_server.py
|
||||
```
|
||||
|
||||
客户端
|
||||
|
|
|
@ -21,12 +21,10 @@ from ppocr.utils.utility import initial_logger, check_and_read_gif
|
|||
logger = initial_logger()
|
||||
|
||||
import tools.infer.utility as utility
|
||||
args = utility.parse_args()
|
||||
if args.use_pdserving is False:
|
||||
from .data_augment import AugmentData
|
||||
from .random_crop_data import RandomCropData
|
||||
from .make_shrink_map import MakeShrinkMap
|
||||
from .make_border_map import MakeBorderMap
|
||||
from .data_augment import AugmentData
|
||||
from .random_crop_data import RandomCropData
|
||||
from .make_shrink_map import MakeShrinkMap
|
||||
from .make_border_map import MakeBorderMap
|
||||
|
||||
|
||||
class DBProcessTrain(object):
|
||||
|
|
|
@ -1,78 +0,0 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
|
||||
def set_paddle_flags(**kwargs):
|
||||
for key, value in kwargs.items():
|
||||
if os.environ.get(key, None) is None:
|
||||
os.environ[key] = str(value)
|
||||
|
||||
|
||||
# NOTE(paddle-dev): All of these flags should be
|
||||
# set before `import paddle`. Otherwise, it would
|
||||
# not take any effect.
|
||||
set_paddle_flags(
|
||||
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
|
||||
)
|
||||
|
||||
import program
|
||||
from paddle import fluid
|
||||
from ppocr.utils.utility import initial_logger
|
||||
logger = initial_logger()
|
||||
from ppocr.utils.save_load import init_model
|
||||
from paddle_serving_client.io import save_model
|
||||
|
||||
|
||||
def main():
|
||||
startup_prog, eval_program, place, config, _ = program.preprocess()
|
||||
|
||||
feeded_var_names, target_vars, fetches_var_name = program.build_export(
|
||||
config, eval_program, startup_prog)
|
||||
eval_program = eval_program.clone(for_test=True)
|
||||
exe = fluid.Executor(place)
|
||||
exe.run(startup_prog)
|
||||
|
||||
init_model(config, eval_program, exe)
|
||||
|
||||
save_inference_dir = config['Global']['save_inference_dir']
|
||||
if not os.path.exists(save_inference_dir):
|
||||
os.makedirs(save_inference_dir)
|
||||
serving_client_dir = "{}/serving_client_dir".format(save_inference_dir)
|
||||
serving_server_dir = "{}/serving_server_dir".format(save_inference_dir)
|
||||
|
||||
feed_dict = {
|
||||
x: eval_program.global_block().var(x)
|
||||
for x in feeded_var_names
|
||||
}
|
||||
fetch_dict = {x.name: x for x in target_vars}
|
||||
save_model(serving_server_dir, serving_client_dir, feed_dict, fetch_dict,
|
||||
eval_program)
|
||||
print(
|
||||
"paddle serving model saved in {}/serving_server_dir and {}/serving_client_dir".
|
||||
format(save_inference_dir, save_inference_dir))
|
||||
print("save success, output_name_list:", fetches_var_name)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -36,10 +36,10 @@ class TextClassifier(object):
|
|||
if args.use_pdserving is False:
|
||||
self.predictor, self.input_tensor, self.output_tensors = \
|
||||
utility.create_predictor(args, mode="cls")
|
||||
self.use_zero_copy_run = args.use_zero_copy_run
|
||||
self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
|
||||
self.cls_batch_num = args.rec_batch_num
|
||||
self.label_list = args.label_list
|
||||
self.use_zero_copy_run = args.use_zero_copy_run
|
||||
self.cls_thresh = args.cls_thresh
|
||||
|
||||
def resize_norm_img(self, img):
|
||||
|
|
|
@ -42,7 +42,6 @@ class TextDetector(object):
|
|||
def __init__(self, args):
|
||||
max_side_len = args.det_max_side_len
|
||||
self.det_algorithm = args.det_algorithm
|
||||
self.use_zero_copy_run = args.use_zero_copy_run
|
||||
preprocess_params = {'max_side_len': max_side_len}
|
||||
postprocess_params = {}
|
||||
if self.det_algorithm == "DB":
|
||||
|
@ -76,6 +75,7 @@ class TextDetector(object):
|
|||
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
|
||||
sys.exit(0)
|
||||
if args.use_pdserving is False:
|
||||
self.use_zero_copy_run = args.use_zero_copy_run
|
||||
self.predictor, self.input_tensor, self.output_tensors =\
|
||||
utility.create_predictor(args, mode="det")
|
||||
|
||||
|
|
|
@ -37,12 +37,12 @@ class TextRecognizer(object):
|
|||
if args.use_pdserving is False:
|
||||
self.predictor, self.input_tensor, self.output_tensors =\
|
||||
utility.create_predictor(args, mode="rec")
|
||||
self.use_zero_copy_run = args.use_zero_copy_run
|
||||
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
|
||||
self.character_type = args.rec_char_type
|
||||
self.rec_batch_num = args.rec_batch_num
|
||||
self.rec_algorithm = args.rec_algorithm
|
||||
self.text_len = args.max_text_length
|
||||
self.use_zero_copy_run = args.use_zero_copy_run
|
||||
char_ops_params = {
|
||||
"character_type": args.rec_char_type,
|
||||
"character_dict_path": args.rec_char_dict_path,
|
||||
|
|
|
@ -37,7 +37,6 @@ def parse_args():
|
|||
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
||||
parser.add_argument("--gpu_mem", type=int, default=8000)
|
||||
parser.add_argument("--use_pdserving", type=str2bool, default=False)
|
||||
|
||||
# params for text detector
|
||||
parser.add_argument("--image_dir", type=str)
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
from paddle_serving_client.io import inference_model_to_serving
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_dir", type=str)
|
||||
parser.add_argument("--server_dir", type=str, default="serving_server_dir")
|
||||
parser.add_argument("--client_dir", type=str, default="serving_client_dir")
|
||||
return parser.parse_args()
|
||||
|
||||
args = parse_args()
|
||||
inference_model_dir = args.model_dir
|
||||
serving_client_dir = args.server_dir
|
||||
serving_server_dir = args.client_dir
|
||||
feed_var_names, fetch_var_names = inference_model_to_serving(
|
||||
inference_model_dir, serving_client_dir, serving_server_dir, model_filename="model", params_filename="params")
|
Loading…
Reference in New Issue