Merge pull request #2357 from caopulan/fix_srn_post
fix srn_postprocess
This commit is contained in:
commit
a09604f897
|
@ -216,6 +216,7 @@ class SRNLabelDecode(BaseRecLabelDecode):
|
|||
character_type='en',
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
self.max_text_length = kwargs['max_text_length']
|
||||
super(SRNLabelDecode, self).__init__(character_dict_path,
|
||||
character_type, use_space_char)
|
||||
|
||||
|
@ -229,9 +230,9 @@ class SRNLabelDecode(BaseRecLabelDecode):
|
|||
preds_idx = np.argmax(pred, axis=1)
|
||||
preds_prob = np.max(pred, axis=1)
|
||||
|
||||
preds_idx = np.reshape(preds_idx, [-1, 25])
|
||||
preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
|
||||
|
||||
preds_prob = np.reshape(preds_prob, [-1, 25])
|
||||
preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])
|
||||
|
||||
text = self.decode(preds_idx, preds_prob)
|
||||
|
||||
|
|
Loading…
Reference in New Issue