fix comment
This commit is contained in:
parent
0f4d92b63f
commit
e7d24ac8b8
|
@ -43,7 +43,6 @@ Architecture:
|
|||
Backbone:
|
||||
name: ResNet
|
||||
layers: 34
|
||||
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
|
@ -53,7 +52,6 @@ Architecture:
|
|||
hidden_size: 256 #
|
||||
l2_decay: 0.00001
|
||||
|
||||
|
||||
Loss:
|
||||
name: AttentionLoss
|
||||
|
||||
|
|
|
@ -192,25 +192,6 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
|||
label = self.decode(label, is_remove_duplicate=False)
|
||||
return text, label
|
||||
|
||||
def encoder(self, labels, labels_length):
|
||||
"""
|
||||
used to encoder labels readed from LMDB dataset, forexample:
|
||||
[35, 25, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]] encode to
|
||||
'you': [0, 35,25,31, 37, 0, ...] 'sos'you'eos'
|
||||
"""
|
||||
if isinstance(labels, paddle.Tensor):
|
||||
labels = labels.numpy()
|
||||
batch_max_length = labels.shape[
|
||||
1] + 2 # add start token 'sos' and end token 'eos'
|
||||
new_labels = np.zeros(
|
||||
[labels.shape[0], batch_max_length]).astype(np.int64)
|
||||
for i in range(labels.shape[0]):
|
||||
new_labels[i, 1:1 + labels_length[i]] = labels[i, :labels_length[
|
||||
i]] # new_labels[i, 0] = 'sos' token
|
||||
new_labels[i, labels_length[i] + 1] = len(
|
||||
self.character) - 1 # add end charactor 'eos' token
|
||||
return new_labels
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||
end_idx = self.get_beg_end_flag_idx("end")
|
||||
|
|
Loading…
Reference in New Issue