deepke/tutorial-notebooks/re/few-shot/tutorial.ipynb


{
"cells": [
{
"cell_type": "markdown",
"source": [
"# 低资源关系抽取实践\n",
"\n",
"> Tutorial作者: 黎洲波zhoubo.li@zju.edu.cn\n",
"\n",
"关系抽取Relation Extraction, RE任务是信息抽取中的关键任务通常而言输入是文本和若干命名实体输出是命名实体对的关系。\n",
"\n",
"DeepKE中的低资源Few-shot关系抽取是基于Pre-train, Prompt, Predict范式在之前的BERT架构中的Attention和Cross-attention部分引入了Prompt参数之后对参数进行Fine-tuning这种方法在低资源场景下表现良好。该方法中的Prompt-tuning原理如下图所示\n",
"\n",
"![关系抽取中的Prompt-tuning](img/img1.png)\n",
"\n",
"通过本次实践展示,我希望读者能够快速了解如何构建低资源关系抽取的模型。\n",
"\n",
"## 数据集\n",
"\n",
"低资源关系抽取常用的数据集有RETACRED, SEMEVAL, TACREV和WIKI80等。本次实践使用的数据集是[SEMEVAL](https://semeval2.fbk.eu/semeval2.php?location=tasks#T11)SEMEVAL数据集来自于2010年的国际语义评测大会中Task 8\"Multi-Way Classification of Semantic Relations Between Pairs of Nominals\",从官网下载的数据文件夹`./data/`的结构如下:\n",
"\n",
"```\n",
".\n",
"├── rel2id.json # 关系标签到ID的映射\n",
"├── temp.txt # 关系标签处理\n",
"├── test.txt # 测试集\n",
"├── train.txt # 训练集\n",
"└── val.txt # 验证集\n",
"```\n",
"\n",
"数据文件的格式在数据集SEMEVAL中的描述如下\n",
"\n",
"```\n",
"Data Format:\n",
"{\n",
" 'token': [tokens in a sentence],\n",
" \"h\": {\n",
" \"name\": mention_name,\n",
" \"pos\" : [postion of mention in a sentence]\n",
" },\n",
" \"t\": {\n",
" \"name\": mention_name,\n",
" \"pos\" : [postion of mention in a sentence]\n",
" },\n",
" \"relation\": relation\n",
"}\n",
"```\n",
"数据集中一共包含9+1种relation各类数据的占比如下图所示\n",
"\n",
"![数据集数据占比](img/img2.png)\n",
"\n",
"## KnowPrompt原理\n",
"此处我们使用了能够对关系标签进行语义解析的Prompt方法所以把这种关系抽取的方法叫Knowledge-aware Prompt-tuning(KnowPrompt)。Fine-tuning图 a、Prompt-tuning图 b和此处的 KnowPrompt图 c方法的模型架构如下图。Prompt中的答案词是指虚拟答案词。\n",
"\n",
"![低资源关系抽取架构图](img/img3.png)\n",
"\n",
"那么接下来我们开始低资源关系抽取实践!"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## 运行环境\n",
"\n",
"Python环境使用Python3并且要求以下packages的版本"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"!pip install numpy==1.20.3\n",
"!pip install tokenizers==0.10.3\n",
"!pip install torch==1.8.0\n",
"!pip install regex==2021.4.4\n",
"!pip install transformers==4.7.0\n",
"!pip install tqdm==4.49.0\n",
"!pip install activations==0.1.0\n",
"!pip install dataclasses==0.6\n",
"!pip install file_utils==0.0.1\n",
"!pip install flax==0.3.4\n",
"!pip install utils==1.0.1"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## 导入模块"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"import os\n",
"import json\n",
"import csv\n",
"import time\n",
"import pickle\n",
"import logging\n",
"import shutil\n",
"import numpy as np\n",
"from tqdm import tqdm\n",
"from functools import partial\n",
"from collections import Counter\n",
"from collections import OrderedDict\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"from transformers import AutoConfig, AutoModel, AutoTokenizer\n",
"from transformers.modeling_utils import PreTrainedModel\n",
"from transformers.optimization import AdamW, get_linear_schedule_with_warmup\n",
"from allennlp.modules.matrix_attention import DotProductMatrixAttention, CosineMatrixAttention, BilinearMatrixAttention\n",
"from transformers import BertTokenizer, BertForMaskedLM"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"## 参数配置"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"class Config(object):\n",
" accelerator=None\n",
" accumulate_grad_batches='1'\n",
" amp_backend='native'\n",
" amp_level='O2'\n",
" auto_lr_find=False\n",
" auto_scale_batch_size=False\n",
" auto_select_gpus=False\n",
" batch_size=16\n",
" benchmark=False\n",
" check_val_every_n_epoch='3'\n",
" checkpoint_callback=True\n",
" data_class='REDataset'\n",
" data_dir='data/k-shot/8-1'\n",
" default_root_dir=None\n",
" deterministic=False\n",
" devices=None\n",
" distributed_backend=None\n",
" fast_dev_run=False\n",
" flush_logs_every_n_steps=100\n",
" gpus=None\n",
" gradient_accumulation_steps=1\n",
" gradient_clip_algorithm='norm'\n",
" gradient_clip_val=0.0\n",
" ipus=None\n",
" limit_predict_batches=1.0\n",
" limit_test_batches=1.0\n",
" limit_train_batches=1.0\n",
" limit_val_batches=1.0\n",
" litmodel_class='BertLitModel'\n",
" load_checkpoint=None\n",
" log_dir=''\n",
" log_every_n_steps=50\n",
" log_gpu_memory=None\n",
" logger=True\n",
" lr=3e-05\n",
" lr_2=3e-05\n",
" max_epochs='30'\n",
" max_seq_length=256\n",
" max_steps=None\n",
" max_time=None\n",
" min_epochs=None\n",
" min_steps=None\n",
" model_class='BertForMaskedLM'\n",
" model_name_or_path='bert-large-uncased'\n",
" move_metrics_to_cpu=False\n",
" multiple_trainloader_mode='max_size_cycle'\n",
" num_nodes=1\n",
" num_processes=1\n",
" num_sanity_val_steps=2\n",
" num_train_epochs=30\n",
" num_workers=8\n",
" optimizer='AdamW'\n",
" overfit_batches=0.0\n",
" plugins=None\n",
" precision=32\n",
" prepare_data_per_node=True\n",
" process_position=0\n",
" profiler=None\n",
" progress_bar_refresh_rate=None\n",
" ptune_k=7\n",
" reload_dataloaders_every_epoch=False\n",
" reload_dataloaders_every_n_epochs=0\n",
" replace_sampler_ddp=True\n",
" resume_from_checkpoint=None\n",
" save_path=''\n",
" seed=666\n",
" stochastic_weight_avg=False\n",
" sync_batchnorm=Falset_lambda=0.001\n",
" task_name='wiki80'\n",
" terminate_on_nan=False\n",
" tpu_cores=None\n",
" track_grad_norm=-1\n",
" train_from_saved_model=''\n",
" truncated_bptt_steps=None\n",
" two_steps=False\n",
" use_prompt=True\n",
" val_check_interval=1.0\n",
" wandb=False\n",
" weight_decay=0.01\n",
" weights_save_path=None\n",
" weights_summary='top'\n",
" \n",
"cfg = Config()"
],
"outputs": [],
"metadata": {}
},
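{
"cell_type": "markdown",
"source": [
"The values above mirror the command-line arguments of the original DeepKE training script. If you want to tweak a few of them for this notebook, you can simply override the attributes on the `cfg` instance. The snippet below is only an illustration; the attribute names come from the `Config` class above:\n",
"\n",
"```\n",
"cfg.batch_size = 8                # e.g. reduce the batch size for a smaller GPU\n",
"cfg.num_train_epochs = 10         # fewer epochs for a quick run\n",
"cfg.data_dir = 'data/k-shot/8-1'  # 8-shot split with seed 1, created later in this notebook\n",
"```"
],
"metadata": {}
},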
{
"cell_type": "markdown",
"source": [
"## 数据集预处理"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"class InputExampleWiki80(object):\n",
" \"\"\"A single training/test example for span pair classification.\"\"\"\n",
"\n",
" def __init__(self, guid, sentence, span1, span2, ner1, ner2, label):\n",
" self.guid = guid\n",
" self.sentence = sentence\n",
" self.span1 = span1\n",
" self.span2 = span2\n",
" self.ner1 = ner1\n",
" self.ner2 = ner2\n",
" self.label = label\n",
"\n",
"class DataProcessor(object):\n",
" \"\"\"Base class for data converters for sequence classification data sets.\"\"\"\n",
"\n",
" def get_train_examples(self, data_dir):\n",
" \"\"\"Gets a collection of `InputExample`s for the train set.\"\"\"\n",
" raise NotImplementedError()\n",
"\n",
" def get_dev_examples(self, data_dir):\n",
" \"\"\"Gets a collection of `InputExample`s for the dev set.\"\"\"\n",
" raise NotImplementedError()\n",
"\n",
" def get_labels(self):\n",
" \"\"\"Gets the list of labels for this data set.\"\"\"\n",
" raise NotImplementedError()\n",
"\n",
" @classmethod\n",
" def _read_tsv(cls, input_file, quotechar=None):\n",
" \"\"\"Reads a tab separated value file.\"\"\"\n",
" with open(input_file, \"r\") as f:\n",
" reader = csv.reader(f, delimiter=\"\\t\", quotechar=quotechar)\n",
" lines = []\n",
" for line in reader:\n",
" lines.append(line)\n",
" return lines\n",
"\n",
"class wiki80Processor(DataProcessor):\n",
" \"\"\"Processor for the TACRED data set.\"\"\"\n",
" def __init__(self, data_path, use_prompt):\n",
" super().__init__()\n",
" self.data_dir = data_path\n",
"\n",
" @classmethod\n",
" def _read_json(cls, input_file):\n",
" data = []\n",
" with open(input_file, \"r\", encoding='utf-8') as reader:\n",
" all_lines = reader.readlines()\n",
" for line in all_lines:\n",
" ins = eval(line)\n",
" data.append(ins)\n",
" return data\n",
"\n",
" def get_train_examples(self, data_dir):\n",
" \"\"\"See base class.\"\"\"\n",
" return self._create_examples(\n",
" self._read_json(os.path.join(data_dir, \"train.txt\")), \"train\")\n",
"\n",
" def get_dev_examples(self, data_dir):\n",
" \"\"\"See base class.\"\"\"\n",
" return self._create_examples(\n",
" self._read_json(os.path.join(data_dir, \"val.txt\")), \"dev\")\n",
"\n",
" def get_test_examples(self, data_dir):\n",
" \"\"\"See base class.\"\"\"\n",
" return self._create_examples(\n",
" self._read_json(os.path.join(data_dir, \"test.txt\")), \"test\")\n",
"\n",
" def get_labels(self, negative_label=\"no_relation\"):\n",
" data_dir = self.data_dir\n",
" \"\"\"See base class.\"\"\"\n",
" # if 'k-shot' in self.data_dir:\n",
" # data_dir = os.path.abspath(os.path.join(self.data_dir, \"../..\"))\n",
" # else:\n",
" # data_dir = self.data_dir\n",
" with open(os.path.join(data_dir,'rel2id.json'), \"r\", encoding='utf-8') as reader:\n",
" re2id = json.load(reader)\n",
" return re2id\n",
"\n",
"\n",
" def _create_examples(self, dataset, set_type):\n",
" \"\"\"Creates examples for the training and dev sets.\"\"\"\n",
" examples = []\n",
" for example in dataset:\n",
" sentence = example['token']\n",
" examples.append(InputExampleWiki80(guid=None,\n",
" sentence=sentence,\n",
" # maybe some bugs here, I don't -1\n",
" span1=(example['h']['pos'][0], example['h']['pos'][1]),\n",
" span2=(example['t']['pos'][0], example['t']['pos'][1]),\n",
" ner1=None,\n",
" ner2=None,\n",
" label=example['relation']))\n",
" return examples\n",
"\n",
"class Config(dict):\n",
" def __getattr__(self, name):\n",
" return self.get(name)\n",
"\n",
" def __setattr__(self, name, val):\n",
" self[name] = val\n",
"\n",
"\n",
"BATCH_SIZE = 8\n",
"NUM_WORKERS = 8\n",
"\n",
"\n",
"class BaseDataModule(nn.Module):\n",
" \"\"\"\n",
" Base DataModule.\n",
" \"\"\"\n",
"\n",
" def __init__(self, cfg) -> None:\n",
" super().__init__()\n",
" self.cfg = Config(vars(cfg)) if cfg is not None else {}\n",
" self.batch_size = self.cfg.get(\"batch_size\", BATCH_SIZE)\n",
" self.num_workers = self.cfg.get(\"num_workers\", NUM_WORKERS)\n",
"\n",
"\n",
" @staticmethod\n",
" def add_to_argparse(parser):\n",
" parser.add_argument(\n",
" \"--batch_size\", type=int, default=BATCH_SIZE, help=\"Number of examples to operate on per forward step.\"\n",
" )\n",
" parser.add_argument(\n",
" \"--num_workers\", type=int, default=NUM_WORKERS, help=\"Number of additional processes to load data.\"\n",
" )\n",
" parser.add_argument(\n",
" \"--data_dir\", type=str, default=\"./dataset/dialogue\", help=\"Number of additional processes to load data.\"\n",
" )\n",
" return parser\n",
"\n",
" def get_data_config(self):\n",
" \"\"\"Return important settings of the dataset, which will be passed to instantiate models.\"\"\"\n",
" return { \"num_labels\": self.num_labels}\n",
"\n",
" def prepare_data(self):\n",
" \"\"\"\n",
" Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`).\n",
" \"\"\"\n",
" pass\n",
"\n",
" def setup(self, stage=None):\n",
" \"\"\"\n",
" Split into train, val, test, and set dims.\n",
" Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test.\n",
" \"\"\"\n",
" self.data_train = None\n",
" self.data_val = None\n",
" self.data_test = None\n",
"\n",
" def train_dataloader(self):\n",
" return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, drop_last=True)\n",
"\n",
" def val_dataloader(self):\n",
" return DataLoader(self.data_val, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, drop_last=True)\n",
"\n",
" def test_dataloader(self):\n",
" return DataLoader(self.data_test, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, drop_last=True)\n",
"\n",
"\n",
"class REDataset(BaseDataModule):\n",
" def __init__(self, cfg) -> None:\n",
" super().__init__(cfg)\n",
" self.processor = wiki80Processor(self.cfg.data_dir, self.cfg.use_prompt)\n",
" self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name_or_path)\n",
" \n",
" \n",
" use_gpt = \"gpt\" in cfg.model_name_or_path\n",
"\n",
" rel2id = self.processor.get_labels()\n",
" self.num_labels = len(rel2id)\n",
"\n",
" entity_list = [\"[object_start]\", \"[object_end]\", \"[subject_start]\", \"[subject_end]\"]\n",
" class_list = [f\"[class{i}]\" for i in range(1, self.num_labels+1)]\n",
"\n",
" num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': entity_list})\n",
" num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': class_list})\n",
" if use_gpt:\n",
" self.tokenizer.add_special_tokens({'cls_token': \"[CLS]\"})\n",
" self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})\n",
" so_list = [\"[sub]\", \"[obj]\"]\n",
" num_added_tokens = self.tokenizer.add_special_tokens({'additional_special_tokens': so_list})\n",
"\n",
" prompt_tokens = [f\"[T{i}]\" for i in range(1,6)]\n",
" self.tokenizer.add_special_tokens({'additional_special_tokens': prompt_tokens})"
],
"outputs": [],
"metadata": {}
},
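{
"cell_type": "markdown",
"source": [
"Before wiring the data module into training, it can help to sanity-check the processor on the raw `./data/` folder. The sketch below is illustrative only and assumes the SEMEVAL files have been downloaded as described in the Dataset section:\n",
"\n",
"```\n",
"processor = wiki80Processor(data_path='data', use_prompt=True)\n",
"rel2id = processor.get_labels()                      # relation label -> id mapping from rel2id.json\n",
"train_examples = processor.get_train_examples('data')\n",
"print(len(rel2id), 'relations;', len(train_examples), 'training examples')\n",
"print(train_examples[0].sentence[:10], '->', train_examples[0].label)\n",
"```"
],
"metadata": {}
},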
{
"cell_type": "markdown",
"source": [
"## 计算指标函数"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"def dialog_f1_eval(logits, labels):\n",
" def getpred(result, T1=0.5, T2=0.4):\n",
" # 使用阈值得到preds, result = logits\n",
" # T2 表示如果都低于T2 那么就是 no relation, 否则选取一个最大的\n",
" ret = []\n",
" for i in range(len(result)):\n",
" r = []\n",
" maxl, maxj = -1, -1\n",
" for j in range(len(result[i])):\n",
" if result[i][j] > T1:\n",
" r += [j]\n",
" if result[i][j] > maxl:\n",
" maxl = result[i][j]\n",
" maxj = j\n",
" if len(r) == 0:\n",
" if maxl <= T2:\n",
" r = [36]\n",
" else:\n",
" r += [maxj]\n",
" ret.append(r)\n",
" return ret\n",
"\n",
" def geteval(devp, data):\n",
" correct_sys, all_sys = 0, 0\n",
" correct_gt = 0\n",
"\n",
" for i in range(len(data)):\n",
" # 每一个样本 都是[1,4,...,20] 表示有1,4,20 是1 如果没有就是[36]\n",
" for id in data[i]:\n",
" if id != 36:\n",
" # 标签中 1 的个数\n",
" correct_gt += 1\n",
" if id in devp[i]:\n",
" # 预测正确\n",
" correct_sys += 1\n",
"\n",
" for id in devp[i]:\n",
" if id != 36:\n",
" all_sys += 1\n",
"\n",
" precision = 1 if all_sys == 0 else correct_sys / all_sys\n",
" recall = 0 if correct_gt == 0 else correct_sys / correct_gt\n",
" f_1 = 2 * precision * recall / (precision + recall) if precision + recall != 0 else 0\n",
" return f_1\n",
"\n",
" logits = np.asarray(logits)\n",
" logits = list(1 / (1 + np.exp(-logits)))\n",
"\n",
" temp_labels = []\n",
" for l in labels:\n",
" t = []\n",
" for i in range(36):\n",
" if l[i] == 1:\n",
" t += [i]\n",
" if len(t) == 0:\n",
" t = [36]\n",
" temp_labels.append(t)\n",
" assert (len(labels) == len(logits))\n",
" labels = temp_labels\n",
"\n",
" bestT2 = bestf_1 = 0\n",
" for T2 in range(51):\n",
" devp = getpred(logits, T2=T2 / 100.)\n",
" f_1 = geteval(devp, labels)\n",
" if f_1 > bestf_1:\n",
" bestf_1 = f_1\n",
" bestT2 = T2 / 100.\n",
"\n",
" return dict(f1=bestf_1, T2=bestT2)\n",
"\n",
"\n",
"\n",
"def f1_eval(logits, labels):\n",
" def getpred(result, T1 = 0.5, T2 = 0.4) :\n",
" # 使用阈值得到preds, result = logits\n",
" # T2 表示如果都低于T2 那么就是 no relation, 否则选取一个最大的\n",
" ret = []\n",
" for i in range(len(result)):\n",
" r = []\n",
" maxl, maxj = -1, -1\n",
" for j in range(len(result[i])):\n",
" if result[i][j] > T1:\n",
" r += [j]\n",
" if result[i][j] > maxl:\n",
" maxl = result[i][j]\n",
" maxj = j\n",
" if len(r) == 0:\n",
" if maxl <= T2:\n",
" r = [36]\n",
" else:\n",
" r += [maxj]\n",
" ret.append(r)\n",
" return ret\n",
"\n",
" def geteval(devp, data):\n",
" correct_sys, all_sys = 0, 0\n",
" correct_gt = 0\n",
" \n",
" for i in range(len(data)):\n",
" # 每一个样本 都是[1,4,...,20] 表示有1,4,20 是1 如果没有就是[36]\n",
" for id in data[i]:\n",
" if id != 36:\n",
" # 标签中 1 的个数\n",
" correct_gt += 1\n",
" if id in devp[i]:\n",
" # 预测正确\n",
" correct_sys += 1\n",
"\n",
" for id in devp[i]:\n",
" if id != 36:\n",
" all_sys += 1\n",
"\n",
" precision = 1 if all_sys == 0 else correct_sys/all_sys\n",
" recall = 0 if correct_gt == 0 else correct_sys/correct_gt\n",
" f_1 = 2*precision*recall/(precision+recall) if precision+recall != 0 else 0\n",
" return f_1\n",
"\n",
" logits = np.asarray(logits)\n",
" logits = list(1 / (1 + np.exp(-logits)))\n",
"\n",
" temp_labels = []\n",
" for l in labels:\n",
" t = []\n",
" for i in range(36):\n",
" if l[i] == 1:\n",
" t += [i]\n",
" if len(t) == 0:\n",
" t = [36]\n",
" temp_labels.append(t)\n",
" assert(len(labels) == len(logits))\n",
" labels = temp_labels\n",
" \n",
" bestT2 = bestf_1 = 0\n",
" for T2 in range(51):\n",
" devp = getpred(logits, T2=T2/100.)\n",
" f_1 = geteval(devp, labels)\n",
" if f_1 > bestf_1:\n",
" bestf_1 = f_1\n",
" bestT2 = T2/100.\n",
"\n",
" return bestf_1, bestT2\n",
"\n",
"\n",
"def f1_score(output, label, rel_num=42, na_num=13):\n",
" correct_by_relation = Counter()\n",
" guess_by_relation = Counter()\n",
" gold_by_relation = Counter()\n",
" output = np.argmax(output, axis=-1)\n",
"\n",
" for i in range(len(output)):\n",
" guess = output[i]\n",
" gold = label[i]\n",
"\n",
" if guess == na_num:\n",
" guess = 0\n",
" elif guess < na_num:\n",
" guess += 1\n",
"\n",
" if gold == na_num:\n",
" gold = 0\n",
" elif gold < na_num:\n",
" gold += 1\n",
"\n",
" if gold == 0 and guess == 0:\n",
" continue\n",
" if gold == 0 and guess != 0:\n",
" guess_by_relation[guess] += 1\n",
" if gold != 0 and guess == 0:\n",
" gold_by_relation[gold] += 1\n",
" if gold != 0 and guess != 0:\n",
" guess_by_relation[guess] += 1\n",
" gold_by_relation[gold] += 1\n",
" if gold == guess:\n",
" correct_by_relation[gold] += 1\n",
" \n",
" f1_by_relation = Counter()\n",
" recall_by_relation = Counter()\n",
" prec_by_relation = Counter()\n",
" for i in range(1, rel_num):\n",
" recall = 0\n",
" if gold_by_relation[i] > 0:\n",
" recall = correct_by_relation[i] / gold_by_relation[i]\n",
" precision = 0\n",
" if guess_by_relation[i] > 0:\n",
" precision = correct_by_relation[i] / guess_by_relation[i]\n",
" if recall + precision > 0 :\n",
" f1_by_relation[i] = 2 * recall * precision / (recall + precision)\n",
" recall_by_relation[i] = recall\n",
" prec_by_relation[i] = precision\n",
"\n",
" micro_f1 = 0\n",
" if sum(guess_by_relation.values()) != 0 and sum(correct_by_relation.values()) != 0:\n",
" recall = sum(correct_by_relation.values()) / sum(gold_by_relation.values())\n",
" prec = sum(correct_by_relation.values()) / sum(guess_by_relation.values()) \n",
" micro_f1 = 2 * recall * prec / (recall+prec)\n",
"\n",
" return dict(f1=micro_f1)"
],
"outputs": [],
"metadata": {}
},
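{
"cell_type": "markdown",
"source": [
"As a quick illustration of `f1_score`, here is a tiny made-up example with three relation classes, where class id 2 plays the role of \"no relation\" (all numbers are hypothetical):\n",
"\n",
"```\n",
"toy_logits = np.array([[0.1, 0.9, 0.0],\n",
"                       [0.8, 0.1, 0.1]])\n",
"toy_labels = [1, 0]\n",
"print(f1_score(toy_logits, toy_labels, rel_num=3, na_num=2))   # {'f1': 1.0}\n",
"```"
],
"metadata": {}
},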
{
"cell_type": "markdown",
"source": [
"## 模型构建"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"### 模型基类"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"OPTIMIZER = \"AdamW\"\n",
"LR = 5e-5\n",
"LOSS = \"cross_entropy\"\n",
"ONE_CYCLE_TOTAL_STEPS = 100\n",
"\n",
"class Config(dict):\n",
" def __getattr__(self, name):\n",
" return self.get(name)\n",
"\n",
" def __setattr__(self, name, val):\n",
" self[name] = val\n",
"\n",
"\n",
"class BaseLitModel(nn.Module):\n",
" \"\"\"\n",
" Generic PyTorch-Lightning class that must be initialized with a PyTorch module.\n",
" \"\"\"\n",
"\n",
" def __init__(self, model, cfg: argparse.Namespace = None, device: torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') = str):\n",
" super().__init__()\n",
" self.model = model\n",
" self.cur_model = model.module if hasattr(model, 'module') else model\n",
" self.device = device\n",
" self.cfg = Config(vars(cfg)) if cfg is not None else {}\n",
"\n",
" optimizer = self.cfg.get(\"optimizer\", OPTIMIZER)\n",
" self.optimizer_class = getattr(torch.optim, optimizer)\n",
" self.lr = self.cfg.get(\"lr\", LR)\n",
"\n",
"\n",
" @staticmethod\n",
" def add_to_argparse(parser):\n",
" parser.add_argument(\"--optimizer\", type=str, default=OPTIMIZER, help=\"optimizer class from torch.optim\")\n",
" parser.add_argument(\"--lr\", type=float, default=LR)\n",
" parser.add_argument(\"--weight_decay\", type=float, default=0.01)\n",
" return parser\n",
"\n",
" def configure_optimizers(self):\n",
" optimizer = self.optimizer_class(self.parameters(), lr=self.lr)\n",
" if self.one_cycle_max_lr is None:\n",
" return optimizer\n",
" scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, max_lr=self.one_cycle_max_lr, total_steps=self.one_cycle_total_steps)\n",
" return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler, \"monitor\": \"val_loss\"}\n",
"\n",
" def forward(self, x):\n",
" return self.model(x)\n",
"\n",
" def training_step(self, batch, batch_idx): # pylint: disable=unused-argument\n",
" x, y = batch\n",
" x.to(self.device)\n",
" logits = x\n",
" loss = (logits - y) ** 2\n",
" print(\"train_loss: \", loss)\n",
" #self.train_acc(logits, y)\n",
" #self.log(\"train_acc\", self.train_acc, on_step=False, on_epoch=True)\n",
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx): # pylint: disable=unused-argument\n",
" x, y = batch\n",
" x.to(self.device)\n",
" logits = x\n",
" loss = (logits - y) ** 2\n",
" print(\"val_loss: \", loss)\n",
"\n",
" def test_step(self, batch, batch_idx): # pylint: disable=unused-argument\n",
" x, y = batch\n",
" x.to(self.device)\n",
" logits = x\n",
" loss = (logits - y) ** 2\n",
" print(\"test_loss: \", loss)\n",
"\n",
" def configure_optimizers(self):\n",
" no_decay_param = [\"bias\", \"LayerNorm.weight\"]\n",
"\n",
" optimizer_group_parameters = [\n",
" {\"params\": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay_param)], \"weight_decay\": self.cfg.weight_decay},\n",
" {\"params\": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay_param)], \"weight_decay\": 0}\n",
" ]\n",
"\n",
" \n",
" optimizer = self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)\n",
" #scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.num_training_steps * 0.1, num_training_steps=self.num_training_steps)\n",
" return optimizer"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"### 模型子类"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"def multilabel_categorical_crossentropy(y_pred, y_true):\n",
" y_pred = (1 - 2 * y_true) * y_pred\n",
" y_pred_neg = y_pred - y_true * 1e12\n",
" y_pred_pos = y_pred - (1 - y_true) * 1e12\n",
" zeros = torch.zeros_like(y_pred[..., :1])\n",
" y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)\n",
" y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)\n",
" neg_loss = torch.logsumexp(y_pred_neg, dim=-1)\n",
" pos_loss = torch.logsumexp(y_pred_pos, dim=-1)\n",
" return (neg_loss + pos_loss).mean()\n",
"\n",
"class BertLitModel(BaseLitModel):\n",
" \"\"\"\n",
" use AutoModelForMaskedLM, and select the output by another layer in the lit model\n",
" \"\"\"\n",
" def __init__(self, model, cfg, tokenizer):\n",
" super().__init__(model, cfg)\n",
" self.tokenizer = tokenizer\n",
" \n",
" with open(f\"{cfg.data_dir}/rel2id.json\",\"r\") as file:\n",
" rel2id = json.load(file)\n",
" \n",
" Na_num = 0\n",
" for k, v in rel2id.items():\n",
" if k == \"NA\" or k == \"no_relation\" or k == \"Other\":\n",
" Na_num = v\n",
" break\n",
" num_relation = len(rel2id)\n",
" # init loss function\n",
" self.loss_fn = multilabel_categorical_crossentropy if \"dialogue\" in cfg.data_dir else nn.CrossEntropyLoss()\n",
" # ignore the no_relation class to compute the f1 score\n",
" self.eval_fn = f1_eval if \"dialogue\" in cfg.data_dir else partial(f1_score, rel_num=num_relation, na_num=Na_num)\n",
" self.best_f1 = 0\n",
" self.t_lambda = cfg.t_lambda\n",
" \n",
" self.label_st_id = tokenizer(\"[class1]\", add_special_tokens=False)['input_ids'][0]\n",
" \n",
" self._init_label_word()\n",
"\n",
" def _init_label_word(self):\n",
" cfg = self.cfg\n",
" # ./dataset/dataset_name\n",
" dataset_name = cfg.data_dir.split(\"/\")[1]\n",
" model_name_or_path = cfg.model_name_or_path.split(\"/\")[-1]\n",
" label_path = f\"./dataset/{model_name_or_path}_{dataset_name}.pt\"\n",
" # [num_labels, num_tokens], ignore the unanswerable\n",
" if \"dialogue\" in cfg.data_dir:\n",
" label_word_idx = torch.load(label_path)[:-1]\n",
" else:\n",
" label_word_idx = torch.load(label_path)\n",
" \n",
" num_labels = len(label_word_idx)\n",
" \n",
" self.cur_model.resize_token_embeddings(len(self.tokenizer))\n",
" with torch.no_grad():\n",
" word_embeddings = self.cur_model.get_input_embeddings()\n",
" continous_label_word = [a[0] for a in self.tokenizer([f\"[class{i}]\" for i in range(1, num_labels+1)], add_special_tokens=False)['input_ids']]\n",
" for i, idx in enumerate(label_word_idx):\n",
" word_embeddings.weight[continous_label_word[i]] = torch.mean(word_embeddings.weight[idx], dim=0)\n",
" # word_embeddings.weight[continous_label_word[i]] = self.relation_embedding[i]\n",
" so_word = [a[0] for a in self.tokenizer([\"[obj]\",\"[sub]\"], add_special_tokens=False)['input_ids']]\n",
" meaning_word = [a[0] for a in self.tokenizer([\"person\",\"organization\", \"location\", \"date\", \"country\"], add_special_tokens=False)['input_ids']]\n",
" \n",
" for i, idx in enumerate(so_word):\n",
" word_embeddings.weight[so_word[i]] = torch.mean(word_embeddings.weight[meaning_word], dim=0)\n",
" assert torch.equal(self.cur_model.get_input_embeddings().weight, word_embeddings.weight)\n",
" assert torch.equal(self.cur_model.get_input_embeddings().weight, self.cur_model.get_output_embeddings().weight)\n",
" \n",
" self.word2label = continous_label_word # a continous list\n",
" \n",
" \n",
" def forward(self, x):\n",
" return self.model(x)\n",
"\n",
" def training_step(self, batch, batch_idx): # pylint: disable=unused-argument\n",
" input_ids, attention_mask, token_type_ids , labels, so = batch\n",
" input_ids = input_ids.to(self.device)\n",
" attention_mask = attention_mask.to(self.device)\n",
" token_type_ids = token_type_ids.to(self.device)\n",
" labels = labels.to(self.device)\n",
" so = so.to(self.device)\n",
" result = self.model(input_ids, attention_mask, token_type_ids, return_dict=True, output_hidden_states=True)\n",
" logits = result.logits\n",
" output_embedding = result.hidden_states[-1]\n",
" logits = self.pvp(logits, input_ids)\n",
" loss = self.loss_fn(logits, labels) + self.t_lambda * self.ke_loss(output_embedding, labels, so)\n",
" #print(\"Train/loss: \", loss)\n",
" return loss\n",
"\n",
" def validation_step(self, batch, batch_idx): # pylint: disable=unused-argument\n",
" input_ids, attention_mask, token_type_ids , labels, _ = batch\n",
" input_ids = input_ids.to(self.device)\n",
" attention_mask = attention_mask.to(self.device)\n",
" token_type_ids = token_type_ids.to(self.device)\n",
" labels = labels.to(self.device)\n",
" logits = self.model(input_ids, attention_mask, token_type_ids, return_dict=True).logits\n",
" logits = self.pvp(logits, input_ids)\n",
" loss = self.loss_fn(logits, labels)\n",
" #print(\"Eval/loss: \", loss)\n",
" return {\"loss\": loss, \"eval_logits\": logits.detach().cpu().numpy(), \"eval_labels\": labels.detach().cpu().numpy()}\n",
" \n",
" def validation_epoch_end(self, outputs):\n",
" logits = np.concatenate([o[\"eval_logits\"] for o in outputs])\n",
" labels = np.concatenate([o[\"eval_labels\"] for o in outputs])\n",
"\n",
" f1 = self.eval_fn(logits, labels)['f1']\n",
" #print(\"Eval/f1: \", f1)\n",
" best_f1 = -1\n",
" if f1 > self.best_f1:\n",
" self.best_f1 = f1\n",
" best_f1 = self.best_f1\n",
" #print(\"Eval/best_f1: \", self.best_f1)\n",
" return f1, best_f1, self.best_f1\n",
"\n",
" def test_step(self, batch, batch_idx): # pylint: disable=unused-argument\n",
" input_ids, attention_mask, token_type_ids , labels, _ = batch\n",
" input_ids = input_ids.to(self.device)\n",
" attention_mask = attention_mask.to(self.device)\n",
" token_type_ids = token_type_ids.to(self.device)\n",
" labels = labels.to(self.device)\n",
" logits = self.model(input_ids, attention_mask, token_type_ids, return_dict=True).logits\n",
" logits = self.pvp(logits, input_ids)\n",
" return {\"test_logits\": logits.detach().cpu().numpy(), \"test_labels\": labels.detach().cpu().numpy()}\n",
"\n",
" def test_epoch_end(self, outputs):\n",
" logits = np.concatenate([o[\"test_logits\"] for o in outputs])\n",
" labels = np.concatenate([o[\"test_labels\"] for o in outputs])\n",
"\n",
" f1 = self.eval_fn(logits, labels)['f1']\n",
" #print(\"Test/f1: \", f1)\n",
" return f1\n",
"\n",
"\n",
" @staticmethod\n",
" def add_to_argparse(parser):\n",
" BaseLitModel.add_to_argparse(parser)\n",
" parser.add_argument(\"--t_lambda\", type=float, default=0.01, help=\"\")\n",
" return parser\n",
" \n",
" def pvp(self, logits, input_ids):\n",
" # convert the [batch_size, seq_len, vocab_size] => [batch_size, num_labels]\n",
" #! hard coded\n",
" _, mask_idx = (input_ids == 103).nonzero(as_tuple=True)\n",
" bs = input_ids.shape[0]\n",
" mask_output = logits[torch.arange(bs), mask_idx]\n",
" assert mask_idx.shape[0] == bs, \"only one mask in sequence!\"\n",
" final_output = mask_output[:,self.word2label]\n",
" \n",
" return final_output\n",
" \n",
" def ke_loss(self, logits, labels, so):\n",
" subject_embedding = []\n",
" object_embedding = []\n",
" bsz = logits.shape[0]\n",
" for i in range(bsz):\n",
" subject_embedding.append(torch.mean(logits[i, so[i][0]:so[i][1]], dim=0))\n",
" object_embedding.append(torch.mean(logits[i, so[i][2]:so[i][3]], dim=0))\n",
" \n",
" subject_embedding = torch.stack(subject_embedding)\n",
" object_embedding = torch.stack(object_embedding)\n",
" # trick , the relation ids is concated, \n",
" relation_embedding = self.cur_model.get_output_embeddings().weight[labels+self.label_st_id]\n",
" \n",
" loss = torch.norm(subject_embedding + relation_embedding - object_embedding, p=2)\n",
" \n",
" return loss\n",
"\n",
" def configure_optimizers(self):\n",
" no_decay_param = [\"bias\", \"LayerNorm.weight\"]\n",
"\n",
" if not self.cfg.two_steps: \n",
" parameters = self.cur_model.named_parameters()\n",
" else:\n",
" # cur_model.bert.embeddings.weight\n",
" parameters = [next(self.cur_model.named_parameters())]\n",
" # only optimize the embedding parameters\n",
" optimizer_group_parameters = [\n",
" {\"params\": [p for n, p in parameters if not any(nd in n for nd in no_decay_param)], \"weight_decay\": self.cfg.weight_decay},\n",
" {\"params\": [p for n, p in parameters if any(nd in n for nd in no_decay_param)], \"weight_decay\": 0}\n",
" ]\n",
"\n",
" \n",
" optimizer = self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)\n",
" return optimizer"
],
"outputs": [],
"metadata": {}
},
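{
"cell_type": "markdown",
"source": [
"The `pvp` method is the verbalizer step: it picks the MLM logits at the `[MASK]` position and keeps only the columns of the virtual label words `[class1]`, `[class2]`, .... A toy version on random tensors (hypothetical token ids and label-word ids) looks like this:\n",
"\n",
"```\n",
"batch, seq_len, vocab = 2, 5, 10\n",
"toy_logits = torch.randn(batch, seq_len, vocab)\n",
"toy_input_ids = torch.tensor([[101, 7, 103, 8, 102],\n",
"                              [101, 103, 9, 6, 102]])    # 103 = [MASK] for BERT tokenizers\n",
"word2label = [4, 5, 6]                                   # hypothetical ids of [class1]..[class3]\n",
"_, mask_idx = (toy_input_ids == 103).nonzero(as_tuple=True)\n",
"mask_output = toy_logits[torch.arange(batch), mask_idx]  # [batch, vocab]\n",
"print(mask_output[:, word2label].shape)                  # torch.Size([2, 3])\n",
"```"
],
"metadata": {}
},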
{
"cell_type": "markdown",
"source": [
"## 输入预处理"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"### 少样本采集"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"Seed = [1, 2, 3, 4, 5]\n",
"mode = 'k-shot'\n",
"data_file = 'train.txt'\n",
"\n",
"def get_labels(path, name, negative_label=\"no_relation\"):\n",
" \"\"\"See base class.\"\"\"\n",
"\n",
" count = Counter()\n",
" with open(path + \"/\" + name, \"r\") as f:\n",
" features = []\n",
" for line in f.readlines():\n",
" line = line.rstrip()\n",
" if len(line) > 0:\n",
" # count[line['relation']] += 1\n",
" features.append(eval(line))\n",
"\n",
" # logger.info(\"label distribution as list: %d labels\" % len(count))\n",
" # # Make sure the negative label is alwyas 0\n",
" # labels = []\n",
" # for label, count in count.most_common():\n",
" # logger.info(\"%s: %d 个 %.2f%%\" % (label, count, count * 100.0 / len(dataset)))\n",
" # if label not in labels:\n",
" # labels.append(label)\n",
" return features\n",
"\n",
"path = 'data'\n",
"\n",
"output_dir = os.path.join(path, mode)\n",
"dataset = get_labels(path, data_file)\n",
"\n",
"for seed in Seed:\n",
"\n",
" # Other datasets\n",
" np.random.seed(seed)\n",
" np.random.shuffle(dataset)\n",
"\n",
" # Set up dir\n",
" k = 8\n",
" setting_dir = os.path.join(output_dir, f\"{k}-{seed}\")\n",
" os.makedirs(setting_dir, exist_ok=True)\n",
"\n",
" label_list = {}\n",
" for line in dataset:\n",
" label = line['relation']\n",
" if label not in label_list:\n",
" label_list[label] = [line]\n",
" else:\n",
" label_list[label].append(line)\n",
"\n",
" with open(os.path.join(setting_dir, \"train.txt\"), \"w\") as f:\n",
" file_list = []\n",
" for label in label_list:\n",
" for line in label_list[label][:k]: # train中每一类取前k个数据\n",
" f.writelines(json.dumps(line))\n",
" f.write('\\n')\n",
"\n",
" f.close()\n",
"\n",
"shutil.copyfile('data/rel2id.json','data/k-shot/8-1/rel2id.json')\n",
"shutil.copyfile('data/val.txt','data/k-shot/8-1/val.txt')\n",
"shutil.copyfile('data/test.txt','data/k-shot/8-1/test.txt')"
],
"outputs": [],
"metadata": {}
},
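{
"cell_type": "markdown",
"source": [
"After sampling, each relation should contribute at most `k = 8` training examples. A quick sanity check (assuming the cell above has been run, so `data/k-shot/8-1/train.txt` exists):\n",
"\n",
"```\n",
"with open('data/k-shot/8-1/train.txt') as f:\n",
"    sampled = [eval(line) for line in f if line.strip()]\n",
"print(Counter(ex['relation'] for ex in sampled))\n",
"```"
],
"metadata": {}
},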
{
"cell_type": "markdown",
"source": [
"### 获取标签"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"def split_label_words(tokenizer, label_list):\n",
" label_word_list = []\n",
" for label in label_list:\n",
" if label == 'no_relation':\n",
" label_word_id = tokenizer.encode('None', add_special_tokens=False)\n",
" label_word_list.append(torch.tensor(label_word_id))\n",
" else:\n",
" tmps = label\n",
" label = label.lower()\n",
" label = label.split(\"(\")[0]\n",
" label = label.replace(\":\",\" \").replace(\"_\",\" \").replace(\"per\",\"person\").replace(\"org\",\"organization\")\n",
" label_word_id = tokenizer(label, add_special_tokens=False)['input_ids']\n",
" print(label, label_word_id)\n",
" label_word_list.append(torch.tensor(label_word_id))\n",
" padded_label_word_list = pad_sequence([x for x in label_word_list], batch_first=True, padding_value=0)\n",
" return padded_label_word_list\n",
"\n",
"\n",
"model_name_or_path = cfg.model_name_or_path\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
"with open(\"data/rel2id.json\", \"r\") as file:\n",
" t = json.load(file)\n",
" label_list = list(t)\n",
"\n",
"t = split_label_words(tokenizer, label_list)\n",
"\n",
"with open(f\"data/{model_name_or_path}.pt\", \"wb\") as file:\n",
" torch.save(t, file)"
],
"outputs": [],
"metadata": {}
},
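{
"cell_type": "markdown",
"source": [
"To see what `split_label_words` does to a single label, here is the same string normalization applied to one SEMEVAL-style relation name (the label value is just an example):\n",
"\n",
"```\n",
"label = 'Cause-Effect(e1,e2)'\n",
"label = label.lower().split('(')[0]\n",
"label = label.replace(':', ' ').replace('_', ' ').replace('per', 'person').replace('org', 'organization')\n",
"print(label)                                              # cause-effect\n",
"print(tokenizer(label, add_special_tokens=False)['input_ids'])\n",
"```"
],
"metadata": {}
},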
{
"cell_type": "markdown",
"source": [
"## 辅助函数"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"def set_seed(cfg):\n",
" torch.cuda.manual_seed_all(cfg.seed)\n",
" np.random.seed(cfg.seed)\n",
" torch.manual_seed(cfg.seed)\n",
" torch.cuda.manual_seed_all(cfg.seed)\n",
"\n",
"def logging(s, print_=True, log_=True):\n",
" if print_:\n",
" print(s)\n",
" if log_:\n",
" with open(cfg.log_dir, 'a+') as f_log:\n",
" f_log.write(s + '\\n')"
],
"outputs": [],
"metadata": {}
},
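{
"cell_type": "markdown",
"source": [
"`logging` prints the message and, when a log file path is given, also appends it to that file; with the empty default `cfg.log_dir` it only prints. For example, to also keep a training log on disk (the file name is arbitrary):\n",
"\n",
"```\n",
"cfg.log_dir = 'train.log'   # optional: write the training log to a file as well\n",
"logging(cfg.log_dir, 'starting the few-shot RE tutorial')\n",
"```"
],
"metadata": {}
},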
{
"cell_type": "markdown",
"source": [
"## 模型训练实践"
],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"### 模型训练"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
"\n",
"device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"data = REDataset(cfg)\n",
"data_config = data.get_data_config()\n",
"\n",
"config = AutoConfig.from_pretrained(cfg.model_name_or_path)\n",
"config.num_labels = data_config[\"num_labels\"]\n",
"\n",
"model = BertForMaskedLM.from_pretrained(cfg.model_name_or_path, config=config)\n",
"\n",
"if cfg.train_from_saved_model != '':\n",
" model.load_state_dict(torch.load(cfg.train_from_saved_model)[\"checkpoint\"])\n",
" print(\"load saved model from {}.\".format(cfg.train_from_saved_model))\n",
"\n",
" \n",
"if torch.cuda.device_count() > 1:\n",
" print(\"Let's use\", torch.cuda.device_count(), \"GPUs!\")\n",
" model = torch.nn.DataParallel(model, device_ids = list(range(torch.cuda.device_count())))\n",
"model.to(device)\n",
"\n",
"cur_model = model.module if hasattr(model, 'module') else model\n",
"\n",
"\n",
"if \"gpt\" in cfg.model_name_or_path or \"roberta\" in cfg.model_name_or_path:\n",
" tokenizer = data.get_tokenizer()\n",
" cur_model.resize_token_embeddings(len(tokenizer))\n",
" cur_model.update_word_idx(len(tokenizer))\n",
" if \"Use\" in cfg.model_class:\n",
" continous_prompt = [a[0] for a in tokenizer([f\"[T{i}]\" for i in range(1,3)], add_special_tokens=False)['input_ids']]\n",
" continous_label_word = [a[0] for a in tokenizer([f\"[class{i}]\" for i in range(1, data.num_labels+1)], add_special_tokens=False)['input_ids']]\n",
" discrete_prompt = [a[0] for a in tokenizer(['It', 'was'], add_special_tokens=False)['input_ids']]\n",
" dataset_name = cfg.data_dir.split(\"/\")[1]\n",
" model.init_unused_weights(continous_prompt, continous_label_word, discrete_prompt, label_path=f\"{cfg.model_name_or_path}_{dataset_name}.pt\")\n",
"# data.setup()\n",
"# relation_embedding = _get_relation_embedding(data)\n",
"lit_model = BertLitModel(cfg=cfg, model=model, tokenizer=data.tokenizer, device=device)\n",
"if cfg.train_from_saved_model != '':\n",
" lit_model.best_f1 = torch.load(cfg.train_from_saved_model)[\"best_f1\"]\n",
"data.tokenizer.save_pretrained('test')\n",
"data.setup()\n",
"\n",
"optimizer = lit_model.configure_optimizers()\n",
"if cfg.train_from_saved_model != '':\n",
" optimizer.load_state_dict(torch.load(cfg.train_from_saved_model)[\"optimizer\"])\n",
" print(\"load saved optimizer from {}.\".format(cfg.train_from_saved_model))\n",
"\n",
"num_training_steps = len(data.train_dataloader()) // cfg.gradient_accumulation_steps * cfg.num_train_epochs\n",
"scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_training_steps * 0.1, num_training_steps=num_training_steps)\n",
"log_step = 100\n",
"\n",
"\n",
"logging(cfg.log_dir,'-' * 89, print_=False)\n",
"logging(cfg.log_dir, time.strftime(\"%Y-%m-%d %H:%M:%S\", time.localtime()) + ' INFO : START TO TRAIN ', print_=False)\n",
"logging(cfg.log_dir,'-' * 89, print_=False)\n",
"\n",
"for epoch in range(cfg.num_train_epochs):\n",
" model.train()\n",
" num_batch = len(data.train_dataloader())\n",
" total_loss = 0\n",
" log_loss = 0\n",
" for index, train_batch in enumerate(tqdm(data.train_dataloader())):\n",
" loss = lit_model.training_step(train_batch, index)\n",
" total_loss += loss.item()\n",
" log_loss += loss.item()\n",
" loss.backward()\n",
"\n",
" optimizer.step()\n",
" scheduler.step()\n",
" optimizer.zero_grad()\n",
"\n",
" if log_step > 0 and (index+1) % log_step == 0:\n",
" cur_loss = log_loss / log_step\n",
" logging(cfg.log_dir, \n",
" '| epoch {:2d} | step {:4d} | lr {} | train loss {:5.3f}'.format(\n",
" epoch, (index+1), scheduler.get_last_lr(), cur_loss * 1000)\n",
" , print_=False)\n",
" log_loss = 0\n",
" avrg_loss = total_loss / num_batch\n",
" logging(cfg.log_dir,\n",
" '| epoch {:2d} | train loss {:5.3f}'.format(\n",
" epoch, avrg_loss * 1000))\n",
" \n",
" model.eval()\n",
" with torch.no_grad():\n",
" val_loss = []\n",
" for val_index, val_batch in enumerate(tqdm(data.val_dataloader())):\n",
" loss = lit_model.validation_step(val_batch, val_index)\n",
" val_loss.append(loss)\n",
" f1, best, best_f1 = lit_model.validation_epoch_end(val_loss)\n",
" logging(cfg.log_dir,'-' * 89)\n",
" logging(cfg.log_dir,\n",
" '| epoch {:2d} | dev_result: {}'.format(epoch, f1))\n",
" logging(cfg.log_dir,'-' * 89)\n",
" logging(cfg.log_dir,\n",
" '| best_f1: {}'.format(best_f1))\n",
" logging(cfg.log_dir,'-' * 89)\n",
" if cfg.save_path != \"\" and best != -1:\n",
" file_name = f\"{epoch}-Eval_f1-{best_f1:.2f}.pt\"\n",
" save_path = cfg.save_path + '/' + file_name\n",
" torch.save({\n",
" 'epoch': epoch,\n",
" 'checkpoint': cur_model.state_dict(),\n",
" 'best_f1': best_f1,\n",
" 'optimizer': optimizer.state_dict()\n",
" }, save_path\n",
" , _use_new_zipfile_serialization=False)\n",
" logging(cfg.log_dir,\n",
" '| successfully save model at: {}'.format(save_path))\n",
" logging(cfg.log_dir,'-' * 89)"
],
"outputs": [],
"metadata": {}
},
{
"cell_type": "markdown",
"source": [
"### 模型预测输出"
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"def test(cfg, model, lit_model, data):\n",
" model.eval()\n",
" with torch.no_grad():\n",
" test_loss = []\n",
" for test_index, test_batch in enumerate(tqdm(data.test_dataloader())):\n",
" loss = lit_model.test_step(test_batch, test_index)\n",
" test_loss.append(loss)\n",
" f1 = lit_model.test_epoch_end(test_loss)\n",
" logging(cfg.log_dir,\n",
" '| test_result: {}'.format(f1))\n",
" logging(cfg.log_dir,'-' * 89)\n",
"\n",
"test(cfg, model, lit_model, data)"
],
"outputs": [],
"metadata": {}
},
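{
"cell_type": "markdown",
"source": [
"Finally, to turn predicted label indices back into human-readable relation names, you can invert the `rel2id` mapping. The indices below are hypothetical argmax outputs of the `pvp` logits:\n",
"\n",
"```\n",
"id2rel = {v: k for k, v in data.processor.get_labels().items()}\n",
"pred_ids = [0, 3]                 # hypothetical predictions\n",
"print([id2rel[i] for i in pred_ids])\n",
"```"
],
"metadata": {}
},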
{
"cell_type": "markdown",
"source": [
"本demo不包括调参部分有兴趣的同学可以自行前往[DeepKE](https://github.com/zjunlp/DeepKE/tree/master)仓库,下载使用更多的模型 :)"
],
"metadata": {}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}