ADD PGNet_v2

This commit is contained in:
Jethong 2021-03-08 15:11:57 +08:00
parent 1f76f449db
commit bb49e1a53f
10 changed files with 227 additions and 1226 deletions

View File

@ -37,6 +37,7 @@ class ClsLabelEncode(object):
class E2ELabelEncode(object):
def __init__(self, label_list, **kwargs):
self.label_list = label_list
self.max_len = 50
def __call__(self, data):
text_label_index_list, temp_text = [], []
@ -47,7 +48,7 @@ class E2ELabelEncode(object):
for c_ in text:
if c_ in self.label_list:
temp_text.append(self.label_list.index(c_))
temp_text = temp_text + [36] * (50 - len(temp_text))
temp_text = temp_text + [36] * (self.max_len - len(temp_text))
text_label_index_list.append(temp_text)
data['strs'] = np.array(text_label_index_list)
return data

View File

@ -32,16 +32,6 @@ class E2EMetric(object):
self.reset()
def __call__(self, preds, batch, **kwargs):
'''
batch: a list produced by dataloaders.
image: np.ndarray of shape (N, C, H, W).
ratio_list: np.ndarray of shape(N,2)
polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
preds: a list of dict produced by post process
points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
'''
gt_polyons_batch = batch[2]
temp_gt_strs_batch = batch[3]
ignore_tags_batch = batch[4]
@ -72,13 +62,6 @@ class E2EMetric(object):
self.results.append(result)
def get_metric(self):
"""
return metrics {
'precision': 0,
'recall': 0,
'hmean': 0
}
"""
metircs = combine_results(self.results)
self.reset()
return metircs

View File

@ -106,172 +106,212 @@ class DeConvBNLayer(nn.Layer):
return x
class FPN_Up_Fusion(nn.Layer):
def __init__(self, in_channels):
super(FPN_Up_Fusion, self).__init__()
in_channels = in_channels[::-1]
out_channels = [256, 256, 192, 192, 128]
class PGFPN(nn.Layer):
def __init__(self, in_channels, **kwargs):
super(PGFPN, self).__init__()
num_inputs = [2048, 2048, 1024, 512, 256]
num_outputs = [256, 256, 192, 192, 128]
self.out_channels = 128
# print(in_channels)
self.conv_bn_layer_1 = ConvBNLayer(
in_channels=3,
out_channels=32,
kernel_size=3,
stride=1,
act=None,
name='FPN_d1')
self.conv_bn_layer_2 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
act=None,
name='FPN_d2')
self.conv_bn_layer_3 = ConvBNLayer(
in_channels=256,
out_channels=128,
kernel_size=3,
stride=1,
act=None,
name='FPN_d3')
self.conv_bn_layer_4 = ConvBNLayer(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=2,
act=None,
name='FPN_d4')
self.conv_bn_layer_5 = ConvBNLayer(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
act='relu',
name='FPN_d5')
self.conv_bn_layer_6 = ConvBNLayer(
in_channels=64,
out_channels=128,
kernel_size=3,
stride=2,
act=None,
name='FPN_d6')
self.conv_bn_layer_7 = ConvBNLayer(
in_channels=128,
out_channels=128,
kernel_size=3,
stride=1,
act='relu',
name='FPN_d7')
self.conv_bn_layer_8 = ConvBNLayer(
in_channels=128,
out_channels=128,
kernel_size=1,
stride=1,
act=None,
name='FPN_d8')
self.h0_conv = ConvBNLayer(
in_channels[0], out_channels[0], 1, 1, act=None, name='conv_h0')
self.h1_conv = ConvBNLayer(
in_channels[1], out_channels[1], 1, 1, act=None, name='conv_h1')
self.h2_conv = ConvBNLayer(
in_channels[2], out_channels[2], 1, 1, act=None, name='conv_h2')
self.h3_conv = ConvBNLayer(
in_channels[3], out_channels[3], 1, 1, act=None, name='conv_h3')
self.h4_conv = ConvBNLayer(
in_channels[4], out_channels[4], 1, 1, act=None, name='conv_h4')
self.conv_h0 = ConvBNLayer(
in_channels=num_inputs[0],
out_channels=num_outputs[0],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(0))
self.conv_h1 = ConvBNLayer(
in_channels=num_inputs[1],
out_channels=num_outputs[1],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(1))
self.conv_h2 = ConvBNLayer(
in_channels=num_inputs[2],
out_channels=num_outputs[2],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(2))
self.conv_h3 = ConvBNLayer(
in_channels=num_inputs[3],
out_channels=num_outputs[3],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(3))
self.conv_h4 = ConvBNLayer(
in_channels=num_inputs[4],
out_channels=num_outputs[4],
kernel_size=1,
stride=1,
act=None,
name="conv_h{}".format(4))
self.dconv0 = DeConvBNLayer(
in_channels=out_channels[0],
out_channels=out_channels[1],
in_channels=num_outputs[0],
out_channels=num_outputs[0 + 1],
name="dconv_{}".format(0))
self.dconv1 = DeConvBNLayer(
in_channels=out_channels[1],
out_channels=out_channels[2],
in_channels=num_outputs[1],
out_channels=num_outputs[1 + 1],
act=None,
name="dconv_{}".format(1))
self.dconv2 = DeConvBNLayer(
in_channels=out_channels[2],
out_channels=out_channels[3],
in_channels=num_outputs[2],
out_channels=num_outputs[2 + 1],
act=None,
name="dconv_{}".format(2))
self.dconv3 = DeConvBNLayer(
in_channels=out_channels[3],
out_channels=out_channels[4],
in_channels=num_outputs[3],
out_channels=num_outputs[3 + 1],
act=None,
name="dconv_{}".format(3))
self.conv_g1 = ConvBNLayer(
in_channels=out_channels[1],
out_channels=out_channels[1],
in_channels=num_outputs[1],
out_channels=num_outputs[1],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(1))
self.conv_g2 = ConvBNLayer(
in_channels=out_channels[2],
out_channels=out_channels[2],
in_channels=num_outputs[2],
out_channels=num_outputs[2],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(2))
self.conv_g3 = ConvBNLayer(
in_channels=out_channels[3],
out_channels=out_channels[3],
in_channels=num_outputs[3],
out_channels=num_outputs[3],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(3))
self.conv_g4 = ConvBNLayer(
in_channels=out_channels[4],
out_channels=out_channels[4],
in_channels=num_outputs[4],
out_channels=num_outputs[4],
kernel_size=3,
stride=1,
act='relu',
name="conv_g{}".format(4))
self.convf = ConvBNLayer(
in_channels=out_channels[4],
out_channels=out_channels[4],
in_channels=num_outputs[4],
out_channels=num_outputs[4],
kernel_size=1,
stride=1,
act=None,
name="conv_f{}".format(4))
def _add_relu(self, x1, x2):
x = paddle.add(x=x1, y=x2)
x = F.relu(x)
return x
def forward(self, x):
f = x[2:][::-1]
h0 = self.h0_conv(f[0])
h1 = self.h1_conv(f[1])
h2 = self.h2_conv(f[2])
h3 = self.h3_conv(f[3])
h4 = self.h4_conv(f[4])
c0, c1, c2, c3, c4, c5, c6 = x
# FPN_Down_Fusion
f = [c0, c1, c2]
g = [None, None, None]
h = [None, None, None]
h[0] = self.conv_bn_layer_1(f[0])
h[1] = self.conv_bn_layer_2(f[1])
h[2] = self.conv_bn_layer_3(f[2])
g0 = self.dconv0(h0)
g[0] = self.conv_bn_layer_4(h[0])
g[1] = paddle.add(g[0], h[1])
g[1] = F.relu(g[1])
g[1] = self.conv_bn_layer_5(g[1])
g[1] = self.conv_bn_layer_6(g[1])
g1 = self.dconv2(self.conv_g2(self._add_relu(g0, h1)))
g2 = self.dconv2(self.conv_g2(self._add_relu(g1, h2)))
g3 = self.dconv3(self.conv_g2(self._add_relu(g2, h3)))
g4 = self.dconv4(self.conv_g2(self._add_relu(g3, h4)))
return g4
g[2] = paddle.add(g[1], h[2])
g[2] = F.relu(g[2])
g[2] = self.conv_bn_layer_7(g[2])
f_down = self.conv_bn_layer_8(g[2])
# FPN UP Fusion
f1 = [c6, c5, c4, c3, c2]
g = [None, None, None, None, None]
h = [None, None, None, None, None]
h[0] = self.conv_h0(f1[0])
h[1] = self.conv_h1(f1[1])
h[2] = self.conv_h2(f1[2])
h[3] = self.conv_h3(f1[3])
h[4] = self.conv_h4(f1[4])
class FPN_Down_Fusion(nn.Layer):
def __init__(self, in_channels):
super(FPN_Down_Fusion, self).__init__()
out_channels = [32, 64, 128]
g[0] = self.dconv0(h[0])
g[1] = paddle.add(g[0], h[1])
g[1] = F.relu(g[1])
g[1] = self.conv_g1(g[1])
g[1] = self.dconv1(g[1])
self.h0_conv = ConvBNLayer(
in_channels[0], out_channels[0], 3, 1, act=None, name='FPN_d1')
self.h1_conv = ConvBNLayer(
in_channels[1], out_channels[1], 3, 1, act=None, name='FPN_d2')
self.h2_conv = ConvBNLayer(
in_channels[2], out_channels[2], 3, 1, act=None, name='FPN_d3')
g[2] = paddle.add(g[1], h[2])
g[2] = F.relu(g[2])
g[2] = self.conv_g2(g[2])
g[2] = self.dconv2(g[2])
self.g0_conv = ConvBNLayer(
out_channels[0], out_channels[1], 3, 2, act=None, name='FPN_d4')
g[3] = paddle.add(g[2], h[3])
g[3] = F.relu(g[3])
g[3] = self.conv_g3(g[3])
g[3] = self.dconv3(g[3])
self.g1_conv = nn.Sequential(
ConvBNLayer(
out_channels[1],
out_channels[1],
3,
1,
act='relu',
name='FPN_d5'),
ConvBNLayer(
out_channels[1], out_channels[2], 3, 2, act=None,
name='FPN_d6'))
self.g2_conv = nn.Sequential(
ConvBNLayer(
out_channels[2],
out_channels[2],
3,
1,
act='relu',
name='FPN_d7'),
ConvBNLayer(
out_channels[2], out_channels[2], 1, 1, act=None,
name='FPN_d8'))
def forward(self, x):
f = x[:3]
h0 = self.h0_conv(f[0])
h1 = self.h1_conv(f[1])
h2 = self.h2_conv(f[2])
g0 = self.g0_conv(h0)
g1 = paddle.add(x=g0, y=h1)
g1 = F.relu(g1)
g1 = self.g1_conv(g1)
g2 = paddle.add(x=g1, y=h2)
g2 = F.relu(g2)
g2 = self.g2_conv(g2)
return g2
class PGFPN(nn.Layer):
def __init__(self, in_channels, with_cab=False, **kwargs):
super(PGFPN, self).__init__()
self.in_channels = in_channels
self.with_cab = with_cab
self.FPN_Down_Fusion = FPN_Down_Fusion(self.in_channels)
self.FPN_Up_Fusion = FPN_Up_Fusion(self.in_channels)
self.out_channels = 128
def forward(self, x):
# down fpn
f_down = self.FPN_Down_Fusion(x)
# up fpn
f_up = self.FPN_Up_Fusion(x)
# fusion
f_common = paddle.add(x=f_down, y=f_up)
g[4] = paddle.add(x=g[3], y=h[4])
g[4] = F.relu(g[4])
g[4] = self.conv_g4(g[4])
f_up = self.convf(g[4])
f_common = paddle.add(f_down, f_up)
f_common = F.relu(f_common)
return f_common

View File

@ -1,9 +1,18 @@
from os import listdir
import os, sys
from scipy import io
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
from tqdm import tqdm
try: # python2
range = xrange
@ -862,16 +871,3 @@ def combine_results(all_data):
'f_score_e2e': f_score_e2e
}
return final
# a = [1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634, 620, 1659, 620, 1654, 681, 1631, 680, 1618,
# 681, 1606, 681, 1594, 681, 1584, 682, 1573, 685, 1542, 694]
# gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}]
# pred_dict = [{'points': np.array(a), 'text': 'ccc'},
# {'points': np.array(a), 'text': 'ccf'}]
# result = []
# for i in range(2):
# result.append(get_socre(gt_dict, pred_dict))
# print(111)
# a = combine_results(result)
# print(a)

View File

@ -1,6 +1,18 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from shapely.geometry import Polygon
#import Polygon
"""
:param det_x: [1, N] Xs of detection's vertices
:param det_y: [1, N] Ys of detection's vertices

View File

@ -1,881 +0,0 @@
from os import listdir
import os, sys
from scipy import io
import numpy as np
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
from tqdm import tqdm
try: # python2
range = xrange
except Exception:
# python3
range = range
"""
Input format: y0,x0, ..... yn,xn. Each detection is separated by the end of line token ('\n')'
"""
# if len(sys.argv) != 4:
# print('\n usage: test.py pred_dir gt_dir savefile')
# sys.exit()
global_tp = 0
global_fp = 0
global_fn = 0
tr = 0.7
tp = 0.6
fsc_k = 0.8
k = 2
def get_socre(gt_dict, pred_dict):
# allInputs = listdir(input_dir)
allInputs = 1
global_pred_str = []
global_gt_str = []
global_sigma = []
global_tau = []
def input_reading_mod(pred_dict, input):
"""This helper reads input from txt files"""
det = []
n = len(pred_dict)
for i in range(n):
points = pred_dict[i]['points']
text = pred_dict[i]['text']
# for i in range(len(points)):
point = ",".join(map(str, points.reshape(-1, )))
det.append([point, text])
return det
def gt_reading_mod(gt_dict, gt_id):
"""This helper reads groundtruths from mat files"""
# gt_id = gt_id.split('.')[0]
gt = []
n = len(gt_dict)
for i in range(n):
points = gt_dict[i]['points'].tolist()
h = len(points)
text = gt_dict[i]['text']
xx = [
np.array(
['x:'], dtype='<U2'), 0, np.array(
['y:'], dtype='<U2'), 0, np.array(
['#'], dtype='<U1'), np.array(
['#'], dtype='<U1')
]
t_x, t_y = [], []
for j in range(h):
t_x.append(points[j][0])
t_y.append(points[j][1])
xx[1] = np.array([t_x], dtype='int16')
xx[3] = np.array([t_y], dtype='int16')
if text != "":
xx[4] = np.array([text], dtype='U{}'.format(len(text)))
xx[5] = np.array(['c'], dtype='<U1')
gt.append(xx)
return gt
def detection_filtering(detections, groundtruths, threshold=0.5):
for gt_id, gt in enumerate(groundtruths):
print
"liushanshan gt[1] = {}".format(gt[1])
print
"liushanshan gt[2] = {}".format(gt[2])
print
"liushanshan gt[3] = {}".format(gt[3])
print
"liushanshan gt[4] = {}".format(gt[4])
print
"liushanshan gt[5] = {}".format(gt[5])
if (gt[5] == '#') and (gt[1].shape[1] > 1):
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
for det_id, detection in enumerate(detections):
detection_orig = detection
detection = [float(x) for x in detection[0].split(',')]
# detection = detection.split(',')
detection = list(map(int, detection))
det_x = detection[0::2]
det_y = detection[1::2]
det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
if det_gt_iou > threshold:
detections[det_id] = []
detections[:] = [item for item in detections if item != []]
return detections
def sigma_calculation(det_x, det_y, gt_x, gt_y):
"""
sigma = inter_area / gt_area
"""
# print(area_of_intersection(det_x, det_y, gt_x, gt_y))
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(gt_x, gt_y)), 2)
def tau_calculation(det_x, det_y, gt_x, gt_y):
"""
tau = inter_area / det_area
"""
# print "liushanshan det_x {}".format(det_x)
# print "liushanshan det_y {}".format(det_y)
# print "liushanshan area {}".format(area(det_x, det_y))
# print "liushanshan tau = {}".format(np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2))
if area(det_x, det_y) == 0.0:
return 0
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
area(det_x, det_y)), 2)
##############################Initialization###################################
###############################################################################
single_data = {}
for input_id in range(allInputs):
if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
and (input_id != 'Deteval_result_non_curved.txt'):
print(input_id)
detections = input_reading_mod(pred_dict, input_id)
# print "liushanshan detections = {}".format(detections)
groundtruths = gt_reading_mod(gt_dict, input_id)
detections = detection_filtering(
detections,
groundtruths) # filters detections overlapping with DC area
dc_id = []
for i in range(len(groundtruths)):
if groundtruths[i][5] == '#':
dc_id.append(i)
cnt = 0
for a in dc_id:
num = a - cnt
del groundtruths[num]
cnt += 1
local_sigma_table = np.zeros((len(groundtruths), len(detections)))
local_tau_table = np.zeros((len(groundtruths), len(detections)))
local_pred_str = {}
local_gt_str = {}
for gt_id, gt in enumerate(groundtruths):
if len(detections) > 0:
for det_id, detection in enumerate(detections):
detection_orig = detection
detection = [float(x) for x in detection[0].split(',')]
detection = list(map(int, detection))
pred_seq_str = detection_orig[1].strip()
det_x = detection[0::2]
det_y = detection[1::2]
gt_x = list(map(int, np.squeeze(gt[1])))
gt_y = list(map(int, np.squeeze(gt[3])))
gt_seq_str = str(gt[4].tolist()[0])
local_sigma_table[gt_id, det_id] = sigma_calculation(
det_x, det_y, gt_x, gt_y)
local_tau_table[gt_id, det_id] = tau_calculation(
det_x, det_y, gt_x, gt_y)
local_pred_str[det_id] = pred_seq_str
local_gt_str[gt_id] = gt_seq_str
global_sigma.append(local_sigma_table)
global_tau.append(local_tau_table)
global_pred_str.append(local_pred_str)
global_gt_str.append(local_gt_str)
print
"liushanshan global_pred_str = {}".format(global_pred_str)
print
"liushanshan global_gt_str = {}".format(global_gt_str)
single_data['sigma'] = global_sigma
single_data['global_tau'] = global_tau
single_data['global_pred_str'] = global_pred_str
single_data['global_gt_str'] = global_gt_str
return single_data
def combine_results(all_data):
global_sigma, global_tau, global_pred_str, global_gt_str = [], [], [], []
for data in all_data:
global_sigma.append(data['sigma'])
global_tau.append(data['global_tau'])
global_pred_str.append(data['global_pred_str'])
global_gt_str.append(data['global_gt_str'])
global_accumulative_recall = 0
global_accumulative_precision = 0
total_num_gt = 0
total_num_det = 0
hit_str_count = 0
hit_count = 0
def one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for gt_id in range(num_gt):
gt_matching_qualified_sigma_candidates = np.where(
local_sigma_table[gt_id, :] > tr)
gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
0].shape[0]
gt_matching_qualified_tau_candidates = np.where(
local_tau_table[gt_id, :] > tp)
gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
0].shape[0]
det_matching_qualified_sigma_candidates = np.where(
local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
> tr)
det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
0].shape[0]
det_matching_qualified_tau_candidates = np.where(
local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
tp)
det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
0].shape[0]
if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
(det_matching_num_qualified_sigma_candidates == 1) and (
det_matching_num_qualified_tau_candidates == 1):
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, gt_id] = 1
matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
# recg start
print
"liushanshan one to one det_id = {}".format(matched_det_id)
print
"liushanshan one to one gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
0]]
print
"liushanshan one to one gt_str_cur = {}".format(gt_str_cur)
print
"liushanshan one to one pred_str_cur = {}".format(pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
det_flag[0, matched_det_id] = 1
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
def one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for gt_id in range(num_gt):
# skip the following if the groundtruth was matched
if gt_flag[0, gt_id] > 0:
continue
non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
if num_non_zero_in_sigma >= k:
####search for all detections that overlaps with this groundtruth
qualified_tau_candidates = np.where((local_tau_table[
gt_id, :] >= tp) & (det_flag[0, :] == 0))
num_qualified_tau_candidates = qualified_tau_candidates[
0].shape[0]
if num_qualified_tau_candidates == 1:
if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
and
(local_sigma_table[gt_id, qualified_tau_candidates] >=
tr)):
# became an one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
print
"liushanshan one to many det_id = {}".format(
qualified_tau_candidates)
print
"liushanshan one to many gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]]
print
"liushanshan one to many gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan one to many pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
>= tr):
gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1
# recg start
print
"liushanshan one to many det_id = {}".format(
qualified_tau_candidates)
print
"liushanshan one to many gt_id = {}".format(gt_id)
gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]]
print
"liushanshan one to many gt_str_cur = {}".format(gt_str_cur)
print
"liushanshan one to many pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
# recg end
global_accumulative_recall = global_accumulative_recall + fsc_k
global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
local_accumulative_recall = local_accumulative_recall + fsc_k
local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
def many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy):
hit_str_num = 0
for det_id in range(num_det):
# skip the following if the detection was matched
if det_flag[0, det_id] > 0:
continue
non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
if num_non_zero_in_tau >= k:
####search for all detections that overlaps with this groundtruth
qualified_sigma_candidates = np.where((
local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
num_qualified_sigma_candidates = qualified_sigma_candidates[
0].shape[0]
if num_qualified_sigma_candidates == 1:
if ((local_tau_table[qualified_sigma_candidates, det_id] >=
tp) and
(local_sigma_table[qualified_sigma_candidates, det_id]
>= tr)):
# became an one-to-one case
global_accumulative_recall = global_accumulative_recall + 1.0
global_accumulative_precision = global_accumulative_precision + 1.0
local_accumulative_recall = local_accumulative_recall + 1.0
local_accumulative_precision = local_accumulative_precision + 1.0
gt_flag[0, qualified_sigma_candidates] = 1
det_flag[0, det_id] = 1
# recg start
print
"liushanshan many to one det_id = {}".format(det_id)
print
"liushanshan many to one gt_id = {}".format(
qualified_sigma_candidates)
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[
idx]
if not global_gt_str[idy].has_key(ele_gt_id):
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
print
"liushanshan many to one gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan many to one pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
break
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
# recg end
elif (np.sum(local_tau_table[qualified_sigma_candidates,
det_id]) >= tp):
det_flag[0, det_id] = 1
gt_flag[0, qualified_sigma_candidates] = 1
# recg start
print
"liushanshan many to one det_id = {}".format(det_id)
print
"liushanshan many to one gt_id = {}".format(
qualified_sigma_candidates)
pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
if not global_gt_str[idy].has_key(ele_gt_id):
continue
gt_str_cur = global_gt_str[idy][ele_gt_id]
print
"liushanshan many to one gt_str_cur = {}".format(
gt_str_cur)
print
"liushanshan many to one pred_str_cur = {}".format(
pred_str_cur)
if pred_str_cur == gt_str_cur:
hit_str_num += 1
break
else:
if pred_str_cur.lower() == gt_str_cur.lower():
hit_str_num += 1
break
else:
print
'no match'
# recg end
global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
global_accumulative_precision = global_accumulative_precision + fsc_k
local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
local_accumulative_precision = local_accumulative_precision + fsc_k
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
for idx in range(len(global_sigma)):
# print(allInputs[idx])
local_sigma_table = np.array(global_sigma[idx])
local_tau_table = global_tau[idx]
num_gt = local_sigma_table.shape[0]
num_det = local_sigma_table.shape[1]
total_num_gt = total_num_gt + num_gt
total_num_det = total_num_det + num_det
local_accumulative_recall = 0
local_accumulative_precision = 0
gt_flag = np.zeros((1, num_gt))
det_flag = np.zeros((1, num_det))
#######first check for one-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
#######then check for one-to-many case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
hit_str_count += hit_str_num
#######then check for many-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx)
try:
recall = global_accumulative_recall / total_num_gt
except ZeroDivisionError:
recall = 0
try:
precision = global_accumulative_precision / total_num_det
except ZeroDivisionError:
precision = 0
try:
f_score = 2 * precision * recall / (precision + recall)
except ZeroDivisionError:
f_score = 0
try:
seqerr = 1 - float(hit_str_count) / global_accumulative_recall
except ZeroDivisionError:
seqerr = 1
try:
recall_e2e = float(hit_str_count) / total_num_gt
except ZeroDivisionError:
recall_e2e = 0
try:
precision_e2e = float(hit_str_count) / total_num_det
except ZeroDivisionError:
precision_e2e = 0
try:
f_score_e2e = 2 * precision_e2e * recall_e2e / (
precision_e2e + recall_e2e)
except ZeroDivisionError:
f_score_e2e = 0
final = {
'total_num_gt': total_num_gt,
'total_num_det': total_num_det,
'global_accumulative_recall': global_accumulative_recall,
'hit_str_count': hit_str_count,
'recall': recall,
'precision': precision,
'f_score': f_score,
'seqerr': seqerr,
'recall_e2e': recall_e2e,
'precision_e2e': precision_e2e,
'f_score_e2e': f_score_e2e
}
return final
# def combine_results(all_data):
# tr = 0.7
# tp = 0.6
# fsc_k = 0.8
# k = 2
# global_sigma = []
# global_tau = []
# global_pred_str = []
# global_gt_str = []
# for data in all_data:
# global_sigma.append(data['sigma'])
# global_tau.append(data['global_tau'])
# global_pred_str.append(data['global_pred_str'])
# global_gt_str.append(data['global_gt_str'])
#
# global_accumulative_recall = 0
# global_accumulative_precision = 0
# total_num_gt = 0
# total_num_det = 0
# hit_str_count = 0
# hit_count = 0
#
# def one_to_one(local_sigma_table, local_tau_table, local_accumulative_recall,
# local_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idy):
# hit_str_num = 0
# for gt_id in range(num_gt):
# gt_matching_qualified_sigma_candidates = np.where(local_sigma_table[gt_id, :] > tr)
# gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[0].shape[0]
# gt_matching_qualified_tau_candidates = np.where(local_tau_table[gt_id, :] > tp)
# gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[0].shape[0]
#
# det_matching_qualified_sigma_candidates = np.where(
# local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]] > tr)
# det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[0].shape[0]
# det_matching_qualified_tau_candidates = np.where(
# local_tau_table[:, gt_matching_qualified_tau_candidates[0]] > tp)
# det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[0].shape[0]
#
# if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
# (det_matching_num_qualified_sigma_candidates == 1) and (
# det_matching_num_qualified_tau_candidates == 1):
# global_accumulative_recall = global_accumulative_recall + 1.0
# global_accumulative_precision = global_accumulative_precision + 1.0
# local_accumulative_recall = local_accumulative_recall + 1.0
# local_accumulative_precision = local_accumulative_precision + 1.0
#
# gt_flag[0, gt_id] = 1
# matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
# # recg start
# print
# "liushanshan one to one det_id = {}".format(matched_det_id)
# print
# "liushanshan one to one gt_id = {}".format(gt_id)
# gt_str_cur = global_gt_str[idy][gt_id]
# pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[0]]
# print
# "liushanshan one to one gt_str_cur = {}".format(gt_str_cur)
# print
# "liushanshan one to one pred_str_cur = {}".format(pred_str_cur)
# if pred_str_cur == gt_str_cur:
# hit_str_num += 1
# else:
# if pred_str_cur.lower() == gt_str_cur.lower():
# hit_str_num += 1
# # recg end
# det_flag[0, matched_det_id] = 1
# return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
#
# def one_to_many(local_sigma_table, local_tau_table, local_accumulative_recall,
# local_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idy):
# hit_str_num = 0
# for gt_id in range(num_gt):
# # skip the following if the groundtruth was matched
# if gt_flag[0, gt_id] > 0:
# continue
#
# non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
# num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
#
# if num_non_zero_in_sigma >= k:
# ####search for all detections that overlaps with this groundtruth
# qualified_tau_candidates = np.where((local_tau_table[gt_id, :] >= tp) & (det_flag[0, :] == 0))
# num_qualified_tau_candidates = qualified_tau_candidates[0].shape[0]
#
# if num_qualified_tau_candidates == 1:
# if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp) and (
# local_sigma_table[gt_id, qualified_tau_candidates] >= tr)):
# # became an one-to-one case
# global_accumulative_recall = global_accumulative_recall + 1.0
# global_accumulative_precision = global_accumulative_precision + 1.0
# local_accumulative_recall = local_accumulative_recall + 1.0
# local_accumulative_precision = local_accumulative_precision + 1.0
#
# gt_flag[0, gt_id] = 1
# det_flag[0, qualified_tau_candidates] = 1
# # recg start
# print
# "liushanshan one to many det_id = {}".format(qualified_tau_candidates)
# print
# "liushanshan one to many gt_id = {}".format(gt_id)
# gt_str_cur = global_gt_str[idy][gt_id]
# pred_str_cur = global_pred_str[idy][qualified_tau_candidates[0].tolist()[0]]
# print
# "liushanshan one to many gt_str_cur = {}".format(gt_str_cur)
# print
# "liushanshan one to many pred_str_cur = {}".format(pred_str_cur)
# if pred_str_cur == gt_str_cur:
# hit_str_num += 1
# else:
# if pred_str_cur.lower() == gt_str_cur.lower():
# hit_str_num += 1
# # recg end
# elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates]) >= tr):
# gt_flag[0, gt_id] = 1
# det_flag[0, qualified_tau_candidates] = 1
# # recg start
# print
# "liushanshan one to many det_id = {}".format(qualified_tau_candidates)
# print
# "liushanshan one to many gt_id = {}".format(gt_id)
# gt_str_cur = global_gt_str[idy][gt_id]
# pred_str_cur = global_pred_str[idy][qualified_tau_candidates[0].tolist()[0]]
# print
# "liushanshan one to many gt_str_cur = {}".format(gt_str_cur)
# print
# "liushanshan one to many pred_str_cur = {}".format(pred_str_cur)
# if pred_str_cur == gt_str_cur:
# hit_str_num += 1
# else:
# if pred_str_cur.lower() == gt_str_cur.lower():
# hit_str_num += 1
# # recg end
#
# global_accumulative_recall = global_accumulative_recall + fsc_k
# global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
#
# local_accumulative_recall = local_accumulative_recall + fsc_k
# local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
#
# return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
#
# def many_to_one(local_sigma_table, local_tau_table, local_accumulative_recall,
# local_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idy):
# hit_str_num = 0
# for det_id in range(num_det):
# # skip the following if the detection was matched
# if det_flag[0, det_id] > 0:
# continue
#
# non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
# num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
#
# if num_non_zero_in_tau >= k:
# ####search for all detections that overlaps with this groundtruth
# qualified_sigma_candidates = np.where((local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
# num_qualified_sigma_candidates = qualified_sigma_candidates[0].shape[0]
#
# if num_qualified_sigma_candidates == 1:
# if ((local_tau_table[qualified_sigma_candidates, det_id] >= tp) and (
# local_sigma_table[qualified_sigma_candidates, det_id] >= tr)):
# # became an one-to-one case
# global_accumulative_recall = global_accumulative_recall + 1.0
# global_accumulative_precision = global_accumulative_precision + 1.0
# local_accumulative_recall = local_accumulative_recall + 1.0
# local_accumulative_precision = local_accumulative_precision + 1.0
#
# gt_flag[0, qualified_sigma_candidates] = 1
# det_flag[0, det_id] = 1
# # recg start
# print
# "liushanshan many to one det_id = {}".format(det_id)
# print
# "liushanshan many to one gt_id = {}".format(qualified_sigma_candidates)
# pred_str_cur = global_pred_str[idy][det_id]
# gt_len = len(qualified_sigma_candidates[0])
# for idx in range(gt_len):
# ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
# if ele_gt_id not in global_gt_str[idy]:
# continue
# gt_str_cur = global_gt_str[idy][ele_gt_id]
# print
# "liushanshan many to one gt_str_cur = {}".format(gt_str_cur)
# print
# "liushanshan many to one pred_str_cur = {}".format(pred_str_cur)
# if pred_str_cur == gt_str_cur:
# hit_str_num += 1
# break
# else:
# if pred_str_cur.lower() == gt_str_cur.lower():
# hit_str_num += 1
# break
# # recg end
# elif (np.sum(local_tau_table[qualified_sigma_candidates, det_id]) >= tp):
# det_flag[0, det_id] = 1
# gt_flag[0, qualified_sigma_candidates] = 1
# # recg start
# print
# "liushanshan many to one det_id = {}".format(det_id)
# print
# "liushanshan many to one gt_id = {}".format(qualified_sigma_candidates)
# pred_str_cur = global_pred_str[idy][det_id]
# gt_len = len(qualified_sigma_candidates[0])
# for idx in range(gt_len):
# ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
# if not global_gt_str[idy].has_key(ele_gt_id):
# continue
# gt_str_cur = global_gt_str[idy][ele_gt_id]
# print
# "liushanshan many to one gt_str_cur = {}".format(gt_str_cur)
# print
# "liushanshan many to one pred_str_cur = {}".format(pred_str_cur)
# if pred_str_cur == gt_str_cur:
# hit_str_num += 1
# break
# else:
# if pred_str_cur.lower() == gt_str_cur.lower():
# hit_str_num += 1
# break
# else:
# print
# 'no match'
# # recg end
#
# global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
# global_accumulative_precision = global_accumulative_precision + fsc_k
#
# local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
# local_accumulative_precision = local_accumulative_precision + fsc_k
# return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
#
# for idx in range(len(global_sigma)):
# local_sigma_table = np.array(global_sigma[idx])
# local_tau_table = np.array(global_tau[idx])
#
# num_gt = local_sigma_table.shape[0]
# num_det = local_sigma_table.shape[1]
#
# total_num_gt = total_num_gt + num_gt
# total_num_det = total_num_det + num_det
#
# local_accumulative_recall = 0
# local_accumulative_precision = 0
# gt_flag = np.zeros((1, num_gt))
# det_flag = np.zeros((1, num_det))
#
# #######first check for one-to-one case##########
# local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
# gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
# local_accumulative_recall, local_accumulative_precision,
# global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idx)
#
# hit_str_count += hit_str_num
# #######then check for one-to-many case##########
# local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
# gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
# local_accumulative_recall, local_accumulative_precision,
# global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idx)
# hit_str_count += hit_str_num
# #######then check for many-to-one case##########
# local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
# gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
# local_accumulative_recall, local_accumulative_precision,
# global_accumulative_recall, global_accumulative_precision,
# gt_flag, det_flag, idx)
# try:
# recall = global_accumulative_recall / total_num_gt
# except ZeroDivisionError:
# recall = 0
#
# try:
# precision = global_accumulative_precision / total_num_det
# except ZeroDivisionError:
# precision = 0
#
# try:
# f_score = 2 * precision * recall / (precision + recall)
# except ZeroDivisionError:
# f_score = 0
#
# try:
# seqerr = 1 - float(hit_str_count) / global_accumulative_recall
# except ZeroDivisionError:
# seqerr = 1
#
# try:
# recall_e2e = float(hit_str_count) / total_num_gt
# except ZeroDivisionError:
# recall_e2e = 0
#
# try:
# precision_e2e = float(hit_str_count) / total_num_det
# except ZeroDivisionError:
# precision_e2e = 0
#
# try:
# f_score_e2e = 2 * precision_e2e * recall_e2e / (precision_e2e + recall_e2e)
# except ZeroDivisionError:
# f_score_e2e = 0
#
# final = {
# 'total_num_gt': total_num_gt,
# 'total_num_det': total_num_det,
# 'global_accumulative_recall': global_accumulative_recall,
# 'hit_str_count': hit_str_count,
# 'recall': recall,
# 'precision': precision,
# 'f_score': f_score,
# 'seqerr': seqerr,
# 'recall_e2e': recall_e2e,
# 'precision_e2e': precision_e2e,
# 'f_score_e2e': f_score_e2e
# }
# return final
a = [
1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634, 620,
1659, 620, 1654, 681, 1631, 680, 1618, 681, 1606, 681, 1594, 681, 1584, 682,
1573, 685, 1542, 694
]
gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}]
pred_dict = [{
'points': np.array(a),
'text': 'ccc'
}, {
'points': np.array(a),
'text': 'ccf'
}]
result = []
result.append(get_socre(gt_dict, gt_dict))
a = combine_results(result)
print(a)

View File

@ -1,3 +1,16 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division

View File

@ -1,6 +1,16 @@
"""
Algorithms for computing the skeleton of a binary image
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from scipy import ndimage as ndi

View File

@ -1,147 +1,21 @@
import os
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import time
def visualize_e2e_result(im_fn, poly_list, seq_strs, src_im):
"""
"""
result_path = './out'
im_basename = os.path.basename(im_fn)
im_prefix = im_basename[:im_basename.rfind('.')]
vis_det_img = src_im.copy()
valid_set = 'partvgg'
gt_dir = "/Users/hongyongjie/Downloads/part_vgg_synth/train"
text_path = os.path.join(gt_dir, im_prefix + '.txt')
fid = open(text_path, 'r')
lines = [line.strip() for line in fid.readlines()]
for line in lines:
if valid_set == 'partvgg':
tokens = line.strip().split('\t')[0].split(',')
# tokens = line.strip().split(',')
coords = tokens[:]
coords = list(map(float, coords))
gt_poly = np.array(coords).reshape(1, 4, 2)
elif valid_set == 'totaltext':
tokens = line.strip().split('\t')[0].split(',')
coords = tokens[:]
coords_len = len(coords) / 2
coords = list(map(float, coords))
gt_poly = np.array(coords).reshape(1, coords_len, 2)
cv2.polylines(
vis_det_img,
np.array(gt_poly).astype(np.int32),
isClosed=True,
color=(255, 0, 0),
thickness=2)
for detected_poly, recognized_str in zip(poly_list, seq_strs):
cv2.polylines(
vis_det_img,
np.array(detected_poly[np.newaxis, ...]).astype(np.int32),
isClosed=True,
color=(0, 0, 255),
thickness=2)
cv2.putText(
vis_det_img,
recognized_str,
org=(int(detected_poly[0, 0]), int(detected_poly[0, 1])),
fontFace=cv2.FONT_HERSHEY_COMPLEX,
fontScale=0.7,
color=(0, 255, 0),
thickness=1)
if not os.path.exists(result_path):
os.makedirs(result_path)
cv2.imwrite("{}/{}_detection.jpg".format(result_path, im_prefix),
vis_det_img)
def visualization_output(src_image,
f_tcl,
f_chars,
output_dir,
image_prefix=None):
"""
"""
# restore BGR image, CHW -> HWC
im_mean = [0.485, 0.456, 0.406]
im_std = [0.229, 0.224, 0.225]
im_mean = np.array(im_mean).reshape((3, 1, 1))
im_std = np.array(im_std).reshape((3, 1, 1))
src_image *= im_std
src_image += im_mean
src_image = src_image.transpose([1, 2, 0])
src_image = src_image[:, :, ::-1] * 255 # BGR -> RGB
H, W, _ = src_image.shape
file_prefix = image_prefix if image_prefix is not None else str(
int(time.time() * 1000))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# visualization f_tcl
tcl_file_name = os.path.join(output_dir, file_prefix + '_0_tcl.jpg')
vis_tcl_img = src_image.copy()
f_tcl_resized = cv2.resize(f_tcl, dsize=(W, H))
vis_tcl_img[:, :, 1] = f_tcl_resized * 255
cv2.imwrite(tcl_file_name, vis_tcl_img)
# visualization char maps
vis_char_img = src_image.copy()
# CHW -> HWC
char_file_name = os.path.join(output_dir, file_prefix + '_1_chars.jpg')
f_chars = np.argmax(f_chars, axis=2)[:, :, np.newaxis].astype('float32')
f_chars[f_chars < 95] = 1.0
f_chars[f_chars == 95] = 0.0
f_chars_resized = cv2.resize(f_chars, dsize=(W, H))
vis_char_img[:, :, 1] = f_chars_resized * 255
cv2.imwrite(char_file_name, vis_char_img)
def visualize_point_result(im_fn, point_list, point_pair_list, src_im, gt_dir,
result_path):
"""
"""
im_basename = os.path.basename(im_fn)
im_prefix = im_basename[:im_basename.rfind('.')]
vis_det_img = src_im.copy()
# draw gt bbox on the image.
text_path = os.path.join(gt_dir, im_prefix + '.txt')
fid = open(text_path, 'r')
lines = [line.strip() for line in fid.readlines()]
for line in lines:
tokens = line.strip().split('\t')
coords = tokens[0].split(',')
coords_len = len(coords)
coords = list(map(float, coords))
gt_poly = np.array(coords).reshape(1, coords_len / 2, 2)
cv2.polylines(
vis_det_img,
np.array(gt_poly).astype(np.int32),
isClosed=True,
color=(255, 255, 255),
thickness=1)
for point, point_pair in zip(point_list, point_pair_list):
cv2.line(
vis_det_img,
tuple(point_pair[0]),
tuple(point_pair[1]), (0, 255, 255),
thickness=1)
cv2.circle(vis_det_img, tuple(point), 2, (0, 0, 255))
cv2.circle(vis_det_img, tuple(point_pair[0]), 2, (255, 0, 0))
cv2.circle(vis_det_img, tuple(point_pair[1]), 2, (0, 255, 0))
if not os.path.exists(result_path):
os.makedirs(result_path)
cv2.imwrite("{}/{}_border_points.jpg".format(result_path, im_prefix),
vis_det_img)
def resize_image(im, max_side_len=512):
"""
resize image to a size multiple of max_stride which is required by the network
@ -295,49 +169,3 @@ def norm2(x, axis=None):
def cos(p1, p2):
return (p1 * p2).sum() / (norm2(p1) * norm2(p2))
def generate_direction_info(image_fn,
H,
W,
ratio_h,
ratio_w,
max_length=640,
out_scale=4,
gt_dir=None):
"""
"""
im_basename = os.path.basename(image_fn)
im_prefix = im_basename[:im_basename.rfind('.')]
instance_direction_map = np.zeros(shape=[H // out_scale, W // out_scale, 3])
if gt_dir is None:
gt_dir = '/home/vis/huangzuming/data/SYNTH_DATA/part_vgg_synth_icdar/processed/val/poly'
# get gt label map
text_path = os.path.join(gt_dir, im_prefix + '.txt')
fid = open(text_path, 'r')
lines = [line.strip() for line in fid.readlines()]
for label_idx, line in enumerate(lines, start=1):
coords, txt = line.strip().split('\t')
if txt == '###':
continue
tokens = coords.strip().split(',')
coords = list(map(float, tokens))
poly = np.array(coords).reshape(4, 2) * np.array(
[ratio_w, ratio_h]).reshape(1, 2) / out_scale
mid_idx = poly.shape[0] // 2
direct_vector = (
(poly[mid_idx] + poly[mid_idx - 1]) - (poly[0] + poly[-1])) / 2.0
direct_vector /= len(txt)
# l2_distance = norm2(direct_vector)
# avg_char_distance = l2_distance / len(txt)
avg_char_distance = 1.0
direct_label = (direct_vector[0], direct_vector[1], avg_char_distance)
cv2.fillPoly(instance_direction_map,
poly.round().astype(np.int32)[np.newaxis, :, :],
direct_label)
instance_direction_map = instance_direction_map.transpose([2, 0, 1])
return instance_direction_map[:2, ...]

View File

@ -44,7 +44,6 @@ class ArgsParser(ArgumentParser):
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
args.config = '/Users/hongyongjie/project/PaddleOCR/configs/e2e/e2e_r50_vd_pg.yml'
assert args.config is not None, \
"Please specify --config=configure_file_path."
args.opt = self._parse_opt(args.opt)