add annotation
parent b4c55ac8c4
commit fa400b26b5
dataset.py (13 lines changed)
@@ -4,7 +4,15 @@ from utils import load_pkl
 def collate_fn(cfg):
     def collate_fn_intra(batch):
+        """
+        Args:
+            batch () : the dataset batch
+        Returns:
+            x (dict) : key is the word field, value is the length
+            y (List) : the collection of relation label values
+        """
         batch.sort(key=lambda data: data['seq_len'], reverse=True)

         max_len = batch[0]['seq_len']
@@ -48,7 +56,9 @@ def collate_fn(cfg):
 class CustomDataset(Dataset):
-    """By default the data are stored in a List"""
+    """
+    By default the data are stored in a List
+    """
     def __init__(self, fp):
         self.file = load_pkl(fp)
@@ -68,7 +78,6 @@ if __name__ == '__main__':
     vocab = load_pkl(vocab_path)
     train_ds = CustomDataset(train_data_path)
     train_dl = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn, drop_last=False)

     for batch_idx, (x, y) in enumerate(train_dl):
         word = x['word']
         for idx in word:
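Aside, not part of the diff: a minimal sketch of what a sort-then-pad collate function with the interface documented above typically does. The field names (word, seq_len, rel2idx) follow those visible elsewhere in this commit; the helper name, padding index, and other details are assumptions, not the repository's exact implementation.

    import torch

    def pad_collate(batch, pad_idx=0):
        # Sort by true length, longest first, as the docstring describes.
        batch.sort(key=lambda d: d['seq_len'], reverse=True)
        max_len = batch[0]['seq_len']
        # Pad every word-id sequence to max_len and stack the batch into tensors.
        word = torch.tensor([d['word'] + [pad_idx] * (max_len - d['seq_len']) for d in batch])
        lens = torch.tensor([d['seq_len'] for d in batch])
        y = torch.tensor([d['rel2idx'] for d in batch])
        return {'word': word, 'lens': lens}, y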
@@ -44,10 +44,16 @@ class PRMetric():
         self.y_pred = np.empty(0)

     def reset(self):
+        """
+        Reset the stored values to empty
+        """
         self.y_true = np.empty(0)
         self.y_pred = np.empty(0)

     def update(self, y_true: torch.Tensor, y_pred: torch.Tensor):
+        """
+        Update with new tensors, keeping the values and detaching any existing gradients
+        """
         y_true = y_true.cpu().detach().numpy()
         y_pred = y_pred.cpu().detach().numpy()
         y_pred = np.argmax(y_pred, axis=-1)
@@ -56,6 +62,9 @@ class PRMetric():
         self.y_pred = np.append(self.y_pred, y_pred)

     def compute(self):
+        """
+        Compute acc, p, r and f1 and return them
+        """
         p, r, f1, _ = precision_recall_fscore_support(self.y_true, self.y_pred, average='macro', warn_for=tuple())
         _, _, acc, _ = precision_recall_fscore_support(self.y_true, self.y_pred, average='micro', warn_for=tuple())
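Aside, not part of the diff: what compute relies on. With micro averaging over a single-label multiclass problem, precision, recall and F1 all collapse to accuracy, which is why the third micro value is used as acc. A small standalone check (the data here is made up for illustration):

    import numpy as np
    import torch
    from sklearn.metrics import precision_recall_fscore_support

    y_true = np.array([0, 1, 2, 1])
    logits = torch.randn(4, 3)                                   # fake model outputs
    y_pred = np.argmax(logits.detach().cpu().numpy(), axis=-1)

    p, r, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro', warn_for=tuple())
    _, _, acc, _ = precision_recall_fscore_support(y_true, y_pred, average='micro', warn_for=tuple())
    print(np.isclose(acc, (y_true == y_pred).mean()))            # True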
@@ -16,7 +16,9 @@ class LM(BasicModule):
     def forward(self, x):
         word, lens = x['word'], x['lens']
         mask = seq_len_to_mask(lens, mask_pos_to_true=False)
-        last_hidden_state, pooler_output = self.bert(word, attention_mask=mask)
+        a = self.bert(word, attention_mask=mask)
+        last_hidden_state = a[0]
+        pooler_output = a[1]
         out, out_pool = self.bilstm(last_hidden_state, lens)
         out_pool = self.dropout(out_pool)
         output = self.fc(out_pool)
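Aside, not part of the diff, but useful context for the change above: recent versions of the transformers library return a ModelOutput object rather than a plain tuple, so tuple-unpacking the call result can break or yield keys instead of tensors, while integer indexing (outputs[0], outputs[1]) works with both return styles. A hedged, standalone sketch of the same indexing pattern; the checkpoint name is only an example:

    from transformers import BertModel, BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')   # example checkpoint
    bert = BertModel.from_pretrained('bert-base-chinese')
    inputs = tokenizer('这是一个测试句子', return_tensors='pt')        # any example sentence
    outputs = bert(**inputs)
    last_hidden_state = outputs[0]   # [B, L, H]
    pooler_output = outputs[1]       # [B, H]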
@@ -13,16 +13,17 @@ class DotAttention(nn.Module):

     def forward(self, Q, K, V, mask_out=None, head_mask=None):
         """
         When the input is X, it is generally assumed that K = V = X
         att_weight = softmax( score_func(q, k) )
         att = sum( att_weight * v )

-        :param Q: [..., L, H]
-        :param K: [..., S, H]
-        :param V: [..., S, H]
-        :param mask_out: [..., 1, S]
-        :return:
+        Args:
+            Q: [..., L, H]
+            K: [..., S, H]
+            V: [..., S, H]
+            mask_out: [..., 1, S]
+        Return:
+            attention_out
+            attention_weight
         """
         H = Q.size(-1)
@@ -52,9 +53,10 @@ class DotAttention(nn.Module):
 class MultiHeadAttention(nn.Module):
     def __init__(self, embed_dim, num_heads, dropout=0.0, output_attentions=True):
         """
-        :param embed_dim: input dimension, must be divisible by num_heads
-        :param num_heads: number of attention heads
-        :param dropout: float
+        Args:
+            embed_dim: input dimension, must be divisible by num_heads
+            num_heads: number of attention heads
+            dropout: float
         """
         super(MultiHeadAttention, self).__init__()
         self.num_heads = num_heads
@@ -72,12 +74,13 @@ class MultiHeadAttention(nn.Module):

     def forward(self, Q, K, V, key_padding_mask=None, attention_mask=None, head_mask=None):
         """
-        :param Q: [B, L, Hs]
-        :param K: [B, S, Hs]
-        :param V: [B, S, Hs]
-        :param key_padding_mask: [B, S] positions that are 1/True are masked out
-        :param attention_mask: [S] / [L, S] mask out the specified positions; positions that are 1/True are masked out
-        :param head_mask: [N] mask out the specified heads; entries that are 1/True are masked out
+        Args:
+            Q: [B, L, Hs]
+            K: [B, S, Hs]
+            V: [B, S, Hs]
+            key_padding_mask: [B, S] positions that are 1/True are masked out
+            attention_mask: [S] / [L, S] mask out the specified positions; positions that are 1/True are masked out
+            head_mask: [N] mask out the specified heads; entries that are 1/True are masked out
         """
         B, L, Hs = Q.shape
         S = V.size(1)
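Aside, not part of the diff: a minimal sketch of the scaled dot-product attention that the DotAttention docstring describes (att_weight = softmax(score_func(q, k)), att = sum(att_weight * v)). Shapes follow the docstring; details such as dropout and head masks are omitted, and the function name is illustrative.

    import torch
    import torch.nn.functional as F

    def dot_attention(Q, K, V, mask_out=None):
        H = Q.size(-1)
        # score(q, k) = q . k / sqrt(H); scores have shape [..., L, S]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / (H ** 0.5)
        if mask_out is not None:
            # positions where mask_out is True are excluded from the softmax
            scores = scores.masked_fill(mask_out, float('-inf'))
        attention_weight = F.softmax(scores, dim=-1)
        attention_out = torch.matmul(attention_weight, V)   # [..., L, H]
        return attention_out, attention_weight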
@@ -21,12 +21,13 @@ class CNN(nn.Module):
     """
     def __init__(self, config):
         """
-        in_channels : usually the word-embedding dimension, or the hidden size
-        out_channels : int
-        kernel_sizes : list; to keep output length = input length, every size must be odd: 3, 5, 7...
-        activation : [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
-        pooling_strategy : [max, avg, cls]
-        dropout : float
+        Args:
+            in_channels: usually the word-embedding dimension, or the hidden size
+            out_channels: int
+            kernel_sizes: list; to keep output length = input length, every size must be odd: 3, 5, 7...
+            activation: [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]
+            pooling_strategy: [max, avg, cls]
+            dropout: float
         """
         super(CNN, self).__init__()
@@ -74,10 +75,13 @@ class CNN(nn.Module):

     def forward(self, x, mask=None):
         """
-        :param x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H], usually the output of the embedding layer
-        :param mask: [batch_size, max_len], 0 within the sentence length and 1 at padding positions. It does not affect the convolution; max-pooling will never pick a padded position
-        :return:
-        """
+        Args:
+            x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H], usually the output of the embedding layer
+            mask: [batch_size, max_len], 0 within the sentence length and 1 at padding positions. It does not affect the convolution; max-pooling will never pick a padded position
+        Return:
+            x: torch.Tensor
+            xp: torch.Tensor
+        """
         # [B, L, H] -> [B, H, L] (treat the H dimension as the input channel dimension)
         x = torch.transpose(x, 1, 2)
@@ -27,7 +27,8 @@ class Capsule(nn.Module):

     def forward(self, x):
         """
-        x: [B, L, H]  # the output of the CNN / RNN
+        Args:
+            x: [B, L, H]  # the output of the CNN / RNN
         L is used as input_num_capsules, H as input_dim_capsule
         """
         B, I, _ = x.size()  # I is input_num_capsules
@@ -5,9 +5,10 @@ import torch.nn as nn
 class Embedding(nn.Module):
     def __init__(self, config):
         """
-        word embedding: index 0 is generally used for padding
-        pos embedding: index 0 is generally used for padding
-        dim_strategy: [cat, sum] whether multiple embeddings are concatenated or summed
+        Args:
+            word embedding: index 0 is generally used for padding
+            pos embedding: index 0 is generally used for padding
+            dim_strategy: [cat, sum] whether multiple embeddings are concatenated or summed
         """
         super(Embedding, self).__init__()
@@ -55,6 +55,7 @@ class RNN(nn.Module):
         B, L, _ = x.size()
         H, N = self.hidden_size, self.num_layers

+        x_len = x_len.cpu()
         x = pack_padded_sequence(x, x_len, batch_first=True, enforce_sorted=True)
         output, hn = self.rnn(x)
         output, _ = pad_packed_sequence(output, batch_first=True, total_length=L)
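Aside, not part of the diff: the added x_len = x_len.cpu() is the usual fix for recent PyTorch releases, where pack_padded_sequence requires the lengths argument to live on the CPU even when the inputs are on the GPU. A minimal, self-contained sketch of the pattern; the module sizes are illustrative:

    import torch
    from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

    rnn = torch.nn.GRU(input_size=8, hidden_size=16, batch_first=True)
    x = torch.randn(3, 5, 8)              # [B, L, H], already sorted by length, longest first
    x_len = torch.tensor([5, 4, 2])

    packed = pack_padded_sequence(x, x_len.cpu(), batch_first=True, enforce_sorted=True)
    output, hn = rnn(packed)
    output, _ = pad_packed_sequence(output, batch_first=True, total_length=5)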
@@ -11,6 +11,15 @@ logger = logging.getLogger(__name__)

 def _handle_pos_limit(pos: List[int], limit: int) -> List[int]:
+    """
+    Handle the sentence positions, applying the length limit
+    Args:
+        pos (List[int]) : the position list for the sentence
+        limit (int) : the limit value
+    Return:
+        [p + limit + 1 for p in pos] (List[int]) : the processed result
+    """
     for i, p in enumerate(pos):
         if p > limit:
             pos[i] = limit
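Aside, not part of the diff: a self-contained paraphrase of what this helper's docstring describes, clipping relative positions to plus/minus limit and shifting them into a non-negative index range. The function name and the symmetric lower-bound clip are assumptions, since the rest of the function body is not shown in the hunk.

    from typing import List

    def clip_positions(pos: List[int], limit: int) -> List[int]:
        # Clip to [-limit, limit], then shift by limit + 1 so indices start at 1
        # (0 can stay reserved for padding in the position-embedding table).
        clipped = [max(-limit, min(p, limit)) for p in pos]
        return [p + limit + 1 for p in clipped]

    print(clip_positions([-75, -3, 0, 4, 90], limit=50))   # [1, 48, 51, 55, 101]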
@@ -20,6 +29,12 @@ def _handle_pos_limit(pos: List[int], limit: int) -> List[int]:

 def _add_pos_seq(train_data: List[Dict], cfg):
+    """
+    Add position sequences
+    Args:
+        train_data (List[Dict]) : the dataset
+        cfg : the config
+    """
     for d in train_data:
         entities_idx = [d['head_idx'], d['tail_idx']
                         ] if d['head_idx'] < d['tail_idx'] else [d['tail_idx'], d['head_idx']]

@@ -39,6 +54,12 @@ def _add_pos_seq(train_data: List[Dict], cfg):

 def _convert_tokens_into_index(data: List[Dict], vocab):
+    """
+    Convert tokens into index values
+    Args:
+        data (List[Dict]) : the dataset
+        vocab (Class) : the vocabulary
+    """
     unk_str = '[UNK]'
     unk_idx = vocab.word2idx[unk_str]
@@ -48,6 +69,13 @@ def _convert_tokens_into_index(data: List[Dict], vocab):

 def _serialize_sentence(data: List[Dict], serial, cfg):
+    """
+    Tokenize the sentences
+    Args:
+        data (List[Dict]) : the dataset
+        serial (Class) : a Serializer instance
+        cfg : the config
+    """
     for d in data:
         sent = d['sentence'].strip()
         sent = sent.replace(d['head'], ' head ', 1).replace(d['tail'], ' tail ', 1)

@@ -68,6 +96,12 @@ def _serialize_sentence(data: List[Dict], serial, cfg):

 def _lm_serialize(data: List[Dict], cfg):
+    """
+    Tokenize for the language-model (lm) pipeline
+    Args:
+        data (List[Dict]) : the dataset
+        cfg : the config
+    """
     logger.info('use bert tokenizer...')
     tokenizer = BertTokenizer.from_pretrained(cfg.lm_file)
     for d in data:
@@ -79,6 +113,12 @@ def _lm_serialize(data: List[Dict], cfg):

 def _add_relation_data(rels: Dict, data: List) -> None:
+    """
+    Attach relation data
+    Args:
+        rels (Dict) : the dictionary of relations
+        data (List) : the data to which relation fields are added
+    """
     for d in data:
         d['rel2idx'] = rels[d['relation']]['index']
         d['head_type'] = rels[d['relation']]['head_type']

@@ -86,6 +126,13 @@ def _add_relation_data(rels: Dict, data: List) -> None:

 def _handle_relation_data(relation_data: List[Dict]) -> Dict:
+    """
+    Process the relation data; every relation has three attributes: index, head_type and tail_type
+    Args:
+        relation_data (List[Dict]) : the relation data to process
+    Return:
+        rels (Dict) : the processed result
+    """
     rels = OrderedDict()
     relation_data = sorted(relation_data, key=lambda i: int(i['index']))
     for d in relation_data:
@@ -99,7 +146,9 @@ def _handle_relation_data(relation_data: List[Dict]) -> Dict:

 def preprocess(cfg):
+    """
+    Data preprocessing stage
+    """
     logger.info('===== start preprocess data =====')
     train_fp = os.path.join(cfg.cwd, cfg.data_path, 'train.csv')
     valid_fp = os.path.join(cfg.cwd, cfg.data_path, 'valid.csv')
serializer.py (101 lines changed)
@@ -15,6 +15,14 @@ class Serializer():
         self.do_chinese_split = do_chinese_split

     def serialize(self, text, never_split: List = None):
+        """
+        Split a piece of text into a List of tokens according to the configured splitting rules
+        Args:
+            text (String) : the text to split
+            never_split (List) : words that must not be split; defaults to empty
+        Return:
+            output_tokens (List) : the split result
+        """
         never_split = self.never_split + (never_split if never_split is not None else [])
         text = self._clean_text(text)

@@ -36,7 +44,13 @@ class Serializer():
         return output_tokens

     def _clean_text(self, text):
-        """Performs invalid character removal and whitespace cleanup on text."""
+        """
+        Remove invalid characters and clean up whitespace in the text
+        Args:
+            text (String) : the text to clean
+        Return:
+            "".join(output) (String) : the cleaned text
+        """
         output = []
         for char in text:
             cp = ord(char)
@@ -49,6 +63,14 @@ class Serializer():
         return "".join(output)

     def _use_jieba_cut(self, text, never_split):
+        """
+        Tokenize with jieba
+        Args:
+            text (String) : the text to split
+            never_split (List) : words that must not be split
+        Return:
+            tokens (List) : the split result
+        """
         for word in never_split:
             jieba.suggest_freq(word, True)
         tokens = jieba.lcut(text)

@@ -61,7 +83,13 @@ class Serializer():
         return tokens

     def _tokenize_chinese_chars(self, text):
-        """Adds whitespace around any CJK character."""
+        """
+        Add whitespace around any CJK character
+        Args:
+            text (String) : the text to process
+        Return:
+            "".join(output) (String) : the text with whitespace added
+        """
         output = []
         for char in text:
             cp = ord(char)
@@ -74,7 +102,13 @@ class Serializer():
         return "".join(output)

     def _orig_tokenize(self, text):
-        """Splits text on whitespace and some punctuations like comma or period"""
+        """
+        Split text on whitespace and some punctuation such as commas or periods
+        Args:
+            text (String) : the text to split
+        Return:
+            tokens (List) : the tokenized result
+        """
         text = text.strip()
         if not text:
             return []

@@ -86,7 +120,13 @@ class Serializer():
         return tokens

     def _whitespace_tokenize(self, text):
-        """Runs basic whitespace cleaning and splitting on a piece of text."""
+        """
+        Run basic whitespace cleanup and splitting on a piece of text
+        Args:
+            text (String) : the text to split
+        Return:
+            tokens (List) : the tokenized result
+        """
         text = text.strip()
         if not text:
             return []
@@ -94,7 +134,14 @@ class Serializer():
         return tokens

     def _run_strip_accents(self, text):
-        """Strips accents from a piece of text."""
+        """
+        Strip accents from a piece of text
+        Args:
+            text (String) : the text to process
+        Return:
+            "".join(output) (String) : the text with accents removed
+        """
         text = unicodedata.normalize("NFD", text)
         output = []
         for char in text:

@@ -105,7 +152,15 @@ class Serializer():
         return "".join(output)

     def _run_split_on_punc(self, text, never_split=None):
-        """Splits punctuation on a piece of text."""
+        """
+        Split a piece of text on punctuation
+        Args:
+            text (String) : the text to split
+            never_split (List) : words that must not be split; defaults to empty
+        Return:
+            ["".join(x) for x in output] (List) : the split result
+        """
         if never_split is not None and text in never_split:
             return [text]
         chars = list(text)
@@ -128,7 +183,13 @@ class Serializer():

     @staticmethod
     def is_control(char):
-        """Checks whether `chars` is a control character."""
+        """
+        Check whether `char` is a control character
+        Args:
+            char : the character
+        Return:
+            bool : the result of the check
+        """
         # These are technically control characters but we count them as whitespace
         # characters.
         if char == "\t" or char == "\n" or char == "\r":

@@ -140,7 +201,13 @@ class Serializer():

     @staticmethod
     def is_whitespace(char):
-        """Checks whether `chars` is a whitespace character."""
+        """
+        Check whether `char` is a whitespace character
+        Args:
+            char : the character
+        Return:
+            bool : the result of the check
+        """
         # \t, \n, and \r are technically control characters but we treat them
         # as whitespace since they are generally considered as such.
         if char == " " or char == "\t" or char == "\n" or char == "\r":
@@ -152,7 +219,15 @@ class Serializer():

     @staticmethod
     def is_chinese_char(cp):
-        """Checks whether CP is the codepoint of a CJK character."""
+        """
+        Check whether `cp` is the codepoint of a CJK (Chinese) character
+        Args:
+            cp (int) : the codepoint
+        Return:
+            bool : the result of the check
+        """
         # This defines a "chinese character" as anything in the CJK Unicode block:
         # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
         #

@@ -174,7 +249,13 @@ class Serializer():

     @staticmethod
     def is_punctuation(char):
-        """Checks whether `chars` is a punctuation character."""
+        """
+        Check whether `char` is a punctuation character
+        Args:
+            char : the character
+        Return:
+            bool : the result of the check
+        """
         cp = ord(char)
         # We treat all non-letter/number ASCII as punctuation.
         # Characters such as "^", "$", and "`" are not in the Unicode
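Aside, not part of the diff: a trimmed-down sketch of the codepoint check that is_chinese_char documents. The real method covers every CJK block listed in the linked Wikipedia article; only the two most common ranges are shown here, so this is an illustration rather than the class's implementation.

    def is_chinese_char(cp: int) -> bool:
        # CJK Unified Ideographs (4E00-9FFF) and Extension A (3400-4DBF) only.
        return 0x4E00 <= cp <= 0x9FFF or 0x3400 <= cp <= 0x4DBF

    print(is_chinese_char(ord('中')), is_chinese_char(ord('a')))   # True False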
trainer.py (28 lines changed)
@@ -7,6 +7,20 @@ logger = logging.getLogger(__name__)

 def train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg):
+    """
+    Train the model
+    Args:
+        epoch (int): current training epoch
+        model (class): the model being trained
+        dataloader (dict): the dataset loader
+        optimizer (Callable): the optimizer
+        criterion (Callable): the loss function
+        device (torch.device): the training device
+        writer (class): the summary writer
+        cfg: the config
+    Return:
+        losses[-1] : the last loss value
+    """
     model.train()

     metric = PRMetric()
@@ -55,6 +69,20 @@ def train(epoch, model, dataloader, optimizer, criterion, device, writer, cfg):

 def validate(epoch, model, dataloader, criterion, device, cfg):
+    """
+    Validate the model
+    Args:
+        epoch (int): current training epoch
+        model (class): the model being evaluated
+        dataloader (dict): the dataset loader
+        criterion (Callable): the loss function
+        device (torch.device): the device
+        cfg: the config
+    Return:
+        f1 : the f1 score
+        loss : the loss value
+    """
     model.eval()

     metric = PRMetric()
@@ -22,6 +22,14 @@ Path = str

 def load_pkl(fp: Path, verbose: bool = True) -> Any:
+    """
+    Load a pickle file
+    Args:
+        fp (String) : path to read the data from
+        verbose (bool) : whether to log the operation
+    Return:
+        data (Any) : the loaded data
+    """
     if verbose:
         logger.info(f'load data from {fp}')

@@ -31,6 +39,13 @@ def load_pkl(fp: Path, verbose: bool = True) -> Any:

 def save_pkl(data: Any, fp: Path, verbose: bool = True) -> None:
+    """
+    Save data to a pickle file
+    Args:
+        data (Any) : the data
+        fp (String) : the path to save to
+        verbose (bool) : whether to log the operation
+    """
     if verbose:
         logger.info(f'save data in {fp}')
@@ -39,6 +54,15 @@ def save_pkl(data: Any, fp: Path, verbose: bool = True) -> None:

 def load_csv(fp: Path, is_tsv: bool = False, verbose: bool = True) -> List:
+    """
+    Load a csv file
+    Args:
+        fp (String) : the file path
+        is_tsv (bool) : whether the file is in excel-tab (tsv) format
+        verbose (bool) : whether to log the operation
+    Return:
+        list(reader) (List) : the loaded rows as a List
+    """
     if verbose:
         logger.info(f'load csv from {fp}')

@@ -49,6 +73,15 @@ def load_csv(fp: Path, is_tsv: bool = False, verbose: bool = True) -> List:

 def save_csv(data: List[Dict], fp: Path, save_in_tsv: False, write_head=True, verbose=True) -> None:
+    """
+    Save data to a csv file
+    Args:
+        data (List) : the List data to save
+        fp (String) : the path to save to
+        save_in_tsv (bool) : whether to save in excel-tab (tsv) format
+        write_head (bool) : whether to write a header row
+        verbose (bool) : whether to log the operation
+    """
     if verbose:
         logger.info(f'save csv file in: {fp}')
@@ -62,6 +95,15 @@ def save_csv(data: List[Dict], fp: Path, save_in_tsv: False, write_head=True, verbose=True) -> None:

 def load_jsonld(fp: Path, verbose: bool = True) -> List:
+    """
+    Load a jsonld file
+    Args:
+        fp (String) : path of the jsonld file
+        verbose (bool) : whether to log the operation
+    Return:
+        datas (List) : the loaded List
+    """
     if verbose:
         logger.info(f'load jsonld from {fp}')

@@ -76,16 +118,21 @@ def load_jsonld(fp: Path, verbose: bool = True) -> List:

 def save_jsonld(fp):
+    """
+    Save a jsonld file
+    """
     pass


 def jsonld2csv(fp: str, verbose: bool = True) -> str:
-    '''
+    """
     Load a jsonld file and store it as a csv file with the same name in the same location
-    :param fp: path of the jsonld file
-    :param verbose: whether to print logging
-    :return: path of the csv file
-    '''
+    Args:
+        fp (String) : path of the jsonld file
+        verbose (bool) : whether to log the operation
+    Return:
+        fp_new (String) : path of the new file
+    """
     data = []
     root, ext = os.path.splitext(fp)
     fp_new = root + '.csv'
@@ -108,12 +155,14 @@ def jsonld2csv(fp: str, verbose: bool = True) -> str:

 def csv2jsonld(fp: str, verbose: bool = True) -> str:
-    '''
-    Load a csv file and store it as a jsonld file with the same name in the same location
-    :param fp: path of the csv file
-    :param verbose: whether to print logging
-    :return: path of the jsonld file
-    '''
+    """
+    Load a csv file and store it as a jsonld file with the same name in the same location
+    Args:
+        fp (String) : path of the csv file
+        verbose (bool) : whether to log the operation
+    Return:
+        fp_new (String) : path of the new file
+    """
     data = []
     root, ext = os.path.splitext(fp)
     fp_new = root + '.jsonld'
@@ -14,26 +14,30 @@ __all__ = [

 def manual_seed(seed: int = 1) -> None:
+    """
+    Set the random seed
+    """
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
-    if torch.cuda.CUDA_ENABLED and use_deterministic_cudnn:
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
+    # if torch.cuda.CUDA_ENABLED and use_deterministic_cudnn:
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False


 def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None, mask_pos_to_true=True):
     """
     Convert a 1-d array of sequence lengths into a 2-d mask; padded positions are 1 by default.

-    :param list, np.ndarray, torch.LongTensor seq_len: shape is (B,)
-    :param int max_len: pad the mask to this length. Defaults (None) to the longest length in seq_len; under nn.DataParallel the seq_len on different cards may differ, so a max_len should be passed in so the mask is padded to that length.
-    :return: np.ndarray, torch.Tensor. Shape is (B, max_length); the element type is bool or torch.uint8
+    Args:
+        seq_len (list, np.ndarray, torch.LongTensor) : shape is (B,)
+        max_len (int) : pad the mask to this length. Defaults (None) to the longest length in seq_len; under nn.DataParallel the seq_len on different cards may differ, so a max_len should be passed in so the mask is padded to that length.
+    Return:
+        mask (np.ndarray, torch.Tensor) : shape is (B, max_length); the element type is bool or torch.uint8
     """
     if isinstance(seq_len, list):
         seq_len = np.array(seq_len)
@@ -58,9 +62,11 @@ def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None, mask_pos_to_true=True):

 def to_one_hot(x: torch.Tensor, length: int) -> torch.Tensor:
     """
-    :param x: [B] usually the target values
-    :param length: L, usually the number of relation classes
-    :return: [B, L] each row has a 1 only at the corresponding position and 0 elsewhere
+    Args:
+        x (torch.Tensor) : [B], usually the target values
+        length (int) : L, usually the number of relation classes
+    Return:
+        x_one_hot.to(device=x.device) (torch.Tensor) : [B, L] each row has a 1 only at the corresponding position and 0 elsewhere
     """
     B = x.size(0)
     x_one_hot = torch.zeros(B, length)
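Aside, not part of the diff: a compact re-implementation of the tensor branch that the seq_len_to_mask docstring describes, to make the broadcasting explicit. The repository's version also accepts lists and np.ndarray inputs; this sketch handles only tensors.

    import torch

    def seq_len_to_mask(seq_len: torch.Tensor, max_len: int = None, mask_pos_to_true: bool = True):
        # Compare position indices [0, max_len) against each length; True marks padding by default.
        max_len = int(max_len or seq_len.max())
        idx = torch.arange(max_len, device=seq_len.device).unsqueeze(0)   # [1, max_len]
        mask = idx >= seq_len.unsqueeze(1)                                 # [B, max_len]
        return mask if mask_pos_to_true else ~mask

    print(seq_len_to_mask(torch.tensor([3, 1]), max_len=4))
    # tensor([[False, False, False,  True],
    #         [False,  True,  True,  True]])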
vocab.py (27 lines changed)
@@ -33,6 +33,9 @@ SPECIAL_TOKENS = OrderedDict(zip(SPECIAL_TOKENS_KEYS, SPECIAL_TOKENS_VALUES))

 class Vocab(object):
+    """
+    Build the vocabulary: add words and remove low-frequency words
+    """
     def __init__(self, name: str = 'basic', init_tokens: Sequence = SPECIAL_TOKENS):
         self.name = name
         self.init_tokens = init_tokens
@@ -44,10 +47,18 @@ class Vocab(object):
         self._add_init_tokens()

     def _add_init_tokens(self):
+        """
+        Add the initial (special) tokens
+        """
         for token in self.init_tokens.values():
             self._add_word(token)

     def _add_word(self, word: str):
+        """
+        Add a single word
+        Args:
+            word (String) : the word to add
+        """
         if word not in self.word2idx:
             self.word2idx[word] = self.count
             self.word2count[word] = 1
@@ -57,15 +68,21 @@ class Vocab(object):
         self.word2count[word] += 1

     def add_words(self, words: Sequence):
+        """
+        Add words from a sequence
+        Args:
+            words (List) : the words to add
+        """
         for word in words:
             self._add_word(word)

     def trim(self, min_freq=2, verbose: Optional[bool] = True):
-        '''Remove a word from the vocabulary when its frequency is below min_freq
-
-        Args:
-            param min_freq: the minimum frequency
-        '''
+        """
+        Remove a word from the vocabulary when its frequency is below min_freq
+        Args:
+            min_freq (int) : the minimum frequency
+            verbose (bool) : whether to log the operation
+        """
         assert min_freq == int(min_freq), f'min_freq must be integer, can\'t be {min_freq}'
         min_freq = int(min_freq)
         if min_freq < 2:
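Aside, not part of the diff: a tiny self-contained illustration of the frequency-based trimming idea that trim documents. The dict-based helper below is purely illustrative and is not the Vocab class itself.

    def trim_vocab(word2count: dict, min_freq: int = 2) -> dict:
        # Keep only words whose count reaches min_freq, then re-index from zero.
        kept = [w for w, c in word2count.items() if c >= min_freq]
        return {w: i for i, w in enumerate(kept)}

    print(trim_vocab({'头实体': 3, '尾实体': 1, '关系': 2}, min_freq=2))   # {'头实体': 0, '关系': 1}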