Merge pull request #2661 from WenmuZhou/fix_srn_post_process

add max_text_length to SRNLabelDecode
This commit is contained in:
xiaoting 2021-04-27 13:58:21 +08:00 committed by GitHub
commit 159b3a26b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 2 deletions

View File

@ -218,6 +218,7 @@ class SRNLabelDecode(BaseRecLabelDecode):
**kwargs):
super(SRNLabelDecode, self).__init__(character_dict_path,
character_type, use_space_char)
self.max_text_length = kwargs.get('max_text_length', 25)
def __call__(self, preds, label=None, *args, **kwargs):
pred = preds['predict']
@ -229,9 +230,9 @@ class SRNLabelDecode(BaseRecLabelDecode):
preds_idx = np.argmax(pred, axis=1)
preds_prob = np.max(pred, axis=1)
preds_idx = np.reshape(preds_idx, [-1, 25])
preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
preds_prob = np.reshape(preds_prob, [-1, 25])
preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])
text = self.decode(preds_idx, preds_prob)