diff --git a/README.md b/README.md index a5c450ea..d3ce2c8a 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,12 @@ PaddleOCR aims to create a rich, leading, and practical OCR tools that help users train better models and apply them into practice. **Recent updates** -- 2020.5.30,Model prediction and training support Windows systems, and the display of recognition results is optimized -- 2020.5.30,Open source general Chinese OCR model -- 2020.5.30,Provide Ultra-lightweight Chinese OCR model inference +- 2020.6.8 Add [dataset](./doc/datasets.md) and keep updating +- 2020.6.5 Add `attention` model in `inference_model` +- 2020.6.5 Support separate prediction and recognition, output result score +- 2020.5.30 Provide ultra-lightweight Chinese OCR online experience +- 2020.5.30 Model prediction and training supported on Windows system +- [more](./doc/update.md) ## Features - Ultra-lightweight Chinese OCR model, total model size is only 8.6M @@ -38,6 +41,8 @@ Please see [Quick installation](./doc/installation.md) #### 2. Download inference models #### (1) Download Ultra-lightweight Chinese OCR models +*If wget is not installed in the windows system, you can copy the link to the browser to download the model. After model downloaded, unzip it and place it in the corresponding directory* + ``` mkdir inference && cd inference # Download the detection part of the Ultra-lightweight Chinese OCR and decompress it @@ -64,6 +69,9 @@ The following code implements text detection and recognition inference tandemly. # Set PYTHONPATH environment variable export PYTHONPATH=. +# Setting environment variable in Windows +SET PYTHONPATH=. + # Prediction on a single image by specifying image path to image_dir python3 tools/infer/predict_system.py --image_dir="./doc/imgs/11.jpg" --det_model_dir="./inference/ch_det_mv3_db/" --rec_model_dir="./inference/ch_rec_mv3_crnn/" @@ -87,6 +95,7 @@ For more text detection and recognition models, please refer to the document [In - [Text detection model training/evaluation/prediction](./doc/detection.md) - [Text recognition model training/evaluation/prediction](./doc/recognition.md) - [Inference](./doc/inference.md) +- [Dataset](./doc/datasets.md) ## Text detection algorithm @@ -104,6 +113,12 @@ On the ICDAR2015 dataset, the text detection result is as follows: |DB|ResNet50_vd|83.79%|80.65%|82.19%|[Download link](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)| |DB|MobileNetV3|75.92%|73.18%|74.53%|[Download link](https://paddleocr.bj.bcebos.com/det_mv3_db.tar)| +For use of [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/datasets.md#1icdar2019-lsvt) street view dataset with a total of 3w training data,the related configuration and pre-trained models for Chinese detection task are as follows: +|Model|Backbone|Configuration file|Pre-trained model| +|-|-|-|-| +|Ultra-lightweight Chinese model|MobileNetV3|det_mv3_db.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar)| +|General Chinese OCR model|ResNet50_vd|det_r50_vd_db.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_det_r50_vd_db.tar)| + * Note: For the training and evaluation of the above DB model, post-processing parameters box_thresh=0.6 and unclip_ratio=1.5 need to be set. If using different datasets and different models for training, these two parameters can be adjusted for better result. For the training guide and use of PaddleOCR text detection algorithms, please refer to the document [Text detection model training/evaluation/prediction](./doc/detection.md) @@ -130,6 +145,12 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |RARE|Resnet34_vd|84.90%|rec_r34_vd_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_attn.tar)| |RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar)| +We use [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/datasets.md#1icdar2019-lsvt) dataset and cropout 30w traning data from original photos by using position groundtruth and make some calibration needed. In addition, based on the LSVT corpus, 500w synthetic data is generated to train the Chinese model. The related configuration and pre-trained models are as follows: +|Model|Backbone|Configuration file|Pre-trained model| +|-|-|-|-| +|Ultra-lightweight Chinese model|MobileNetV3|rec_chinese_lite_train.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar)| +|General Chinese OCR model|Resnet34_vd|rec_chinese_common_train.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn.tar)| + Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./doc/recognition.md) ## End-to-end OCR algorithm @@ -172,6 +193,8 @@ Please refer to the document for training guide and use of PaddleOCR text recogn 5. Release time of self-developed algorithm Baidu Self-developed algorithms such as SAST, SRN and end2end PSL will be released in June or July. Please be patient. + +[more](./doc/FAQ.md) ## Welcome to the PaddleOCR technical exchange group Add Wechat: paddlehelp, remark OCR, small assistant will pull you into the group ~ diff --git a/configs/rec/rec_benchmark_reader.yml b/configs/rec/rec_benchmark_reader.yml index 3d1e3e0b..524f2f68 100755 --- a/configs/rec/rec_benchmark_reader.yml +++ b/configs/rec/rec_benchmark_reader.yml @@ -10,4 +10,3 @@ EvalReader: TestReader: reader_function: ppocr.data.rec.dataset_traversal,LMDBReader lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/ - infer_img: ./infer_img diff --git a/configs/rec/rec_chinese_common_train.yml b/configs/rec/rec_chinese_common_train.yml new file mode 100644 index 00000000..af56dca2 --- /dev/null +++ b/configs/rec/rec_chinese_common_train.yml @@ -0,0 +1,43 @@ +Global: + algorithm: CRNN + use_gpu: true + epoch_num: 3000 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec_CRNN + save_epoch_step: 3 + eval_batch_step: 2000 + train_batch_size_per_card: 128 + test_batch_size_per_card: 128 + image_shape: [3, 32, 320] + max_text_length: 25 + character_type: ch + character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt + loss_type: ctc + reader_yml: ./configs/rec/rec_chinese_reader.yml + pretrain_weights: + checkpoints: + save_inference_dir: + infer_img: + +Architecture: + function: ppocr.modeling.architectures.rec_model,RecModel + +Backbone: + function: ppocr.modeling.backbones.rec_resnet_vd,ResNet + layers: 34 + +Head: + function: ppocr.modeling.heads.rec_ctc_head,CTCPredict + encoder_type: rnn + SeqRNN: + hidden_size: 256 + +Loss: + function: ppocr.modeling.losses.rec_ctc_loss,CTCLoss + +Optimizer: + function: ppocr.optimizer,AdamDecay + base_lr: 0.0005 + beta1: 0.9 + beta2: 0.999 diff --git a/configs/rec/rec_chinese_lite_train.yml b/configs/rec/rec_chinese_lite_train.yml index ec1b7a69..b64313a1 100755 --- a/configs/rec/rec_chinese_lite_train.yml +++ b/configs/rec/rec_chinese_lite_train.yml @@ -15,9 +15,11 @@ Global: character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt loss_type: ctc reader_yml: ./configs/rec/rec_chinese_reader.yml - pretrain_weights: + pretrain_weights: checkpoints: save_inference_dir: + infer_img: + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_chinese_reader.yml b/configs/rec/rec_chinese_reader.yml index f09a1ea7..a44efd99 100755 --- a/configs/rec/rec_chinese_reader.yml +++ b/configs/rec/rec_chinese_reader.yml @@ -11,4 +11,3 @@ EvalReader: TestReader: reader_function: ppocr.data.rec.dataset_traversal,SimpleReader - infer_img: ./infer_img diff --git a/configs/rec/rec_icdar15_reader.yml b/configs/rec/rec_icdar15_reader.yml index 12facda1..322d5f25 100755 --- a/configs/rec/rec_icdar15_reader.yml +++ b/configs/rec/rec_icdar15_reader.yml @@ -11,4 +11,3 @@ EvalReader: TestReader: reader_function: ppocr.data.rec.dataset_traversal,SimpleReader - infer_img: ./infer_img diff --git a/configs/rec/rec_icdar15_train.yml b/configs/rec/rec_icdar15_train.yml index 6596fc33..934a9410 100755 --- a/configs/rec/rec_icdar15_train.yml +++ b/configs/rec/rec_icdar15_train.yml @@ -14,9 +14,11 @@ Global: character_type: en loss_type: ctc reader_yml: ./configs/rec/rec_icdar15_reader.yml - pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy + pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy checkpoints: save_inference_dir: + infer_img: + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml index 11a09ee9..d2e096fb 100755 --- a/configs/rec/rec_mv3_none_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml @@ -17,6 +17,7 @@ Global: pretrain_weights: checkpoints: save_inference_dir: + infer_img: Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_mv3_none_none_ctc.yml b/configs/rec/rec_mv3_none_none_ctc.yml index bbbb6d1f..ceec09ce 100755 --- a/configs/rec/rec_mv3_none_none_ctc.yml +++ b/configs/rec/rec_mv3_none_none_ctc.yml @@ -17,6 +17,7 @@ Global: pretrain_weights: checkpoints: save_inference_dir: + infer_img: Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_mv3_tps_bilstm_attn.yml b/configs/rec/rec_mv3_tps_bilstm_attn.yml index 03a2e901..7fc4f679 100755 --- a/configs/rec/rec_mv3_tps_bilstm_attn.yml +++ b/configs/rec/rec_mv3_tps_bilstm_attn.yml @@ -13,11 +13,14 @@ Global: max_text_length: 25 character_type: en loss_type: attention + tps: true reader_yml: ./configs/rec/rec_benchmark_reader.yml pretrain_weights: checkpoints: save_inference_dir: - + infer_img: + + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_mv3_tps_bilstm_ctc.yml b/configs/rec/rec_mv3_tps_bilstm_ctc.yml index 47247b72..4b9660bc 100755 --- a/configs/rec/rec_mv3_tps_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_tps_bilstm_ctc.yml @@ -13,10 +13,12 @@ Global: max_text_length: 25 character_type: en loss_type: ctc + tps: true reader_yml: ./configs/rec/rec_benchmark_reader.yml pretrain_weights: checkpoints: save_inference_dir: + infer_img: Architecture: diff --git a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml index 10181936..b71e8fea 100755 --- a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml @@ -17,7 +17,9 @@ Global: pretrain_weights: checkpoints: save_inference_dir: - + infer_img: + + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_r34_vd_none_none_ctc.yml b/configs/rec/rec_r34_vd_none_none_ctc.yml index ff4c5763..d9c9458d 100755 --- a/configs/rec/rec_r34_vd_none_none_ctc.yml +++ b/configs/rec/rec_r34_vd_none_none_ctc.yml @@ -17,6 +17,7 @@ Global: pretrain_weights: checkpoints: save_inference_dir: + infer_img: Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_r34_vd_tps_bilstm_attn.yml b/configs/rec/rec_r34_vd_tps_bilstm_attn.yml index 4d96e9e7..dfcd97fa 100755 --- a/configs/rec/rec_r34_vd_tps_bilstm_attn.yml +++ b/configs/rec/rec_r34_vd_tps_bilstm_attn.yml @@ -13,10 +13,13 @@ Global: max_text_length: 25 character_type: en loss_type: attention + tps: true reader_yml: ./configs/rec/rec_benchmark_reader.yml pretrain_weights: checkpoints: save_inference_dir: + infer_img: + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml index 844721a2..574a088c 100755 --- a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml @@ -13,10 +13,13 @@ Global: max_text_length: 25 character_type: en loss_type: ctc + tps: true reader_yml: ./configs/rec/rec_benchmark_reader.yml pretrain_weights: checkpoints: save_inference_dir: + infer_img: + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/doc/FAQ.md b/doc/FAQ.md new file mode 100644 index 00000000..373e275d --- /dev/null +++ b/doc/FAQ.md @@ -0,0 +1,43 @@ +## FAQ + +1. **预测报错:got an unexpected keyword argument 'gradient_clip'** +安装的paddle版本不对,目前本项目仅支持paddle1.7,近期会适配到1.8。 + +2. **转换attention识别模型时报错:KeyError: 'predict'** +基于Attention损失的识别模型推理还在调试中。对于中文文本识别,建议优先选择基于CTC损失的识别模型,实践中也发现基于Attention损失的效果不如基于CTC损失的识别模型。 + +3. **关于推理速度** +图片中的文字较多时,预测时间会增,可以使用--rec_batch_num设置更小预测batch num,默认值为30,可以改为10或其他数值。 + +4. **服务部署与移动端部署** +预计6月中下旬会先后发布基于Serving的服务部署方案和基于Paddle Lite的移动端部署方案,欢迎持续关注。 + +5. **自研算法发布时间** +自研算法SAST、SRN、End2End-PSL都将在6-7月陆续发布,敬请期待。 + +6. **如何在Windows或Mac系统上运行** +PaddleOCR已完成Windows和Mac系统适配,运行时注意两点:1、在[快速安装](installation.md)时,如果不想安装docker,可跳过第一步,直接从第二步安装paddle开始。2、inference模型下载时,如果没有安装wget,可直接点击模型链接或将链接地址复制到浏览器进行下载,并解压放置到相应目录。 + +7. **超轻量模型和通用OCR模型的区别** +目前PaddleOCR开源了2个中文模型,分别是8.6M超轻量中文模型和通用中文OCR模型。两者对比信息如下: + - 相同点:两者使用相同的**算法**和**训练数据**; + - 不同点:不同之处在于**骨干网络**和**通道参数**,超轻量模型使用MobileNetV3作为骨干网络,通用模型使用Resnet50_vd作为检测模型backbone,Resnet34_vd作为识别模型backbone,具体参数差异可对比两种模型训练的配置文件. + +|模型|骨干网络|检测训练配置|识别训练配置| +|-|-|-|-| +|8.6M超轻量中文OCR模型|MobileNetV3+MobileNetV3|det_mv3_db.yml|rec_chinese_lite_train.yml| +|通用中文OCR模型|Resnet50_vd+Resnet34_vd|det_r50_vd_db.yml|rec_chinese_common_train.yml| + +8. **是否有计划开源仅识别数字或仅识别英文+数字的模型** +暂不计划开源仅数字、仅数字+英文、或其他小垂类专用模型。PaddleOCR开源了多种检测、识别算法供用户自定义训练,两种中文模型也是基于开源的算法库训练产出,有小垂类需求的小伙伴,可以按照教程准备好数据,选择合适的配置文件,自行训练,相信能有不错的效果。训练有任何问题欢迎提issue或在交流群提问,我们会及时解答。 + +9. **开源模型使用的训练数据是什么,能否开源** +目前开源的模型,数据集和量级如下: + - 检测: + 英文数据集,ICDAR2015 + 中文数据集,LSVT街景数据集训练数据3w张图片 + - 识别: + 英文数据集,MJSynth和SynthText合成数据,数据量上千万。 + 中文数据集,LSVT街景数据集根据真值将图crop出来,并进行位置校准,总共30w张图像。此外基于LSVT的语料,合成数据500w。 + + 其中,公开数据集都是开源的,用户可自行搜索下载,也可参考[中文数据集](datasets.md),合成数据暂不开源,用户可使用开源合成工具自行合成,可参考的合成工具包括[text_renderer](https://github.com/Sanster/text_renderer)、[SynthText](https://github.com/ankush-me/SynthText)、[TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator)等。 \ No newline at end of file diff --git a/doc/WeChat.jpeg b/doc/WeChat.jpeg deleted file mode 100644 index c721b7d9..00000000 Binary files a/doc/WeChat.jpeg and /dev/null differ diff --git a/doc/datasets.md b/doc/datasets.md new file mode 100644 index 00000000..1bca82c4 --- /dev/null +++ b/doc/datasets.md @@ -0,0 +1,58 @@ +## 数据集 +这里整理了常用中文数据集,持续更新中,欢迎各位小伙伴贡献数据集~ +- [ICDAR2019-LSVT](#ICDAR2019-LSVT) +- [ICDAR2017-RCTW-17](#ICDAR2017-RCTW-17) +- [中文街景文字识别](#中文街景文字识别) +- [中文文档文字识别](#中文文档文字识别) +- [ICDAR2019-ArT](#ICDAR2019-ArT) + +除了开源数据,用户还可使用合成工具自行合成,可参考的合成工具包括[text_renderer](https://github.com/Sanster/text_renderer)、[SynthText](https://github.com/ankush-me/SynthText)、[TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator)等。 + + +#### 1、ICDAR2019-LSVT +- **数据来源**:https://ai.baidu.com/broad/introduction?dataset=lsvt +- **数据简介**: 共45w中文街景图像,包含5w(2w测试+3w训练)全标注数据(文本坐标+文本内容),40w弱标注数据(仅文本内容),如下图所示: + ![](datasets/LSVT_1.jpg) + (a) 全标注数据 + ![](datasets/LSVT_2.jpg) + (b) 弱标注数据 +- **下载地址**:https://ai.baidu.com/broad/download?dataset=lsvt + + +#### 2、ICDAR2017-RCTW-17 +- **数据来源**:https://rctw.vlrlab.net/ +- **数据简介**:共包含12,000+图像,大部分图片是通过手机摄像头在野外采集的。有些是截图。这些图片展示了各种各样的场景,包括街景、海报、菜单、室内场景和手机应用程序的截图。 + ![](datasets/rctw.jpg) +- **下载地址**:https://rctw.vlrlab.net/dataset/ + + +#### 3、中文街景文字识别 +- **数据来源**:https://aistudio.baidu.com/aistudio/competition/detail/8 +- **数据简介**:共包括29万张图片,其中21万张图片作为训练集(带标注),8万张作为测试集(无标注)。数据集采自中国街景,并由街景图片中的文字行区域(例如店铺标牌、地标等等)截取出来而形成。所有图像都经过一些预处理,将文字区域利用仿射变化,等比映射为一张高为48像素的图片,如图所示: + ![](datasets/ch_street_rec_1.png) + (a) 标注:魅派集成吊顶 + ![](datasets/ch_street_rec_2.png) + (b) 标注:母婴用品连锁 +- **下载地址** +https://aistudio.baidu.com/aistudio/datasetdetail/8429 + + +#### 4、中文文档文字识别 +- **数据来源**:https://github.com/YCG09/chinese_ocr +- **数据简介**: + - 共约364万张图片,按照99:1划分成训练集和验证集。 + - 数据利用中文语料库(新闻 + 文言文),通过字体、大小、灰度、模糊、透视、拉伸等变化随机生成 + - 包含汉字、英文字母、数字和标点共5990个字符(字符集合:https://github.com/YCG09/chinese_ocr/blob/master/train/char_std_5990.txt ) + - 每个样本固定10个字符,字符随机截取自语料库中的句子 + - 图片分辨率统一为280x32 + ![](datasets/ch_doc1.jpg) + ![](datasets/ch_doc2.jpg) + ![](datasets/ch_doc3.jpg) +- **下载地址**:https://pan.baidu.com/s/1QkI7kjah8SPHwOQ40rS1Pw (密码:lu7m) + + +#### 5、ICDAR2019-ArT +- **数据来源**:https://ai.baidu.com/broad/introduction?dataset=art +- **数据简介**:共包含10,166张图像,训练集5603图,测试集4563图。由Total-Text、SCUT-CTW1500、Baidu Curved Scene Text三部分组成,包含水平、多方向和弯曲等多种形状的文本。 + ![](datasets/ArT.jpg) +- **下载地址**:https://ai.baidu.com/broad/download?dataset=art \ No newline at end of file diff --git a/doc/datasets/ArT.jpg b/doc/datasets/ArT.jpg new file mode 100644 index 00000000..7855ba4a Binary files /dev/null and b/doc/datasets/ArT.jpg differ diff --git a/doc/datasets/LSVT_1.jpg b/doc/datasets/LSVT_1.jpg new file mode 100644 index 00000000..ea11a7da Binary files /dev/null and b/doc/datasets/LSVT_1.jpg differ diff --git a/doc/datasets/LSVT_2.jpg b/doc/datasets/LSVT_2.jpg new file mode 100644 index 00000000..67bfbe52 Binary files /dev/null and b/doc/datasets/LSVT_2.jpg differ diff --git a/doc/datasets/ch_doc1.jpg b/doc/datasets/ch_doc1.jpg new file mode 100644 index 00000000..53534400 Binary files /dev/null and b/doc/datasets/ch_doc1.jpg differ diff --git a/doc/datasets/ch_doc2.jpg b/doc/datasets/ch_doc2.jpg new file mode 100644 index 00000000..23343b8d Binary files /dev/null and b/doc/datasets/ch_doc2.jpg differ diff --git a/doc/datasets/ch_doc3.jpg b/doc/datasets/ch_doc3.jpg new file mode 100644 index 00000000..c0c20536 Binary files /dev/null and b/doc/datasets/ch_doc3.jpg differ diff --git a/doc/datasets/ch_street_rec_1.png b/doc/datasets/ch_street_rec_1.png new file mode 100644 index 00000000..a0e158cb Binary files /dev/null and b/doc/datasets/ch_street_rec_1.png differ diff --git a/doc/datasets/ch_street_rec_2.png b/doc/datasets/ch_street_rec_2.png new file mode 100644 index 00000000..bfa0fd01 Binary files /dev/null and b/doc/datasets/ch_street_rec_2.png differ diff --git a/doc/datasets/rctw.jpg b/doc/datasets/rctw.jpg new file mode 100644 index 00000000..1e1f945b Binary files /dev/null and b/doc/datasets/rctw.jpg differ diff --git a/doc/detection.md b/doc/detection.md new file mode 100644 index 00000000..849755bf --- /dev/null +++ b/doc/detection.md @@ -0,0 +1,99 @@ +# 文字检测 + +本节以icdar15数据集为例,介绍PaddleOCR中检测模型的训练、评估与测试。 + +## 数据准备 +icdar2015数据集可以从[官网](https://rrc.cvc.uab.es/?ch=4&com=downloads)下载到,首次下载需注册。 + +将下载到的数据集解压到工作目录下,假设解压在 PaddleOCR/train_data/ 下。另外,PaddleOCR将零散的标注文件整理成单独的标注文件 +,您可以通过wget的方式进行下载。 +``` +# 在PaddleOCR路径下 +cd PaddleOCR/ +wget -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/train_icdar2015_label.txt +wget -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/test_icdar2015_label.txt +``` + +解压数据集和下载标注文件后,PaddleOCR/train_data/ 有两个文件夹和两个文件,分别是: +``` +/PaddleOCR/train_data/icdar2015/text_localization/ + └─ icdar_c4_train_imgs/ icdar数据集的训练数据 + └─ ch4_test_images/ icdar数据集的测试数据 + └─ train_icdar2015_label.txt icdar数据集的训练标注 + └─ test_icdar2015_label.txt icdar数据集的测试标注 +``` + +提供的标注文件格式为: +``` +" 图像文件名 json.dumps编码的图像标注信息" +ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}] +``` +json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 `points` 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。 +`transcription` 表示当前文本框的文字,在文本检测任务中并不需要这个信息。 +如果您想在其他数据集上训练PaddleOCR,可以按照上述形式构建标注文件。 + + +## 快速启动训练 + +首先下载pretrain model,PaddleOCR的检测模型目前支持两种backbone,分别是MobileNetV3、ResNet50_vd, +您可以根据需求使用[PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures)中的模型更换backbone。 +``` +cd PaddleOCR/ +# 下载MobileNetV3的预训练模型 +wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x0_5_pretrained.tar +# 下载ResNet50的预训练模型 +wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar +``` + +**启动训练** + +*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false* + +``` +python3 tools/train.py -c configs/det/det_mv3_db.yml +``` + +上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。 +有关配置文件的详细解释,请参考[链接](./config.md)。 + +您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001 +``` +python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001 +``` + +## 指标评估 + +PaddleOCR计算三个OCR检测相关的指标,分别是:Precision、Recall、Hmean。 + +运行如下代码,根据配置文件det_db_mv3.yml中save_res_path指定的测试集检测结果文件,计算评估指标。 + +评估时设置后处理参数box_thresh=0.6,unclip_ratio=1.5,使用不同数据集、不同模型训练,可调整这两个参数进行优化 +``` +python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5 +``` +训练中模型参数默认保存在Global.save_model_dir目录下。在评估指标时,需要设置Global.checkpoints指向保存的参数文件。 + +比如: +``` +python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5 +``` + +* 注:box_thresh、unclip_ratio是DB后处理所需要的参数,在评估EAST模型时不需要设置 + +## 测试检测效果 + +测试单张图像的检测效果 +``` +python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" +``` + +测试DB模型时,调整后处理阈值, +``` +python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5 +``` + + +测试文件夹下所有图像的检测效果 +``` +python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy" +``` diff --git a/doc/inference.md b/doc/inference.md new file mode 100644 index 00000000..0d5f45fd --- /dev/null +++ b/doc/inference.md @@ -0,0 +1,219 @@ + +# 基于预测引擎推理 + +inference 模型(fluid.io.save_inference_model保存的模型) +一般是模型训练完成后保存的固化模型,多用于预测部署。 +训练过程中保存的模型是checkpoints模型,保存的是模型的参数,多用于恢复训练等。 +与checkpoints模型相比,inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合与实际系统集成。更详细的介绍请参考文档[分类预测框架](https://paddleclas.readthedocs.io/zh_CN/latest/extension/paddle_inference.html). + +接下来首先介绍如何将训练的模型转换成inference模型,然后将依次介绍文本检测、文本识别以及两者串联基于预测引擎推理。 + +## 训练模型转inference模型 +### 检测模型转inference模型 + +下载超轻量级中文检测模型: +``` +wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar && tar xf ./ch_lite/ch_det_mv3_db.tar -C ./ch_lite/ +``` +上述模型是以MobileNetV3为backbone训练的DB算法,将训练好的模型转换成inference模型只需要运行如下命令: +``` +python3 tools/export_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./ch_lite/det_mv3_db/best_accuracy Global.save_inference_dir=./inference/det_db/ +``` +转inference模型时,使用的配置文件和训练时使用的配置文件相同。另外,还需要设置配置文件中的Global.checkpoints、Global.save_inference_dir参数。 +其中Global.checkpoints指向训练中保存的模型参数文件,Global.save_inference_dir是生成的inference模型要保存的目录。 +转换成功后,在save_inference_dir 目录下有两个文件: +``` +inference/det_db/ + └─ model 检测inference模型的program文件 + └─ params 检测inference模型的参数文件 +``` + +### 识别模型转inference模型 + +下载超轻量中文识别模型: +``` +wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar && tar xf ./ch_lite/ch_rec_mv3_crnn.tar -C ./ch_lite/ +``` + +识别模型转inference模型与检测的方式相同,如下: +``` +python3 tools/export_model.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints=./ch_lite/rec_mv3_crnn/best_accuracy \ + Global.save_inference_dir=./inference/rec_crnn/ +``` + +如果您是在自己的数据集上训练的模型,并且调整了中文字符的字典文件,请注意修改配置文件中的character_dict_path是否是所需要的字典文件。 + +转换成功后,在目录下有两个文件: +``` +/inference/rec_crnn/ + └─ model 识别inference模型的program文件 + └─ params 识别inference模型的参数文件 +``` + +## 文本检测模型推理 + +下面将介绍超轻量中文检测模型推理、DB文本检测模型推理和EAST文本检测模型推理。默认配置是根据DB文本检测模型推理设置的。由于EAST和DB算法差别很大,在推理时,需要通过传入相应的参数适配EAST文本检测算法。 + +### 1.超轻量中文检测模型推理 + +超轻量中文检测模型推理,可以执行如下命令: + +``` +python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" +``` + +可视化文本检测结果默认保存到 ./inference_results 文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下: + +![](imgs_results/det_res_2.jpg) + +通过设置参数det_max_side_len的大小,改变检测算法中图片规范化的最大值。当图片的长宽都小于det_max_side_len,则使用原图预测,否则将图片等比例缩放到最大值,进行预测。该参数默认设置为det_max_side_len=960. 如果输入图片的分辨率比较大,而且想使用更大的分辨率预测,可以执行如下命令: + +``` +python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_max_side_len=1200 +``` + +如果想使用CPU进行预测,执行命令如下 +``` +python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False +``` + +### 2.DB文本检测模型推理 + +首先将DB文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)),可以使用如下命令进行转换: + +``` +# -c后面设置训练算法的yml配置文件 +# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。 +# Global.save_inference_dir参数设置转换的模型将保存的地址。 + +python3 tools/export_model.py -c configs/det/det_r50_vd_db.yml -o Global.checkpoints="./models/det_r50_vd_db/best_accuracy" Global.save_inference_dir="./inference/det_db" +``` + +DB文本检测模型推理,可以执行如下命令: + +``` +python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_db/" +``` + +可视化文本检测结果默认保存到 ./inference_results 文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下: + +![](imgs_results/det_res_img_10_db.jpg) + +**注意**:由于ICDAR2015数据集只有1000张训练图像,主要针对英文场景,所以上述模型对中文文本图像检测效果非常差。 + +### 3.EAST文本检测模型推理 + +首先将EAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/det_r50_vd_east.tar)),可以使用如下命令进行转换: + +``` +# -c后面设置训练算法的yml配置文件 +# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。 +# Global.save_inference_dir参数设置转换的模型将保存的地址。 + +python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.checkpoints="./models/det_r50_vd_east/best_accuracy" Global.save_inference_dir="./inference/det_east" +``` + +EAST文本检测模型推理,需要设置参数det_algorithm,指定检测算法类型为EAST,可以执行如下命令: + +``` +python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" +``` +可视化文本检测结果默认保存到 ./inference_results 文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下: + +![](imgs_results/det_res_img_10_east.jpg) + +**注意**:本代码库中EAST后处理中NMS采用的Python版本,所以预测速度比较耗时。如果采用C++版本,会有明显加速。 + + +## 文本识别模型推理 + +下面将介绍超轻量中文识别模型推理和基于CTC损失的识别模型推理。**而基于Attention损失的识别模型推理还在调试中**。对于中文文本识别,建议优先选择基于CTC损失的识别模型,实践中也发现基于Attention损失的效果不如基于CTC损失的识别模型。 + + +### 1.超轻量中文识别模型推理 + +超轻量中文识别模型推理,可以执行如下命令: + +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --rec_model_dir="./inference/rec_crnn/" +``` + +![](imgs_words/ch/word_4.jpg) + +执行命令后,上面图像的预测结果(识别的文本和得分)会打印到屏幕上,示例如下: + +Predicts of ./doc/imgs_words/ch/word_4.jpg:['实力活力', 0.89552695] + + +### 2.基于CTC损失的识别模型推理 + +我们以STAR-Net为例,介绍基于CTC损失的识别模型推理。 CRNN和Rosetta使用方式类似,不用设置识别算法参数rec_algorithm。 + +首先将STAR-Net文本识别训练过程中保存的模型,转换成inference model。以基于Resnet34_vd骨干网络,使用MJSynth和SynthText两个英文文本识别合成数据集训练 +的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_ctc.tar)),可以使用如下命令进行转换: + +``` +# -c后面设置训练算法的yml配置文件 +# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。 +# Global.save_inference_dir参数设置转换的模型将保存的地址。 + +python3 tools/export_model.py -c configs/rec/rec_r34_vd_tps_bilstm_ctc.yml -o Global.checkpoints="./models/rec_r34_vd_tps_bilstm_ctc/best_accuracy" Global.save_inference_dir="./inference/starnet" +``` + +STAR-Net文本识别模型推理,可以执行如下命令: + +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en" +``` + +### 3.基于Attention损失的识别模型推理 + +基于Attention损失的识别模型与ctc不同,需要额外设置识别算法参数 --rec_algorithm="RARE" + +RARE 文本识别模型推理,可以执行如下命令: +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/sare/" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_algorithm="RARE" +``` + +![](imgs_words_en/word_336.png) + +执行命令后,上面图像的识别结果如下: + +Predicts of ./doc/imgs_words_en/word_336.png:['super', 0.9999555] + +**注意**:由于上述模型是参考[DTRB](https://arxiv.org/abs/1904.01906)文本识别训练和评估流程,与超轻量级中文识别模型训练有两方面不同: + +- 训练时采用的图像分辨率不同,训练上述模型采用的图像分辨率是[3,32,100],而中文模型训练时,为了保证长文本的识别效果,训练时采用的图像分辨率是[3, 32, 320]。预测推理程序默认的的形状参数是训练中文采用的图像分辨率,即[3, 32, 320]。因此,这里推理上述英文模型时,需要通过参数rec_image_shape设置识别图像的形状。 + +- 字符列表,DTRB论文中实验只是针对26个小写英文本母和10个数字进行实验,总共36个字符。所有大小字符都转成了小写字符,不在上面列表的字符都忽略,认为是空格。因此这里没有输入字符字典,而是通过如下命令生成字典.因此在推理时需要设置参数rec_char_type,指定为英文"en"。 + +``` +self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" +dict_character = list(self.character_str) +``` + +## 文本检测、识别串联推理 + +### 1.超轻量中文OCR模型推理 + +在执行预测时,需要通过参数image_dir指定单张图像或者图像集合的路径、参数det_model_dir指定检测inference模型的路径和参数rec_model_dir指定识别inference模型的路径。可视化识别结果默认保存到 ./inference_results 文件夹里面。 + +``` +python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --rec_model_dir="./inference/rec_crnn/" +``` + +执行命令后,识别结果图像如下: + +![](imgs_results/2.jpg) + +### 2.其他模型推理 + +如果想尝试使用其他检测算法或者识别算法,请参考上述文本检测模型推理和文本识别模型推理,更新相应配置和模型,下面给出基于EAST文本检测和STAR-Net文本识别执行命令: + +``` +python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en" +``` + +执行命令后,识别结果图像如下: + +![](imgs_results/img_10.jpg) diff --git a/doc/installation.md b/doc/installation.md new file mode 100644 index 00000000..f1edbf4a --- /dev/null +++ b/doc/installation.md @@ -0,0 +1,81 @@ +## 快速安装 + +经测试PaddleOCR可在glibc 2.23上运行,您也可以测试其他glibc版本或安装glic 2.23 +PaddleOCR 工作环境 +- PaddlePaddle1.7 +- python3 +- glibc 2.23 + +建议使用我们提供的docker运行PaddleOCR,有关docker使用请参考[链接](https://docs.docker.com/get-started/)。 + +*如您希望使用 mac 或 windows直接运行预测代码,可以从第2步开始执行。* + +1. (建议)准备docker环境。第一次使用这个镜像,会自动下载该镜像,请耐心等待。 +``` +# 切换到工作目录下 +cd /home/Projects +# 首次运行需创建一个docker容器,再次运行时不需要运行当前命令 +# 创建一个名字为ppocr的docker容器,并将当前目录映射到容器的/paddle目录下 + +如果您希望在CPU环境下使用docker,使用docker而不是nvidia-docker创建docker +sudo docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev /bin/bash + +如果您的机器安装的是CUDA9,请运行以下命令创建容器 +sudo nvidia-docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev /bin/bash + +如果您的机器安装的是CUDA10,请运行以下命令创建容器 +sudo nvidia-docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda10.0-cudnn7-dev /bin/bash + +您也可以访问[DockerHub](https://hub.docker.com/r/paddlepaddle/paddle/tags/)获取与您机器适配的镜像。 + +# ctrl+P+Q可退出docker,重新进入docker使用如下命令 +sudo docker container exec -it ppocr /bin/bash +``` + +注意:如果docker pull过慢,可以按照如下步骤手动下载后加载docker,以cuda9 docker为例,使用cuda10 docker只需要将cuda9改为cuda10即可。 +``` +# 下载CUDA9 docker的压缩文件,并解压 +wget https://paddleocr.bj.bcebos.com/docker/docker_pdocr_cuda9.tar.gz +# 为减少下载时间,上传的docker image是压缩过的,需要解压使用 +tar zxf docker_pdocr_cuda9.tar.gz +# 创建image +docker load < docker_pdocr_cuda9.tar +# 完成上述步骤后通过docker images检查是否加载了下载的镜像 +docker images +# 执行docker images后如果有下面的输出,即可按照按照 步骤1 创建docker环境。 +hub.baidubce.com/paddlepaddle/paddle latest-gpu-cuda9.0-cudnn7-dev f56310dcc829 +``` + +2. 安装PaddlePaddle Fluid v1.7(暂不支持更高版本,适配工作进行中) +``` +pip3 install --upgrade pip + +如果您的机器安装的是CUDA9,请运行以下命令安装 +python3 -m pip install paddlepaddle-gpu==1.7.2.post97 -i https://pypi.tuna.tsinghua.edu.cn/simple + +如果您的机器安装的是CUDA10,请运行以下命令安装 +python3 -m pip install paddlepaddle-gpu==1.7.2.post107 -i https://pypi.tuna.tsinghua.edu.cn/simple + +如果您的机器是CPU,请运行以下命令安装 + +python3 -m pip install paddlepaddle==1.7.2 -i https://pypi.tuna.tsinghua.edu.cn/simple + +更多的版本需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。 +``` + +3. 克隆PaddleOCR repo代码 +``` +【推荐】git clone https://github.com/PaddlePaddle/PaddleOCR + +如果因为网络问题无法pull成功,也可选择使用码云上的托管: + +git clone https://gitee.com/paddlepaddle/PaddleOCR + +注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。 +``` + +4. 安装第三方库 +``` +cd PaddleOCR +pip3 install -r requirments.txt +``` diff --git a/doc/recognition.md b/doc/recognition.md new file mode 100644 index 00000000..45fa74ab --- /dev/null +++ b/doc/recognition.md @@ -0,0 +1,225 @@ +## 文字识别 + +### 数据准备 + + +PaddleOCR 支持两种数据格式: `lmdb` 用于训练公开数据,调试算法; `通用数据` 训练自己的数据: + +请按如下步骤设置数据集: + +训练数据的默认存储路径是 `PaddleOCR/train_data`,如果您的磁盘上已有数据集,只需创建软链接至数据集目录: + +``` +ln -sf /train_data/dataset +``` + + +* 数据下载 + +若您本地没有数据集,可以在官网下载 [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) 数据,用于快速验证。也可以参考[DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here),下载 benchmark 所需的lmdb格式数据集。 + +* 使用自己数据集: + +若您希望使用自己的数据进行训练,请参考下文组织您的数据。 + +- 训练集 + +首先请将训练图片放入同一个文件夹(train_images),并用一个txt文件(rec_gt_train.txt)记录图片路径和标签。 + +* 注意: 默认请将图片路径和图片标签用 \t 分割,如用其他方式分割将造成训练报错 + +``` +" 图像文件名 图像标注信息 " + +train_data/train_0001.jpg 简单可依赖 +train_data/train_0002.jpg 用科技让复杂的世界更简单 +``` +PaddleOCR 提供了一份用于训练 icdar2015 数据集的标签文件,通过以下方式下载: + +``` +# 训练集标签 +wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_train.txt +# 测试集标签 +wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_test.txt + + +``` + +最终训练集应有如下文件结构: + +``` +|-train_data + |-ic15_data + |- rec_gt_train.txt + |- train + |- word_001.png + |- word_002.jpg + |- word_003.jpg + | ... +``` + +- 测试集 + +同训练集类似,测试集也需要提供一个包含所有图片的文件夹(test)和一个rec_gt_test.txt,测试集的结构如下所示: + +``` +|-train_data + |-ic15_data + |- rec_gt_test.txt + |- test + |- word_001.jpg + |- word_002.jpg + |- word_003.jpg + | ... +``` + +- 字典 + +最后需要提供一个字典({word_dict_name}.txt),使模型在训练时,可以将所有出现的字符映射为字典的索引。 + +因此字典需要包含所有希望被正确识别的字符,{word_dict_name}.txt需要写成如下格式,并以 `utf-8` 编码格式保存: + +``` +l +d +a +d +r +n +``` + +word_dict.txt 每行有一个单字,将字符与数字索引映射在一起,“and” 将被映射成 [2 5 1] + +`ppocr/utils/ppocr_keys_v1.txt` 是一个包含6623个字符的中文字典, +`ppocr/utils/ic15_dict.txt` 是一个包含36个字符的英文字典, +您可以按需使用。 + +如需自定义dic文件,请修改 `configs/rec/rec_icdar15_train.yml` 中的 `character_dict_path` 字段, 并将 `character_type` 设置为 `ch`。 + +### 启动训练 + +PaddleOCR提供了训练脚本、评估脚本和预测脚本,本节将以 CRNN 识别模型为例: + +首先下载pretrain model,您可以下载训练好的模型在 icdar2015 数据上进行finetune + +``` +cd PaddleOCR/ +# 下载MobileNetV3的预训练模型 +wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/rec_mv3_none_bilstm_ctc.tar +# 解压模型参数 +cd pretrain_models +tar -xf rec_mv3_none_bilstm_ctc.tar && rm -rf rec_mv3_none_bilstm_ctc.tar +``` + +开始训练: + +*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false* + +``` +# 设置PYTHONPATH路径 +export PYTHONPATH=$PYTHONPATH:. +# GPU训练 支持单卡,多卡训练,通过CUDA_VISIBLE_DEVICES指定卡号 +export CUDA_VISIBLE_DEVICES=0,1,2,3 +# 训练icdar15英文数据 +python3 tools/train.py -c configs/rec/rec_icdar15_train.yml +``` + +PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_train.yml` 中修改 `eval_batch_step` 设置评估频率,默认每500个iter评估一次。评估过程中默认将最佳acc模型,保存为 `output/rec_CRNN/best_accuracy` 。 + +如果验证集很大,测试将会比较耗时,建议减少评估次数,或训练完再进行评估。 + +* 提示: 可通过 -c 参数选择 `configs/rec/` 路径下的多种模型配置进行训练,PaddleOCR支持的识别算法有: + + +| 配置文件 | 算法名称 | backbone | trans | seq | pred | +| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | +| rec_chinese_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | +| rec_icdar15_train.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc | +| rec_mv3_none_bilstm_ctc.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc | +| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc | +| rec_mv3_tps_bilstm_ctc.yml | STARNet | Mobilenet_v3 large 0.5 | tps | BiLSTM | ctc | +| rec_mv3_tps_bilstm_attn.yml | RARE | Mobilenet_v3 large 0.5 | tps | BiLSTM | attention | +| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc | +| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc | +| rec_r34_vd_tps_bilstm_attn.yml | RARE | Resnet34_vd | tps | BiLSTM | attention | +| rec_r34_vd_tps_bilstm_ctc.yml | STARNet | Resnet34_vd | tps | BiLSTM | ctc | + +训练中文数据,推荐使用`rec_chinese_lite_train.yml`,如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件: + +以 `rec_mv3_none_none_ctc.yml` 为例: +``` +Global: + ... + # 修改 image_shape 以适应长文本 + image_shape: [3, 32, 320] + ... + # 修改字符类型 + character_type: ch + # 添加自定义字典,如修改字典请将路径指向新字典 + character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt + ... + # 修改reader类型 + reader_yml: ./configs/rec/rec_chinese_reader.yml + ... + +... +``` +**注意,预测/评估时的配置文件请务必与训练一致。** + + + +### 评估 + +评估数据集可以通过 `configs/rec/rec_icdar15_reader.yml` 修改EvalReader中的 `label_file_path` 设置。 + +*注意* 评估时必须确保配置文件中 infer_img 字段为空 +``` +export CUDA_VISIBLE_DEVICES=0 +# GPU 评估, Global.checkpoints 为待测权重 +python3 tools/eval.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy +``` + +### 预测 + +* 训练引擎的预测 + +使用 PaddleOCR 训练好的模型,可以通过以下脚本进行快速预测。 + +默认预测图片存储在 `infer_img` 里,通过 `-o Global.checkpoints` 指定权重: + +``` +# 预测英文结果 +python3 tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png +``` + +预测图片: + +![](./imgs_words/en/word_1.png) + +得到输入图像的预测结果: + +``` +infer_img: doc/imgs_words/en/word_1.png + index: [19 24 18 23 29] + word : joint +``` + +预测使用的配置文件必须与训练一致,如您通过 `python3 tools/train.py -c configs/rec/rec_chinese_lite_train.yml` 完成了中文模型的训练, +您可以使用如下命令进行中文模型预测。 + +``` +# 预测中文结果 +python3 tools/infer_rec.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/ch/word_1.jpg +``` + +预测图片: + +![](./imgs_words/ch/word_1.jpg) + +得到输入图像的预测结果: + +``` +infer_img: doc/imgs_words/ch/word_1.jpg + index: [2092 177 312 2503] + word : 韩国小馆 +``` diff --git a/doc/update.md b/doc/update.md new file mode 100644 index 00000000..69961b98 --- /dev/null +++ b/doc/update.md @@ -0,0 +1,10 @@ +# 版本更新 + +- 2020.6.5 支持 `attetnion` 模型导出 `inference_model` +- 2020.6.5 支持单独预测识别时,输出结果得分 +- 2020.5.30 提供超轻量级中文OCR在线体验 +- 2020.5.30 模型预测、训练支持Windows系统 +- 2020.5.30 开源通用中文OCR模型 +- 2020.5.14 发布[PaddleOCR公开课](https://www.bilibili.com/video/BV1nf4y1U7RX?p=4) +- 2020.5.14 发布[PaddleOCR实战练习](https://aistudio.baidu.com/aistudio/projectdetail/467229) +- 2020.5.14 开源8.6M超轻量级中文OCR模型 diff --git a/ppocr/data/det/dataset_traversal.py b/ppocr/data/det/dataset_traversal.py index 272d7317..ab635de3 100644 --- a/ppocr/data/det/dataset_traversal.py +++ b/ppocr/data/det/dataset_traversal.py @@ -61,8 +61,6 @@ class TrainReader(object): if len(batch_outs) == self.batch_size: yield batch_outs batch_outs = [] - if len(batch_outs) != 0: - yield batch_outs return batch_iter_reader diff --git a/ppocr/data/det/db_process.py b/ppocr/data/det/db_process.py index d347ed44..0993324a 100644 --- a/ppocr/data/det/db_process.py +++ b/ppocr/data/det/db_process.py @@ -17,6 +17,8 @@ import cv2 import numpy as np import json import sys +from ppocr.utils.utility import initial_logger +logger = initial_logger() from .data_augment import AugmentData from .random_crop_data import RandomCropData @@ -100,6 +102,7 @@ class DBProcessTrain(object): img_path, gt_label = self.convert_label_infor(label_infor) imgvalue = cv2.imread(img_path) if imgvalue is None: + logger.info("{} does not exist!".format(img_path)) return None data = self.make_data_dict(imgvalue, gt_label) data = AugmentData(data) diff --git a/ppocr/data/rec/dataset_traversal.py b/ppocr/data/rec/dataset_traversal.py index f60b9fe3..e57717d9 100755 --- a/ppocr/data/rec/dataset_traversal.py +++ b/ppocr/data/rec/dataset_traversal.py @@ -41,13 +41,18 @@ class LMDBReader(object): self.loss_type = params['loss_type'] self.max_text_length = params['max_text_length'] self.mode = params['mode'] + self.drop_last = False + self.use_tps = False + if "tps" in params: + self.ues_tps = True if params['mode'] == 'train': self.batch_size = params['train_batch_size_per_card'] - elif params['mode'] == "eval": + self.drop_last = True + else: self.batch_size = params['test_batch_size_per_card'] - elif params['mode'] == "test": - self.batch_size = 1 - self.infer_img = params["infer_img"] + self.drop_last = False + self.infer_img = params['infer_img'] + def load_hierarchical_lmdb_dataset(self): lmdb_sets = {} dataset_idx = 0 @@ -100,13 +105,18 @@ class LMDBReader(object): process_id = 0 def sample_iter_reader(): - if self.mode == 'test': + if self.mode != 'train' and self.infer_img is not None: image_file_list = get_image_file_list(self.infer_img) for single_img in image_file_list: img = cv2.imread(single_img) - if img.shape[-1]==1 or len(list(img.shape))==2: + if img.shape[-1] == 1 or len(list(img.shape)) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - norm_img = process_image(img, self.image_shape) + norm_img = process_image( + img=img, + image_shape=self.image_shape, + char_ops=self.char_ops, + tps=self.use_tps, + infer_mode=True) yield norm_img else: lmdb_sets = self.load_hierarchical_lmdb_dataset() @@ -126,9 +136,13 @@ class LMDBReader(object): if sample_info is None: continue img, label = sample_info - outs = process_image(img, self.image_shape, label, - self.char_ops, self.loss_type, - self.max_text_length) + outs = process_image( + img=img, + image_shape=self.image_shape, + label=label, + char_ops=self.char_ops, + loss_type=self.loss_type, + max_text_length=self.max_text_length) if outs is None: continue yield outs @@ -136,6 +150,7 @@ class LMDBReader(object): if finish_read_num == len(lmdb_sets): break self.close_lmdb_dataset(lmdb_sets) + def batch_iter_reader(): batch_outs = [] for outs in sample_iter_reader(): @@ -143,10 +158,11 @@ class LMDBReader(object): if len(batch_outs) == self.batch_size: yield batch_outs batch_outs = [] - if len(batch_outs) != 0: - yield batch_outs + if not self.drop_last: + if len(batch_outs) != 0: + yield batch_outs - if self.mode != 'test': + if self.infer_img is None: return batch_iter_reader return sample_iter_reader @@ -165,26 +181,34 @@ class SimpleReader(object): self.loss_type = params['loss_type'] self.max_text_length = params['max_text_length'] self.mode = params['mode'] + self.infer_img = params['infer_img'] + self.use_tps = False + if "tps" in params: + self.use_tps = True if params['mode'] == 'train': self.batch_size = params['train_batch_size_per_card'] - elif params['mode'] == 'eval': - self.batch_size = params['test_batch_size_per_card'] + self.drop_last = True else: - self.batch_size = 1 - self.infer_img = params['infer_img'] + self.batch_size = params['test_batch_size_per_card'] + self.drop_last = False def __call__(self, process_id): if self.mode != 'train': process_id = 0 def sample_iter_reader(): - if self.mode == 'test': + if self.mode != 'train' and self.infer_img is not None: image_file_list = get_image_file_list(self.infer_img) for single_img in image_file_list: img = cv2.imread(single_img) - if img.shape[-1]==1 or len(list(img.shape))==2: + if img.shape[-1] == 1 or len(list(img.shape)) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - norm_img = process_image(img, self.image_shape) + norm_img = process_image( + img=img, + image_shape=self.image_shape, + char_ops=self.char_ops, + tps=self.use_tps, + infer_mode=True) yield norm_img else: with open(self.label_file_path, "rb") as fin: @@ -192,7 +216,7 @@ class SimpleReader(object): img_num = len(label_infor_list) img_id_list = list(range(img_num)) random.shuffle(img_id_list) - if sys.platform=="win32": + if sys.platform == "win32": print("multiprocess is not fully compatible with Windows." "num_workers will be 1.") self.num_workers = 1 @@ -204,7 +228,7 @@ class SimpleReader(object): if img is None: logger.info("{} does not exist!".format(img_path)) continue - if img.shape[-1]==1 or len(list(img.shape))==2: + if img.shape[-1] == 1 or len(list(img.shape)) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) label = substr[1] @@ -222,9 +246,10 @@ class SimpleReader(object): if len(batch_outs) == self.batch_size: yield batch_outs batch_outs = [] - if len(batch_outs) != 0: - yield batch_outs + if not self.drop_last: + if len(batch_outs) != 0: + yield batch_outs - if self.mode != 'test': + if self.infer_img is None: return batch_iter_reader return sample_iter_reader diff --git a/ppocr/data/rec/img_tools.py b/ppocr/data/rec/img_tools.py index a27f108c..57543293 100755 --- a/ppocr/data/rec/img_tools.py +++ b/ppocr/data/rec/img_tools.py @@ -48,6 +48,32 @@ def resize_norm_img(img, image_shape): return padding_im +def resize_norm_img_chinese(img, image_shape): + imgC, imgH, imgW = image_shape + # todo: change to 0 and modified image shape + max_wh_ratio = 0 + h, w = img.shape[0], img.shape[1] + ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, ratio) + imgW = int(32 * max_wh_ratio) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + if image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + def get_img_data(value): """get_img_data""" if not value: @@ -66,8 +92,13 @@ def process_image(img, label=None, char_ops=None, loss_type=None, - max_text_length=None): - norm_img = resize_norm_img(img, image_shape) + max_text_length=None, + tps=None, + infer_mode=False): + if not infer_mode or char_ops.character_type == "en" or tps != None: + norm_img = resize_norm_img(img, image_shape) + else: + norm_img = resize_norm_img_chinese(img, image_shape) norm_img = norm_img[np.newaxis, :] if label is not None: char_num = char_ops.get_char_num() diff --git a/ppocr/modeling/architectures/rec_model.py b/ppocr/modeling/architectures/rec_model.py index d88c620b..e80a50ab 100755 --- a/ppocr/modeling/architectures/rec_model.py +++ b/ppocr/modeling/architectures/rec_model.py @@ -30,6 +30,8 @@ class RecModel(object): global_params = params['Global'] char_num = global_params['char_ops'].get_char_num() global_params['char_num'] = char_num + self.char_type = global_params['character_type'] + self.infer_img = global_params['infer_img'] if "TPS" in params: tps_params = deepcopy(params["TPS"]) tps_params.update(global_params) @@ -60,8 +62,8 @@ class RecModel(object): def create_feed(self, mode): image_shape = deepcopy(self.image_shape) image_shape.insert(0, -1) - image = fluid.data(name='image', shape=image_shape, dtype='float32') if mode == "train": + image = fluid.data(name='image', shape=image_shape, dtype='float32') if self.loss_type == "attention": label_in = fluid.data( name='label_in', @@ -86,6 +88,16 @@ class RecModel(object): use_double_buffer=True, iterable=False) else: + if self.char_type == "ch" and self.infer_img: + image_shape[-1] = -1 + if self.tps != None: + logger.info( + "WARNRNG!!!\n" + "TPS does not support variable shape in chinese!" + "We set img_shape to be the same , it may affect the inference effect" + ) + image_shape = deepcopy(self.image_shape) + image = fluid.data(name='image', shape=image_shape, dtype='float32') labels = None loader = None return image, labels, loader @@ -110,7 +122,11 @@ class RecModel(object): return loader, outputs elif mode == "export": predict = predicts['predict'] - predict = fluid.layers.softmax(predict) + if self.loss_type == "ctc": + predict = fluid.layers.softmax(predict) return [image, {'decoded_out': decoded_out, 'predicts': predict}] else: - return loader, {'decoded_out': decoded_out} + predict = predicts['predict'] + if self.loss_type == "ctc": + predict = fluid.layers.softmax(predict) + return loader, {'decoded_out': decoded_out, 'predicts': predict} diff --git a/ppocr/modeling/heads/rec_attention_head.py b/ppocr/modeling/heads/rec_attention_head.py index 8f5b4cc4..66c8f300 100755 --- a/ppocr/modeling/heads/rec_attention_head.py +++ b/ppocr/modeling/heads/rec_attention_head.py @@ -123,6 +123,8 @@ class AttentionPredict(object): full_ids = fluid.layers.fill_constant_batch_size_like( input=init_state, shape=[-1, 1], dtype='int64', value=1) + full_scores = fluid.layers.fill_constant_batch_size_like( + input=init_state, shape=[-1, 1], dtype='float32', value=1) cond = layers.less_than(x=counter, y=array_len) while_op = layers.While(cond=cond) @@ -171,6 +173,9 @@ class AttentionPredict(object): new_ids = fluid.layers.concat([full_ids, topk_indices], axis=1) fluid.layers.assign(new_ids, full_ids) + new_scores = fluid.layers.concat([full_scores, topk_scores], axis=1) + fluid.layers.assign(new_scores, full_scores) + layers.increment(x=counter, value=1, in_place=True) # update the memories @@ -184,7 +189,7 @@ class AttentionPredict(object): length_cond = layers.less_than(x=counter, y=array_len) finish_cond = layers.logical_not(layers.is_empty(x=topk_indices)) layers.logical_and(x=length_cond, y=finish_cond, out=cond) - return full_ids + return full_ids, full_scores def __call__(self, inputs, labels=None, mode=None): encoder_features = self.encoder(inputs) @@ -223,10 +228,10 @@ class AttentionPredict(object): decoder_size, char_num) _, decoded_out = layers.topk(input=predict, k=1) decoded_out = layers.lod_reset(decoded_out, y=label_out) - predicts = {'predict': predict, 'decoded_out': decoded_out} + predicts = {'predict':predict, 'decoded_out':decoded_out} else: - ids = self.gru_attention_infer( + ids, predict = self.gru_attention_infer( decoder_boot, self.max_length, char_num, word_vector_dim, encoded_vector, encoded_proj, decoder_size) - predicts = {'decoded_out': ids} + predicts = {'predict':predict, 'decoded_out':ids} return predicts diff --git a/tools/eval_utils/eval_rec_utils.py b/tools/eval_utils/eval_rec_utils.py index 3ceaa159..2d7d7e1d 100644 --- a/tools/eval_utils/eval_rec_utils.py +++ b/tools/eval_utils/eval_rec_utils.py @@ -48,7 +48,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode): total_sample_num = 0 total_acc_num = 0 total_batch_num = 0 - if mode == "test": + if mode == "eval": is_remove_duplicate = False else: is_remove_duplicate = True @@ -91,11 +91,11 @@ def test_rec_benchmark(exe, config, eval_info_dict): total_correct_number = 0 eval_data_acc_info = {} for eval_data in eval_data_list: - config['EvalReader']['lmdb_sets_dir'] = \ + config['TestReader']['lmdb_sets_dir'] = \ eval_data_dir + "/" + eval_data - eval_reader = reader_main(config=config, mode="eval") + eval_reader = reader_main(config=config, mode="test") eval_info_dict['reader'] = eval_reader - metrics = eval_rec_run(exe, config, eval_info_dict, "eval") + metrics = eval_rec_run(exe, config, eval_info_dict, "test") total_evaluation_data_number += metrics['total_sample_num'] total_correct_number += metrics['total_acc_num'] eval_data_acc_info[eval_data] = metrics diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index c005bfb7..d47b81af 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -32,10 +32,16 @@ class TextRecognizer(object): self.rec_image_shape = image_shape self.character_type = args.rec_char_type self.rec_batch_num = args.rec_batch_num + self.rec_algorithm = args.rec_algorithm char_ops_params = {} char_ops_params["character_type"] = args.rec_char_type char_ops_params["character_dict_path"] = args.rec_char_dict_path - char_ops_params['loss_type'] = 'ctc' + if self.rec_algorithm != "RARE": + char_ops_params['loss_type'] = 'ctc' + self.loss_type = 'ctc' + else: + char_ops_params['loss_type'] = 'attention' + self.loss_type = 'attention' self.char_ops = CharacterOps(char_ops_params) def resize_norm_img(self, img, max_wh_ratio): @@ -80,26 +86,43 @@ class TextRecognizer(object): starttime = time.time() self.input_tensor.copy_from_cpu(norm_img_batch) self.predictor.zero_copy_run() - rec_idx_batch = self.output_tensors[0].copy_to_cpu() - rec_idx_lod = self.output_tensors[0].lod()[0] - predict_batch = self.output_tensors[1].copy_to_cpu() - predict_lod = self.output_tensors[1].lod()[0] - elapse = time.time() - starttime - predict_time += elapse - starttime = time.time() - for rno in range(len(rec_idx_lod) - 1): - beg = rec_idx_lod[rno] - end = rec_idx_lod[rno + 1] - rec_idx_tmp = rec_idx_batch[beg:end, 0] - preds_text = self.char_ops.decode(rec_idx_tmp) - beg = predict_lod[rno] - end = predict_lod[rno + 1] - probs = predict_batch[beg:end, :] - ind = np.argmax(probs, axis=1) - blank = probs.shape[1] - valid_ind = np.where(ind != (blank - 1))[0] - score = np.mean(probs[valid_ind, ind[valid_ind]]) - rec_res.append([preds_text, score]) + + if self.loss_type == "ctc": + rec_idx_batch = self.output_tensors[0].copy_to_cpu() + rec_idx_lod = self.output_tensors[0].lod()[0] + predict_batch = self.output_tensors[1].copy_to_cpu() + predict_lod = self.output_tensors[1].lod()[0] + elapse = time.time() - starttime + predict_time += elapse + for rno in range(len(rec_idx_lod) - 1): + beg = rec_idx_lod[rno] + end = rec_idx_lod[rno + 1] + rec_idx_tmp = rec_idx_batch[beg:end, 0] + preds_text = self.char_ops.decode(rec_idx_tmp) + beg = predict_lod[rno] + end = predict_lod[rno + 1] + probs = predict_batch[beg:end, :] + ind = np.argmax(probs, axis=1) + blank = probs.shape[1] + valid_ind = np.where(ind != (blank - 1))[0] + score = np.mean(probs[valid_ind, ind[valid_ind]]) + rec_res.append([preds_text, score]) + else: + rec_idx_batch = self.output_tensors[0].copy_to_cpu() + predict_batch = self.output_tensors[1].copy_to_cpu() + elapse = time.time() - starttime + predict_time += elapse + for rno in range(len(rec_idx_batch)): + end_pos = np.where(rec_idx_batch[rno, :] == 1)[0] + if len(end_pos) <= 1: + preds = rec_idx_batch[rno, 1:] + score = np.mean(predict_batch[rno, 1:]) + else: + preds = rec_idx_batch[rno, 1:end_pos[1]] + score = np.mean(predict_batch[rno, 1:end_pos[1]]) + preds_text = self.char_ops.decode(preds) + rec_res.append([preds_text, score]) + return rec_res, predict_time @@ -116,7 +139,17 @@ if __name__ == "__main__": continue valid_image_file_list.append(image_file) img_list.append(img) - rec_res, predict_time = text_recognizer(img_list) + try: + rec_res, predict_time = text_recognizer(img_list) + except Exception as e: + print(e) + logger.info( + "ERROR!!!! \n" + "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" + "If your model has tps module: " + "TPS does not support variable shape.\n" + "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") + exit() for ino in range(len(img_list)): print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) print("Total predict time for %d images:%.3f" % diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 25bae1ca..ec64a38b 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -21,6 +21,7 @@ import time import multiprocessing import numpy as np + def set_paddle_flags(**kwargs): for key, value in kwargs.items(): if os.environ.get(key, None) is None: @@ -54,6 +55,7 @@ def main(): program.merge_config(FLAGS.opt) logger.info(config) char_ops = CharacterOps(config['Global']) + loss_type = config['Global']['loss_type'] config['Global']['char_ops'] = char_ops # check if set use_gpu=True in paddlepaddle cpu version @@ -78,35 +80,44 @@ def main(): init_model(config, eval_prog, exe) blobs = reader_main(config, 'test')() - infer_img = config['TestReader']['infer_img'] + infer_img = config['Global']['infer_img'] infer_list = get_image_file_list(infer_img) max_img_num = len(infer_list) if len(infer_list) == 0: logger.info("Can not find img in infer_img dir.") for i in range(max_img_num): - print("infer_img:",infer_list[i]) + print("infer_img:%s" % infer_list[i]) img = next(blobs) predict = exe.run(program=eval_prog, feed={"image": img}, fetch_list=fetch_varname_list, return_numpy=False) - - preds = np.array(predict[0]) - if preds.shape[1] == 1: + if loss_type == "ctc": + preds = np.array(predict[0]) preds = preds.reshape(-1) preds_lod = predict[0].lod()[0] preds_text = char_ops.decode(preds) - else: + probs = np.array(predict[1]) + ind = np.argmax(probs, axis=1) + blank = probs.shape[1] + valid_ind = np.where(ind != (blank - 1))[0] + score = np.mean(probs[valid_ind, ind[valid_ind]]) + elif loss_type == "attention": + preds = np.array(predict[0]) + probs = np.array(predict[1]) end_pos = np.where(preds[0, :] == 1)[0] if len(end_pos) <= 1: - preds_text = preds[0, 1:] + preds = preds[0, 1:] + score = np.mean(probs[0, 1:]) else: - preds_text = preds[0, 1:end_pos[1]] - preds_text = preds_text.reshape(-1) - preds_text = char_ops.decode(preds_text) + preds = preds[0, 1:end_pos[1]] + score = np.mean(probs[0, 1:end_pos[1]]) + preds = preds.reshape(-1) + preds_text = char_ops.decode(preds) - print("\t index:",preds) - print("\t word :",preds_text) + print("\t index:", preds) + print("\t word :", preds_text) + print("\t score :", score) # save for inference model target_var = [] diff --git a/tools/program.py b/tools/program.py index 18a4ab7d..30e9d737 100755 --- a/tools/program.py +++ b/tools/program.py @@ -114,7 +114,7 @@ def merge_config(config): global_config[key] = value else: sub_keys = key.split('.') - assert (sub_keys[0] in global_config) + assert (sub_keys[0] in global_config), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(global_config.keys(), sub_keys[0]) cur = global_config[sub_keys[0]] for idx, sub_key in enumerate(sub_keys[1:]): assert (sub_key in cur)