add max_text_length to SRNLabelDecode
This commit is contained in:
parent
88c6ad8a31
commit
38fc1fae63
|
@ -218,6 +218,7 @@ class SRNLabelDecode(BaseRecLabelDecode):
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(SRNLabelDecode, self).__init__(character_dict_path,
|
super(SRNLabelDecode, self).__init__(character_dict_path,
|
||||||
character_type, use_space_char)
|
character_type, use_space_char)
|
||||||
|
self.max_text_length = kwargs.get('max_text_length', 25)
|
||||||
|
|
||||||
def __call__(self, preds, label=None, *args, **kwargs):
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
pred = preds['predict']
|
pred = preds['predict']
|
||||||
|
@ -229,9 +230,9 @@ class SRNLabelDecode(BaseRecLabelDecode):
|
||||||
preds_idx = np.argmax(pred, axis=1)
|
preds_idx = np.argmax(pred, axis=1)
|
||||||
preds_prob = np.max(pred, axis=1)
|
preds_prob = np.max(pred, axis=1)
|
||||||
|
|
||||||
preds_idx = np.reshape(preds_idx, [-1, 25])
|
preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
|
||||||
|
|
||||||
preds_prob = np.reshape(preds_prob, [-1, 25])
|
preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])
|
||||||
|
|
||||||
text = self.decode(preds_idx, preds_prob)
|
text = self.decode(preds_idx, preds_prob)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue