import torch
from torch.utils.data import Dataset

from utils import load_pkl


def collate_fn(cfg):
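    """Build and return a batch collate function configured by ``cfg``.

    The returned closure pads every sample to the longest sequence in the
    batch and stacks each field into a tensor.
    """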

    def collate_fn_intra(batch):
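        # Sort by descending sequence length; downstream encoders (e.g. RNNs
        # packed with `pack_padded_sequence`) expect this ordering.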
        batch.sort(key=lambda data: data['seq_len'], reverse=True)

        max_len = batch[0]['seq_len']

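        # Helper: right-pad a token/feature list with zeros to `max_len`.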
        def _padding(x, max_len):
            return x + [0] * (max_len - len(x))

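        # Accumulators: `x` collects model inputs, `y` the relation labels.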
        x, y = dict(), []
        word, word_len = [], []
        head_pos, tail_pos = [], []
        pcnn_mask = []
        for data in batch:
            word.append(_padding(data['token2idx'], max_len))
            word_len.append(data['seq_len'])
            y.append(int(data['rel2idx']))

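            # Non-LM models consume explicit position features relative to the
            # head and tail entities.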
            if cfg.model_name != 'lm':
                head_pos.append(_padding(data['head_pos'], max_len))
                tail_pos.append(_padding(data['tail_pos'], max_len))
                if cfg.model_name == 'cnn':
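                    # PCNN additionally needs an entity-segment mask,
                    # presumably for its piecewise max pooling.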
                    if cfg.use_pcnn:
                        pcnn_mask.append(_padding(data['entities_pos'], max_len))

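        # Stack the accumulated per-sample lists into batch tensors.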
        x['word'] = torch.tensor(word)
        x['lens'] = torch.tensor(word_len)
        y = torch.tensor(y)

        if cfg.model_name != 'lm':
            x['head_pos'] = torch.tensor(head_pos)
            x['tail_pos'] = torch.tensor(tail_pos)
            if cfg.model_name == 'cnn' and cfg.use_pcnn:
                x['pcnn_mask'] = torch.tensor(pcnn_mask)
            if cfg.model_name == 'gcn':
                # No suitable tool for building parse trees was found, so
                # the adjacency matrix is randomly initialized for now.
                B, L = len(batch), max_len
                adj = torch.empty(B, L, L).random_(2)
                x['adj'] = adj

        return x, y

    return collate_fn_intra


class CustomDataset(Dataset):
    """Stores the samples in a plain Python list by default."""

    def __init__(self, fp):
        self.file = load_pkl(fp)

    def __getitem__(self, item):
        sample = self.file[item]
        return sample

    def __len__(self):
        return len(self.file)


if __name__ == '__main__':
    from torch.utils.data import DataLoader

    train_data_path = 'data/out/train.pkl'
    vocab_path = 'data/out/vocab.pkl'
    unk_str = 'UNK'
    vocab = load_pkl(vocab_path)
    train_ds = CustomDataset(train_data_path)

    # `collate_fn` is a factory and must be called with a config object;
    # a minimal stand-in config (assumed fields) keeps this smoke test runnable.
    from types import SimpleNamespace
    cfg = SimpleNamespace(model_name='cnn', use_pcnn=False)
    train_dl = DataLoader(train_ds, batch_size=4, shuffle=True,
                          collate_fn=collate_fn(cfg), drop_last=False)

    for batch_idx, (x, y) in enumerate(train_dl):
        word = x['word']
        for idx in word:
            idx2token = ''.join([vocab.idx2word.get(i, unk_str) for i in idx.numpy()])
            print(idx2token)
        print(y)
        break

        # A typical training step would then look like:
        # x, y = x.to(device), y.to(device)
        # optimizer.zero_grad()
        # y_pred = model(x)
        # loss = criterion(y_pred, y)
        # loss.backward()
        # optimizer.step()