fix comment

LDOUBLEV 2021-02-01 06:44:04 +00:00
parent e7d24ac8b8
commit 550022ea66
2 changed files with 1 addition and 16 deletions


@@ -211,7 +211,7 @@ class AttnLabelEncode(BaseRecLabelEncode):
         text = self.encode(text)
         if text is None:
             return None
-        if len(text) > self.max_text_len:
+        if len(text) >= self.max_text_len:
             return None
         data['length'] = np.array(len(text))
         text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
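The tightened length check is the only functional change in this commit: ">" becomes ">=", so a label of exactly max_text_len characters is now rejected as well. Because the encoded text is then wrapped with start/end markers (index 0 and len(self.character) - 1) and zero-padded on the line cut off above, the stricter check guarantees at least one trailing padding slot after the end marker. A minimal sketch of that encode-and-pad step; the helper name and defaults are illustrative, not PaddleOCR's actual API:

import numpy as np

def encode_attn_label(text_ids, num_chars, max_text_len=25):
    # text_ids: the label already mapped to character indices (list of int).
    if text_ids is None:
        return None
    # ">=" (the fix) instead of ">": a label of exactly max_text_len
    # characters is rejected too, so the padded target below always keeps
    # at least one zero after the end marker.
    if len(text_ids) >= max_text_len:
        return None
    length = np.array(len(text_ids))
    # [start] + ids + [end] + zero padding, mirroring the padding line above;
    # the result always has max_text_len + 2 entries.
    label = [0] + text_ids + [num_chars - 1] + [0] * (max_text_len - len(text_ids))
    return {"length": length, "label": np.array(label)}

# A 3-character label passes; a 25-character label now returns None.
print(encode_attn_label([5, 7, 9], num_chars=38))
print(encode_attn_label(list(range(1, 26)), num_chars=38))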


@@ -194,18 +194,3 @@ class AttentionLSTMCell(nn.Layer):
         cur_hidden = self.rnn(concat_context, prev_hidden)
 
         return cur_hidden, alpha
-
-
-if __name__ == '__main__':
-    paddle.disable_static()
-    model = Attention(100, 200, 10)
-    x = np.random.uniform(-1, 1, [2, 10, 100]).astype(np.float32)
-    y = np.random.randint(0, 10, [2, 21]).astype(np.int32)
-    xp = paddle.to_tensor(x)
-    yp = paddle.to_tensor(y)
-    res = model(inputs=xp, targets=yp, is_train=True, batch_max_length=20)
-    print("res: ", res.shape)