Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleOCR into master

This commit is contained in:
WenmuZhou 2020-09-16 11:14:18 +08:00
commit 227f5f3a36
28 changed files with 580 additions and 133 deletions

View File

@ -189,7 +189,7 @@ PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训
请扫描下面二维码完成问卷填写获取加群二维码和OCR方向的炼丹秘籍
<div align="center">
<img src="./doc/joinus.jpg" width = "200" height = "200" />
<img src="./doc/joinus.PNG" width = "200" height = "200" />
</div>
<a name="许可证书"></a>

View File

@ -56,7 +56,6 @@ Mobile DEMO experience (based on EasyEdge and Paddle-Lite, supports iOS and Andr
- Algorithm introduction
- [Text Detection Algorithm](#TEXTDETECTIONALGORITHM)
- [Text Recognition Algorithm](#TEXTRECOGNITIONALGORITHM)
- [END-TO-END OCR Algorithm](#ENDENDOCRALGORITHM)
- Model training/evaluation
- [Text Detection](./doc/doc_en/detection_en.md)
- [Text Recognition](./doc/doc_en/recognition_en.md)
@ -158,10 +157,6 @@ We use [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./doc/doc_en/recognition_en.md)
<a name="ENDENDOCRALGORITHM"></a>
## END-TO-END OCR Algorithm
- [ ] [End2End-PSL](https://arxiv.org/abs/1909.07808)(Baidu Self-Research, coming soon)
## Visualization
<a name="UCOCRVIS"></a>
@ -211,7 +206,7 @@ Please refer to the document for training guide and use of PaddleOCR text recogn
Scan the QR code below with your wechat and completing the questionnaire, you can access to offical technical exchange group.
<div align="center">
<img src="./doc/joinus.jpg" width = "200" height = "200" />
<img src="./doc/joinus.PNG" width = "200" height = "200" />
</div>
<a name="LICENSE"></a>

View File

@ -26,6 +26,8 @@ void DBDetector::LoadModel(const std::string &model_dir) {
config.DisableGpu();
if (this->use_mkldnn_) {
config.EnableMKLDNN();
// cache 10 different shapes for mkldnn to avoid memory leak
config.SetMkldnnCacheCapacity(10);
}
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
}

View File

@ -126,6 +126,8 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
config.DisableGpu();
if (this->use_mkldnn_) {
config.EnableMKLDNN();
// cache 10 different shapes for mkldnn to avoid memory leak
config.SetMkldnnCacheCapacity(10);
}
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
}

View File

@ -3,7 +3,7 @@ use_gpu 0
gpu_id 0
gpu_mem 4000
cpu_math_library_num_threads 10
use_mkldnn 0
use_mkldnn 1
use_zero_copy_run 1
# det config

View File

@ -20,7 +20,7 @@ git clone https://github.com/PaddlePaddle/PaddleOCR.git
```
b. Goto Dockerfile directorypsNeed to distinguish between cpu and gpu version, the following takes cpu as an example, gpu version needs to replace the keyword
```
cd docker/cpu
cd deploy/docker/cpu
```
c. Build image
```

View File

@ -20,7 +20,7 @@ git clone https://github.com/PaddlePaddle/PaddleOCR.git
```
b.切换至Dockerfile目录需要区分cpu或gpu版本下文以cpu为例gpu版本需要替换一下关键字即可
```
cd docker/cpu
cd deploy/docker/cpu
```
c.生成镜像
```

View File

@ -0,0 +1,34 @@
> 运行示例前请先安装1.2.0或更高版本PaddleSlim
# 模型量化压缩教程
## 概述
该示例使用PaddleSlim提供的[量化压缩API](https://paddlepaddle.github.io/PaddleSlim/api/quantization_api/)对OCR模型进行压缩。
在阅读该示例前,建议您先了解以下内容:
- [OCR模型的常规训练方法](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/detection.md)
- [PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/)
## 安装PaddleSlim
可按照[PaddleSlim使用文档](https://paddlepaddle.github.io/PaddleSlim/)中的步骤安装PaddleSlim。
## 量化训练
进入PaddleOCR根目录通过以下命令对模型进行量化
```bash
python deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=det_mv3_db/best_accuracy Global.save_model_dir=./output/quant_model
```
## 评估并导出
在得到量化训练保存的模型后我们可以将其导出为inference_model用于预测部署
```bash
python deploy/slim/quantization/export_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=output/quant_model/best_accuracy Global.save_model_dir=./output/quant_model
```

View File

@ -0,0 +1,129 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..', '..', '..')))
sys.path.append(
os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools')))
def set_paddle_flags(**kwargs):
for key, value in kwargs.items():
if os.environ.get(key, None) is None:
os.environ[key] = str(value)
# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags(
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
)
import program
from paddle import fluid
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from ppocr.utils.save_load import init_model, load_params
from ppocr.utils.character import CharacterOps
from ppocr.utils.utility import create_module
from ppocr.data.reader_main import reader_main
from paddleslim.quant import quant_aware, convert
from paddle.fluid.layer_helper import LayerHelper
from eval_utils.eval_det_utils import eval_det_run
from eval_utils.eval_rec_utils import eval_rec_run
def main():
# 1. quantization configs
quant_config = {
# weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type': 'channel_wise_abs_max',
# activation quantize type, default is 'moving_average_abs_max'
'activation_quantize_type': 'moving_average_abs_max',
# weight quantize bit num, default is 8
'weight_bits': 8,
# activation quantize bit num, default is 8
'activation_bits': 8,
# ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'],
# ops of type in quantize_op_types, will be quantized
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8',
# window size for 'range_abs_max' quantization. defaulf is 10000
'window_size': 10000,
# The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9,
}
startup_prog, eval_program, place, config, alg_type = program.preprocess()
feeded_var_names, target_vars, fetches_var_name = program.build_export(
config, eval_program, startup_prog)
eval_program = eval_program.clone(for_test=True)
exe = fluid.Executor(place)
exe.run(startup_prog)
eval_program = quant_aware(
eval_program, place, quant_config, scope=None, for_test=True)
init_model(config, eval_program, exe)
# 2. Convert the program before save inference program
# The dtype of eval_program's weights is float32, but in int8 range.
eval_program = convert(eval_program, place, quant_config, scope=None)
eval_fetch_name_list = fetches_var_name
eval_fetch_varname_list = [v.name for v in target_vars]
eval_reader = reader_main(config=config, mode="eval")
quant_info_dict = {'program':eval_program,\
'reader':eval_reader,\
'fetch_name_list':eval_fetch_name_list,\
'fetch_varname_list':eval_fetch_varname_list}
if alg_type == 'det':
final_metrics = eval_det_run(exe, config, quant_info_dict, "eval")
else:
final_metrics = eval_rec_run(exe, config, quant_info_dict, "eval")
print(final_metrics)
# 3. Save inference model
model_path = "./quant_model"
if not os.path.isdir(model_path):
os.makedirs(model_path)
fluid.io.save_inference_model(
dirname=model_path,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
executor=exe,
main_program=eval_program,
model_filename=model_path + '/model',
params_filename=model_path + '/params')
print("model saved as {}".format(model_path))
if __name__ == '__main__':
main()

188
deploy/slim/quantization/quant.py Executable file
View File

@ -0,0 +1,188 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..', '..', '..')))
sys.path.append(
os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools')))
def set_paddle_flags(**kwargs):
for key, value in kwargs.items():
if os.environ.get(key, None) is None:
os.environ[key] = str(value)
# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags(
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
)
import tools.program as program
from paddle import fluid
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from ppocr.data.reader_main import reader_main
from ppocr.utils.save_load import init_model
from paddle.fluid.contrib.model_stat import summary
# quant dependencies
import paddle
import paddle.fluid as fluid
from paddleslim.quant import quant_aware, convert
from paddle.fluid.layer_helper import LayerHelper
def pact(x):
"""
Process a variable using the pact method you define
Args:
x(Tensor): Paddle Tensor, need to be preprocess before quantization
Returns:
The processed Tensor x.
"""
helper = LayerHelper("pact", **locals())
dtype = 'float32'
init_thres = 20
u_param_attr = fluid.ParamAttr(
name=x.name + '_pact',
initializer=fluid.initializer.ConstantInitializer(value=init_thres),
regularizer=fluid.regularizer.L2Decay(0.0001),
learning_rate=1)
u_param = helper.create_parameter(attr=u_param_attr, shape=[1], dtype=dtype)
x = fluid.layers.elementwise_sub(
x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
x = fluid.layers.elementwise_add(
x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
return x
def get_optimizer():
"""
Build a program using a model and an optimizer
"""
return fluid.optimizer.AdamOptimizer(0.001)
def main():
train_build_outputs = program.build(
config, train_program, startup_program, mode='train')
train_loader = train_build_outputs[0]
train_fetch_name_list = train_build_outputs[1]
train_fetch_varname_list = train_build_outputs[2]
train_opt_loss_name = train_build_outputs[3]
model_average = train_build_outputs[-1]
eval_program = fluid.Program()
eval_build_outputs = program.build(
config, eval_program, startup_program, mode='eval')
eval_fetch_name_list = eval_build_outputs[1]
eval_fetch_varname_list = eval_build_outputs[2]
eval_program = eval_program.clone(for_test=True)
train_reader = reader_main(config=config, mode="train")
train_loader.set_sample_list_generator(train_reader, places=place)
eval_reader = reader_main(config=config, mode="eval")
exe = fluid.Executor(place)
exe.run(startup_program)
# 1. quantization configs
quant_config = {
# weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type': 'channel_wise_abs_max',
# activation quantize type, default is 'moving_average_abs_max'
'activation_quantize_type': 'moving_average_abs_max',
# weight quantize bit num, default is 8
'weight_bits': 8,
# activation quantize bit num, default is 8
'activation_bits': 8,
# ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'],
# ops of type in quantize_op_types, will be quantized
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8',
# window size for 'range_abs_max' quantization. defaulf is 10000
'window_size': 10000,
# The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9,
}
# 2. quantization transform programs (training aware)
# Make some quantization transforms in the graph before training and testing.
# According to the weight and activation quantization type, the graph will be added
# some fake quantize operators and fake dequantize operators.
act_preprocess_func = pact
optimizer_func = get_optimizer
executor = exe
eval_program = quant_aware(
eval_program,
place,
quant_config,
scope=None,
act_preprocess_func=act_preprocess_func,
optimizer_func=optimizer_func,
executor=executor,
for_test=True)
quant_train_program = quant_aware(
train_program,
place,
quant_config,
scope=None,
act_preprocess_func=act_preprocess_func,
optimizer_func=optimizer_func,
executor=executor,
for_test=False,
return_program=True)
# compile program for multi-devices
train_compile_program = program.create_multi_devices_program(
quant_train_program, train_opt_loss_name, for_quant=True)
init_model(config, quant_train_program, exe)
train_info_dict = {'compile_program':train_compile_program,\
'train_program':quant_train_program,\
'reader':train_loader,\
'fetch_name_list':train_fetch_name_list,\
'fetch_varname_list':train_fetch_varname_list,\
'model_average': model_average}
eval_info_dict = {'program':eval_program,\
'reader':eval_reader,\
'fetch_name_list':eval_fetch_name_list,\
'fetch_varname_list':eval_fetch_varname_list}
if train_alg_type == 'det':
program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict)
else:
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
if __name__ == '__main__':
startup_program, train_program, place, config, train_alg_type = program.preprocess(
)
main()

View File

@ -140,7 +140,7 @@ PaddleOCR提供了多种数据增强方式如果您希望在训练时加入
训练过程中每种扰动方式以50%的概率被选择,具体代码实现请参考:[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py)
*由于OpenCV的兼容性问题扰动操作暂时只支持GPU*
*由于OpenCV的兼容性问题扰动操作暂时只支持Linux*
- 训练

View File

@ -61,6 +61,14 @@ hub install deploy\hubserving\ocr_rec\
hub install deploy\hubserving\ocr_system\
```
#### 安装模型
安装服务模块前,需要将训练好的模型放到对应的文件夹内。默认使用的是:
./inference/ch_det_mv3_db/
./inference/ch_rec_mv3_crnn/
这两个模型可以在https://github.com/PaddlePaddle/PaddleOCR 下载
可以在./deploy/hubserving/ocr_system/params.py 里面修改成自己的模型
### 3. 启动服务
#### 方式1. 命令行命令启动仅支持CPU
**启动命令:**

BIN
doc/joinus.PNG Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 29 KiB

View File

@ -1,58 +0,0 @@
English | [简体中文](README_cn.md)
## Introduction
Many user hopes package the PaddleOCR service into an docker image, so that it can be quickly released and used in the docker or k8s environment.
This page provide some standardized code to achieve this goal. You can quickly publish the PaddleOCR project into a callable Restful API service through the following steps. (At present, the deployment based on the HubServing mode is implemented first, and author plans to increase the deployment of the PaddleServing mode in the futrue)
## 1. Prerequisites
You need to install the following basic components first
a. Docker
b. Graphics driver and CUDA 10.0+GPU
c. NVIDIA Container ToolkitGPUDocker 19.03+ can skip this
d. cuDNN 7.6+GPU
## 2. Build Image
a. Download PaddleOCR sourcecode
```
git clone https://github.com/PaddlePaddle/PaddleOCR.git
```
b. Goto Dockerfile directorypsNeed to distinguish between cpu and gpu version, the following takes cpu as an example, gpu version needs to replace the keyword
```
cd docker/cpu
```
c. Build image
```
docker build -t paddleocr:cpu .
```
## 3. Start container
a. CPU version
```
sudo docker run -dp 8866:8866 --name paddle_ocr paddleocr:cpu
```
b. GPU version (base on NVIDIA Container Toolkit)
```
sudo nvidia-docker run -dp 8866:8866 --name paddle_ocr paddleocr:gpu
```
c. GPU version (Docker 19.03++)
```
sudo docker run -dp 8866:8866 --gpus all --name paddle_ocr paddleocr:gpu
```
d. Check service statusIf you can see the following statement then it means completedSuccessfully installed ocr_system && Running on http://0.0.0.0:8866/
```
docker logs -f paddle_ocr
```
## 4. Test
a. Calculate the Base64 encoding of the picture to be recognized (if you just test, you can use a free online tool, likehttps://freeonlinetools24.com/base64-image/
b. Post a service requestsample request in sample_request.txt
```
curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"Input image Base64 encode(need to delete the code 'data:image/jpg;base64,'\"]}" http://localhost:8866/predict/ocr_system
```
c. Get resposneIf the call is successful, the following result will be returned
```
{"msg":"","results":[[{"confidence":0.8403433561325073,"text":"约定","text_region":[[345,377],[641,390],[634,540],[339,528]]},{"confidence":0.8131805658340454,"text":"最终相遇","text_region":[[356,532],[624,530],[624,596],[356,598]]}]],"status":"0"}
```

View File

@ -1,4 +1,4 @@
# -*- coding:utf-8 -*-
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
@ -121,24 +121,22 @@ def RandomCropData(data, size):
all_care_polys = [
text_polys[i] for i, tag in enumerate(ignore_tags) if not tag
]
# 计算crop区域
crop_x, crop_y, crop_w, crop_h = crop_area(im, all_care_polys,
min_crop_side_ratio, max_tries)
# crop 图片 保持比例填充
scale_w = size[0] / crop_w
scale_h = size[1] / crop_h
dh, dw = size
scale_w = dw / crop_w
scale_h = dh / crop_h
scale = min(scale_w, scale_h)
h = int(crop_h * scale)
w = int(crop_w * scale)
if keep_ratio:
padimg = np.zeros((size[1], size[0], im.shape[2]), im.dtype)
padimg = np.zeros((dh, dw, im.shape[2]), im.dtype)
padimg[:h, :w] = cv2.resize(
im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
img = padimg
else:
img = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w],
tuple(size))
# crop 文本框
(dw, dh))
text_polys_crop = []
ignore_tags_crop = []
texts_crop = []

View File

@ -67,6 +67,7 @@ class DetModel(object):
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
image.stop_gradient = False
if mode == "train":
if self.algorithm == "EAST":
h, w = int(image_shape[1] // 4), int(image_shape[2] // 4)
@ -108,7 +109,10 @@ class DetModel(object):
name='tvo', shape=[9, 128, 128], dtype='float32')
input_tco = fluid.layers.data(
name='tco', shape=[3, 128, 128], dtype='float32')
feed_list = [image, input_score, input_border, input_mask, input_tvo, input_tco]
feed_list = [
image, input_score, input_border, input_mask, input_tvo,
input_tco
]
labels = {'input_score': input_score,\
'input_border': input_border,\
'input_mask': input_mask,\

View File

@ -68,6 +68,7 @@ class RecModel(object):
image_shape.insert(0, -1)
if mode == "train":
image = fluid.data(name='image', shape=image_shape, dtype='float32')
image.stop_gradient = False
if self.loss_type == "attention":
label_in = fluid.data(
name='label_in',
@ -136,7 +137,7 @@ class RecModel(object):
else:
labels = None
loader = None
if self.char_type == "ch" and self.infer_img:
if self.char_type == "ch" and self.infer_img and self.loss_type != "srn":
image_shape[-1] = -1
if self.tps != None:
logger.info(
@ -146,6 +147,7 @@ class RecModel(object):
)
image_shape = deepcopy(self.image_shape)
image = fluid.data(name='image', shape=image_shape, dtype='float32')
image.stop_gradient = False
if self.loss_type == "srn":
encoder_word_pos = fluid.data(
name="encoder_word_pos",
@ -172,16 +174,13 @@ class RecModel(object):
self.max_text_length
],
dtype="float32")
feed_list = [
image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
labels = {
'encoder_word_pos': encoder_word_pos,
'gsrm_word_pos': gsrm_word_pos,
'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2
}
return image, labels, loader
def __call__(self, mode):
@ -218,8 +217,13 @@ class RecModel(object):
if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict)
if self.loss_type == "srn":
raise Exception(
"Warning! SRN does not support export model currently")
return [
image, labels, {
'decoded_out': decoded_out,
'predicts': predict
}
]
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
else:
predict = predicts['predict']

View File

@ -35,12 +35,13 @@ class CTCPredict(object):
self.fc_decay = params.get("fc_decay", 0.0004)
def __call__(self, inputs, labels=None, mode=None):
encoder_features = self.encoder(inputs)
if self.encoder_type != "reshape":
encoder_features = fluid.layers.concat(encoder_features, axis=1)
name = "ctc_fc"
para_attr, bias_attr = get_para_bias_attr(
l2_decay=self.fc_decay, k=encoder_features.shape[1], name=name)
with fluid.scope_guard("skip_quant"):
encoder_features = self.encoder(inputs)
if self.encoder_type != "reshape":
encoder_features = fluid.layers.concat(encoder_features, axis=1)
name = "ctc_fc"
para_attr, bias_attr = get_para_bias_attr(
l2_decay=self.fc_decay, k=encoder_features.shape[1], name=name)
predict = fluid.layers.fc(input=encoder_features,
size=self.char_num + 1,
param_attr=para_attr,

View File

@ -90,15 +90,3 @@ def check_and_read_gif(img_path):
return imgvalue, True
return None, False
def create_multi_devices_program(program, loss_var_name):
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = True
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_iteration_per_drop_scope = 1
compile_program = fluid.CompiledProgram(program).with_data_parallel(
loss_name=loss_var_name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
return compile_program

View File

@ -40,6 +40,7 @@ class TextRecognizer(object):
self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
self.text_len = args.max_text_length
self.use_zero_copy_run = args.use_zero_copy_run
char_ops_params = {
"character_type": args.rec_char_type,
@ -47,12 +48,15 @@ class TextRecognizer(object):
"use_space_char": args.use_space_char,
"max_text_length": args.max_text_length
}
if self.rec_algorithm != "RARE":
if self.rec_algorithm in ["CRNN", "Rosetta", "STAR-Net"]:
char_ops_params['loss_type'] = 'ctc'
self.loss_type = 'ctc'
else:
elif self.rec_algorithm == "RARE":
char_ops_params['loss_type'] = 'attention'
self.loss_type = 'attention'
elif self.rec_algorithm == "SRN":
char_ops_params['loss_type'] = 'srn'
self.loss_type = 'srn'
self.char_ops = CharacterOps(char_ops_params)
def resize_norm_img(self, img, max_wh_ratio):
@ -75,6 +79,83 @@ class TextRecognizer(object):
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(self, image_shape, num_heads, max_text_length,
char_num):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(
gsrm_slf_attn_bias1,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(
gsrm_slf_attn_bias2,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
encoder_word_pos = encoder_word_pos[np.newaxis, :]
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def process_image_srn(self,
img,
image_shape,
num_heads,
max_text_length,
char_ops=None):
norm_img = self.resize_norm_img_srn(img, image_shape)
norm_img = norm_img[np.newaxis, :]
char_num = char_ops.get_char_num()
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
self.srn_other_inputs(image_shape, num_heads, max_text_length, char_num)
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2)
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
@ -84,7 +165,7 @@ class TextRecognizer(object):
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
# rec_res = []
#rec_res = []
rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num
predict_time = 0
@ -98,20 +179,62 @@ class TextRecognizer(object):
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
# norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
if self.loss_type != "srn":
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
else:
norm_img = self.process_image_srn(img_list[indices[ino]],
self.rec_image_shape, 8,
25, self.char_ops)
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
gsrm_slf_attn_bias2_list = []
encoder_word_pos_list.append(norm_img[1])
gsrm_word_pos_list.append(norm_img[2])
gsrm_slf_attn_bias1_list.append(norm_img[3])
gsrm_slf_attn_bias2_list.append(norm_img[4])
norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch, axis=0)
norm_img_batch = norm_img_batch.copy()
starttime = time.time()
if self.use_zero_copy_run:
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run()
else:
if self.loss_type == "srn":
starttime = time.time()
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = np.concatenate(
gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = np.concatenate(
gsrm_slf_attn_bias2_list)
starttime = time.time()
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
self.predictor.run([norm_img_batch])
encoder_word_pos_list = fluid.core.PaddleTensor(
encoder_word_pos_list)
gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor(
gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor(
gsrm_slf_attn_bias2_list)
inputs = [
norm_img_batch, encoder_word_pos_list,
gsrm_slf_attn_bias1_list, gsrm_slf_attn_bias2_list,
gsrm_word_pos_list
]
self.predictor.run(inputs)
else:
starttime = time.time()
if self.use_zero_copy_run:
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run()
else:
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
self.predictor.run([norm_img_batch])
if self.loss_type == "ctc":
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
@ -136,6 +259,26 @@ class TextRecognizer(object):
score = np.mean(probs[valid_ind, ind[valid_ind]])
# rec_res.append([preds_text, score])
rec_res[indices[beg_img_no + rno]] = [preds_text, score]
elif self.loss_type == 'srn':
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
probs = self.output_tensors[1].copy_to_cpu()
char_num = self.char_ops.get_char_num()
preds = rec_idx_batch.reshape(-1)
elapse = time.time() - starttime
predict_time += elapse
total_preds = preds.copy()
for ino in range(int(len(rec_idx_batch) / self.text_len)):
preds = total_preds[ino * self.text_len:(ino + 1) *
self.text_len]
ind = np.argmax(probs, axis=1)
valid_ind = np.where(preds != int(char_num - 1))[0]
if len(valid_ind) == 0:
continue
score = np.mean(probs[valid_ind, ind[valid_ind]])
preds = preds[:valid_ind[-1] + 1]
preds_text = self.char_ops.decode(preds)
rec_res[indices[beg_img_no + ino]] = [preds_text, score]
else:
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
predict_batch = self.output_tensors[1].copy_to_cpu()
@ -170,6 +313,7 @@ def main(args):
continue
valid_image_file_list.append(image_file)
img_list.append(img)
try:
rec_res, predict_time = text_recognizer(img_list)
except Exception as e:

View File

@ -122,7 +122,6 @@ def main(args):
image_file_list = get_image_file_list(args.image_dir)
text_sys = TextSystem(args)
is_visualize = True
tackle_img_num = 0
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
@ -131,9 +130,6 @@ def main(args):
logger.info("error in loading image:{}".format(image_file))
continue
starttime = time.time()
tackle_img_num += 1
if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0:
text_sys = TextSystem(args)
dt_boxes, rec_res = text_sys(img)
elapse = time.time() - starttime
print("Predict time of %s: %.3fs" % (image_file, elapse))
@ -153,11 +149,7 @@ def main(args):
scores = [rec_res[i][1] for i in range(len(rec_res))]
draw_img = draw_ocr(
image,
boxes,
txts,
scores,
drop_score=drop_score)
image, boxes, txts, scores, drop_score=drop_score)
draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)

View File

@ -101,6 +101,8 @@ def create_predictor(args, mode):
config.disable_gpu()
config.set_cpu_math_library_num_threads(6)
if args.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
#config.enable_memory_optim()
@ -114,7 +116,8 @@ def create_predictor(args, mode):
predictor = create_paddle_predictor(config)
input_names = predictor.get_input_names()
input_tensor = predictor.get_input_tensor(input_names[0])
for name in input_names:
input_tensor = predictor.get_input_tensor(name)
output_names = predictor.get_output_names()
output_tensors = []
for output_name in output_names:

View File

@ -145,7 +145,7 @@ def main():
preds = preds.reshape(-1)
probs = np.array(predict[1])
ind = np.argmax(probs, axis=1)
valid_ind = np.where(preds != int(char_num-1))[0]
valid_ind = np.where(preds != int(char_num - 1))[0]
if len(valid_ind) == 0:
continue
score = np.mean(probs[valid_ind, ind[valid_ind]])

View File

@ -208,18 +208,29 @@ def build_export(config, main_prog, startup_prog):
with fluid.unique_name.guard():
func_infor = config['Architecture']['function']
model = create_module(func_infor)(params=config)
image, outputs = model(mode='export')
algorithm = config['Global']['algorithm']
if algorithm == "SRN":
image, others, outputs = model(mode='export')
else:
image, outputs = model(mode='export')
fetches_var_name = sorted([name for name in outputs.keys()])
fetches_var = [outputs[name] for name in fetches_var_name]
feeded_var_names = [image.name]
if algorithm == "SRN":
others_var_names = sorted([name for name in others.keys()])
feeded_var_names = [image.name] + others_var_names
else:
feeded_var_names = [image.name]
target_vars = fetches_var
return feeded_var_names, target_vars, fetches_var_name
def create_multi_devices_program(program, loss_var_name):
def create_multi_devices_program(program, loss_var_name, for_quant=False):
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = True
if for_quant:
build_strategy.fuse_all_reduce_ops = False
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_iteration_per_drop_scope = 1
compile_program = fluid.CompiledProgram(program).with_data_parallel(
@ -409,7 +420,9 @@ def preprocess():
check_gpu(use_gpu)
alg = config['Global']['algorithm']
assert alg in ['EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN'
]
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
config['Global']['char_ops'] = CharacterOps(config['Global'])