From d1ffa5f748e7de3a91d74c9d2b464a7907a5e473 Mon Sep 17 00:00:00 2001
From: TimelordRi <35120358+TimelordRi@users.noreply.github.com>
Date: Tue, 28 Sep 2021 12:46:39 +0800
Subject: [PATCH] Add files via upload

---
 tutorial-notebooks/re/document/tutorial.ipynb | 1167 +++++++++++++++++
 1 file changed, 1167 insertions(+)
 create mode 100644 tutorial-notebooks/re/document/tutorial.ipynb

diff --git a/tutorial-notebooks/re/document/tutorial.ipynb b/tutorial-notebooks/re/document/tutorial.ipynb
new file mode 100644
index 0000000..326b92f
--- /dev/null
+++ b/tutorial-notebooks/re/document/tutorial.ipynb
@@ -0,0 +1,1167 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "source": [
+    "# 文档级关系抽取实践\n",
+    "\n",
+    "> Tutorial作者: 黎洲波(zhoubo.li@zju.edu.cn)\n",
+    "\n",
+    "关系抽取(Relation Extraction, RE)是信息抽取中的关键任务：通常输入是一段文本和若干命名实体，输出是命名实体对之间的关系。\n",
+    "\n",
+    "DeepKE中的文档级(Document-level)关系抽取借鉴了计算机视觉(Computer Vision, CV)中的语义分割方法，使用U型网络(U-net)结构的神经网络模型，在[DocRED](https://github.com/thunlp/DocRED/tree/master/)数据集上取得了良好的效果。文档级关系抽取任务的输入输出如下图所示，图中彩色标注的是命名实体。不同于句子级(Sentence-level)关系抽取，文档级关系抽取不仅能在单个句子内部抽取实体之间的关系，还可以跨句子抽取实体之间的关系。\n",
+    "\n",
+    "![文档级关系抽取](img/img1.png)\n",
+    "\n",
+    "通过本次实践展示，我希望读者能够快速了解如何构建文档级关系抽取的模型。\n",
+    "\n",
+    "## 数据集\n",
+    "\n",
+    "文档级关系抽取常用的数据集有DocRED、CDR和GDA。本次实践使用的数据集是[DocRED](https://github.com/thunlp/DocRED/tree/master/)，数据文件夹`./data/`的结构如下：\n",
+    "\n",
+    "```\n",
+    ".\n",
+    "├── dev.json             # 验证集\n",
+    "├── rel_info.json        # 关系集\n",
+    "├── rel2id.json          # 关系标签到ID的映射\n",
+    "├── test.json            # 测试集\n",
+    "└── train_annotated.json # 训练集\n",
+    "```\n",
+    "\n",
+    "数据文件的格式在DocRED数据集中的描述如下：\n",
+    "\n",
+    "```json\n",
+    "Data Format:\n",
+    "{\n",
+    "  'title',\n",
+    "  'sents': [\n",
+    "    [word in sent 0],\n",
+    "    [word in sent 1]\n",
+    "  ],\n",
+    "  'vertexSet': [\n",
+    "    [\n",
+    "      { 'name': mention_name,\n",
+    "        'sent_id': mention in which sentence,\n",
+    "        'pos': position of mention in a sentence,\n",
+    "        'type': NER_type},\n",
+    "      {another mention}\n",
+    "    ],\n",
+    "    [another entity]\n",
+    "  ],\n",
+    "  'labels': [\n",
+    "    {\n",
+    "      'h': idx of head entity in vertexSet,\n",
+    "      't': idx of tail entity in vertexSet,\n",
+    "      'r': relation,\n",
+    "      'evidence': evidence sentences' id\n",
+    "    }\n",
+    "  ]\n",
+    "}\n",
+    "```\n",
+    "\n",
+    "## DocuNet原理\n",
+    "\n",
+    "此处，我们借用视觉图像处理中语义分割常用的U-net结构神经网络来完成关系抽取，因此把该方法称为Document U-shaped Network(DocuNet)。DocuNet的原理结构图如下：\n",
+    "\n",
+    "![文档级关系抽取架构图](img/img2.png)\n",
+    "\n",
+    "那么接下来我们开始文档级关系抽取实践!\n",
+    "\n",
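+    "我们先用一段简单的示例代码查看一个训练样本的结构（这里假设数据文件已按上面的目录结构放在`./data/`下，代码仅作演示参考）：\n",
+    "\n",
+    "```python\n",
+    "import json\n",
+    "\n",
+    "# 假设 DocRED 数据已经下载到 ./data/ 目录\n",
+    "with open('./data/train_annotated.json', 'r') as f:\n",
+    "    train_data = json.load(f)\n",
+    "\n",
+    "sample = train_data[0]\n",
+    "print('title:', sample['title'])\n",
+    "print('句子数:', len(sample['sents']))\n",
+    "print('实体数:', len(sample['vertexSet']))\n",
+    "print('关系标注数:', len(sample['labels']))\n",
+    "# 查看第一个实体的第一个提及(mention)\n",
+    "print(sample['vertexSet'][0][0])\n",
+    "```"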
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## 运行环境\n",
+    "\n",
+    "本教程使用Python 3.7与CUDA 10.2（二者需预先安装好，无法通过pip安装），并要求以下packages的版本："
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "source": [
+    "# 注意：Python 3.7 与 CUDA 10.2 需预先配置好，这里只安装所需的 Python 包\n",
+    "!pip install torch==1.5.0\n",
+    "!pip install transformers==3.0.4\n",
+    "!pip install opt-einsum==3.3.0\n",
+    "!pip install ujson\n",
+    "!pip install tqdm\n",
+    "!pip install allennlp\n",
+    "!pip install matplotlib"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## 导入模块"
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "source": [
+    "import os\n",
+    "import time\n",
+    "import random\n",
+    "import pickle\n",
+    "\n",
+    "import numpy as np\n",
+    "import ujson as json\n",
+    "from tqdm import tqdm\n",
+    "from opt_einsum import contract\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "from torch.utils.data import DataLoader\n",
+    "from transformers import AutoConfig, AutoModel, AutoTokenizer\n",
+    "from transformers.optimization import AdamW, get_linear_schedule_with_warmup\n",
+    "from allennlp.modules.matrix_attention import DotProductMatrixAttention, CosineMatrixAttention, BilinearMatrixAttention"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## 数据集预处理"
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "source": [
+    "docred_rel2id = json.load(open('./data/rel2id.json', 'r'))\n",
+    "# 关系ID到关系标签的反向映射，预测输出(to_official)时会用到\n",
+    "id2rel = {v: k for k, v in docred_rel2id.items()}\n",
+    "\n",
+    "def chunks(l, n):\n",
+    "    res = []\n",
+    "    for i in range(0, len(l), n):\n",
+    "        assert len(l[i:i + n]) == n\n",
+    "        res += [l[i:i + n]]\n",
+    "    return res\n",
+    "\n",
+    "class ReadDataset:\n",
+    "    def __init__(self, dataset: str, tokenizer, max_seq_length: int = 1024,\n",
+    "                 transformers: str = 'bert') -> None:\n",
+    "        self.transformers = transformers\n",
+    "        self.dataset = dataset\n",
+    "        self.tokenizer = tokenizer\n",
+    "        self.max_seq_length = max_seq_length\n",
+    "\n",
+    "    def read(self, file_in: str):\n",
+    "        save_file = file_in.split('.json')[0] + '_' + self.transformers + '_' \\\n",
+    "                    + self.dataset + '.pkl'\n",
+    "        if self.dataset == 'docred':\n",
+    "            return read_docred(self.transformers, file_in, save_file, self.tokenizer, self.max_seq_length)\n",
+    "        else:\n",
+    "            raise RuntimeError(\"No read func for this dataset.\")\n",
+    "\n",
+    "def read_docred(transfermers, file_in, save_file, tokenizer, max_seq_length=1024):\n",
+    "    if os.path.exists(save_file):\n",
+    "        with open(file=save_file, mode='rb') as fr:\n",
+    "            features = pickle.load(fr)\n",
+    "        print('load preprocessed data from {}.'.format(save_file))\n",
+    "        return features\n",
+    "    else:\n",
+    "        max_len = 0\n",
+    "        up512_num = 0\n",
+    "        i_line = 0\n",
+    "        pos_samples = 0\n",
+    "        neg_samples = 0\n",
+    "        features = []\n",
+    "        if file_in == \"\":\n",
+    "            return None\n",
+    "        with open(file_in, \"r\") as fh:\n",
+    "            data = json.load(fh)\n",
+    "        if transfermers == 'bert':\n",
+    "            # entity_type = [\"ORG\", \"-\", \"LOC\", \"-\", \"TIME\", \"-\", \"PER\", \"-\", \"MISC\", \"-\", \"NUM\"]\n",
+    "            entity_type = [\"-\", \"ORG\", \"-\", \"LOC\", \"-\", \"TIME\", \"-\", \"PER\", \"-\", \"MISC\", \"-\", \"NUM\"]\n",
+    "\n",
+    "        for sample in tqdm(data, desc=\"Example\"):\n",
+    "            sents = []\n",
+    "            sent_map = []\n",
+    "\n",
+    "            entities = sample['vertexSet']\n",
+    "            entity_start, entity_end = [], []\n",
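+    "            # 先记录每个实体提及(mention)的起止位置和实体类型：\n",
+    "            # 后面分词时会据此在每个提及的前后插入特殊标记（BERT用[unused*]，其他模型用*），\n",
+    "            # 用来显式标出实体边界\n",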
+ " mention_types = []\n", + " for entity in entities:\n", + " for mention in entity:\n", + " sent_id = mention[\"sent_id\"]\n", + " pos = mention[\"pos\"]\n", + " entity_start.append((sent_id, pos[0]))\n", + " entity_end.append((sent_id, pos[1] - 1))\n", + " mention_types.append(mention['type'])\n", + "\n", + " for i_s, sent in enumerate(sample['sents']):\n", + " new_map = {}\n", + " for i_t, token in enumerate(sent):\n", + " tokens_wordpiece = tokenizer.tokenize(token)\n", + " if (i_s, i_t) in entity_start:\n", + " t = entity_start.index((i_s, i_t))\n", + " if transfermers == 'bert':\n", + " mention_type = mention_types[t]\n", + " special_token_i = entity_type.index(mention_type)\n", + " special_token = ['[unused' + str(special_token_i) + ']']\n", + " else:\n", + " special_token = ['*']\n", + " tokens_wordpiece = special_token + tokens_wordpiece\n", + " # tokens_wordpiece = [\"[unused0]\"]+ tokens_wordpiece\n", + "\n", + " if (i_s, i_t) in entity_end:\n", + " t = entity_end.index((i_s, i_t))\n", + " if transfermers == 'bert':\n", + " mention_type = mention_types[t]\n", + " special_token_i = entity_type.index(mention_type) + 50\n", + " special_token = ['[unused' + str(special_token_i) + ']']\n", + " else:\n", + " special_token = ['*']\n", + " tokens_wordpiece = tokens_wordpiece + special_token\n", + "\n", + " # tokens_wordpiece = tokens_wordpiece + [\"[unused1]\"]\n", + " # print(tokens_wordpiece,tokenizer.convert_tokens_to_ids(tokens_wordpiece))\n", + "\n", + " new_map[i_t] = len(sents)\n", + " sents.extend(tokens_wordpiece)\n", + " new_map[i_t + 1] = len(sents)\n", + " sent_map.append(new_map)\n", + "\n", + " if len(sents)>max_len:\n", + " max_len=len(sents)\n", + " if len(sents)>512:\n", + " up512_num += 1\n", + "\n", + " train_triple = {}\n", + " if \"labels\" in sample:\n", + " for label in sample['labels']:\n", + " evidence = label['evidence']\n", + " r = int(docred_rel2id[label['r']])\n", + " if (label['h'], label['t']) not in train_triple:\n", + " train_triple[(label['h'], label['t'])] = [\n", + " {'relation': r, 'evidence': evidence}]\n", + " else:\n", + " train_triple[(label['h'], label['t'])].append(\n", + " {'relation': r, 'evidence': evidence})\n", + "\n", + " entity_pos = []\n", + " for e in entities:\n", + " entity_pos.append([])\n", + " mention_num = len(e)\n", + " for m in e:\n", + " start = sent_map[m[\"sent_id\"]][m[\"pos\"][0]]\n", + " end = sent_map[m[\"sent_id\"]][m[\"pos\"][1]]\n", + " entity_pos[-1].append((start, end,))\n", + "\n", + "\n", + " relations, hts = [], []\n", + " # Get positive samples from dataset\n", + " for h, t in train_triple.keys():\n", + " relation = [0] * len(docred_rel2id)\n", + " for mention in train_triple[h, t]:\n", + " relation[mention[\"relation\"]] = 1\n", + " evidence = mention[\"evidence\"]\n", + " relations.append(relation)\n", + " hts.append([h, t])\n", + " pos_samples += 1\n", + "\n", + " # Get negative samples from dataset\n", + " for h in range(len(entities)):\n", + " for t in range(len(entities)):\n", + " if h != t and [h, t] not in hts:\n", + " relation = [1] + [0] * (len(docred_rel2id) - 1)\n", + " relations.append(relation)\n", + " hts.append([h, t])\n", + " neg_samples += 1\n", + "\n", + " assert len(relations) == len(entities) * (len(entities) - 1)\n", + "\n", + " if len(hts)==0:\n", + " print(len(sent))\n", + " sents = sents[:max_seq_length - 2]\n", + " input_ids = tokenizer.convert_tokens_to_ids(sents)\n", + " input_ids = tokenizer.build_inputs_with_special_tokens(input_ids)\n", + "\n", + " i_line += 1\n", + " feature = 
{'input_ids': input_ids,\n", + " 'entity_pos': entity_pos,\n", + " 'labels': relations,\n", + " 'hts': hts,\n", + " 'title': sample['title'],\n", + " }\n", + " features.append(feature)\n", + "\n", + "\n", + "\n", + " print(\"# of documents {}.\".format(i_line))\n", + " print(\"# of positive examples {}.\".format(pos_samples))\n", + " print(\"# of negative examples {}.\".format(neg_samples))\n", + " print(\"# {} examples len>512 and max len is {}.\".format(up512_num, max_len))\n", + "\n", + "\n", + " with open(file=save_file, mode='wb') as fw:\n", + " pickle.dump(features, fw)\n", + " print('finish reading {} and save preprocessed data to {}.'.format(file_in, save_file))\n", + "\n", + " return features" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "## 模型构建" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "class AttentionUNet(torch.nn.Module):\n", + " \"\"\"\n", + " UNet, down sampling & up sampling for global reasoning\n", + " \"\"\"\n", + "\n", + " def __init__(self, input_channels, class_number, **kwcfg):\n", + " super(AttentionUNet, self).__init__()\n", + "\n", + " down_channel = kwcfg['down_channel'] # default = 256\n", + "\n", + " down_channel_2 = down_channel * 2\n", + " up_channel_1 = down_channel_2 * 2\n", + " up_channel_2 = down_channel * 2\n", + "\n", + " self.inc = InConv(input_channels, down_channel)\n", + " self.down1 = DownLayer(down_channel, down_channel_2)\n", + " self.down2 = DownLayer(down_channel_2, down_channel_2)\n", + "\n", + " self.up1 = UpLayer(up_channel_1, up_channel_1 // 4)\n", + " self.up2 = UpLayer(up_channel_2, up_channel_2 // 4)\n", + " self.outc = OutConv(up_channel_2 // 4, class_number)\n", + "\n", + " def forward(self, attention_channels):\n", + " \"\"\"\n", + " Given multi-channel attention map, return the logits of every one mapping into 3-class\n", + " :param attention_channels:\n", + " :return:\n", + " \"\"\"\n", + " # attention_channels as the shape of: batch_size x channel x width x height\n", + " x = attention_channels\n", + " x1 = self.inc(x)\n", + " x2 = self.down1(x1)\n", + " x3 = self.down2(x2)\n", + " x = self.up1(x3, x2)\n", + " x = self.up2(x, x1)\n", + " output = self.outc(x)\n", + " # attn_map as the shape of: batch_size x width x height x class\n", + " output = output.permute(0, 2, 3, 1).contiguous()\n", + " return output\n", + "\n", + "\n", + "class DoubleConv(nn.Module):\n", + " \"\"\"(conv => [BN] => ReLU) * 2\"\"\"\n", + "\n", + " def __init__(self, in_ch, out_ch):\n", + " super(DoubleConv, self).__init__()\n", + " self.double_conv = nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),\n", + " nn.BatchNorm2d(out_ch),\n", + " nn.ReLU(inplace=True),\n", + " nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),\n", + " nn.BatchNorm2d(out_ch),\n", + " nn.ReLU(inplace=True))\n", + "\n", + " def forward(self, x):\n", + " x = self.double_conv(x)\n", + " return x\n", + "\n", + "\n", + "class InConv(nn.Module):\n", + "\n", + " def __init__(self, in_ch, out_ch):\n", + " super(InConv, self).__init__()\n", + " self.conv = DoubleConv(in_ch, out_ch)\n", + "\n", + " def forward(self, x):\n", + " x = self.conv(x)\n", + " return x\n", + "\n", + "\n", + "class DownLayer(nn.Module):\n", + "\n", + " def __init__(self, in_ch, out_ch):\n", + " super(DownLayer, self).__init__()\n", + " self.maxpool_conv = nn.Sequential(\n", + " nn.MaxPool2d(kernel_size=2),\n", + " DoubleConv(in_ch, out_ch)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = 
self.maxpool_conv(x)\n",
+    "        return x\n",
+    "\n",
+    "\n",
+    "class UpLayer(nn.Module):\n",
+    "\n",
+    "    def __init__(self, in_ch, out_ch, bilinear=True):\n",
+    "        super(UpLayer, self).__init__()\n",
+    "        if bilinear:\n",
+    "            self.up = nn.Upsample(scale_factor=2, mode='bilinear',\n",
+    "                                  align_corners=True)\n",
+    "        else:\n",
+    "            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)\n",
+    "        self.conv = DoubleConv(in_ch, out_ch)\n",
+    "\n",
+    "    def forward(self, x1, x2):\n",
+    "        x1 = self.up(x1)\n",
+    "        diffY = x2.size()[2] - x1.size()[2]\n",
+    "        diffX = x2.size()[3] - x1.size()[3]\n",
+    "        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY -\n",
+    "                        diffY // 2))\n",
+    "        x = torch.cat([x2, x1], dim=1)\n",
+    "        x = self.conv(x)\n",
+    "        return x\n",
+    "\n",
+    "\n",
+    "class OutConv(nn.Module):\n",
+    "\n",
+    "    def __init__(self, in_ch, out_ch):\n",
+    "        super(OutConv, self).__init__()\n",
+    "        self.conv = nn.Conv2d(in_ch, out_ch, 1)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        x = self.conv(x)\n",
+    "        return x\n",
+    "\n",
+    "class DocREModel(nn.Module):\n",
+    "    def __init__(self, config, cfg, model, emb_size=768, block_size=64, num_labels=-1):\n",
+    "        super().__init__()\n",
+    "        self.config = config\n",
+    "        self.bert_model = model\n",
+    "        self.hidden_size = config.hidden_size\n",
+    "        self.loss_fnt = balanced_loss()  # 使用下文“损失函数”一节中定义的多标签损失\n",
+    "\n",
+    "        self.head_extractor = nn.Linear(1 * config.hidden_size + cfg.unet_out_dim, emb_size)\n",
+    "        self.tail_extractor = nn.Linear(1 * config.hidden_size + cfg.unet_out_dim, emb_size)\n",
+    "        # self.head_extractor = nn.Linear(1 * config.hidden_size , emb_size)\n",
+    "        # self.tail_extractor = nn.Linear(1 * config.hidden_size , emb_size)\n",
+    "        self.bilinear = nn.Linear(emb_size * block_size, config.num_labels)\n",
+    "\n",
+    "        self.emb_size = emb_size\n",
+    "        self.block_size = block_size\n",
+    "        self.num_labels = num_labels\n",
+    "\n",
+    "        self.bertdrop = nn.Dropout(0.6)\n",
+    "        self.unet_in_dim = cfg.unet_in_dim\n",
+    "        self.unet_out_dim = cfg.unet_out_dim\n",
+    "        self.liner = nn.Linear(config.hidden_size, cfg.unet_in_dim)\n",
+    "        self.min_height = cfg.max_height\n",
+    "        self.channel_type = cfg.channel_type\n",
+    "        self.segmentation_net = AttentionUNet(input_channels=cfg.unet_in_dim,\n",
+    "                                              class_number=cfg.unet_out_dim,\n",
+    "                                              down_channel=cfg.down_dim)\n",
+    "\n",
+    "    def encode(self, input_ids, attention_mask, entity_pos):\n",
+    "        config = self.config\n",
+    "        if config.transformer_type == \"bert\":\n",
+    "            start_tokens = [config.cls_token_id]\n",
+    "            end_tokens = [config.sep_token_id]\n",
+    "        elif config.transformer_type == \"roberta\":\n",
+    "            start_tokens = [config.cls_token_id]\n",
+    "            end_tokens = [config.sep_token_id, config.sep_token_id]\n",
+    "        sequence_output, attention = process_long_input(self.bert_model, input_ids, attention_mask, start_tokens, end_tokens)\n",
+    "        return sequence_output, attention\n",
+    "\n",
+    "    def get_hrt(self, sequence_output, attention, entity_pos, hts):\n",
+    "        offset = 1 if self.config.transformer_type in [\"bert\", \"roberta\"] else 0\n",
+    "        bs, h, _, c = attention.size()\n",
+    "        # ne = max([len(x) for x in entity_pos]) # 本次bs中的最大实体数\n",
+    "\n",
+    "        hss, tss, rss = [], [], []\n",
+    "        entity_es = []\n",
+    "        entity_as = []\n",
+    "        for i in range(len(entity_pos)):\n",
+    "            entity_embs, entity_atts = [], []\n",
+    "            for entity_num, e in enumerate(entity_pos[i]):\n",
+    "                if len(e) > 1:\n",
+    "                    e_emb, e_att = [], []\n",
+    "                    for start, end in e:\n",
+    "                        if start + offset < c:\n",
+    "                            # In case the entity mention is 
truncated due to limited max seq length.\n", + " e_emb.append(sequence_output[i, start + offset])\n", + " e_att.append(attention[i, :, start + offset])\n", + " if len(e_emb) > 0:\n", + " e_emb = torch.logsumexp(torch.stack(e_emb, dim=0), dim=0)\n", + " e_att = torch.stack(e_att, dim=0).mean(0)\n", + " else:\n", + " e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)\n", + " e_att = torch.zeros(h, c).to(attention)\n", + " else:\n", + " start, end = e[0]\n", + " if start + offset < c:\n", + " e_emb = sequence_output[i, start + offset]\n", + " e_att = attention[i, :, start + offset]\n", + " else:\n", + " e_emb = torch.zeros(self.config.hidden_size).to(sequence_output)\n", + " e_att = torch.zeros(h, c).to(attention)\n", + " entity_embs.append(e_emb)\n", + " entity_atts.append(e_att)\n", + " for _ in range(self.min_height-entity_num-1):\n", + " entity_atts.append(e_att)\n", + "\n", + " entity_embs = torch.stack(entity_embs, dim=0) # [n_e, d]\n", + " entity_atts = torch.stack(entity_atts, dim=0) # [n_e, h, seq_len]\n", + "\n", + "\n", + " entity_es.append(entity_embs)\n", + " entity_as.append(entity_atts)\n", + " ht_i = torch.LongTensor(hts[i]).to(sequence_output.device)\n", + " hs = torch.index_select(entity_embs, 0, ht_i[:, 0])\n", + " ts = torch.index_select(entity_embs, 0, ht_i[:, 1])\n", + "\n", + " hss.append(hs)\n", + " tss.append(ts)\n", + " hss = torch.cat(hss, dim=0)\n", + " tss = torch.cat(tss, dim=0)\n", + " return hss, tss, entity_es, entity_as\n", + "\n", + " def get_mask(self, ents, bs, ne, run_device):\n", + " ent_mask = torch.zeros(bs, ne, device=run_device)\n", + " rel_mask = torch.zeros(bs, ne, ne, device=run_device)\n", + " for _b in range(bs):\n", + " ent_mask[_b, :len(ents[_b])] = 1\n", + " rel_mask[_b, :len(ents[_b]), :len(ents[_b])] = 1\n", + " return ent_mask, rel_mask\n", + "\n", + "\n", + " def get_ht(self, rel_enco, hts):\n", + " htss = []\n", + " for i in range(len(hts)):\n", + " ht_index = hts[i]\n", + " for (h_index, t_index) in ht_index:\n", + " htss.append(rel_enco[i,h_index,t_index])\n", + " htss = torch.stack(htss,dim=0)\n", + " return htss\n", + "\n", + " def get_channel_map(self, sequence_output, entity_as):\n", + " # sequence_output = sequence_output.to('cpu')\n", + " # attention = attention.to('cpu')\n", + " bs,_,d = sequence_output.size()\n", + " # ne = max([len(x) for x in entity_as]) # 本次bs中的最大实体数\n", + " ne = self.min_height\n", + "\n", + " index_pair = []\n", + " for i in range(ne):\n", + " tmp = torch.cat((torch.ones((ne, 1), dtype=int) * i, torch.arange(0, ne).unsqueeze(1)), dim=-1)\n", + " index_pair.append(tmp)\n", + " index_pair = torch.stack(index_pair, dim=0).reshape(-1, 2).to(sequence_output.device)\n", + " map_rss = []\n", + " for b in range(bs):\n", + " entity_atts = entity_as[b]\n", + " h_att = torch.index_select(entity_atts, 0, index_pair[:, 0])\n", + " t_att = torch.index_select(entity_atts, 0, index_pair[:, 1])\n", + " ht_att = (h_att * t_att).mean(1)\n", + " ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)\n", + " rs = contract(\"ld,rl->rd\", sequence_output[b], ht_att)\n", + " map_rss.append(rs)\n", + " map_rss = torch.cat(map_rss, dim=0).reshape(bs, ne, ne, d)\n", + " return map_rss\n", + "\n", + " def forward(self,\n", + " input_ids=None,\n", + " attention_mask=None,\n", + " labels=None,\n", + " entity_pos=None,\n", + " hts=None,\n", + " instance_mask=None,\n", + " ):\n", + "\n", + " sequence_output, attention = self.encode(input_ids, attention_mask,entity_pos)\n", + "\n", + " bs, sequen_len, d = 
sequence_output.shape\n",
+    "        run_device = sequence_output.device.index\n",
+    "        ne = max([len(x) for x in entity_pos]) # 本次bs中的最大实体数\n",
+    "        ent_mask, rel_mask = self.get_mask(entity_pos, bs, ne, run_device)\n",
+    "\n",
+    "        # get hs, ts and entity_embs >> entity_rs\n",
+    "        hs, ts, entity_embs, entity_as = self.get_hrt(sequence_output, attention, entity_pos, hts)\n",
+    "\n",
+    "        # 获得通道map的两种不同方法\n",
+    "        if self.channel_type == 'context-based':\n",
+    "            feature_map = self.get_channel_map(sequence_output, entity_as)\n",
+    "            # print('feature_map:', feature_map.shape)\n",
+    "            attn_input = self.liner(feature_map).permute(0, 3, 1, 2).contiguous()\n",
+    "\n",
+    "        elif self.channel_type == 'similarity-based':\n",
+    "            ent_encode = sequence_output.new_zeros(bs, self.min_height, d)\n",
+    "            for _b in range(bs):\n",
+    "                entity_emb = entity_embs[_b]\n",
+    "                entity_num = entity_emb.size(0)\n",
+    "                ent_encode[_b, :entity_num, :] = entity_emb\n",
+    "            # similar0 = ElementWiseMatrixAttention()(ent_encode, ent_encode).unsqueeze(-1)\n",
+    "            similar1 = DotProductMatrixAttention()(ent_encode, ent_encode).unsqueeze(-1)\n",
+    "            similar2 = CosineMatrixAttention()(ent_encode, ent_encode).unsqueeze(-1)\n",
+    "            similar3 = BilinearMatrixAttention(self.emb_size, self.emb_size).to(ent_encode.device)(ent_encode, ent_encode).unsqueeze(-1)\n",
+    "            attn_input = torch.cat([similar1, similar2, similar3], dim=-1).permute(0, 3, 1, 2).contiguous()\n",
+    "        else:\n",
+    "            raise Exception(\"channel_type must be specified correctly\")\n",
+    "\n",
+    "        attn_map = self.segmentation_net(attn_input)\n",
+    "        h_t = self.get_ht(attn_map, hts)\n",
+    "\n",
+    "        hs = torch.tanh(self.head_extractor(torch.cat([hs, h_t], dim=1)))\n",
+    "        ts = torch.tanh(self.tail_extractor(torch.cat([ts, h_t], dim=1)))\n",
+    "\n",
+    "        b1 = hs.view(-1, self.emb_size // self.block_size, self.block_size)\n",
+    "        b2 = ts.view(-1, self.emb_size // self.block_size, self.block_size)\n",
+    "        bl = (b1.unsqueeze(3) * b2.unsqueeze(2)).view(-1, self.emb_size * self.block_size)\n",
+    "        logits = self.bilinear(bl)\n",
+    "\n",
+    "        output = (self.loss_fnt.get_label(logits, num_labels=self.num_labels))\n",
+    "        if labels is not None:\n",
+    "            labels = [torch.tensor(label) for label in labels]\n",
+    "            labels = torch.cat(labels, dim=0).to(logits)\n",
+    "            loss = self.loss_fnt(logits.float(), labels.float())\n",
+    "            output = (loss.to(sequence_output), output)\n",
+    "        return output"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## 损失函数"
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "source": [
+    "def multilabel_categorical_crossentropy(y_true, y_pred):\n",
+    "    \"\"\"多标签分类的交叉熵\n",
+    "    说明：y_true和y_pred的shape一致，y_true的元素非0即1，\n",
+    "         1表示对应的类为目标类，0表示对应的类为非目标类。\n",
+    "    警告：请保证y_pred的值域是全体实数，换言之一般情况下y_pred\n",
+    "         不用加激活函数，尤其是不能加sigmoid或者softmax！预测\n",
+    "         阶段则输出y_pred大于0的类。如有疑问，请仔细阅读并理解\n",
+    "         本文。\n",
+    "    \"\"\"\n",
+    "    y_pred = (1 - 2 * y_true) * y_pred\n",
+    "    y_pred_neg = y_pred - y_true * 1e30\n",
+    "    y_pred_pos = y_pred - (1 - y_true) * 1e30\n",
+    "    zeros = torch.zeros_like(y_pred[..., :1])\n",
+    "    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)\n",
+    "    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)\n",
+    "    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)\n",
+    "    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)\n",
+    "    return neg_loss + pos_loss\n",
+    "\n",
+    "\n",
+    "class balanced_loss(nn.Module):\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "\n",
+    "    def 
forward(self, logits, labels):\n", + "\n", + " loss = multilabel_categorical_crossentropy(labels,logits)\n", + " loss = loss.mean()\n", + " return loss\n", + "\n", + " def get_label(self, logits, num_labels=-1):\n", + " th_logit = torch.zeros_like(logits[..., :1])\n", + " output = torch.zeros_like(logits).to(logits)\n", + " mask = (logits > th_logit)\n", + " if num_labels > 0:\n", + " top_v, _ = torch.topk(logits, num_labels, dim=1)\n", + " top_v = top_v[:, -1]\n", + " mask = (logits >= top_v.unsqueeze(1)) & mask\n", + " output[mask] = 1.0\n", + " output[:, 0] = (output[:,1:].sum(1) == 0.).to(logits)\n", + "\n", + " return output" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "## 输入预处理" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "def process_long_input(model, input_ids, attention_mask, start_tokens, end_tokens):\n", + " # Split the input to 2 overlapping chunks. Now BERT can encode inputs of which the length are up to 1024.\n", + " n, c = input_ids.size()\n", + " start_tokens = torch.tensor(start_tokens).to(input_ids)\n", + " end_tokens = torch.tensor(end_tokens).to(input_ids)\n", + " len_start = start_tokens.size(0)\n", + " len_end = end_tokens.size(0)\n", + " if c <= 512:\n", + " output = model(\n", + " input_ids=input_ids,\n", + " attention_mask=attention_mask,\n", + " output_attentions=True,\n", + " )\n", + " sequence_output = output[0]\n", + " attention = output[-1][-1]\n", + " else:\n", + " new_input_ids, new_attention_mask, num_seg = [], [], []\n", + " seq_len = attention_mask.sum(1).cpu().numpy().astype(np.int32).tolist()\n", + " for i, l_i in enumerate(seq_len):\n", + " if l_i <= 512:\n", + " new_input_ids.append(input_ids[i, :512])\n", + " new_attention_mask.append(attention_mask[i, :512])\n", + " num_seg.append(1)\n", + " else:\n", + " input_ids1 = torch.cat([input_ids[i, :512 - len_end], end_tokens], dim=-1)\n", + " input_ids2 = torch.cat([start_tokens, input_ids[i, (l_i - 512 + len_start): l_i]], dim=-1)\n", + " attention_mask1 = attention_mask[i, :512]\n", + " attention_mask2 = attention_mask[i, (l_i - 512): l_i]\n", + " new_input_ids.extend([input_ids1, input_ids2])\n", + " new_attention_mask.extend([attention_mask1, attention_mask2])\n", + " num_seg.append(2)\n", + " input_ids = torch.stack(new_input_ids, dim=0)\n", + " attention_mask = torch.stack(new_attention_mask, dim=0)\n", + " output = model(\n", + " input_ids=input_ids,\n", + " attention_mask=attention_mask,\n", + " output_attentions=True,\n", + " )\n", + " sequence_output = output[0]\n", + " attention = output[-1][-1]\n", + " i = 0\n", + " new_output, new_attention = [], []\n", + " for (n_s, l_i) in zip(num_seg, seq_len):\n", + " if n_s == 1:\n", + " output = F.pad(sequence_output[i], (0, 0, 0, c - 512))\n", + " att = F.pad(attention[i], (0, c - 512, 0, c - 512))\n", + " new_output.append(output)\n", + " new_attention.append(att)\n", + " elif n_s == 2:\n", + " output1 = sequence_output[i][:512 - len_end]\n", + " mask1 = attention_mask[i][:512 - len_end]\n", + " att1 = attention[i][:, :512 - len_end, :512 - len_end]\n", + " output1 = F.pad(output1, (0, 0, 0, c - 512 + len_end))\n", + " mask1 = F.pad(mask1, (0, c - 512 + len_end))\n", + " att1 = F.pad(att1, (0, c - 512 + len_end, 0, c - 512 + len_end))\n", + "\n", + " output2 = sequence_output[i + 1][len_start:]\n", + " mask2 = attention_mask[i + 1][len_start:]\n", + " att2 = attention[i + 1][:, len_start:, len_start:]\n", + " output2 = F.pad(output2, (0, 0, l_i - 512 + 
len_start, c - l_i))\n",
+    "                mask2 = F.pad(mask2, (l_i - 512 + len_start, c - l_i))\n",
+    "                att2 = F.pad(att2, [l_i - 512 + len_start, c - l_i, l_i - 512 + len_start, c - l_i])\n",
+    "                mask = mask1 + mask2 + 1e-10\n",
+    "                output = (output1 + output2) / mask.unsqueeze(-1)\n",
+    "                att = (att1 + att2)\n",
+    "                att = att / (att.sum(-1, keepdim=True) + 1e-10)\n",
+    "                new_output.append(output)\n",
+    "                new_attention.append(att)\n",
+    "            i += n_s\n",
+    "        sequence_output = torch.stack(new_output, dim=0)\n",
+    "        attention = torch.stack(new_attention, dim=0)\n",
+    "    return sequence_output, attention"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## 辅助函数"
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "source": [
+    "def set_seed(cfg):\n",
+    "    random.seed(cfg.seed)\n",
+    "    np.random.seed(cfg.seed)\n",
+    "    torch.manual_seed(cfg.seed)\n",
+    "    if cfg.n_gpu > 0 and torch.cuda.is_available():\n",
+    "        torch.cuda.manual_seed_all(cfg.seed)\n",
+    "\n",
+    "def collate_fn(batch):\n",
+    "    max_len = max([len(f[\"input_ids\"]) for f in batch])\n",
+    "    input_ids = [f[\"input_ids\"] + [0] * (max_len - len(f[\"input_ids\"])) for f in batch]\n",
+    "    input_mask = [[1.0] * len(f[\"input_ids\"]) + [0.0] * (max_len - len(f[\"input_ids\"])) for f in batch]\n",
+    "    input_ids = torch.tensor(input_ids, dtype=torch.long)\n",
+    "    input_mask = torch.tensor(input_mask, dtype=torch.float)\n",
+    "    entity_pos = [f[\"entity_pos\"] for f in batch]\n",
+    "\n",
+    "    labels = [f[\"labels\"] for f in batch]\n",
+    "    hts = [f[\"hts\"] for f in batch]\n",
+    "    output = (input_ids, input_mask, labels, entity_pos, hts)\n",
+    "    return output\n",
+    "\n",
+    "def to_official(preds, features):\n",
+    "    h_idx, t_idx, title = [], [], []\n",
+    "\n",
+    "    for f in features:\n",
+    "        hts = f[\"hts\"]\n",
+    "        h_idx += [ht[0] for ht in hts]\n",
+    "        t_idx += [ht[1] for ht in hts]\n",
+    "        title += [f[\"title\"] for ht in hts]\n",
+    "\n",
+    "    res = []\n",
+    "\n",
+    "    for i in range(preds.shape[0]):\n",
+    "        pred = preds[i]\n",
+    "        pred = np.nonzero(pred)[0].tolist()\n",
+    "        for p in pred:\n",
+    "            if p != 0:\n",
+    "                res.append(\n",
+    "                    {\n",
+    "                        'title': title[i],\n",
+    "                        'h_idx': h_idx[i],\n",
+    "                        't_idx': t_idx[i],\n",
+    "                        'r': id2rel[p],\n",
+    "                    }\n",
+    "                )\n",
+    "    return res\n",
+    "\n",
+    "def report(cfg, model, features):\n",
+    "\n",
+    "    dataloader = DataLoader(features, batch_size=cfg.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False)\n",
+    "    preds = []\n",
+    "    for batch in dataloader:\n",
+    "        model.eval()\n",
+    "        inputs = {'input_ids': batch[0].to(cfg.device),\n",
+    "                  'attention_mask': batch[1].to(cfg.device),\n",
+    "                  'entity_pos': batch[3],\n",
+    "                  'hts': batch[4],\n",
+    "                  }\n",
+    "\n",
+    "        with torch.no_grad():\n",
+    "            pred, *_ = model(**inputs)\n",
+    "            pred = pred.cpu().numpy()\n",
+    "            pred[np.isnan(pred)] = 0\n",
+    "            preds.append(pred)\n",
+    "    preds = np.concatenate(preds, axis=0).astype(np.float32)\n",
+    "    preds = to_official(preds, features)\n",
+    "    return preds\n",
+    "\n",
+    "def logging(s, print_=True, log_=True):\n",
+    "    if print_:\n",
+    "        print(s)\n",
+    "    if log_:\n",
+    "        with open(cfg.log_dir, 'a+') as f_log:\n",
+    "            f_log.write(s + '\\n')"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## 模型训练实践"
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "### 模型参数配置"
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "source": [
+    "class Config(object):\n",
+    "    
adam_epsilon=1e-06\n", + " bert_lr=3e-05\n", + " channel_type='context-based'\n", + " config_name=''\n", + " data_dir='./data'\n", + " dataset='docred'\n", + " dev_file='dev.json'\n", + " down_dim=256\n", + " evaluation_steps=-1\n", + " gradient_accumulation_steps=2\n", + " learning_rate=0.0004\n", + " log_dir='./train_roberta.log'\n", + " max_grad_norm=1.0\n", + " max_height=42\n", + " max_seq_length=1024\n", + " model_name_or_path='roberta-base'\n", + " num_class=97\n", + " num_labels=4\n", + " num_train_epochs=30\n", + " save_path='./train_roberta.pt'\n", + " seed=111\n", + " test_batch_size=2\n", + " test_file='test.json'\n", + " tokenizer_name=''\n", + " train_batch_size=2\n", + " train_file='train_annotated.json'\n", + " train_from_saved_model=''\n", + " transformer_type='roberta'\n", + " unet_in_dim=3\n", + " unet_out_dim=256\n", + " warmup_ratio=0.06\n", + " load_path='./train_roberta.pt'\n", + " \n", + "cfg = Config()" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "### 模型训练" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "def finetune(features, optimizer, num_epoch, num_steps, model):\n", + " cur_model = model.module if hasattr(model, 'module') else model\n", + " if cfg.train_from_saved_model != '':\n", + " best_score = torch.load(cfg.train_from_saved_model)[\"best_f1\"]\n", + " epoch_delta = torch.load(cfg.train_from_saved_model)[\"epoch\"] + 1\n", + " else:\n", + " epoch_delta = 0\n", + " best_score = -1\n", + " train_dataloader = DataLoader(features, batch_size=cfg.train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True)\n", + " train_iterator = [epoch + epoch_delta for epoch in range(num_epoch)]\n", + " total_steps = int(len(train_dataloader) * num_epoch // cfg.gradient_accumulation_steps)\n", + " warmup_steps = int(total_steps * cfg.warmup_ratio)\n", + " scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)\n", + " print(\"Total steps: {}\".format(total_steps))\n", + " print(\"Warmup steps: {}\".format(warmup_steps))\n", + " global_step = 0\n", + " log_step = 100\n", + " total_loss = 0\n", + " \n", + "\n", + "\n", + " #scaler = GradScaler()\n", + " for epoch in train_iterator:\n", + " start_time = time.time()\n", + " optimizer.zero_grad()\n", + "\n", + " for step, batch in enumerate(train_dataloader):\n", + " model.train()\n", + "\n", + " inputs = {'input_ids': batch[0].to(cfg.device),\n", + " 'attention_mask': batch[1].to(cfg.device),\n", + " 'labels': batch[2],\n", + " 'entity_pos': batch[3],\n", + " 'hts': batch[4],\n", + " }\n", + " #with autocast():\n", + " outputs = model(**inputs)\n", + " loss = outputs[0] / cfg.gradient_accumulation_steps\n", + " total_loss += loss.item()\n", + " # scaler.scale(loss).backward()\n", + " \n", + "\n", + " loss.backward()\n", + "\n", + " if step % cfg.gradient_accumulation_steps == 0:\n", + " #scaler.unscale_(optimizer)\n", + " if cfg.max_grad_norm > 0:\n", + " # torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), cfg.max_grad_norm)\n", + " torch.nn.utils.clip_grad_norm_(cur_model.parameters(), cfg.max_grad_norm)\n", + " #scaler.step(optimizer)\n", + " #scaler.update()\n", + " #scheduler.step()\n", + " optimizer.step()\n", + " scheduler.step()\n", + " optimizer.zero_grad()\n", + " global_step += 1\n", + " num_steps += 1\n", + " if global_step % log_step == 0:\n", + " cur_loss = total_loss / log_step\n", + " elapsed = time.time() - start_time\n", + " 
logging(\n", + " '| epoch {:2d} | step {:4d} | min/b {:5.2f} | lr {} | train loss {:5.3f}'.format(\n", + " epoch, global_step, elapsed / 60, scheduler.get_last_lr(), cur_loss * 1000))\n", + " total_loss = 0\n", + " start_time = time.time()\n", + "\n", + " if (step + 1) == len(train_dataloader) - 1 or (cfg.evaluation_steps > 0 and num_steps % cfg.evaluation_steps == 0 and step % cfg.gradient_accumulation_steps == 0):\n", + " # if step ==0:\n", + " logging('-' * 89)\n", + " eval_start_time = time.time()\n", + " dev_score, dev_output = evaluate(cfg, model, dev_features, tag=\"dev\")\n", + "\n", + " logging(\n", + " '| epoch {:3d} | time: {:5.2f}s | dev_result:{}'.format(epoch, time.time() - eval_start_time,\n", + " dev_output))\n", + " logging('-' * 89)\n", + " if dev_score > best_score:\n", + " best_score = dev_score\n", + " logging(\n", + " '| epoch {:3d} | best_f1:{}'.format(epoch, best_score))\n", + " pred = report(cfg, model, test_features)\n", + " with open(\"./submit_result/best_result.json\", \"w\") as fh:\n", + " json.dump(pred, fh)\n", + " if cfg.save_path != \"\":\n", + " torch.save({\n", + " 'epoch': epoch,\n", + " 'checkpoint': cur_model.state_dict(),\n", + " 'best_f1': best_score,\n", + " 'optimizer': optimizer.state_dict()\n", + " }, cfg.save_path\n", + " , _use_new_zipfile_serialization=False)\n", + " return num_steps\n", + "\n", + "def train(cfg, model, train_features, dev_features, test_features):\n", + " cur_model = model.module if hasattr(model, 'module') else model\n", + " extract_layer = [\"extractor\", \"bilinear\"]\n", + " bert_layer = ['bert_model']\n", + " optimizer_grouped_parameters = [\n", + " {\"params\": [p for n, p in cur_model.named_parameters() if any(nd in n for nd in bert_layer)], \"lr\": cfg.bert_lr},\n", + " {\"params\": [p for n, p in cur_model.named_parameters() if any(nd in n for nd in extract_layer)], \"lr\": 1e-4},\n", + " {\"params\": [p for n, p in cur_model.named_parameters() if not any(nd in n for nd in extract_layer + bert_layer)]},\n", + " ]\n", + "\n", + " optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.learning_rate, eps=cfg.adam_epsilon)\n", + " if cfg.train_from_saved_model != '':\n", + " optimizer.load_state_dict(torch.load(cfg.train_from_saved_model)[\"optimizer\"])\n", + " print(\"load saved optimizer from {}.\".format(cfg.train_from_saved_model))\n", + " \n", + "\n", + " num_steps = 0\n", + " set_seed(cfg)\n", + " model.zero_grad()\n", + " finetune(train_features, optimizer, cfg.num_train_epochs, num_steps, model)" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", + "cfg.n_gpu = torch.cuda.device_count()\n", + "cfg.device = device\n", + "\n", + "config = AutoConfig.from_pretrained(\n", + " cfg.config_name if cfg.config_name else cfg.model_name_or_path,\n", + " num_labels=cfg.num_class,\n", + ")\n", + "tokenizer = AutoTokenizer.from_pretrained(\n", + " cfg.tokenizer_name if cfg.tokenizer_name else cfg.model_name_or_path,\n", + ")\n", + "\n", + "Dataset = ReadDataset(cfg.dataset, tokenizer, cfg.max_seq_length)\n", + "\n", + "train_file = os.path.join(cfg.data_dir, cfg.train_file)\n", + "dev_file = os.path.join(cfg.data_dir, cfg.dev_file)\n", + "test_file = os.path.join(cfg.data_dir, cfg.test_file)\n", + "train_features = Dataset.read(train_file)\n", + "dev_features = Dataset.read(dev_file)\n", + "test_features = Dataset.read(test_file)\n", + "\n", + "model = AutoModel.from_pretrained(\n", 
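+    "    # 从 HuggingFace Hub 下载并加载预训练编码器权重（默认 roberta-base，可通过 Config.model_name_or_path 修改）\n",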
+ " cfg.model_name_or_path,\n", + " from_tf=bool(\".ckpt\" in cfg.model_name_or_path),\n", + " config=config,\n", + ")\n", + "\n", + "config.cls_token_id = tokenizer.cls_token_id\n", + "config.sep_token_id = tokenizer.sep_token_id\n", + "config.transformer_type = cfg.transformer_type\n", + "\n", + "set_seed(cfg)\n", + "model = DocREModel(config, cfg, model, num_labels=cfg.num_labels)\n", + "if cfg.train_from_saved_model != '':\n", + " model.load_state_dict(torch.load(cfg.train_from_saved_model)[\"checkpoint\"])\n", + " print(\"load saved model from {}.\".format(cfg.train_from_saved_model))\n", + "\n", + "if torch.cuda.device_count() > 1:\n", + " print(\"Let's use\", torch.cuda.device_count(), \"GPUs!\")\n", + " model = torch.nn.DataParallel(model, device_ids = list(range(torch.cuda.device_count())))\n", + "model.to(device)\n", + "\n", + "\n", + "# Training\n", + "train(cfg, model, train_features, dev_features, test_features)" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "### 模型预测输出" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "\n", + "model.load_state_dict(torch.load(cfg.load_path)['checkpoint'])\n", + "T_features = test_features # Testing on the test set\n", + "#T_score, T_output = evaluate(args, model, T_features, tag=\"test\")\n", + "pred = report(cfg, model, T_features)\n", + "print(pred)\n", + "with open(\"./result.json\", \"w\") as fh:\n", + " json.dump(pred, fh)" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "本demo不包括调参部分,有兴趣的同学可以自行前往[DeepKE](https://github.com/zjunlp/DeepKE/tree/master)仓库,下载使用更多的模型 :)" + ], + "metadata": {} + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file