diff --git a/dataset.py b/dataset.py
index 145926f..6ecfebf 100644
--- a/dataset.py
+++ b/dataset.py
@@ -4,7 +4,15 @@ from utils import load_pkl
 
 
 def collate_fn(cfg):
+    def collate_fn_intra(batch):
+        """
+        Args:
+            batch (): the dataset
+        Returns:
+            x (dict): keys are the words, values are the lengths
+            y (List): the values corresponding to the relations
+        """
         batch.sort(key=lambda data: data['seq_len'], reverse=True)
         max_len = batch[0]['seq_len']
@@ -48,7 +56,9 @@ def collate_fn(cfg):
 
 
 class CustomDataset(Dataset):
-    """By default the data is stored in a List."""
+    """
+    By default the data is stored in a List.
+    """
     def __init__(self, fp):
         self.file = load_pkl(fp)
@@ -68,7 +78,6 @@ if __name__ == '__main__':
     vocab = load_pkl(vocab_path)
     train_ds = CustomDataset(train_data_path)
     train_dl = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn, drop_last=False)
-
     for batch_idx, (x, y) in enumerate(train_dl):
         word = x['word']
         for idx in word:
diff --git a/metrics.py b/metrics.py
index 1e4b420..6ca22d6 100644
--- a/metrics.py
+++ b/metrics.py
@@ -44,10 +44,16 @@ class PRMetric():
         self.y_pred = np.empty(0)
 
     def reset(self):
+        """
+        Reset the stored labels and predictions to empty arrays.
+        """
         self.y_true = np.empty(0)
         self.y_pred = np.empty(0)
 
     def update(self, y_true: torch.Tensor, y_pred: torch.Tensor):
+        """
+        Update with new tensors: keep the values and detach them from the original gradients.
+        """
         y_true = y_true.cpu().detach().numpy()
         y_pred = y_pred.cpu().detach().numpy()
         y_pred = np.argmax(y_pred, axis=-1)
@@ -56,6 +62,9 @@ class PRMetric():
         self.y_pred = np.append(self.y_pred, y_pred)
 
     def compute(self):
+        """
+        Compute and return acc, p, r and f1.
+        """
         p, r, f1, _ = precision_recall_fscore_support(self.y_true, self.y_pred, average='macro', warn_for=tuple())
         _, _, acc, _ = precision_recall_fscore_support(self.y_true, self.y_pred, average='micro', warn_for=tuple())
diff --git a/models/LM.py b/models/LM.py
index 44eb60b..9a7dab6 100644
--- a/models/LM.py
+++ b/models/LM.py
@@ -16,7 +16,9 @@ class LM(BasicModule):
     def forward(self, x):
         word, lens = x['word'], x['lens']
         mask = seq_len_to_mask(lens, mask_pos_to_true=False)
-        last_hidden_state, pooler_output = self.bert(word, attention_mask=mask)
+        a = self.bert(word, attention_mask=mask)
+        last_hidden_state = a[0]
+        pooler_output = a[1]
         out, out_pool = self.bilstm(last_hidden_state, lens)
         out_pool = self.dropout(out_pool)
         output = self.fc(out_pool)
diff --git a/module/Attention.py b/module/Attention.py
index 851be48..91ef887 100644
--- a/module/Attention.py
+++ b/module/Attention.py
@@ -13,16 +13,17 @@ class DotAttention(nn.Module):
 
     def forward(self, Q, K, V, mask_out=None, head_mask=None):
         """
-        Typically, when the input information is X, it is assumed that K = V = X.
-
+        Typically, when the input information is X, it is assumed that K = V = X.
         att_weight = softmax( score_func(q, k) )
         att = sum( att_weight * v )
-
-        :param Q: [..., L, H]
-        :param K: [..., S, H]
-        :param V: [..., S, H]
-        :param mask_out: [..., 1, S]
-        :return:
+        Args:
+            Q: [..., L, H]
+            K: [..., S, H]
+            V: [..., S, H]
+            mask_out: [..., 1, S]
+        Returns:
+            attention_out
+            attention_weight
         """
         H = Q.size(-1)
@@ -52,9 +53,10 @@ class DotAttention(nn.Module):
 class MultiHeadAttention(nn.Module):
     def __init__(self, embed_dim, num_heads, dropout=0.0, output_attentions=True):
         """
-        :param embed_dim: the input dimension; must be divisible by num_heads
-        :param num_heads: the number of attention heads
-        :param dropout: float.
+        Args:
+            embed_dim: the input dimension; must be divisible by num_heads
+            num_heads: the number of attention heads
+            dropout: float.
         """
         super(MultiHeadAttention, self).__init__()
         self.num_heads = num_heads
@@ -72,12 +74,13 @@ class MultiHeadAttention(nn.Module):
 
     def forward(self, Q, K, V, key_padding_mask=None, attention_mask=None, head_mask=None):
         """
-        :param Q: [B, L, Hs]
-        :param K: [B, S, Hs]
-        :param V: [B, S, Hs]
-        :param key_padding_mask: [B, S] positions marked 1/True are masked
-        :param attention_mask: [S] / [L, S] mask out the given positions; positions marked 1/True are masked
-        :param head_mask: [N] mask out the given heads; positions marked 1/True are masked
+        Args:
+            Q: [B, L, Hs]
+            K: [B, S, Hs]
+            V: [B, S, Hs]
+            key_padding_mask: [B, S] positions marked 1/True are masked
+            attention_mask: [S] / [L, S] mask out the given positions; positions marked 1/True are masked
+            head_mask: [N] mask out the given heads; positions marked 1/True are masked
         """
         B, L, Hs = Q.shape
         S = V.size(1)
diff --git a/module/CNN.py b/module/CNN.py
index 8d43366..4495200 100644
--- a/module/CNN.py
+++ b/module/CNN.py
@@ -21,12 +21,13 @@ class CNN(nn.Module):
     """
     def __init__(self, config):
         """
-        in_channels      : usually the word embedding dimension, or the hidden size
-        out_channels     : int
-        kernel_sizes     : list; to keep output length = input length, each size must be odd: 3, 5, 7...
-        activation       : [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
-        pooling_strategy : [max, avg, cls]
-        dropout:         : float
+        Args:
+            in_channels: usually the word embedding dimension, or the hidden size
+            out_channels: int
+            kernel_sizes: list; to keep output length = input length, each size must be odd: 3, 5, 7...
+            activation: [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
+            pooling_strategy: [max, avg, cls]
+            dropout: float
         """
         super(CNN, self).__init__()
@@ -74,10 +75,13 @@ class CNN(nn.Module):
 
     def forward(self, x, mask=None):
         """
-        :param x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H], usually the values after embedding
-        :param mask: [batch_size, max_len], 0 for real tokens and 1 for padding. It does not affect the convolution, and max-pooling will never pool a padded (zero) position
-        :return:
-        """
+        Args:
+            x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H], usually the values after embedding
+            mask: [batch_size, max_len], 0 for real tokens and 1 for padding. It does not affect the convolution, and max-pooling will never pool a padded (zero) position
+        Returns:
+            x: torch.Tensor
+            xp: torch.Tensor
+        """
         # [B, L, H] -> [B, H, L] (note: the H dimension is used as the input channel dimension)
         x = torch.transpose(x, 1, 2)
diff --git a/module/Capsule.py b/module/Capsule.py
index 6c69668..1f4a3b1 100644
--- a/module/Capsule.py
+++ b/module/Capsule.py
@@ -27,7 +27,8 @@ class Capsule(nn.Module):
 
     def forward(self, x):
         """
-        x: [B, L, H]    # the output obtained from the CNN / RNN
+        Args:
+            x: [B, L, H] # the output obtained from the CNN / RNN
             L is used as input_num_capsules, H as input_dim_capsule
         """
         B, I, _ = x.size()  # I is input_num_capsules
diff --git a/module/Embedding.py b/module/Embedding.py
index 83074a3..2b435d6 100644
--- a/module/Embedding.py
+++ b/module/Embedding.py
@@ -5,9 +5,10 @@ import torch.nn as nn
 class Embedding(nn.Module):
     def __init__(self, config):
         """
-        word embedding: index 0 is usually padding
-        pos embedding: index 0 is usually padding
-        dim_strategy: [cat, sum] whether multiple embeddings are concatenated or summed
+        Args:
+            word embedding: index 0 is usually padding
+            pos embedding: index 0 is usually padding
+            dim_strategy: [cat, sum] whether multiple embeddings are concatenated or summed
         """
         super(Embedding, self).__init__()
diff --git a/module/RNN.py b/module/RNN.py
index 052b3f9..f177b5e 100644
--- a/module/RNN.py
+++ b/module/RNN.py
@@ -55,6 +55,7 @@ class RNN(nn.Module):
         B, L, _ = x.size()
         H, N = self.hidden_size, self.num_layers
 
+        x_len = x_len.cpu()
         x = pack_padded_sequence(x, x_len, batch_first=True, enforce_sorted=True)
         output, hn = self.rnn(x)
         output, _ = pad_packed_sequence(output, batch_first=True, total_length=L)
diff --git a/preprocess.py b/preprocess.py
index 866039c..67c4efd 100644
--- a/preprocess.py
+++ b/preprocess.py
@@ -11,6 +11,15 @@ logger = logging.getLogger(__name__)
 
 
 def _handle_pos_limit(pos: List[int], limit: int) -> List[int]:
+    """
+    Clip the position values to the given limit.
+    Args:
+        pos (List[int]): the position list for the sentence
+        limit (int): the limit value
+    Returns:
+        [p + limit + 1 for p in pos] (List[int]): the processed result
+
+    """
     for i, p in enumerate(pos):
         if p > limit:
             pos[i] = limit
@@ -20,6 +29,12 @@
 def _add_pos_seq(train_data: List[Dict], cfg):
+    """
+    Add position sequences.
+    Args:
+        train_data (List[Dict]): the dataset
+        cfg: the configuration
+    """
     for d in train_data:
         entities_idx = [d['head_idx'], d['tail_idx']
                         ] if d['head_idx'] < d['tail_idx'] else [d['tail_idx'], d['head_idx']]
@@ -39,6 +54,12 @@ def _add_pos_seq(train_data: List[Dict], cfg):
 
 
 def _convert_tokens_into_index(data: List[Dict], vocab):
+    """
+    Convert tokens into index values.
+    Args:
+        data (List[Dict]): the dataset
+        vocab (Class): the vocabulary
+    """
     unk_str = '[UNK]'
     unk_idx = vocab.word2idx[unk_str]
@@ -48,6 +69,13 @@ def _convert_tokens_into_index(data: List[Dict], vocab):
 
 
 def _serialize_sentence(data: List[Dict], serial, cfg):
+    """
+    Tokenize the sentences.
+    Args:
+        data (List[Dict]): the dataset
+        serial (Class): the Serializer instance
+        cfg: the configuration
+    """
     for d in data:
         sent = d['sentence'].strip()
         sent = sent.replace(d['head'], ' head ', 1).replace(d['tail'], ' tail ', 1)
@@ -68,6 +96,12 @@ def _serialize_sentence(data: List[Dict], serial, cfg):
 
 
 def _lm_serialize(data: List[Dict], cfg):
+    """
+    Tokenize with the LM (BERT) tokenizer.
+    Args:
+        data (List[Dict]): the dataset
+        cfg: the configuration
+    """
     logger.info('use bert tokenizer...')
     tokenizer = BertTokenizer.from_pretrained(cfg.lm_file)
     for d in data:
@@ -79,6 +113,12 @@ def _lm_serialize(data: List[Dict], cfg):
 
 
 def _add_relation_data(rels: Dict, data: List) -> None:
+    """
+    Add relation data.
+    Args:
+        rels (Dict): the relation dictionary
+        data (List): the data to enrich with relation information
+    """
     for d in data:
         d['rel2idx'] = rels[d['relation']]['index']
         d['head_type'] = rels[d['relation']]['head_type']
@@ -86,6 +126,13 @@ def _add_relation_data(rels: Dict, data: List) -> None:
 
 
 def _handle_relation_data(relation_data: List[Dict]) -> Dict:
+    """
+    Process the relation data; each relation has three attributes: index, head_type and tail_type.
+    Args:
+        relation_data (List[Dict]): the relation data to process
+    Returns:
+        rels (Dict): the processed result
+    """
     rels = OrderedDict()
     relation_data = sorted(relation_data, key=lambda i: int(i['index']))
     for d in relation_data:
@@ -99,7 +146,9 @@ def _handle_relation_data(relation_data: List[Dict]) -> Dict:
 
 
 def preprocess(cfg):
-
+    """
+    Data preprocessing stage.
+    """
     logger.info('===== start preprocess data =====')
     train_fp = os.path.join(cfg.cwd, cfg.data_path, 'train.csv')
     valid_fp = os.path.join(cfg.cwd, cfg.data_path, 'valid.csv')
diff --git a/serializer.py b/serializer.py
index c896e07..b061afe 100644
--- a/serializer.py
+++ b/serializer.py
@@ -15,6 +15,14 @@ class Serializer():
         self.do_chinese_split = do_chinese_split
 
     def serialize(self, text, never_split: List = None):
+        """
+        Split a piece of text into a list of tokens according to the specified splitting rules.
+        Args:
+            text (String): the text to split
+            never_split (List): words that must not be split; defaults to None
+        Returns:
+            output_tokens (List): the resulting tokens
+        """
         never_split = self.never_split + (never_split if never_split is not None else [])
         text = self._clean_text(text)
@@ -36,7 +44,13 @@ class Serializer():
         return output_tokens
 
     def _clean_text(self, text):
-        """Performs invalid character removal and whitespace cleanup on text."""
+        """
+        Remove invalid characters and clean up whitespace in the text.
+        Args:
+            text (String): the text to clean
+        Returns:
+            "".join(output) (String): the cleaned text
+        """
         output = []
         for char in text:
             cp = ord(char)
@@ -49,6 +63,14 @@ class Serializer():
         return "".join(output)
 
     def _use_jieba_cut(self, text, never_split):
+        """
+        Tokenize with jieba.
+        Args:
+            text (String): the text to split
+            never_split (List): words that must not be split
+        Returns:
+            tokens (List): the resulting tokens
+        """
         for word in never_split:
             jieba.suggest_freq(word, True)
         tokens = jieba.lcut(text)
@@ -61,7 +83,13 @@ class Serializer():
         return tokens
 
     def _tokenize_chinese_chars(self, text):
-        """Adds whitespace around any CJK character."""
+        """
+        Add whitespace around any CJK character.
+        Args:
+            text (String): the text to process
+        Returns:
+            "".join(output) (String): the processed text
+        """
         output = []
         for char in text:
             cp = ord(char)
@@ -74,7 +102,13 @@ class Serializer():
         return "".join(output)
 
     def _orig_tokenize(self, text):
-        """Splits text on whitespace and some punctuations like comma or period"""
+        """
+        Split text on whitespace and some punctuation such as commas or periods.
+        Args:
+            text (String): the text to split
+        Returns:
+            tokens (List): the resulting tokens
+        """
         text = text.strip()
         if not text:
             return []
@@ -86,7 +120,13 @@ class Serializer():
         return tokens
 
     def _whitespace_tokenize(self, text):
-        """Runs basic whitespace cleaning and splitting on a piece of text."""
+        """
+        Run basic whitespace cleaning and splitting on a piece of text.
+        Args:
+            text (String): the text to split
+        Returns:
+            tokens (List): the resulting tokens
+        """
         text = text.strip()
         if not text:
             return []
@@ -94,7 +134,14 @@ class Serializer():
         return tokens
 
     def _run_strip_accents(self, text):
-        """Strips accents from a piece of text."""
+        """
+        Strip accents from a piece of text.
+        Args:
+            text (String): the text to process
+        Returns:
+            "".join(output) (String): the text with accents removed
+
+        """
         text = unicodedata.normalize("NFD", text)
         output = []
         for char in text:
@@ -105,7 +152,15 @@ class Serializer():
         return "".join(output)
 
     def _run_split_on_punc(self, text, never_split=None):
-        """Splits punctuation on a piece of text."""
+        """
+        Split a piece of text on punctuation.
+        Args:
+            text (String): the text to split
+            never_split (List): words that must not be split; defaults to None
+        Returns:
+            ["".join(x) for x in output] (List): the split result
+        """
+
         if never_split is not None and text in never_split:
             return [text]
         chars = list(text)
@@ -128,7 +183,13 @@ class Serializer():
 
     @staticmethod
     def is_control(char):
-        """Checks whether `chars` is a control character."""
+        """
+        Check whether the character is a control character.
+        Args:
+            char: the character
+        Returns:
+            bool: the result
+        """
         # These are technically control characters but we count them as whitespace
         # characters.
         if char == "\t" or char == "\n" or char == "\r":
@@ -140,7 +201,13 @@ class Serializer():
 
    @staticmethod
    def is_whitespace(char):
-        """Checks whether `chars` is a whitespace character."""
+        """
+        Check whether the character is a whitespace character.
+        Args:
+            char: the character
+        Returns:
+            bool: the result
+        """
         # \t, \n, and \r are technically control characters but we treat them
         # as whitespace since they are generally considered as such.
         if char == " " or char == "\t" or char == "\n" or char == "\r":
@@ -152,7 +219,15 @@ class Serializer():
 
     @staticmethod
     def is_chinese_char(cp):
-        """Checks whether CP is the codepoint of a CJK character."""
+        """
+
+        Check whether cp is the code point of a CJK character.
+        Args:
+            cp (int): the code point of the character
+        Returns:
+            bool: the result
+
+        """
         # This defines a "chinese character" as anything in the CJK Unicode block:
         #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
         #
@@ -174,7 +249,13 @@ class Serializer():
 
     @staticmethod
     def is_punctuation(char):
-        """Checks whether `chars` is a punctuation character."""
+        """
+        Check whether the character is a punctuation character.
+        Args:
+            char: the character
+        Returns:
+            bool: the result
+        """
         cp = ord(char)
         # We treat all non-letter/number ASCII as punctuation.
         # Characters such as "^", "$", and "`" are not in the Unicode
diff --git a/trainer.py b/trainer.py
index 478bc92..c2a889e 100644
--- a/trainer.py
+++ b/trainer.py
@@ -7,6 +7,20 @@ logger = logging.getLogger(__name__)
 
 
 def train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg):
+    """
+    Train the model for one epoch.
+    Args:
+        epoch (int): the current training epoch
+        model (class): the model to train
+        dataloader (dict): the dataset loader
+        optimizer (Callable): the optimizer
+        criterion (Callable): the loss function
+        device (torch.device): the training device
+        writer (class): the output writer
+        cfg: the configuration
+    Returns:
+        losses[-1]: the last loss value
+    """
     model.train()
 
     metric = PRMetric()
@@ -55,6 +69,19 @@ def train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg):
 
 
 def validate(epoch, model, dataloader, criterion, device, cfg):
+    """
+    Validate the model.
+    Args:
+        epoch (int): the current training epoch
+        model (class): the model to validate
+        dataloader (dict): the dataset loader
+        criterion (Callable): the loss function
+        device (torch.device): the device to run on
+        cfg: the configuration
+    Returns:
+        f1: the F1 score
+        loss: the loss value
+    """
     model.eval()
 
     metric = PRMetric()
diff --git a/utils/ioUtils.py b/utils/ioUtils.py
index 42b5a91..f7e8626 100644
--- a/utils/ioUtils.py
+++ b/utils/ioUtils.py
@@ -22,6 +22,14 @@ Path = str
 
 
 def load_pkl(fp: Path, verbose: bool = True) -> Any:
+    """
+    Load a pickle file.
+    Args:
+        fp (String): the path to read from
+        verbose (bool): whether to log
+    Returns:
+        data (Any): the loaded data
+    """
     if verbose:
         logger.info(f'load data from {fp}')
@@ -31,6 +39,13 @@ def load_pkl(fp: Path, verbose: bool = True) -> Any:
 
 
 def save_pkl(data: Any, fp: Path, verbose: bool = True) -> None:
+    """
+    Save data to a pickle file.
+    Args:
+        data (Any): the data to save
+        fp (String): the path to save to
+        verbose (bool): whether to log
+    """
     if verbose:
         logger.info(f'save data in {fp}')
@@ -39,6 +54,15 @@ def save_pkl(data: Any, fp: Path, verbose: bool = True) -> None:
 
 
 def load_csv(fp: Path, is_tsv: bool = False, verbose: bool = True) -> List:
+    """
+    Load a csv file.
+    Args:
+        fp (String): the file path
+        is_tsv (bool): whether the file uses the excel-tab dialect
+        verbose (bool): whether to log
+    Returns:
+        list(reader) (List): the loaded data as a List
+    """
     if verbose:
         logger.info(f'load csv from {fp}')
 
     dialect = 'excel-tab' if is_tsv else 'excel'
@@ -49,6 +73,15 @@
 
 
 def save_csv(data: List[Dict], fp: Path, save_in_tsv: False, write_head=True, verbose=True) -> None:
+    """
+    Save data to a csv file.
+    Args:
+        data (List): the List data to save
+        fp (String): the path to save to
+        save_in_tsv (bool): whether to save with the excel-tab dialect
+        write_head (bool): whether to write the header row
+        verbose (bool): whether to log
+    """
     if verbose:
         logger.info(f'save csv file in: {fp}')
 
     with open(fp, 'w', encoding='utf-8') as f:
@@ -62,6 +95,15 @@ def save_csv(data: List[Dict], fp: Path, save_in_tsv: False, write_head=True, ve
 
 
 def load_jsonld(fp: Path, verbose: bool = True) -> List:
+
+    """
+    Load a jsonld file.
+    Args:
+        fp (String): the jsonld file path
+        verbose (bool): whether to log
+    Returns:
+        datas (List): the loaded List
+    """
     if verbose:
         logger.info(f'load jsonld from {fp}')
@@ -76,16 +118,21 @@ def load_jsonld(fp: Path, verbose: bool = True) -> List:
 
 
 def save_jsonld(fp):
+    """
+    Save a jsonld file.
+    """
     pass
 
 
 def jsonld2csv(fp: str, verbose: bool = True) -> str:
-    '''
+    """
     Read a jsonld file and save it as a csv file with the same name in the same location.
-    :param fp: the jsonld file path
-    :param verbose: whether print logging
-    :return: the csv file path
-    '''
+    Args:
+        fp (String): the jsonld file path
+        verbose (bool): whether to log
+    Returns:
+        fp_new (String): the new file path
+    """
     data = []
     root, ext = os.path.splitext(fp)
     fp_new = root + '.csv'
@@ -108,12 +155,14 @@ def jsonld2csv(fp: str, verbose: bool = True) -> str:
 
 
 def csv2jsonld(fp: str, verbose: bool = True) -> str:
-    '''
-    Read a csv file and save it as a jsonld file with the same name in the same location.
-    :param fp: the csv file path
-    :param verbose: whether print logging
-    :return: the jsonld file path
-    '''
+    """
+    Read a csv file and save it as a jsonld file with the same name in the same location.
+    Args:
+        fp (String): the csv file path
+        verbose (bool): whether to log
+    Returns:
+        fp_new (String): the new file path
+    """
     data = []
     root, ext = os.path.splitext(fp)
     fp_new = root + '.jsonld'
diff --git a/utils/nnUtils.py b/utils/nnUtils.py
index 9d14e57..ef2a63a 100644
--- a/utils/nnUtils.py
+++ b/utils/nnUtils.py
@@ -14,26 +14,30 @@ __all__ = [
 
 
 def manual_seed(seed: int = 1) -> None:
+    """
+    Set the random seed.
+    """
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
-    if torch.cuda.CUDA_ENABLED and use_deterministic_cudnn:
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
+    #if torch.cuda.CUDA_ENABLED and use_deterministic_cudnn:
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
 
 
 def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None, mask_pos_to_true=True):
     """
     Convert a 1-d array of sequence lengths into a 2-d mask; by default the padded positions are 1.
-    转变 1-d seq_len到2-d mask.
+    转变 1-d seq_len到2-d mask。
 
-    :param list, np.ndarray, torch.LongTensor seq_len: the shape will be (B,)
-    :param int max_len: pad to this length. By default (None) the longest length in seq_len is used. Under nn.DataParallel, however, different GPUs
-        may see different seq_len, so a max_len should be passed in so that the mask is padded to that length.
-    :return: np.ndarray, torch.Tensor. The shape will be (B, max_length); the element type is bool or torch.uint8
+    Args:
+        seq_len (list, np.ndarray, torch.LongTensor): the shape will be (B,)
+        max_len (int): pad to this length. By default (None) the longest length in seq_len is used. Under nn.DataParallel, however, different GPUs may see different seq_len, so a max_len should be passed in so that the mask is padded to that length.
+    Returns:
+        mask (np.ndarray, torch.Tensor): the shape will be (B, max_length); the element type is bool or torch.uint8
     """
     if isinstance(seq_len, list):
         seq_len = np.array(seq_len)
@@ -58,9 +62,11 @@ def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None
 
 def to_one_hot(x: torch.Tensor, length: int) -> torch.Tensor:
     """
-    :param x: [B], usually the target values
-    :param length: L, usually the number of relation types
-    :return: [B, L] each row has a 1 only at the corresponding position and 0 elsewhere
+    Args:
+        x (torch.Tensor): [B], usually the target values
+        length (int): L, usually the number of relation types
+    Returns:
+        x_one_hot.to(device=x.device) (torch.Tensor): [B, L] each row has a 1 only at the corresponding position and 0 elsewhere
     """
     B = x.size(0)
     x_one_hot = torch.zeros(B, length)
diff --git a/vocab.py b/vocab.py
index a4728cd..2b877ac 100644
--- a/vocab.py
+++ b/vocab.py
@@ -33,6 +33,9 @@ SPECIAL_TOKENS = OrderedDict(zip(SPECIAL_TOKENS_KEYS, SPECIAL_TOKENS_VALUES))
 
 
 class Vocab(object):
+    """
+    Build the vocabulary: add words and remove low-frequency words.
+    """
     def __init__(self, name: str = 'basic', init_tokens: Sequence = SPECIAL_TOKENS):
         self.name = name
         self.init_tokens = init_tokens
@@ -44,10 +47,18 @@ class Vocab(object):
         self._add_init_tokens()
 
     def _add_init_tokens(self):
+        """
+        Add the initial tokens.
+        """
         for token in self.init_tokens.values():
             self._add_word(token)
 
     def _add_word(self, word: str):
+        """
+        Add a single word.
+        Args:
+            word (String): the word to add
+        """
         if word not in self.word2idx:
             self.word2idx[word] = self.count
             self.word2count[word] = 1
@@ -57,15 +68,21 @@ class Vocab(object):
             self.word2count[word] += 1
 
     def add_words(self, words: Sequence):
+        """
+        Add words from a sequence.
+        Args:
+            words (List): the words to add
+        """
         for word in words:
             self._add_word(word)
 
     def trim(self, min_freq=2, verbose: Optional[bool] = True):
-        '''Remove a word from the vocabulary when its frequency is below min_freq.
-
-        Args:
-            param min_freq: the minimum frequency
-        '''
+        """
+        Remove a word from the vocabulary when its frequency is below min_freq.
+        Args:
+            min_freq (int): the minimum frequency
+            verbose (bool): whether to log
+        """
         assert min_freq == int(min_freq), f'min_freq must be integer, can\'t be {min_freq}'
         min_freq = int(min_freq)
         if min_freq < 2: