Merge pull request #1438 from MissPenguin/dygraph
update inference for east & sast
This commit is contained in:
commit
80fc4f59dc
|
@ -17,17 +17,17 @@ PaddleOCR开源的文本检测算法列表:
|
|||
|
||||
|模型|骨干网络|precision|recall|Hmean|下载链接|
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
|EAST|ResNet50_vd|88.18%|85.51%|86.82%|[下载链接 (coming soon)](link)|
|
||||
|EAST|MobileNetV3|81.67%|79.83%|80.74%|[下载链接 (coming soon)](coming soon)|
|
||||
|DB|ResNet50_vd|83.79%|80.65%|82.19%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|
||||
|DB|MobileNetV3|75.92%|73.18%|74.53%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|
||||
|SAST|ResNet50_vd|92.18%|82.96%|87.33%|[下载链接 (coming soon)](link)|
|
||||
|EAST|ResNet50_vd|88.76%|81.36%|84.90%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|
||||
|EAST|MobileNetV3|78.24%|79.15%|78.69%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)|
|
||||
|DB|ResNet50_vd|86.41%|78.72%|82.38%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|
||||
|DB|MobileNetV3|77.29%|73.08%|75.12%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|
||||
|SAST|ResNet50_vd|91.83%|81.80%|86.52%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar))|
|
||||
|
||||
在Total-text文本检测公开数据集上,算法效果如下:
|
||||
|
||||
|模型|骨干网络|precision|recall|Hmean|下载链接|
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
|SAST|ResNet50_vd|88.74%|79.80%|84.03%|[下载链接 (coming soon)](link)|
|
||||
|SAST|ResNet50_vd|89.05%|76.80%|82.47%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)|
|
||||
|
||||
**说明:** SAST模型训练额外加入了icdar2013、icdar2017、COCO-Text、ArT等公开数据集进行调优。PaddleOCR用到的经过整理格式的英文公开数据集下载:[百度云地址](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (提取码: 2bpi)
|
||||
|
||||
|
@ -47,13 +47,10 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
|||
参考[DTRB](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
||||
|
||||
|模型|骨干网络|Avg Accuracy|模型存储命名|下载链接|
|
||||
| --- | --- | --- | --- | --- |
|
||||
|Rosetta|MobileNetV3|78.05%|rec_mv3_none_none_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar)|
|
||||
|-|-|-|-|-|
|
||||
|Rosetta|Resnet34_vd|80.9%|rec_r34_vd_none_none_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar)|
|
||||
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|
||||
|Rosetta|MobileNetV3|78.05%|rec_mv3_none_none_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar)|
|
||||
|CRNN|Resnet34_vd|82.76%|rec_r34_vd_none_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar)|
|
||||
|STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|[下载链接 (coming soon )]()|
|
||||
|STAR-Net|Resnet34_vd|83.93%|rec_r34_vd_tps_bilstm_ctc|[下载链接 (coming soon )]()|
|
||||
|
||||
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|
||||
|
||||
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。
|
||||
|
|
|
@ -180,7 +180,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_
|
|||
<a name="EAST文本检测模型推理"></a>
|
||||
### 3. EAST文本检测模型推理
|
||||
|
||||
首先将EAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例( [模型下载地址 (coming soon)](link) ),可以使用如下命令进行转换:
|
||||
首先将EAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例( [模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar) ),可以使用如下命令进行转换:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_east
|
||||
|
@ -193,7 +193,7 @@ python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/img
|
|||
```
|
||||
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
|
||||
|
||||
(coming soon)
|
||||
![](../imgs_results/det_res_img_10_east.jpg)
|
||||
|
||||
**注意**:本代码库中,EAST后处理Locality-Aware NMS有python和c++两种版本,c++版速度明显快于python版。由于c++版本nms编译版本问题,只有python3.5环境下会调用c++版nms,其他情况将调用python版nms。
|
||||
|
||||
|
@ -201,7 +201,7 @@ python3 tools/infer/predict_det.py --det_algorithm="EAST" --image_dir="./doc/img
|
|||
<a name="SAST文本检测模型推理"></a>
|
||||
### 4. SAST文本检测模型推理
|
||||
#### (1). 四边形文本检测模型(ICDAR2015)
|
||||
首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址(coming soon)](link)),可以使用如下命令进行转换:
|
||||
首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)),可以使用如下命令进行转换:
|
||||
```
|
||||
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_ic15
|
||||
|
||||
|
@ -212,10 +212,10 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
|
|||
```
|
||||
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
|
||||
|
||||
(coming soon)
|
||||
![](../imgs_results/det_res_img_10_sast.jpg)
|
||||
|
||||
#### (2). 弯曲文本检测模型(Total-Text)
|
||||
首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例([模型下载地址(coming soon)](link)),可以使用如下命令进行转换:
|
||||
首先将SAST文本检测训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)),可以使用如下命令进行转换:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_tt
|
||||
|
@ -228,7 +228,7 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
|
|||
```
|
||||
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
|
||||
|
||||
(coming soon)
|
||||
![](../imgs_results/det_res_img623_sast.jpg)
|
||||
|
||||
**注意**:本代码库中,SAST后处理Locality-Aware NMS有python和c++两种版本,c++版速度明显快于python版。由于c++版本nms编译版本问题,只有python3.5环境下会调用c++版nms,其他情况将调用python版nms。
|
||||
|
||||
|
|
|
@ -19,17 +19,17 @@ On the ICDAR2015 dataset, the text detection result is as follows:
|
|||
|
||||
|Model|Backbone|precision|recall|Hmean|Download link|
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
|EAST|ResNet50_vd|88.18%|85.51%|86.82%|[download link (coming soon)](link)|
|
||||
|EAST|MobileNetV3|81.67%|79.83%|80.74%|[download link (coming soon)](coming soon)|
|
||||
|DB|ResNet50_vd|83.79%|80.65%|82.19%|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|
||||
|DB|MobileNetV3|75.92%|73.18%|74.53%|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|
||||
|SAST|ResNet50_vd|92.18%|82.96%|87.33%|[download link (coming soon)](link)|
|
||||
|EAST|ResNet50_vd|88.76%|81.36%|84.90%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|
||||
|EAST|MobileNetV3|78.24%|79.15%|78.69%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)|
|
||||
|DB|ResNet50_vd|86.41%|78.72%|82.38%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|
||||
|DB|MobileNetV3|77.29%|73.08%|75.12%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|
||||
|SAST|ResNet50_vd|91.83%|81.80%|86.52%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar))|
|
||||
|
||||
On Total-Text dataset, the text detection result is as follows:
|
||||
|
||||
|Model|Backbone|precision|recall|Hmean|Download link|
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
|SAST|ResNet50_vd|88.74%|79.80%|84.03%|[download link (coming soon)](link)|
|
||||
|SAST|ResNet50_vd|89.05%|76.80%|82.47%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)|
|
||||
|
||||
**Note:** Additional data, like icdar2013, icdar2017, COCO-Text, ArT, was added to the model training of SAST. Download English public dataset in organized format used by PaddleOCR from [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (download code: 2bpi).
|
||||
|
||||
|
@ -48,14 +48,10 @@ PaddleOCR open-source text recognition algorithms list:
|
|||
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
|
||||
|
||||
|Model|Backbone|Avg Accuracy|Module combination|Download link|
|
||||
| --- | --- | --- | --- | --- |
|
||||
|Rosetta|MobileNetV3|78.05%|rec_mv3_none_none_ctc|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar)|
|
||||
|Rosetta|Resnet34_vd|80.9%|rec_r34_vd_none_none_ctc|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar)|
|
||||
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|
||||
|CRNN|Resnet34_vd|82.76%|rec_r34_vd_none_bilstm_ctc|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar)|
|
||||
|STAR-Net|MobileNetV3|81.56%|rec_mv3_tps_bilstm_ctc|[download link (coming soon )]()|
|
||||
|STAR-Net|Resnet34_vd|83.93%|rec_r34_vd_tps_bilstm_ctc|[download link (coming soon )]()|
|
||||
|
||||
|
||||
|-|-|-|-|-|
|
||||
|Rosetta|Resnet34_vd|80.9%|rec_r34_vd_none_none_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar)|
|
||||
|Rosetta|MobileNetV3|78.05%|rec_mv3_none_none_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar)|
|
||||
|CRNN|Resnet34_vd|82.76%|rec_r34_vd_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar)|
|
||||
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|
||||
|
||||
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./doc/doc_en/recognition_en.md)
|
||||
|
|
|
@ -187,7 +187,7 @@ The visualized text detection results are saved to the `./inference_results` fol
|
|||
<a name="EAST_DETECTION"></a>
|
||||
### 3. EAST TEXT DETECTION MODEL INFERENCE
|
||||
|
||||
First, convert the model saved in the EAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link (coming soon)](link)), you can use the following command to convert:
|
||||
First, convert the model saved in the EAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)), you can use the following command to convert:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/det/det_r50_vd_east.yml -o Global.pretrained_model=./det_r50_vd_east_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_east
|
||||
|
@ -200,7 +200,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img_10.jpg" --det_
|
|||
|
||||
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
|
||||
|
||||
(coming soon)
|
||||
![](../imgs_results/det_res_img_10_east.jpg)
|
||||
|
||||
**Note**: EAST post-processing locality aware NMS has two versions: Python and C++. The speed of C++ version is obviously faster than that of Python version. Due to the compilation version problem of NMS of C++ version, C++ version NMS will be called only in Python 3.5 environment, and python version NMS will be called in other cases.
|
||||
|
||||
|
@ -208,7 +208,7 @@ The visualized text detection results are saved to the `./inference_results` fol
|
|||
<a name="SAST_DETECTION"></a>
|
||||
### 4. SAST TEXT DETECTION MODEL INFERENCE
|
||||
#### (1). Quadrangle text detection model (ICDAR2015)
|
||||
First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link (coming soon)](link)), you can use the following command to convert:
|
||||
First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)), you can use the following command to convert:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_icdar15.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_ic15
|
||||
|
@ -222,10 +222,10 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
|
|||
|
||||
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
|
||||
|
||||
(coming soon)
|
||||
![](../imgs_results/det_res_img_10_sast.jpg)
|
||||
|
||||
#### (2). Curved text detection model (Total-Text)
|
||||
First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link (coming soon)](https://paddleocr.bj.bcebos.com/SAST/sast_r50_vd_total_text.tar)), you can use the following command to convert:
|
||||
First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the Total-Text English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)), you can use the following command to convert:
|
||||
|
||||
```
|
||||
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/det_sast_tt
|
||||
|
@ -239,7 +239,7 @@ python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/img
|
|||
|
||||
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
|
||||
|
||||
(coming soon)
|
||||
![](../imgs_results/det_res_img623_sast.jpg)
|
||||
|
||||
**Note**: SAST post-processing locality aware NMS has two versions: Python and C++. The speed of C++ version is obviously faster than that of Python version. Due to the compilation version problem of NMS of C++ version, C++ version NMS will be called only in Python 3.5 environment, and python version NMS will be called in other cases.
|
||||
|
||||
|
|
Binary file not shown.
Before Width: | Height: | Size: 126 KiB After Width: | Height: | Size: 125 KiB |
Binary file not shown.
Before Width: | Height: | Size: 331 KiB After Width: | Height: | Size: 332 KiB |
Binary file not shown.
Before Width: | Height: | Size: 333 KiB After Width: | Height: | Size: 332 KiB |
|
@ -19,12 +19,10 @@ from __future__ import print_function
|
|||
import numpy as np
|
||||
from .locality_aware_nms import nms_locality
|
||||
import cv2
|
||||
import paddle
|
||||
|
||||
import os
|
||||
import sys
|
||||
# __dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
# sys.path.append(__dir__)
|
||||
# sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
|
||||
class EASTPostProcess(object):
|
||||
|
@ -113,11 +111,14 @@ class EASTPostProcess(object):
|
|||
def __call__(self, outs_dict, shape_list):
|
||||
score_list = outs_dict['f_score']
|
||||
geo_list = outs_dict['f_geo']
|
||||
if isinstance(score_list, paddle.Tensor):
|
||||
score_list = score_list.numpy()
|
||||
geo_list = geo_list.numpy()
|
||||
img_num = len(shape_list)
|
||||
dt_boxes_list = []
|
||||
for ino in range(img_num):
|
||||
score = score_list[ino].numpy()
|
||||
geo = geo_list[ino].numpy()
|
||||
score = score_list[ino]
|
||||
geo = geo_list[ino]
|
||||
boxes = self.detect(
|
||||
score_map=score,
|
||||
geo_map=geo,
|
||||
|
|
|
@ -24,7 +24,7 @@ sys.path.append(os.path.join(__dir__, '..'))
|
|||
|
||||
import numpy as np
|
||||
from .locality_aware_nms import nms_locality
|
||||
# import lanms
|
||||
import paddle
|
||||
import cv2
|
||||
import time
|
||||
|
||||
|
@ -276,14 +276,19 @@ class SASTPostProcess(object):
|
|||
border_list = outs_dict['f_border']
|
||||
tvo_list = outs_dict['f_tvo']
|
||||
tco_list = outs_dict['f_tco']
|
||||
if isinstance(score_list, paddle.Tensor):
|
||||
score_list = score_list.numpy()
|
||||
border_list = border_list.numpy()
|
||||
tvo_list = tvo_list.numpy()
|
||||
tco_list = tco_list.numpy()
|
||||
|
||||
img_num = len(shape_list)
|
||||
poly_lists = []
|
||||
for ino in range(img_num):
|
||||
p_score = score_list[ino].transpose((1,2,0)).numpy()
|
||||
p_border = border_list[ino].transpose((1,2,0)).numpy()
|
||||
p_tvo = tvo_list[ino].transpose((1,2,0)).numpy()
|
||||
p_tco = tco_list[ino].transpose((1,2,0)).numpy()
|
||||
p_score = score_list[ino].transpose((1,2,0))
|
||||
p_border = border_list[ino].transpose((1,2,0))
|
||||
p_tvo = tvo_list[ino].transpose((1,2,0))
|
||||
p_tco = tco_list[ino].transpose((1,2,0))
|
||||
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
|
||||
|
||||
poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
|
||||
|
|
|
@ -37,33 +37,51 @@ class TextDetector(object):
|
|||
def __init__(self, args):
|
||||
self.det_algorithm = args.det_algorithm
|
||||
self.use_zero_copy_run = args.use_zero_copy_run
|
||||
pre_process_list = [{
|
||||
'DetResizeForTest': {
|
||||
'limit_side_len': args.det_limit_side_len,
|
||||
'limit_type': args.det_limit_type
|
||||
}
|
||||
}, {
|
||||
'NormalizeImage': {
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'scale': '1./255.',
|
||||
'order': 'hwc'
|
||||
}
|
||||
}, {
|
||||
'ToCHWImage': None
|
||||
}, {
|
||||
'KeepKeys': {
|
||||
'keep_keys': ['image', 'shape']
|
||||
}
|
||||
}]
|
||||
postprocess_params = {}
|
||||
if self.det_algorithm == "DB":
|
||||
pre_process_list = [{
|
||||
'DetResizeForTest': {
|
||||
'limit_side_len': args.det_limit_side_len,
|
||||
'limit_type': args.det_limit_type
|
||||
}
|
||||
}, {
|
||||
'NormalizeImage': {
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'scale': '1./255.',
|
||||
'order': 'hwc'
|
||||
}
|
||||
}, {
|
||||
'ToCHWImage': None
|
||||
}, {
|
||||
'KeepKeys': {
|
||||
'keep_keys': ['image', 'shape']
|
||||
}
|
||||
}]
|
||||
postprocess_params['name'] = 'DBPostProcess'
|
||||
postprocess_params["thresh"] = args.det_db_thresh
|
||||
postprocess_params["box_thresh"] = args.det_db_box_thresh
|
||||
postprocess_params["max_candidates"] = 1000
|
||||
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
|
||||
postprocess_params["use_dilation"] = True
|
||||
elif self.det_algorithm == "EAST":
|
||||
postprocess_params['name'] = 'EASTPostProcess'
|
||||
postprocess_params["score_thresh"] = args.det_east_score_thresh
|
||||
postprocess_params["cover_thresh"] = args.det_east_cover_thresh
|
||||
postprocess_params["nms_thresh"] = args.det_east_nms_thresh
|
||||
elif self.det_algorithm == "SAST":
|
||||
postprocess_params['name'] = 'SASTPostProcess'
|
||||
postprocess_params["score_thresh"] = args.det_sast_score_thresh
|
||||
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
|
||||
self.det_sast_polygon = args.det_sast_polygon
|
||||
if self.det_sast_polygon:
|
||||
postprocess_params["sample_pts_num"] = 6
|
||||
postprocess_params["expand_scale"] = 1.2
|
||||
postprocess_params["shrink_ratio_of_width"] = 0.2
|
||||
else:
|
||||
postprocess_params["sample_pts_num"] = 2
|
||||
postprocess_params["expand_scale"] = 1.0
|
||||
postprocess_params["shrink_ratio_of_width"] = 0.3
|
||||
else:
|
||||
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
|
||||
sys.exit(0)
|
||||
|
@ -149,12 +167,25 @@ class TextDetector(object):
|
|||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
preds = outputs[0]
|
||||
|
||||
# preds = self.predictor(img)
|
||||
preds = {}
|
||||
if self.det_algorithm == "EAST":
|
||||
preds['f_geo'] = outputs[0]
|
||||
preds['f_score'] = outputs[1]
|
||||
elif self.det_algorithm == 'SAST':
|
||||
preds['f_border'] = outputs[0]
|
||||
preds['f_score'] = outputs[1]
|
||||
preds['f_tco'] = outputs[2]
|
||||
preds['f_tvo'] = outputs[3]
|
||||
else:
|
||||
preds = outputs[0]
|
||||
|
||||
post_result = self.postprocess_op(preds, shape_list)
|
||||
dt_boxes = post_result[0]['points']
|
||||
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
|
||||
if self.det_algorithm == "SAST" and self.det_sast_polygon:
|
||||
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
|
||||
else:
|
||||
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
|
||||
elapse = time.time() - starttime
|
||||
return dt_boxes, elapse
|
||||
|
||||
|
|
Loading…
Reference in New Issue