后处理添加类型判断
This commit is contained in:
parent
4402e62959
commit
44840726ff
|
@ -18,6 +18,7 @@ from __future__ import print_function
|
|||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import paddle
|
||||
from shapely.geometry import Polygon
|
||||
import pyclipper
|
||||
|
||||
|
@ -130,7 +131,9 @@ class DBPostProcess(object):
|
|||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
||||
|
||||
def __call__(self, pred, shape_list):
|
||||
pred = pred.numpy()[:, 0, :, :]
|
||||
if isinstance(pred, paddle.Tensor):
|
||||
pred = pred.numpy()
|
||||
pred = pred[:, 0, :, :]
|
||||
segmentation = pred > self.thresh
|
||||
|
||||
boxes_batch = []
|
||||
|
@ -140,4 +143,4 @@ class DBPostProcess(object):
|
|||
pred[batch_index], segmentation[batch_index], width, height)
|
||||
|
||||
boxes_batch.append({'points': boxes})
|
||||
return boxes_batch
|
||||
return boxes_batch
|
|
@ -1,4 +1,5 @@
|
|||
import cv2
|
||||
import paddle
|
||||
import numpy as np
|
||||
import pyclipper
|
||||
from shapely.geometry import Polygon
|
||||
|
@ -23,7 +24,9 @@ class DBPostProcess():
|
|||
pred:
|
||||
binary: text region segmentation map, with shape (N, 1,H, W)
|
||||
'''
|
||||
pred = pred.numpy()[:, 0, :, :]
|
||||
if isinstance(pred, paddle.Tensor):
|
||||
pred = pred.numpy()
|
||||
pred = pred[:, 0, :, :]
|
||||
segmentation = self.binarize(pred)
|
||||
batch_out = []
|
||||
for batch_index in range(pred.shape[0]):
|
||||
|
@ -130,4 +133,4 @@ class DBPostProcess():
|
|||
box[:, 0] = box[:, 0] - xmin
|
||||
box[:, 1] = box[:, 1] - ymin
|
||||
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
|
||||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
||||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
|
@ -100,9 +100,10 @@ class CTCLabelDecode(BaseRecLabelDecode):
|
|||
character_type, use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
# out = self.decode_preds(preds)
|
||||
|
||||
preds = F.softmax(preds, axis=2).numpy()
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
text = self.decode(preds_idx, preds_prob)
|
||||
|
@ -116,19 +117,18 @@ class CTCLabelDecode(BaseRecLabelDecode):
|
|||
return dict_character
|
||||
|
||||
def decode_preds(self, preds):
|
||||
probs = F.softmax(preds, axis=2).numpy()
|
||||
probs_ind = np.argmax(probs, axis=2)
|
||||
probs_ind = np.argmax(preds, axis=2)
|
||||
|
||||
B, N, _ = preds.shape
|
||||
l = np.ones(B).astype(np.int64) * N
|
||||
length = paddle.to_variable(l)
|
||||
length = paddle.to_tensor(l)
|
||||
out = paddle.fluid.layers.ctc_greedy_decoder(preds, 0, length)
|
||||
batch_res = [
|
||||
x[:idx[0]] for x, idx in zip(out[0].numpy(), out[1].numpy())
|
||||
]
|
||||
|
||||
result_list = []
|
||||
for sample_idx, ind, prob in zip(batch_res, probs_ind, probs):
|
||||
for sample_idx, ind, prob in zip(batch_res, probs_ind, preds):
|
||||
char_list = [self.character[idx] for idx in sample_idx]
|
||||
valid_ind = np.where(ind != 0)[0]
|
||||
if len(valid_ind) == 0:
|
||||
|
@ -172,4 +172,4 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
|||
else:
|
||||
assert False, "unsupport type %s in get_beg_end_flag_idx" \
|
||||
% beg_or_end
|
||||
return idx
|
||||
return idx
|
Loading…
Reference in New Issue