ADD PGNet_v2
parent 1f76f449db
commit bb49e1a53f
@@ -37,6 +37,7 @@ class ClsLabelEncode(object):
 class E2ELabelEncode(object):
     def __init__(self, label_list, **kwargs):
         self.label_list = label_list
+        self.max_len = 50
 
     def __call__(self, data):
         text_label_index_list, temp_text = [], []
@@ -47,7 +48,7 @@ class E2ELabelEncode(object):
             for c_ in text:
                 if c_ in self.label_list:
                     temp_text.append(self.label_list.index(c_))
-            temp_text = temp_text + [36] * (50 - len(temp_text))
+            temp_text = temp_text + [36] * (self.max_len - len(temp_text))
             text_label_index_list.append(temp_text)
         data['strs'] = np.array(text_label_index_list)
         return data
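
A standalone sketch of the padding rule this hunk touches (the label_list below is a made-up 36-character charset; 36 is the pad index the code hard-wires, one past the last valid label index):

    import numpy as np

    label_list = list("0123456789abcdefghijklmnopqrstuvwxyz")  # hypothetical charset
    max_len = 50  # now held in self.max_len instead of the literal 50

    temp_text = [label_list.index(c) for c in "pgnet" if c in label_list]
    temp_text = temp_text + [36] * (max_len - len(temp_text))  # pad with index 36
    print(np.array([temp_text]).shape)  # (1, 50)
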
@@ -32,16 +32,6 @@ class E2EMetric(object):
         self.reset()
 
     def __call__(self, preds, batch, **kwargs):
-        '''
-        batch: a list produced by dataloaders.
-            image: np.ndarray of shape (N, C, H, W).
-            ratio_list: np.ndarray of shape (N, 2)
-            polygons: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
-            ignore_tags: np.ndarray of shape (N, K), indicates whether a region is ignorable or not.
-        preds: a list of dict produced by post process
-            points: np.ndarray of shape (N, K, 4, 2), the polygons of objective regions.
-        '''
-
         gt_polyons_batch = batch[2]
         temp_gt_strs_batch = batch[3]
         ignore_tags_batch = batch[4]
@@ -72,13 +62,6 @@ class E2EMetric(object):
         self.results.append(result)
 
     def get_metric(self):
-        """
-        return metrics {
-            'precision': 0,
-            'recall': 0,
-            'hmean': 0
-        }
-        """
         metircs = combine_results(self.results)
         self.reset()
         return metircs
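
The removed docstring was the only statement of the expected batch layout; a hedged sketch of the accumulate-then-combine lifecycle the class keeps (the loader and constructor arguments are assumed scaffolding, not code from this commit):

    # Assumed evaluation harness around E2EMetric:
    metric = E2EMetric(label_list)       # hypothetical constructor args
    for preds, batch in eval_loader:     # batch[2]: polygons, batch[3]: strs, batch[4]: ignore_tags
        metric(preds, batch)             # __call__ appends one per-image result to self.results
    metrics = metric.get_metric()        # runs combine_results(self.results), then reset()
    print(metrics)
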
@@ -106,172 +106,212 @@ class DeConvBNLayer(nn.Layer):
         return x
 
 
-class FPN_Up_Fusion(nn.Layer):
-    def __init__(self, in_channels):
-        super(FPN_Up_Fusion, self).__init__()
-        in_channels = in_channels[::-1]
-        out_channels = [256, 256, 192, 192, 128]
-
-        self.h0_conv = ConvBNLayer(
-            in_channels[0], out_channels[0], 1, 1, act=None, name='conv_h0')
-        self.h1_conv = ConvBNLayer(
-            in_channels[1], out_channels[1], 1, 1, act=None, name='conv_h1')
-        self.h2_conv = ConvBNLayer(
-            in_channels[2], out_channels[2], 1, 1, act=None, name='conv_h2')
-        self.h3_conv = ConvBNLayer(
-            in_channels[3], out_channels[3], 1, 1, act=None, name='conv_h3')
-        self.h4_conv = ConvBNLayer(
-            in_channels[4], out_channels[4], 1, 1, act=None, name='conv_h4')
+class PGFPN(nn.Layer):
+    def __init__(self, in_channels, **kwargs):
+        super(PGFPN, self).__init__()
+        num_inputs = [2048, 2048, 1024, 512, 256]
+        num_outputs = [256, 256, 192, 192, 128]
+        self.out_channels = 128
+        # print(in_channels)
+        self.conv_bn_layer_1 = ConvBNLayer(
+            in_channels=3,
+            out_channels=32,
+            kernel_size=3,
+            stride=1,
+            act=None,
+            name='FPN_d1')
+        self.conv_bn_layer_2 = ConvBNLayer(
+            in_channels=64,
+            out_channels=64,
+            kernel_size=3,
+            stride=1,
+            act=None,
+            name='FPN_d2')
+        self.conv_bn_layer_3 = ConvBNLayer(
+            in_channels=256,
+            out_channels=128,
+            kernel_size=3,
+            stride=1,
+            act=None,
+            name='FPN_d3')
+        self.conv_bn_layer_4 = ConvBNLayer(
+            in_channels=32,
+            out_channels=64,
+            kernel_size=3,
+            stride=2,
+            act=None,
+            name='FPN_d4')
+        self.conv_bn_layer_5 = ConvBNLayer(
+            in_channels=64,
+            out_channels=64,
+            kernel_size=3,
+            stride=1,
+            act='relu',
+            name='FPN_d5')
+        self.conv_bn_layer_6 = ConvBNLayer(
+            in_channels=64,
+            out_channels=128,
+            kernel_size=3,
+            stride=2,
+            act=None,
+            name='FPN_d6')
+        self.conv_bn_layer_7 = ConvBNLayer(
+            in_channels=128,
+            out_channels=128,
+            kernel_size=3,
+            stride=1,
+            act='relu',
+            name='FPN_d7')
+        self.conv_bn_layer_8 = ConvBNLayer(
+            in_channels=128,
+            out_channels=128,
+            kernel_size=1,
+            stride=1,
+            act=None,
+            name='FPN_d8')
+
+        self.conv_h0 = ConvBNLayer(
+            in_channels=num_inputs[0],
+            out_channels=num_outputs[0],
+            kernel_size=1,
+            stride=1,
+            act=None,
+            name="conv_h{}".format(0))
+        self.conv_h1 = ConvBNLayer(
+            in_channels=num_inputs[1],
+            out_channels=num_outputs[1],
+            kernel_size=1,
+            stride=1,
+            act=None,
+            name="conv_h{}".format(1))
+        self.conv_h2 = ConvBNLayer(
+            in_channels=num_inputs[2],
+            out_channels=num_outputs[2],
+            kernel_size=1,
+            stride=1,
+            act=None,
+            name="conv_h{}".format(2))
+        self.conv_h3 = ConvBNLayer(
+            in_channels=num_inputs[3],
+            out_channels=num_outputs[3],
+            kernel_size=1,
+            stride=1,
+            act=None,
+            name="conv_h{}".format(3))
+        self.conv_h4 = ConvBNLayer(
+            in_channels=num_inputs[4],
+            out_channels=num_outputs[4],
+            kernel_size=1,
+            stride=1,
+            act=None,
+            name="conv_h{}".format(4))
 
         self.dconv0 = DeConvBNLayer(
-            in_channels=out_channels[0],
-            out_channels=out_channels[1],
+            in_channels=num_outputs[0],
+            out_channels=num_outputs[0 + 1],
             name="dconv_{}".format(0))
         self.dconv1 = DeConvBNLayer(
-            in_channels=out_channels[1],
-            out_channels=out_channels[2],
+            in_channels=num_outputs[1],
+            out_channels=num_outputs[1 + 1],
             act=None,
             name="dconv_{}".format(1))
         self.dconv2 = DeConvBNLayer(
-            in_channels=out_channels[2],
-            out_channels=out_channels[3],
+            in_channels=num_outputs[2],
+            out_channels=num_outputs[2 + 1],
             act=None,
             name="dconv_{}".format(2))
         self.dconv3 = DeConvBNLayer(
-            in_channels=out_channels[3],
-            out_channels=out_channels[4],
+            in_channels=num_outputs[3],
+            out_channels=num_outputs[3 + 1],
             act=None,
             name="dconv_{}".format(3))
         self.conv_g1 = ConvBNLayer(
-            in_channels=out_channels[1],
-            out_channels=out_channels[1],
+            in_channels=num_outputs[1],
+            out_channels=num_outputs[1],
             kernel_size=3,
             stride=1,
             act='relu',
             name="conv_g{}".format(1))
         self.conv_g2 = ConvBNLayer(
-            in_channels=out_channels[2],
-            out_channels=out_channels[2],
+            in_channels=num_outputs[2],
+            out_channels=num_outputs[2],
             kernel_size=3,
             stride=1,
             act='relu',
             name="conv_g{}".format(2))
         self.conv_g3 = ConvBNLayer(
-            in_channels=out_channels[3],
-            out_channels=out_channels[3],
+            in_channels=num_outputs[3],
+            out_channels=num_outputs[3],
             kernel_size=3,
             stride=1,
             act='relu',
             name="conv_g{}".format(3))
         self.conv_g4 = ConvBNLayer(
-            in_channels=out_channels[4],
-            out_channels=out_channels[4],
+            in_channels=num_outputs[4],
+            out_channels=num_outputs[4],
             kernel_size=3,
             stride=1,
             act='relu',
             name="conv_g{}".format(4))
         self.convf = ConvBNLayer(
-            in_channels=out_channels[4],
-            out_channels=out_channels[4],
+            in_channels=num_outputs[4],
+            out_channels=num_outputs[4],
             kernel_size=1,
             stride=1,
             act=None,
             name="conv_f{}".format(4))
 
-    def _add_relu(self, x1, x2):
-        x = paddle.add(x=x1, y=x2)
-        x = F.relu(x)
-        return x
-
     def forward(self, x):
-        f = x[2:][::-1]
-        h0 = self.h0_conv(f[0])
-        h1 = self.h1_conv(f[1])
-        h2 = self.h2_conv(f[2])
-        h3 = self.h3_conv(f[3])
-        h4 = self.h4_conv(f[4])
-
-        g0 = self.dconv0(h0)
-        g1 = self.dconv2(self.conv_g2(self._add_relu(g0, h1)))
-        g2 = self.dconv2(self.conv_g2(self._add_relu(g1, h2)))
-        g3 = self.dconv3(self.conv_g2(self._add_relu(g2, h3)))
-        g4 = self.dconv4(self.conv_g2(self._add_relu(g3, h4)))
-        return g4
-
-
-class FPN_Down_Fusion(nn.Layer):
-    def __init__(self, in_channels):
-        super(FPN_Down_Fusion, self).__init__()
-        out_channels = [32, 64, 128]
-
-        self.h0_conv = ConvBNLayer(
-            in_channels[0], out_channels[0], 3, 1, act=None, name='FPN_d1')
-        self.h1_conv = ConvBNLayer(
-            in_channels[1], out_channels[1], 3, 1, act=None, name='FPN_d2')
-        self.h2_conv = ConvBNLayer(
-            in_channels[2], out_channels[2], 3, 1, act=None, name='FPN_d3')
-
-        self.g0_conv = ConvBNLayer(
-            out_channels[0], out_channels[1], 3, 2, act=None, name='FPN_d4')
-
-        self.g1_conv = nn.Sequential(
-            ConvBNLayer(
-                out_channels[1],
-                out_channels[1],
-                3,
-                1,
-                act='relu',
-                name='FPN_d5'),
-            ConvBNLayer(
-                out_channels[1], out_channels[2], 3, 2, act=None,
-                name='FPN_d6'))
-
-        self.g2_conv = nn.Sequential(
-            ConvBNLayer(
-                out_channels[2],
-                out_channels[2],
-                3,
-                1,
-                act='relu',
-                name='FPN_d7'),
-            ConvBNLayer(
-                out_channels[2], out_channels[2], 1, 1, act=None,
-                name='FPN_d8'))
-
-    def forward(self, x):
-        f = x[:3]
-        h0 = self.h0_conv(f[0])
-        h1 = self.h1_conv(f[1])
-        h2 = self.h2_conv(f[2])
-        g0 = self.g0_conv(h0)
-        g1 = paddle.add(x=g0, y=h1)
-        g1 = F.relu(g1)
-        g1 = self.g1_conv(g1)
-        g2 = paddle.add(x=g1, y=h2)
-        g2 = F.relu(g2)
-        g2 = self.g2_conv(g2)
-        return g2
-
-
-class PGFPN(nn.Layer):
-    def __init__(self, in_channels, with_cab=False, **kwargs):
-        super(PGFPN, self).__init__()
-        self.in_channels = in_channels
-        self.with_cab = with_cab
-        self.FPN_Down_Fusion = FPN_Down_Fusion(self.in_channels)
-        self.FPN_Up_Fusion = FPN_Up_Fusion(self.in_channels)
-        self.out_channels = 128
-
-    def forward(self, x):
-        # down fpn
-        f_down = self.FPN_Down_Fusion(x)
-
-        # up fpn
-        f_up = self.FPN_Up_Fusion(x)
-
-        # fusion
-        f_common = paddle.add(x=f_down, y=f_up)
+        c0, c1, c2, c3, c4, c5, c6 = x
+        # FPN_Down_Fusion
+        f = [c0, c1, c2]
+        g = [None, None, None]
+        h = [None, None, None]
+        h[0] = self.conv_bn_layer_1(f[0])
+        h[1] = self.conv_bn_layer_2(f[1])
+        h[2] = self.conv_bn_layer_3(f[2])
+
+        g[0] = self.conv_bn_layer_4(h[0])
+        g[1] = paddle.add(g[0], h[1])
+        g[1] = F.relu(g[1])
+        g[1] = self.conv_bn_layer_5(g[1])
+        g[1] = self.conv_bn_layer_6(g[1])
+
+        g[2] = paddle.add(g[1], h[2])
+        g[2] = F.relu(g[2])
+        g[2] = self.conv_bn_layer_7(g[2])
+        f_down = self.conv_bn_layer_8(g[2])
+
+        # FPN UP Fusion
+        f1 = [c6, c5, c4, c3, c2]
+        g = [None, None, None, None, None]
+        h = [None, None, None, None, None]
+        h[0] = self.conv_h0(f1[0])
+        h[1] = self.conv_h1(f1[1])
+        h[2] = self.conv_h2(f1[2])
+        h[3] = self.conv_h3(f1[3])
+        h[4] = self.conv_h4(f1[4])
+
+        g[0] = self.dconv0(h[0])
+        g[1] = paddle.add(g[0], h[1])
+        g[1] = F.relu(g[1])
+        g[1] = self.conv_g1(g[1])
+        g[1] = self.dconv1(g[1])
+
+        g[2] = paddle.add(g[1], h[2])
+        g[2] = F.relu(g[2])
+        g[2] = self.conv_g2(g[2])
+        g[2] = self.dconv2(g[2])
+
+        g[3] = paddle.add(g[2], h[3])
+        g[3] = F.relu(g[3])
+        g[3] = self.conv_g3(g[3])
+        g[3] = self.dconv3(g[3])
+
+        g[4] = paddle.add(x=g[3], y=h[4])
+        g[4] = F.relu(g[4])
+        g[4] = self.conv_g4(g[4])
+        f_up = self.convf(g[4])
+        f_common = paddle.add(f_down, f_up)
         f_common = F.relu(f_common)
-
         return f_common
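
Every fusion step in the new PGFPN.forward repeats one pattern: add the carried feature to the lateral feature, apply ReLU, then convolve. A minimal runnable sketch of that step (shapes are made up for illustration; assumes PaddlePaddle is installed):

    import paddle
    import paddle.nn.functional as F

    carried = paddle.rand([1, 64, 32, 32])  # g[i-1], carried across scales
    lateral = paddle.rand([1, 64, 32, 32])  # h[i], the lateral branch
    fused = F.relu(paddle.add(carried, lateral))  # what _add_relu used to do inline
    print(fused.shape)  # [1, 64, 32, 32]
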
@@ -1,9 +1,18 @@
-from os import listdir
-import os, sys
-from scipy import io
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import numpy as np
 from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
-from tqdm import tqdm
 
 try:  # python2
     range = xrange
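
The surviving `try: range = xrange` shim targets python2; a tighter python3-safe equivalent, should the module ever drop py2 (my variant, not the commit's):

    try:  # python2
        range = xrange  # noqa: F821
    except NameError:  # python3: xrange is gone, the builtin range already behaves this way
        pass
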
@@ -862,16 +871,3 @@ def combine_results(all_data):
         'f_score_e2e': f_score_e2e
     }
     return final
-
-
-# a = [1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634, 620, 1659, 620, 1654, 681, 1631, 680, 1618,
-#      681, 1606, 681, 1594, 681, 1584, 682, 1573, 685, 1542, 694]
-# gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}]
-# pred_dict = [{'points': np.array(a), 'text': 'ccc'},
-#              {'points': np.array(a), 'text': 'ccf'}]
-# result = []
-# for i in range(2):
-#     result.append(get_socre(gt_dict, pred_dict))
-# print(111)
-# a = combine_results(result)
-# print(a)
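
The block removed above doubled as the module's only usage example; a cleaned-up version of the same smoke test (values are the illustrative ones from the removed comments; `get_socre` is the module's own, misspelled, entry point):

    import numpy as np

    a = [1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622,
         1634, 620, 1659, 620, 1654, 681, 1631, 680, 1618, 681, 1606, 681,
         1594, 681, 1584, 682, 1573, 685, 1542, 694]
    gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}]
    pred_dict = [{'points': np.array(a), 'text': 'ccc'},
                 {'points': np.array(a), 'text': 'ccf'}]
    results = [get_socre(gt_dict, pred_dict) for _ in range(2)]  # one dict per image
    print(combine_results(results))  # precision/recall/f_score plus *_e2e variants
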
@@ -1,6 +1,18 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import numpy as np
 from shapely.geometry import Polygon
-#import Polygon
 """
 :param det_x: [1, N] Xs of detection's vertices
 :param det_y: [1, N] Ys of detection's vertices
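
For orientation, a shapely-based restatement of what `area` and `area_of_intersection` compute; this is an assumption about their semantics drawn from the docstrings and call sites elsewhere in this diff, not the module's actual implementation:

    from shapely.geometry import Polygon

    def area(xs, ys):
        # Polygon area from parallel X and Y coordinate lists.
        return Polygon(list(zip(xs, ys))).area

    def area_of_intersection(det_x, det_y, gt_x, gt_y):
        # Overlap area between detection and ground-truth polygons;
        # buffer(0) repairs self-intersecting rings before intersecting.
        det = Polygon(list(zip(det_x, det_y))).buffer(0)
        gt = Polygon(list(zip(gt_x, gt_y))).buffer(0)
        return det.intersection(gt).area
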
@@ -1,881 +0,0 @@
-from os import listdir
-import os, sys
-from scipy import io
-import numpy as np
-from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
-from tqdm import tqdm
-
-try:  # python2
-    range = xrange
-except Exception:
-    # python3
-    range = range
-"""
-Input format: y0,x0, ..... yn,xn. Each detection is separated by the end of line token ('\n')'
-"""
-
-# if len(sys.argv) != 4:
-#     print('\n usage: test.py pred_dir gt_dir savefile')
-#     sys.exit()
-global_tp = 0
-global_fp = 0
-global_fn = 0
-
-tr = 0.7
-tp = 0.6
-fsc_k = 0.8
-k = 2
-
-
-def get_socre(gt_dict, pred_dict):
-    # allInputs = listdir(input_dir)
-    allInputs = 1
-    global_pred_str = []
-    global_gt_str = []
-    global_sigma = []
-    global_tau = []
-
-    def input_reading_mod(pred_dict, input):
-        """This helper reads input from txt files"""
-        det = []
-        n = len(pred_dict)
-        for i in range(n):
-            points = pred_dict[i]['points']
-            text = pred_dict[i]['text']
-            # for i in range(len(points)):
-            point = ",".join(map(str, points.reshape(-1, )))
-            det.append([point, text])
-        return det
-
-    def gt_reading_mod(gt_dict, gt_id):
-        """This helper reads groundtruths from mat files"""
-        # gt_id = gt_id.split('.')[0]
-        gt = []
-        n = len(gt_dict)
-        for i in range(n):
-            points = gt_dict[i]['points'].tolist()
-            h = len(points)
-            text = gt_dict[i]['text']
-            xx = [
-                np.array(
-                    ['x:'], dtype='<U2'), 0, np.array(
-                        ['y:'], dtype='<U2'), 0, np.array(
-                            ['#'], dtype='<U1'), np.array(
-                                ['#'], dtype='<U1')
-            ]
-            t_x, t_y = [], []
-            for j in range(h):
-                t_x.append(points[j][0])
-                t_y.append(points[j][1])
-            xx[1] = np.array([t_x], dtype='int16')
-            xx[3] = np.array([t_y], dtype='int16')
-            if text != "":
-                xx[4] = np.array([text], dtype='U{}'.format(len(text)))
-                xx[5] = np.array(['c'], dtype='<U1')
-            gt.append(xx)
-        return gt
-
-    def detection_filtering(detections, groundtruths, threshold=0.5):
-        for gt_id, gt in enumerate(groundtruths):
-            print
-            "liushanshan gt[1] = {}".format(gt[1])
-            print
-            "liushanshan gt[2] = {}".format(gt[2])
-            print
-            "liushanshan gt[3] = {}".format(gt[3])
-            print
-            "liushanshan gt[4] = {}".format(gt[4])
-            print
-            "liushanshan gt[5] = {}".format(gt[5])
-            if (gt[5] == '#') and (gt[1].shape[1] > 1):
-                gt_x = list(map(int, np.squeeze(gt[1])))
-                gt_y = list(map(int, np.squeeze(gt[3])))
-                for det_id, detection in enumerate(detections):
-                    detection_orig = detection
-                    detection = [float(x) for x in detection[0].split(',')]
-                    # detection = detection.split(',')
-                    detection = list(map(int, detection))
-                    det_x = detection[0::2]
-                    det_y = detection[1::2]
-                    det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
-                    if det_gt_iou > threshold:
-                        detections[det_id] = []
-
-            detections[:] = [item for item in detections if item != []]
-        return detections
-
-    def sigma_calculation(det_x, det_y, gt_x, gt_y):
-        """
-        sigma = inter_area / gt_area
-        """
-        # print(area_of_intersection(det_x, det_y, gt_x, gt_y))
-        return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
-                         area(gt_x, gt_y)), 2)
-
-    def tau_calculation(det_x, det_y, gt_x, gt_y):
-        """
-        tau = inter_area / det_area
-        """
-        # print "liushanshan det_x {}".format(det_x)
-        # print "liushanshan det_y {}".format(det_y)
-        # print "liushanshan area {}".format(area(det_x, det_y))
-        # print "liushanshan tau = {}".format(np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / area(det_x, det_y)), 2))
-        if area(det_x, det_y) == 0.0:
-            return 0
-        return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
-                         area(det_x, det_y)), 2)
-
-    ##############################Initialization###################################
-    ###############################################################################
-    single_data = {}
-    for input_id in range(allInputs):
-        if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
-                input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
-                input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
-                and (input_id != 'Deteval_result_non_curved.txt'):
-            print(input_id)
-            detections = input_reading_mod(pred_dict, input_id)
-            # print "liushanshan detections = {}".format(detections)
-            groundtruths = gt_reading_mod(gt_dict, input_id)
-            detections = detection_filtering(
-                detections,
-                groundtruths)  # filters detections overlapping with DC area
-            dc_id = []
-            for i in range(len(groundtruths)):
-                if groundtruths[i][5] == '#':
-                    dc_id.append(i)
-            cnt = 0
-            for a in dc_id:
-                num = a - cnt
-                del groundtruths[num]
-                cnt += 1
-
-            local_sigma_table = np.zeros((len(groundtruths), len(detections)))
-            local_tau_table = np.zeros((len(groundtruths), len(detections)))
-            local_pred_str = {}
-            local_gt_str = {}
-
-            for gt_id, gt in enumerate(groundtruths):
-                if len(detections) > 0:
-                    for det_id, detection in enumerate(detections):
-                        detection_orig = detection
-                        detection = [float(x) for x in detection[0].split(',')]
-                        detection = list(map(int, detection))
-                        pred_seq_str = detection_orig[1].strip()
-                        det_x = detection[0::2]
-                        det_y = detection[1::2]
-                        gt_x = list(map(int, np.squeeze(gt[1])))
-                        gt_y = list(map(int, np.squeeze(gt[3])))
-                        gt_seq_str = str(gt[4].tolist()[0])
-
-                        local_sigma_table[gt_id, det_id] = sigma_calculation(
-                            det_x, det_y, gt_x, gt_y)
-                        local_tau_table[gt_id, det_id] = tau_calculation(
-                            det_x, det_y, gt_x, gt_y)
-                        local_pred_str[det_id] = pred_seq_str
-                        local_gt_str[gt_id] = gt_seq_str
-
-            global_sigma.append(local_sigma_table)
-            global_tau.append(local_tau_table)
-            global_pred_str.append(local_pred_str)
-            global_gt_str.append(local_gt_str)
-    print
-    "liushanshan global_pred_str = {}".format(global_pred_str)
-    print
-    "liushanshan global_gt_str = {}".format(global_gt_str)
-    single_data['sigma'] = global_sigma
-    single_data['global_tau'] = global_tau
-    single_data['global_pred_str'] = global_pred_str
-    single_data['global_gt_str'] = global_gt_str
-    return single_data
-
-
-def combine_results(all_data):
-    global_sigma, global_tau, global_pred_str, global_gt_str = [], [], [], []
-    for data in all_data:
-        global_sigma.append(data['sigma'])
-        global_tau.append(data['global_tau'])
-        global_pred_str.append(data['global_pred_str'])
-        global_gt_str.append(data['global_gt_str'])
-    global_accumulative_recall = 0
-    global_accumulative_precision = 0
-    total_num_gt = 0
-    total_num_det = 0
-    hit_str_count = 0
-    hit_count = 0
-
-    def one_to_one(local_sigma_table, local_tau_table,
-                   local_accumulative_recall, local_accumulative_precision,
-                   global_accumulative_recall, global_accumulative_precision,
-                   gt_flag, det_flag, idy):
-        hit_str_num = 0
-        for gt_id in range(num_gt):
-            gt_matching_qualified_sigma_candidates = np.where(
-                local_sigma_table[gt_id, :] > tr)
-            gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
-                0].shape[0]
-            gt_matching_qualified_tau_candidates = np.where(
-                local_tau_table[gt_id, :] > tp)
-            gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
-                0].shape[0]
-
-            det_matching_qualified_sigma_candidates = np.where(
-                local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
-                > tr)
-            det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
-                0].shape[0]
-            det_matching_qualified_tau_candidates = np.where(
-                local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
-                tp)
-            det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
-                0].shape[0]
-
-            if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
-                    (det_matching_num_qualified_sigma_candidates == 1) and (
-                        det_matching_num_qualified_tau_candidates == 1):
-                global_accumulative_recall = global_accumulative_recall + 1.0
-                global_accumulative_precision = global_accumulative_precision + 1.0
-                local_accumulative_recall = local_accumulative_recall + 1.0
-                local_accumulative_precision = local_accumulative_precision + 1.0
-
-                gt_flag[0, gt_id] = 1
-                matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
-                # recg start
-                print
-                "liushanshan one to one det_id = {}".format(matched_det_id)
-                print
-                "liushanshan one to one gt_id = {}".format(gt_id)
-                gt_str_cur = global_gt_str[idy][gt_id]
-                pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
-                    0]]
-                print
-                "liushanshan one to one gt_str_cur = {}".format(gt_str_cur)
-                print
-                "liushanshan one to one pred_str_cur = {}".format(pred_str_cur)
-                if pred_str_cur == gt_str_cur:
-                    hit_str_num += 1
-                else:
-                    if pred_str_cur.lower() == gt_str_cur.lower():
-                        hit_str_num += 1
-                # recg end
-                det_flag[0, matched_det_id] = 1
-        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
-
-    def one_to_many(local_sigma_table, local_tau_table,
-                    local_accumulative_recall, local_accumulative_precision,
-                    global_accumulative_recall, global_accumulative_precision,
-                    gt_flag, det_flag, idy):
-        hit_str_num = 0
-        for gt_id in range(num_gt):
-            # skip the following if the groundtruth was matched
-            if gt_flag[0, gt_id] > 0:
-                continue
-
-            non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
-            num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
-
-            if num_non_zero_in_sigma >= k:
-                ####search for all detections that overlaps with this groundtruth
-                qualified_tau_candidates = np.where((local_tau_table[
-                    gt_id, :] >= tp) & (det_flag[0, :] == 0))
-                num_qualified_tau_candidates = qualified_tau_candidates[
-                    0].shape[0]
-
-                if num_qualified_tau_candidates == 1:
-                    if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
-                            and
-                        (local_sigma_table[gt_id, qualified_tau_candidates] >=
-                         tr)):
-                        # became an one-to-one case
-                        global_accumulative_recall = global_accumulative_recall + 1.0
-                        global_accumulative_precision = global_accumulative_precision + 1.0
-                        local_accumulative_recall = local_accumulative_recall + 1.0
-                        local_accumulative_precision = local_accumulative_precision + 1.0
-
-                        gt_flag[0, gt_id] = 1
-                        det_flag[0, qualified_tau_candidates] = 1
-                        # recg start
-                        print
-                        "liushanshan one to many det_id = {}".format(
-                            qualified_tau_candidates)
-                        print
-                        "liushanshan one to many gt_id = {}".format(gt_id)
-                        gt_str_cur = global_gt_str[idy][gt_id]
-                        pred_str_cur = global_pred_str[idy][
-                            qualified_tau_candidates[0].tolist()[0]]
-                        print
-                        "liushanshan one to many gt_str_cur = {}".format(
-                            gt_str_cur)
-                        print
-                        "liushanshan one to many pred_str_cur = {}".format(
-                            pred_str_cur)
-                        if pred_str_cur == gt_str_cur:
-                            hit_str_num += 1
-                        else:
-                            if pred_str_cur.lower() == gt_str_cur.lower():
-                                hit_str_num += 1
-                        # recg end
-                elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
-                      >= tr):
-                    gt_flag[0, gt_id] = 1
-                    det_flag[0, qualified_tau_candidates] = 1
-                    # recg start
-                    print
-                    "liushanshan one to many det_id = {}".format(
-                        qualified_tau_candidates)
-                    print
-                    "liushanshan one to many gt_id = {}".format(gt_id)
-                    gt_str_cur = global_gt_str[idy][gt_id]
-                    pred_str_cur = global_pred_str[idy][
-                        qualified_tau_candidates[0].tolist()[0]]
-                    print
-                    "liushanshan one to many gt_str_cur = {}".format(gt_str_cur)
-                    print
-                    "liushanshan one to many pred_str_cur = {}".format(
-                        pred_str_cur)
-                    if pred_str_cur == gt_str_cur:
-                        hit_str_num += 1
-                    else:
-                        if pred_str_cur.lower() == gt_str_cur.lower():
-                            hit_str_num += 1
-                    # recg end
-
-                    global_accumulative_recall = global_accumulative_recall + fsc_k
-                    global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
-
-                    local_accumulative_recall = local_accumulative_recall + fsc_k
-                    local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
-
-        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
-
-    def many_to_one(local_sigma_table, local_tau_table,
-                    local_accumulative_recall, local_accumulative_precision,
-                    global_accumulative_recall, global_accumulative_precision,
-                    gt_flag, det_flag, idy):
-        hit_str_num = 0
-        for det_id in range(num_det):
-            # skip the following if the detection was matched
-            if det_flag[0, det_id] > 0:
-                continue
-
-            non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
-            num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
-
-            if num_non_zero_in_tau >= k:
-                ####search for all detections that overlaps with this groundtruth
-                qualified_sigma_candidates = np.where((
-                    local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
-                num_qualified_sigma_candidates = qualified_sigma_candidates[
-                    0].shape[0]
-
-                if num_qualified_sigma_candidates == 1:
-                    if ((local_tau_table[qualified_sigma_candidates, det_id] >=
-                         tp) and
-                        (local_sigma_table[qualified_sigma_candidates, det_id]
-                         >= tr)):
-                        # became an one-to-one case
-                        global_accumulative_recall = global_accumulative_recall + 1.0
-                        global_accumulative_precision = global_accumulative_precision + 1.0
-                        local_accumulative_recall = local_accumulative_recall + 1.0
-                        local_accumulative_precision = local_accumulative_precision + 1.0
-
-                        gt_flag[0, qualified_sigma_candidates] = 1
-                        det_flag[0, det_id] = 1
-                        # recg start
-                        print
-                        "liushanshan many to one det_id = {}".format(det_id)
-                        print
-                        "liushanshan many to one gt_id = {}".format(
-                            qualified_sigma_candidates)
-                        pred_str_cur = global_pred_str[idy][det_id]
-                        gt_len = len(qualified_sigma_candidates[0])
-                        for idx in range(gt_len):
-                            ele_gt_id = qualified_sigma_candidates[0].tolist()[
-                                idx]
-                            if not global_gt_str[idy].has_key(ele_gt_id):
-                                continue
-                            gt_str_cur = global_gt_str[idy][ele_gt_id]
-                            print
-                            "liushanshan many to one gt_str_cur = {}".format(
-                                gt_str_cur)
-                            print
-                            "liushanshan many to one pred_str_cur = {}".format(
-                                pred_str_cur)
-                            if pred_str_cur == gt_str_cur:
-                                hit_str_num += 1
-                                break
-                            else:
-                                if pred_str_cur.lower() == gt_str_cur.lower():
-                                    hit_str_num += 1
-                                    break
-                        # recg end
-                elif (np.sum(local_tau_table[qualified_sigma_candidates,
-                                             det_id]) >= tp):
-                    det_flag[0, det_id] = 1
-                    gt_flag[0, qualified_sigma_candidates] = 1
-                    # recg start
-                    print
-                    "liushanshan many to one det_id = {}".format(det_id)
-                    print
-                    "liushanshan many to one gt_id = {}".format(
-                        qualified_sigma_candidates)
-                    pred_str_cur = global_pred_str[idy][det_id]
-                    gt_len = len(qualified_sigma_candidates[0])
-                    for idx in range(gt_len):
-                        ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
-                        if not global_gt_str[idy].has_key(ele_gt_id):
-                            continue
-                        gt_str_cur = global_gt_str[idy][ele_gt_id]
-                        print
-                        "liushanshan many to one gt_str_cur = {}".format(
-                            gt_str_cur)
-                        print
-                        "liushanshan many to one pred_str_cur = {}".format(
-                            pred_str_cur)
-                        if pred_str_cur == gt_str_cur:
-                            hit_str_num += 1
-                            break
-                        else:
-                            if pred_str_cur.lower() == gt_str_cur.lower():
-                                hit_str_num += 1
-                                break
-                            else:
-                                print
-                                'no match'
-                    # recg end
-
-                    global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
-                    global_accumulative_precision = global_accumulative_precision + fsc_k
-
-                    local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
-                    local_accumulative_precision = local_accumulative_precision + fsc_k
-        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
-
-    for idx in range(len(global_sigma)):
-        # print(allInputs[idx])
-        local_sigma_table = np.array(global_sigma[idx])
-        local_tau_table = global_tau[idx]
-
-        num_gt = local_sigma_table.shape[0]
-        num_det = local_sigma_table.shape[1]
-
-        total_num_gt = total_num_gt + num_gt
-        total_num_det = total_num_det + num_det
-
-        local_accumulative_recall = 0
-        local_accumulative_precision = 0
-        gt_flag = np.zeros((1, num_gt))
-        det_flag = np.zeros((1, num_det))
-
-        #######first check for one-to-one case##########
-        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
-            gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
-                                                        local_accumulative_recall, local_accumulative_precision,
-                                                        global_accumulative_recall, global_accumulative_precision,
-                                                        gt_flag, det_flag, idx)
-
-        hit_str_count += hit_str_num
-        #######then check for one-to-many case##########
-        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
-            gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
-                                                         local_accumulative_recall, local_accumulative_precision,
-                                                         global_accumulative_recall, global_accumulative_precision,
-                                                         gt_flag, det_flag, idx)
-        hit_str_count += hit_str_num
-        #######then check for many-to-one case##########
-        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
-            gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
-                                                         local_accumulative_recall, local_accumulative_precision,
-                                                         global_accumulative_recall, global_accumulative_precision,
-                                                         gt_flag, det_flag, idx)
-
-    try:
-        recall = global_accumulative_recall / total_num_gt
-    except ZeroDivisionError:
-        recall = 0
-
-    try:
-        precision = global_accumulative_precision / total_num_det
-    except ZeroDivisionError:
-        precision = 0
-
-    try:
-        f_score = 2 * precision * recall / (precision + recall)
-    except ZeroDivisionError:
-        f_score = 0
-
-    try:
-        seqerr = 1 - float(hit_str_count) / global_accumulative_recall
-    except ZeroDivisionError:
-        seqerr = 1
-
-    try:
-        recall_e2e = float(hit_str_count) / total_num_gt
-    except ZeroDivisionError:
-        recall_e2e = 0
-
-    try:
-        precision_e2e = float(hit_str_count) / total_num_det
-    except ZeroDivisionError:
-        precision_e2e = 0
-
-    try:
-        f_score_e2e = 2 * precision_e2e * recall_e2e / (
-            precision_e2e + recall_e2e)
-    except ZeroDivisionError:
-        f_score_e2e = 0
-
-    final = {
-        'total_num_gt': total_num_gt,
-        'total_num_det': total_num_det,
-        'global_accumulative_recall': global_accumulative_recall,
-        'hit_str_count': hit_str_count,
-        'recall': recall,
-        'precision': precision,
-        'f_score': f_score,
-        'seqerr': seqerr,
-        'recall_e2e': recall_e2e,
-        'precision_e2e': precision_e2e,
-        'f_score_e2e': f_score_e2e
-    }
-    return final
-
-
-a = [
-    1526, 642, 1565, 629, 1579, 627, 1593, 625, 1607, 623, 1620, 622, 1634, 620,
-    1659, 620, 1654, 681, 1631, 680, 1618, 681, 1606, 681, 1594, 681, 1584, 682,
-    1573, 685, 1542, 694
-]
-gt_dict = [{'points': np.array(a).reshape(-1, 2), 'text': 'MILK'}]
-pred_dict = [{
-    'points': np.array(a),
-    'text': 'ccc'
-}, {
-    'points': np.array(a),
-    'text': 'ccf'
-}]
-result = []
-result.append(get_socre(gt_dict, gt_dict))
-a = combine_results(result)
-print(a)
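
The deleted file's docstrings define the two overlap ratios the whole Deteval protocol runs on: sigma = inter_area / gt_area (recall side) and tau = inter_area / det_area (precision side). A worked example, with shapely standing in for polygon_fast:

    import numpy as np
    from shapely.geometry import Polygon

    det = Polygon([(0, 0), (4, 0), (4, 2), (0, 2)])  # detection, area 8
    gt = Polygon([(2, 0), (6, 0), (6, 2), (2, 2)])   # ground truth, area 8
    inter = det.intersection(gt).area                # 4.0
    sigma = np.round(inter / gt.area, 2)             # 0.5
    tau = np.round(inter / det.area, 2)              # 0.5
    print(sigma, tau)  # both under tr=0.7 and tp=0.6, so no one-to-one match
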
@@ -1,3 +1,16 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Contains various CTC decoders."""
 from __future__ import absolute_import
 from __future__ import division
@ -1,6 +1,16 @@
"""
Algorithms for computing the skeleton of a binary image
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from scipy import ndimage as ndi
@ -1,147 +1,21 @@
import os
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import time
def visualize_e2e_result(im_fn, poly_list, seq_strs, src_im):
    """Draw gt polygons (blue, BGR) and detected polygons (red) with their
    recognized strings (green) on the source image, then save it under ./out."""
    result_path = './out'
    im_basename = os.path.basename(im_fn)
    im_prefix = im_basename[:im_basename.rfind('.')]
    vis_det_img = src_im.copy()
    valid_set = 'partvgg'
    gt_dir = "/Users/hongyongjie/Downloads/part_vgg_synth/train"  # hardcoded debug path
    text_path = os.path.join(gt_dir, im_prefix + '.txt')
    with open(text_path, 'r') as fid:  # close the file instead of leaking the handle
        lines = [line.strip() for line in fid.readlines()]
    for line in lines:
        if valid_set == 'partvgg':
            tokens = line.strip().split('\t')[0].split(',')
            coords = list(map(float, tokens))
            gt_poly = np.array(coords).reshape(1, 4, 2)
        elif valid_set == 'totaltext':
            tokens = line.strip().split('\t')[0].split(',')
            coords = list(map(float, tokens))
            coords_len = len(coords) // 2  # integer division: reshape needs an int
            gt_poly = np.array(coords).reshape(1, coords_len, 2)
        cv2.polylines(
            vis_det_img,
            np.array(gt_poly).astype(np.int32),
            isClosed=True,
            color=(255, 0, 0),
            thickness=2)

    for detected_poly, recognized_str in zip(poly_list, seq_strs):
        cv2.polylines(
            vis_det_img,
            np.array(detected_poly[np.newaxis, ...]).astype(np.int32),
            isClosed=True,
            color=(0, 0, 255),
            thickness=2)
        cv2.putText(
            vis_det_img,
            recognized_str,
            org=(int(detected_poly[0, 0]), int(detected_poly[0, 1])),
            fontFace=cv2.FONT_HERSHEY_COMPLEX,
            fontScale=0.7,
            color=(0, 255, 0),
            thickness=1)

    if not os.path.exists(result_path):
        os.makedirs(result_path)
    cv2.imwrite("{}/{}_detection.jpg".format(result_path, im_prefix),
                vis_det_img)
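Both branches above expect one annotation per line in the form x1,y1,...,xN,yN<TAB>transcript. A hypothetical line and the parse it goes through (sample values made up for illustration):

sample = "1526,642,1659,620,1654,681,1542,694\tMILK"  # hypothetical 4-point line
coords = list(map(float, sample.split('\t')[0].split(',')))
gt_poly = np.array(coords).reshape(1, 4, 2)  # (num_polys, num_points, xy), as cv2.polylines expects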
def visualization_output(src_image,
                         f_tcl,
                         f_chars,
                         output_dir,
                         image_prefix=None):
    """De-normalize the network input, then overlay the TCL map and the
    character map on it and save both images for debugging."""
    # restore BGR image, CHW -> HWC
    im_mean = [0.485, 0.456, 0.406]
    im_std = [0.229, 0.224, 0.225]

    im_mean = np.array(im_mean).reshape((3, 1, 1))
    im_std = np.array(im_std).reshape((3, 1, 1))
    src_image *= im_std
    src_image += im_mean
    src_image = src_image.transpose([1, 2, 0])
    src_image = src_image[:, :, ::-1] * 255  # flip channel order back for OpenCV
    H, W, _ = src_image.shape

    file_prefix = image_prefix if image_prefix is not None else str(
        int(time.time() * 1000))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # visualize f_tcl on the green channel
    tcl_file_name = os.path.join(output_dir, file_prefix + '_0_tcl.jpg')
    vis_tcl_img = src_image.copy()
    f_tcl_resized = cv2.resize(f_tcl, dsize=(W, H))
    vis_tcl_img[:, :, 1] = f_tcl_resized * 255
    cv2.imwrite(tcl_file_name, vis_tcl_img)

    # visualize char maps; 95 appears to be the background class index
    vis_char_img = src_image.copy()
    # CHW -> HWC
    char_file_name = os.path.join(output_dir, file_prefix + '_1_chars.jpg')
    f_chars = np.argmax(f_chars, axis=2)[:, :, np.newaxis].astype('float32')
    f_chars[f_chars < 95] = 1.0
    f_chars[f_chars == 95] = 0.0
    f_chars_resized = cv2.resize(f_chars, dsize=(W, H))
    vis_char_img[:, :, 1] = f_chars_resized * 255
    cv2.imwrite(char_file_name, vis_char_img)
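The block above inverts the usual ImageNet-style normalization (subtract mean, divide by std, CHW layout). A tiny self-check of the same arithmetic, illustrative only:

img = np.random.rand(3, 64, 64).astype('float32')        # fake normalized CHW input
mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
restored = (img * std + mean).transpose([1, 2, 0])[:, :, ::-1] * 255
assert restored.shape == (64, 64, 3)                     # HWC, channels flipped for OpenCV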
def visualize_point_result(im_fn, point_list, point_pair_list, src_im, gt_dir,
                           result_path):
    """Draw gt polygons (white) plus predicted center points (red) and their
    border point pairs (blue/green, joined by yellow lines), then save."""
    im_basename = os.path.basename(im_fn)
    im_prefix = im_basename[:im_basename.rfind('.')]
    vis_det_img = src_im.copy()

    # draw gt bbox on the image.
    text_path = os.path.join(gt_dir, im_prefix + '.txt')
    with open(text_path, 'r') as fid:  # close the file instead of leaking the handle
        lines = [line.strip() for line in fid.readlines()]
    for line in lines:
        tokens = line.strip().split('\t')
        coords = tokens[0].split(',')
        coords_len = len(coords)
        coords = list(map(float, coords))
        gt_poly = np.array(coords).reshape(1, coords_len // 2, 2)  # int division: reshape needs an int
        cv2.polylines(
            vis_det_img,
            np.array(gt_poly).astype(np.int32),
            isClosed=True,
            color=(255, 255, 255),
            thickness=1)

    for point, point_pair in zip(point_list, point_pair_list):
        cv2.line(
            vis_det_img,
            tuple(point_pair[0]),
            tuple(point_pair[1]), (0, 255, 255),
            thickness=1)
        cv2.circle(vis_det_img, tuple(point), 2, (0, 0, 255))
        cv2.circle(vis_det_img, tuple(point_pair[0]), 2, (255, 0, 0))
        cv2.circle(vis_det_img, tuple(point_pair[1]), 2, (0, 255, 0))

    if not os.path.exists(result_path):
        os.makedirs(result_path)
    cv2.imwrite("{}/{}_border_points.jpg".format(result_path, im_prefix),
                vis_det_img)
def resize_image(im, max_side_len=512):
    """
    resize image to a size multiple of max_stride which is required by the network
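The body of resize_image is elided by this hunk; a common shape for this kind of resize, sketched under the assumption stated in the docstring (cap the long side, then snap each side to a multiple of the network stride; the stride value 32 is our guess, not taken from the commit):

def resize_image_sketch(im, max_side_len=512, stride=32):
    h, w = im.shape[:2]
    ratio = min(1.0, float(max_side_len) / max(h, w))  # shrink only, never enlarge
    new_h = max(stride, int(h * ratio) // stride * stride)
    new_w = max(stride, int(w * ratio) // stride * stride)
    return cv2.resize(im, (new_w, new_h)), (new_h / float(h), new_w / float(w))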
@ -295,49 +169,3 @@ def norm2(x, axis=None):
def cos(p1, p2):
    return (p1 * p2).sum() / (norm2(p1) * norm2(p2))
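For intuition, cos above is plain cosine similarity between 2-D point vectors:

p1 = np.array([1.0, 0.0])
p2 = np.array([0.0, 1.0])
assert abs(cos(p1, p2)) < 1e-9  # orthogonal vectors score 0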
def generate_direction_info(image_fn,
                            H,
                            W,
                            ratio_h,
                            ratio_w,
                            max_length=640,
                            out_scale=4,
                            gt_dir=None):
    """Rasterize a per-pixel text-direction map from the gt polygons;
    channels 0 and 1 hold the per-character reading-direction step."""
    im_basename = os.path.basename(image_fn)
    im_prefix = im_basename[:im_basename.rfind('.')]
    instance_direction_map = np.zeros(shape=[H // out_scale, W // out_scale, 3])

    if gt_dir is None:
        gt_dir = '/home/vis/huangzuming/data/SYNTH_DATA/part_vgg_synth_icdar/processed/val/poly'  # hardcoded debug path

    # get gt label map
    text_path = os.path.join(gt_dir, im_prefix + '.txt')
    with open(text_path, 'r') as fid:  # close the file instead of leaking the handle
        lines = [line.strip() for line in fid.readlines()]
    for label_idx, line in enumerate(lines, start=1):
        coords, txt = line.strip().split('\t')
        if txt == '###':
            continue
        tokens = coords.strip().split(',')
        coords = list(map(float, tokens))
        poly = np.array(coords).reshape(4, 2) * np.array(
            [ratio_w, ratio_h]).reshape(1, 2) / out_scale
        mid_idx = poly.shape[0] // 2
        # vector from the midpoint of the first edge to the midpoint of the
        # opposite edge, i.e. the reading direction of the word
        direct_vector = (
            (poly[mid_idx] + poly[mid_idx - 1]) - (poly[0] + poly[-1])) / 2.0

        direct_vector /= len(txt)  # average step per character
        # l2_distance = norm2(direct_vector)
        # avg_char_distance = l2_distance / len(txt)
        avg_char_distance = 1.0

        direct_label = (direct_vector[0], direct_vector[1], avg_char_distance)
        cv2.fillPoly(instance_direction_map,
                     poly.round().astype(np.int32)[np.newaxis, :, :],
                     direct_label)
    instance_direction_map = instance_direction_map.transpose([2, 0, 1])
    return instance_direction_map[:2, ...]
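A worked example of the direction computation above, for an axis-aligned 4-point box and a 4-character word (values made up for illustration):

poly = np.array([[0, 0], [40, 0], [40, 10], [0, 10]], dtype='float32')
mid_idx = poly.shape[0] // 2                      # 2
left_mid2 = poly[0] + poly[-1]                    # [0, 10]: 2x midpoint of left edge
right_mid2 = poly[mid_idx] + poly[mid_idx - 1]    # [80, 10]: 2x midpoint of right edge
direct_vector = (right_mid2 - left_mid2) / 2.0    # [40, 0]: reading direction
direct_vector /= 4                                # per-character step: [10, 0]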
@ -44,7 +44,6 @@ class ArgsParser(ArgumentParser):

    def parse_args(self, argv=None):
        args = super(ArgsParser, self).parse_args(argv)
        args.config = '/Users/hongyongjie/project/PaddleOCR/configs/e2e/e2e_r50_vd_pg.yml'
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
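With the hardcoded override removed by this hunk, the config file must come from the command line again, e.g. --config=configs/e2e/e2e_r50_vd_pg.yml (the flag name is taken from the assert message above; the yml path is the one the deleted line pointed at, made relative).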