parent
acae147b2b
commit
69da55ecfb
|
@ -0,0 +1,92 @@
|
|||
[**中文**](https://github.com/zjunlp/DocRE/blob/master/README_CN.md) | [**English**](https://github.com/zjunlp/DocRE/blob/master/README.md)
|
||||
|
||||
|
||||
|
||||
<p align="center">
|
||||
<font size=7><strong>DocuNet: Document-level Relation Extraction as Semantic Segmentation</strong></font>
|
||||
</p>
|
||||
|
||||
|
||||
|
||||
This repository is the official implementation of [**DocuNet**](https://github.com/zjunlp/DocRE/), the model proposed in the paper **[Document-level Relation Extraction as Semantic Segmentation](https://www.ijcai.org/proceedings/2021/551)**, accepted at the **IJCAI 2021** main conference.
|
||||
|
||||
|
||||
# Contributors
|
||||
Xiang Chen, Xin Xie, Shumin Deng, Ningyu Zhang, and Huajun Chen.
|
||||
|
||||
|
||||
# Brief Introduction
|
||||
This paper proposes the DocuNet model, which is the first to formulate document-level relation extraction as a semantic segmentation task from computer vision: relations for all entity pairs in a document are predicted jointly by applying a U-shaped segmentation network to the entity-pair feature matrix.
|
||||
|
||||
|
||||
# Requirements
|
||||
|
||||
To install requirements:
|
||||
|
||||
```setup
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
# Training
|
||||
|
||||
To train the DocuNet model from the paper on the [DocRED](https://github.com/thunlp/DocRED) dataset, run:
|
||||
|
||||
```bash
|
||||
>> bash scripts/run_docred.sh # use BERT/RoBERTa by setting --transformer_type
|
||||
```
|
||||
|
||||
To train the DocuNet model on the CDR and GDA datasets (both can be obtained by following the preprocessing instructions in [edge-oriented graph](https://github.com/fenchri/edge-oriented-graph)), run:
|
||||
|
||||
```bash
|
||||
>> bash scripts/run_cdr.sh # for CDR
|
||||
>> bash scripts/run_gda.sh # for GDA
|
||||
```
|
||||
|
||||
|
||||
|
||||
# Evaluation
|
||||
|
||||
To evaluate a trained model, set the `--load_path` argument in the training scripts to a saved checkpoint. The program logs the evaluation results automatically. For DocRED, it also generates a test file `result.json` in the official evaluation format; you can compress and submit it to CodaLab for the official test score.
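For example, a minimal sketch (the checkpoint path is hypothetical; keep the remaining arguments identical to `scripts/run_docred.sh`):

```bash
# Sketch only: evaluate a saved DocRED checkpoint instead of training from scratch.
python -u train_balanceloss.py \
    --data_dir ./dataset/docred \
    --load_path ./checkpoint/docred/docunet_bert.pt
# ... plus the other arguments used in scripts/run_docred.sh
```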
|
||||
|
||||
|
||||
# Results
|
||||
|
||||
Our model achieves the following performance:
|
||||
|
||||
### Document-level Relation Extraction on [DocRED](https://github.com/thunlp/DocRED)
|
||||
|
||||
|
||||
| Model | Ign F1 on Dev | F1 on Dev | Ign F1 on Test | F1 on Test |
|
||||
| :----------------: |:--------------: | :------------: | ------------------ | ------------------ |
|
||||
| DocuNet-BERT (base) | 59.86±0.13 | 61.83±0.19 | 59.93 | 61.86 |
|
||||
| DocuNet-RoBERTa (large) | 62.23±0.12 | 64.12±0.14 | 62.39 | 64.55 |
|
||||
|
||||
### Document-level Relation Extraction on [CDR and GDA](https://github.com/fenchri/edge-oriented-graph)
|
||||
|
||||
| Model | CDR | GDA |
|
||||
| :----------------: | :----------------: | :----------------: |
|
||||
| DocuNet-SciBERT (base) | 76.3±0.40 | 85.3±0.50 |
|
||||
|
||||
|
||||
|
||||
|
||||
# Papers for the Project & How to Cite
|
||||
If you use or extend our work, please cite the following paper:
|
||||
|
||||
```
|
||||
@inproceedings{ijcai2021-551,
|
||||
title = {Document-level Relation Extraction as Semantic Segmentation},
|
||||
author = {Zhang, Ningyu and Chen, Xiang and Xie, Xin and Deng, Shumin and Tan, Chuanqi and Chen, Mosha and Huang, Fei and Si, Luo and Chen, Huajun},
|
||||
booktitle = {Proceedings of the Thirtieth International Joint Conference on
|
||||
Artificial Intelligence, {IJCAI-21}},
|
||||
publisher = {International Joint Conferences on Artificial Intelligence Organization},
|
||||
editor = {Zhi-Hua Zhou},
|
||||
pages = {3999--4006},
|
||||
year = {2021},
|
||||
month = {8},
|
||||
note = {Main Track},
|
||||
doi = {10.24963/ijcai.2021/551},
|
||||
url = {https://doi.org/10.24963/ijcai.2021/551},
|
||||
}
|
||||
```
|
|
@ -0,0 +1,90 @@
|
|||
[**中文**](https://github.com/zjunlp/DocRE/blob/master/README_CN.md) | [**English**](https://github.com/zjunlp/DocRE/blob/master/README.md)
|
||||
|
||||
<p align="center">
|
||||
<font size=7><strong>DocuNet:一个基于语义分割方法实现文档级关系抽取的模型</strong></font>
|
||||
</p>
|
||||
|
||||
这是针对我们[**DocuNet**](https://github.com/zjunlp/DocuNet)项目的官方实现代码。这个模型是在**[Document-level Relation Extraction as Semantic Segmentation](https://www.ijcai.org/proceedings/2021/551)**论文中提出来的,该论文已被**IJCAI2021**主会录用。
|
||||
|
||||
|
||||
# 项目成员
|
||||
陈想,谢辛,邓淑敏,张宁豫,陈华钧。
|
||||
|
||||
|
||||
# 项目简介
|
||||
本文创新性地提出DocuNet模型,首次将文档级关系抽取任务类比于计算机视觉中的语义分割任务。
|
||||
|
||||
|
||||
# 环境要求
|
||||
需要按以下命令去配置项目运行环境:
|
||||
|
||||
```setup
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
# 模型训练
|
||||
|
||||
## DocRED
|
||||
|
||||
|
||||
请运行以下命令在DocRED中训练DocuNet模型:
|
||||
|
||||
```bash
|
||||
>> bash scripts/run_docred.sh # use BERT/RoBERTa by setting --transformer_type
|
||||
```
|
||||
## CDR和GDA
|
||||
|
||||
请运行以下命令在CDR和GDA中训练DocuNet模型:
|
||||
|
||||
```bash
|
||||
>> bash scripts/run_cdr.sh # for CDR
|
||||
>> bash scripts/run_gda.sh # for GDA
|
||||
```
|
||||
|
||||
数据集CDR和GDA可以根据[edge-oriented graph](https://github.com/fenchri/edge-oriented-graph)指南获取。
|
||||
|
||||
|
||||
# 评估效果
|
||||
要评估论文中的训练模型,您可以在训练脚本中设置 `--load_path` 参数。程序会自动记录评估结果。对于 DocRED,它将生成一个官方评估格式的测试文件 `result.json`。您可以压缩并提交到 CodaLab 以获得官方测试分数。
|
||||
|
||||
|
||||
# 结果
|
||||
|
||||
我们的模型达到了以下的性能:
|
||||
|
||||
### 在[DocRED](https://github.com/thunlp/DocRED)上的文档级关系抽取
|
||||
|
||||
| 模型 | Ign F1 on Dev | F1 on Dev | Ign F1 on Test | F1 on Test |
|
||||
| :----------------: |:--------------: | :------------: | ------------------ | ------------------ |
|
||||
| DocuNet-BERT (base) | 59.86±0.13 | 61.83±0.19 | 59.93 | 61.86 |
|
||||
| DocuNet-RoBERTa (large) | 62.23±0.12 | 64.12±0.14 | 62.39 | 64.55 |
|
||||
|
||||
|
||||
### 在[CDR和GDA](https://github.com/fenchri/edge-oriented-graph)上的文档级关系抽取
|
||||
|
||||
| 模型 | CDR | GDA |
|
||||
| :----------------: | :----------------: | :----------------: |
|
||||
| DocuNet-SciBERT (base) | 76.3±0.40 | 85.3±0.50 |
|
||||
|
||||
|
||||
# 有关论文
|
||||
如果您使用或拓展我们的工作,请引用以下论文:
|
||||
|
||||
```
|
||||
@inproceedings{ijcai2021-551,
|
||||
title = {Document-level Relation Extraction as Semantic Segmentation},
|
||||
author = {Zhang, Ningyu and Chen, Xiang and Xie, Xin and Deng, Shumin and Tan, Chuanqi and Chen, Mosha and Huang, Fei and Si, Luo and Chen, Huajun},
|
||||
booktitle = {Proceedings of the Thirtieth International Joint Conference on
|
||||
Artificial Intelligence, {IJCAI-21}},
|
||||
publisher = {International Joint Conferences on Artificial Intelligence Organization},
|
||||
editor = {Zhi-Hua Zhou},
|
||||
pages = {3999--4006},
|
||||
year = {2021},
|
||||
month = {8},
|
||||
note = {Main Track},
|
||||
doi = {10.24963/ijcai.2021/551},
|
||||
url = {https://doi.org/10.24963/ijcai.2021/551},
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,119 @@
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
|
||||
|
||||
class AttentionUNet(torch.nn.Module):
|
||||
"""
|
||||
    U-Net-style encoder-decoder (down-sampling and up-sampling) used for global reasoning over the entity-pair feature map.
|
||||
"""
|
||||
|
||||
def __init__(self, input_channels, class_number, **kwargs):
|
||||
super(AttentionUNet, self).__init__()
|
||||
|
||||
down_channel = kwargs['down_channel'] # default = 256
|
||||
|
||||
down_channel_2 = down_channel * 2
|
||||
up_channel_1 = down_channel_2 * 2
|
||||
up_channel_2 = down_channel * 2
|
||||
|
||||
self.inc = InConv(input_channels, down_channel)
|
||||
self.down1 = DownLayer(down_channel, down_channel_2)
|
||||
self.down2 = DownLayer(down_channel_2, down_channel_2)
|
||||
|
||||
self.up1 = UpLayer(up_channel_1, up_channel_1 // 4)
|
||||
self.up2 = UpLayer(up_channel_2, up_channel_2 // 4)
|
||||
self.outc = OutConv(up_channel_2 // 4, class_number)
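        # Rough shape flow (B = batch, W x H = spatial size of the feature map):
        #   inc:          (B, input_channels, W, H) -> (B, down_channel, W, H)
        #   down1/down2:  halve the spatial size via 2x2 max pooling, then DoubleConv
        #   up1/up2:      upsample and concatenate with the matching encoder feature map
        #   outc:         1x1 convolution producing `class_number` output channels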
|
||||
|
||||
def forward(self, attention_channels):
|
||||
"""
|
||||
        Given a multi-channel attention/feature map, return per-pixel class logits.
        :param attention_channels: tensor of shape (batch_size, channels, width, height)
        :return: tensor of shape (batch_size, width, height, class_number)
|
||||
"""
|
||||
# attention_channels as the shape of: batch_size x channel x width x height
|
||||
x = attention_channels
|
||||
x1 = self.inc(x)
|
||||
x2 = self.down1(x1)
|
||||
x3 = self.down2(x2)
|
||||
x = self.up1(x3, x2)
|
||||
x = self.up2(x, x1)
|
||||
output = self.outc(x)
|
||||
# attn_map as the shape of: batch_size x width x height x class
|
||||
output = output.permute(0, 2, 3, 1).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
class DoubleConv(nn.Module):
|
||||
"""(conv => [BN] => ReLU) * 2"""
|
||||
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super(DoubleConv, self).__init__()
|
||||
self.double_conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(out_ch),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(out_ch),
|
||||
nn.ReLU(inplace=True))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.double_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class InConv(nn.Module):
|
||||
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super(InConv, self).__init__()
|
||||
self.conv = DoubleConv(in_ch, out_ch)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class DownLayer(nn.Module):
|
||||
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super(DownLayer, self).__init__()
|
||||
self.maxpool_conv = nn.Sequential(
|
||||
nn.MaxPool2d(kernel_size=2),
|
||||
DoubleConv(in_ch, out_ch)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.maxpool_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class UpLayer(nn.Module):
|
||||
|
||||
def __init__(self, in_ch, out_ch, bilinear=True):
|
||||
super(UpLayer, self).__init__()
|
||||
if bilinear:
|
||||
self.up = nn.Upsample(scale_factor=2, mode='bilinear',
|
||||
align_corners=True)
|
||||
else:
|
||||
self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
|
||||
self.conv = DoubleConv(in_ch, out_ch)
|
||||
|
||||
def forward(self, x1, x2):
|
||||
x1 = self.up(x1)
|
||||
diffY = x2.size()[2] - x1.size()[2]
|
||||
diffX = x2.size()[3] - x1.size()[3]
|
||||
x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY -
|
||||
diffY // 2))
|
||||
x = torch.cat([x2, x1], dim=1)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class OutConv(nn.Module):
|
||||
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super(OutConv, self).__init__()
|
||||
self.conv = nn.Conv2d(in_ch, out_ch, 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
|
@ -0,0 +1 @@
|
|||
# this is an empty file
|
|
@ -0,0 +1 @@
|
|||
# this is an empty file
|
|
@ -0,0 +1,24 @@
|
|||
import torch
|
||||
from overrides import overrides
|
||||
|
||||
from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention
|
||||
|
||||
|
||||
@MatrixAttention.register("ele_multiply")
|
||||
class ElementWiseMatrixAttention(MatrixAttention):
|
||||
"""
|
||||
    This similarity function computes the element-wise product between each pair of vectors
    (one value per embedding dimension) rather than a summed dot product; the result has
    shape ``(batch, embedding_dim, num_rows_1, num_rows_2)``.
|
||||
"""
|
||||
def __init__(self) -> None:
|
||||
super(ElementWiseMatrixAttention, self).__init__()
|
||||
|
||||
@overrides
|
||||
def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor:
|
||||
result = torch.einsum('iaj,ibj->ijab', [tensor_1, tensor_2])
|
||||
return result
|
|
@ -0,0 +1,176 @@
|
|||
import os
|
||||
import os.path
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
rel2id = json.load(open('../meta/rel2id.json', 'r'))
|
||||
id2rel = {value: key for key, value in rel2id.items()}
|
||||
|
||||
|
||||
def to_official(preds, features):
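    """
    Convert binary prediction vectors into the official DocRED submission format:
    one dict per predicted (title, h_idx, t_idx, r) triple, skipping the NA class (index 0).
    """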
|
||||
h_idx, t_idx, title = [], [], []
|
||||
|
||||
for f in features:
|
||||
hts = f["hts"]
|
||||
h_idx += [ht[0] for ht in hts]
|
||||
t_idx += [ht[1] for ht in hts]
|
||||
title += [f["title"] for ht in hts]
|
||||
|
||||
res = []
|
||||
print('h_idx, preds', len(h_idx), len(preds))
|
||||
# assert len(h_idx) == len(preds)
|
||||
|
||||
|
||||
for i in range(preds.shape[0]):
|
||||
pred = preds[i]
|
||||
pred = np.nonzero(pred)[0].tolist()
|
||||
for p in pred:
|
||||
if p != 0:
|
||||
res.append(
|
||||
{
|
||||
'title': title[i],
|
||||
'h_idx': h_idx[i],
|
||||
't_idx': t_idx[i],
|
||||
'r': id2rel[p],
|
||||
}
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
def gen_train_facts(data_file_name, truth_dir):
|
||||
fact_file_name = data_file_name[data_file_name.find("train_"):]
|
||||
fact_file_name = os.path.join(truth_dir, fact_file_name.replace(".json", ".fact"))
|
||||
|
||||
if os.path.exists(fact_file_name):
|
||||
fact_in_train = set([])
|
||||
triples = json.load(open(fact_file_name))
|
||||
for x in triples:
|
||||
fact_in_train.add(tuple(x))
|
||||
return fact_in_train
|
||||
|
||||
fact_in_train = set([])
|
||||
ori_data = json.load(open(data_file_name))
|
||||
for data in ori_data:
|
||||
vertexSet = data['vertexSet']
|
||||
for label in data['labels']:
|
||||
rel = label['r']
|
||||
for n1 in vertexSet[label['h']]:
|
||||
for n2 in vertexSet[label['t']]:
|
||||
fact_in_train.add((n1['name'], n2['name'], rel))
|
||||
|
||||
json.dump(list(fact_in_train), open(fact_file_name, "w"))
|
||||
|
||||
return fact_in_train
|
||||
|
||||
|
||||
def official_evaluate(tmp, path):
|
||||
'''
|
||||
Adapted from the official evaluation code
|
||||
'''
|
||||
truth_dir = os.path.join(path, 'ref')
|
||||
|
||||
if not os.path.exists(truth_dir):
|
||||
os.makedirs(truth_dir)
|
||||
|
||||
fact_in_train_annotated = gen_train_facts(os.path.join(path, "train_annotated.json"), truth_dir)
|
||||
fact_in_train_distant = gen_train_facts(os.path.join(path, "train_distant.json"), truth_dir)
|
||||
|
||||
truth = json.load(open(os.path.join(path, "dev.json")))
|
||||
|
||||
std = {}
|
||||
tot_evidences = 0
|
||||
titleset = set([])
|
||||
|
||||
title2vectexSet = {}
|
||||
|
||||
for x in truth:
|
||||
title = x['title']
|
||||
titleset.add(title)
|
||||
|
||||
vertexSet = x['vertexSet']
|
||||
title2vectexSet[title] = vertexSet
|
||||
|
||||
for label in x['labels']:
|
||||
r = label['r']
|
||||
h_idx = label['h']
|
||||
t_idx = label['t']
|
||||
std[(title, r, h_idx, t_idx)] = set(label['evidence'])
|
||||
tot_evidences += len(label['evidence'])
|
||||
|
||||
tot_relations = len(std)
|
||||
tmp.sort(key=lambda x: (x['title'], x['h_idx'], x['t_idx'], x['r']))
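    # Deduplicate the sorted predictions: keep only the first occurrence of each
    # (title, h_idx, t_idx, r) tuple as the final submission answer.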
|
||||
submission_answer = [tmp[0]]
|
||||
for i in range(1, len(tmp)):
|
||||
x = tmp[i]
|
||||
y = tmp[i - 1]
|
||||
if (x['title'], x['h_idx'], x['t_idx'], x['r']) != (y['title'], y['h_idx'], y['t_idx'], y['r']):
|
||||
submission_answer.append(tmp[i])
|
||||
|
||||
correct_re = 0
|
||||
correct_evidence = 0
|
||||
pred_evi = 0
|
||||
|
||||
correct_in_train_annotated = 0
|
||||
correct_in_train_distant = 0
|
||||
titleset2 = set([])
|
||||
for x in submission_answer:
|
||||
title = x['title']
|
||||
h_idx = x['h_idx']
|
||||
t_idx = x['t_idx']
|
||||
r = x['r']
|
||||
titleset2.add(title)
|
||||
if title not in title2vectexSet:
|
||||
continue
|
||||
vertexSet = title2vectexSet[title]
|
||||
|
||||
if 'evidence' in x:
|
||||
evi = set(x['evidence'])
|
||||
else:
|
||||
evi = set([])
|
||||
pred_evi += len(evi)
|
||||
|
||||
if (title, r, h_idx, t_idx) in std:
|
||||
correct_re += 1
|
||||
stdevi = std[(title, r, h_idx, t_idx)]
|
||||
correct_evidence += len(stdevi & evi)
|
||||
in_train_annotated = in_train_distant = False
|
||||
for n1 in vertexSet[h_idx]:
|
||||
for n2 in vertexSet[t_idx]:
|
||||
if (n1['name'], n2['name'], r) in fact_in_train_annotated:
|
||||
in_train_annotated = True
|
||||
if (n1['name'], n2['name'], r) in fact_in_train_distant:
|
||||
in_train_distant = True
|
||||
|
||||
if in_train_annotated:
|
||||
correct_in_train_annotated += 1
|
||||
if in_train_distant:
|
||||
correct_in_train_distant += 1
|
||||
|
||||
re_p = 1.0 * correct_re / len(submission_answer)
|
||||
re_r = 1.0 * correct_re / tot_relations
|
||||
if re_p + re_r == 0:
|
||||
re_f1 = 0
|
||||
else:
|
||||
re_f1 = 2.0 * re_p * re_r / (re_p + re_r)
|
||||
|
||||
evi_p = 1.0 * correct_evidence / pred_evi if pred_evi > 0 else 0
|
||||
evi_r = 1.0 * correct_evidence / tot_evidences
|
||||
if evi_p + evi_r == 0:
|
||||
evi_f1 = 0
|
||||
else:
|
||||
evi_f1 = 2.0 * evi_p * evi_r / (evi_p + evi_r)
|
||||
|
||||
re_p_ignore_train_annotated = 1.0 * (correct_re - correct_in_train_annotated) / (len(submission_answer) - correct_in_train_annotated + 1e-5)
|
||||
re_p_ignore_train = 1.0 * (correct_re - correct_in_train_distant) / (len(submission_answer) - correct_in_train_distant + 1e-5)
|
||||
|
||||
if re_p_ignore_train_annotated + re_r == 0:
|
||||
re_f1_ignore_train_annotated = 0
|
||||
else:
|
||||
re_f1_ignore_train_annotated = 2.0 * re_p_ignore_train_annotated * re_r / (re_p_ignore_train_annotated + re_r)
|
||||
|
||||
if re_p_ignore_train + re_r == 0:
|
||||
re_f1_ignore_train = 0
|
||||
else:
|
||||
re_f1_ignore_train = 2.0 * re_p_ignore_train * re_r / (re_p_ignore_train + re_r)
|
||||
|
||||
return re_f1, evi_f1, re_f1_ignore_train_annotated, re_f1_ignore_train, re_p, re_r
|
|
@ -0,0 +1 @@
|
|||
# this is an empty file
|
|
@ -0,0 +1,77 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
def process_long_input(model, input_ids, attention_mask, start_tokens, end_tokens):
|
||||
    # Split the input into two overlapping chunks so that BERT can encode inputs of up to 1024 tokens.
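    # For sequences longer than 512 tokens, the document is encoded as two overlapping
    # 512-token windows (front and back); the overlapping token embeddings and attentions
    # are averaged back into a single full-length sequence further below.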
|
||||
n, c = input_ids.size()
|
||||
start_tokens = torch.tensor(start_tokens).to(input_ids)
|
||||
end_tokens = torch.tensor(end_tokens).to(input_ids)
|
||||
len_start = start_tokens.size(0)
|
||||
len_end = end_tokens.size(0)
|
||||
if c <= 512:
|
||||
output = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=True,
|
||||
)
|
||||
sequence_output = output[0]
|
||||
attention = output[-1][-1]
|
||||
else:
|
||||
new_input_ids, new_attention_mask, num_seg = [], [], []
|
||||
seq_len = attention_mask.sum(1).cpu().numpy().astype(np.int32).tolist()
|
||||
for i, l_i in enumerate(seq_len):
|
||||
if l_i <= 512:
|
||||
new_input_ids.append(input_ids[i, :512])
|
||||
new_attention_mask.append(attention_mask[i, :512])
|
||||
num_seg.append(1)
|
||||
else:
|
||||
input_ids1 = torch.cat([input_ids[i, :512 - len_end], end_tokens], dim=-1)
|
||||
input_ids2 = torch.cat([start_tokens, input_ids[i, (l_i - 512 + len_start): l_i]], dim=-1)
|
||||
attention_mask1 = attention_mask[i, :512]
|
||||
attention_mask2 = attention_mask[i, (l_i - 512): l_i]
|
||||
new_input_ids.extend([input_ids1, input_ids2])
|
||||
new_attention_mask.extend([attention_mask1, attention_mask2])
|
||||
num_seg.append(2)
|
||||
input_ids = torch.stack(new_input_ids, dim=0)
|
||||
attention_mask = torch.stack(new_attention_mask, dim=0)
|
||||
output = model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=True,
|
||||
)
|
||||
sequence_output = output[0]
|
||||
attention = output[-1][-1]
|
||||
i = 0
|
||||
new_output, new_attention = [], []
|
||||
for (n_s, l_i) in zip(num_seg, seq_len):
|
||||
if n_s == 1:
|
||||
output = F.pad(sequence_output[i], (0, 0, 0, c - 512))
|
||||
att = F.pad(attention[i], (0, c - 512, 0, c - 512))
|
||||
new_output.append(output)
|
||||
new_attention.append(att)
|
||||
elif n_s == 2:
|
||||
output1 = sequence_output[i][:512 - len_end]
|
||||
mask1 = attention_mask[i][:512 - len_end]
|
||||
att1 = attention[i][:, :512 - len_end, :512 - len_end]
|
||||
output1 = F.pad(output1, (0, 0, 0, c - 512 + len_end))
|
||||
mask1 = F.pad(mask1, (0, c - 512 + len_end))
|
||||
att1 = F.pad(att1, (0, c - 512 + len_end, 0, c - 512 + len_end))
|
||||
|
||||
output2 = sequence_output[i + 1][len_start:]
|
||||
mask2 = attention_mask[i + 1][len_start:]
|
||||
att2 = attention[i + 1][:, len_start:, len_start:]
|
||||
output2 = F.pad(output2, (0, 0, l_i - 512 + len_start, c - l_i))
|
||||
mask2 = F.pad(mask2, (l_i - 512 + len_start, c - l_i))
|
||||
att2 = F.pad(att2, [l_i - 512 + len_start, c - l_i, l_i - 512 + len_start, c - l_i])
|
||||
mask = mask1 + mask2 + 1e-10
|
||||
output = (output1 + output2) / mask.unsqueeze(-1)
|
||||
att = (att1 + att2)
|
||||
att = att / (att.sum(-1, keepdim=True) + 1e-10)
|
||||
new_output.append(output)
|
||||
new_attention.append(att)
|
||||
i += n_s
|
||||
sequence_output = torch.stack(new_output, dim=0)
|
||||
attention = torch.stack(new_attention, dim=0)
|
||||
return sequence_output, attention
|
|
@ -0,0 +1,87 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def multilabel_categorical_crossentropy(y_true, y_pred):
|
||||
"""多标签分类的交叉熵
|
||||
说明:y_true和y_pred的shape一致,y_true的元素非0即1,
|
||||
1表示对应的类为目标类,0表示对应的类为非目标类。
|
||||
警告:请保证y_pred的值域是全体实数,换言之一般情况下y_pred
|
||||
不用加激活函数,尤其是不能加sigmoid或者softmax!预测
|
||||
阶段则输出y_pred大于0的类。如有疑问,请仔细阅读并理解
|
||||
本文。
|
||||
"""
|
||||
y_pred = (1 - 2 * y_true) * y_pred
|
||||
y_pred_neg = y_pred - y_true * 1e30
|
||||
y_pred_pos = y_pred - (1 - y_true) * 1e30
|
||||
zeros = torch.zeros_like(y_pred[..., :1])
|
||||
y_pred_neg = torch.cat([y_pred_neg, zeros],dim=-1)
|
||||
y_pred_pos = torch.cat((y_pred_pos, zeros),dim=-1)
|
||||
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
|
||||
return neg_loss + pos_loss
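
# Usage note (illustrative shapes): `y_true` is a 0/1 multi-hot tensor and `y_pred` holds
# raw scores of the same shape, e.g. [batch_size, num_classes]; at inference time a class
# is predicted whenever its score is greater than 0.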
|
||||
|
||||
|
||||
class balanced_loss(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
|
||||
loss = multilabel_categorical_crossentropy(labels,logits)
|
||||
loss = loss.mean()
|
||||
return loss
|
||||
|
||||
def get_label(self, logits, num_labels=-1):
|
||||
th_logit = torch.zeros_like(logits[..., :1])
|
||||
output = torch.zeros_like(logits).to(logits)
|
||||
mask = (logits > th_logit)
|
||||
if num_labels > 0:
|
||||
top_v, _ = torch.topk(logits, num_labels, dim=1)
|
||||
top_v = top_v[:, -1]
|
||||
mask = (logits >= top_v.unsqueeze(1)) & mask
|
||||
output[mask] = 1.0
|
||||
output[:, 0] = (output[:,1:].sum(1) == 0.).to(logits)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ATLoss(nn.Module):
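    """
    Adaptive-thresholding loss (ATLOP-style): class index 0 acts as a learnable threshold
    (TH) class; positive classes are ranked above TH, and TH is ranked above negative classes.
    """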
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
# TH label
|
||||
th_label = torch.zeros_like(labels, dtype=torch.float).to(labels)
|
||||
th_label[:, 0] = 1.0
|
||||
labels[:, 0] = 0.0
|
||||
|
||||
p_mask = labels + th_label
|
||||
n_mask = 1 - labels
|
||||
|
||||
# Rank positive classes to TH
|
||||
logit1 = logits - (1 - p_mask) * 1e30
|
||||
loss1 = -(F.log_softmax(logit1, dim=-1) * labels).sum(1)
|
||||
|
||||
# Rank TH to negative classes
|
||||
logit2 = logits - (1 - n_mask) * 1e30
|
||||
loss2 = -(F.log_softmax(logit2, dim=-1) * th_label).sum(1)
|
||||
|
||||
# Sum two parts
|
||||
loss = loss1 + loss2
|
||||
loss = loss.mean()
|
||||
return loss
|
||||
|
||||
def get_label(self, logits, num_labels=-1):
|
||||
th_logit = logits[:, 0].unsqueeze(1)
|
||||
output = torch.zeros_like(logits).to(logits)
|
||||
mask = (logits > th_logit)
|
||||
if num_labels > 0:
|
||||
top_v, _ = torch.topk(logits, num_labels, dim=1)
|
||||
top_v = top_v[:, -1]
|
||||
mask = (logits >= top_v.unsqueeze(1)) & mask
|
||||
output[mask] = 1.0
|
||||
output[:, 0] = (output.sum(1) == 0.).to(logits)
|
||||
return output
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
{
|
||||
"P1376": 79,
|
||||
"P607": 27,
|
||||
"P136": 73,
|
||||
"P137": 63,
|
||||
"P131": 2,
|
||||
"P527": 11,
|
||||
"P1412": 38,
|
||||
"P206": 33,
|
||||
"P205": 77,
|
||||
"P449": 52,
|
||||
"P127": 34,
|
||||
"P123": 49,
|
||||
"P86": 66,
|
||||
"P840": 85,
|
||||
"P355": 72,
|
||||
"P737": 93,
|
||||
"P740": 84,
|
||||
"P190": 94,
|
||||
"P576": 71,
|
||||
"P749": 68,
|
||||
"P112": 65,
|
||||
"P118": 40,
|
||||
"P17": 1,
|
||||
"P19": 14,
|
||||
"P3373": 19,
|
||||
"P6": 42,
|
||||
"P276": 44,
|
||||
"P1001": 24,
|
||||
"P580": 62,
|
||||
"P582": 83,
|
||||
"P585": 64,
|
||||
"P463": 18,
|
||||
"P676": 87,
|
||||
"P674": 46,
|
||||
"P264": 10,
|
||||
"P108": 43,
|
||||
"P102": 17,
|
||||
"P25": 81,
|
||||
"P27": 3,
|
||||
"P26": 26,
|
||||
"P20": 37,
|
||||
"P22": 30,
|
||||
"Na": 0,
|
||||
"P807": 95,
|
||||
"P800": 51,
|
||||
"P279": 78,
|
||||
"P1336": 88,
|
||||
"P577": 5,
|
||||
"P570": 8,
|
||||
"P571": 15,
|
||||
"P178": 36,
|
||||
"P179": 55,
|
||||
"P272": 75,
|
||||
"P170": 35,
|
||||
"P171": 80,
|
||||
"P172": 76,
|
||||
"P175": 6,
|
||||
"P176": 67,
|
||||
"P39": 91,
|
||||
"P30": 21,
|
||||
"P31": 60,
|
||||
"P36": 70,
|
||||
"P37": 58,
|
||||
"P35": 54,
|
||||
"P400": 31,
|
||||
"P403": 61,
|
||||
"P361": 12,
|
||||
"P364": 74,
|
||||
"P569": 7,
|
||||
"P710": 41,
|
||||
"P1344": 32,
|
||||
"P488": 82,
|
||||
"P241": 59,
|
||||
"P162": 57,
|
||||
"P161": 9,
|
||||
"P166": 47,
|
||||
"P40": 20,
|
||||
"P1441": 23,
|
||||
"P156": 45,
|
||||
"P155": 39,
|
||||
"P150": 4,
|
||||
"P551": 90,
|
||||
"P706": 56,
|
||||
"P159": 29,
|
||||
"P495": 13,
|
||||
"P58": 53,
|
||||
"P194": 48,
|
||||
"P54": 16,
|
||||
"P57": 28,
|
||||
"P50": 22,
|
||||
"P1366": 86,
|
||||
"P1365": 92,
|
||||
"P937": 69,
|
||||
"P140": 50,
|
||||
"P69": 25,
|
||||
"P1198": 96,
|
||||
"P1056": 89
|
||||
}
|
|
@ -0,0 +1,207 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from opt_einsum import contract
|
||||
from long_seq import process_long_input
|
||||
from losses import balanced_loss as ATLoss
|
||||
import torch.nn.functional as F
|
||||
from allennlp.modules.matrix_attention import DotProductMatrixAttention, CosineMatrixAttention, BilinearMatrixAttention
|
||||
from element_wise import ElementWiseMatrixAttention
|
||||
from attn_unet import AttentionUNet
|
||||
|
||||
class DocREModel(nn.Module):
|
||||
def __init__(self, config, args, model, emb_size=768, block_size=64, num_labels=-1):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.bert_model = model
|
||||
self.hidden_size = config.hidden_size
|
||||
self.loss_fnt = ATLoss()
|
||||
|
||||
self.head_extractor = nn.Linear(1 * config.hidden_size + args.unet_out_dim, emb_size)
|
||||
self.tail_extractor = nn.Linear(1 * config.hidden_size + args.unet_out_dim, emb_size)
|
||||
# self.head_extractor = nn.Linear(1 * config.hidden_size , emb_size)
|
||||
# self.tail_extractor = nn.Linear(1 * config.hidden_size , emb_size)
|
||||
self.bilinear = nn.Linear(emb_size * block_size, config.num_labels)
|
||||
|
||||
self.emb_size = emb_size
|
||||
self.block_size = block_size
|
||||
self.num_labels = num_labels
|
||||
|
||||
self.bertdrop = nn.Dropout(0.6)
|
||||
self.unet_in_dim = args.unet_in_dim
|
||||
        self.unet_out_dim = args.unet_out_dim
|
||||
self.liner = nn.Linear(config.hidden_size, args.unet_in_dim)
|
||||
self.min_height = args.max_height
|
||||
self.channel_type = args.channel_type
|
||||
self.segmentation_net = AttentionUNet(input_channels=args.unet_in_dim,
|
||||
class_number=args.unet_out_dim,
|
||||
down_channel=args.down_dim)
|
||||
|
||||
|
||||
def encode(self, input_ids, attention_mask,entity_pos):
|
||||
config = self.config
|
||||
if config.transformer_type == "bert":
|
||||
start_tokens = [config.cls_token_id]
|
||||
end_tokens = [config.sep_token_id]
|
||||
elif config.transformer_type == "roberta":
|
||||
start_tokens = [config.cls_token_id]
|
||||
end_tokens = [config.sep_token_id, config.sep_token_id]
|
||||
sequence_output, attention = process_long_input(self.bert_model, input_ids, attention_mask, start_tokens, end_tokens)
|
||||
return sequence_output, attention
|
||||
|
||||
def get_hrt(self, sequence_output, attention, entity_pos, hts):
|
||||
offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
|
||||
bs, h, _, c = attention.size()
|
||||
        # ne = max([len(x) for x in entity_pos])  # maximum number of entities in this batch
|
||||
|
||||
hss, tss, rss = [], [], []
|
||||
entity_es = []
|
||||
entity_as = []
|
||||
for i in range(len(entity_pos)):
|
||||
entity_embs, entity_atts = [], []
|
||||
for entity_num, e in enumerate(entity_pos[i]):
|
||||
if len(e) > 1:
|
||||
e_emb, e_att = [], []
|
||||
for start, end in e:
|
||||
if start + offset < c:
|
||||
# In case the entity mention is truncated due to limited max seq length.
|
||||
e_emb.append(sequence_output[i, start + offset])
|
||||
e_att.append(attention[i, :, start + offset])
|
||||
if len(e_emb) > 0:
|
||||
e_emb = torch.logsumexp(torch.stack(e_emb, dim=0), dim=0)
|
||||
e_att = torch.stack(e_att, dim=0).mean(0)
|
||||
else:
|
||||
e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
|
||||
e_att = torch.zeros(h, c).to(attention)
|
||||
else:
|
||||
start, end = e[0]
|
||||
if start + offset < c:
|
||||
e_emb = sequence_output[i, start + offset]
|
||||
e_att = attention[i, :, start + offset]
|
||||
else:
|
||||
e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
|
||||
e_att = torch.zeros(h, c).to(attention)
|
||||
entity_embs.append(e_emb)
|
||||
entity_atts.append(e_att)
|
||||
for _ in range(self.min_height-entity_num-1):
|
||||
entity_atts.append(e_att)
|
||||
|
||||
entity_embs = torch.stack(entity_embs, dim=0) # [n_e, d]
|
||||
entity_atts = torch.stack(entity_atts, dim=0) # [n_e, h, seq_len]
|
||||
|
||||
|
||||
entity_es.append(entity_embs)
|
||||
entity_as.append(entity_atts)
|
||||
ht_i = torch.LongTensor(hts[i]).to(sequence_output.device)
|
||||
hs = torch.index_select(entity_embs, 0, ht_i[:, 0])
|
||||
ts = torch.index_select(entity_embs, 0, ht_i[:, 1])
|
||||
|
||||
hss.append(hs)
|
||||
tss.append(ts)
|
||||
hss = torch.cat(hss, dim=0)
|
||||
tss = torch.cat(tss, dim=0)
|
||||
return hss, tss, entity_es, entity_as
|
||||
|
||||
def get_mask(self, ents, bs, ne, run_device):
|
||||
ent_mask = torch.zeros(bs, ne, device=run_device)
|
||||
rel_mask = torch.zeros(bs, ne, ne, device=run_device)
|
||||
for _b in range(bs):
|
||||
ent_mask[_b, :len(ents[_b])] = 1
|
||||
rel_mask[_b, :len(ents[_b]), :len(ents[_b])] = 1
|
||||
return ent_mask, rel_mask
|
||||
|
||||
|
||||
def get_ht(self, rel_enco, hts):
|
||||
htss = []
|
||||
for i in range(len(hts)):
|
||||
ht_index = hts[i]
|
||||
for (h_index, t_index) in ht_index:
|
||||
htss.append(rel_enco[i,h_index,t_index])
|
||||
htss = torch.stack(htss,dim=0)
|
||||
return htss
|
||||
|
||||
def get_channel_map(self, sequence_output, entity_as):
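        # Build the (bs, ne, ne, d) entity-pair feature map: for every (head, tail) pair,
        # a pair-specific context embedding is computed as a weighted sum of token embeddings,
        # with weights given by the normalized product of the two entities' attention scores.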
|
||||
# sequence_output = sequence_output.to('cpu')
|
||||
# attention = attention.to('cpu')
|
||||
bs,_,d = sequence_output.size()
|
||||
        # ne = max([len(x) for x in entity_as])  # maximum number of entities in this batch
|
||||
ne = self.min_height
|
||||
|
||||
index_pair = []
|
||||
for i in range(ne):
|
||||
tmp = torch.cat((torch.ones((ne, 1), dtype=int) * i, torch.arange(0, ne).unsqueeze(1)), dim=-1)
|
||||
index_pair.append(tmp)
|
||||
index_pair = torch.stack(index_pair, dim=0).reshape(-1, 2).to(sequence_output.device)
|
||||
map_rss = []
|
||||
for b in range(bs):
|
||||
entity_atts = entity_as[b]
|
||||
h_att = torch.index_select(entity_atts, 0, index_pair[:, 0])
|
||||
t_att = torch.index_select(entity_atts, 0, index_pair[:, 1])
|
||||
ht_att = (h_att * t_att).mean(1)
|
||||
ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)
|
||||
rs = contract("ld,rl->rd", sequence_output[b], ht_att)
|
||||
map_rss.append(rs)
|
||||
map_rss = torch.cat(map_rss, dim=0).reshape(bs, ne, ne, d)
|
||||
return map_rss
|
||||
|
||||
def forward(self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
labels=None,
|
||||
entity_pos=None,
|
||||
hts=None,
|
||||
instance_mask=None,
|
||||
):
|
||||
|
||||
sequence_output, attention = self.encode(input_ids, attention_mask,entity_pos)
|
||||
|
||||
bs, sequen_len, d = sequence_output.shape
|
||||
run_device = sequence_output.device.index
|
||||
        ne = max([len(x) for x in entity_pos])  # maximum number of entities in this batch
|
||||
ent_mask, rel_mask = self.get_mask(entity_pos, bs, ne, run_device)
|
||||
|
||||
# get hs, ts and entity_embs >> entity_rs
|
||||
hs, ts, entity_embs, entity_as = self.get_hrt(sequence_output, attention, entity_pos, hts)
|
||||
|
||||
|
||||
        # Two different ways to build the channel map (entity-pair feature matrix)
|
||||
if self.channel_type == 'context-based':
|
||||
feature_map = self.get_channel_map(sequence_output, entity_as)
|
||||
##print('feature_map:', feature_map.shape)
|
||||
attn_input = self.liner(feature_map).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
elif self.channel_type == 'similarity-based':
|
||||
ent_encode = sequence_output.new_zeros(bs, self.min_height, d)
|
||||
for _b in range(bs):
|
||||
entity_emb = entity_embs[_b]
|
||||
entity_num = entity_emb.size(0)
|
||||
ent_encode[_b, :entity_num, :] = entity_emb
|
||||
# similar0 = ElementWiseMatrixAttention()(ent_encode, ent_encode).unsqueeze(-1)
|
||||
similar1 = DotProductMatrixAttention()(ent_encode, ent_encode).unsqueeze(-1)
|
||||
similar2 = CosineMatrixAttention()(ent_encode, ent_encode).unsqueeze(-1)
|
||||
            similar3 = BilinearMatrixAttention(self.emb_size, self.emb_size).to(ent_encode.device)(ent_encode, ent_encode).unsqueeze(-1)
|
||||
attn_input = torch.cat([similar1,similar2,similar3],dim=-1).permute(0, 3, 1, 2).contiguous()
|
||||
else:
|
||||
            raise Exception("channel_type must be specified correctly")
|
||||
|
||||
|
||||
attn_map = self.segmentation_net(attn_input)
|
||||
h_t = self.get_ht (attn_map, hts)
|
||||
|
||||
hs = torch.tanh(self.head_extractor(torch.cat([hs, h_t], dim=1)))
|
||||
ts = torch.tanh(self.tail_extractor(torch.cat([ts, h_t], dim=1)))
|
||||
|
||||
|
||||
b1 = hs.view(-1, self.emb_size // self.block_size, self.block_size)
|
||||
b2 = ts.view(-1, self.emb_size // self.block_size, self.block_size)
|
||||
bl = (b1.unsqueeze(3) * b2.unsqueeze(2)).view(-1, self.emb_size * self.block_size)
|
||||
logits = self.bilinear(bl)
|
||||
|
||||
|
||||
output = (self.loss_fnt.get_label(logits, num_labels=self.num_labels))
|
||||
if labels is not None:
|
||||
labels = [torch.tensor(label) for label in labels]
|
||||
labels = torch.cat(labels, dim=0).to(logits)
|
||||
loss = self.loss_fnt(logits.float(), labels.float())
|
||||
output = (loss.to(sequence_output), output)
|
||||
return output
|
||||
|
|
@ -0,0 +1,217 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from opt_einsum import contract
|
||||
from long_seq import process_long_input
|
||||
from losses import ATLoss
|
||||
import torch.nn.functional as F
|
||||
from allennlp.modules.matrix_attention import DotProductMatrixAttention, CosineMatrixAttention, BilinearMatrixAttention
|
||||
from element_wise import ElementWiseMatrixAttention
|
||||
from attn_unet import AttentionUNet
|
||||
|
||||
class DocREModel(nn.Module):
|
||||
def __init__(self, config, args, model, emb_size=768, block_size=64, num_labels=-1):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.bert_model = model
|
||||
self.hidden_size = config.hidden_size
|
||||
self.loss_fnt = ATLoss()
|
||||
|
||||
self.head_extractor = nn.Linear(1 * config.hidden_size + args.unet_out_dim, emb_size)
|
||||
self.tail_extractor = nn.Linear(1 * config.hidden_size + args.unet_out_dim, emb_size)
|
||||
# self.head_extractor = nn.Linear(1 * config.hidden_size , emb_size)
|
||||
# self.tail_extractor = nn.Linear(1 * config.hidden_size , emb_size)
|
||||
self.bilinear = nn.Linear(emb_size * block_size, config.num_labels)
|
||||
|
||||
self.emb_size = emb_size
|
||||
self.block_size = block_size
|
||||
self.num_labels = num_labels
|
||||
|
||||
self.bertdrop = nn.Dropout(0.6)
|
||||
self.unet_in_dim = args.unet_in_dim
|
||||
        self.unet_out_dim = args.unet_out_dim
|
||||
self.liner = nn.Linear(config.hidden_size, args.unet_in_dim)
|
||||
self.min_height = args.max_height
|
||||
self.channel_type = args.channel_type
|
||||
self.segmentation_net = AttentionUNet(input_channels=args.unet_in_dim,
|
||||
class_number=args.unet_out_dim,
|
||||
down_channel=args.down_dim)
|
||||
|
||||
|
||||
def encode(self, input_ids, attention_mask,entity_pos):
|
||||
config = self.config
|
||||
if config.transformer_type == "bert":
|
||||
start_tokens = [config.cls_token_id]
|
||||
end_tokens = [config.sep_token_id]
|
||||
elif config.transformer_type == "roberta":
|
||||
start_tokens = [config.cls_token_id]
|
||||
end_tokens = [config.sep_token_id, config.sep_token_id]
|
||||
sequence_output, attention = process_long_input(self.bert_model, input_ids, attention_mask, start_tokens, end_tokens)
|
||||
return sequence_output, attention
|
||||
|
||||
def get_hrt(self, sequence_output, attention, entity_pos, hts):
|
||||
offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
|
||||
bs, h, _, c = attention.size()
|
||||
        # ne = max([len(x) for x in entity_pos])  # maximum number of entities in this batch
|
||||
|
||||
hss, tss, rss = [], [], []
|
||||
entity_es = []
|
||||
entity_as = []
|
||||
for i in range(len(entity_pos)):
|
||||
entity_embs, entity_atts = [], []
|
||||
for entity_num, e in enumerate(entity_pos[i]):
|
||||
if len(e) > 1:
|
||||
e_emb, e_att = [], []
|
||||
for start, end in e:
|
||||
if start + offset < c:
|
||||
# In case the entity mention is truncated due to limited max seq length.
|
||||
e_emb.append(sequence_output[i, start + offset])
|
||||
e_att.append(attention[i, :, start + offset])
|
||||
if len(e_emb) > 0:
|
||||
e_emb = torch.logsumexp(torch.stack(e_emb, dim=0), dim=0)
|
||||
e_att = torch.stack(e_att, dim=0).mean(0)
|
||||
else:
|
||||
e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
|
||||
e_att = torch.zeros(h, c).to(attention)
|
||||
else:
|
||||
start, end = e[0]
|
||||
if start + offset < c:
|
||||
e_emb = sequence_output[i, start + offset]
|
||||
e_att = attention[i, :, start + offset]
|
||||
else:
|
||||
e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)
|
||||
e_att = torch.zeros(h, c).to(attention)
|
||||
entity_embs.append(e_emb)
|
||||
entity_atts.append(e_att)
|
||||
for _ in range(self.min_height-entity_num-1):
|
||||
entity_atts.append(e_att)
|
||||
|
||||
entity_embs = torch.stack(entity_embs, dim=0) # [n_e, d]
|
||||
entity_atts = torch.stack(entity_atts, dim=0) # [n_e, h, seq_len]
|
||||
|
||||
|
||||
entity_es.append(entity_embs)
|
||||
entity_as.append(entity_atts)
|
||||
ht_i = torch.LongTensor(hts[i]).to(sequence_output.device)
|
||||
hs = torch.index_select(entity_embs, 0, ht_i[:, 0])
|
||||
ts = torch.index_select(entity_embs, 0, ht_i[:, 1])
|
||||
|
||||
# h_att = torch.index_select(entity_atts, 0, ht_i[:, 0])
|
||||
# t_att = torch.index_select(entity_atts, 0, ht_i[:, 1])
|
||||
# ht_att = (h_att * t_att).mean(1)
|
||||
# ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)
|
||||
# rs = contract("ld,rl->rd", sequence_output[i], ht_att)
|
||||
hss.append(hs)
|
||||
tss.append(ts)
|
||||
# rss.append(rs)
|
||||
hss = torch.cat(hss, dim=0)
|
||||
tss = torch.cat(tss, dim=0)
|
||||
# rss = torch.cat(rss, dim=0)
|
||||
return hss, tss, entity_es, entity_as
|
||||
|
||||
def get_mask(self, ents, bs, ne, run_device):
|
||||
ent_mask = torch.zeros(bs, ne, device=run_device)
|
||||
rel_mask = torch.zeros(bs, ne, ne, device=run_device)
|
||||
for _b in range(bs):
|
||||
ent_mask[_b, :len(ents[_b])] = 1
|
||||
rel_mask[_b, :len(ents[_b]), :len(ents[_b])] = 1
|
||||
return ent_mask, rel_mask
|
||||
|
||||
|
||||
def get_ht(self, rel_enco, hts):
|
||||
htss = []
|
||||
for i in range(len(hts)):
|
||||
ht_index = hts[i]
|
||||
for (h_index, t_index) in ht_index:
|
||||
htss.append(rel_enco[i,h_index,t_index])
|
||||
htss = torch.stack(htss,dim=0)
|
||||
return htss
|
||||
|
||||
def get_channel_map(self, sequence_output, entity_as):
|
||||
# sequence_output = sequence_output.to('cpu')
|
||||
# attention = attention.to('cpu')
|
||||
bs,_,d = sequence_output.size()
|
||||
        # ne = max([len(x) for x in entity_as])  # maximum number of entities in this batch
|
||||
ne = self.min_height
|
||||
|
||||
index_pair = []
|
||||
for i in range(ne):
|
||||
tmp = torch.cat((torch.ones((ne, 1), dtype=int) * i, torch.arange(0, ne).unsqueeze(1)), dim=-1)
|
||||
index_pair.append(tmp)
|
||||
index_pair = torch.stack(index_pair, dim=0).reshape(-1, 2).to(sequence_output.device)
|
||||
map_rss = []
|
||||
for b in range(bs):
|
||||
entity_atts = entity_as[b]
|
||||
h_att = torch.index_select(entity_atts, 0, index_pair[:, 0])
|
||||
t_att = torch.index_select(entity_atts, 0, index_pair[:, 1])
|
||||
ht_att = (h_att * t_att).mean(1)
|
||||
ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)
|
||||
rs = contract("ld,rl->rd", sequence_output[b], ht_att)
|
||||
map_rss.append(rs)
|
||||
map_rss = torch.cat(map_rss, dim=0).reshape(bs, ne, ne, d)
|
||||
return map_rss
|
||||
|
||||
def forward(self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
labels=None,
|
||||
entity_pos=None,
|
||||
hts=None,
|
||||
instance_mask=None,
|
||||
):
|
||||
|
||||
sequence_output, attention = self.encode(input_ids, attention_mask,entity_pos)
|
||||
# sequence_output = self.bertdrop(sequence_output)
|
||||
|
||||
bs, sequen_len, d = sequence_output.shape
|
||||
run_device = sequence_output.device.index
|
||||
        ne = max([len(x) for x in entity_pos])  # maximum number of entities in this batch
|
||||
ent_mask, rel_mask = self.get_mask(entity_pos, bs, ne, run_device)
|
||||
|
||||
# get hs, ts and entity_embs >> entity_rs
|
||||
hs, ts, entity_embs, entity_as = self.get_hrt(sequence_output, attention, entity_pos, hts)
|
||||
|
||||
|
||||
        # Two different ways to build the channel map (entity-pair feature matrix)
|
||||
if self.channel_type == 'att_map':
|
||||
feature_map = self.get_channel_map(sequence_output, entity_as)
|
||||
##print('feature_map:', feature_map.shape)
|
||||
attn_input = self.liner(feature_map).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
elif self.channel_type == 'dot_map':
|
||||
ent_encode = sequence_output.new_zeros(bs, self.min_height, d)
|
||||
for _b in range(bs):
|
||||
entity_emb = entity_embs[_b]
|
||||
entity_num = entity_emb.size(0)
|
||||
ent_encode[_b, :entity_num, :] = entity_emb
|
||||
# similar0 = ElementWiseMatrixAttention()(ent_encode, ent_encode).unsqueeze(-1)
|
||||
similar1 = DotProductMatrixAttention()(ent_encode, ent_encode).unsqueeze(-1)
|
||||
similar2 = CosineMatrixAttention()(ent_encode, ent_encode).unsqueeze(-1)
|
||||
            similar3 = BilinearMatrixAttention(self.emb_size, self.emb_size).to(ent_encode.device)(ent_encode, ent_encode).unsqueeze(-1)
|
||||
attn_input = torch.cat([similar1,similar2,similar3],dim=-1).permute(0, 3, 1, 2).contiguous()
|
||||
else:
|
||||
            raise Exception("channel_type must be specified correctly")
|
||||
|
||||
|
||||
attn_map = self.segmentation_net(attn_input)
|
||||
h_t = self.get_ht (attn_map, hts)
|
||||
|
||||
hs = torch.tanh(self.head_extractor(torch.cat([hs, h_t], dim=1)))
|
||||
ts = torch.tanh(self.tail_extractor(torch.cat([ts, h_t], dim=1)))
|
||||
|
||||
|
||||
b1 = hs.view(-1, self.emb_size // self.block_size, self.block_size)
|
||||
b2 = ts.view(-1, self.emb_size // self.block_size, self.block_size)
|
||||
bl = (b1.unsqueeze(3) * b2.unsqueeze(2)).view(-1, self.emb_size * self.block_size)
|
||||
# bl = torch.cat((bl,ht),dim=-1)
|
||||
logits = self.bilinear(bl)
|
||||
|
||||
output = logits
|
||||
# output = (self.loss_fnt.get_label(logits, num_labels=self.num_labels))
|
||||
if labels is not None:
|
||||
labels = [torch.tensor(label) for label in labels]
|
||||
labels = torch.cat(labels, dim=0).to(logits)
|
||||
loss = self.loss_fnt(logits.float(), labels.float())
|
||||
output = (loss.to(sequence_output), output)
|
||||
return output
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
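
# Plot dev F1 against the number of entities per document, comparing DocuNet with and
# without the U-shaped segmentation network (ablation-style figure).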
|
||||
|
||||
|
||||
name_list = ['1-5', '6-10', '11-15', '16-20', '21-25', '26-30', '31-35', '35-42']
|
||||
|
||||
data1 = [70.3, 65.0, 64.0, 63.8, 62.7, 59.2, 57.7, 55.3 ]
|
||||
data2 = [70.0, 64.2, 63.1, 62.0, 60.5, 56.1, 54.3, 51.4 ]
|
||||
|
||||
x =list(range(len(data1)))
|
||||
|
||||
|
||||
plt.figure(figsize=(6, 3))
|
||||
plt.plot(x,data1,marker='>',label='w/ U-shaped Network')
|
||||
plt.plot(x,data2,'--',marker='o',label='w/o U-shaped Network')
|
||||
|
||||
|
||||
plt.ylim([50,75])
|
||||
plt.xlabel('# of Entities per Document')
|
||||
plt.ylabel('dev F1 (in %)')
|
||||
plt.xticks(x,name_list)
|
||||
plt.legend(loc='upper right')
|
||||
plt.show()
|
|
@ -0,0 +1,437 @@
|
|||
|
||||
from torch.utils import data
|
||||
from tqdm import tqdm
|
||||
|
||||
import ujson as json
|
||||
import os
|
||||
import pickle
|
||||
import random
|
||||
import numpy as np
|
||||
docred_rel2id = json.load(open('../meta/rel2id.json', 'r'))
|
||||
cdr_rel2id = {'1:NR:2': 0, '1:CID:2': 1}
|
||||
gda_rel2id = {'1:NR:2': 0, '1:GDA:2': 1}
|
||||
|
||||
|
||||
def chunks(l, n):
|
||||
res = []
|
||||
for i in range(0, len(l), n):
|
||||
assert len(l[i:i + n]) == n
|
||||
res += [l[i:i + n]]
|
||||
return res
|
||||
|
||||
class ReadDataset:
|
||||
def __init__(self, dataset: str, tokenizer, max_seq_Length: int = 1024,
|
||||
transformers: str = 'bert') -> None:
|
||||
self.transformers = transformers
|
||||
self.dataset = dataset
|
||||
self.tokenizer = tokenizer
|
||||
self.max_seq_Length = max_seq_Length
|
||||
|
||||
def read(self, file_in: str):
|
||||
save_file = file_in.split('.json')[0] + self.transformers + '_' \
|
||||
+ self.dataset + '.pkl'
|
||||
if self.dataset == 'docred':
|
||||
read_docred(self.transformers, file_in, save_file, self.tokenizer, self.max_seq_Length)
|
||||
elif self.dataset == 'cdr':
|
||||
read_cdr(file_in, save_file, self.tokenizer, self.max_seq_Length)
|
||||
elif self.dataset == 'gda':
|
||||
read_gda(file_in, save_file, self.tokenizer, self.max_seq_Length)
|
||||
else:
|
||||
raise RuntimeError("No read func for this dataset.")
|
||||
|
||||
|
||||
|
||||
def read_docred(transfermers, file_in, save_file, tokenizer, max_seq_length=1024):
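    # Convert a DocRED-format json file into model features: wordpiece input_ids with
    # entity-mention marker tokens, per-entity mention spans, multi-hot relation labels,
    # and all head/tail entity-pair indices; the result is cached to `save_file` as a pickle.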
|
||||
if os.path.exists(save_file):
|
||||
with open(file=save_file, mode='rb') as fr:
|
||||
features = pickle.load(fr)
|
||||
fr.close()
|
||||
print('load preprocessed data from {}.'.format(save_file))
|
||||
return features
|
||||
else:
|
||||
max_len = 0
|
||||
up512_num = 0
|
||||
i_line = 0
|
||||
pos_samples = 0
|
||||
neg_samples = 0
|
||||
features = []
|
||||
if file_in == "":
|
||||
return None
|
||||
with open(file_in, "r") as fh:
|
||||
data = json.load(fh)
|
||||
if transfermers == 'bert':
|
||||
# entity_type = ["ORG", "-", "LOC", "-", "TIME", "-", "PER", "-", "MISC", "-", "NUM"]
|
||||
entity_type = ["-", "ORG", "-", "LOC", "-", "TIME", "-", "PER", "-", "MISC", "-", "NUM"]
|
||||
|
||||
|
||||
for sample in tqdm(data, desc="Example"):
|
||||
sents = []
|
||||
sent_map = []
|
||||
|
||||
entities = sample['vertexSet']
|
||||
entity_start, entity_end = [], []
|
||||
mention_types = []
|
||||
for entity in entities:
|
||||
for mention in entity:
|
||||
sent_id = mention["sent_id"]
|
||||
pos = mention["pos"]
|
||||
entity_start.append((sent_id, pos[0]))
|
||||
entity_end.append((sent_id, pos[1] - 1))
|
||||
mention_types.append(mention['type'])
|
||||
|
||||
for i_s, sent in enumerate(sample['sents']):
|
||||
new_map = {}
|
||||
for i_t, token in enumerate(sent):
|
||||
tokens_wordpiece = tokenizer.tokenize(token)
|
||||
if (i_s, i_t) in entity_start:
|
||||
t = entity_start.index((i_s, i_t))
|
||||
if transfermers == 'bert':
|
||||
mention_type = mention_types[t]
|
||||
special_token_i = entity_type.index(mention_type)
|
||||
special_token = ['[unused' + str(special_token_i) + ']']
|
||||
else:
|
||||
special_token = ['*']
|
||||
tokens_wordpiece = special_token + tokens_wordpiece
|
||||
# tokens_wordpiece = ["[unused0]"]+ tokens_wordpiece
|
||||
|
||||
if (i_s, i_t) in entity_end:
|
||||
t = entity_end.index((i_s, i_t))
|
||||
if transfermers == 'bert':
|
||||
mention_type = mention_types[t]
|
||||
special_token_i = entity_type.index(mention_type) + 50
|
||||
special_token = ['[unused' + str(special_token_i) + ']']
|
||||
else:
|
||||
special_token = ['*']
|
||||
tokens_wordpiece = tokens_wordpiece + special_token
|
||||
|
||||
# tokens_wordpiece = tokens_wordpiece + ["[unused1]"]
|
||||
# print(tokens_wordpiece,tokenizer.convert_tokens_to_ids(tokens_wordpiece))
|
||||
|
||||
new_map[i_t] = len(sents)
|
||||
sents.extend(tokens_wordpiece)
|
||||
new_map[i_t + 1] = len(sents)
|
||||
sent_map.append(new_map)
|
||||
|
||||
if len(sents)>max_len:
|
||||
max_len=len(sents)
|
||||
if len(sents)>512:
|
||||
up512_num += 1
|
||||
|
||||
train_triple = {}
|
||||
if "labels" in sample:
|
||||
for label in sample['labels']:
|
||||
evidence = label['evidence']
|
||||
r = int(docred_rel2id[label['r']])
|
||||
if (label['h'], label['t']) not in train_triple:
|
||||
train_triple[(label['h'], label['t'])] = [
|
||||
{'relation': r, 'evidence': evidence}]
|
||||
else:
|
||||
train_triple[(label['h'], label['t'])].append(
|
||||
{'relation': r, 'evidence': evidence})
|
||||
|
||||
entity_pos = []
|
||||
for e in entities:
|
||||
entity_pos.append([])
|
||||
mention_num = len(e)
|
||||
for m in e:
|
||||
start = sent_map[m["sent_id"]][m["pos"][0]]
|
||||
end = sent_map[m["sent_id"]][m["pos"][1]]
|
||||
entity_pos[-1].append((start, end,))
|
||||
|
||||
|
||||
relations, hts = [], []
|
||||
# Get positive samples from dataset
|
||||
for h, t in train_triple.keys():
|
||||
relation = [0] * len(docred_rel2id)
|
||||
for mention in train_triple[h, t]:
|
||||
relation[mention["relation"]] = 1
|
||||
evidence = mention["evidence"]
|
||||
relations.append(relation)
|
||||
hts.append([h, t])
|
||||
pos_samples += 1
|
||||
|
||||
# Get negative samples from dataset
|
||||
for h in range(len(entities)):
|
||||
for t in range(len(entities)):
|
||||
if h != t and [h, t] not in hts:
|
||||
relation = [1] + [0] * (len(docred_rel2id) - 1)
|
||||
relations.append(relation)
|
||||
hts.append([h, t])
|
||||
neg_samples += 1
|
||||
|
||||
assert len(relations) == len(entities) * (len(entities) - 1)
|
||||
|
||||
if len(hts)==0:
|
||||
print(len(sent))
|
||||
sents = sents[:max_seq_length - 2]
|
||||
input_ids = tokenizer.convert_tokens_to_ids(sents)
|
||||
input_ids = tokenizer.build_inputs_with_special_tokens(input_ids)
|
||||
|
||||
i_line += 1
|
||||
feature = {'input_ids': input_ids,
|
||||
'entity_pos': entity_pos,
|
||||
'labels': relations,
|
||||
'hts': hts,
|
||||
'title': sample['title'],
|
||||
}
|
||||
features.append(feature)
|
||||
|
||||
|
||||
|
||||
print("# of documents {}.".format(i_line))
|
||||
print("# of positive examples {}.".format(pos_samples))
|
||||
print("# of negative examples {}.".format(neg_samples))
|
||||
print("# {} examples len>512 and max len is {}.".format(up512_num, max_len))
|
||||
|
||||
|
||||
with open(file=save_file, mode='wb') as fw:
|
||||
pickle.dump(features, fw)
|
||||
print('finish reading {} and save preprocessed data to {}.'.format(file_in, save_file))
|
||||
|
||||
return features
|
||||
|
||||
|
||||
|
||||
def read_cdr(file_in, save_file, tokenizer, max_seq_length=1024):
|
||||
if os.path.exists(save_file):
|
||||
with open(file=save_file, mode='rb') as fr:
|
||||
features = pickle.load(fr)
|
||||
fr.close()
|
||||
print('load preprocessed data from {}.'.format(save_file))
|
||||
return features
|
||||
else:
|
||||
pmids = set()
|
||||
features = []
|
||||
maxlen = 0
|
||||
with open(file_in, 'r') as infile:
|
||||
lines = infile.readlines()
|
||||
for i_l, line in enumerate(tqdm(lines)):
|
||||
line = line.rstrip().split('\t')
|
||||
pmid = line[0]
|
||||
|
||||
if pmid not in pmids:
|
||||
pmids.add(pmid)
|
||||
text = line[1]
|
||||
prs = chunks(line[2:], 17)
|
||||
|
||||
ent2idx = {}
|
||||
train_triples = {}
|
||||
|
||||
entity_pos = set()
|
||||
for p in prs:
|
||||
es = list(map(int, p[8].split(':')))
|
||||
ed = list(map(int, p[9].split(':')))
|
||||
tpy = p[7]
|
||||
for start, end in zip(es, ed):
|
||||
entity_pos.add((start, end, tpy))
|
||||
|
||||
es = list(map(int, p[14].split(':')))
|
||||
ed = list(map(int, p[15].split(':')))
|
||||
tpy = p[13]
|
||||
for start, end in zip(es, ed):
|
||||
entity_pos.add((start, end, tpy))
|
||||
|
||||
sents = [t.split(' ') for t in text.split('|')]
|
||||
new_sents = []
|
||||
sent_map = {}
|
||||
i_t = 0
|
||||
for sent in sents:
|
||||
for token in sent:
|
||||
tokens_wordpiece = tokenizer.tokenize(token)
|
||||
for start, end, tpy in list(entity_pos):
|
||||
if i_t == start:
|
||||
tokens_wordpiece = ["*"] + tokens_wordpiece
|
||||
if i_t + 1 == end:
|
||||
tokens_wordpiece = tokens_wordpiece + ["*"]
|
||||
sent_map[i_t] = len(new_sents)
|
||||
new_sents.extend(tokens_wordpiece)
|
||||
i_t += 1
|
||||
sent_map[i_t] = len(new_sents)
|
||||
sents = new_sents
|
||||
|
||||
entity_pos = []
|
||||
|
||||
for p in prs:
|
||||
if p[0] == "not_include":
|
||||
continue
|
||||
if p[1] == "L2R":
|
||||
h_id, t_id = p[5], p[11]
|
||||
h_start, t_start = p[8], p[14]
|
||||
h_end, t_end = p[9], p[15]
|
||||
else:
|
||||
t_id, h_id = p[5], p[11]
|
||||
t_start, h_start = p[8], p[14]
|
||||
t_end, h_end = p[9], p[15]
|
||||
h_start = map(int, h_start.split(':'))
|
||||
h_end = map(int, h_end.split(':'))
|
||||
t_start = map(int, t_start.split(':'))
|
||||
t_end = map(int, t_end.split(':'))
|
||||
h_start = [sent_map[idx] for idx in h_start]
|
||||
h_end = [sent_map[idx] for idx in h_end]
|
||||
t_start = [sent_map[idx] for idx in t_start]
|
||||
t_end = [sent_map[idx] for idx in t_end]
|
||||
if h_id not in ent2idx:
|
||||
ent2idx[h_id] = len(ent2idx)
|
||||
entity_pos.append(list(zip(h_start, h_end)))
|
||||
if t_id not in ent2idx:
|
||||
ent2idx[t_id] = len(ent2idx)
|
||||
entity_pos.append(list(zip(t_start, t_end)))
|
||||
h_id, t_id = ent2idx[h_id], ent2idx[t_id]
|
||||
|
||||
r = cdr_rel2id[p[0]]
|
||||
if (h_id, t_id) not in train_triples:
|
||||
train_triples[(h_id, t_id)] = [{'relation': r}]
|
||||
else:
|
||||
train_triples[(h_id, t_id)].append({'relation': r})
|
||||
|
||||
relations, hts = [], []
|
||||
for h, t in train_triples.keys():
|
||||
relation = [0] * len(cdr_rel2id)
|
||||
for mention in train_triples[h, t]:
|
||||
relation[mention["relation"]] = 1
|
||||
relations.append(relation)
|
||||
hts.append([h, t])
|
||||
|
||||
maxlen = max(maxlen, len(sents))
|
||||
sents = sents[:max_seq_length - 2]
|
||||
input_ids = tokenizer.convert_tokens_to_ids(sents)
|
||||
input_ids = tokenizer.build_inputs_with_special_tokens(input_ids)
|
||||
|
||||
if len(hts) > 0:
|
||||
feature = {'input_ids': input_ids,
|
||||
'entity_pos': entity_pos,
|
||||
'labels': relations,
|
||||
'hts': hts,
|
||||
'title': pmid,
|
||||
}
|
||||
features.append(feature)
|
||||
print("Number of documents: {}.".format(len(features)))
|
||||
print("Max document length: {}.".format(maxlen))
|
||||
|
||||
with open(file=save_file, mode='wb') as fw:
|
||||
pickle.dump(features, fw)
|
||||
print('finish reading {} and save preprocessed data to {}.'.format(file_in, save_file))
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def read_gda(file_in, save_file, tokenizer, max_seq_length=1024):
    if os.path.exists(save_file):
        with open(file=save_file, mode='rb') as fr:
            features = pickle.load(fr)
            fr.close()
        print('Loaded preprocessed data from {}.'.format(save_file))
        return features
    else:
        pmids = set()
        features = []
        maxlen = 0
        with open(file_in, 'r') as infile:
            lines = infile.readlines()
            for i_l, line in enumerate(tqdm(lines)):
                line = line.rstrip().split('\t')
                pmid = line[0]

                if pmid not in pmids:
                    pmids.add(pmid)
                    text = line[1]
                    prs = chunks(line[2:], 17)

                    ent2idx = {}
                    train_triples = {}

                    entity_pos = set()
                    for p in prs:
                        es = list(map(int, p[8].split(':')))
                        ed = list(map(int, p[9].split(':')))
                        tpy = p[7]
                        for start, end in zip(es, ed):
                            entity_pos.add((start, end, tpy))

                        es = list(map(int, p[14].split(':')))
                        ed = list(map(int, p[15].split(':')))
                        tpy = p[13]
                        for start, end in zip(es, ed):
                            entity_pos.add((start, end, tpy))

                    # Tokenise into wordpieces, wrapping every entity mention with '*'
                    # markers and mapping original token offsets to wordpiece offsets.
                    sents = [t.split(' ') for t in text.split('|')]
                    new_sents = []
                    sent_map = {}
                    i_t = 0
                    for sent in sents:
                        for token in sent:
                            tokens_wordpiece = tokenizer.tokenize(token)
                            for start, end, tpy in list(entity_pos):
                                if i_t == start:
                                    tokens_wordpiece = ["*"] + tokens_wordpiece
                                if i_t + 1 == end:
                                    tokens_wordpiece = tokens_wordpiece + ["*"]
                            sent_map[i_t] = len(new_sents)
                            new_sents.extend(tokens_wordpiece)
                            i_t += 1
                    sent_map[i_t] = len(new_sents)
                    sents = new_sents

                    entity_pos = []

                    # Collect per-entity mention spans and gold triples,
                    # normalising head/tail order with the L2R flag.
                    for p in prs:
                        if p[0] == "not_include":
                            continue
                        if p[1] == "L2R":
                            h_id, t_id = p[5], p[11]
                            h_start, t_start = p[8], p[14]
                            h_end, t_end = p[9], p[15]
                        else:
                            t_id, h_id = p[5], p[11]
                            t_start, h_start = p[8], p[14]
                            t_end, h_end = p[9], p[15]
                        h_start = map(int, h_start.split(':'))
                        h_end = map(int, h_end.split(':'))
                        t_start = map(int, t_start.split(':'))
                        t_end = map(int, t_end.split(':'))
                        h_start = [sent_map[idx] for idx in h_start]
                        h_end = [sent_map[idx] for idx in h_end]
                        t_start = [sent_map[idx] for idx in t_start]
                        t_end = [sent_map[idx] for idx in t_end]
                        if h_id not in ent2idx:
                            ent2idx[h_id] = len(ent2idx)
                            entity_pos.append(list(zip(h_start, h_end)))
                        if t_id not in ent2idx:
                            ent2idx[t_id] = len(ent2idx)
                            entity_pos.append(list(zip(t_start, t_end)))
                        h_id, t_id = ent2idx[h_id], ent2idx[t_id]

                        r = gda_rel2id[p[0]]
                        if (h_id, t_id) not in train_triples:
                            train_triples[(h_id, t_id)] = [{'relation': r}]
                        else:
                            train_triples[(h_id, t_id)].append({'relation': r})

                    relations, hts = [], []
                    for h, t in train_triples.keys():
                        relation = [0] * len(gda_rel2id)
                        for mention in train_triples[h, t]:
                            relation[mention["relation"]] = 1
                        relations.append(relation)
                        hts.append([h, t])

                    maxlen = max(maxlen, len(sents))
                    sents = sents[:max_seq_length - 2]
                    input_ids = tokenizer.convert_tokens_to_ids(sents)
                    input_ids = tokenizer.build_inputs_with_special_tokens(input_ids)

                    if len(hts) > 0:
                        feature = {'input_ids': input_ids,
                                   'entity_pos': entity_pos,
                                   'labels': relations,
                                   'hts': hts,
                                   'title': pmid,
                                   }
                        features.append(feature)
        print("Number of documents: {}.".format(len(features)))
        print("Max document length: {}.".format(maxlen))
        with open(file=save_file, mode='wb') as fw:
            pickle.dump(features, fw)
        print('Finished reading {} and saved preprocessed data to {}.'.format(file_in, save_file))

        return features

@@ -0,0 +1,8 @@
python==3.7
cuda==10.2
torch==1.5.0
transformers==3.0.4
opt-einsum==3.3.0
ujson
tqdm
allennlp

@@ -0,0 +1,35 @@
#! /bin/bash
export CUDA_VISIBLE_DEVICES=0

if true; then
    type=context-based
    bs=4
    bl=3e-5
    uls=(4e-4)
    accum=1
    for ul in ${uls[@]}
    do
        python -u ../train_bio.py --data_dir ../dataset/cdr \
            --max_height 35 \
            --channel_type $type \
            --bert_lr $bl \
            --transformer_type bert \
            --model_name_or_path allenai/scibert_scivocab_cased \
            --train_file train.data \
            --dev_file dev.data \
            --test_file test.data \
            --train_batch_size $bs \
            --test_batch_size $bs \
            --gradient_accumulation_steps $accum \
            --num_labels 1 \
            --learning_rate $ul \
            --max_grad_norm 1.0 \
            --warmup_ratio 0.06 \
            --num_train_epochs 30.0 \
            --seed 111 \
            --num_class 2 \
            --save_path ../checkpoint/cdr/train_scibert-lr${bl}_accum${accum}_unet-lr${ul}_bs${bs}.pt \
            --log_dir ../logs/cdr/train_scibert-lr${bl}_accum${accum}_unet-lr${ul}_bs${bs}.log
    done
fi

@@ -0,0 +1,66 @@
#! /bin/bash
export CUDA_VISIBLE_DEVICES=0

# -------------------Training Shell Script--------------------
if true; then
    transformer_type=bert
    channel_type=context-based
    if [[ $transformer_type == bert ]]; then
        bs=4
        bl=3e-5
        uls=(3e-4 4e-4 5e-4)
        accum=1
        for ul in ${uls[@]}
        do
            python -u ../train_balanceloss.py --data_dir ../dataset/docred \
                --channel_type $channel_type \
                --bert_lr $bl \
                --transformer_type $transformer_type \
                --model_name_or_path bert-base-cased \
                --train_file train_annotated.json \
                --dev_file dev.json \
                --test_file test.json \
                --train_batch_size $bs \
                --test_batch_size $bs \
                --gradient_accumulation_steps $accum \
                --num_labels 3 \
                --learning_rate $ul \
                --max_grad_norm 1.0 \
                --warmup_ratio 0.06 \
                --num_train_epochs 30.0 \
                --seed 66 \
                --num_class 97 \
                --save_path ../checkpoint/docred/train_bert-lr${bl}_accum${accum}_unet-lr${ul}_type_${channel_type}.pt \
                --log_dir ../logs/docred/train_bert-lr${bl}_accum${accum}_unet-lr${ul}_type_${channel_type}.log
        done
    elif [[ $transformer_type == roberta ]]; then
        type=context-based
        bs=2
        bl=3e-5
        uls=(4e-4)
        accum=2
        for ul in ${uls[@]}
        do
            python -u ../train_balanceloss.py --data_dir ../dataset/docred \
                --channel_type $channel_type \
                --bert_lr $bl \
                --transformer_type $transformer_type \
                --model_name_or_path roberta-large \
                --train_file train_annotated.json \
                --dev_file dev.json \
                --test_file test.json \
                --train_batch_size $bs \
                --test_batch_size $bs \
                --gradient_accumulation_steps $accum \
                --num_labels 4 \
                --learning_rate $ul \
                --max_grad_norm 1.0 \
                --warmup_ratio 0.06 \
                --num_train_epochs 30.0 \
                --seed 111 \
                --num_class 97 \
                --save_path ../checkpoint/docred/train_bert-lr${bl}_accum${accum}_unet-lr${ul}_type_${channel_type}.pt \
                --log_dir ../logs/docred/train_bert-lr${bl}_accum${accum}_unet-lr${ul}_type_${channel_type}.log
        done
    fi
fi

@@ -0,0 +1,35 @@
#! /bin/bash
export CUDA_VISIBLE_DEVICES=0

if true; then
    type=context-based
    bs=4
    bl=3e-5
    uls=(4e-4)
    accum=4
    for ul in ${uls[@]}
    do
        python -u ../train_bio.py --data_dir ../dataset/gda \
            --max_height 35 \
            --channel_type $type \
            --bert_lr $bl \
            --transformer_type bert \
            --model_name_or_path allenai/scibert_scivocab_cased \
            --train_file train.data \
            --dev_file dev.data \
            --test_file test.data \
            --train_batch_size $bs \
            --test_batch_size $bs \
            --gradient_accumulation_steps $accum \
            --num_labels 1 \
            --learning_rate $ul \
            --max_grad_norm 1.0 \
            --warmup_ratio 0.06 \
            --num_train_epochs 10.0 \
            --evaluation_steps 400 \
            --seed 66 \
            --num_class 2 \
            --save_path ../checkpoint/gda/train_scibert-lr${bl}_accum${accum}_unet-lr${ul}_bs${bs}.pt \
            --log_dir ../logs/gda/train_scibert-lr${bl}_accum${accum}_unet-lr${ul}_bs${bs}.log
    done
fi

@@ -0,0 +1 @@
# this is an empty file

@@ -0,0 +1,325 @@
import argparse
import os
import time
from datetime import datetime
import numpy as np
import torch

import ujson as json
from torch.utils.data import DataLoader
from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
from model_balanceloss import DocREModel
from utils_sample import set_seed, collate_fn
from evaluation import to_official, official_evaluate
from prepro import ReadDataset


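# Training entry point: `finetune` builds one AdamW optimizer over three
# parameter groups (encoder weights at --bert_lr, extractor/bilinear layers
# at 1e-4, everything else at --learning_rate), trains with gradient
# accumulation, gradient clipping and a linear warmup schedule, evaluates on
# the dev set periodically, and checkpoints the best dev-F1 model while
# writing test predictions to result.json.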
def train(args, model, train_features, dev_features, test_features):
    def logging(s, print_=True, log_=True):
        if print_:
            print(s)
        if log_:
            with open(args.log_dir, 'a+') as f_log:
                f_log.write(s + '\n')

    def finetune(features, optimizer, num_epoch, num_steps, model):
        if args.train_from_saved_model != '':
            best_score = torch.load(args.train_from_saved_model)["best_f1"]
            epoch_delta = torch.load(args.train_from_saved_model)["epoch"] + 1
        else:
            epoch_delta = 0
            best_score = -1
        train_dataloader = DataLoader(features, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True)
        train_iterator = [epoch + epoch_delta for epoch in range(int(num_epoch))]
        total_steps = int(len(train_dataloader) * num_epoch // args.gradient_accumulation_steps)
        warmup_steps = int(total_steps * args.warmup_ratio)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
        print("Total steps: {}".format(total_steps))
        print("Warmup steps: {}".format(warmup_steps))
        global_step = 0
        log_step = 100
        total_loss = 0
        print('torch.cuda.device_count():', torch.cuda.device_count())
        if torch.cuda.device_count() > 1:
            print("Let's use", torch.cuda.device_count(), "GPUs!")

        # scaler = GradScaler()
        for epoch in train_iterator:
            start_time = time.time()
            optimizer.zero_grad()

            for step, batch in enumerate(train_dataloader):
                model.train()

                inputs = {'input_ids': batch[0].to(args.device),
                          'attention_mask': batch[1].to(args.device),
                          'labels': batch[2],
                          'entity_pos': batch[3],
                          'hts': batch[4],
                          }
                # with autocast():
                outputs = model(**inputs)
                loss = outputs[0] / args.gradient_accumulation_steps
                total_loss += loss.item()
                # scaler.scale(loss).backward()

                loss.backward()

                if step % args.gradient_accumulation_steps == 0:
                    # scaler.unscale_(optimizer)
                    if args.max_grad_norm > 0:
                        # torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    # scaler.step(optimizer)
                    # scaler.update()
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1
                    num_steps += 1
                    if global_step % log_step == 0:
                        cur_loss = total_loss / log_step
                        elapsed = time.time() - start_time
                        logging(
                            '| epoch {:2d} | step {:4d} | min/b {:5.2f} | lr {} | train loss {:5.3f}'.format(
                                epoch, global_step, elapsed / 60, scheduler.get_lr(), cur_loss * 1000))
                        total_loss = 0
                        start_time = time.time()

                if (step + 1) == len(train_dataloader) - 1 or (args.evaluation_steps > 0 and num_steps % args.evaluation_steps == 0 and step % args.gradient_accumulation_steps == 0):
                    # if step == 0:
                    logging('-' * 89)
                    eval_start_time = time.time()
                    dev_score, dev_output = evaluate(args, model, dev_features, tag="dev")

                    logging(
                        '| epoch {:3d} | time: {:5.2f}s | dev_result:{}'.format(epoch, time.time() - eval_start_time,
                                                                                dev_output))
                    logging('-' * 89)
                    if dev_score > best_score:
                        best_score = dev_score
                        logging(
                            '| epoch {:3d} | best_f1:{}'.format(epoch, best_score))
                        pred = report(args, model, test_features)
                        with open("result.json", "w") as fh:
                            json.dump(pred, fh)
                        if args.save_path != "":
                            torch.save({
                                'epoch': epoch,
                                'checkpoint': model.state_dict(),
                                'best_f1': best_score,
                                'optimizer': optimizer.state_dict()
                            }, args.save_path, _use_new_zipfile_serialization=False)
        return num_steps

    extract_layer = ["extractor", "bilinear"]
    bert_layer = ['bert_model']
    optimizer_grouped_parameters = [
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in bert_layer)], "lr": args.bert_lr},
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in extract_layer)], "lr": 1e-4},
        {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in extract_layer + bert_layer)]},
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    if args.train_from_saved_model != '':
        optimizer.load_state_dict(torch.load(args.train_from_saved_model)["optimizer"])
        print("load saved optimizer from {}.".format(args.train_from_saved_model))

    num_steps = 0
    set_seed(args)
    model.zero_grad()
    finetune(train_features, optimizer, args.num_train_epochs, num_steps, model)


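# evaluate() converts the model's pair-wise predictions into the official
# DocRED submission format via to_official() and scores them with
# official_evaluate(), returning the dev F1 used for model selection together
# with a dict of F1 / Ign F1 / precision / recall / average loss.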
def evaluate(args, model, features, tag="dev"):

    dataloader = DataLoader(features, batch_size=args.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)
    preds = []
    total_loss = 0
    for i, batch in enumerate(dataloader):
        model.eval()

        inputs = {'input_ids': batch[0].to(args.device),
                  'attention_mask': batch[1].to(args.device),
                  'labels': batch[2],
                  'entity_pos': batch[3],
                  'hts': batch[4],
                  }

        with torch.no_grad():
            output = model(**inputs)
            loss = output[0]
            pred = output[1].cpu().numpy()
        pred[np.isnan(pred)] = 0
        preds.append(pred)
        total_loss += loss.item()

    average_loss = total_loss / (i + 1)
    preds = np.concatenate(preds, axis=0).astype(np.float32)
    ans = to_official(preds, features)
    # Initialise the metrics so the output dict is well defined even when no pair is predicted.
    best_f1, best_f1_ign, re_p, re_r = 0.0, 0.0, 0.0, 0.0
    if len(ans) > 0:
        best_f1, _, best_f1_ign, _, re_p, re_r = official_evaluate(ans, args.data_dir)
    output = {
        tag + "_F1": best_f1 * 100,
        tag + "_F1_ign": best_f1_ign * 100,
        tag + "_re_p": re_p * 100,
        tag + "_re_r": re_r * 100,
        tag + "_average_loss": average_loss
    }
    return best_f1, output


def report(args, model, features):

    dataloader = DataLoader(features, batch_size=args.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)
    preds = []
    for batch in dataloader:
        model.eval()

        inputs = {'input_ids': batch[0].to(args.device),
                  'attention_mask': batch[1].to(args.device),
                  'entity_pos': batch[3],
                  'hts': batch[4],
                  }

        with torch.no_grad():
            pred, *_ = model(**inputs)
            pred = pred.cpu().numpy()
            pred[np.isnan(pred)] = 0
            preds.append(pred)
            print(preds)

    preds = np.concatenate(preds, axis=0).astype(np.float32)
    print(preds)
    preds = to_official(preds, features)
    print(preds)
    return preds


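# Command-line entry point: parses the arguments used by the run_*.sh scripts,
# loads the dataset splits through ReadDataset, wraps the pretrained encoder in
# DocREModel, then either trains (--load_path empty) or loads a checkpoint,
# evaluates on the test split, and writes predictions to result.json.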
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument("--data_dir", default="./dataset/docred", type=str)
    parser.add_argument("--transformer_type", default="bert", type=str)
    parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str)

    parser.add_argument("--train_file", default="train_annotated.json", type=str)
    parser.add_argument("--dev_file", default="dev.json", type=str)
    parser.add_argument("--test_file", default="test.json", type=str)
    parser.add_argument("--save_path", default="", type=str)
    parser.add_argument("--load_path", default="", type=str)

    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--max_seq_length", default=1024, type=int,
                        help="The maximum total input sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")

    parser.add_argument("--train_batch_size", default=4, type=int,
                        help="Batch size for training.")
    parser.add_argument("--test_batch_size", default=8, type=int,
                        help="Batch size for testing.")
    parser.add_argument("--gradient_accumulation_steps", default=1, type=int,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--num_labels", default=4, type=int,
                        help="Max number of labels in prediction.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--bert_lr", default=5e-5, type=float,
                        help="The initial learning rate for the pretrained encoder.")
    parser.add_argument("--adam_epsilon", default=1e-6, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--warmup_ratio", default=0.06, type=float,
                        help="Warm up ratio for Adam.")
    parser.add_argument("--num_train_epochs", default=30.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--evaluation_steps", default=-1, type=int,
                        help="Number of training steps between evaluations.")
    parser.add_argument("--seed", type=int, default=66,
                        help="Random seed for initialization.")
    parser.add_argument("--num_class", type=int, default=97,
                        help="Number of relation types in dataset.")

    parser.add_argument("--unet_in_dim", type=int, default=3,
                        help="unet_in_dim.")
    parser.add_argument("--unet_out_dim", type=int, default=256,
                        help="unet_out_dim.")
    parser.add_argument("--down_dim", type=int, default=256,
                        help="down_dim.")
    parser.add_argument("--channel_type", type=str, default='',
                        help="Type of the entity-pair feature channels (e.g. context-based).")
    parser.add_argument("--log_dir", type=str, default='',
                        help="Path of the training log file.")
    parser.add_argument("--max_height", type=int, default=42,
                        help="Maximum number of entities per document (size of the entity-pair matrix).")
    parser.add_argument("--train_from_saved_model", type=str, default='',
                        help="Train from a saved model.")
    parser.add_argument("--dataset", type=str, default='docred',
                        help="Dataset name used by ReadDataset (default: docred).")

    args = parser.parse_args()
    print('args:', args)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()
    args.device = device

    config = AutoConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=args.num_class,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
    )

    Dataset = ReadDataset(args.dataset, tokenizer, args.max_seq_length)

    train_file = os.path.join(args.data_dir, args.train_file)
    dev_file = os.path.join(args.data_dir, args.dev_file)
    test_file = os.path.join(args.data_dir, args.test_file)
    train_features = Dataset.read(train_file)
    dev_features = Dataset.read(dev_file)
    test_features = Dataset.read(test_file)

    model = AutoModel.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
    )

    config.cls_token_id = tokenizer.cls_token_id
    config.sep_token_id = tokenizer.sep_token_id
    config.transformer_type = args.transformer_type

    set_seed(args)
    model = DocREModel(config, args, model, num_labels=args.num_labels)
    if args.train_from_saved_model != '':
        model.load_state_dict(torch.load(args.train_from_saved_model)["checkpoint"])
        print("load saved model from {}.".format(args.train_from_saved_model))
    model.to(0)

    if args.load_path == "":  # Training
        train(args, model, train_features, dev_features, test_features)
    else:  # Testing
        model.load_state_dict(torch.load(args.load_path)['checkpoint'])
        T_features = test_features  # Testing on the test set
        T_score, T_output = evaluate(args, model, T_features, tag="test")
        print(T_output)
        pred = report(args, model, T_features)
        with open("result.json", "w") as fh:
            json.dump(pred, fh)


if __name__ == "__main__":
    main()

@@ -0,0 +1,73 @@
import torch
import random
import numpy as np


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0 and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)


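# collate_fn_sample pads each batch to its longest sequence and, unlike the
# plain collate_fn below, subsamples negative entity pairs: for every document
# it keeps all positive (head, tail) pairs and at most max(20, 8 * #positives)
# shuffled negatives, which bounds the number of candidate pairs per document.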
def collate_fn_sample(batch):
    max_len = max([len(f["input_ids"]) for f in batch])
    input_ids = [f["input_ids"] + [0] * (max_len - len(f["input_ids"])) for f in batch]
    input_mask = [[1.0] * len(f["input_ids"]) + [0.0] * (max_len - len(f["input_ids"])) for f in batch]
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    input_mask = torch.tensor(input_mask, dtype=torch.float)
    entity_pos = [f["entity_pos"] for f in batch]

    negative_alpha = 8
    positive_alpha = 1
    labels, hts = [], []
    for f in batch:
        randnum = random.randint(0, 1000000)

        pos_hts = f['pos_hts']
        pos_labels = f['pos_labels']
        neg_hts = f['neg_hts']
        neg_labels = f['neg_labels']
        if negative_alpha > 0:
            # Shuffle negative pairs and their labels with the same seed so they stay aligned.
            random.seed(randnum)
            random.shuffle(neg_hts)
            random.seed(randnum)
            random.shuffle(neg_labels)
            lower_bound = int(max(20, len(pos_hts) * negative_alpha))
            hts.append(pos_hts * positive_alpha + neg_hts[:lower_bound])
            labels.append(pos_labels * positive_alpha + neg_labels[:lower_bound])

    # labels = [f["labels"] for f in batch]
    # hts = [f["hts"] for f in batch]

    # entity_pos_single = []
    # # for f in batch:
    # #     entity_pos_item = f["entity_pos"]
    # #     entity_pos2 = []
    # #     for e in entity_pos_item:
    # #         entity_pos2.append([])
    # #         mention_num = len(e)
    # #         bounds = np.random.randint(mention_num, size=3)
    # #         for bound in bounds:
    # #             entity_pos2[-1].append(e[bound])
    # #     entity_pos_single.append( torch.tensor(entity_pos2) )

    output = (input_ids, input_mask, labels, entity_pos, hts, )

    return output


def collate_fn(batch):
    max_len = max([len(f["input_ids"]) for f in batch])
    input_ids = [f["input_ids"] + [0] * (max_len - len(f["input_ids"])) for f in batch]
    input_mask = [[1.0] * len(f["input_ids"]) + [0.0] * (max_len - len(f["input_ids"])) for f in batch]
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    input_mask = torch.tensor(input_mask, dtype=torch.float)
    entity_pos = [f["entity_pos"] for f in batch]

    labels = [f["labels"] for f in batch]
    hts = [f["hts"] for f in batch]
    output = (input_ids, input_mask, labels, entity_pos, hts)
    return output