commit 14fce808ff
@@ -0,0 +1,135 @@
Global:
  use_gpu: true
  epoch_num: 600
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/det_mv3_pse/
  save_epoch_step: 600
  # evaluation is run every 63 iterations
  eval_batch_step: [ 0,63 ]
  cal_metric_during_train: False
  pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
  checkpoints: #./output/det_r50_vd_pse_batch8_ColorJitter/best_accuracy
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_pse/predicts_pse.txt

Architecture:
  model_type: det
  algorithm: PSE
  Transform: null
  Backbone:
    name: MobileNetV3
    scale: 0.5
    model_name: large
  Neck:
    name: FPN
    out_channels: 96
  Head:
    name: PSEHead
    hidden_dim: 96
    out_channels: 7

Loss:
  name: PSELoss
  alpha: 0.7
  ohem_ratio: 3
  kernel_sample_mask: pred
  reduction: none

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Step
    learning_rate: 0.001
    step_size: 200
    gamma: 0.1
  regularizer:
    name: 'L2'
    factor: 0.0005

PostProcess:
  name: PSEPostProcess
  thresh: 0
  box_thresh: 0.85
  min_area: 16
  box_type: box # 'box' or 'poly'
  scale: 1

Metric:
  name: DetMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - ColorJitter:
          brightness: 0.12549019607843137
          saturation: 0.5
      - IaaAugment:
          augmenter_args:
            - { 'type': Resize, 'args': { 'size': [ 0.5, 3 ] } }
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [ -10, 10 ] } }
      - MakePseGt:
          kernel_num: 7
          min_shrink_ratio: 0.4
          size: 640
      - RandomCropImgMask:
          size: [ 640,640 ]
          main_key: gt_text
          crop_keys: [ 'image', 'gt_text', 'gt_kernels', 'mask' ]
      - NormalizeImage:
          scale: 1./255.
          mean: [ 0.485, 0.456, 0.406 ]
          std: [ 0.229, 0.224, 0.225 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'image', 'gt_text', 'gt_kernels', 'mask' ] # the order of the dataloader list
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 16
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
          limit_side_len: 736
          limit_type: min
      - NormalizeImage:
          scale: 1./255.
          mean: [ 0.485, 0.456, 0.406 ]
          std: [ 0.229, 0.224, 0.225 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'image', 'shape', 'polys', 'ignore_tags' ]
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1
    num_workers: 8
@@ -0,0 +1,134 @@
Global:
  use_gpu: true
  epoch_num: 600
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/det_r50_vd_pse/
  save_epoch_step: 600
  # evaluation is run every 125 iterations
  eval_batch_step: [ 0,125 ]
  cal_metric_during_train: False
  pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained
  checkpoints: #./output/det_r50_vd_pse_batch8_ColorJitter/best_accuracy
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_pse/predicts_pse.txt

Architecture:
  model_type: det
  algorithm: PSE
  Transform:
  Backbone:
    name: ResNet
    layers: 50
  Neck:
    name: FPN
    out_channels: 256
  Head:
    name: PSEHead
    hidden_dim: 256
    out_channels: 7

Loss:
  name: PSELoss
  alpha: 0.7
  ohem_ratio: 3
  kernel_sample_mask: pred
  reduction: none

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Step
    learning_rate: 0.0001
    step_size: 200
    gamma: 0.1
  regularizer:
    name: 'L2'
    factor: 0.0005

PostProcess:
  name: PSEPostProcess
  thresh: 0
  box_thresh: 0.85
  min_area: 16
  box_type: box # 'box' or 'poly'
  scale: 1

Metric:
  name: DetMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - ColorJitter:
          brightness: 0.12549019607843137
          saturation: 0.5
      - IaaAugment:
          augmenter_args:
            - { 'type': Resize, 'args': { 'size': [ 0.5, 3 ] } }
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [ -10, 10 ] } }
      - MakePseGt:
          kernel_num: 7
          min_shrink_ratio: 0.4
          size: 640
      - RandomCropImgMask:
          size: [ 640,640 ]
          main_key: gt_text
          crop_keys: [ 'image', 'gt_text', 'gt_kernels', 'mask' ]
      - NormalizeImage:
          scale: 1./255.
          mean: [ 0.485, 0.456, 0.406 ]
          std: [ 0.229, 0.224, 0.225 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'image', 'gt_text', 'gt_kernels', 'mask' ] # the order of the dataloader list
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
          limit_side_len: 736
          limit_type: min
      - NormalizeImage:
          scale: 1./255.
          mean: [ 0.485, 0.456, 0.406 ]
          std: [ 0.229, 0.224, 0.225 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'image', 'shape', 'polys', 'ignore_tags' ]
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1
    num_workers: 8
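The `eval_batch_step` comments in both configs follow from the dataset size: assuming single-card training on the 1000 ICDAR2015 training images referenced by `train_icdar2015_label.txt`, one epoch is roughly one evaluation interval.

```python
import math

# 1000 ICDAR2015 training images / batch_size_per_card
print(math.ceil(1000 / 16))  # 63  -> eval_batch_step: [ 0,63 ]  (det_mv3_pse)
print(math.ceil(1000 / 8))   # 125 -> eval_batch_step: [ 0,125 ] (det_r50_vd_pse)
```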
@@ -9,11 +9,13 @@
 ### 1. Text Detection Algorithms

 List of text detection algorithms open-sourced by PaddleOCR:
-- [x] DB([paper]( https://arxiv.org/abs/1911.08947)) [2](recommended by ppocr)
-- [x] EAST([paper](https://arxiv.org/abs/1704.03155))[1]
-- [x] SAST([paper](https://arxiv.org/abs/1908.05498))[4]
+- [x] DB([paper]( https://arxiv.org/abs/1911.08947))(recommended by ppocr)
+- [x] EAST([paper](https://arxiv.org/abs/1704.03155))
+- [x] SAST([paper](https://arxiv.org/abs/1908.05498))
+- [x] PSENet([paper](https://arxiv.org/abs/1903.12473v2))

 On the ICDAR2015 public text detection dataset, the results are as follows:

 |Model|Backbone|precision|recall|Hmean|Download link|
 | --- | --- | --- | --- | --- | --- |
 |EAST|ResNet50_vd|85.80%|86.71%|86.25%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|

@@ -21,6 +23,8 @@ PaddleOCR开源的文本检测算法列表:
 |DB|ResNet50_vd|86.41%|78.72%|82.38%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
 |DB|MobileNetV3|77.29%|73.08%|75.12%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
 |SAST|ResNet50_vd|91.39%|83.77%|87.42%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
+|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)|
+|PSE|MobileNetV3|82.20%|70.48%|75.89%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|

 On the Total-Text public text detection dataset, the results are as follows:

@@ -39,15 +43,15 @@ PaddleOCR文本检测算法的训练和使用请参考文档教程中[模型训
 ### 2. Text Recognition Algorithms

 List of text recognition algorithms open-sourced by PaddleOCR (dynamic graph):
-- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7](recommended by ppocr)
-- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
-- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
-- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
-- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
+- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))(recommended by ppocr)
+- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
+- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
+- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))
+- [x] SRN([paper](https://arxiv.org/abs/2003.12294))
 - [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))
 - [x] SAR([paper](https://arxiv.org/abs/1811.00751v2))

-Following the [DTRB][3](https://arxiv.org/abs/1904.01906) text recognition training and evaluation protocol, models are trained on the MJSynth and SynthText datasets and evaluated on the IIIT, SVT, IC03, IC13, IC15, SVTP and CUTE datasets; the results are as follows:
+Following the [DTRB](https://arxiv.org/abs/1904.01906) text recognition training and evaluation protocol, models are trained on the MJSynth and SynthText datasets and evaluated on the IIIT, SVT, IC03, IC13, IC15, SVTP and CUTE datasets; the results are as follows:

 |Model|Backbone|Avg Accuracy|Model storage name|Download link|
 |---|---|---|---|---|
@@ -11,9 +11,10 @@ This tutorial lists the text detection algorithms and text recognition algorithm
 ### 1. Text Detection Algorithm

 PaddleOCR open source text detection algorithms list:
-- [x] EAST([paper](https://arxiv.org/abs/1704.03155))[2]
-- [x] DB([paper](https://arxiv.org/abs/1911.08947))[1]
-- [x] SAST([paper](https://arxiv.org/abs/1908.05498))[4]
+- [x] EAST([paper](https://arxiv.org/abs/1704.03155))
+- [x] DB([paper](https://arxiv.org/abs/1911.08947))
+- [x] SAST([paper](https://arxiv.org/abs/1908.05498))
+- [x] PSE([paper](https://arxiv.org/abs/1903.12473v2))

 On the ICDAR2015 dataset, the text detection result is as follows:

@@ -24,6 +25,8 @@ On the ICDAR2015 dataset, the text detection result is as follows:
 |DB|ResNet50_vd|86.41%|78.72%|82.38%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
 |DB|MobileNetV3|77.29%|73.08%|75.12%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
 |SAST|ResNet50_vd|91.39%|83.77%|87.42%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
+|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)|
+|PSE|MobileNetV3|82.20%|70.48%|75.89%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|

 On Total-Text dataset, the text detection result is as follows:

@@ -41,11 +44,11 @@ For the training guide and use of PaddleOCR text detection algorithms, please re
 ### 2. Text Recognition Algorithm

 PaddleOCR open-source text recognition algorithms list:
-- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7]
-- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
-- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
-- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
-- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
+- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))
+- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
+- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
+- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))
+- [x] SRN([paper](https://arxiv.org/abs/2003.12294))
 - [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))
 - [x] SAR([paper](https://arxiv.org/abs/1811.00751v2))

@@ -0,0 +1,26 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.vision.transforms import ColorJitter as pp_ColorJitter

__all__ = ['ColorJitter']


class ColorJitter(object):
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, **kwargs):
        self.aug = pp_ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, data):
        image = data['image']
        image = self.aug(image)
        data['image'] = image
        return data
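A minimal usage sketch of this wrapper (not part of the commit): it assumes an HWC uint8 image, as produced by `DecodeImage` earlier in the pipeline, and that `paddle.vision.transforms.ColorJitter` accepts an ndarray directly.

```python
import numpy as np
from ppocr.data.imaug import ColorJitter  # exported by the __init__.py change below

jitter = ColorJitter(brightness=0.125, saturation=0.5)
data = {'image': np.random.randint(0, 255, (640, 640, 3), dtype=np.uint8)}
data = jitter(data)          # only the 'image' entry is modified
print(data['image'].shape)   # (640, 640, 3)
```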
@@ -19,11 +19,13 @@ from __future__ import unicode_literals
 from .iaa_augment import IaaAugment
 from .make_border_map import MakeBorderMap
 from .make_shrink_map import MakeShrinkMap
-from .random_crop_data import EastRandomCropData, PSERandomCrop
+from .random_crop_data import EastRandomCropData, RandomCropImgMask
+from .make_pse_gt import MakePseGt

 from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg
 from .randaugment import RandAugment
 from .copy_paste import CopyPaste
+from .ColorJitter import ColorJitter
 from .operators import *
 from .label_ops import *

@@ -0,0 +1,85 @@
# -*- coding:utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import cv2
import numpy as np
import pyclipper
from shapely.geometry import Polygon

__all__ = ['MakePseGt']


class MakePseGt(object):
    r'''
    Making binary mask from detection data with ICDAR format.
    Typically following the process of class `MakeICDARData`.
    '''

    def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs):
        self.kernel_num = kernel_num
        self.min_shrink_ratio = min_shrink_ratio
        self.size = size

    def __call__(self, data):

        image = data['image']
        text_polys = data['polys']
        ignore_tags = data['ignore_tags']

        h, w, _ = image.shape
        short_edge = min(h, w)
        if short_edge < self.size:
            # keep short_size >= self.size
            scale = self.size / short_edge
            image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
            text_polys *= scale

        gt_kernels = []
        for i in range(1, self.kernel_num + 1):
            # s1->sn, from big to small
            rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1) * i
            text_kernel, ignore_tags = self.generate_kernel(
                image.shape[0:2], rate, text_polys, ignore_tags)
            gt_kernels.append(text_kernel)

        training_mask = np.ones(image.shape[0:2], dtype='uint8')
        for i in range(text_polys.shape[0]):
            if ignore_tags[i]:
                cv2.fillPoly(training_mask,
                             text_polys[i].astype(np.int32)[np.newaxis, :, :], 0)

        gt_kernels = np.array(gt_kernels)
        gt_kernels[gt_kernels > 0] = 1

        data['image'] = image
        data['polys'] = text_polys
        data['gt_kernels'] = gt_kernels[0:]
        data['gt_text'] = gt_kernels[0]
        data['mask'] = training_mask.astype('float32')
        return data

    def generate_kernel(self, img_size, shrink_ratio, text_polys, ignore_tags=None):
        h, w = img_size
        text_kernel = np.zeros((h, w), dtype=np.float32)
        for i, poly in enumerate(text_polys):
            polygon = Polygon(poly)
            distance = polygon.area * (1 - shrink_ratio * shrink_ratio) / (
                polygon.length + 1e-6)
            subject = [tuple(l) for l in poly]
            pco = pyclipper.PyclipperOffset()
            pco.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
            shrinked = np.array(pco.Execute(-distance))

            if len(shrinked) == 0 or shrinked.size == 0:
                if ignore_tags is not None:
                    ignore_tags[i] = True
                continue
            try:
                shrinked = np.array(shrinked[0]).reshape(-1, 2)
            except:
                if ignore_tags is not None:
                    ignore_tags[i] = True
                continue
            cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1)
        return text_kernel, ignore_tags
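For reference, the shrink rates generated by the loop in `__call__` can be tabulated directly; with the defaults used by the PSE configs (`kernel_num: 7`, `min_shrink_ratio: 0.4`) the kernels shrink from 0.9 down to 0.3 of the original polygon:

```python
kernel_num, min_shrink_ratio = 7, 0.4
rates = [1.0 - (1.0 - min_shrink_ratio) / (kernel_num - 1) * i
         for i in range(1, kernel_num + 1)]
print([round(r, 2) for r in rates])  # [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3]
```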
@@ -164,47 +164,55 @@ class EastRandomCropData(object):
         return data


-class PSERandomCrop(object):
-    def __init__(self, size, **kwargs):
+class RandomCropImgMask(object):
+    def __init__(self, size, main_key, crop_keys, p=3 / 8, **kwargs):
         self.size = size
+        self.main_key = main_key
+        self.crop_keys = crop_keys
+        self.p = p

     def __call__(self, data):
-        imgs = data['imgs']
+        image = data['image']

-        h, w = imgs[0].shape[0:2]
+        h, w = image.shape[0:2]
         th, tw = self.size
         if w == tw and h == th:
-            return imgs
+            return data

-        # the label contains text instances; crop with a given probability, controlled via threshold_label_map
-        if np.max(imgs[2]) > 0 and random.random() > 3 / 8:
-            # top-left corner of the text instances
-            tl = np.min(np.where(imgs[2] > 0), axis=1) - self.size
+        mask = data[self.main_key]
+        if np.max(mask) > 0 and random.random() > self.p:
+            # make sure to crop the text region
+            tl = np.min(np.where(mask > 0), axis=1) - (th, tw)
             tl[tl < 0] = 0
-            # bottom-right corner of the text instances
-            br = np.max(np.where(imgs[2] > 0), axis=1) - self.size
+            br = np.max(np.where(mask > 0), axis=1) - (th, tw)
             br[br < 0] = 0
-            # make sure there is enough room to crop when the bottom-right corner is picked
+
             br[0] = min(br[0], h - th)
             br[1] = min(br[1], w - tw)

-            for _ in range(50000):
-                i = random.randint(tl[0], br[0])
-                j = random.randint(tl[1], br[1])
-                # make sure the shrink_label_map contains text
-                if imgs[1][i:i + th, j:j + tw].sum() <= 0:
-                    continue
-                else:
-                    break
+            i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
+            j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
         else:
-            i = random.randint(0, h - th)
-            j = random.randint(0, w - tw)
+            i = random.randint(0, h - th) if h - th > 0 else 0
+            j = random.randint(0, w - tw) if w - tw > 0 else 0

-        # return i, j, th, tw
-        for idx in range(len(imgs)):
-            if len(imgs[idx].shape) == 3:
-                imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
-            else:
-                imgs[idx] = imgs[idx][i:i + th, j:j + tw]
-        data['imgs'] = imgs
+        for k in data:
+            if k in self.crop_keys:
+                if len(data[k].shape) == 3:
+                    if np.argmin(data[k].shape) == 0:
+                        img = data[k][:, i:i + th, j:j + tw]
+                        if img.shape[1] != img.shape[2]:
+                            a = 1
+                    elif np.argmin(data[k].shape) == 2:
+                        img = data[k][i:i + th, j:j + tw, :]
+                        if img.shape[1] != img.shape[0]:
+                            a = 1
+                    else:
+                        img = data[k]
+                else:
+                    img = data[k][i:i + th, j:j + tw]
+                    if img.shape[0] != img.shape[1]:
+                        a = 1
+                data[k] = img
         return data
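A small sketch (not part of the commit) of how `RandomCropImgMask` consumes the keys produced by `MakePseGt`; the array shapes below are illustrative:

```python
import numpy as np
from ppocr.data.imaug import RandomCropImgMask

crop = RandomCropImgMask(
    size=[640, 640], main_key='gt_text',
    crop_keys=['image', 'gt_text', 'gt_kernels', 'mask'])
data = {
    'image': np.zeros((736, 1280, 3), dtype=np.float32),      # HWC
    'gt_text': np.zeros((736, 1280), dtype=np.float32),
    'gt_kernels': np.zeros((7, 736, 1280), dtype=np.float32),  # kernel stack, CHW
    'mask': np.ones((736, 1280), dtype=np.float32),
}
data = crop(data)
print(data['image'].shape, data['gt_kernels'].shape)  # (640, 640, 3) (7, 640, 640)
```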
@@ -20,6 +20,7 @@ import paddle.nn as nn
 from .det_db_loss import DBLoss
 from .det_east_loss import EASTLoss
 from .det_sast_loss import SASTLoss
+from .det_pse_loss import PSELoss

 # rec loss
 from .rec_ctc_loss import CTCLoss

@@ -42,10 +43,12 @@ from .combined_loss import CombinedLoss
 # table loss
 from .table_att_loss import TableAttentionLoss


 def build_loss(config):
     support_dict = [
-        'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
-        'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss', 'SARLoss'
+        'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
+        'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
+        'TableAttentionLoss', 'SARLoss'
     ]

     config = copy.deepcopy(config)
@@ -75,12 +75,6 @@ class BalanceLoss(nn.Layer):
             mask (variable): masked maps.
         return: (variable) balanced loss
         """
-        # if self.main_loss_type in ['DiceLoss']:
-        #     # For the loss that returns to scalar value, perform ohem on the mask
-        #     mask = ohem_batch(pred, gt, mask, self.negative_ratio)
-        #     loss = self.loss(pred, gt, mask)
-        #     return loss
-
         positive = gt * mask
         negative = (1 - gt) * mask

@@ -153,53 +147,4 @@ class BCELoss(nn.Layer):

     def forward(self, input, label, mask=None, weight=None, name=None):
         loss = F.binary_cross_entropy(input, label, reduction=self.reduction)
-        return loss
-
-
-def ohem_single(score, gt_text, training_mask, ohem_ratio):
-    pos_num = (int)(np.sum(gt_text > 0.5)) - (
-        int)(np.sum((gt_text > 0.5) & (training_mask <= 0.5)))
-
-    if pos_num == 0:
-        # selected_mask = gt_text.copy() * 0 # may be not good
-        selected_mask = training_mask
-        selected_mask = selected_mask.reshape(
-            1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
-        return selected_mask
-
-    neg_num = (int)(np.sum(gt_text <= 0.5))
-    neg_num = (int)(min(pos_num * ohem_ratio, neg_num))
-
-    if neg_num == 0:
-        selected_mask = training_mask
-        selected_mask = selected_mask.reshape(
-            1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
-        return selected_mask
-
-    neg_score = score[gt_text <= 0.5]
-    # sort the negative sample scores from high to low
-    neg_score_sorted = np.sort(-neg_score)
-    threshold = -neg_score_sorted[neg_num - 1]
-    # select the mask of high-scoring negative samples and all positive samples
-    selected_mask = ((score >= threshold) |
-                     (gt_text > 0.5)) & (training_mask > 0.5)
-    selected_mask = selected_mask.reshape(
-        1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
-    return selected_mask
-
-
-def ohem_batch(scores, gt_texts, training_masks, ohem_ratio):
-    scores = scores.numpy()
-    gt_texts = gt_texts.numpy()
-    training_masks = training_masks.numpy()
-
-    selected_masks = []
-    for i in range(scores.shape[0]):
-        selected_masks.append(
-            ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[
-                i, :, :], ohem_ratio))
-
-    selected_masks = np.concatenate(selected_masks, 0)
-    selected_masks = paddle.to_tensor(selected_masks)
-
-    return selected_masks
+        return loss
@@ -0,0 +1,145 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import nn
from paddle.nn import functional as F
import numpy as np
from ppocr.utils.iou import iou


class PSELoss(nn.Layer):
    def __init__(self,
                 alpha,
                 ohem_ratio=3,
                 kernel_sample_mask='pred',
                 reduction='sum',
                 eps=1e-6,
                 **kwargs):
        """Implement PSE Loss.
        """
        super(PSELoss, self).__init__()
        assert reduction in ['sum', 'mean', 'none']
        self.alpha = alpha
        self.ohem_ratio = ohem_ratio
        self.kernel_sample_mask = kernel_sample_mask
        self.reduction = reduction
        self.eps = eps

    def forward(self, outputs, labels):
        predicts = outputs['maps']
        predicts = F.interpolate(predicts, scale_factor=4)

        texts = predicts[:, 0, :, :]
        kernels = predicts[:, 1:, :, :]
        gt_texts, gt_kernels, training_masks = labels[1:]

        # text loss
        selected_masks = self.ohem_batch(texts, gt_texts, training_masks)

        loss_text = self.dice_loss(texts, gt_texts, selected_masks)
        iou_text = iou((texts > 0).astype('int64'),
                       gt_texts,
                       training_masks,
                       reduce=False)
        losses = dict(loss_text=loss_text, iou_text=iou_text)

        # kernel loss
        loss_kernels = []
        if self.kernel_sample_mask == 'gt':
            selected_masks = gt_texts * training_masks
        elif self.kernel_sample_mask == 'pred':
            selected_masks = (
                F.sigmoid(texts) > 0.5).astype('float32') * training_masks

        for i in range(kernels.shape[1]):
            kernel_i = kernels[:, i, :, :]
            gt_kernel_i = gt_kernels[:, i, :, :]
            loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i,
                                           selected_masks)
            loss_kernels.append(loss_kernel_i)
        loss_kernels = paddle.mean(paddle.stack(loss_kernels, axis=1), axis=1)
        iou_kernel = iou((kernels[:, -1, :, :] > 0).astype('int64'),
                         gt_kernels[:, -1, :, :],
                         training_masks * gt_texts,
                         reduce=False)
        losses.update(dict(loss_kernels=loss_kernels, iou_kernel=iou_kernel))

        loss = self.alpha * loss_text + (1 - self.alpha) * loss_kernels
        losses['loss'] = loss
        if self.reduction == 'sum':
            losses = {x: paddle.sum(v) for x, v in losses.items()}
        elif self.reduction == 'mean':
            losses = {x: paddle.mean(v) for x, v in losses.items()}
        return losses

    def dice_loss(self, input, target, mask):
        input = F.sigmoid(input)

        input = input.reshape([input.shape[0], -1])
        target = target.reshape([target.shape[0], -1])
        mask = mask.reshape([mask.shape[0], -1])

        input = input * mask
        target = target * mask

        a = paddle.sum(input * target, 1)
        b = paddle.sum(input * input, 1) + self.eps
        c = paddle.sum(target * target, 1) + self.eps
        d = (2 * a) / (b + c)
        return 1 - d

    def ohem_single(self, score, gt_text, training_mask, ohem_ratio=3):
        pos_num = int(paddle.sum((gt_text > 0.5).astype('float32'))) - int(
            paddle.sum(
                paddle.logical_and((gt_text > 0.5), (training_mask <= 0.5))
                .astype('float32')))

        if pos_num == 0:
            selected_mask = training_mask
            selected_mask = selected_mask.reshape(
                [1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
                    'float32')
            return selected_mask

        neg_num = int(paddle.sum((gt_text <= 0.5).astype('float32')))
        neg_num = int(min(pos_num * ohem_ratio, neg_num))

        if neg_num == 0:
            selected_mask = training_mask
            selected_mask = selected_mask.reshape(
                [1, selected_mask.shape[0],
                 selected_mask.shape[1]]).astype('float32')
            return selected_mask

        neg_score = paddle.masked_select(score, gt_text <= 0.5)
        neg_score_sorted = paddle.sort(-neg_score)
        threshold = -neg_score_sorted[neg_num - 1]

        selected_mask = paddle.logical_and(
            paddle.logical_or((score >= threshold), (gt_text > 0.5)),
            (training_mask > 0.5))
        selected_mask = selected_mask.reshape(
            [1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
                'float32')
        return selected_mask

    def ohem_batch(self, scores, gt_texts, training_masks, ohem_ratio=3):
        selected_masks = []
        for i in range(scores.shape[0]):
            selected_masks.append(
                self.ohem_single(scores[i, :, :], gt_texts[i, :, :],
                                 training_masks[i, :, :], ohem_ratio))

        selected_masks = paddle.concat(selected_masks, 0).astype('float32')
        return selected_masks
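The dice term used above is easy to check by hand. A standalone NumPy version of `dice_loss` for a single flattened sample, with the mask left out (all ones), gives a small loss for a good prediction:

```python
import numpy as np

def dice_loss_np(prob, target, eps=1e-6):
    # prob: sigmoid-activated prediction; target: {0, 1} labels
    a = np.sum(prob * target)
    b = np.sum(prob * prob) + eps
    c = np.sum(target * target) + eps
    return 1 - (2 * a) / (b + c)

prob = np.array([0.9, 0.8, 0.1, 0.2])
target = np.array([1.0, 1.0, 0.0, 0.0])
print(dice_loss_np(prob, target))  # ~0.029
```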
@@ -169,21 +169,10 @@ class DetectionIoUEvaluator(object):
            numGlobalCareDet += numDetCare

            perSampleMetrics = {
                'precision': precision,
                'recall': recall,
                'hmean': hmean,
                'pairs': pairs,
                'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
                'gtPolPoints': gtPolPoints,
                'detPolPoints': detPolPoints,
                'gtCare': numGtCare,
                'detCare': numDetCare,
                'gtDontCare': gtDontCarePolsNum,
                'detDontCare': detDontCarePolsNum,
                'detMatched': detMatched,
                'evaluationLog': evaluationLog
            }

            return perSampleMetrics

    def combine_results(self, results):

@@ -20,6 +20,7 @@ def build_head(config):
     from .det_db_head import DBHead
     from .det_east_head import EASTHead
     from .det_sast_head import SASTHead
+    from .det_pse_head import PSEHead
     from .e2e_pg_head import PGHead

     # rec head

@@ -32,8 +33,9 @@ def build_head(config):
     # cls head
     from .cls_head import ClsHead
     support_dict = [
-        'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
-        'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead', 'SARHead'
+        'DBHead', 'PSEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead',
+        'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
+        'TableAttentionHead', 'SARHead'
     ]

     #table head
@@ -0,0 +1,35 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import nn


class PSEHead(nn.Layer):
    def __init__(self,
                 in_channels,
                 hidden_dim=256,
                 out_channels=7,
                 **kwargs):
        super(PSEHead, self).__init__()
        self.conv1 = nn.Conv2D(in_channels, hidden_dim, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2D(hidden_dim)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2D(hidden_dim, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, **kwargs):
        out = self.conv1(x)
        out = self.relu1(self.bn1(out))
        out = self.conv2(out)
        return {'maps': out}
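A shape sketch for the head (not part of the commit), using the widths from the det_mv3_pse config: the FPN below outputs 4 × 96 = 384 channels, and the spatial size assumes a 640 × 640 input whose first backbone feature map is at stride 4:

```python
import paddle
from ppocr.modeling.heads.det_pse_head import PSEHead

head = PSEHead(in_channels=384, hidden_dim=96, out_channels=7)
x = paddle.rand([1, 384, 160, 160])   # fused FPN feature map
print(head(x)['maps'].shape)          # [1, 7, 160, 160]
```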
@@ -22,7 +22,8 @@ def build_neck(config):
     from .rnn import SequenceEncoder
     from .pg_fpn import PGFPN
     from .table_fpn import TableFPN
-    support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
+    from .fpn import FPN
+    support_dict = ['FPN', 'DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']

     module_name = config.pop('name')
     assert module_name in support_dict, Exception('neck only support {}'.format(
@@ -0,0 +1,100 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.nn as nn
import paddle
import math
import paddle.nn.functional as F


class Conv_BN_ReLU(nn.Layer):
    def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0):
        super(Conv_BN_ReLU, self).__init__()
        self.conv = nn.Conv2D(in_planes, out_planes, kernel_size=kernel_size,
                              stride=stride, padding=padding, bias_attr=False)
        self.bn = nn.BatchNorm2D(out_planes, momentum=0.1)
        self.relu = nn.ReLU()

        for m in self.sublayers():
            if isinstance(m, nn.Conv2D):
                n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
                m.weight = paddle.create_parameter(
                    shape=m.weight.shape, dtype='float32',
                    default_initializer=paddle.nn.initializer.Normal(0, math.sqrt(2. / n)))
            elif isinstance(m, nn.BatchNorm2D):
                m.weight = paddle.create_parameter(
                    shape=m.weight.shape, dtype='float32',
                    default_initializer=paddle.nn.initializer.Constant(1.0))
                m.bias = paddle.create_parameter(
                    shape=m.bias.shape, dtype='float32',
                    default_initializer=paddle.nn.initializer.Constant(0.0))

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class FPN(nn.Layer):
    def __init__(self, in_channels, out_channels):
        super(FPN, self).__init__()

        # Top layer
        self.toplayer_ = Conv_BN_ReLU(in_channels[3], out_channels, kernel_size=1, stride=1, padding=0)
        # Lateral layers
        self.latlayer1_ = Conv_BN_ReLU(in_channels[2], out_channels, kernel_size=1, stride=1, padding=0)

        self.latlayer2_ = Conv_BN_ReLU(in_channels[1], out_channels, kernel_size=1, stride=1, padding=0)

        self.latlayer3_ = Conv_BN_ReLU(in_channels[0], out_channels, kernel_size=1, stride=1, padding=0)

        # Smooth layers
        self.smooth1_ = Conv_BN_ReLU(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.smooth2_ = Conv_BN_ReLU(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.smooth3_ = Conv_BN_ReLU(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.out_channels = out_channels * 4
        for m in self.sublayers():
            if isinstance(m, nn.Conv2D):
                n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
                m.weight = paddle.create_parameter(
                    shape=m.weight.shape, dtype='float32',
                    default_initializer=paddle.nn.initializer.Normal(0, math.sqrt(2. / n)))
            elif isinstance(m, nn.BatchNorm2D):
                m.weight = paddle.create_parameter(
                    shape=m.weight.shape, dtype='float32',
                    default_initializer=paddle.nn.initializer.Constant(1.0))
                m.bias = paddle.create_parameter(
                    shape=m.bias.shape, dtype='float32',
                    default_initializer=paddle.nn.initializer.Constant(0.0))

    def _upsample(self, x, scale=1):
        return F.upsample(x, scale_factor=scale, mode='bilinear')

    def _upsample_add(self, x, y, scale=1):
        return F.upsample(x, scale_factor=scale, mode='bilinear') + y

    def forward(self, x):
        f2, f3, f4, f5 = x
        p5 = self.toplayer_(f5)

        f4 = self.latlayer1_(f4)
        p4 = self._upsample_add(p5, f4, 2)
        p4 = self.smooth1_(p4)

        f3 = self.latlayer2_(f3)
        p3 = self._upsample_add(p4, f3, 2)
        p3 = self.smooth2_(p3)

        f2 = self.latlayer3_(f2)
        p2 = self._upsample_add(p3, f2, 2)
        p2 = self.smooth3_(p2)

        p3 = self._upsample(p3, 2)
        p4 = self._upsample(p4, 4)
        p5 = self._upsample(p5, 8)

        fuse = paddle.concat([p2, p3, p4, p5], axis=1)
        return fuse
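A wiring sketch for the FPN above (not part of the commit). The backbone channel counts here are purely illustrative placeholders; the real values come from the MobileNetV3 / ResNet backbones. The four inputs sit at strides 4, 8, 16 and 32 of a 640 × 640 image, and the fused output is 4 × out_channels at stride 4:

```python
import paddle
from ppocr.modeling.necks.fpn import FPN

in_channels = [16, 24, 56, 480]       # hypothetical backbone channels
neck = FPN(in_channels=in_channels, out_channels=96)
feats = [paddle.rand([1, c, 640 // s, 640 // s])
         for c, s in zip(in_channels, [4, 8, 16, 32])]
print(neck(feats).shape, neck.out_channels)  # [1, 384, 160, 160] 384
```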
@@ -28,12 +28,14 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
     TableLabelDecode, SARLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
+from .pse_postprocess import PSEPostProcess


 def build_post_process(config, global_config=None):
     support_dict = [
-        'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
-        'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
-        'DistillationCTCLabelDecode', 'TableLabelDecode',
+        'DBPostProcess', 'PSEPostProcess', 'EASTPostProcess', 'SASTPostProcess',
+        'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
+        'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
         'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode'
     ]

@@ -0,0 +1,15 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .pse_postprocess import PSEPostProcess
@@ -0,0 +1,5 @@
## Build
Code from https://github.com/whai362/pan_pp.pytorch
```python
python3 setup.py build_ext --inplace
```
@@ -0,0 +1,23 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import subprocess

python_path = sys.executable

if subprocess.call('cd ppocr/postprocess/pse_postprocess/pse;{} setup.py build_ext --inplace;cd -'.format(python_path), shell=True) != 0:
    raise RuntimeError('Cannot compile pse: {}'.format(os.path.dirname(os.path.realpath(__file__))))

from .pse import pse
@@ -0,0 +1,70 @@

import numpy as np
import cv2
cimport numpy as np
cimport cython
cimport libcpp
cimport libcpp.pair
cimport libcpp.queue
from libcpp.pair cimport *
from libcpp.queue cimport *

@cython.boundscheck(False)
@cython.wraparound(False)
cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels,
                                         np.ndarray[np.int32_t, ndim=2] label,
                                         int kernel_num,
                                         int label_num,
                                         float min_area=0):
    cdef np.ndarray[np.int32_t, ndim=2] pred
    pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32)

    for label_idx in range(1, label_num):
        if np.sum(label == label_idx) < min_area:
            label[label == label_idx] = 0

    cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \
        queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
    cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \
        queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
    cdef np.int16_t* dx = [-1, 1, 0, 0]
    cdef np.int16_t* dy = [0, 0, -1, 1]
    cdef np.int16_t tmpx, tmpy

    points = np.array(np.where(label > 0)).transpose((1, 0))
    for point_idx in range(points.shape[0]):
        tmpx, tmpy = points[point_idx, 0], points[point_idx, 1]
        que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
        pred[tmpx, tmpy] = label[tmpx, tmpy]

    cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur
    cdef int cur_label
    for kernel_idx in range(kernel_num - 1, -1, -1):
        while not que.empty():
            cur = que.front()
            que.pop()
            cur_label = pred[cur.first, cur.second]

            is_edge = True
            for j in range(4):
                tmpx = cur.first + dx[j]
                tmpy = cur.second + dy[j]
                if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]:
                    continue
                if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
                    continue

                que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
                pred[tmpx, tmpy] = cur_label
                is_edge = False
            if is_edge:
                nxt_que.push(cur)

        que, nxt_que = nxt_que, que

    return pred

def pse(kernels, min_area):
    kernel_num = kernels.shape[0]
    label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4)
    return _pse(kernels[:-1], label, kernel_num, label_num, min_area)
@@ -0,0 +1,14 @@
from distutils.core import setup, Extension
from Cython.Build import cythonize
import numpy

setup(ext_modules=cythonize(Extension(
    'pse',
    sources=['pse.pyx'],
    language='c++',
    include_dirs=[numpy.get_include()],
    library_dirs=[],
    libraries=[],
    extra_compile_args=['-O3'],
    extra_link_args=[]
)))
@@ -0,0 +1,112 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import cv2
import paddle
from paddle.nn import functional as F

from ppocr.postprocess.pse_postprocess.pse import pse


class PSEPostProcess(object):
    """
    The post process for PSE.
    """

    def __init__(self,
                 thresh=0.5,
                 box_thresh=0.85,
                 min_area=16,
                 box_type='box',
                 scale=4,
                 **kwargs):
        assert box_type in ['box', 'poly'], 'Only box and poly is supported'
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.min_area = min_area
        self.box_type = box_type
        self.scale = scale

    def __call__(self, outs_dict, shape_list):
        pred = outs_dict['maps']
        if not isinstance(pred, paddle.Tensor):
            pred = paddle.to_tensor(pred)
        pred = F.interpolate(pred, scale_factor=4 // self.scale, mode='bilinear')

        score = F.sigmoid(pred[:, 0, :, :])

        kernels = (pred > self.thresh).astype('float32')
        text_mask = kernels[:, 0, :, :]
        kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask

        score = score.numpy()
        kernels = kernels.numpy().astype(np.uint8)

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            boxes, scores = self.boxes_from_bitmap(score[batch_index],
                                                   kernels[batch_index],
                                                   shape_list[batch_index])

            boxes_batch.append({'points': boxes, 'scores': scores})
        return boxes_batch

    def boxes_from_bitmap(self, score, kernels, shape):
        label = pse(kernels, self.min_area)
        return self.generate_box(score, label, shape)

    def generate_box(self, score, label, shape):
        src_h, src_w, ratio_h, ratio_w = shape
        label_num = np.max(label) + 1

        boxes = []
        scores = []
        for i in range(1, label_num):
            ind = label == i
            points = np.array(np.where(ind)).transpose((1, 0))[:, ::-1]

            if points.shape[0] < self.min_area:
                label[ind] = 0
                continue

            score_i = np.mean(score[ind])
            if score_i < self.box_thresh:
                label[ind] = 0
                continue

            if self.box_type == 'box':
                rect = cv2.minAreaRect(points)
                bbox = cv2.boxPoints(rect)
            elif self.box_type == 'poly':
                box_height = np.max(points[:, 1]) + 10
                box_width = np.max(points[:, 0]) + 10

                mask = np.zeros((box_height, box_width), np.uint8)
                mask[points[:, 1], points[:, 0]] = 255

                contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                bbox = np.squeeze(contours[0], 1)
            else:
                raise NotImplementedError

            bbox[:, 0] = np.clip(
                np.round(bbox[:, 0] / ratio_w), 0, src_w)
            bbox[:, 1] = np.clip(
                np.round(bbox[:, 1] / ratio_h), 0, src_h)
            boxes.append(bbox)
            scores.append(score_i)
        return boxes, scores
@@ -0,0 +1,48 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

EPS = 1e-6

def iou_single(a, b, mask, n_class):
    valid = mask == 1
    a = a.masked_select(valid)
    b = b.masked_select(valid)
    miou = []
    for i in range(n_class):
        if a.shape == [0] and a.shape == b.shape:
            inter = paddle.to_tensor(0.0)
            union = paddle.to_tensor(0.0)
        else:
            inter = ((a == i).logical_and(b == i)).astype('float32')
            union = ((a == i).logical_or(b == i)).astype('float32')
        miou.append(paddle.sum(inter) / (paddle.sum(union) + EPS))
    miou = sum(miou) / len(miou)
    return miou

def iou(a, b, mask, n_class=2, reduce=True):
    batch_size = a.shape[0]

    a = a.reshape([batch_size, -1])
    b = b.reshape([batch_size, -1])
    mask = mask.reshape([batch_size, -1])

    iou = paddle.zeros((batch_size, ), dtype='float32')
    for i in range(batch_size):
        iou[i] = iou_single(a[i], b[i], mask[i], n_class)

    if reduce:
        iou = paddle.mean(iou)
    return iou
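A toy check of the helper (not part of the commit): a 1 × 2 × 3 prediction and label with every pixel valid. One of the two foreground pixels is matched, so the per-class IoUs are 0.8 (background) and 0.5 (text), giving a mean of about 0.65:

```python
import paddle
from ppocr.utils.iou import iou

pred = paddle.to_tensor([[[1, 1, 0], [0, 0, 0]]], dtype='int64')
gt = paddle.to_tensor([[[1, 0, 0], [0, 0, 0]]], dtype='int64')
mask = paddle.ones([1, 2, 3], dtype='float32')
print(iou(pred, gt, mask))  # ~0.65
```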
@@ -8,6 +8,7 @@ numpy
 visualdl
 python-Levenshtein
 opencv-contrib-python==4.4.0.46
+cython
 lxml
 premailer
 openpyxl
@@ -89,6 +89,14 @@ class TextDetector(object):
             postprocess_params["sample_pts_num"] = 2
             postprocess_params["expand_scale"] = 1.0
             postprocess_params["shrink_ratio_of_width"] = 0.3
+        elif self.det_algorithm == "PSE":
+            postprocess_params['name'] = 'PSEPostProcess'
+            postprocess_params["thresh"] = args.det_pse_thresh
+            postprocess_params["box_thresh"] = args.det_pse_box_thresh
+            postprocess_params["min_area"] = args.det_pse_min_area
+            postprocess_params["box_type"] = args.det_pse_box_type
+            postprocess_params["scale"] = args.det_pse_scale
+            self.det_pse_box_type = args.det_pse_box_type
         else:
             logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
             sys.exit(0)

@@ -209,7 +217,7 @@ class TextDetector(object):
             preds['f_score'] = outputs[1]
             preds['f_tco'] = outputs[2]
             preds['f_tvo'] = outputs[3]
-        elif self.det_algorithm == 'DB':
+        elif self.det_algorithm in ['DB', 'PSE']:
             preds['maps'] = outputs[0]
         else:
             raise NotImplementedError

@@ -217,7 +225,9 @@ class TextDetector(object):
         #self.predictor.try_shrink_memory()
         post_result = self.postprocess_op(preds, shape_list)
         dt_boxes = post_result[0]['points']
-        if self.det_algorithm == "SAST" and self.det_sast_polygon:
+        if (self.det_algorithm == "SAST" and
+                self.det_sast_polygon) or (self.det_algorithm == "PSE" and
+                                           self.det_pse_box_type == 'poly'):
             dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
         else:
             dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
@@ -63,6 +63,13 @@ def init_args():
     parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
     parser.add_argument("--det_sast_polygon", type=str2bool, default=False)

+    # PSE params
+    parser.add_argument("--det_pse_thresh", type=float, default=0)
+    parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
+    parser.add_argument("--det_pse_min_area", type=float, default=16)
+    parser.add_argument("--det_pse_box_type", type=str, default='box')
+    parser.add_argument("--det_pse_scale", type=int, default=1)
+
     # params for text recognizer
     parser.add_argument("--rec_algorithm", type=str, default='CRNN')
     parser.add_argument("--rec_model_dir", type=str)
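Taken together, the new `PSE` branch in predict_det.py and the defaults above amount to building the post-processor like this (a sketch; it assumes the compiled `pse` extension from ppocr/postprocess/pse_postprocess/pse is available):

```python
from ppocr.postprocess import build_post_process

post_op = build_post_process({
    'name': 'PSEPostProcess',  # values mirror the det_pse_* defaults above
    'thresh': 0,
    'box_thresh': 0.85,
    'min_area': 16,
    'box_type': 'box',
    'scale': 1,
})
```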
@@ -402,7 +402,7 @@ def preprocess(is_train=False):
     alg = config['Architecture']['algorithm']
     assert alg in [
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
-        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR'
+        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE'
     ]

     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'