commit
91341b81bd
41
README.md
41
README.md
|
@ -4,12 +4,11 @@ English | [简体中文](README_cn.md)
|
|||
PaddleOCR aims to create rich, leading, and practical OCR tools that help users train better models and apply them into practice.
|
||||
|
||||
**Recent updates**
|
||||
- 2020.8.16, Release text detection algorithm [SAST](https://arxiv.org/abs/1908.05498) and text recognition algorithm [SRN](https://arxiv.org/abs/2003.12294)
|
||||
- 2020.7.23, Release the playback and PPT of live class on BiliBili station, PaddleOCR Introduction, [address](https://aistudio.baidu.com/aistudio/course/introduce/1519)
|
||||
- 2020.7.15, Add mobile App demo , support both iOS and Android ( based on easyedge and Paddle Lite)
|
||||
- 2020.7.15, Improve the deployment ability, add the C + + inference , serving deployment. In addtion, the benchmarks of the ultra-lightweight OCR model are provided.
|
||||
- 2020.7.15, Improve the deployment ability, add the C + + inference , serving deployment. In addition, the benchmarks of the ultra-lightweight OCR model are provided.
|
||||
- 2020.7.15, Add several related datasets, data annotation and synthesis tools.
|
||||
- 2020.7.9 Add a new model to support recognize the character "space".
|
||||
- 2020.7.9 Add the data augument and learning rate decay strategies during training.
|
||||
- [more](./doc/doc_en/update_en.md)
|
||||
|
||||
## Features
|
||||
|
@ -18,7 +17,7 @@ PaddleOCR aims to create rich, leading, and practical OCR tools that help users
|
|||
- Detection model DB (4.1M) + recognition model CRNN (4.5M)
|
||||
- Various text detection algorithms: EAST, DB
|
||||
- Various text recognition algorithms: Rosetta, CRNN, STAR-Net, RARE
|
||||
- Support Linux, Windows, MacOS and other systems.
|
||||
- Support Linux, Windows, macOS and other systems.
|
||||
|
||||
## Visualization
|
||||
|
||||
|
@ -30,9 +29,9 @@ PaddleOCR aims to create rich, leading, and practical OCR tools that help users
|
|||
|
||||
You can also quickly experience the ultra-lightweight OCR : [Online Experience](https://www.paddlepaddle.org.cn/hub/scene/ocr)
|
||||
|
||||
Mobile DEMO experience (based on EasyEdge and Paddle-Lite, supports iOS and Android systems): [Sign in the website to obtain the QR code for installing the App](https://ai.baidu.com/easyedge/app/openSource?from=paddlelite)
|
||||
Mobile DEMO experience (based on EasyEdge and Paddle-Lite, supports iOS and Android systems): [Sign in to the website to obtain the QR code for installing the App](https://ai.baidu.com/easyedge/app/openSource?from=paddlelite)
|
||||
|
||||
Also, you can scan the QR code blow to install the App (**Android support only**)
|
||||
Also, you can scan the QR code below to install the App (**Android support only**)
|
||||
|
||||
<div align="center">
|
||||
<img src="./doc/ocr-android-easyedge.png" width = "200" height = "200" />
|
||||
|
@ -79,7 +78,7 @@ Mobile DEMO experience (based on EasyEdge and Paddle-Lite, supports iOS and Andr
|
|||
- Visualization
|
||||
- [Ultra-lightweight Chinese/English OCR Visualization](#UCOCRVIS)
|
||||
- [General Chinese/English OCR Visualization](#GeOCRVIS)
|
||||
- [Chinese/English OCR Visualization (Support Space Recognization )](#SpaceOCRVIS)
|
||||
- [Chinese/English OCR Visualization (Support Space Recognition )](#SpaceOCRVIS)
|
||||
- [Community](#Community)
|
||||
- [References](./doc/doc_en/reference_en.md)
|
||||
- [License](#LICENSE)
|
||||
|
@ -91,7 +90,7 @@ Mobile DEMO experience (based on EasyEdge and Paddle-Lite, supports iOS and Andr
|
|||
PaddleOCR open source text detection algorithms list:
|
||||
- [x] EAST([paper](https://arxiv.org/abs/1704.03155))
|
||||
- [x] DB([paper](https://arxiv.org/abs/1911.08947))
|
||||
- [ ] SAST([paper](https://arxiv.org/abs/1908.05498))(Baidu Self-Research, comming soon)
|
||||
- [x] SAST([paper](https://arxiv.org/abs/1908.05498))(Baidu Self-Research)
|
||||
|
||||
On the ICDAR2015 dataset, the text detection result is as follows:
|
||||
|
||||
|
@ -101,8 +100,17 @@ On the ICDAR2015 dataset, the text detection result is as follows:
|
|||
|EAST|MobileNetV3|81.67%|79.83%|80.74%|[Download link](https://paddleocr.bj.bcebos.com/det_mv3_east.tar)|
|
||||
|DB|ResNet50_vd|83.79%|80.65%|82.19%|[Download link](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)|
|
||||
|DB|MobileNetV3|75.92%|73.18%|74.53%|[Download link](https://paddleocr.bj.bcebos.com/det_mv3_db.tar)|
|
||||
|SAST|ResNet50_vd|92.18%|82.96%|87.33%|[Download link](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_icdar2015.tar)|
|
||||
|
||||
For use of [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/datasets_en.md#1-icdar2019-lsvt) street view dataset with a total of 3w training data,the related configuration and pre-trained models for text detection task are as follows:
|
||||
On Total-Text dataset, the text detection result is as follows:
|
||||
|
||||
|Model|Backbone|precision|recall|Hmean|Download link|
|
||||
|-|-|-|-|-|-|
|
||||
|SAST|ResNet50_vd|88.74%|79.80%|84.03%|[Download link](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_total_text.tar)|
|
||||
|
||||
**Note:** Additional data, like icdar2013, icdar2017, COCO-Text, ArT, was added to the model training of SAST. Download English public dataset in organized format used by PaddleOCR from [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (download code: 2bpi).
|
||||
|
||||
For use of [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/datasets_en.md#1-icdar2019-lsvt) street view dataset with a total of 3w training data,the related configuration and pre-trained models for text detection task are as follows:
|
||||
|Model|Backbone|Configuration file|Pre-trained model|
|
||||
|-|-|-|-|
|
||||
|ultra-lightweight OCR model|MobileNetV3|det_mv3_db.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar)|
|
||||
|
@ -120,7 +128,7 @@ PaddleOCR open-source text recognition algorithms list:
|
|||
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
|
||||
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
|
||||
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))
|
||||
- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))(Baidu Self-Research, comming soon)
|
||||
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))(Baidu Self-Research)
|
||||
|
||||
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
|
||||
|
||||
|
@ -134,8 +142,14 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|
|||
|STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_ctc.tar)|
|
||||
|RARE|Resnet34_vd|84.90%|rec_r34_vd_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_attn.tar)|
|
||||
|RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|[Download link](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar)|
|
||||
|SRN|Resnet50_vd_fpn|88.33%|rec_r50fpn_vd_none_srn|[Download link](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)|
|
||||
|
||||
**Note:** SRN model uses data expansion method to expand the two training sets mentioned above, and the expanded data can be downloaded from [Baidu Drive](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA) (download code: y3ry).
|
||||
|
||||
The average accuracy of the two-stage training in the original paper is 89.74%, and that of one stage training in paddleocr is 88.33%. Both pre-trained weights can be downloaded [here](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar).
|
||||
|
||||
We use [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/datasets_en.md#1-icdar2019-lsvt) dataset and cropout 30w training data from original photos by using position groundtruth and make some calibration needed. In addition, based on the LSVT corpus, 500w synthetic data is generated to train the model. The related configuration and pre-trained models are as follows:
|
||||
|
||||
We use [LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/datasets_en.md#1-icdar2019-lsvt) dataset and cropout 30w traning data from original photos by using position groundtruth and make some calibration needed. In addition, based on the LSVT corpus, 500w synthetic data is generated to train the model. The related configuration and pre-trained models are as follows:
|
||||
|Model|Backbone|Configuration file|Pre-trained model|
|
||||
|-|-|-|-|
|
||||
|ultra-lightweight OCR model|MobileNetV3|rec_chinese_lite_train.yml|[Download link](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn.tar)|[inference model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar) & [pre-trained model](https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance.tar)|
|
||||
|
@ -145,7 +159,7 @@ Please refer to the document for training guide and use of PaddleOCR text recogn
|
|||
|
||||
<a name="ENDENDOCRALGORITHM"></a>
|
||||
## END-TO-END OCR Algorithm
|
||||
- [ ] [End2End-PSL](https://arxiv.org/abs/1909.07808)(Baidu Self-Research, comming soon)
|
||||
- [ ] [End2End-PSL](https://arxiv.org/abs/1909.07808)(Baidu Self-Research, coming soon)
|
||||
|
||||
## Visualization
|
||||
|
||||
|
@ -207,9 +221,10 @@ This project is released under <a href="https://github.com/PaddlePaddle/PaddleOC
|
|||
## Contribution
|
||||
We welcome all the contributions to PaddleOCR and appreciate for your feedback very much.
|
||||
|
||||
- Many thanks to [Khanh Tran](https://github.com/xxxpsyduck) for contributing the English documentation.
|
||||
- Many thanks to [Khanh Tran](https://github.com/xxxpsyduck) and [Karl Horky](https://github.com/karlhorky) for contributing and revising the English documentation.
|
||||
- Many thanks to [zhangxin](https://github.com/ZhangXinNan) for contributing the new visualize function、add .gitgnore and discard set PYTHONPATH manually.
|
||||
- Many thanks to [lyl120117](https://github.com/lyl120117) for contributing the code for printing the network structure.
|
||||
- Thanks [xiangyubo](https://github.com/xiangyubo) for contributing the handwritten Chinese OCR datasets.
|
||||
- Thanks [authorfu](https://github.com/authorfu) for contributing Android demo and [xiadeye](https://github.com/xiadeye) contributing iOS demo, respectively.
|
||||
- Thanks [BeyondYourself](https://github.com/BeyondYourself) for contributing many great suggestions and simplifying part of the code style.
|
||||
- Thanks [tangmq](https://gitee.com/tangmq) for contributing Dockerized deployment services to PaddleOCR and supporting the rapid release of callable Restful API services.
|
||||
|
|
25
README_cn.md
25
README_cn.md
|
@ -4,12 +4,11 @@
|
|||
PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力使用者训练出更好的模型,并应用落地。
|
||||
|
||||
**近期更新**
|
||||
- 2020.8.16 开源文本检测算法[SAST](https://arxiv.org/abs/1908.05498)和文本识别算法[SRN](https://arxiv.org/abs/2003.12294)
|
||||
- 2020.7.23 发布7月21日B站直播课回放和PPT,PaddleOCR开源大礼包全面解读,[获取地址](https://aistudio.baidu.com/aistudio/course/introduce/1519)
|
||||
- 2020.7.15 添加基于EasyEdge和Paddle-Lite的移动端DEMO,支持iOS和Android系统
|
||||
- 2020.7.15 完善预测部署,添加基于C++预测引擎推理、服务化部署和端侧部署方案,以及超轻量级中文OCR模型预测耗时Benchmark
|
||||
- 2020.7.15 整理OCR相关数据集、常用数据标注以及合成工具
|
||||
- 2020.7.9 添加支持空格的识别模型,识别效果,预测及训练方式请参考快速开始和文本识别训练相关文档
|
||||
- 2020.7.9 添加数据增强、学习率衰减策略,具体参考[配置文件](./doc/doc_ch/config.md)
|
||||
- [more](./doc/doc_ch/update.md)
|
||||
|
||||
|
||||
|
@ -93,7 +92,7 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
|
|||
PaddleOCR开源的文本检测算法列表:
|
||||
- [x] EAST([paper](https://arxiv.org/abs/1704.03155))
|
||||
- [x] DB([paper](https://arxiv.org/abs/1911.08947))
|
||||
- [ ] SAST([paper](https://arxiv.org/abs/1908.05498))(百度自研, coming soon)
|
||||
- [x] SAST([paper](https://arxiv.org/abs/1908.05498))(百度自研)
|
||||
|
||||
在ICDAR2015文本检测公开数据集上,算法效果如下:
|
||||
|
||||
|
@ -103,8 +102,19 @@ PaddleOCR开源的文本检测算法列表:
|
|||
|EAST|MobileNetV3|81.67%|79.83%|80.74%|[下载链接](https://paddleocr.bj.bcebos.com/det_mv3_east.tar)|
|
||||
|DB|ResNet50_vd|83.79%|80.65%|82.19%|[下载链接](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)|
|
||||
|DB|MobileNetV3|75.92%|73.18%|74.53%|[下载链接](https://paddleocr.bj.bcebos.com/det_mv3_db.tar)|
|
||||
|SAST|ResNet50_vd|92.18%|82.96%|87.33%|[下载链接](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_icdar2015.tar)|
|
||||
|
||||
在Total-text文本检测公开数据集上,算法效果如下:
|
||||
|
||||
|模型|骨干网络|precision|recall|Hmean|下载链接|
|
||||
|-|-|-|-|-|-|
|
||||
|SAST|ResNet50_vd|88.74%|79.80%|84.03%|[下载链接](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_total_text.tar)|
|
||||
|
||||
**说明:** SAST模型训练额外加入了icdar2013、icdar2017、COCO-Text、ArT等公开数据集进行调优。PaddleOCR用到的经过整理格式的英文公开数据集下载:[百度云地址](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (提取码: 2bpi)
|
||||
|
||||
|
||||
使用[LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/datasets.md#1icdar2019-lsvt)街景数据集共3w张数据,训练中文检测模型的相关配置和预训练文件如下:
|
||||
|
||||
|模型|骨干网络|配置文件|预训练模型|
|
||||
|-|-|-|-|
|
||||
|超轻量中文模型|MobileNetV3|det_mv3_db.yml|[下载链接](https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar)|
|
||||
|
@ -122,7 +132,7 @@ PaddleOCR开源的文本识别算法列表:
|
|||
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
|
||||
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
|
||||
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))
|
||||
- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))(百度自研, coming soon)
|
||||
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))(百度自研)
|
||||
|
||||
参考[DTRB](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
||||
|
||||
|
@ -136,6 +146,10 @@ PaddleOCR开源的文本识别算法列表:
|
|||
|STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_ctc.tar)|
|
||||
|RARE|Resnet34_vd|84.90%|rec_r34_vd_tps_bilstm_attn|[下载链接](https://paddleocr.bj.bcebos.com/rec_r34_vd_tps_bilstm_attn.tar)|
|
||||
|RARE|MobileNetV3|83.32%|rec_mv3_tps_bilstm_attn|[下载链接](https://paddleocr.bj.bcebos.com/rec_mv3_tps_bilstm_attn.tar)|
|
||||
|SRN|Resnet50_vd_fpn|88.33%|rec_r50fpn_vd_none_srn|[下载链接](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)|
|
||||
|
||||
**说明:** SRN模型使用了数据扰动方法对上述提到对两个训练集进行增广,增广后的数据可以在[百度网盘](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA)上下载,提取码: y3ry。
|
||||
原始论文使用两阶段训练平均精度为89.74%,PaddleOCR中使用one-stage训练,平均精度为88.33%。两种预训练权重均在[下载链接](https://paddleocr.bj.bcebos.com/SRN/rec_r50fpn_vd_none_srn.tar)中。
|
||||
|
||||
使用[LSVT](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/datasets.md#1icdar2019-lsvt)街景数据集根据真值将图crop出来30w数据,进行位置校准。此外基于LSVT语料生成500w合成数据训练中文模型,相关配置和预训练文件如下:
|
||||
|
||||
|
@ -205,9 +219,10 @@ PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训
|
|||
## 贡献代码
|
||||
我们非常欢迎你为PaddleOCR贡献代码,也十分感谢你的反馈。
|
||||
|
||||
- 非常感谢 [Khanh Tran](https://github.com/xxxpsyduck) 贡献了英文文档
|
||||
- 非常感谢 [Khanh Tran](https://github.com/xxxpsyduck) 和 [Karl Horky](https://github.com/karlhorky) 贡献修改英文文档
|
||||
- 非常感谢 [zhangxin](https://github.com/ZhangXinNan)([Blog](https://blog.csdn.net/sdlypyzq)) 贡献新的可视化方式、添加.gitgnore、处理手动设置PYTHONPATH环境变量的问题
|
||||
- 非常感谢 [lyl120117](https://github.com/lyl120117) 贡献打印网络结构的代码
|
||||
- 非常感谢 [xiangyubo](https://github.com/xiangyubo) 贡献手写中文OCR数据集
|
||||
- 非常感谢 [authorfu](https://github.com/authorfu) 贡献Android和[xiadeye](https://github.com/xiadeye) 贡献IOS的demo代码
|
||||
- 非常感谢 [BeyondYourself](https://github.com/BeyondYourself) 给PaddleOCR提了很多非常棒的建议,并简化了PaddleOCR的部分代码风格。
|
||||
- 非常感谢 [tangmq](https://gitee.com/tangmq) 给PaddleOCR增加Docker化部署服务,支持快速发布可调用的Restful API服务。
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
Global:
|
||||
algorithm: SAST
|
||||
use_gpu: true
|
||||
epoch_num: 2000
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/det_sast/
|
||||
save_epoch_step: 20
|
||||
eval_batch_step: 5000
|
||||
train_batch_size_per_card: 8
|
||||
test_batch_size_per_card: 8
|
||||
image_shape: [3, 512, 512]
|
||||
reader_yml: ./configs/det/det_sast_icdar15_reader.yml
|
||||
pretrain_weights: ./pretrain_models/ResNet50_vd_ssld_pretrained/
|
||||
save_res_path: ./output/det_sast/predicts_sast.txt
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.det_model,DetModel
|
||||
|
||||
Backbone:
|
||||
function: ppocr.modeling.backbones.det_resnet_vd_sast,ResNet
|
||||
layers: 50
|
||||
|
||||
Head:
|
||||
function: ppocr.modeling.heads.det_sast_head,SASTHead
|
||||
model_name: large
|
||||
only_fpn_up: False
|
||||
# with_cab: False
|
||||
with_cab: True
|
||||
|
||||
Loss:
|
||||
function: ppocr.modeling.losses.det_sast_loss,SASTLoss
|
||||
|
||||
Optimizer:
|
||||
function: ppocr.optimizer,RMSProp
|
||||
base_lr: 0.001
|
||||
decay:
|
||||
function: piecewise_decay
|
||||
boundaries: [30000, 50000, 80000, 100000, 150000]
|
||||
decay_rate: 0.3
|
||||
|
||||
PostProcess:
|
||||
function: ppocr.postprocess.sast_postprocess,SASTPostProcess
|
||||
score_thresh: 0.5
|
||||
sample_pts_num: 2
|
||||
nms_thresh: 0.2
|
||||
expand_scale: 1.0
|
||||
shrink_ratio_of_width: 0.3
|
|
@ -0,0 +1,50 @@
|
|||
Global:
|
||||
algorithm: SAST
|
||||
use_gpu: true
|
||||
epoch_num: 2000
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/det_sast/
|
||||
save_epoch_step: 20
|
||||
eval_batch_step: 5000
|
||||
train_batch_size_per_card: 8
|
||||
test_batch_size_per_card: 1
|
||||
image_shape: [3, 512, 512]
|
||||
reader_yml: ./configs/det/det_sast_totaltext_reader.yml
|
||||
pretrain_weights: ./pretrain_models/ResNet50_vd_ssld_pretrained/
|
||||
save_res_path: ./output/det_sast/predicts_sast.txt
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.det_model,DetModel
|
||||
|
||||
Backbone:
|
||||
function: ppocr.modeling.backbones.det_resnet_vd_sast,ResNet
|
||||
layers: 50
|
||||
|
||||
Head:
|
||||
function: ppocr.modeling.heads.det_sast_head,SASTHead
|
||||
model_name: large
|
||||
only_fpn_up: False
|
||||
# with_cab: False
|
||||
with_cab: True
|
||||
|
||||
Loss:
|
||||
function: ppocr.modeling.losses.det_sast_loss,SASTLoss
|
||||
|
||||
Optimizer:
|
||||
function: ppocr.optimizer,RMSProp
|
||||
base_lr: 0.001
|
||||
decay:
|
||||
function: piecewise_decay
|
||||
boundaries: [30000, 50000, 80000, 100000, 150000]
|
||||
decay_rate: 0.3
|
||||
|
||||
PostProcess:
|
||||
function: ppocr.postprocess.sast_postprocess,SASTPostProcess
|
||||
score_thresh: 0.5
|
||||
sample_pts_num: 6
|
||||
nms_thresh: 0.2
|
||||
expand_scale: 1.2
|
||||
shrink_ratio_of_width: 0.2
|
|
@ -0,0 +1,24 @@
|
|||
TrainReader:
|
||||
reader_function: ppocr.data.det.dataset_traversal,TrainReader
|
||||
process_function: ppocr.data.det.sast_process,SASTProcessTrain
|
||||
num_workers: 8
|
||||
img_set_dir: ./train_data/
|
||||
label_file_path: [./train_data/icdar2013/train_label_json.txt, ./train_data/icdar2015/train_label_json.txt, ./train_data/icdar17_mlt_latin/train_label_json.txt, ./train_data/coco_text_icdar_4pts/train_label_json.txt]
|
||||
data_ratio_list: [0.1, 0.45, 0.3, 0.15]
|
||||
min_crop_side_ratio: 0.3
|
||||
min_crop_size: 24
|
||||
min_text_size: 4
|
||||
max_text_size: 512
|
||||
|
||||
EvalReader:
|
||||
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
|
||||
process_function: ppocr.data.det.sast_process,SASTProcessTest
|
||||
img_set_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
max_side_len: 1536
|
||||
|
||||
TestReader:
|
||||
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
|
||||
process_function: ppocr.data.det.sast_process,SASTProcessTest
|
||||
infer_img: ./train_data/icdar2015/text_localization/ch4_test_images/img_11.jpg
|
||||
max_side_len: 1536
|
|
@ -0,0 +1,24 @@
|
|||
TrainReader:
|
||||
reader_function: ppocr.data.det.dataset_traversal,TrainReader
|
||||
process_function: ppocr.data.det.sast_process,SASTProcessTrain
|
||||
num_workers: 8
|
||||
img_set_dir: ./train_data/
|
||||
label_file_path: [./train_data/art_latin_icdar_14pt/train_no_tt_test/train_label_json.txt, ./train_data/total_text_icdar_14pt/train_label_json.txt]
|
||||
data_ratio_list: [0.5, 0.5]
|
||||
min_crop_side_ratio: 0.3
|
||||
min_crop_size: 24
|
||||
min_text_size: 4
|
||||
max_text_size: 512
|
||||
|
||||
EvalReader:
|
||||
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
|
||||
process_function: ppocr.data.det.sast_process,SASTProcessTest
|
||||
img_set_dir: ./train_data/
|
||||
label_file_path: ./train_data/total_text_icdar_14pt/test_label_json.txt
|
||||
max_side_len: 768
|
||||
|
||||
TestReader:
|
||||
reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
|
||||
process_function: ppocr.data.det.sast_process,SASTProcessTest
|
||||
infer_img: ./train_data/afs/total_text/Images/Test/img623.jpg
|
||||
max_side_len: 768
|
|
@ -0,0 +1,49 @@
|
|||
Global:
|
||||
algorithm: SRN
|
||||
use_gpu: true
|
||||
epoch_num: 72
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: output/rec_pvam_withrotate
|
||||
save_epoch_step: 1
|
||||
eval_batch_step: 8000
|
||||
train_batch_size_per_card: 64
|
||||
test_batch_size_per_card: 1
|
||||
image_shape: [1, 64, 256]
|
||||
max_text_length: 25
|
||||
character_type: en
|
||||
loss_type: srn
|
||||
num_heads: 8
|
||||
average_window: 0.15
|
||||
max_average_window: 15625
|
||||
min_average_window: 10000
|
||||
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
||||
pretrain_weights:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
infer_img:
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.rec_model,RecModel
|
||||
|
||||
Backbone:
|
||||
function: ppocr.modeling.backbones.rec_resnet50_fpn,ResNet
|
||||
layers: 50
|
||||
|
||||
Head:
|
||||
function: ppocr.modeling.heads.rec_srn_all_head,SRNPredict
|
||||
encoder_type: rnn
|
||||
num_encoder_TUs: 2
|
||||
num_decoder_TUs: 4
|
||||
hidden_dims: 512
|
||||
SeqRNN:
|
||||
hidden_size: 256
|
||||
|
||||
Loss:
|
||||
function: ppocr.modeling.losses.rec_srn_loss,SRNLoss
|
||||
|
||||
Optimizer:
|
||||
function: ppocr.optimizer,AdamDecay
|
||||
base_lr: 0.0001
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
|
@ -3,11 +3,11 @@ import java.security.MessageDigest
|
|||
apply plugin: 'com.android.application'
|
||||
|
||||
android {
|
||||
compileSdkVersion 28
|
||||
compileSdkVersion 29
|
||||
defaultConfig {
|
||||
applicationId "com.baidu.paddle.lite.demo.ocr"
|
||||
minSdkVersion 15
|
||||
targetSdkVersion 28
|
||||
minSdkVersion 23
|
||||
targetSdkVersion 29
|
||||
versionCode 1
|
||||
versionName "1.0"
|
||||
testInstrumentationRunner "android.support.test.runner.AndroidJUnitRunner"
|
||||
|
@ -39,9 +39,8 @@ android {
|
|||
|
||||
dependencies {
|
||||
implementation fileTree(include: ['*.jar'], dir: 'libs')
|
||||
implementation 'com.android.support:appcompat-v7:28.0.0'
|
||||
implementation 'com.android.support.constraint:constraint-layout:1.1.3'
|
||||
implementation 'com.android.support:design:28.0.0'
|
||||
implementation 'androidx.appcompat:appcompat:1.1.0'
|
||||
implementation 'androidx.constraintlayout:constraintlayout:1.1.3'
|
||||
testImplementation 'junit:junit:4.12'
|
||||
androidTestImplementation 'com.android.support.test:runner:1.0.2'
|
||||
androidTestImplementation 'com.android.support.test.espresso:espresso-core:3.0.2'
|
||||
|
|
|
@ -14,10 +14,10 @@
|
|||
android:roundIcon="@mipmap/ic_launcher_round"
|
||||
android:supportsRtl="true"
|
||||
android:theme="@style/AppTheme">
|
||||
<!-- to test MiniActivity, change this to com.baidu.paddle.lite.demo.ocr.MiniActivity -->
|
||||
<activity android:name="com.baidu.paddle.lite.demo.ocr.MainActivity">
|
||||
<intent-filter>
|
||||
<action android:name="android.intent.action.MAIN"/>
|
||||
|
||||
<category android:name="android.intent.category.LAUNCHER"/>
|
||||
</intent-filter>
|
||||
</activity>
|
||||
|
@ -26,7 +26,7 @@
|
|||
android:label="Settings">
|
||||
</activity>
|
||||
<provider
|
||||
android:name="android.support.v4.content.FileProvider"
|
||||
android:name="androidx.core.content.FileProvider"
|
||||
android:authorities="com.baidu.paddle.lite.demo.ocr.fileprovider"
|
||||
android:exported="false"
|
||||
android:grantUriPermissions="true">
|
||||
|
|
|
@ -30,7 +30,7 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(JNIEnv *env, jobject
|
|||
}
|
||||
|
||||
/**
|
||||
* "LITE_POWER_HIGH" 转为 paddle::lite_api::LITE_POWER_HIGH
|
||||
* "LITE_POWER_HIGH" convert to paddle::lite_api::LITE_POWER_HIGH
|
||||
* @param cpu_mode
|
||||
* @return
|
||||
*/
|
||||
|
|
|
@ -37,7 +37,7 @@ int OCR_PPredictor::init_from_file(const std::string &det_model_path, const std:
|
|||
return RETURN_OK;
|
||||
}
|
||||
/**
|
||||
* 调试用,保存第一步的框选结果
|
||||
* for debug use, show result of First Step
|
||||
* @param filter_boxes
|
||||
* @param boxes
|
||||
* @param srcimg
|
||||
|
|
|
@ -12,26 +12,26 @@
|
|||
namespace ppredictor {
|
||||
|
||||
/**
|
||||
* 配置
|
||||
* Config
|
||||
*/
|
||||
struct OCR_Config {
|
||||
int thread_num = 4; // 线程数
|
||||
int thread_num = 4; // Thread num
|
||||
paddle::lite_api::PowerMode mode = paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode
|
||||
};
|
||||
|
||||
/**
|
||||
* 一个四边形内图片的推理结果,
|
||||
* PolyGone Result
|
||||
*/
|
||||
struct OCRPredictResult {
|
||||
std::vector<int> word_index; //
|
||||
std::vector<int> word_index;
|
||||
std::vector<std::vector<int>> points;
|
||||
float score;
|
||||
};
|
||||
|
||||
/**
|
||||
* OCR 一共有2个模型进行推理,
|
||||
* 1. 使用第一个模型(det),框选出多个四边形
|
||||
* 2. 从原图从抠出这些多边形,使用第二个模型(rec),获取文本
|
||||
* OCR there are 2 models
|
||||
* 1. First model(det),select polygones to show where are the texts
|
||||
* 2. crop from the origin images, use these polygones to infer
|
||||
*/
|
||||
class OCR_PPredictor : public PPredictor_Interface {
|
||||
public:
|
||||
|
@ -50,7 +50,7 @@ public:
|
|||
int init(const std::string &det_model_content, const std::string &rec_model_content);
|
||||
int init_from_file(const std::string &det_model_path, const std::string &rec_model_path);
|
||||
/**
|
||||
* 返回OCR结果
|
||||
* Return OCR result
|
||||
* @param dims
|
||||
* @param input_data
|
||||
* @param input_len
|
||||
|
@ -69,7 +69,7 @@ public:
|
|||
private:
|
||||
|
||||
/**
|
||||
* 从第一个模型的结果中计算有文字的四边形
|
||||
* calcul Polygone from the result image of first model
|
||||
* @param pred
|
||||
* @param output_height
|
||||
* @param output_width
|
||||
|
@ -81,7 +81,7 @@ private:
|
|||
const cv::Mat &origin);
|
||||
|
||||
/**
|
||||
* 第二个模型的推理
|
||||
* infer for second model
|
||||
*
|
||||
* @param boxes
|
||||
* @param origin
|
||||
|
@ -91,14 +91,14 @@ private:
|
|||
infer_rec(const std::vector<std::vector<std::vector<int>>> &boxes, const cv::Mat &origin);
|
||||
|
||||
/**
|
||||
* 第二个模型提取文字的后处理
|
||||
* Postprocess or sencod model to extract text
|
||||
* @param res
|
||||
* @return
|
||||
*/
|
||||
std::vector<int> postprocess_rec_word_index(const PredictorOutput &res);
|
||||
|
||||
/**
|
||||
* 计算第二个模型的文字的置信度
|
||||
* calculate confidence of second model text result
|
||||
* @param res
|
||||
* @return
|
||||
*/
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
namespace ppredictor {
|
||||
|
||||
/**
|
||||
* PaddleLite Preditor 通用接口
|
||||
* PaddleLite Preditor Common Interface
|
||||
*/
|
||||
class PPredictor_Interface {
|
||||
public:
|
||||
|
@ -21,7 +21,7 @@ public:
|
|||
};
|
||||
|
||||
/**
|
||||
* 通用推理
|
||||
* Common Predictor
|
||||
*/
|
||||
class PPredictor : public PPredictor_Interface {
|
||||
public:
|
||||
|
@ -33,9 +33,9 @@ public:
|
|||
}
|
||||
|
||||
/**
|
||||
* 初始化paddlitelite的opt模型,nb格式,与init_paddle二选一
|
||||
* init paddlitelite opt model,nb format ,or use ini_paddle
|
||||
* @param model_content
|
||||
* @return 0 目前是固定值0, 之后其他值表示失败
|
||||
* @return 0
|
||||
*/
|
||||
virtual int init_nb(const std::string &model_content);
|
||||
|
||||
|
|
|
@ -21,10 +21,10 @@ public:
|
|||
const std::vector<std::vector<uint64_t>> get_lod() const;
|
||||
const std::vector<int64_t> get_shape() const;
|
||||
|
||||
std::vector<float> data; // 通常是float返回,与下面的data_int二选一
|
||||
std::vector<int> data_int; // 少数层是int返回,与 data二选一
|
||||
std::vector<int64_t> shape; // PaddleLite输出层的shape
|
||||
std::vector<std::vector<uint64_t>> lod; // PaddleLite输出层的lod
|
||||
std::vector<float> data; // return float, or use data_int
|
||||
std::vector<int> data_int; // several layers return int ,or use data
|
||||
std::vector<int64_t> shape; // PaddleLite output shape
|
||||
std::vector<std::vector<uint64_t>> lod; // PaddleLite output lod
|
||||
|
||||
private:
|
||||
std::unique_ptr<const paddle::lite_api::Tensor> _tensor;
|
||||
|
|
|
@ -19,15 +19,16 @@ package com.baidu.paddle.lite.demo.ocr;
|
|||
import android.content.res.Configuration;
|
||||
import android.os.Bundle;
|
||||
import android.preference.PreferenceActivity;
|
||||
import android.support.annotation.LayoutRes;
|
||||
import android.support.annotation.Nullable;
|
||||
import android.support.v7.app.ActionBar;
|
||||
import android.support.v7.app.AppCompatDelegate;
|
||||
import android.support.v7.widget.Toolbar;
|
||||
import android.view.MenuInflater;
|
||||
import android.view.View;
|
||||
import android.view.ViewGroup;
|
||||
|
||||
import androidx.annotation.LayoutRes;
|
||||
import androidx.annotation.Nullable;
|
||||
import androidx.appcompat.app.ActionBar;
|
||||
import androidx.appcompat.app.AppCompatDelegate;
|
||||
import androidx.appcompat.widget.Toolbar;
|
||||
|
||||
/**
|
||||
* A {@link PreferenceActivity} which implements and proxies the necessary calls
|
||||
* to be used with AppCompat.
|
||||
|
|
|
@ -19,11 +19,6 @@ import android.os.HandlerThread;
|
|||
import android.os.Message;
|
||||
import android.preference.PreferenceManager;
|
||||
import android.provider.MediaStore;
|
||||
import android.support.annotation.NonNull;
|
||||
import android.support.v4.app.ActivityCompat;
|
||||
import android.support.v4.content.ContextCompat;
|
||||
import android.support.v4.content.FileProvider;
|
||||
import android.support.v7.app.AppCompatActivity;
|
||||
import android.text.method.ScrollingMovementMethod;
|
||||
import android.util.Log;
|
||||
import android.view.Menu;
|
||||
|
@ -33,6 +28,12 @@ import android.widget.ImageView;
|
|||
import android.widget.TextView;
|
||||
import android.widget.Toast;
|
||||
|
||||
import androidx.annotation.NonNull;
|
||||
import androidx.appcompat.app.AppCompatActivity;
|
||||
import androidx.core.app.ActivityCompat;
|
||||
import androidx.core.content.ContextCompat;
|
||||
import androidx.core.content.FileProvider;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
|
|
|
@ -0,0 +1,157 @@
|
|||
package com.baidu.paddle.lite.demo.ocr;
|
||||
|
||||
import android.graphics.Bitmap;
|
||||
import android.graphics.BitmapFactory;
|
||||
import android.os.Build;
|
||||
import android.os.Bundle;
|
||||
import android.os.Handler;
|
||||
import android.os.HandlerThread;
|
||||
import android.os.Message;
|
||||
import android.util.Log;
|
||||
import android.view.View;
|
||||
import android.widget.Button;
|
||||
import android.widget.ImageView;
|
||||
import android.widget.TextView;
|
||||
import android.widget.Toast;
|
||||
|
||||
import androidx.appcompat.app.AppCompatActivity;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
|
||||
public class MiniActivity extends AppCompatActivity {
|
||||
|
||||
|
||||
public static final int REQUEST_LOAD_MODEL = 0;
|
||||
public static final int REQUEST_RUN_MODEL = 1;
|
||||
public static final int REQUEST_UNLOAD_MODEL = 2;
|
||||
public static final int RESPONSE_LOAD_MODEL_SUCCESSED = 0;
|
||||
public static final int RESPONSE_LOAD_MODEL_FAILED = 1;
|
||||
public static final int RESPONSE_RUN_MODEL_SUCCESSED = 2;
|
||||
public static final int RESPONSE_RUN_MODEL_FAILED = 3;
|
||||
|
||||
private static final String TAG = "MiniActivity";
|
||||
|
||||
protected Handler receiver = null; // Receive messages from worker thread
|
||||
protected Handler sender = null; // Send command to worker thread
|
||||
protected HandlerThread worker = null; // Worker thread to load&run model
|
||||
protected volatile Predictor predictor = null;
|
||||
|
||||
private String assetModelDirPath = "models/ocr_v1_for_cpu";
|
||||
private String assetlabelFilePath = "labels/ppocr_keys_v1.txt";
|
||||
|
||||
private Button button;
|
||||
private ImageView imageView; // image result
|
||||
private TextView textView; // text result
|
||||
|
||||
@Override
|
||||
protected void onCreate(Bundle savedInstanceState) {
|
||||
super.onCreate(savedInstanceState);
|
||||
setContentView(R.layout.activity_mini);
|
||||
|
||||
Log.i(TAG, "SHOW in Logcat");
|
||||
|
||||
// Prepare the worker thread for mode loading and inference
|
||||
worker = new HandlerThread("Predictor Worker");
|
||||
worker.start();
|
||||
sender = new Handler(worker.getLooper()) {
|
||||
public void handleMessage(Message msg) {
|
||||
switch (msg.what) {
|
||||
case REQUEST_LOAD_MODEL:
|
||||
// Load model and reload test image
|
||||
if (!onLoadModel()) {
|
||||
runOnUiThread(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
Toast.makeText(MiniActivity.this, "Load model failed!", Toast.LENGTH_SHORT).show();
|
||||
}
|
||||
});
|
||||
}
|
||||
break;
|
||||
case REQUEST_RUN_MODEL:
|
||||
// Run model if model is loaded
|
||||
final boolean isSuccessed = onRunModel();
|
||||
runOnUiThread(new Runnable() {
|
||||
@Override
|
||||
public void run() {
|
||||
if (isSuccessed){
|
||||
onRunModelSuccessed();
|
||||
}else{
|
||||
Toast.makeText(MiniActivity.this, "Run model failed!", Toast.LENGTH_SHORT).show();
|
||||
}
|
||||
}
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
};
|
||||
sender.sendEmptyMessage(REQUEST_LOAD_MODEL); // corresponding to REQUEST_LOAD_MODEL, to call onLoadModel()
|
||||
|
||||
imageView = findViewById(R.id.imageView);
|
||||
textView = findViewById(R.id.sample_text);
|
||||
button = findViewById(R.id.button);
|
||||
button.setOnClickListener(new View.OnClickListener() {
|
||||
@Override
|
||||
public void onClick(View v) {
|
||||
sender.sendEmptyMessage(REQUEST_RUN_MODEL);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void onDestroy() {
|
||||
onUnloadModel();
|
||||
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.JELLY_BEAN_MR2) {
|
||||
worker.quitSafely();
|
||||
} else {
|
||||
worker.quit();
|
||||
}
|
||||
super.onDestroy();
|
||||
}
|
||||
|
||||
/**
|
||||
* call in onCreate, model init
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
private boolean onLoadModel() {
|
||||
if (predictor == null) {
|
||||
predictor = new Predictor();
|
||||
}
|
||||
return predictor.init(this, assetModelDirPath, assetlabelFilePath);
|
||||
}
|
||||
|
||||
/**
|
||||
* init engine
|
||||
* call in onCreate
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
private boolean onRunModel() {
|
||||
try {
|
||||
String assetImagePath = "images/5.jpg";
|
||||
InputStream imageStream = getAssets().open(assetImagePath);
|
||||
Bitmap image = BitmapFactory.decodeStream(imageStream);
|
||||
// Input is Bitmap
|
||||
predictor.setInputImage(image);
|
||||
return predictor.isLoaded() && predictor.runModel();
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private void onRunModelSuccessed() {
|
||||
Log.i(TAG, "onRunModelSuccessed");
|
||||
textView.setText(predictor.outputResult);
|
||||
imageView.setImageBitmap(predictor.outputImage);
|
||||
}
|
||||
|
||||
private void onUnloadModel() {
|
||||
if (predictor != null) {
|
||||
predictor.releaseModel();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -38,7 +38,7 @@ public class Predictor {
|
|||
protected float scoreThreshold = 0.1f;
|
||||
protected Bitmap inputImage = null;
|
||||
protected Bitmap outputImage = null;
|
||||
protected String outputResult = "";
|
||||
protected volatile String outputResult = "";
|
||||
protected float preprocessTime = 0;
|
||||
protected float postprocessTime = 0;
|
||||
|
||||
|
@ -46,6 +46,16 @@ public class Predictor {
|
|||
public Predictor() {
|
||||
}
|
||||
|
||||
public boolean init(Context appCtx, String modelPath, String labelPath) {
|
||||
isLoaded = loadModel(appCtx, modelPath, cpuThreadNum, cpuPowerMode);
|
||||
if (!isLoaded) {
|
||||
return false;
|
||||
}
|
||||
isLoaded = loadLabel(appCtx, labelPath);
|
||||
return isLoaded;
|
||||
}
|
||||
|
||||
|
||||
public boolean init(Context appCtx, String modelPath, String labelPath, int cpuThreadNum, String cpuPowerMode,
|
||||
String inputColorFormat,
|
||||
long[] inputShape, float[] inputMean,
|
||||
|
@ -76,11 +86,7 @@ public class Predictor {
|
|||
Log.e(TAG, "Only BGR color format is supported.");
|
||||
return false;
|
||||
}
|
||||
isLoaded = loadModel(appCtx, modelPath, cpuThreadNum, cpuPowerMode);
|
||||
if (!isLoaded) {
|
||||
return false;
|
||||
}
|
||||
isLoaded = loadLabel(appCtx, labelPath);
|
||||
boolean isLoaded = init(appCtx, modelPath, labelPath);
|
||||
if (!isLoaded) {
|
||||
return false;
|
||||
}
|
||||
|
@ -222,7 +228,7 @@ public class Predictor {
|
|||
for (int i = 0; i < warmupIterNum; i++) {
|
||||
paddlePredictor.runImage(inputData, width, height, channels, inputImage);
|
||||
}
|
||||
warmupIterNum = 0; // 之后不要再warm了
|
||||
warmupIterNum = 0; // do not need warm
|
||||
// Run inference
|
||||
start = new Date();
|
||||
ArrayList<OcrResultModel> results = paddlePredictor.runImage(inputData, width, height, channels, inputImage);
|
||||
|
@ -317,7 +323,7 @@ public class Predictor {
|
|||
for (Point p : result.getPoints()) {
|
||||
sb.append("(").append(p.x).append(",").append(p.y).append(") ");
|
||||
}
|
||||
Log.i(TAG, sb.toString());
|
||||
Log.i(TAG, sb.toString()); // show LOG in Logcat panel
|
||||
outputResultSb.append(i + 1).append(": ").append(result.getLabel()).append("\n");
|
||||
}
|
||||
outputResult = outputResultSb.toString();
|
||||
|
|
|
@ -5,7 +5,8 @@ import android.os.Bundle;
|
|||
import android.preference.CheckBoxPreference;
|
||||
import android.preference.EditTextPreference;
|
||||
import android.preference.ListPreference;
|
||||
import android.support.v7.app.ActionBar;
|
||||
|
||||
import androidx.appcompat.app.ActionBar;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<android.support.constraint.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:app="http://schemas.android.com/apk/res-auto"
|
||||
xmlns:tools="http://schemas.android.com/tools"
|
||||
android:layout_width="match_parent"
|
||||
|
@ -96,4 +96,4 @@
|
|||
|
||||
</RelativeLayout>
|
||||
|
||||
</android.support.constraint.ConstraintLayout>
|
||||
</androidx.constraintlayout.widget.ConstraintLayout>
|
|
@ -0,0 +1,46 @@
|
|||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<!-- for MiniActivity Use Only -->
|
||||
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
|
||||
xmlns:app="http://schemas.android.com/apk/res-auto"
|
||||
xmlns:tools="http://schemas.android.com/tools"
|
||||
android:layout_width="match_parent"
|
||||
android:layout_height="match_parent"
|
||||
app:layout_constraintLeft_toLeftOf="parent"
|
||||
app:layout_constraintLeft_toRightOf="parent"
|
||||
tools:context=".MainActivity">
|
||||
|
||||
<TextView
|
||||
android:id="@+id/sample_text"
|
||||
android:layout_width="0dp"
|
||||
android:layout_height="wrap_content"
|
||||
android:text="Hello World!"
|
||||
app:layout_constraintLeft_toLeftOf="parent"
|
||||
app:layout_constraintRight_toRightOf="parent"
|
||||
app:layout_constraintTop_toBottomOf="@id/imageView"
|
||||
android:scrollbars="vertical"
|
||||
/>
|
||||
|
||||
<ImageView
|
||||
android:id="@+id/imageView"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:paddingTop="20dp"
|
||||
android:paddingBottom="20dp"
|
||||
app:layout_constraintBottom_toTopOf="@id/imageView"
|
||||
app:layout_constraintLeft_toLeftOf="parent"
|
||||
app:layout_constraintRight_toRightOf="parent"
|
||||
app:layout_constraintTop_toTopOf="parent"
|
||||
tools:srcCompat="@tools:sample/avatars" />
|
||||
|
||||
<Button
|
||||
android:id="@+id/button"
|
||||
android:layout_width="wrap_content"
|
||||
android:layout_height="wrap_content"
|
||||
android:layout_marginBottom="4dp"
|
||||
android:text="Button"
|
||||
app:layout_constraintBottom_toBottomOf="parent"
|
||||
app:layout_constraintLeft_toLeftOf="parent"
|
||||
app:layout_constraintRight_toRightOf="parent"
|
||||
tools:layout_editor_absoluteX="161dp" />
|
||||
|
||||
</androidx.constraintlayout.widget.ConstraintLayout>
|
|
@ -1,4 +1,4 @@
|
|||
#Thu Aug 22 15:05:37 CST 2019
|
||||
#Wed Jul 22 23:48:44 CST 2020
|
||||
distributionBase=GRADLE_USER_HOME
|
||||
distributionPath=wrapper/dists
|
||||
zipStoreBase=GRADLE_USER_HOME
|
||||
|
|
|
@ -47,7 +47,7 @@ public:
|
|||
|
||||
this->det_db_box_thresh = stod(config_map_["det_db_box_thresh"]);
|
||||
|
||||
this->det_db_box_thresh = stod(config_map_["det_db_box_thresh"]);
|
||||
this->det_db_unclip_ratio = stod(config_map_["det_db_unclip_ratio"]);
|
||||
|
||||
this->det_model_dir.assign(config_map_["det_model_dir"]);
|
||||
|
||||
|
|
|
@ -33,6 +33,8 @@ def read_params():
|
|||
cfg.rec_image_shape = "3, 32, 320"
|
||||
cfg.rec_char_type = 'ch'
|
||||
cfg.rec_batch_num = 30
|
||||
cfg.max_text_length = 25
|
||||
|
||||
cfg.rec_char_dict_path = "./ppocr/utils/ppocr_keys_v1.txt"
|
||||
cfg.use_space_char = True
|
||||
|
||||
|
|
|
@ -33,6 +33,8 @@ def read_params():
|
|||
cfg.rec_image_shape = "3, 32, 320"
|
||||
cfg.rec_char_type = 'ch'
|
||||
cfg.rec_batch_num = 30
|
||||
cfg.max_text_length = 25
|
||||
|
||||
cfg.rec_char_dict_path = "./ppocr/utils/ppocr_keys_v1.txt"
|
||||
cfg.use_space_char = True
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ download_and_extract() {
|
|||
}
|
||||
|
||||
echo -e "[Download ios ocr demo denpendancy]\n"
|
||||
download_and_extract "${OCR_MODEL_URL}" "./ios-demo/ocr_demo/models"
|
||||
download_and_extract "${PADDLE_LITE_LIB_URL}" "./ios-demo/ocr_demo"
|
||||
download_and_extract "${OPENCV3_FRAMEWORK_URL}" "./ios-demo/ocr_demo"
|
||||
download_and_extract "${OCR_MODEL_URL}" "./ocr_demo/models"
|
||||
download_and_extract "${PADDLE_LITE_LIB_URL}" "./ocr_demo"
|
||||
download_and_extract "${OPENCV3_FRAMEWORK_URL}" "./ocr_demo"
|
||||
echo -e "[done]\n"
|
||||
|
|
|
@ -13,7 +13,7 @@ deployment solutions for end-side deployment issues.
|
|||
- Computer (for Compiling Paddle Lite)
|
||||
- Mobile phone (arm7 or arm8)
|
||||
|
||||
## 2. Build ncnn library
|
||||
## 2. Build PaddleLite library
|
||||
[build for Docker](https://paddle-lite.readthedocs.io/zh/latest/user_guides/source_compile.html#docker)
|
||||
[build for Linux](https://paddle-lite.readthedocs.io/zh/latest/user_guides/source_compile.html#android)
|
||||
[build for MAC OS](https://paddle-lite.readthedocs.io/zh/latest/user_guides/source_compile.html#id13)
|
||||
|
|
|
@ -21,7 +21,10 @@ from paddle_serving_client import Client
|
|||
from paddle_serving_app.reader import Sequential, ResizeByFactor
|
||||
from paddle_serving_app.reader import Div, Normalize, Transpose
|
||||
from paddle_serving_app.reader import DBPostProcess, FilterBoxes
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
if sys.argv[1] == 'gpu':
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
elif sys.argv[1] == 'cpu':
|
||||
from paddle_serving_server.web_service import WebService
|
||||
import time
|
||||
import re
|
||||
import base64
|
||||
|
@ -64,8 +67,13 @@ class OCRService(WebService):
|
|||
|
||||
ocr_service = OCRService(name="ocr")
|
||||
ocr_service.load_model_config("ocr_det_model")
|
||||
ocr_service.set_gpus("0")
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
|
||||
ocr_service.init_det()
|
||||
ocr_service.run_debugger_service()
|
||||
if sys.argv[1] == 'gpu':
|
||||
ocr_service.set_gpus("0")
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
|
||||
ocr_service.run_debugger_service(gpu=True)
|
||||
elif sys.argv[1] == 'cpu':
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292)
|
||||
ocr_service.run_debugger_service()
|
||||
ocr_service.init_det()
|
||||
ocr_service.run_web_service()
|
||||
|
|
|
@ -21,7 +21,10 @@ from paddle_serving_client import Client
|
|||
from paddle_serving_app.reader import Sequential, ResizeByFactor
|
||||
from paddle_serving_app.reader import Div, Normalize, Transpose
|
||||
from paddle_serving_app.reader import DBPostProcess, FilterBoxes
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
if sys.argv[1] == 'gpu':
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
elif sys.argv[1] == 'cpu':
|
||||
from paddle_serving_server.web_service import WebService
|
||||
import time
|
||||
import re
|
||||
import base64
|
||||
|
@ -65,8 +68,11 @@ class OCRService(WebService):
|
|||
|
||||
ocr_service = OCRService(name="ocr")
|
||||
ocr_service.load_model_config("ocr_det_model")
|
||||
ocr_service.set_gpus("0")
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
|
||||
if sys.argv[1] == 'gpu':
|
||||
ocr_service.set_gpus("0")
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
|
||||
elif sys.argv[1] == 'cpu':
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu")
|
||||
ocr_service.init_det()
|
||||
ocr_service.run_rpc_service()
|
||||
ocr_service.run_web_service()
|
||||
|
|
|
@ -22,7 +22,10 @@ from paddle_serving_client import Client
|
|||
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
|
||||
from paddle_serving_app.reader import Div, Normalize, Transpose
|
||||
from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
if sys.argv[1] == 'gpu':
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
elif sys.argv[1] == 'cpu':
|
||||
from paddle_serving_server.web_service import WebService
|
||||
from paddle_serving_app.local_predict import Debugger
|
||||
import time
|
||||
import re
|
||||
|
@ -37,8 +40,12 @@ class OCRService(WebService):
|
|||
(2, 0, 1))
|
||||
])
|
||||
self.det_client = Debugger()
|
||||
self.det_client.load_model_config(
|
||||
det_model_config, gpu=True, profile=False)
|
||||
if sys.argv[1] == 'gpu':
|
||||
self.det_client.load_model_config(
|
||||
det_model_config, gpu=True, profile=False)
|
||||
elif sys.argv[1] == 'cpu':
|
||||
self.det_client.load_model_config(
|
||||
det_model_config, gpu=False, profile=False)
|
||||
self.ocr_reader = OCRReader()
|
||||
|
||||
def preprocess(self, feed=[], fetch=[]):
|
||||
|
@ -97,7 +104,11 @@ class OCRService(WebService):
|
|||
|
||||
ocr_service = OCRService(name="ocr")
|
||||
ocr_service.load_model_config("ocr_rec_model")
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292)
|
||||
ocr_service.init_det_debugger(det_model_config="ocr_det_model")
|
||||
ocr_service.run_debugger_service(gpu=True)
|
||||
if sys.argv[1] == 'gpu':
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
|
||||
ocr_service.run_debugger_service(gpu=True)
|
||||
elif sys.argv[1] == 'cpu':
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu")
|
||||
ocr_service.run_debugger_service()
|
||||
ocr_service.run_web_service()
|
||||
|
|
|
@ -22,7 +22,10 @@ from paddle_serving_client import Client
|
|||
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
|
||||
from paddle_serving_app.reader import Div, Normalize, Transpose
|
||||
from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
if sys.argv[1] == 'gpu':
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
elif sys.argv[1] == 'cpu':
|
||||
from paddle_serving_server.web_service import WebService
|
||||
import time
|
||||
import re
|
||||
import base64
|
||||
|
@ -90,8 +93,11 @@ class OCRService(WebService):
|
|||
|
||||
ocr_service = OCRService(name="ocr")
|
||||
ocr_service.load_model_config("ocr_rec_model")
|
||||
ocr_service.set_gpus("0")
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
|
||||
if sys.argv[1] == 'gpu':
|
||||
ocr_service.set_gpus("0")
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
|
||||
elif sys.argv[1] == 'cpu':
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292)
|
||||
ocr_service.init_det_client(
|
||||
det_port=9293,
|
||||
det_client_config="ocr_det_client/serving_client_conf.prototxt")
|
||||
|
|
|
@ -16,20 +16,33 @@
|
|||
|
||||
**Python3操作指南:**
|
||||
```
|
||||
#以下提供beta版本的paddle serving whl包,欢迎试用,正式版会在7月底正式上线
|
||||
#以下提供beta版本的paddle serving whl包,欢迎试用,正式版会在8月中正式上线
|
||||
#GPU用户下载server包使用这个链接
|
||||
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/others/paddle_serving_server_gpu-0.3.2-py3-none-any.whl
|
||||
python -m pip install paddle_serving_server_gpu-0.3.2-py3-none-any.whl
|
||||
#CPU版本使用这个链接
|
||||
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/others/paddle_serving_server-0.3.2-py3-none-any.whl
|
||||
python -m pip install paddle_serving_server-0.3.2-py3-none-any.whl
|
||||
#客户端和App包使用以下链接(CPU,GPU通用)
|
||||
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/others/paddle_serving_client-0.3.2-cp36-none-any.whl
|
||||
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/others/paddle_serving_app-0.1.2-py3-none-any.whl
|
||||
python -m pip install paddle_serving_app-0.1.2-py3-none-any.whl paddle_serving_server_gpu-0.3.2-py3-none-any.whl paddle_serving_client-0.3.2-cp36-none-any.whl
|
||||
python -m pip install paddle_serving_app-0.1.2-py3-none-any.whl paddle_serving_client-0.3.2-cp36-none-any.whl
|
||||
```
|
||||
|
||||
**Python2操作指南:**
|
||||
```
|
||||
#以下提供beta版本的paddle serving whl包,欢迎试用,正式版会在7月底正式上线
|
||||
#以下提供beta版本的paddle serving whl包,欢迎试用,正式版会在8月中正式上线
|
||||
#GPU用户下载server包使用这个链接
|
||||
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/others/paddle_serving_server_gpu-0.3.2-py2-none-any.whl
|
||||
python -m pip install paddle_serving_server_gpu-0.3.2-py2-none-any.whl
|
||||
#CPU版本使用这个链接
|
||||
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/others/paddle_serving_server-0.3.2-py2-none-any.whl
|
||||
python -m pip install paddle_serving_server-0.3.2-py2-none-any.whl
|
||||
|
||||
#客户端和App包使用以下链接(CPU,GPU通用)
|
||||
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/others/paddle_serving_app-0.1.2-py2-none-any.whl
|
||||
wget --no-check-certificate https://paddle-serving.bj.bcebos.com/others/paddle_serving_client-0.3.2-cp27-none-any.whl
|
||||
python -m pip install paddle_serving_app-0.1.2-py2-none-any.whl paddle_serving_server_gpu-0.3.2-py2-none-any.whl paddle_serving_client-0.3.2-cp27-none-any.whl
|
||||
python -m pip install paddle_serving_app-0.1.2-py2-none-any.whl paddle_serving_client-0.3.2-cp27-none-any.whl
|
||||
```
|
||||
|
||||
### 2. 模型转换
|
||||
|
@ -42,6 +55,23 @@ tar -xzvf ocr_det.tar.gz
|
|||
```
|
||||
执行上述命令会下载`db_crnn_mobile`的模型,如果想要下载规模更大的`db_crnn_server`模型,可以在下载预测模型并解压之后。参考[如何从Paddle保存的预测模型转为Paddle Serving格式可部署的模型](https://github.com/PaddlePaddle/Serving/blob/develop/doc/INFERENCE_TO_SERVING_CN.md)。
|
||||
|
||||
我们以`ch_rec_r34_vd_crnn`模型作为例子,下载链接在:
|
||||
|
||||
```
|
||||
wget --no-check-certificate https://paddleocr.bj.bcebos.com/ch_models/ch_rec_r34_vd_crnn_infer.tar
|
||||
tar xf ch_rec_r34_vd_crnn_infer.tar
|
||||
```
|
||||
因此我们按照Serving模型转换教程,运行下列python文件。
|
||||
```
|
||||
from paddle_serving_client.io import inference_model_to_serving
|
||||
inference_model_dir = "ch_rec_r34_vd_crnn"
|
||||
serving_client_dir = "serving_client_dir"
|
||||
serving_server_dir = "serving_server_dir"
|
||||
feed_var_names, fetch_var_names = inference_model_to_serving(
|
||||
inference_model_dir, serving_client_dir, serving_server_dir, model_filename="model", params_filename="params")
|
||||
```
|
||||
最终会在`serving_client_dir`和`serving_server_dir`生成客户端和服务端的模型配置。
|
||||
|
||||
### 3. 启动服务
|
||||
启动服务可以根据实际需求选择启动`标准版`或者`快速版`,两种方式的对比如下表:
|
||||
|
||||
|
@ -53,14 +83,21 @@ tar -xzvf ocr_det.tar.gz
|
|||
#### 方式1. 启动标准版服务
|
||||
|
||||
```
|
||||
# cpu,gpu启动二选一,以下是cpu启动
|
||||
python -m paddle_serving_server.serve --model ocr_det_model --port 9293
|
||||
python ocr_web_server.py cpu
|
||||
# gpu启动
|
||||
python -m paddle_serving_server_gpu.serve --model ocr_det_model --port 9293 --gpu_id 0
|
||||
python ocr_web_server.py
|
||||
python ocr_web_server.py gpu
|
||||
```
|
||||
|
||||
#### 方式2. 启动快速版服务
|
||||
|
||||
```
|
||||
python ocr_local_server.py
|
||||
# cpu,gpu启动二选一,以下是cpu启动
|
||||
python ocr_local_server.py cpu
|
||||
# gpu启动
|
||||
python ocr_local_server.py gpu
|
||||
```
|
||||
|
||||
## 发送预测请求
|
||||
|
@ -85,7 +122,7 @@ python ocr_web_client.py
|
|||
|
||||
在`ocr_web_server.py`或是`ocr_local_server.py`当中的`preprocess`函数里面做了检测服务和识别服务的前处理,`postprocess`函数里面做了识别的后处理服务,可以在相应的函数中做修改。调用了`paddle_serving_app`库提供的常见CV模型的前处理/后处理库。
|
||||
|
||||
如果想要单独启动Paddle Serving的检测服务和识别服务,参见下列表格, 执行对应的脚本即可。
|
||||
如果想要单独启动Paddle Serving的检测服务和识别服务,参见下列表格, 执行对应的脚本即可,并且在命令行参数注明用的CPU或是GPU来提供服务。
|
||||
|
||||
| 模型 | 标准版 | 快速版 |
|
||||
| ---- | ----------------- | ------------------- |
|
||||
|
|
|
@ -22,7 +22,10 @@ from paddle_serving_client import Client
|
|||
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
|
||||
from paddle_serving_app.reader import Div, Normalize, Transpose
|
||||
from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
if sys.argv[1] == 'gpu':
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
elif sys.argv[1] == 'cpu':
|
||||
from paddle_serving_server.web_service import WebService
|
||||
import time
|
||||
import re
|
||||
import base64
|
||||
|
@ -65,8 +68,12 @@ class OCRService(WebService):
|
|||
|
||||
ocr_service = OCRService(name="ocr")
|
||||
ocr_service.load_model_config("ocr_rec_model")
|
||||
ocr_service.set_gpus("0")
|
||||
ocr_service.init_rec()
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
|
||||
ocr_service.run_debugger_service()
|
||||
if sys.argv[1] == 'gpu':
|
||||
ocr_service.set_gpus("0")
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
|
||||
ocr_service.run_debugger_service(gpu=True)
|
||||
elif sys.argv[1] == 'cpu':
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu")
|
||||
ocr_service.run_debugger_service()
|
||||
ocr_service.run_web_service()
|
||||
|
|
|
@ -22,7 +22,10 @@ from paddle_serving_client import Client
|
|||
from paddle_serving_app.reader import Sequential, URL2Image, ResizeByFactor
|
||||
from paddle_serving_app.reader import Div, Normalize, Transpose
|
||||
from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
if sys.argv[1] == 'gpu':
|
||||
from paddle_serving_server_gpu.web_service import WebService
|
||||
elif sys.argv[1] == 'cpu':
|
||||
from paddle_serving_server.web_service import WebService
|
||||
import time
|
||||
import re
|
||||
import base64
|
||||
|
@ -64,8 +67,11 @@ class OCRService(WebService):
|
|||
|
||||
ocr_service = OCRService(name="ocr")
|
||||
ocr_service.load_model_config("ocr_rec_model")
|
||||
ocr_service.set_gpus("0")
|
||||
ocr_service.init_rec()
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
|
||||
if sys.argv[1] == 'gpu':
|
||||
ocr_service.set_gpus("0")
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292, device="gpu", gpuid=0)
|
||||
elif sys.argv[1] == 'cpu':
|
||||
ocr_service.prepare_server(workdir="workdir", port=9292, device="cpu")
|
||||
ocr_service.run_rpc_service()
|
||||
ocr_service.run_web_service()
|
||||
|
|
|
@ -39,7 +39,7 @@ PaddleOCR已完成Windows和Mac系统适配,运行时注意两点:1、在[
|
|||
- 识别:
|
||||
英文数据集,MJSynth和SynthText合成数据,数据量上千万。
|
||||
中文数据集,LSVT街景数据集根据真值将图crop出来,并进行位置校准,总共30w张图像。此外基于LSVT的语料,合成数据500w。
|
||||
|
||||
|
||||
其中,公开数据集都是开源的,用户可自行搜索下载,也可参考[中文数据集](./datasets.md),合成数据暂不开源,用户可使用开源合成工具自行合成,可参考的合成工具包括[text_renderer](https://github.com/Sanster/text_renderer)、[SynthText](https://github.com/ankush-me/SynthText)、[TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator)等。
|
||||
|
||||
10. **使用带TPS的识别模型预测报错**
|
||||
|
@ -49,3 +49,5 @@ PaddleOCR已完成Windows和Mac系统适配,运行时注意两点:1、在[
|
|||
11. **自定义字典训练的模型,识别结果出现字典里没出现的字**
|
||||
预测时没有设置采用的自定义字典路径。设置方法是在预测时,通过增加输入参数rec_char_dict_path来设置。
|
||||
|
||||
12. **cpp infer与python inference的结果不一致,相差较大**
|
||||
导出的inference model版本与预测库版本需要保持一致,比如在Windows下,Paddle官网提供的预测库版本是1.8,而PaddleOCR提供的inference model 版本是1.7,因此最终预测结果会有差别。可以在Paddle1.8环境下导出模型,再基于该模型进行预测。
|
||||
|
|
|
@ -32,6 +32,9 @@
|
|||
| loss_type | 设置 loss 类型 | ctc | 支持两种loss: ctc / attention |
|
||||
| distort | 设置是否使用数据增强 | false | 设置为true时,将在训练时随机进行扰动,支持的扰动操作可阅读[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py) |
|
||||
| use_space_char | 设置是否识别空格 | false | 仅在 character_type=ch 时支持空格 |
|
||||
| average_window | ModelAverage优化器中的窗口长度计算比例 | 0.15 | 目前仅应用与SRN |
|
||||
| max_average_window | 平均值计算窗口长度的最大值 | 15625 | 推荐设置为一轮训练中mini-batchs的数目|
|
||||
| min_average_window | 平均值计算窗口长度的最小值 | 10000 | \ |
|
||||
| reader_yml | 设置reader配置文件 | ./configs/rec/rec_icdar15_reader.yml | \ |
|
||||
| pretrain_weights | 加载预训练模型路径 | ./pretrain_models/CRNN/best_accuracy | \ |
|
||||
| checkpoints | 加载模型参数路径 | None | 用于中断后加载参数继续训练 |
|
||||
|
@ -60,6 +63,9 @@
|
|||
| beta1 | 设置一阶矩估计的指数衰减率 | 0.9 | \ |
|
||||
| beta2 | 设置二阶矩估计的指数衰减率 | 0.999 | \ |
|
||||
| decay | 是否使用decay | \ | \ |
|
||||
| function(decay) | 设置decay方式 | cosine_decay | 目前只支持cosin_decay |
|
||||
| step_each_epoch | 每个epoch包含多少次迭代 | 20 | 计算方式:total_image_num / (batch_size_per_card * card_size) |
|
||||
| total_epoch | 总共迭代多少个epoch | 1000 | 与Global.epoch_num 一致 |
|
||||
| function(decay) | 设置decay方式 | - | 目前支持cosine_decay, cosine_decay_warmup与piecewise_decay |
|
||||
| step_each_epoch | 每个epoch包含多少次迭代, cosine_decay/cosine_decay_warmup时有效 | 20 | 计算方式:total_image_num / (batch_size_per_card * card_size) |
|
||||
| total_epoch | 总共迭代多少个epoch, cosine_decay/cosine_decay_warmup时有效 | 1000 | 与Global.epoch_num 一致 |
|
||||
| warmup_minibatch | 线性warmup的迭代次数, cosine_decay_warmup时有效 | 1000 | \ |
|
||||
| boundaries | 学习率下降时的迭代次数间隔, piecewise_decay时有效 | - | 参数为列表形式 |
|
||||
| decay_rate | 学习率衰减系数, piecewise_decay时有效 | - | \ |
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
# 文字检测
|
||||
|
||||
本节以icdar15数据集为例,介绍PaddleOCR中检测模型的训练、评估与测试。
|
||||
本节以icdar2015数据集为例,介绍PaddleOCR中检测模型的训练、评估与测试。
|
||||
|
||||
## 数据准备
|
||||
icdar2015数据集可以从[官网](https://rrc.cvc.uab.es/?ch=4&com=downloads)下载到,首次下载需注册。
|
||||
|
||||
将下载到的数据集解压到工作目录下,假设解压在 PaddleOCR/train_data/ 下。另外,PaddleOCR将零散的标注文件整理成单独的标注文件
|
||||
,您可以通过wget的方式进行下载。
|
||||
```
|
||||
```shell
|
||||
# 在PaddleOCR路径下
|
||||
cd PaddleOCR/
|
||||
wget -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/train_icdar2015_label.txt
|
||||
|
@ -23,21 +23,21 @@ wget -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/test_icdar2015_la
|
|||
└─ test_icdar2015_label.txt icdar数据集的测试标注
|
||||
```
|
||||
|
||||
提供的标注文件格式为,其中中间是"\t"分隔:
|
||||
提供的标注文件格式如下,中间用"\t"分隔:
|
||||
```
|
||||
" 图像文件名 json.dumps编码的图像标注信息"
|
||||
ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}]
|
||||
ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]]}, {...}]
|
||||
```
|
||||
json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 `points` 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。
|
||||
`transcription` 表示当前文本框的文字,在文本检测任务中并不需要这个信息。
|
||||
如果您想在其他数据集上训练PaddleOCR,可以按照上述形式构建标注文件。
|
||||
`transcription` 表示当前文本框的文字,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。**
|
||||
|
||||
如果您想在其他数据集上训练,可以按照上述形式构建标注文件。
|
||||
|
||||
## 快速启动训练
|
||||
|
||||
首先下载模型backbone的pretrain model,PaddleOCR的检测模型目前支持两种backbone,分别是MobileNetV3、ResNet50_vd,
|
||||
您可以根据需求使用[PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures)中的模型更换backbone。
|
||||
```
|
||||
```shell
|
||||
cd PaddleOCR/
|
||||
# 下载MobileNetV3的预训练模型
|
||||
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x0_5_pretrained.tar
|
||||
|
@ -45,7 +45,7 @@ wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/Mob
|
|||
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
|
||||
|
||||
# 解压预训练模型文件,以MobileNetV3为例
|
||||
tar xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_models/
|
||||
tar -xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_models/
|
||||
|
||||
# 注:正确解压backbone预训练权重文件后,文件夹下包含众多以网络层命名的权重文件,格式如下:
|
||||
./pretrain_models/MobileNetV3_large_x0_5_pretrained/
|
||||
|
@ -57,11 +57,11 @@ tar xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_models
|
|||
|
||||
```
|
||||
|
||||
**启动训练**
|
||||
#### 启动训练
|
||||
|
||||
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
|
||||
|
||||
```
|
||||
```shell
|
||||
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/
|
||||
```
|
||||
|
||||
|
@ -69,52 +69,52 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=
|
|||
有关配置文件的详细解释,请参考[链接](./config.md)。
|
||||
|
||||
您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
|
||||
```
|
||||
```shell
|
||||
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
|
||||
```
|
||||
|
||||
**断点训练**
|
||||
#### 断点训练
|
||||
|
||||
如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径:
|
||||
```
|
||||
```shell
|
||||
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./your/trained/model
|
||||
```
|
||||
|
||||
**注意**:Global.checkpoints的优先级高于Global.pretrain_weights的优先级,即同时指定两个参数时,优先加载Global.checkpoints指定的模型,如果Global.checkpoints指定的模型路径有误,会加载Global.pretrain_weights指定的模型。
|
||||
**注意**:`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。
|
||||
|
||||
## 指标评估
|
||||
|
||||
PaddleOCR计算三个OCR检测相关的指标,分别是:Precision、Recall、Hmean。
|
||||
|
||||
运行如下代码,根据配置文件det_db_mv3.yml中save_res_path指定的测试集检测结果文件,计算评估指标。
|
||||
运行如下代码,根据配置文件`det_db_mv3.yml`中`save_res_path`指定的测试集检测结果文件,计算评估指标。
|
||||
|
||||
评估时设置后处理参数box_thresh=0.6,unclip_ratio=1.5,使用不同数据集、不同模型训练,可调整这两个参数进行优化
|
||||
```
|
||||
评估时设置后处理参数`box_thresh=0.6`,`unclip_ratio=1.5`,使用不同数据集、不同模型训练,可调整这两个参数进行优化
|
||||
```shell
|
||||
python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
|
||||
```
|
||||
训练中模型参数默认保存在Global.save_model_dir目录下。在评估指标时,需要设置Global.checkpoints指向保存的参数文件。
|
||||
训练中模型参数默认保存在`Global.save_model_dir`目录下。在评估指标时,需要设置`Global.checkpoints`指向保存的参数文件。
|
||||
|
||||
比如:
|
||||
```
|
||||
```shell
|
||||
python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
|
||||
```
|
||||
|
||||
* 注:box_thresh、unclip_ratio是DB后处理所需要的参数,在评估EAST模型时不需要设置
|
||||
* 注:`box_thresh`、`unclip_ratio`是DB后处理所需要的参数,在评估EAST模型时不需要设置
|
||||
|
||||
## 测试检测效果
|
||||
|
||||
测试单张图像的检测效果
|
||||
```
|
||||
```shell
|
||||
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
|
||||
```
|
||||
|
||||
测试DB模型时,调整后处理阈值,
|
||||
```
|
||||
```shell
|
||||
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
|
||||
```
|
||||
|
||||
|
||||
测试文件夹下所有图像的检测效果
|
||||
```
|
||||
```shell
|
||||
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"
|
||||
```
|
||||
|
|
|
@ -1,14 +1,37 @@
|
|||
|
||||
# 基于Python预测引擎推理
|
||||
|
||||
inference 模型(fluid.io.save_inference_model保存的模型)
|
||||
一般是模型训练完成后保存的固化模型,多用于预测部署。
|
||||
训练过程中保存的模型是checkpoints模型,保存的是模型的参数,多用于恢复训练等。
|
||||
inference 模型(`fluid.io.save_inference_model`保存的模型)
|
||||
一般是模型训练完成后保存的固化模型,多用于预测部署。训练过程中保存的模型是checkpoints模型,保存的是模型的参数,多用于恢复训练等。
|
||||
与checkpoints模型相比,inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合与实际系统集成。更详细的介绍请参考文档[分类预测框架](https://paddleclas.readthedocs.io/zh_CN/latest/extension/paddle_inference.html).
|
||||
|
||||
接下来首先介绍如何将训练的模型转换成inference模型,然后将依次介绍文本检测、文本识别以及两者串联基于预测引擎推理。
|
||||
|
||||
|
||||
- [一、训练模型转inference模型](#训练模型转inference模型)
|
||||
- [检测模型转inference模型](#检测模型转inference模型)
|
||||
- [识别模型转inference模型](#识别模型转inference模型)
|
||||
|
||||
- [二、文本检测模型推理](#文本检测模型推理)
|
||||
- [1. 超轻量中文检测模型推理](#超轻量中文检测模型推理)
|
||||
- [2. DB文本检测模型推理](#DB文本检测模型推理)
|
||||
- [3. EAST文本检测模型推理](#EAST文本检测模型推理)
|
||||
- [4. SAST文本检测模型推理](#SAST文本检测模型推理)
|
||||
|
||||
- [三、文本识别模型推理](#文本识别模型推理)
|
||||
- [1. 超轻量中文识别模型推理](#超轻量中文识别模型推理)
|
||||
- [2. 基于CTC损失的识别模型推理](#基于CTC损失的识别模型推理)
|
||||
- [3. 基于Attention损失的识别模型推理](#基于Attention损失的识别模型推理)
|
||||
- [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
|
||||
|
||||
- [四、文本检测、识别串联推理](#文本检测、识别串联推理)
|
||||
- [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理)
|
||||
- [2. 其他模型推理](#其他模型推理)
|
||||
|
||||
|
||||
<a name="训练模型转inference模型"></a>
|
||||
## 一、训练模型转inference模型
|
||||
<a name="检测模型转inference模型"></a>
|
||||
### 检测模型转inference模型
|
||||
|
||||
下载超轻量级中文检测模型:
|
||||
|
@ -24,15 +47,16 @@ wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db.tar &
|
|||
|
||||
python3 tools/export_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./ch_lite/det_mv3_db/best_accuracy Global.save_inference_dir=./inference/det_db/
|
||||
```
|
||||
转inference模型时,使用的配置文件和训练时使用的配置文件相同。另外,还需要设置配置文件中的Global.checkpoints、Global.save_inference_dir参数。
|
||||
其中Global.checkpoints指向训练中保存的模型参数文件,Global.save_inference_dir是生成的inference模型要保存的目录。
|
||||
转换成功后,在save_inference_dir 目录下有两个文件:
|
||||
转inference模型时,使用的配置文件和训练时使用的配置文件相同。另外,还需要设置配置文件中的`Global.checkpoints`、`Global.save_inference_dir`参数。
|
||||
其中`Global.checkpoints`指向训练中保存的模型参数文件,`Global.save_inference_dir`是生成的inference模型要保存的目录。
|
||||
转换成功后,在`save_inference_dir`目录下有两个文件:
|
||||
```
|
||||
inference/det_db/
|
||||
└─ model 检测inference模型的program文件
|
||||
└─ params 检测inference模型的参数文件
|
||||
```
|
||||
|
||||
<a name="识别模型转inference模型"></a>
|
||||
### 识别模型转inference模型
|
||||
|
||||
下载超轻量中文识别模型:
|
||||
|
@ -51,7 +75,7 @@ python3 tools/export_model.py -c configs/rec/rec_chinese_lite_train.yml -o Globa
|
|||
Global.save_inference_dir=./inference/rec_crnn/
|
||||
```
|
||||
|
||||
如果您是在自己的数据集上训练的模型,并且调整了中文字符的字典文件,请注意修改配置文件中的character_dict_path是否是所需要的字典文件。
|
||||
**注意:**如果您是在自己的数据集上训练的模型,并且调整了中文字符的字典文件,请注意修改配置文件中的`character_dict_path`是否是所需要的字典文件。
|
||||
|
||||
转换成功后,在目录下有两个文件:
|
||||
```
|
||||
|
@ -60,11 +84,13 @@ python3 tools/export_model.py -c configs/rec/rec_chinese_lite_train.yml -o Globa
|
|||
└─ params 识别inference模型的参数文件
|
||||
```
|
||||
|
||||
<a name="文本检测模型推理"></a>
|
||||
## 二、文本检测模型推理
|
||||
|
||||
下面将介绍超轻量中文检测模型推理、DB文本检测模型推理和EAST文本检测模型推理。默认配置是根据DB文本检测模型推理设置的。由于EAST和DB算法差别很大,在推理时,需要通过传入相应的参数适配EAST文本检测算法。
|
||||
文本检测模型推理,默认使用DB模型的配置参数。当不使用DB模型时,在推理时,需要通过传入相应的参数进行算法适配,细节参考下文。
|
||||
|
||||
### 1.超轻量中文检测模型推理
|
||||
<a name="超轻量中文检测模型推理"></a>
|
||||
### 1. 超轻量中文检测模型推理
|
||||
|
||||
超轻量中文检测模型推理,可以执行如下命令:
|
||||
|
||||
|
@ -72,11 +98,11 @@ python3 tools/export_model.py -c configs/rec/rec_chinese_lite_train.yml -o Globa
|
|||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/"
|
||||
```
|
||||
|
||||
可视化文本检测结果默认保存到 ./inference_results 文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
|
||||
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
|
||||
|
||||
![](../imgs_results/det_res_2.jpg)
|
||||
|
||||
通过设置参数det_max_side_len的大小,改变检测算法中图片规范化的最大值。当图片的长宽都小于det_max_side_len,则使用原图预测,否则将图片等比例缩放到最大值,进行预测。该参数默认设置为det_max_side_len=960. 如果输入图片的分辨率比较大,而且想使用更大的分辨率预测,可以执行如下命令:
|
||||
通过设置参数`det_max_side_len`的大小,改变检测算法中图片规范化的最大值。当图片的长宽都小于`det_max_side_len`,则使用原图预测,否则将图片等比例缩放到最大值,进行预测。该参数默认设置为`det_max_side_len=960`。 如果输入图片的分辨率比较大,而且想使用更大的分辨率预测,可以执行如下命令:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_max_side_len=1200
|
||||
|
@ -87,7 +113,8 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_di
|
|||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
```
|
||||
|
||||
### 2.DB文本检测模型推理
|
||||
<a name="DB文本检测模型推理"></a>
|
||||
### 2. DB文本检测模型推理
|
||||
|
||||
首先将DB文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)),可以使用如下命令进行转换:
|
||||
|
||||
|
@ -105,13 +132,14 @@ DB文本检测模型推理,可以执行如下命令:
|
|||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_db/"
|
||||
```
|
||||
|
||||
可视化文本检测结果默认保存到 ./inference_results 文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
|
||||
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
|
||||
|
||||
![](../imgs_results/det_res_img_10_db.jpg)
|
||||
|
||||
**注意**:由于ICDAR2015数据集只有1000张训练图像,主要针对英文场景,所以上述模型对中文文本图像检测效果非常差。
|
||||
**注意**:由于ICDAR2015数据集只有1000张训练图像,且主要针对英文场景,所以上述模型对中文文本图像检测效果会比较差。
|
||||
|
||||
### 3.EAST文本检测模型推理
|
||||
<a name="EAST文本检测模型推理"></a>
|
||||
### 3. EAST文本检测模型推理
|
||||
|
||||
首先将EAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/det_r50_vd_east.tar)),可以使用如下命令进行转换:
|
||||
|
||||
|
@ -123,24 +151,59 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_
|
|||
python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.checkpoints="./models/det_r50_vd_east/best_accuracy" Global.save_inference_dir="./inference/det_east"
|
||||
```
|
||||
|
||||
EAST文本检测模型推理,需要设置参数det_algorithm,指定检测算法类型为EAST,可以执行如下命令:
|
||||
**EAST文本检测模型推理,需要设置参数`--det_algorithm="EAST"`**,可以执行如下命令:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/"
|
||||
```
|
||||
可视化文本检测结果默认保存到 ./inference_results 文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
|
||||
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
|
||||
|
||||
![](../imgs_results/det_res_img_10_east.jpg)
|
||||
|
||||
**注意**:本代码库中EAST后处理中NMS采用的Python版本,所以预测速度比较耗时。如果采用C++版本,会有明显加速。
|
||||
**注意**:本代码库中,EAST后处理Locality-Aware NMS有python和c++两种版本,c++版速度明显快于python版。由于c++版本nms编译版本问题,只有python3.5环境下会调用c++版nms,其他情况将调用python版nms。
|
||||
|
||||
|
||||
<a name="SAST文本检测模型推理"></a>
|
||||
### 4. SAST文本检测模型推理
|
||||
#### (1). 四边形文本检测模型(ICDAR2015)
|
||||
首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_icdar2015.tar)),可以使用如下命令进行转换:
|
||||
```
|
||||
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.checkpoints="./models/sast_r50_vd_icdar2015/best_accuracy" Global.save_inference_dir="./inference/det_sast_ic15"
|
||||
```
|
||||
**SAST文本检测模型推理,需要设置参数`--det_algorithm="SAST"`**,可以执行如下命令:
|
||||
```
|
||||
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_sast_ic15/"
|
||||
```
|
||||
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
|
||||
|
||||
![](../imgs_results/det_res_img_10_sast.jpg)
|
||||
|
||||
#### (2). 弯曲文本检测模型(Total-Text)
|
||||
首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_total_text.tar)),可以使用如下命令进行转换:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.checkpoints="./models/sast_r50_vd_total_text/best_accuracy" Global.save_inference_dir="./inference/det_sast_tt"
|
||||
```
|
||||
|
||||
**SAST文本检测模型推理,需要设置参数`--det_algorithm="SAST"`,同时,还需要增加参数`--det_sast_polygon=True`,**可以执行如下命令:
|
||||
```
|
||||
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_sast_tt/" --det_sast_polygon=True
|
||||
```
|
||||
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
|
||||
|
||||
![](../imgs_results/det_res_img623_sast.jpg)
|
||||
|
||||
**注意**:本代码库中,SAST后处理Locality-Aware NMS有python和c++两种版本,c++版速度明显快于python版。由于c++版本nms编译版本问题,只有python3.5环境下会调用c++版nms,其他情况将调用python版nms。
|
||||
|
||||
|
||||
<a name="文本识别模型推理"></a>
|
||||
## 三、文本识别模型推理
|
||||
|
||||
下面将介绍超轻量中文识别模型推理、基于CTC损失的识别模型推理和基于Attention损失的识别模型推理。对于中文文本识别,建议优先选择基于CTC损失的识别模型,实践中也发现基于Attention损失的效果不如基于CTC损失的识别模型。此外,如果训练时修改了文本的字典,请参考下面的自定义文本识别字典的推理。
|
||||
|
||||
|
||||
### 1.超轻量中文识别模型推理
|
||||
<a name="超轻量中文识别模型推理"></a>
|
||||
### 1. 超轻量中文识别模型推理
|
||||
|
||||
超轻量中文识别模型推理,可以执行如下命令:
|
||||
|
||||
|
@ -155,7 +218,8 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg"
|
|||
Predicts of ./doc/imgs_words/ch/word_4.jpg:['实力活力', 0.89552695]
|
||||
|
||||
|
||||
### 2.基于CTC损失的识别模型推理
|
||||
<a name="基于CTC损失的识别模型推理"></a>
|
||||
### 2. 基于CTC损失的识别模型推理
|
||||
|
||||
我们以STAR-Net为例,介绍基于CTC损失的识别模型推理。 CRNN和Rosetta使用方式类似,不用设置识别算法参数rec_algorithm。
|
||||
|
||||
|
@ -176,7 +240,8 @@ STAR-Net文本识别模型推理,可以执行如下命令:
|
|||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
|
||||
```
|
||||
|
||||
### 3.基于Attention损失的识别模型推理
|
||||
<a name="基于Attention损失的识别模型推理"></a>
|
||||
### 3. 基于Attention损失的识别模型推理
|
||||
|
||||
基于Attention损失的识别模型与ctc不同,需要额外设置识别算法参数 --rec_algorithm="RARE"
|
||||
|
||||
|
@ -202,16 +267,18 @@ self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
|||
dict_character = list(self.character_str)
|
||||
```
|
||||
|
||||
### 4.自定义文本识别字典的推理
|
||||
<a name="自定义文本识别字典的推理"></a>
|
||||
### 4. 自定义文本识别字典的推理
|
||||
如果训练时修改了文本的字典,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_char_dict_path="your text dict path"
|
||||
```
|
||||
|
||||
<a name="文本检测、识别串联推理"></a>
|
||||
## 四、文本检测、识别串联推理
|
||||
|
||||
### 1.超轻量中文OCR模型推理
|
||||
<a name="超轻量中文OCR模型推理"></a>
|
||||
### 1. 超轻量中文OCR模型推理
|
||||
|
||||
在执行预测时,需要通过参数image_dir指定单张图像或者图像集合的路径、参数det_model_dir指定检测inference模型的路径和参数rec_model_dir指定识别inference模型的路径。可视化识别结果默认保存到 ./inference_results 文件夹里面。
|
||||
|
||||
|
@ -223,9 +290,14 @@ python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model
|
|||
|
||||
![](../imgs_results/2.jpg)
|
||||
|
||||
### 2.其他模型推理
|
||||
<a name="其他模型推理"></a>
|
||||
### 2. 其他模型推理
|
||||
|
||||
如果想尝试使用其他检测算法或者识别算法,请参考上述文本检测模型推理和文本识别模型推理,更新相应配置和模型,下面给出基于EAST文本检测和STAR-Net文本识别执行命令:
|
||||
如果想尝试使用其他检测算法或者识别算法,请参考上述文本检测模型推理和文本识别模型推理,更新相应配置和模型。
|
||||
|
||||
**注意:由于检测框矫正逻辑的局限性,暂不支持使用SAST弯曲文本检测模型(即,使用参数`--det_sast_polygon=True`时)进行模型串联。**
|
||||
|
||||
下面给出基于EAST文本检测和STAR-Net文本识别执行命令:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
|
||||
|
|
|
@ -11,7 +11,7 @@ PaddleOCR 工作环境
|
|||
|
||||
*如您希望使用 mac 或 windows直接运行预测代码,可以从第2步开始执行。*
|
||||
|
||||
1. (建议)准备docker环境。第一次使用这个镜像,会自动下载该镜像,请耐心等待。
|
||||
**1. (建议)准备docker环境。第一次使用这个镜像,会自动下载该镜像,请耐心等待。**
|
||||
```
|
||||
# 切换到工作目录下
|
||||
cd /home/Projects
|
||||
|
@ -21,10 +21,10 @@ cd /home/Projects
|
|||
如果您希望在CPU环境下使用docker,使用docker而不是nvidia-docker创建docker
|
||||
sudo docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev /bin/bash
|
||||
|
||||
如果您的机器安装的是CUDA9,请运行以下命令创建容器
|
||||
如果使用CUDA9,请运行以下命令创建容器
|
||||
sudo nvidia-docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev /bin/bash
|
||||
|
||||
如果您的机器安装的是CUDA10,请运行以下命令创建容器
|
||||
如果使用CUDA10,请运行以下命令创建容器
|
||||
sudo nvidia-docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda10.0-cudnn7-dev /bin/bash
|
||||
|
||||
您也可以访问[DockerHub](https://hub.docker.com/r/paddlepaddle/paddle/tags/)获取与您机器适配的镜像。
|
||||
|
@ -47,19 +47,7 @@ docker images
|
|||
hub.baidubce.com/paddlepaddle/paddle latest-gpu-cuda9.0-cudnn7-dev f56310dcc829
|
||||
```
|
||||
|
||||
2. 更改python3默认版本
|
||||
|
||||
docker中的python默认使用python3.5,PaddleOCR需要在Python3.7下执行(该版本下,对于第三方依赖库的兼容性更好一些)。进入docker后,可以编辑`/etc/profile`文件,之后在文件末尾添加
|
||||
|
||||
```shell
|
||||
|
||||
alias python3=python3.7
|
||||
alias pip3=pip3.7
|
||||
```
|
||||
|
||||
保存之后,使用`source /etc/profile`命令使设置的默认Python生效。
|
||||
|
||||
3. 安装PaddlePaddle Fluid v1.7
|
||||
**2. 安装PaddlePaddle Fluid v1.7**
|
||||
```
|
||||
pip3 install --upgrade pip
|
||||
|
||||
|
@ -76,7 +64,7 @@ python3 -m pip install paddlepaddle==1.7.2 -i https://pypi.tuna.tsinghua.edu.cn/
|
|||
更多的版本需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
|
||||
```
|
||||
|
||||
4. 克隆PaddleOCR repo代码
|
||||
**3. 克隆PaddleOCR repo代码**
|
||||
```
|
||||
【推荐】git clone https://github.com/PaddlePaddle/PaddleOCR
|
||||
|
||||
|
@ -87,7 +75,7 @@ git clone https://gitee.com/paddlepaddle/PaddleOCR
|
|||
注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。
|
||||
```
|
||||
|
||||
5. 安装第三方库
|
||||
**4. 安装第三方库**
|
||||
```
|
||||
cd PaddleOCR
|
||||
pip3 install -r requirments.txt
|
||||
|
|
|
@ -18,6 +18,8 @@ ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset
|
|||
|
||||
若您本地没有数据集,可以在官网下载 [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) 数据,用于快速验证。也可以参考[DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here),下载 benchmark 所需的lmdb格式数据集。
|
||||
|
||||
如果希望复现SRN的论文指标,需要下载离线[增广数据](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA),提取码: y3ry。增广数据是由MJSynth和SynthText做旋转和扰动得到的。数据下载完成后请解压到 {your_path}/PaddleOCR/train_data/data_lmdb_release/training/ 路径下。
|
||||
|
||||
* 使用自己数据集:
|
||||
|
||||
若您希望使用自己的数据进行训练,请参考下文组织您的数据。
|
||||
|
@ -161,6 +163,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
|
|||
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
|
||||
| rec_r34_vd_tps_bilstm_attn.yml | RARE | Resnet34_vd | tps | BiLSTM | attention |
|
||||
| rec_r34_vd_tps_bilstm_ctc.yml | STARNet | Resnet34_vd | tps | BiLSTM | ctc |
|
||||
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
||||
|
||||
训练中文数据,推荐使用`rec_chinese_lite_train.yml`,如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# 更新
|
||||
- 2020.8.16 开源文本检测算法[SAST](https://arxiv.org/abs/1908.05498)和文本识别算法[SRN](https://arxiv.org/abs/2003.12294)
|
||||
- 2020.7.23 发布7月21日B站直播课回放和PPT,PaddleOCR开源大礼包全面解读,[获取地址](https://aistudio.baidu.com/aistudio/course/introduce/1519)
|
||||
- 2020.7.15 添加基于EasyEdge和Paddle-Lite的移动端DEMO,支持iOS和Android系统
|
||||
- 2020.7.15 完善预测部署,添加基于C++预测引擎推理、服务化部署和端侧部署方案,以及超轻量级中文OCR模型预测耗时Benchmark
|
||||
|
|
|
@ -48,6 +48,9 @@ At present, the open source model, dataset and magnitude are as follows:
|
|||
Error message: Input(X) dims[3] and Input(Grid) dims[2] should be equal, but received X dimension[3](108) != Grid dimension[2](100)
|
||||
Solution:TPS does not support variable shape. Please set --rec_image_shape='3,32,100' and --rec_char_type='en'
|
||||
|
||||
11. **Custom dictionary used during training, the recognition results show that words do not appear in the dictionary**
|
||||
11. **Custom dictionary used during training, the recognition results show that words do not appear in the dictionary**
|
||||
The used custom dictionary path is not set when making prediction. The solution is setting parameter `rec_char_dict_path` to the corresponding dictionary file.
|
||||
|
||||
The used custom dictionary path is not set when making prediction. The solution is setting parameter `rec_char_dict_path` to the corresponding dictionary file.
|
||||
|
||||
12. **Results of cpp_infer and python_inference are very different**
|
||||
Versions of exprted inference model and inference libraray should be same. For example, on Windows platform, version of the inference libraray that PaddlePaddle provides is 1.8, but version of the inference model that PaddleOCR provides is 1.7, you should export model yourself(`tools/export_model.py`) on PaddlePaddle1.8 and then use the exported model for inference.
|
||||
|
|
|
@ -60,6 +60,9 @@ Take `rec_icdar15_train.yml` as an example:
|
|||
| beta1 | Set the exponential decay rate for the 1st moment estimates | 0.9 | \ |
|
||||
| beta2 | Set the exponential decay rate for the 2nd moment estimates | 0.999 | \ |
|
||||
| decay | Whether to use decay | \ | \ |
|
||||
| function(decay) | Set the decay function | cosine_decay | Only support cosine_decay |
|
||||
| step_each_epoch | The number of steps in an epoch. | 20 | Calculation :total_image_num / (batch_size_per_card * card_size) |
|
||||
| total_epoch | The number of epochs | 1000 | Consistent with Global.epoch_num |
|
||||
| function(decay) | Set the decay function | cosine_decay | Support cosine_decay, cosine_decay_warmup and piecewise_decay |
|
||||
| step_each_epoch | The number of steps in an epoch. Used in cosine_decay/cosine_decay_warmup | 20 | Calculation: total_image_num / (batch_size_per_card * card_size) |
|
||||
| total_epoch | The number of epochs. Used in cosine_decay/cosine_decay_warmup | 1000 | Consistent with Global.epoch_num |
|
||||
| warmup_minibatch | Number of steps for linear warmup. Used in cosine_decay_warmup | 1000 | \ |
|
||||
| boundaries | The step intervals to reduce learning rate. Used in piecewise_decay | - | The format is list |
|
||||
| decay_rate | Learning rate decay rate. Used in piecewise_decay | - | \ |
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
# TEXT DETECTION
|
||||
|
||||
This section uses the icdar15 dataset as an example to introduce the training, evaluation, and testing of the detection model in PaddleOCR.
|
||||
This section uses the icdar2015 dataset as an example to introduce the training, evaluation, and testing of the detection model in PaddleOCR.
|
||||
|
||||
## DATA PREPARATION
|
||||
The icdar2015 dataset can be obtained from [official website](https://rrc.cvc.uab.es/?ch=4&com=downloads). Registration is required for downloading.
|
||||
|
||||
Decompress the downloaded dataset to the working directory, assuming it is decompressed under PaddleOCR/train_data/. In addition, PaddleOCR organizes many scattered annotation files into two separate annotation files for train and test respectively, which can be downloaded by wget:
|
||||
```
|
||||
```shell
|
||||
# Under the PaddleOCR path
|
||||
cd PaddleOCR/
|
||||
wget -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/train_icdar2015_label.txt
|
||||
|
@ -25,18 +25,21 @@ After decompressing the data set and downloading the annotation file, PaddleOCR/
|
|||
The provided annotation file format is as follow, seperated by "\t":
|
||||
```
|
||||
" Image file name Image annotation information encoded by json.dumps"
|
||||
ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}]
|
||||
ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]]}, {...}]
|
||||
```
|
||||
The image annotation after json.dumps() encoding is a list containing multiple dictionaries. The `points` in the dictionary represent the coordinates (x, y) of the four points of the text box, arranged clockwise from the point at the upper left corner.
|
||||
The image annotation after **json.dumps()** encoding is a list containing multiple dictionaries.
|
||||
|
||||
`transcription` represents the text of the current text box, and this information is not needed in the text detection task.
|
||||
If you want to train PaddleOCR on other datasets, you can build the annotation file according to the above format.
|
||||
The `points` in the dictionary represent the coordinates (x, y) of the four points of the text box, arranged clockwise from the point at the upper left corner.
|
||||
|
||||
`transcription` represents the text of the current text box. **When its content is "###" it means that the text box is invalid and will be skipped during training.**
|
||||
|
||||
If you want to train PaddleOCR on other datasets, please build the annotation file according to the above format.
|
||||
|
||||
|
||||
## TRAINING
|
||||
|
||||
First download the pretrained model. The detection model of PaddleOCR currently supports two backbones, namely MobileNetV3 and ResNet50_vd. You can use the model in [PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures) to replace backbone according to your needs.
|
||||
```
|
||||
```shell
|
||||
cd PaddleOCR/
|
||||
# Download the pre-trained model of MobileNetV3
|
||||
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x0_5_pretrained.tar
|
||||
|
@ -44,7 +47,7 @@ wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/Mob
|
|||
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar
|
||||
|
||||
# decompressing the pre-training model file, take MobileNetV3 as an example
|
||||
tar xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_models/
|
||||
tar -xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_models/
|
||||
|
||||
# Note: After decompressing the backbone pre-training weight file correctly, the file list in the folder is as follows:
|
||||
./pretrain_models/MobileNetV3_large_x0_5_pretrained/
|
||||
|
@ -56,9 +59,9 @@ tar xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_models
|
|||
|
||||
```
|
||||
|
||||
**START TRAINING**
|
||||
#### START TRAINING
|
||||
*If CPU version installed, please set the parameter `use_gpu` to `false` in the configuration.*
|
||||
```
|
||||
```shell
|
||||
python3 tools/train.py -c configs/det/det_mv3_db.yml
|
||||
```
|
||||
|
||||
|
@ -66,19 +69,19 @@ In the above instruction, use `-c` to select the training to use the `configs/de
|
|||
For a detailed explanation of the configuration file, please refer to [config](./config_en.md).
|
||||
|
||||
You can also use `-o` to change the training parameters without modifying the yml file. For example, adjust the training learning rate to 0.0001
|
||||
```
|
||||
```shell
|
||||
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
|
||||
```
|
||||
|
||||
**load trained model and conntinue training**
|
||||
#### load trained model and conntinue training
|
||||
If you expect to load trained model and continue the training again, you can specify the parameter `Global.checkpoints` as the model path to be loaded.
|
||||
|
||||
For example:
|
||||
```
|
||||
```shell
|
||||
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./your/trained/model
|
||||
```
|
||||
|
||||
**Note**:The priority of `Global.checkpoints` is higher than that of `Global.pretrain_weights`, that is, when two parameters are specified at the same time, the model specified by Global.checkpoints will be loaded first. If the model path specified by `Global.checkpoints` is wrong, the one specified by `Global.pretrain_weights` will be loaded.
|
||||
**Note**: The priority of `Global.checkpoints` is higher than that of `Global.pretrain_weights`, that is, when two parameters are specified at the same time, the model specified by `Global.checkpoints` will be loaded first. If the model path specified by `Global.checkpoints` is wrong, the one specified by `Global.pretrain_weights` will be loaded.
|
||||
|
||||
|
||||
## EVALUATION
|
||||
|
@ -89,7 +92,7 @@ Run the following code to calculate the evaluation indicators. The result will b
|
|||
|
||||
When evaluating, set post-processing parameters `box_thresh=0.6`, `unclip_ratio=1.5`. If you use different datasets, different models for training, these two parameters should be adjusted for better result.
|
||||
|
||||
```
|
||||
```shell
|
||||
python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
|
||||
```
|
||||
The model parameters during training are saved in the `Global.save_model_dir` directory by default. When evaluating indicators, you need to set `Global.checkpoints` to point to the saved parameter file.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
|
||||
# Reasoning based on Python prediction engine
|
||||
|
||||
The inference model (the model saved by fluid.io.save_inference_model) is generally a solidified model saved after the model training is completed, and is mostly used to give prediction in deployment.
|
||||
The inference model (the model saved by `fluid.io.save_inference_model`) is generally a solidified model saved after the model training is completed, and is mostly used to give prediction in deployment.
|
||||
|
||||
The model saved during the training process is the checkpoints model, which saves the parameters of the model and is mostly used to resume training.
|
||||
|
||||
|
@ -9,7 +9,31 @@ Compared with the checkpoints model, the inference model will additionally save
|
|||
|
||||
Next, we first introduce how to convert a trained model into an inference model, and then we will introduce text detection, text recognition, and the concatenation of them based on inference model.
|
||||
|
||||
- [CONVERT TRAINING MODEL TO INFERENCE MODEL](#CONVERT)
|
||||
- [Convert detection model to inference model](#Convert_detection_model)
|
||||
- [Convert recognition model to inference model](#Convert_recognition_model)
|
||||
|
||||
|
||||
- [TEXT DETECTION MODEL INFERENCE](#DETECTION_MODEL_INFERENCE)
|
||||
- [1. LIGHTWEIGHT CHINESE DETECTION MODEL INFERENCE](#LIGHTWEIGHT_DETECTION)
|
||||
- [2. DB TEXT DETECTION MODEL INFERENCE](#DB_DETECTION)
|
||||
- [3. EAST TEXT DETECTION MODEL INFERENCE](#EAST_DETECTION)
|
||||
- [4. SAST TEXT DETECTION MODEL INFERENCE](#SAST_DETECTION)
|
||||
|
||||
- [TEXT RECOGNITION MODEL INFERENCE](#RECOGNITION_MODEL_INFERENCE)
|
||||
- [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_RECOGNITION)
|
||||
- [2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE](#CTC-BASED_RECOGNITION)
|
||||
- [3. ATTENTION-BASED TEXT RECOGNITION MODEL INFERENCE](#ATTENTION-BASED_RECOGNITION)
|
||||
- [4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
|
||||
|
||||
|
||||
- [TEXT DETECTION AND RECOGNITION INFERENCE CONCATENATION](#CONCATENATION)
|
||||
- [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_CHINESE_MODEL)
|
||||
- [2. OTHER MODELS](#OTHER_MODELS)
|
||||
|
||||
<a name="CONVERT"></a>
|
||||
## CONVERT TRAINING MODEL TO INFERENCE MODEL
|
||||
<a name="Convert_detection_model"></a>
|
||||
### Convert detection model to inference model
|
||||
|
||||
Download the lightweight Chinese detection model:
|
||||
|
@ -35,6 +59,7 @@ inference/det_db/
|
|||
└─ params Check the parameter file of the inference model
|
||||
```
|
||||
|
||||
<a name="Convert_recognition_model"></a>
|
||||
### Convert recognition model to inference model
|
||||
|
||||
Download the lightweight Chinese recognition model:
|
||||
|
@ -62,11 +87,13 @@ After the conversion is successful, there are two files in the directory:
|
|||
└─ params Identify the parameter files of the inference model
|
||||
```
|
||||
|
||||
<a name="DETECTION_MODEL_INFERENCE"></a>
|
||||
## TEXT DETECTION MODEL INFERENCE
|
||||
|
||||
The following will introduce the lightweight Chinese detection model inference, DB text detection model inference and EAST text detection model inference. The default configuration is based on the inference setting of the DB text detection model.
|
||||
Because EAST and DB algorithms are very different, when inference, it is necessary to **adapt the EAST text detection algorithm by passing in corresponding parameters**.
|
||||
|
||||
<a name="LIGHTWEIGHT_DETECTION"></a>
|
||||
### 1. LIGHTWEIGHT CHINESE DETECTION MODEL INFERENCE
|
||||
|
||||
For lightweight Chinese detection model inference, you can execute the following commands:
|
||||
|
@ -90,6 +117,7 @@ If you want to use the CPU for prediction, execute the command as follows
|
|||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
```
|
||||
|
||||
<a name="DB_DETECTION"></a>
|
||||
### 2. DB TEXT DETECTION MODEL INFERENCE
|
||||
|
||||
First, convert the model saved in the DB text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)), you can use the following command to convert:
|
||||
|
@ -114,6 +142,7 @@ The visualized text detection results are saved to the `./inference_results` fol
|
|||
|
||||
**Note**: Since the ICDAR2015 dataset has only 1,000 training images, mainly for English scenes, the above model has very poor detection result on Chinese text images.
|
||||
|
||||
<a name="EAST_DETECTION"></a>
|
||||
### 3. EAST TEXT DETECTION MODEL INFERENCE
|
||||
|
||||
First, convert the model saved in the EAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/det_r50_vd_east.tar)), you can use the following command to convert:
|
||||
|
@ -126,23 +155,64 @@ First, convert the model saved in the EAST text detection training process into
|
|||
python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.checkpoints="./models/det_r50_vd_east/best_accuracy" Global.save_inference_dir="./inference/det_east"
|
||||
```
|
||||
|
||||
For EAST text detection model inference, you need to set the parameter det_algorithm, specify the detection algorithm type to EAST, run the following command:
|
||||
**For EAST text detection model inference, you need to set the parameter ``--det_algorithm="EAST"``**, run the following command:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST"
|
||||
```
|
||||
|
||||
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
|
||||
|
||||
![](../imgs_results/det_res_img_10_east.jpg)
|
||||
|
||||
**Note**: The Python version of NMS in EAST post-processing used in this codebase so the prediction speed is quite slow. If you use the C++ version, there will be a significant speedup.
|
||||
**Note**: EAST post-processing locality aware NMS has two versions: Python and C++. The speed of C++ version is obviously faster than that of Python version. Due to the compilation version problem of NMS of C++ version, C++ version NMS will be called only in Python 3.5 environment, and python version NMS will be called in other cases.
|
||||
|
||||
|
||||
<a name="SAST_DETECTION"></a>
|
||||
### 4. SAST TEXT DETECTION MODEL INFERENCE
|
||||
#### (1). Quadrangle text detection model (ICDAR2015)
|
||||
First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_icdar2015.tar)), you can use the following command to convert:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.checkpoints="./models/sast_r50_vd_icdar2015/best_accuracy" Global.save_inference_dir="./inference/det_sast_ic15"
|
||||
```
|
||||
|
||||
**For SAST quadrangle text detection model inference, you need to set the parameter `--det_algorithm="SAST"`**, run the following command:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_sast_ic15/"
|
||||
```
|
||||
|
||||
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
|
||||
|
||||
![](../imgs_results/det_res_img_10_sast.jpg)
|
||||
|
||||
#### (2). Curved text detection model (Total-Text)
|
||||
First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_total_text.tar)), you can use the following command to convert:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.checkpoints="./models/sast_r50_vd_total_text/best_accuracy" Global.save_inference_dir="./inference/det_sast_tt"
|
||||
```
|
||||
|
||||
**For SAST curved text detection model inference, you need to set the parameter `--det_algorithm="SAST"` and `--det_sast_polygon=True`**, run the following command:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_sast_tt/" --det_sast_polygon=True
|
||||
```
|
||||
|
||||
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
|
||||
|
||||
![](../imgs_results/det_res_img623_sast.jpg)
|
||||
|
||||
**Note**: SAST post-processing locality aware NMS has two versions: Python and C++. The speed of C++ version is obviously faster than that of Python version. Due to the compilation version problem of NMS of C++ version, C++ version NMS will be called only in Python 3.5 environment, and python version NMS will be called in other cases.
|
||||
|
||||
<a name="RECOGNITION_MODEL_INFERENCE"></a>
|
||||
## TEXT RECOGNITION MODEL INFERENCE
|
||||
|
||||
The following will introduce the lightweight Chinese recognition model inference, other CTC-based and Attention-based text recognition models inference. For Chinese text recognition, it is recommended to choose the recognition model based on CTC loss. In practice, it is also found that the result of the model based on Attention loss is not as good as the one based on CTC loss. In addition, if the characters dictionary is modified during training, make sure that you use the same characters set during inferencing. Please check below for details.
|
||||
|
||||
|
||||
<a name="LIGHTWEIGHT_RECOGNITION"></a>
|
||||
### 1. LIGHTWEIGHT CHINESE TEXT RECOGNITION MODEL REFERENCE
|
||||
|
||||
For lightweight Chinese recognition model inference, you can execute the following commands:
|
||||
|
@ -158,6 +228,7 @@ After executing the command, the prediction results (recognized text and score)
|
|||
Predicts of ./doc/imgs_words/ch/word_4.jpg:['实力活力', 0.89552695]
|
||||
|
||||
|
||||
<a name="CTC-BASED_RECOGNITION"></a>
|
||||
### 2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE
|
||||
|
||||
Taking STAR-Net as an example, we introduce the recognition model inference based on CTC loss. CRNN and Rosetta are used in a similar way, by setting the recognition algorithm parameter `rec_algorithm`.
|
||||
|
@ -178,6 +249,7 @@ For STAR-Net text recognition model inference, execute the following commands:
|
|||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
|
||||
```
|
||||
|
||||
<a name="ATTENTION-BASED_RECOGNITION"></a>
|
||||
### 3. ATTENTION-BASED TEXT RECOGNITION MODEL INFERENCE
|
||||
![](../imgs_words_en/word_336.png)
|
||||
|
||||
|
@ -196,6 +268,7 @@ self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
|||
dict_character = list(self.character_str)
|
||||
```
|
||||
|
||||
<a name="USING_CUSTOM_CHARACTERS"></a>
|
||||
### 4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY
|
||||
If the chars dictionary is modified during training, you need to specify the new dictionary path by setting the parameter `rec_char_dict_path` when using your inference model to predict.
|
||||
|
||||
|
@ -203,8 +276,10 @@ If the chars dictionary is modified during training, you need to specify the new
|
|||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_char_dict_path="your text dict path"
|
||||
```
|
||||
|
||||
<a name="CONCATENATION"></a>
|
||||
## TEXT DETECTION AND RECOGNITION INFERENCE CONCATENATION
|
||||
|
||||
<a name="LIGHTWEIGHT_CHINESE_MODEL"></a>
|
||||
### 1. LIGHTWEIGHT CHINESE MODEL
|
||||
|
||||
When performing prediction, you need to specify the path of a single image or a folder of images through the parameter `image_dir`, the parameter `det_model_dir` specifies the path to detect the inference model, and the parameter `rec_model_dir` specifies the path to identify the inference model. The visualized recognition results are saved to the `./inference_results` folder by default.
|
||||
|
@ -217,9 +292,14 @@ After executing the command, the recognition result image is as follows:
|
|||
|
||||
![](../imgs_results/2.jpg)
|
||||
|
||||
<a name="OTHER_MODELS"></a>
|
||||
### 2. OTHER MODELS
|
||||
|
||||
If you want to try other detection algorithms or recognition algorithms, please refer to the above text detection model inference and text recognition model inference, update the corresponding configuration and model, the following command uses the combination of the EAST text detection and STAR-Net text recognition:
|
||||
If you want to try other detection algorithms or recognition algorithms, please refer to the above text detection model inference and text recognition model inference, update the corresponding configuration and model.
|
||||
|
||||
**Note: due to the limitation of rotation logic of detected box, SAST curved text detection model (using the parameter `det_sast_polygon=True`) is not supported for model combination yet.**
|
||||
|
||||
The following command uses the combination of the EAST text detection and STAR-Net text recognition:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
|
||||
|
|
|
@ -11,21 +11,21 @@ It is recommended to use the docker provided by us to run PaddleOCR, please refe
|
|||
|
||||
*If you want to directly run the prediction code on mac or windows, you can start from step 2.*
|
||||
|
||||
1. (Recommended) Prepare a docker environment. The first time you use this image, it will be downloaded automatically. Please be patient.
|
||||
**1. (Recommended) Prepare a docker environment. The first time you use this image, it will be downloaded automatically. Please be patient.**
|
||||
```
|
||||
# Switch to the working directory
|
||||
cd /home/Projects
|
||||
# You need to create a docker container for the first run, and do not need to run the current command when you run it again
|
||||
# Create a docker container named ppocr and map the current directory to the /paddle directory of the container
|
||||
|
||||
#If you want to use docker in a CPU environment, use docker instead of nvidia-docker to create docker
|
||||
#If using CPU, use docker instead of nvidia-docker to create docker
|
||||
sudo docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev /bin/bash
|
||||
```
|
||||
If you have cuda9 installed on your machine, please run the following command to create a container:
|
||||
If using CUDA9, please run the following command to create a container:
|
||||
```
|
||||
sudo nvidia-docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev /bin/bash
|
||||
```
|
||||
If you have cuda10 installed on your machine, please run the following command to create a container:
|
||||
If using CUDA10, please run the following command to create a container:
|
||||
```
|
||||
sudo nvidia-docker run --name ppocr -v $PWD:/paddle --network=host -it hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda10.0-cudnn7-dev /bin/bash
|
||||
```
|
||||
|
@ -49,20 +49,7 @@ docker images
|
|||
hub.baidubce.com/paddlepaddle/paddle latest-gpu-cuda9.0-cudnn7-dev f56310dcc829
|
||||
```
|
||||
|
||||
2. Change default version of python3
|
||||
|
||||
Python3.5 is used as the default version of python. However, Python3.7 is preferred in PaddleOCR for better compatibility of third-party libraries. After entering docker, you can edit file `/etc/profile`, add the following content at the end of the file.
|
||||
|
||||
|
||||
```shell
|
||||
|
||||
alias python3=python3.7
|
||||
alias pip3=pip3.7
|
||||
```
|
||||
|
||||
After saving the file `/etc/profile`. The command `source /etc/profile` needs to be carried out to make the default python version as 3.7 effective.
|
||||
|
||||
3. Install PaddlePaddle Fluid v1.7 (the higher version is not supported yet, the adaptation work is in progress)
|
||||
**2. Install PaddlePaddle Fluid v1.7 (the higher version is not supported yet, the adaptation work is in progress)**
|
||||
```
|
||||
pip3 install --upgrade pip
|
||||
|
||||
|
@ -78,7 +65,7 @@ python3 -m pip install paddlepaddle==1.7.2 -i https://pypi.tuna.tsinghua.edu.cn/
|
|||
For more software version requirements, please refer to the instructions in [Installation Document](https://www.paddlepaddle.org.cn/install/quick) for operation.
|
||||
|
||||
|
||||
4. Clone PaddleOCR repo
|
||||
**3. Clone PaddleOCR repo**
|
||||
```
|
||||
# Recommend
|
||||
git clone https://github.com/PaddlePaddle/PaddleOCR
|
||||
|
@ -90,7 +77,7 @@ git clone https://gitee.com/paddlepaddle/PaddleOCR
|
|||
# Note: The cloud-hosting code may not be able to synchronize the update with this GitHub project in real time. There might be a delay of 3-5 days. Please give priority to the recommended method.
|
||||
```
|
||||
|
||||
5. Install third-party libraries
|
||||
**4. Install third-party libraries**
|
||||
```
|
||||
cd PaddleOCR
|
||||
pip3 install -r requirments.txt
|
||||
|
|
|
@ -18,6 +18,8 @@ ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset
|
|||
|
||||
If you do not have a dataset locally, you can download it on the official website [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads). Also refer to [DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here),download the lmdb format dataset required for benchmark
|
||||
|
||||
If you want to reproduce the paper indicators of SRN, you need to download offline [augmented data](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA), extraction code: y3ry. The augmented data is obtained by rotation and perturbation of mjsynth and synthtext. Please unzip the data to {your_path}/PaddleOCR/train_data/data_lmdb_Release/training/path.
|
||||
|
||||
* Use your own dataset:
|
||||
|
||||
If you want to use your own data for training, please refer to the following to organize your data.
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# RECENT UPDATES
|
||||
- 2020.8.16 Release text detection algorithm [SAST](https://arxiv.org/abs/1908.05498) and text recognition algorithm [SRN](https://arxiv.org/abs/2003.12294)
|
||||
- 2020.7.23, Release the playback and PPT of live class on BiliBili station, PaddleOCR Introduction, [address](https://aistudio.baidu.com/aistudio/course/introduce/1519)
|
||||
- 2020.7.15, Add mobile App demo , support both iOS and Android ( based on easyedge and Paddle Lite)
|
||||
- 2020.7.15, Improve the deployment ability, add the C + + inference , serving deployment. In addtion, the benchmarks of the ultra-lightweight Chinese OCR model are provided.
|
||||
|
|
Binary file not shown.
After Width: | Height: | Size: 248 KiB |
Binary file not shown.
After Width: | Height: | Size: 126 KiB |
Binary file not shown.
After Width: | Height: | Size: 333 KiB |
|
@ -0,0 +1,28 @@
|
|||
# Version: 1.0.0
|
||||
FROM hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev
|
||||
|
||||
# PaddleOCR base on Python3.7
|
||||
RUN pip3.7 install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
RUN python3.7 -m pip install paddlepaddle==1.7.2 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
RUN pip3.7 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
RUN git clone https://gitee.com/PaddlePaddle/PaddleOCR
|
||||
|
||||
WORKDIR /PaddleOCR
|
||||
|
||||
RUN pip3.7 install -r requirments.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
RUN mkdir -p /PaddleOCR/inference
|
||||
# Download orc detect model(light version). if you want to change normal version, you can change ch_det_mv3_db_infer to ch_det_r50_vd_db_infer, also remember change det_model_dir in deploy/hubserving/ocr_system/params.py)
|
||||
ADD https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar /PaddleOCR/inference
|
||||
RUN tar xf /PaddleOCR/inference/ch_det_mv3_db_infer.tar -C /PaddleOCR/inference
|
||||
|
||||
# Download orc recognition model(light version). If you want to change normal version, you can change ch_rec_mv3_crnn_infer to ch_rec_r34_vd_crnn_enhance_infer, also remember change rec_model_dir in deploy/hubserving/ocr_system/params.py)
|
||||
ADD https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_infer.tar /PaddleOCR/inference
|
||||
RUN tar xf /PaddleOCR/inference/ch_rec_mv3_crnn_infer.tar -C /PaddleOCR/inference
|
||||
|
||||
EXPOSE 8866
|
||||
|
||||
CMD ["/bin/bash","-c","export PYTHONPATH=. && hub install deploy/hubserving/ocr_system/ && hub serving start -m ocr_system"]
|
|
@ -0,0 +1,28 @@
|
|||
# Version: 1.0.0
|
||||
FROM hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda10.0-cudnn7-dev
|
||||
|
||||
# PaddleOCR base on Python3.7
|
||||
RUN pip3.7 install --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
RUN python3.7 -m pip install paddlepaddle-gpu==1.7.2.post107 -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
RUN pip3.7 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
RUN git clone https://gitee.com/PaddlePaddle/PaddleOCR
|
||||
|
||||
WORKDIR /home/PaddleOCR
|
||||
|
||||
RUN pip3.7 install -r requirments.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
RUN mkdir -p /PaddleOCR/inference
|
||||
# Download orc detect model(light version). if you want to change normal version, you can change ch_det_mv3_db_infer to ch_det_r50_vd_db_infer, also remember change det_model_dir in deploy/hubserving/ocr_system/params.py)
|
||||
ADD https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar /PaddleOCR/inference
|
||||
RUN tar xf /PaddleOCR/inference/ch_det_mv3_db_infer.tar -C /PaddleOCR/inference
|
||||
|
||||
# Download orc recognition model(light version). If you want to change normal version, you can change ch_rec_mv3_crnn_infer to ch_rec_r34_vd_crnn_enhance_infer, also remember change rec_model_dir in deploy/hubserving/ocr_system/params.py)
|
||||
ADD https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_infer.tar /PaddleOCR/inference
|
||||
RUN tar xf /PaddleOCR/inference/ch_rec_mv3_crnn_infer.tar -C /PaddleOCR/inference
|
||||
|
||||
EXPOSE 8866
|
||||
|
||||
CMD ["/bin/bash","-c","export PYTHONPATH=. && hub install deploy/hubserving/ocr_system/ && hub serving start -m ocr_system"]
|
|
@ -0,0 +1,55 @@
|
|||
# Docker化部署服务
|
||||
在日常项目应用中,相信大家一般都会希望能通过Docker技术,把PaddleOCR服务打包成一个镜像,以便在Docker或k8s环境里,快速发布上线使用。
|
||||
|
||||
本文将提供一些标准化的代码来实现这样的目标。大家通过如下步骤可以把PaddleOCR项目快速发布成可调用的Restful API服务。(目前暂时先实现了基于HubServing模式的部署,后续作者计划增加PaddleServing模式的部署)
|
||||
|
||||
## 1.实施前提准备
|
||||
|
||||
需要先完成如下基本组件的安装:
|
||||
a. Docker环境
|
||||
b. 显卡驱动和CUDA 10.0+(GPU)
|
||||
c. NVIDIA Container Toolkit(GPU,Docker 19.03以上版本可以跳过此步)
|
||||
d. cuDNN 7.6+(GPU)
|
||||
|
||||
## 2.制作镜像
|
||||
a.下载PaddleOCR项目代码
|
||||
```
|
||||
git clone https://github.com/PaddlePaddle/PaddleOCR.git
|
||||
```
|
||||
b.切换至Dockerfile目录(注:需要区分cpu或gpu版本,下文以cpu为例,gpu版本需要替换一下关键字即可)
|
||||
```
|
||||
cd docker/cpu
|
||||
```
|
||||
c.生成镜像
|
||||
```
|
||||
docker build -t paddleocr:cpu .
|
||||
```
|
||||
|
||||
## 3.启动Docker容器
|
||||
a. CPU 版本
|
||||
```
|
||||
sudo docker run -dp 8866:8866 --name paddle_ocr paddleocr:cpu
|
||||
```
|
||||
b. GPU 版本 (通过NVIDIA Container Toolkit)
|
||||
```
|
||||
sudo nvidia-docker run -dp 8866:8866 --name paddle_ocr paddleocr:gpu
|
||||
```
|
||||
c. GPU 版本 (Docker 19.03以上版本,可以直接用如下命令)
|
||||
```
|
||||
sudo docker run -dp 8866:8866 --gpus all --name paddle_ocr paddleocr:gpu
|
||||
```
|
||||
d. 检查服务运行情况(出现:Successfully installed ocr_system和Running on http://0.0.0.0:8866/等信息,表示运行成功)
|
||||
```
|
||||
docker logs -f paddle_ocr
|
||||
```
|
||||
|
||||
## 4.测试服务
|
||||
a. 计算待识别图片的Base64编码(如果只是测试一下效果,可以通过免费的在线工具实现,如:http://tool.chinaz.com/tools/imgtobase/)
|
||||
b. 发送服务请求(可参见sample_request.txt中的值)
|
||||
```
|
||||
curl -H "Content-Type:application/json" -X POST --data "{\"images\": [\"填入图片Base64编码(需要删除'data:image/jpg;base64,')\"]}" http://localhost:8866/predict/ocr_system
|
||||
```
|
||||
c. 返回结果(如果调用成功,会返回如下结果)
|
||||
```
|
||||
{"msg":"","results":[[{"confidence":0.8403433561325073,"text":"约定","text_region":[[345,377],[641,390],[634,540],[339,528]]},{"confidence":0.8131805658340454,"text":"最终相遇","text_region":[[356,532],[624,530],[624,596],[356,598]]}]],"status":"0"}
|
||||
```
|
File diff suppressed because one or more lines are too long
|
@ -31,22 +31,27 @@ class TrainReader(object):
|
|||
def __init__(self, params):
|
||||
self.num_workers = params['num_workers']
|
||||
self.label_file_path = params['label_file_path']
|
||||
print(self.label_file_path)
|
||||
self.use_mul_data = False
|
||||
if isinstance(self.label_file_path, list):
|
||||
self.use_mul_data = True
|
||||
self.data_ratio_list = params['data_ratio_list']
|
||||
self.batch_size = params['train_batch_size_per_card']
|
||||
assert 'process_function' in params,\
|
||||
"absence process_function in Reader"
|
||||
self.process = create_module(params['process_function'])(params)
|
||||
|
||||
def __call__(self, process_id):
|
||||
with open(self.label_file_path, "rb") as fin:
|
||||
label_infor_list = fin.readlines()
|
||||
img_num = len(label_infor_list)
|
||||
img_id_list = list(range(img_num))
|
||||
if sys.platform == "win32" and self.num_workers != 1:
|
||||
print("multiprocess is not fully compatible with Windows."
|
||||
"num_workers will be 1.")
|
||||
self.num_workers = 1
|
||||
def sample_iter_reader():
|
||||
with open(self.label_file_path, "rb") as fin:
|
||||
label_infor_list = fin.readlines()
|
||||
img_num = len(label_infor_list)
|
||||
img_id_list = list(range(img_num))
|
||||
random.shuffle(img_id_list)
|
||||
if sys.platform == "win32" and self.num_workers != 1:
|
||||
print("multiprocess is not fully compatible with Windows."
|
||||
"num_workers will be 1.")
|
||||
self.num_workers = 1
|
||||
for img_id in range(process_id, img_num, self.num_workers):
|
||||
label_infor = label_infor_list[img_id_list[img_id]]
|
||||
outs = self.process(label_infor)
|
||||
|
@ -54,13 +59,64 @@ class TrainReader(object):
|
|||
continue
|
||||
yield outs
|
||||
|
||||
def sample_iter_reader_mul():
|
||||
batch_size = 1000
|
||||
data_source_list = self.label_file_path
|
||||
batch_size_list = list(map(int, [max(1.0, batch_size * x) for x in self.data_ratio_list]))
|
||||
print(self.data_ratio_list, batch_size_list)
|
||||
|
||||
data_filename_list, data_size_list, fetch_record_list = [], [], []
|
||||
for data_source in data_source_list:
|
||||
image_files = open(data_source, "rb").readlines()
|
||||
random.shuffle(image_files)
|
||||
data_filename_list.append(image_files)
|
||||
data_size_list.append(len(image_files))
|
||||
fetch_record_list.append(0)
|
||||
|
||||
image_batch = []
|
||||
# get a batch of img_fns and poly_fns
|
||||
for i in range(0, len(batch_size_list)):
|
||||
bs = batch_size_list[i]
|
||||
ds = data_size_list[i]
|
||||
image_names = data_filename_list[i]
|
||||
fetch_record = fetch_record_list[i]
|
||||
data_path = data_source_list[i]
|
||||
for j in range(fetch_record, fetch_record + bs):
|
||||
index = j % ds
|
||||
image_batch.append(image_names[index])
|
||||
|
||||
if (fetch_record + bs) > ds:
|
||||
fetch_record_list[i] = 0
|
||||
random.shuffle(data_filename_list[i])
|
||||
else:
|
||||
fetch_record_list[i] = fetch_record + bs
|
||||
|
||||
if sys.platform == "win32":
|
||||
print("multiprocess is not fully compatible with Windows."
|
||||
"num_workers will be 1.")
|
||||
self.num_workers = 1
|
||||
|
||||
for label_infor in image_batch:
|
||||
outs = self.process(label_infor)
|
||||
if outs is None:
|
||||
continue
|
||||
yield outs
|
||||
|
||||
def batch_iter_reader():
|
||||
batch_outs = []
|
||||
for outs in sample_iter_reader():
|
||||
batch_outs.append(outs)
|
||||
if len(batch_outs) == self.batch_size:
|
||||
yield batch_outs
|
||||
batch_outs = []
|
||||
if self.use_mul_data:
|
||||
print("Sample date from multiple datasets!")
|
||||
for outs in sample_iter_reader_mul():
|
||||
batch_outs.append(outs)
|
||||
if len(batch_outs) == self.batch_size:
|
||||
yield batch_outs
|
||||
batch_outs = []
|
||||
else:
|
||||
for outs in sample_iter_reader():
|
||||
batch_outs.append(outs)
|
||||
if len(batch_outs) == self.batch_size:
|
||||
yield batch_outs
|
||||
batch_outs = []
|
||||
|
||||
return batch_iter_reader
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ import cv2
|
|||
import numpy as np
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
|
||||
class EASTProcessTrain(object):
|
||||
def __init__(self, params):
|
||||
|
@ -52,7 +53,7 @@ class EASTProcessTrain(object):
|
|||
label_infor = label_infor.decode()
|
||||
label_infor = label_infor.encode('utf-8').decode('utf-8-sig')
|
||||
substr = label_infor.strip("\n").split("\t")
|
||||
img_path = self.img_set_dir + substr[0]
|
||||
img_path = os.path.join(self.img_set_dir, substr[0])
|
||||
label = json.loads(substr[1])
|
||||
nBox = len(label)
|
||||
wordBBs, txts, txt_tags = [], [], []
|
||||
|
|
|
@ -0,0 +1,781 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
import math
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
|
||||
class SASTProcessTrain(object):
|
||||
"""
|
||||
SAST process function for training
|
||||
"""
|
||||
def __init__(self, params):
|
||||
self.img_set_dir = params['img_set_dir']
|
||||
self.min_crop_side_ratio = params['min_crop_side_ratio']
|
||||
self.min_crop_size = params['min_crop_size']
|
||||
image_shape = params['image_shape']
|
||||
self.input_size = image_shape[1]
|
||||
self.min_text_size = params['min_text_size']
|
||||
self.max_text_size = params['max_text_size']
|
||||
|
||||
def convert_label_infor(self, label_infor):
|
||||
label_infor = label_infor.decode()
|
||||
label_infor = label_infor.encode('utf-8').decode('utf-8-sig')
|
||||
substr = label_infor.strip("\n").split("\t")
|
||||
img_path = self.img_set_dir + substr[0]
|
||||
label = json.loads(substr[1])
|
||||
nBox = len(label)
|
||||
wordBBs, txts, txt_tags = [], [], []
|
||||
for bno in range(0, nBox):
|
||||
wordBB = label[bno]['points']
|
||||
txt = label[bno]['transcription']
|
||||
wordBBs.append(wordBB)
|
||||
txts.append(txt)
|
||||
if txt == '###':
|
||||
txt_tags.append(True)
|
||||
else:
|
||||
txt_tags.append(False)
|
||||
wordBBs = np.array(wordBBs, dtype=np.float32)
|
||||
txt_tags = np.array(txt_tags, dtype=np.bool)
|
||||
return img_path, wordBBs, txt_tags, txts
|
||||
|
||||
def quad_area(self, poly):
|
||||
"""
|
||||
compute area of a polygon
|
||||
:param poly:
|
||||
:return:
|
||||
"""
|
||||
edge = [
|
||||
(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
|
||||
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
|
||||
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
|
||||
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])
|
||||
]
|
||||
return np.sum(edge) / 2.
|
||||
|
||||
def gen_quad_from_poly(self, poly):
|
||||
"""
|
||||
Generate min area quad from poly.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
||||
if True:
|
||||
rect = cv2.minAreaRect(poly.astype(np.int32)) # (center (x,y), (width, height), angle of rotation)
|
||||
center_point = rect[0]
|
||||
box = np.array(cv2.boxPoints(rect))
|
||||
|
||||
first_point_idx = 0
|
||||
min_dist = 1e4
|
||||
for i in range(4):
|
||||
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
|
||||
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
|
||||
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
|
||||
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
first_point_idx = i
|
||||
for i in range(4):
|
||||
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
||||
|
||||
return min_area_quad
|
||||
|
||||
def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
|
||||
"""
|
||||
check so that the text poly is in the same direction,
|
||||
and also filter some invalid polygons
|
||||
:param polys:
|
||||
:param tags:
|
||||
:return:
|
||||
"""
|
||||
(h, w) = xxx_todo_changeme
|
||||
if polys.shape[0] == 0:
|
||||
return polys, np.array([]), np.array([])
|
||||
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
|
||||
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
|
||||
|
||||
validated_polys = []
|
||||
validated_tags = []
|
||||
hv_tags = []
|
||||
for poly, tag in zip(polys, tags):
|
||||
quad = self.gen_quad_from_poly(poly)
|
||||
p_area = self.quad_area(quad)
|
||||
if abs(p_area) < 1:
|
||||
print('invalid poly')
|
||||
continue
|
||||
if p_area > 0:
|
||||
if tag == False:
|
||||
print('poly in wrong direction')
|
||||
tag = True # reversed cases should be ignore
|
||||
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1), :]
|
||||
quad = quad[(0, 3, 2, 1), :]
|
||||
|
||||
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] - quad[2])
|
||||
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])
|
||||
hv_tag = 1
|
||||
|
||||
if len_w * 2.0 < len_h:
|
||||
hv_tag = 0
|
||||
|
||||
validated_polys.append(poly)
|
||||
validated_tags.append(tag)
|
||||
hv_tags.append(hv_tag)
|
||||
return np.array(validated_polys), np.array(validated_tags), np.array(hv_tags)
|
||||
|
||||
def crop_area(self, im, polys, tags, hv_tags, txts, crop_background=False, max_tries=25):
|
||||
"""
|
||||
make random crop from the input image
|
||||
:param im:
|
||||
:param polys:
|
||||
:param tags:
|
||||
:param crop_background:
|
||||
:param max_tries: 50 -> 25
|
||||
:return:
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
pad_h = h // 10
|
||||
pad_w = w // 10
|
||||
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
||||
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
||||
for poly in polys:
|
||||
poly = np.round(poly, decimals=0).astype(np.int32)
|
||||
minx = np.min(poly[:, 0])
|
||||
maxx = np.max(poly[:, 0])
|
||||
w_array[minx + pad_w: maxx + pad_w] = 1
|
||||
miny = np.min(poly[:, 1])
|
||||
maxy = np.max(poly[:, 1])
|
||||
h_array[miny + pad_h: maxy + pad_h] = 1
|
||||
# ensure the cropped area not across a text
|
||||
h_axis = np.where(h_array == 0)[0]
|
||||
w_axis = np.where(w_array == 0)[0]
|
||||
if len(h_axis) == 0 or len(w_axis) == 0:
|
||||
return im, polys, tags, hv_tags, txts
|
||||
for i in range(max_tries):
|
||||
xx = np.random.choice(w_axis, size=2)
|
||||
xmin = np.min(xx) - pad_w
|
||||
xmax = np.max(xx) - pad_w
|
||||
xmin = np.clip(xmin, 0, w - 1)
|
||||
xmax = np.clip(xmax, 0, w - 1)
|
||||
yy = np.random.choice(h_axis, size=2)
|
||||
ymin = np.min(yy) - pad_h
|
||||
ymax = np.max(yy) - pad_h
|
||||
ymin = np.clip(ymin, 0, h - 1)
|
||||
ymax = np.clip(ymax, 0, h - 1)
|
||||
# if xmax - xmin < ARGS.min_crop_side_ratio * w or \
|
||||
# ymax - ymin < ARGS.min_crop_side_ratio * h:
|
||||
if xmax - xmin < self.min_crop_size or \
|
||||
ymax - ymin < self.min_crop_size:
|
||||
# area too small
|
||||
continue
|
||||
if polys.shape[0] != 0:
|
||||
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
|
||||
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
|
||||
selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
||||
else:
|
||||
selected_polys = []
|
||||
if len(selected_polys) == 0:
|
||||
# no text in this area
|
||||
if crop_background:
|
||||
txts_tmp = []
|
||||
for selected_poly in selected_polys:
|
||||
txts_tmp.append(txts[selected_poly])
|
||||
txts = txts_tmp
|
||||
return im[ymin : ymax + 1, xmin : xmax + 1, :], \
|
||||
polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
|
||||
else:
|
||||
continue
|
||||
im = im[ymin: ymax + 1, xmin: xmax + 1, :]
|
||||
polys = polys[selected_polys]
|
||||
tags = tags[selected_polys]
|
||||
hv_tags = hv_tags[selected_polys]
|
||||
txts_tmp = []
|
||||
for selected_poly in selected_polys:
|
||||
txts_tmp.append(txts[selected_poly])
|
||||
txts = txts_tmp
|
||||
polys[:, :, 0] -= xmin
|
||||
polys[:, :, 1] -= ymin
|
||||
return im, polys, tags, hv_tags, txts
|
||||
|
||||
return im, polys, tags, hv_tags, txts
|
||||
|
||||
def generate_direction_map(self, poly_quads, direction_map):
|
||||
"""
|
||||
"""
|
||||
width_list = []
|
||||
height_list = []
|
||||
for quad in poly_quads:
|
||||
quad_w = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
|
||||
quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0
|
||||
width_list.append(quad_w)
|
||||
height_list.append(quad_h)
|
||||
norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0)
|
||||
average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0)
|
||||
|
||||
for quad in poly_quads:
|
||||
direct_vector_full = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
|
||||
direct_vector = direct_vector_full / (np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
|
||||
direction_label = tuple(map(float, [direct_vector[0], direct_vector[1], 1.0 / (average_height + 1e-6)]))
|
||||
cv2.fillPoly(direction_map, quad.round().astype(np.int32)[np.newaxis, :, :], direction_label)
|
||||
return direction_map
|
||||
|
||||
def calculate_average_height(self, poly_quads):
|
||||
"""
|
||||
"""
|
||||
height_list = []
|
||||
for quad in poly_quads:
|
||||
quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0
|
||||
height_list.append(quad_h)
|
||||
average_height = max(sum(height_list) / len(height_list), 1.0)
|
||||
return average_height
|
||||
|
||||
def generate_tcl_label(self, hw, polys, tags, ds_ratio,
|
||||
tcl_ratio=0.3, shrink_ratio_of_width=0.15):
|
||||
"""
|
||||
Generate polygon.
|
||||
"""
|
||||
h, w = hw
|
||||
h, w = int(h * ds_ratio), int(w * ds_ratio)
|
||||
polys = polys * ds_ratio
|
||||
|
||||
score_map = np.zeros((h, w,), dtype=np.float32)
|
||||
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
|
||||
training_mask = np.ones((h, w,), dtype=np.float32)
|
||||
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape([1, 1, 3]).astype(np.float32)
|
||||
|
||||
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
|
||||
poly = poly_tag[0]
|
||||
tag = poly_tag[1]
|
||||
|
||||
# generate min_area_quad
|
||||
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
|
||||
min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
|
||||
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
|
||||
min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
|
||||
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
|
||||
|
||||
if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
|
||||
or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
|
||||
continue
|
||||
|
||||
if tag:
|
||||
# continue
|
||||
cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0.15)
|
||||
else:
|
||||
tcl_poly = self.poly2tcl(poly, tcl_ratio)
|
||||
tcl_quads = self.poly2quads(tcl_poly)
|
||||
poly_quads = self.poly2quads(poly)
|
||||
# stcl map
|
||||
stcl_quads, quad_index = self.shrink_poly_along_width(tcl_quads, shrink_ratio_of_width=shrink_ratio_of_width,
|
||||
expand_height_ratio=1.0 / tcl_ratio)
|
||||
# generate tcl map
|
||||
cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0)
|
||||
|
||||
# generate tbo map
|
||||
for idx, quad in enumerate(stcl_quads):
|
||||
quad_mask = np.zeros((h, w), dtype=np.float32)
|
||||
quad_mask = cv2.fillPoly(quad_mask, np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
|
||||
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]], quad_mask, tbo_map)
|
||||
return score_map, tbo_map, training_mask
|
||||
|
||||
def generate_tvo_and_tco(self, hw, polys, tags, tcl_ratio=0.3, ds_ratio=0.25):
|
||||
"""
|
||||
Generate tcl map, tvo map and tbo map.
|
||||
"""
|
||||
h, w = hw
|
||||
h, w = int(h * ds_ratio), int(w * ds_ratio)
|
||||
polys = polys * ds_ratio
|
||||
poly_mask = np.zeros((h, w), dtype=np.float32)
|
||||
|
||||
tvo_map = np.ones((9, h, w), dtype=np.float32)
|
||||
tvo_map[0:-1:2] = np.tile(np.arange(0, w), (h, 1))
|
||||
tvo_map[1:-1:2] = np.tile(np.arange(0, w), (h, 1)).T
|
||||
poly_tv_xy_map = np.zeros((8, h, w), dtype=np.float32)
|
||||
|
||||
# tco map
|
||||
tco_map = np.ones((3, h, w), dtype=np.float32)
|
||||
tco_map[0] = np.tile(np.arange(0, w), (h, 1))
|
||||
tco_map[1] = np.tile(np.arange(0, w), (h, 1)).T
|
||||
poly_tc_xy_map = np.zeros((2, h, w), dtype=np.float32)
|
||||
|
||||
poly_short_edge_map = np.ones((h, w), dtype=np.float32)
|
||||
|
||||
for poly, poly_tag in zip(polys, tags):
|
||||
|
||||
if poly_tag == True:
|
||||
continue
|
||||
|
||||
# adjust point order for vertical poly
|
||||
poly = self.adjust_point(poly)
|
||||
|
||||
# generate min_area_quad
|
||||
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
|
||||
min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
|
||||
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
|
||||
min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
|
||||
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
|
||||
|
||||
# generate tcl map and text, 128 * 128
|
||||
tcl_poly = self.poly2tcl(poly, tcl_ratio)
|
||||
|
||||
# generate poly_tv_xy_map
|
||||
for idx in range(4):
|
||||
cv2.fillPoly(poly_tv_xy_map[2 * idx],
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||
float(min(max(min_area_quad[idx, 0], 0), w)))
|
||||
cv2.fillPoly(poly_tv_xy_map[2 * idx + 1],
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||
float(min(max(min_area_quad[idx, 1], 0), h)))
|
||||
|
||||
# generate poly_tc_xy_map
|
||||
for idx in range(2):
|
||||
cv2.fillPoly(poly_tc_xy_map[idx],
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), float(center_point[idx]))
|
||||
|
||||
# generate poly_short_edge_map
|
||||
cv2.fillPoly(poly_short_edge_map,
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||
float(max(min(min_area_quad_h, min_area_quad_w), 1.0)))
|
||||
|
||||
# generate poly_mask and training_mask
|
||||
cv2.fillPoly(poly_mask, np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), 1)
|
||||
|
||||
tvo_map *= poly_mask
|
||||
tvo_map[:8] -= poly_tv_xy_map
|
||||
tvo_map[-1] /= poly_short_edge_map
|
||||
tvo_map = tvo_map.transpose((1, 2, 0))
|
||||
|
||||
tco_map *= poly_mask
|
||||
tco_map[:2] -= poly_tc_xy_map
|
||||
tco_map[-1] /= poly_short_edge_map
|
||||
tco_map = tco_map.transpose((1, 2, 0))
|
||||
|
||||
return tvo_map, tco_map
|
||||
|
||||
def adjust_point(self, poly):
|
||||
"""
|
||||
adjust point order.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
if point_num == 4:
|
||||
len_1 = np.linalg.norm(poly[0] - poly[1])
|
||||
len_2 = np.linalg.norm(poly[1] - poly[2])
|
||||
len_3 = np.linalg.norm(poly[2] - poly[3])
|
||||
len_4 = np.linalg.norm(poly[3] - poly[0])
|
||||
|
||||
if (len_1 + len_3) * 1.5 < (len_2 + len_4):
|
||||
poly = poly[[1, 2, 3, 0], :]
|
||||
|
||||
elif point_num > 4:
|
||||
vector_1 = poly[0] - poly[1]
|
||||
vector_2 = poly[1] - poly[2]
|
||||
cos_theta = np.dot(vector_1, vector_2) / (np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
|
||||
theta = np.arccos(np.round(cos_theta, decimals=4))
|
||||
|
||||
if abs(theta) > (70 / 180 * math.pi):
|
||||
index = list(range(1, point_num)) + [0]
|
||||
poly = poly[np.array(index), :]
|
||||
return poly
|
||||
|
||||
def gen_min_area_quad_from_poly(self, poly):
|
||||
"""
|
||||
Generate min area quad from poly.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
||||
if point_num == 4:
|
||||
min_area_quad = poly
|
||||
center_point = np.sum(poly, axis=0) / 4
|
||||
else:
|
||||
rect = cv2.minAreaRect(poly.astype(np.int32)) # (center (x,y), (width, height), angle of rotation)
|
||||
center_point = rect[0]
|
||||
box = np.array(cv2.boxPoints(rect))
|
||||
|
||||
first_point_idx = 0
|
||||
min_dist = 1e4
|
||||
for i in range(4):
|
||||
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
|
||||
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
|
||||
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
|
||||
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
first_point_idx = i
|
||||
|
||||
for i in range(4):
|
||||
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
||||
|
||||
return min_area_quad, center_point
|
||||
|
||||
def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
|
||||
"""
|
||||
Generate shrink_quad_along_width.
|
||||
"""
|
||||
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
||||
|
||||
def shrink_poly_along_width(self, quads, shrink_ratio_of_width, expand_height_ratio=1.0):
|
||||
"""
|
||||
shrink poly with given length.
|
||||
"""
|
||||
upper_edge_list = []
|
||||
|
||||
def get_cut_info(edge_len_list, cut_len):
|
||||
for idx, edge_len in enumerate(edge_len_list):
|
||||
cut_len -= edge_len
|
||||
if cut_len <= 0.000001:
|
||||
ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
|
||||
return idx, ratio
|
||||
|
||||
for quad in quads:
|
||||
upper_edge_len = np.linalg.norm(quad[0] - quad[1])
|
||||
upper_edge_list.append(upper_edge_len)
|
||||
|
||||
# length of left edge and right edge.
|
||||
left_length = np.linalg.norm(quads[0][0] - quads[0][3]) * expand_height_ratio
|
||||
right_length = np.linalg.norm(quads[-1][1] - quads[-1][2]) * expand_height_ratio
|
||||
|
||||
shrink_length = min(left_length, right_length, sum(upper_edge_list)) * shrink_ratio_of_width
|
||||
# shrinking length
|
||||
upper_len_left = shrink_length
|
||||
upper_len_right = sum(upper_edge_list) - shrink_length
|
||||
|
||||
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
|
||||
left_quad = self.shrink_quad_along_width(quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
|
||||
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
|
||||
right_quad = self.shrink_quad_along_width(quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
|
||||
|
||||
out_quad_list = []
|
||||
if left_idx == right_idx:
|
||||
out_quad_list.append([left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
|
||||
else:
|
||||
out_quad_list.append(left_quad)
|
||||
for idx in range(left_idx + 1, right_idx):
|
||||
out_quad_list.append(quads[idx])
|
||||
out_quad_list.append(right_quad)
|
||||
|
||||
return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
|
||||
|
||||
def vector_angle(self, A, B):
|
||||
"""
|
||||
Calculate the angle between vector AB and x-axis positive direction.
|
||||
"""
|
||||
AB = np.array([B[1] - A[1], B[0] - A[0]])
|
||||
return np.arctan2(*AB)
|
||||
|
||||
def theta_line_cross_point(self, theta, point):
|
||||
"""
|
||||
Calculate the line through given point and angle in ax + by + c =0 form.
|
||||
"""
|
||||
x, y = point
|
||||
cos = np.cos(theta)
|
||||
sin = np.sin(theta)
|
||||
return [sin, -cos, cos * y - sin * x]
|
||||
|
||||
def line_cross_two_point(self, A, B):
|
||||
"""
|
||||
Calculate the line through given point A and B in ax + by + c =0 form.
|
||||
"""
|
||||
angle = self.vector_angle(A, B)
|
||||
return self.theta_line_cross_point(angle, A)
|
||||
|
||||
def average_angle(self, poly):
|
||||
"""
|
||||
Calculate the average angle between left and right edge in given poly.
|
||||
"""
|
||||
p0, p1, p2, p3 = poly
|
||||
angle30 = self.vector_angle(p3, p0)
|
||||
angle21 = self.vector_angle(p2, p1)
|
||||
return (angle30 + angle21) / 2
|
||||
|
||||
def line_cross_point(self, line1, line2):
|
||||
"""
|
||||
line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2
|
||||
"""
|
||||
a1, b1, c1 = line1
|
||||
a2, b2, c2 = line2
|
||||
d = a1 * b2 - a2 * b1
|
||||
|
||||
if d == 0:
|
||||
#print("line1", line1)
|
||||
#print("line2", line2)
|
||||
print('Cross point does not exist')
|
||||
return np.array([0, 0], dtype=np.float32)
|
||||
else:
|
||||
x = (b1 * c2 - b2 * c1) / d
|
||||
y = (a2 * c1 - a1 * c2) / d
|
||||
|
||||
return np.array([x, y], dtype=np.float32)
|
||||
|
||||
def quad2tcl(self, poly, ratio):
|
||||
"""
|
||||
Generate center line by poly clock-wise point. (4, 2)
|
||||
"""
|
||||
ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
||||
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
|
||||
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
|
||||
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
|
||||
|
||||
def poly2tcl(self, poly, ratio):
|
||||
"""
|
||||
Generate center line by poly clock-wise point.
|
||||
"""
|
||||
ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
||||
tcl_poly = np.zeros_like(poly)
|
||||
point_num = poly.shape[0]
|
||||
|
||||
for idx in range(point_num // 2):
|
||||
point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]) * ratio_pair
|
||||
tcl_poly[idx] = point_pair[0]
|
||||
tcl_poly[point_num - 1 - idx] = point_pair[1]
|
||||
return tcl_poly
|
||||
|
||||
def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
|
||||
"""
|
||||
Generate tbo_map for give quad.
|
||||
"""
|
||||
# upper and lower line function: ax + by + c = 0;
|
||||
up_line = self.line_cross_two_point(quad[0], quad[1])
|
||||
lower_line = self.line_cross_two_point(quad[3], quad[2])
|
||||
|
||||
quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2]))
|
||||
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3]))
|
||||
|
||||
# average angle of left and right line.
|
||||
angle = self.average_angle(quad)
|
||||
|
||||
xy_in_poly = np.argwhere(tcl_mask == 1)
|
||||
for y, x in xy_in_poly:
|
||||
point = (x, y)
|
||||
line = self.theta_line_cross_point(angle, point)
|
||||
cross_point_upper = self.line_cross_point(up_line, line)
|
||||
cross_point_lower = self.line_cross_point(lower_line, line)
|
||||
##FIX, offset reverse
|
||||
upper_offset_x, upper_offset_y = cross_point_upper - point
|
||||
lower_offset_x, lower_offset_y = cross_point_lower - point
|
||||
tbo_map[y, x, 0] = upper_offset_y
|
||||
tbo_map[y, x, 1] = upper_offset_x
|
||||
tbo_map[y, x, 2] = lower_offset_y
|
||||
tbo_map[y, x, 3] = lower_offset_x
|
||||
tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
|
||||
return tbo_map
|
||||
|
||||
def poly2quads(self, poly):
|
||||
"""
|
||||
Split poly into quads.
|
||||
"""
|
||||
quad_list = []
|
||||
point_num = poly.shape[0]
|
||||
|
||||
# point pair
|
||||
point_pair_list = []
|
||||
for idx in range(point_num // 2):
|
||||
point_pair = [poly[idx], poly[point_num - 1 - idx]]
|
||||
point_pair_list.append(point_pair)
|
||||
|
||||
quad_num = point_num // 2 - 1
|
||||
for idx in range(quad_num):
|
||||
# reshape and adjust to clock-wise
|
||||
quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]).reshape(4, 2)[[0, 2, 3, 1]])
|
||||
|
||||
return np.array(quad_list)
|
||||
|
||||
def extract_polys(self, poly_txt_path):
|
||||
"""
|
||||
Read text_polys, txt_tags, txts from give txt file.
|
||||
"""
|
||||
text_polys, txt_tags, txts = [], [], []
|
||||
|
||||
with open(poly_txt_path) as f:
|
||||
for line in f.readlines():
|
||||
poly_str, txt = line.strip().split('\t')
|
||||
poly = map(float, poly_str.split(','))
|
||||
text_polys.append(np.array(poly, dtype=np.float32).reshape(-1, 2))
|
||||
txts.append(txt)
|
||||
if txt == '###':
|
||||
txt_tags.append(True)
|
||||
else:
|
||||
txt_tags.append(False)
|
||||
|
||||
return np.array(map(np.array, text_polys)), \
|
||||
np.array(txt_tags, dtype=np.bool), txts
|
||||
|
||||
def __call__(self, label_infor):
|
||||
infor = self.convert_label_infor(label_infor)
|
||||
im_path, text_polys, text_tags, text_strs = infor
|
||||
im = cv2.imread(im_path)
|
||||
if im is None:
|
||||
return None
|
||||
if text_polys.shape[0] == 0:
|
||||
return None
|
||||
|
||||
h, w, _ = im.shape
|
||||
text_polys, text_tags, hv_tags = self.check_and_validate_polys(text_polys, text_tags, (h, w))
|
||||
|
||||
if text_polys.shape[0] == 0:
|
||||
return None
|
||||
|
||||
#set aspect ratio and keep area fix
|
||||
asp_scales = np.arange(1.0, 1.55, 0.1)
|
||||
asp_scale = np.random.choice(asp_scales)
|
||||
|
||||
if np.random.rand() < 0.5:
|
||||
asp_scale = 1.0 / asp_scale
|
||||
asp_scale = math.sqrt(asp_scale)
|
||||
|
||||
asp_wx = asp_scale
|
||||
asp_hy = 1.0 / asp_scale
|
||||
im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
|
||||
text_polys[:, :, 0] *= asp_wx
|
||||
text_polys[:, :, 1] *= asp_hy
|
||||
|
||||
h, w, _ = im.shape
|
||||
if max(h, w) > 2048:
|
||||
rd_scale = 2048.0 / max(h, w)
|
||||
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
|
||||
text_polys *= rd_scale
|
||||
h, w, _ = im.shape
|
||||
if min(h, w) < 16:
|
||||
return None
|
||||
|
||||
#no background
|
||||
im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(im, \
|
||||
text_polys, text_tags, hv_tags, text_strs, crop_background=False)
|
||||
if text_polys.shape[0] == 0:
|
||||
return None
|
||||
#continue for all ignore case
|
||||
if np.sum((text_tags * 1.0)) >= text_tags.size:
|
||||
return None
|
||||
new_h, new_w, _ = im.shape
|
||||
if (new_h is None) or (new_w is None):
|
||||
return None
|
||||
#resize image
|
||||
std_ratio = float(self.input_size) / max(new_w, new_h)
|
||||
rand_scales = np.array([0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
|
||||
rz_scale = std_ratio * np.random.choice(rand_scales)
|
||||
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
|
||||
text_polys[:, :, 0] *= rz_scale
|
||||
text_polys[:, :, 1] *= rz_scale
|
||||
|
||||
#add gaussian blur
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
ks = np.random.permutation(5)[0] + 1
|
||||
ks = int(ks/2)*2 + 1
|
||||
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
|
||||
#add brighter
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
im = im * (1.0 + np.random.rand() * 0.5)
|
||||
im = np.clip(im, 0.0, 255.0)
|
||||
#add darker
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
im = im * (1.0 - np.random.rand() * 0.5)
|
||||
im = np.clip(im, 0.0, 255.0)
|
||||
|
||||
# Padding the im to [input_size, input_size]
|
||||
new_h, new_w, _ = im.shape
|
||||
if min(new_w, new_h) < self.input_size * 0.5:
|
||||
return None
|
||||
|
||||
im_padded = np.ones((self.input_size, self.input_size, 3), dtype=np.float32)
|
||||
im_padded[:, :, 2] = 0.485 * 255
|
||||
im_padded[:, :, 1] = 0.456 * 255
|
||||
im_padded[:, :, 0] = 0.406 * 255
|
||||
|
||||
# Random the start position
|
||||
del_h = self.input_size - new_h
|
||||
del_w = self.input_size - new_w
|
||||
sh, sw = 0, 0
|
||||
if del_h > 1:
|
||||
sh = int(np.random.rand() * del_h)
|
||||
if del_w > 1:
|
||||
sw = int(np.random.rand() * del_w)
|
||||
|
||||
# Padding
|
||||
im_padded[sh: sh + new_h, sw: sw + new_w, :] = im.copy()
|
||||
text_polys[:, :, 0] += sw
|
||||
text_polys[:, :, 1] += sh
|
||||
|
||||
score_map, border_map, training_mask = self.generate_tcl_label((self.input_size, self.input_size),
|
||||
text_polys, text_tags, 0.25)
|
||||
|
||||
# SAST head
|
||||
tvo_map, tco_map = self.generate_tvo_and_tco((self.input_size, self.input_size), text_polys, text_tags, tcl_ratio=0.3, ds_ratio=0.25)
|
||||
# print("test--------tvo_map shape:", tvo_map.shape)
|
||||
|
||||
im_padded[:, :, 2] -= 0.485 * 255
|
||||
im_padded[:, :, 1] -= 0.456 * 255
|
||||
im_padded[:, :, 0] -= 0.406 * 255
|
||||
im_padded[:, :, 2] /= (255.0 * 0.229)
|
||||
im_padded[:, :, 1] /= (255.0 * 0.224)
|
||||
im_padded[:, :, 0] /= (255.0 * 0.225)
|
||||
im_padded = im_padded.transpose((2, 0, 1))
|
||||
|
||||
return im_padded[::-1, :, :], score_map[np.newaxis, :, :], border_map.transpose((2, 0, 1)), training_mask[np.newaxis, :, :], tvo_map.transpose((2, 0, 1)), tco_map.transpose((2, 0, 1))
|
||||
|
||||
|
||||
class SASTProcessTest(object):
|
||||
"""
|
||||
SAST process function for test
|
||||
"""
|
||||
def __init__(self, params):
|
||||
super(SASTProcessTest, self).__init__()
|
||||
if 'max_side_len' in params:
|
||||
self.max_side_len = params['max_side_len']
|
||||
else:
|
||||
self.max_side_len = 2400
|
||||
|
||||
def resize_image(self, im):
|
||||
"""
|
||||
resize image to a size multiple of max_stride which is required by the network
|
||||
:param im: the resized image
|
||||
:param max_side_len: limit of max image size to avoid out of memory in gpu
|
||||
:return: the resized image and the resize ratio
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
# Fix the longer side
|
||||
if resize_h > resize_w:
|
||||
ratio = float(self.max_side_len) / resize_h
|
||||
else:
|
||||
ratio = float(self.max_side_len) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
def __call__(self, im):
|
||||
src_h, src_w, _ = im.shape
|
||||
im, (ratio_h, ratio_w) = self.resize_image(im)
|
||||
img_mean = [0.485, 0.456, 0.406]
|
||||
img_std = [0.229, 0.224, 0.225]
|
||||
im = im[:, :, ::-1].astype(np.float32)
|
||||
im = im / 255
|
||||
im -= img_mean
|
||||
im /= img_std
|
||||
im = im.transpose((2, 0, 1))
|
||||
im = im[np.newaxis, :]
|
||||
return [im, (ratio_h, ratio_w, src_h, src_w)]
|
|
@ -26,7 +26,7 @@ from ppocr.utils.utility import initial_logger
|
|||
from ppocr.utils.utility import get_image_file_list
|
||||
logger = initial_logger()
|
||||
|
||||
from .img_tools import process_image, get_img_data
|
||||
from .img_tools import process_image, process_image_srn, get_img_data
|
||||
|
||||
|
||||
class LMDBReader(object):
|
||||
|
@ -43,6 +43,9 @@ class LMDBReader(object):
|
|||
self.mode = params['mode']
|
||||
self.drop_last = False
|
||||
self.use_tps = False
|
||||
self.num_heads = None
|
||||
if "num_heads" in params:
|
||||
self.num_heads = params['num_heads']
|
||||
if "tps" in params:
|
||||
self.ues_tps = True
|
||||
self.use_distort = False
|
||||
|
@ -119,12 +122,19 @@ class LMDBReader(object):
|
|||
img = cv2.imread(single_img)
|
||||
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
norm_img = process_image(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
char_ops=self.char_ops,
|
||||
tps=self.use_tps,
|
||||
infer_mode=True)
|
||||
if self.loss_type == 'srn':
|
||||
norm_img = process_image_srn(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
num_heads=self.num_heads,
|
||||
max_text_length=self.max_text_length)
|
||||
else:
|
||||
norm_img = process_image(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
char_ops=self.char_ops,
|
||||
tps=self.use_tps,
|
||||
infer_mode=True)
|
||||
yield norm_img
|
||||
else:
|
||||
lmdb_sets = self.load_hierarchical_lmdb_dataset()
|
||||
|
@ -144,14 +154,25 @@ class LMDBReader(object):
|
|||
if sample_info is None:
|
||||
continue
|
||||
img, label = sample_info
|
||||
outs = process_image(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
label=label,
|
||||
char_ops=self.char_ops,
|
||||
loss_type=self.loss_type,
|
||||
max_text_length=self.max_text_length,
|
||||
distort=self.use_distort)
|
||||
outs = []
|
||||
if self.loss_type == "srn":
|
||||
outs = process_image_srn(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
num_heads=self.num_heads,
|
||||
max_text_length=self.max_text_length,
|
||||
label=label,
|
||||
char_ops=self.char_ops,
|
||||
loss_type=self.loss_type)
|
||||
|
||||
else:
|
||||
outs = process_image(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
label=label,
|
||||
char_ops=self.char_ops,
|
||||
loss_type=self.loss_type,
|
||||
max_text_length=self.max_text_length)
|
||||
if outs is None:
|
||||
continue
|
||||
yield outs
|
||||
|
@ -185,6 +206,7 @@ class SimpleReader(object):
|
|||
if params['mode'] != 'test':
|
||||
self.img_set_dir = params['img_set_dir']
|
||||
self.label_file_path = params['label_file_path']
|
||||
self.use_gpu = params['use_gpu']
|
||||
self.char_ops = params['char_ops']
|
||||
self.image_shape = params['image_shape']
|
||||
self.loss_type = params['loss_type']
|
||||
|
@ -213,6 +235,15 @@ class SimpleReader(object):
|
|||
if self.mode != 'train':
|
||||
process_id = 0
|
||||
|
||||
def get_device_num():
|
||||
if self.use_gpu:
|
||||
gpus = os.environ.get("CUDA_VISIBLE_DEVICES", '1')
|
||||
gpu_num = len(gpus.split(','))
|
||||
return gpu_num
|
||||
else:
|
||||
cpu_num = os.environ.get("CPU_NUM", 1)
|
||||
return int(cpu_num)
|
||||
|
||||
def sample_iter_reader():
|
||||
if self.mode != 'train' and self.infer_img is not None:
|
||||
image_file_list = get_image_file_list(self.infer_img)
|
||||
|
@ -237,6 +268,12 @@ class SimpleReader(object):
|
|||
print("multiprocess is not fully compatible with Windows."
|
||||
"num_workers will be 1.")
|
||||
self.num_workers = 1
|
||||
if self.batch_size * get_device_num(
|
||||
) * self.num_workers > img_num:
|
||||
raise Exception(
|
||||
"The number of the whole data ({}) is smaller than the batch_size * devices_num * num_workers ({})".
|
||||
format(img_num, self.batch_size * get_device_num() *
|
||||
self.num_workers))
|
||||
for img_id in range(process_id, img_num, self.num_workers):
|
||||
label_infor = label_infor_list[img_id_list[img_id]]
|
||||
substr = label_infor.decode('utf-8').strip("\n").split("\t")
|
||||
|
|
|
@ -360,7 +360,7 @@ def process_image(img,
|
|||
text = char_ops.encode(label)
|
||||
if len(text) == 0 or len(text) > max_text_length:
|
||||
logger.info(
|
||||
"Warning in ppocr/data/rec/img_tools.py:line362: Wrong data type."
|
||||
"Warning in ppocr/data/rec/img_tools.py: Wrong data type."
|
||||
"Excepted string with length between 1 and {}, but "
|
||||
"got '{}'. Label is '{}'".format(max_text_length,
|
||||
len(text), label))
|
||||
|
@ -381,3 +381,84 @@ def process_image(img,
|
|||
assert False, "Unsupport loss_type %s in process_image"\
|
||||
% loss_type
|
||||
return (norm_img)
|
||||
|
||||
def resize_norm_img_srn(img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
img_black = np.zeros((imgH, imgW))
|
||||
im_hei = img.shape[0]
|
||||
im_wid = img.shape[1]
|
||||
|
||||
if im_wid <= im_hei * 1:
|
||||
img_new = cv2.resize(img, (imgH * 1, imgH))
|
||||
elif im_wid <= im_hei * 2:
|
||||
img_new = cv2.resize(img, (imgH * 2, imgH))
|
||||
elif im_wid <= im_hei * 3:
|
||||
img_new = cv2.resize(img, (imgH * 3, imgH))
|
||||
else:
|
||||
img_new = cv2.resize(img, (imgW, imgH))
|
||||
|
||||
img_np = np.asarray(img_new)
|
||||
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
|
||||
img_black[:, 0:img_np.shape[1]] = img_np
|
||||
img_black = img_black[:, :, np.newaxis]
|
||||
|
||||
row, col, c = img_black.shape
|
||||
c = 1
|
||||
|
||||
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||
|
||||
def srn_other_inputs(image_shape,
|
||||
num_heads,
|
||||
max_text_length):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||
|
||||
encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype('int64')
|
||||
gsrm_word_pos = np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype('int64')
|
||||
|
||||
lbl_weight = np.array([37] * max_text_length).reshape((-1,1)).astype('int64')
|
||||
|
||||
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape([-1, 1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [1, num_heads, 1, 1]) * [-1e9]
|
||||
|
||||
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape([-1, 1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [1, num_heads, 1, 1]) * [-1e9]
|
||||
|
||||
encoder_word_pos = encoder_word_pos[np.newaxis, :]
|
||||
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
|
||||
|
||||
return [lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2]
|
||||
|
||||
def process_image_srn(img,
|
||||
image_shape,
|
||||
num_heads,
|
||||
max_text_length,
|
||||
label=None,
|
||||
char_ops=None,
|
||||
loss_type=None):
|
||||
norm_img = resize_norm_img_srn(img, image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
[lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
||||
srn_other_inputs(image_shape, num_heads, max_text_length)
|
||||
|
||||
if label is not None:
|
||||
char_num = char_ops.get_char_num()
|
||||
text = char_ops.encode(label)
|
||||
if len(text) == 0 or len(text) > max_text_length:
|
||||
return None
|
||||
else:
|
||||
if loss_type == "srn":
|
||||
text_padded = [37] * max_text_length
|
||||
for i in range(len(text)):
|
||||
text_padded[i] = text[i]
|
||||
lbl_weight[i] = [1.0]
|
||||
text_padded = np.array(text_padded)
|
||||
text = text_padded.reshape(-1, 1)
|
||||
return (norm_img, text,encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2,lbl_weight)
|
||||
else:
|
||||
assert False, "Unsupport loss_type %s in process_image"\
|
||||
% loss_type
|
||||
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2)
|
||||
|
|
|
@ -97,6 +97,23 @@ class DetModel(object):
|
|||
'shrink_mask':shrink_mask,\
|
||||
'threshold_map':threshold_map,\
|
||||
'threshold_mask':threshold_mask}
|
||||
elif self.algorithm == "SAST":
|
||||
input_score = fluid.layers.data(
|
||||
name='score', shape=[1, 128, 128], dtype='float32')
|
||||
input_border = fluid.layers.data(
|
||||
name='border', shape=[5, 128, 128], dtype='float32')
|
||||
input_mask = fluid.layers.data(
|
||||
name='mask', shape=[1, 128, 128], dtype='float32')
|
||||
input_tvo = fluid.layers.data(
|
||||
name='tvo', shape=[9, 128, 128], dtype='float32')
|
||||
input_tco = fluid.layers.data(
|
||||
name='tco', shape=[3, 128, 128], dtype='float32')
|
||||
feed_list = [image, input_score, input_border, input_mask, input_tvo, input_tco]
|
||||
labels = {'input_score': input_score,\
|
||||
'input_border': input_border,\
|
||||
'input_mask': input_mask,\
|
||||
'input_tvo': input_tvo,\
|
||||
'input_tco': input_tco}
|
||||
loader = fluid.io.DataLoader.from_generator(
|
||||
feed_list=feed_list,
|
||||
capacity=64,
|
||||
|
|
|
@ -58,6 +58,10 @@ class RecModel(object):
|
|||
self.loss_type = global_params['loss_type']
|
||||
self.image_shape = global_params['image_shape']
|
||||
self.max_text_length = global_params['max_text_length']
|
||||
if "num_heads" in global_params:
|
||||
self.num_heads = global_params["num_heads"]
|
||||
else:
|
||||
self.num_heads = None
|
||||
|
||||
def create_feed(self, mode):
|
||||
image_shape = deepcopy(self.image_shape)
|
||||
|
@ -77,6 +81,48 @@ class RecModel(object):
|
|||
lod_level=1)
|
||||
feed_list = [image, label_in, label_out]
|
||||
labels = {'label_in': label_in, 'label_out': label_out}
|
||||
elif self.loss_type == "srn":
|
||||
encoder_word_pos = fluid.data(
|
||||
name="encoder_word_pos",
|
||||
shape=[
|
||||
-1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)),
|
||||
1
|
||||
],
|
||||
dtype="int64")
|
||||
gsrm_word_pos = fluid.data(
|
||||
name="gsrm_word_pos",
|
||||
shape=[-1, self.max_text_length, 1],
|
||||
dtype="int64")
|
||||
gsrm_slf_attn_bias1 = fluid.data(
|
||||
name="gsrm_slf_attn_bias1",
|
||||
shape=[
|
||||
-1, self.num_heads, self.max_text_length,
|
||||
self.max_text_length
|
||||
],
|
||||
dtype="float32")
|
||||
gsrm_slf_attn_bias2 = fluid.data(
|
||||
name="gsrm_slf_attn_bias2",
|
||||
shape=[
|
||||
-1, self.num_heads, self.max_text_length,
|
||||
self.max_text_length
|
||||
],
|
||||
dtype="float32")
|
||||
lbl_weight = fluid.layers.data(
|
||||
name="lbl_weight", shape=[-1, 1], dtype='int64')
|
||||
label = fluid.data(
|
||||
name='label', shape=[-1, 1], dtype='int32', lod_level=1)
|
||||
feed_list = [
|
||||
image, label, encoder_word_pos, gsrm_word_pos,
|
||||
gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight
|
||||
]
|
||||
labels = {
|
||||
'label': label,
|
||||
'encoder_word_pos': encoder_word_pos,
|
||||
'gsrm_word_pos': gsrm_word_pos,
|
||||
'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
|
||||
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2,
|
||||
'lbl_weight': lbl_weight
|
||||
}
|
||||
else:
|
||||
label = fluid.data(
|
||||
name='label', shape=[None, 1], dtype='int32', lod_level=1)
|
||||
|
@ -88,6 +134,8 @@ class RecModel(object):
|
|||
use_double_buffer=True,
|
||||
iterable=False)
|
||||
else:
|
||||
labels = None
|
||||
loader = None
|
||||
if self.char_type == "ch" and self.infer_img:
|
||||
image_shape[-1] = -1
|
||||
if self.tps != None:
|
||||
|
@ -98,8 +146,42 @@ class RecModel(object):
|
|||
)
|
||||
image_shape = deepcopy(self.image_shape)
|
||||
image = fluid.data(name='image', shape=image_shape, dtype='float32')
|
||||
labels = None
|
||||
loader = None
|
||||
if self.loss_type == "srn":
|
||||
encoder_word_pos = fluid.data(
|
||||
name="encoder_word_pos",
|
||||
shape=[
|
||||
-1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)),
|
||||
1
|
||||
],
|
||||
dtype="int64")
|
||||
gsrm_word_pos = fluid.data(
|
||||
name="gsrm_word_pos",
|
||||
shape=[-1, self.max_text_length, 1],
|
||||
dtype="int64")
|
||||
gsrm_slf_attn_bias1 = fluid.data(
|
||||
name="gsrm_slf_attn_bias1",
|
||||
shape=[
|
||||
-1, self.num_heads, self.max_text_length,
|
||||
self.max_text_length
|
||||
],
|
||||
dtype="float32")
|
||||
gsrm_slf_attn_bias2 = fluid.data(
|
||||
name="gsrm_slf_attn_bias2",
|
||||
shape=[
|
||||
-1, self.num_heads, self.max_text_length,
|
||||
self.max_text_length
|
||||
],
|
||||
dtype="float32")
|
||||
feed_list = [
|
||||
image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2
|
||||
]
|
||||
labels = {
|
||||
'encoder_word_pos': encoder_word_pos,
|
||||
'gsrm_word_pos': gsrm_word_pos,
|
||||
'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
|
||||
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2
|
||||
}
|
||||
return image, labels, loader
|
||||
|
||||
def __call__(self, mode):
|
||||
|
@ -117,13 +199,27 @@ class RecModel(object):
|
|||
label = labels['label_out']
|
||||
else:
|
||||
label = labels['label']
|
||||
outputs = {'total_loss':loss, 'decoded_out':\
|
||||
decoded_out, 'label':label}
|
||||
if self.loss_type == 'srn':
|
||||
total_loss, img_loss, word_loss = self.loss(predicts, labels)
|
||||
outputs = {
|
||||
'total_loss': total_loss,
|
||||
'img_loss': img_loss,
|
||||
'word_loss': word_loss,
|
||||
'decoded_out': decoded_out,
|
||||
'label': label
|
||||
}
|
||||
else:
|
||||
outputs = {'total_loss':loss, 'decoded_out':\
|
||||
decoded_out, 'label':label}
|
||||
return loader, outputs
|
||||
|
||||
elif mode == "export":
|
||||
predict = predicts['predict']
|
||||
if self.loss_type == "ctc":
|
||||
predict = fluid.layers.softmax(predict)
|
||||
if self.loss_type == "srn":
|
||||
raise Exception(
|
||||
"Warning! SRN does not support export model currently")
|
||||
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
|
||||
else:
|
||||
predict = predicts['predict']
|
||||
|
|
|
@ -0,0 +1,274 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
|
||||
__all__ = ["ResNet"]
|
||||
|
||||
|
||||
class ResNet(object):
|
||||
def __init__(self, params):
|
||||
"""
|
||||
the Resnet backbone network for detection module.
|
||||
Args:
|
||||
params(dict): the super parameters for network build
|
||||
"""
|
||||
self.layers = params['layers']
|
||||
supported_layers = [18, 34, 50, 101, 152]
|
||||
assert self.layers in supported_layers, \
|
||||
"supported layers are {} but input layer is {}".format(supported_layers, self.layers)
|
||||
self.is_3x3 = True
|
||||
|
||||
def __call__(self, input):
|
||||
layers = self.layers
|
||||
is_3x3 = self.is_3x3
|
||||
# if layers == 18:
|
||||
# depth = [2, 2, 2, 2]
|
||||
# elif layers == 34 or layers == 50:
|
||||
# depth = [3, 4, 6, 3]
|
||||
# elif layers == 101:
|
||||
# depth = [3, 4, 23, 3]
|
||||
# elif layers == 152:
|
||||
# depth = [3, 8, 36, 3]
|
||||
# elif layers == 200:
|
||||
# depth = [3, 12, 48, 3]
|
||||
# num_filters = [64, 128, 256, 512]
|
||||
# outs = []
|
||||
|
||||
if layers == 18:
|
||||
depth = [2, 2, 2, 2]#, 3, 3]
|
||||
elif layers == 34 or layers == 50:
|
||||
#depth = [3, 4, 6, 3]#, 3, 3]
|
||||
depth = [3, 4, 6, 3, 3]#, 3]
|
||||
elif layers == 101:
|
||||
depth = [3, 4, 23, 3]#, 3, 3]
|
||||
elif layers == 152:
|
||||
depth = [3, 8, 36, 3]#, 3, 3]
|
||||
num_filters = [64, 128, 256, 512, 512]#, 512]
|
||||
blocks = {}
|
||||
|
||||
idx = 'block_0'
|
||||
blocks[idx] = input
|
||||
|
||||
if is_3x3 == False:
|
||||
conv = self.conv_bn_layer(
|
||||
input=input,
|
||||
num_filters=64,
|
||||
filter_size=7,
|
||||
stride=2,
|
||||
act='relu')
|
||||
else:
|
||||
conv = self.conv_bn_layer(
|
||||
input=input,
|
||||
num_filters=32,
|
||||
filter_size=3,
|
||||
stride=2,
|
||||
act='relu',
|
||||
name='conv1_1')
|
||||
conv = self.conv_bn_layer(
|
||||
input=conv,
|
||||
num_filters=32,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name='conv1_2')
|
||||
conv = self.conv_bn_layer(
|
||||
input=conv,
|
||||
num_filters=64,
|
||||
filter_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name='conv1_3')
|
||||
idx = 'block_1'
|
||||
blocks[idx] = conv
|
||||
|
||||
conv = fluid.layers.pool2d(
|
||||
input=conv,
|
||||
pool_size=3,
|
||||
pool_stride=2,
|
||||
pool_padding=1,
|
||||
pool_type='max')
|
||||
|
||||
if layers >= 50:
|
||||
for block in range(len(depth)):
|
||||
for i in range(depth[block]):
|
||||
if layers in [101, 152, 200] and block == 2:
|
||||
if i == 0:
|
||||
conv_name = "res" + str(block + 2) + "a"
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + "b" + str(i)
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
conv = self.bottleneck_block(
|
||||
input=conv,
|
||||
num_filters=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
if_first=block == i == 0,
|
||||
name=conv_name)
|
||||
# outs.append(conv)
|
||||
idx = 'block_' + str(block + 2)
|
||||
blocks[idx] = conv
|
||||
else:
|
||||
for block in range(len(depth)):
|
||||
for i in range(depth[block]):
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
conv = self.basic_block(
|
||||
input=conv,
|
||||
num_filters=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
if_first=block == i == 0,
|
||||
name=conv_name)
|
||||
# outs.append(conv)
|
||||
idx = 'block_' + str(block + 2)
|
||||
blocks[idx] = conv
|
||||
# return outs
|
||||
return blocks
|
||||
|
||||
def conv_bn_layer(self,
|
||||
input,
|
||||
num_filters,
|
||||
filter_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
act=None,
|
||||
name=None):
|
||||
conv = fluid.layers.conv2d(
|
||||
input=input,
|
||||
num_filters=num_filters,
|
||||
filter_size=filter_size,
|
||||
stride=stride,
|
||||
padding=(filter_size - 1) // 2,
|
||||
groups=groups,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=False)
|
||||
if name == "conv1":
|
||||
bn_name = "bn_" + name
|
||||
else:
|
||||
bn_name = "bn" + name[3:]
|
||||
return fluid.layers.batch_norm(
|
||||
input=conv,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=bn_name + '_scale'),
|
||||
bias_attr=ParamAttr(bn_name + '_offset'),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance')
|
||||
|
||||
def conv_bn_layer_new(self,
|
||||
input,
|
||||
num_filters,
|
||||
filter_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
act=None,
|
||||
name=None):
|
||||
pool = fluid.layers.pool2d(
|
||||
input=input,
|
||||
pool_size=2,
|
||||
pool_stride=2,
|
||||
pool_padding=0,
|
||||
pool_type='avg',
|
||||
ceil_mode=True)
|
||||
|
||||
conv = fluid.layers.conv2d(
|
||||
input=pool,
|
||||
num_filters=num_filters,
|
||||
filter_size=filter_size,
|
||||
stride=1,
|
||||
padding=(filter_size - 1) // 2,
|
||||
groups=groups,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=False)
|
||||
if name == "conv1":
|
||||
bn_name = "bn_" + name
|
||||
else:
|
||||
bn_name = "bn" + name[3:]
|
||||
return fluid.layers.batch_norm(
|
||||
input=conv,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=bn_name + '_scale'),
|
||||
bias_attr=ParamAttr(bn_name + '_offset'),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance')
|
||||
|
||||
def shortcut(self, input, ch_out, stride, name, if_first=False):
|
||||
ch_in = input.shape[1]
|
||||
if ch_in != ch_out or stride != 1:
|
||||
if if_first:
|
||||
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
|
||||
else:
|
||||
return self.conv_bn_layer_new(
|
||||
input, ch_out, 1, stride, name=name)
|
||||
elif if_first:
|
||||
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
|
||||
else:
|
||||
return input
|
||||
|
||||
def bottleneck_block(self, input, num_filters, stride, name, if_first):
|
||||
conv0 = self.conv_bn_layer(
|
||||
input=input,
|
||||
num_filters=num_filters,
|
||||
filter_size=1,
|
||||
act='relu',
|
||||
name=name + "_branch2a")
|
||||
conv1 = self.conv_bn_layer(
|
||||
input=conv0,
|
||||
num_filters=num_filters,
|
||||
filter_size=3,
|
||||
stride=stride,
|
||||
act='relu',
|
||||
name=name + "_branch2b")
|
||||
conv2 = self.conv_bn_layer(
|
||||
input=conv1,
|
||||
num_filters=num_filters * 4,
|
||||
filter_size=1,
|
||||
act=None,
|
||||
name=name + "_branch2c")
|
||||
|
||||
short = self.shortcut(
|
||||
input,
|
||||
num_filters * 4,
|
||||
stride,
|
||||
if_first=if_first,
|
||||
name=name + "_branch1")
|
||||
|
||||
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
|
||||
|
||||
def basic_block(self, input, num_filters, stride, name, if_first):
|
||||
conv0 = self.conv_bn_layer(
|
||||
input=input,
|
||||
num_filters=num_filters,
|
||||
filter_size=3,
|
||||
act='relu',
|
||||
stride=stride,
|
||||
name=name + "_branch2a")
|
||||
conv1 = self.conv_bn_layer(
|
||||
input=conv0,
|
||||
num_filters=num_filters,
|
||||
filter_size=3,
|
||||
act=None,
|
||||
name=name + "_branch2b")
|
||||
short = self.shortcut(
|
||||
input,
|
||||
num_filters,
|
||||
stride,
|
||||
if_first=if_first,
|
||||
name=name + "_branch1")
|
||||
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
|
|
@ -31,16 +31,28 @@ __all__ = [
|
|||
|
||||
class MobileNetV3():
|
||||
def __init__(self, params):
|
||||
self.scale = params['scale']
|
||||
model_name = params['model_name']
|
||||
self.scale = params.get("scale", 0.5)
|
||||
model_name = params.get("model_name", "small")
|
||||
large_stride = params.get("large_stride", [1, 2, 2, 2])
|
||||
small_stride = params.get("small_stride", [2, 2, 2, 2])
|
||||
|
||||
assert isinstance(large_stride, list), "large_stride type must " \
|
||||
"be list but got {}".format(type(large_stride))
|
||||
assert isinstance(small_stride, list), "small_stride type must " \
|
||||
"be list but got {}".format(type(small_stride))
|
||||
assert len(large_stride) == 4, "large_stride length must be " \
|
||||
"4 but got {}".format(len(large_stride))
|
||||
assert len(small_stride) == 4, "small_stride length must be " \
|
||||
"4 but got {}".format(len(small_stride))
|
||||
|
||||
self.inplanes = 16
|
||||
if model_name == "large":
|
||||
self.cfg = [
|
||||
# k, exp, c, se, nl, s,
|
||||
[3, 16, 16, False, 'relu', 1],
|
||||
[3, 64, 24, False, 'relu', (2, 1)],
|
||||
[3, 16, 16, False, 'relu', large_stride[0]],
|
||||
[3, 64, 24, False, 'relu', (large_stride[1], 1)],
|
||||
[3, 72, 24, False, 'relu', 1],
|
||||
[5, 72, 40, True, 'relu', (2, 1)],
|
||||
[5, 72, 40, True, 'relu', (large_stride[2], 1)],
|
||||
[5, 120, 40, True, 'relu', 1],
|
||||
[5, 120, 40, True, 'relu', 1],
|
||||
[3, 240, 80, False, 'hard_swish', 1],
|
||||
|
@ -49,7 +61,7 @@ class MobileNetV3():
|
|||
[3, 184, 80, False, 'hard_swish', 1],
|
||||
[3, 480, 112, True, 'hard_swish', 1],
|
||||
[3, 672, 112, True, 'hard_swish', 1],
|
||||
[5, 672, 160, True, 'hard_swish', (2, 1)],
|
||||
[5, 672, 160, True, 'hard_swish', (large_stride[3], 1)],
|
||||
[5, 960, 160, True, 'hard_swish', 1],
|
||||
[5, 960, 160, True, 'hard_swish', 1],
|
||||
]
|
||||
|
@ -58,15 +70,15 @@ class MobileNetV3():
|
|||
elif model_name == "small":
|
||||
self.cfg = [
|
||||
# k, exp, c, se, nl, s,
|
||||
[3, 16, 16, True, 'relu', (2, 1)],
|
||||
[3, 72, 24, False, 'relu', (2, 1)],
|
||||
[3, 16, 16, True, 'relu', (small_stride[0], 1)],
|
||||
[3, 72, 24, False, 'relu', (small_stride[1], 1)],
|
||||
[3, 88, 24, False, 'relu', 1],
|
||||
[5, 96, 40, True, 'hard_swish', (2, 1)],
|
||||
[5, 96, 40, True, 'hard_swish', (small_stride[2], 1)],
|
||||
[5, 240, 40, True, 'hard_swish', 1],
|
||||
[5, 240, 40, True, 'hard_swish', 1],
|
||||
[5, 120, 48, True, 'hard_swish', 1],
|
||||
[5, 144, 48, True, 'hard_swish', 1],
|
||||
[5, 288, 96, True, 'hard_swish', (2, 1)],
|
||||
[5, 288, 96, True, 'hard_swish', (small_stride[3], 1)],
|
||||
[5, 576, 96, True, 'hard_swish', 1],
|
||||
[5, 576, 96, True, 'hard_swish', 1],
|
||||
]
|
||||
|
|
|
@ -0,0 +1,172 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
|
||||
|
||||
__all__ = ["ResNet", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
|
||||
|
||||
Trainable = True
|
||||
w_nolr = fluid.ParamAttr(
|
||||
trainable = Trainable)
|
||||
train_parameters = {
|
||||
"input_size": [3, 224, 224],
|
||||
"input_mean": [0.485, 0.456, 0.406],
|
||||
"input_std": [0.229, 0.224, 0.225],
|
||||
"learning_strategy": {
|
||||
"name": "piecewise_decay",
|
||||
"batch_size": 256,
|
||||
"epochs": [30, 60, 90],
|
||||
"steps": [0.1, 0.01, 0.001, 0.0001]
|
||||
}
|
||||
}
|
||||
|
||||
class ResNet():
|
||||
def __init__(self, params):
|
||||
self.layers = params['layers']
|
||||
self.params = train_parameters
|
||||
|
||||
|
||||
def __call__(self, input):
|
||||
layers = self.layers
|
||||
supported_layers = [18, 34, 50, 101, 152]
|
||||
assert layers in supported_layers, \
|
||||
"supported layers are {} but input layer is {}".format(supported_layers, layers)
|
||||
|
||||
if layers == 18:
|
||||
depth = [2, 2, 2, 2]
|
||||
elif layers == 34 or layers == 50:
|
||||
depth = [3, 4, 6, 3]
|
||||
elif layers == 101:
|
||||
depth = [3, 4, 23, 3]
|
||||
elif layers == 152:
|
||||
depth = [3, 8, 36, 3]
|
||||
stride_list = [(2,2),(2,2),(1,1),(1,1)]
|
||||
num_filters = [64, 128, 256, 512]
|
||||
|
||||
conv = self.conv_bn_layer(
|
||||
input=input, num_filters=64, filter_size=7, stride=2, act='relu', name="conv1")
|
||||
F = []
|
||||
if layers >= 50:
|
||||
for block in range(len(depth)):
|
||||
for i in range(depth[block]):
|
||||
if layers in [101, 152] and block == 2:
|
||||
if i == 0:
|
||||
conv_name = "res" + str(block + 2) + "a"
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + "b" + str(i)
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
conv = self.bottleneck_block(
|
||||
input=conv,
|
||||
num_filters=num_filters[block],
|
||||
stride=stride_list[block] if i == 0 else 1, name=conv_name)
|
||||
F.append(conv)
|
||||
|
||||
base = F[-1]
|
||||
for i in [-2, -3]:
|
||||
b, c, w, h = F[i].shape
|
||||
if (w,h) == base.shape[2:]:
|
||||
base = base
|
||||
else:
|
||||
base = fluid.layers.conv2d_transpose( input=base, num_filters=c,filter_size=4, stride=2,
|
||||
padding=1,act=None,
|
||||
param_attr=w_nolr,
|
||||
bias_attr=w_nolr)
|
||||
base = fluid.layers.batch_norm(base, act = "relu", param_attr=w_nolr, bias_attr=w_nolr)
|
||||
base = fluid.layers.concat([base, F[i]], axis=1)
|
||||
base = fluid.layers.conv2d(base, num_filters=c, filter_size=1, param_attr=w_nolr, bias_attr=w_nolr)
|
||||
base = fluid.layers.conv2d(base, num_filters=c, filter_size=3,padding = 1, param_attr=w_nolr, bias_attr=w_nolr)
|
||||
base = fluid.layers.batch_norm(base, act = "relu", param_attr=w_nolr, bias_attr=w_nolr)
|
||||
|
||||
base = fluid.layers.conv2d(base, num_filters=512, filter_size=1,bias_attr=w_nolr,param_attr=w_nolr)
|
||||
|
||||
return base
|
||||
|
||||
def conv_bn_layer(self,
|
||||
input,
|
||||
num_filters,
|
||||
filter_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
act=None,
|
||||
name=None):
|
||||
conv = fluid.layers.conv2d(
|
||||
input=input,
|
||||
num_filters=num_filters,
|
||||
filter_size= 2 if stride==(1,1) else filter_size,
|
||||
dilation = 2 if stride==(1,1) else 1,
|
||||
stride=stride,
|
||||
padding=(filter_size - 1) // 2,
|
||||
groups=groups,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name=name + "_weights",trainable = Trainable),
|
||||
bias_attr=False,
|
||||
name=name + '.conv2d.output.1')
|
||||
|
||||
if name == "conv1":
|
||||
bn_name = "bn_" + name
|
||||
else:
|
||||
bn_name = "bn" + name[3:]
|
||||
return fluid.layers.batch_norm(input=conv,
|
||||
act=act,
|
||||
name=bn_name + '.output.1',
|
||||
param_attr=ParamAttr(name=bn_name + '_scale',trainable = Trainable),
|
||||
bias_attr=ParamAttr(bn_name + '_offset',trainable = Trainable),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance', )
|
||||
|
||||
def shortcut(self, input, ch_out, stride, is_first, name):
|
||||
ch_in = input.shape[1]
|
||||
if ch_in != ch_out or stride != 1 or is_first == True:
|
||||
if stride == (1,1):
|
||||
return self.conv_bn_layer(input, ch_out, 1, 1, name=name)
|
||||
else: #stride == (2,2)
|
||||
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
|
||||
|
||||
else:
|
||||
return input
|
||||
|
||||
def bottleneck_block(self, input, num_filters, stride, name):
|
||||
conv0 = self.conv_bn_layer(
|
||||
input=input, num_filters=num_filters, filter_size=1, act='relu', name=name + "_branch2a")
|
||||
conv1 = self.conv_bn_layer(
|
||||
input=conv0,
|
||||
num_filters=num_filters,
|
||||
filter_size=3,
|
||||
stride=stride,
|
||||
act='relu',
|
||||
name=name + "_branch2b")
|
||||
conv2 = self.conv_bn_layer(
|
||||
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name + "_branch2c")
|
||||
|
||||
short = self.shortcut(input, num_filters * 4, stride, is_first=False, name=name + "_branch1")
|
||||
|
||||
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu', name=name + ".add.output.5")
|
||||
|
||||
def basic_block(self, input, num_filters, stride, is_first, name):
|
||||
conv0 = self.conv_bn_layer(input=input, num_filters=num_filters, filter_size=3, act='relu', stride=stride,
|
||||
name=name + "_branch2a")
|
||||
conv1 = self.conv_bn_layer(input=conv0, num_filters=num_filters, filter_size=3, act=None,
|
||||
name=name + "_branch2b")
|
||||
short = self.shortcut(input, num_filters, stride, is_first, name=name + "_branch1")
|
||||
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
|
|
@ -0,0 +1,228 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle.fluid as fluid
|
||||
from ..common_functions import conv_bn_layer, deconv_bn_layer
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class SASTHead(object):
|
||||
"""
|
||||
SAST:
|
||||
see arxiv: https://arxiv.org/abs/1908.05498
|
||||
args:
|
||||
params(dict): the super parameters for network build
|
||||
"""
|
||||
|
||||
def __init__(self, params):
|
||||
self.model_name = params['model_name']
|
||||
self.with_cab = params['with_cab']
|
||||
|
||||
def FPN_Up_Fusion(self, blocks):
|
||||
"""
|
||||
blocks{}: contain block_2, block_3, block_4, block_5, block_6, block_7 with
|
||||
1/4, 1/8, 1/16, 1/32, 1/64, 1/128 resolution.
|
||||
"""
|
||||
f = [blocks['block_6'], blocks['block_5'], blocks['block_4'], blocks['block_3'], blocks['block_2']]
|
||||
num_outputs = [256, 256, 192, 192, 128]
|
||||
g = [None, None, None, None, None]
|
||||
h = [None, None, None, None, None]
|
||||
for i in range(5):
|
||||
h[i] = conv_bn_layer(input=f[i], num_filters=num_outputs[i],
|
||||
filter_size=1, stride=1, act=None, name='fpn_up_h'+str(i))
|
||||
|
||||
for i in range(4):
|
||||
if i == 0:
|
||||
g[i] = deconv_bn_layer(input=h[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g0')
|
||||
#print("g[{}] shape: {}".format(i, g[i].shape))
|
||||
else:
|
||||
g[i] = fluid.layers.elementwise_add(x=g[i - 1], y=h[i])
|
||||
g[i] = fluid.layers.relu(g[i])
|
||||
#g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i],
|
||||
# filter_size=1, stride=1, act='relu')
|
||||
g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i],
|
||||
filter_size=3, stride=1, act='relu', name='fpn_up_g%d_1'%i)
|
||||
g[i] = deconv_bn_layer(input=g[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g%d_2'%i)
|
||||
#print("g[{}] shape: {}".format(i, g[i].shape))
|
||||
|
||||
g[4] = fluid.layers.elementwise_add(x=g[3], y=h[4])
|
||||
g[4] = fluid.layers.relu(g[4])
|
||||
g[4] = conv_bn_layer(input=g[4], num_filters=num_outputs[4],
|
||||
filter_size=3, stride=1, act='relu', name='fpn_up_fusion_1')
|
||||
g[4] = conv_bn_layer(input=g[4], num_filters=num_outputs[4],
|
||||
filter_size=1, stride=1, act=None, name='fpn_up_fusion_2')
|
||||
|
||||
return g[4]
|
||||
|
||||
def FPN_Down_Fusion(self, blocks):
|
||||
"""
|
||||
blocks{}: contain block_2, block_3, block_4, block_5, block_6, block_7 with
|
||||
1/4, 1/8, 1/16, 1/32, 1/64, 1/128 resolution.
|
||||
"""
|
||||
f = [blocks['block_0'], blocks['block_1'], blocks['block_2']]
|
||||
num_outputs = [32, 64, 128]
|
||||
g = [None, None, None]
|
||||
h = [None, None, None]
|
||||
for i in range(3):
|
||||
h[i] = conv_bn_layer(input=f[i], num_filters=num_outputs[i],
|
||||
filter_size=3, stride=1, act=None, name='fpn_down_h'+str(i))
|
||||
for i in range(2):
|
||||
if i == 0:
|
||||
g[i] = conv_bn_layer(input=h[i], num_filters=num_outputs[i+1], filter_size=3, stride=2, act=None, name='fpn_down_g0')
|
||||
else:
|
||||
g[i] = fluid.layers.elementwise_add(x=g[i - 1], y=h[i])
|
||||
g[i] = fluid.layers.relu(g[i])
|
||||
g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i], filter_size=3, stride=1, act='relu', name='fpn_down_g%d_1'%i)
|
||||
g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i+1], filter_size=3, stride=2, act=None, name='fpn_down_g%d_2'%i)
|
||||
# print("g[{}] shape: {}".format(i, g[i].shape))
|
||||
g[2] = fluid.layers.elementwise_add(x=g[1], y=h[2])
|
||||
g[2] = fluid.layers.relu(g[2])
|
||||
g[2] = conv_bn_layer(input=g[2], num_filters=num_outputs[2],
|
||||
filter_size=3, stride=1, act='relu', name='fpn_down_fusion_1')
|
||||
g[2] = conv_bn_layer(input=g[2], num_filters=num_outputs[2],
|
||||
filter_size=1, stride=1, act=None, name='fpn_down_fusion_2')
|
||||
return g[2]
|
||||
|
||||
def SAST_Header1(self, f_common):
|
||||
"""Detector header."""
|
||||
#f_score
|
||||
f_score = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_score1')
|
||||
f_score = conv_bn_layer(input=f_score, num_filters=64, filter_size=3, stride=1, act='relu', name='f_score2')
|
||||
f_score = conv_bn_layer(input=f_score, num_filters=128, filter_size=1, stride=1, act='relu', name='f_score3')
|
||||
f_score = conv_bn_layer(input=f_score, num_filters=1, filter_size=3, stride=1, name='f_score4')
|
||||
f_score = fluid.layers.sigmoid(f_score)
|
||||
# print("f_score shape: {}".format(f_score.shape))
|
||||
|
||||
#f_boder
|
||||
f_border = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_border1')
|
||||
f_border = conv_bn_layer(input=f_border, num_filters=64, filter_size=3, stride=1, act='relu', name='f_border2')
|
||||
f_border = conv_bn_layer(input=f_border, num_filters=128, filter_size=1, stride=1, act='relu', name='f_border3')
|
||||
f_border = conv_bn_layer(input=f_border, num_filters=4, filter_size=3, stride=1, name='f_border4')
|
||||
# print("f_border shape: {}".format(f_border.shape))
|
||||
|
||||
return f_score, f_border
|
||||
|
||||
def SAST_Header2(self, f_common):
|
||||
"""Detector header."""
|
||||
#f_tvo
|
||||
f_tvo = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_tvo1')
|
||||
f_tvo = conv_bn_layer(input=f_tvo, num_filters=64, filter_size=3, stride=1, act='relu', name='f_tvo2')
|
||||
f_tvo = conv_bn_layer(input=f_tvo, num_filters=128, filter_size=1, stride=1, act='relu', name='f_tvo3')
|
||||
f_tvo = conv_bn_layer(input=f_tvo, num_filters=8, filter_size=3, stride=1, name='f_tvo4')
|
||||
# print("f_tvo shape: {}".format(f_tvo.shape))
|
||||
|
||||
#f_tco
|
||||
f_tco = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_tco1')
|
||||
f_tco = conv_bn_layer(input=f_tco, num_filters=64, filter_size=3, stride=1, act='relu', name='f_tco2')
|
||||
f_tco = conv_bn_layer(input=f_tco, num_filters=128, filter_size=1, stride=1, act='relu', name='f_tco3')
|
||||
f_tco = conv_bn_layer(input=f_tco, num_filters=2, filter_size=3, stride=1, name='f_tco4')
|
||||
# print("f_tco shape: {}".format(f_tco.shape))
|
||||
|
||||
return f_tvo, f_tco
|
||||
|
||||
def cross_attention(self, f_common):
|
||||
"""
|
||||
"""
|
||||
f_shape = fluid.layers.shape(f_common)
|
||||
f_theta = conv_bn_layer(input=f_common, num_filters=128, filter_size=1, stride=1, act='relu', name='f_theta')
|
||||
f_phi = conv_bn_layer(input=f_common, num_filters=128, filter_size=1, stride=1, act='relu', name='f_phi')
|
||||
f_g = conv_bn_layer(input=f_common, num_filters=128, filter_size=1, stride=1, act='relu', name='f_g')
|
||||
### horizon
|
||||
fh_theta = f_theta
|
||||
fh_phi = f_phi
|
||||
fh_g = f_g
|
||||
#flatten
|
||||
fh_theta = fluid.layers.transpose(fh_theta, [0, 2, 3, 1])
|
||||
fh_theta = fluid.layers.reshape(fh_theta, [f_shape[0] * f_shape[2], f_shape[3], 128])
|
||||
fh_phi = fluid.layers.transpose(fh_phi, [0, 2, 3, 1])
|
||||
fh_phi = fluid.layers.reshape(fh_phi, [f_shape[0] * f_shape[2], f_shape[3], 128])
|
||||
fh_g = fluid.layers.transpose(fh_g, [0, 2, 3, 1])
|
||||
fh_g = fluid.layers.reshape(fh_g, [f_shape[0] * f_shape[2], f_shape[3], 128])
|
||||
#correlation
|
||||
fh_attn = fluid.layers.matmul(fh_theta, fluid.layers.transpose(fh_phi, [0, 2, 1]))
|
||||
#scale
|
||||
fh_attn = fh_attn / (128 ** 0.5)
|
||||
fh_attn = fluid.layers.softmax(fh_attn)
|
||||
#weighted sum
|
||||
fh_weight = fluid.layers.matmul(fh_attn, fh_g)
|
||||
fh_weight = fluid.layers.reshape(fh_weight, [f_shape[0], f_shape[2], f_shape[3], 128])
|
||||
# print("fh_weight: {}".format(fh_weight.shape))
|
||||
fh_weight = fluid.layers.transpose(fh_weight, [0, 3, 1, 2])
|
||||
fh_weight = conv_bn_layer(input=fh_weight, num_filters=128, filter_size=1, stride=1, name='fh_weight')
|
||||
#short cut
|
||||
fh_sc = conv_bn_layer(input=f_common, num_filters=128, filter_size=1, stride=1, name='fh_sc')
|
||||
f_h = fluid.layers.relu(fh_weight + fh_sc)
|
||||
######
|
||||
#vertical
|
||||
fv_theta = fluid.layers.transpose(f_theta, [0, 1, 3, 2])
|
||||
fv_phi = fluid.layers.transpose(f_phi, [0, 1, 3, 2])
|
||||
fv_g = fluid.layers.transpose(f_g, [0, 1, 3, 2])
|
||||
#flatten
|
||||
fv_theta = fluid.layers.transpose(fv_theta, [0, 2, 3, 1])
|
||||
fv_theta = fluid.layers.reshape(fv_theta, [f_shape[0] * f_shape[3], f_shape[2], 128])
|
||||
fv_phi = fluid.layers.transpose(fv_phi, [0, 2, 3, 1])
|
||||
fv_phi = fluid.layers.reshape(fv_phi, [f_shape[0] * f_shape[3], f_shape[2], 128])
|
||||
fv_g = fluid.layers.transpose(fv_g, [0, 2, 3, 1])
|
||||
fv_g = fluid.layers.reshape(fv_g, [f_shape[0] * f_shape[3], f_shape[2], 128])
|
||||
#correlation
|
||||
fv_attn = fluid.layers.matmul(fv_theta, fluid.layers.transpose(fv_phi, [0, 2, 1]))
|
||||
#scale
|
||||
fv_attn = fv_attn / (128 ** 0.5)
|
||||
fv_attn = fluid.layers.softmax(fv_attn)
|
||||
#weighted sum
|
||||
fv_weight = fluid.layers.matmul(fv_attn, fv_g)
|
||||
fv_weight = fluid.layers.reshape(fv_weight, [f_shape[0], f_shape[3], f_shape[2], 128])
|
||||
# print("fv_weight: {}".format(fv_weight.shape))
|
||||
fv_weight = fluid.layers.transpose(fv_weight, [0, 3, 2, 1])
|
||||
fv_weight = conv_bn_layer(input=fv_weight, num_filters=128, filter_size=1, stride=1, name='fv_weight')
|
||||
#short cut
|
||||
fv_sc = conv_bn_layer(input=f_common, num_filters=128, filter_size=1, stride=1, name='fv_sc')
|
||||
f_v = fluid.layers.relu(fv_weight + fv_sc)
|
||||
######
|
||||
f_attn = fluid.layers.concat([f_h, f_v], axis=1)
|
||||
f_attn = conv_bn_layer(input=f_attn, num_filters=128, filter_size=1, stride=1, act='relu', name='f_attn')
|
||||
return f_attn
|
||||
|
||||
def __call__(self, blocks, with_cab=False):
|
||||
# for k, v in blocks.items():
|
||||
# print(k, v.shape)
|
||||
|
||||
#down fpn
|
||||
f_down = self.FPN_Down_Fusion(blocks)
|
||||
# print("f_down shape: {}".format(f_down.shape))
|
||||
#up fpn
|
||||
f_up = self.FPN_Up_Fusion(blocks)
|
||||
# print("f_up shape: {}".format(f_up.shape))
|
||||
#fusion
|
||||
f_common = fluid.layers.elementwise_add(x=f_down, y=f_up)
|
||||
f_common = fluid.layers.relu(f_common)
|
||||
# print("f_common: {}".format(f_common.shape))
|
||||
|
||||
if self.with_cab:
|
||||
# print('enhence f_common with CAB.')
|
||||
f_common = self.cross_attention(f_common)
|
||||
|
||||
f_score, f_border= self.SAST_Header1(f_common)
|
||||
f_tvo, f_tco = self.SAST_Header2(f_common)
|
||||
|
||||
predicts = OrderedDict()
|
||||
predicts['f_score'] = f_score
|
||||
predicts['f_border'] = f_border
|
||||
predicts['f_tvo'] = f_tvo
|
||||
predicts['f_tco'] = f_tco
|
||||
return predicts
|
|
@ -32,6 +32,7 @@ class CTCPredict(object):
|
|||
self.char_num = params['char_num']
|
||||
self.encoder = SequenceEncoder(params)
|
||||
self.encoder_type = params['encoder_type']
|
||||
self.fc_decay = params.get("fc_decay", 0.0004)
|
||||
|
||||
def __call__(self, inputs, labels=None, mode=None):
|
||||
encoder_features = self.encoder(inputs)
|
||||
|
@ -39,7 +40,7 @@ class CTCPredict(object):
|
|||
encoder_features = fluid.layers.concat(encoder_features, axis=1)
|
||||
name = "ctc_fc"
|
||||
para_attr, bias_attr = get_para_bias_attr(
|
||||
l2_decay=0.0004, k=encoder_features.shape[1], name=name)
|
||||
l2_decay=self.fc_decay, k=encoder_features.shape[1], name=name)
|
||||
predict = fluid.layers.fc(input=encoder_features,
|
||||
size=self.char_num + 1,
|
||||
param_attr=para_attr,
|
||||
|
|
|
@ -0,0 +1,230 @@
|
|||
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
import numpy as np
|
||||
from .self_attention.model import wrap_encoder
|
||||
from .self_attention.model import wrap_encoder_forFeature
|
||||
gradient_clip = 10
|
||||
|
||||
|
||||
class SRNPredict(object):
|
||||
def __init__(self, params):
|
||||
super(SRNPredict, self).__init__()
|
||||
self.char_num = params['char_num']
|
||||
self.max_length = params['max_text_length']
|
||||
|
||||
self.num_heads = params['num_heads']
|
||||
self.num_encoder_TUs = params['num_encoder_TUs']
|
||||
self.num_decoder_TUs = params['num_decoder_TUs']
|
||||
self.hidden_dims = params['hidden_dims']
|
||||
|
||||
def pvam(self, inputs, others):
|
||||
|
||||
b, c, h, w = inputs.shape
|
||||
conv_features = fluid.layers.reshape(x=inputs, shape=[-1, c, h * w])
|
||||
conv_features = fluid.layers.transpose(x=conv_features, perm=[0, 2, 1])
|
||||
|
||||
#===== Transformer encoder =====
|
||||
b, t, c = conv_features.shape
|
||||
encoder_word_pos = others["encoder_word_pos"]
|
||||
gsrm_word_pos = others["gsrm_word_pos"]
|
||||
|
||||
enc_inputs = [conv_features, encoder_word_pos, None]
|
||||
word_features = wrap_encoder_forFeature(
|
||||
src_vocab_size=-1,
|
||||
max_length=t,
|
||||
n_layer=self.num_encoder_TUs,
|
||||
n_head=self.num_heads,
|
||||
d_key=int(self.hidden_dims / self.num_heads),
|
||||
d_value=int(self.hidden_dims / self.num_heads),
|
||||
d_model=self.hidden_dims,
|
||||
d_inner_hid=self.hidden_dims,
|
||||
prepostprocess_dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
relu_dropout=0.1,
|
||||
preprocess_cmd="n",
|
||||
postprocess_cmd="da",
|
||||
weight_sharing=True,
|
||||
enc_inputs=enc_inputs, )
|
||||
fluid.clip.set_gradient_clip(
|
||||
fluid.clip.GradientClipByValue(gradient_clip))
|
||||
|
||||
#===== Parallel Visual Attention Module =====
|
||||
b, t, c = word_features.shape
|
||||
|
||||
word_features = fluid.layers.fc(word_features, c, num_flatten_dims=2)
|
||||
word_features_ = fluid.layers.reshape(word_features, [-1, 1, t, c])
|
||||
word_features_ = fluid.layers.expand(word_features_,
|
||||
[1, self.max_length, 1, 1])
|
||||
word_pos_feature = fluid.layers.embedding(gsrm_word_pos,
|
||||
[self.max_length, c])
|
||||
word_pos_ = fluid.layers.reshape(word_pos_feature,
|
||||
[-1, self.max_length, 1, c])
|
||||
word_pos_ = fluid.layers.expand(word_pos_, [1, 1, t, 1])
|
||||
temp = fluid.layers.elementwise_add(
|
||||
word_features_, word_pos_, act='tanh')
|
||||
|
||||
attention_weight = fluid.layers.fc(input=temp,
|
||||
size=1,
|
||||
num_flatten_dims=3,
|
||||
bias_attr=False)
|
||||
attention_weight = fluid.layers.reshape(
|
||||
x=attention_weight, shape=[-1, self.max_length, t])
|
||||
attention_weight = fluid.layers.softmax(input=attention_weight, axis=-1)
|
||||
|
||||
pvam_features = fluid.layers.matmul(attention_weight,
|
||||
word_features) #[b, max_length, c]
|
||||
|
||||
return pvam_features
|
||||
|
||||
def gsrm(self, pvam_features, others):
|
||||
|
||||
#===== GSRM Visual-to-semantic embedding block =====
|
||||
b, t, c = pvam_features.shape
|
||||
word_out = fluid.layers.fc(
|
||||
input=fluid.layers.reshape(pvam_features, [-1, c]),
|
||||
size=self.char_num,
|
||||
act="softmax")
|
||||
#word_out.stop_gradient = True
|
||||
word_ids = fluid.layers.argmax(word_out, axis=1)
|
||||
word_ids.stop_gradient = True
|
||||
word_ids = fluid.layers.reshape(x=word_ids, shape=[-1, t, 1])
|
||||
|
||||
#===== GSRM Semantic reasoning block =====
|
||||
"""
|
||||
This module is achieved through bi-transformers,
|
||||
ngram_feature1 is the froward one, ngram_fetaure2 is the backward one
|
||||
"""
|
||||
pad_idx = self.char_num
|
||||
gsrm_word_pos = others["gsrm_word_pos"]
|
||||
gsrm_slf_attn_bias1 = others["gsrm_slf_attn_bias1"]
|
||||
gsrm_slf_attn_bias2 = others["gsrm_slf_attn_bias2"]
|
||||
|
||||
def prepare_bi(word_ids):
|
||||
"""
|
||||
prepare bi for gsrm
|
||||
word1 for forward; word2 for backward
|
||||
"""
|
||||
word1 = fluid.layers.cast(word_ids, "float32")
|
||||
word1 = fluid.layers.pad(word1, [0, 0, 1, 0, 0, 0],
|
||||
pad_value=1.0 * pad_idx)
|
||||
word1 = fluid.layers.cast(word1, "int64")
|
||||
word1 = word1[:, :-1, :]
|
||||
word2 = word_ids
|
||||
return word1, word2
|
||||
|
||||
word1, word2 = prepare_bi(word_ids)
|
||||
word1.stop_gradient = True
|
||||
word2.stop_gradient = True
|
||||
enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1]
|
||||
enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2]
|
||||
|
||||
gsrm_feature1 = wrap_encoder(
|
||||
src_vocab_size=self.char_num + 1,
|
||||
max_length=self.max_length,
|
||||
n_layer=self.num_decoder_TUs,
|
||||
n_head=self.num_heads,
|
||||
d_key=int(self.hidden_dims / self.num_heads),
|
||||
d_value=int(self.hidden_dims / self.num_heads),
|
||||
d_model=self.hidden_dims,
|
||||
d_inner_hid=self.hidden_dims,
|
||||
prepostprocess_dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
relu_dropout=0.1,
|
||||
preprocess_cmd="n",
|
||||
postprocess_cmd="da",
|
||||
weight_sharing=True,
|
||||
enc_inputs=enc_inputs_1, )
|
||||
gsrm_feature2 = wrap_encoder(
|
||||
src_vocab_size=self.char_num + 1,
|
||||
max_length=self.max_length,
|
||||
n_layer=self.num_decoder_TUs,
|
||||
n_head=self.num_heads,
|
||||
d_key=int(self.hidden_dims / self.num_heads),
|
||||
d_value=int(self.hidden_dims / self.num_heads),
|
||||
d_model=self.hidden_dims,
|
||||
d_inner_hid=self.hidden_dims,
|
||||
prepostprocess_dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
relu_dropout=0.1,
|
||||
preprocess_cmd="n",
|
||||
postprocess_cmd="da",
|
||||
weight_sharing=True,
|
||||
enc_inputs=enc_inputs_2, )
|
||||
gsrm_feature2 = fluid.layers.pad(gsrm_feature2, [0, 0, 0, 1, 0, 0],
|
||||
pad_value=0.)
|
||||
gsrm_feature2 = gsrm_feature2[:, 1:, ]
|
||||
gsrm_features = gsrm_feature1 + gsrm_feature2
|
||||
|
||||
b, t, c = gsrm_features.shape
|
||||
|
||||
gsrm_out = fluid.layers.matmul(
|
||||
x=gsrm_features,
|
||||
y=fluid.default_main_program().global_block().var(
|
||||
"src_word_emb_table"),
|
||||
transpose_y=True)
|
||||
b, t, c = gsrm_out.shape
|
||||
gsrm_out = fluid.layers.softmax(input=fluid.layers.reshape(gsrm_out,
|
||||
[-1, c]))
|
||||
|
||||
return gsrm_features, word_out, gsrm_out
|
||||
|
||||
def vsfd(self, pvam_features, gsrm_features):
|
||||
|
||||
#===== Visual-Semantic Fusion Decoder Module =====
|
||||
b, t, c1 = pvam_features.shape
|
||||
b, t, c2 = gsrm_features.shape
|
||||
combine_features_ = fluid.layers.concat(
|
||||
[pvam_features, gsrm_features], axis=2)
|
||||
img_comb_features_ = fluid.layers.reshape(
|
||||
x=combine_features_, shape=[-1, c1 + c2])
|
||||
img_comb_features_map = fluid.layers.fc(input=img_comb_features_,
|
||||
size=c1,
|
||||
act="sigmoid")
|
||||
img_comb_features_map = fluid.layers.reshape(
|
||||
x=img_comb_features_map, shape=[-1, t, c1])
|
||||
combine_features = img_comb_features_map * pvam_features + (
|
||||
1.0 - img_comb_features_map) * gsrm_features
|
||||
img_comb_features = fluid.layers.reshape(
|
||||
x=combine_features, shape=[-1, c1])
|
||||
|
||||
fc_out = fluid.layers.fc(input=img_comb_features,
|
||||
size=self.char_num,
|
||||
act="softmax")
|
||||
return fc_out
|
||||
|
||||
def __call__(self, inputs, others, mode=None):
|
||||
|
||||
pvam_features = self.pvam(inputs, others)
|
||||
gsrm_features, word_out, gsrm_out = self.gsrm(pvam_features, others)
|
||||
final_out = self.vsfd(pvam_features, gsrm_features)
|
||||
|
||||
_, decoded_out = fluid.layers.topk(input=final_out, k=1)
|
||||
predicts = {
|
||||
'predict': final_out,
|
||||
'decoded_out': decoded_out,
|
||||
'word_out': word_out,
|
||||
'gsrm_out': gsrm_out
|
||||
}
|
||||
|
||||
return predicts
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,115 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle.fluid as fluid
|
||||
|
||||
|
||||
class SASTLoss(object):
|
||||
"""
|
||||
SAST Loss function
|
||||
"""
|
||||
|
||||
def __init__(self, params=None):
|
||||
super(SASTLoss, self).__init__()
|
||||
|
||||
def __call__(self, predicts, labels):
|
||||
"""
|
||||
tcl_pos: N x 128 x 3
|
||||
tcl_mask: N x 128 x 1
|
||||
tcl_label: N x X list or LoDTensor
|
||||
"""
|
||||
|
||||
f_score = predicts['f_score']
|
||||
f_border = predicts['f_border']
|
||||
f_tvo = predicts['f_tvo']
|
||||
f_tco = predicts['f_tco']
|
||||
|
||||
l_score = labels['input_score']
|
||||
l_border = labels['input_border']
|
||||
l_mask = labels['input_mask']
|
||||
l_tvo = labels['input_tvo']
|
||||
l_tco = labels['input_tco']
|
||||
|
||||
#score_loss
|
||||
intersection = fluid.layers.reduce_sum(f_score * l_score * l_mask)
|
||||
union = fluid.layers.reduce_sum(f_score * l_mask) + fluid.layers.reduce_sum(l_score * l_mask)
|
||||
score_loss = 1.0 - 2 * intersection / (union + 1e-5)
|
||||
|
||||
#border loss
|
||||
l_border_split, l_border_norm = fluid.layers.split(l_border, num_or_sections=[4, 1], dim=1)
|
||||
f_border_split = f_border
|
||||
l_border_norm_split = fluid.layers.expand(x=l_border_norm, expand_times=[1, 4, 1, 1])
|
||||
l_border_score = fluid.layers.expand(x=l_score, expand_times=[1, 4, 1, 1])
|
||||
l_border_mask = fluid.layers.expand(x=l_mask, expand_times=[1, 4, 1, 1])
|
||||
border_diff = l_border_split - f_border_split
|
||||
abs_border_diff = fluid.layers.abs(border_diff)
|
||||
border_sign = abs_border_diff < 1.0
|
||||
border_sign = fluid.layers.cast(border_sign, dtype='float32')
|
||||
border_sign.stop_gradient = True
|
||||
border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
|
||||
(abs_border_diff - 0.5) * (1.0 - border_sign)
|
||||
border_out_loss = l_border_norm_split * border_in_loss
|
||||
border_loss = fluid.layers.reduce_sum(border_out_loss * l_border_score * l_border_mask) / \
|
||||
(fluid.layers.reduce_sum(l_border_score * l_border_mask) + 1e-5)
|
||||
|
||||
#tvo_loss
|
||||
l_tvo_split, l_tvo_norm = fluid.layers.split(l_tvo, num_or_sections=[8, 1], dim=1)
|
||||
f_tvo_split = f_tvo
|
||||
l_tvo_norm_split = fluid.layers.expand(x=l_tvo_norm, expand_times=[1, 8, 1, 1])
|
||||
l_tvo_score = fluid.layers.expand(x=l_score, expand_times=[1, 8, 1, 1])
|
||||
l_tvo_mask = fluid.layers.expand(x=l_mask, expand_times=[1, 8, 1, 1])
|
||||
#
|
||||
tvo_geo_diff = l_tvo_split - f_tvo_split
|
||||
abs_tvo_geo_diff = fluid.layers.abs(tvo_geo_diff)
|
||||
tvo_sign = abs_tvo_geo_diff < 1.0
|
||||
tvo_sign = fluid.layers.cast(tvo_sign, dtype='float32')
|
||||
tvo_sign.stop_gradient = True
|
||||
tvo_in_loss = 0.5 * abs_tvo_geo_diff * abs_tvo_geo_diff * tvo_sign + \
|
||||
(abs_tvo_geo_diff - 0.5) * (1.0 - tvo_sign)
|
||||
tvo_out_loss = l_tvo_norm_split * tvo_in_loss
|
||||
tvo_loss = fluid.layers.reduce_sum(tvo_out_loss * l_tvo_score * l_tvo_mask) / \
|
||||
(fluid.layers.reduce_sum(l_tvo_score * l_tvo_mask) + 1e-5)
|
||||
|
||||
#tco_loss
|
||||
l_tco_split, l_tco_norm = fluid.layers.split(l_tco, num_or_sections=[2, 1], dim=1)
|
||||
f_tco_split = f_tco
|
||||
l_tco_norm_split = fluid.layers.expand(x=l_tco_norm, expand_times=[1, 2, 1, 1])
|
||||
l_tco_score = fluid.layers.expand(x=l_score, expand_times=[1, 2, 1, 1])
|
||||
l_tco_mask = fluid.layers.expand(x=l_mask, expand_times=[1, 2, 1, 1])
|
||||
#
|
||||
tco_geo_diff = l_tco_split - f_tco_split
|
||||
abs_tco_geo_diff = fluid.layers.abs(tco_geo_diff)
|
||||
tco_sign = abs_tco_geo_diff < 1.0
|
||||
tco_sign = fluid.layers.cast(tco_sign, dtype='float32')
|
||||
tco_sign.stop_gradient = True
|
||||
tco_in_loss = 0.5 * abs_tco_geo_diff * abs_tco_geo_diff * tco_sign + \
|
||||
(abs_tco_geo_diff - 0.5) * (1.0 - tco_sign)
|
||||
tco_out_loss = l_tco_norm_split * tco_in_loss
|
||||
tco_loss = fluid.layers.reduce_sum(tco_out_loss * l_tco_score * l_tco_mask) / \
|
||||
(fluid.layers.reduce_sum(l_tco_score * l_tco_mask) + 1e-5)
|
||||
|
||||
|
||||
# total loss
|
||||
tvo_lw, tco_lw = 1.5, 1.5
|
||||
score_lw, border_lw = 1.0, 1.0
|
||||
total_loss = score_loss * score_lw + border_loss * border_lw + \
|
||||
tvo_loss * tvo_lw + tco_loss * tco_lw
|
||||
|
||||
losses = {'total_loss':total_loss, "score_loss":score_loss,\
|
||||
"border_loss":border_loss, 'tvo_loss':tvo_loss, 'tco_loss':tco_loss}
|
||||
return losses
|
|
@ -0,0 +1,55 @@
|
|||
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
|
||||
|
||||
class SRNLoss(object):
|
||||
def __init__(self, params):
|
||||
super(SRNLoss, self).__init__()
|
||||
self.char_num = params['char_num']
|
||||
|
||||
def __call__(self, predicts, others):
|
||||
predict = predicts['predict']
|
||||
word_predict = predicts['word_out']
|
||||
gsrm_predict = predicts['gsrm_out']
|
||||
label = others['label']
|
||||
lbl_weight = others['lbl_weight']
|
||||
|
||||
casted_label = fluid.layers.cast(x=label, dtype='int64')
|
||||
cost_word = fluid.layers.cross_entropy(
|
||||
input=word_predict, label=casted_label)
|
||||
cost_gsrm = fluid.layers.cross_entropy(
|
||||
input=gsrm_predict, label=casted_label)
|
||||
cost_vsfd = fluid.layers.cross_entropy(
|
||||
input=predict, label=casted_label)
|
||||
|
||||
cost_word = fluid.layers.reshape(
|
||||
x=fluid.layers.reduce_sum(cost_word), shape=[1])
|
||||
cost_gsrm = fluid.layers.reshape(
|
||||
x=fluid.layers.reduce_sum(cost_gsrm), shape=[1])
|
||||
cost_vsfd = fluid.layers.reshape(
|
||||
x=fluid.layers.reduce_sum(cost_vsfd), shape=[1])
|
||||
|
||||
sum_cost = fluid.layers.sum(
|
||||
[cost_word, cost_vsfd * 2.0, cost_gsrm * 0.15])
|
||||
|
||||
return [sum_cost, cost_vsfd, cost_word]
|
|
@ -14,14 +14,50 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import math
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.regularizer import L2Decay
|
||||
from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
|
||||
import paddle.fluid.layers.ops as ops
|
||||
|
||||
from ppocr.utils.utility import initial_logger
|
||||
|
||||
logger = initial_logger()
|
||||
|
||||
|
||||
def cosine_decay_with_warmup(learning_rate,
|
||||
step_each_epoch,
|
||||
epochs=500,
|
||||
warmup_minibatch=1000):
|
||||
"""Applies cosine decay to the learning rate.
|
||||
lr = 0.05 * (math.cos(epoch * (math.pi / 120)) + 1)
|
||||
decrease lr for every mini-batch and start with warmup.
|
||||
"""
|
||||
global_step = _decay_step_counter()
|
||||
lr = fluid.layers.tensor.create_global_var(
|
||||
shape=[1],
|
||||
value=0.0,
|
||||
dtype='float32',
|
||||
persistable=True,
|
||||
name="learning_rate")
|
||||
|
||||
warmup_minibatch = fluid.layers.fill_constant(
|
||||
shape=[1],
|
||||
dtype='float32',
|
||||
value=float(warmup_minibatch),
|
||||
force_cpu=True)
|
||||
|
||||
with fluid.layers.control_flow.Switch() as switch:
|
||||
with switch.case(global_step < warmup_minibatch):
|
||||
decayed_lr = learning_rate * (1.0 * global_step / warmup_minibatch)
|
||||
fluid.layers.tensor.assign(input=decayed_lr, output=lr)
|
||||
with switch.default():
|
||||
decayed_lr = learning_rate * \
|
||||
(ops.cos((global_step - warmup_minibatch) * (math.pi / (epochs * step_each_epoch))) + 1)/2
|
||||
fluid.layers.tensor.assign(input=decayed_lr, output=lr)
|
||||
return lr
|
||||
|
||||
|
||||
def AdamDecay(params, parameter_list=None):
|
||||
"""
|
||||
define optimizer function
|
||||
|
@ -36,17 +72,39 @@ def AdamDecay(params, parameter_list=None):
|
|||
l2_decay = params.get("l2_decay", 0.0)
|
||||
|
||||
if 'decay' in params:
|
||||
supported_decay_mode = [
|
||||
"cosine_decay", "cosine_decay_warmup", "piecewise_decay"
|
||||
]
|
||||
params = params['decay']
|
||||
decay_mode = params['function']
|
||||
step_each_epoch = params['step_each_epoch']
|
||||
total_epoch = params['total_epoch']
|
||||
assert decay_mode in supported_decay_mode, "Supported decay mode is {}, but got {}".format(
|
||||
supported_decay_mode, decay_mode)
|
||||
|
||||
if decay_mode == "cosine_decay":
|
||||
step_each_epoch = params['step_each_epoch']
|
||||
total_epoch = params['total_epoch']
|
||||
base_lr = fluid.layers.cosine_decay(
|
||||
learning_rate=base_lr,
|
||||
step_each_epoch=step_each_epoch,
|
||||
epochs=total_epoch)
|
||||
else:
|
||||
logger.info("Only support Cosine decay currently")
|
||||
elif decay_mode == "cosine_decay_warmup":
|
||||
step_each_epoch = params['step_each_epoch']
|
||||
total_epoch = params['total_epoch']
|
||||
warmup_minibatch = params.get("warmup_minibatch", 1000)
|
||||
base_lr = cosine_decay_with_warmup(
|
||||
learning_rate=base_lr,
|
||||
step_each_epoch=step_each_epoch,
|
||||
epochs=total_epoch,
|
||||
warmup_minibatch=warmup_minibatch)
|
||||
elif decay_mode == "piecewise_decay":
|
||||
boundaries = params["boundaries"]
|
||||
decay_rate = params["decay_rate"]
|
||||
values = [
|
||||
base_lr * decay_rate**idx
|
||||
for idx in range(len(boundaries) + 1)
|
||||
]
|
||||
base_lr = fluid.layers.piecewise_decay(boundaries, values)
|
||||
|
||||
optimizer = fluid.optimizer.Adam(
|
||||
learning_rate=base_lr,
|
||||
beta1=beta1,
|
||||
|
@ -54,3 +112,44 @@ def AdamDecay(params, parameter_list=None):
|
|||
regularization=L2Decay(regularization_coeff=l2_decay),
|
||||
parameter_list=parameter_list)
|
||||
return optimizer
|
||||
|
||||
|
||||
def RMSProp(params, parameter_list=None):
|
||||
"""
|
||||
define optimizer function
|
||||
args:
|
||||
params(dict): the super parameters
|
||||
parameter_list (list): list of Variable names to update to minimize loss
|
||||
return:
|
||||
"""
|
||||
base_lr = params.get("base_lr", 0.001)
|
||||
l2_decay = params.get("l2_decay", 0.00005)
|
||||
|
||||
if 'decay' in params:
|
||||
supported_decay_mode = ["cosine_decay", "piecewise_decay"]
|
||||
params = params['decay']
|
||||
decay_mode = params['function']
|
||||
assert decay_mode in supported_decay_mode, "Supported decay mode is {}, but got {}".format(
|
||||
supported_decay_mode, decay_mode)
|
||||
|
||||
if decay_mode == "cosine_decay":
|
||||
step_each_epoch = params['step_each_epoch']
|
||||
total_epoch = params['total_epoch']
|
||||
base_lr = fluid.layers.cosine_decay(
|
||||
learning_rate=base_lr,
|
||||
step_each_epoch=step_each_epoch,
|
||||
epochs=total_epoch)
|
||||
elif decay_mode == "piecewise_decay":
|
||||
boundaries = params["boundaries"]
|
||||
decay_rate = params["decay_rate"]
|
||||
values = [
|
||||
base_lr * decay_rate**idx
|
||||
for idx in range(len(boundaries) + 1)
|
||||
]
|
||||
base_lr = fluid.layers.piecewise_decay(boundaries, values)
|
||||
|
||||
optimizer = fluid.optimizer.RMSProp(
|
||||
learning_rate=base_lr,
|
||||
regularization=fluid.regularizer.L2Decay(regularization_coeff=l2_decay))
|
||||
|
||||
return optimizer
|
||||
|
|
|
@ -22,9 +22,9 @@ import cv2
|
|||
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
|
||||
class EASTPostPocess(object):
|
||||
|
|
|
@ -25,7 +25,7 @@ import ycm_core
|
|||
# These are the compilation flags that will be used in case there's no
|
||||
# compilation database set (by default, one is not set).
|
||||
# CHANGE THIS LIST OF FLAGS. YES, THIS IS THE DROID YOU HAVE BEEN LOOKING FOR.
|
||||
sys.path.append(os.path.dirname(__file__))
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
|
||||
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
|
||||
|
|
|
@ -0,0 +1,289 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
|
||||
import numpy as np
|
||||
from .locality_aware_nms import nms_locality
|
||||
# import lanms
|
||||
import cv2
|
||||
import time
|
||||
|
||||
|
||||
class SASTPostProcess(object):
|
||||
"""
|
||||
The post process for SAST.
|
||||
"""
|
||||
|
||||
def __init__(self, params):
|
||||
self.score_thresh = params.get('score_thresh', 0.5)
|
||||
self.nms_thresh = params.get('nms_thresh', 0.2)
|
||||
self.sample_pts_num = params.get('sample_pts_num', 2)
|
||||
self.shrink_ratio_of_width = params.get('shrink_ratio_of_width', 0.3)
|
||||
self.expand_scale = params.get('expand_scale', 1.0)
|
||||
self.tcl_map_thresh = 0.5
|
||||
|
||||
# c++ la-nms is faster, but only support python 3.5
|
||||
self.is_python35 = False
|
||||
if sys.version_info.major == 3 and sys.version_info.minor == 5:
|
||||
self.is_python35 = True
|
||||
|
||||
def point_pair2poly(self, point_pair_list):
|
||||
"""
|
||||
Transfer vertical point_pairs into poly point in clockwise.
|
||||
"""
|
||||
# constract poly
|
||||
point_num = len(point_pair_list) * 2
|
||||
point_list = [0] * point_num
|
||||
for idx, point_pair in enumerate(point_pair_list):
|
||||
point_list[idx] = point_pair[0]
|
||||
point_list[point_num - 1 - idx] = point_pair[1]
|
||||
return np.array(point_list).reshape(-1, 2)
|
||||
|
||||
def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
|
||||
"""
|
||||
Generate shrink_quad_along_width.
|
||||
"""
|
||||
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
||||
|
||||
def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
|
||||
"""
|
||||
expand poly along width.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
|
||||
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
|
||||
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
|
||||
left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
|
||||
right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1],
|
||||
poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32)
|
||||
right_ratio = 1.0 + \
|
||||
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
|
||||
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
|
||||
right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
|
||||
poly[0] = left_quad_expand[0]
|
||||
poly[-1] = left_quad_expand[-1]
|
||||
poly[point_num // 2 - 1] = right_quad_expand[1]
|
||||
poly[point_num // 2] = right_quad_expand[2]
|
||||
return poly
|
||||
|
||||
def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
|
||||
"""Restore quad."""
|
||||
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
|
||||
xy_text = xy_text[:, ::-1] # (n, 2)
|
||||
|
||||
# Sort the text boxes via the y axis
|
||||
xy_text = xy_text[np.argsort(xy_text[:, 1])]
|
||||
|
||||
scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
|
||||
scores = scores[:, np.newaxis]
|
||||
|
||||
# Restore
|
||||
point_num = int(tvo_map.shape[-1] / 2)
|
||||
assert point_num == 4
|
||||
tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
|
||||
xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
|
||||
quads = xy_text_tile - tvo_map
|
||||
|
||||
return scores, quads, xy_text
|
||||
|
||||
def quad_area(self, quad):
|
||||
"""
|
||||
compute area of a quad.
|
||||
"""
|
||||
edge = [
|
||||
(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
|
||||
(quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
|
||||
(quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
|
||||
(quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])
|
||||
]
|
||||
return np.sum(edge) / 2.
|
||||
|
||||
def nms(self, dets):
|
||||
if self.is_python35:
|
||||
import lanms
|
||||
dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh)
|
||||
else:
|
||||
dets = nms_locality(dets, self.nms_thresh)
|
||||
return dets
|
||||
|
||||
def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
|
||||
"""
|
||||
Cluster pixels in tcl_map based on quads.
|
||||
"""
|
||||
instance_count = quads.shape[0] + 1 # contain background
|
||||
instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
|
||||
if instance_count == 1:
|
||||
return instance_count, instance_label_map
|
||||
|
||||
# predict text center
|
||||
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
|
||||
n = xy_text.shape[0]
|
||||
xy_text = xy_text[:, ::-1] # (n, 2)
|
||||
tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
|
||||
pred_tc = xy_text - tco
|
||||
|
||||
# get gt text center
|
||||
m = quads.shape[0]
|
||||
gt_tc = np.mean(quads, axis=1) # (m, 2)
|
||||
|
||||
pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2)
|
||||
gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
|
||||
dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
|
||||
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
|
||||
|
||||
instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
|
||||
return instance_count, instance_label_map
|
||||
|
||||
def estimate_sample_pts_num(self, quad, xy_text):
|
||||
"""
|
||||
Estimate sample points number.
|
||||
"""
|
||||
eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0
|
||||
ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
|
||||
|
||||
dense_sample_pts_num = max(2, int(ew))
|
||||
dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num,
|
||||
endpoint=True, dtype=np.float32).astype(np.int32)]
|
||||
|
||||
dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
|
||||
estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))
|
||||
|
||||
sample_pts_num = max(2, int(estimate_arc_len / eh))
|
||||
return sample_pts_num
|
||||
|
||||
def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h,
|
||||
shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0):
|
||||
"""
|
||||
first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
|
||||
"""
|
||||
# restore quad
|
||||
scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
|
||||
dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
|
||||
dets = self.nms(dets)
|
||||
if dets.shape[0] == 0:
|
||||
return []
|
||||
quads = dets[:, :-1].reshape(-1, 4, 2)
|
||||
|
||||
# Compute quad area
|
||||
quad_areas = []
|
||||
for quad in quads:
|
||||
quad_areas.append(-self.quad_area(quad))
|
||||
|
||||
# instance segmentation
|
||||
# instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
|
||||
instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map)
|
||||
|
||||
# restore single poly with tcl instance.
|
||||
poly_list = []
|
||||
for instance_idx in range(1, instance_count):
|
||||
xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1]
|
||||
quad = quads[instance_idx - 1]
|
||||
q_area = quad_areas[instance_idx - 1]
|
||||
if q_area < 5:
|
||||
continue
|
||||
|
||||
#
|
||||
len1 = float(np.linalg.norm(quad[0] -quad[1]))
|
||||
len2 = float(np.linalg.norm(quad[1] -quad[2]))
|
||||
min_len = min(len1, len2)
|
||||
if min_len < 3:
|
||||
continue
|
||||
|
||||
# filter small CC
|
||||
if xy_text.shape[0] <= 0:
|
||||
continue
|
||||
|
||||
# filter low confidence instance
|
||||
xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
|
||||
if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
|
||||
# if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
|
||||
continue
|
||||
|
||||
# sort xy_text
|
||||
left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0,
|
||||
(quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
|
||||
right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0,
|
||||
(quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
|
||||
proj_unit_vec = (right_center_pt - left_center_pt) / \
|
||||
(np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
|
||||
proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
|
||||
xy_text = xy_text[np.argsort(proj_value)]
|
||||
|
||||
# Sample pts in tcl map
|
||||
if self.sample_pts_num == 0:
|
||||
sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
|
||||
else:
|
||||
sample_pts_num = self.sample_pts_num
|
||||
xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num,
|
||||
endpoint=True, dtype=np.float32).astype(np.int32)]
|
||||
|
||||
point_pair_list = []
|
||||
for x, y in xy_center_line:
|
||||
# get corresponding offset
|
||||
offset = tbo_map[y, x, :].reshape(2, 2)
|
||||
if offset_expand != 1.0:
|
||||
offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
|
||||
expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0)
|
||||
offset_detal = offset / offset_length * expand_length
|
||||
offset = offset + offset_detal
|
||||
# original point
|
||||
ori_yx = np.array([y, x], dtype=np.float32)
|
||||
point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2)
|
||||
point_pair_list.append(point_pair)
|
||||
|
||||
# ndarry: (x, 2), expand poly along width
|
||||
detected_poly = self.point_pair2poly(point_pair_list)
|
||||
detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width)
|
||||
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
|
||||
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
|
||||
poly_list.append(detected_poly)
|
||||
|
||||
return poly_list
|
||||
|
||||
def __call__(self, outs_dict, ratio_list):
|
||||
score_list = outs_dict['f_score']
|
||||
border_list = outs_dict['f_border']
|
||||
tvo_list = outs_dict['f_tvo']
|
||||
tco_list = outs_dict['f_tco']
|
||||
|
||||
img_num = len(ratio_list)
|
||||
poly_lists = []
|
||||
for ino in range(img_num):
|
||||
p_score = score_list[ino].transpose((1,2,0))
|
||||
p_border = border_list[ino].transpose((1,2,0))
|
||||
p_tvo = tvo_list[ino].transpose((1,2,0))
|
||||
p_tco = tco_list[ino].transpose((1,2,0))
|
||||
# print(p_score.shape, p_border.shape, p_tvo.shape, p_tco.shape)
|
||||
ratio_h, ratio_w, src_h, src_w = ratio_list[ino]
|
||||
|
||||
poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
|
||||
shrink_ratio_of_width=self.shrink_ratio_of_width,
|
||||
tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale)
|
||||
|
||||
poly_lists.append(poly_list)
|
||||
|
||||
return poly_lists
|
||||
|
|
@ -25,6 +25,9 @@ class CharacterOps(object):
|
|||
def __init__(self, config):
|
||||
self.character_type = config['character_type']
|
||||
self.loss_type = config['loss_type']
|
||||
self.max_text_len = config['max_text_length']
|
||||
if self.loss_type == "srn" and self.character_type != "en":
|
||||
raise Exception("SRN can only support in character_type == en")
|
||||
if self.character_type == "en":
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
|
@ -54,6 +57,8 @@ class CharacterOps(object):
|
|||
self.end_str = "eos"
|
||||
if self.loss_type == "attention":
|
||||
dict_character = [self.beg_str, self.end_str] + dict_character
|
||||
elif self.loss_type == "srn":
|
||||
dict_character = dict_character + [self.beg_str, self.end_str]
|
||||
self.dict = {}
|
||||
for i, char in enumerate(dict_character):
|
||||
self.dict[char] = i
|
||||
|
@ -147,6 +152,39 @@ def cal_predicts_accuracy(char_ops,
|
|||
return acc, acc_num, img_num
|
||||
|
||||
|
||||
def cal_predicts_accuracy_srn(char_ops,
|
||||
preds,
|
||||
labels,
|
||||
max_text_len,
|
||||
is_debug=False):
|
||||
acc_num = 0
|
||||
img_num = 0
|
||||
|
||||
total_len = preds.shape[0]
|
||||
img_num = int(total_len / max_text_len)
|
||||
for i in range(img_num):
|
||||
cur_label = []
|
||||
cur_pred = []
|
||||
for j in range(max_text_len):
|
||||
if labels[j + i * max_text_len] != 37: #0
|
||||
cur_label.append(labels[j + i * max_text_len][0])
|
||||
else:
|
||||
break
|
||||
|
||||
for j in range(max_text_len + 1):
|
||||
if j < len(cur_label) and preds[j + i * max_text_len][
|
||||
0] != cur_label[j]:
|
||||
break
|
||||
elif j == len(cur_label) and j == max_text_len:
|
||||
acc_num += 1
|
||||
break
|
||||
elif j == len(cur_label) and preds[j + i * max_text_len][0] == 37:
|
||||
acc_num += 1
|
||||
break
|
||||
acc = acc_num * 1.0 / img_num
|
||||
return acc, acc_num, img_num
|
||||
|
||||
|
||||
def convert_rec_attention_infer_res(preds):
|
||||
img_num = preds.shape[0]
|
||||
target_lod = [0]
|
||||
|
|
|
@ -114,15 +114,15 @@ def init_model(config, program, exe):
|
|||
fluid.load(program, path, exe)
|
||||
logger.info("Finish initing model from {}".format(path))
|
||||
else:
|
||||
raise ValueError(
|
||||
"Model checkpoints {} does not exists,"
|
||||
"check if you lost the file prefix.".format(checkpoints + '.pdparams'))
|
||||
|
||||
pretrain_weights = config['Global'].get('pretrain_weights')
|
||||
if pretrain_weights:
|
||||
path = pretrain_weights
|
||||
load_params(exe, program, path)
|
||||
logger.info("Finish initing model from {}".format(path))
|
||||
raise ValueError("Model checkpoints {} does not exists,"
|
||||
"check if you lost the file prefix.".format(
|
||||
checkpoints + '.pdparams'))
|
||||
else:
|
||||
pretrain_weights = config['Global'].get('pretrain_weights')
|
||||
if pretrain_weights:
|
||||
path = pretrain_weights
|
||||
load_params(exe, program, path)
|
||||
logger.info("Finish initing model from {}".format(path))
|
||||
|
||||
|
||||
def save_model(program, model_path):
|
||||
|
|
|
@ -18,9 +18,9 @@ from __future__ import print_function
|
|||
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
|
||||
def set_paddle_flags(**kwargs):
|
||||
|
|
|
@ -88,8 +88,8 @@ class DetectionIoUEvaluator(object):
|
|||
points = gt[n]['points']
|
||||
# transcription = gt[n]['text']
|
||||
dontCare = gt[n]['ignore']
|
||||
points = Polygon(points)
|
||||
points = points.buffer(0)
|
||||
# points = Polygon(points)
|
||||
# points = points.buffer(0)
|
||||
if not Polygon(points).is_valid or not Polygon(points).is_simple:
|
||||
continue
|
||||
|
||||
|
@ -105,8 +105,8 @@ class DetectionIoUEvaluator(object):
|
|||
|
||||
for n in range(len(pred)):
|
||||
points = pred[n]['points']
|
||||
points = Polygon(points)
|
||||
points = points.buffer(0)
|
||||
# points = Polygon(points)
|
||||
# points = points.buffer(0)
|
||||
if not Polygon(points).is_valid or not Polygon(points).is_simple:
|
||||
continue
|
||||
|
||||
|
|
|
@ -29,7 +29,7 @@ FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
|
|||
logging.basicConfig(level=logging.INFO, format=FORMAT)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from ppocr.utils.character import cal_predicts_accuracy
|
||||
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn
|
||||
from ppocr.utils.character import convert_rec_label_to_lod
|
||||
from ppocr.utils.character import convert_rec_attention_infer_res
|
||||
from ppocr.utils.utility import create_module
|
||||
|
@ -60,22 +60,60 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
|
|||
for ino in range(img_num):
|
||||
img_list.append(data[ino][0])
|
||||
label_list.append(data[ino][1])
|
||||
img_list = np.concatenate(img_list, axis=0)
|
||||
outs = exe.run(eval_info_dict['program'], \
|
||||
|
||||
if config['Global']['loss_type'] != "srn":
|
||||
img_list = np.concatenate(img_list, axis=0)
|
||||
outs = exe.run(eval_info_dict['program'], \
|
||||
feed={'image': img_list}, \
|
||||
fetch_list=eval_info_dict['fetch_varname_list'], \
|
||||
return_numpy=False)
|
||||
preds = np.array(outs[0])
|
||||
if preds.shape[1] != 1:
|
||||
preds, preds_lod = convert_rec_attention_infer_res(preds)
|
||||
preds = np.array(outs[0])
|
||||
|
||||
if config['Global']['loss_type'] == "attention":
|
||||
preds, preds_lod = convert_rec_attention_infer_res(preds)
|
||||
else:
|
||||
preds_lod = outs[0].lod()[0]
|
||||
labels, labels_lod = convert_rec_label_to_lod(label_list)
|
||||
acc, acc_num, sample_num = cal_predicts_accuracy(
|
||||
char_ops, preds, preds_lod, labels, labels_lod,
|
||||
is_remove_duplicate)
|
||||
else:
|
||||
preds_lod = outs[0].lod()[0]
|
||||
labels, labels_lod = convert_rec_label_to_lod(label_list)
|
||||
acc, acc_num, sample_num = cal_predicts_accuracy(
|
||||
char_ops, preds, preds_lod, labels, labels_lod, is_remove_duplicate)
|
||||
encoder_word_pos_list = []
|
||||
gsrm_word_pos_list = []
|
||||
gsrm_slf_attn_bias1_list = []
|
||||
gsrm_slf_attn_bias2_list = []
|
||||
for ino in range(img_num):
|
||||
encoder_word_pos_list.append(data[ino][2])
|
||||
gsrm_word_pos_list.append(data[ino][3])
|
||||
gsrm_slf_attn_bias1_list.append(data[ino][4])
|
||||
gsrm_slf_attn_bias2_list.append(data[ino][5])
|
||||
|
||||
img_list = np.concatenate(img_list, axis=0)
|
||||
label_list = np.concatenate(label_list, axis=0)
|
||||
encoder_word_pos_list = np.concatenate(
|
||||
encoder_word_pos_list, axis=0).astype(np.int64)
|
||||
gsrm_word_pos_list = np.concatenate(
|
||||
gsrm_word_pos_list, axis=0).astype(np.int64)
|
||||
gsrm_slf_attn_bias1_list = np.concatenate(
|
||||
gsrm_slf_attn_bias1_list, axis=0).astype(np.float32)
|
||||
gsrm_slf_attn_bias2_list = np.concatenate(
|
||||
gsrm_slf_attn_bias2_list, axis=0).astype(np.float32)
|
||||
|
||||
labels = label_list
|
||||
|
||||
outs = exe.run(eval_info_dict['program'], \
|
||||
feed={'image': img_list, 'encoder_word_pos': encoder_word_pos_list,
|
||||
'gsrm_word_pos': gsrm_word_pos_list, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1_list,
|
||||
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2_list}, \
|
||||
fetch_list=eval_info_dict['fetch_varname_list'], \
|
||||
return_numpy=False)
|
||||
preds = np.array(outs[0])
|
||||
acc, acc_num, sample_num = cal_predicts_accuracy_srn(
|
||||
char_ops, preds, labels, config['Global']['max_text_length'])
|
||||
|
||||
total_acc_num += acc_num
|
||||
total_sample_num += sample_num
|
||||
logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
|
||||
#logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
|
||||
total_batch_num += 1
|
||||
avg_acc = total_acc_num * 1.0 / total_sample_num
|
||||
metrics = {'avg_acc': avg_acc, "total_acc_num": total_acc_num, \
|
||||
|
|
|
@ -18,9 +18,9 @@ from __future__ import print_function
|
|||
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
|
||||
def set_paddle_flags(**kwargs):
|
||||
|
|
|
@ -13,19 +13,21 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '../..'))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
import tools.infer.utility as utility
|
||||
from ppocr.utils.utility import initial_logger
|
||||
logger = initial_logger()
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
import cv2
|
||||
from ppocr.data.det.sast_process import SASTProcessTest
|
||||
from ppocr.data.det.east_process import EASTProcessTest
|
||||
from ppocr.data.det.db_process import DBProcessTest
|
||||
from ppocr.postprocess.db_postprocess import DBPostProcess
|
||||
from ppocr.postprocess.east_postprocess import EASTPostPocess
|
||||
from ppocr.postprocess.sast_postprocess import SASTPostProcess
|
||||
import copy
|
||||
import numpy as np
|
||||
import math
|
||||
|
@ -52,6 +54,20 @@ class TextDetector(object):
|
|||
postprocess_params["cover_thresh"] = args.det_east_cover_thresh
|
||||
postprocess_params["nms_thresh"] = args.det_east_nms_thresh
|
||||
self.postprocess_op = EASTPostPocess(postprocess_params)
|
||||
elif self.det_algorithm == "SAST":
|
||||
self.preprocess_op = SASTProcessTest(preprocess_params)
|
||||
postprocess_params["score_thresh"] = args.det_sast_score_thresh
|
||||
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
|
||||
self.det_sast_polygon = args.det_sast_polygon
|
||||
if self.det_sast_polygon:
|
||||
postprocess_params["sample_pts_num"] = 6
|
||||
postprocess_params["expand_scale"] = 1.2
|
||||
postprocess_params["shrink_ratio_of_width"] = 0.2
|
||||
else:
|
||||
postprocess_params["sample_pts_num"] = 2
|
||||
postprocess_params["expand_scale"] = 1.0
|
||||
postprocess_params["shrink_ratio_of_width"] = 0.3
|
||||
self.postprocess_op = SASTPostProcess(postprocess_params)
|
||||
else:
|
||||
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
|
||||
sys.exit(0)
|
||||
|
@ -84,7 +100,7 @@ class TextDetector(object):
|
|||
return rect
|
||||
|
||||
def clip_det_res(self, points, img_height, img_width):
|
||||
for pno in range(4):
|
||||
for pno in range(points.shape[0]):
|
||||
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
|
||||
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
|
||||
return points
|
||||
|
@ -103,6 +119,15 @@ class TextDetector(object):
|
|||
dt_boxes = np.array(dt_boxes_new)
|
||||
return dt_boxes
|
||||
|
||||
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
|
||||
img_height, img_width = image_shape[0:2]
|
||||
dt_boxes_new = []
|
||||
for box in dt_boxes:
|
||||
box = self.clip_det_res(box, img_height, img_width)
|
||||
dt_boxes_new.append(box)
|
||||
dt_boxes = np.array(dt_boxes_new)
|
||||
return dt_boxes
|
||||
|
||||
def __call__(self, img):
|
||||
ori_im = img.copy()
|
||||
im, ratio_list = self.preprocess_op(img)
|
||||
|
@ -120,11 +145,20 @@ class TextDetector(object):
|
|||
if self.det_algorithm == "EAST":
|
||||
outs_dict['f_geo'] = outputs[0]
|
||||
outs_dict['f_score'] = outputs[1]
|
||||
elif self.det_algorithm == 'SAST':
|
||||
outs_dict['f_border'] = outputs[0]
|
||||
outs_dict['f_score'] = outputs[1]
|
||||
outs_dict['f_tco'] = outputs[2]
|
||||
outs_dict['f_tvo'] = outputs[3]
|
||||
else:
|
||||
outs_dict['maps'] = outputs[0]
|
||||
|
||||
dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list])
|
||||
dt_boxes = dt_boxes_list[0]
|
||||
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
|
||||
if self.det_algorithm == "SAST" and self.det_sast_polygon:
|
||||
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
|
||||
else:
|
||||
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
|
||||
elapse = time.time() - starttime
|
||||
return dt_boxes, elapse
|
||||
|
||||
|
|
|
@ -40,7 +40,8 @@ class TextRecognizer(object):
|
|||
char_ops_params = {
|
||||
"character_type": args.rec_char_type,
|
||||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char
|
||||
"use_space_char": args.use_space_char,
|
||||
"max_text_length": args.max_text_length
|
||||
}
|
||||
if self.rec_algorithm != "RARE":
|
||||
char_ops_params['loss_type'] = 'ctc'
|
||||
|
@ -122,9 +123,9 @@ class TextRecognizer(object):
|
|||
ind = np.argmax(probs, axis=1)
|
||||
blank = probs.shape[1]
|
||||
valid_ind = np.where(ind != (blank - 1))[0]
|
||||
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
||||
if len(valid_ind) == 0:
|
||||
continue
|
||||
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
||||
# rec_res.append([preds_text, score])
|
||||
rec_res[indices[beg_img_no + rno]] = [preds_text, score]
|
||||
else:
|
||||
|
|
|
@ -53,12 +53,18 @@ def parse_args():
|
|||
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
|
||||
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
|
||||
|
||||
#SAST parmas
|
||||
parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
|
||||
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
|
||||
parser.add_argument("--det_sast_polygon", type=bool, default=False)
|
||||
|
||||
#params for text recognizer
|
||||
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
|
||||
parser.add_argument("--rec_model_dir", type=str)
|
||||
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
|
||||
parser.add_argument("--rec_char_type", type=str, default='ch')
|
||||
parser.add_argument("--rec_batch_num", type=int, default=30)
|
||||
parser.add_argument("--max_text_length", type=int, default=25)
|
||||
parser.add_argument(
|
||||
"--rec_char_dict_path",
|
||||
type=str,
|
||||
|
@ -95,7 +101,7 @@ def create_predictor(args, mode):
|
|||
config.set_cpu_math_library_num_threads(6)
|
||||
if args.enable_mkldnn:
|
||||
config.enable_mkldnn()
|
||||
|
||||
|
||||
#config.enable_memory_optim()
|
||||
config.disable_glog_info()
|
||||
|
||||
|
@ -169,7 +175,7 @@ def draw_ocr_box_txt(image, boxes, txts):
|
|||
img_right = Image.new('RGB', (w, h), (255, 255, 255))
|
||||
|
||||
import random
|
||||
# 每次使用相同的随机种子 ,可以保证两次颜色一致
|
||||
|
||||
random.seed(0)
|
||||
draw_left = ImageDraw.Draw(img_left)
|
||||
draw_right = ImageDraw.Draw(img_right)
|
||||
|
|
|
@ -22,9 +22,9 @@ import json
|
|||
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
|
||||
def set_paddle_flags(**kwargs):
|
||||
|
@ -134,8 +134,10 @@ def main():
|
|||
dic = {'f_score': outs[0], 'f_geo': outs[1]}
|
||||
elif config['Global']['algorithm'] == 'DB':
|
||||
dic = {'maps': outs[0]}
|
||||
elif config['Global']['algorithm'] == 'SAST':
|
||||
dic = {'f_score': outs[0], 'f_border': outs[1], 'f_tvo': outs[2], 'f_tco': outs[3]}
|
||||
else:
|
||||
raise Exception("only support algorithm: ['EAST', 'DB']")
|
||||
raise Exception("only support algorithm: ['EAST', 'DB', 'SAST']")
|
||||
dt_boxes_list = postprocess(dic, ratio_list)
|
||||
for ino in range(img_num):
|
||||
dt_boxes = dt_boxes_list[ino]
|
||||
|
@ -149,7 +151,7 @@ def main():
|
|||
fout.write(otstr.encode())
|
||||
src_img = cv2.imread(img_name)
|
||||
draw_det_res(dt_boxes, config, src_img, img_name)
|
||||
|
||||
|
||||
logger.info("success!")
|
||||
|
||||
|
||||
|
|
|
@ -19,9 +19,9 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
|
||||
def set_paddle_flags(**kwargs):
|
||||
|
@ -64,7 +64,6 @@ def main():
|
|||
exe = fluid.Executor(place)
|
||||
|
||||
rec_model = create_module(config['Architecture']['function'])(params=config)
|
||||
|
||||
startup_prog = fluid.Program()
|
||||
eval_prog = fluid.Program()
|
||||
with fluid.program_guard(eval_prog, startup_prog):
|
||||
|
@ -86,10 +85,36 @@ def main():
|
|||
for i in range(max_img_num):
|
||||
logger.info("infer_img:%s" % infer_list[i])
|
||||
img = next(blobs)
|
||||
predict = exe.run(program=eval_prog,
|
||||
feed={"image": img},
|
||||
fetch_list=fetch_varname_list,
|
||||
return_numpy=False)
|
||||
if loss_type != "srn":
|
||||
predict = exe.run(program=eval_prog,
|
||||
feed={"image": img},
|
||||
fetch_list=fetch_varname_list,
|
||||
return_numpy=False)
|
||||
else:
|
||||
encoder_word_pos_list = []
|
||||
gsrm_word_pos_list = []
|
||||
gsrm_slf_attn_bias1_list = []
|
||||
gsrm_slf_attn_bias2_list = []
|
||||
encoder_word_pos_list.append(img[1])
|
||||
gsrm_word_pos_list.append(img[2])
|
||||
gsrm_slf_attn_bias1_list.append(img[3])
|
||||
gsrm_slf_attn_bias2_list.append(img[4])
|
||||
|
||||
encoder_word_pos_list = np.concatenate(
|
||||
encoder_word_pos_list, axis=0).astype(np.int64)
|
||||
gsrm_word_pos_list = np.concatenate(
|
||||
gsrm_word_pos_list, axis=0).astype(np.int64)
|
||||
gsrm_slf_attn_bias1_list = np.concatenate(
|
||||
gsrm_slf_attn_bias1_list, axis=0).astype(np.float32)
|
||||
gsrm_slf_attn_bias2_list = np.concatenate(
|
||||
gsrm_slf_attn_bias2_list, axis=0).astype(np.float32)
|
||||
|
||||
predict = exe.run(program=eval_prog, \
|
||||
feed={'image': img[0], 'encoder_word_pos': encoder_word_pos_list,
|
||||
'gsrm_word_pos': gsrm_word_pos_list, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1_list,
|
||||
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2_list}, \
|
||||
fetch_list=fetch_varname_list, \
|
||||
return_numpy=False)
|
||||
if loss_type == "ctc":
|
||||
preds = np.array(predict[0])
|
||||
preds = preds.reshape(-1)
|
||||
|
@ -114,7 +139,18 @@ def main():
|
|||
score = np.mean(probs[0, 1:end_pos[1]])
|
||||
preds = preds.reshape(-1)
|
||||
preds_text = char_ops.decode(preds)
|
||||
|
||||
elif loss_type == "srn":
|
||||
cur_pred = []
|
||||
preds = np.array(predict[0])
|
||||
preds = preds.reshape(-1)
|
||||
probs = np.array(predict[1])
|
||||
ind = np.argmax(probs, axis=1)
|
||||
valid_ind = np.where(preds != 37)[0]
|
||||
if len(valid_ind) == 0:
|
||||
continue
|
||||
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
||||
preds = preds[:valid_ind[-1] + 1]
|
||||
preds_text = char_ops.decode(preds)
|
||||
logger.info("\t index: {}".format(preds))
|
||||
logger.info("\t word : {}".format(preds_text))
|
||||
logger.info("\t score: {}".format(score))
|
||||
|
|
|
@ -32,7 +32,8 @@ from eval_utils.eval_det_utils import eval_det_run
|
|||
from eval_utils.eval_rec_utils import eval_rec_run
|
||||
from ppocr.utils.save_load import save_model
|
||||
import numpy as np
|
||||
from ppocr.utils.character import cal_predicts_accuracy, CharacterOps
|
||||
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps
|
||||
|
||||
|
||||
class ArgsParser(ArgumentParser):
|
||||
def __init__(self):
|
||||
|
@ -81,10 +82,8 @@ default_config = {'Global': {'debug': False, }}
|
|||
def load_config(file_path):
|
||||
"""
|
||||
Load config from yml/yaml file.
|
||||
|
||||
Args:
|
||||
file_path (str): Path of the config file to be loaded.
|
||||
|
||||
Returns: global config
|
||||
"""
|
||||
merge_config(default_config)
|
||||
|
@ -103,10 +102,8 @@ def load_config(file_path):
|
|||
def merge_config(config):
|
||||
"""
|
||||
Merge config into global config.
|
||||
|
||||
Args:
|
||||
config (dict): Config to be merged.
|
||||
|
||||
Returns: global config
|
||||
"""
|
||||
for key, value in config.items():
|
||||
|
@ -157,13 +154,11 @@ def build(config, main_prog, startup_prog, mode):
|
|||
3. create a model
|
||||
4. create fetchs
|
||||
5. create an optimizer
|
||||
|
||||
Args:
|
||||
config(dict): config
|
||||
main_prog(): main program
|
||||
startup_prog(): startup program
|
||||
is_train(bool): train or valid
|
||||
|
||||
Returns:
|
||||
dataloader(): a bridge between the model and the data
|
||||
fetchs(dict): dict of model outputs(included loss and measures)
|
||||
|
@ -176,8 +171,16 @@ def build(config, main_prog, startup_prog, mode):
|
|||
fetch_name_list = list(outputs.keys())
|
||||
fetch_varname_list = [outputs[v].name for v in fetch_name_list]
|
||||
opt_loss_name = None
|
||||
model_average = None
|
||||
img_loss_name = None
|
||||
word_loss_name = None
|
||||
if mode == "train":
|
||||
opt_loss = outputs['total_loss']
|
||||
# srn loss
|
||||
#img_loss = outputs['img_loss']
|
||||
#word_loss = outputs['word_loss']
|
||||
#img_loss_name = img_loss.name
|
||||
#word_loss_name = word_loss.name
|
||||
opt_params = config['Optimizer']
|
||||
optimizer = create_module(opt_params['function'])(opt_params)
|
||||
optimizer.minimize(opt_loss)
|
||||
|
@ -185,7 +188,17 @@ def build(config, main_prog, startup_prog, mode):
|
|||
global_lr = optimizer._global_learning_rate()
|
||||
fetch_name_list.insert(0, "lr")
|
||||
fetch_varname_list.insert(0, global_lr.name)
|
||||
return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name)
|
||||
if "loss_type" in config["Global"]:
|
||||
if config['Global']["loss_type"] == 'srn':
|
||||
model_average = fluid.optimizer.ModelAverage(
|
||||
config['Global']['average_window'],
|
||||
min_average_window=config['Global'][
|
||||
'min_average_window'],
|
||||
max_average_window=config['Global'][
|
||||
'max_average_window'])
|
||||
|
||||
return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name,
|
||||
model_average)
|
||||
|
||||
|
||||
def build_export(config, main_prog, startup_prog):
|
||||
|
@ -329,14 +342,20 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
|
|||
lr = np.mean(np.array(train_outs[fetch_map['lr']]))
|
||||
preds_idx = fetch_map['decoded_out']
|
||||
preds = np.array(train_outs[preds_idx])
|
||||
preds_lod = train_outs[preds_idx].lod()[0]
|
||||
labels_idx = fetch_map['label']
|
||||
labels = np.array(train_outs[labels_idx])
|
||||
labels_lod = train_outs[labels_idx].lod()[0]
|
||||
|
||||
acc, acc_num, img_num = cal_predicts_accuracy(
|
||||
config['Global']['char_ops'], preds, preds_lod, labels,
|
||||
labels_lod)
|
||||
if config['Global']['loss_type'] != 'srn':
|
||||
preds_lod = train_outs[preds_idx].lod()[0]
|
||||
labels_lod = train_outs[labels_idx].lod()[0]
|
||||
|
||||
acc, acc_num, img_num = cal_predicts_accuracy(
|
||||
config['Global']['char_ops'], preds, preds_lod, labels,
|
||||
labels_lod)
|
||||
else:
|
||||
acc, acc_num, img_num = cal_predicts_accuracy_srn(
|
||||
config['Global']['char_ops'], preds, labels,
|
||||
config['Global']['max_text_length'])
|
||||
t2 = time.time()
|
||||
train_batch_elapse = t2 - t1
|
||||
stats = {'loss': loss, 'acc': acc}
|
||||
|
@ -350,6 +369,9 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
|
|||
|
||||
if train_batch_id > 0 and\
|
||||
train_batch_id % eval_batch_step == 0:
|
||||
model_average = train_info_dict['model_average']
|
||||
if model_average != None:
|
||||
model_average.apply(exe)
|
||||
metrics = eval_rec_run(exe, config, eval_info_dict, "eval")
|
||||
eval_acc = metrics['avg_acc']
|
||||
eval_sample_num = metrics['total_sample_num']
|
||||
|
@ -375,6 +397,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
|
|||
save_model(train_info_dict['train_program'], save_path)
|
||||
return
|
||||
|
||||
|
||||
def preprocess():
|
||||
FLAGS = ArgsParser().parse_args()
|
||||
config = load_config(FLAGS.config)
|
||||
|
@ -386,15 +409,15 @@ def preprocess():
|
|||
check_gpu(use_gpu)
|
||||
|
||||
alg = config['Global']['algorithm']
|
||||
assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE']
|
||||
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']:
|
||||
assert alg in ['EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']
|
||||
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
|
||||
config['Global']['char_ops'] = CharacterOps(config['Global'])
|
||||
|
||||
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
|
||||
startup_program = fluid.Program()
|
||||
train_program = fluid.Program()
|
||||
|
||||
if alg in ['EAST', 'DB']:
|
||||
if alg in ['EAST', 'DB', 'SAST']:
|
||||
train_alg_type = 'det'
|
||||
else:
|
||||
train_alg_type = 'rec'
|
||||
|
|
|
@ -18,9 +18,9 @@ from __future__ import print_function
|
|||
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
|
||||
def set_paddle_flags(**kwargs):
|
||||
|
@ -52,6 +52,7 @@ def main():
|
|||
train_fetch_name_list = train_build_outputs[1]
|
||||
train_fetch_varname_list = train_build_outputs[2]
|
||||
train_opt_loss_name = train_build_outputs[3]
|
||||
model_average = train_build_outputs[-1]
|
||||
|
||||
eval_program = fluid.Program()
|
||||
eval_build_outputs = program.build(
|
||||
|
@ -85,7 +86,8 @@ def main():
|
|||
'train_program':train_program,\
|
||||
'reader':train_loader,\
|
||||
'fetch_name_list':train_fetch_name_list,\
|
||||
'fetch_varname_list':train_fetch_varname_list}
|
||||
'fetch_varname_list':train_fetch_varname_list,\
|
||||
'model_average': model_average}
|
||||
|
||||
eval_info_dict = {'program':eval_program,\
|
||||
'reader':eval_reader,\
|
||||
|
|
Loading…
Reference in New Issue