add psenet
This commit is contained in:
parent
bd92c22b61
commit
fa790288f1
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from paddle.vision.transforms import ColorJitter as pp_ColorJitter
|
||||
|
||||
__all__ = ['ColorJitter']
|
||||
|
||||
class ColorJitter(object):
    """Apply Paddle's ``ColorJitter`` transform to the ``'image'`` entry of a sample dict."""

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, **kwargs):
        # Any extra keyword arguments are accepted and ignored.
        self.aug = pp_ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, data):
        """Replace ``data['image']`` with its jittered version and return ``data``."""
        data['image'] = self.aug(data['image'])
        return data
|
|
@ -19,11 +19,13 @@ from __future__ import unicode_literals
|
|||
from .iaa_augment import IaaAugment
|
||||
from .make_border_map import MakeBorderMap
|
||||
from .make_shrink_map import MakeShrinkMap
|
||||
from .random_crop_data import EastRandomCropData, PSERandomCrop
|
||||
from .random_crop_data import EastRandomCropData, RandomCropImgMask
|
||||
from .make_pse_gt import MakePseGt
|
||||
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
|
||||
from .randaugment import RandAugment
|
||||
from .copy_paste import CopyPaste
|
||||
from .ColorJitter import ColorJitter
|
||||
from .operators import *
|
||||
from .label_ops import *
|
||||
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
# -*- coding:utf-8 -*-
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pyclipper
|
||||
from shapely.geometry import Polygon
|
||||
|
||||
__all__ = ['MakePseGt']
|
||||
|
||||
class MakePseGt(object):
    r'''
    Build PSE ground-truth maps from polygon detection annotations.

    Reads ``data['image']``, ``data['polys']`` and ``data['ignore_tags']``.
    Writes ``data['gt_kernels']`` (one binary map per shrink rate),
    ``data['gt_text']`` (the first, least-shrunk kernel) and ``data['mask']``
    (float32 training mask, 0 inside ignored polygons). ``data['image']`` and
    ``data['polys']`` are rewritten too because the sample may be upscaled.
    '''

    def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs):
        # NOTE: kernel_num == 1 makes the rate formula in __call__ divide by
        # zero (denominator is kernel_num - 1).
        self.kernel_num = kernel_num
        self.min_shrink_ratio = min_shrink_ratio
        self.size = size

    def __call__(self, data):
        image = data['image']
        text_polys = data['polys']
        ignore_tags = data['ignore_tags']

        h, w, _ = image.shape
        short_edge = min(h, w)
        if short_edge < self.size:
            # keep short_size >= self.size
            scale = self.size / short_edge
            image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
            # Out-of-place multiply: does not mutate the caller's array and
            # also works when the polygons are stored with an integer dtype.
            text_polys = text_polys * scale

        gt_kernels = []
        for i in range(1, self.kernel_num + 1):
            # s1->sn, from big to small
            # NOTE(review): for i == kernel_num the rate drops below
            # min_shrink_ratio -- confirm this is the intended schedule.
            rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1) * i
            text_kernel, ignore_tags = self.generate_kernel(
                image.shape[0:2], rate, text_polys, ignore_tags)
            gt_kernels.append(text_kernel)

        # Pixels inside ignored polygons are excluded from the loss (mask == 0).
        training_mask = np.ones(image.shape[0:2], dtype='uint8')
        for i in range(text_polys.shape[0]):
            if ignore_tags[i]:
                cv2.fillPoly(training_mask,
                             text_polys[i].astype(np.int32)[np.newaxis, :, :], 0)

        # generate_kernel paints instance ids (i + 1); binarise them here.
        gt_kernels = np.array(gt_kernels)
        gt_kernels[gt_kernels > 0] = 1

        data['image'] = image
        data['polys'] = text_polys
        data['gt_kernels'] = gt_kernels[0:]
        data['gt_text'] = gt_kernels[0]
        data['mask'] = training_mask.astype('float32')
        return data

    def generate_kernel(self, img_size, shrink_ratio, text_polys, ignore_tags=None):
        """Rasterise every polygon shrunk by ``shrink_ratio`` into one map.

        Args:
            img_size: ``(h, w)`` of the output map.
            shrink_ratio: ratio used to derive the inward offset distance.
            text_polys: iterable of polygons, each a sequence of (x, y) points.
            ignore_tags: optional mutable sequence; entry ``i`` is set to True
                when polygon ``i`` cannot be shrunk.

        Returns:
            ``(text_kernel, ignore_tags)`` where ``text_kernel`` is a float32
            ``(h, w)`` array holding ``i + 1`` inside the shrunk polygon ``i``.
        """
        h, w = img_size
        text_kernel = np.zeros((h, w), dtype=np.float32)
        for i, poly in enumerate(text_polys):
            polygon = Polygon(poly)
            distance = polygon.area * (1 - shrink_ratio * shrink_ratio) / (polygon.length + 1e-6)
            subject = [tuple(l) for l in poly]
            pco = pyclipper.PyclipperOffset()
            pco.AddPath(subject, pyclipper.JT_ROUND,
                        pyclipper.ET_CLOSEDPOLYGON)
            # Keep the offset result as a plain list of paths. Wrapping it in
            # np.array() raises on recent NumPy when the offset splits the
            # polygon into several paths with different vertex counts.
            shrinked = pco.Execute(-distance)

            if len(shrinked) == 0 or len(shrinked[0]) == 0:
                if ignore_tags is not None:
                    ignore_tags[i] = True
                continue
            try:
                # Only the first resulting path is rasterised.
                shrinked = np.array(shrinked[0]).reshape(-1, 2)
            except (ValueError, TypeError):
                if ignore_tags is not None:
                    ignore_tags[i] = True
                continue
            cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1)
        return text_kernel, ignore_tags
|
|
@ -164,47 +164,55 @@ class EastRandomCropData(object):
|
|||
return data
|
||||
|
||||
|
||||
class RandomCropImgMask(object):
    """Randomly crop selected entries of a sample dict to a fixed size.

    Args:
        size: ``(th, tw)`` target crop height and width.
        main_key: key of the map used to decide where text is; when it has
            positive pixels the crop origin is drawn from a window derived
            from their bounding box, with probability ``1 - p``.
        crop_keys: keys of ``data`` to crop with the chosen window.
        p: probability of ignoring ``main_key`` and cropping at a uniformly
            random position instead.
    """

    def __init__(self, size, main_key, crop_keys, p=3 / 8, **kwargs):
        self.size = size
        self.main_key = main_key
        self.crop_keys = crop_keys
        self.p = p

    def __call__(self, data):
        image = data['image']

        h, w = image.shape[0:2]
        th, tw = self.size
        if w == tw and h == th:
            return data

        mask = data[self.main_key]
        if np.max(mask) > 0 and random.random() > self.p:
            # make sure to crop the text region
            tl = np.min(np.where(mask > 0), axis=1) - (th, tw)
            tl[tl < 0] = 0
            br = np.max(np.where(mask > 0), axis=1) - (th, tw)
            br[br < 0] = 0

            # Keep the crop window inside the image.
            br[0] = min(br[0], h - th)
            br[1] = min(br[1], w - tw)

            i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
            j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
        else:
            i = random.randint(0, h - th) if h - th > 0 else 0
            j = random.randint(0, w - tw) if w - tw > 0 else 0

        for k in data:
            if k not in self.crop_keys:
                continue
            value = data[k]
            if len(value.shape) == 3:
                # The smallest axis is skipped and the other two are cropped;
                # a 3-D array whose smallest axis is the middle one is left
                # untouched.
                smallest_axis = np.argmin(value.shape)
                if smallest_axis == 0:
                    value = value[:, i:i + th, j:j + tw]
                elif smallest_axis == 2:
                    value = value[i:i + th, j:j + tw, :]
            else:
                value = value[i:i + th, j:j + tw]
            data[k] = value
        return data
|
||||
|
|
|
@ -13,13 +13,12 @@
|
|||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
# det loss
|
||||
from .det_db_loss import DBLoss
|
||||
from .det_east_loss import EASTLoss
|
||||
from .det_sast_loss import SASTLoss
|
||||
from .det_pse_loss import PSELoss
|
||||
|
||||
# rec loss
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
|
@ -41,9 +40,10 @@ from .combined_loss import CombinedLoss
|
|||
# table loss
|
||||
from .table_att_loss import TableAttentionLoss
|
||||
|
||||
|
||||
def build_loss(config):
|
||||
support_dict = [
|
||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||
'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
|
||||
]
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -75,12 +75,6 @@ class BalanceLoss(nn.Layer):
|
|||
mask (variable): masked maps.
|
||||
return: (variable) balanced loss
|
||||
"""
|
||||
# if self.main_loss_type in ['DiceLoss']:
|
||||
# # For the loss that returns to scalar value, perform ohem on the mask
|
||||
# mask = ohem_batch(pred, gt, mask, self.negative_ratio)
|
||||
# loss = self.loss(pred, gt, mask)
|
||||
# return loss
|
||||
|
||||
positive = gt * mask
|
||||
negative = (1 - gt) * mask
|
||||
|
||||
|
@ -153,53 +147,4 @@ class BCELoss(nn.Layer):
|
|||
|
||||
def forward(self, input, label, mask=None, weight=None, name=None):
|
||||
loss = F.binary_cross_entropy(input, label, reduction=self.reduction)
|
||||
return loss
|
||||
|
||||
|
||||
def ohem_single(score, gt_text, training_mask, ohem_ratio):
    """Select the pixels of one sample that take part in the loss (OHEM).

    Positives are pixels with ``gt_text > 0.5``. At most
    ``ohem_ratio * pos_num`` of the highest-scoring negatives are kept as
    well. When there are no usable positives or no negatives, the training
    mask itself is returned.

    Returns:
        float32 array of shape ``(1, H, W)``.
    """

    def _batched(arr):
        # Add a leading axis of size 1 and cast, as every return path does.
        return arr.reshape(1, arr.shape[0], arr.shape[1]).astype('float32')

    is_pos = gt_text > 0.5
    is_neg = gt_text <= 0.5

    # Positives that fall inside ignored regions do not count.
    n_pos = int(np.sum(is_pos)) - int(np.sum(is_pos & (training_mask <= 0.5)))
    if n_pos == 0:
        return _batched(training_mask)

    n_neg = int(min(n_pos * ohem_ratio, int(np.sum(is_neg))))
    if n_neg == 0:
        return _batched(training_mask)

    # Sort negative scores from high to low and take the n_neg-th as cut-off.
    descending = np.sort(-score[is_neg])
    threshold = -descending[n_neg - 1]

    # Keep hard negatives and all positives, restricted to the training mask.
    keep = ((score >= threshold) | is_pos) & (training_mask > 0.5)
    return _batched(keep)
|
||||
|
||||
|
||||
def ohem_batch(scores, gt_texts, training_masks, ohem_ratio):
    """Run ``ohem_single`` on every sample of a batch.

    The three inputs are converted with ``.numpy()``; the per-sample masks are
    concatenated along axis 0 and returned as a Paddle tensor.
    """
    scores_np = scores.numpy()
    gt_np = gt_texts.numpy()
    masks_np = training_masks.numpy()

    per_sample = [
        ohem_single(scores_np[idx, :, :], gt_np[idx, :, :],
                    masks_np[idx, :, :], ohem_ratio)
        for idx in range(scores_np.shape[0])
    ]

    return paddle.to_tensor(np.concatenate(per_sample, 0))
|
||||
return loss
|
|
@ -0,0 +1,119 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Time : 3/29/19 11:03 AM
|
||||
# @Author : zhoujun
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
import numpy as np
|
||||
from ppocr.utils.iou import iou
|
||||
|
||||
|
||||
class PSELoss(nn.Layer):
    def __init__(self, alpha, ohem_ratio=3, kernel_sample_mask='pred', reduction='sum', **kwargs):
        """Implement PSE Loss.

        Args:
            alpha: weight of the text loss; kernels get ``1 - alpha``.
            ohem_ratio: negative-to-positive ratio used when mining hard
                negatives for the text loss.
            kernel_sample_mask: ``'gt'`` or ``'pred'``; selects which text map
                restricts the kernel loss. Any other value leaves the OHEM
                mask of the text loss in place.
            reduction: ``'sum'``, ``'mean'`` or ``'none'``, applied to every
                entry of the returned dict.
        """
        super(PSELoss, self).__init__()
        assert reduction in ['sum', 'mean', 'none']
        self.alpha = alpha
        self.ohem_ratio = ohem_ratio
        self.kernel_sample_mask = kernel_sample_mask
        self.reduction = reduction

    def forward(self, outputs, labels):
        """Compute text loss, kernel loss and IoU metrics.

        ``outputs['maps']`` is upsampled by 4; channel 0 is the text map and
        the remaining channels are kernels. ``labels[1:]`` must unpack into
        ``(gt_texts, gt_kernels, training_masks)``.

        Returns:
            dict with ``loss``, ``loss_text``, ``loss_kernels``, ``iou_text``
            and ``iou_kernel``.
        """
        predicts = outputs['maps']
        predicts = F.interpolate(predicts, scale_factor=4)

        texts = predicts[:, 0, :, :]
        kernels = predicts[:, 1:, :, :]
        gt_texts, gt_kernels, training_masks = labels[1:]

        # text loss
        # Use the configured ratio; it was previously stored but never passed.
        selected_masks = self.ohem_batch(texts, gt_texts, training_masks, self.ohem_ratio)

        loss_text = self.dice_loss(texts, gt_texts, selected_masks)
        iou_text = iou((texts > 0).astype('int64'), gt_texts, training_masks, reduce=False)
        losses = dict(
            loss_text=loss_text,
            iou_text=iou_text
        )

        # kernel loss
        loss_kernels = []
        if self.kernel_sample_mask == 'gt':
            selected_masks = gt_texts * training_masks
        elif self.kernel_sample_mask == 'pred':
            selected_masks = (F.sigmoid(texts) > 0.5).astype('float32') * training_masks

        for i in range(kernels.shape[1]):
            kernel_i = kernels[:, i, :, :]
            gt_kernel_i = gt_kernels[:, i, :, :]
            loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i, selected_masks)
            loss_kernels.append(loss_kernel_i)
        loss_kernels = paddle.mean(paddle.stack(loss_kernels, axis=1), axis=1)
        iou_kernel = iou(
            (kernels[:, -1, :, :] > 0).astype('int64'), gt_kernels[:, -1, :, :], training_masks * gt_texts,
            reduce=False)
        losses.update(dict(
            loss_kernels=loss_kernels,
            iou_kernel=iou_kernel
        ))
        loss = self.alpha * loss_text + (1 - self.alpha) * loss_kernels
        losses['loss'] = loss
        if self.reduction == 'sum':
            losses = {x: paddle.sum(v) for x, v in losses.items()}
        elif self.reduction == 'mean':
            losses = {x: paddle.mean(v) for x, v in losses.items()}
        return losses

    def dice_loss(self, input, target, mask):
        """Per-sample dice loss of ``sigmoid(input)`` against ``target`` under ``mask``."""
        input = F.sigmoid(input)

        input = input.reshape([input.shape[0], -1])
        target = target.reshape([target.shape[0], -1])
        mask = mask.reshape([mask.shape[0], -1])

        input = input * mask
        target = target * mask

        a = paddle.sum(input * target, 1)
        # The 0.001 terms keep the denominator non-zero for empty maps.
        b = paddle.sum(input * input, 1) + 0.001
        c = paddle.sum(target * target, 1) + 0.001
        d = (2 * a) / (b + c)
        return 1 - d

    def ohem_single(self, score, gt_text, training_mask, ohem_ratio=3):
        """Select the pixels of one sample used by the text loss (OHEM).

        Returns a float32 tensor of shape ``[1, H, W]``: all positives plus at
        most ``ohem_ratio * pos_num`` highest-scoring negatives, restricted to
        ``training_mask``. Falls back to ``training_mask`` when there are no
        usable positives or no negatives.
        """
        pos_num = int(paddle.sum((gt_text > 0.5).astype('float32'))) - int(
            paddle.sum(paddle.logical_and((gt_text > 0.5), (training_mask <= 0.5)).astype('float32')))

        if pos_num == 0:
            # selected_mask = gt_text.copy() * 0 # may be not good
            selected_mask = training_mask
            selected_mask = selected_mask.reshape([1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
                'float32')
            return selected_mask

        neg_num = int(paddle.sum((gt_text <= 0.5).astype('float32')))
        neg_num = int(min(pos_num * ohem_ratio, neg_num))

        if neg_num == 0:
            selected_mask = training_mask
            # reshape, not view: keeps this branch consistent with the others.
            selected_mask = selected_mask.reshape([1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
                'float32')
            return selected_mask

        # Sort negative scores from high to low; the neg_num-th is the cut-off.
        neg_score = paddle.masked_select(score, gt_text <= 0.5)
        neg_score_sorted = paddle.sort(-neg_score)
        threshold = -neg_score_sorted[neg_num - 1]

        selected_mask = paddle.logical_and(paddle.logical_or((score >= threshold), (gt_text > 0.5)),
                                           (training_mask > 0.5))
        selected_mask = selected_mask.reshape([1, selected_mask.shape[0], selected_mask.shape[1]]).astype('float32')
        return selected_mask

    def ohem_batch(self, scores, gt_texts, training_masks, ohem_ratio=3):
        """Apply ``ohem_single`` to every sample and concatenate along axis 0."""
        selected_masks = []
        for i in range(scores.shape[0]):
            selected_masks.append(
                self.ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[i, :, :], ohem_ratio))

        selected_masks = paddle.concat(selected_masks, 0).astype('float32')
        return selected_masks
|
|
@ -20,6 +20,7 @@ def build_head(config):
|
|||
from .det_db_head import DBHead
|
||||
from .det_east_head import EASTHead
|
||||
from .det_sast_head import SASTHead
|
||||
from .det_pse_head import PSEHead
|
||||
from .e2e_pg_head import PGHead
|
||||
|
||||
# rec head
|
||||
|
@ -30,10 +31,10 @@ def build_head(config):
|
|||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
support_dict = [
|
||||
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
||||
'DBHead', 'PSEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
||||
'SRNHead', 'PGHead', 'TableAttentionHead']
|
||||
|
||||
#table head
|
||||
# table head
|
||||
from .table_att_head import TableAttentionHead
|
||||
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class PSEHead(nn.Layer):
    """PSE detection head: 3x3 conv + batch norm + ReLU, then a 1x1 conv.

    The forward pass returns ``{'maps': tensor}`` with ``out_channels``
    channels.
    """

    def __init__(self, in_channels, hidden_dim=256, out_channels=7, **kwargs):
        super(PSEHead, self).__init__()
        # 3x3 convolution with padding 1 and stride 1.
        self.conv1 = nn.Conv2D(
            in_channels, hidden_dim, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2D(hidden_dim)
        self.relu1 = nn.ReLU()

        # 1x1 projection to the output maps.
        self.conv2 = nn.Conv2D(
            hidden_dim, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, **kwargs):
        hidden = self.relu1(self.bn1(self.conv1(x)))
        return {'maps': self.conv2(hidden)}
|
|
@ -28,13 +28,14 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
|
|||
TableLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
from .pse_postprocess import PSEPostProcess
|
||||
|
||||
|
||||
def build_post_process(config, global_config=None):
|
||||
support_dict = [
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode'
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode', 'PSEPostProcess'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .pse_postprocess import PSEPostProcess
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .pse import pse
|
|
@ -0,0 +1,70 @@
|
|||
|
||||
import numpy as np
|
||||
import cv2
|
||||
cimport numpy as np
|
||||
cimport cython
|
||||
cimport libcpp
|
||||
cimport libcpp.pair
|
||||
cimport libcpp.queue
|
||||
from libcpp.pair cimport *
|
||||
from libcpp.queue cimport *
|
||||
|
||||
# Progressive scale expansion: grow labelled seed regions through the kernel
# maps with a breadth-first search.
#
#   kernels    : uint8 maps indexed [kernel_idx, x, y]; a zero pixel blocks
#                expansion into it.
#   label      : int32 seed labels; modified in place (components smaller
#                than min_area are reset to 0).
#   kernel_num : kernels are visited from index kernel_num - 1 down to 0.
#   label_num  : labels 1 .. label_num - 1 are checked against min_area.
#
# Returns the int32 map of expanded labels.
#
# NOTE: bounds checking and negative-index wraparound are disabled, so the
# caller must guarantee kernel_num <= kernels.shape[0].
# NOTE(review): coordinates are held as np.int16_t, so maps larger than 32767
# pixels on a side presumably overflow -- confirm against expected inputs.
@cython.boundscheck(False)
@cython.wraparound(False)
cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels,
                                        np.ndarray[np.int32_t, ndim=2] label,
                                        int kernel_num,
                                        int label_num,
                                        float min_area=0):
    cdef np.ndarray[np.int32_t, ndim=2] pred
    pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32)

    # Drop seed components that are too small.
    for label_idx in range(1, label_num):
        if np.sum(label == label_idx) < min_area:
            label[label == label_idx] = 0

    # que holds the frontier for the current kernel; nxt_que collects pixels
    # that could not expand and are retried with the next kernel.
    cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \
        queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
    cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \
        queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
    # 4-neighbourhood offsets.
    cdef np.int16_t* dx = [-1, 1, 0, 0]
    cdef np.int16_t* dy = [0, 0, -1, 1]
    cdef np.int16_t tmpx, tmpy

    # Seed the queue and the output with every labelled pixel.
    points = np.array(np.where(label > 0)).transpose((1, 0))
    for point_idx in range(points.shape[0]):
        tmpx, tmpy = points[point_idx, 0], points[point_idx, 1]
        que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
        pred[tmpx, tmpy] = label[tmpx, tmpy]

    cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur
    cdef int cur_label
    for kernel_idx in range(kernel_num - 1, -1, -1):
        while not que.empty():
            cur = que.front()
            que.pop()
            cur_label = pred[cur.first, cur.second]

            is_edge = True
            for j in range(4):
                tmpx = cur.first + dx[j]
                tmpy = cur.second + dy[j]
                # Skip neighbours outside the map.
                if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]:
                    continue
                # Skip neighbours outside this kernel or already labelled
                # (first label to arrive wins).
                if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
                    continue

                que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
                pred[tmpx, tmpy] = cur_label
                is_edge = False
            # Pixels that expanded nowhere are kept for the next kernel.
            if is_edge:
                nxt_que.push(cur)

        que, nxt_que = nxt_que, que

    return pred
|
||||
|
||||
def pse(kernels, min_area):
    # Label the connected components (4-connectivity) of the last kernel map
    # and expand them through the remaining maps with _pse.
    #
    # The kernel count passed to _pse is the length of the sliced array.
    # Passing the full count made _pse, which runs with bounds checking
    # disabled, index one element past the end of ``kernels[:-1]``.
    label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4)
    expand_kernels = kernels[:-1]
    return _pse(expand_kernels, label, expand_kernels.shape[0], label_num, min_area)
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import paddle
|
||||
from paddle.nn import functional as F
|
||||
|
||||
from ppocr.postprocess.pse_postprocess.pse import pse
|
||||
|
||||
|
||||
class PSEPostProcess(object):
    """
    The post process for PSE.

    Thresholds the predicted maps, expands the kernels with ``pse`` and turns
    every surviving labelled region into a box or polygon rescaled to the
    source image size.
    """

    def __init__(self,
                 thresh=0.5,
                 box_thresh=0.85,
                 min_area=16,
                 box_type='box',
                 scale=4,
                 **kwargs):
        assert box_type in ['box', 'poly'], 'Only box and poly is supported'
        self.thresh = thresh          # binarisation threshold on the raw maps
        self.box_thresh = box_thresh  # minimum mean score of a region
        self.min_area = min_area      # minimum pixel count of a region
        self.box_type = box_type
        self.scale = scale

    def __call__(self, outs_dict, shape_list):
        """Return one ``{'points': boxes, 'scores': scores}`` dict per sample.

        Each ``shape_list`` entry must unpack into
        ``(src_h, src_w, ratio_h, ratio_w)``; only the first two are used.
        """
        pred = outs_dict['maps']
        if not isinstance(pred, paddle.Tensor):
            pred = paddle.to_tensor(pred)
        pred = F.interpolate(pred, scale_factor=4 // self.scale, mode='bilinear')

        score = F.sigmoid(pred[:, 0, :, :])

        kernels = (pred > self.thresh).astype('float32')
        # Give the text mask an explicit channel axis so that it broadcasts
        # over the kernel channels. Without it [N, C, H, W] * [N, H, W] lines
        # the batch axis up against the channel axis, which is only correct
        # for batch size 1.
        text_mask = paddle.unsqueeze(kernels[:, 0, :, :], axis=1)
        kernels = kernels * text_mask

        score = score.numpy()
        kernels = kernels.numpy().astype(np.uint8)

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
            boxes, scores = self.boxes_from_bitmap(score[batch_index], kernels[batch_index], src_h, src_w)

            boxes_batch.append({'points': boxes, 'scores': scores})
        return boxes_batch

    def boxes_from_bitmap(self, score, kernels, src_h, src_w):
        """Expand the kernels of one sample and convert the labels to boxes."""
        label = pse(kernels, self.min_area)
        return self.generate_box(score, label, src_h, src_w)

    def generate_box(self, score, label, src_h, src_w):
        """Convert a label map into boxes and their mean scores.

        Regions with fewer than ``min_area`` pixels or a mean score below
        ``box_thresh`` are dropped and zeroed in ``label`` (in place).

        Returns:
            ``(boxes, scores)``: two parallel lists.
        """
        height, width = label.shape
        label_num = np.max(label) + 1

        boxes = []
        scores = []
        for i in range(1, label_num):
            ind = label == i
            # np.where yields (row, col); reverse to (x, y) for OpenCV.
            points = np.array(np.where(ind)).transpose((1, 0))[:, ::-1]

            if points.shape[0] < self.min_area:
                label[ind] = 0
                continue

            score_i = np.mean(score[ind])
            if score_i < self.box_thresh:
                label[ind] = 0
                continue

            if self.box_type == 'box':
                rect = cv2.minAreaRect(points)
                bbox = cv2.boxPoints(rect)
            elif self.box_type == 'poly':
                # Rasterise the region on a canvas slightly larger than its
                # extent and take the first external contour.
                box_height = np.max(points[:, 1]) + 10
                box_width = np.max(points[:, 0]) + 10

                mask = np.zeros((box_height, box_width), np.uint8)
                mask[points[:, 1], points[:, 0]] = 255

                contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                bbox = np.squeeze(contours[0], 1)
            else:
                raise NotImplementedError

            # Rescale from map coordinates to the source image and clip.
            bbox[:, 0] = np.clip(
                np.round(bbox[:, 0] / width * src_w), 0, src_w)
            bbox[:, 1] = np.clip(
                np.round(bbox[:, 1] / height * src_h), 0, src_h)

            boxes.append(bbox)
            scores.append(score_i)
        return boxes, scores
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Developer smoke test: run the post-process on a saved network output
    # and print it next to a saved reference result.
    #
    # Usage: python <this file> [out.npy [det_res.npy]]
    # The defaults are the original author's local paths, so on any other
    # machine both files must be given on the command line.
    import sys

    out_path = sys.argv[1] if len(sys.argv) > 1 else \
        '/Users/zhoujun20/Desktop/工作相关/OCR/论文复现/pan_pp.pytorch/out.npy'
    res_path = sys.argv[2] if len(sys.argv) > 2 else \
        '/Users/zhoujun20/Desktop/工作相关/OCR/论文复现/pan_pp.pytorch/det_res.npy'

    post = PSEPostProcess(thresh=0.5,
                          box_thresh=0.85,
                          min_area=16,
                          box_type='poly',
                          scale=4)
    out = np.load(out_path)
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load
    # reference files from a trusted source.
    res = np.load(res_path, allow_pickle=True).tolist()
    out = {'maps': paddle.to_tensor(out)}
    det_res = post(out, shape_list=[[720, 1280, 1, 1]])
    print(det_res)
    print(res)
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
def iou_single(a, b, mask, n_class):
    """Mean IoU over ``n_class`` class ids for one flattened sample.

    Only positions where ``mask == 1`` are compared. When no position is
    valid, every class contributes ``0 / EPS``.
    """
    keep = mask == 1
    a = a.masked_select(keep)
    b = b.masked_select(keep)
    # Loop-invariant: true when the mask selected nothing.
    nothing_valid = a.shape == [0] and a.shape == b.shape

    per_class = []
    for cls_id in range(n_class):
        if nothing_valid:
            inter = paddle.to_tensor(0.0)
            union = paddle.to_tensor(0.0)
        else:
            a_hit = a == cls_id
            b_hit = b == cls_id
            inter = (a_hit.logical_and(b_hit)).astype('float32')
            union = (a_hit.logical_or(b_hit)).astype('float32')
        per_class.append(paddle.sum(inter) / (paddle.sum(union) + EPS))
    return sum(per_class) / len(per_class)
|
||||
|
||||
def iou(a, b, mask, n_class=2, reduce=True):
    """Per-sample mean IoU between ``a`` and ``b`` under ``mask``.

    Every input is flattened to ``[batch, -1]`` and ``iou_single`` is run on
    each sample. Returns the batch mean when ``reduce`` is true, otherwise a
    float32 tensor with one value per sample.
    """
    n_samples = a.shape[0]

    flat_a = a.reshape([n_samples, -1])
    flat_b = b.reshape([n_samples, -1])
    flat_mask = mask.reshape([n_samples, -1])

    scores = paddle.zeros((n_samples,), dtype='float32')
    for idx in range(n_samples):
        scores[idx] = iou_single(flat_a[idx], flat_b[idx], flat_mask[idx], n_class)

    if reduce:
        return paddle.mean(scores)
    return scores
|
|
@ -395,7 +395,7 @@ def preprocess(is_train=False):
|
|||
alg = config['Architecture']['algorithm']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'TableAttn'
|
||||
'CLS', 'PGNet', 'Distillation', 'TableAttn', 'PSE'
|
||||
]
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
|
|
Loading…
Reference in New Issue