commit 10b54d6696
@@ -24,7 +24,7 @@ from paddle.nn import functional as F


 def get_para_bias_attr(l2_decay, k, name):
-    regularizer = paddle.fluid.regularizer.L2Decay(l2_decay)
+    regularizer = paddle.regularizer.L2Decay(l2_decay)
     stdv = 1.0 / math.sqrt(k * 1.0)
     initializer = nn.initializer.Uniform(-stdv, stdv)
     weight_attr = ParamAttr(
@@ -33,6 +33,7 @@ def get_para_bias_attr(l2_decay, k, name):
         regularizer=regularizer, initializer=initializer, name=name + "_b_attr")
     return [weight_attr, bias_attr]


 class CTCHead(nn.Layer):
     def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
         super(CTCHead, self).__init__()
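Note: a minimal usage sketch of the updated helper under Paddle 2.x. The "_w_attr" naming, the channel sizes, and the ctc_fc call site are illustrative assumptions, not lines from this commit.

    # Sketch (assumption): how get_para_bias_attr is typically consumed by an FC head.
    import math

    import paddle
    from paddle import ParamAttr, nn


    def get_para_bias_attr(l2_decay, k, name):
        regularizer = paddle.regularizer.L2Decay(l2_decay)
        stdv = 1.0 / math.sqrt(k * 1.0)
        initializer = nn.initializer.Uniform(-stdv, stdv)
        weight_attr = ParamAttr(
            regularizer=regularizer, initializer=initializer, name=name + "_w_attr")
        bias_attr = ParamAttr(
            regularizer=regularizer, initializer=initializer, name=name + "_b_attr")
        return [weight_attr, bias_attr]


    # Hypothetical call site: an FC layer whose parameters decay with coefficient fc_decay.
    in_channels, out_channels, fc_decay = 96, 37, 0.0004
    weight_attr, bias_attr = get_para_bias_attr(
        l2_decay=fc_decay, k=in_channels, name="ctc_fc")
    fc = nn.Linear(in_channels, out_channels, weight_attr=weight_attr, bias_attr=bias_attr)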
@@ -17,7 +17,7 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals

-from paddle import fluid
+import paddle


 class L1Decay(object):
@@ -32,8 +32,7 @@ class L1Decay(object):
         self.regularization_coeff = factor

     def __call__(self):
-        reg = fluid.regularizer.L1Decay(
-            regularization_coeff=self.regularization_coeff)
+        reg = paddle.regularizer.L1Decay(self.regularization_coeff)
         return reg


@@ -49,6 +48,5 @@ class L2Decay(object):
         self.regularization_coeff = factor

     def __call__(self):
-        reg = fluid.regularizer.L2Decay(
-            regularization_coeff=self.regularization_coeff)
+        reg = paddle.regularizer.L2Decay(self.regularization_coeff)
         return reg
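Note: a minimal sketch of how the migrated L1Decay/L2Decay wrappers plug into a Paddle 2.x optimizer; the Momentum optimizer and the coefficients are illustrative, not part of this commit.

    # Sketch (assumption): the wrapper's __call__ now returns a paddle.regularizer
    # object, which 2.x optimizers accept through weight_decay.
    import paddle

    linear = paddle.nn.Linear(10, 10)
    reg = paddle.regularizer.L2Decay(1e-4)   # equivalent to L2Decay(factor=1e-4)()
    optimizer = paddle.optimizer.Momentum(
        learning_rate=0.001,
        momentum=0.9,
        parameters=linear.parameters(),
        weight_decay=reg)                    # replaces fluid.regularizer.L2Decay(...)

weight_decay also accepts a plain float coefficient, so the wrapper classes mainly preserve the existing config interface.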
@@ -102,7 +102,6 @@ class CTCLabelDecode(BaseRecLabelDecode):
     def __call__(self, preds, label=None, *args, **kwargs):
         if isinstance(preds, paddle.Tensor):
             preds = preds.numpy()
-        # out = self.decode_preds(preds)

         preds_idx = preds.argmax(axis=2)
         preds_prob = preds.max(axis=2)
@@ -116,27 +115,6 @@ class CTCLabelDecode(BaseRecLabelDecode):
         dict_character = ['blank'] + dict_character
         return dict_character

-    def decode_preds(self, preds):
-        probs_ind = np.argmax(preds, axis=2)
-
-        B, N, _ = preds.shape
-        l = np.ones(B).astype(np.int64) * N
-        length = paddle.to_tensor(l)
-        out = paddle.fluid.layers.ctc_greedy_decoder(preds, 0, length)
-        batch_res = [
-            x[:idx[0]] for x, idx in zip(out[0].numpy(), out[1].numpy())
-        ]
-
-        result_list = []
-        for sample_idx, ind, prob in zip(batch_res, probs_ind, preds):
-            char_list = [self.character[idx] for idx in sample_idx]
-            valid_ind = np.where(ind != 0)[0]
-            if len(valid_ind) == 0:
-                continue
-            conf_list = prob[valid_ind, ind[valid_ind]]
-            result_list.append((''.join(char_list), conf_list))
-        return result_list
-

 class AttnLabelDecode(BaseRecLabelDecode):
     """ Convert between text-label and text-index """
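Note: the deleted decode_preds relied on paddle.fluid.layers.ctc_greedy_decoder; the surviving __call__ path works from a plain numpy argmax instead. A self-contained sketch of the equivalent greedy CTC collapse (blank at index 0, merge repeats) is below; the function name ctc_greedy_decode is illustrative, not this repo's API.

    # Sketch (assumption): numpy-only greedy CTC decoding, matching the convention
    # dict_character = ['blank'] + dict_character, i.e. blank is index 0.
    import numpy as np


    def ctc_greedy_decode(preds, character):
        # preds: (B, T, C) per-timestep class probabilities; character: list with 'blank' first.
        preds_idx = preds.argmax(axis=2)   # (B, T) best class per timestep
        preds_prob = preds.max(axis=2)     # (B, T) its probability
        results = []
        for idx, prob in zip(preds_idx, preds_prob):
            chars, confs = [], []
            prev = -1
            for t, c in enumerate(idx):
                # keep a label only if it is not blank and not a repeat of the previous step
                if c != 0 and c != prev:
                    chars.append(character[c])
                    confs.append(prob[t])
                prev = c
            results.append(("".join(chars), float(np.mean(confs)) if confs else 0.0))
        return results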