import pytest

from serializer import Serializer
from vocab import Vocab


def test_vocab():
    vocab = Vocab('test')
    sent = ' 我是中国人，我爱中国。 I\'m Chinese, I love China'

    # The serializer lower-cases the text, splits Chinese per character,
    # and splits English on whitespace and punctuation.
    serializer = Serializer(do_lower_case=True)
    tokens = serializer.serialize(sent)
    assert tokens == [
        '我', '是', '中', '国', '人', '，', '我', '爱', '中', '国', '。',
        'i', "'", 'm', 'chinese', ',', 'i', 'love', 'china'
    ]

    vocab.add_words(tokens)
    unk_str = '[UNK]'
    unk_idx = vocab.word2idx[unk_str]

    # The 19 tokens contain 15 unique words; count == 22 because the vocab
    # already holds 7 reserved entries (including '[UNK]') before add_words.
    assert vocab.count == 22
    assert len(vocab.word2idx) == len(vocab.idx2word) == 22

    # Drop every word that occurs fewer than 2 times.
    vocab.trim(2, verbose=False)

    # Only '我', '中', '国' and 'i' occur twice: 7 reserved + 4 kept = 11.
    assert vocab.count == 11
    assert len(vocab.word2idx) == len(vocab.idx2word) == 11

    # Trimmed words fall back to the [UNK] index.
    token2idx = [vocab.word2idx.get(i, unk_idx) for i in tokens]
    assert len(tokens) == len(token2idx)
    assert token2idx == [7, 1, 8, 9, 1, 1, 7, 1, 8, 9, 1, 10, 1, 1, 1, 1, 10, 1, 1]

    # Mapping back to words replaces trimmed tokens with the '[UNK]' string.
    idx2tokens = [vocab.idx2word.get(i, unk_str) for i in token2idx]
    assert len(idx2tokens) == len(token2idx)
    assert ' '.join(idx2tokens) == '我 [UNK] 中 国 [UNK] [UNK] 我 [UNK] 中 国 [UNK] i [UNK] [UNK] [UNK] [UNK] i [UNK] [UNK]'


if __name__ == '__main__':
    pytest.main()