update from original repo

Khanh Tran 2020-06-09 11:45:49 +07:00
commit 406463efb6
42 changed files with 1025 additions and 85 deletions


@@ -2,9 +2,12 @@
PaddleOCR aims to create rich, leading, and practical OCR tools that help users train better models and apply them in practice.

**Recent updates**
-- 2020.5.30 Model prediction and training support Windows systems, and the display of recognition results is optimized
-- 2020.5.30 Open source general Chinese OCR model
-- 2020.5.30 Provide ultra-lightweight Chinese OCR model inference
+- 2020.6.8 Add [dataset](./doc/datasets.md) and keep updating
+- 2020.6.5 Add `attention` model in `inference_model`
+- 2020.6.5 Support separate prediction and recognition, output result score
+- 2020.5.30 Provide ultra-lightweight Chinese OCR online experience
+- 2020.5.30 Model prediction and training supported on Windows system
+- [more](./doc/update.md)

## Features
- Ultra-lightweight Chinese OCR model, total model size is only 8.6M
@@ -38,6 +41,8 @@ Please see [Quick installation](./doc/installation.md)
#### 2. Download inference models
#### (1) Download Ultra-lightweight Chinese OCR models
+*If wget is not installed on Windows, you can copy the link into a browser to download the model. After the model is downloaded, unzip it and place it in the corresponding directory.*
```
mkdir inference && cd inference
# Download the detection part of the Ultra-lightweight Chinese OCR and decompress it
@@ -64,6 +69,9 @@ The following code implements text detection and recognition inference in tandem.
# Set PYTHONPATH environment variable
export PYTHONPATH=.
+# Set the environment variable on Windows
+SET PYTHONPATH=.

# Prediction on a single image by specifying image path to image_dir
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/11.jpg" --det_model_dir="./inference/ch_det_mv3_db/" --rec_model_dir="./inference/ch_rec_mv3_crnn/"
@@ -87,6 +95,7 @@ For more text detection and recognition models, please refer to the document [Inference](./doc/inference.md)
- [Text detection model training/evaluation/prediction](./doc/detection.md)
- [Text recognition model training/evaluation/prediction](./doc/recognition.md)
- [Inference](./doc/inference.md)
+- [Dataset](./doc/datasets.md)
## Text detection algorithm
@@ -104,6 +113,12 @@ On the ICDAR2015 dataset, the text detection result is as follows:
|DB|ResNet50_vd|83.79%|80.65%|82.19%|[Download link](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)|
|DB|MobileNetV3|75.92%|73.18%|74.53%|[Download link](https://paddleocr.bj.bcebos.com/det_mv3_db.tar)|
+For the [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/datasets.md#1icdar2019-lsvt) street view dataset (30k training images), the configurations and pre-trained models for the Chinese detection task are as follows:
+|Model|Backbone|Configuration file|Pre-trained model|
+|-|-|-|-|
+|Ultra-lightweight Chinese model|MobileNetV3|det_mv3_db.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar)|
+|General Chinese OCR model|ResNet50_vd|det_r50_vd_db.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db.tar)|

* Note: For training and evaluation of the above DB models, the post-processing parameters box_thresh=0.6 and unclip_ratio=1.5 need to be set. When training with different datasets or models, these two parameters can be adjusted for better results.
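  For example, the two parameters can be passed explicitly when evaluating (this mirrors the evaluation command in the detection tutorial; the checkpoint path is a placeholder):

  ```
  python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
  ```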
For the training guide and use of PaddleOCR text detection algorithms, please refer to the document [Text detection model training/evaluation/prediction](./doc/detection.md)
@@ -130,6 +145,12 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation results are as follows:
|RARE|Resnet34_vd|84.90%|rec_r34_vd_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_attn.tar)|
|RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar)|
+Using the position ground truth of the [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/datasets.md#1icdar2019-lsvt) dataset, we cropped 300k training images out of the original photos and applied the necessary position calibration. In addition, 5 million synthetic images were generated from the LSVT corpus to train the Chinese model. The configurations and pre-trained models are as follows:
+|Model|Backbone|Configuration file|Pre-trained model|
+|-|-|-|-|
+|Ultra-lightweight Chinese model|MobileNetV3|rec_chinese_lite_train.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar)|
+|General Chinese OCR model|Resnet34_vd|rec_chinese_common_train.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn.tar)|

Please refer to the document [Text recognition model training/evaluation/prediction](./doc/recognition.md) for the training guide and use of PaddleOCR text recognition algorithms.
## End-to-end OCR algorithm
@@ -173,6 +194,8 @@ Please refer to the document for training guide and use of PaddleOCR text recognition algorithms
Baidu self-developed algorithms such as SAST, SRN and end2end PSL will be released in June or July. Please be patient.

+[more](./doc/FAQ.md)
## Welcome to the PaddleOCR technical exchange group
Add WeChat: paddlehelp, note "OCR", and the assistant will add you to the group.


@@ -10,4 +10,3 @@ EvalReader:
TestReader:
  reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
  lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
-  infer_img: ./infer_img


@@ -0,0 +1,43 @@
Global:
  algorithm: CRNN
  use_gpu: true
  epoch_num: 3000
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_CRNN
  save_epoch_step: 3
  eval_batch_step: 2000
  train_batch_size_per_card: 128
  test_batch_size_per_card: 128
  image_shape: [3, 32, 320]
  max_text_length: 25
  character_type: ch
  character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
  loss_type: ctc
  reader_yml: ./configs/rec/rec_chinese_reader.yml
  pretrain_weights:
  checkpoints:
  save_inference_dir:
  infer_img:

Architecture:
  function: ppocr.modeling.architectures.rec_model,RecModel

Backbone:
  function: ppocr.modeling.backbones.rec_resnet_vd,ResNet
  layers: 34

Head:
  function: ppocr.modeling.heads.rec_ctc_head,CTCPredict
  encoder_type: rnn
  SeqRNN:
    hidden_size: 256

Loss:
  function: ppocr.modeling.losses.rec_ctc_loss,CTCLoss

Optimizer:
  function: ppocr.optimizer,AdamDecay
  base_lr: 0.0005
  beta1: 0.9
  beta2: 0.999


@@ -18,6 +18,8 @@ Global:
  pretrain_weights:
  checkpoints:
  save_inference_dir:
+  infer_img:
Architecture:
  function: ppocr.modeling.architectures.rec_model,RecModel


@@ -11,4 +11,3 @@ EvalReader:
TestReader:
  reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
-  infer_img: ./infer_img


@@ -11,4 +11,3 @@ EvalReader:
TestReader:
  reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
-  infer_img: ./infer_img


@@ -17,6 +17,8 @@ Global:
  pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy
  checkpoints:
  save_inference_dir:
+  infer_img:
Architecture:
  function: ppocr.modeling.architectures.rec_model,RecModel


@@ -17,6 +17,7 @@ Global:
  pretrain_weights:
  checkpoints:
  save_inference_dir:
+  infer_img:
Architecture:
  function: ppocr.modeling.architectures.rec_model,RecModel


@@ -17,6 +17,7 @@ Global:
  pretrain_weights:
  checkpoints:
  save_inference_dir:
+  infer_img:
Architecture:
  function: ppocr.modeling.architectures.rec_model,RecModel


@@ -13,10 +13,13 @@ Global:
  max_text_length: 25
  character_type: en
  loss_type: attention
+  tps: true
  reader_yml: ./configs/rec/rec_benchmark_reader.yml
  pretrain_weights:
  checkpoints:
  save_inference_dir:
+  infer_img:
Architecture:
  function: ppocr.modeling.architectures.rec_model,RecModel


@@ -13,10 +13,12 @@ Global:
  max_text_length: 25
  character_type: en
  loss_type: ctc
+  tps: true
  reader_yml: ./configs/rec/rec_benchmark_reader.yml
  pretrain_weights:
  checkpoints:
  save_inference_dir:
+  infer_img:
Architecture:


@@ -17,6 +17,8 @@ Global:
  pretrain_weights:
  checkpoints:
  save_inference_dir:
+  infer_img:
Architecture:
  function: ppocr.modeling.architectures.rec_model,RecModel


@@ -17,6 +17,7 @@ Global:
  pretrain_weights:
  checkpoints:
  save_inference_dir:
+  infer_img:
Architecture:
  function: ppocr.modeling.architectures.rec_model,RecModel


@@ -13,10 +13,13 @@ Global:
  max_text_length: 25
  character_type: en
  loss_type: attention
+  tps: true
  reader_yml: ./configs/rec/rec_benchmark_reader.yml
  pretrain_weights:
  checkpoints:
  save_inference_dir:
+  infer_img:
Architecture:
  function: ppocr.modeling.architectures.rec_model,RecModel


@@ -13,10 +13,13 @@ Global:
  max_text_length: 25
  character_type: en
  loss_type: ctc
+  tps: true
  reader_yml: ./configs/rec/rec_benchmark_reader.yml
  pretrain_weights:
  checkpoints:
  save_inference_dir:
+  infer_img:
Architecture:
  function: ppocr.modeling.architectures.rec_model,RecModel

doc/FAQ.md (new file, 43 lines)

@@ -0,0 +1,43 @@
## FAQ

1. **Prediction error: got an unexpected keyword argument 'gradient_clip'**

    The installed Paddle version is wrong. This project currently only supports Paddle 1.7; adaptation to 1.8 is in progress.
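    For reference, the install commands in [Quick installation](installation.md) pin this version; for example, on a CUDA10 machine (pick the wheel matching your environment):

    ```
    python3 -m pip install paddlepaddle-gpu==1.7.2.post107 -i https://pypi.tuna.tsinghua.edu.cn/simple
    ```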
2. **KeyError: 'predict' when converting an attention recognition model**

    Inference for recognition models based on the attention loss is still being debugged. For Chinese text recognition, we recommend choosing models based on the CTC loss first; in practice, attention-based models have also been found to perform worse than CTC-based ones.

3. **About inference speed**

    When an image contains a lot of text, prediction takes longer. You can use --rec_batch_num to set a smaller prediction batch size; the default value is 30, which can be reduced to 10 or another value.
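    For example, a full-pipeline run with a smaller recognition batch could look like the following (assuming predict_system.py accepts the same --rec_batch_num flag as predict_rec.py):

    ```
    python3 tools/infer/predict_system.py --image_dir="./doc/imgs/11.jpg" --det_model_dir="./inference/ch_det_mv3_db/" --rec_model_dir="./inference/ch_rec_mv3_crnn/" --rec_batch_num=10
    ```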
4. **Server deployment and mobile deployment**

    A Serving-based server deployment solution and a Paddle Lite-based mobile deployment solution are expected in mid-to-late June. Stay tuned.

5. **Release date of the self-developed algorithms**

    The self-developed algorithms SAST, SRN and End2End-PSL will be released in June and July; please stay tuned.

6. **How to run on Windows or Mac?**

    PaddleOCR has been adapted to Windows and Mac. Two notes: 1. In [Quick installation](installation.md), if you do not want to install Docker, you can skip the first step and start directly from the second step, installing Paddle. 2. When downloading the inference models, if wget is not installed, you can click the model link directly or copy it into a browser to download, then unzip the model and place it in the corresponding directory.

7. **Difference between the ultra-lightweight model and the general OCR model**

    PaddleOCR currently open-sources two Chinese models, the 8.6M ultra-lightweight Chinese model and the general Chinese OCR model. They compare as follows:

    - Similarities: both use the same **algorithms** and **training data**.
    - Differences: they differ in the **backbone networks** and **channel parameters**. The ultra-lightweight model uses MobileNetV3 as the backbone, while the general model uses ResNet50_vd as the detection backbone and ResNet34_vd as the recognition backbone. The exact parameter differences can be seen by comparing the two models' training configuration files:

    |Model|Backbones|Detection config|Recognition config|
    |-|-|-|-|
    |8.6M ultra-lightweight Chinese OCR model|MobileNetV3+MobileNetV3|det_mv3_db.yml|rec_chinese_lite_train.yml|
    |General Chinese OCR model|ResNet50_vd+ResNet34_vd|det_r50_vd_db.yml|rec_chinese_common_train.yml|

8. **Is there a plan to open-source models that recognize only digits, or only English plus digits?**

    We do not currently plan to release models specialized for digits only, digits+English only, or other small verticals. PaddleOCR open-sources a variety of detection and recognition algorithms for custom training, and the two Chinese models were also trained with this open algorithm library. If you have a vertical use case, prepare your data following the tutorials, choose a suitable configuration file and train your own model; good results can be expected. If you have any training problems, feel free to open an issue or ask in the user group; we will answer promptly.

9. **What training data do the open-source models use? Will it be open-sourced?**

    The currently open-sourced models, datasets and data scale are as follows:
    - Detection:
        - English dataset: ICDAR2015
        - Chinese dataset: LSVT street view dataset, 30k training images
    - Recognition:
        - English dataset: MJSynth and SynthText synthetic data, more than ten million samples.
        - Chinese dataset: LSVT street view dataset, with images cropped according to the ground truth and position-calibrated, 300k images in total; in addition, 5 million synthetic samples were generated based on the LSVT corpus.

    The public datasets are all open source; you can search and download them yourself, or refer to [Chinese datasets](datasets.md). The synthetic data is not released; you can synthesize your own with open-source tools such as [text_renderer](https://github.com/Sanster/text_renderer), [SynthText](https://github.com/ankush-me/SynthText) and [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator).

One binary file changed (previously 194 KiB; name not shown).

doc/datasets.md (new file, 58 lines)

@@ -0,0 +1,58 @@
## Datasets

Commonly used Chinese datasets are collected here, and the list is updated continuously; contributions are welcome!
- [ICDAR2019-LSVT](#ICDAR2019-LSVT)
- [ICDAR2017-RCTW-17](#ICDAR2017-RCTW-17)
- [Chinese Street View Text Recognition](#中文街景文字识别)
- [Chinese Document Text Recognition](#中文文档文字识别)
- [ICDAR2019-ArT](#ICDAR2019-ArT)

Besides the open datasets, you can also synthesize data yourself with tools such as [text_renderer](https://github.com/Sanster/text_renderer), [SynthText](https://github.com/ankush-me/SynthText) and [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator).

<a name="ICDAR2019-LSVT"></a>
#### 1. ICDAR2019-LSVT
- **Source**: https://ai.baidu.com/broad/introduction?dataset=lsvt
- **Description**: 450k Chinese street view images in total, including 50k fully annotated images (text coordinates + text content; 20k for testing and 30k for training) and 400k weakly annotated images (text content only), as shown below:
![](datasets/LSVT_1.jpg)
(a) Fully annotated data
![](datasets/LSVT_2.jpg)
(b) Weakly annotated data
- **Download**: https://ai.baidu.com/broad/download?dataset=lsvt

<a name="ICDAR2017-RCTW-17"></a>
#### 2. ICDAR2017-RCTW-17
- **Source**: https://rctw.vlrlab.net/
- **Description**: more than 12,000 images, most of them collected in the wild with phone cameras, plus some screenshots. They cover a wide variety of scenes, including street views, posters, menus, indoor scenes and screenshots of mobile applications.
![](datasets/rctw.jpg)
- **Download**: https://rctw.vlrlab.net/dataset/

<a name="中文街景文字识别"></a>
#### 3. Chinese Street View Text Recognition
- **Source**: https://aistudio.baidu.com/aistudio/competition/detail/8
- **Description**: 290k images in total, of which 210k are used as the training set (with labels) and 80k as the test set (without labels). The data was collected from Chinese street views and consists of text-line regions (such as shop signs and landmarks) cropped from the street view images. All images are preprocessed: the text region is mapped, using affine transformation, to an image of 48 pixels in height, as shown below:
![](datasets/ch_street_rec_1.png)
(a) Label: 魅派集成吊顶
![](datasets/ch_street_rec_2.png)
(b) Label: 母婴用品连锁
- **Download**:
https://aistudio.baidu.com/aistudio/datasetdetail/8429

<a name="中文文档文字识别"></a>
#### 4. Chinese Document Text Recognition
- **Source**: https://github.com/YCG09/chinese_ocr
- **Description**:
    - About 3.64 million images in total, split into a training set and a validation set at a ratio of 99:1.
    - Generated randomly from a Chinese corpus (news + classical Chinese) by varying the font, size, gray level, blur, perspective, stretching, and so on.
    - 5,990 characters in total, covering Chinese characters, English letters, digits and punctuation. Character set: https://github.com/YCG09/chinese_ocr/blob/master/train/char_std_5990.txt
    - Each sample is fixed to 10 characters, randomly cropped from sentences in the corpus.
    - The image resolution is uniformly 280x32.
![](datasets/ch_doc1.jpg)
![](datasets/ch_doc2.jpg)
![](datasets/ch_doc3.jpg)
- **Download**: https://pan.baidu.com/s/1QkI7kjah8SPHwOQ40rS1Pw (password: lu7m)

<a name="ICDAR2019-ArT"></a>
#### 5. ICDAR2019-ArT
- **Source**: https://ai.baidu.com/broad/introduction?dataset=art
- **Description**: 10,166 images in total, with 5,603 for training and 4,563 for testing. The dataset is composed of Total-Text, SCUT-CTW1500 and Baidu Curved Scene Text, and contains horizontal, multi-oriented and curved text.
![](datasets/ArT.jpg)
- **Download**: https://ai.baidu.com/broad/download?dataset=art

New binary files: doc/datasets/ArT.jpg (3.1 MiB), doc/datasets/LSVT_1.jpg (123 KiB), doc/datasets/LSVT_2.jpg (94 KiB), doc/datasets/ch_doc1.jpg (2.2 KiB), doc/datasets/ch_doc2.jpg (2.4 KiB), doc/datasets/ch_doc3.jpg (2.1 KiB), doc/datasets/rctw.jpg (93 KiB), and two further images whose names are not shown (100 KiB, 114 KiB).

doc/detection.md (new file, 99 lines)

@@ -0,0 +1,99 @@
# Text Detection

This section takes the icdar15 dataset as an example to introduce the training, evaluation and testing of detection models in PaddleOCR.

## Data preparation
The icdar2015 dataset can be downloaded from the [official website](https://rrc.cvc.uab.es/?ch=4&com=downloads); registration is required for the first download.

Decompress the downloaded dataset into your workspace, for example under PaddleOCR/train_data/. PaddleOCR also consolidates the scattered annotation files into single annotation files, which you can download with wget:
```
# In the PaddleOCR directory
cd PaddleOCR/
wget -P ./train_data/  https://paddleocr.bj.bcebos.com/dataset/train_icdar2015_label.txt
wget -P ./train_data/  https://paddleocr.bj.bcebos.com/dataset/test_icdar2015_label.txt
```
After decompressing the dataset and downloading the annotation files, PaddleOCR/train_data/ contains two folders and two files:
```
/PaddleOCR/train_data/icdar2015/text_localization/
  └─ icdar_c4_train_imgs/         training images of the icdar dataset
  └─ ch4_test_images/             test images of the icdar dataset
  └─ train_icdar2015_label.txt    training annotations of the icdar dataset
  └─ test_icdar2015_label.txt     test annotations of the icdar dataset
```
The provided annotation file format is:
```
" image file name             json.dumps-encoded image annotation info"
ch4_test_images/img_61.jpg    [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}]
```
The image annotation info before json.dumps encoding is a list of dictionaries. In each dictionary, `points` holds the (x, y) coordinates of the four corners of a text box, starting from the top-left corner and proceeding clockwise. `transcription` is the text of the current text box; it is not needed for the text detection task.

If you want to train PaddleOCR on other datasets, build annotation files in the format above.
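For illustration, a minimal Python sketch of reading this annotation format (assuming a tab separates the two fields, as in the recognition label files; this is not part of PaddleOCR's reader code):

```
import json

# Each line: "<image file name>\t<json-encoded list of text boxes>"
with open("train_data/train_icdar2015_label.txt", "r", encoding="utf-8") as f:
    for line in f:
        img_path, annotation = line.rstrip("\n").split("\t", 1)
        for box in json.loads(annotation):
            # box["points"]: four (x, y) corners, clockwise from the top-left
            # box["transcription"]: the text content, unused for detection
            print(img_path, box["points"], box["transcription"])
```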
## Quick start of training

First download a pretrained model. PaddleOCR's detection models currently support two backbones, MobileNetV3 and ResNet50_vd. You can replace the backbone as needed with any model from [PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures).
```
cd PaddleOCR/
# Download the pretrained model of MobileNetV3
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x0_5_pretrained.tar
# Download the pretrained model of ResNet50
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
```

**Start training**

*If you installed the CPU version of Paddle, set the `use_gpu` field in the configuration file to false.*
```
python3 tools/train.py -c configs/det/det_mv3_db.yml
```
In the command above, -c selects the configuration file configs/det/det_mv3_db.yml for training. For a detailed explanation of the configuration file, see [this document](./config.md).

You can also change training parameters with the -o option without modifying the yml file. For example, to set the training learning rate to 0.0001:
```
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
```

## Evaluation

PaddleOCR computes three metrics related to OCR detection: Precision, Recall and Hmean.
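Hmean here is the harmonic mean (F-measure) of Precision and Recall:

```
Hmean = 2 * Precision * Recall / (Precision + Recall)
```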
Run the following code to compute the evaluation metrics from the test-set detection results file specified by save_res_path in the configuration file det_mv3_db.yml.

For evaluation, set the post-processing parameters box_thresh=0.6 and unclip_ratio=1.5. If you train with other datasets or models, these two parameters can be tuned for better results.
```
python3 tools/eval.py -c configs/det/det_mv3_db.yml  -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
During training, model parameters are saved to the Global.save_model_dir directory by default. To compute the metrics, set Global.checkpoints to the saved parameter file, for example:
```
python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
* Note: box_thresh and unclip_ratio are parameters required by the DB post-processing; they do not need to be set when evaluating an EAST model.

## Testing detection results

Test the detection result of a single image:
```
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
```
When testing a DB model, adjust the post-processing thresholds:
```
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
Test the detection results of all images in a folder:
```
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
```

doc/inference.md (new file, 219 lines)

@@ -0,0 +1,219 @@
# Inference with the Prediction Engine

An inference model (a model saved with fluid.io.save_inference_model) is generally a frozen model saved after training, mostly used for deployment in prediction.

By contrast, a model saved during training is a checkpoints model, which stores only the model parameters and is mostly used to resume training.

Compared with a checkpoints model, an inference model additionally stores the structure of the model. Because of this it performs well in deployment and accelerated inference, is flexible, and is well suited for integration with production systems. For a more detailed introduction, see the document [Classification prediction framework](https://paddleclas.readthedocs.io/zh_CN/latest/extension/paddle_inference.html).
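As a minimal, self-contained sketch of what this freezing step does (illustrative only, with a toy network; this is not PaddleOCR's export_model.py):

```
import paddle.fluid as fluid

# Toy network; in PaddleOCR the real program is built from the training config.
image = fluid.data(name="image", shape=[-1, 3, 32, 320], dtype="float32")
predict = fluid.layers.fc(input=image, size=10)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# save_inference_model stores the pruned program plus the parameters, which is
# exactly what distinguishes an inference model from a checkpoints model.
fluid.io.save_inference_model(
    dirname="./inference/demo/",
    feeded_var_names=["image"],   # inputs required at prediction time
    target_vars=[predict],        # outputs to keep
    executor=exe,
    model_filename="model",       # program file, as in the listings below
    params_filename="params")     # parameter file
```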
This document first introduces how to convert a trained model into an inference model, and then covers text detection, text recognition, and the concatenation of the two, all based on the prediction engine.

## Converting a trained model to an inference model
### Converting a detection model to an inference model

Download the ultra-lightweight Chinese detection model:
```
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar && tar xf ./ch_lite/ch_det_mv3_db.tar -C ./ch_lite/
```
The model above uses the DB algorithm trained with MobileNetV3 as the backbone. To convert the trained model into an inference model, simply run the following command:
```
python3 tools/export_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./ch_lite/det_mv3_db/best_accuracy Global.save_inference_dir=./inference/det_db/
```
When converting to an inference model, the configuration file used is the same as the one used for training. In addition, set the Global.checkpoints and Global.save_inference_dir parameters in the configuration: Global.checkpoints points to the model parameter file saved during training, and Global.save_inference_dir is the directory where the generated inference model will be saved.

After a successful conversion, there are two files in the save_inference_dir directory:
```
inference/det_db/
  └─  model     program file of the detection inference model
  └─  params    parameter file of the detection inference model
```
### Converting a recognition model to an inference model

Download the ultra-lightweight Chinese recognition model:
```
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar && tar xf ./ch_lite/ch_rec_mv3_crnn.tar -C ./ch_lite/
```
The recognition model is converted in the same way as the detection model, as follows:
```
python3 tools/export_model.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints=./ch_lite/rec_mv3_crnn/best_accuracy \
        Global.save_inference_dir=./inference/rec_crnn/
```
If you trained the model on your own dataset and adjusted the dictionary file for Chinese characters, check whether the character_dict_path in the configuration file is the dictionary file you need.

After a successful conversion, there are two files in the directory:
```
/inference/rec_crnn/
  └─  model     program file of the recognition inference model
  └─  params    parameter file of the recognition inference model
```
## Text detection model inference

The following introduces ultra-lightweight Chinese detection model inference, DB text detection model inference and EAST text detection model inference. The default configuration is set for DB text detection model inference; since EAST and DB differ substantially, EAST text detection inference requires passing in the corresponding parameters.

### 1. Ultra-lightweight Chinese detection model inference

For ultra-lightweight Chinese detection model inference, run the following command:
```
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/"
```
The visualized text detection results are saved to the ./inference_results folder by default, with file names prefixed 'det_res'. Example:

![](imgs_results/det_res_2.jpg)

The parameter det_max_side_len sets the upper bound to which images are normalized in the detection algorithm. If both the width and the height of an image are smaller than det_max_side_len, the original image is used for prediction; otherwise the image is scaled down proportionally to that limit. The parameter defaults to det_max_side_len=960. If the input images have a high resolution and you want to predict at a larger size, run:
```
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_max_side_len=1200
```
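The rule described above amounts to the following scale factor (a sketch of the assumed behavior, not the exact source code):

```
def resize_scale(h, w, det_max_side_len=960):
    # Keep the original size when both sides fit; otherwise shrink
    # proportionally so that the longer side equals det_max_side_len.
    if max(h, w) <= det_max_side_len:
        return 1.0
    return float(det_max_side_len) / max(h, w)

print(resize_scale(720, 1280))  # 0.75: the 1280-pixel side is scaled to 960
```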
To predict on CPU instead, run:
```
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
```
### 2. DB text detection model inference

First, convert the model saved during DB training into an inference model. Taking the model trained on the ICDAR2015 English dataset with the ResNet50_vd backbone as an example ([model download link](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)), it can be converted with:
```
# -c selects the yml configuration file of the training algorithm
# Global.checkpoints points to the trained model to convert, without the file suffix .pdmodel, .pdopt or .pdparams.
# Global.save_inference_dir is the directory where the converted model will be saved.
python3 tools/export_model.py -c configs/det/det_r50_vd_db.yml -o Global.checkpoints="./models/det_r50_vd_db/best_accuracy" Global.save_inference_dir="./inference/det_db"
```
For DB text detection model inference, run:
```
python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_db/"
```
The visualized text detection results are saved to the ./inference_results folder by default, with file names prefixed 'det_res'. Example:

![](imgs_results/det_res_img_10_db.jpg)

**Note**: since the ICDAR2015 dataset has only 1,000 training images and mainly targets English scenes, the model above performs very poorly on Chinese text images.

### 3. EAST text detection model inference

First, convert the model saved during EAST training into an inference model. Taking the model trained on the ICDAR2015 English dataset with the ResNet50_vd backbone as an example ([model download link](https://paddleocr.bj.bcebos.com/det_r50_vd_east.tar)), it can be converted with:
```
# -c selects the yml configuration file of the training algorithm
# Global.checkpoints points to the trained model to convert, without the file suffix .pdmodel, .pdopt or .pdparams.
# Global.save_inference_dir is the directory where the converted model will be saved.
python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.checkpoints="./models/det_r50_vd_east/best_accuracy" Global.save_inference_dir="./inference/det_east"
```
EAST text detection model inference requires setting the parameter det_algorithm to EAST; run:
```
python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST"
```
The visualized text detection results are saved to the ./inference_results folder by default, with file names prefixed 'det_res'. Example:

![](imgs_results/det_res_img_10_east.jpg)

**Note**: the NMS in the EAST post-processing of this codebase uses the Python implementation, so prediction is rather slow. A C++ version would be noticeably faster.

## Text recognition model inference

The following introduces ultra-lightweight Chinese recognition model inference and CTC-loss-based recognition model inference. **Inference of attention-loss-based recognition models is still being debugged.** For Chinese text recognition, we recommend choosing CTC-loss-based recognition models first; in practice, attention-loss-based models have also been found to perform worse than CTC-loss-based ones.

### 1. Ultra-lightweight Chinese recognition model inference

For ultra-lightweight Chinese recognition model inference, run the following command:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --rec_model_dir="./inference/rec_crnn/"
```
![](imgs_words/ch/word_4.jpg)

After running the command, the prediction result (recognized text and score) of the image above is printed to the screen, for example:

Predicts of ./doc/imgs_words/ch/word_4.jpg:['实力活力', 0.89552695]

### 2. CTC-loss-based recognition model inference

We take STAR-Net as an example of CTC-loss-based recognition model inference. CRNN and Rosetta are used in a similar way, without setting the recognition algorithm parameter rec_algorithm.

First, convert the model saved during STAR-Net training into an inference model. Taking the model trained on the MJSynth and SynthText English text recognition synthetic datasets with the Resnet34_vd backbone as an example ([model download link](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_ctc.tar)), it can be converted with:
```
# -c selects the yml configuration file of the training algorithm
# Global.checkpoints points to the trained model to convert, without the file suffix .pdmodel, .pdopt or .pdparams.
# Global.save_inference_dir is the directory where the converted model will be saved.
python3 tools/export_model.py -c configs/rec/rec_r34_vd_tps_bilstm_ctc.yml -o Global.checkpoints="./models/rec_r34_vd_tps_bilstm_ctc/best_accuracy" Global.save_inference_dir="./inference/starnet"
```
For STAR-Net text recognition model inference, run:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
```
### 3. Attention-loss-based recognition model inference

Unlike CTC, attention-loss-based recognition models additionally require the recognition algorithm parameter --rec_algorithm="RARE".
For RARE text recognition model inference, run:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/sare/" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_algorithm="RARE"
```
![](imgs_words_en/word_336.png)

After running the command, the recognition result of the image above is:

Predicts of ./doc/imgs_words_en/word_336.png:['super', 0.9999555]

**Note**: since the models above follow the [DTRB](https://arxiv.org/abs/1904.01906) text recognition training and evaluation pipeline, they differ from the ultra-lightweight Chinese recognition model training in two respects:

- Image resolution during training: the models above were trained at the resolution [3, 32, 100], while Chinese model training uses [3, 32, 320] to guarantee recognition of long text. The inference program defaults to the shape used for Chinese training, i.e. [3, 32, 320], so when running these English models, set the recognition image shape with the rec_image_shape parameter.
- Character list: the DTRB paper experiments with only 26 lowercase English letters and 10 digits, 36 characters in total. All upper- and lowercase characters are converted to lowercase, and characters not in the list are ignored and treated as spaces. Therefore no character dictionary file is used; instead the dictionary is generated by the code below. Hence the parameter rec_char_type must be set to "en" at inference time.
```
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
```
## Concatenated text detection and recognition inference

### 1. Ultra-lightweight Chinese OCR model inference

When running prediction, specify the path of a single image or an image folder with the parameter image_dir, the path of the detection inference model with det_model_dir, and the path of the recognition inference model with rec_model_dir. The visualized recognition results are saved to the ./inference_results folder by default.
```
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/"  --rec_model_dir="./inference/rec_crnn/"
```
After running the command, the recognition result image is:

![](imgs_results/2.jpg)

### 2. Inference with other models

If you want to try other detection or recognition algorithms, refer to the text detection and text recognition model inference sections above and update the corresponding configuration and models. Below is the command for detection with EAST and recognition with STAR-Net:
```
python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
```
After running the command, the recognition result image is:

![](imgs_results/img_10.jpg)

doc/installation.md (new file, 81 lines)

@@ -0,0 +1,81 @@
## Quick installation

PaddleOCR has been tested to run on glibc 2.23. You can also try other glibc versions, or install glibc 2.23 yourself.

PaddleOCR working environment:
- PaddlePaddle 1.7
- python3
- glibc 2.23

We recommend running PaddleOCR with the Docker image we provide; for Docker usage, see [this link](https://docs.docker.com/get-started/).

*If you want to run the prediction code directly on Mac or Windows, start from step 2.*

1. (Recommended) Prepare a Docker environment. The first time you use this image, it will be downloaded automatically; please wait patiently.
```
# Switch to the working directory
cd /home/Projects
# The first run creates a docker container; subsequent runs do not need this command
# Create a docker container named ppocr and map the current directory to the /paddle directory of the container

If you want to use Docker in a CPU environment, create the container with docker instead of nvidia-docker:
sudo docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev /bin/bash

If your machine has CUDA9 installed, create the container with:
sudo nvidia-docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev /bin/bash

If your machine has CUDA10 installed, create the container with:
sudo nvidia-docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda10.0-cudnn7-dev /bin/bash

You can also visit [DockerHub](https://hub.docker.com/r/paddlepaddle/paddle/tags/) to get an image that fits your machine.

# ctrl+P+Q detaches from the container; to re-enter it, run:
sudo docker container exec -it ppocr /bin/bash
```
Note: if docker pull is too slow, you can download and load the docker image manually as follows, taking the CUDA9 image as an example; for the CUDA10 image, just replace cuda9 with cuda10.
```
# Download the compressed file of the CUDA9 docker image and decompress it
wget https://paddleocr.bj.bcebos.com/docker/docker_pdocr_cuda9.tar.gz
# To reduce download time, the uploaded docker image is compressed and must be decompressed before use
tar zxf docker_pdocr_cuda9.tar.gz
# Create the image
docker load < docker_pdocr_cuda9.tar
# After the steps above, check with docker images whether the downloaded image has been loaded
docker images
# If docker images prints the following output, you can create the docker environment following step 1.
hub.baidubce.com/paddlepaddle/paddle   latest-gpu-cuda9.0-cudnn7-dev   f56310dcc829
```
2. Install PaddlePaddle Fluid v1.7 (higher versions are not supported yet; adaptation is in progress)
```
pip3 install --upgrade pip

If your machine has CUDA9 installed, run:
python3 -m pip install paddlepaddle-gpu==1.7.2.post97 -i https://pypi.tuna.tsinghua.edu.cn/simple

If your machine has CUDA10 installed, run:
python3 -m pip install paddlepaddle-gpu==1.7.2.post107 -i https://pypi.tuna.tsinghua.edu.cn/simple

If your machine is CPU-only, run:
python3 -m pip install paddlepaddle==1.7.2 -i https://pypi.tuna.tsinghua.edu.cn/simple

For other version requirements, follow the instructions in the [installation document](https://www.paddlepaddle.org.cn/install/quick).
```
3. Clone the PaddleOCR repo
```
# (Recommended)
git clone https://github.com/PaddlePaddle/PaddleOCR

If pulling from GitHub fails because of network problems, you can also use the repository hosted on Gitee:
git clone https://gitee.com/paddlepaddle/PaddleOCR

Note: the code hosted on Gitee may not sync with this GitHub project in real time; there is a delay of 3-5 days. Please use the recommended way first.
```
4. Install third-party libraries
```
cd PaddleOCR
pip3 install -r requirments.txt
```

doc/recognition.md (new file, 225 lines)

@@ -0,0 +1,225 @@
## Text Recognition

### Data preparation

PaddleOCR supports two data formats: `lmdb`, used to train on public datasets and debug algorithms, and `general data`, used to train on your own data.

Set up the dataset as follows.

The default storage path for training data is `PaddleOCR/train_data`. If you already have a dataset on disk, just create a soft link to the dataset directory:
```
ln -sf <path/to/dataset> <path/to/paddle_detection>/train_data/dataset
```
* Data download

If you do not have a dataset locally, you can download [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) from the official website for a quick trial, or refer to [DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here) to download the lmdb-format datasets used by the benchmark.

* Using your own dataset:

If you want to train on your own data, organize it as described below.

- Training set

First put the training images into one folder (train_images) and record the image paths and labels in a txt file (rec_gt_train.txt).

* Note: in the txt file, separate the image path and the image label with \t by default; any other separator will cause training errors.
```
" image file name             image label "

train_data/train_0001.jpg   简单可依赖
train_data/train_0002.jpg   用科技让复杂的世界更简单
```
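For illustration, this file can be read in Python as follows (a sketch, not PaddleOCR's reader; note the mandatory \t separator):

```
with open("train_data/rec_gt_train.txt", "r", encoding="utf-8") as f:
    for line in f:
        # image path and label are separated by a single tab
        img_path, label = line.rstrip("\n").split("\t", 1)
        print(img_path, label)
```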
PaddleOCR provides label files for training on the icdar2015 dataset; download them as follows:
```
# Training set labels
wget -P ./train_data/ic15_data  https://paddleocr.bj.bcebos.com/dataset/rec_gt_train.txt
# Test set labels
wget -P ./train_data/ic15_data  https://paddleocr.bj.bcebos.com/dataset/rec_gt_test.txt
```
The final training set should have the following file structure:
```
|-train_data
    |-ic15_data
        |- rec_gt_train.txt
        |- train
            |- word_001.png
            |- word_002.jpg
            |- word_003.jpg
            | ...
```
- Test set

Similar to the training set, the test set also needs a folder (test) containing all images and a rec_gt_test.txt file; the structure of the test set is:
```
|-train_data
    |-ic15_data
        |- rec_gt_test.txt
        |- test
            |- word_001.jpg
            |- word_002.jpg
            |- word_003.jpg
            | ...
```
- Dictionary

Finally, a dictionary ({word_dict_name}.txt) must be provided so that, during training, the model can map every character that occurs to an index in the dictionary.

The dictionary therefore needs to contain every character that you want recognized correctly. {word_dict_name}.txt is written in the following format and saved with `utf-8` encoding:
```
l
d
a
d
r
n
```
Each line of word_dict.txt holds a single character, pairing characters with numeric indices; for example, "and" is mapped to [2 5 1] (see the sketch at the end of this section).

`ppocr/utils/ppocr_keys_v1.txt` is a Chinese dictionary with 6,623 characters, and
`ppocr/utils/ic15_dict.txt` is an English dictionary with 36 characters;
you can use them as needed.

To customize the dict file, modify the `character_dict_path` field in `configs/rec/rec_icdar15_train.yml` and set `character_type` to `ch`.
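A small sketch of the index mapping described above (illustrative only; the real mapping is implemented by PaddleOCR's CharacterOps). Keeping the first index for a duplicated character, "and" maps to [2, 5, 1] with the example dictionary:

```
# Load the dictionary: one character per line, index = line number (from 0)
with open("word_dict.txt", "r", encoding="utf-8") as f:
    chars = [line.rstrip("\n") for line in f]

char_to_idx = {}
for idx, ch in enumerate(chars):
    char_to_idx.setdefault(ch, idx)  # first occurrence wins for duplicates

print([char_to_idx[c] for c in "and"])  # [2, 5, 1]
```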
### Start training

PaddleOCR provides training, evaluation and prediction scripts. This section takes the CRNN recognition model as an example.

First download a pretrained model; you can download a trained model and finetune it on the icdar2015 data:
```
cd PaddleOCR/
# Download the pretrained model of MobileNetV3
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/rec_mv3_none_bilstm_ctc.tar
# Decompress the model parameters
cd pretrain_models
tar -xf rec_mv3_none_bilstm_ctc.tar && rm -rf rec_mv3_none_bilstm_ctc.tar
```
Start training:

*If you installed the CPU version, set the `use_gpu` field in the configuration file to false.*
```
# Set the PYTHONPATH
export PYTHONPATH=$PYTHONPATH:.
# GPU training. Single-card and multi-card training are supported; specify the card ids with CUDA_VISIBLE_DEVICES
export CUDA_VISIBLE_DEVICES=0,1,2,3
# Train on the icdar15 English data
python3 tools/train.py -c configs/rec/rec_icdar15_train.yml
```
PaddleOCR supports alternating training and evaluation. You can set the evaluation frequency by modifying `eval_batch_step` in `configs/rec/rec_icdar15_train.yml`; by default the model is evaluated every 500 iterations, and the best-accuracy model is saved as `output/rec_CRNN/best_accuracy` during evaluation.

If the validation set is large, evaluation will be time-consuming; we recommend reducing the evaluation frequency, or evaluating after training finishes.
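For example, to evaluate every 500 iterations, you would set the following in the Global section of the training configuration (for comparison, the new Chinese recognition config added in this commit uses eval_batch_step: 2000):

```
eval_batch_step: 500
```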
* Tip: the -c option can select any of the model configurations under `configs/rec/` for training. The recognition algorithms supported by PaddleOCR are:

| Configuration file | Algorithm | backbone | trans | seq | pred |
| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: |
| rec_chinese_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc |
| rec_icdar15_train.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
| rec_mv3_none_bilstm_ctc.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
| rec_mv3_tps_bilstm_ctc.yml | STARNet | Mobilenet_v3 large 0.5 | tps | BiLSTM | ctc |
| rec_mv3_tps_bilstm_attn.yml | RARE | Mobilenet_v3 large 0.5 | tps | BiLSTM | attention |
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
| rec_r34_vd_tps_bilstm_attn.yml | RARE | Resnet34_vd | tps | BiLSTM | attention |
| rec_r34_vd_tps_bilstm_ctc.yml | STARNet | Resnet34_vd | tps | BiLSTM | ctc |

For training Chinese data, `rec_chinese_lite_train.yml` is recommended. If you want to try other algorithms on Chinese datasets, modify the configuration file as follows, taking `rec_mv3_none_none_ctc.yml` as an example:
```
Global:
  ...
  # Modify image_shape to fit long text
  image_shape: [3, 32, 320]
  ...
  # Modify the character type
  character_type: ch
  # Add a custom dictionary; if you modify the dictionary, point this path to the new file
  character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
  ...
  # Modify the reader type
  reader_yml: ./configs/rec/rec_chinese_reader.yml
  ...
  ...
```
**Note: the configuration file used for prediction/evaluation must be consistent with the one used for training.**

### Evaluation

The evaluation dataset can be set via the `label_file_path` field of EvalReader in `configs/rec/rec_icdar15_reader.yml`.

*Note*: when evaluating, make sure that the infer_img field in the configuration file is empty.
```
export CUDA_VISIBLE_DEVICES=0
# GPU evaluation; Global.checkpoints is the weight to be evaluated
python3 tools/eval.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy
```
### Prediction

* Prediction with the training engine

With a model trained by PaddleOCR, you can quickly run prediction with the following scripts.

By default, the prediction images are stored in `infer_img`, and the weights are specified with `-o Global.checkpoints`:
```
# Predict English results
python3 tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
```
Input image:

![](./imgs_words/en/word_1.png)

Prediction result for the input image:
```
infer_img: doc/imgs_words/en/word_1.png
     index: [19 24 18 23 29]
     word : joint
```
The configuration file used for prediction must be consistent with the one used for training. For example, if you finished training the Chinese model with `python3 tools/train.py -c configs/rec/rec_chinese_lite_train.yml`, you can predict with the Chinese model as follows:
```
# Predict Chinese results
python3 tools/infer_rec.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/ch/word_1.jpg
```
Input image:

![](./imgs_words/ch/word_1.jpg)

Prediction result for the input image:
```
infer_img: doc/imgs_words/ch/word_1.jpg
     index: [2092  177  312 2503]
     word : 韩国小馆
```

doc/update.md (new file, 10 lines)

@@ -0,0 +1,10 @@
# Release Notes
- 2020.6.5 Support exporting the `attention` model to `inference_model`
- 2020.6.5 Support outputting the prediction score when running recognition alone
- 2020.5.30 Provide an ultra-lightweight Chinese OCR online experience
- 2020.5.30 Model prediction and training supported on Windows
- 2020.5.30 Open-source the general Chinese OCR model
- 2020.5.14 Release the [PaddleOCR open course](https://www.bilibili.com/video/BV1nf4y1U7RX?p=4)
- 2020.5.14 Release the [PaddleOCR hands-on project](https://aistudio.baidu.com/aistudio/projectdetail/467229)
- 2020.5.14 Open-source the 8.6M ultra-lightweight Chinese OCR model


@@ -61,8 +61,6 @@ class TrainReader(object):
                if len(batch_outs) == self.batch_size:
                    yield batch_outs
                    batch_outs = []
-               if len(batch_outs) != 0:
-                   yield batch_outs
            return batch_iter_reader


@@ -17,6 +17,8 @@ import cv2
import numpy as np
import json
import sys
+from ppocr.utils.utility import initial_logger
+logger = initial_logger()

from .data_augment import AugmentData
from .random_crop_data import RandomCropData

@@ -100,6 +102,7 @@ class DBProcessTrain(object):
        img_path, gt_label = self.convert_label_infor(label_infor)
        imgvalue = cv2.imread(img_path)
        if imgvalue is None:
+            logger.info("{} does not exist!".format(img_path))
            return None
        data = self.make_data_dict(imgvalue, gt_label)
        data = AugmentData(data)


@@ -41,13 +41,18 @@ class LMDBReader(object):
        self.loss_type = params['loss_type']
        self.max_text_length = params['max_text_length']
        self.mode = params['mode']
+        self.drop_last = False
+        self.use_tps = False
+        if "tps" in params:
+            self.use_tps = True
        if params['mode'] == 'train':
            self.batch_size = params['train_batch_size_per_card']
-        elif params['mode'] == "eval":
-            self.batch_size = params['test_batch_size_per_card']
-        elif params['mode'] == "test":
-            self.batch_size = 1
-            self.infer_img = params["infer_img"]
+            self.drop_last = True
+        else:
+            self.batch_size = params['test_batch_size_per_card']
+            self.drop_last = False
+            self.infer_img = params['infer_img']

    def load_hierarchical_lmdb_dataset(self):
        lmdb_sets = {}
        dataset_idx = 0

@@ -100,13 +105,18 @@
            process_id = 0

        def sample_iter_reader():
-            if self.mode == 'test':
+            if self.mode != 'train' and self.infer_img is not None:
                image_file_list = get_image_file_list(self.infer_img)
                for single_img in image_file_list:
                    img = cv2.imread(single_img)
-                    if img.shape[-1]==1 or len(list(img.shape))==2:
+                    if img.shape[-1] == 1 or len(list(img.shape)) == 2:
                        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-                    norm_img = process_image(img, self.image_shape)
+                    norm_img = process_image(
+                        img=img,
+                        image_shape=self.image_shape,
+                        char_ops=self.char_ops,
+                        tps=self.use_tps,
+                        infer_mode=True)
                    yield norm_img
            else:
                lmdb_sets = self.load_hierarchical_lmdb_dataset()

@@ -126,9 +136,13 @@
                        if sample_info is None:
                            continue
                        img, label = sample_info
-                        outs = process_image(img, self.image_shape, label,
-                                             self.char_ops, self.loss_type,
-                                             self.max_text_length)
+                        outs = process_image(
+                            img=img,
+                            image_shape=self.image_shape,
+                            label=label,
+                            char_ops=self.char_ops,
+                            loss_type=self.loss_type,
+                            max_text_length=self.max_text_length)
                        if outs is None:
                            continue
                        yield outs

@@ -136,6 +150,7 @@
                    if finish_read_num == len(lmdb_sets):
                        break
                self.close_lmdb_dataset(lmdb_sets)

        def batch_iter_reader():
            batch_outs = []
            for outs in sample_iter_reader():

@@ -143,10 +158,11 @@
                if len(batch_outs) == self.batch_size:
                    yield batch_outs
                    batch_outs = []
-            if len(batch_outs) != 0:
-                yield batch_outs
+            if not self.drop_last:
+                if len(batch_outs) != 0:
+                    yield batch_outs

-        if self.mode != 'test':
+        if self.infer_img is None:
            return batch_iter_reader
        return sample_iter_reader

@@ -165,26 +181,34 @@ class SimpleReader(object):
        self.loss_type = params['loss_type']
        self.max_text_length = params['max_text_length']
        self.mode = params['mode']
+        self.infer_img = params['infer_img']
+        self.use_tps = False
+        if "tps" in params:
+            self.use_tps = True
        if params['mode'] == 'train':
            self.batch_size = params['train_batch_size_per_card']
-        elif params['mode'] == 'eval':
-            self.batch_size = params['test_batch_size_per_card']
+            self.drop_last = True
        else:
-            self.batch_size = 1
-            self.infer_img = params['infer_img']
+            self.batch_size = params['test_batch_size_per_card']
+            self.drop_last = False

    def __call__(self, process_id):
        if self.mode != 'train':
            process_id = 0

        def sample_iter_reader():
-            if self.mode == 'test':
+            if self.mode != 'train' and self.infer_img is not None:
                image_file_list = get_image_file_list(self.infer_img)
                for single_img in image_file_list:
                    img = cv2.imread(single_img)
-                    if img.shape[-1]==1 or len(list(img.shape))==2:
+                    if img.shape[-1] == 1 or len(list(img.shape)) == 2:
                        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-                    norm_img = process_image(img, self.image_shape)
+                    norm_img = process_image(
+                        img=img,
+                        image_shape=self.image_shape,
+                        char_ops=self.char_ops,
+                        tps=self.use_tps,
+                        infer_mode=True)
                    yield norm_img
            else:
                with open(self.label_file_path, "rb") as fin:

@@ -192,7 +216,7 @@
                    img_num = len(label_infor_list)
                    img_id_list = list(range(img_num))
                    random.shuffle(img_id_list)
-                    if sys.platform=="win32":
+                    if sys.platform == "win32":
                        print("multiprocess is not fully compatible with Windows."
                              "num_workers will be 1.")
                        self.num_workers = 1

@@ -204,7 +228,7 @@
                        if img is None:
                            logger.info("{} does not exist!".format(img_path))
                            continue
-                        if img.shape[-1]==1 or len(list(img.shape))==2:
+                        if img.shape[-1] == 1 or len(list(img.shape)) == 2:
                            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
                        label = substr[1]

@@ -222,9 +246,10 @@
                if len(batch_outs) == self.batch_size:
                    yield batch_outs
                    batch_outs = []
-            if len(batch_outs) != 0:
-                yield batch_outs
+            if not self.drop_last:
+                if len(batch_outs) != 0:
+                    yield batch_outs

-        if self.mode != 'test':
+        if self.infer_img is None:
            return batch_iter_reader
        return sample_iter_reader


@@ -48,6 +48,32 @@ def resize_norm_img(img, image_shape):
    return padding_im

+def resize_norm_img_chinese(img, image_shape):
+    imgC, imgH, imgW = image_shape
+    # todo: change to 0 and modified image shape
+    max_wh_ratio = 0
+    h, w = img.shape[0], img.shape[1]
+    ratio = w * 1.0 / h
+    max_wh_ratio = max(max_wh_ratio, ratio)
+    # widen the target width to match the image's aspect ratio at height 32
+    imgW = int(32 * max_wh_ratio)
+    if math.ceil(imgH * ratio) > imgW:
+        resized_w = imgW
+    else:
+        resized_w = int(math.ceil(imgH * ratio))
+    resized_image = cv2.resize(img, (resized_w, imgH))
+    resized_image = resized_image.astype('float32')
+    if image_shape[0] == 1:
+        resized_image = resized_image / 255
+        resized_image = resized_image[np.newaxis, :]
+    else:
+        resized_image = resized_image.transpose((2, 0, 1)) / 255
+    # normalize to [-1, 1] and right-pad with zeros up to imgW
+    resized_image -= 0.5
+    resized_image /= 0.5
+    padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+    padding_im[:, :, 0:resized_w] = resized_image
+    return padding_im

def get_img_data(value):
    """get_img_data"""
    if not value:

@@ -66,8 +92,13 @@ def process_image(img,
                  label=None,
                  char_ops=None,
                  loss_type=None,
-                  max_text_length=None):
-    norm_img = resize_norm_img(img, image_shape)
+                  max_text_length=None,
+                  tps=None,
+                  infer_mode=False):
+    # at inference time, Chinese models without TPS keep a variable width
+    if not infer_mode or char_ops.character_type == "en" or tps != None:
+        norm_img = resize_norm_img(img, image_shape)
+    else:
+        norm_img = resize_norm_img_chinese(img, image_shape)
    norm_img = norm_img[np.newaxis, :]
    if label is not None:
        char_num = char_ops.get_char_num()


@@ -30,6 +30,8 @@ class RecModel(object):
        global_params = params['Global']
        char_num = global_params['char_ops'].get_char_num()
        global_params['char_num'] = char_num
+        self.char_type = global_params['character_type']
+        self.infer_img = global_params['infer_img']
        if "TPS" in params:
            tps_params = deepcopy(params["TPS"])
            tps_params.update(global_params)

@@ -60,8 +62,8 @@ class RecModel(object):
    def create_feed(self, mode):
        image_shape = deepcopy(self.image_shape)
        image_shape.insert(0, -1)
-        image = fluid.data(name='image', shape=image_shape, dtype='float32')
        if mode == "train":
+            image = fluid.data(name='image', shape=image_shape, dtype='float32')
            if self.loss_type == "attention":
                label_in = fluid.data(
                    name='label_in',

@@ -86,6 +88,16 @@ class RecModel(object):
                use_double_buffer=True,
                iterable=False)
        else:
+            if self.char_type == "ch" and self.infer_img:
+                image_shape[-1] = -1
+                if self.tps != None:
+                    logger.info(
+                        "WARNING!!!\n"
+                        "TPS does not support variable shape for Chinese models!"
+                        "We set img_shape to be the same, it may affect the inference effect"
+                    )
+                    image_shape = deepcopy(self.image_shape)
+            image = fluid.data(name='image', shape=image_shape, dtype='float32')
            labels = None
            loader = None
        return image, labels, loader

@@ -110,7 +122,11 @@ class RecModel(object):
            return loader, outputs
        elif mode == "export":
            predict = predicts['predict']
-            predict = fluid.layers.softmax(predict)
+            if self.loss_type == "ctc":
+                predict = fluid.layers.softmax(predict)
            return [image, {'decoded_out': decoded_out, 'predicts': predict}]
        else:
-            return loader, {'decoded_out': decoded_out}
+            predict = predicts['predict']
+            if self.loss_type == "ctc":
+                predict = fluid.layers.softmax(predict)
+            return loader, {'decoded_out': decoded_out, 'predicts': predict}


@@ -123,6 +123,8 @@ class AttentionPredict(object):
        full_ids = fluid.layers.fill_constant_batch_size_like(
            input=init_state, shape=[-1, 1], dtype='int64', value=1)
+        full_scores = fluid.layers.fill_constant_batch_size_like(
+            input=init_state, shape=[-1, 1], dtype='float32', value=1)

        cond = layers.less_than(x=counter, y=array_len)
        while_op = layers.While(cond=cond)

@@ -171,6 +173,9 @@ class AttentionPredict(object):
            new_ids = fluid.layers.concat([full_ids, topk_indices], axis=1)
            fluid.layers.assign(new_ids, full_ids)
+            new_scores = fluid.layers.concat([full_scores, topk_scores], axis=1)
+            fluid.layers.assign(new_scores, full_scores)

            layers.increment(x=counter, value=1, in_place=True)

            # update the memories

@@ -184,7 +189,7 @@ class AttentionPredict(object):
            length_cond = layers.less_than(x=counter, y=array_len)
            finish_cond = layers.logical_not(layers.is_empty(x=topk_indices))
            layers.logical_and(x=length_cond, y=finish_cond, out=cond)
-        return full_ids
+        return full_ids, full_scores

    def __call__(self, inputs, labels=None, mode=None):
        encoder_features = self.encoder(inputs)

@@ -223,10 +228,10 @@ class AttentionPredict(object):
                decoder_size, char_num)
            _, decoded_out = layers.topk(input=predict, k=1)
            decoded_out = layers.lod_reset(decoded_out, y=label_out)
            predicts = {'predict': predict, 'decoded_out': decoded_out}
        else:
-            ids = self.gru_attention_infer(
+            ids, predict = self.gru_attention_infer(
                decoder_boot, self.max_length, char_num, word_vector_dim,
                encoded_vector, encoded_proj, decoder_size)
-            predicts = {'decoded_out': ids}
+            predicts = {'predict': predict, 'decoded_out': ids}
        return predicts


@@ -48,7 +48,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
    total_sample_num = 0
    total_acc_num = 0
    total_batch_num = 0
-    if mode == "test":
+    if mode == "eval":
        is_remove_duplicate = False
    else:
        is_remove_duplicate = True

@@ -91,11 +91,11 @@ def test_rec_benchmark(exe, config, eval_info_dict):
    total_correct_number = 0
    eval_data_acc_info = {}
    for eval_data in eval_data_list:
-        config['EvalReader']['lmdb_sets_dir'] = \
+        config['TestReader']['lmdb_sets_dir'] = \
            eval_data_dir + "/" + eval_data
-        eval_reader = reader_main(config=config, mode="eval")
+        eval_reader = reader_main(config=config, mode="test")
        eval_info_dict['reader'] = eval_reader
-        metrics = eval_rec_run(exe, config, eval_info_dict, "eval")
+        metrics = eval_rec_run(exe, config, eval_info_dict, "test")
        total_evaluation_data_number += metrics['total_sample_num']
        total_correct_number += metrics['total_acc_num']
        eval_data_acc_info[eval_data] = metrics


@@ -32,10 +32,16 @@ class TextRecognizer(object):
        self.rec_image_shape = image_shape
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
+        self.rec_algorithm = args.rec_algorithm
        char_ops_params = {}
        char_ops_params["character_type"] = args.rec_char_type
        char_ops_params["character_dict_path"] = args.rec_char_dict_path
-        char_ops_params['loss_type'] = 'ctc'
+        if self.rec_algorithm != "RARE":
+            char_ops_params['loss_type'] = 'ctc'
+            self.loss_type = 'ctc'
+        else:
+            char_ops_params['loss_type'] = 'attention'
+            self.loss_type = 'attention'
        self.char_ops = CharacterOps(char_ops_params)

    def resize_norm_img(self, img, max_wh_ratio):

@@ -80,26 +86,43 @@ class TextRecognizer(object):
            starttime = time.time()
            self.input_tensor.copy_from_cpu(norm_img_batch)
            self.predictor.zero_copy_run()
-            rec_idx_batch = self.output_tensors[0].copy_to_cpu()
-            rec_idx_lod = self.output_tensors[0].lod()[0]
-            predict_batch = self.output_tensors[1].copy_to_cpu()
-            predict_lod = self.output_tensors[1].lod()[0]
-            elapse = time.time() - starttime
-            predict_time += elapse
-            starttime = time.time()
-            for rno in range(len(rec_idx_lod) - 1):
-                beg = rec_idx_lod[rno]
-                end = rec_idx_lod[rno + 1]
-                rec_idx_tmp = rec_idx_batch[beg:end, 0]
-                preds_text = self.char_ops.decode(rec_idx_tmp)
-                beg = predict_lod[rno]
-                end = predict_lod[rno + 1]
-                probs = predict_batch[beg:end, :]
-                ind = np.argmax(probs, axis=1)
-                blank = probs.shape[1]
-                valid_ind = np.where(ind != (blank - 1))[0]
-                score = np.mean(probs[valid_ind, ind[valid_ind]])
-                rec_res.append([preds_text, score])
+            if self.loss_type == "ctc":
+                rec_idx_batch = self.output_tensors[0].copy_to_cpu()
+                rec_idx_lod = self.output_tensors[0].lod()[0]
+                predict_batch = self.output_tensors[1].copy_to_cpu()
+                predict_lod = self.output_tensors[1].lod()[0]
+                elapse = time.time() - starttime
+                predict_time += elapse
+                for rno in range(len(rec_idx_lod) - 1):
+                    beg = rec_idx_lod[rno]
+                    end = rec_idx_lod[rno + 1]
+                    rec_idx_tmp = rec_idx_batch[beg:end, 0]
+                    preds_text = self.char_ops.decode(rec_idx_tmp)
+                    beg = predict_lod[rno]
+                    end = predict_lod[rno + 1]
+                    probs = predict_batch[beg:end, :]
+                    ind = np.argmax(probs, axis=1)
+                    blank = probs.shape[1]
+                    valid_ind = np.where(ind != (blank - 1))[0]
+                    score = np.mean(probs[valid_ind, ind[valid_ind]])
+                    rec_res.append([preds_text, score])
+            else:
+                rec_idx_batch = self.output_tensors[0].copy_to_cpu()
+                predict_batch = self.output_tensors[1].copy_to_cpu()
+                elapse = time.time() - starttime
+                predict_time += elapse
+                for rno in range(len(rec_idx_batch)):
+                    end_pos = np.where(rec_idx_batch[rno, :] == 1)[0]
+                    if len(end_pos) <= 1:
+                        preds = rec_idx_batch[rno, 1:]
+                        score = np.mean(predict_batch[rno, 1:])
+                    else:
+                        preds = rec_idx_batch[rno, 1:end_pos[1]]
+                        score = np.mean(predict_batch[rno, 1:end_pos[1]])
+                    preds_text = self.char_ops.decode(preds)
+                    rec_res.append([preds_text, score])

        return rec_res, predict_time

@@ -116,7 +139,17 @@ if __name__ == "__main__":
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
-    rec_res, predict_time = text_recognizer(img_list)
+    try:
+        rec_res, predict_time = text_recognizer(img_list)
+    except Exception as e:
+        print(e)
+        logger.info(
+            "ERROR!!!! \n"
+            "Please read the FAQ: https://github.com/PaddlePaddle/PaddleOCR#faq \n"
+            "If your model has a tps module: "
+            "TPS does not support variable shape.\n"
+            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
+        exit()
    for ino in range(len(img_list)):
        print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
    print("Total predict time for %d images:%.3f" %


@@ -21,6 +21,7 @@ import time
import multiprocessing
import numpy as np

def set_paddle_flags(**kwargs):
    for key, value in kwargs.items():
        if os.environ.get(key, None) is None:

@@ -54,6 +55,7 @@ def main():
    program.merge_config(FLAGS.opt)
    logger.info(config)
    char_ops = CharacterOps(config['Global'])
+    loss_type = config['Global']['loss_type']
    config['Global']['char_ops'] = char_ops

    # check if set use_gpu=True in paddlepaddle cpu version

@@ -78,35 +80,44 @@ def main():
    init_model(config, eval_prog, exe)

    blobs = reader_main(config, 'test')()
-    infer_img = config['TestReader']['infer_img']
+    infer_img = config['Global']['infer_img']
    infer_list = get_image_file_list(infer_img)
    max_img_num = len(infer_list)
    if len(infer_list) == 0:
        logger.info("Can not find img in infer_img dir.")
    for i in range(max_img_num):
-        print("infer_img:", infer_list[i])
+        print("infer_img:%s" % infer_list[i])
        img = next(blobs)
        predict = exe.run(program=eval_prog,
                          feed={"image": img},
                          fetch_list=fetch_varname_list,
                          return_numpy=False)
-        preds = np.array(predict[0])
-        if preds.shape[1] == 1:
-            preds = preds.reshape(-1)
-            preds_lod = predict[0].lod()[0]
-            preds_text = char_ops.decode(preds)
-        else:
-            end_pos = np.where(preds[0, :] == 1)[0]
-            if len(end_pos) <= 1:
-                preds_text = preds[0, 1:]
-            else:
-                preds_text = preds[0, 1:end_pos[1]]
-            preds_text = preds_text.reshape(-1)
-            preds_text = char_ops.decode(preds_text)
+        if loss_type == "ctc":
+            preds = np.array(predict[0])
+            preds = preds.reshape(-1)
+            preds_lod = predict[0].lod()[0]
+            preds_text = char_ops.decode(preds)
+            probs = np.array(predict[1])
+            ind = np.argmax(probs, axis=1)
+            blank = probs.shape[1]
+            valid_ind = np.where(ind != (blank - 1))[0]
+            score = np.mean(probs[valid_ind, ind[valid_ind]])
+        elif loss_type == "attention":
+            preds = np.array(predict[0])
+            probs = np.array(predict[1])
+            end_pos = np.where(preds[0, :] == 1)[0]
+            if len(end_pos) <= 1:
+                preds = preds[0, 1:]
+                score = np.mean(probs[0, 1:])
+            else:
+                preds = preds[0, 1:end_pos[1]]
+                score = np.mean(probs[0, 1:end_pos[1]])
+            preds = preds.reshape(-1)
+            preds_text = char_ops.decode(preds)
-        print("\t index:",preds)
-        print("\t word :",preds_text)
+        print("\t index:", preds)
+        print("\t word :", preds_text)
+        print("\t score :", score)

    # save for inference model
    target_var = []


@@ -114,7 +114,7 @@ def merge_config(config):
            global_config[key] = value
        else:
            sub_keys = key.split('.')
-            assert (sub_keys[0] in global_config)
+            assert (sub_keys[0] in global_config), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(global_config.keys(), sub_keys[0])
            cur = global_config[sub_keys[0]]
            for idx, sub_key in enumerate(sub_keys[1:]):
                assert (sub_key in cur)