commit
9198059544
|
@ -59,8 +59,10 @@ Optimizer:
|
||||||
PostProcess:
|
PostProcess:
|
||||||
name: PGPostProcess
|
name: PGPostProcess
|
||||||
score_thresh: 0.5
|
score_thresh: 0.5
|
||||||
|
mode: fast # post-processing mode: fast or slow
|
||||||
Metric:
|
Metric:
|
||||||
name: E2EMetric
|
name: E2EMetric
|
||||||
|
gt_mat_dir: # directory containing the ground-truth .mat files
|
||||||
character_dict_path: ppocr/utils/ic15_dict.txt
|
character_dict_path: ppocr/utils/ic15_dict.txt
|
||||||
main_indicator: f_score_e2e
|
main_indicator: f_score_e2e
|
||||||
|
|
||||||
|
@ -106,7 +108,7 @@ Eval:
|
||||||
order: 'hwc'
|
order: 'hwc'
|
||||||
- ToCHWImage:
|
- ToCHWImage:
|
||||||
- KeepKeys:
|
- KeepKeys:
|
||||||
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags' ]
|
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id']
|
||||||
loader:
|
loader:
|
||||||
shuffle: False
|
shuffle: False
|
||||||
drop_last: False
|
drop_last: False
|
||||||
|
|
|
@ -28,13 +28,10 @@ inference 模型(`paddle.jit.save`保存的模型)
|
||||||
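- [4. Inference with a custom text recognition dictionary](#自定义文本识别字典的推理)
|
- [4. Inference with a custom text recognition dictionary](#自定义文本识别字典的推理)
|
||||||
- [5. Multilingual model inference](#多语言模型的推理)
|
- [5. Multilingual model inference](#多语言模型的推理)
|
||||||
|
|
||||||
- [IV. End-to-end model inference](#端到端模型推理)
|
- [IV. Angle classification model inference](#方向识别模型推理)
|
||||||
- [1. PGNet end-to-end model inference](#PGNet端到端模型推理)
|
|
||||||
|
|
||||||
- [V. Angle classification model inference](#方向识别模型推理)
|
|
||||||
- [1. Angle classification model inference](#方向分类模型推理)
|
- [1. Angle classification model inference](#方向分类模型推理)
|
||||||
|
|
||||||
- [VI. Chained text detection, angle classification, and text recognition inference](#文本检测、方向分类和文字识别串联推理)
|
- [V. Chained text detection, angle classification, and text recognition inference](#文本检测、方向分类和文字识别串联推理)
|
||||||
- [1. Ultra-lightweight Chinese OCR model inference](#超轻量中文OCR模型推理)
|
- [1. Ultra-lightweight Chinese OCR model inference](#超轻量中文OCR模型推理)
|
||||||
- [2. Other model inference](#其他模型推理)
|
- [2. Other model inference](#其他模型推理)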
|
||||||
|
|
||||||
|
@ -362,38 +359,8 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" -
|
||||||
Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
|
Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
|
||||||
```
|
```
|
||||||
|
|
||||||
<a name="端到端模型推理"></a>
|
|
||||||
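## IV. End-to-end model inference
|
|
||||||
|
|
||||||
End-to-end model inference uses the PGNet configuration parameters by default. When a model other than PGNet is used, pass the corresponding parameters at inference time to adapt the algorithm; see below for details.
|
|
||||||
<a name="PGNet端到端模型推理"></a>
|
|
||||||
### 1. PGNet end-to-end model inference
|
|
||||||
#### (1). Quadrangle text detection model (ICDAR2015)
|
|
||||||
First, convert the model saved during PGNet end-to-end training into an inference model. Taking the model based on the Resnet50_vd backbone and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)), use the following command to convert it: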
|
|
||||||
```
|
|
||||||
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
|
|
||||||
```
|
|
||||||
**PGNet end-to-end model inference requires setting the parameter `--e2e_algorithm="PGNet"`**. Run the following command:
|
|
||||||
```
|
|
||||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
|
|
||||||
```
|
|
||||||
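The visualized text detection results are saved to the `./inference_results` folder by default, and the result file names are prefixed with 'e2e_res'. An example result is shown below: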
|
|
||||||
|
|
||||||
![](../imgs_results/e2e_res_img_10_pgnet.jpg)
|
|
||||||
|
|
||||||
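#### (2). Curved text detection model (Total-Text)
|
|
||||||
It shares the same inference model as the quadrangle text detection model.
|
|
||||||
**PGNet end-to-end model inference requires setting the parameter `--e2e_algorithm="PGNet"` and also adding the parameter `--e2e_pgnet_polygon=True`.** Run the following command: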
|
|
||||||
```
|
|
||||||
python3.7 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
|
|
||||||
```
|
|
||||||
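The visualized end-to-end results are saved to the `./inference_results` folder by default, and the result file names are prefixed with 'e2e_res'. An example result is shown below: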
|
|
||||||
|
|
||||||
![](../imgs_results/e2e_res_img623_pgnet.jpg)
|
|
||||||
|
|
||||||
|
|
||||||
<a name="方向分类模型推理"></a>
|
<a name="方向分类模型推理"></a>
|
||||||
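## V. Angle classification model inference
|
## IV. Angle classification model inference
|
||||||
|
|
||||||
The following describes angle classification model inference.
|
The following describes angle classification model inference.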
|
||||||
|
|
||||||
|
@ -418,7 +385,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999982]
|
||||||
```
|
```
|
||||||
|
|
||||||
<a name="文本检测、方向分类和文字识别串联推理"></a>
|
<a name="文本检测、方向分类和文字识别串联推理"></a>
|
||||||
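## VI. Chained text detection, angle classification, and text recognition inference
|
## V. Chained text detection, angle classification, and text recognition inference
|
||||||
<a name="超轻量中文OCR模型推理"></a>
|
<a name="超轻量中文OCR模型推理"></a>
|
||||||
### 1. Ultra-lightweight Chinese OCR model inference
|
### 1. Ultra-lightweight Chinese OCR model inference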
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
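- [I. Introduction](#简介)
|
- [I. Introduction](#简介)
|
||||||
- [II. Environment configuration](#环境配置)
|
- [II. Environment configuration](#环境配置)
|
||||||
- [III. Quick use](#快速使用)
|
- [III. Quick use](#快速使用)
|
||||||
- [IV. Model training, evaluation, and inference](#快速训练)
|
- [IV. Model training, evaluation, and inference](#模型训练、评估、推理)
|
||||||
|
|
||||||
<a name="简介"></a>
|
<a name="简介"></a>
|
||||||
## I. Introduction
|
## I. Introduction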
|
||||||
|
@ -16,11 +16,13 @@ OCR算法可以分为两阶段算法和端对端的算法。二阶段OCR算法
|
||||||
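- A graph-based modification module (GRM) is proposed to further improve the model's recognition performance
|
- A graph-based modification module (GRM) is proposed to further improve the model's recognition performance
|
||||||
- Higher accuracy and faster prediction speed
|
- Higher accuracy and faster prediction speed
|
||||||
|
|
||||||
For details of the PGNet algorithm, see the [paper](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf), and the schematic diagram of the algorithm is shown below:
|
For details of the PGNet algorithm, see the [paper](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf) , and the schematic diagram of the algorithm is shown below: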
|
||||||
![](../pgnet_framework.png)
|
![](../pgnet_framework.png)
|
||||||
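After feature extraction, the input image is sent to four branches: the TBO module for text border offset prediction, the TCL module for text center-line prediction, the TDO module for text direction offset prediction, and the TCC module for text character classification map prediction.
|
After feature extraction, the input image is sent to four branches: the TBO module for text border offset prediction, the TCL module for text center-line prediction, the TDO module for text direction offset prediction, and the TCC module for text character classification map prediction.
|
||||||
After post-processing, the outputs of TBO and TCL yield the text detection results, while TCL, TDO, and TCC are responsible for text recognition.
|
After post-processing, the outputs of TBO and TCL yield the text detection results, while TCL, TDO, and TCC are responsible for text recognition.
|
||||||
|
|
||||||
The detection and recognition results are shown below:
|
The detection and recognition results are shown below: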
|
||||||
|
|
||||||
![](../imgs_results/e2e_res_img293_pgnet.png)
|
![](../imgs_results/e2e_res_img293_pgnet.png)
|
||||||
![](../imgs_results/e2e_res_img295_pgnet.png)
|
![](../imgs_results/e2e_res_img295_pgnet.png)
|
||||||
|
|
||||||
|
@ -49,24 +51,24 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.
|
||||||
### Single image or image set prediction
|
### Single image or image set prediction
|
||||||
```bash
|
```bash
|
||||||
# Predict a single image specified by image_dir
|
# Predict a single image specified by image_dir
|
||||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
|
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
|
||||||
|
|
||||||
# Predict the set of images specified by image_dir
|
# Predict the set of images specified by image_dir
|
||||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
|
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
|
||||||
|
|
||||||
# To predict on CPU, set the use_gpu parameter to False
|
# To predict on CPU, set the use_gpu parameter to False
|
||||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True --use_gpu=False
|
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True --use_gpu=False
|
||||||
```
|
```
|
||||||
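### Visualization results
|
### Visualization results
|
||||||
The visualized text detection results are saved to the ./inference_results folder by default, and the result file names are prefixed with 'e2e_res'. An example result is shown below:
|
The visualized text detection results are saved to the ./inference_results folder by default, and the result file names are prefixed with 'e2e_res'. An example result is shown below: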
|
||||||
![](../imgs_results/e2e_res_img623_pgnet.jpg)
|
![](../imgs_results/e2e_res_img623_pgnet.jpg)
|
||||||
|
|
||||||
<a name="快速训练"></a>
|
<a name="模型训练、评估、推理"></a>
|
||||||
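## IV. Model training, evaluation, and inference
|
## IV. Model training, evaluation, and inference
|
||||||
This section uses the totaltext dataset as an example to introduce the training, evaluation, and testing of end-to-end models in PaddleOCR.
|
This section uses the totaltext dataset as an example to introduce the training, evaluation, and testing of end-to-end models in PaddleOCR.
|
||||||
|
|
||||||
### Data preparation
|
### Data preparation
|
||||||
Download and extract the [totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md)dataset into the PaddleOCR/train_data/ directory. The dataset is organized as follows:
|
Download and extract the [totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md) dataset into the PaddleOCR/train_data/ directory. The dataset is organized as follows: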
|
||||||
```
|
```
|
||||||
/PaddleOCR/train_data/total_text/train/
|
/PaddleOCR/train_data/total_text/train/
|
||||||
|- rgb/ # training data of the total_text dataset
|
|- rgb/ # training data of the total_text dataset
|
||||||
|
@ -135,20 +137,20 @@ python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{
|
||||||
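### Model prediction
|
### Model prediction
|
||||||
Test the end-to-end recognition result on a single image:
|
Test the end-to-end recognition result on a single image: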
|
||||||
```shell
|
```shell
|
||||||
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
|
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false
|
||||||
```
|
```
|
||||||
|
|
||||||
Test the end-to-end recognition result on all images in a folder:
|
Test the end-to-end recognition result on all images in a folder:
|
||||||
```shell
|
```shell
|
||||||
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
|
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false
|
||||||
```
|
```
|
||||||
|
|
||||||
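### Model inference
|
### Model inference
|
||||||
#### (1).Quadrangle text detection model (ICDAR2015)
|
#### (1). Quadrangle text detection model (ICDAR2015)
|
||||||
First, convert the model saved during PGNet end-to-end training into an inference model. Taking the model based on the Resnet50_vd backbone and trained on the English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)), use the following command to convert it:
|
First, convert the model saved during PGNet end-to-end training into an inference model. Taking the model based on the Resnet50_vd backbone and trained on the English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)), use the following command to convert it: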
|
||||||
```
|
```
|
||||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
|
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
|
||||||
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
|
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
|
||||||
```
|
```
|
||||||
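**PGNet end-to-end model inference requires setting the parameter `--e2e_algorithm="PGNet"`**. Run the following command:
|
**PGNet end-to-end model inference requires setting the parameter `--e2e_algorithm="PGNet"`**. Run the following command: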
|
||||||
```
|
```
|
||||||
|
@ -158,7 +160,7 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
|
||||||
|
|
||||||
![](../imgs_results/e2e_res_img_10_pgnet.jpg)
|
![](../imgs_results/e2e_res_img_10_pgnet.jpg)
|
||||||
|
|
||||||
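#### (2).Curved text detection model (Total-Text)
|
#### (2). Curved text detection model (Total-Text)
|
||||||
For curved text examples:
|
For curved text examples:
|
||||||
|
|
||||||
**PGNet end-to-end model inference requires setting the parameter `--e2e_algorithm="PGNet"` and also adding the parameter `--e2e_pgnet_polygon=True`.** Run the following command:
|
**PGNet end-to-end model inference requires setting the parameter `--e2e_algorithm="PGNet"` and also adding the parameter `--e2e_pgnet_polygon=True`.** Run the following command: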
|
||||||
|
@ -168,3 +170,10 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
|
||||||
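The visualized end-to-end results are saved to the `./inference_results` folder by default, and the result file names are prefixed with 'e2e_res'. An example result is shown below:
|
The visualized end-to-end results are saved to the `./inference_results` folder by default, and the result file names are prefixed with 'e2e_res'. An example result is shown below: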
|
||||||
|
|
||||||
![](../imgs_results/e2e_res_img623_pgnet.jpg)
|
![](../imgs_results/e2e_res_img623_pgnet.jpg)
|
||||||
|
|
||||||
|
#### (3). Performance
|
||||||
|
| |det_precision|det_recall|det_f_score|e2e_precision|e2e_recall|e2e_f_score|FPS (size=640)|
|
||||||
|
| --- | --- | --- | --- | --- | --- | --- | --- |
|
||||||
|
|Paper|85.30|86.80|86.1|-|-|61.7|38.20|
|
||||||
|
|Ours|87.03|82.48|84.69|61.71|58.43|60.03|62.61|
|
||||||
|
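*Note: the PGNet implementation in PaddleOCR is optimized for prediction speed; it significantly improves end-to-end prediction speed while keeping the accuracy drop within an acceptable range.*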
|
||||||
|
|
|
@ -15,7 +15,7 @@ In recent years, the end-to-end OCR algorithm has been well developed, including
|
||||||
- A graph based modification module (GRM) is proposed to further improve the performance of model recognition
|
- A graph based modification module (GRM) is proposed to further improve the performance of model recognition
|
||||||
- Higher accuracy and faster prediction speed
|
- Higher accuracy and faster prediction speed
|
||||||
|
|
||||||
For details of the PGNet algorithm, please refer to the [paper](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf). The schematic diagram of the algorithm is as follows:
|
For details of the PGNet algorithm, please refer to the [paper](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf) . The schematic diagram of the algorithm is as follows:
|
||||||
![](../pgnet_framework.png)
|
![](../pgnet_framework.png)
|
||||||
After feature extraction, the input image is sent to four branches: TBO module for text edge offset prediction, TCL module for text centerline prediction, TDO module for text direction offset prediction, and TCC module for text character classification graph prediction.
|
After feature extraction, the input image is sent to four branches: TBO module for text edge offset prediction, TCL module for text centerline prediction, TDO module for text direction offset prediction, and TCC module for text character classification graph prediction.
|
||||||
After post-processing, the outputs of TBO and TCL yield the text detection results, while TCL, TDO, and TCC are responsible for text recognition.
|
After post-processing, the outputs of TBO and TCL yield the text detection results, while TCL, TDO, and TCC are responsible for text recognition.
|
||||||
|
@ -49,13 +49,13 @@ After decompression, there should be the following file structure:
|
||||||
### Single image or image set prediction
|
### Single image or image set prediction
|
||||||
```bash
|
```bash
|
||||||
# Predict a single image specified by image_dir
|
# Predict a single image specified by image_dir
|
||||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
|
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
|
||||||
|
|
||||||
# Predict the set of images specified by image_dir
|
# Predict the set of images specified by image_dir
|
||||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
|
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True
|
||||||
|
|
||||||
# To predict on CPU, set the use_gpu parameter to False
|
# To predict on CPU, set the use_gpu parameter to False
|
||||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True --use_gpu=False
|
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_server_pgnetA_infer/" --e2e_pgnet_polygon=True --use_gpu=False
|
||||||
```
|
```
|
||||||
### Visualization results
|
### Visualization results
|
||||||
The visualized end-to-end results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'e2e_res'. Examples of results are as follows:
|
The visualized end-to-end results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'e2e_res'. Examples of results are as follows:
|
||||||
|
@ -141,12 +141,12 @@ python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{
|
||||||
### Model Test
|
### Model Test
|
||||||
Test the end-to-end result on a single image:
|
Test the end-to-end result on a single image:
|
||||||
```shell
|
```shell
|
||||||
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
|
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false
|
||||||
```
|
```
|
||||||
|
|
||||||
Test the end-to-end result on all images in the folder:
|
Test the end-to-end result on all images in the folder:
|
||||||
```shell
|
```shell
|
||||||
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
|
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/e2e_pgnet/best_accuracy" Global.load_static_weights=false
|
||||||
```
|
```
|
||||||
|
|
||||||
### Model inference
|
### Model inference
|
||||||
|
@ -154,7 +154,7 @@ python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img=
|
||||||
First, convert the model saved during PGNet end-to-end training into an inference model. Taking the model trained on the English dataset in the first training stage (based on a composite dataset) as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)), use the following command to convert it:
|
First, convert the model saved during PGNet end-to-end training into an inference model. Taking the model trained on the English dataset in the first training stage (based on a composite dataset) as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)), use the following command to convert it:
|
||||||
```
|
```
|
||||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
|
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
|
||||||
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
|
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
|
||||||
```
|
```
|
||||||
**For PGNet quadrangle end-to-end model inference, you need to set the parameter `--e2e_algorithm="PGNet"`**. Run the following command:
|
**For PGNet quadrangle end-to-end model inference, you need to set the parameter `--e2e_algorithm="PGNet"`**. Run the following command:
|
||||||
```
|
```
|
||||||
|
@ -173,3 +173,9 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
|
||||||
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'e2e_res'. Examples of results are as follows:
|
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'e2e_res'. Examples of results are as follows:
|
||||||
|
|
||||||
![](../imgs_results/e2e_res_img623_pgnet.jpg)
|
![](../imgs_results/e2e_res_img623_pgnet.jpg)
|
||||||
|
#### (3). Performance
|
||||||
|
| |det_precision|det_recall|det_f_score|e2e_precision|e2e_recall|e2e_f_score|FPS (size=640)|
|
||||||
|
| --- | --- | --- | --- | --- | --- | --- | --- |
|
||||||
|
|Paper|85.30|86.80|86.1|-|-|61.7|38.20|
|
||||||
|
|Ours|87.03|82.48|84.69|61.71|58.43|60.03|62.61|
|
||||||
|
*Note: the PGNet implementation in PaddleOCR is optimized for prediction speed and can significantly improve end-to-end prediction speed while keeping the accuracy drop within an acceptable range.*
|
||||||
|
|
|
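As a quick sanity check on the performance table above, the f-score columns of the "Ours" row are consistent with the harmonic mean of precision and recall. The sketch below assumes that standard definition; the helper name `f_score` is illustrative and is not part of PaddleOCR.

```python
def f_score(precision, recall):
    """Harmonic mean of precision and recall (both given in percent)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# "Ours" row of the table: detection and end-to-end metrics.
det_f = f_score(87.03, 82.48)
e2e_f = f_score(61.71, 58.43)
print(round(det_f, 2), round(e2e_f, 2))  # 84.69 60.03
```

Both values agree with the `det_f_score` and `e2e_f_score` columns reported for the PaddleOCR implementation.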
@ -200,18 +200,16 @@ class E2ELabelEncode(BaseRecLabelEncode):
|
||||||
self.pad_num = len(self.dict) # the length to pad
|
self.pad_num = len(self.dict) # the length to pad
|
||||||
|
|
||||||
def __call__(self, data):
|
def __call__(self, data):
|
||||||
text_label_index_list, temp_text = [], []
|
|
||||||
texts = data['strs']
|
texts = data['strs']
|
||||||
|
temp_texts = []
|
||||||
for text in texts:
|
for text in texts:
|
||||||
text = text.lower()
|
text = text.lower()
|
||||||
temp_text = []
|
text = self.encode(text)
|
||||||
for c_ in text:
|
if text is None:
|
||||||
if c_ in self.dict:
|
return None
|
||||||
temp_text.append(self.dict[c_])
|
text = text + [self.pad_num] * (self.max_text_len - len(text))
|
||||||
temp_text = temp_text + [self.pad_num] * (self.max_text_len -
|
temp_texts.append(text)
|
||||||
len(temp_text))
|
data['strs'] = np.array(temp_texts)
|
||||||
text_label_index_list.append(temp_text)
|
|
||||||
data['strs'] = np.array(text_label_index_list)
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
|
|
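The new `E2ELabelEncode.__call__` pads every encoded text to `max_text_len` with `pad_num` and returns `None` when encoding fails. A minimal sketch of that padding step follows. It uses plain lists instead of `np.array`, the function name `pad_encoded_texts` is hypothetical, and the value `pad_num=36` is only an example.

```python
def pad_encoded_texts(encoded_texts, max_text_len, pad_num):
    """Pad each list of character indices to max_text_len with pad_num.

    Returns None when any text is missing or too long, mirroring the
    early `return None` in the diff above.
    """
    padded = []
    for indices in encoded_texts:
        if indices is None or len(indices) > max_text_len:
            return None
        padded.append(indices + [pad_num] * (max_text_len - len(indices)))
    return padded


labels = pad_encoded_texts([[1, 2], [3]], max_text_len=4, pad_num=36)
print(labels)  # [[1, 2, 36, 36], [3, 36, 36, 36]]
```

Because all rows end up the same length, the result can be wrapped in `np.array` to form a rectangular label matrix, as the commit does for `data['strs']`.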
@ -64,9 +64,6 @@ class PGDataSet(Dataset):
|
||||||
for line in f.readlines():
|
for line in f.readlines():
|
||||||
poly_str, txt = line.strip().split('\t')
|
poly_str, txt = line.strip().split('\t')
|
||||||
poly = list(map(float, poly_str.split(',')))
|
poly = list(map(float, poly_str.split(',')))
|
||||||
if self.mode.lower() == "eval":
|
|
||||||
while len(poly) < 100:
|
|
||||||
poly.append(-1)
|
|
||||||
text_polys.append(
|
text_polys.append(
|
||||||
np.array(
|
np.array(
|
||||||
poly, dtype=np.float32).reshape(-1, 2))
|
poly, dtype=np.float32).reshape(-1, 2))
|
||||||
|
@ -139,10 +136,6 @@ class PGDataSet(Dataset):
|
||||||
try:
|
try:
|
||||||
if self.data_format == 'icdar':
|
if self.data_format == 'icdar':
|
||||||
im_path = os.path.join(data_path, 'rgb', data_line)
|
im_path = os.path.join(data_path, 'rgb', data_line)
|
||||||
if self.mode.lower() == "eval":
|
|
||||||
poly_path = os.path.join(data_path, 'poly_gt',
|
|
||||||
data_line.split('.')[0] + '.txt')
|
|
||||||
else:
|
|
||||||
poly_path = os.path.join(data_path, 'poly',
|
poly_path = os.path.join(data_path, 'poly',
|
||||||
data_line.split('.')[0] + '.txt')
|
data_line.split('.')[0] + '.txt')
|
||||||
text_polys, text_tags, text_strs = self.extract_polys(poly_path)
|
text_polys, text_tags, text_strs = self.extract_polys(poly_path)
|
||||||
|
@ -150,12 +143,14 @@ class PGDataSet(Dataset):
|
||||||
image_dir = os.path.join(os.path.dirname(data_path), 'image')
|
image_dir = os.path.join(os.path.dirname(data_path), 'image')
|
||||||
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
|
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
|
||||||
data_line, image_dir)
|
data_line, image_dir)
|
||||||
|
img_id = int(data_line.split(".")[0][3:])
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
'img_path': im_path,
|
'img_path': im_path,
|
||||||
'polys': text_polys,
|
'polys': text_polys,
|
||||||
'tags': text_tags,
|
'tags': text_tags,
|
||||||
'strs': text_strs
|
'strs': text_strs,
|
||||||
|
'img_id': img_id
|
||||||
}
|
}
|
||||||
with open(data['img_path'], 'rb') as f:
|
with open(data['img_path'], 'rb') as f:
|
||||||
img = f.read()
|
img = f.read()
|
||||||
|
|
|
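The added line `img_id = int(data_line.split(".")[0][3:])` assumes file names of the form `img<N>.<ext>`. The sketch below isolates that expression in a hypothetical helper, `parse_img_id`, to show what it accepts and where it fails.

```python
def parse_img_id(data_line):
    """Extract the numeric id from a file name such as 'img623.jpg'."""
    stem = data_line.split(".")[0]  # 'img623'
    return int(stem[3:])            # drop the 3-character prefix


print(parse_img_id("img623.jpg"))  # 623

try:
    parse_img_id("img_10.jpg")  # stem[3:] is '_10', which int() rejects
except ValueError:
    print("file name does not match the img<N> pattern")
```

File names with a different prefix length or a separator after `img` would therefore need renaming or a more tolerant parser before they can be used with this dataset class.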
@ -19,57 +19,28 @@ from __future__ import print_function
|
||||||
__all__ = ['E2EMetric']
|
__all__ = ['E2EMetric']
|
||||||
|
|
||||||
from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results
|
from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results
|
||||||
from ppocr.utils.e2e_utils.extract_textpoint import get_dict
|
from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict
|
||||||
|
|
||||||
|
|
||||||
class E2EMetric(object):
|
class E2EMetric(object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
gt_mat_dir,
|
||||||
character_dict_path,
|
character_dict_path,
|
||||||
main_indicator='f_score_e2e',
|
main_indicator='f_score_e2e',
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
self.gt_mat_dir = gt_mat_dir
|
||||||
self.label_list = get_dict(character_dict_path)
|
self.label_list = get_dict(character_dict_path)
|
||||||
self.max_index = len(self.label_list)
|
self.max_index = len(self.label_list)
|
||||||
self.main_indicator = main_indicator
|
self.main_indicator = main_indicator
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def __call__(self, preds, batch, **kwargs):
|
def __call__(self, preds, batch, **kwargs):
|
||||||
temp_gt_polyons_batch = batch[2]
|
img_id = batch[5][0]
|
||||||
temp_gt_strs_batch = batch[3]
|
|
||||||
ignore_tags_batch = batch[4]
|
|
||||||
gt_polyons_batch = []
|
|
||||||
gt_strs_batch = []
|
|
||||||
|
|
||||||
temp_gt_polyons_batch = temp_gt_polyons_batch[0].tolist()
|
|
||||||
for temp_list in temp_gt_polyons_batch:
|
|
||||||
t = []
|
|
||||||
for index in temp_list:
|
|
||||||
if index[0] != -1 and index[1] != -1:
|
|
||||||
t.append(index)
|
|
||||||
gt_polyons_batch.append(t)
|
|
||||||
|
|
||||||
temp_gt_strs_batch = temp_gt_strs_batch[0].tolist()
|
|
||||||
for temp_list in temp_gt_strs_batch:
|
|
||||||
t = ""
|
|
||||||
for index in temp_list:
|
|
||||||
if index < self.max_index:
|
|
||||||
t += self.label_list[index]
|
|
||||||
gt_strs_batch.append(t)
|
|
||||||
|
|
||||||
for pred, gt_polyons, gt_strs, ignore_tags in zip(
|
|
||||||
[preds], [gt_polyons_batch], [gt_strs_batch], ignore_tags_batch):
|
|
||||||
# prepare gt
|
|
||||||
gt_info_list = [{
|
|
||||||
'points': gt_polyon,
|
|
||||||
'text': gt_str,
|
|
||||||
'ignore': ignore_tag
|
|
||||||
} for gt_polyon, gt_str, ignore_tag in
|
|
||||||
zip(gt_polyons, gt_strs, ignore_tags)]
|
|
||||||
# prepare det
|
|
||||||
e2e_info_list = [{
|
e2e_info_list = [{
|
||||||
'points': det_polyon,
|
'points': det_polyon,
|
||||||
'text': pred_str
|
'text': pred_str
|
||||||
} for det_polyon, pred_str in zip(pred['points'], pred['strs'])]
|
} for det_polyon, pred_str in zip(preds['points'], preds['strs'])]
|
||||||
result = get_socre(gt_info_list, e2e_info_list)
|
result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
|
||||||
self.results.append(result)
|
self.results.append(result)
|
||||||
|
|
||||||
def get_metric(self):
|
def get_metric(self):
|
||||||
|
|
|
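With this change the metric no longer receives ground truth from the batch. Instead, `get_socre` loads one `.mat` file per image from `gt_mat_dir`, using the pattern `'%s/poly_gt_img%s.mat' % (gt_dir, gt_id)` shown in the `Deteval.py` hunk of this commit. The sketch below rebuilds that path and lists missing files before an evaluation run. The helper names and the example directory are hypothetical.

```python
import os


def gt_mat_path(gt_mat_dir, img_id):
    """Build the ground-truth path using the pattern from Deteval.py."""
    return '%s/poly_gt_img%s.mat' % (gt_mat_dir, img_id)


def missing_gt_files(gt_mat_dir, img_ids):
    """Return the expected .mat paths that do not exist on disk."""
    return [
        gt_mat_path(gt_mat_dir, img_id) for img_id in img_ids
        if not os.path.isfile(gt_mat_path(gt_mat_dir, img_id))
    ]


# Hypothetical directory; replace with the value set for Metric.gt_mat_dir.
print(gt_mat_path("./train_data/total_text/gt", 623))
# ./train_data/total_text/gt/poly_gt_img623.mat
```

Running a check like `missing_gt_files` up front can surface a misconfigured `gt_mat_dir` before `io.loadmat` fails partway through evaluation.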
@ -22,10 +22,7 @@ import sys
|
||||||
__dir__ = os.path.dirname(__file__)
|
__dir__ = os.path.dirname(__file__)
|
||||||
sys.path.append(__dir__)
|
sys.path.append(__dir__)
|
||||||
sys.path.append(os.path.join(__dir__, '..'))
|
sys.path.append(os.path.join(__dir__, '..'))
|
||||||
|
from ppocr.utils.e2e_utils.pgnet_pp_utils import PGNet_PostProcess
|
||||||
from ppocr.utils.e2e_utils.extract_textpoint import *
|
|
||||||
from ppocr.utils.e2e_utils.visual import *
|
|
||||||
import paddle
|
|
||||||
|
|
||||||
|
|
||||||
class PGPostProcess(object):
|
class PGPostProcess(object):
|
||||||
|
@ -33,10 +30,12 @@ class PGPostProcess(object):
|
||||||
The post process for PGNet.
|
The post process for PGNet.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, character_dict_path, valid_set, score_thresh, **kwargs):
|
def __init__(self, character_dict_path, valid_set, score_thresh, mode,
|
||||||
self.Lexicon_Table = get_dict(character_dict_path)
|
**kwargs):
|
||||||
|
self.character_dict_path = character_dict_path
|
||||||
self.valid_set = valid_set
|
self.valid_set = valid_set
|
||||||
self.score_thresh = score_thresh
|
self.score_thresh = score_thresh
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
# C++ la-nms is faster, but only supports Python 3.5
|
# C++ la-nms is faster, but only supports Python 3.5
|
||||||
self.is_python35 = False
|
self.is_python35 = False
|
||||||
|
@ -44,112 +43,10 @@ class PGPostProcess(object):
|
||||||
self.is_python35 = True
|
self.is_python35 = True
|
||||||
|
|
||||||
def __call__(self, outs_dict, shape_list):
|
def __call__(self, outs_dict, shape_list):
|
||||||
p_score = outs_dict['f_score']
|
post = PGNet_PostProcess(self.character_dict_path, self.valid_set,
|
||||||
p_border = outs_dict['f_border']
|
self.score_thresh, outs_dict, shape_list)
|
||||||
p_char = outs_dict['f_char']
|
if self.mode == 'fast':
|
||||||
p_direction = outs_dict['f_direction']
|
data = post.pg_postprocess_fast()
|
||||||
if isinstance(p_score, paddle.Tensor):
|
|
||||||
p_score = p_score[0].numpy()
|
|
||||||
p_border = p_border[0].numpy()
|
|
||||||
p_direction = p_direction[0].numpy()
|
|
||||||
p_char = p_char[0].numpy()
|
|
||||||
else:
|
else:
|
||||||
p_score = p_score[0]
|
data = post.pg_postprocess_slow()
|
||||||
p_border = p_border[0]
|
|
||||||
p_direction = p_direction[0]
|
|
||||||
p_char = p_char[0]
|
|
||||||
src_h, src_w, ratio_h, ratio_w = shape_list[0]
|
|
||||||
is_curved = self.valid_set == "totaltext"
|
|
||||||
instance_yxs_list = generate_pivot_list(
|
|
||||||
p_score,
|
|
||||||
p_char,
|
|
||||||
p_direction,
|
|
||||||
score_thresh=self.score_thresh,
|
|
||||||
is_backbone=True,
|
|
||||||
is_curved=is_curved)
|
|
||||||
p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
|
|
||||||
char_seq_idx_set = []
|
|
||||||
for i in range(len(instance_yxs_list)):
|
|
||||||
gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
|
|
||||||
f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
|
|
||||||
feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
|
|
||||||
feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
|
|
||||||
feature_len = [len(feature_seq[0])]
|
|
||||||
featyre_seq = paddle.to_tensor(feature_seq)
|
|
||||||
feature_len = np.array([feature_len]).astype(np.int64)
|
|
||||||
length = paddle.to_tensor(feature_len)
|
|
||||||
seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
|
|
||||||
input=featyre_seq, blank=36, input_length=length)
|
|
||||||
seq_pred_str = seq_pred[0].numpy().tolist()[0]
|
|
||||||
seq_len = seq_pred[1].numpy()[0][0]
|
|
||||||
temp_t = []
|
|
||||||
for c in seq_pred_str[:seq_len]:
|
|
||||||
temp_t.append(c)
|
|
||||||
char_seq_idx_set.append(temp_t)
|
|
||||||
seq_strs = []
|
|
||||||
for char_idx_set in char_seq_idx_set:
|
|
||||||
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
|
|
||||||
seq_strs.append(pr_str)
|
|
||||||
poly_list = []
|
|
||||||
keep_str_list = []
|
|
||||||
all_point_list = []
|
|
||||||
all_point_pair_list = []
|
|
||||||
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
|
|
||||||
if len(yx_center_line) == 1:
|
|
||||||
yx_center_line.append(yx_center_line[-1])
|
|
||||||
|
|
||||||
offset_expand = 1.0
|
|
||||||
if self.valid_set == 'totaltext':
|
|
||||||
offset_expand = 1.2
|
|
||||||
|
|
||||||
point_pair_list = []
|
|
||||||
for batch_id, y, x in yx_center_line:
|
|
||||||
offset = p_border[:, y, x].reshape(2, 2)
|
|
||||||
if offset_expand != 1.0:
|
|
||||||
offset_length = np.linalg.norm(
|
|
||||||
offset, axis=1, keepdims=True)
|
|
||||||
expand_length = np.clip(
|
|
||||||
offset_length * (offset_expand - 1),
|
|
||||||
a_min=0.5,
|
|
||||||
a_max=3.0)
|
|
||||||
offset_detal = offset / offset_length * expand_length
|
|
||||||
offset = offset + offset_detal
|
|
||||||
ori_yx = np.array([y, x], dtype=np.float32)
|
|
||||||
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
|
|
||||||
[ratio_w, ratio_h]).reshape(-1, 2)
|
|
||||||
point_pair_list.append(point_pair)
|
|
||||||
|
|
||||||
all_point_list.append([
|
|
||||||
int(round(x * 4.0 / ratio_w)),
|
|
||||||
int(round(y * 4.0 / ratio_h))
|
|
||||||
])
|
|
||||||
all_point_pair_list.append(point_pair.round().astype(np.int32)
|
|
||||||
.tolist())
|
|
||||||
|
|
||||||
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
|
|
||||||
detected_poly = expand_poly_along_width(
|
|
||||||
detected_poly, shrink_ratio_of_width=0.2)
|
|
||||||
detected_poly[:, 0] = np.clip(
|
|
||||||
detected_poly[:, 0], a_min=0, a_max=src_w)
|
|
||||||
detected_poly[:, 1] = np.clip(
|
|
||||||
detected_poly[:, 1], a_min=0, a_max=src_h)
|
|
||||||
|
|
||||||
if len(keep_str) < 2:
|
|
||||||
continue
|
|
||||||
|
|
||||||
keep_str_list.append(keep_str)
|
|
||||||
if self.valid_set == 'partvgg':
|
|
||||||
middle_point = len(detected_poly) // 2
|
|
||||||
detected_poly = detected_poly[
|
|
||||||
[0, middle_point - 1, middle_point, -1], :]
|
|
||||||
poly_list.append(detected_poly)
|
|
||||||
elif self.valid_set == 'totaltext':
|
|
||||||
poly_list.append(detected_poly)
|
|
||||||
else:
|
|
||||||
print('--> Not supported format.')
|
|
||||||
exit(-1)
|
|
||||||
data = {
|
|
||||||
'points': poly_list,
|
|
||||||
'strs': keep_str_list,
|
|
||||||
}
|
|
||||||
return data
|
return data
|
||||||
|
|
|
@ -13,10 +13,11 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import scipy.io as io
|
||||||
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
|
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
|
||||||
|
|
||||||
|
|
||||||
def get_socre(gt_dict, pred_dict):
|
def get_socre(gt_dir, img_id, pred_dict):
|
||||||
allInputs = 1
|
allInputs = 1
|
||||||
|
|
||||||
def input_reading_mod(pred_dict):
|
def input_reading_mod(pred_dict):
|
||||||
|
@ -30,31 +31,9 @@ def get_socre(gt_dict, pred_dict):
|
||||||
det.append([point, text])
|
det.append([point, text])
|
||||||
return det
|
return det
|
||||||
|
|
||||||
def gt_reading_mod(gt_dict):
|
def gt_reading_mod(gt_dir, gt_id):
|
||||||
"""This helper reads groundtruths from mat files"""
|
gt = io.loadmat('%s/poly_gt_img%s.mat' % (gt_dir, gt_id))
|
||||||
gt = []
|
gt = gt['polygt']
|
||||||
n = len(gt_dict)
|
|
||||||
for i in range(n):
|
|
||||||
points = gt_dict[i]['points']
|
|
||||||
h = len(points)
|
|
||||||
text = gt_dict[i]['text']
|
|
||||||
xx = [
|
|
||||||
np.array(
|
|
||||||
['x:'], dtype='<U2'), 0, np.array(
|
|
||||||
['y:'], dtype='<U2'), 0, np.array(
|
|
||||||
['#'], dtype='<U1'), np.array(
|
|
||||||
['#'], dtype='<U1')
|
|
||||||
]
|
|
||||||
t_x, t_y = [], []
|
|
||||||
for j in range(h):
|
|
||||||
t_x.append(points[j][0])
|
|
||||||
t_y.append(points[j][1])
|
|
||||||
xx[1] = np.array([t_x], dtype='int16')
|
|
||||||
xx[3] = np.array([t_y], dtype='int16')
|
|
||||||
if text != "" and "#" not in text:
|
|
||||||
xx[4] = np.array([text], dtype='U{}'.format(len(text)))
|
|
||||||
xx[5] = np.array(['c'], dtype='<U1')
|
|
||||||
gt.append(xx)
|
|
||||||
return gt
|
return gt
|
||||||
|
|
||||||
def detection_filtering(detections, groundtruths, threshold=0.5):
|
def detection_filtering(detections, groundtruths, threshold=0.5):
|
||||||
|
@ -101,7 +80,7 @@ def get_socre(gt_dict, pred_dict):
|
||||||
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
|
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
|
||||||
and (input_id != 'Deteval_result_non_curved.txt'):
|
and (input_id != 'Deteval_result_non_curved.txt'):
|
||||||
detections = input_reading_mod(pred_dict)
|
detections = input_reading_mod(pred_dict)
|
||||||
groundtruths = gt_reading_mod(gt_dict)
|
groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
|
||||||
detections = detection_filtering(
|
detections = detection_filtering(
|
||||||
detections,
|
detections,
|
||||||
groundtruths) # filters detections overlapping with DC area
|
groundtruths) # filters detections overlapping with DC area
|
||||||
|
|
|
@ -0,0 +1,457 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Contains various CTC decoders."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from itertools import groupby
|
||||||
|
from skimage.morphology._skeletonize import thin
|
||||||
|
|
||||||
|
|
||||||
|
def get_dict(character_dict_path):
|
||||||
|
character_str = ""
|
||||||
|
with open(character_dict_path, "rb") as fin:
|
||||||
|
lines = fin.readlines()
|
||||||
|
for line in lines:
|
||||||
|
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||||
|
character_str += line
|
||||||
|
dict_character = list(character_str)
|
||||||
|
return dict_character
|
||||||
|
|
||||||
|
|
||||||
|
def softmax(logits):
|
||||||
|
"""
|
||||||
|
logits: N x d
|
||||||
|
"""
|
||||||
|
max_value = np.max(logits, axis=1, keepdims=True)
|
||||||
|
exp = np.exp(logits - max_value)
|
||||||
|
exp_sum = np.sum(exp, axis=1, keepdims=True)
|
||||||
|
dist = exp / exp_sum
|
||||||
|
return dist
|
||||||
|
|
||||||
|
|
||||||
|
def get_keep_pos_idxs(labels, remove_blank=None):
|
||||||
|
"""
|
||||||
|
Remove duplicate and get pos idxs of keep items.
|
||||||
|
The value of remove_blank should be either None or 95.
|
||||||
|
"""
|
||||||
|
duplicate_len_list = []
|
||||||
|
keep_pos_idx_list = []
|
||||||
|
keep_char_idx_list = []
|
||||||
|
for k, v_ in groupby(labels):
|
||||||
|
current_len = len(list(v_))
|
||||||
|
if k != remove_blank:
|
||||||
|
current_idx = int(sum(duplicate_len_list) + current_len // 2)
|
||||||
|
keep_pos_idx_list.append(current_idx)
|
||||||
|
keep_char_idx_list.append(k)
|
||||||
|
duplicate_len_list.append(current_len)
|
||||||
|
return keep_char_idx_list, keep_pos_idx_list
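The run-collapsing step in `get_keep_pos_idxs` can be illustrated with a small standalone sketch. The function name `keep_positions` and the toy labels below are hypothetical and not part of the commit; the sketch only assumes that each kept run reports the index of its middle element, as the code above does.

```python
from itertools import groupby


def keep_positions(labels, remove_label=None):
    """Collapse runs of equal labels; report each kept run's label and middle index."""
    kept_labels, kept_positions = [], []
    consumed = 0  # number of labels that belong to all earlier runs
    for label, run in groupby(labels):
        run_len = len(list(run))
        if label != remove_label:
            kept_labels.append(label)
            kept_positions.append(consumed + run_len // 2)
        consumed += run_len
    return kept_labels, kept_positions


labels = [3, 3, 3, 0, 5, 5]
all_runs = keep_positions(labels)                       # keep every run
without_zero = keep_positions(labels, remove_label=0)   # drop runs of label 0
print(all_runs, without_zero)
```

Passing `remove_label=None` keeps every run, which mirrors what `keep_blank_in_idxs=True` appears to do in `ctc_greedy_decoder`.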
|
||||||
|
|
||||||
|
|
||||||
|
def remove_blank(labels, blank=0):
|
||||||
|
new_labels = [x for x in labels if x != blank]
|
||||||
|
return new_labels
|
||||||
|
|
||||||
|
|
||||||
|
def insert_blank(labels, blank=0):
|
||||||
|
new_labels = [blank]
|
||||||
|
for l in labels:
|
||||||
|
new_labels += [l, blank]
|
||||||
|
return new_labels
|
||||||
|
|
||||||
|
|
||||||
|
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
|
||||||
|
"""
|
||||||
|
CTC greedy (best path) decoder.
|
||||||
|
"""
|
||||||
|
raw_str = np.argmax(np.array(probs_seq), axis=1)
|
||||||
|
remove_blank_in_pos = None if keep_blank_in_idxs else blank
|
||||||
|
dedup_str, keep_idx_list = get_keep_pos_idxs(
|
||||||
|
raw_str, remove_blank=remove_blank_in_pos)
|
||||||
|
dst_str = remove_blank(dedup_str, blank=blank)
|
||||||
|
return dst_str, keep_idx_list
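As a rough illustration of best-path CTC decoding, the three steps (per-step argmax, collapse repeats, drop blanks) can be written without NumPy. This is a minimal sketch under assumptions: `ctc_greedy_decode` and the toy scores are hypothetical, and the blank index is passed explicitly rather than fixed at 95.

```python
from itertools import groupby


def ctc_greedy_decode(probs_seq, blank):
    """Best-path CTC decoding over a T x C list of per-step scores."""
    # 1. Pick the highest-scoring class at every time step.
    best_path = [max(range(len(step)), key=step.__getitem__) for step in probs_seq]
    # 2. Collapse runs of identical labels into a single label.
    collapsed = [label for label, _ in groupby(best_path)]
    # 3. Drop the blank label.
    return [label for label in collapsed if label != blank]


# Toy input with 3 classes; class 2 plays the role of the blank.
probs = [
    [0.1, 0.8, 0.1],    # argmax 1
    [0.2, 0.7, 0.1],    # argmax 1 (repeat, collapsed with the previous step)
    [0.1, 0.1, 0.8],    # argmax 2 (blank)
    [0.1, 0.8, 0.1],    # argmax 1 (kept, because a blank separates it)
    [0.9, 0.05, 0.05],  # argmax 0
]
decoded = ctc_greedy_decode(probs, blank=2)
print(decoded)
```

Collapsing before removing blanks is what lets a blank separate two genuine occurrences of the same character.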
|
||||||
|
|
||||||
|
|
||||||
|
def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
|
||||||
|
_, _, C = logits_map.shape
|
||||||
|
ys, xs = zip(*gather_info)
|
||||||
|
logits_seq = logits_map[list(ys), list(xs)]
|
||||||
|
probs_seq = logits_seq
|
||||||
|
labels = np.argmax(probs_seq, axis=1)
|
||||||
|
dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
|
||||||
|
detal = len(gather_info) // (pts_num - 1)
|
||||||
|
keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1]
|
||||||
|
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
|
||||||
|
return dst_str, keep_gather_list
|
||||||
|
|
||||||
|
|
||||||
|
def ctc_decoder_for_image(gather_info_list,
|
||||||
|
logits_map,
|
||||||
|
Lexicon_Table,
|
||||||
|
pts_num=6):
|
||||||
|
"""
|
||||||
|
CTC decoder for all text instances of one image (instances are decoded sequentially).
|
||||||
|
"""
|
||||||
|
decoder_str = []
|
||||||
|
decoder_xys = []
|
||||||
|
for gather_info in gather_info_list:
|
||||||
|
if len(gather_info) < pts_num:
|
||||||
|
continue
|
||||||
|
dst_str, xys_list = instance_ctc_greedy_decoder(
|
||||||
|
gather_info, logits_map, pts_num=pts_num)
|
||||||
|
dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
|
||||||
|
if len(dst_str_readable) < 2:
|
||||||
|
continue
|
||||||
|
decoder_str.append(dst_str_readable)
|
||||||
|
decoder_xys.append(xys_list)
|
||||||
|
return decoder_str, decoder_xys
|
||||||
|
|
||||||
|
|
||||||
|
def sort_with_direction(pos_list, f_direction):
|
||||||
|
"""
|
||||||
|
f_direction: h x w x 2
|
||||||
|
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def sort_part_with_direction(pos_list, point_direction):
|
||||||
|
pos_list = np.array(pos_list).reshape(-1, 2)
|
||||||
|
point_direction = np.array(point_direction).reshape(-1, 2)
|
||||||
|
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
||||||
|
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
||||||
|
sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
|
||||||
|
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
|
||||||
|
return sorted_list, sorted_direction
|
||||||
|
|
||||||
|
pos_list = np.array(pos_list).reshape(-1, 2)
|
||||||
|
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
|
||||||
|
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
||||||
|
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
|
||||||
|
point_direction)
|
||||||
|
|
||||||
|
point_num = len(sorted_point)
|
||||||
|
if point_num >= 16:
|
||||||
|
middle_num = point_num // 2
|
||||||
|
first_part_point = sorted_point[:middle_num]
|
||||||
|
first_point_direction = sorted_direction[:middle_num]
|
||||||
|
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
|
||||||
|
first_part_point, first_point_direction)
|
||||||
|
|
||||||
|
last_part_point = sorted_point[middle_num:]
|
||||||
|
last_point_direction = sorted_direction[middle_num:]
|
||||||
|
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
|
||||||
|
last_part_point, last_point_direction)
|
||||||
|
sorted_point = sorted_fist_part_point + sorted_last_part_point
|
||||||
|
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
|
||||||
|
|
||||||
|
return sorted_point, np.array(sorted_direction)
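The ordering idea used by `sort_with_direction` (project each point onto the mean direction, then sort by that scalar) can be sketched in plain Python. The helper name `sort_by_mean_direction` and the sample data are hypothetical; directions are assumed to be given already in (dy, dx) order, which is what the `[:, ::-1]` flip above seems to produce.

```python
def sort_by_mean_direction(points, directions):
    """Order (y, x) points by their projection onto the mean (dy, dx) direction."""
    n = len(directions)
    mean_dy = sum(d[0] for d in directions) / n
    mean_dx = sum(d[1] for d in directions) / n
    # The dot product with the mean direction is a 1-D coordinate along the text line.
    return sorted(points, key=lambda p: p[0] * mean_dy + p[1] * mean_dx)


points = [(5, 30), (5, 10), (6, 20)]
directions = [(0.0, 1.0), (0.1, 0.9), (-0.1, 1.1)]  # roughly left-to-right
ordered = sort_by_mean_direction(points, directions)
print(ordered)
```

For long instances the code above repeats this sort on each half separately, presumably so that curved text is ordered by a more local direction estimate.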
|
||||||
|
|
||||||
|
|
||||||
|
def add_id(pos_list, image_id=0):
|
||||||
|
"""
|
||||||
|
Add id for gather feature, for inference.
|
||||||
|
"""
|
||||||
|
new_list = []
|
||||||
|
for item in pos_list:
|
||||||
|
new_list.append((image_id, item[0], item[1]))
|
||||||
|
return new_list
|
||||||
|
|
||||||
|
|
||||||
|
def sort_and_expand_with_direction(pos_list, f_direction):
|
||||||
|
"""
|
||||||
|
f_direction: h x w x 2
|
||||||
|
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||||
|
"""
|
||||||
|
h, w, _ = f_direction.shape
|
||||||
|
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
|
||||||
|
|
||||||
|
point_num = len(sorted_list)
|
||||||
|
sub_direction_len = max(point_num // 3, 2)
|
||||||
|
left_direction = point_direction[:sub_direction_len, :]
|
||||||
|
right_dirction = point_direction[point_num - sub_direction_len:, :]
|
||||||
|
|
||||||
|
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
|
||||||
|
left_average_len = np.linalg.norm(left_average_direction)
|
||||||
|
left_start = np.array(sorted_list[0])
|
||||||
|
left_step = left_average_direction / (left_average_len + 1e-6)
|
||||||
|
|
||||||
|
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
|
||||||
|
right_average_len = np.linalg.norm(right_average_direction)
|
||||||
|
right_step = right_average_direction / (right_average_len + 1e-6)
|
||||||
|
right_start = np.array(sorted_list[-1])
|
||||||
|
|
||||||
|
append_num = max(
|
||||||
|
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
|
||||||
|
left_list = []
|
||||||
|
right_list = []
|
||||||
|
for i in range(append_num):
|
||||||
|
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
|
||||||
|
'int32').tolist()
|
||||||
|
if ly < h and lx < w and (ly, lx) not in left_list:
|
||||||
|
left_list.append((ly, lx))
|
||||||
|
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
|
||||||
|
'int32').tolist()
|
||||||
|
if ry < h and rx < w and (ry, rx) not in right_list:
|
||||||
|
right_list.append((ry, rx))
|
||||||
|
|
||||||
|
all_list = left_list[::-1] + sorted_list + right_list
|
||||||
|
return all_list
|
||||||
|
|
||||||
|
|
||||||
|
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
|
||||||
|
"""
|
||||||
|
f_direction: h x w x 2
|
||||||
|
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||||
|
binary_tcl_map: h x w
|
||||||
|
"""
|
||||||
|
h, w, _ = f_direction.shape
|
||||||
|
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
|
||||||
|
|
||||||
|
point_num = len(sorted_list)
|
||||||
|
sub_direction_len = max(point_num // 3, 2)
|
||||||
|
left_direction = point_direction[:sub_direction_len, :]
|
||||||
|
right_dirction = point_direction[point_num - sub_direction_len:, :]
|
||||||
|
|
||||||
|
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
|
||||||
|
left_average_len = np.linalg.norm(left_average_direction)
|
||||||
|
left_start = np.array(sorted_list[0])
|
||||||
|
left_step = left_average_direction / (left_average_len + 1e-6)
|
||||||
|
|
||||||
|
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
|
||||||
|
right_average_len = np.linalg.norm(right_average_direction)
|
||||||
|
right_step = right_average_direction / (right_average_len + 1e-6)
|
||||||
|
right_start = np.array(sorted_list[-1])
|
||||||
|
|
||||||
|
append_num = max(
|
||||||
|
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
|
||||||
|
max_append_num = 2 * append_num
|
||||||
|
|
||||||
|
left_list = []
|
||||||
|
right_list = []
|
||||||
|
for i in range(max_append_num):
|
||||||
|
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
|
||||||
|
'int32').tolist()
|
||||||
|
if ly < h and lx < w and (ly, lx) not in left_list:
|
||||||
|
if binary_tcl_map[ly, lx] > 0.5:
|
||||||
|
left_list.append((ly, lx))
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
for i in range(max_append_num):
|
||||||
|
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
|
||||||
|
'int32').tolist()
|
||||||
|
if ry < h and rx < w and (ry, rx) not in right_list:
|
||||||
|
if binary_tcl_map[ry, rx] > 0.5:
|
||||||
|
right_list.append((ry, rx))
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
all_list = left_list[::-1] + sorted_list + right_list
|
||||||
|
return all_list
|
||||||
|
|
||||||
|
|
||||||
|
def point_pair2poly(point_pair_list):
|
||||||
|
"""
|
||||||
|
Convert vertical point pairs into polygon points in clockwise order.
|
||||||
|
"""
|
||||||
|
point_num = len(point_pair_list) * 2
|
||||||
|
point_list = [0] * point_num
|
||||||
|
for idx, point_pair in enumerate(point_pair_list):
|
||||||
|
point_list[idx] = point_pair[0]
|
||||||
|
point_list[point_num - 1 - idx] = point_pair[1]
|
||||||
|
return np.array(point_list).reshape(-1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
|
||||||
|
ratio_pair = np.array(
|
||||||
|
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||||
|
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||||
|
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||||
|
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
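The quad resizing in `shrink_quad_along_width` amounts to linear interpolation along the top and bottom edges. A minimal sketch, assuming the quad is ordered top-left, top-right, bottom-right, bottom-left; `lerp` and `resize_quad_along_width` are hypothetical names used only for illustration.

```python
def lerp(p, q, t):
    """Point at fraction t on the segment from p to q (t may lie outside [0, 1])."""
    return (p[0] + (q[0] - p[0]) * t, p[1] + (q[1] - p[1]) * t)


def resize_quad_along_width(quad, begin_ratio=0.0, end_ratio=1.0):
    """Move the left edge to begin_ratio and the right edge to end_ratio."""
    p0, p1, p2, p3 = quad
    return [
        lerp(p0, p1, begin_ratio),  # new top-left
        lerp(p0, p1, end_ratio),    # new top-right
        lerp(p3, p2, end_ratio),    # new bottom-right
        lerp(p3, p2, begin_ratio),  # new bottom-left
    ]


quad = [(0.0, 0.0), (10.0, 0.0), (10.0, 4.0), (0.0, 4.0)]
shrunk = resize_quad_along_width(quad, 0.25, 0.75)
expanded = resize_quad_along_width(quad, -0.5, 1.0)
print(shrunk, expanded)
```

A negative `begin_ratio` or an `end_ratio` above 1.0 extrapolates past the original edge, which is how `expand_poly_along_width` reuses the same routine to grow the polygon.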
|
||||||
|
|
||||||
|
|
||||||
|
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
|
||||||
|
"""
|
||||||
|
expand poly along width.
|
||||||
|
"""
|
||||||
|
point_num = poly.shape[0]
|
||||||
|
left_quad = np.array(
|
||||||
|
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
|
||||||
|
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
|
||||||
|
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
|
||||||
|
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
|
||||||
|
right_quad = np.array(
|
||||||
|
[
|
||||||
|
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
|
||||||
|
poly[point_num // 2], poly[point_num // 2 + 1]
|
||||||
|
],
|
||||||
|
dtype=np.float32)
|
||||||
|
right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
|
||||||
|
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
|
||||||
|
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
|
||||||
|
poly[0] = left_quad_expand[0]
|
||||||
|
poly[-1] = left_quad_expand[-1]
|
||||||
|
poly[point_num // 2 - 1] = right_quad_expand[1]
|
||||||
|
poly[point_num // 2] = right_quad_expand[2]
|
||||||
|
return poly
|
||||||
|
|
||||||
|
|
||||||
|
def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w,
|
||||||
|
src_h, valid_set):
|
||||||
|
poly_list = []
|
||||||
|
keep_str_list = []
|
||||||
|
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
|
||||||
|
if len(keep_str) < 2:
|
||||||
|
print('--> too short, {}'.format(keep_str))
|
||||||
|
continue
|
||||||
|
|
||||||
|
offset_expand = 1.0
|
||||||
|
if valid_set == 'totaltext':
|
||||||
|
offset_expand = 1.2
|
||||||
|
|
||||||
|
point_pair_list = []
|
||||||
|
for y, x in yx_center_line:
|
||||||
|
offset = p_border[:, y, x].reshape(2, 2) * offset_expand
|
||||||
|
ori_yx = np.array([y, x], dtype=np.float32)
|
||||||
|
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
|
||||||
|
[ratio_w, ratio_h]).reshape(-1, 2)
|
||||||
|
point_pair_list.append(point_pair)
|
||||||
|
|
||||||
|
detected_poly = point_pair2poly(point_pair_list)
|
||||||
|
detected_poly = expand_poly_along_width(
|
||||||
|
detected_poly, shrink_ratio_of_width=0.2)
|
||||||
|
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
|
||||||
|
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
|
||||||
|
|
||||||
|
keep_str_list.append(keep_str)
|
||||||
|
if valid_set == 'partvgg':
|
||||||
|
middle_point = len(detected_poly) // 2
|
||||||
|
detected_poly = detected_poly[
|
||||||
|
[0, middle_point - 1, middle_point, -1], :]
|
||||||
|
poly_list.append(detected_poly)
|
||||||
|
elif valid_set == 'totaltext':
|
||||||
|
poly_list.append(detected_poly)
|
||||||
|
else:
|
||||||
|
print('--> Not supported format.')
|
||||||
|
exit(-1)
|
||||||
|
return poly_list, keep_str_list
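The coordinate mapping inside `restore_poly` can be read as: add the border offset to the feature-map point, swap to (x, y), scale by 4.0, undo the resize ratios, then clip to the source image. The sketch below restates that for a single border point. The name `feature_to_image_xy`, the parameter name `stride` for the 4.0 factor, and the sample values are assumptions made for illustration.

```python
def feature_to_image_xy(y, x, offset_yx, ratio_h, ratio_w, src_h, src_w, stride=4.0):
    """Map a feature-map point plus a (dy, dx) offset to clipped image-space (x, y)."""
    img_x = (x + offset_yx[1]) * stride / ratio_w
    img_y = (y + offset_yx[0]) * stride / ratio_h
    # Clip to the source image bounds, as np.clip does in the code above.
    img_x = min(max(img_x, 0), src_w)
    img_y = min(max(img_y, 0), src_h)
    return img_x, img_y


# Hypothetical values: the image was resized by 0.5 in both directions.
mapped = feature_to_image_xy(
    y=10, x=20, offset_yx=(-2.0, 0.5),
    ratio_h=0.5, ratio_w=0.5, src_h=100, src_w=160)
print(mapped)
```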
|
||||||
|
|
||||||
|
|
||||||
|
def generate_pivot_list_fast(p_score,
|
||||||
|
p_char_maps,
|
||||||
|
f_direction,
|
||||||
|
Lexicon_Table,
|
||||||
|
score_thresh=0.5):
|
||||||
|
"""
|
||||||
|
return center point and end point of TCL instance; filter with the char maps;
|
||||||
|
"""
|
||||||
|
p_score = p_score[0]
|
||||||
|
f_direction = f_direction.transpose(1, 2, 0)
|
||||||
|
p_tcl_map = (p_score > score_thresh) * 1.0
|
||||||
|
skeleton_map = thin(p_tcl_map.astype(np.uint8))
|
||||||
|
instance_count, instance_label_map = cv2.connectedComponents(
|
||||||
|
skeleton_map.astype(np.uint8), connectivity=8)
|
||||||
|
|
||||||
|
# get TCL Instance
|
||||||
|
all_pos_yxs = []
|
||||||
|
if instance_count > 0:
|
||||||
|
for instance_id in range(1, instance_count):
|
||||||
|
pos_list = []
|
||||||
|
ys, xs = np.where(instance_label_map == instance_id)
|
||||||
|
pos_list = list(zip(ys, xs))
|
||||||
|
|
||||||
|
if len(pos_list) < 3:
|
||||||
|
continue
|
||||||
|
|
||||||
|
pos_list_sorted = sort_and_expand_with_direction_v2(
|
||||||
|
pos_list, f_direction, p_tcl_map)
|
||||||
|
all_pos_yxs.append(pos_list_sorted)
|
||||||
|
|
||||||
|
p_char_maps = p_char_maps.transpose([1, 2, 0])
|
||||||
|
decoded_str, keep_yxs_list = ctc_decoder_for_image(
|
||||||
|
all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table)
|
||||||
|
return keep_yxs_list, decoded_str
|
||||||
|
|
||||||
|
|
||||||
|
def extract_main_direction(pos_list, f_direction):
|
||||||
|
"""
|
||||||
|
f_direction: h x w x 2
|
||||||
|
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||||
|
"""
|
||||||
|
pos_list = np.array(pos_list)
|
||||||
|
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
|
||||||
|
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
||||||
|
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
||||||
|
average_direction = average_direction / (
|
||||||
|
np.linalg.norm(average_direction) + 1e-6)
|
||||||
|
return average_direction
|
||||||
|
|
||||||
|
|
||||||
|
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
|
||||||
|
"""
|
||||||
|
f_direction: h x w x 2
|
||||||
|
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
|
||||||
|
"""
|
||||||
|
pos_list_full = np.array(pos_list).reshape(-1, 3)
|
||||||
|
pos_list = pos_list_full[:, 1:]
|
||||||
|
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
|
||||||
|
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
||||||
|
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
||||||
|
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
||||||
|
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
|
||||||
|
return sorted_list
|
||||||
|
|
||||||
|
|
||||||
|
def sort_by_direction_with_image_id(pos_list, f_direction):
|
||||||
|
"""
|
||||||
|
f_direction: h x w x 2
|
||||||
|
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def sort_part_with_direction(pos_list_full, point_direction):
|
||||||
|
pos_list_full = np.array(pos_list_full).reshape(-1, 3)
|
||||||
|
pos_list = pos_list_full[:, 1:]
|
||||||
|
point_direction = np.array(point_direction).reshape(-1, 2)
|
||||||
|
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
||||||
|
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
||||||
|
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
|
||||||
|
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
|
||||||
|
return sorted_list, sorted_direction
|
||||||
|
|
||||||
|
pos_list = np.array(pos_list).reshape(-1, 3)
|
||||||
|
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
|
||||||
|
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
||||||
|
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
|
||||||
|
point_direction)
|
||||||
|
|
||||||
|
point_num = len(sorted_point)
|
||||||
|
if point_num >= 16:
|
||||||
|
middle_num = point_num // 2
|
||||||
|
first_part_point = sorted_point[:middle_num]
|
||||||
|
first_point_direction = sorted_direction[:middle_num]
|
||||||
|
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
|
||||||
|
first_part_point, first_point_direction)
|
||||||
|
|
||||||
|
last_part_point = sorted_point[middle_num:]
|
||||||
|
last_point_direction = sorted_direction[middle_num:]
|
||||||
|
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
|
||||||
|
last_part_point, last_point_direction)
|
||||||
|
sorted_point = sorted_fist_part_point + sorted_last_part_point
|
||||||
|
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
|
||||||
|
|
||||||
|
return sorted_point
|
|
@ -35,6 +35,64 @@ def get_dict(character_dict_path):
|
||||||
return dict_character
|
return dict_character
|
||||||
|
|
||||||
|
|
||||||
|
def point_pair2poly(point_pair_list):
|
||||||
|
"""
|
||||||
|
Convert vertical point pairs into polygon points in clockwise order.
|
||||||
|
"""
|
||||||
|
pair_length_list = []
|
||||||
|
for point_pair in point_pair_list:
|
||||||
|
pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
|
||||||
|
pair_length_list.append(pair_length)
|
||||||
|
pair_length_list = np.array(pair_length_list)
|
||||||
|
pair_info = (pair_length_list.max(), pair_length_list.min(),
|
||||||
|
pair_length_list.mean())
|
||||||
|
|
||||||
|
point_num = len(point_pair_list) * 2
|
||||||
|
point_list = [0] * point_num
|
||||||
|
for idx, point_pair in enumerate(point_pair_list):
|
||||||
|
point_list[idx] = point_pair[0]
|
||||||
|
point_list[point_num - 1 - idx] = point_pair[1]
|
||||||
|
return np.array(point_list).reshape(-1, 2), pair_info
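The polygon assembly in `point_pair2poly` walks one border forward and the other backward. A plain-Python sketch of that idea, including the max/min/mean pair-length summary, might look like the following; `point_pairs_to_polygon` and the sample pairs are hypothetical.

```python
import math


def point_pairs_to_polygon(point_pairs):
    """Return (polygon, (max_len, min_len, mean_len)) for a list of point pairs."""
    lengths = [math.dist(first, second) for first, second in point_pairs]
    info = (max(lengths), min(lengths), sum(lengths) / len(lengths))
    first_border = [pair[0] for pair in point_pairs]
    second_border = [pair[1] for pair in point_pairs]
    # Forward along the first border, then backward along the second one.
    return first_border + second_border[::-1], info


pairs = [((0, 0), (0, 10)), ((5, 0), (5, 10)), ((10, 0), (10, 10))]
polygon, pair_info = point_pairs_to_polygon(pairs)
print(polygon, pair_info)
```

With image coordinates (y pointing down), this ordering traces the outline clockwise when the first point of each pair lies on the upper border.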
|
||||||
|
|
||||||
|
|
||||||
|
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
|
||||||
|
"""
|
||||||
|
Shrink (or expand) a quad along its width by interpolating along its top and bottom edges.
|
||||||
|
"""
|
||||||
|
ratio_pair = np.array(
|
||||||
|
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||||
|
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||||
|
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||||
|
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
||||||
|
|
||||||
|
|
||||||
|
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
|
||||||
|
"""
|
||||||
|
expand poly along width.
|
||||||
|
"""
|
||||||
|
point_num = poly.shape[0]
|
||||||
|
left_quad = np.array(
|
||||||
|
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
|
||||||
|
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
|
||||||
|
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
|
||||||
|
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
|
||||||
|
right_quad = np.array(
|
||||||
|
[
|
||||||
|
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
|
||||||
|
poly[point_num // 2], poly[point_num // 2 + 1]
|
||||||
|
],
|
||||||
|
dtype=np.float32)
|
||||||
|
right_ratio = 1.0 + \
|
||||||
|
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
|
||||||
|
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
|
||||||
|
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
|
||||||
|
poly[0] = left_quad_expand[0]
|
||||||
|
poly[-1] = left_quad_expand[-1]
|
||||||
|
poly[point_num // 2 - 1] = right_quad_expand[1]
|
||||||
|
poly[point_num // 2] = right_quad_expand[2]
|
||||||
|
return poly
|
||||||
|
|
||||||
|
|
||||||
def softmax(logits):
|
def softmax(logits):
|
||||||
"""
|
"""
|
||||||
logits: N x d
|
logits: N x d
|
||||||
|
@ -399,7 +457,7 @@ def generate_pivot_list_horizontal(p_score,
|
||||||
return center_pos_yxs, end_points_yxs
|
return center_pos_yxs, end_points_yxs
|
||||||
|
|
||||||
|
|
||||||
def generate_pivot_list(p_score,
|
def generate_pivot_list_slow(p_score,
|
||||||
p_char_maps,
|
p_char_maps,
|
||||||
f_direction,
|
f_direction,
|
||||||
score_thresh=0.5,
|
score_thresh=0.5,
|
|
@ -0,0 +1,181 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
import paddle
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
__dir__ = os.path.dirname(__file__)
|
||||||
|
sys.path.append(__dir__)
|
||||||
|
sys.path.append(os.path.join(__dir__, '..'))
|
||||||
|
from extract_textpoint_slow import *
|
||||||
|
from extract_textpoint_fast import generate_pivot_list_fast, restore_poly
|
||||||
|
|
||||||
|
|
||||||
|
class PGNet_PostProcess(object):
|
||||||
|
# provides two post-processing modes: fast and slow
|
||||||
|
def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict,
|
||||||
|
shape_list):
|
||||||
|
self.Lexicon_Table = get_dict(character_dict_path)
|
||||||
|
self.valid_set = valid_set
|
||||||
|
self.score_thresh = score_thresh
|
||||||
|
self.outs_dict = outs_dict
|
||||||
|
self.shape_list = shape_list
|
||||||
|
|
||||||
|
def pg_postprocess_fast(self):
|
||||||
|
p_score = self.outs_dict['f_score']
|
||||||
|
p_border = self.outs_dict['f_border']
|
||||||
|
p_char = self.outs_dict['f_char']
|
||||||
|
p_direction = self.outs_dict['f_direction']
|
||||||
|
if isinstance(p_score, paddle.Tensor):
|
||||||
|
p_score = p_score[0].numpy()
|
||||||
|
p_border = p_border[0].numpy()
|
||||||
|
p_direction = p_direction[0].numpy()
|
||||||
|
p_char = p_char[0].numpy()
|
||||||
|
else:
|
||||||
|
p_score = p_score[0]
|
||||||
|
p_border = p_border[0]
|
||||||
|
p_direction = p_direction[0]
|
||||||
|
p_char = p_char[0]
|
||||||
|
|
||||||
|
src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
|
||||||
|
instance_yxs_list, seq_strs = generate_pivot_list_fast(
|
||||||
|
p_score,
|
||||||
|
p_char,
|
||||||
|
p_direction,
|
||||||
|
self.Lexicon_Table,
|
||||||
|
score_thresh=self.score_thresh)
|
||||||
|
poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
|
||||||
|
p_border, ratio_w, ratio_h,
|
||||||
|
src_w, src_h, self.valid_set)
|
||||||
|
data = {
|
||||||
|
'points': poly_list,
|
||||||
|
'strs': keep_str_list,
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
|
||||||
|
def pg_postprocess_slow(self):
|
||||||
|
p_score = self.outs_dict['f_score']
|
||||||
|
p_border = self.outs_dict['f_border']
|
||||||
|
p_char = self.outs_dict['f_char']
|
||||||
|
p_direction = self.outs_dict['f_direction']
|
||||||
|
if isinstance(p_score, paddle.Tensor):
|
||||||
|
p_score = p_score[0].numpy()
|
||||||
|
p_border = p_border[0].numpy()
|
||||||
|
p_direction = p_direction[0].numpy()
|
||||||
|
p_char = p_char[0].numpy()
|
||||||
|
else:
|
||||||
|
p_score = p_score[0]
|
||||||
|
p_border = p_border[0]
|
||||||
|
p_direction = p_direction[0]
|
||||||
|
p_char = p_char[0]
|
||||||
|
src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
|
||||||
|
is_curved = self.valid_set == "totaltext"
|
||||||
|
instance_yxs_list = generate_pivot_list_slow(
|
||||||
|
p_score,
|
||||||
|
p_char,
|
||||||
|
p_direction,
|
||||||
|
score_thresh=self.score_thresh,
|
||||||
|
is_backbone=True,
|
||||||
|
is_curved=is_curved)
|
||||||
|
p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
|
||||||
|
char_seq_idx_set = []
|
||||||
|
for i in range(len(instance_yxs_list)):
|
||||||
|
gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
|
||||||
|
f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
|
||||||
|
feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
|
||||||
|
feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
|
||||||
|
feature_len = [len(feature_seq[0])]
|
||||||
|
featyre_seq = paddle.to_tensor(feature_seq)
|
||||||
|
feature_len = np.array([feature_len]).astype(np.int64)
|
||||||
|
length = paddle.to_tensor(feature_len)
|
||||||
|
seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
|
||||||
|
input=featyre_seq, blank=36, input_length=length)
|
||||||
|
seq_pred_str = seq_pred[0].numpy().tolist()[0]
|
||||||
|
seq_len = seq_pred[1].numpy()[0][0]
|
||||||
|
temp_t = []
|
||||||
|
for c in seq_pred_str[:seq_len]:
|
||||||
|
temp_t.append(c)
|
||||||
|
char_seq_idx_set.append(temp_t)
|
||||||
|
seq_strs = []
|
||||||
|
for char_idx_set in char_seq_idx_set:
|
||||||
|
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
|
||||||
|
seq_strs.append(pr_str)
|
||||||
|
poly_list = []
|
||||||
|
keep_str_list = []
|
||||||
|
all_point_list = []
|
||||||
|
all_point_pair_list = []
|
||||||
|
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
|
||||||
|
if len(yx_center_line) == 1:
|
||||||
|
yx_center_line.append(yx_center_line[-1])
|
||||||
|
|
||||||
|
offset_expand = 1.0
|
||||||
|
if self.valid_set == 'totaltext':
|
||||||
|
offset_expand = 1.2
|
||||||
|
|
||||||
|
point_pair_list = []
|
||||||
|
for batch_id, y, x in yx_center_line:
|
||||||
|
offset = p_border[:, y, x].reshape(2, 2)
|
||||||
|
if offset_expand != 1.0:
|
||||||
|
offset_length = np.linalg.norm(
|
||||||
|
offset, axis=1, keepdims=True)
|
||||||
|
expand_length = np.clip(
|
||||||
|
offset_length * (offset_expand - 1),
|
||||||
|
a_min=0.5,
|
||||||
|
a_max=3.0)
|
||||||
|
offset_detal = offset / offset_length * expand_length
|
||||||
|
offset = offset + offset_detal
|
||||||
|
ori_yx = np.array([y, x], dtype=np.float32)
|
||||||
|
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
|
||||||
|
[ratio_w, ratio_h]).reshape(-1, 2)
|
||||||
|
point_pair_list.append(point_pair)
|
||||||
|
|
||||||
|
all_point_list.append([
|
||||||
|
int(round(x * 4.0 / ratio_w)),
|
||||||
|
int(round(y * 4.0 / ratio_h))
|
||||||
|
])
|
||||||
|
all_point_pair_list.append(point_pair.round().astype(np.int32)
|
||||||
|
.tolist())
|
||||||
|
|
||||||
|
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
|
||||||
|
detected_poly = expand_poly_along_width(
|
||||||
|
detected_poly, shrink_ratio_of_width=0.2)
|
||||||
|
detected_poly[:, 0] = np.clip(
|
||||||
|
detected_poly[:, 0], a_min=0, a_max=src_w)
|
||||||
|
detected_poly[:, 1] = np.clip(
|
||||||
|
detected_poly[:, 1], a_min=0, a_max=src_h)
|
||||||
|
|
||||||
|
if len(keep_str) < 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
keep_str_list.append(keep_str)
|
||||||
|
detected_poly = np.round(detected_poly).astype('int32')
|
||||||
|
if self.valid_set == 'partvgg':
|
||||||
|
middle_point = len(detected_poly) // 2
|
||||||
|
detected_poly = detected_poly[
|
||||||
|
[0, middle_point - 1, middle_point, -1], :]
|
||||||
|
poly_list.append(detected_poly)
|
||||||
|
elif self.valid_set == 'totaltext':
|
||||||
|
poly_list.append(detected_poly)
|
||||||
|
else:
|
||||||
|
print('--> Not supported format.')
|
||||||
|
exit(-1)
|
||||||
|
data = {
|
||||||
|
'points': poly_list,
|
||||||
|
'strs': keep_str_list,
|
||||||
|
}
|
||||||
|
return data
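The config change in this commit adds a `mode: fast` option with the comment "fast or slow". How that value selects between `pg_postprocess_fast` and `pg_postprocess_slow` is not shown in this part of the diff, so the following is only a sketch under that assumption; the class `PostProcessDispatchSketch` and its stub methods are hypothetical.

```python
class PostProcessDispatchSketch:
    """Hypothetical dispatcher that picks a post-process path by mode name."""

    def __init__(self, mode="fast"):
        self.mode = mode

    def run_fast(self):
        # Stand-in for the fast path; returns the same dict shape as the real code.
        return {"points": [], "strs": [], "path": "fast"}

    def run_slow(self):
        # Stand-in for the slow path.
        return {"points": [], "strs": [], "path": "slow"}

    def __call__(self):
        handlers = {"fast": self.run_fast, "slow": self.run_slow}
        if self.mode not in handlers:
            raise ValueError("mode must be 'fast' or 'slow', got %r" % (self.mode,))
        return handlers[self.mode]()


fast_result = PostProcessDispatchSketch("fast")()
slow_result = PostProcessDispatchSketch("slow")()
print(fast_result["path"], slow_result["path"])
```

Raising `ValueError` for an unknown mode is a design choice of this sketch; it avoids terminating the process the way `exit(-1)` does in the code above.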
|