From 7d2cb7f4cf1cc923797170dff730b87d80790b49 Mon Sep 17 00:00:00 2001
From: tlk-dsg <467460833@qq.com>
Date: Wed, 27 Oct 2021 17:34:55 +0800
Subject: [PATCH] fix bug

---
 tutorial-notebooks/re/standard/GCN.ipynb | 706 -----------------------
 1 file changed, 706 deletions(-)
 delete mode 100644 tutorial-notebooks/re/standard/GCN.ipynb

diff --git a/tutorial-notebooks/re/standard/GCN.ipynb b/tutorial-notebooks/re/standard/GCN.ipynb
deleted file mode 100644
index 265c766..0000000
--- a/tutorial-notebooks/re/standard/GCN.ipynb
+++ /dev/null
@@ -1,706 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "source": [
- "## Relation Extraction Experiment\n",
- "> Tutorial author: 余海阳 (yuhaiyang@zju.edu.cn)\n",
- "\n",
- "In this demo, we use a `GCN` to extract relations.\n",
- "We hope this demo helps you understand the process of constructing a knowledge graph, as well as the principles and common methods of triple extraction.\n",
- "\n",
- "This demo uses `Python 3`.\n",
- "\n",
- "### Dataset\n",
- "In this example, we extract triples from a small set of Chinese sentences.\n",
- "\n",
- "sentence|relation|head|tail\n",
- ":---:|:---:|:---:|:---:\n",
- "孔正锡在2005年以一部温馨的爱情电影《长腿叔叔》敲开电影界大门。|导演|长腿叔叔|孔正锡\n",
- "《伤心的树》是吴宗宪的音乐作品,收录在《你比从前快乐》专辑中。|所属专辑|伤心的树|你比从前快乐\n",
- "2000年8月,「天坛大佛」荣获「香港十大杰出工程项目」第四名。|所在城市|天坛大佛|香港\n",
- "\n",
- "\n",
- "- train.csv: contains 6 training samples; each line is one triple, ordered as sentence, relation, head entity, tail entity, and separated by `,`.\n",
- "- valid.csv: contains 3 validation samples, in the same format as train.csv.\n",
- "- test.csv: contains 3 test samples, in the same format as train.csv.\n",
- "- relation.csv: contains the 4 relation types; each line defines one relation, with its head entity type, tail entity type, relation name, and index, separated by `,`."
- ],
- "metadata": {}
- },
- {
- "cell_type": "markdown",
- "source": [
- "### GCN\n",
- "\n",
- "![GCN](img/GCN.png)\n",
- "\n",
- "The input to the model consists of word embeddings, position embeddings, and an adjacency matrix (`adj_matrix`) whose nodes are the word tokens of the sentence.\n",
- "The token features pass through a stack of GCN layers (2 or 3 layers is usually enough; stacking more rarely improves the result), are max-pooled over the tokens, and are fed to a fully connected layer that predicts the relation.\n"
- ],
- "metadata": {}
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "source": [
- "# The notebook runs on PyTorch; make sure the dependencies are installed before running\n",
- "!pip install torch\n",
- "!pip install matplotlib\n",
- "!pip install scikit-learn\n",
- "!pip install transformers"
- ],
- "outputs": [],
- "metadata": {}
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "source": [
- "# Import all required modules\n",
- "import os\n",
- "import csv\n",
- "import math\n",
- "import pickle\n",
- "import logging\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "import torch.nn.functional as F\n",
- "import numpy as np\n",
- "import matplotlib.pyplot as plt\n",
- "from torch import optim\n",
- "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
- "from torch.utils.data import Dataset, DataLoader\n",
- "from sklearn.metrics import precision_recall_fscore_support\n",
- "from typing import List, Tuple, Dict, Any, Sequence, Optional, Union\n",
- "from transformers import BertTokenizer, BertModel\n",
- "\n",
- "logger = logging.getLogger(__name__)"
- ],
- "outputs": [],
- "metadata": {}
- },
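- {
- "cell_type": "markdown",
- "source": [
- "Before building the full model, the propagation rule can be checked on toy tensors. The next cell is a minimal sketch (not part of the original pipeline): one normalized GCN step over a 4-token chain graph, mirroring the `Ax` / `AxW / L` computation in the `GCN` class further down. The chain adjacency is an assumption made purely for illustration."
- ],
- "metadata": {}
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "source": [
- "# Toy check of one GCN propagation step (illustration only)\n",
- "T, D = 4, 10                     # 4 tokens, feature dim 10\n",
- "x = torch.randn(1, T, D)         # a batch with one sentence\n",
- "adj = torch.zeros(1, T, T)\n",
- "for i in range(T - 1):           # chain graph: token i <-> token i+1\n",
- "    adj[0, i, i + 1] = adj[0, i + 1, i] = 1\n",
- "\n",
- "W = nn.Linear(D, D)\n",
- "L = adj.sum(2).unsqueeze(2) + 1  # node degree + 1 (self-loop) for normalization\n",
- "h = F.relu((W(adj.bmm(x)) + W(x)) / L)  # aggregate neighbours + self, then normalize\n",
- "print(h.shape)                   # torch.Size([1, 4, 10])"
- ],
- "outputs": [],
- "metadata": {}
- },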
- {
- "cell_type": "code",
- "execution_count": null,
- "source": [
- "# Configuration of model parameters\n",
- "class Config(object):\n",
- "    model_name = 'gcn'  # ['cnn', 'gcn', 'lm']\n",
- "    use_pcnn = True\n",
- "    min_freq = 1\n",
- "    pos_limit = 20\n",
- "    out_path = 'data/out'\n",
- "    batch_size = 2\n",
- "    word_dim = 10\n",
- "    pos_dim = 5\n",
- "    dim_strategy = 'sum'  # ['sum', 'cat']\n",
- "    out_channels = 20\n",
- "    intermediate = 10\n",
- "    kernel_sizes = [3, 5, 7]\n",
- "    activation = 'gelu'\n",
- "    pooling_strategy = 'max'\n",
- "    dropout = 0.3\n",
- "    num_layers = 2         # number of GCN layers, read by the GCN model below\n",
- "    input_size = word_dim  # GCN input dim; equals word_dim because dim_strategy == 'sum'\n",
- "    hidden_size = 20       # GCN hidden dim (demo-scale value)\n",
- "    epoch = 10\n",
- "    num_relations = 4\n",
- "    learning_rate = 3e-4\n",
- "    lr_factor = 0.7  # decay rate of the learning rate\n",
- "    lr_patience = 3  # epochs to wait before decaying the learning rate\n",
- "    weight_decay = 1e-3  # L2 regularization\n",
- "    early_stopping_patience = 6\n",
- "    train_log = True\n",
- "    log_interval = 1\n",
- "    show_plot = True\n",
- "    only_comparison_plot = False\n",
- "    plot_utils = 'matplot'\n",
- "    lm_file = 'bert-base-chinese'\n",
- "    lm_num_hidden_layers = 2\n",
- "    rnn_layers = 2\n",
- "\n",
- "cfg = Config()"
- ],
- "outputs": [],
- "metadata": {}
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "source": [
- "# Vocab maps each token to an index; the indices feed the embedding layer\n",
- "# By convention index 0 is [PAD] and index 1 is [UNK]\n",
- "\n",
- "class Vocab(object):\n",
- "    def __init__(self, name: str = 'basic', init_tokens=[\"[PAD]\", \"[UNK]\"]):\n",
- "        self.name = name\n",
- "        self.init_tokens = init_tokens\n",
- "        self.trimed = False\n",
- "        self.word2idx = {}\n",
- "        self.word2count = {}\n",
- "        self.idx2word = {}\n",
- "        self.count = 0\n",
- "        self._add_init_tokens()\n",
- "\n",
- "    def _add_init_tokens(self):\n",
- "        for token in self.init_tokens:\n",
- "            self._add_word(token)\n",
- "\n",
- "    def _add_word(self, word: str):\n",
- "        if word not in self.word2idx:\n",
- "            self.word2idx[word] = self.count\n",
- "            self.word2count[word] = 1\n",
- "            self.idx2word[self.count] = word\n",
- "            self.count += 1\n",
- "        else:\n",
- "            self.word2count[word] += 1\n",
- "\n",
- "    def add_words(self, words: Sequence):\n",
- "        for word in words:\n",
- "            self._add_word(word)\n",
- "\n",
- "    def trim(self, min_freq=2, verbose: Optional[bool] = True):\n",
- "        assert min_freq == int(min_freq), f'min_freq must be integer, can\\'t be {min_freq}'\n",
- "        min_freq = int(min_freq)\n",
- "        if min_freq < 2:\n",
- "            return\n",
- "        if self.trimed:\n",
- "            return\n",
- "        self.trimed = True\n",
- "\n",
- "        keep_words = []\n",
- "        new_words = []\n",
- "\n",
- "        for k, v in self.word2count.items():\n",
- "            if v >= min_freq:\n",
- "                keep_words.append(k)\n",
- "                new_words.extend([k] * v)\n",
- "        if verbose:\n",
- "            kept = len(keep_words)\n",
- "            total = len(self.word2idx) - len(self.init_tokens)\n",
- "            logger.info('vocab trimmed, kept words [{} / {}] = {:.2f}%'.format(kept, total, kept / total * 100))\n",
- "\n",
- "        # Reinitialize the dictionaries with only the kept words\n",
- "        self.word2idx = {}\n",
- "        self.word2count = {}\n",
- "        self.idx2word = {}\n",
- "        self.count = 0\n",
- "        self._add_init_tokens()\n",
- "        self.add_words(new_words)"
- ],
- "outputs": [],
- "metadata": {}
- },
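- {
- "cell_type": "markdown",
- "source": [
- "A quick usage check of `Vocab` (illustration only; the sample tokens are made up):"
- ],
- "metadata": {}
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "source": [
- "# Build a tiny vocabulary and look at the resulting indices\n",
- "demo_vocab = Vocab('demo')\n",
- "demo_vocab.add_words(['电', '影', '电', '影', '界'])\n",
- "print(demo_vocab.word2idx)    # {'[PAD]': 0, '[UNK]': 1, '电': 2, '影': 3, '界': 4}\n",
- "print(demo_vocab.word2count)  # '电' and '影' occur twice, '界' once\n",
- "demo_vocab.trim(min_freq=2)   # with min_freq=2, '界' is dropped\n",
- "print(demo_vocab.word2idx)    # {'[PAD]': 0, '[UNK]': 1, '电': 2, '影': 3}"
- ],
- "outputs": [],
- "metadata": {}
- },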
"\n", - " def _add_word(self, word: str):\n", - " if word not in self.word2idx:\n", - " self.word2idx[word] = self.count\n", - " self.word2count[word] = 1\n", - " self.idx2word[self.count] = word\n", - " self.count += 1\n", - " else:\n", - " self.word2count[word] += 1\n", - "\n", - " def add_words(self, words: Sequence):\n", - " for word in words:\n", - " self._add_word(word)\n", - "\n", - " def trim(self, min_freq=2, verbose: Optional[bool] = True):\n", - " \n", - " assert min_freq == int(min_freq), f'min_freq must be integer, can\\'t be {min_freq}'\n", - " min_freq = int(min_freq)\n", - " if min_freq < 2:\n", - " return\n", - " if self.trimed:\n", - " return\n", - " self.trimed = True\n", - "\n", - " keep_words = []\n", - " new_words = []\n", - "\n", - " for k, v in self.word2count.items():\n", - " if v >= min_freq:\n", - " keep_words.append(k)\n", - " new_words.extend([k] * v)\n", - " if verbose:\n", - " before_len = len(keep_words)\n", - " after_len = len(self.word2idx) - len(self.init_tokens)\n", - " logger.info('vocab after be trimmed, keep words [{} / {}] = {:.2f}%'.format(before_len, after_len, before_len / after_len * 100))\n", - "\n", - " # Reinitialize dictionaries\n", - " self.word2idx = {}\n", - " self.word2count = {}\n", - " self.idx2word = {}\n", - " self.count = 0\n", - " self._add_init_tokens()\n", - " self.add_words(new_words)" - ], - "outputs": [], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": null, - "source": [ - "# Functions required for preprocessing\n", - "Path = str\n", - "\n", - "def load_csv(fp: Path, is_tsv: bool = False, verbose: bool = True) -> List:\n", - " if verbose:\n", - " logger.info(f'load csv from {fp}')\n", - "\n", - " dialect = 'excel-tab' if is_tsv else 'excel'\n", - " with open(fp, encoding='utf-8') as f:\n", - " reader = csv.DictReader(f, dialect=dialect)\n", - " return list(reader)\n", - "\n", - " \n", - "def load_pkl(fp: Path, verbose: bool = True) -> Any:\n", - " if verbose:\n", - " logger.info(f'load data from {fp}')\n", - "\n", - " with open(fp, 'rb') as f:\n", - " data = pickle.load(f)\n", - " return data\n", - "\n", - "\n", - "def save_pkl(data: Any, fp: Path, verbose: bool = True) -> None:\n", - " if verbose:\n", - " logger.info(f'save data in {fp}')\n", - "\n", - " with open(fp, 'wb') as f:\n", - " pickle.dump(data, f)\n", - " \n", - " \n", - "def _handle_relation_data(relation_data: List[Dict]) -> Dict:\n", - " rels = dict()\n", - " for d in relation_data:\n", - " rels[d['relation']] = {\n", - " 'index': int(d['index']),\n", - " 'head_type': d['head_type'],\n", - " 'tail_type': d['tail_type'],\n", - " }\n", - " return rels\n", - "\n", - "\n", - "def _add_relation_data(rels: Dict,data: List) -> None:\n", - " for d in data:\n", - " d['rel2idx'] = rels[d['relation']]['index']\n", - " d['head_type'] = rels[d['relation']]['head_type']\n", - " d['tail_type'] = rels[d['relation']]['tail_type']\n", - "\n", - "\n", - "def _convert_tokens_into_index(data: List[Dict], vocab):\n", - " unk_str = '[UNK]'\n", - " unk_idx = vocab.word2idx[unk_str]\n", - "\n", - " for d in data:\n", - " d['token2idx'] = [vocab.word2idx.get(i, unk_idx) for i in d['tokens']]\n", - "\n", - "\n", - "def _add_pos_seq(train_data: List[Dict], cfg):\n", - " for d in train_data:\n", - " d['head_offset'], d['tail_offset'], d['lens'] = int(d['head_offset']), int(d['tail_offset']), int(d['lens'])\n", - " entities_idx = [d['head_offset'], d['tail_offset']] if d['head_offset'] < d['tail_offset'] else [d['tail_offset'], d['head_offset']]\n", - "\n", - " 
- {
- "cell_type": "code",
- "execution_count": null,
- "source": [
- "# Preprocess\n",
- "logger.info('load raw files...')\n",
- "train_fp = os.path.join('data/train.csv')\n",
- "valid_fp = os.path.join('data/valid.csv')\n",
- "test_fp = os.path.join('data/test.csv')\n",
- "relation_fp = os.path.join('data/relation.csv')\n",
- "\n",
- "train_data = load_csv(train_fp)\n",
- "valid_data = load_csv(valid_fp)\n",
- "test_data = load_csv(test_fp)\n",
- "relation_data = load_csv(relation_fp)\n",
- "\n",
- "# the 'tokens' column stores a Python list literal, so parse it with eval\n",
- "for d in train_data:\n",
- "    d['tokens'] = eval(d['tokens'])\n",
- "for d in valid_data:\n",
- "    d['tokens'] = eval(d['tokens'])\n",
- "for d in test_data:\n",
- "    d['tokens'] = eval(d['tokens'])\n",
- "\n",
- "logger.info('convert relation into index...')\n",
- "rels = _handle_relation_data(relation_data)\n",
- "_add_relation_data(rels, train_data)\n",
- "_add_relation_data(rels, valid_data)\n",
- "_add_relation_data(rels, test_data)\n",
- "\n",
- "logger.info('verify whether to use a pretrained language model...')\n",
- "if cfg.model_name == 'lm':\n",
- "    logger.info('use the pretrained language model to serialize sentences...')\n",
- "    _lm_serialize(train_data, cfg)\n",
- "    _lm_serialize(valid_data, cfg)\n",
- "    _lm_serialize(test_data, cfg)\n",
- "else:\n",
- "    logger.info('build vocabulary...')\n",
- "    vocab = Vocab('word')\n",
- "    train_tokens = [d['tokens'] for d in train_data]\n",
- "    valid_tokens = [d['tokens'] for d in valid_data]\n",
- "    test_tokens = [d['tokens'] for d in test_data]\n",
- "    sent_tokens = [*train_tokens, *valid_tokens, *test_tokens]\n",
- "    for sent in sent_tokens:\n",
- "        vocab.add_words(sent)\n",
- "    vocab.trim(min_freq=cfg.min_freq)\n",
- "\n",
- "    logger.info('convert tokens into index...')\n",
- "    _convert_tokens_into_index(train_data, vocab)\n",
- "    _convert_tokens_into_index(valid_data, vocab)\n",
- "    _convert_tokens_into_index(test_data, vocab)\n",
- "\n",
- "    logger.info('build position sequence...')\n",
- "    _add_pos_seq(train_data, cfg)\n",
- "    _add_pos_seq(valid_data, cfg)\n",
- "    _add_pos_seq(test_data, cfg)\n",
- "\n",
- "logger.info('save data for backup...')\n",
- "os.makedirs(cfg.out_path, exist_ok=True)\n",
- "train_save_fp = os.path.join(cfg.out_path, 'train.pkl')\n",
- "valid_save_fp = os.path.join(cfg.out_path, 'valid.pkl')\n",
- "test_save_fp = os.path.join(cfg.out_path, 'test.pkl')\n",
- "save_pkl(train_data, train_save_fp)\n",
- "save_pkl(valid_data, valid_save_fp)\n",
- "save_pkl(test_data, test_save_fp)\n",
- "\n",
- "if cfg.model_name != 'lm':\n",
- "    vocab_save_fp = os.path.join(cfg.out_path, 'vocab.pkl')\n",
- "    vocab_txt = os.path.join(cfg.out_path, 'vocab.txt')\n",
- "    save_pkl(vocab, vocab_save_fp)\n",
- "    logger.info('save vocab in a txt file, for inspection...')\n",
- "    with open(vocab_txt, 'w', encoding='utf-8') as f:\n",
- "        f.write(os.linesep.join(vocab.word2idx.keys()))"
- ],
- "outputs": [],
- "metadata": {}
- },
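- {
- "cell_type": "markdown",
- "source": [
- "To make the preprocessing concrete, the next cell just prints one training sample; the printed fields are exactly those created above (illustration only)."
- ],
- "metadata": {}
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "source": [
- "# Peek at one preprocessed training sample\n",
- "sample = train_data[0]\n",
- "print(sample['tokens'])      # token list\n",
- "print(sample['token2idx'])   # token indices from the vocabulary\n",
- "print(sample['rel2idx'])     # relation index\n",
- "if cfg.model_name != 'lm':\n",
- "    print(sample['head_pos'][:10], sample['tail_pos'][:10])  # clipped position encodings"
- ],
- "outputs": [],
- "metadata": {}
- },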
cfg)\n", - "else:\n", - " logger.info('build vocabulary...')\n", - " vocab = Vocab('word')\n", - " train_tokens = [d['tokens'] for d in train_data]\n", - " valid_tokens = [d['tokens'] for d in valid_data]\n", - " test_tokens = [d['tokens'] for d in test_data]\n", - " sent_tokens = [*train_tokens, *valid_tokens, *test_tokens]\n", - " for sent in sent_tokens:\n", - " vocab.add_words(sent)\n", - " vocab.trim(min_freq=cfg.min_freq)\n", - "\n", - " logger.info('convert tokens into index...')\n", - " _convert_tokens_into_index(train_data, vocab)\n", - " _convert_tokens_into_index(valid_data, vocab)\n", - " _convert_tokens_into_index(test_data, vocab)\n", - "\n", - " logger.info('build position sequence...')\n", - " _add_pos_seq(train_data, cfg)\n", - " _add_pos_seq(valid_data, cfg)\n", - " _add_pos_seq(test_data, cfg)\n", - "\n", - "logger.info('save data for backup...')\n", - "os.makedirs(cfg.out_path, exist_ok=True)\n", - "train_save_fp = os.path.join(cfg.out_path, 'train.pkl')\n", - "valid_save_fp = os.path.join(cfg.out_path, 'valid.pkl')\n", - "test_save_fp = os.path.join(cfg.out_path, 'test.pkl')\n", - "save_pkl(train_data, train_save_fp)\n", - "save_pkl(valid_data, valid_save_fp)\n", - "save_pkl(test_data, test_save_fp)\n", - "\n", - "if cfg.model_name != 'lm':\n", - " vocab_save_fp = os.path.join(cfg.out_path, 'vocab.pkl')\n", - " vocab_txt = os.path.join(cfg.out_path, 'vocab.txt')\n", - " save_pkl(vocab, vocab_save_fp)\n", - " logger.info('save vocab in txt file, for watching...')\n", - " with open(vocab_txt, 'w', encoding='utf-8') as f:\n", - " f.write(os.linesep.join(vocab.word2idx.keys()))" - ], - "outputs": [], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": null, - "source": [ - "# embedding layer\n", - "class Embedding(nn.Module):\n", - " def __init__(self, config):\n", - " \"\"\"\n", - " word embedding: 一般 0 为 padding\n", - " pos embedding: 一般 0 为 padding\n", - " dim_strategy: [cat, sum] 多个 embedding 是拼接还是相加\n", - " \"\"\"\n", - " super(Embedding, self).__init__()\n", - "\n", - " # self.xxx = config.xxx\n", - " self.vocab_size = config.vocab_size\n", - " self.word_dim = config.word_dim\n", - " self.pos_size = config.pos_limit * 2 + 2\n", - " self.pos_dim = config.pos_dim if config.dim_strategy == 'cat' else config.word_dim\n", - " self.dim_strategy = config.dim_strategy\n", - "\n", - " self.wordEmbed = nn.Embedding(self.vocab_size,self.word_dim,padding_idx=0)\n", - " self.headPosEmbed = nn.Embedding(self.pos_size,self.pos_dim,padding_idx=0)\n", - " self.tailPosEmbed = nn.Embedding(self.pos_size,self.pos_dim,padding_idx=0)\n", - "\n", - "\n", - " def forward(self, *x):\n", - " word, head, tail = x\n", - " word_embedding = self.wordEmbed(word)\n", - " head_embedding = self.headPosEmbed(head)\n", - " tail_embedding = self.tailPosEmbed(tail)\n", - "\n", - " if self.dim_strategy == 'cat':\n", - " return torch.cat((word_embedding,head_embedding, tail_embedding), -1)\n", - " elif self.dim_strategy == 'sum':\n", - " # 此时 pos_dim == word_dim\n", - " return word_embedding + head_embedding + tail_embedding\n", - " else:\n", - " raise Exception('dim_strategy must choose from [sum, cat]')" - ], - "outputs": [], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": null, - "source": [ - "# gcn model\n", - "class GCN(nn.Module):\n", - " def __init__(self,cfg):\n", - " super(GCN , self).__init__()\n", - "\n", - " self.num_layers = cfg.num_layers\n", - " self.input_size = cfg.input_size\n", - " self.hidden_size = cfg.hidden_size\n", - " self.dropout = 
cfg.dropout\n", - "\n", - " self.fc1 = nn.Linear(self.input_size , self.hidden_size)\n", - " self.fc = nn.Linear(self.hidden_size , self.hidden_size)\n", - " self.weight_list = nn.ModuleList()\n", - " for i in range(self.num_layers):\n", - " self.weight_list.append(nn.Linear(self.hidden_size * (i + 1),self.hidden_size))\n", - " self.dropout = nn.Dropout(self.dropout)\n", - "\n", - " def forward(self , x, adj):\n", - " L = adj.sum(2).unsqueeze(2) + 1\n", - " outputs = self.fc1(x)\n", - " cache_list = [outputs]\n", - " output_list = []\n", - " for l in range(self.num_layers):\n", - " Ax = adj.bmm(outputs)\n", - " AxW = self.weight_list[l](Ax)\n", - " AxW = AxW + self.weight_list[l](outputs)\n", - " AxW = AxW / L\n", - " gAxW = F.relu(AxW)\n", - " cache_list.append(gAxW)\n", - " outputs = torch.cat(cache_list , dim=2)\n", - " output_list.append(self.dropout(gAxW))\n", - " # gcn_outputs = torch.cat(output_list, dim=2)\n", - " gcn_outputs = output_list[self.num_layers - 1]\n", - " gcn_outputs = gcn_outputs + self.fc1(x)\n", - "\n", - " out = self.fc(gcn_outputs)\n", - " return out" - ], - "outputs": [], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": null, - "source": [ - "# p,r,f1 measurement\n", - "class PRMetric():\n", - " def __init__(self):\n", - " \"\"\"\n", - " 暂时调用 sklearn 的方法\n", - " \"\"\"\n", - " self.y_true = np.empty(0)\n", - " self.y_pred = np.empty(0)\n", - "\n", - " def reset(self):\n", - " self.y_true = np.empty(0)\n", - " self.y_pred = np.empty(0)\n", - "\n", - " def update(self, y_true:torch.Tensor, y_pred:torch.Tensor):\n", - " y_true = y_true.cpu().detach().numpy()\n", - " y_pred = y_pred.cpu().detach().numpy()\n", - " y_pred = np.argmax(y_pred,axis=-1)\n", - "\n", - " self.y_true = np.append(self.y_true, y_true)\n", - " self.y_pred = np.append(self.y_pred, y_pred)\n", - "\n", - " def compute(self):\n", - " p, r, f1, _ = precision_recall_fscore_support(self.y_true,self.y_pred,average='macro',warn_for=tuple())\n", - " _, _, acc, _ = precision_recall_fscore_support(self.y_true,self.y_pred,average='micro',warn_for=tuple())\n", - "\n", - " return acc,p,r,f1" - ], - "outputs": [], - "metadata": {} - }, - { - "cell_type": "code", - "execution_count": null, - "source": [ - "# Iteration in training process\n", - "def train(epoch, model, dataloader, optimizer, criterion, cfg):\n", - " model.train()\n", - "\n", - " metric = PRMetric()\n", - " losses = []\n", - "\n", - " for batch_idx, (x, y) in enumerate(dataloader, 1):\n", - " optimizer.zero_grad()\n", - " y_pred = model(x)\n", - " loss = criterion(y_pred, y)\n", - "\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " metric.update(y_true=y, y_pred=y_pred)\n", - " losses.append(loss.item())\n", - "\n", - " data_total = len(dataloader.dataset)\n", - " data_cal = data_total if batch_idx == len(dataloader) else batch_idx * len(y)\n", - " if (cfg.train_log and batch_idx % cfg.log_interval == 0) or batch_idx == len(dataloader):\n", - " acc,p,r,f1 = metric.compute()\n", - " print(f'Train Epoch {epoch}: [{data_cal}/{data_total} ({100. * data_cal / data_total:.0f}%)]\\t'\n", - " f'Loss: {loss.item():.6f}')\n", - " print(f'Train Epoch {epoch}: Acc: {100. 
- {
- "cell_type": "code",
- "execution_count": null,
- "source": [
- "# Load the datasets\n",
- "train_dataset = CustomDataset(train_save_fp)\n",
- "valid_dataset = CustomDataset(valid_save_fp)\n",
- "test_dataset = CustomDataset(test_save_fp)\n",
- "\n",
- "train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n",
- "valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n",
- "test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))"
- ],
- "outputs": [],
- "metadata": {}
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "source": [
- "# Once the preprocessed data is loaded, vocab_size is known\n",
- "vocab = load_pkl(vocab_save_fp)\n",
- "vocab_size = vocab.count\n",
- "cfg.vocab_size = vocab_size"
- ],
- "outputs": [],
- "metadata": {}
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "source": [
- "# Main entry: define the model, optimizer, and loss function, then run the epochs.\n",
- "# The validation loss drives early stopping: when it stops decreasing,\n",
- "# the model is likely at its best generalization point.\n",
- "model = GCNRelationModel(cfg)  # the wrapper defined above: Embedding -> GCN -> classifier\n",
- "print(model)\n",
- "\n",
- "optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)\n",
- "scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)\n",
- "criterion = nn.CrossEntropyLoss()\n",
- "\n",
- "best_f1, best_epoch = -1, 0\n",
- "es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1 = 1000, -1, 0, 0, 0, -1\n",
- "train_losses, valid_losses = [], []\n",
- "\n",
- "logger.info('=' * 10 + ' Start training ' + '=' * 10)\n",
- "for epoch in range(1, cfg.epoch + 1):\n",
- "    train_loss = train(epoch, model, train_dataloader, optimizer, criterion, cfg)\n",
- "    valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion)\n",
- "    scheduler.step(valid_loss)\n",
- "\n",
- "    train_losses.append(train_loss)\n",
- "    valid_losses.append(valid_loss)\n",
- "    if best_f1 < valid_f1:\n",
- "        best_f1 = valid_f1\n",
- "        best_epoch = epoch\n",
- "    # early stopping bookkeeping on the validation loss\n",
- "    if es_loss > valid_loss:\n",
- "        es_loss = valid_loss\n",
- "        es_f1 = valid_f1\n",
- "        es_epoch = epoch\n",
- "        es_patience = 0\n",
- "        best_es_epoch = epoch\n",
- "        best_es_f1 = valid_f1\n",
- "    else:\n",
- "        es_patience += 1\n",
- "        if es_patience >= cfg.early_stopping_patience:\n",
- "            best_es_epoch = es_epoch\n",
- "            best_es_f1 = es_f1\n",
- "\n",
- "if cfg.show_plot:\n",
- "    if cfg.plot_utils == 'matplot':\n",
- "        plt.plot(train_losses, 'x-')\n",
- "        plt.plot(valid_losses, '+-')\n",
- "        plt.legend(['train', 'valid'])\n",
- "        plt.title('train/valid comparison loss')\n",
- "        plt.show()\n",
- "\n",
- "\n",
- "print(f'best (valid loss criterion) early stopping epoch: {best_es_epoch}, '\n",
- "      f'its macro f1: {best_es_f1:.4f}')\n",
- "print(f'total {cfg.epoch} epochs, best (valid macro f1) epoch: {best_epoch}, '\n",
- "      f'its macro f1: {best_f1:.4f}')\n",
- "\n",
- "test_f1, _ = validate(0, model, test_dataloader, criterion, verbose=False)\n",
- "print(f'after {cfg.epoch} epochs, final test macro f1: {test_f1:.4f}')"
- ],
- "outputs": [],
- "metadata": {}
- },
- {
- "cell_type": "markdown",
- "source": [
- "This demo does not cover hyper-parameter tuning. If you are interested, visit the [deepke](http://openkg.cn/tool/deepke) repository to download and try more models :)"
- ],
- "metadata": {}
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}