Add files via upload

DocuNet
This commit is contained in:
TimelordRi 2021-09-26 10:01:56 +08:00 committed by GitHub
parent acae147b2b
commit 69da55ecfb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 2195 additions and 0 deletions

View File

@ -0,0 +1,92 @@
[**中文**](https://github.com/zjunlp/DocRE/blob/master/README_CN.md) | [**English**](https://github.com/zjunlp/DocRE/blob/master/README.md)
<p align="center">
<font size=7><strong>DocuNet: Document-level Relation Extraction as Semantic Segmentation</strong></font>
</p>
This repository is the official implementation of [**DocuNet**](https://github.com/zjunlp/DocRE/), the model proposed in the paper **[Document-level Relation Extraction as Semantic Segmentation](https://www.ijcai.org/proceedings/2021/551)**, accepted at the **IJCAI 2021** main conference.
# Contributors
Xiang Chen, Xin Xie, Shumin Deng, Ningyu Zhang, and Huajun Chen.
# Brief Introduction
This paper innovatively proposes the DocuNet model, which is the first to formulate document-level relation extraction as a semantic segmentation task from computer vision.
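As a rough illustration of the idea (a minimal sketch, not the repository code: the toy sizes, the three hand-picked similarity channels, and the single convolution standing in for the U-Net are all assumptions), a document is rendered as an entity-by-entity "image" whose channels hold pairwise features, and every pixel is then labeled with a relation class:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy sizes (assumed): 4 entities, hidden dim 8, 3 channels, 5 relation classes.
entity_embs = torch.randn(4, 8)                     # one vector per entity

# Each channel is one pairwise feature over the 4 x 4 entity grid.
dot = entity_embs @ entity_embs.T                   # dot-product similarity
cos = F.cosine_similarity(entity_embs.unsqueeze(1),
                          entity_embs.unsqueeze(0), dim=-1)
dist = (entity_embs.unsqueeze(1) - entity_embs.unsqueeze(0)).norm(dim=-1)
image = torch.stack([dot, cos, dist]).unsqueeze(0)  # shape (1, 3, 4, 4)

# Stand-in for the paper's U-Net: any image-to-image net that keeps the grid.
segmenter = nn.Conv2d(3, 5, kernel_size=3, padding=1)
logits = segmenter(image)                           # shape (1, 5, 4, 4)
# "Pixel" (i, j) now scores the relation between entity i and entity j.
```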
# Requirements
To install requirements:
```setup
pip install -r requirements.txt
```
# Training
To train the DocuNet model from the paper on the [DocRED](https://github.com/thunlp/DocRED) dataset, run this command:
```bash
>> bash scripts/run_docred.sh # set --transformer_type to bert or roberta
```
To train the DocuNet model from the paper on the CDR and GDA datasets, run these commands:
```bash
>> bash scripts/run_cdr.sh # for CDR
>> bash scripts/run_gda.sh # for GDA
```
# Evaluation
To evaluate a trained model from the paper, set the `--load_path` argument in the training scripts. The program logs the evaluation results automatically. For DocRED it also generates a test file `result.json` in the official evaluation format, which you can compress and submit to CodaLab for the official test score.
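For example, assuming `result.json` is in the working directory (the archive name is arbitrary):
```bash
>> zip result.zip result.json
```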
# Results
Our model achieves the following performance:
### Document-level Relation Extraction on [DocRED](https://github.com/thunlp/DocRED)
| Model | Ign F1 on Dev | F1 on Dev | Ign F1 on Test | F1 on Test |
| :----------------: |:--------------: | :------------: | ------------------ | ------------------ |
| DocuNet-BERT (base) | 59.86±0.13 | 61.83±0.19 | 59.93 | 61.86 |
| DocuNet-RoBERTa (large) | 62.23±0.12 | 64.12±0.14 | 62.39 | 64.55 |
### Document-level Relation Extraction on [CDR and GDA](https://github.com/fenchri/edge-oriented-graph)
| Model | CDR | GDA |
| :----------------: | :----------------: | :----------------: |
| DocuNet-SciBERT (base) | 76.3±0.40 | 85.3±0.50 |
# Papers for the Project & How to Cite
If you use or extend our work, please cite the following paper:
```
@inproceedings{ijcai2021-551,
title = {Document-level Relation Extraction as Semantic Segmentation},
author = {Zhang, Ningyu and Chen, Xiang and Xie, Xin and Deng, Shumin and Tan, Chuanqi and Chen, Mosha and Huang, Fei and Si, Luo and Chen, Huajun},
booktitle = {Proceedings of the Thirtieth International Joint Conference on
Artificial Intelligence, {IJCAI-21}},
publisher = {International Joint Conferences on Artificial Intelligence Organization},
editor = {Zhi-Hua Zhou},
pages = {3999--4006},
year = {2021},
month = {8},
note = {Main Track},
doi = {10.24963/ijcai.2021/551},
url = {https://doi.org/10.24963/ijcai.2021/551},
}
```

View File

@ -0,0 +1,90 @@
[**中文**](https://github.com/zjunlp/DocRE/blob/master/README_CN.md) | [**English**](https://github.com/zjunlp/DocRE/blob/master/README.md)
<p align="center">
<font size=7><strong>DocuNet: Document-level Relation Extraction as Semantic Segmentation</strong></font>
</p>
This is the official implementation of our [**DocuNet**](https://github.com/zjunlp/DocuNet) project. The model was proposed in the paper **[Document-level Relation Extraction as Semantic Segmentation](https://www.ijcai.org/proceedings/2021/551)**, accepted at the **IJCAI 2021** main conference.
# Contributors
Xiang Chen, Xin Xie, Shumin Deng, Ningyu Zhang, and Huajun Chen.
# Brief Introduction
This paper innovatively proposes the DocuNet model, which is the first to formulate document-level relation extraction as a semantic segmentation task from computer vision.
# Requirements
To install requirements:
```setup
pip install -r requirements.txt
```
# Training
## DocRED
To train the DocuNet model on the DocRED dataset, run this command:
```bash
>> bash scripts/run_docred.sh # set --transformer_type to bert or roberta
```
## CDR and GDA
To train the DocuNet model on the CDR and GDA datasets, run these commands:
```bash
>> bash scripts/run_cdr.sh # for CDR
>> bash scripts/run_gda.sh # for GDA
```
The CDR and GDA datasets can be obtained following the instructions of [edge-oriented graph](https://github.com/fenchri/edge-oriented-graph).
# Evaluation
To evaluate a trained model from the paper, set the `--load_path` argument in the training scripts. The program logs the evaluation results automatically. For DocRED it also generates a test file `result.json` in the official evaluation format, which you can compress and submit to CodaLab for the official test score.
# Results
Our model achieves the following performance:
### Document-level Relation Extraction on [DocRED](https://github.com/thunlp/DocRED)
| Model | Ign F1 on Dev | F1 on Dev | Ign F1 on Test | F1 on Test |
| :----------------: |:--------------: | :------------: | ------------------ | ------------------ |
| DocuNet-BERT (base) | 59.86±0.13 | 61.83±0.19 | 59.93 | 61.86 |
| DocuNet-RoBERTa (large) | 62.23±0.12 | 64.12±0.14 | 62.39 | 64.55 |
### Document-level Relation Extraction on [CDR and GDA](https://github.com/fenchri/edge-oriented-graph)
| Model | CDR | GDA |
| :----------------: | :----------------: | :----------------: |
| DocuNet-SciBERT (base) | 76.3±0.40 | 85.3±0.50 |
# Papers for the Project & How to Cite
If you use or extend our work, please cite the following paper:
```
@inproceedings{ijcai2021-551,
title = {Document-level Relation Extraction as Semantic Segmentation},
author = {Zhang, Ningyu and Chen, Xiang and Xie, Xin and Deng, Shumin and Tan, Chuanqi and Chen, Mosha and Huang, Fei and Si, Luo and Chen, Huajun},
booktitle = {Proceedings of the Thirtieth International Joint Conference on
Artificial Intelligence, {IJCAI-21}},
publisher = {International Joint Conferences on Artificial Intelligence Organization},
editor = {Zhi-Hua Zhou},
pages = {3999--4006},
year = {2021},
month = {8},
note = {Main Track},
doi = {10.24963/ijcai.2021/551},
url = {https://doi.org/10.24963/ijcai.2021/551},
}
```

View File

@ -0,0 +1,119 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
class AttentionUNet(torch.nn.Module):
"""
U-Net with down-sampling and up-sampling paths for global reasoning over the entity-pair map
"""
def __init__(self, input_channels, class_number, **kwargs):
super(AttentionUNet, self).__init__()
down_channel = kwargs['down_channel'] # default = 256
down_channel_2 = down_channel * 2
up_channel_1 = down_channel_2 * 2
up_channel_2 = down_channel * 2
self.inc = InConv(input_channels, down_channel)
self.down1 = DownLayer(down_channel, down_channel_2)
self.down2 = DownLayer(down_channel_2, down_channel_2)
self.up1 = UpLayer(up_channel_1, up_channel_1 // 4)
self.up2 = UpLayer(up_channel_2, up_channel_2 // 4)
self.outc = OutConv(up_channel_2 // 4, class_number)
def forward(self, attention_channels):
"""
Given multi-channel attention map, return the logits of every one mapping into 3-class
:param attention_channels:
:return:
"""
# attention_channels as the shape of: batch_size x channel x width x height
x = attention_channels
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x = self.up1(x3, x2)
x = self.up2(x, x1)
output = self.outc(x)
# output shape: batch_size x width x height x class_number
output = output.permute(0, 2, 3, 1).contiguous()
return output
class DoubleConv(nn.Module):
"""(conv => [BN] => ReLU) * 2"""
def __init__(self, in_ch, out_ch):
super(DoubleConv, self).__init__()
self.double_conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True))
def forward(self, x):
x = self.double_conv(x)
return x
class InConv(nn.Module):
def __init__(self, in_ch, out_ch):
super(InConv, self).__init__()
self.conv = DoubleConv(in_ch, out_ch)
def forward(self, x):
x = self.conv(x)
return x
class DownLayer(nn.Module):
def __init__(self, in_ch, out_ch):
super(DownLayer, self).__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(kernel_size=2),
DoubleConv(in_ch, out_ch)
)
def forward(self, x):
x = self.maxpool_conv(x)
return x
class UpLayer(nn.Module):
def __init__(self, in_ch, out_ch, bilinear=True):
super(UpLayer, self).__init__()
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear',
align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
self.conv = DoubleConv(in_ch, out_ch)
def forward(self, x1, x2):
x1 = self.up(x1)
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY -
diffY // 2))
x = torch.cat([x2, x1], dim=1)
x = self.conv(x)
return x
class OutConv(nn.Module):
def __init__(self, in_ch, out_ch):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 1)
def forward(self, x):
x = self.conv(x)
return x
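# Usage sketch (illustrative; the shapes below are assumptions, not fixed by the model):
# unet = AttentionUNet(input_channels=3, class_number=256, down_channel=256)
# logits = unet(torch.randn(2, 3, 42, 42))  # -> (2, 42, 42, 256)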

View File

@ -0,0 +1 @@
# this is an empty file

View File

@ -0,0 +1 @@
# this is an empty file

View File

@ -0,0 +1,24 @@
import torch
from overrides import overrides
from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention
@MatrixAttention.register("ele_multiply")
class ElementWiseMatrixAttention(MatrixAttention):
"""
This similarity function simply computes the dot product between each pair of vectors, with an
optional scaling to reduce the variance of the output elements.
Parameters
----------
scale_output : ``bool``, optional
If ``True``, we will scale the output by ``math.sqrt(tensor.size(-1))``, to reduce the
variance in the result.
"""
def __init__(self) -> None:
super(ElementWiseMatrixAttention, self).__init__()
@overrides
def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
result = torch.einsum('iaj,ibj->ijab', [tensor_1, tensor_2])
return result

View File

@ -0,0 +1,176 @@
import os
import os.path
import json
import numpy as np
rel2id = json.load(open('../meta/rel2id.json', 'r'))
id2rel = {value: key for key, value in rel2id.items()}
def to_official(preds, features):
h_idx, t_idx, title = [], [], []
for f in features:
hts = f["hts"]
h_idx += [ht[0] for ht in hts]
t_idx += [ht[1] for ht in hts]
title += [f["title"] for ht in hts]
res = []
print('h_idx, preds', len(h_idx), len(preds))
# assert len(h_idx) == len(preds)
for i in range(preds.shape[0]):
pred = preds[i]
pred = np.nonzero(pred)[0].tolist()
for p in pred:
if p != 0:
res.append(
{
'title': title[i],
'h_idx': h_idx[i],
't_idx': t_idx[i],
'r': id2rel[p],
}
)
return res
def gen_train_facts(data_file_name, truth_dir):
fact_file_name = data_file_name[data_file_name.find("train_"):]
fact_file_name = os.path.join(truth_dir, fact_file_name.replace(".json", ".fact"))
if os.path.exists(fact_file_name):
fact_in_train = set([])
triples = json.load(open(fact_file_name))
for x in triples:
fact_in_train.add(tuple(x))
return fact_in_train
fact_in_train = set([])
ori_data = json.load(open(data_file_name))
for data in ori_data:
vertexSet = data['vertexSet']
for label in data['labels']:
rel = label['r']
for n1 in vertexSet[label['h']]:
for n2 in vertexSet[label['t']]:
fact_in_train.add((n1['name'], n2['name'], rel))
json.dump(list(fact_in_train), open(fact_file_name, "w"))
return fact_in_train
def official_evaluate(tmp, path):
'''
Adapted from the official evaluation code
'''
truth_dir = os.path.join(path, 'ref')
if not os.path.exists(truth_dir):
os.makedirs(truth_dir)
fact_in_train_annotated = gen_train_facts(os.path.join(path, "train_annotated.json"), truth_dir)
fact_in_train_distant = gen_train_facts(os.path.join(path, "train_distant.json"), truth_dir)
truth = json.load(open(os.path.join(path, "dev.json")))
std = {}
tot_evidences = 0
titleset = set([])
title2vectexSet = {}
for x in truth:
title = x['title']
titleset.add(title)
vertexSet = x['vertexSet']
title2vectexSet[title] = vertexSet
for label in x['labels']:
r = label['r']
h_idx = label['h']
t_idx = label['t']
std[(title, r, h_idx, t_idx)] = set(label['evidence'])
tot_evidences += len(label['evidence'])
tot_relations = len(std)
tmp.sort(key=lambda x: (x['title'], x['h_idx'], x['t_idx'], x['r']))
submission_answer = [tmp[0]]
for i in range(1, len(tmp)):
x = tmp[i]
y = tmp[i - 1]
if (x['title'], x['h_idx'], x['t_idx'], x['r']) != (y['title'], y['h_idx'], y['t_idx'], y['r']):
submission_answer.append(tmp[i])
correct_re = 0
correct_evidence = 0
pred_evi = 0
correct_in_train_annotated = 0
correct_in_train_distant = 0
titleset2 = set([])
for x in submission_answer:
title = x['title']
h_idx = x['h_idx']
t_idx = x['t_idx']
r = x['r']
titleset2.add(title)
if title not in title2vectexSet:
continue
vertexSet = title2vectexSet[title]
if 'evidence' in x:
evi = set(x['evidence'])
else:
evi = set([])
pred_evi += len(evi)
if (title, r, h_idx, t_idx) in std:
correct_re += 1
stdevi = std[(title, r, h_idx, t_idx)]
correct_evidence += len(stdevi & evi)
in_train_annotated = in_train_distant = False
for n1 in vertexSet[h_idx]:
for n2 in vertexSet[t_idx]:
if (n1['name'], n2['name'], r) in fact_in_train_annotated:
in_train_annotated = True
if (n1['name'], n2['name'], r) in fact_in_train_distant:
in_train_distant = True
if in_train_annotated:
correct_in_train_annotated += 1
if in_train_distant:
correct_in_train_distant += 1
re_p = 1.0 * correct_re / len(submission_answer)
re_r = 1.0 * correct_re / tot_relations
if re_p + re_r == 0:
re_f1 = 0
else:
re_f1 = 2.0 * re_p * re_r / (re_p + re_r)
evi_p = 1.0 * correct_evidence / pred_evi if pred_evi > 0 else 0
evi_r = 1.0 * correct_evidence / tot_evidences
if evi_p + evi_r == 0:
evi_f1 = 0
else:
evi_f1 = 2.0 * evi_p * evi_r / (evi_p + evi_r)
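# "Ign" scores ignore relation facts that already appear in the training set
# (annotated or distant), so they measure generalization to unseen facts;
# the small epsilon guards against division by zero.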
re_p_ignore_train_annotated = 1.0 * (correct_re - correct_in_train_annotated) / (len(submission_answer) - correct_in_train_annotated + 1e-5)
re_p_ignore_train = 1.0 * (correct_re - correct_in_train_distant) / (len(submission_answer) - correct_in_train_distant + 1e-5)
if re_p_ignore_train_annotated + re_r == 0:
re_f1_ignore_train_annotated = 0
else:
re_f1_ignore_train_annotated = 2.0 * re_p_ignore_train_annotated * re_r / (re_p_ignore_train_annotated + re_r)
if re_p_ignore_train + re_r == 0:
re_f1_ignore_train = 0
else:
re_f1_ignore_train = 2.0 * re_p_ignore_train * re_r / (re_p_ignore_train + re_r)
return re_f1, evi_f1, re_f1_ignore_train_annotated, re_f1_ignore_train, re_p, re_r

View File

@ -0,0 +1 @@
# this is an empty file

View File

@ -0,0 +1,77 @@
import torch
import torch.nn.functional as F
import numpy as np
def process_long_input(model, input_ids, attention_mask, start_tokens, end_tokens):
# Split the input into two overlapping chunks so that BERT-style encoders (512-token limit) can process sequences up to 1024 tokens.
n, c = input_ids.size()
start_tokens = torch.tensor(start_tokens).to(input_ids)
end_tokens = torch.tensor(end_tokens).to(input_ids)
len_start = start_tokens.size(0)
len_end = end_tokens.size(0)
if c <= 512:
output = model(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=True,
)
sequence_output = output[0]
attention = output[-1][-1]
else:
new_input_ids, new_attention_mask, num_seg = [], [], []
seq_len = attention_mask.sum(1).cpu().numpy().astype(np.int32).tolist()
for i, l_i in enumerate(seq_len):
if l_i <= 512:
new_input_ids.append(input_ids[i, :512])
new_attention_mask.append(attention_mask[i, :512])
num_seg.append(1)
else:
input_ids1 = torch.cat([input_ids[i, :512 - len_end], end_tokens], dim=-1)
input_ids2 = torch.cat([start_tokens, input_ids[i, (l_i - 512 + len_start): l_i]], dim=-1)
attention_mask1 = attention_mask[i, :512]
attention_mask2 = attention_mask[i, (l_i - 512): l_i]
new_input_ids.extend([input_ids1, input_ids2])
new_attention_mask.extend([attention_mask1, attention_mask2])
num_seg.append(2)
input_ids = torch.stack(new_input_ids, dim=0)
attention_mask = torch.stack(new_attention_mask, dim=0)
output = model(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=True,
)
sequence_output = output[0]
attention = output[-1][-1]
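# Merge the chunk outputs back to the original length c: single-chunk documents
# are simply padded, while for two-chunk documents the overlapping middle region
# is averaged (token embeddings by the mask count, attentions by re-normalization).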
i = 0
new_output, new_attention = [], []
for (n_s, l_i) in zip(num_seg, seq_len):
if n_s == 1:
output = F.pad(sequence_output[i], (0, 0, 0, c - 512))
att = F.pad(attention[i], (0, c - 512, 0, c - 512))
new_output.append(output)
new_attention.append(att)
elif n_s == 2:
output1 = sequence_output[i][:512 - len_end]
mask1 = attention_mask[i][:512 - len_end]
att1 = attention[i][:, :512 - len_end, :512 - len_end]
output1 = F.pad(output1, (0, 0, 0, c - 512 + len_end))
mask1 = F.pad(mask1, (0, c - 512 + len_end))
att1 = F.pad(att1, (0, c - 512 + len_end, 0, c - 512 + len_end))
output2 = sequence_output[i + 1][len_start:]
mask2 = attention_mask[i + 1][len_start:]
att2 = attention[i + 1][:, len_start:, len_start:]
output2 = F.pad(output2, (0, 0, l_i - 512 + len_start, c - l_i))
mask2 = F.pad(mask2, (l_i - 512 + len_start, c - l_i))
att2 = F.pad(att2, [l_i - 512 + len_start, c - l_i, l_i - 512 + len_start, c - l_i])
mask = mask1 + mask2 + 1e-10
output = (output1 + output2) / mask.unsqueeze(-1)
att = (att1 + att2)
att = att / (att.sum(-1, keepdim=True) + 1e-10)
new_output.append(output)
new_attention.append(att)
i += n_s
sequence_output = torch.stack(new_output, dim=0)
attention = torch.stack(new_attention, dim=0)
return sequence_output, attention

View File

@ -0,0 +1,87 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
def multilabel_categorical_crossentropy(y_true, y_pred):
"""多标签分类的交叉熵
说明y_true和y_pred的shape一致y_true的元素非0即1
1表示对应的类为目标类0表示对应的类为非目标类
警告请保证y_pred的值域是全体实数换言之一般情况下y_pred
不用加激活函数尤其是不能加sigmoid或者softmax预测
阶段则输出y_pred大于0的类如有疑问请仔细阅读并理解
本文
"""
y_pred = (1 - 2 * y_true) * y_pred
y_pred_neg = y_pred - y_true * 1e30
y_pred_pos = y_pred - (1 - y_true) * 1e30
zeros = torch.zeros_like(y_pred[..., :1])
y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
return neg_loss + pos_loss
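# Usage sketch (illustrative): y_pred are raw scores (no sigmoid/softmax),
# y_true is multi-hot. For example:
#   y_true = torch.tensor([[0., 1., 0.]]); y_pred = torch.randn(1, 3)
#   loss = multilabel_categorical_crossentropy(y_true, y_pred).mean()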
class balanced_loss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, logits, labels):
loss = multilabel_categorical_crossentropy(labels,logits)
loss = loss.mean()
return loss
def get_label(self, logits, num_labels=-1):
th_logit = torch.zeros_like(logits[..., :1])
output = torch.zeros_like(logits).to(logits)
mask = (logits > th_logit)
if num_labels > 0:
top_v, _ = torch.topk(logits, num_labels, dim=1)
top_v = top_v[:, -1]
mask = (logits >= top_v.unsqueeze(1)) & mask
output[mask] = 1.0
output[:, 0] = (output[:,1:].sum(1) == 0.).to(logits)
return output
class ATLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, logits, labels):
# TH label
th_label = torch.zeros_like(labels, dtype=torch.float).to(labels)
th_label[:, 0] = 1.0
labels[:, 0] = 0.0
p_mask = labels + th_label
n_mask = 1 - labels
# Rank positive classes to TH
logit1 = logits - (1 - p_mask) * 1e30
loss1 = -(F.log_softmax(logit1, dim=-1) * labels).sum(1)
# Rank TH to negative classes
logit2 = logits - (1 - n_mask) * 1e30
loss2 = -(F.log_softmax(logit2, dim=-1) * th_label).sum(1)
# Sum two parts
loss = loss1 + loss2
loss = loss.mean()
return loss
def get_label(self, logits, num_labels=-1):
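# Adaptive thresholding: class 0 serves as a learned per-pair threshold (TH).
# Predict every class whose logit exceeds the TH logit, optionally capped to
# the top num_labels classes; if nothing exceeds it, fall back to class 0 (NA).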
th_logit = logits[:, 0].unsqueeze(1)
output = torch.zeros_like(logits).to(logits)
mask = (logits > th_logit)
if num_labels > 0:
top_v, _ = torch.topk(logits, num_labels, dim=1)
top_v = top_v[:, -1]
mask = (logits >= top_v.unsqueeze(1)) & mask
output[mask] = 1.0
output[:, 0] = (output.sum(1) == 0.).to(logits)
return output

View File

@ -0,0 +1,99 @@
{
"P1376": 79,
"P607": 27,
"P136": 73,
"P137": 63,
"P131": 2,
"P527": 11,
"P1412": 38,
"P206": 33,
"P205": 77,
"P449": 52,
"P127": 34,
"P123": 49,
"P86": 66,
"P840": 85,
"P355": 72,
"P737": 93,
"P740": 84,
"P190": 94,
"P576": 71,
"P749": 68,
"P112": 65,
"P118": 40,
"P17": 1,
"P19": 14,
"P3373": 19,
"P6": 42,
"P276": 44,
"P1001": 24,
"P580": 62,
"P582": 83,
"P585": 64,
"P463": 18,
"P676": 87,
"P674": 46,
"P264": 10,
"P108": 43,
"P102": 17,
"P25": 81,
"P27": 3,
"P26": 26,
"P20": 37,
"P22": 30,
"Na": 0,
"P807": 95,
"P800": 51,
"P279": 78,
"P1336": 88,
"P577": 5,
"P570": 8,
"P571": 15,
"P178": 36,
"P179": 55,
"P272": 75,
"P170": 35,
"P171": 80,
"P172": 76,
"P175": 6,
"P176": 67,
"P39": 91,
"P30": 21,
"P31": 60,
"P36": 70,
"P37": 58,
"P35": 54,
"P400": 31,
"P403": 61,
"P361": 12,
"P364": 74,
"P569": 7,
"P710": 41,
"P1344": 32,
"P488": 82,
"P241": 59,
"P162": 57,
"P161": 9,
"P166": 47,
"P40": 20,
"P1441": 23,
"P156": 45,
"P155": 39,
"P150": 4,
"P551": 90,
"P706": 56,
"P159": 29,
"P495": 13,
"P58": 53,
"P194": 48,
"P54": 16,
"P57": 28,
"P50": 22,
"P1366": 86,
"P1365": 92,
"P937": 69,
"P140": 50,
"P69": 25,
"P1198": 96,
"P1056": 89
}

View File

@ -0,0 +1,207 @@
import torch
import torch.nn as nn
from opt_einsum import contract
from long_seq import process_long_input
from losses import balanced_loss as ATLoss
import torch.nn.functional as F
from allennlp.modules.matrix_attention import DotProductMatrixAttention, CosineMatrixAttention, BilinearMatrixAttention
from element_wise import ElementWiseMatrixAttention
from attn_unet import AttentionUNet
class DocREModel(nn.Module):
def __init__(self, config, args, model, emb_size=768, block_size=64, num_labels=-1):
super().__init__()
self.config = config
self.bert_model = model
self.hidden_size = config.hidden_size
self.loss_fnt = ATLoss()
self.head_extractor = nn.Linear(1 * config.hidden_size + args.unet_out_dim, emb_size)
self.tail_extractor = nn.Linear(1 * config.hidden_size + args.unet_out_dim, emb_size)
# self.head_extractor = nn.Linear(1 * config.hidden_size , emb_size)
# self.tail_extractor = nn.Linear(1 * config.hidden_size , emb_size)
self.bilinear = nn.Linear(emb_size * block_size, config.num_labels)
self.emb_size = emb_size
self.block_size = block_size
self.num_labels = num_labels
self.bertdrop = nn.Dropout(0.6)
self.unet_in_dim = args.unet_in_dim
self.unet_out_dim = args.unet_out_dim
self.liner = nn.Linear(config.hidden_size, args.unet_in_dim)
self.min_height = args.max_height
self.channel_type = args.channel_type
self.segmentation_net = AttentionUNet(input_channels=args.unet_in_dim,
class_number=args.unet_out_dim,
down_channel=args.down_dim)
def encode(self, input_ids, attention_mask,entity_pos):
config = self.config
if config.transformer_type == "bert":
start_tokens = [config.cls_token_id]
end_tokens = [config.sep_token_id]
elif config.transformer_type == "roberta":
start_tokens = [config.cls_token_id]
end_tokens = [config.sep_token_id, config.sep_token_id]
sequence_output, attention = process_long_input(self.bert_model, input_ids, attention_mask, start_tokens, end_tokens)
return sequence_output, attention
def get_hrt(self, sequence_output, attention, entity_pos, hts):
offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
bs, h, _, c = attention.size()
# ne = max([len(x) for x in entity_pos]) # max number of entities in this batch
hss, tss, rss = [], [], []
entity_es = []
entity_as = []
for i in range(len(entity_pos)):
entity_embs, entity_atts = [], []
for entity_num, e in enumerate(entity_pos[i]):
if len(e) > 1:
e_emb, e_att = [], []
for start, end in e:
if start + offset < c:
# In case the entity mention is truncated due to limited max seq length.
e_emb.append(sequence_output[i, start + offset])
e_att.append(attention[i, :, start + offset])
if len(e_emb) > 0:
e_emb = torch.logsumexp(torch.stack(e_emb, dim=0), dim=0)
e_att = torch.stack(e_att, dim=0).mean(0)
else:
e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
e_att = torch.zeros(h, c).to(attention)
else:
start, end = e[0]
if start + offset < c:
e_emb = sequence_output[i, start + offset]
e_att = attention[i, :, start + offset]
else:
e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
e_att = torch.zeros(h, c).to(attention)
entity_embs.append(e_emb)
entity_atts.append(e_att)
for _ in range(self.min_height-entity_num-1):
entity_atts.append(e_att)
entity_embs = torch.stack(entity_embs, dim=0) # [n_e, d]
entity_atts = torch.stack(entity_atts, dim=0) # [n_e, h, seq_len]
entity_es.append(entity_embs)
entity_as.append(entity_atts)
ht_i = torch.LongTensor(hts[i]).to(sequence_output.device)
hs = torch.index_select(entity_embs, 0, ht_i[:, 0])
ts = torch.index_select(entity_embs, 0, ht_i[:, 1])
hss.append(hs)
tss.append(ts)
hss = torch.cat(hss, dim=0)
tss = torch.cat(tss, dim=0)
return hss, tss, entity_es, entity_as
def get_mask(self, ents, bs, ne, run_device):
ent_mask = torch.zeros(bs, ne, device=run_device)
rel_mask = torch.zeros(bs, ne, ne, device=run_device)
for _b in range(bs):
ent_mask[_b, :len(ents[_b])] = 1
rel_mask[_b, :len(ents[_b]), :len(ents[_b])] = 1
return ent_mask, rel_mask
def get_ht(self, rel_enco, hts):
htss = []
for i in range(len(hts)):
ht_index = hts[i]
for (h_index, t_index) in ht_index:
htss.append(rel_enco[i,h_index,t_index])
htss = torch.stack(htss,dim=0)
return htss
def get_channel_map(self, sequence_output, entity_as):
# sequence_output = sequence_output.to('cpu')
# attention = attention.to('cpu')
bs,_,d = sequence_output.size()
# ne = max([len(x) for x in entity_as]) # max number of entities in this batch
ne = self.min_height
index_pair = []
for i in range(ne):
tmp = torch.cat((torch.ones((ne, 1), dtype=int) * i, torch.arange(0, ne).unsqueeze(1)), dim=-1)
index_pair.append(tmp)
index_pair = torch.stack(index_pair, dim=0).reshape(-1, 2).to(sequence_output.device)
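# For every (h, t) cell of the ne x ne map, multiply the head and tail attention
# distributions, renormalize, and use the result to pool the sequence output into
# an entity-pair-aware context vector (the "context-based" channel map).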
map_rss = []
for b in range(bs):
entity_atts = entity_as[b]
h_att = torch.index_select(entity_atts, 0, index_pair[:, 0])
t_att = torch.index_select(entity_atts, 0, index_pair[:, 1])
ht_att = (h_att * t_att).mean(1)
ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)
rs = contract("ld,rl->rd", sequence_output[b], ht_att)
map_rss.append(rs)
map_rss = torch.cat(map_rss, dim=0).reshape(bs, ne, ne, d)
return map_rss
def forward(self,
input_ids=None,
attention_mask=None,
labels=None,
entity_pos=None,
hts=None,
instance_mask=None,
):
sequence_output, attention = self.encode(input_ids, attention_mask,entity_pos)
bs, sequen_len, d = sequence_output.shape
run_device = sequence_output.device.index
ne = max([len(x) for x in entity_pos]) # max number of entities in this batch
ent_mask, rel_mask = self.get_mask(entity_pos, bs, ne, run_device)
# get hs, ts and entity_embs >> entity_rs
hs, ts, entity_embs, entity_as = self.get_hrt(sequence_output, attention, entity_pos, hts)
# Two different ways to build the channel map
if self.channel_type == 'context-based':
feature_map = self.get_channel_map(sequence_output, entity_as)
##print('feature_map:', feature_map.shape)
attn_input = self.liner(feature_map).permute(0, 3, 1, 2).contiguous()
elif self.channel_type == 'similarity-based':
ent_encode = sequence_output.new_zeros(bs, self.min_height, d)
for _b in range(bs):
entity_emb = entity_embs[_b]
entity_num = entity_emb.size(0)
ent_encode[_b, :entity_num, :] = entity_emb
# similar0 = ElementWiseMatrixAttention()(ent_encode, ent_encode).unsqueeze(-1)
similar1 = DotProductMatrixAttention()(ent_encode, ent_encode).unsqueeze(-1)
similar2 = CosineMatrixAttention()(ent_encode, ent_encode).unsqueeze(-1)
similar3 = BilinearMatrixAttention(self.emb_size, self.emb_size).to(ent_encode.device)(ent_encode, ent_encode).unsqueeze(-1)
attn_input = torch.cat([similar1,similar2,similar3],dim=-1).permute(0, 3, 1, 2).contiguous()
else:
raise Exception("channel_type must be specified correctly")
attn_map = self.segmentation_net(attn_input)
h_t = self.get_ht(attn_map, hts)
hs = torch.tanh(self.head_extractor(torch.cat([hs, h_t], dim=1)))
ts = torch.tanh(self.tail_extractor(torch.cat([ts, h_t], dim=1)))
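# Grouped bilinear classifier: split the head/tail embeddings into blocks of
# block_size and take block-wise outer products, approximating a full bilinear
# layer with far fewer parameters.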
b1 = hs.view(-1, self.emb_size // self.block_size, self.block_size)
b2 = ts.view(-1, self.emb_size // self.block_size, self.block_size)
bl = (b1.unsqueeze(3) * b2.unsqueeze(2)).view(-1, self.emb_size * self.block_size)
logits = self.bilinear(bl)
output = (self.loss_fnt.get_label(logits, num_labels=self.num_labels))
if labels is not None:
labels = [torch.tensor(label) for label in labels]
labels = torch.cat(labels, dim=0).to(logits)
loss = self.loss_fnt(logits.float(), labels.float())
output = (loss.to(sequence_output), output)
return output

View File

@ -0,0 +1,217 @@
import torch
import torch.nn as nn
from opt_einsum import contract
from long_seq import process_long_input
from losses import ATLoss
import torch.nn.functional as F
from allennlp.modules.matrix_attention import DotProductMatrixAttention, CosineMatrixAttention, BilinearMatrixAttention
from element_wise import ElementWiseMatrixAttention
from attn_unet import AttentionUNet
class DocREModel(nn.Module):
def __init__(self, config, args, model, emb_size=768, block_size=64, num_labels=-1):
super().__init__()
self.config = config
self.bert_model = model
self.hidden_size = config.hidden_size
self.loss_fnt = ATLoss()
self.head_extractor = nn.Linear(1 * config.hidden_size + args.unet_out_dim, emb_size)
self.tail_extractor = nn.Linear(1 * config.hidden_size + args.unet_out_dim, emb_size)
# self.head_extractor = nn.Linear(1 * config.hidden_size , emb_size)
# self.tail_extractor = nn.Linear(1 * config.hidden_size , emb_size)
self.bilinear = nn.Linear(emb_size * block_size, config.num_labels)
self.emb_size = emb_size
self.block_size = block_size
self.num_labels = num_labels
self.bertdrop = nn.Dropout(0.6)
self.unet_in_dim = args.unet_in_dim
self.unet_out_dim = args.unet_out_dim
self.liner = nn.Linear(config.hidden_size, args.unet_in_dim)
self.min_height = args.max_height
self.channel_type = args.channel_type
self.segmentation_net = AttentionUNet(input_channels=args.unet_in_dim,
class_number=args.unet_out_dim,
down_channel=args.down_dim)
def encode(self, input_ids, attention_mask,entity_pos):
config = self.config
if config.transformer_type == "bert":
start_tokens = [config.cls_token_id]
end_tokens = [config.sep_token_id]
elif config.transformer_type == "roberta":
start_tokens = [config.cls_token_id]
end_tokens = [config.sep_token_id, config.sep_token_id]
sequence_output, attention = process_long_input(self.bert_model, input_ids, attention_mask, start_tokens, end_tokens)
return sequence_output, attention
def get_hrt(self, sequence_output, attention, entity_pos, hts):
offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
bs, h, _, c = attention.size()
# ne = max([len(x) for x in entity_pos]) # max number of entities in this batch
hss, tss, rss = [], [], []
entity_es = []
entity_as = []
for i in range(len(entity_pos)):
entity_embs, entity_atts = [], []
for entity_num, e in enumerate(entity_pos[i]):
if len(e) > 1:
e_emb, e_att = [], []
for start, end in e:
if start + offset < c:
# In case the entity mention is truncated due to limited max seq length.
e_emb.append(sequence_output[i, start + offset])
e_att.append(attention[i, :, start + offset])
if len(e_emb) > 0:
e_emb = torch.logsumexp(torch.stack(e_emb, dim=0), dim=0)
e_att = torch.stack(e_att, dim=0).mean(0)
else:
e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
e_att = torch.zeros(h, c).to(attention)
else:
start, end = e[0]
if start + offset < c:
e_emb = sequence_output[i, start + offset]
e_att = attention[i, :, start + offset]
else:
e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
e_att = torch.zeros(h, c).to(attention)
entity_embs.append(e_emb)
entity_atts.append(e_att)
for _ in range(self.min_height-entity_num-1):
entity_atts.append(e_att)
entity_embs = torch.stack(entity_embs, dim=0) # [n_e, d]
entity_atts = torch.stack(entity_atts, dim=0) # [n_e, h, seq_len]
entity_es.append(entity_embs)
entity_as.append(entity_atts)
ht_i = torch.LongTensor(hts[i]).to(sequence_output.device)
hs = torch.index_select(entity_embs, 0, ht_i[:, 0])
ts = torch.index_select(entity_embs, 0, ht_i[:, 1])
# h_att = torch.index_select(entity_atts, 0, ht_i[:, 0])
# t_att = torch.index_select(entity_atts, 0, ht_i[:, 1])
# ht_att = (h_att * t_att).mean(1)
# ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)
# rs = contract("ld,rl->rd", sequence_output[i], ht_att)
hss.append(hs)
tss.append(ts)
# rss.append(rs)
hss = torch.cat(hss, dim=0)
tss = torch.cat(tss, dim=0)
# rss = torch.cat(rss, dim=0)
return hss, tss, entity_es, entity_as
def get_mask(self, ents, bs, ne, run_device):
ent_mask = torch.zeros(bs, ne, device=run_device)
rel_mask = torch.zeros(bs, ne, ne, device=run_device)
for _b in range(bs):
ent_mask[_b, :len(ents[_b])] = 1
rel_mask[_b, :len(ents[_b]), :len(ents[_b])] = 1
return ent_mask, rel_mask
def get_ht(self, rel_enco, hts):
htss = []
for i in range(len(hts)):
ht_index = hts[i]
for (h_index, t_index) in ht_index:
htss.append(rel_enco[i,h_index,t_index])
htss = torch.stack(htss,dim=0)
return htss
def get_channel_map(self, sequence_output, entity_as):
# sequence_output = sequence_output.to('cpu')
# attention = attention.to('cpu')
bs,_,d = sequence_output.size()
# ne = max([len(x) for x in entity_as]) # max number of entities in this batch
ne = self.min_height
index_pair = []
for i in range(ne):
tmp = torch.cat((torch.ones((ne, 1), dtype=int) * i, torch.arange(0, ne).unsqueeze(1)), dim=-1)
index_pair.append(tmp)
index_pair = torch.stack(index_pair, dim=0).reshape(-1, 2).to(sequence_output.device)
map_rss = []
for b in range(bs):
entity_atts = entity_as[b]
h_att = torch.index_select(entity_atts, 0, index_pair[:, 0])
t_att = torch.index_select(entity_atts, 0, index_pair[:, 1])
ht_att = (h_att * t_att).mean(1)
ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)
rs = contract("ld,rl->rd", sequence_output[b], ht_att)
map_rss.append(rs)
map_rss = torch.cat(map_rss, dim=0).reshape(bs, ne, ne, d)
return map_rss
def forward(self,
input_ids=None,
attention_mask=None,
labels=None,
entity_pos=None,
hts=None,
instance_mask=None,
):
sequence_output, attention = self.encode(input_ids, attention_mask,entity_pos)
# sequence_output = self.bertdrop(sequence_output)
bs, sequen_len, d = sequence_output.shape
run_device = sequence_output.device.index
ne = max([len(x) for x in entity_pos]) # max number of entities in this batch
ent_mask, rel_mask = self.get_mask(entity_pos, bs, ne, run_device)
# get hs, ts and entity_embs >> entity_rs
hs, ts, entity_embs, entity_as = self.get_hrt(sequence_output, attention, entity_pos, hts)
# Two different ways to build the channel map
if self.channel_type == 'att_map':
feature_map = self.get_channel_map(sequence_output, entity_as)
##print('feature_map:', feature_map.shape)
attn_input = self.liner(feature_map).permute(0, 3, 1, 2).contiguous()
elif self.channel_type == 'dot_map':
ent_encode = sequence_output.new_zeros(bs, self.min_height, d)
for _b in range(bs):
entity_emb = entity_embs[_b]
entity_num = entity_emb.size(0)
ent_encode[_b, :entity_num, :] = entity_emb
# similar0 = ElementWiseMatrixAttention()(ent_encode, ent_encode).unsqueeze(-1)
similar1 = DotProductMatrixAttention()(ent_encode, ent_encode).unsqueeze(-1)
similar2 = CosineMatrixAttention()(ent_encode, ent_encode).unsqueeze(-1)
similar3 = BilinearMatrixAttention(self.emb_size, self.emb_size).to(ent_encode.device)(ent_encode, ent_encode).unsqueeze(-1)
attn_input = torch.cat([similar1,similar2,similar3],dim=-1).permute(0, 3, 1, 2).contiguous()
else:
raise Exception("channel_type must be specified correctly")
attn_map = self.segmentation_net(attn_input)
h_t = self.get_ht(attn_map, hts)
hs = torch.tanh(self.head_extractor(torch.cat([hs, h_t], dim=1)))
ts = torch.tanh(self.tail_extractor(torch.cat([ts, h_t], dim=1)))
b1 = hs.view(-1, self.emb_size // self.block_size, self.block_size)
b2 = ts.view(-1, self.emb_size // self.block_size, self.block_size)
bl = (b1.unsqueeze(3) * b2.unsqueeze(2)).view(-1, self.emb_size * self.block_size)
# bl = torch.cat((bl,ht),dim=-1)
logits = self.bilinear(bl)
output = logits
# output = (self.loss_fnt.get_label(logits, num_labels=self.num_labels))
if labels is not None:
labels = [torch.tensor(label) for label in labels]
labels = torch.cat(labels, dim=0).to(logits)
loss = self.loss_fnt(logits.float(), labels.float())
output = (loss.to(sequence_output), output)
return output

View File

@ -0,0 +1,24 @@
import matplotlib.pyplot as plt
name_list = ['1-5', '6-10', '11-15', '16-20', '21-25', '26-30', '31-35', '35-42']
data1 = [70.3, 65.0, 64.0, 63.8, 62.7, 59.2, 57.7, 55.3 ]
data2 = [70.0, 64.2, 63.1, 62.0, 60.5, 56.1, 54.3, 51.4 ]
x = list(range(len(data1)))
plt.figure(figsize=(6, 3))
plt.plot(x, data1, marker='>', label='w/ U-shaped Network')
plt.plot(x, data2, '--', marker='o', label='w/o U-shaped Network')
plt.ylim([50, 75])
plt.xlabel('# of Entities per Document')
plt.ylabel('dev F1 (in %)')
plt.xticks(x, name_list)
plt.legend(loc='upper right')
plt.show()

View File

@ -0,0 +1,437 @@
from torch.utils import data
from tqdm import tqdm
import ujson as json
import os
import pickle
import random
import numpy as np
docred_rel2id = json.load(open('../meta/rel2id.json', 'r'))
cdr_rel2id = {'1:NR:2': 0, '1:CID:2': 1}
gda_rel2id = {'1:NR:2': 0, '1:GDA:2': 1}
def chunks(l, n):
res = []
for i in range(0, len(l), n):
assert len(l[i:i + n]) == n
res += [l[i:i + n]]
return res
class ReadDataset:
def __init__(self, dataset: str, tokenizer, max_seq_Length: int = 1024,
transformers: str = 'bert') -> None:
self.transformers = transformers
self.dataset = dataset
self.tokenizer = tokenizer
self.max_seq_Length = max_seq_Length
def read(self, file_in: str):
save_file = file_in.split('.json')[0] + self.transformers + '_' \
+ self.dataset + '.pkl'
if self.dataset == 'docred':
return read_docred(self.transformers, file_in, save_file, self.tokenizer, self.max_seq_Length)
elif self.dataset == 'cdr':
return read_cdr(file_in, save_file, self.tokenizer, self.max_seq_Length)
elif self.dataset == 'gda':
return read_gda(file_in, save_file, self.tokenizer, self.max_seq_Length)
else:
raise RuntimeError("No read func for this dataset.")
def read_docred(transfermers, file_in, save_file, tokenizer, max_seq_length=1024):
if os.path.exists(save_file):
with open(file=save_file, mode='rb') as fr:
features = pickle.load(fr)
fr.close()
print('load preprocessed data from {}.'.format(save_file))
return features
else:
max_len = 0
up512_num = 0
i_line = 0
pos_samples = 0
neg_samples = 0
features = []
if file_in == "":
return None
with open(file_in, "r") as fh:
data = json.load(fh)
if transfermers == 'bert':
# entity_type = ["ORG", "-", "LOC", "-", "TIME", "-", "PER", "-", "MISC", "-", "NUM"]
entity_type = ["-", "ORG", "-", "LOC", "-", "TIME", "-", "PER", "-", "MISC", "-", "NUM"]
for sample in tqdm(data, desc="Example"):
sents = []
sent_map = []
entities = sample['vertexSet']
entity_start, entity_end = [], []
mention_types = []
for entity in entities:
for mention in entity:
sent_id = mention["sent_id"]
pos = mention["pos"]
entity_start.append((sent_id, pos[0]))
entity_end.append((sent_id, pos[1] - 1))
mention_types.append(mention['type'])
for i_s, sent in enumerate(sample['sents']):
new_map = {}
for i_t, token in enumerate(sent):
tokens_wordpiece = tokenizer.tokenize(token)
if (i_s, i_t) in entity_start:
t = entity_start.index((i_s, i_t))
if transfermers == 'bert':
mention_type = mention_types[t]
special_token_i = entity_type.index(mention_type)
special_token = ['[unused' + str(special_token_i) + ']']
else:
special_token = ['*']
tokens_wordpiece = special_token + tokens_wordpiece
# tokens_wordpiece = ["[unused0]"]+ tokens_wordpiece
if (i_s, i_t) in entity_end:
t = entity_end.index((i_s, i_t))
if transfermers == 'bert':
mention_type = mention_types[t]
special_token_i = entity_type.index(mention_type) + 50
special_token = ['[unused' + str(special_token_i) + ']']
else:
special_token = ['*']
tokens_wordpiece = tokens_wordpiece + special_token
# tokens_wordpiece = tokens_wordpiece + ["[unused1]"]
# print(tokens_wordpiece,tokenizer.convert_tokens_to_ids(tokens_wordpiece))
new_map[i_t] = len(sents)
sents.extend(tokens_wordpiece)
new_map[i_t + 1] = len(sents)
sent_map.append(new_map)
if len(sents)>max_len:
max_len=len(sents)
if len(sents)>512:
up512_num += 1
train_triple = {}
if "labels" in sample:
for label in sample['labels']:
evidence = label['evidence']
r = int(docred_rel2id[label['r']])
if (label['h'], label['t']) not in train_triple:
train_triple[(label['h'], label['t'])] = [
{'relation': r, 'evidence': evidence}]
else:
train_triple[(label['h'], label['t'])].append(
{'relation': r, 'evidence': evidence})
entity_pos = []
for e in entities:
entity_pos.append([])
mention_num = len(e)
for m in e:
start = sent_map[m["sent_id"]][m["pos"][0]]
end = sent_map[m["sent_id"]][m["pos"][1]]
entity_pos[-1].append((start, end,))
relations, hts = [], []
# Get positive samples from dataset
for h, t in train_triple.keys():
relation = [0] * len(docred_rel2id)
for mention in train_triple[h, t]:
relation[mention["relation"]] = 1
evidence = mention["evidence"]
relations.append(relation)
hts.append([h, t])
pos_samples += 1
# Get negative samples from dataset
for h in range(len(entities)):
for t in range(len(entities)):
if h != t and [h, t] not in hts:
relation = [1] + [0] * (len(docred_rel2id) - 1)
relations.append(relation)
hts.append([h, t])
neg_samples += 1
assert len(relations) == len(entities) * (len(entities) - 1)
if len(hts)==0:
print(len(sent))
sents = sents[:max_seq_length - 2]
input_ids = tokenizer.convert_tokens_to_ids(sents)
input_ids = tokenizer.build_inputs_with_special_tokens(input_ids)
i_line += 1
feature = {'input_ids': input_ids,
'entity_pos': entity_pos,
'labels': relations,
'hts': hts,
'title': sample['title'],
}
features.append(feature)
print("# of documents {}.".format(i_line))
print("# of positive examples {}.".format(pos_samples))
print("# of negative examples {}.".format(neg_samples))
print("# {} examples len>512 and max len is {}.".format(up512_num, max_len))
with open(file=save_file, mode='wb') as fw:
pickle.dump(features, fw)
print('finish reading {} and save preprocessed data to {}.'.format(file_in, save_file))
return features
def read_cdr(file_in, save_file, tokenizer, max_seq_length=1024):
if os.path.exists(save_file):
with open(file=save_file, mode='rb') as fr:
features = pickle.load(fr)
fr.close()
print('load preprocessed data from {}.'.format(save_file))
return features
else:
pmids = set()
features = []
maxlen = 0
with open(file_in, 'r') as infile:
lines = infile.readlines()
for i_l, line in enumerate(tqdm(lines)):
line = line.rstrip().split('\t')
pmid = line[0]
if pmid not in pmids:
pmids.add(pmid)
text = line[1]
prs = chunks(line[2:], 17)
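# Each entity-pair record spans 17 tab-separated fields (relation, direction,
# entity ids, types, mention offsets); this layout follows the
# edge-oriented-graph preprocessing linked in the README.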
ent2idx = {}
train_triples = {}
entity_pos = set()
for p in prs:
es = list(map(int, p[8].split(':')))
ed = list(map(int, p[9].split(':')))
tpy = p[7]
for start, end in zip(es, ed):
entity_pos.add((start, end, tpy))
es = list(map(int, p[14].split(':')))
ed = list(map(int, p[15].split(':')))
tpy = p[13]
for start, end in zip(es, ed):
entity_pos.add((start, end, tpy))
sents = [t.split(' ') for t in text.split('|')]
new_sents = []
sent_map = {}
i_t = 0
for sent in sents:
for token in sent:
tokens_wordpiece = tokenizer.tokenize(token)
for start, end, tpy in list(entity_pos):
if i_t == start:
tokens_wordpiece = ["*"] + tokens_wordpiece
if i_t + 1 == end:
tokens_wordpiece = tokens_wordpiece + ["*"]
sent_map[i_t] = len(new_sents)
new_sents.extend(tokens_wordpiece)
i_t += 1
sent_map[i_t] = len(new_sents)
sents = new_sents
entity_pos = []
for p in prs:
if p[0] == "not_include":
continue
if p[1] == "L2R":
h_id, t_id = p[5], p[11]
h_start, t_start = p[8], p[14]
h_end, t_end = p[9], p[15]
else:
t_id, h_id = p[5], p[11]
t_start, h_start = p[8], p[14]
t_end, h_end = p[9], p[15]
h_start = map(int, h_start.split(':'))
h_end = map(int, h_end.split(':'))
t_start = map(int, t_start.split(':'))
t_end = map(int, t_end.split(':'))
h_start = [sent_map[idx] for idx in h_start]
h_end = [sent_map[idx] for idx in h_end]
t_start = [sent_map[idx] for idx in t_start]
t_end = [sent_map[idx] for idx in t_end]
if h_id not in ent2idx:
ent2idx[h_id] = len(ent2idx)
entity_pos.append(list(zip(h_start, h_end)))
if t_id not in ent2idx:
ent2idx[t_id] = len(ent2idx)
entity_pos.append(list(zip(t_start, t_end)))
h_id, t_id = ent2idx[h_id], ent2idx[t_id]
r = cdr_rel2id[p[0]]
if (h_id, t_id) not in train_triples:
train_triples[(h_id, t_id)] = [{'relation': r}]
else:
train_triples[(h_id, t_id)].append({'relation': r})
relations, hts = [], []
for h, t in train_triples.keys():
relation = [0] * len(cdr_rel2id)
for mention in train_triples[h, t]:
relation[mention["relation"]] = 1
relations.append(relation)
hts.append([h, t])
maxlen = max(maxlen, len(sents))
sents = sents[:max_seq_length - 2]
input_ids = tokenizer.convert_tokens_to_ids(sents)
input_ids = tokenizer.build_inputs_with_special_tokens(input_ids)
if len(hts) > 0:
feature = {'input_ids': input_ids,
'entity_pos': entity_pos,
'labels': relations,
'hts': hts,
'title': pmid,
}
features.append(feature)
print("Number of documents: {}.".format(len(features)))
print("Max document length: {}.".format(maxlen))
with open(file=save_file, mode='wb') as fw:
pickle.dump(features, fw)
print('finish reading {} and save preprocessed data to {}.'.format(file_in, save_file))
return features
def read_gda(file_in, save_file, tokenizer, max_seq_length=1024):
if os.path.exists(save_file):
with open(file=save_file, mode='rb') as fr:
features = pickle.load(fr)
fr.close()
print('load preprocessed data from {}.'.format(save_file))
return features
else:
pmids = set()
features = []
maxlen = 0
with open(file_in, 'r') as infile:
lines = infile.readlines()
for i_l, line in enumerate(tqdm(lines)):
line = line.rstrip().split('\t')
pmid = line[0]
if pmid not in pmids:
pmids.add(pmid)
text = line[1]
prs = chunks(line[2:], 17)
ent2idx = {}
train_triples = {}
entity_pos = set()
for p in prs:
es = list(map(int, p[8].split(':')))
ed = list(map(int, p[9].split(':')))
tpy = p[7]
for start, end in zip(es, ed):
entity_pos.add((start, end, tpy))
es = list(map(int, p[14].split(':')))
ed = list(map(int, p[15].split(':')))
tpy = p[13]
for start, end in zip(es, ed):
entity_pos.add((start, end, tpy))
sents = [t.split(' ') for t in text.split('|')]
new_sents = []
sent_map = {}
i_t = 0
for sent in sents:
for token in sent:
tokens_wordpiece = tokenizer.tokenize(token)
for start, end, tpy in list(entity_pos):
if i_t == start:
tokens_wordpiece = ["*"] + tokens_wordpiece
if i_t + 1 == end:
tokens_wordpiece = tokens_wordpiece + ["*"]
sent_map[i_t] = len(new_sents)
new_sents.extend(tokens_wordpiece)
i_t += 1
sent_map[i_t] = len(new_sents)
sents = new_sents
entity_pos = []
for p in prs:
if p[0] == "not_include":
continue
if p[1] == "L2R":
h_id, t_id = p[5], p[11]
h_start, t_start = p[8], p[14]
h_end, t_end = p[9], p[15]
else:
t_id, h_id = p[5], p[11]
t_start, h_start = p[8], p[14]
t_end, h_end = p[9], p[15]
h_start = map(int, h_start.split(':'))
h_end = map(int, h_end.split(':'))
t_start = map(int, t_start.split(':'))
t_end = map(int, t_end.split(':'))
h_start = [sent_map[idx] for idx in h_start]
h_end = [sent_map[idx] for idx in h_end]
t_start = [sent_map[idx] for idx in t_start]
t_end = [sent_map[idx] for idx in t_end]
if h_id not in ent2idx:
ent2idx[h_id] = len(ent2idx)
entity_pos.append(list(zip(h_start, h_end)))
if t_id not in ent2idx:
ent2idx[t_id] = len(ent2idx)
entity_pos.append(list(zip(t_start, t_end)))
h_id, t_id = ent2idx[h_id], ent2idx[t_id]
r = gda_rel2id[p[0]]
if (h_id, t_id) not in train_triples:
train_triples[(h_id, t_id)] = [{'relation': r}]
else:
train_triples[(h_id, t_id)].append({'relation': r})
relations, hts = [], []
for h, t in train_triples.keys():
relation = [0] * len(gda_rel2id)
for mention in train_triples[h, t]:
relation[mention["relation"]] = 1
relations.append(relation)
hts.append([h, t])
maxlen = max(maxlen, len(sents))
sents = sents[:max_seq_length - 2]
input_ids = tokenizer.convert_tokens_to_ids(sents)
input_ids = tokenizer.build_inputs_with_special_tokens(input_ids)
if len(hts) > 0:
feature = {'input_ids': input_ids,
'entity_pos': entity_pos,
'labels': relations,
'hts': hts,
'title': pmid,
}
features.append(feature)
print("Number of documents: {}.".format(len(features)))
print("Max document length: {}.".format(maxlen))
with open(file=save_file, mode='wb') as fw:
pickle.dump(features, fw)
print('finish reading {} and save preprocessed data to {}.'.format(file_in, save_file))
return features

View File

@ -0,0 +1,8 @@
# python==3.7 (interpreter version, not a pip package)
# cuda==10.2 (CUDA toolkit version, not a pip package)
torch==1.5.0
transformers==3.0.4
opt-einsum==3.3.0
ujson
tqdm
allennlp

View File

@ -0,0 +1,35 @@
#! /bin/bash
export CUDA_VISIBLE_DEVICES=0
if true; then
type=context-based
bs=4
bl=3e-5
uls=(4e-4)
accum=1
for ul in ${uls[@]}
do
python -u ../train_bio.py --data_dir ../dataset/cdr \
--max_height 35 \
--channel_type $type \
--bert_lr $bl \
--transformer_type bert \
--model_name_or_path allenai/scibert_scivocab_cased \
--train_file train.data \
--dev_file dev.data \
--test_file test.data \
--train_batch_size $bs \
--test_batch_size $bs \
--gradient_accumulation_steps $accum \
--num_labels 1 \
--learning_rate $ul \
--max_grad_norm 1.0 \
--warmup_ratio 0.06 \
--num_train_epochs 30.0 \
--seed 111 \
--num_class 2 \
--save_path ../checkpoint/cdr/train_scibert-lr${bl}_accum${accum}_unet-lr${ul}_bs${bs}.pt \
--log_dir ../logs/cdr/train_scibert-lr${bl}_accum${accum}_unet-lr${ul}_bs${bs}.log
done
fi

View File

@ -0,0 +1,66 @@
#! /bin/bash
export CUDA_VISIBLE_DEVICES=0
# -------------------Training Shell Script--------------------
if true; then
transformer_type=bert
channel_type=context-based
if [[ $transformer_type == bert ]]; then
bs=4
bl=3e-5
uls=(3e-4 4e-4 5e-4)
accum=1
for ul in ${uls[@]}
do
python -u ../train_balanceloss.py --data_dir ../dataset/docred \
--channel_type $channel_type \
--bert_lr $bl \
--transformer_type $transformer_type \
--model_name_or_path bert-base-cased \
--train_file train_annotated.json \
--dev_file dev.json \
--test_file test.json \
--train_batch_size $bs \
--test_batch_size $bs \
--gradient_accumulation_steps $accum \
--num_labels 3 \
--learning_rate $ul \
--max_grad_norm 1.0 \
--warmup_ratio 0.06 \
--num_train_epochs 30.0 \
--seed 66 \
--num_class 97 \
--save_path ../checkpoint/docred/train_bert-lr${bl}_accum${accum}_unet-lr${ul}_type_${channel_type}.pt \
--log_dir ../logs/docred/train_bert-lr${bl}_accum${accum}_unet-lr${ul}_type_${channel_type}.log
done
elif [[ $transformer_type == roberta ]]; then
bs=2
bl=3e-5
uls=(4e-4)
accum=2
for ul in ${uls[@]}
do
python -u ../train_balanceloss.py --data_dir ../dataset/docred \
--channel_type $channel_type \
--bert_lr $bl \
--transformer_type $transformer_type \
--model_name_or_path roberta-large \
--train_file train_annotated.json \
--dev_file dev.json \
--test_file test.json \
--train_batch_size $bs \
--test_batch_size $bs \
--gradient_accumulation_steps $accum \
--num_labels 4 \
--learning_rate $ul \
--max_grad_norm 1.0 \
--warmup_ratio 0.06 \
--num_train_epochs 30.0 \
--seed 111 \
--num_class 97 \
--save_path ../checkpoint/docred/train_bert-lr${bl}_accum${accum}_unet-lr${ul}_type_${channel_type}.pt \
--log_dir ../logs/docred/train_bert-lr${bl}_accum${accum}_unet-lr${ul}_type_${channel_type}.log
done
fi
fi

View File

@ -0,0 +1,35 @@
#! /bin/bash
export CUDA_VISIBLE_DEVICES=0
if true; then
type=context-based
bs=4
bl=3e-5
uls=(4e-4)
accum=4
for ul in ${uls[@]}
do
python -u ../train_bio.py --data_dir ../dataset/gda \
--max_height 35 \
--channel_type $type \
--bert_lr $bl \
--transformer_type bert \
--model_name_or_path allenai/scibert_scivocab_cased \
--train_file train.data \
--dev_file dev.data \
--test_file test.data \
--train_batch_size $bs \
--test_batch_size $bs \
--gradient_accumulation_steps $accum \
--num_labels 1 \
--learning_rate $ul \
--max_grad_norm 1.0 \
--warmup_ratio 0.06 \
--num_train_epochs 10.0 \
--evaluation_steps 400 \
--seed 66 \
--num_class 2 \
--save_path ../checkpoint/gda/train_scibert-lr${bl}_accum${accum}_unet-lr${ul}_bs${bs}.pt \
--log_dir ../logs/gda/train_scibert-lr${bl}_accum${accum}_unet-lr${ul}_bs${bs}.log
done
fi

View File

@ -0,0 +1 @@
# this is an empty file

View File

@ -0,0 +1,325 @@
import argparse
import os
import time
from datetime import datetime
import numpy as np
import torch
import ujson as json
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
from model_balanceloss import DocREModel
from utils_sample import set_seed, collate_fn
from evaluation import to_official, official_evaluate
from prepro import ReadDataset
def train(args, model, train_features, dev_features, test_features):
def logging(s, print_=True, log_=True):
if print_:
print(s)
if log_:
with open(args.log_dir, 'a+') as f_log:
f_log.write(s + '\n')
def finetune(features, optimizer, num_epoch, num_steps, model):
if args.train_from_saved_model != '':
best_score = torch.load(args.train_from_saved_model)["best_f1"]
epoch_delta = torch.load(args.train_from_saved_model)["epoch"] + 1
else:
epoch_delta = 0
best_score = -1
train_dataloader = DataLoader(features, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True)
train_iterator = [epoch + epoch_delta for epoch in range(int(num_epoch))]
total_steps = int(len(train_dataloader) * num_epoch // args.gradient_accumulation_steps)
warmup_steps = int(total_steps * args.warmup_ratio)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
print("Total steps: {}".format(total_steps))
print("Warmup steps: {}".format(warmup_steps))
global_step = 0
log_step = 100
total_loss = 0
print('torch.cuda.device_count():',torch.cuda.device_count())
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
#scaler = GradScaler()
for epoch in train_iterator:
start_time = time.time()
optimizer.zero_grad()
for step, batch in enumerate(train_dataloader):
model.train()
inputs = {'input_ids': batch[0].to(args.device),
'attention_mask': batch[1].to(args.device),
'labels': batch[2],
'entity_pos': batch[3],
'hts': batch[4],
}
#with autocast():
outputs = model(**inputs)
loss = outputs[0] / args.gradient_accumulation_steps
total_loss += loss.item()
# scaler.scale(loss).backward()
loss.backward()
if step % args.gradient_accumulation_steps == 0:
#scaler.unscale_(optimizer)
if args.max_grad_norm > 0:
# torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
#scaler.step(optimizer)
#scaler.update()
#scheduler.step()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
global_step += 1
num_steps += 1
if global_step % log_step == 0:
cur_loss = total_loss / log_step
elapsed = time.time() - start_time
logging(
'| epoch {:2d} | step {:4d} | min/b {:5.2f} | lr {} | train loss {:5.3f}'.format(
epoch, global_step, elapsed / 60, scheduler.get_lr(), cur_loss * 1000))
total_loss = 0
start_time = time.time()
if (step + 1) == len(train_dataloader) - 1 or (args.evaluation_steps > 0 and num_steps % args.evaluation_steps == 0 and step % args.gradient_accumulation_steps == 0):
# if step ==0:
logging('-' * 89)
eval_start_time = time.time()
dev_score, dev_output = evaluate(args, model, dev_features, tag="dev")
logging(
'| epoch {:3d} | time: {:5.2f}s | dev_result:{}'.format(epoch, time.time() - eval_start_time,
dev_output))
logging('-' * 89)
if dev_score > best_score:
best_score = dev_score
logging(
'| epoch {:3d} | best_f1:{}'.format(epoch, best_score))
pred = report(args, model, test_features)
with open("result.json", "w") as fh:
json.dump(pred, fh)
if args.save_path != "":
torch.save({
'epoch': epoch,
'checkpoint': model.state_dict(),
'best_f1': best_score,
'optimizer': optimizer.state_dict()
}, args.save_path
, _use_new_zipfile_serialization=False)
return num_steps
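# Three learning-rate groups: the BERT encoder uses --bert_lr, the head/tail
# extractors and the bilinear classifier use 1e-4, and the remaining modules
# (e.g. the U-Net) use --learning_rate.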
extract_layer = ["extractor", "bilinear"]
bert_layer = ['bert_model']
optimizer_grouped_parameters = [
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in bert_layer)], "lr": args.bert_lr},
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in extract_layer)], "lr": 1e-4},
{"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in extract_layer + bert_layer)]},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
if args.train_from_saved_model != '':
optimizer.load_state_dict(torch.load(args.train_from_saved_model)["optimizer"])
print("load saved optimizer from {}.".format(args.train_from_saved_model))
num_steps = 0
set_seed(args)
model.zero_grad()
finetune(train_features, optimizer, args.num_train_epochs, num_steps, model)
def evaluate(args, model, features, tag="dev"):
dataloader = DataLoader(features, batch_size=args.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)
preds = []
total_loss = 0
for i, batch in enumerate(dataloader):
model.eval()
inputs = {'input_ids': batch[0].to(args.device),
'attention_mask': batch[1].to(args.device),
'labels': batch[2],
'entity_pos': batch[3],
'hts': batch[4],
}
with torch.no_grad():
output = model(**inputs)
loss = output[0]
pred = output[1].cpu().numpy()
pred[np.isnan(pred)] = 0
preds.append(pred)
total_loss += loss.item()
average_loss = total_loss / (i + 1)
preds = np.concatenate(preds, axis=0).astype(np.float32)
ans = to_official(preds, features)
best_f1 = best_f1_ign = re_p = re_r = 0
if len(ans) > 0:
best_f1, _, best_f1_ign, _, re_p, re_r = official_evaluate(ans, args.data_dir)
output = {
tag + "_F1": best_f1 * 100,
tag + "_F1_ign": best_f1_ign * 100,
tag + "_re_p": re_p * 100,
tag + "_re_r": re_r * 100,
tag + "_average_loss": average_loss
}
return best_f1, output
def report(args, model, features):
dataloader = DataLoader(features, batch_size=args.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)
preds = []
    model.eval()
    for batch in dataloader:
inputs = {'input_ids': batch[0].to(args.device),
'attention_mask': batch[1].to(args.device),
'entity_pos': batch[3],
'hts': batch[4],
}
with torch.no_grad():
pred, *_ = model(**inputs)
pred = pred.cpu().numpy()
pred[np.isnan(pred)] = 0
preds.append(pred)
    preds = np.concatenate(preds, axis=0).astype(np.float32)
    preds = to_official(preds, features)
    return preds
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", default="./dataset/docred", type=str)
parser.add_argument("--transformer_type", default="bert", type=str)
parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str)
parser.add_argument("--train_file", default="train_annotated.json", type=str)
parser.add_argument("--dev_file", default="dev.json", type=str)
parser.add_argument("--test_file", default="test.json", type=str)
parser.add_argument("--save_path", default="", type=str)
parser.add_argument("--load_path", default="", type=str)
parser.add_argument("--config_name", default="", type=str,
help="Pretrained config name or path if not the same as model_name")
parser.add_argument("--tokenizer_name", default="", type=str,
help="Pretrained tokenizer name or path if not the same as model_name")
parser.add_argument("--max_seq_length", default=1024, type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--train_batch_size", default=4, type=int,
help="Batch size for training.")
parser.add_argument("--test_batch_size", default=8, type=int,
help="Batch size for testing.")
parser.add_argument("--gradient_accumulation_steps", default=1, type=int,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--num_labels", default=4, type=int,
help="Max number of labels in prediction.")
parser.add_argument("--learning_rate", default=5e-5, type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--bert_lr", default=5e-5, type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--adam_epsilon", default=1e-6, type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float,
help="Max gradient norm.")
parser.add_argument("--warmup_ratio", default=0.06, type=float,
help="Warm up ratio for Adam.")
parser.add_argument("--num_train_epochs", default=30.0, type=float,
help="Total number of training epochs to perform.")
parser.add_argument("--evaluation_steps", default=-1, type=int,
help="Number of training steps between evaluations.")
parser.add_argument("--seed", type=int, default=66,
help="random seed for initialization")
parser.add_argument("--num_class", type=int, default=97,
help="Number of relation types in dataset.")
parser.add_argument("--unet_in_dim", type=int, default=3,
help="unet_in_dim.")
parser.add_argument("--unet_out_dim", type=int, default=256,
help="unet_out_dim.")
parser.add_argument("--down_dim", type=int, default=256,
help="down_dim.")
parser.add_argument("--channel_type", type=str, default='',
help="unet_out_dim.")
parser.add_argument("--log_dir", type=str, default='',
help="log.")
parser.add_argument("--max_height", type=int, default=42,
help="log.")
parser.add_argument("--train_from_saved_model", type=str, default='',
help="train from a saved model.")
parser.add_argument("--dataset", type=str, default='docred',
help="dataset type")
args = parser.parse_args()
    print('args:', args)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
args.n_gpu = torch.cuda.device_count()
args.device = device
config = AutoConfig.from_pretrained(
args.config_name if args.config_name else args.model_name_or_path,
num_labels=args.num_class,
)
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
)
Dataset = ReadDataset(args.dataset, tokenizer, args.max_seq_length)
train_file = os.path.join(args.data_dir, args.train_file)
dev_file = os.path.join(args.data_dir, args.dev_file)
test_file = os.path.join(args.data_dir, args.test_file)
train_features = Dataset.read(train_file)
dev_features = Dataset.read(dev_file)
test_features = Dataset.read(test_file)
model = AutoModel.from_pretrained(
args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path),
config=config,
)
config.cls_token_id = tokenizer.cls_token_id
config.sep_token_id = tokenizer.sep_token_id
config.transformer_type = args.transformer_type
set_seed(args)
model = DocREModel(config, args, model, num_labels=args.num_labels)
if args.train_from_saved_model != '':
model.load_state_dict(torch.load(args.train_from_saved_model)["checkpoint"])
print("load saved model from {}.".format(args.train_from_saved_model))
    model.to(args.device)  # use the device selected above instead of hard-coding GPU 0
if args.load_path == "": # Training
train(args, model, train_features, dev_features, test_features)
else: # Testing
model.load_state_dict(torch.load(args.load_path)['checkpoint'])
T_features = test_features # Testing on the test set
T_score, T_output = evaluate(args, model, T_features, tag="test")
print(T_output)
pred = report(args, model, T_features)
with open("result.json", "w") as fh:
json.dump(pred, fh)
if __name__ == "__main__":
main()
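
# Example invocation (a sketch; the script filename "train.py" is an assumption,
# while the flags are the ones defined in main() above):
#
#   python train.py --data_dir ./dataset/docred --transformer_type bert \
#       --model_name_or_path bert-base-cased --train_batch_size 4 \
#       --num_class 97 --save_path ./checkpoint/docred.pt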

View File

@ -0,0 +1,73 @@
import torch
import random
import numpy as np
def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0 and torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
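
# collate_fn_sample below is the training-time collator with negative
# down-sampling: it pads each batch to its longest sequence, keeps every
# positive entity pair, and caps the shuffled negatives at
# max(20, negative_alpha * #positives) per document.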
def collate_fn_sample(batch):
max_len = max([len(f["input_ids"]) for f in batch])
input_ids = [f["input_ids"] + [0] * (max_len - len(f["input_ids"])) for f in batch]
input_mask = [[1.0] * len(f["input_ids"]) + [0.0] * (max_len - len(f["input_ids"])) for f in batch]
input_ids = torch.tensor(input_ids, dtype=torch.long)
input_mask = torch.tensor(input_mask, dtype=torch.float)
entity_pos = [f["entity_pos"] for f in batch]
negative_alpha = 8
positive_alpha = 1
labels, hts = [], []
for f in batch:
randnum = random.randint(0, 1000000)
pos_hts = f['pos_hts']
pos_labels = f['pos_labels']
neg_hts = f['neg_hts']
neg_labels = f['neg_labels']
        if negative_alpha > 0:
            # Reuse the same seed for both shuffles so hts and labels stay aligned.
            random.seed(randnum)
            random.shuffle(neg_hts)
            random.seed(randnum)
            random.shuffle(neg_labels)
            # Keep at most max(20, negative_alpha * #positives) negative pairs.
            lower_bound = int(max(20, len(pos_hts) * negative_alpha))
            hts.append(pos_hts * positive_alpha + neg_hts[:lower_bound])
            labels.append(pos_labels * positive_alpha + neg_labels[:lower_bound])
    output = (input_ids, input_mask, labels, entity_pos, hts)
    return output
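
# collate_fn below is the evaluation-time collator: same padding as above, but
# it keeps all entity pairs and labels without any negative sampling.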
def collate_fn(batch):
max_len = max([len(f["input_ids"]) for f in batch])
input_ids = [f["input_ids"] + [0] * (max_len - len(f["input_ids"])) for f in batch]
input_mask = [[1.0] * len(f["input_ids"]) + [0.0] * (max_len - len(f["input_ids"])) for f in batch]
input_ids = torch.tensor(input_ids, dtype=torch.long)
input_mask = torch.tensor(input_mask, dtype=torch.float)
entity_pos = [f["entity_pos"] for f in batch]
labels = [f["labels"] for f in batch]
hts = [f["hts"] for f in batch]
    output = (input_ids, input_mask, labels, entity_pos, hts)
    return output
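
# Minimal usage sketch (hypothetical toy features; the keys mirror those
# consumed above):
#
#   from torch.utils.data import DataLoader
#   features = [
#       {"input_ids": [101, 7592, 102], "entity_pos": [], "labels": [], "hts": []},
#       {"input_ids": [101, 2023, 2003, 102], "entity_pos": [], "labels": [], "hts": []},
#   ]
#   loader = DataLoader(features, batch_size=2, collate_fn=collate_fn)
#   input_ids, input_mask, labels, entity_pos, hts = next(iter(loader))
#   # input_ids.shape == (2, 4): the shorter sequence is zero-padded to length 4.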