{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Low-Resource Relation Extraction in Practice\n",
"\n",
"> Tutorial author: Zhoubo Li (黎洲波, zhoubo.li@zju.edu.cn)\n",
"\n",
"Relation Extraction (RE) is a key task in information extraction. Typically, the input is a text together with several named entities, and the output is the relation between a pair of those entities.\n",
"\n",
"Low-resource (few-shot) relation extraction in DeepKE follows the \"Pre-train, Prompt, Predict\" paradigm. Prompt parameters are introduced into the attention and cross-attention parts of the BERT architecture and are then fine-tuned. This approach performs well in low-resource settings. The prompt-tuning principle behind this method is shown in the figure below:\n",
"\n",
"\n",
"\n",
"This walkthrough aims to help readers quickly learn how to build a low-resource relation extraction model.\n",
"\n",
"## Dataset\n",
"\n",
"Datasets commonly used for low-resource relation extraction include RETACRED, SEMEVAL, TACREV and WIKI80. This walkthrough uses [SEMEVAL](https://semeval2.fbk.eu/semeval2.php?location=tasks#T11), which comes from Task 8 of the 2010 International Workshop on Semantic Evaluation (SemEval-2010), \"Multi-Way Classification of Semantic Relations Between Pairs of Nominals\". The data folder `./data/` downloaded from the official website has the following structure:\n",
"\n",
"```\n",
".\n",
"├── rel2id.json # mapping from relation labels to IDs\n",
"├── temp.txt    # relation label processing\n",
"├── test.txt    # test set\n",
"├── train.txt   # training set\n",
"└── val.txt     # validation set\n",
"```\n",
"\n",
"The SEMEVAL dataset describes the format of the data files as follows:\n",
"\n",
"```\n",
"Data Format:\n",
"{\n",
"    'token': [tokens in a sentence],\n",
"    \"h\": {\n",
"        \"name\": mention_name,\n",
"        \"pos\": [position of mention in a sentence]\n",
"    },\n",
"    \"t\": {\n",
"        \"name\": mention_name,\n",
"        \"pos\": [position of mention in a sentence]\n",
"    },\n",
"    \"relation\": relation\n",
"}\n",
"```\n",
"The dataset contains 9+1 relation types (nine semantic relations plus `Other`). The proportion of each class is shown in the figure below:\n",
"\n",
"\n",
"\n",
"## How KnowPrompt Works\n",
"Here we use a prompt method that can semantically parse relation labels, so we call this relation extraction approach Knowledge-aware Prompt-tuning (KnowPrompt). The figure below compares the model architectures of fine-tuning (figure a), prompt-tuning (figure b) and KnowPrompt (figure c). The answer words in the prompt are virtual answer words, that is, the learnable `[class1]`, `[class2]`, ... tokens that the code below adds to the tokenizer.\n",
"\n",
"\n",
"\n",
"Now let's start the low-resource relation extraction walkthrough!"
],
"metadata": {}
},
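{
"cell_type": "markdown",
"source": [
"The cell below is a minimal, hypothetical sketch of reading one record in the format above. The sample record is made up for illustration and is not taken from the SEMEVAL files. The sketch assumes that `pos` is a half-open `[start, end)` token span, so check this against the real data. It uses `ast.literal_eval` from the standard library, which parses a Python dict literal with either quote style without executing code."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import ast\n",
"\n",
"# Hypothetical record in the format described above (not taken from the real data files).\n",
"sample = {'token': ['the', 'apples', 'are', 'in', 'the', 'basket'],\n",
"          'h': {'name': 'apples', 'pos': [1, 2]},\n",
"          't': {'name': 'basket', 'pos': [5, 6]},\n",
"          'relation': 'Content-Container(e1,e2)'}\n",
"line = repr(sample)  # one line of a data file, as a string\n",
"\n",
"# literal_eval parses the dict literal without executing arbitrary code\n",
"record = ast.literal_eval(line)\n",
"\n",
"def mention_tokens(record, key):\n",
"    # Assumption: 'pos' is a half-open [start, end) span over record['token']\n",
"    start, end = record[key]['pos']\n",
"    return record['token'][start:end]\n",
"\n",
"head = mention_tokens(record, 'h')\n",
"tail = mention_tokens(record, 't')\n",
"print(head, tail, record['relation'])"
],
"outputs": [],
"metadata": {}
},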
{
"cell_type": "markdown",
"source": [
"## Environment\n",
"\n",
"Use Python 3 with the following package versions:"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"!pip install numpy==1.20.3\n",
"!pip install tokenizers==0.10.3\n",
"!pip install torch==1.8.0\n",
"!pip install regex==2021.4.4\n",
"!pip install transformers==4.7.0\n",
"!pip install tqdm==4.49.0\n",
"!pip install activations==0.1.0\n",
"!pip install dataclasses==0.6\n",
"!pip install file_utils==0.0.1\n",
"!pip install flax==0.3.4\n",
"!pip install utils==1.0.1"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Importing Modules"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import os\n",
"import json\n",
"import csv\n",
"import time\n",
"import pickle\n",
"import logging\n",
"import shutil\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"from functools import partial\n",
"from collections import Counter\n",
"from collections import OrderedDict\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"from transformers import AutoConfig, AutoModel, AutoTokenizer\n",
"from transformers.modeling_utils import PreTrainedModel\n",
"from transformers.optimization import AdamW, get_linear_schedule_with_warmup\n",
"from allennlp.modules.matrix_attention import DotProductMatrixAttention, CosineMatrixAttention, BilinearMatrixAttention\n",
"from transformers import BertTokenizer, BertForMaskedLM"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## Configuration"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"class Config(object):\n",
"    accelerator=None\n",
"    accumulate_grad_batches='1'\n",
"    amp_backend='native'\n",
"    amp_level='O2'\n",
"    auto_lr_find=False\n",
"    auto_scale_batch_size=False\n",
"    auto_select_gpus=False\n",
"    batch_size=16\n",
"    benchmark=False\n",
"    check_val_every_n_epoch='3'\n",
"    checkpoint_callback=True\n",
"    data_class='REDataset'\n",
"    data_dir='data/k-shot/8-1'\n",
"    default_root_dir=None\n",
"    deterministic=False\n",
"    devices=None\n",
"    distributed_backend=None\n",
"    fast_dev_run=False\n",
"    flush_logs_every_n_steps=100\n",
"    gpus=None\n",
"    gradient_accumulation_steps=1\n",
"    gradient_clip_algorithm='norm'\n",
"    gradient_clip_val=0.0\n",
"    ipus=None\n",
"    limit_predict_batches=1.0\n",
"    limit_test_batches=1.0\n",
"    limit_train_batches=1.0\n",
"    limit_val_batches=1.0\n",
"    litmodel_class='BertLitModel'\n",
"    load_checkpoint=None\n",
"    log_dir=''\n",
"    log_every_n_steps=50\n",
"    log_gpu_memory=None\n",
"    logger=True\n",
"    lr=3e-05\n",
"    lr_2=3e-05\n",
"    max_epochs='30'\n",
"    max_seq_length=256\n",
"    max_steps=None\n",
"    max_time=None\n",
"    min_epochs=None\n",
"    min_steps=None\n",
"    model_class='BertForMaskedLM'\n",
"    model_name_or_path='bert-large-uncased'\n",
"    move_metrics_to_cpu=False\n",
"    multiple_trainloader_mode='max_size_cycle'\n",
"    num_nodes=1\n",
"    num_processes=1\n",
"    num_sanity_val_steps=2\n",
"    num_train_epochs=30\n",
"    num_workers=8\n",
"    optimizer='AdamW'\n",
"    overfit_batches=0.0\n",
"    plugins=None\n",
"    precision=32\n",
"    prepare_data_per_node=True\n",
"    process_position=0\n",
"    profiler=None\n",
"    progress_bar_refresh_rate=None\n",
"    ptune_k=7\n",
"    reload_dataloaders_every_epoch=False\n",
"    reload_dataloaders_every_n_epochs=0\n",
"    replace_sampler_ddp=True\n",
"    resume_from_checkpoint=None\n",
"    save_path=''\n",
"    seed=666\n",
"    stochastic_weight_avg=False\n",
"    sync_batchnorm=False\n",
"    t_lambda=0.001\n",
"    task_name='wiki80'\n",
"    terminate_on_nan=False\n",
"    tpu_cores=None\n",
"    track_grad_norm=-1\n",
"    train_from_saved_model=''\n",
"    truncated_bptt_steps=None\n",
"    two_steps=False\n",
"    use_prompt=True\n",
"    val_check_interval=1.0\n",
"    wandb=False\n",
"    weight_decay=0.01\n",
"    weights_save_path=None\n",
"    weights_summary='top'\n",
"\n",
"cfg = Config()"
],
"outputs": [],
"metadata": {}
},
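{
"cell_type": "markdown",
"source": [
"The `Config` class above stores its settings as class attributes. `vars()` only reads an object's instance `__dict__`, so `vars(cfg)` returns an empty dict for such an object. This matters because `vars(cfg)` is what the original data and model classes call. The small check below illustrates this with a hypothetical `DemoConfig`. It also shows one possible workaround, a hypothetical helper `cfg_to_dict` that collects public attributes through `dir()`."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"class DemoConfig(object):\n",
"    # hypothetical stand-in for the Config class above: settings are class attributes\n",
"    batch_size = 16\n",
"    lr = 3e-05\n",
"\n",
"demo = DemoConfig()\n",
"# vars() only returns instance attributes, so the class-level settings are missed\n",
"print(vars(demo))  # {}\n",
"\n",
"def cfg_to_dict(cfg):\n",
"    # hypothetical helper: dir() also lists class-level attributes; skip private/dunder names\n",
"    return {k: getattr(cfg, k) for k in dir(cfg) if not k.startswith('_')}\n",
"\n",
"print(cfg_to_dict(demo))  # {'batch_size': 16, 'lr': 3e-05}"
],
"outputs": [],
"metadata": {}
},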
{
"cell_type": "markdown",
"source": [
"## Dataset Preprocessing"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import ast\n",
"\n",
"\n",
"class InputExampleWiki80(object):\n",
"    \"\"\"A single training/test example for span pair classification.\"\"\"\n",
"\n",
"    def __init__(self, guid, sentence, span1, span2, ner1, ner2, label):\n",
"        self.guid = guid\n",
"        self.sentence = sentence\n",
"        self.span1 = span1\n",
"        self.span2 = span2\n",
"        self.ner1 = ner1\n",
"        self.ner2 = ner2\n",
"        self.label = label\n",
"\n",
"class DataProcessor(object):\n",
"    \"\"\"Base class for data converters for sequence classification data sets.\"\"\"\n",
"\n",
"    def get_train_examples(self, data_dir):\n",
"        \"\"\"Gets a collection of `InputExample`s for the train set.\"\"\"\n",
"        raise NotImplementedError()\n",
"\n",
"    def get_dev_examples(self, data_dir):\n",
"        \"\"\"Gets a collection of `InputExample`s for the dev set.\"\"\"\n",
"        raise NotImplementedError()\n",
"\n",
"    def get_labels(self):\n",
"        \"\"\"Gets the list of labels for this data set.\"\"\"\n",
"        raise NotImplementedError()\n",
"\n",
"    @classmethod\n",
"    def _read_tsv(cls, input_file, quotechar=None):\n",
"        \"\"\"Reads a tab separated value file.\"\"\"\n",
"        with open(input_file, \"r\") as f:\n",
"            reader = csv.reader(f, delimiter=\"\\t\", quotechar=quotechar)\n",
"            lines = []\n",
"            for line in reader:\n",
"                lines.append(line)\n",
"            return lines\n",
"\n",
"class wiki80Processor(DataProcessor):\n",
"    \"\"\"Processor for data sets in the wiki80 format (used here for SEMEVAL).\"\"\"\n",
"    def __init__(self, data_path, use_prompt):\n",
"        super().__init__()\n",
"        self.data_dir = data_path\n",
"\n",
"    @classmethod\n",
"    def _read_json(cls, input_file):\n",
"        data = []\n",
"        with open(input_file, \"r\", encoding='utf-8') as reader:\n",
"            all_lines = reader.readlines()\n",
"            for line in all_lines:\n",
"                if not line.strip():\n",
"                    continue  # skip blank lines\n",
"                # each line is a dict literal; literal_eval parses it without executing code (unlike eval)\n",
"                ins = ast.literal_eval(line.strip())\n",
"                data.append(ins)\n",
"        return data\n",
"\n",
"    def get_train_examples(self, data_dir):\n",
"        \"\"\"See base class.\"\"\"\n",
"        return self._create_examples(\n",
"            self._read_json(os.path.join(data_dir, \"train.txt\")), \"train\")\n",
"\n",
"    def get_dev_examples(self, data_dir):\n",
"        \"\"\"See base class.\"\"\"\n",
"        return self._create_examples(\n",
"            self._read_json(os.path.join(data_dir, \"val.txt\")), \"dev\")\n",
"\n",
"    def get_test_examples(self, data_dir):\n",
"        \"\"\"See base class.\"\"\"\n",
"        return self._create_examples(\n",
"            self._read_json(os.path.join(data_dir, \"test.txt\")), \"test\")\n",
"\n",
"    def get_labels(self, negative_label=\"no_relation\"):\n",
"        \"\"\"See base class.\"\"\"\n",
"        data_dir = self.data_dir\n",
"        # if 'k-shot' in self.data_dir:\n",
"        #     data_dir = os.path.abspath(os.path.join(self.data_dir, \"../..\"))\n",
"        # else:\n",
"        #     data_dir = self.data_dir\n",
"        with open(os.path.join(data_dir, 'rel2id.json'), \"r\", encoding='utf-8') as reader:\n",
"            re2id = json.load(reader)\n",
"        return re2id\n",
"\n",
"\n",
"    def _create_examples(self, dataset, set_type):\n",
"        \"\"\"Creates examples for the training and dev sets.\"\"\"\n",
"        examples = []\n",
"        for example in dataset:\n",
"            sentence = example['token']\n",
"            examples.append(InputExampleWiki80(guid=None,\n",
"                                               sentence=sentence,\n",
"                                               # possible off-by-one: the 'pos' values are used as-is (no -1 applied)\n",
"                                               span1=(example['h']['pos'][0], example['h']['pos'][1]),\n",
"                                               span2=(example['t']['pos'][0], example['t']['pos'][1]),\n",
"                                               ner1=None,\n",
"                                               ner2=None,\n",
"                                               label=example['relation']))\n",
"        return examples\n",
"\n",
"class Config(dict):\n",
"    def __getattr__(self, name):\n",
"        return self.get(name)\n",
"\n",
"    def __setattr__(self, name, val):\n",
"        self[name] = val\n",
"\n",
"\n",
"BATCH_SIZE = 8\n",
"NUM_WORKERS = 8\n",
"\n",
"\n",
"class BaseDataModule(nn.Module):\n",
"    \"\"\"\n",
"    Base DataModule.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, cfg) -> None:\n",
"        super().__init__()\n",
"        # vars(cfg) only returns instance attributes, which is empty for the Config class above\n",
"        # (its settings are class attributes), so collect public attributes with dir() instead\n",
"        self.cfg = Config({k: getattr(cfg, k) for k in dir(cfg) if not k.startswith('_')}) if cfg is not None else Config()\n",
"        self.batch_size = self.cfg.get(\"batch_size\", BATCH_SIZE)\n",
"        self.num_workers = self.cfg.get(\"num_workers\", NUM_WORKERS)\n",
"\n",
"\n",
"    @staticmethod\n",
"    def add_to_argparse(parser):\n",
"        parser.add_argument(\n",
"            \"--batch_size\", type=int, default=BATCH_SIZE, help=\"Number of examples to operate on per forward step.\"\n",
"        )\n",
"        parser.add_argument(\n",
"            \"--num_workers\", type=int, default=NUM_WORKERS, help=\"Number of additional processes to load data.\"\n",
"        )\n",
"        parser.add_argument(\n",
"            \"--data_dir\", type=str, default=\"./dataset/dialogue\", help=\"Directory containing the dataset files.\"\n",
"        )\n",
"        return parser\n",
"\n",
"    def get_data_config(self):\n",
"        \"\"\"Return important settings of the dataset, which will be passed to instantiate models.\"\"\"\n",
"        return {\"num_labels\": self.num_labels}\n",
"\n",
"    def prepare_data(self):\n",
"        \"\"\"\n",
"        Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`).\n",
"        \"\"\"\n",
"        pass\n",
"\n",
"    def setup(self, stage=None):\n",
"        \"\"\"\n",
"        Split into train, val, test, and set dims.\n",
"        Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test.\n",
"        \"\"\"\n",
"        self.data_train = None\n",
"        self.data_val = None\n",
"        self.data_test = None\n",
"\n",
"    def train_dataloader(self):\n",
"        return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, drop_last=True)\n",
"\n",
"    def val_dataloader(self):\n",
"        return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, drop_last=True)\n",
"\n",
"    def test_dataloader(self):\n",
"        return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, drop_last=True)\n",
"\n",
"\n",
"class REDataset(BaseDataModule):\n",
"    def __init__(self, cfg) -> None:\n",
"        super().__init__(cfg)\n",
"        self.processor = wiki80Processor(self.cfg.data_dir, self.cfg.use_prompt)\n",
"        self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name_or_path)\n",
"\n",
"        use_gpt = \"gpt\" in cfg.model_name_or_path\n",
"\n",
"        rel2id = self.processor.get_labels()\n",
"        self.num_labels = len(rel2id)\n",
"\n",
"        # entity marker tokens and one virtual label token [classN] per relation\n",
"        entity_list = [\"[object_start]\", \"[object_end]\", \"[subject_start]\", \"[subject_end]\"]\n",
"        class_list = [f\"[class{i}]\" for i in range(1, self.num_labels+1)]\n",
"\n",
"        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': entity_list})\n",
"        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': class_list})\n",
"        if use_gpt:\n",
"            self.tokenizer.add_special_tokens({'cls_token': \"[CLS]\"})\n",
"            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})\n",
"        so_list = [\"[sub]\", \"[obj]\"]\n",
"        num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': so_list})\n",
"\n",
"        prompt_tokens = [f\"[T{i}]\" for i in range(1,6)]\n",
"        self.tokenizer.add_special_tokens({'additional_special_tokens': prompt_tokens})"
],
"outputs": [],
"metadata": {}
},
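{
"cell_type": "markdown",
"source": [
"`REDataset` registers the marker tokens `[subject_start]`, `[subject_end]`, `[object_start]` and `[object_end]`, but this section does not show the step that inserts them into a sentence. As a rough, hypothetical illustration only (this is not DeepKE's actual conversion code), the helper `add_entity_markers` below wraps the head span (`span1`) and the tail span (`span2`) of a token list with these markers. It assumes that the spans are half-open and non-overlapping and that the head is treated as the subject."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"def add_entity_markers(tokens, span1, span2):\n",
"    # Hypothetical helper, not DeepKE's conversion code.\n",
"    # span1 = head/subject, span2 = tail/object; both assumed half-open [start, end) and non-overlapping.\n",
"    starts = {span1[0]: '[subject_start]', span2[0]: '[object_start]'}\n",
"    ends = {span1[1]: '[subject_end]', span2[1]: '[object_end]'}\n",
"    out = []\n",
"    for i, tok in enumerate(tokens):\n",
"        if i in ends:  # close a span that ended just before token i\n",
"            out.append(ends[i])\n",
"        if i in starts:  # open a span that begins at token i\n",
"            out.append(starts[i])\n",
"        out.append(tok)\n",
"    if len(tokens) in ends:  # close a span that ends at the last token\n",
"        out.append(ends[len(tokens)])\n",
"    return out\n",
"\n",
"tokens = ['the', 'apples', 'are', 'in', 'the', 'basket']\n",
"print(add_entity_markers(tokens, (1, 2), (5, 6)))"
],
"outputs": [],
"metadata": {}
},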
{
"cell_type": "markdown",
"source": [
"## Metric Functions"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"def dialog_f1_eval(logits, labels):\n",
"    def getpred(result, T1=0.5, T2=0.4):\n",
"        # use thresholds to obtain preds; result = sigmoid scores of the logits\n",
"        # T2: if every score is below T2 the prediction is no relation; otherwise take the highest-scoring label\n",
"        ret = []\n",
"        for i in range(len(result)):\n",
"            r = []\n",
"            maxl, maxj = -1, -1\n",
"            for j in range(len(result[i])):\n",
"                if result[i][j] > T1:\n",
"                    r += [j]\n",
"                if result[i][j] > maxl:\n",
"                    maxl = result[i][j]\n",
"                    maxj = j\n",
"            if len(r) == 0:\n",
"                if maxl <= T2:\n",
"                    r = [36]\n",
"                else:\n",
"                    r += [maxj]\n",
"            ret.append(r)\n",
"        return ret\n",
"\n",
"    def geteval(devp, data):\n",
"        correct_sys, all_sys = 0, 0\n",
"        correct_gt = 0\n",
"\n",
"        for i in range(len(data)):\n",
"            # each sample is a list such as [1, 4, ..., 20], meaning labels 1, 4 and 20 are positive; [36] if there are none\n",
"            for id in data[i]:\n",
"                if id != 36:\n",
"                    # count the positive (1) entries in the gold labels\n",
"                    correct_gt += 1\n",
"                    if id in devp[i]:\n",
"                        # correct prediction\n",
"                        correct_sys += 1\n",
"\n",
"            for id in devp[i]:\n",
"                if id != 36:\n",
"                    all_sys += 1\n",
"\n",
"        precision = 1 if all_sys == 0 else correct_sys / all_sys\n",
"        recall = 0 if correct_gt == 0 else correct_sys / correct_gt\n",
"        f_1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0\n",
"        return f_1\n",
"\n",
"    logits = np.asarray(logits)\n",
"    logits = list(1 / (1 + np.exp(-logits)))\n",
"\n",
"    temp_labels = []\n",
"    for l in labels:\n",
"        t = []\n",
"        for i in range(36):\n",
"            if l[i] == 1:\n",
"                t += [i]\n",
"        if len(t) == 0:\n",
"            t = [36]\n",
"        temp_labels.append(t)\n",
"    assert (len(labels) == len(logits))\n",
"    labels = temp_labels\n",
"\n",
"    # search T2 over 0.00, 0.01, ..., 0.50 and keep the best F1\n",
"    bestT2 = bestf_1 = 0\n",
"    for T2 in range(51):\n",
"        devp = getpred(logits, T2=T2 / 100.)\n",
"        f_1 = geteval(devp, labels)\n",
"        if f_1 > bestf_1:\n",
"            bestf_1 = f_1\n",
"            bestT2 = T2 / 100.\n",
"\n",
"    return dict(f1=bestf_1, T2=bestT2)\n",
"\n",
"\n",
"\n",
"def f1_eval(logits, labels):\n",
"    def getpred(result, T1=0.5, T2=0.4):\n",
"        # use thresholds to obtain preds; result = sigmoid scores of the logits\n",
"        # T2: if every score is below T2 the prediction is no relation; otherwise take the highest-scoring label\n",
"        ret = []\n",
"        for i in range(len(result)):\n",
"            r = []\n",
"            maxl, maxj = -1, -1\n",
"            for j in range(len(result[i])):\n",
"                if result[i][j] > T1:\n",
"                    r += [j]\n",
"                if result[i][j] > maxl:\n",
"                    maxl = result[i][j]\n",
"                    maxj = j\n",
"            if len(r) == 0:\n",
"                if maxl <= T2:\n",
"                    r = [36]\n",
"                else:\n",
"                    r += [maxj]\n",
"            ret.append(r)\n",
"        return ret\n",
"\n",
"    def geteval(devp, data):\n",
"        correct_sys, all_sys = 0, 0\n",
"        correct_gt = 0\n",
"\n",
"        for i in range(len(data)):\n",
"            # each sample is a list such as [1, 4, ..., 20], meaning labels 1, 4 and 20 are positive; [36] if there are none\n",
"            for id in data[i]:\n",
"                if id != 36:\n",
"                    # count the positive (1) entries in the gold labels\n",
"                    correct_gt += 1\n",
"                    if id in devp[i]:\n",
"                        # correct prediction\n",
"                        correct_sys += 1\n",
"\n",
"            for id in devp[i]:\n",
"                if id != 36:\n",
"                    all_sys += 1\n",
"\n",
"        precision = 1 if all_sys == 0 else correct_sys / all_sys\n",
"        recall = 0 if correct_gt == 0 else correct_sys / correct_gt\n",
"        f_1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0\n",
"        return f_1\n",
"\n",
"    logits = np.asarray(logits)\n",
"    logits = list(1 / (1 + np.exp(-logits)))\n",
"\n",
"    temp_labels = []\n",
"    for l in labels:\n",
"        t = []\n",
"        for i in range(36):\n",
"            if l[i] == 1:\n",
"                t += [i]\n",
"        if len(t) == 0:\n",
"            t = [36]\n",
"        temp_labels.append(t)\n",
"    assert (len(labels) == len(logits))\n",
"    labels = temp_labels\n",
"\n",
"    bestT2 = bestf_1 = 0\n",
"    for T2 in range(51):\n",
"        devp = getpred(logits, T2=T2 / 100.)\n",
"        f_1 = geteval(devp, labels)\n",
"        if f_1 > bestf_1:\n",
"            bestf_1 = f_1\n",
"            bestT2 = T2 / 100.\n",
"\n",
"    # note: returns a tuple, unlike dialog_f1_eval, which returns a dict\n",
"    return bestf_1, bestT2\n",
"\n",
"\n",
"def f1_score(output, label, rel_num=42, na_num=13):\n",
"    correct_by_relation = Counter()\n",
"    guess_by_relation = Counter()\n",
"    gold_by_relation = Counter()\n",
"    output = np.argmax(output, axis=-1)\n",
"\n",
"    for i in range(len(output)):\n",
"        guess = output[i]\n",
"        gold = label[i]\n",
"\n",
"        # remap the no-relation index na_num to 0 (and shift the smaller indices up by one)\n",
"        if guess == na_num:\n",
"            guess = 0\n",
"        elif guess < na_num:\n",
"            guess += 1\n",
"\n",
"        if gold == na_num:\n",
"            gold = 0\n",
"        elif gold < na_num:\n",
"            gold += 1\n",
"\n",
"        if gold == 0 and guess == 0:\n",
"            continue\n",
"        if gold == 0 and guess != 0:\n",
"            guess_by_relation[guess] += 1\n",
"        if gold != 0 and guess == 0:\n",
"            gold_by_relation[gold] += 1\n",
"        if gold != 0 and guess != 0:\n",
"            guess_by_relation[guess] += 1\n",
"            gold_by_relation[gold] += 1\n",
"            if gold == guess:\n",
"                correct_by_relation[gold] += 1\n",
"\n",
"    f1_by_relation = Counter()\n",
"    recall_by_relation = Counter()\n",
"    prec_by_relation = Counter()\n",
"    for i in range(1, rel_num):\n",
"        recall = 0\n",
"        if gold_by_relation[i] > 0:\n",
"            recall = correct_by_relation[i] / gold_by_relation[i]\n",
"        precision = 0\n",
"        if guess_by_relation[i] > 0:\n",
"            precision = correct_by_relation[i] / guess_by_relation[i]\n",
"        if recall + precision > 0:\n",
"            f1_by_relation[i] = 2 * recall * precision / (recall + precision)\n",
"        recall_by_relation[i] = recall\n",
"        prec_by_relation[i] = precision\n",
"\n",
"    micro_f1 = 0\n",
"    if sum(guess_by_relation.values()) != 0 and sum(correct_by_relation.values()) != 0:\n",
"        recall = sum(correct_by_relation.values()) / sum(gold_by_relation.values())\n",
"        prec = sum(correct_by_relation.values()) / sum(guess_by_relation.values())\n",
"        micro_f1 = 2 * recall * prec / (recall + prec)\n",
"\n",
"    return dict(f1=micro_f1)"
],
"outputs": [],
"metadata": {}
},
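{
"cell_type": "markdown",
"source": [
"`f1_score` first remaps the no-relation index `na_num` to 0. It then computes a micro-averaged F1 that counts neither no-relation predictions nor no-relation gold labels. The hypothetical `micro_f1_ignore_na` below is a simplified restatement of that micro-F1 without the index remapping. It is meant only to make the counting explicit on a toy example."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"def micro_f1_ignore_na(preds, golds, na_label):\n",
"    # hypothetical simplified restatement of the micro-F1 computed by f1_score\n",
"    correct = guessed = gold = 0\n",
"    for p, g in zip(preds, golds):\n",
"        if p != na_label:\n",
"            guessed += 1  # a relation was predicted\n",
"        if g != na_label:\n",
"            gold += 1  # a relation is present in the gold labels\n",
"        if p == g and g != na_label:\n",
"            correct += 1  # correct, non-NA prediction\n",
"    if correct == 0:\n",
"        return 0.0\n",
"    precision = correct / guessed\n",
"    recall = correct / gold\n",
"    return 2 * precision * recall / (precision + recall)\n",
"\n",
"# precision = 1/3, recall = 1/2, so F1 = 0.4\n",
"print(micro_f1_ignore_na([1, 2, 0, 1], [1, 0, 0, 2], na_label=0))"
],
"outputs": [],
"metadata": {}
},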
{
"cell_type": "markdown",
"source": [
"## Building the Model"
],
"metadata": {}
},
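{
"cell_type": "markdown",
"source": [
"In the model below, `pvp` turns the masked-LM output into relation scores. It takes the vocabulary logits at the single `[MASK]` position of each sequence and keeps only the columns of the virtual label tokens `[class1]` through `[classN]`. The tensor sketch below shows that indexing on toy shapes. `select_label_logits`, the mask id 9 and the label ids are hypothetical values chosen for illustration. The sketch uses the row indices returned by `nonzero` instead of `torch.arange(bs)` as in `pvp`. Both give the same result when each sequence contains exactly one mask."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import torch\n",
"\n",
"def select_label_logits(logits, input_ids, mask_token_id, label_token_ids):\n",
"    # logits: [batch, seq_len, vocab]; exactly one mask token per sequence is assumed\n",
"    batch_idx, mask_idx = (input_ids == mask_token_id).nonzero(as_tuple=True)\n",
"    assert batch_idx.shape[0] == input_ids.shape[0], 'expected exactly one mask per sequence'\n",
"    mask_logits = logits[batch_idx, mask_idx]  # [batch, vocab]\n",
"    return mask_logits[:, label_token_ids]     # [batch, num_labels]\n",
"\n",
"# toy example: batch 2, seq_len 4, vocab 10, mask id 9, label token ids 5, 6, 7\n",
"input_ids = torch.tensor([[1, 9, 2, 3], [1, 2, 9, 3]])\n",
"logits = torch.arange(2 * 4 * 10, dtype=torch.float).reshape(2, 4, 10)\n",
"scores = select_label_logits(logits, input_ids, mask_token_id=9, label_token_ids=[5, 6, 7])\n",
"print(scores.tolist())  # [[15.0, 16.0, 17.0], [65.0, 66.0, 67.0]]"
],
"outputs": [],
"metadata": {}
},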
{
"cell_type": "markdown",
"source": [
"### Base Model Class"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import argparse\n",
"\n",
"OPTIMIZER = \"AdamW\"\n",
"LR = 5e-5\n",
"LOSS = \"cross_entropy\"\n",
"ONE_CYCLE_TOTAL_STEPS = 100\n",
"\n",
"class Config(dict):\n",
"    def __getattr__(self, name):\n",
"        return self.get(name)\n",
"\n",
"    def __setattr__(self, name, val):\n",
"        self[name] = val\n",
"\n",
"\n",
"class BaseLitModel(nn.Module):\n",
"    \"\"\"\n",
"    Generic PyTorch-Lightning-style class that must be initialized with a PyTorch module.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, model, cfg: argparse.Namespace = None, device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'):\n",
"        super().__init__()\n",
"        self.model = model\n",
"        self.cur_model = model.module if hasattr(model, 'module') else model\n",
"        self.device = device\n",
"        # vars(cfg) only returns instance attributes, which is empty for the Config class above\n",
"        # (its settings are class attributes), so collect public attributes with dir() instead\n",
"        self.cfg = Config({k: getattr(cfg, k) for k in dir(cfg) if not k.startswith('_')}) if cfg is not None else Config()\n",
"\n",
"        optimizer = self.cfg.get(\"optimizer\", OPTIMIZER)\n",
"        self.optimizer_class = getattr(torch.optim, optimizer)\n",
"        self.lr = self.cfg.get(\"lr\", LR)\n",
"\n",
"\n",
"    @staticmethod\n",
"    def add_to_argparse(parser):\n",
"        parser.add_argument(\"--optimizer\", type=str, default=OPTIMIZER, help=\"optimizer class from torch.optim\")\n",
"        parser.add_argument(\"--lr\", type=float, default=LR)\n",
"        parser.add_argument(\"--weight_decay\", type=float, default=0.01)\n",
"        return parser\n",
"\n",
"    def forward(self, x):\n",
"        return self.model(x)\n",
"\n",
"    # The *_step methods below are placeholders (squared error on the raw input);\n",
"    # BertLitModel overrides them with the real model calls.\n",
"    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument\n",
"        x, y = batch\n",
"        x, y = x.to(self.device), y.to(self.device)  # Tensor.to() is not in-place, so reassign\n",
"        logits = x\n",
"        loss = ((logits - y) ** 2).mean()\n",
"        print(\"train_loss: \", loss)\n",
"        #self.train_acc(logits, y)\n",
"        #self.log(\"train_acc\", self.train_acc, on_step=False, on_epoch=True)\n",
"        return loss\n",
"\n",
"    def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument\n",
"        x, y = batch\n",
"        x, y = x.to(self.device), y.to(self.device)\n",
"        logits = x\n",
"        loss = ((logits - y) ** 2).mean()\n",
"        print(\"val_loss: \", loss)\n",
"\n",
"    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument\n",
"        x, y = batch\n",
"        x, y = x.to(self.device), y.to(self.device)\n",
"        logits = x\n",
"        loss = ((logits - y) ** 2).mean()\n",
"        print(\"test_loss: \", loss)\n",
"\n",
"    def configure_optimizers(self):\n",
"        # apply weight decay to all parameters except biases and LayerNorm weights\n",
"        no_decay_param = [\"bias\", \"LayerNorm.weight\"]\n",
"\n",
"        optimizer_group_parameters = [\n",
"            {\"params\": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay_param)], \"weight_decay\": self.cfg.weight_decay},\n",
"            {\"params\": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay_param)], \"weight_decay\": 0}\n",
"        ]\n",
"\n",
"        optimizer = self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)\n",
"        #scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.num_training_steps * 0.1, num_training_steps=self.num_training_steps)\n",
"        return optimizer"
],
"outputs": [],
"metadata": {}
},
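{
"cell_type": "markdown",
"source": [
"`configure_optimizers` applies weight decay to every parameter except biases and LayerNorm weights. It decides by matching substrings of the parameter names. The toy module below is hypothetical, and its attribute name `LayerNorm` mimics BERT's naming. It shows which parameter names land in each group. The values `lr=3e-5` and `weight_decay=0.01` mirror the configuration above."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class TinyBlock(nn.Module):\n",
"    # hypothetical toy module; the attribute name LayerNorm mimics BERT's naming\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.dense = nn.Linear(4, 4)\n",
"        self.LayerNorm = nn.LayerNorm(4)\n",
"\n",
"model = TinyBlock()\n",
"no_decay = ['bias', 'LayerNorm.weight']\n",
"decay_names = [n for n, _ in model.named_parameters() if not any(nd in n for nd in no_decay)]\n",
"no_decay_names = [n for n, _ in model.named_parameters() if any(nd in n for nd in no_decay)]\n",
"print(decay_names)     # ['dense.weight']\n",
"print(no_decay_names)  # ['dense.bias', 'LayerNorm.weight', 'LayerNorm.bias']\n",
"\n",
"groups = [\n",
"    {'params': [p for n, p in model.named_parameters() if n in decay_names], 'weight_decay': 0.01},\n",
"    {'params': [p for n, p in model.named_parameters() if n in no_decay_names], 'weight_decay': 0.0},\n",
"]\n",
"optimizer = torch.optim.AdamW(groups, lr=3e-5, eps=1e-8)\n",
"print([len(g['params']) for g in optimizer.param_groups])  # [1, 3]"
],
"outputs": [],
"metadata": {}
},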
{
"cell_type": "markdown",
"source": [
"### Model Subclass"
],
"metadata": {}
},
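{
"cell_type": "markdown",
"source": [
"When `data_dir` contains `dialogue`, `BertLitModel` uses `multilabel_categorical_crossentropy`, defined in the next cell, instead of `nn.CrossEntropyLoss`. As we read that function, it computes log(1 + sum over negative labels of exp(s_i)) + log(1 + sum over positive labels of exp(-s_j)), and the appended zero column supplies the 1 inside each log. The check below copies the function and compares it with that closed form on a toy example. The closed form is our own restatement and does not come from the source."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import torch\n",
"\n",
"def multilabel_categorical_crossentropy(y_pred, y_true):\n",
"    # copy of the function defined in the next cell\n",
"    y_pred = (1 - 2 * y_true) * y_pred          # negate the scores of positive labels\n",
"    y_pred_neg = y_pred - y_true * 1e12         # mask out positives for the negative term\n",
"    y_pred_pos = y_pred - (1 - y_true) * 1e12   # mask out negatives for the positive term\n",
"    zeros = torch.zeros_like(y_pred[..., :1])   # exp(0) = 1 supplies the 1 inside each log\n",
"    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)\n",
"    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)\n",
"    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)\n",
"    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)\n",
"    return (neg_loss + pos_loss).mean()\n",
"\n",
"scores = torch.tensor([[2.0, -1.0, 0.5]])  # toy scores for 3 labels\n",
"labels = torch.tensor([[1.0, 0.0, 0.0]])   # label 0 is the only positive\n",
"# closed form: log(1 + e^-1 + e^0.5) for the negatives + log(1 + e^-2) for the positive\n",
"expected = torch.log(1 + torch.exp(torch.tensor(-1.0)) + torch.exp(torch.tensor(0.5))) + torch.log(1 + torch.exp(torch.tensor(-2.0)))\n",
"loss = multilabel_categorical_crossentropy(scores, labels)\n",
"print(loss.item(), expected.item())\n",
"print(torch.allclose(loss, expected))  # True"
],
"outputs": [],
"metadata": {}
},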
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"source": [
|
||
"def multilabel_categorical_crossentropy(y_pred, y_true):\n",
|
||
" y_pred = (1 - 2 * y_true) * y_pred\n",
|
||
" y_pred_neg = y_pred - y_true * 1e12\n",
|
||
" y_pred_pos = y_pred - (1 - y_true) * 1e12\n",
|
||
" zeros = torch.zeros_like(y_pred[..., :1])\n",
|
||
" y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)\n",
|
||
" y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)\n",
|
||
" neg_loss = torch.logsumexp(y_pred_neg, dim=-1)\n",
|
||
" pos_loss = torch.logsumexp(y_pred_pos, dim=-1)\n",
|
||
" return (neg_loss + pos_loss).mean()\n",
|
||
"\n",
|
||
"class BertLitModel(BaseLitModel):\n",
|
||
" \"\"\"\n",
|
||
" use AutoModelForMaskedLM, and select the output by another layer in the lit model\n",
|
||
" \"\"\"\n",
|
||
" def __init__(self, model, cfg, tokenizer):\n",
|
||
" super().__init__(model, cfg)\n",
|
||
" self.tokenizer = tokenizer\n",
|
||
" \n",
|
||
" with open(f\"{cfg.data_dir}/rel2id.json\",\"r\") as file:\n",
|
||
" rel2id = json.load(file)\n",
|
||
" \n",
|
||
" Na_num = 0\n",
|
||
" for k, v in rel2id.items():\n",
|
||
" if k == \"NA\" or k == \"no_relation\" or k == \"Other\":\n",
|
||
" Na_num = v\n",
|
||
" break\n",
|
||
" num_relation = len(rel2id)\n",
|
||
" # init loss function\n",
|
||
" self.loss_fn = multilabel_categorical_crossentropy if \"dialogue\" in cfg.data_dir else nn.CrossEntropyLoss()\n",
|
||
" # ignore the no_relation class to compute the f1 score\n",
|
||
" self.eval_fn = f1_eval if \"dialogue\" in cfg.data_dir else partial(f1_score, rel_num=num_relation, na_num=Na_num)\n",
|
||
" self.best_f1 = 0\n",
|
||
" self.t_lambda = cfg.t_lambda\n",
|
||
" \n",
|
||
" self.label_st_id = tokenizer(\"[class1]\", add_special_tokens=False)['input_ids'][0]\n",
|
||
" \n",
|
||
" self._init_label_word()\n",
|
||
"\n",
|
||
" def _init_label_word(self):\n",
|
||
" cfg = self.cfg\n",
|
||
" # ./dataset/dataset_name\n",
|
||
" dataset_name = cfg.data_dir.split(\"/\")[1]\n",
|
||
" model_name_or_path = cfg.model_name_or_path.split(\"/\")[-1]\n",
|
||
" label_path = f\"./dataset/{model_name_or_path}_{dataset_name}.pt\"\n",
|
||
" # [num_labels, num_tokens], ignore the unanswerable\n",
|
||
" if \"dialogue\" in cfg.data_dir:\n",
|
||
" label_word_idx = torch.load(label_path)[:-1]\n",
|
||
" else:\n",
|
||
" label_word_idx = torch.load(label_path)\n",
|
||
" \n",
|
||
" num_labels = len(label_word_idx)\n",
|
||
" \n",
|
||
" self.cur_model.resize_token_embeddings(len(self.tokenizer))\n",
|
||
" with torch.no_grad():\n",
|
||
" word_embeddings = self.cur_model.get_input_embeddings()\n",
|
||
" continous_label_word = [a[0] for a in self.tokenizer([f\"[class{i}]\" for i in range(1, num_labels+1)], add_special_tokens=False)['input_ids']]\n",
|
||
" for i, idx in enumerate(label_word_idx):\n",
|
||
" word_embeddings.weight[continous_label_word[i]] = torch.mean(word_embeddings.weight[idx], dim=0)\n",
|
||
" # word_embeddings.weight[continous_label_word[i]] = self.relation_embedding[i]\n",
|
||
" so_word = [a[0] for a in self.tokenizer([\"[obj]\",\"[sub]\"], add_special_tokens=False)['input_ids']]\n",
|
||
" meaning_word = [a[0] for a in self.tokenizer([\"person\",\"organization\", \"location\", \"date\", \"country\"], add_special_tokens=False)['input_ids']]\n",
|
||
" \n",
|
||
" for i, idx in enumerate(so_word):\n",
|
||
" word_embeddings.weight[so_word[i]] = torch.mean(word_embeddings.weight[meaning_word], dim=0)\n",
|
||
" assert torch.equal(self.cur_model.get_input_embeddings().weight, word_embeddings.weight)\n",
|
||
" assert torch.equal(self.cur_model.get_input_embeddings().weight, self.cur_model.get_output_embeddings().weight)\n",
|
||
" \n",
|
||
" self.word2label = continous_label_word # a continous list\n",
|
||
" \n",
|
||
" \n",
|
||
" def forward(self, x):\n",
|
||
" return self.model(x)\n",
|
||
"\n",
|
||
" def training_step(self, batch, batch_idx): # pylint: disable=unused-argument\n",
|
||
" input_ids, attention_mask, token_type_ids , labels, so = batch\n",
|
||
" input_ids = input_ids.to(self.device)\n",
|
||
" attention_mask = attention_mask.to(self.device)\n",
|
||
" token_type_ids = token_type_ids.to(self.device)\n",
|
||
" labels = labels.to(self.device)\n",
|
||
" so = so.to(self.device)\n",
|
||
" result = self.model(input_ids, attention_mask, token_type_ids, return_dict=True, output_hidden_states=True)\n",
|
||
" logits = result.logits\n",
|
||
" output_embedding = result.hidden_states[-1]\n",
|
||
" logits = self.pvp(logits, input_ids)\n",
|
||
" loss = self.loss_fn(logits, labels) + self.t_lambda * self.ke_loss(output_embedding, labels, so)\n",
|
||
" #print(\"Train/loss: \", loss)\n",
|
||
" return loss\n",
|
||
"\n",
|
||
" def validation_step(self, batch, batch_idx): # pylint: disable=unused-argument\n",
|
||
" input_ids, attention_mask, token_type_ids , labels, _ = batch\n",
|
||
" input_ids = input_ids.to(self.device)\n",
|
||
" attention_mask = attention_mask.to(self.device)\n",
|
||
" token_type_ids = token_type_ids.to(self.device)\n",
|
||
" labels = labels.to(self.device)\n",
|
||
" logits = self.model(input_ids, attention_mask, token_type_ids, return_dict=True).logits\n",
|
||
" logits = self.pvp(logits, input_ids)\n",
|
||
" loss = self.loss_fn(logits, labels)\n",
|
||
" #print(\"Eval/loss: \", loss)\n",
|
||
" return {\"loss\": loss, \"eval_logits\": logits.detach().cpu().numpy(), \"eval_labels\": labels.detach().cpu().numpy()}\n",
|
||
" \n",
|
||
" def validation_epoch_end(self, outputs):\n",
|
||
" logits = np.concatenate([o[\"eval_logits\"] for o in outputs])\n",
|
||
" labels = np.concatenate([o[\"eval_labels\"] for o in outputs])\n",
|
||
"\n",
|
||
" f1 = self.eval_fn(logits, labels)['f1']\n",
|
||
" #print(\"Eval/f1: \", f1)\n",
|
||
" best_f1 = -1\n",
|
||
" if f1 > self.best_f1:\n",
|
||
" self.best_f1 = f1\n",
|
||
" best_f1 = self.best_f1\n",
|
||
" #print(\"Eval/best_f1: \", self.best_f1)\n",
|
||
" return f1, best_f1, self.best_f1\n",
|
||
"\n",
|
||
" def test_step(self, batch, batch_idx): # pylint: disable=unused-argument\n",
|
||
" input_ids, attention_mask, token_type_ids , labels, _ = batch\n",
|
||
" input_ids = input_ids.to(self.device)\n",
|
||
" attention_mask = attention_mask.to(self.device)\n",
|
||
" token_type_ids = token_type_ids.to(self.device)\n",
|
||
" labels = labels.to(self.device)\n",
|
||
" logits = self.model(input_ids, attention_mask, token_type_ids, return_dict=True).logits\n",
|
||
" logits = self.pvp(logits, input_ids)\n",
|
||
" return {\"test_logits\": logits.detach().cpu().numpy(), \"test_labels\": labels.detach().cpu().numpy()}\n",
|
||
"\n",
|
||
" def test_epoch_end(self, outputs):\n",
|
||
" logits = np.concatenate([o[\"test_logits\"] for o in outputs])\n",
|
||
" labels = np.concatenate([o[\"test_labels\"] for o in outputs])\n",
|
||
"\n",
|
||
" f1 = self.eval_fn(logits, labels)['f1']\n",
|
||
" #print(\"Test/f1: \", f1)\n",
|
||
" return f1\n",
|
||
"\n",
|
||
"\n",
|
||
" @staticmethod\n",
|
||
" def add_to_argparse(parser):\n",
|
||
" BaseLitModel.add_to_argparse(parser)\n",
|
||
" parser.add_argument(\"--t_lambda\", type=float, default=0.01, help=\"\")\n",
|
||
" return parser\n",
|
||
" \n",
|
||
" def pvp(self, logits, input_ids):\n",
|
||
" # convert the [batch_size, seq_len, vocab_size] => [batch_size, num_labels]\n",
|
||
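"        #! hard coded: 103 is the [MASK] token id in standard BERT vocabularies (e.g. bert-base-uncased)\n",
"        _, mask_idx = (input_ids == 103).nonzero(as_tuple=True)\n",
"        bs = input_ids.shape[0]\n",
"        # check before indexing: every sequence must contain exactly one [MASK]\n",
"        assert mask_idx.shape[0] == bs, \"only one mask in sequence!\"\n",
"        mask_output = logits[torch.arange(bs), mask_idx]\n",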
|
||
" final_output = mask_output[:,self.word2label]\n",
|
||
" \n",
|
||
" return final_output\n",
|
||
" \n",
|
||
" def ke_loss(self, logits, labels, so):\n",
|
||
" subject_embedding = []\n",
|
||
" object_embedding = []\n",
|
||
" bsz = logits.shape[0]\n",
|
||
" for i in range(bsz):\n",
|
||
" subject_embedding.append(torch.mean(logits[i, so[i][0]:so[i][1]], dim=0))\n",
|
||
" object_embedding.append(torch.mean(logits[i, so[i][2]:so[i][3]], dim=0))\n",
|
||
" \n",
|
||
" subject_embedding = torch.stack(subject_embedding)\n",
|
||
" object_embedding = torch.stack(object_embedding)\n",
|
||
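"        # trick: the relation label tokens are appended to the vocabulary in label order,\n",
"        # so label i corresponds to token id label_st_id + i\n",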
|
||
" relation_embedding = self.cur_model.get_output_embeddings().weight[labels+self.label_st_id]\n",
|
||
" \n",
|
||
" loss = torch.norm(subject_embedding + relation_embedding - object_embedding, p=2)\n",
|
||
" \n",
|
||
" return loss\n",
|
||
"\n",
|
||
" def configure_optimizers(self):\n",
|
||
" no_decay_param = [\"bias\", \"LayerNorm.weight\"]\n",
|
||
"\n",
|
||
" if not self.cfg.two_steps: \n",
|
||
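"            # materialize the generator: it is iterated twice when building the parameter groups below,\n",
"            # otherwise the second (no-decay) group would silently be empty\n",
"            parameters = list(self.cur_model.named_parameters())\n",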
|
||
" else:\n",
|
||
"            # two-step mode: only optimize the embedding parameters, i.e. the first named\n",
"            # parameter (bert.embeddings.word_embeddings.weight for BertForMaskedLM)\n",
"            parameters = [next(self.cur_model.named_parameters())]\n",
|
||
" optimizer_group_parameters = [\n",
|
||
" {\"params\": [p for n, p in parameters if not any(nd in n for nd in no_decay_param)], \"weight_decay\": self.cfg.weight_decay},\n",
|
||
" {\"params\": [p for n, p in parameters if any(nd in n for nd in no_decay_param)], \"weight_decay\": 0}\n",
|
||
" ]\n",
|
||
"\n",
|
||
" \n",
|
||
" optimizer = self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)\n",
|
||
" return optimizer"
|
||
],
|
||
"outputs": [],
|
||
"metadata": {}
|
||
},
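{
"cell_type": "markdown",
"source": [
"The two helper methods `pvp` and `ke_loss` above can be illustrated on toy tensors. The next cell is a minimal standalone sketch, not DeepKE code: the vocabulary size, the mask id `3` (standing in for BERT's `103`), the label-word ids `6` to `8` and the entity spans are made-up assumptions. `pvp` keeps the vocabulary logits at the single `[MASK]` position and then selects the label-word columns, mapping `[batch, seq_len, vocab]` to `[batch, num_labels]`. `ke_loss` averages the hidden states over the subject and object spans and takes the L2 norm of `subject + relation - object`, a TransE-style constraint in which the relation vector is the output embedding of the gold label's token (assuming, as the code above does, that label tokens have contiguous ids starting at `label_st_id`)."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import torch\n",
"\n",
"torch.manual_seed(0)\n",
"\n",
"# toy sizes, chosen only for illustration\n",
"batch_size, seq_len, vocab_size, hidden = 2, 6, 10, 4\n",
"MASK_ID = 3                 # stands in for the BERT [MASK] id (103)\n",
"label_word_ids = [6, 7, 8]  # hypothetical vocabulary ids of three relation label words\n",
"\n",
"# each sequence contains exactly one mask token\n",
"input_ids = torch.tensor([[1, 5, 3, 2, 0, 0],\n",
"                          [1, 4, 9, 3, 2, 0]])\n",
"logits = torch.randn(batch_size, seq_len, vocab_size)\n",
"\n",
"# pvp: logits at the mask position, then only the label-word columns\n",
"_, mask_idx = (input_ids == MASK_ID).nonzero(as_tuple=True)\n",
"assert mask_idx.shape[0] == batch_size\n",
"mask_output = logits[torch.arange(batch_size), mask_idx]  # [batch, vocab]\n",
"label_logits = mask_output[:, label_word_ids]              # [batch, num_labels]\n",
"print(label_logits.shape)  # torch.Size([2, 3])\n",
"\n",
"# ke_loss: TransE-style distance between subject + relation and object\n",
"hidden_states = torch.randn(batch_size, seq_len, hidden)\n",
"output_embeddings = torch.randn(vocab_size, hidden)  # stands in for get_output_embeddings().weight\n",
"labels = torch.tensor([0, 2])\n",
"label_st_id = label_word_ids[0]  # assumes the label tokens have contiguous ids\n",
"so = [[1, 2, 3, 4], [1, 3, 4, 5]]  # [subj_start, subj_end, obj_start, obj_end] per example\n",
"\n",
"subj = torch.stack([hidden_states[i, so[i][0]:so[i][1]].mean(dim=0) for i in range(batch_size)])\n",
"obj = torch.stack([hidden_states[i, so[i][2]:so[i][3]].mean(dim=0) for i in range(batch_size)])\n",
"rel = output_embeddings[labels + label_st_id]\n",
"ke = torch.norm(subj + rel - obj, p=2)\n",
"print(ke.item())"
],
"outputs": [],
"metadata": {}
},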
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"## Input Preprocessing"
|
||
],
|
||
"metadata": {}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"### Few-shot Sampling"
|
||
],
|
||
"metadata": {}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"source": [
|
||
"import ast\n",
"\n",
"Seed = [1, 2, 3, 4, 5]\n",
"mode = 'k-shot'\n",
"data_file = 'train.txt'\n",
"k = 8  # number of training examples kept per relation label\n",
"\n",
"def get_labels(path, name, negative_label=\"no_relation\"):\n",
"    \"\"\"Read the examples in `path/name`, one dict literal per line.\"\"\"\n",
"    features = []\n",
"    with open(os.path.join(path, name), \"r\") as f:\n",
"        for line in f:\n",
"            line = line.rstrip()\n",
"            if len(line) > 0:\n",
"                # literal_eval parses the dict safely (unlike eval)\n",
"                features.append(ast.literal_eval(line))\n",
"    return features\n",
"\n",
"path = 'data'\n",
"output_dir = os.path.join(path, mode)\n",
"dataset = get_labels(path, data_file)\n",
"\n",
"for seed in Seed:\n",
"    # shuffle with a fixed seed so each split is reproducible\n",
"    np.random.seed(seed)\n",
"    np.random.shuffle(dataset)\n",
"\n",
"    # one output directory per setting, e.g. data/k-shot/8-1\n",
"    setting_dir = os.path.join(output_dir, f\"{k}-{seed}\")\n",
"    os.makedirs(setting_dir, exist_ok=True)\n",
"\n",
"    # group the examples by relation label\n",
"    label_list = {}\n",
"    for line in dataset:\n",
"        label_list.setdefault(line['relation'], []).append(line)\n",
"\n",
"    # keep the first k examples of every label as the few-shot training set\n",
"    with open(os.path.join(setting_dir, \"train.txt\"), \"w\") as f:\n",
"        for label in label_list:\n",
"            for line in label_list[label][:k]:\n",
"                f.write(json.dumps(line) + '\\n')\n",
"\n",
"    # the label map and the validation/test sets are shared by every setting\n",
"    for name in ['rel2id.json', 'val.txt', 'test.txt']:\n",
"        shutil.copyfile(os.path.join(path, name), os.path.join(setting_dir, name))"
|
||
],
|
||
"outputs": [],
|
||
"metadata": {}
|
||
},
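{
"cell_type": "markdown",
"source": [
"The sampling cell above keeps, for every relation label, the first `k` examples of a seeded shuffle (labels with fewer than `k` examples keep all of them). The next cell reproduces that idea on a made-up toy dataset so the per-label cap can be checked without the SemEval files. `sample_k_shot` is a hypothetical helper, not part of DeepKE, and it uses Python's `random.Random` instead of NumPy's global random state, so its splits differ from the files written above."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import random\n",
"from collections import Counter, defaultdict\n",
"\n",
"def sample_k_shot(examples, k, seed):\n",
"    # shuffle a copy with a fixed seed so the split is reproducible\n",
"    rng = random.Random(seed)\n",
"    shuffled = list(examples)\n",
"    rng.shuffle(shuffled)\n",
"    # group by relation label and keep at most k examples per label\n",
"    by_label = defaultdict(list)\n",
"    for ex in shuffled:\n",
"        by_label[ex['relation']].append(ex)\n",
"    return [ex for label in by_label for ex in by_label[label][:k]]\n",
"\n",
"# made-up toy data: three labels with 5, 3 and 12 examples\n",
"toy = ([{'relation': 'Cause-Effect(e1,e2)', 'id': i} for i in range(5)]\n",
"       + [{'relation': 'Other', 'id': i} for i in range(3)]\n",
"       + [{'relation': 'Message-Topic(e1,e2)', 'id': i} for i in range(12)])\n",
"\n",
"few = sample_k_shot(toy, k=4, seed=1)\n",
"counts = Counter(ex['relation'] for ex in few)\n",
"print(counts['Cause-Effect(e1,e2)'], counts['Other'], counts['Message-Topic(e1,e2)'])  # 4 3 4\n",
"# the same seed gives the same split\n",
"print(few == sample_k_shot(toy, k=4, seed=1))  # True"
],
"outputs": [],
"metadata": {}
},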
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"### Getting Label Words"
|
||
],
|
||
"metadata": {}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"source": [
|
||
"def split_label_words(tokenizer, label_list):\n",
"    label_word_list = []\n",
"    for label in label_list:\n",
"        if label == 'no_relation':\n",
"            label_word_id = tokenizer.encode('None', add_special_tokens=False)\n",
"            label_word_list.append(torch.tensor(label_word_id))\n",
"        else:\n",
"            # turn the relation name into a short phrase, e.g. 'Cause-Effect(e1,e2)' -> 'cause-effect'\n",
"            label = label.lower()\n",
"            label = label.split(\"(\")[0]\n",
"            # the ':' / '_' separators and the per/org expansions target TACRED-style label names\n",
"            label = label.replace(\":\",\" \").replace(\"_\",\" \").replace(\"per\",\"person\").replace(\"org\",\"organization\")\n",
"            label_word_id = tokenizer(label, add_special_tokens=False)['input_ids']\n",
"            print(label, label_word_id)\n",
"            label_word_list.append(torch.tensor(label_word_id))\n",
"    # right-pad with 0 so all label words fit in one [num_labels, max_len] tensor\n",
"    padded_label_word_list = pad_sequence(label_word_list, batch_first=True, padding_value=0)\n",
"    return padded_label_word_list\n",
"\n",
"\n",
"model_name_or_path = cfg.model_name_or_path\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
"with open(\"data/rel2id.json\", \"r\") as file:\n",
"    label_list = list(json.load(file))  # relation names, in rel2id order\n",
"\n",
"label_word_ids = split_label_words(tokenizer, label_list)\n",
"\n",
"with open(f\"data/{model_name_or_path}.pt\", \"wb\") as file:\n",
"    torch.save(label_word_ids, file)"
|
||
],
|
||
"outputs": [],
|
||
"metadata": {}
|
||
},
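{
"cell_type": "markdown",
"source": [
"`split_label_words` turns each relation name into a short phrase, tokenizes it, and right-pads the variable-length id lists into a single tensor with `pad_sequence`. The next cell shows only the string normalization and the padding: `normalize_label` is a hypothetical helper that mirrors the lowercase, `(`-split and separator steps (it omits the TACRED-specific `per`/`org` expansions), and the word-to-id table is made up in place of a real tokenizer. Because the direction suffix is dropped, `Cause-Effect(e1,e2)` and `Cause-Effect(e2,e1)` receive the same phrase."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import torch\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"\n",
"def normalize_label(name):\n",
"    # lowercase, drop the '(e1,e2)' direction suffix, turn ':' and '_' into spaces\n",
"    name = name.lower().split('(')[0]\n",
"    return name.replace(':', ' ').replace('_', ' ')\n",
"\n",
"# made-up token ids standing in for the BERT tokenizer output\n",
"toy_ids = {'cause-effect': [11, 12, 13], 'other': [14], 'message-topic': [15, 12, 16]}\n",
"\n",
"names = ['Cause-Effect(e1,e2)', 'Other', 'Message-Topic(e2,e1)']\n",
"phrases = [normalize_label(n) for n in names]\n",
"print(phrases)  # ['cause-effect', 'other', 'message-topic']\n",
"\n",
"padded = pad_sequence([torch.tensor(toy_ids[p]) for p in phrases], batch_first=True, padding_value=0)\n",
"print(padded.tolist())  # [[11, 12, 13], [14, 0, 0], [15, 12, 16]]"
],
"outputs": [],
"metadata": {}
},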
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"## Helper Functions"
|
||
],
|
||
"metadata": {}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"source": [
|
||
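"import random\n",
"\n",
"def set_seed(cfg):\n",
"    # seed Python, NumPy and PyTorch (CPU and all GPUs) for reproducible runs\n",
"    random.seed(cfg.seed)\n",
"    np.random.seed(cfg.seed)\n",
"    torch.manual_seed(cfg.seed)\n",
"    torch.cuda.manual_seed_all(cfg.seed)\n",
"\n",
"def logging(log_dir, s, print_=True, log_=True):\n",
"    # print the message and/or append it to the log file at log_dir\n",
"    # (note: this function shadows the standard-library logging module)\n",
"    if print_:\n",
"        print(s)\n",
"    if log_ and log_dir != '':\n",
"        with open(log_dir, 'a+') as f_log:\n",
"            f_log.write(s + '\\n')"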
|
||
],
|
||
"outputs": [],
|
||
"metadata": {}
|
||
},
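{
"cell_type": "markdown",
"source": [
"Seeding every random number generator makes repeated runs draw the same numbers. `set_seed_value` in the next cell is a hypothetical variant of `set_seed` that takes an integer instead of `cfg`; the check draws from Python, NumPy and PyTorch twice under the same seed and compares the results. Bit-exact GPU training may additionally require cuDNN settings such as `torch.backends.cudnn.deterministic = True`, which this notebook does not set."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import random\n",
"import numpy as np\n",
"import torch\n",
"\n",
"def set_seed_value(seed):\n",
"    # seed Python, NumPy and PyTorch (CPU and all visible GPUs; safe without CUDA)\n",
"    random.seed(seed)\n",
"    np.random.seed(seed)\n",
"    torch.manual_seed(seed)\n",
"    torch.cuda.manual_seed_all(seed)\n",
"\n",
"def draw():\n",
"    # one number from each generator\n",
"    return random.random(), float(np.random.rand()), torch.rand(1).item()\n",
"\n",
"set_seed_value(42)\n",
"first = draw()\n",
"set_seed_value(42)\n",
"second = draw()\n",
"print(first == second)  # True"
],
"outputs": [],
"metadata": {}
},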
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"## Model Training in Practice"
|
||
],
|
||
"metadata": {}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"### Training"
|
||
],
|
||
"metadata": {}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"source": [
|
||
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
||
"\n",
|
||
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
|
||
"\n",
|
||
"data = REDataset(cfg)\n",
|
||
"data_config = data.get_data_config()\n",
|
||
"\n",
|
||
"config = AutoConfig.from_pretrained(cfg.model_name_or_path)\n",
|
||
"config.num_labels = data_config[\"num_labels\"]\n",
|
||
"\n",
|
||
"model = BertForMaskedLM.from_pretrained(cfg.model_name_or_path, config=config)\n",
|
||
"\n",
|
||
"if cfg.train_from_saved_model != '':\n",
|
||
" model.load_state_dict(torch.load(cfg.train_from_saved_model)[\"checkpoint\"])\n",
|
||
" print(\"load saved model from {}.\".format(cfg.train_from_saved_model))\n",
|
||
"\n",
|
||
" \n",
|
||
"if torch.cuda.device_count() > 1:\n",
|
||
" print(\"Let's use\", torch.cuda.device_count(), \"GPUs!\")\n",
|
||
" model = torch.nn.DataParallel(model, device_ids = list(range(torch.cuda.device_count())))\n",
|
||
"model.to(device)\n",
|
||
"\n",
|
||
"cur_model = model.module if hasattr(model, 'module') else model\n",
|
||
"\n",
|
||
"\n",
|
||
"if \"gpt\" in cfg.model_name_or_path or \"roberta\" in cfg.model_name_or_path:\n",
|
||
" tokenizer = data.get_tokenizer()\n",
|
||
" cur_model.resize_token_embeddings(len(tokenizer))\n",
|
||
" cur_model.update_word_idx(len(tokenizer))\n",
|
||
" if \"Use\" in cfg.model_class:\n",
|
||
" continous_prompt = [a[0] for a in tokenizer([f\"[T{i}]\" for i in range(1,3)], add_special_tokens=False)['input_ids']]\n",
|
||
" continous_label_word = [a[0] for a in tokenizer([f\"[class{i}]\" for i in range(1, data.num_labels+1)], add_special_tokens=False)['input_ids']]\n",
|
||
" discrete_prompt = [a[0] for a in tokenizer(['It', 'was'], add_special_tokens=False)['input_ids']]\n",
|
||
" dataset_name = cfg.data_dir.split(\"/\")[1]\n",
|
||
" model.init_unused_weights(continous_prompt, continous_label_word, discrete_prompt, label_path=f\"{cfg.model_name_or_path}_{dataset_name}.pt\")\n",
|
||
"# data.setup()\n",
|
||
"# relation_embedding = _get_relation_embedding(data)\n",
|
||
"lit_model = BertLitModel(cfg=cfg, model=model, tokenizer=data.tokenizer, device=device)\n",
|
||
"if cfg.train_from_saved_model != '':\n",
|
||
" lit_model.best_f1 = torch.load(cfg.train_from_saved_model)[\"best_f1\"]\n",
|
||
"data.tokenizer.save_pretrained('test')\n",
|
||
"data.setup()\n",
|
||
"\n",
|
||
"optimizer = lit_model.configure_optimizers()\n",
|
||
"if cfg.train_from_saved_model != '':\n",
|
||
" optimizer.load_state_dict(torch.load(cfg.train_from_saved_model)[\"optimizer\"])\n",
|
||
" print(\"load saved optimizer from {}.\".format(cfg.train_from_saved_model))\n",
|
||
"\n",
|
||
"num_training_steps = len(data.train_dataloader()) // cfg.gradient_accumulation_steps * cfg.num_train_epochs\n",
|
||
"scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(num_training_steps * 0.1), num_training_steps=num_training_steps)\n",
|
||
"log_step = 100\n",
|
||
"\n",
|
||
"\n",
|
||
"logging(cfg.log_dir,'-' * 89, print_=False)\n",
|
||
"logging(cfg.log_dir, time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()) + ' INFO : START TO TRAIN ', print_=False)\n",
|
||
"logging(cfg.log_dir,'-' * 89, print_=False)\n",
|
||
"\n",
|
||
"for epoch in range(cfg.num_train_epochs):\n",
|
||
" model.train()\n",
|
||
" num_batch = len(data.train_dataloader())\n",
|
||
" total_loss = 0\n",
|
||
" log_loss = 0\n",
|
||
" for index, train_batch in enumerate(tqdm(data.train_dataloader())):\n",
|
||
" loss = lit_model.training_step(train_batch, index)\n",
|
||
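"        total_loss += loss.item()\n",
"        log_loss += loss.item()\n",
"        # scale the loss so that gradients summed over gradient_accumulation_steps batches are averaged\n",
"        (loss / cfg.gradient_accumulation_steps).backward()\n",
"\n",
"        # update once every gradient_accumulation_steps batches, matching num_training_steps above\n",
"        if (index + 1) % cfg.gradient_accumulation_steps == 0:\n",
"            optimizer.step()\n",
"            scheduler.step()\n",
"            optimizer.zero_grad()\n",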
|
||
"\n",
|
||
" if log_step > 0 and (index+1) % log_step == 0:\n",
|
||
" cur_loss = log_loss / log_step\n",
|
||
" logging(cfg.log_dir, \n",
|
||
" '| epoch {:2d} | step {:4d} | lr {} | train loss {:5.3f}'.format(\n",
|
||
" epoch, (index+1), scheduler.get_last_lr(), cur_loss * 1000)\n",
|
||
" , print_=False)\n",
|
||
" log_loss = 0\n",
|
||
" avrg_loss = total_loss / num_batch\n",
|
||
" logging(cfg.log_dir,\n",
|
||
" '| epoch {:2d} | train loss {:5.3f}'.format(\n",
|
||
" epoch, avrg_loss * 1000))\n",
|
||
" \n",
|
||
" model.eval()\n",
|
||
" with torch.no_grad():\n",
|
||
" val_loss = []\n",
|
||
" for val_index, val_batch in enumerate(tqdm(data.val_dataloader())):\n",
|
||
" loss = lit_model.validation_step(val_batch, val_index)\n",
|
||
" val_loss.append(loss)\n",
|
||
" f1, best, best_f1 = lit_model.validation_epoch_end(val_loss)\n",
|
||
" logging(cfg.log_dir,'-' * 89)\n",
|
||
" logging(cfg.log_dir,\n",
|
||
" '| epoch {:2d} | dev_result: {}'.format(epoch, f1))\n",
|
||
" logging(cfg.log_dir,'-' * 89)\n",
|
||
" logging(cfg.log_dir,\n",
|
||
" '| best_f1: {}'.format(best_f1))\n",
|
||
" logging(cfg.log_dir,'-' * 89)\n",
|
||
" if cfg.save_path != \"\" and best != -1:\n",
|
||
" file_name = f\"{epoch}-Eval_f1-{best_f1:.2f}.pt\"\n",
|
||
" save_path = cfg.save_path + '/' + file_name\n",
|
||
" torch.save({\n",
|
||
" 'epoch': epoch,\n",
|
||
" 'checkpoint': cur_model.state_dict(),\n",
|
||
" 'best_f1': best_f1,\n",
|
||
" 'optimizer': optimizer.state_dict()\n",
|
||
" }, save_path\n",
|
||
" , _use_new_zipfile_serialization=False)\n",
|
||
" logging(cfg.log_dir,\n",
|
||
" '| successfully save model at: {}'.format(save_path))\n",
|
||
" logging(cfg.log_dir,'-' * 89)"
|
||
],
|
||
"outputs": [],
|
||
"metadata": {}
|
||
},
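{
"cell_type": "markdown",
"source": [
"The scheduler in the training cell comes from `transformers.get_linear_schedule_with_warmup`: the learning rate rises linearly from 0 to `lr` over the warmup steps (10% of the run here), then decays linearly to 0 at `num_training_steps`. The next cell recomputes that multiplier in plain Python for a hypothetical run of 100 optimizer steps; it shows why `num_training_steps` must count optimizer steps (after gradient accumulation) rather than batches, since the learning rate reaches 0 at exactly that step."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"def linear_warmup_factor(step, num_warmup_steps, num_training_steps):\n",
"    # multiplier applied to the base learning rate at a given optimizer step\n",
"    if step < num_warmup_steps:\n",
"        return step / max(1, num_warmup_steps)\n",
"    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))\n",
"\n",
"total = 100                # hypothetical number of optimizer steps\n",
"warmup = int(total * 0.1)  # 10% warmup, as in the training cell\n",
"for step in [0, 5, 10, 55, 100]:\n",
"    print(step, linear_warmup_factor(step, warmup, total))\n",
"# 0 0.0 / 5 0.5 / 10 1.0 / 55 0.5 / 100 0.0"
],
"outputs": [],
"metadata": {}
},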
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"### Prediction on the Test Set"
|
||
],
|
||
"metadata": {}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"source": [
|
||
"def test(cfg, model, lit_model, data):\n",
|
||
" model.eval()\n",
|
||
" with torch.no_grad():\n",
|
||
" test_loss = []\n",
|
||
" for test_index, test_batch in enumerate(tqdm(data.test_dataloader())):\n",
|
||
" loss = lit_model.test_step(test_batch, test_index)\n",
|
||
" test_loss.append(loss)\n",
|
||
" f1 = lit_model.test_epoch_end(test_loss)\n",
|
||
" logging(cfg.log_dir,\n",
|
||
" '| test_result: {}'.format(f1))\n",
|
||
" logging(cfg.log_dir,'-' * 89)\n",
|
||
"\n",
|
||
"test(cfg, model, lit_model, data)"
|
||
],
|
||
"outputs": [],
|
||
"metadata": {}
|
||
},
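{
"cell_type": "markdown",
"source": [
"The test cell reports the F1 returned by the model's `eval_fn`. As a hedged illustration only, the next cell computes one common relation-extraction variant: micro-averaged F1 in which the negative class (assumed here to have id `0`) is excluded from both predictions and gold labels. `micro_f1_excluding_negative` is a hypothetical helper, and its result is not guaranteed to match `eval_fn`."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import numpy as np\n",
"\n",
"def micro_f1_excluding_negative(logits, labels, negative_id=0):\n",
"    # predicted class = argmax over the label logits\n",
"    preds = np.argmax(logits, axis=-1)\n",
"    # only non-negative predictions and gold labels count towards precision and recall\n",
"    correct = np.sum((preds == labels) & (labels != negative_id))\n",
"    predicted = np.sum(preds != negative_id)\n",
"    gold = np.sum(labels != negative_id)\n",
"    precision = correct / predicted if predicted else 0.0\n",
"    recall = correct / gold if gold else 0.0\n",
"    if precision + recall == 0:\n",
"        return 0.0\n",
"    return 2 * precision * recall / (precision + recall)\n",
"\n",
"# made-up logits for 4 examples and 3 classes (class 0 = negative)\n",
"logits = np.array([[0.1, 2.0, 0.3],\n",
"                   [1.5, 0.2, 0.1],\n",
"                   [0.2, 0.1, 3.0],\n",
"                   [0.0, 1.0, 0.5]])\n",
"labels = np.array([1, 0, 1, 1])\n",
"print(round(float(micro_f1_excluding_negative(logits, labels)), 3))  # 0.667"
],
"outputs": [],
"metadata": {}
},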
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"This demo does not cover hyperparameter tuning. If you are interested, visit the [DeepKE](https://github.com/zjunlp/DeepKE/tree/master) repository to download and try more models :)"
|
||
],
|
||
"metadata": {}
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.8.8"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
} |