add annotation

This commit is contained in:
tlk-dsg 2021-04-14 13:29:01 +08:00
parent b4c55ac8c4
commit fa400b26b5
14 changed files with 332 additions and 72 deletions

View File

@ -4,7 +4,15 @@ from utils import load_pkl
def collate_fn(cfg):
def collate_fn_intra(batch):
"""
Args:
batch () : the dataset
Returns:
x (dict) : keys are the words, values are the lengths
y (List) : collection of the corresponding relation label values
"""
batch.sort(key=lambda data: data['seq_len'], reverse=True)
max_len = batch[0]['seq_len']
@ -48,7 +56,9 @@ def collate_fn(cfg):
class CustomDataset(Dataset):
"""默认使用 List 存储数据"""
"""
默认使用 List 存储数据
"""
def __init__(self, fp):
self.file = load_pkl(fp)
@ -68,7 +78,6 @@ if __name__ == '__main__':
vocab = load_pkl(vocab_path)
train_ds = CustomDataset(train_data_path)
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn, drop_last=False)
for batch_idx, (x, y) in enumerate(train_dl):
word = x['word']
for idx in word:
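
A hedged sketch of consuming one batch produced by the collate_fn above; the 'lens' key is an assumption based on how the models below read their inputs:

    # hypothetical continuation of the __main__ block above
    for batch_idx, (x, y) in enumerate(train_dl):
        word = x['word']   # padded word-id tensor, roughly [batch_size, max_len]
        lens = x['lens']   # assumed key holding the per-sample sequence lengths
        print(word.shape, len(y))
        break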

View File

@ -44,10 +44,16 @@ class PRMetric():
self.y_pred = np.empty(0)
def reset(self):
"""
Reset the stored labels and predictions to empty (zero-length) arrays
"""
self.y_true = np.empty(0)
self.y_pred = np.empty(0)
def update(self, y_true: torch.Tensor, y_pred: torch.Tensor):
"""
Update with new tensors: keep the values and detach them from the original gradients
"""
y_true = y_true.cpu().detach().numpy()
y_pred = y_pred.cpu().detach().numpy()
y_pred = np.argmax(y_pred, axis=-1)
@ -56,6 +62,9 @@ class PRMetric():
self.y_pred = np.append(self.y_pred, y_pred)
def compute(self):
"""
Compute and return acc, p, r and f1
"""
p, r, f1, _ = precision_recall_fscore_support(self.y_true, self.y_pred, average='macro', warn_for=tuple())
_, _, acc, _ = precision_recall_fscore_support(self.y_true, self.y_pred, average='micro', warn_for=tuple())
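
A small usage sketch of PRMetric based on the methods documented above; the return order of compute() (acc, p, r, f1) follows the docstring and is otherwise an assumption:

    import torch
    # assumes PRMetric is imported from this metrics module
    metric = PRMetric()
    metric.reset()
    y_true = torch.tensor([0, 1, 1])
    y_pred = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]])  # logits; update() takes the argmax
    metric.update(y_true, y_pred)
    acc, p, r, f1 = metric.compute()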

View File

@ -16,7 +16,9 @@ class LM(BasicModule):
def forward(self, x):
word, lens = x['word'], x['lens']
mask = seq_len_to_mask(lens, mask_pos_to_true=False)
last_hidden_state, pooler_output = self.bert(word, attention_mask=mask)
a = self.bert(word, attention_mask=mask)
last_hidden_state = a[0]
pooler_output = a[1]
out, out_pool = self.bilstm(last_hidden_state, lens)
out_pool = self.dropout(out_pool)
output = self.fc(out_pool)
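
For context on the change above: recent versions of huggingface transformers return a model-output object rather than a plain tuple, so tuple unpacking no longer yields the tensors, while integer indexing works for both. A hedged sketch of the equivalent, more explicit form:

    outputs = self.bert(word, attention_mask=mask)
    last_hidden_state = outputs[0]   # [B, L, H]
    pooler_output = outputs[1]       # [B, H]
    # with recent transformers the same values are also available as
    # outputs.last_hidden_state and outputs.pooler_output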

View File

@ -13,16 +13,17 @@ class DotAttention(nn.Module):
def forward(self, Q, K, V, mask_out=None, head_mask=None):
"""
For a typical input X, assume K = V = X
att_weight = softmax( score_func(q, k) )
att = sum( att_weight * v )
:param Q: [..., L, H]
:param K: [..., S, H]
:param V: [..., S, H]
:param mask_out: [..., 1, S]
:return:
Args:
Q: [..., L, H]
K: [..., S, H]
V: [..., S, H]
mask_out: [..., 1, S]
Return:
attention_out
attention_weight
"""
H = Q.size(-1)
@ -52,9 +53,10 @@ class DotAttention(nn.Module):
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads, dropout=0.0, output_attentions=True):
"""
:param embed_dim: input dimension, must be divisible by num_heads
:param num_heads: number of attention heads
:param dropout: float
Args:
embed_dim: input dimension, must be divisible by num_heads
num_heads: number of attention heads
dropout: float
"""
super(MultiHeadAttention, self).__init__()
self.num_heads = num_heads
@ -72,12 +74,13 @@ class MultiHeadAttention(nn.Module):
def forward(self, Q, K, V, key_padding_mask=None, attention_mask=None, head_mask=None):
"""
:param Q: [B, L, Hs]
:param K: [B, S, Hs]
:param V: [B, S, Hs]
:param key_padding_mask: [B, S] positions marked 1/True are masked
:param attention_mask: [S] / [L, S] masks the specified positions; 1/True positions are masked
:param head_mask: [N] masks the specified heads; 1/True positions are masked
Args:
Q: [B, L, Hs]
K: [B, S, Hs]
V: [B, S, Hs]
key_padding_mask: [B, S] positions marked 1/True are masked
attention_mask: [S] / [L, S] masks the specified positions; 1/True positions are masked
head_mask: [N] masks the specified heads; 1/True positions are masked
"""
B, L, Hs = Q.shape
S = V.size(1)
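
A shape-level usage sketch of MultiHeadAttention as documented above (assuming the class is imported from this module; the forward return value is not unpacked because it depends on output_attentions):

    import torch

    B, L, S, Hs = 2, 5, 7, 64
    mha = MultiHeadAttention(embed_dim=Hs, num_heads=4, dropout=0.1)
    Q = torch.randn(B, L, Hs)
    K = torch.randn(B, S, Hs)
    V = torch.randn(B, S, Hs)
    key_padding_mask = torch.zeros(B, S, dtype=torch.bool)  # 1/True positions would be masked
    result = mha(Q, K, V, key_padding_mask=key_padding_mask)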

View File

@ -21,12 +21,13 @@ class CNN(nn.Module):
"""
def __init__(self, config):
"""
in_channels : usually the word embedding dimension or the hidden size
out_channels : int
kernel_sizes : list; must be odd (3, 5, 7, ...) so that output length = input length
activation : [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
pooling_strategy : [max, avg, cls]
dropout : float
Args:
in_channels: usually the word embedding dimension or the hidden size
out_channels: int
kernel_sizes: list; must be odd (3, 5, 7, ...) so that output length = input length
activation: [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
pooling_strategy: [max, avg, cls]
dropout: float
"""
super(CNN, self).__init__()
@ -74,10 +75,13 @@ class CNN(nn.Module):
def forward(self, x, mask=None):
"""
:param x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H], usually the values after embedding
:param mask: [batch_size, max_len]; non-pad positions are 0 and padded positions are 1; this does not affect the convolution, and max-pool will never pool a padded (zero) position
:return:
"""
Args:
x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H], usually the values after embedding
mask: [batch_size, max_len]; non-pad positions are 0 and padded positions are 1; this does not affect the convolution, and max-pool will never pool a padded (zero) position
Return:
x: torch.Tensor
xp: torch.Tensor
"""
# [B, L, H] -> [B, H, L] (treat the H dimension as the input channel dimension)
x = torch.transpose(x, 1, 2)

View File

@ -27,7 +27,8 @@ class Capsule(nn.Module):
def forward(self, x):
"""
x: [B, L, H] # output from the CNN / RNN
Args:
x: [B, L, H] # output from the CNN / RNN
L serves as input_num_capsules, H as input_dim_capsule
"""
B, I, _ = x.size() # I is input_num_capsules

View File

@ -5,9 +5,10 @@ import torch.nn as nn
class Embedding(nn.Module):
def __init__(self, config):
"""
word embedding: usually uses index 0 for padding
pos embedding: usually uses index 0 for padding
dim_strategy: [cat, sum] whether multiple embeddings are concatenated or summed
Args:
word embedding: usually uses index 0 for padding
pos embedding: usually uses index 0 for padding
dim_strategy: [cat, sum] whether multiple embeddings are concatenated or summed
"""
super(Embedding, self).__init__()

View File

@ -55,6 +55,7 @@ class RNN(nn.Module):
B, L, _ = x.size()
H, N = self.hidden_size, self.num_layers
x_len = x_len.cpu()
x = pack_padded_sequence(x, x_len, batch_first=True, enforce_sorted=True)
output, hn = self.rnn(x)
output, _ = pad_packed_sequence(output, batch_first=True, total_length=L)
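
The x_len.cpu() line added above matters because, from PyTorch 1.7 on, pack_padded_sequence requires the lengths argument to live on the CPU even when the inputs are on the GPU. A minimal standalone sketch:

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    x = torch.randn(3, 4, 8)               # [B, L, H], already sorted by length
    x_len = torch.tensor([4, 3, 2]).cpu()  # lengths must be a CPU tensor
    packed = pack_padded_sequence(x, x_len, batch_first=True, enforce_sorted=True)
    output, _ = pad_packed_sequence(packed, batch_first=True, total_length=4)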

View File

@ -11,6 +11,15 @@ logger = logging.getLogger(__name__)
def _handle_pos_limit(pos: List[int], limit: int) -> List[int]:
"""
Clip position values to the configured limit
Args:
pos (List[int]) : list of positions for the sentence
limit (int) : the limit value
Return:
[p + limit + 1 for p in pos] (List[int]) : the processed result
"""
for i, p in enumerate(pos):
if p > limit:
pos[i] = limit
@ -20,6 +29,12 @@ def _handle_pos_limit(pos: List[int], limit: int) -> List[int]:
def _add_pos_seq(train_data: List[Dict], cfg):
"""
Add position sequences
Args:
train_data (List[Dict]) : the dataset
cfg : the configuration
"""
for d in train_data:
entities_idx = [d['head_idx'], d['tail_idx']
] if d['head_idx'] < d['tail_idx'] else [d['tail_idx'], d['head_idx']]
@ -39,6 +54,12 @@ def _add_pos_seq(train_data: List[Dict], cfg):
def _convert_tokens_into_index(data: List[Dict], vocab):
"""
Convert tokens into index values
Args:
data (List[Dict]) : the dataset
vocab (Class) : the vocabulary
"""
unk_str = '[UNK]'
unk_idx = vocab.word2idx[unk_str]
@ -48,6 +69,13 @@ def _convert_tokens_into_index(data: List[Dict], vocab):
def _serialize_sentence(data: List[Dict], serial, cfg):
"""
Tokenize the sentences
Args:
data (List[Dict]) : the dataset
serial (Class): the Serializer class
cfg : the configuration
"""
for d in data:
sent = d['sentence'].strip()
sent = sent.replace(d['head'], ' head ', 1).replace(d['tail'], ' tail ', 1)
@ -68,6 +96,12 @@ def _serialize_sentence(data: List[Dict], serial, cfg):
def _lm_serialize(data: List[Dict], cfg):
"""
Tokenize for the LM model
Args:
data (List[Dict]) : the dataset
cfg : the configuration
"""
logger.info('use bert tokenizer...')
tokenizer = BertTokenizer.from_pretrained(cfg.lm_file)
for d in data:
@ -79,6 +113,12 @@ def _lm_serialize(data: List[Dict], cfg):
def _add_relation_data(rels: Dict, data: List) -> None:
"""
Add relation data
Args:
rels (Dict) : dictionary of relations
data (List) : the data to which relation information is added
"""
for d in data:
d['rel2idx'] = rels[d['relation']]['index']
d['head_type'] = rels[d['relation']]['head_type']
@ -86,6 +126,13 @@ def _add_relation_data(rels: Dict, data: List) -> None:
def _handle_relation_data(relation_data: List[Dict]) -> Dict:
"""
Process the relation data; each relation has three attributes: index, head_type and tail_type
Args:
relation_data (List[Dict]) : the relation data to process
Return:
rels (Dict) : the processed result
"""
rels = OrderedDict()
relation_data = sorted(relation_data, key=lambda i: int(i['index']))
for d in relation_data:
@ -99,7 +146,9 @@ def _handle_relation_data(relation_data: List[Dict]) -> Dict:
def preprocess(cfg):
"""
Data preprocessing stage
"""
logger.info('===== start preprocess data =====')
train_fp = os.path.join(cfg.cwd, cfg.data_path, 'train.csv')
valid_fp = os.path.join(cfg.cwd, cfg.data_path, 'valid.csv')
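
A worked sketch of the position handling documented above; clipping values below -limit is an assumption, since only the upper clip is visible in this hunk:

    pos = [-5, -1, 0, 3]
    limit = 2
    clipped = [max(-limit, min(limit, p)) for p in pos]  # -> [-2, -1, 0, 2]
    shifted = [p + limit + 1 for p in clipped]           # -> [1, 2, 3, 5], all indices non-negative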

View File

@ -15,6 +15,14 @@ class Serializer():
self.do_chinese_split = do_chinese_split
def serialize(self, text, never_split: List = None):
"""
Split a piece of text into a list of tokens according to the specified splitting rules
Args:
text (String) : the text to split
never_split (List) : words that are never split; empty by default
Return:
output_tokens (List): the split result
"""
never_split = self.never_split + (never_split if never_split is not None else [])
text = self._clean_text(text)
@ -36,7 +44,13 @@ class Serializer():
return output_tokens
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
"""
Remove invalid characters and clean up whitespace in the text
Args:
text (String) : the text to clean
Return:
"".join(output) (String) : the cleaned text
"""
output = []
for char in text:
cp = ord(char)
@ -49,6 +63,14 @@ class Serializer():
return "".join(output)
def _use_jieba_cut(self, text, never_split):
"""
Tokenize with jieba
Args:
text (String) : the text to split
never_split (List) : words that are never split
Return:
tokens (List) : the split result
"""
for word in never_split:
jieba.suggest_freq(word, True)
tokens = jieba.lcut(text)
@ -61,7 +83,13 @@ class Serializer():
return tokens
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
"""
Add whitespace around CJK characters
Args:
text (String) : the text to process
Return:
"".join(output) (String) : the text with spaces added
"""
output = []
for char in text:
cp = ord(char)
@ -74,7 +102,13 @@ class Serializer():
return "".join(output)
def _orig_tokenize(self, text):
"""Splits text on whitespace and some punctuations like comma or period"""
"""
Split the text on whitespace and some punctuation such as commas or periods
Args:
text (String) : the text to split
Return:
tokens (List) : the tokenized result
"""
text = text.strip()
if not text:
return []
@ -86,7 +120,13 @@ class Serializer():
return tokens
def _whitespace_tokenize(self, text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
"""
Perform basic whitespace cleanup and splitting
Args:
text (String) : the text to split
Return:
tokens (List) : the tokenized result
"""
text = text.strip()
if not text:
return []
@ -94,7 +134,14 @@ class Serializer():
return tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
"""
Strip accents from the text
Args:
text (String) : the text to process
Return:
"".join(output) (String) : the text with accents removed
"""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
@ -105,7 +152,15 @@ class Serializer():
return "".join(output)
def _run_split_on_punc(self, text, never_split=None):
"""Splits punctuation on a piece of text."""
"""
Split the text on punctuation
Args:
text (String) : the text to split
never_split (List) : words that are never split; empty by default
Return:
["".join(x) for x in output] (List) : the split result
"""
if never_split is not None and text in never_split:
return [text]
chars = list(text)
@ -128,7 +183,13 @@ class Serializer():
@staticmethod
def is_control(char):
"""Checks whether `chars` is a control character."""
"""
Check whether the character is a control character
Args:
char : the character
Return:
bool : the result of the check
"""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
@ -140,7 +201,13 @@ class Serializer():
@staticmethod
def is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
"""
Check whether the character is a whitespace character
Args:
char : the character
Return:
bool : the result of the check
"""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
@ -152,7 +219,15 @@ class Serializer():
@staticmethod
def is_chinese_char(cp):
"""Checks whether CP is the codepoint of a CJK character."""
"""
Check whether the codepoint belongs to a CJK character
Args:
cp : the character's Unicode codepoint
Return:
bool : the result of the check
"""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
@ -174,7 +249,13 @@ class Serializer():
@staticmethod
def is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
"""
Check whether the character is a punctuation character
Args:
char : the character
Return:
bool : the result of the check
"""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
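
A quick sketch of the character-level helpers documented above (static methods, so no Serializer instance is needed; assumes the class is importable):

    print(Serializer.is_chinese_char(ord('中')))  # True: the codepoint falls in a CJK block
    print(Serializer.is_whitespace(' '))          # True
    print(Serializer.is_punctuation(','))         # True
    print(Serializer.is_control('\t'))            # False: tabs are treated as whitespace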

View File

@ -7,6 +7,20 @@ logger = logging.getLogger(__name__)
def train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg):
"""
Train the model
Args:
epoch (int): the current training epoch
model (class): the model to train
dataloader (dict): the dataset
optimizer (Callable): the optimizer
criterion (Callable): the loss function
device (torch.device): the device used for training
writer (class): the output writer
cfg: the configuration
Return:
losses[-1] : the loss value
"""
model.train()
metric = PRMetric()
@ -55,6 +69,20 @@ def train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg):
def validate(epoch, model, dataloader, criterion, device, cfg):
"""
Validate the model
Args:
epoch (int): the current training epoch
model (class): the model to validate
dataloader (dict): the dataset
criterion (Callable): the loss function
device (torch.device): the device used for evaluation
cfg: the configuration
Return:
f1 : the f1 value
loss : the loss value
"""
model.eval()
metric = PRMetric()
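
A hedged sketch of how train and validate fit together according to the docstrings above; cfg.epoch and the surrounding objects are assumptions, not shown in this diff:

    best_f1, best_epoch = -1.0, 0
    for epoch in range(1, cfg.epoch + 1):  # assumption: cfg.epoch holds the number of epochs
        train(epoch, model, train_dataloader, optimizer, criterion, device, writer, cfg)
        f1, loss = validate(epoch, model, valid_dataloader, criterion, device, cfg)
        if f1 > best_f1:
            best_f1, best_epoch = f1, epoch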

View File

@ -22,6 +22,14 @@ Path = str
def load_pkl(fp: Path, verbose: bool = True) -> Any:
"""
Load a pickle file
Args:
fp (String) : path to load the data from
verbose (bool) : whether to print logging
Return:
data (Any) : the loaded data
"""
if verbose:
logger.info(f'load data from {fp}')
@ -31,6 +39,13 @@ def load_pkl(fp: Path, verbose: bool = True) -> Any:
def save_pkl(data: Any, fp: Path, verbose: bool = True) -> None:
"""
Save data to a pickle file
Args:
data (Any) : the data
fp (String) : path to save to
verbose (bool) : whether to print logging
"""
if verbose:
logger.info(f'save data in {fp}')
@ -39,6 +54,15 @@ def save_pkl(data: Any, fp: Path, verbose: bool = True) -> None:
def load_csv(fp: Path, is_tsv: bool = False, verbose: bool = True) -> List:
"""
Load a csv file
Args:
fp (String) : path of the file
is_tsv (bool) : whether the file uses the excel-tab dialect
verbose (bool) : whether to print logging
Return:
list(reader) (List): the loaded data as a List
"""
if verbose:
logger.info(f'load csv from {fp}')
@ -49,6 +73,15 @@ def load_csv(fp: Path, is_tsv: bool = False, verbose: bool = True) -> List:
def save_csv(data: List[Dict], fp: Path, save_in_tsv: False, write_head=True, verbose=True) -> None:
"""
Save data to a csv file
Args:
data (List) : the List data to save
fp (String) : path to save to
save_in_tsv (bool) : whether to save using the excel-tab dialect
write_head (bool) : whether to write the header row
verbose (bool) : whether to print logging
"""
if verbose:
logger.info(f'save csv file in: {fp}')
@ -62,6 +95,15 @@ def save_csv(data: List[Dict], fp: Path, save_in_tsv: False, write_head=True, ve
def load_jsonld(fp: Path, verbose: bool = True) -> List:
"""
Load a jsonld file
Args:
fp (String): path of the jsonld file
verbose (bool): whether to print logging
Return:
datas (List) : the loaded List
"""
if verbose:
logger.info(f'load jsonld from {fp}')
@ -76,16 +118,21 @@ def load_jsonld(fp: Path, verbose: bool = True) -> List:
def save_jsonld(fp):
"""
Save a jsonld file
"""
pass
def jsonld2csv(fp: str, verbose: bool = True) -> str:
'''
"""
Read a jsonld file and save it as a csv file with the same name in the same location
:param fp: path of the jsonld file
:param verbose: whether to print logging
:return: path of the csv file
'''
Args:
fp (String): path of the jsonld file
verbose (bool): whether to print logging
Return:
fp_new (String): the new file path
"""
data = []
root, ext = os.path.splitext(fp)
fp_new = root + '.csv'
@ -108,12 +155,14 @@ def jsonld2csv(fp: str, verbose: bool = True) -> str:
def csv2jsonld(fp: str, verbose: bool = True) -> str:
'''
Read a csv file and save it as a jsonld file with the same name in the same location
:param fp: path of the csv file
:param verbose: whether to print logging
:return: path of the jsonld file
'''
"""
Read a csv file and save it as a jsonld file with the same name in the same location
Args:
fp (String): path of the csv file
verbose (bool): whether to print logging
Return:
fp_new (String): the new file path
"""
data = []
root, ext = os.path.splitext(fp)
fp_new = root + '.jsonld'
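
A minimal round-trip sketch of the pickle helpers documented above (the path is hypothetical):

    save_pkl({'relation': 'founder', 'index': 3}, 'demo.pkl')
    data = load_pkl('demo.pkl')  # -> {'relation': 'founder', 'index': 3}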

View File

@ -14,26 +14,30 @@ __all__ = [
def manual_seed(seed: int = 1) -> None:
"""
Set the random seed
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if torch.cuda.CUDA_ENABLED and use_deterministic_cudnn:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
#if torch.cuda.CUDA_ENABLED and use_deterministic_cudnn:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None, mask_pos_to_true=True):
"""
Convert a one-dimensional array of sequence lengths into a two-dimensional mask; by default the padded positions are 1
Convert a 1-d seq_len into a 2-d mask
:param list, np.ndarray, torch.LongTensor seq_len: shape is (B,)
:param int max_len: pad every mask to this length; by default (None) the longest length in seq_len is used, but under nn.DataParallel different GPUs may see different seq_len, so a max_len can be passed in so that every mask is padded to that length
:return: np.ndarray, torch.Tensor of shape (B, max_length); element type is bool or torch.uint8
Args:
seq_len (list, np.ndarray, torch.LongTensor) : shape is (B,)
max_len (int): pad every mask to this length; by default (None) the longest length in seq_len is used, but under nn.DataParallel different GPUs may see different seq_len, so a max_len can be passed in so that every mask is padded to that length
Return:
mask (np.ndarray, torch.Tensor) : shape is (B, max_length); element type is bool or torch.uint8
"""
if isinstance(seq_len, list):
seq_len = np.array(seq_len)
@ -58,9 +62,11 @@ def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None
def to_one_hot(x: torch.Tensor, length: int) -> torch.Tensor:
"""
:param x: [B] usually the target values
:param length: L, usually the number of relation classes
:return: [B, L] each row has a 1 only at the corresponding position and 0 elsewhere
Args:
x (torch.Tensor): [B], usually the target values
length (int) : L, usually the number of relation classes
Return:
x_one_hot.to(device=x.device) (torch.Tensor) : [B, L] each row has a 1 only at the corresponding position and 0 elsewhere
"""
B = x.size(0)
x_one_hot = torch.zeros(B, length)
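
A small sketch of the two helpers documented above (assuming both are imported from this module); shapes follow the docstrings:

    import torch

    seq_len = torch.tensor([3, 1, 2])
    mask = seq_len_to_mask(seq_len)                # shape (3, 3); padded positions are True by default
    one_hot = to_one_hot(torch.tensor([0, 2]), 4)  # shape (2, 4), a single 1 per row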

View File

@ -33,6 +33,9 @@ SPECIAL_TOKENS = OrderedDict(zip(SPECIAL_TOKENS_KEYS, SPECIAL_TOKENS_VALUES))
class Vocab(object):
"""
Build the vocabulary: add words and remove low-frequency words
"""
def __init__(self, name: str = 'basic', init_tokens: Sequence = SPECIAL_TOKENS):
self.name = name
self.init_tokens = init_tokens
@ -44,10 +47,18 @@ class Vocab(object):
self._add_init_tokens()
def _add_init_tokens(self):
"""
Add the initial tokens
"""
for token in self.init_tokens.values():
self._add_word(token)
def _add_word(self, word: str):
"""
Add a single word
Args:
word (String) : the word to add
"""
if word not in self.word2idx:
self.word2idx[word] = self.count
self.word2count[word] = 1
@ -57,15 +68,21 @@ class Vocab(object):
self.word2count[word] += 1
def add_words(self, words: Sequence):
"""
Add words from a sequence
Args:
words (List) : the words to add
"""
for word in words:
self._add_word(word)
def trim(self, min_freq=2, verbose: Optional[bool] = True):
'''Remove a word from the vocabulary when its frequency is below min_freq
Args:
param min_freq: the minimum word frequency
'''
"""
Remove a word from the vocabulary when its frequency is below min_freq
Args:
min_freq (int): the minimum word frequency
verbose (bool) : whether to print logging
"""
assert min_freq == int(min_freq), f'min_freq must be integer, can\'t be {min_freq}'
min_freq = int(min_freq)
if min_freq < 2: