Merge branch 'dygraph' into fix_doc
This commit is contained in:
commit
3e289f6acf
|
@ -14,12 +14,13 @@ Global:
|
|||
load_static_weights: True
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/
|
||||
checkpoints:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img:
|
||||
infer_img:
|
||||
save_res_path: ./output/sast_r50_vd_ic15/predicts_sast.txt
|
||||
|
||||
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: SAST
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 600
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/pgnet_r50_vd_totaltext/
|
||||
save_epoch_step: 10
|
||||
# evaluation is run every 0 iterationss after the 1000th iteration
|
||||
eval_batch_step: [ 0, 1000 ]
|
||||
# 1. If pretrained_model is saved in static mode, such as classification pretrained model
|
||||
# from static branch, load_static_weights must be set as True.
|
||||
# 2. If you want to finetune the pretrained models we provide in the docs,
|
||||
# you should set load_static_weights as False.
|
||||
load_static_weights: False
|
||||
cal_metric_during_train: False
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img:
|
||||
valid_set: totaltext # two mode: totaltext valid curved words, partvgg valid non-curved words
|
||||
save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt
|
||||
character_dict_path: ppocr/utils/ic15_dict.txt
|
||||
character_type: EN
|
||||
max_text_length: 50 # the max length in seq
|
||||
max_text_nums: 30 # the max seq nums in a pic
|
||||
tcl_len: 64
|
||||
|
||||
Architecture:
|
||||
model_type: e2e
|
||||
algorithm: PGNet
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 50
|
||||
Neck:
|
||||
name: PGFPN
|
||||
Head:
|
||||
name: PGHead
|
||||
|
||||
Loss:
|
||||
name: PGLoss
|
||||
tcl_bs: 64
|
||||
max_text_length: 50 # the same as Global: max_text_length
|
||||
max_text_nums: 30 # the same as Global:max_text_nums
|
||||
pad_num: 36 # the length of dict for pad
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
learning_rate: 0.001
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
|
||||
PostProcess:
|
||||
name: PGPostProcess
|
||||
score_thresh: 0.5
|
||||
Metric:
|
||||
name: E2EMetric
|
||||
character_dict_path: ppocr/utils/ic15_dict.txt
|
||||
main_indicator: f_score_e2e
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: PGDataSet
|
||||
label_file_list: [.././train_data/total_text/train/]
|
||||
ratio_list: [1.0]
|
||||
data_format: icdar #two data format: icdar/textnet
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- PGProcessTrain:
|
||||
batch_size: 14 # same as loader: batch_size_per_card
|
||||
min_crop_size: 24
|
||||
min_text_size: 4
|
||||
max_text_size: 512
|
||||
- KeepKeys:
|
||||
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: True
|
||||
batch_size_per_card: 14
|
||||
num_workers: 16
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: PGDataSet
|
||||
data_dir: ./train_data/
|
||||
label_file_list: [./train_data/total_text/test/]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
channel_first: False
|
||||
- E2ELabelEncode:
|
||||
- E2EResizeForTest:
|
||||
max_side_len: 768
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [ 0.485, 0.456, 0.406 ]
|
||||
std: [ 0.229, 0.224, 0.225 ]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags' ]
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -12,7 +12,8 @@ inference 模型(`paddle.jit.save`保存的模型)
|
|||
- [一、训练模型转inference模型](#训练模型转inference模型)
|
||||
- [检测模型转inference模型](#检测模型转inference模型)
|
||||
- [识别模型转inference模型](#识别模型转inference模型)
|
||||
- [方向分类模型转inference模型](#方向分类模型转inference模型)
|
||||
- [方向分类模型转inference模型](#方向分类模型转inference模型)
|
||||
- [端到端模型转inference模型](#端到端模型转inference模型)
|
||||
|
||||
- [二、文本检测模型推理](#文本检测模型推理)
|
||||
- [1. 超轻量中文检测模型推理](#超轻量中文检测模型推理)
|
||||
|
@ -27,10 +28,13 @@ inference 模型(`paddle.jit.save`保存的模型)
|
|||
- [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
|
||||
- [5. 多语言模型的推理](#多语言模型的推理)
|
||||
|
||||
- [四、方向分类模型推理](#方向识别模型推理)
|
||||
- [四、端到端模型推理](#端到端模型推理)
|
||||
- [1. PGNet端到端模型推理](#PGNet端到端模型推理)
|
||||
|
||||
- [五、方向分类模型推理](#方向识别模型推理)
|
||||
- [1. 方向分类模型推理](#方向分类模型推理)
|
||||
|
||||
- [五、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理)
|
||||
- [六、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理)
|
||||
- [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理)
|
||||
- [2. 其他模型推理](#其他模型推理)
|
||||
|
||||
|
@ -118,6 +122,32 @@ python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.pretrained_mo
|
|||
├── inference.pdiparams.info # 分类inference模型的参数信息,可忽略
|
||||
└── inference.pdmodel # 分类inference模型的program文件
|
||||
```
|
||||
<a name="端到端模型转inference模型"></a>
|
||||
### 端到端模型转inference模型
|
||||
|
||||
下载端到端模型:
|
||||
```
|
||||
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar && tar xf ./ch_lite/ch_ppocr_mobile_v2.0_cls_train.tar -C ./ch_lite/
|
||||
```
|
||||
|
||||
端到端模型转inference模型与检测的方式相同,如下:
|
||||
```
|
||||
# -c 后面设置训练算法的yml配置文件
|
||||
# -o 配置可选参数
|
||||
# Global.pretrained_model 参数设置待转换的训练模型地址,不用添加文件后缀 .pdmodel,.pdopt或.pdparams。
|
||||
# Global.load_static_weights 参数需要设置为 False。
|
||||
# Global.save_inference_dir参数设置转换的模型将保存的地址。
|
||||
|
||||
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_cls_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e/
|
||||
```
|
||||
|
||||
转换成功后,在目录下有三个文件:
|
||||
```
|
||||
/inference/e2e/
|
||||
├── inference.pdiparams # 分类inference模型的参数文件
|
||||
├── inference.pdiparams.info # 分类inference模型的参数信息,可忽略
|
||||
└── inference.pdmodel # 分类inference模型的program文件
|
||||
```
|
||||
|
||||
<a name="文本检测模型推理"></a>
|
||||
## 二、文本检测模型推理
|
||||
|
@ -332,8 +362,38 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" -
|
|||
Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
|
||||
```
|
||||
|
||||
<a name="端到端模型推理"></a>
|
||||
## 四、端到端模型推理
|
||||
|
||||
端到端模型推理,默认使用PGNet模型的配置参数。当不使用PGNet模型时,在推理时,需要通过传入相应的参数进行算法适配,细节参考下文。
|
||||
<a name="PGNet端到端模型推理"></a>
|
||||
### 1. PGNet端到端模型推理
|
||||
#### (1). 四边形文本检测模型(ICDAR2015)
|
||||
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)),可以使用如下命令进行转换:
|
||||
```
|
||||
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
|
||||
```
|
||||
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**,可以执行如下命令:
|
||||
```
|
||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
|
||||
```
|
||||
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
|
||||
|
||||
![](../imgs_results/e2e_res_img_10_pgnet.jpg)
|
||||
|
||||
#### (2). 弯曲文本检测模型(Total-Text)
|
||||
和四边形文本检测模型共用一个推理模型
|
||||
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**可以执行如下命令:
|
||||
```
|
||||
python3.7 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
|
||||
```
|
||||
可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
|
||||
|
||||
![](../imgs_results/e2e_res_img623_pgnet.jpg)
|
||||
|
||||
|
||||
<a name="方向分类模型推理"></a>
|
||||
## 四、方向分类模型推理
|
||||
## 五、方向分类模型推理
|
||||
|
||||
下面将介绍方向分类模型推理。
|
||||
|
||||
|
@ -358,7 +418,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999982]
|
|||
```
|
||||
|
||||
<a name="文本检测、方向分类和文字识别串联推理"></a>
|
||||
## 五、文本检测、方向分类和文字识别串联推理
|
||||
## 六、文本检测、方向分类和文字识别串联推理
|
||||
<a name="超轻量中文OCR模型推理"></a>
|
||||
### 1. 超轻量中文OCR模型推理
|
||||
|
||||
|
|
|
@ -0,0 +1,176 @@
|
|||
# 端对端OCR算法-PGNet
|
||||
- [一、简介](#简介)
|
||||
- [二、环境配置](#环境配置)
|
||||
- [三、快速使用](#快速使用)
|
||||
- [四、快速训练](#开始训练)
|
||||
- [五、预测推理](#预测推理)
|
||||
|
||||
|
||||
<a name="简介"></a>
|
||||
##简介
|
||||
OCR算法可以分为两阶段算法和端对端的算法。二阶段OCR算法一般分为两个部分,文本检测和文本识别算法,文件检测算法从图像中得到文本行的检测框,然后识别算法去识别文本框中的内容。而端对端OCR算法可以在一个算法中完成文字检测和文字识别,其基本思想是设计一个同时具有检测单元和识别模块的模型,共享其中两者的CNN特征,并联合训练。由于一个算法即可完成文字识别,端对端模型更小,速度更快。
|
||||
|
||||
### PGNet算法介绍
|
||||
近些年来,端对端OCR算法得到了良好的发展,包括MaskTextSpotter系列、TextSnake、TextDragon、PGNet系列等算法。在这些算法中,PGNet算法具备其他算法不具备的优势,包括:
|
||||
- 设计PGNet loss指导训练,不需要字符级别的标注
|
||||
- 不需要NMS和ROI相关操作,加速预测
|
||||
- 提出预测文本行内的阅读顺序模块;
|
||||
- 提出基于图的修正模块(GRM)来进一步提高模型识别性能
|
||||
- 精度更高,预测速度更快
|
||||
|
||||
PGNet算法细节详见[论文](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf), 算法原理图如下所示:
|
||||
![](../pgnet_framework.png)
|
||||
输入图像经过特征提取送入四个分支,分别是:文本边缘偏移量预测TBO模块,文本中心线预测TCL模块,文本方向偏移量预测TDO模块,以及文本字符分类图预测TCC模块。
|
||||
其中TBO以及TCL的输出经过后处理后可以得到文本的检测结果,TCL、TDO、TCC负责文本识别。
|
||||
其检测识别效果图如下:
|
||||
![](../imgs_results/e2e_res_img293_pgnet.png)
|
||||
![](../imgs_results/e2e_res_img295_pgnet.png)
|
||||
|
||||
<a name="环境配置"></a>
|
||||
##环境配置
|
||||
请先参考[快速安装](./installation.md)配置PaddleOCR运行环境。
|
||||
|
||||
*注意:也可以通过 whl 包安装使用PaddleOCR,具体参考[Paddleocr Package使用说明](./whl.md)。*
|
||||
|
||||
<a name="快速使用"></a>
|
||||
##快速使用
|
||||
### inference模型下载
|
||||
本节以训练好的端到端模型为例,快速使用模型预测,首先下载训练好的端到端inference模型[下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.tar)
|
||||
```
|
||||
mkdir inference && cd inference
|
||||
# 下载英文端到端模型并解压
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.tar && tar xf e2e_server_pgnetA_infer.tar
|
||||
```
|
||||
* windows 环境下如果没有安装wget,下载模型时可将链接复制到浏览器中下载,并解压放置在相应目录下
|
||||
|
||||
解压完毕后应有如下文件结构:
|
||||
```
|
||||
├── e2e_server_pgnetA_infer
|
||||
│ ├── inference.pdiparams
|
||||
│ ├── inference.pdiparams.info
|
||||
│ └── inference.pdmodel
|
||||
```
|
||||
### 单张图像或者图像集合预测
|
||||
```bash
|
||||
# 预测image_dir指定的单张图像
|
||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
|
||||
|
||||
# 预测image_dir指定的图像集合
|
||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
|
||||
|
||||
# 如果想使用CPU进行预测,需设置use_gpu参数为False
|
||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True --use_gpu=False
|
||||
```
|
||||
<a name="开始训练"></a>
|
||||
##开始训练
|
||||
本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。
|
||||
###数据形式为icdar, 十六点标注数据
|
||||
解压数据集和下载标注文件后,PaddleOCR/train_data/total_text/train/ 有两个文件夹,分别是:
|
||||
```
|
||||
/PaddleOCR/train_data/total_text/train/
|
||||
|- rgb/ total_text数据集的训练数据
|
||||
|- gt_0.png
|
||||
| ...
|
||||
|- total_text.txt total_text数据集的训练标注
|
||||
```
|
||||
|
||||
提供的标注文件格式如下,中间用"\t"分隔:
|
||||
```
|
||||
" 图像文件名 json.dumps编码的图像标注信息"
|
||||
rgb/gt_0.png [{"transcription": "EST", "points": [[1004.0,689.0],[1019.0,698.0],[1034.0,708.0],[1049.0,718.0],[1064.0,728.0],[1079.0,738.0],[1095.0,748.0],[1094.0,774.0],[1079.0,765.0],[1065.0,756.0],[1050.0,747.0],[1036.0,738.0],[1021.0,729.0],[1007.0,721.0]]}, {...}]
|
||||
```
|
||||
json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 `points` 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。
|
||||
`transcription` 表示当前文本框的文字,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。**
|
||||
如果您想在其他数据集上训练,可以按照上述形式构建标注文件。
|
||||
|
||||
### 快速启动训练
|
||||
|
||||
模型训练一般分两步骤进行,第一步可以选择用合成数据训练,第二步加载第一步训练好的模型训练,这边我们提供了第一步训练好的模型,可以直接加载,从第二步开始训练
|
||||
[下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/train_step1.tar)
|
||||
```shell
|
||||
cd PaddleOCR/
|
||||
下载ResNet50_vd的动态图预训练模型
|
||||
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/train_step1.tar
|
||||
可以得到以下的文件格式
|
||||
./pretrain_models/train_step1/
|
||||
└─ best_accuracy.pdopt
|
||||
└─ best_accuracy.states
|
||||
└─ best_accuracy.pdparams
|
||||
|
||||
```
|
||||
|
||||
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
|
||||
|
||||
```shell
|
||||
# 单机单卡训练 e2e 模型
|
||||
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./pretrain_models/train_step1/best_accuracy Global.load_static_weights=False
|
||||
# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./pretrain_models/train_step1/best_accuracy Global.load_static_weights=False
|
||||
```
|
||||
|
||||
上述指令中,通过-c 选择训练使用configs/e2e/e2e_r50_vd_pg.yml配置文件。
|
||||
有关配置文件的详细解释,请参考[链接](./config.md)。
|
||||
|
||||
您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
|
||||
```shell
|
||||
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Optimizer.base_lr=0.0001
|
||||
```
|
||||
|
||||
#### 断点训练
|
||||
|
||||
如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径:
|
||||
```shell
|
||||
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints=./your/trained/model
|
||||
```
|
||||
|
||||
**注意**:`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。
|
||||
|
||||
<a name="预测推理"></a>
|
||||
## 预测推理
|
||||
|
||||
PaddleOCR计算三个OCR端到端相关的指标,分别是:Precision、Recall、Hmean。
|
||||
|
||||
运行如下代码,根据配置文件`e2e_r50_vd_pg.yml`中`save_res_path`指定的测试集检测结果文件,计算评估指标。
|
||||
|
||||
评估时设置后处理参数`max_side_len=768`,使用不同数据集、不同模型训练,可调整参数进行优化
|
||||
训练中模型参数默认保存在`Global.save_model_dir`目录下。在评估指标时,需要设置`Global.checkpoints`指向保存的参数文件。
|
||||
```shell
|
||||
python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{path/to/weights}/best_accuracy"
|
||||
```
|
||||
|
||||
### 测试端到端效果
|
||||
测试单张图像的端到端识别效果
|
||||
```shell
|
||||
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
|
||||
```
|
||||
|
||||
测试文件夹下所有图像的端到端识别效果
|
||||
```shell
|
||||
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
|
||||
```
|
||||
|
||||
###转为推理模型
|
||||
### (1). 四边形文本检测模型(ICDAR2015)
|
||||
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,以英文数据集训练的模型为例[模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar) ,可以使用如下命令进行转换:
|
||||
```
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
|
||||
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
|
||||
```
|
||||
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**,可以执行如下命令:
|
||||
```
|
||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
|
||||
```
|
||||
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
|
||||
|
||||
![](../imgs_results/e2e_res_img_10_pgnet.jpg)
|
||||
|
||||
### (2). 弯曲文本检测模型(Total-Text)
|
||||
对于弯曲文本样例
|
||||
|
||||
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**可以执行如下命令:
|
||||
```
|
||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
|
||||
```
|
||||
可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
|
||||
|
||||
![](../imgs_results/e2e_res_img623_pgnet.jpg)
|
Binary file not shown.
After Width: | Height: | Size: 663 KiB |
Binary file not shown.
After Width: | Height: | Size: 467 KiB |
Binary file not shown.
After Width: | Height: | Size: 134 KiB |
Binary file not shown.
After Width: | Height: | Size: 337 KiB |
Binary file not shown.
After Width: | Height: | Size: 242 KiB |
|
@ -34,6 +34,7 @@ import paddle.distributed as dist
|
|||
from ppocr.data.imaug import transform, create_operators
|
||||
from ppocr.data.simple_dataset import SimpleDataSet
|
||||
from ppocr.data.lmdb_dataset import LMDBDataSet
|
||||
from ppocr.data.pgnet_dataset import PGDataSet
|
||||
|
||||
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
||||
|
||||
|
@ -54,7 +55,7 @@ signal.signal(signal.SIGTERM, term_mp)
|
|||
def build_dataloader(config, mode, device, logger, seed=None):
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
support_dict = ['SimpleDataSet', 'LMDBDataSet']
|
||||
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet']
|
||||
module_name = config[mode]['dataset']['name']
|
||||
assert module_name in support_dict, Exception(
|
||||
'DataSet only support {}'.format(support_dict))
|
||||
|
@ -72,14 +73,14 @@ def build_dataloader(config, mode, device, logger, seed=None):
|
|||
else:
|
||||
use_shared_memory = True
|
||||
if mode == "Train":
|
||||
#Distribute data to multiple cards
|
||||
# Distribute data to multiple cards
|
||||
batch_sampler = DistributedBatchSampler(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
drop_last=drop_last)
|
||||
else:
|
||||
#Distribute data to single card
|
||||
# Distribute data to single card
|
||||
batch_sampler = BatchSampler(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
|
|
|
@ -28,6 +28,7 @@ from .label_ops import *
|
|||
|
||||
from .east_process import *
|
||||
from .sast_process import *
|
||||
from .pg_process import *
|
||||
|
||||
|
||||
def transform(data, ops=None):
|
||||
|
|
|
@ -187,6 +187,34 @@ class CTCLabelEncode(BaseRecLabelEncode):
|
|||
return dict_character
|
||||
|
||||
|
||||
class E2ELabelEncode(BaseRecLabelEncode):
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
character_type='EN',
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
super(E2ELabelEncode,
|
||||
self).__init__(max_text_length, character_dict_path,
|
||||
character_type, use_space_char)
|
||||
self.pad_num = len(self.dict) # the length to pad
|
||||
|
||||
def __call__(self, data):
|
||||
text_label_index_list, temp_text = [], []
|
||||
texts = data['strs']
|
||||
for text in texts:
|
||||
text = text.lower()
|
||||
temp_text = []
|
||||
for c_ in text:
|
||||
if c_ in self.dict:
|
||||
temp_text.append(self.dict[c_])
|
||||
temp_text = temp_text + [self.pad_num] * (self.max_text_len -
|
||||
len(temp_text))
|
||||
text_label_index_list.append(temp_text)
|
||||
data['strs'] = np.array(text_label_index_list)
|
||||
return data
|
||||
|
||||
|
||||
class AttnLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
|
|
@ -197,7 +197,6 @@ class DetResizeForTest(object):
|
|||
sys.exit(0)
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
# return img, np.array([h, w])
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def resize_image_type2(self, img):
|
||||
|
@ -206,7 +205,6 @@ class DetResizeForTest(object):
|
|||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
# Fix the longer side
|
||||
if resize_h > resize_w:
|
||||
ratio = float(self.resize_long) / resize_h
|
||||
else:
|
||||
|
@ -223,3 +221,72 @@ class DetResizeForTest(object):
|
|||
ratio_w = resize_w / float(w)
|
||||
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
|
||||
class E2EResizeForTest(object):
|
||||
def __init__(self, **kwargs):
|
||||
super(E2EResizeForTest, self).__init__()
|
||||
self.max_side_len = kwargs['max_side_len']
|
||||
self.valid_set = kwargs['valid_set']
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
src_h, src_w, _ = img.shape
|
||||
if self.valid_set == 'totaltext':
|
||||
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
|
||||
img, max_side_len=self.max_side_len)
|
||||
else:
|
||||
im_resized, (ratio_h, ratio_w) = self.resize_image(
|
||||
img, max_side_len=self.max_side_len)
|
||||
data['image'] = im_resized
|
||||
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
|
||||
return data
|
||||
|
||||
def resize_image_for_totaltext(self, im, max_side_len=512):
|
||||
|
||||
h, w, _ = im.shape
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
ratio = 1.25
|
||||
if h * ratio > max_side_len:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
def resize_image(self, im, max_side_len=512):
|
||||
"""
|
||||
resize image to a size multiple of max_stride which is required by the network
|
||||
:param im: the resized image
|
||||
:param max_side_len: limit of max image size to avoid out of memory in gpu
|
||||
:return: the resized image and the resize ratio
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
# Fix the longer side
|
||||
if resize_h > resize_w:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
else:
|
||||
ratio = float(max_side_len) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
|
|
@ -0,0 +1,906 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['PGProcessTrain']
|
||||
|
||||
|
||||
class PGProcessTrain(object):
|
||||
def __init__(self,
|
||||
character_dict_path,
|
||||
max_text_length,
|
||||
max_text_nums,
|
||||
tcl_len,
|
||||
batch_size=14,
|
||||
min_crop_size=24,
|
||||
min_text_size=4,
|
||||
max_text_size=512,
|
||||
**kwargs):
|
||||
self.tcl_len = tcl_len
|
||||
self.max_text_length = max_text_length
|
||||
self.max_text_nums = max_text_nums
|
||||
self.batch_size = batch_size
|
||||
self.min_crop_size = min_crop_size
|
||||
self.min_text_size = min_text_size
|
||||
self.max_text_size = max_text_size
|
||||
self.Lexicon_Table = self.get_dict(character_dict_path)
|
||||
self.pad_num = len(self.Lexicon_Table)
|
||||
self.img_id = 0
|
||||
|
||||
def get_dict(self, character_dict_path):
|
||||
character_str = ""
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||
character_str += line
|
||||
dict_character = list(character_str)
|
||||
return dict_character
|
||||
|
||||
def quad_area(self, poly):
|
||||
"""
|
||||
compute area of a polygon
|
||||
:param poly:
|
||||
:return:
|
||||
"""
|
||||
edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
|
||||
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
|
||||
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
|
||||
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
|
||||
return np.sum(edge) / 2.
|
||||
|
||||
def gen_quad_from_poly(self, poly):
|
||||
"""
|
||||
Generate min area quad from poly.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
||||
rect = cv2.minAreaRect(poly.astype(
|
||||
np.int32)) # (center (x,y), (width, height), angle of rotation)
|
||||
box = np.array(cv2.boxPoints(rect))
|
||||
|
||||
first_point_idx = 0
|
||||
min_dist = 1e4
|
||||
for i in range(4):
|
||||
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
|
||||
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
|
||||
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
|
||||
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
first_point_idx = i
|
||||
for i in range(4):
|
||||
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
||||
|
||||
return min_area_quad
|
||||
|
||||
def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
|
||||
"""
|
||||
check so that the text poly is in the same direction,
|
||||
and also filter some invalid polygons
|
||||
:param polys:
|
||||
:param tags:
|
||||
:return:
|
||||
"""
|
||||
(h, w) = xxx_todo_changeme
|
||||
if polys.shape[0] == 0:
|
||||
return polys, np.array([]), np.array([])
|
||||
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
|
||||
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
|
||||
|
||||
validated_polys = []
|
||||
validated_tags = []
|
||||
hv_tags = []
|
||||
for poly, tag in zip(polys, tags):
|
||||
quad = self.gen_quad_from_poly(poly)
|
||||
p_area = self.quad_area(quad)
|
||||
if abs(p_area) < 1:
|
||||
print('invalid poly')
|
||||
continue
|
||||
if p_area > 0:
|
||||
if tag == False:
|
||||
print('poly in wrong direction')
|
||||
tag = True # reversed cases should be ignore
|
||||
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
|
||||
1), :]
|
||||
quad = quad[(0, 3, 2, 1), :]
|
||||
|
||||
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
|
||||
quad[2])
|
||||
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
|
||||
quad[2])
|
||||
hv_tag = 1
|
||||
|
||||
if len_w * 2.0 < len_h:
|
||||
hv_tag = 0
|
||||
|
||||
validated_polys.append(poly)
|
||||
validated_tags.append(tag)
|
||||
hv_tags.append(hv_tag)
|
||||
return np.array(validated_polys), np.array(validated_tags), np.array(
|
||||
hv_tags)
|
||||
|
||||
def crop_area(self,
|
||||
im,
|
||||
polys,
|
||||
tags,
|
||||
hv_tags,
|
||||
txts,
|
||||
crop_background=False,
|
||||
max_tries=25):
|
||||
"""
|
||||
make random crop from the input image
|
||||
:param im:
|
||||
:param polys: [b,4,2]
|
||||
:param tags:
|
||||
:param crop_background:
|
||||
:param max_tries: 50 -> 25
|
||||
:return:
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
pad_h = h // 10
|
||||
pad_w = w // 10
|
||||
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
||||
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
||||
for poly in polys:
|
||||
poly = np.round(poly, decimals=0).astype(np.int32)
|
||||
minx = np.min(poly[:, 0])
|
||||
maxx = np.max(poly[:, 0])
|
||||
w_array[minx + pad_w:maxx + pad_w] = 1
|
||||
miny = np.min(poly[:, 1])
|
||||
maxy = np.max(poly[:, 1])
|
||||
h_array[miny + pad_h:maxy + pad_h] = 1
|
||||
# ensure the cropped area not across a text
|
||||
h_axis = np.where(h_array == 0)[0]
|
||||
w_axis = np.where(w_array == 0)[0]
|
||||
if len(h_axis) == 0 or len(w_axis) == 0:
|
||||
return im, polys, tags, hv_tags, txts
|
||||
for i in range(max_tries):
|
||||
xx = np.random.choice(w_axis, size=2)
|
||||
xmin = np.min(xx) - pad_w
|
||||
xmax = np.max(xx) - pad_w
|
||||
xmin = np.clip(xmin, 0, w - 1)
|
||||
xmax = np.clip(xmax, 0, w - 1)
|
||||
yy = np.random.choice(h_axis, size=2)
|
||||
ymin = np.min(yy) - pad_h
|
||||
ymax = np.max(yy) - pad_h
|
||||
ymin = np.clip(ymin, 0, h - 1)
|
||||
ymax = np.clip(ymax, 0, h - 1)
|
||||
if xmax - xmin < self.min_crop_size or \
|
||||
ymax - ymin < self.min_crop_size:
|
||||
continue
|
||||
if polys.shape[0] != 0:
|
||||
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
|
||||
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
|
||||
selected_polys = np.where(
|
||||
np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
||||
else:
|
||||
selected_polys = []
|
||||
if len(selected_polys) == 0:
|
||||
# no text in this area
|
||||
if crop_background:
|
||||
txts_tmp = []
|
||||
for selected_poly in selected_polys:
|
||||
txts_tmp.append(txts[selected_poly])
|
||||
txts = txts_tmp
|
||||
return im[ymin: ymax + 1, xmin: xmax + 1, :], \
|
||||
polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
|
||||
else:
|
||||
continue
|
||||
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
|
||||
polys = polys[selected_polys]
|
||||
tags = tags[selected_polys]
|
||||
hv_tags = hv_tags[selected_polys]
|
||||
txts_tmp = []
|
||||
for selected_poly in selected_polys:
|
||||
txts_tmp.append(txts[selected_poly])
|
||||
txts = txts_tmp
|
||||
polys[:, :, 0] -= xmin
|
||||
polys[:, :, 1] -= ymin
|
||||
return im, polys, tags, hv_tags, txts
|
||||
|
||||
return im, polys, tags, hv_tags, txts
|
||||
|
||||
def fit_and_gather_tcl_points_v2(self,
|
||||
min_area_quad,
|
||||
poly,
|
||||
max_h,
|
||||
max_w,
|
||||
fixed_point_num=64,
|
||||
img_id=0,
|
||||
reference_height=3):
|
||||
"""
|
||||
Find the center point of poly as key_points, then fit and gather.
|
||||
"""
|
||||
key_point_xys = []
|
||||
point_num = poly.shape[0]
|
||||
for idx in range(point_num // 2):
|
||||
center_point = (poly[idx] + poly[point_num - 1 - idx]) / 2.0
|
||||
key_point_xys.append(center_point)
|
||||
|
||||
tmp_image = np.zeros(
|
||||
shape=(
|
||||
max_h,
|
||||
max_w, ), dtype='float32')
|
||||
cv2.polylines(tmp_image, [np.array(key_point_xys).astype('int32')],
|
||||
False, 1.0)
|
||||
ys, xs = np.where(tmp_image > 0)
|
||||
xy_text = np.array(list(zip(xs, ys)), dtype='float32')
|
||||
|
||||
left_center_pt = (
|
||||
(min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2)
|
||||
right_center_pt = (
|
||||
(min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2)
|
||||
proj_unit_vec = (right_center_pt - left_center_pt) / (
|
||||
np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
|
||||
proj_unit_vec_tile = np.tile(proj_unit_vec,
|
||||
(xy_text.shape[0], 1)) # (n, 2)
|
||||
left_center_pt_tile = np.tile(left_center_pt,
|
||||
(xy_text.shape[0], 1)) # (n, 2)
|
||||
xy_text_to_left_center = xy_text - left_center_pt_tile
|
||||
proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1)
|
||||
xy_text = xy_text[np.argsort(proj_value)]
|
||||
|
||||
# convert to np and keep the num of point not greater then fixed_point_num
|
||||
pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1] # xy-> yx
|
||||
point_num = len(pos_info)
|
||||
if point_num > fixed_point_num:
|
||||
keep_ids = [
|
||||
int((point_num * 1.0 / fixed_point_num) * x)
|
||||
for x in range(fixed_point_num)
|
||||
]
|
||||
pos_info = pos_info[keep_ids, :]
|
||||
|
||||
keep = int(min(len(pos_info), fixed_point_num))
|
||||
if np.random.rand() < 0.2 and reference_height >= 3:
|
||||
dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3
|
||||
random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape(
|
||||
[keep, 1])
|
||||
pos_info += random_float
|
||||
pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
|
||||
pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
|
||||
|
||||
# padding to fixed length
|
||||
pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
|
||||
pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
|
||||
pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
|
||||
pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
|
||||
pos_m[:keep] = 1.0
|
||||
return pos_l, pos_m
|
||||
|
||||
def generate_direction_map(self, poly_quads, n_char, direction_map):
|
||||
"""
|
||||
"""
|
||||
width_list = []
|
||||
height_list = []
|
||||
for quad in poly_quads:
|
||||
quad_w = (np.linalg.norm(quad[0] - quad[1]) +
|
||||
np.linalg.norm(quad[2] - quad[3])) / 2.0
|
||||
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
|
||||
np.linalg.norm(quad[2] - quad[1])) / 2.0
|
||||
width_list.append(quad_w)
|
||||
height_list.append(quad_h)
|
||||
norm_width = max(sum(width_list) / n_char, 1.0)
|
||||
average_height = max(sum(height_list) / len(height_list), 1.0)
|
||||
k = 1
|
||||
for quad in poly_quads:
|
||||
direct_vector_full = (
|
||||
(quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
|
||||
direct_vector = direct_vector_full / (
|
||||
np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
|
||||
direction_label = tuple(
|
||||
map(float,
|
||||
[direct_vector[0], direct_vector[1], 1.0 / average_height]))
|
||||
cv2.fillPoly(direction_map,
|
||||
quad.round().astype(np.int32)[np.newaxis, :, :],
|
||||
direction_label)
|
||||
k += 1
|
||||
return direction_map
|
||||
|
||||
def calculate_average_height(self, poly_quads):
|
||||
"""
|
||||
"""
|
||||
height_list = []
|
||||
for quad in poly_quads:
|
||||
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
|
||||
np.linalg.norm(quad[2] - quad[1])) / 2.0
|
||||
height_list.append(quad_h)
|
||||
average_height = max(sum(height_list) / len(height_list), 1.0)
|
||||
return average_height
|
||||
|
||||
def generate_tcl_ctc_label(self,
|
||||
h,
|
||||
w,
|
||||
polys,
|
||||
tags,
|
||||
text_strs,
|
||||
ds_ratio,
|
||||
tcl_ratio=0.3,
|
||||
shrink_ratio_of_width=0.15):
|
||||
"""
|
||||
Generate polygon.
|
||||
"""
|
||||
score_map_big = np.zeros(
|
||||
(
|
||||
h,
|
||||
w, ), dtype=np.float32)
|
||||
h, w = int(h * ds_ratio), int(w * ds_ratio)
|
||||
polys = polys * ds_ratio
|
||||
|
||||
score_map = np.zeros(
|
||||
(
|
||||
h,
|
||||
w, ), dtype=np.float32)
|
||||
score_label_map = np.zeros(
|
||||
(
|
||||
h,
|
||||
w, ), dtype=np.float32)
|
||||
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
|
||||
training_mask = np.ones(
|
||||
(
|
||||
h,
|
||||
w, ), dtype=np.float32)
|
||||
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
|
||||
[1, 1, 3]).astype(np.float32)
|
||||
|
||||
label_idx = 0
|
||||
score_label_map_text_label_list = []
|
||||
pos_list, pos_mask, label_list = [], [], []
|
||||
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
|
||||
poly = poly_tag[0]
|
||||
tag = poly_tag[1]
|
||||
|
||||
# generate min_area_quad
|
||||
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
|
||||
min_area_quad_h = 0.5 * (
|
||||
np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
|
||||
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
|
||||
min_area_quad_w = 0.5 * (
|
||||
np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
|
||||
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
|
||||
|
||||
if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
|
||||
or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
|
||||
continue
|
||||
|
||||
if tag:
|
||||
cv2.fillPoly(training_mask,
|
||||
poly.astype(np.int32)[np.newaxis, :, :], 0.15)
|
||||
else:
|
||||
text_label = text_strs[poly_idx]
|
||||
text_label = self.prepare_text_label(text_label,
|
||||
self.Lexicon_Table)
|
||||
|
||||
text_label_index_list = [[self.Lexicon_Table.index(c_)]
|
||||
for c_ in text_label
|
||||
if c_ in self.Lexicon_Table]
|
||||
if len(text_label_index_list) < 1:
|
||||
continue
|
||||
|
||||
tcl_poly = self.poly2tcl(poly, tcl_ratio)
|
||||
tcl_quads = self.poly2quads(tcl_poly)
|
||||
poly_quads = self.poly2quads(poly)
|
||||
|
||||
stcl_quads, quad_index = self.shrink_poly_along_width(
|
||||
tcl_quads,
|
||||
shrink_ratio_of_width=shrink_ratio_of_width,
|
||||
expand_height_ratio=1.0 / tcl_ratio)
|
||||
|
||||
cv2.fillPoly(score_map,
|
||||
np.round(stcl_quads).astype(np.int32), 1.0)
|
||||
cv2.fillPoly(score_map_big,
|
||||
np.round(stcl_quads / ds_ratio).astype(np.int32),
|
||||
1.0)
|
||||
|
||||
for idx, quad in enumerate(stcl_quads):
|
||||
quad_mask = np.zeros((h, w), dtype=np.float32)
|
||||
quad_mask = cv2.fillPoly(
|
||||
quad_mask,
|
||||
np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
|
||||
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
|
||||
quad_mask, tbo_map)
|
||||
|
||||
# score label map and score_label_map_text_label_list for refine
|
||||
if label_idx == 0:
|
||||
text_pos_list_ = [[len(self.Lexicon_Table)], ]
|
||||
score_label_map_text_label_list.append(text_pos_list_)
|
||||
|
||||
label_idx += 1
|
||||
cv2.fillPoly(score_label_map,
|
||||
np.round(poly_quads).astype(np.int32), label_idx)
|
||||
score_label_map_text_label_list.append(text_label_index_list)
|
||||
|
||||
# direction info, fix-me
|
||||
n_char = len(text_label_index_list)
|
||||
direction_map = self.generate_direction_map(poly_quads, n_char,
|
||||
direction_map)
|
||||
|
||||
# pos info
|
||||
average_shrink_height = self.calculate_average_height(
|
||||
stcl_quads)
|
||||
pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
|
||||
min_area_quad,
|
||||
poly,
|
||||
max_h=h,
|
||||
max_w=w,
|
||||
fixed_point_num=64,
|
||||
img_id=self.img_id,
|
||||
reference_height=average_shrink_height)
|
||||
|
||||
label_l = text_label_index_list
|
||||
if len(text_label_index_list) < 2:
|
||||
continue
|
||||
|
||||
pos_list.append(pos_l)
|
||||
pos_mask.append(pos_m)
|
||||
label_list.append(label_l)
|
||||
|
||||
# use big score_map for smooth tcl lines
|
||||
score_map_big_resized = cv2.resize(
|
||||
score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio)
|
||||
score_map = np.array(score_map_big_resized > 1e-3, dtype='float32')
|
||||
|
||||
return score_map, score_label_map, tbo_map, direction_map, training_mask, \
|
||||
pos_list, pos_mask, label_list, score_label_map_text_label_list
|
||||
|
||||
def adjust_point(self, poly):
|
||||
"""
|
||||
adjust point order.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
if point_num == 4:
|
||||
len_1 = np.linalg.norm(poly[0] - poly[1])
|
||||
len_2 = np.linalg.norm(poly[1] - poly[2])
|
||||
len_3 = np.linalg.norm(poly[2] - poly[3])
|
||||
len_4 = np.linalg.norm(poly[3] - poly[0])
|
||||
|
||||
if (len_1 + len_3) * 1.5 < (len_2 + len_4):
|
||||
poly = poly[[1, 2, 3, 0], :]
|
||||
|
||||
elif point_num > 4:
|
||||
vector_1 = poly[0] - poly[1]
|
||||
vector_2 = poly[1] - poly[2]
|
||||
cos_theta = np.dot(vector_1, vector_2) / (
|
||||
np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
|
||||
theta = np.arccos(np.round(cos_theta, decimals=4))
|
||||
|
||||
if abs(theta) > (70 / 180 * math.pi):
|
||||
index = list(range(1, point_num)) + [0]
|
||||
poly = poly[np.array(index), :]
|
||||
return poly
|
||||
|
||||
def gen_min_area_quad_from_poly(self, poly):
|
||||
"""
|
||||
Generate min area quad from poly.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
||||
if point_num == 4:
|
||||
min_area_quad = poly
|
||||
center_point = np.sum(poly, axis=0) / 4
|
||||
else:
|
||||
rect = cv2.minAreaRect(poly.astype(
|
||||
np.int32)) # (center (x,y), (width, height), angle of rotation)
|
||||
center_point = rect[0]
|
||||
box = np.array(cv2.boxPoints(rect))
|
||||
|
||||
first_point_idx = 0
|
||||
min_dist = 1e4
|
||||
for i in range(4):
|
||||
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
|
||||
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
|
||||
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
|
||||
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
first_point_idx = i
|
||||
|
||||
for i in range(4):
|
||||
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
||||
|
||||
return min_area_quad, center_point
|
||||
|
||||
def shrink_quad_along_width(self,
|
||||
quad,
|
||||
begin_width_ratio=0.,
|
||||
end_width_ratio=1.):
|
||||
"""
|
||||
Generate shrink_quad_along_width.
|
||||
"""
|
||||
ratio_pair = np.array(
|
||||
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
||||
|
||||
def shrink_poly_along_width(self,
|
||||
quads,
|
||||
shrink_ratio_of_width,
|
||||
expand_height_ratio=1.0):
|
||||
"""
|
||||
shrink poly with given length.
|
||||
"""
|
||||
upper_edge_list = []
|
||||
|
||||
def get_cut_info(edge_len_list, cut_len):
|
||||
for idx, edge_len in enumerate(edge_len_list):
|
||||
cut_len -= edge_len
|
||||
if cut_len <= 0.000001:
|
||||
ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
|
||||
return idx, ratio
|
||||
|
||||
for quad in quads:
|
||||
upper_edge_len = np.linalg.norm(quad[0] - quad[1])
|
||||
upper_edge_list.append(upper_edge_len)
|
||||
|
||||
# length of left edge and right edge.
|
||||
left_length = np.linalg.norm(quads[0][0] - quads[0][
|
||||
3]) * expand_height_ratio
|
||||
right_length = np.linalg.norm(quads[-1][1] - quads[-1][
|
||||
2]) * expand_height_ratio
|
||||
|
||||
shrink_length = min(left_length, right_length,
|
||||
sum(upper_edge_list)) * shrink_ratio_of_width
|
||||
# shrinking length
|
||||
upper_len_left = shrink_length
|
||||
upper_len_right = sum(upper_edge_list) - shrink_length
|
||||
|
||||
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
|
||||
left_quad = self.shrink_quad_along_width(
|
||||
quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
|
||||
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
|
||||
right_quad = self.shrink_quad_along_width(
|
||||
quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
|
||||
|
||||
out_quad_list = []
|
||||
if left_idx == right_idx:
|
||||
out_quad_list.append(
|
||||
[left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
|
||||
else:
|
||||
out_quad_list.append(left_quad)
|
||||
for idx in range(left_idx + 1, right_idx):
|
||||
out_quad_list.append(quads[idx])
|
||||
out_quad_list.append(right_quad)
|
||||
|
||||
return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
|
||||
|
||||
def prepare_text_label(self, label_str, Lexicon_Table):
|
||||
"""
|
||||
Prepare text lablel by given Lexicon_Table.
|
||||
"""
|
||||
if len(Lexicon_Table) == 36:
|
||||
return label_str.lower()
|
||||
else:
|
||||
return label_str
|
||||
|
||||
def vector_angle(self, A, B):
|
||||
"""
|
||||
Calculate the angle between vector AB and x-axis positive direction.
|
||||
"""
|
||||
AB = np.array([B[1] - A[1], B[0] - A[0]])
|
||||
return np.arctan2(*AB)
|
||||
|
||||
def theta_line_cross_point(self, theta, point):
|
||||
"""
|
||||
Calculate the line through given point and angle in ax + by + c =0 form.
|
||||
"""
|
||||
x, y = point
|
||||
cos = np.cos(theta)
|
||||
sin = np.sin(theta)
|
||||
return [sin, -cos, cos * y - sin * x]
|
||||
|
||||
def line_cross_two_point(self, A, B):
|
||||
"""
|
||||
Calculate the line through given point A and B in ax + by + c =0 form.
|
||||
"""
|
||||
angle = self.vector_angle(A, B)
|
||||
return self.theta_line_cross_point(angle, A)
|
||||
|
||||
def average_angle(self, poly):
|
||||
"""
|
||||
Calculate the average angle between left and right edge in given poly.
|
||||
"""
|
||||
p0, p1, p2, p3 = poly
|
||||
angle30 = self.vector_angle(p3, p0)
|
||||
angle21 = self.vector_angle(p2, p1)
|
||||
return (angle30 + angle21) / 2
|
||||
|
||||
def line_cross_point(self, line1, line2):
|
||||
"""
|
||||
line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2
|
||||
"""
|
||||
a1, b1, c1 = line1
|
||||
a2, b2, c2 = line2
|
||||
d = a1 * b2 - a2 * b1
|
||||
|
||||
if d == 0:
|
||||
print('Cross point does not exist')
|
||||
return np.array([0, 0], dtype=np.float32)
|
||||
else:
|
||||
x = (b1 * c2 - b2 * c1) / d
|
||||
y = (a2 * c1 - a1 * c2) / d
|
||||
|
||||
return np.array([x, y], dtype=np.float32)
|
||||
|
||||
def quad2tcl(self, poly, ratio):
|
||||
"""
|
||||
Generate center line by poly clock-wise point. (4, 2)
|
||||
"""
|
||||
ratio_pair = np.array(
|
||||
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
||||
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
|
||||
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
|
||||
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
|
||||
|
||||
def poly2tcl(self, poly, ratio):
|
||||
"""
|
||||
Generate center line by poly clock-wise point.
|
||||
"""
|
||||
ratio_pair = np.array(
|
||||
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
||||
tcl_poly = np.zeros_like(poly)
|
||||
point_num = poly.shape[0]
|
||||
|
||||
for idx in range(point_num // 2):
|
||||
point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
|
||||
) * ratio_pair
|
||||
tcl_poly[idx] = point_pair[0]
|
||||
tcl_poly[point_num - 1 - idx] = point_pair[1]
|
||||
return tcl_poly
|
||||
|
||||
def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
|
||||
"""
|
||||
Generate tbo_map for give quad.
|
||||
"""
|
||||
# upper and lower line function: ax + by + c = 0;
|
||||
up_line = self.line_cross_two_point(quad[0], quad[1])
|
||||
lower_line = self.line_cross_two_point(quad[3], quad[2])
|
||||
|
||||
quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
|
||||
np.linalg.norm(quad[1] - quad[2]))
|
||||
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
|
||||
np.linalg.norm(quad[2] - quad[3]))
|
||||
|
||||
# average angle of left and right line.
|
||||
angle = self.average_angle(quad)
|
||||
|
||||
xy_in_poly = np.argwhere(tcl_mask == 1)
|
||||
for y, x in xy_in_poly:
|
||||
point = (x, y)
|
||||
line = self.theta_line_cross_point(angle, point)
|
||||
cross_point_upper = self.line_cross_point(up_line, line)
|
||||
cross_point_lower = self.line_cross_point(lower_line, line)
|
||||
##FIX, offset reverse
|
||||
upper_offset_x, upper_offset_y = cross_point_upper - point
|
||||
lower_offset_x, lower_offset_y = cross_point_lower - point
|
||||
tbo_map[y, x, 0] = upper_offset_y
|
||||
tbo_map[y, x, 1] = upper_offset_x
|
||||
tbo_map[y, x, 2] = lower_offset_y
|
||||
tbo_map[y, x, 3] = lower_offset_x
|
||||
tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
|
||||
return tbo_map
|
||||
|
||||
def poly2quads(self, poly):
|
||||
"""
|
||||
Split poly into quads.
|
||||
"""
|
||||
quad_list = []
|
||||
point_num = poly.shape[0]
|
||||
|
||||
# point pair
|
||||
point_pair_list = []
|
||||
for idx in range(point_num // 2):
|
||||
point_pair = [poly[idx], poly[point_num - 1 - idx]]
|
||||
point_pair_list.append(point_pair)
|
||||
|
||||
quad_num = point_num // 2 - 1
|
||||
for idx in range(quad_num):
|
||||
# reshape and adjust to clock-wise
|
||||
quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
|
||||
).reshape(4, 2)[[0, 2, 3, 1]])
|
||||
|
||||
return np.array(quad_list)
|
||||
|
||||
def rotate_im_poly(self, im, text_polys):
|
||||
"""
|
||||
rotate image with 90 / 180 / 270 degre
|
||||
"""
|
||||
im_w, im_h = im.shape[1], im.shape[0]
|
||||
dst_im = im.copy()
|
||||
dst_polys = []
|
||||
rand_degree_ratio = np.random.rand()
|
||||
rand_degree_cnt = 1
|
||||
if rand_degree_ratio > 0.5:
|
||||
rand_degree_cnt = 3
|
||||
for i in range(rand_degree_cnt):
|
||||
dst_im = np.rot90(dst_im)
|
||||
rot_degree = -90 * rand_degree_cnt
|
||||
rot_angle = rot_degree * math.pi / 180.0
|
||||
n_poly = text_polys.shape[0]
|
||||
cx, cy = 0.5 * im_w, 0.5 * im_h
|
||||
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
|
||||
for i in range(n_poly):
|
||||
wordBB = text_polys[i]
|
||||
poly = []
|
||||
for j in range(4): # 16->4
|
||||
sx, sy = wordBB[j][0], wordBB[j][1]
|
||||
dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * (
|
||||
sy - cy) + ncx
|
||||
dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * (
|
||||
sy - cy) + ncy
|
||||
poly.append([dx, dy])
|
||||
dst_polys.append(poly)
|
||||
return dst_im, np.array(dst_polys, dtype=np.float32)
|
||||
|
||||
def __call__(self, data):
|
||||
input_size = 512
|
||||
im = data['image']
|
||||
text_polys = data['polys']
|
||||
text_tags = data['tags']
|
||||
text_strs = data['strs']
|
||||
h, w, _ = im.shape
|
||||
text_polys, text_tags, hv_tags = self.check_and_validate_polys(
|
||||
text_polys, text_tags, (h, w))
|
||||
if text_polys.shape[0] <= 0:
|
||||
return None
|
||||
# set aspect ratio and keep area fix
|
||||
asp_scales = np.arange(1.0, 1.55, 0.1)
|
||||
asp_scale = np.random.choice(asp_scales)
|
||||
if np.random.rand() < 0.5:
|
||||
asp_scale = 1.0 / asp_scale
|
||||
asp_scale = math.sqrt(asp_scale)
|
||||
|
||||
asp_wx = asp_scale
|
||||
asp_hy = 1.0 / asp_scale
|
||||
im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
|
||||
text_polys[:, :, 0] *= asp_wx
|
||||
text_polys[:, :, 1] *= asp_hy
|
||||
|
||||
h, w, _ = im.shape
|
||||
if max(h, w) > 2048:
|
||||
rd_scale = 2048.0 / max(h, w)
|
||||
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
|
||||
text_polys *= rd_scale
|
||||
h, w, _ = im.shape
|
||||
if min(h, w) < 16:
|
||||
return None
|
||||
|
||||
# no background
|
||||
im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(
|
||||
im,
|
||||
text_polys,
|
||||
text_tags,
|
||||
hv_tags,
|
||||
text_strs,
|
||||
crop_background=False)
|
||||
|
||||
if text_polys.shape[0] == 0:
|
||||
return None
|
||||
# # continue for all ignore case
|
||||
if np.sum((text_tags * 1.0)) >= text_tags.size:
|
||||
return None
|
||||
new_h, new_w, _ = im.shape
|
||||
if (new_h is None) or (new_w is None):
|
||||
return None
|
||||
# resize image
|
||||
std_ratio = float(input_size) / max(new_w, new_h)
|
||||
rand_scales = np.array(
|
||||
[0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
|
||||
rz_scale = std_ratio * np.random.choice(rand_scales)
|
||||
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
|
||||
text_polys[:, :, 0] *= rz_scale
|
||||
text_polys[:, :, 1] *= rz_scale
|
||||
|
||||
# add gaussian blur
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
ks = np.random.permutation(5)[0] + 1
|
||||
ks = int(ks / 2) * 2 + 1
|
||||
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
|
||||
# add brighter
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
im = im * (1.0 + np.random.rand() * 0.5)
|
||||
im = np.clip(im, 0.0, 255.0)
|
||||
# add darker
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
im = im * (1.0 - np.random.rand() * 0.5)
|
||||
im = np.clip(im, 0.0, 255.0)
|
||||
|
||||
# Padding the im to [input_size, input_size]
|
||||
new_h, new_w, _ = im.shape
|
||||
if min(new_w, new_h) < input_size * 0.5:
|
||||
return None
|
||||
im_padded = np.ones((input_size, input_size, 3), dtype=np.float32)
|
||||
im_padded[:, :, 2] = 0.485 * 255
|
||||
im_padded[:, :, 1] = 0.456 * 255
|
||||
im_padded[:, :, 0] = 0.406 * 255
|
||||
|
||||
# Random the start position
|
||||
del_h = input_size - new_h
|
||||
del_w = input_size - new_w
|
||||
sh, sw = 0, 0
|
||||
if del_h > 1:
|
||||
sh = int(np.random.rand() * del_h)
|
||||
if del_w > 1:
|
||||
sw = int(np.random.rand() * del_w)
|
||||
|
||||
# Padding
|
||||
im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
|
||||
text_polys[:, :, 0] += sw
|
||||
text_polys[:, :, 1] += sh
|
||||
|
||||
score_map, score_label_map, border_map, direction_map, training_mask, \
|
||||
pos_list, pos_mask, label_list, score_label_map_text_label = self.generate_tcl_ctc_label(input_size,
|
||||
input_size,
|
||||
text_polys,
|
||||
text_tags,
|
||||
text_strs, 0.25)
|
||||
if len(label_list) <= 0: # eliminate negative samples
|
||||
return None
|
||||
pos_list_temp = np.zeros([64, 3])
|
||||
pos_mask_temp = np.zeros([64, 1])
|
||||
label_list_temp = np.zeros([self.max_text_length, 1]) + self.pad_num
|
||||
|
||||
for i, label in enumerate(label_list):
|
||||
n = len(label)
|
||||
if n > self.max_text_length:
|
||||
label_list[i] = label[:self.max_text_length]
|
||||
continue
|
||||
while n < self.max_text_length:
|
||||
label.append([self.pad_num])
|
||||
n += 1
|
||||
|
||||
for i in range(len(label_list)):
|
||||
label_list[i] = np.array(label_list[i])
|
||||
|
||||
if len(pos_list) <= 0 or len(pos_list) > self.max_text_nums:
|
||||
return None
|
||||
for __ in range(self.max_text_nums - len(pos_list), 0, -1):
|
||||
pos_list.append(pos_list_temp)
|
||||
pos_mask.append(pos_mask_temp)
|
||||
label_list.append(label_list_temp)
|
||||
|
||||
if self.img_id == self.batch_size - 1:
|
||||
self.img_id = 0
|
||||
else:
|
||||
self.img_id += 1
|
||||
|
||||
im_padded[:, :, 2] -= 0.485 * 255
|
||||
im_padded[:, :, 1] -= 0.456 * 255
|
||||
im_padded[:, :, 0] -= 0.406 * 255
|
||||
im_padded[:, :, 2] /= (255.0 * 0.229)
|
||||
im_padded[:, :, 1] /= (255.0 * 0.224)
|
||||
im_padded[:, :, 0] /= (255.0 * 0.225)
|
||||
im_padded = im_padded.transpose((2, 0, 1))
|
||||
images = im_padded[::-1, :, :]
|
||||
tcl_maps = score_map[np.newaxis, :, :]
|
||||
tcl_label_maps = score_label_map[np.newaxis, :, :]
|
||||
border_maps = border_map.transpose((2, 0, 1))
|
||||
direction_maps = direction_map.transpose((2, 0, 1))
|
||||
training_masks = training_mask[np.newaxis, :, :]
|
||||
pos_list = np.array(pos_list)
|
||||
pos_mask = np.array(pos_mask)
|
||||
label_list = np.array(label_list)
|
||||
data['images'] = images
|
||||
data['tcl_maps'] = tcl_maps
|
||||
data['tcl_label_maps'] = tcl_label_maps
|
||||
data['border_maps'] = border_maps
|
||||
data['direction_maps'] = direction_maps
|
||||
data['training_masks'] = training_masks
|
||||
data['label_list'] = label_list
|
||||
data['pos_list'] = pos_list
|
||||
data['pos_mask'] = pos_mask
|
||||
return data
|
|
@ -0,0 +1,175 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import os
|
||||
from paddle.io import Dataset
|
||||
from .imaug import transform, create_operators
|
||||
import random
|
||||
|
||||
|
||||
class PGDataSet(Dataset):
|
||||
def __init__(self, config, mode, logger, seed=None):
|
||||
super(PGDataSet, self).__init__()
|
||||
|
||||
self.logger = logger
|
||||
self.seed = seed
|
||||
self.mode = mode
|
||||
global_config = config['Global']
|
||||
dataset_config = config[mode]['dataset']
|
||||
loader_config = config[mode]['loader']
|
||||
|
||||
label_file_list = dataset_config.pop('label_file_list')
|
||||
data_source_num = len(label_file_list)
|
||||
ratio_list = dataset_config.get("ratio_list", [1.0])
|
||||
if isinstance(ratio_list, (float, int)):
|
||||
ratio_list = [float(ratio_list)] * int(data_source_num)
|
||||
self.data_format = dataset_config.get('data_format', 'icdar')
|
||||
assert len(
|
||||
ratio_list
|
||||
) == data_source_num, "The length of ratio_list should be the same as the file_list."
|
||||
self.do_shuffle = loader_config['shuffle']
|
||||
|
||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
|
||||
self.data_format)
|
||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||
if mode.lower() == "train":
|
||||
self.shuffle_data_random()
|
||||
|
||||
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||
|
||||
def shuffle_data_random(self):
|
||||
if self.do_shuffle:
|
||||
random.seed(self.seed)
|
||||
random.shuffle(self.data_lines)
|
||||
return
|
||||
|
||||
def extract_polys(self, poly_txt_path):
|
||||
"""
|
||||
Read text_polys, txt_tags, txts from give txt file.
|
||||
"""
|
||||
text_polys, txt_tags, txts = [], [], []
|
||||
with open(poly_txt_path) as f:
|
||||
for line in f.readlines():
|
||||
poly_str, txt = line.strip().split('\t')
|
||||
poly = list(map(float, poly_str.split(',')))
|
||||
if self.mode.lower() == "eval":
|
||||
while len(poly) < 100:
|
||||
poly.append(-1)
|
||||
text_polys.append(
|
||||
np.array(
|
||||
poly, dtype=np.float32).reshape(-1, 2))
|
||||
txts.append(txt)
|
||||
txt_tags.append(txt == '###')
|
||||
|
||||
return np.array(list(map(np.array, text_polys))), \
|
||||
np.array(txt_tags, dtype=np.bool), txts
|
||||
|
||||
def extract_info_textnet(self, im_fn, img_dir=''):
|
||||
"""
|
||||
Extract information from line in textnet format.
|
||||
"""
|
||||
info_list = im_fn.split('\t')
|
||||
img_path = ''
|
||||
for ext in [
|
||||
'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'JPG'
|
||||
]:
|
||||
if os.path.exists(os.path.join(img_dir, info_list[0] + "." + ext)):
|
||||
img_path = os.path.join(img_dir, info_list[0] + "." + ext)
|
||||
break
|
||||
|
||||
if img_path == '':
|
||||
print('Image {0} NOT found in {1}, and it will be ignored.'.format(
|
||||
info_list[0], img_dir))
|
||||
|
||||
nBox = (len(info_list) - 1) // 9
|
||||
wordBBs, txts, txt_tags = [], [], []
|
||||
for n in range(0, nBox):
|
||||
wordBB = list(map(float, info_list[n * 9 + 1:(n + 1) * 9]))
|
||||
txt = info_list[(n + 1) * 9]
|
||||
wordBBs.append([[wordBB[0], wordBB[1]], [wordBB[2], wordBB[3]],
|
||||
[wordBB[4], wordBB[5]], [wordBB[6], wordBB[7]]])
|
||||
txts.append(txt)
|
||||
if txt == '###':
|
||||
txt_tags.append(True)
|
||||
else:
|
||||
txt_tags.append(False)
|
||||
return img_path, np.array(wordBBs, dtype=np.float32), txt_tags, txts
|
||||
|
||||
def get_image_info_list(self, file_list, ratio_list, data_format='textnet'):
|
||||
if isinstance(file_list, str):
|
||||
file_list = [file_list]
|
||||
data_lines = []
|
||||
for idx, data_source in enumerate(file_list):
|
||||
image_files = []
|
||||
if data_format == 'icdar':
|
||||
image_files = [(data_source, x) for x in
|
||||
os.listdir(os.path.join(data_source, 'rgb'))
|
||||
if x.split('.')[-1] in [
|
||||
'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif',
|
||||
'tiff', 'gif', 'JPG'
|
||||
]]
|
||||
elif data_format == 'textnet':
|
||||
with open(data_source) as f:
|
||||
image_files = [(data_source, x.strip())
|
||||
for x in f.readlines()]
|
||||
else:
|
||||
print("Unrecognized data format...")
|
||||
exit(-1)
|
||||
random.seed(self.seed)
|
||||
image_files = random.sample(
|
||||
image_files, round(len(image_files) * ratio_list[idx]))
|
||||
data_lines.extend(image_files)
|
||||
return data_lines
|
||||
|
||||
def __getitem__(self, idx):
|
||||
file_idx = self.data_idx_order_list[idx]
|
||||
data_path, data_line = self.data_lines[file_idx]
|
||||
try:
|
||||
if self.data_format == 'icdar':
|
||||
im_path = os.path.join(data_path, 'rgb', data_line)
|
||||
if self.mode.lower() == "eval":
|
||||
poly_path = os.path.join(data_path, 'poly_gt',
|
||||
data_line.split('.')[0] + '.txt')
|
||||
else:
|
||||
poly_path = os.path.join(data_path, 'poly',
|
||||
data_line.split('.')[0] + '.txt')
|
||||
text_polys, text_tags, text_strs = self.extract_polys(poly_path)
|
||||
else:
|
||||
image_dir = os.path.join(os.path.dirname(data_path), 'image')
|
||||
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
|
||||
data_line, image_dir)
|
||||
|
||||
data = {
|
||||
'img_path': im_path,
|
||||
'polys': text_polys,
|
||||
'tags': text_tags,
|
||||
'strs': text_strs
|
||||
}
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
outs = transform(data, self.ops)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"When parsing line {}, error happened with msg: {}".format(
|
||||
self.data_idx_order_list[idx], e))
|
||||
outs = None
|
||||
if outs is None:
|
||||
return self.__getitem__(np.random.randint(self.__len__()))
|
||||
return outs
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_idx_order_list)
|
|
@ -29,10 +29,11 @@ def build_loss(config):
|
|||
# cls loss
|
||||
from .cls_loss import ClsLoss
|
||||
|
||||
# e2e loss
|
||||
from .e2e_pg_loss import PGLoss
|
||||
support_dict = [
|
||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||
'SRNLoss'
|
||||
]
|
||||
'SRNLoss', 'PGLoss']
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle import nn
|
||||
import paddle
|
||||
|
||||
from .det_basic_loss import DiceLoss
|
||||
from ppocr.utils.e2e_utils.extract_batchsize import pre_process
|
||||
|
||||
|
||||
class PGLoss(nn.Layer):
|
||||
def __init__(self,
|
||||
tcl_bs,
|
||||
max_text_length,
|
||||
max_text_nums,
|
||||
pad_num,
|
||||
eps=1e-6,
|
||||
**kwargs):
|
||||
super(PGLoss, self).__init__()
|
||||
self.tcl_bs = tcl_bs
|
||||
self.max_text_nums = max_text_nums
|
||||
self.max_text_length = max_text_length
|
||||
self.pad_num = pad_num
|
||||
self.dice_loss = DiceLoss(eps=eps)
|
||||
|
||||
def border_loss(self, f_border, l_border, l_score, l_mask):
|
||||
l_border_split, l_border_norm = paddle.tensor.split(
|
||||
l_border, num_or_sections=[4, 1], axis=1)
|
||||
f_border_split = f_border
|
||||
b, c, h, w = l_border_norm.shape
|
||||
l_border_norm_split = paddle.expand(
|
||||
x=l_border_norm, shape=[b, 4 * c, h, w])
|
||||
b, c, h, w = l_score.shape
|
||||
l_border_score = paddle.expand(x=l_score, shape=[b, 4 * c, h, w])
|
||||
b, c, h, w = l_mask.shape
|
||||
l_border_mask = paddle.expand(x=l_mask, shape=[b, 4 * c, h, w])
|
||||
border_diff = l_border_split - f_border_split
|
||||
abs_border_diff = paddle.abs(border_diff)
|
||||
border_sign = abs_border_diff < 1.0
|
||||
border_sign = paddle.cast(border_sign, dtype='float32')
|
||||
border_sign.stop_gradient = True
|
||||
border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
|
||||
(abs_border_diff - 0.5) * (1.0 - border_sign)
|
||||
border_out_loss = l_border_norm_split * border_in_loss
|
||||
border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
|
||||
(paddle.sum(l_border_score * l_border_mask) + 1e-5)
|
||||
return border_loss
|
||||
|
||||
def direction_loss(self, f_direction, l_direction, l_score, l_mask):
|
||||
l_direction_split, l_direction_norm = paddle.tensor.split(
|
||||
l_direction, num_or_sections=[2, 1], axis=1)
|
||||
f_direction_split = f_direction
|
||||
b, c, h, w = l_direction_norm.shape
|
||||
l_direction_norm_split = paddle.expand(
|
||||
x=l_direction_norm, shape=[b, 2 * c, h, w])
|
||||
b, c, h, w = l_score.shape
|
||||
l_direction_score = paddle.expand(x=l_score, shape=[b, 2 * c, h, w])
|
||||
b, c, h, w = l_mask.shape
|
||||
l_direction_mask = paddle.expand(x=l_mask, shape=[b, 2 * c, h, w])
|
||||
direction_diff = l_direction_split - f_direction_split
|
||||
abs_direction_diff = paddle.abs(direction_diff)
|
||||
direction_sign = abs_direction_diff < 1.0
|
||||
direction_sign = paddle.cast(direction_sign, dtype='float32')
|
||||
direction_sign.stop_gradient = True
|
||||
direction_in_loss = 0.5 * abs_direction_diff * abs_direction_diff * direction_sign + \
|
||||
(abs_direction_diff - 0.5) * (1.0 - direction_sign)
|
||||
direction_out_loss = l_direction_norm_split * direction_in_loss
|
||||
direction_loss = paddle.sum(direction_out_loss * l_direction_score * l_direction_mask) / \
|
||||
(paddle.sum(l_direction_score * l_direction_mask) + 1e-5)
|
||||
return direction_loss
|
||||
|
||||
def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t):
|
||||
f_char = paddle.transpose(f_char, [0, 2, 3, 1])
|
||||
tcl_pos = paddle.reshape(tcl_pos, [-1, 3])
|
||||
tcl_pos = paddle.cast(tcl_pos, dtype=int)
|
||||
f_tcl_char = paddle.gather_nd(f_char, tcl_pos)
|
||||
f_tcl_char = paddle.reshape(f_tcl_char,
|
||||
[-1, 64, 37]) # len(Lexicon_Table)+1
|
||||
f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2)
|
||||
f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0
|
||||
b, c, l = tcl_mask.shape
|
||||
tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l])
|
||||
tcl_mask_fg.stop_gradient = True
|
||||
f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (
|
||||
-20.0)
|
||||
f_tcl_char_mask = paddle.concat([f_tcl_char_fg, f_tcl_char_bg], axis=2)
|
||||
f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2))
|
||||
N, B, _ = f_tcl_char_ld.shape
|
||||
input_lengths = paddle.to_tensor([N] * B, dtype='int64')
|
||||
cost = paddle.nn.functional.ctc_loss(
|
||||
log_probs=f_tcl_char_ld,
|
||||
labels=tcl_label,
|
||||
input_lengths=input_lengths,
|
||||
label_lengths=label_t,
|
||||
blank=self.pad_num,
|
||||
reduction='none')
|
||||
cost = cost.mean()
|
||||
return cost
|
||||
|
||||
def forward(self, predicts, labels):
|
||||
images, tcl_maps, tcl_label_maps, border_maps \
|
||||
, direction_maps, training_masks, label_list, pos_list, pos_mask = labels
|
||||
# for all the batch_size
|
||||
pos_list, pos_mask, label_list, label_t = pre_process(
|
||||
label_list, pos_list, pos_mask, self.max_text_length,
|
||||
self.max_text_nums, self.pad_num, self.tcl_bs)
|
||||
|
||||
f_score, f_border, f_direction, f_char = predicts['f_score'], predicts['f_border'], predicts['f_direction'], \
|
||||
predicts['f_char']
|
||||
score_loss = self.dice_loss(f_score, tcl_maps, training_masks)
|
||||
border_loss = self.border_loss(f_border, border_maps, tcl_maps,
|
||||
training_masks)
|
||||
direction_loss = self.direction_loss(f_direction, direction_maps,
|
||||
tcl_maps, training_masks)
|
||||
ctc_loss = self.ctcloss(f_char, pos_list, pos_mask, label_list, label_t)
|
||||
loss_all = score_loss + border_loss + direction_loss + 5 * ctc_loss
|
||||
|
||||
losses = {
|
||||
'loss': loss_all,
|
||||
"score_loss": score_loss,
|
||||
"border_loss": border_loss,
|
||||
"direction_loss": direction_loss,
|
||||
"ctc_loss": ctc_loss
|
||||
}
|
||||
return losses
|
|
@ -26,8 +26,9 @@ def build_metric(config):
|
|||
from .det_metric import DetMetric
|
||||
from .rec_metric import RecMetric
|
||||
from .cls_metric import ClsMetric
|
||||
from .e2e_metric import E2EMetric
|
||||
|
||||
support_dict = ['DetMetric', 'RecMetric', 'ClsMetric']
|
||||
support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric']
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
__all__ = ['E2EMetric']
|
||||
|
||||
from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results
|
||||
from ppocr.utils.e2e_utils.extract_textpoint import get_dict
|
||||
|
||||
|
||||
class E2EMetric(object):
|
||||
def __init__(self,
|
||||
character_dict_path,
|
||||
main_indicator='f_score_e2e',
|
||||
**kwargs):
|
||||
self.label_list = get_dict(character_dict_path)
|
||||
self.max_index = len(self.label_list)
|
||||
self.main_indicator = main_indicator
|
||||
self.reset()
|
||||
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
temp_gt_polyons_batch = batch[2]
|
||||
temp_gt_strs_batch = batch[3]
|
||||
ignore_tags_batch = batch[4]
|
||||
gt_polyons_batch = []
|
||||
gt_strs_batch = []
|
||||
|
||||
temp_gt_polyons_batch = temp_gt_polyons_batch[0].tolist()
|
||||
for temp_list in temp_gt_polyons_batch:
|
||||
t = []
|
||||
for index in temp_list:
|
||||
if index[0] != -1 and index[1] != -1:
|
||||
t.append(index)
|
||||
gt_polyons_batch.append(t)
|
||||
|
||||
temp_gt_strs_batch = temp_gt_strs_batch[0].tolist()
|
||||
for temp_list in temp_gt_strs_batch:
|
||||
t = ""
|
||||
for index in temp_list:
|
||||
if index < self.max_index:
|
||||
t += self.label_list[index]
|
||||
gt_strs_batch.append(t)
|
||||
|
||||
for pred, gt_polyons, gt_strs, ignore_tags in zip(
|
||||
[preds], [gt_polyons_batch], [gt_strs_batch], ignore_tags_batch):
|
||||
# prepare gt
|
||||
gt_info_list = [{
|
||||
'points': gt_polyon,
|
||||
'text': gt_str,
|
||||
'ignore': ignore_tag
|
||||
} for gt_polyon, gt_str, ignore_tag in
|
||||
zip(gt_polyons, gt_strs, ignore_tags)]
|
||||
# prepare det
|
||||
e2e_info_list = [{
|
||||
'points': det_polyon,
|
||||
'text': pred_str
|
||||
} for det_polyon, pred_str in zip(pred['points'], pred['strs'])]
|
||||
result = get_socre(gt_info_list, e2e_info_list)
|
||||
self.results.append(result)
|
||||
|
||||
def get_metric(self):
|
||||
metircs = combine_results(self.results)
|
||||
self.reset()
|
||||
return metircs
|
||||
|
||||
def reset(self):
|
||||
self.results = [] # clear results
|
|
@ -150,7 +150,7 @@ class DetectionIoUEvaluator(object):
|
|||
pairs.append({'gt': gtNum, 'det': detNum})
|
||||
detMatchedNums.append(detNum)
|
||||
evaluationLog += "Match GT #" + \
|
||||
str(gtNum) + " with Det #" + str(detNum) + "\n"
|
||||
str(gtNum) + " with Det #" + str(detNum) + "\n"
|
||||
|
||||
numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
|
||||
numDetCare = (len(detPols) - len(detDontCarePolsNum))
|
||||
|
@ -162,7 +162,7 @@ class DetectionIoUEvaluator(object):
|
|||
precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
|
||||
|
||||
hmean = 0 if (precision + recall) == 0 else 2.0 * \
|
||||
precision * recall / (precision + recall)
|
||||
precision * recall / (precision + recall)
|
||||
|
||||
matchedSum += detMatched
|
||||
numGlobalCareGt += numGtCare
|
||||
|
@ -200,7 +200,8 @@ class DetectionIoUEvaluator(object):
|
|||
methodPrecision = 0 if numGlobalCareDet == 0 else float(
|
||||
matchedSum) / numGlobalCareDet
|
||||
methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
|
||||
methodRecall * methodPrecision / (methodRecall + methodPrecision)
|
||||
methodRecall * methodPrecision / (
|
||||
methodRecall + methodPrecision)
|
||||
# print(methodRecall, methodPrecision, methodHmean)
|
||||
# sys.exit(-1)
|
||||
methodMetrics = {
|
||||
|
|
|
@ -26,6 +26,9 @@ def build_backbone(config, model_type):
|
|||
from .rec_resnet_vd import ResNet
|
||||
from .rec_resnet_fpn import ResNetFPN
|
||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
|
||||
elif model_type == 'e2e':
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
support_dict = ['ResNet']
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -0,0 +1,265 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
__all__ = ["ResNet"]
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
is_vd_mode=False,
|
||||
act=None,
|
||||
name=None, ):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
self.is_vd_mode = is_vd_mode
|
||||
self._pool2d_avg = nn.AvgPool2D(
|
||||
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||
self._conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=False)
|
||||
if name == "conv1":
|
||||
bn_name = "bn_" + name
|
||||
else:
|
||||
bn_name = "bn" + name[3:]
|
||||
self._batch_norm = nn.BatchNorm(
|
||||
out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=bn_name + '_scale'),
|
||||
bias_attr=ParamAttr(bn_name + '_offset'),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance')
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
return y
|
||||
|
||||
|
||||
class BottleneckBlock(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
name=None):
|
||||
super(BottleneckBlock, self).__init__()
|
||||
|
||||
self.conv0 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
act='relu',
|
||||
name=name + "_branch2a")
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
act='relu',
|
||||
name=name + "_branch2b")
|
||||
self.conv2 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels * 4,
|
||||
kernel_size=1,
|
||||
act=None,
|
||||
name=name + "_branch2c")
|
||||
|
||||
if not shortcut:
|
||||
self.short = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels * 4,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
is_vd_mode=False if if_first else True,
|
||||
name=name + "_branch1")
|
||||
|
||||
self.shortcut = shortcut
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv0(inputs)
|
||||
conv1 = self.conv1(y)
|
||||
conv2 = self.conv2(conv1)
|
||||
|
||||
if self.shortcut:
|
||||
short = inputs
|
||||
else:
|
||||
short = self.short(inputs)
|
||||
y = paddle.add(x=short, y=conv2)
|
||||
y = F.relu(y)
|
||||
return y
|
||||
|
||||
|
||||
class BasicBlock(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
name=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.stride = stride
|
||||
self.conv0 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
act='relu',
|
||||
name=name + "_branch2a")
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
act=None,
|
||||
name=name + "_branch2b")
|
||||
|
||||
if not shortcut:
|
||||
self.short = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
is_vd_mode=False if if_first else True,
|
||||
name=name + "_branch1")
|
||||
|
||||
self.shortcut = shortcut
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv0(inputs)
|
||||
conv1 = self.conv1(y)
|
||||
|
||||
if self.shortcut:
|
||||
short = inputs
|
||||
else:
|
||||
short = self.short(inputs)
|
||||
y = paddle.add(x=short, y=conv1)
|
||||
y = F.relu(y)
|
||||
return y
|
||||
|
||||
|
||||
class ResNet(nn.Layer):
|
||||
def __init__(self, in_channels=3, layers=50, **kwargs):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
self.layers = layers
|
||||
supported_layers = [18, 34, 50, 101, 152, 200]
|
||||
assert layers in supported_layers, \
|
||||
"supported layers are {} but input layer is {}".format(
|
||||
supported_layers, layers)
|
||||
|
||||
if layers == 18:
|
||||
depth = [2, 2, 2, 2]
|
||||
elif layers == 34 or layers == 50:
|
||||
# depth = [3, 4, 6, 3]
|
||||
depth = [3, 4, 6, 3, 3]
|
||||
elif layers == 101:
|
||||
depth = [3, 4, 23, 3]
|
||||
elif layers == 152:
|
||||
depth = [3, 8, 36, 3]
|
||||
elif layers == 200:
|
||||
depth = [3, 12, 48, 3]
|
||||
num_channels = [64, 256, 512, 1024,
|
||||
2048] if layers >= 50 else [64, 64, 128, 256]
|
||||
num_filters = [64, 128, 256, 512, 512]
|
||||
|
||||
self.conv1_1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=64,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
act='relu',
|
||||
name="conv1_1")
|
||||
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.stages = []
|
||||
self.out_channels = [3, 64]
|
||||
# num_filters = [64, 128, 256, 512, 512]
|
||||
if layers >= 50:
|
||||
for block in range(len(depth)):
|
||||
block_list = []
|
||||
shortcut = False
|
||||
for i in range(depth[block]):
|
||||
if layers in [101, 152] and block == 2:
|
||||
if i == 0:
|
||||
conv_name = "res" + str(block + 2) + "a"
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + "b" + str(i)
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
bottleneck_block = self.add_sublayer(
|
||||
'bb_%d_%d' % (block, i),
|
||||
BottleneckBlock(
|
||||
in_channels=num_channels[block]
|
||||
if i == 0 else num_filters[block] * 4,
|
||||
out_channels=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
if_first=block == i == 0,
|
||||
name=conv_name))
|
||||
shortcut = True
|
||||
block_list.append(bottleneck_block)
|
||||
self.out_channels.append(num_filters[block] * 4)
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
else:
|
||||
for block in range(len(depth)):
|
||||
block_list = []
|
||||
shortcut = False
|
||||
for i in range(depth[block]):
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
basic_block = self.add_sublayer(
|
||||
'bb_%d_%d' % (block, i),
|
||||
BasicBlock(
|
||||
in_channels=num_channels[block]
|
||||
if i == 0 else num_filters[block],
|
||||
out_channels=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
if_first=block == i == 0,
|
||||
name=conv_name))
|
||||
shortcut = True
|
||||
block_list.append(basic_block)
|
||||
self.out_channels.append(num_filters[block])
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
|
||||
def forward(self, inputs):
|
||||
out = [inputs]
|
||||
y = self.conv1_1(inputs)
|
||||
out.append(y)
|
||||
y = self.pool2d_max(y)
|
||||
for block in self.stages:
|
||||
y = block(y)
|
||||
out.append(y)
|
||||
return out
|
|
@ -20,6 +20,7 @@ def build_head(config):
|
|||
from .det_db_head import DBHead
|
||||
from .det_east_head import EASTHead
|
||||
from .det_sast_head import SASTHead
|
||||
from .e2e_pg_head import PGHead
|
||||
|
||||
# rec head
|
||||
from .rec_ctc_head import CTCHead
|
||||
|
@ -30,8 +31,8 @@ def build_head(config):
|
|||
from .cls_head import ClsHead
|
||||
support_dict = [
|
||||
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
||||
'SRNHead'
|
||||
]
|
||||
'SRNHead', 'PGHead']
|
||||
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('head only support {}'.format(
|
||||
|
|
|
@ -0,0 +1,253 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self.if_act = if_act
|
||||
self.act = act
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + '_weights'),
|
||||
bias_attr=False)
|
||||
|
||||
self.bn = nn.BatchNorm(
|
||||
num_channels=out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name="bn_" + name + "_scale"),
|
||||
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
|
||||
moving_mean_name="bn_" + name + "_mean",
|
||||
moving_variance_name="bn_" + name + "_variance",
|
||||
use_global_stats=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
class PGHead(nn.Layer):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super(PGHead, self).__init__()
|
||||
self.conv_f_score1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=64,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_score{}".format(1))
|
||||
self.conv_f_score2 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act='relu',
|
||||
name="conv_f_score{}".format(2))
|
||||
self.conv_f_score3 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=128,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_score{}".format(3))
|
||||
|
||||
self.conv1 = nn.Conv2D(
|
||||
in_channels=128,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=1,
|
||||
weight_attr=ParamAttr(name="conv_f_score{}".format(4)),
|
||||
bias_attr=False)
|
||||
|
||||
self.conv_f_boder1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=64,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_boder{}".format(1))
|
||||
self.conv_f_boder2 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act='relu',
|
||||
name="conv_f_boder{}".format(2))
|
||||
self.conv_f_boder3 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=128,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_boder{}".format(3))
|
||||
self.conv2 = nn.Conv2D(
|
||||
in_channels=128,
|
||||
out_channels=4,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=1,
|
||||
weight_attr=ParamAttr(name="conv_f_boder{}".format(4)),
|
||||
bias_attr=False)
|
||||
self.conv_f_char1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=128,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_char{}".format(1))
|
||||
self.conv_f_char2 = ConvBNLayer(
|
||||
in_channels=128,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act='relu',
|
||||
name="conv_f_char{}".format(2))
|
||||
self.conv_f_char3 = ConvBNLayer(
|
||||
in_channels=128,
|
||||
out_channels=256,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_char{}".format(3))
|
||||
self.conv_f_char4 = ConvBNLayer(
|
||||
in_channels=256,
|
||||
out_channels=256,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act='relu',
|
||||
name="conv_f_char{}".format(4))
|
||||
self.conv_f_char5 = ConvBNLayer(
|
||||
in_channels=256,
|
||||
out_channels=256,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_char{}".format(5))
|
||||
self.conv3 = nn.Conv2D(
|
||||
in_channels=256,
|
||||
out_channels=37,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=1,
|
||||
weight_attr=ParamAttr(name="conv_f_char{}".format(6)),
|
||||
bias_attr=False)
|
||||
|
||||
self.conv_f_direc1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=64,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_direc{}".format(1))
|
||||
self.conv_f_direc2 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act='relu',
|
||||
name="conv_f_direc{}".format(2))
|
||||
self.conv_f_direc3 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=128,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_direc{}".format(3))
|
||||
self.conv4 = nn.Conv2D(
|
||||
in_channels=128,
|
||||
out_channels=2,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=1,
|
||||
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
|
||||
bias_attr=False)
|
||||
|
||||
def forward(self, x):
|
||||
f_score = self.conv_f_score1(x)
|
||||
f_score = self.conv_f_score2(f_score)
|
||||
f_score = self.conv_f_score3(f_score)
|
||||
f_score = self.conv1(f_score)
|
||||
f_score = F.sigmoid(f_score)
|
||||
|
||||
# f_border
|
||||
f_border = self.conv_f_boder1(x)
|
||||
f_border = self.conv_f_boder2(f_border)
|
||||
f_border = self.conv_f_boder3(f_border)
|
||||
f_border = self.conv2(f_border)
|
||||
|
||||
f_char = self.conv_f_char1(x)
|
||||
f_char = self.conv_f_char2(f_char)
|
||||
f_char = self.conv_f_char3(f_char)
|
||||
f_char = self.conv_f_char4(f_char)
|
||||
f_char = self.conv_f_char5(f_char)
|
||||
f_char = self.conv3(f_char)
|
||||
|
||||
f_direction = self.conv_f_direc1(x)
|
||||
f_direction = self.conv_f_direc2(f_direction)
|
||||
f_direction = self.conv_f_direc3(f_direction)
|
||||
f_direction = self.conv4(f_direction)
|
||||
|
||||
predicts = {}
|
||||
predicts['f_score'] = f_score
|
||||
predicts['f_border'] = f_border
|
||||
predicts['f_char'] = f_char
|
||||
predicts['f_direction'] = f_direction
|
||||
return predicts
|
|
@ -14,12 +14,14 @@
|
|||
|
||||
__all__ = ['build_neck']
|
||||
|
||||
|
||||
def build_neck(config):
|
||||
from .db_fpn import DBFPN
|
||||
from .east_fpn import EASTFPN
|
||||
from .sast_fpn import SASTFPN
|
||||
from .rnn import SequenceEncoder
|
||||
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder']
|
||||
from .pg_fpn import PGFPN
|
||||
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('neck only support {}'.format(
|
||||
|
|
|
@ -0,0 +1,314 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
is_vd_mode=False,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
self.is_vd_mode = is_vd_mode
|
||||
self._pool2d_avg = nn.AvgPool2D(
|
||||
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||
self._conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=False)
|
||||
if name == "conv1":
|
||||
bn_name = "bn_" + name
|
||||
else:
|
||||
bn_name = "bn" + name[3:]
|
||||
self._batch_norm = nn.BatchNorm(
|
||||
out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=bn_name + '_scale'),
|
||||
bias_attr=ParamAttr(bn_name + '_offset'),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance',
|
||||
use_global_stats=False)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
return y
|
||||
|
||||
|
||||
class DeConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act=None,
|
||||
name=None):
|
||||
super(DeConvBNLayer, self).__init__()
|
||||
|
||||
self.if_act = if_act
|
||||
self.act = act
|
||||
self.deconv = nn.Conv2DTranspose(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + '_weights'),
|
||||
bias_attr=False)
|
||||
self.bn = nn.BatchNorm(
|
||||
num_channels=out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name="bn_" + name + "_scale"),
|
||||
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
|
||||
moving_mean_name="bn_" + name + "_mean",
|
||||
moving_variance_name="bn_" + name + "_variance",
|
||||
use_global_stats=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.deconv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
class PGFPN(nn.Layer):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super(PGFPN, self).__init__()
|
||||
num_inputs = [2048, 2048, 1024, 512, 256]
|
||||
num_outputs = [256, 256, 192, 192, 128]
|
||||
self.out_channels = 128
|
||||
self.conv_bn_layer_1 = ConvBNLayer(
|
||||
in_channels=3,
|
||||
out_channels=32,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act=None,
|
||||
name='FPN_d1')
|
||||
self.conv_bn_layer_2 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act=None,
|
||||
name='FPN_d2')
|
||||
self.conv_bn_layer_3 = ConvBNLayer(
|
||||
in_channels=256,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act=None,
|
||||
name='FPN_d3')
|
||||
self.conv_bn_layer_4 = ConvBNLayer(
|
||||
in_channels=32,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
act=None,
|
||||
name='FPN_d4')
|
||||
self.conv_bn_layer_5 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name='FPN_d5')
|
||||
self.conv_bn_layer_6 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
act=None,
|
||||
name='FPN_d6')
|
||||
self.conv_bn_layer_7 = ConvBNLayer(
|
||||
in_channels=128,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name='FPN_d7')
|
||||
self.conv_bn_layer_8 = ConvBNLayer(
|
||||
in_channels=128,
|
||||
out_channels=128,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name='FPN_d8')
|
||||
|
||||
self.conv_h0 = ConvBNLayer(
|
||||
in_channels=num_inputs[0],
|
||||
out_channels=num_outputs[0],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="conv_h{}".format(0))
|
||||
self.conv_h1 = ConvBNLayer(
|
||||
in_channels=num_inputs[1],
|
||||
out_channels=num_outputs[1],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="conv_h{}".format(1))
|
||||
self.conv_h2 = ConvBNLayer(
|
||||
in_channels=num_inputs[2],
|
||||
out_channels=num_outputs[2],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="conv_h{}".format(2))
|
||||
self.conv_h3 = ConvBNLayer(
|
||||
in_channels=num_inputs[3],
|
||||
out_channels=num_outputs[3],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="conv_h{}".format(3))
|
||||
self.conv_h4 = ConvBNLayer(
|
||||
in_channels=num_inputs[4],
|
||||
out_channels=num_outputs[4],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="conv_h{}".format(4))
|
||||
|
||||
self.dconv0 = DeConvBNLayer(
|
||||
in_channels=num_outputs[0],
|
||||
out_channels=num_outputs[0 + 1],
|
||||
name="dconv_{}".format(0))
|
||||
self.dconv1 = DeConvBNLayer(
|
||||
in_channels=num_outputs[1],
|
||||
out_channels=num_outputs[1 + 1],
|
||||
act=None,
|
||||
name="dconv_{}".format(1))
|
||||
self.dconv2 = DeConvBNLayer(
|
||||
in_channels=num_outputs[2],
|
||||
out_channels=num_outputs[2 + 1],
|
||||
act=None,
|
||||
name="dconv_{}".format(2))
|
||||
self.dconv3 = DeConvBNLayer(
|
||||
in_channels=num_outputs[3],
|
||||
out_channels=num_outputs[3 + 1],
|
||||
act=None,
|
||||
name="dconv_{}".format(3))
|
||||
self.conv_g1 = ConvBNLayer(
|
||||
in_channels=num_outputs[1],
|
||||
out_channels=num_outputs[1],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name="conv_g{}".format(1))
|
||||
self.conv_g2 = ConvBNLayer(
|
||||
in_channels=num_outputs[2],
|
||||
out_channels=num_outputs[2],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name="conv_g{}".format(2))
|
||||
self.conv_g3 = ConvBNLayer(
|
||||
in_channels=num_outputs[3],
|
||||
out_channels=num_outputs[3],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name="conv_g{}".format(3))
|
||||
self.conv_g4 = ConvBNLayer(
|
||||
in_channels=num_outputs[4],
|
||||
out_channels=num_outputs[4],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name="conv_g{}".format(4))
|
||||
self.convf = ConvBNLayer(
|
||||
in_channels=num_outputs[4],
|
||||
out_channels=num_outputs[4],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="conv_f{}".format(4))
|
||||
|
||||
def forward(self, x):
|
||||
c0, c1, c2, c3, c4, c5, c6 = x
|
||||
# FPN_Down_Fusion
|
||||
f = [c0, c1, c2]
|
||||
g = [None, None, None]
|
||||
h = [None, None, None]
|
||||
h[0] = self.conv_bn_layer_1(f[0])
|
||||
h[1] = self.conv_bn_layer_2(f[1])
|
||||
h[2] = self.conv_bn_layer_3(f[2])
|
||||
|
||||
g[0] = self.conv_bn_layer_4(h[0])
|
||||
g[1] = paddle.add(g[0], h[1])
|
||||
g[1] = F.relu(g[1])
|
||||
g[1] = self.conv_bn_layer_5(g[1])
|
||||
g[1] = self.conv_bn_layer_6(g[1])
|
||||
|
||||
g[2] = paddle.add(g[1], h[2])
|
||||
g[2] = F.relu(g[2])
|
||||
g[2] = self.conv_bn_layer_7(g[2])
|
||||
f_down = self.conv_bn_layer_8(g[2])
|
||||
|
||||
# FPN UP Fusion
|
||||
f1 = [c6, c5, c4, c3, c2]
|
||||
g = [None, None, None, None, None]
|
||||
h = [None, None, None, None, None]
|
||||
h[0] = self.conv_h0(f1[0])
|
||||
h[1] = self.conv_h1(f1[1])
|
||||
h[2] = self.conv_h2(f1[2])
|
||||
h[3] = self.conv_h3(f1[3])
|
||||
h[4] = self.conv_h4(f1[4])
|
||||
|
||||
g[0] = self.dconv0(h[0])
|
||||
g[1] = paddle.add(g[0], h[1])
|
||||
g[1] = F.relu(g[1])
|
||||
g[1] = self.conv_g1(g[1])
|
||||
g[1] = self.dconv1(g[1])
|
||||
|
||||
g[2] = paddle.add(g[1], h[2])
|
||||
g[2] = F.relu(g[2])
|
||||
g[2] = self.conv_g2(g[2])
|
||||
g[2] = self.dconv2(g[2])
|
||||
|
||||
g[3] = paddle.add(g[2], h[3])
|
||||
g[3] = F.relu(g[3])
|
||||
g[3] = self.conv_g3(g[3])
|
||||
g[3] = self.dconv3(g[3])
|
||||
|
||||
g[4] = paddle.add(x=g[3], y=h[4])
|
||||
g[4] = F.relu(g[4])
|
||||
g[4] = self.conv_g4(g[4])
|
||||
f_up = self.convf(g[4])
|
||||
f_common = paddle.add(f_down, f_up)
|
||||
f_common = F.relu(f_common)
|
||||
return f_common
|
|
@ -28,10 +28,11 @@ def build_post_process(config, global_config=None):
|
|||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
|
||||
support_dict = [
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode'
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -0,0 +1,155 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
|
||||
from ppocr.utils.e2e_utils.extract_textpoint import *
|
||||
from ppocr.utils.e2e_utils.visual import *
|
||||
import paddle
|
||||
|
||||
|
||||
class PGPostProcess(object):
|
||||
"""
|
||||
The post process for PGNet.
|
||||
"""
|
||||
|
||||
def __init__(self, character_dict_path, valid_set, score_thresh, **kwargs):
|
||||
self.Lexicon_Table = get_dict(character_dict_path)
|
||||
self.valid_set = valid_set
|
||||
self.score_thresh = score_thresh
|
||||
|
||||
# c++ la-nms is faster, but only support python 3.5
|
||||
self.is_python35 = False
|
||||
if sys.version_info.major == 3 and sys.version_info.minor == 5:
|
||||
self.is_python35 = True
|
||||
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
p_score = outs_dict['f_score']
|
||||
p_border = outs_dict['f_border']
|
||||
p_char = outs_dict['f_char']
|
||||
p_direction = outs_dict['f_direction']
|
||||
if isinstance(p_score, paddle.Tensor):
|
||||
p_score = p_score[0].numpy()
|
||||
p_border = p_border[0].numpy()
|
||||
p_direction = p_direction[0].numpy()
|
||||
p_char = p_char[0].numpy()
|
||||
else:
|
||||
p_score = p_score[0]
|
||||
p_border = p_border[0]
|
||||
p_direction = p_direction[0]
|
||||
p_char = p_char[0]
|
||||
src_h, src_w, ratio_h, ratio_w = shape_list[0]
|
||||
is_curved = self.valid_set == "totaltext"
|
||||
instance_yxs_list = generate_pivot_list(
|
||||
p_score,
|
||||
p_char,
|
||||
p_direction,
|
||||
score_thresh=self.score_thresh,
|
||||
is_backbone=True,
|
||||
is_curved=is_curved)
|
||||
p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
|
||||
char_seq_idx_set = []
|
||||
for i in range(len(instance_yxs_list)):
|
||||
gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
|
||||
f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
|
||||
feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
|
||||
feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
|
||||
feature_len = [len(feature_seq[0])]
|
||||
featyre_seq = paddle.to_tensor(feature_seq)
|
||||
feature_len = np.array([feature_len]).astype(np.int64)
|
||||
length = paddle.to_tensor(feature_len)
|
||||
seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
|
||||
input=featyre_seq, blank=36, input_length=length)
|
||||
seq_pred_str = seq_pred[0].numpy().tolist()[0]
|
||||
seq_len = seq_pred[1].numpy()[0][0]
|
||||
temp_t = []
|
||||
for c in seq_pred_str[:seq_len]:
|
||||
temp_t.append(c)
|
||||
char_seq_idx_set.append(temp_t)
|
||||
seq_strs = []
|
||||
for char_idx_set in char_seq_idx_set:
|
||||
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
|
||||
seq_strs.append(pr_str)
|
||||
poly_list = []
|
||||
keep_str_list = []
|
||||
all_point_list = []
|
||||
all_point_pair_list = []
|
||||
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
|
||||
if len(yx_center_line) == 1:
|
||||
yx_center_line.append(yx_center_line[-1])
|
||||
|
||||
offset_expand = 1.0
|
||||
if self.valid_set == 'totaltext':
|
||||
offset_expand = 1.2
|
||||
|
||||
point_pair_list = []
|
||||
for batch_id, y, x in yx_center_line:
|
||||
offset = p_border[:, y, x].reshape(2, 2)
|
||||
if offset_expand != 1.0:
|
||||
offset_length = np.linalg.norm(
|
||||
offset, axis=1, keepdims=True)
|
||||
expand_length = np.clip(
|
||||
offset_length * (offset_expand - 1),
|
||||
a_min=0.5,
|
||||
a_max=3.0)
|
||||
offset_detal = offset / offset_length * expand_length
|
||||
offset = offset + offset_detal
|
||||
ori_yx = np.array([y, x], dtype=np.float32)
|
||||
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
|
||||
[ratio_w, ratio_h]).reshape(-1, 2)
|
||||
point_pair_list.append(point_pair)
|
||||
|
||||
all_point_list.append([
|
||||
int(round(x * 4.0 / ratio_w)),
|
||||
int(round(y * 4.0 / ratio_h))
|
||||
])
|
||||
all_point_pair_list.append(point_pair.round().astype(np.int32)
|
||||
.tolist())
|
||||
|
||||
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
|
||||
detected_poly = expand_poly_along_width(
|
||||
detected_poly, shrink_ratio_of_width=0.2)
|
||||
detected_poly[:, 0] = np.clip(
|
||||
detected_poly[:, 0], a_min=0, a_max=src_w)
|
||||
detected_poly[:, 1] = np.clip(
|
||||
detected_poly[:, 1], a_min=0, a_max=src_h)
|
||||
|
||||
if len(keep_str) < 2:
|
||||
continue
|
||||
|
||||
keep_str_list.append(keep_str)
|
||||
if self.valid_set == 'partvgg':
|
||||
middle_point = len(detected_poly) // 2
|
||||
detected_poly = detected_poly[
|
||||
[0, middle_point - 1, middle_point, -1], :]
|
||||
poly_list.append(detected_poly)
|
||||
elif self.valid_set == 'totaltext':
|
||||
poly_list.append(detected_poly)
|
||||
else:
|
||||
print('--> Not supported format.')
|
||||
exit(-1)
|
||||
data = {
|
||||
'points': poly_list,
|
||||
'strs': keep_str_list,
|
||||
}
|
||||
return data
|
|
@ -18,6 +18,7 @@ from __future__ import print_function
|
|||
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
|
@ -49,12 +50,12 @@ class SASTPostProcess(object):
|
|||
self.shrink_ratio_of_width = shrink_ratio_of_width
|
||||
self.expand_scale = expand_scale
|
||||
self.tcl_map_thresh = tcl_map_thresh
|
||||
|
||||
|
||||
# c++ la-nms is faster, but only support python 3.5
|
||||
self.is_python35 = False
|
||||
if sys.version_info.major == 3 and sys.version_info.minor == 5:
|
||||
self.is_python35 = True
|
||||
|
||||
|
||||
def point_pair2poly(self, point_pair_list):
|
||||
"""
|
||||
Transfer vertical point_pairs into poly point in clockwise.
|
||||
|
@ -66,31 +67,42 @@ class SASTPostProcess(object):
|
|||
point_list[idx] = point_pair[0]
|
||||
point_list[point_num - 1 - idx] = point_pair[1]
|
||||
return np.array(point_list).reshape(-1, 2)
|
||||
|
||||
def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
|
||||
|
||||
def shrink_quad_along_width(self,
|
||||
quad,
|
||||
begin_width_ratio=0.,
|
||||
end_width_ratio=1.):
|
||||
"""
|
||||
Generate shrink_quad_along_width.
|
||||
"""
|
||||
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||
ratio_pair = np.array(
|
||||
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
||||
|
||||
|
||||
def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
|
||||
"""
|
||||
expand poly along width.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
|
||||
left_quad = np.array(
|
||||
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
|
||||
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
|
||||
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
|
||||
left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
|
||||
right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1],
|
||||
poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32)
|
||||
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
|
||||
left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio,
|
||||
1.0)
|
||||
right_quad = np.array(
|
||||
[
|
||||
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
|
||||
poly[point_num // 2], poly[point_num // 2 + 1]
|
||||
],
|
||||
dtype=np.float32)
|
||||
right_ratio = 1.0 + \
|
||||
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
|
||||
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
|
||||
right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
|
||||
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
|
||||
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
|
||||
right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0,
|
||||
right_ratio)
|
||||
poly[0] = left_quad_expand[0]
|
||||
poly[-1] = left_quad_expand[-1]
|
||||
poly[point_num // 2 - 1] = right_quad_expand[1]
|
||||
|
@ -100,7 +112,7 @@ class SASTPostProcess(object):
|
|||
def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
|
||||
"""Restore quad."""
|
||||
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
|
||||
xy_text = xy_text[:, ::-1] # (n, 2)
|
||||
xy_text = xy_text[:, ::-1] # (n, 2)
|
||||
|
||||
# Sort the text boxes via the y axis
|
||||
xy_text = xy_text[np.argsort(xy_text[:, 1])]
|
||||
|
@ -112,7 +124,7 @@ class SASTPostProcess(object):
|
|||
point_num = int(tvo_map.shape[-1] / 2)
|
||||
assert point_num == 4
|
||||
tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
|
||||
xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
|
||||
xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
|
||||
quads = xy_text_tile - tvo_map
|
||||
|
||||
return scores, quads, xy_text
|
||||
|
@ -121,14 +133,12 @@ class SASTPostProcess(object):
|
|||
"""
|
||||
compute area of a quad.
|
||||
"""
|
||||
edge = [
|
||||
(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
|
||||
(quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
|
||||
(quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
|
||||
(quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])
|
||||
]
|
||||
edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
|
||||
(quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
|
||||
(quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
|
||||
(quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
|
||||
return np.sum(edge) / 2.
|
||||
|
||||
|
||||
def nms(self, dets):
|
||||
if self.is_python35:
|
||||
import lanms
|
||||
|
@ -141,7 +151,7 @@ class SASTPostProcess(object):
|
|||
"""
|
||||
Cluster pixels in tcl_map based on quads.
|
||||
"""
|
||||
instance_count = quads.shape[0] + 1 # contain background
|
||||
instance_count = quads.shape[0] + 1 # contain background
|
||||
instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
|
||||
if instance_count == 1:
|
||||
return instance_count, instance_label_map
|
||||
|
@ -149,18 +159,19 @@ class SASTPostProcess(object):
|
|||
# predict text center
|
||||
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
|
||||
n = xy_text.shape[0]
|
||||
xy_text = xy_text[:, ::-1] # (n, 2)
|
||||
tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
|
||||
xy_text = xy_text[:, ::-1] # (n, 2)
|
||||
tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
|
||||
pred_tc = xy_text - tco
|
||||
|
||||
|
||||
# get gt text center
|
||||
m = quads.shape[0]
|
||||
gt_tc = np.mean(quads, axis=1) # (m, 2)
|
||||
gt_tc = np.mean(quads, axis=1) # (m, 2)
|
||||
|
||||
pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2)
|
||||
gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
|
||||
dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
|
||||
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
|
||||
pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :],
|
||||
(1, m, 1)) # (n, m, 2)
|
||||
gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
|
||||
dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
|
||||
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
|
||||
|
||||
instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
|
||||
return instance_count, instance_label_map
|
||||
|
@ -169,26 +180,47 @@ class SASTPostProcess(object):
|
|||
"""
|
||||
Estimate sample points number.
|
||||
"""
|
||||
eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0
|
||||
ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
|
||||
eh = (np.linalg.norm(quad[0] - quad[3]) +
|
||||
np.linalg.norm(quad[1] - quad[2])) / 2.0
|
||||
ew = (np.linalg.norm(quad[0] - quad[1]) +
|
||||
np.linalg.norm(quad[2] - quad[3])) / 2.0
|
||||
|
||||
dense_sample_pts_num = max(2, int(ew))
|
||||
dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num,
|
||||
endpoint=True, dtype=np.float32).astype(np.int32)]
|
||||
dense_xy_center_line = xy_text[np.linspace(
|
||||
0,
|
||||
xy_text.shape[0] - 1,
|
||||
dense_sample_pts_num,
|
||||
endpoint=True,
|
||||
dtype=np.float32).astype(np.int32)]
|
||||
|
||||
dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
|
||||
estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))
|
||||
dense_xy_center_line_diff = dense_xy_center_line[
|
||||
1:] - dense_xy_center_line[:-1]
|
||||
estimate_arc_len = np.sum(
|
||||
np.linalg.norm(
|
||||
dense_xy_center_line_diff, axis=1))
|
||||
|
||||
sample_pts_num = max(2, int(estimate_arc_len / eh))
|
||||
return sample_pts_num
|
||||
|
||||
def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h,
|
||||
shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0):
|
||||
def detect_sast(self,
|
||||
tcl_map,
|
||||
tvo_map,
|
||||
tbo_map,
|
||||
tco_map,
|
||||
ratio_w,
|
||||
ratio_h,
|
||||
src_w,
|
||||
src_h,
|
||||
shrink_ratio_of_width=0.3,
|
||||
tcl_map_thresh=0.5,
|
||||
offset_expand=1.0,
|
||||
out_strid=4.0):
|
||||
"""
|
||||
first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
|
||||
"""
|
||||
# restore quad
|
||||
scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
|
||||
scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh,
|
||||
tvo_map)
|
||||
dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
|
||||
dets = self.nms(dets)
|
||||
if dets.shape[0] == 0:
|
||||
|
@ -202,7 +234,8 @@ class SASTPostProcess(object):
|
|||
|
||||
# instance segmentation
|
||||
# instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
|
||||
instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map)
|
||||
instance_count, instance_label_map = self.cluster_by_quads_tco(
|
||||
tcl_map, tcl_map_thresh, quads, tco_map)
|
||||
|
||||
# restore single poly with tcl instance.
|
||||
poly_list = []
|
||||
|
@ -212,10 +245,10 @@ class SASTPostProcess(object):
|
|||
q_area = quad_areas[instance_idx - 1]
|
||||
if q_area < 5:
|
||||
continue
|
||||
|
||||
|
||||
#
|
||||
len1 = float(np.linalg.norm(quad[0] -quad[1]))
|
||||
len2 = float(np.linalg.norm(quad[1] -quad[2]))
|
||||
len1 = float(np.linalg.norm(quad[0] - quad[1]))
|
||||
len2 = float(np.linalg.norm(quad[1] - quad[2]))
|
||||
min_len = min(len1, len2)
|
||||
if min_len < 3:
|
||||
continue
|
||||
|
@ -225,16 +258,18 @@ class SASTPostProcess(object):
|
|||
continue
|
||||
|
||||
# filter low confidence instance
|
||||
xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
|
||||
xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
|
||||
if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
|
||||
# if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
|
||||
# if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
|
||||
continue
|
||||
|
||||
# sort xy_text
|
||||
left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0,
|
||||
(quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
|
||||
right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0,
|
||||
(quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
|
||||
left_center_pt = np.array(
|
||||
[[(quad[0, 0] + quad[-1, 0]) / 2.0,
|
||||
(quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
|
||||
right_center_pt = np.array(
|
||||
[[(quad[1, 0] + quad[2, 0]) / 2.0,
|
||||
(quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
|
||||
proj_unit_vec = (right_center_pt - left_center_pt) / \
|
||||
(np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
|
||||
proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
|
||||
|
@ -245,33 +280,45 @@ class SASTPostProcess(object):
|
|||
sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
|
||||
else:
|
||||
sample_pts_num = self.sample_pts_num
|
||||
xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num,
|
||||
endpoint=True, dtype=np.float32).astype(np.int32)]
|
||||
xy_center_line = xy_text[np.linspace(
|
||||
0,
|
||||
xy_text.shape[0] - 1,
|
||||
sample_pts_num,
|
||||
endpoint=True,
|
||||
dtype=np.float32).astype(np.int32)]
|
||||
|
||||
point_pair_list = []
|
||||
for x, y in xy_center_line:
|
||||
# get corresponding offset
|
||||
offset = tbo_map[y, x, :].reshape(2, 2)
|
||||
if offset_expand != 1.0:
|
||||
offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
|
||||
expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0)
|
||||
offset_length = np.linalg.norm(
|
||||
offset, axis=1, keepdims=True)
|
||||
expand_length = np.clip(
|
||||
offset_length * (offset_expand - 1),
|
||||
a_min=0.5,
|
||||
a_max=3.0)
|
||||
offset_detal = offset / offset_length * expand_length
|
||||
offset = offset + offset_detal
|
||||
# original point
|
||||
offset = offset + offset_detal
|
||||
# original point
|
||||
ori_yx = np.array([y, x], dtype=np.float32)
|
||||
point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2)
|
||||
point_pair = (ori_yx + offset)[:, ::-1] * out_strid / np.array(
|
||||
[ratio_w, ratio_h]).reshape(-1, 2)
|
||||
point_pair_list.append(point_pair)
|
||||
|
||||
# ndarry: (x, 2), expand poly along width
|
||||
detected_poly = self.point_pair2poly(point_pair_list)
|
||||
detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width)
|
||||
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
|
||||
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
|
||||
detected_poly = self.expand_poly_along_width(detected_poly,
|
||||
shrink_ratio_of_width)
|
||||
detected_poly[:, 0] = np.clip(
|
||||
detected_poly[:, 0], a_min=0, a_max=src_w)
|
||||
detected_poly[:, 1] = np.clip(
|
||||
detected_poly[:, 1], a_min=0, a_max=src_h)
|
||||
poly_list.append(detected_poly)
|
||||
|
||||
return poly_list
|
||||
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
score_list = outs_dict['f_score']
|
||||
border_list = outs_dict['f_border']
|
||||
tvo_list = outs_dict['f_tvo']
|
||||
|
@ -281,20 +328,28 @@ class SASTPostProcess(object):
|
|||
border_list = border_list.numpy()
|
||||
tvo_list = tvo_list.numpy()
|
||||
tco_list = tco_list.numpy()
|
||||
|
||||
|
||||
img_num = len(shape_list)
|
||||
poly_lists = []
|
||||
for ino in range(img_num):
|
||||
p_score = score_list[ino].transpose((1,2,0))
|
||||
p_border = border_list[ino].transpose((1,2,0))
|
||||
p_tvo = tvo_list[ino].transpose((1,2,0))
|
||||
p_tco = tco_list[ino].transpose((1,2,0))
|
||||
p_score = score_list[ino].transpose((1, 2, 0))
|
||||
p_border = border_list[ino].transpose((1, 2, 0))
|
||||
p_tvo = tvo_list[ino].transpose((1, 2, 0))
|
||||
p_tco = tco_list[ino].transpose((1, 2, 0))
|
||||
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
|
||||
|
||||
poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
|
||||
shrink_ratio_of_width=self.shrink_ratio_of_width,
|
||||
tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale)
|
||||
poly_list = self.detect_sast(
|
||||
p_score,
|
||||
p_tvo,
|
||||
p_border,
|
||||
p_tco,
|
||||
ratio_w,
|
||||
ratio_h,
|
||||
src_w,
|
||||
src_h,
|
||||
shrink_ratio_of_width=self.shrink_ratio_of_width,
|
||||
tcl_map_thresh=self.tcl_map_thresh,
|
||||
offset_expand=self.expand_scale)
|
||||
poly_lists.append({'points': np.array(poly_list)})
|
||||
|
||||
return poly_lists
|
||||
|
||||
|
|
|
@ -0,0 +1,458 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
|
||||
|
||||
|
||||
def get_socre(gt_dict, pred_dict):
|
||||
allInputs = 1
|
||||
|
||||
def input_reading_mod(pred_dict):
|
||||
"""This helper reads input from txt files"""
|
||||
det = []
|
||||
n = len(pred_dict)
|
||||
for i in range(n):
|
||||
points = pred_dict[i]['points']
|
||||
text = pred_dict[i]['text']
|
||||
point = ",".join(map(str, points.reshape(-1, )))
|
||||
det.append([point, text])
|
||||
return det
|
||||
|
||||
def gt_reading_mod(gt_dict):
|
||||
"""This helper reads groundtruths from mat files"""
|
||||
gt = []
|
||||
n = len(gt_dict)
|
||||
for i in range(n):
|
||||
points = gt_dict[i]['points']
|
||||
h = len(points)
|
||||
text = gt_dict[i]['text']
|
||||
xx = [
|
||||
np.array(
|
||||
['x:'], dtype='<U2'), 0, np.array(
|
||||
['y:'], dtype='<U2'), 0, np.array(
|
||||
['#'], dtype='<U1'), np.array(
|
||||
['#'], dtype='<U1')
|
||||
]
|
||||
t_x, t_y = [], []
|
||||
for j in range(h):
|
||||
t_x.append(points[j][0])
|
||||
t_y.append(points[j][1])
|
||||
xx[1] = np.array([t_x], dtype='int16')
|
||||
xx[3] = np.array([t_y], dtype='int16')
|
||||
if text != "" and "#" not in text:
|
||||
xx[4] = np.array([text], dtype='U{}'.format(len(text)))
|
||||
xx[5] = np.array(['c'], dtype='<U1')
|
||||
gt.append(xx)
|
||||
return gt
|
||||
|
||||
def detection_filtering(detections, groundtruths, threshold=0.5):
|
||||
for gt_id, gt in enumerate(groundtruths):
|
||||
if (gt[5] == '#') and (gt[1].shape[1] > 1):
|
||||
gt_x = list(map(int, np.squeeze(gt[1])))
|
||||
gt_y = list(map(int, np.squeeze(gt[3])))
|
||||
for det_id, detection in enumerate(detections):
|
||||
detection_orig = detection
|
||||
detection = [float(x) for x in detection[0].split(',')]
|
||||
detection = list(map(int, detection))
|
||||
det_x = detection[0::2]
|
||||
det_y = detection[1::2]
|
||||
det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
|
||||
if det_gt_iou > threshold:
|
||||
detections[det_id] = []
|
||||
|
||||
detections[:] = [item for item in detections if item != []]
|
||||
return detections
|
||||
|
||||
def sigma_calculation(det_x, det_y, gt_x, gt_y):
|
||||
"""
|
||||
sigma = inter_area / gt_area
|
||||
"""
|
||||
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
|
||||
area(gt_x, gt_y)), 2)
|
||||
|
||||
def tau_calculation(det_x, det_y, gt_x, gt_y):
|
||||
if area(det_x, det_y) == 0.0:
|
||||
return 0
|
||||
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
|
||||
area(det_x, det_y)), 2)
|
||||
|
||||
##############################Initialization###################################
|
||||
# global_sigma = []
|
||||
# global_tau = []
|
||||
# global_pred_str = []
|
||||
# global_gt_str = []
|
||||
###############################################################################
|
||||
|
||||
for input_id in range(allInputs):
|
||||
if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
|
||||
input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
|
||||
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
|
||||
and (input_id != 'Deteval_result_non_curved.txt'):
|
||||
detections = input_reading_mod(pred_dict)
|
||||
groundtruths = gt_reading_mod(gt_dict)
|
||||
detections = detection_filtering(
|
||||
detections,
|
||||
groundtruths) # filters detections overlapping with DC area
|
||||
dc_id = []
|
||||
for i in range(len(groundtruths)):
|
||||
if groundtruths[i][5] == '#':
|
||||
dc_id.append(i)
|
||||
cnt = 0
|
||||
for a in dc_id:
|
||||
num = a - cnt
|
||||
del groundtruths[num]
|
||||
cnt += 1
|
||||
|
||||
local_sigma_table = np.zeros((len(groundtruths), len(detections)))
|
||||
local_tau_table = np.zeros((len(groundtruths), len(detections)))
|
||||
local_pred_str = {}
|
||||
local_gt_str = {}
|
||||
|
||||
for gt_id, gt in enumerate(groundtruths):
|
||||
if len(detections) > 0:
|
||||
for det_id, detection in enumerate(detections):
|
||||
detection_orig = detection
|
||||
detection = [float(x) for x in detection[0].split(',')]
|
||||
detection = list(map(int, detection))
|
||||
pred_seq_str = detection_orig[1].strip()
|
||||
det_x = detection[0::2]
|
||||
det_y = detection[1::2]
|
||||
gt_x = list(map(int, np.squeeze(gt[1])))
|
||||
gt_y = list(map(int, np.squeeze(gt[3])))
|
||||
gt_seq_str = str(gt[4].tolist()[0])
|
||||
|
||||
local_sigma_table[gt_id, det_id] = sigma_calculation(
|
||||
det_x, det_y, gt_x, gt_y)
|
||||
local_tau_table[gt_id, det_id] = tau_calculation(
|
||||
det_x, det_y, gt_x, gt_y)
|
||||
local_pred_str[det_id] = pred_seq_str
|
||||
local_gt_str[gt_id] = gt_seq_str
|
||||
|
||||
global_sigma = local_sigma_table
|
||||
global_tau = local_tau_table
|
||||
global_pred_str = local_pred_str
|
||||
global_gt_str = local_gt_str
|
||||
|
||||
single_data = {}
|
||||
single_data['sigma'] = global_sigma
|
||||
single_data['global_tau'] = global_tau
|
||||
single_data['global_pred_str'] = global_pred_str
|
||||
single_data['global_gt_str'] = global_gt_str
|
||||
return single_data
|
||||
|
||||
|
||||
def combine_results(all_data):
|
||||
tr = 0.7
|
||||
tp = 0.6
|
||||
fsc_k = 0.8
|
||||
k = 2
|
||||
global_sigma = []
|
||||
global_tau = []
|
||||
global_pred_str = []
|
||||
global_gt_str = []
|
||||
for data in all_data:
|
||||
global_sigma.append(data['sigma'])
|
||||
global_tau.append(data['global_tau'])
|
||||
global_pred_str.append(data['global_pred_str'])
|
||||
global_gt_str.append(data['global_gt_str'])
|
||||
|
||||
global_accumulative_recall = 0
|
||||
global_accumulative_precision = 0
|
||||
total_num_gt = 0
|
||||
total_num_det = 0
|
||||
hit_str_count = 0
|
||||
hit_count = 0
|
||||
|
||||
def one_to_one(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idy):
|
||||
hit_str_num = 0
|
||||
for gt_id in range(num_gt):
|
||||
gt_matching_qualified_sigma_candidates = np.where(
|
||||
local_sigma_table[gt_id, :] > tr)
|
||||
gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
|
||||
0].shape[0]
|
||||
gt_matching_qualified_tau_candidates = np.where(
|
||||
local_tau_table[gt_id, :] > tp)
|
||||
gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
|
||||
0].shape[0]
|
||||
|
||||
det_matching_qualified_sigma_candidates = np.where(
|
||||
local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
|
||||
> tr)
|
||||
det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
|
||||
0].shape[0]
|
||||
det_matching_qualified_tau_candidates = np.where(
|
||||
local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
|
||||
tp)
|
||||
det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
|
||||
0].shape[0]
|
||||
|
||||
if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
|
||||
(det_matching_num_qualified_sigma_candidates == 1) and (
|
||||
det_matching_num_qualified_tau_candidates == 1):
|
||||
global_accumulative_recall = global_accumulative_recall + 1.0
|
||||
global_accumulative_precision = global_accumulative_precision + 1.0
|
||||
local_accumulative_recall = local_accumulative_recall + 1.0
|
||||
local_accumulative_precision = local_accumulative_precision + 1.0
|
||||
|
||||
gt_flag[0, gt_id] = 1
|
||||
matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
|
||||
# recg start
|
||||
gt_str_cur = global_gt_str[idy][gt_id]
|
||||
pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
|
||||
0]]
|
||||
if pred_str_cur == gt_str_cur:
|
||||
hit_str_num += 1
|
||||
else:
|
||||
if pred_str_cur.lower() == gt_str_cur.lower():
|
||||
hit_str_num += 1
|
||||
# recg end
|
||||
det_flag[0, matched_det_id] = 1
|
||||
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
|
||||
|
||||
def one_to_many(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idy):
|
||||
hit_str_num = 0
|
||||
for gt_id in range(num_gt):
|
||||
# skip the following if the groundtruth was matched
|
||||
if gt_flag[0, gt_id] > 0:
|
||||
continue
|
||||
|
||||
non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
|
||||
num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
|
||||
|
||||
if num_non_zero_in_sigma >= k:
|
||||
####search for all detections that overlaps with this groundtruth
|
||||
qualified_tau_candidates = np.where((local_tau_table[
|
||||
gt_id, :] >= tp) & (det_flag[0, :] == 0))
|
||||
num_qualified_tau_candidates = qualified_tau_candidates[
|
||||
0].shape[0]
|
||||
|
||||
if num_qualified_tau_candidates == 1:
|
||||
if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
|
||||
and
|
||||
(local_sigma_table[gt_id, qualified_tau_candidates] >=
|
||||
tr)):
|
||||
# became an one-to-one case
|
||||
global_accumulative_recall = global_accumulative_recall + 1.0
|
||||
global_accumulative_precision = global_accumulative_precision + 1.0
|
||||
local_accumulative_recall = local_accumulative_recall + 1.0
|
||||
local_accumulative_precision = local_accumulative_precision + 1.0
|
||||
|
||||
gt_flag[0, gt_id] = 1
|
||||
det_flag[0, qualified_tau_candidates] = 1
|
||||
# recg start
|
||||
gt_str_cur = global_gt_str[idy][gt_id]
|
||||
pred_str_cur = global_pred_str[idy][
|
||||
qualified_tau_candidates[0].tolist()[0]]
|
||||
if pred_str_cur == gt_str_cur:
|
||||
hit_str_num += 1
|
||||
else:
|
||||
if pred_str_cur.lower() == gt_str_cur.lower():
|
||||
hit_str_num += 1
|
||||
# recg end
|
||||
elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
|
||||
>= tr):
|
||||
gt_flag[0, gt_id] = 1
|
||||
det_flag[0, qualified_tau_candidates] = 1
|
||||
# recg start
|
||||
gt_str_cur = global_gt_str[idy][gt_id]
|
||||
pred_str_cur = global_pred_str[idy][
|
||||
qualified_tau_candidates[0].tolist()[0]]
|
||||
if pred_str_cur == gt_str_cur:
|
||||
hit_str_num += 1
|
||||
else:
|
||||
if pred_str_cur.lower() == gt_str_cur.lower():
|
||||
hit_str_num += 1
|
||||
# recg end
|
||||
|
||||
global_accumulative_recall = global_accumulative_recall + fsc_k
|
||||
global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
|
||||
|
||||
local_accumulative_recall = local_accumulative_recall + fsc_k
|
||||
local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
|
||||
|
||||
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
|
||||
|
||||
def many_to_one(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idy):
|
||||
hit_str_num = 0
|
||||
for det_id in range(num_det):
|
||||
# skip the following if the detection was matched
|
||||
if det_flag[0, det_id] > 0:
|
||||
continue
|
||||
|
||||
non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
|
||||
num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
|
||||
|
||||
if num_non_zero_in_tau >= k:
|
||||
####search for all detections that overlaps with this groundtruth
|
||||
qualified_sigma_candidates = np.where((
|
||||
local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
|
||||
num_qualified_sigma_candidates = qualified_sigma_candidates[
|
||||
0].shape[0]
|
||||
|
||||
if num_qualified_sigma_candidates == 1:
|
||||
if ((local_tau_table[qualified_sigma_candidates, det_id] >=
|
||||
tp) and
|
||||
(local_sigma_table[qualified_sigma_candidates, det_id]
|
||||
>= tr)):
|
||||
# became an one-to-one case
|
||||
global_accumulative_recall = global_accumulative_recall + 1.0
|
||||
global_accumulative_precision = global_accumulative_precision + 1.0
|
||||
local_accumulative_recall = local_accumulative_recall + 1.0
|
||||
local_accumulative_precision = local_accumulative_precision + 1.0
|
||||
|
||||
gt_flag[0, qualified_sigma_candidates] = 1
|
||||
det_flag[0, det_id] = 1
|
||||
# recg start
|
||||
pred_str_cur = global_pred_str[idy][det_id]
|
||||
gt_len = len(qualified_sigma_candidates[0])
|
||||
for idx in range(gt_len):
|
||||
ele_gt_id = qualified_sigma_candidates[0].tolist()[
|
||||
idx]
|
||||
if ele_gt_id not in global_gt_str[idy]:
|
||||
continue
|
||||
gt_str_cur = global_gt_str[idy][ele_gt_id]
|
||||
if pred_str_cur == gt_str_cur:
|
||||
hit_str_num += 1
|
||||
break
|
||||
else:
|
||||
if pred_str_cur.lower() == gt_str_cur.lower():
|
||||
hit_str_num += 1
|
||||
break
|
||||
# recg end
|
||||
elif (np.sum(local_tau_table[qualified_sigma_candidates,
|
||||
det_id]) >= tp):
|
||||
det_flag[0, det_id] = 1
|
||||
gt_flag[0, qualified_sigma_candidates] = 1
|
||||
# recg start
|
||||
pred_str_cur = global_pred_str[idy][det_id]
|
||||
gt_len = len(qualified_sigma_candidates[0])
|
||||
for idx in range(gt_len):
|
||||
ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
|
||||
if ele_gt_id not in global_gt_str[idy]:
|
||||
continue
|
||||
gt_str_cur = global_gt_str[idy][ele_gt_id]
|
||||
if pred_str_cur == gt_str_cur:
|
||||
hit_str_num += 1
|
||||
break
|
||||
else:
|
||||
if pred_str_cur.lower() == gt_str_cur.lower():
|
||||
hit_str_num += 1
|
||||
break
|
||||
# recg end
|
||||
|
||||
global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
|
||||
global_accumulative_precision = global_accumulative_precision + fsc_k
|
||||
|
||||
local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
|
||||
local_accumulative_precision = local_accumulative_precision + fsc_k
|
||||
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
|
||||
|
||||
for idx in range(len(global_sigma)):
|
||||
local_sigma_table = np.array(global_sigma[idx])
|
||||
local_tau_table = global_tau[idx]
|
||||
|
||||
num_gt = local_sigma_table.shape[0]
|
||||
num_det = local_sigma_table.shape[1]
|
||||
|
||||
total_num_gt = total_num_gt + num_gt
|
||||
total_num_det = total_num_det + num_det
|
||||
|
||||
local_accumulative_recall = 0
|
||||
local_accumulative_precision = 0
|
||||
gt_flag = np.zeros((1, num_gt))
|
||||
det_flag = np.zeros((1, num_det))
|
||||
|
||||
#######first check for one-to-one case##########
|
||||
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
|
||||
gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idx)
|
||||
|
||||
hit_str_count += hit_str_num
|
||||
#######then check for one-to-many case##########
|
||||
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
|
||||
gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idx)
|
||||
hit_str_count += hit_str_num
|
||||
#######then check for many-to-one case##########
|
||||
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
|
||||
gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idx)
|
||||
hit_str_count += hit_str_num
|
||||
|
||||
try:
|
||||
recall = global_accumulative_recall / total_num_gt
|
||||
except ZeroDivisionError:
|
||||
recall = 0
|
||||
|
||||
try:
|
||||
precision = global_accumulative_precision / total_num_det
|
||||
except ZeroDivisionError:
|
||||
precision = 0
|
||||
|
||||
try:
|
||||
f_score = 2 * precision * recall / (precision + recall)
|
||||
except ZeroDivisionError:
|
||||
f_score = 0
|
||||
|
||||
try:
|
||||
seqerr = 1 - float(hit_str_count) / global_accumulative_recall
|
||||
except ZeroDivisionError:
|
||||
seqerr = 1
|
||||
|
||||
try:
|
||||
recall_e2e = float(hit_str_count) / total_num_gt
|
||||
except ZeroDivisionError:
|
||||
recall_e2e = 0
|
||||
|
||||
try:
|
||||
precision_e2e = float(hit_str_count) / total_num_det
|
||||
except ZeroDivisionError:
|
||||
precision_e2e = 0
|
||||
|
||||
try:
|
||||
f_score_e2e = 2 * precision_e2e * recall_e2e / (
|
||||
precision_e2e + recall_e2e)
|
||||
except ZeroDivisionError:
|
||||
f_score_e2e = 0
|
||||
|
||||
final = {
|
||||
'total_num_gt': total_num_gt,
|
||||
'total_num_det': total_num_det,
|
||||
'global_accumulative_recall': global_accumulative_recall,
|
||||
'hit_str_count': hit_str_count,
|
||||
'recall': recall,
|
||||
'precision': precision,
|
||||
'f_score': f_score,
|
||||
'seqerr': seqerr,
|
||||
'recall_e2e': recall_e2e,
|
||||
'precision_e2e': precision_e2e,
|
||||
'f_score_e2e': f_score_e2e
|
||||
}
|
||||
return final
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
from shapely.geometry import Polygon
|
||||
"""
|
||||
:param det_x: [1, N] Xs of detection's vertices
|
||||
:param det_y: [1, N] Ys of detection's vertices
|
||||
:param gt_x: [1, N] Xs of groundtruth's vertices
|
||||
:param gt_y: [1, N] Ys of groundtruth's vertices
|
||||
|
||||
##############
|
||||
All the calculation of 'AREA' in this script is handled by:
|
||||
1) First generating a binary mask with the polygon area filled up with 1's
|
||||
2) Summing up all the 1's
|
||||
"""
|
||||
|
||||
|
||||
def area(x, y):
|
||||
polygon = Polygon(np.stack([x, y], axis=1))
|
||||
return float(polygon.area)
|
||||
|
||||
|
||||
def approx_area_of_intersection(det_x, det_y, gt_x, gt_y):
|
||||
"""
|
||||
This helper determine if both polygons are intersecting with each others with an approximation method.
|
||||
Area of intersection represented by the minimum bounding rectangular [xmin, ymin, xmax, ymax]
|
||||
"""
|
||||
det_ymax = np.max(det_y)
|
||||
det_xmax = np.max(det_x)
|
||||
det_ymin = np.min(det_y)
|
||||
det_xmin = np.min(det_x)
|
||||
|
||||
gt_ymax = np.max(gt_y)
|
||||
gt_xmax = np.max(gt_x)
|
||||
gt_ymin = np.min(gt_y)
|
||||
gt_xmin = np.min(gt_x)
|
||||
|
||||
all_min_ymax = np.minimum(det_ymax, gt_ymax)
|
||||
all_max_ymin = np.maximum(det_ymin, gt_ymin)
|
||||
|
||||
intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin))
|
||||
|
||||
all_min_xmax = np.minimum(det_xmax, gt_xmax)
|
||||
all_max_xmin = np.maximum(det_xmin, gt_xmin)
|
||||
intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin))
|
||||
|
||||
return intersect_heights * intersect_widths
|
||||
|
||||
|
||||
def area_of_intersection(det_x, det_y, gt_x, gt_y):
|
||||
p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
|
||||
p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
|
||||
return float(p1.intersection(p2).area)
|
||||
|
||||
|
||||
def area_of_union(det_x, det_y, gt_x, gt_y):
|
||||
p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
|
||||
p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
|
||||
return float(p1.union(p2).area)
|
||||
|
||||
|
||||
def iou(det_x, det_y, gt_x, gt_y):
|
||||
return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
|
||||
area_of_union(det_x, det_y, gt_x, gt_y) + 1.0)
|
||||
|
||||
|
||||
def iod(det_x, det_y, gt_x, gt_y):
|
||||
"""
|
||||
This helper determine the fraction of intersection area over detection area
|
||||
"""
|
||||
return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
|
||||
area(det_x, det_y) + 1.0)
|
|
@ -0,0 +1,87 @@
|
|||
import paddle
|
||||
import numpy as np
|
||||
import copy
|
||||
|
||||
|
||||
def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
|
||||
"""
|
||||
"""
|
||||
pos_lists_, pos_masks_, label_lists_ = [], [], []
|
||||
img_bs = batch_size
|
||||
ngpu = int(batch_size / img_bs)
|
||||
img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
|
||||
pos_lists_split, pos_masks_split, label_lists_split = [], [], []
|
||||
for i in range(ngpu):
|
||||
pos_lists_split.append([])
|
||||
pos_masks_split.append([])
|
||||
label_lists_split.append([])
|
||||
|
||||
for i in range(img_ids.shape[0]):
|
||||
img_id = img_ids[i]
|
||||
gpu_id = int(img_id / img_bs)
|
||||
img_id = img_id % img_bs
|
||||
pos_list = pos_lists[i].copy()
|
||||
pos_list[:, 0] = img_id
|
||||
pos_lists_split[gpu_id].append(pos_list)
|
||||
pos_masks_split[gpu_id].append(pos_masks[i].copy())
|
||||
label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))
|
||||
# repeat or delete
|
||||
for i in range(ngpu):
|
||||
vp_len = len(pos_lists_split[i])
|
||||
if vp_len <= tcl_bs:
|
||||
for j in range(0, tcl_bs - vp_len):
|
||||
pos_list = pos_lists_split[i][j].copy()
|
||||
pos_lists_split[i].append(pos_list)
|
||||
pos_mask = pos_masks_split[i][j].copy()
|
||||
pos_masks_split[i].append(pos_mask)
|
||||
label_list = copy.deepcopy(label_lists_split[i][j])
|
||||
label_lists_split[i].append(label_list)
|
||||
else:
|
||||
for j in range(0, vp_len - tcl_bs):
|
||||
c_len = len(pos_lists_split[i])
|
||||
pop_id = np.random.permutation(c_len)[0]
|
||||
pos_lists_split[i].pop(pop_id)
|
||||
pos_masks_split[i].pop(pop_id)
|
||||
label_lists_split[i].pop(pop_id)
|
||||
# merge
|
||||
for i in range(ngpu):
|
||||
pos_lists_.extend(pos_lists_split[i])
|
||||
pos_masks_.extend(pos_masks_split[i])
|
||||
label_lists_.extend(label_lists_split[i])
|
||||
return pos_lists_, pos_masks_, label_lists_
|
||||
|
||||
|
||||
def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
|
||||
pad_num, tcl_bs):
|
||||
label_list = label_list.numpy()
|
||||
batch, _, _, _ = label_list.shape
|
||||
pos_list = pos_list.numpy()
|
||||
pos_mask = pos_mask.numpy()
|
||||
pos_list_t = []
|
||||
pos_mask_t = []
|
||||
label_list_t = []
|
||||
for i in range(batch):
|
||||
for j in range(max_text_nums):
|
||||
if pos_mask[i, j].any():
|
||||
pos_list_t.append(pos_list[i][j])
|
||||
pos_mask_t.append(pos_mask[i][j])
|
||||
label_list_t.append(label_list[i][j])
|
||||
pos_list, pos_mask, label_list = org_tcl_rois(batch, pos_list_t, pos_mask_t,
|
||||
label_list_t, tcl_bs)
|
||||
label = []
|
||||
tt = [l.tolist() for l in label_list]
|
||||
for i in range(tcl_bs):
|
||||
k = 0
|
||||
for j in range(max_text_length):
|
||||
if tt[i][j][0] != pad_num:
|
||||
k += 1
|
||||
else:
|
||||
break
|
||||
label.append(k)
|
||||
label = paddle.to_tensor(label)
|
||||
label = paddle.cast(label, dtype='int64')
|
||||
pos_list = paddle.to_tensor(pos_list)
|
||||
pos_mask = paddle.to_tensor(pos_mask)
|
||||
label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2)
|
||||
label_list = paddle.cast(label_list, dtype='int32')
|
||||
return pos_list, pos_mask, label_list, label
|
|
@ -0,0 +1,532 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains various CTC decoders."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import cv2
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from itertools import groupby
|
||||
from skimage.morphology._skeletonize import thin
|
||||
|
||||
|
||||
def get_dict(character_dict_path):
|
||||
character_str = ""
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||
character_str += line
|
||||
dict_character = list(character_str)
|
||||
return dict_character
|
||||
|
||||
|
||||
def softmax(logits):
|
||||
"""
|
||||
logits: N x d
|
||||
"""
|
||||
max_value = np.max(logits, axis=1, keepdims=True)
|
||||
exp = np.exp(logits - max_value)
|
||||
exp_sum = np.sum(exp, axis=1, keepdims=True)
|
||||
dist = exp / exp_sum
|
||||
return dist
|
||||
|
||||
|
||||
def get_keep_pos_idxs(labels, remove_blank=None):
|
||||
"""
|
||||
Remove duplicate and get pos idxs of keep items.
|
||||
The value of keep_blank should be [None, 95].
|
||||
"""
|
||||
duplicate_len_list = []
|
||||
keep_pos_idx_list = []
|
||||
keep_char_idx_list = []
|
||||
for k, v_ in groupby(labels):
|
||||
current_len = len(list(v_))
|
||||
if k != remove_blank:
|
||||
current_idx = int(sum(duplicate_len_list) + current_len // 2)
|
||||
keep_pos_idx_list.append(current_idx)
|
||||
keep_char_idx_list.append(k)
|
||||
duplicate_len_list.append(current_len)
|
||||
return keep_char_idx_list, keep_pos_idx_list
|
||||
|
||||
|
||||
def remove_blank(labels, blank=0):
|
||||
new_labels = [x for x in labels if x != blank]
|
||||
return new_labels
|
||||
|
||||
|
||||
def insert_blank(labels, blank=0):
|
||||
new_labels = [blank]
|
||||
for l in labels:
|
||||
new_labels += [l, blank]
|
||||
return new_labels
|
||||
|
||||
|
||||
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
|
||||
"""
|
||||
CTC greedy (best path) decoder.
|
||||
"""
|
||||
raw_str = np.argmax(np.array(probs_seq), axis=1)
|
||||
remove_blank_in_pos = None if keep_blank_in_idxs else blank
|
||||
dedup_str, keep_idx_list = get_keep_pos_idxs(
|
||||
raw_str, remove_blank=remove_blank_in_pos)
|
||||
dst_str = remove_blank(dedup_str, blank=blank)
|
||||
return dst_str, keep_idx_list
|
||||
|
||||
|
||||
def instance_ctc_greedy_decoder(gather_info,
|
||||
logits_map,
|
||||
keep_blank_in_idxs=True):
|
||||
"""
|
||||
gather_info: [[x, y], [x, y] ...]
|
||||
logits_map: H x W X (n_chars + 1)
|
||||
"""
|
||||
_, _, C = logits_map.shape
|
||||
ys, xs = zip(*gather_info)
|
||||
logits_seq = logits_map[list(ys), list(xs)] # n x 96
|
||||
probs_seq = softmax(logits_seq)
|
||||
dst_str, keep_idx_list = ctc_greedy_decoder(
|
||||
probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
|
||||
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
|
||||
return dst_str, keep_gather_list
|
||||
|
||||
|
||||
def ctc_decoder_for_image(gather_info_list, logits_map,
|
||||
keep_blank_in_idxs=True):
|
||||
"""
|
||||
CTC decoder using multiple processes.
|
||||
"""
|
||||
decoder_results = []
|
||||
for gather_info in gather_info_list:
|
||||
res = instance_ctc_greedy_decoder(
|
||||
gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
|
||||
decoder_results.append(res)
|
||||
return decoder_results
|
||||
|
||||
|
||||
def sort_with_direction(pos_list, f_direction):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||
"""
|
||||
|
||||
def sort_part_with_direction(pos_list, point_direction):
|
||||
pos_list = np.array(pos_list).reshape(-1, 2)
|
||||
point_direction = np.array(point_direction).reshape(-1, 2)
|
||||
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
||||
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
||||
sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
|
||||
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
|
||||
return sorted_list, sorted_direction
|
||||
|
||||
pos_list = np.array(pos_list).reshape(-1, 2)
|
||||
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
|
||||
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
||||
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
|
||||
point_direction)
|
||||
|
||||
point_num = len(sorted_point)
|
||||
if point_num >= 16:
|
||||
middle_num = point_num // 2
|
||||
first_part_point = sorted_point[:middle_num]
|
||||
first_point_direction = sorted_direction[:middle_num]
|
||||
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
|
||||
first_part_point, first_point_direction)
|
||||
|
||||
last_part_point = sorted_point[middle_num:]
|
||||
last_point_direction = sorted_direction[middle_num:]
|
||||
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
|
||||
last_part_point, last_point_direction)
|
||||
sorted_point = sorted_fist_part_point + sorted_last_part_point
|
||||
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
|
||||
|
||||
return sorted_point, np.array(sorted_direction)
|
||||
|
||||
|
||||
def add_id(pos_list, image_id=0):
|
||||
"""
|
||||
Add id for gather feature, for inference.
|
||||
"""
|
||||
new_list = []
|
||||
for item in pos_list:
|
||||
new_list.append((image_id, item[0], item[1]))
|
||||
return new_list
|
||||
|
||||
|
||||
def sort_and_expand_with_direction(pos_list, f_direction):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||
"""
|
||||
h, w, _ = f_direction.shape
|
||||
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
|
||||
|
||||
# expand along
|
||||
point_num = len(sorted_list)
|
||||
sub_direction_len = max(point_num // 3, 2)
|
||||
left_direction = point_direction[:sub_direction_len, :]
|
||||
right_dirction = point_direction[point_num - sub_direction_len:, :]
|
||||
|
||||
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
|
||||
left_average_len = np.linalg.norm(left_average_direction)
|
||||
left_start = np.array(sorted_list[0])
|
||||
left_step = left_average_direction / (left_average_len + 1e-6)
|
||||
|
||||
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
|
||||
right_average_len = np.linalg.norm(right_average_direction)
|
||||
right_step = right_average_direction / (right_average_len + 1e-6)
|
||||
right_start = np.array(sorted_list[-1])
|
||||
|
||||
append_num = max(
|
||||
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
|
||||
left_list = []
|
||||
right_list = []
|
||||
for i in range(append_num):
|
||||
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
|
||||
'int32').tolist()
|
||||
if ly < h and lx < w and (ly, lx) not in left_list:
|
||||
left_list.append((ly, lx))
|
||||
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
|
||||
'int32').tolist()
|
||||
if ry < h and rx < w and (ry, rx) not in right_list:
|
||||
right_list.append((ry, rx))
|
||||
|
||||
all_list = left_list[::-1] + sorted_list + right_list
|
||||
return all_list
|
||||
|
||||
|
||||
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||
binary_tcl_map: h x w
|
||||
"""
|
||||
h, w, _ = f_direction.shape
|
||||
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
|
||||
|
||||
# expand along
|
||||
point_num = len(sorted_list)
|
||||
sub_direction_len = max(point_num // 3, 2)
|
||||
left_direction = point_direction[:sub_direction_len, :]
|
||||
right_dirction = point_direction[point_num - sub_direction_len:, :]
|
||||
|
||||
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
|
||||
left_average_len = np.linalg.norm(left_average_direction)
|
||||
left_start = np.array(sorted_list[0])
|
||||
left_step = left_average_direction / (left_average_len + 1e-6)
|
||||
|
||||
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
|
||||
right_average_len = np.linalg.norm(right_average_direction)
|
||||
right_step = right_average_direction / (right_average_len + 1e-6)
|
||||
right_start = np.array(sorted_list[-1])
|
||||
|
||||
append_num = max(
|
||||
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
|
||||
max_append_num = 2 * append_num
|
||||
|
||||
left_list = []
|
||||
right_list = []
|
||||
for i in range(max_append_num):
|
||||
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
|
||||
'int32').tolist()
|
||||
if ly < h and lx < w and (ly, lx) not in left_list:
|
||||
if binary_tcl_map[ly, lx] > 0.5:
|
||||
left_list.append((ly, lx))
|
||||
else:
|
||||
break
|
||||
|
||||
for i in range(max_append_num):
|
||||
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
|
||||
'int32').tolist()
|
||||
if ry < h and rx < w and (ry, rx) not in right_list:
|
||||
if binary_tcl_map[ry, rx] > 0.5:
|
||||
right_list.append((ry, rx))
|
||||
else:
|
||||
break
|
||||
|
||||
all_list = left_list[::-1] + sorted_list + right_list
|
||||
return all_list
|
||||
|
||||
|
||||
def generate_pivot_list_curved(p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=0.5,
|
||||
is_expand=True,
|
||||
is_backbone=False,
|
||||
image_id=0):
|
||||
"""
|
||||
return center point and end point of TCL instance; filter with the char maps;
|
||||
"""
|
||||
p_score = p_score[0]
|
||||
f_direction = f_direction.transpose(1, 2, 0)
|
||||
p_tcl_map = (p_score > score_thresh) * 1.0
|
||||
skeleton_map = thin(p_tcl_map)
|
||||
instance_count, instance_label_map = cv2.connectedComponents(
|
||||
skeleton_map.astype(np.uint8), connectivity=8)
|
||||
|
||||
# get TCL Instance
|
||||
all_pos_yxs = []
|
||||
center_pos_yxs = []
|
||||
end_points_yxs = []
|
||||
instance_center_pos_yxs = []
|
||||
if instance_count > 0:
|
||||
for instance_id in range(1, instance_count):
|
||||
pos_list = []
|
||||
ys, xs = np.where(instance_label_map == instance_id)
|
||||
pos_list = list(zip(ys, xs))
|
||||
|
||||
### FIX-ME, eliminate outlier
|
||||
if len(pos_list) < 3:
|
||||
continue
|
||||
|
||||
if is_expand:
|
||||
pos_list_sorted = sort_and_expand_with_direction_v2(
|
||||
pos_list, f_direction, p_tcl_map)
|
||||
else:
|
||||
pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
|
||||
all_pos_yxs.append(pos_list_sorted)
|
||||
|
||||
# use decoder to filter backgroud points.
|
||||
p_char_maps = p_char_maps.transpose([1, 2, 0])
|
||||
decode_res = ctc_decoder_for_image(
|
||||
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
|
||||
for decoded_str, keep_yxs_list in decode_res:
|
||||
if is_backbone:
|
||||
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
|
||||
instance_center_pos_yxs.append(keep_yxs_list_with_id)
|
||||
else:
|
||||
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
|
||||
center_pos_yxs.extend(keep_yxs_list)
|
||||
|
||||
if is_backbone:
|
||||
return instance_center_pos_yxs
|
||||
else:
|
||||
return center_pos_yxs, end_points_yxs
|
||||
|
||||
|
||||
def generate_pivot_list_horizontal(p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=0.5,
|
||||
is_backbone=False,
|
||||
image_id=0):
|
||||
"""
|
||||
return center point and end point of TCL instance; filter with the char maps;
|
||||
"""
|
||||
p_score = p_score[0]
|
||||
f_direction = f_direction.transpose(1, 2, 0)
|
||||
p_tcl_map_bi = (p_score > score_thresh) * 1.0
|
||||
instance_count, instance_label_map = cv2.connectedComponents(
|
||||
p_tcl_map_bi.astype(np.uint8), connectivity=8)
|
||||
|
||||
# get TCL Instance
|
||||
all_pos_yxs = []
|
||||
center_pos_yxs = []
|
||||
end_points_yxs = []
|
||||
instance_center_pos_yxs = []
|
||||
|
||||
if instance_count > 0:
|
||||
for instance_id in range(1, instance_count):
|
||||
pos_list = []
|
||||
ys, xs = np.where(instance_label_map == instance_id)
|
||||
pos_list = list(zip(ys, xs))
|
||||
|
||||
### FIX-ME, eliminate outlier
|
||||
if len(pos_list) < 5:
|
||||
continue
|
||||
|
||||
# add rule here
|
||||
main_direction = extract_main_direction(pos_list,
|
||||
f_direction) # y x
|
||||
reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x
|
||||
is_h_angle = abs(np.sum(
|
||||
main_direction * reference_directin)) < math.cos(math.pi / 180 *
|
||||
70)
|
||||
|
||||
point_yxs = np.array(pos_list)
|
||||
max_y, max_x = np.max(point_yxs, axis=0)
|
||||
min_y, min_x = np.min(point_yxs, axis=0)
|
||||
is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
|
||||
|
||||
pos_list_final = []
|
||||
if is_h_len:
|
||||
xs = np.unique(xs)
|
||||
for x in xs:
|
||||
ys = instance_label_map[:, x].copy().reshape((-1, ))
|
||||
y = int(np.where(ys == instance_id)[0].mean())
|
||||
pos_list_final.append((y, x))
|
||||
else:
|
||||
ys = np.unique(ys)
|
||||
for y in ys:
|
||||
xs = instance_label_map[y, :].copy().reshape((-1, ))
|
||||
x = int(np.where(xs == instance_id)[0].mean())
|
||||
pos_list_final.append((y, x))
|
||||
|
||||
pos_list_sorted, _ = sort_with_direction(pos_list_final,
|
||||
f_direction)
|
||||
all_pos_yxs.append(pos_list_sorted)
|
||||
|
||||
# use decoder to filter backgroud points.
|
||||
p_char_maps = p_char_maps.transpose([1, 2, 0])
|
||||
decode_res = ctc_decoder_for_image(
|
||||
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
|
||||
for decoded_str, keep_yxs_list in decode_res:
|
||||
if is_backbone:
|
||||
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
|
||||
instance_center_pos_yxs.append(keep_yxs_list_with_id)
|
||||
else:
|
||||
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
|
||||
center_pos_yxs.extend(keep_yxs_list)
|
||||
|
||||
if is_backbone:
|
||||
return instance_center_pos_yxs
|
||||
else:
|
||||
return center_pos_yxs, end_points_yxs
|
||||
|
||||
|
||||
def generate_pivot_list(p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=0.5,
|
||||
is_backbone=False,
|
||||
is_curved=True,
|
||||
image_id=0):
|
||||
"""
|
||||
Warp all the function together.
|
||||
"""
|
||||
if is_curved:
|
||||
return generate_pivot_list_curved(
|
||||
p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=score_thresh,
|
||||
is_expand=True,
|
||||
is_backbone=is_backbone,
|
||||
image_id=image_id)
|
||||
else:
|
||||
return generate_pivot_list_horizontal(
|
||||
p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=score_thresh,
|
||||
is_backbone=is_backbone,
|
||||
image_id=image_id)
|
||||
|
||||
|
||||
# for refine module
|
||||
def extract_main_direction(pos_list, f_direction):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||
"""
|
||||
pos_list = np.array(pos_list)
|
||||
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
|
||||
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
||||
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
||||
average_direction = average_direction / (
|
||||
np.linalg.norm(average_direction) + 1e-6)
|
||||
return average_direction
|
||||
|
||||
|
||||
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
|
||||
"""
|
||||
pos_list_full = np.array(pos_list).reshape(-1, 3)
|
||||
pos_list = pos_list_full[:, 1:]
|
||||
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
|
||||
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
||||
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
||||
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
||||
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
|
||||
return sorted_list
|
||||
|
||||
|
||||
def sort_by_direction_with_image_id(pos_list, f_direction):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||
"""
|
||||
|
||||
def sort_part_with_direction(pos_list_full, point_direction):
|
||||
pos_list_full = np.array(pos_list_full).reshape(-1, 3)
|
||||
pos_list = pos_list_full[:, 1:]
|
||||
point_direction = np.array(point_direction).reshape(-1, 2)
|
||||
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
||||
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
||||
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
|
||||
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
|
||||
return sorted_list, sorted_direction
|
||||
|
||||
pos_list = np.array(pos_list).reshape(-1, 3)
|
||||
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
|
||||
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
||||
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
|
||||
point_direction)
|
||||
|
||||
point_num = len(sorted_point)
|
||||
if point_num >= 16:
|
||||
middle_num = point_num // 2
|
||||
first_part_point = sorted_point[:middle_num]
|
||||
first_point_direction = sorted_direction[:middle_num]
|
||||
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
|
||||
first_part_point, first_point_direction)
|
||||
|
||||
last_part_point = sorted_point[middle_num:]
|
||||
last_point_direction = sorted_direction[middle_num:]
|
||||
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
|
||||
last_part_point, last_point_direction)
|
||||
sorted_point = sorted_fist_part_point + sorted_last_part_point
|
||||
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
|
||||
|
||||
return sorted_point
|
||||
|
||||
|
||||
def generate_pivot_list_tt_inference(p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=0.5,
|
||||
is_backbone=False,
|
||||
is_curved=True,
|
||||
image_id=0):
|
||||
"""
|
||||
return center point and end point of TCL instance; filter with the char maps;
|
||||
"""
|
||||
p_score = p_score[0]
|
||||
f_direction = f_direction.transpose(1, 2, 0)
|
||||
p_tcl_map = (p_score > score_thresh) * 1.0
|
||||
skeleton_map = thin(p_tcl_map)
|
||||
instance_count, instance_label_map = cv2.connectedComponents(
|
||||
skeleton_map.astype(np.uint8), connectivity=8)
|
||||
|
||||
# get TCL Instance
|
||||
all_pos_yxs = []
|
||||
if instance_count > 0:
|
||||
for instance_id in range(1, instance_count):
|
||||
pos_list = []
|
||||
ys, xs = np.where(instance_label_map == instance_id)
|
||||
pos_list = list(zip(ys, xs))
|
||||
### FIX-ME, eliminate outlier
|
||||
if len(pos_list) < 3:
|
||||
continue
|
||||
pos_list_sorted = sort_and_expand_with_direction_v2(
|
||||
pos_list, f_direction, p_tcl_map)
|
||||
pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
|
||||
all_pos_yxs.append(pos_list_sorted_with_id)
|
||||
return all_pos_yxs
|
|
@ -0,0 +1,162 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import cv2
|
||||
import time
|
||||
|
||||
|
||||
def resize_image(im, max_side_len=512):
|
||||
"""
|
||||
resize image to a size multiple of max_stride which is required by the network
|
||||
:param im: the resized image
|
||||
:param max_side_len: limit of max image size to avoid out of memory in gpu
|
||||
:return: the resized image and the resize ratio
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
if resize_h > resize_w:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
else:
|
||||
ratio = float(max_side_len) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
|
||||
def resize_image_min(im, max_side_len=512):
|
||||
"""
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
if resize_h < resize_w:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
else:
|
||||
ratio = float(max_side_len) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
|
||||
def resize_image_for_totaltext(im, max_side_len=512):
|
||||
"""
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
ratio = 1.25
|
||||
if h * ratio > max_side_len:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
|
||||
def point_pair2poly(point_pair_list):
|
||||
"""
|
||||
Transfer vertical point_pairs into poly point in clockwise.
|
||||
"""
|
||||
pair_length_list = []
|
||||
for point_pair in point_pair_list:
|
||||
pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
|
||||
pair_length_list.append(pair_length)
|
||||
pair_length_list = np.array(pair_length_list)
|
||||
pair_info = (pair_length_list.max(), pair_length_list.min(),
|
||||
pair_length_list.mean())
|
||||
|
||||
point_num = len(point_pair_list) * 2
|
||||
point_list = [0] * point_num
|
||||
for idx, point_pair in enumerate(point_pair_list):
|
||||
point_list[idx] = point_pair[0]
|
||||
point_list[point_num - 1 - idx] = point_pair[1]
|
||||
return np.array(point_list).reshape(-1, 2), pair_info
|
||||
|
||||
|
||||
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
|
||||
"""
|
||||
Generate shrink_quad_along_width.
|
||||
"""
|
||||
ratio_pair = np.array(
|
||||
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
||||
|
||||
|
||||
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
|
||||
"""
|
||||
expand poly along width.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
left_quad = np.array(
|
||||
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
|
||||
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
|
||||
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
|
||||
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
|
||||
right_quad = np.array(
|
||||
[
|
||||
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
|
||||
poly[point_num // 2], poly[point_num // 2 + 1]
|
||||
],
|
||||
dtype=np.float32)
|
||||
right_ratio = 1.0 + \
|
||||
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
|
||||
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
|
||||
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
|
||||
poly[0] = left_quad_expand[0]
|
||||
poly[-1] = left_quad_expand[-1]
|
||||
poly[point_num // 2 - 1] = right_quad_expand[1]
|
||||
poly[point_num // 2] = right_quad_expand[2]
|
||||
return poly
|
||||
|
||||
|
||||
def norm2(x, axis=None):
|
||||
if axis:
|
||||
return np.sqrt(np.sum(x**2, axis=axis))
|
||||
return np.sqrt(np.sum(x**2))
|
||||
|
||||
|
||||
def cos(p1, p2):
|
||||
return (p1 * p2).sum() / (norm2(p1) * norm2(p2))
|
|
@ -7,4 +7,5 @@ opencv-python==4.2.0.32
|
|||
tqdm
|
||||
numpy
|
||||
visualdl
|
||||
python-Levenshtein
|
||||
python-Levenshtein
|
||||
opencv-contrib-python
|
|
@ -0,0 +1,158 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import time
|
||||
import sys
|
||||
|
||||
import tools.infer.utility as utility
|
||||
from ppocr.utils.logging import get_logger
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
from ppocr.data import create_operators, transform
|
||||
from ppocr.postprocess import build_post_process
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class TextE2E(object):
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.e2e_algorithm = args.e2e_algorithm
|
||||
pre_process_list = [{
|
||||
'E2EResizeForTest': {}
|
||||
}, {
|
||||
'NormalizeImage': {
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'scale': '1./255.',
|
||||
'order': 'hwc'
|
||||
}
|
||||
}, {
|
||||
'ToCHWImage': None
|
||||
}, {
|
||||
'KeepKeys': {
|
||||
'keep_keys': ['image', 'shape']
|
||||
}
|
||||
}]
|
||||
postprocess_params = {}
|
||||
if self.e2e_algorithm == "PGNet":
|
||||
pre_process_list[0] = {
|
||||
'E2EResizeForTest': {
|
||||
'max_side_len': args.e2e_limit_side_len,
|
||||
'valid_set': 'totaltext'
|
||||
}
|
||||
}
|
||||
postprocess_params['name'] = 'PGPostProcess'
|
||||
postprocess_params["score_thresh"] = args.e2e_pgnet_score_thresh
|
||||
postprocess_params["character_dict_path"] = args.e2e_char_dict_path
|
||||
postprocess_params["valid_set"] = args.e2e_pgnet_valid_set
|
||||
self.e2e_pgnet_polygon = args.e2e_pgnet_polygon
|
||||
else:
|
||||
logger.info("unknown e2e_algorithm:{}".format(self.e2e_algorithm))
|
||||
sys.exit(0)
|
||||
|
||||
self.preprocess_op = create_operators(pre_process_list)
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
|
||||
args, 'e2e', logger) # paddle.jit.load(args.det_model_dir)
|
||||
# self.predictor.eval()
|
||||
|
||||
def clip_det_res(self, points, img_height, img_width):
|
||||
for pno in range(points.shape[0]):
|
||||
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
|
||||
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
|
||||
return points
|
||||
|
||||
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
|
||||
img_height, img_width = image_shape[0:2]
|
||||
dt_boxes_new = []
|
||||
for box in dt_boxes:
|
||||
box = self.clip_det_res(box, img_height, img_width)
|
||||
dt_boxes_new.append(box)
|
||||
dt_boxes = np.array(dt_boxes_new)
|
||||
return dt_boxes
|
||||
|
||||
def __call__(self, img):
|
||||
|
||||
ori_im = img.copy()
|
||||
data = {'image': img}
|
||||
data = transform(data, self.preprocess_op)
|
||||
img, shape_list = data
|
||||
if img is None:
|
||||
return None, 0
|
||||
img = np.expand_dims(img, axis=0)
|
||||
shape_list = np.expand_dims(shape_list, axis=0)
|
||||
img = img.copy()
|
||||
starttime = time.time()
|
||||
|
||||
self.input_tensor.copy_from_cpu(img)
|
||||
self.predictor.run()
|
||||
outputs = []
|
||||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
|
||||
preds = {}
|
||||
if self.e2e_algorithm == 'PGNet':
|
||||
preds['f_border'] = outputs[0]
|
||||
preds['f_char'] = outputs[1]
|
||||
preds['f_direction'] = outputs[2]
|
||||
preds['f_score'] = outputs[3]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
post_result = self.postprocess_op(preds, shape_list)
|
||||
points, strs = post_result['points'], post_result['strs']
|
||||
dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
|
||||
elapse = time.time() - starttime
|
||||
return dt_boxes, strs, elapse
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = utility.parse_args()
|
||||
image_file_list = get_image_file_list(args.image_dir)
|
||||
text_detector = TextE2E(args)
|
||||
count = 0
|
||||
total_time = 0
|
||||
draw_img_save = "./inference_results"
|
||||
if not os.path.exists(draw_img_save):
|
||||
os.makedirs(draw_img_save)
|
||||
for image_file in image_file_list:
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
if not flag:
|
||||
img = cv2.imread(image_file)
|
||||
if img is None:
|
||||
logger.info("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
points, strs, elapse = text_detector(img)
|
||||
if count > 0:
|
||||
total_time += elapse
|
||||
count += 1
|
||||
logger.info("Predict time of {}: {}".format(image_file, elapse))
|
||||
src_im = utility.draw_e2e_res(points, strs, image_file)
|
||||
img_name_pure = os.path.split(image_file)[-1]
|
||||
img_path = os.path.join(draw_img_save,
|
||||
"e2e_res_{}".format(img_name_pure))
|
||||
cv2.imwrite(img_path, src_im)
|
||||
logger.info("The visualized image saved in {}".format(img_path))
|
||||
if count > 1:
|
||||
logger.info("Avg Time: {}".format(total_time / (count - 1)))
|
|
@ -74,6 +74,19 @@ def parse_args():
|
|||
"--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
|
||||
parser.add_argument("--drop_score", type=float, default=0.5)
|
||||
|
||||
# params for e2e
|
||||
parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
|
||||
parser.add_argument("--e2e_model_dir", type=str)
|
||||
parser.add_argument("--e2e_limit_side_len", type=float, default=768)
|
||||
parser.add_argument("--e2e_limit_type", type=str, default='max')
|
||||
|
||||
# PGNet parmas
|
||||
parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
|
||||
parser.add_argument(
|
||||
"--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
|
||||
parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
|
||||
parser.add_argument("--e2e_pgnet_polygon", type=bool, default=True)
|
||||
|
||||
# params for text classifier
|
||||
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
|
||||
parser.add_argument("--cls_model_dir", type=str)
|
||||
|
@ -93,8 +106,10 @@ def create_predictor(args, mode, logger):
|
|||
model_dir = args.det_model_dir
|
||||
elif mode == 'cls':
|
||||
model_dir = args.cls_model_dir
|
||||
else:
|
||||
elif mode == 'rec':
|
||||
model_dir = args.rec_model_dir
|
||||
else:
|
||||
model_dir = args.e2e_model_dir
|
||||
|
||||
if model_dir is None:
|
||||
logger.info("not find {} model file path {}".format(mode, model_dir))
|
||||
|
@ -148,6 +163,22 @@ def create_predictor(args, mode, logger):
|
|||
return predictor, input_tensor, output_tensors
|
||||
|
||||
|
||||
def draw_e2e_res(dt_boxes, strs, img_path):
|
||||
src_im = cv2.imread(img_path)
|
||||
for box, str in zip(dt_boxes, strs):
|
||||
box = box.astype(np.int32).reshape((-1, 1, 2))
|
||||
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
|
||||
cv2.putText(
|
||||
src_im,
|
||||
str,
|
||||
org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
|
||||
fontFace=cv2.FONT_HERSHEY_COMPLEX,
|
||||
fontScale=0.7,
|
||||
color=(0, 255, 0),
|
||||
thickness=1)
|
||||
return src_im
|
||||
|
||||
|
||||
def draw_text_det_res(dt_boxes, img_path):
|
||||
src_im = cv2.imread(img_path)
|
||||
for box in dt_boxes:
|
||||
|
|
|
@ -0,0 +1,122 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
||||
|
||||
import cv2
|
||||
import json
|
||||
import paddle
|
||||
|
||||
from ppocr.data import create_operators, transform
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.utility import get_image_file_list
|
||||
import tools.program as program
|
||||
|
||||
|
||||
def draw_e2e_res(dt_boxes, strs, config, img, img_name):
|
||||
if len(dt_boxes) > 0:
|
||||
src_im = img
|
||||
for box, str in zip(dt_boxes, strs):
|
||||
box = box.astype(np.int32).reshape((-1, 1, 2))
|
||||
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
|
||||
cv2.putText(
|
||||
src_im,
|
||||
str,
|
||||
org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
|
||||
fontFace=cv2.FONT_HERSHEY_COMPLEX,
|
||||
fontScale=0.7,
|
||||
color=(0, 255, 0),
|
||||
thickness=1)
|
||||
save_det_path = os.path.dirname(config['Global'][
|
||||
'save_res_path']) + "/e2e_results/"
|
||||
if not os.path.exists(save_det_path):
|
||||
os.makedirs(save_det_path)
|
||||
save_path = os.path.join(save_det_path, os.path.basename(img_name))
|
||||
cv2.imwrite(save_path, src_im)
|
||||
logger.info("The e2e Image saved in {}".format(save_path))
|
||||
|
||||
|
||||
def main():
|
||||
global_config = config['Global']
|
||||
|
||||
# build model
|
||||
model = build_model(config['Architecture'])
|
||||
|
||||
init_model(config, model, logger)
|
||||
|
||||
# build post process
|
||||
post_process_class = build_post_process(config['PostProcess'],
|
||||
global_config)
|
||||
|
||||
# create data ops
|
||||
transforms = []
|
||||
for op in config['Eval']['dataset']['transforms']:
|
||||
op_name = list(op)[0]
|
||||
if 'Label' in op_name:
|
||||
continue
|
||||
elif op_name == 'KeepKeys':
|
||||
op[op_name]['keep_keys'] = ['image', 'shape']
|
||||
transforms.append(op)
|
||||
|
||||
ops = create_operators(transforms, global_config)
|
||||
|
||||
save_res_path = config['Global']['save_res_path']
|
||||
if not os.path.exists(os.path.dirname(save_res_path)):
|
||||
os.makedirs(os.path.dirname(save_res_path))
|
||||
|
||||
model.eval()
|
||||
with open(save_res_path, "wb") as fout:
|
||||
for file in get_image_file_list(config['Global']['infer_img']):
|
||||
logger.info("infer_img: {}".format(file))
|
||||
with open(file, 'rb') as f:
|
||||
img = f.read()
|
||||
data = {'image': img}
|
||||
batch = transform(data, ops)
|
||||
images = np.expand_dims(batch[0], axis=0)
|
||||
shape_list = np.expand_dims(batch[1], axis=0)
|
||||
images = paddle.to_tensor(images)
|
||||
preds = model(images)
|
||||
post_result = post_process_class(preds, shape_list)
|
||||
points, strs = post_result['points'], post_result['strs']
|
||||
# write resule
|
||||
dt_boxes_json = []
|
||||
for poly, str in zip(points, strs):
|
||||
tmp_json = {"transcription": str}
|
||||
tmp_json['points'] = poly.tolist()
|
||||
dt_boxes_json.append(tmp_json)
|
||||
otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
|
||||
fout.write(otstr.encode())
|
||||
src_img = cv2.imread(file)
|
||||
draw_e2e_res(points, strs, config, src_img, file)
|
||||
logger.info("success!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config, device, logger, vdl_writer = program.preprocess()
|
||||
main()
|
|
@ -375,7 +375,8 @@ def preprocess(is_train=False):
|
|||
|
||||
alg = config['Architecture']['algorithm']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet'
|
||||
]
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
|
|
Loading…
Reference in New Issue