import logging
from collections import OrderedDict
from typing import Mapping, Optional, Sequence


logger = logging.getLogger(__name__)
# Special tokens every Vocab is seeded with.  The mapping is the single source
# of truth; the parallel KEYS/VALUES lists are derived from it so the two can
# never fall out of sync (the previous zip of two hand-maintained lists would
# silently drop entries if the lists drifted apart in length).
SPECIAL_TOKENS = OrderedDict([
    ("pad_token", "[PAD]"),
    ("unk_token", "[UNK]"),
    ("mask_token", "[MASK]"),
    ("cls_token", "[CLS]"),
    ("sep_token", "[SEP]"),
    ("bos_token", "[BOS]"),
    ("eos_token", "[EOS]"),
    ("head_token", "HEAD"),
    ("tail_token", "TAIL"),
])

# Kept for backward compatibility with callers importing the flat lists.
SPECIAL_TOKENS_KEYS = list(SPECIAL_TOKENS.keys())
SPECIAL_TOKENS_VALUES = list(SPECIAL_TOKENS.values())
class Vocab(object):
    """Bidirectional word <-> index vocabulary with per-word frequency counts.

    Special tokens are registered first so they always occupy the lowest
    indices; ``trim`` drops rare words and reindexes the remainder.
    """

    def __init__(self, name: str = 'basic', init_tokens: Optional[Mapping] = None):
        """Create a vocabulary pre-seeded with special tokens.

        Args:
            name: identifier for this vocabulary.
            init_tokens: mapping of special-token names to token strings
                (e.g. ``{"pad_token": "[PAD]"}``).  Defaults to the
                module-level ``SPECIAL_TOKENS``.  Using ``None`` as the
                default avoids sharing one mutable OrderedDict between
                every instance.
        """
        self.name = name
        # Fall back to the module-level mapping only when the caller did not
        # supply one (mutable-default-argument fix; behavior is unchanged for
        # existing callers).
        self.init_tokens = SPECIAL_TOKENS if init_tokens is None else init_tokens
        self.trimed = False  # (sic) attribute name kept for backward compatibility
        self.word2idx = {}    # word -> index
        self.word2count = {}  # word -> frequency
        self.idx2word = {}    # index -> word
        self.count = 0        # next index to assign == current vocab size
        self._add_init_tokens()

    def _add_init_tokens(self):
        """Register every special token so they occupy the first indices."""
        for token in self.init_tokens.values():
            self._add_word(token)

    def _add_word(self, word: str):
        """Add one word: assign a fresh index on first sight, else bump its count."""
        if word not in self.word2idx:
            self.word2idx[word] = self.count
            self.word2count[word] = 1
            self.idx2word[self.count] = word
            self.count += 1
        else:
            self.word2count[word] += 1

    def add_words(self, words: Sequence):
        """Add every word in ``words``; duplicates increase their counts."""
        for word in words:
            self._add_word(word)

    def trim(self, min_freq=2, verbose: Optional[bool] = True):
        """Drop words whose frequency is below ``min_freq`` and reindex.

        Special tokens are always re-added afterwards, regardless of their
        counts.  Trimming is one-shot: subsequent calls are no-ops.

        Args:
            min_freq: minimum frequency a word needs to be kept; values
                below 2 make the call a no-op.
            verbose: when truthy, log how many words were kept.

        Raises:
            ValueError: if ``min_freq`` is not an integral number.
        """
        # Raise instead of assert: asserts are stripped under ``python -O``.
        if min_freq != int(min_freq):
            raise ValueError(f'min_freq must be integer, can\'t be {min_freq}')
        min_freq = int(min_freq)
        if min_freq < 2:
            return
        if self.trimed:
            return
        self.trimed = True

        keep_words = []
        new_words = []

        for word, freq in self.word2count.items():
            if freq >= min_freq:
                keep_words.append(word)
                # Repeat each kept word ``freq`` times so re-adding restores
                # its original count.
                new_words.extend([word] * freq)
        if verbose:
            kept = len(keep_words)
            # Pre-trim vocabulary size excluding the special tokens.
            total = len(self.word2idx) - len(self.init_tokens)
            # Guard against an empty (special-tokens-only) vocabulary.
            ratio = kept / total * 100 if total else 0.0
            logger.info('vocab trimmed, keep words [%d / %d] = %.2f%%',
                        kept, total, ratio)

        # Reinitialize the dictionaries, then rebuild from the kept words.
        self.word2idx = {}
        self.word2count = {}
        self.idx2word = {}
        self.count = 0
        self._add_init_tokens()
        self.add_words(new_words)
if __name__ == '__main__':
    # Smoke test: build a character-level vocab from one sentence, then
    # trim characters that occur fewer than two times.
    vocab = Vocab('test')
    sent = list(' 我是中国人,我爱中国。')
    print(sent)

    vocab.add_words(sent)
    print(vocab.word2count)
    vocab.trim(2)
    print(vocab.word2count)