fix comment
This commit is contained in:
parent
e7d24ac8b8
commit
550022ea66
|
@ -211,7 +211,7 @@ class AttnLabelEncode(BaseRecLabelEncode):
|
|||
text = self.encode(text)
|
||||
if text is None:
|
||||
return None
|
||||
if len(text) > self.max_text_len:
|
||||
if len(text) >= self.max_text_len:
|
||||
return None
|
||||
data['length'] = np.array(len(text))
|
||||
text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
|
||||
|
|
|
@ -194,18 +194,3 @@ class AttentionLSTMCell(nn.Layer):
|
|||
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||
|
||||
return cur_hidden, alpha
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
paddle.disable_static()
|
||||
|
||||
model = Attention(100, 200, 10)
|
||||
|
||||
x = np.random.uniform(-1, 1, [2, 10, 100]).astype(np.float32)
|
||||
y = np.random.randint(0, 10, [2, 21]).astype(np.int32)
|
||||
|
||||
xp = paddle.to_tensor(x)
|
||||
yp = paddle.to_tensor(y)
|
||||
|
||||
res = model(inputs=xp, targets=yp, is_train=True, batch_max_length=20)
|
||||
print("res: ", res.shape)
|
||||
|
|
Loading…
Reference in New Issue