# deepke/vocab.py
#
# (file-viewer listing metadata: 109 lines, 2.7 KiB, Python)

import logging
from collections import OrderedDict
from typing import Sequence, Optional
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Special tokens as (role name, surface form) pairs. The order is significant:
# Vocab seeds itself by iterating these values, so earlier entries receive
# lower indices.
SPECIAL_TOKENS = OrderedDict(
    [
        ("pad_token", "[PAD]"),
        ("unk_token", "[UNK]"),
        ("mask_token", "[MASK]"),
        ("cls_token", "[CLS]"),
        ("sep_token", "[SEP]"),
        ("bos_token", "[BOS]"),
        ("eos_token", "[EOS]"),
        ("head_token", "HEAD"),
        ("tail_token", "TAIL"),
    ]
)

# Parallel lists derived from the mapping above (role names / surface forms).
SPECIAL_TOKENS_KEYS = list(SPECIAL_TOKENS.keys())
SPECIAL_TOKENS_VALUES = list(SPECIAL_TOKENS.values())
class Vocab(object):
    """Vocabulary mapping words to integer ids while tracking word counts.

    The vocabulary is seeded with ``init_tokens`` on construction (and again
    after ``trim``), so those tokens always occupy the lowest ids, in the
    order they were given.

    Attributes:
        name: Free-form label for this vocabulary.
        init_tokens: The initial tokens exactly as passed to the constructor.
        trimed: True once ``trim`` has actually pruned the vocabulary.
        word2idx: Maps word -> id.
        word2count: Maps word -> count. Seeding an init token contributes a
            count of 1.
        idx2word: Maps id -> word.
        count: Number of distinct words; also the next id to be assigned.
    """

    def __init__(self, name: str = 'basic', init_tokens: Sequence = SPECIAL_TOKENS):
        """Create a vocabulary pre-populated with ``init_tokens``.

        Args:
            name: Label stored on the instance.
            init_tokens: Tokens to seed the vocabulary with. May be a mapping
                (its values are used, as with the default ``SPECIAL_TOKENS``),
                a plain sequence of tokens, or ``None`` for no seed tokens.
        """
        self.name = name
        self.init_tokens = init_tokens
        self.trimed = False
        self._reset()

    def _reset(self):
        """Clear all tables and re-seed them with the init tokens."""
        self.word2idx = {}
        self.word2count = {}
        self.idx2word = {}
        self.count = 0
        self._add_init_tokens()

    def _init_token_list(self):
        """Return the init tokens as a list, whatever container was passed."""
        tokens = self.init_tokens
        if tokens is None:
            return []
        # Mapping-like input (e.g. SPECIAL_TOKENS): the tokens are the values.
        if hasattr(tokens, 'values'):
            tokens = tokens.values()
        return list(tokens)

    def _add_init_tokens(self):
        """Register every init token, in order."""
        for token in self._init_token_list():
            self._add_word(token)

    def _add_word(self, word: str):
        """Register one occurrence of ``word``, assigning a new id if unseen."""
        if word in self.word2idx:
            self.word2count[word] += 1
            return
        self.word2idx[word] = self.count
        self.word2count[word] = 1
        self.idx2word[self.count] = word
        self.count += 1

    def add_words(self, words: Sequence):
        """Register one occurrence of each item in ``words``."""
        for word in words:
            self._add_word(word)

    def trim(self, min_freq=2, verbose: Optional[bool] = True):
        """Remove every word whose count is below ``min_freq``.

        Init tokens are always retained (they are re-seeded first, keeping
        the lowest ids). Surviving words keep their relative order and their
        exact counts; ids are reassigned contiguously. The method is a no-op
        when ``min_freq < 2`` or when the vocabulary was already trimmed.

        Args:
            min_freq: Minimum count a word needs to be kept. Must be
                integer-valued.
            verbose: When truthy, log how many non-init words were kept.

        Raises:
            AssertionError: If ``min_freq`` is not integer-valued.
        """
        min_freq_int = int(min_freq)
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type is kept so
        # existing callers are unaffected.
        if min_freq != min_freq_int:
            raise AssertionError(f'min_freq must be integer, can\'t be {min_freq}')
        min_freq = min_freq_int

        if min_freq < 2 or self.trimed:
            return
        self.trimed = True

        kept = [(word, freq) for word, freq in self.word2count.items()
                if freq >= min_freq]

        if verbose:
            # Report on non-init words only, for both numerator and
            # denominator, and guard against a vocabulary with no such words.
            init_set = set(self._init_token_list())
            kept_len = sum(1 for word, _ in kept if word not in init_set)
            total_len = sum(1 for word in self.word2idx if word not in init_set)
            percent = kept_len / total_len * 100 if total_len else 0.0
            logger.info('vocab after be trimmed, keep words [{} / {}] = {:.2f}%'.format(
                kept_len, total_len, percent))

        # Rebuild the tables: init tokens first, then the surviving words.
        # Counts are restored directly instead of replaying every occurrence,
        # which avoids building a corpus-sized list and avoids adding the
        # seeding count a second time to init tokens that survived.
        self._reset()
        for word, freq in kept:
            if word not in self.word2idx:
                self._add_word(word)
            self.word2count[word] = freq
if __name__ == '__main__':
    # Small manual demo: build a character-level vocabulary from one
    # sentence, then show the counts before and after trimming at 2.
    demo_vocab = Vocab('test')
    chars = list(' 我是中国人,我爱中国。')
    print(chars)
    demo_vocab.add_words(chars)
    print(demo_vocab.word2count)
    demo_vocab.trim(2)
    print(demo_vocab.word2count)