diff --git a/src/deepke/ae/regular/module/Embedding.py b/src/deepke/ae/regular/module/Embedding.py
index 83074a3..a55590e 100644
--- a/src/deepke/ae/regular/module/Embedding.py
+++ b/src/deepke/ae/regular/module/Embedding.py
@@ -19,21 +19,21 @@ class Embedding(nn.Module):
         self.dim_strategy = config.dim_strategy
 
         self.wordEmbed = nn.Embedding(self.vocab_size, self.word_dim, padding_idx=0)
-        self.headPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
-        self.tailPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
+        self.entityPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
+        self.attribute_keyPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)
 
         self.layer_norm = nn.LayerNorm(self.word_dim)
 
     def forward(self, *x):
-        word, head, tail = x
+        word, entity, attribute_key = x
         word_embedding = self.wordEmbed(word)
-        head_embedding = self.headPosEmbed(head)
-        tail_embedding = self.tailPosEmbed(tail)
+        entity_embedding = self.entityPosEmbed(entity)
+        attribute_key_embedding = self.attribute_keyPosEmbed(attribute_key)
 
         if self.dim_strategy == 'cat':
-            return torch.cat((word_embedding, head_embedding, tail_embedding), -1)
+            return torch.cat((word_embedding, entity_embedding, attribute_key_embedding), -1)
         elif self.dim_strategy == 'sum':
             # 此时 pos_dim == word_dim
-            return self.layer_norm(word_embedding + head_embedding + tail_embedding)
+            return self.layer_norm(word_embedding + entity_embedding + attribute_key_embedding)
         else:
             raise Exception('dim_strategy must choose from [sum, cat]')
diff --git a/tutorial-notebooks/ae/regular/RNN.ipynb b/tutorial-notebooks/ae/regular/RNN.ipynb
new file mode 100644
index 0000000..4ff8b9f
--- /dev/null
+++ b/tutorial-notebooks/ae/regular/RNN.ipynb
@@ -0,0 +1,761 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "source": [
+    "## Attribute extraction experiment\n",
+    "> Tutorial author: 陶联宽(22051063@zju.edu.cn)\n",
+    "\n",
+    "In this demo, we use an `lstm` to extract attributes. We hope this demo helps you understand the process of constructing a knowledge graph, as well as the principles and common methods of triple extraction.\n",
+    "\n",
+    "This demo uses `Python3`.\n",
+    "\n",
+    "### Dataset\n",
+    "In this example, we use some Chinese text from which to extract triples.\n",
+    "\n",
+    "sentence|attribute|entity|entity_offset|attribute_value|attribute_value_offset\n",
+    ":---:|:---:|:---:|:---:|:---:|:---:\n",
+    "苏轼(1037~1101年),字子瞻,又字和仲,号“东坡居士”,眉州眉山(即今四川眉州)人,是宋代(北宋)著名的文学家、书画家|字|苏轼|0|和仲|21\n",
+    "阳成俊,男,汉族,贵州省委党校大学学历|民族|阳成俊|0|汉族|6\n",
+    "司马懿,字仲达,河南温县人|字|司马懿|0|仲达|6\n",
+    "\n",
+    "- train.csv: contains 6 training triples; each line represents one triple, with the fields sentence, attribute, entity, entity offset, attribute value, attribute value offset, separated by commas.\n",
+    "- valid.csv: contains 2 validation triples, in the same comma-separated format as train.csv.\n",
+    "- test.csv: contains 2 test triples, in the same comma-separated format as train.csv.\n",
+    "- attribute.csv: contains 3 attributes; each line gives an attribute and its index, separated by a comma.\n"
+   ],
+   "metadata": {}
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "### LSTM\n",
+    "\n",
+    "![LSTM](img/LSTM.jpg)\n",
+    "\n",
+    "The sentence representation mainly consists of word embeddings. After the RNN layer, the hidden states at the positions of the entity and the attribute key are selected; by passing them through the 
full connection layer, the attribution information of the sentence can be obtained." + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "# Run the neural network with pytorch and confirm whether it is installed before running\n", + "!pip install torch\n", + "!pip install matplotlib\n", + "!pip install transformers" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "# import the whole modules\n", + "import os\n", + "import csv\n", + "import math\n", + "import pickle\n", + "import logging\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from torch import optim\n", + "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n", + "from torch.utils.data import Dataset,DataLoader\n", + "from sklearn.metrics import precision_recall_fscore_support\n", + "from typing import List, Tuple, Dict, Any, Sequence, Optional, Union\n", + "from transformers import BertTokenizer, BertModel\n", + "\n", + "logger = logging.getLogger(__name__)" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "# Configuration file of model parameters\n", + "# use_pcnn Parameter controls whether there is a piece_ Wise pooling\n", + "\n", + "class Config(object):\n", + " model_name = 'rnn' # ['cnn', 'gcn', 'lm','rnn']\n", + " use_pcnn = False\n", + " min_freq = 1\n", + " pos_limit = 20\n", + " out_path = 'data/out' \n", + " batch_size = 2 \n", + " word_dim = 10\n", + " pos_dim = 5\n", + " dim_strategy = 'sum' # ['sum', 'cat']\n", + " out_channels = 20\n", + " intermediate = 10\n", + " kernel_sizes = [3, 5, 7]\n", + " activation = 'gelu'\n", + " pooling_strategy = 'max'\n", + " dropout = 0.3\n", + " epoch = 10\n", + " num_relations = 4\n", + " learning_rate = 3e-4\n", + " lr_factor = 0.7 # 学习率的衰减率\n", + " lr_patience = 3 # 学习率衰减的等待epoch\n", + " weight_decay = 1e-3 # L2正则\n", + " early_stopping_patience = 6\n", + " train_log = True\n", + " log_interval = 1\n", + " show_plot = True\n", + " only_comparison_plot = False\n", + " plot_utils = 'matplot'\n", + " lm_file = 'bert-base-chinese'\n", + " lm_num_hidden_layers = 2\n", + " rnn_layers = 2\n", + " \n", + "cfg = Config()" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "# Word token builds a one hot dictionary, and then inputs it to the embedding layer to obtain the corresponding word information matrix\n", + "# 0 is pad by default and 1 is unknown\n", + "class Vocab(object):\n", + " def __init__(self, name: str = 'basic', init_tokens = [\"[PAD]\", \"[UNK]\"]):\n", + " self.name = name\n", + " self.init_tokens = init_tokens\n", + " self.trimed = False\n", + " self.word2idx = {}\n", + " self.word2count = {}\n", + " self.idx2word = {}\n", + " self.count = 0\n", + " self._add_init_tokens()\n", + "\n", + " def _add_init_tokens(self):\n", + " for token in self.init_tokens:\n", + " self._add_word(token)\n", + "\n", + " def _add_word(self, word: str):\n", + " if word not in self.word2idx:\n", + " self.word2idx[word] = self.count\n", + " self.word2count[word] = 1\n", + " self.idx2word[self.count] = word\n", + " self.count += 1\n", + " else:\n", + " self.word2count[word] += 1\n", + "\n", + " def add_words(self, words: Sequence):\n", + " for word in words:\n", + " self._add_word(word)\n", + "\n", + " def trim(self, min_freq=2, 
verbose: Optional[bool] = True):\n", + " assert min_freq == int(min_freq), f'min_freq must be integer, can\\'t be {min_freq}'\n", + " min_freq = int(min_freq)\n", + " if min_freq < 2:\n", + " return\n", + " if self.trimed:\n", + " return\n", + " self.trimed = True\n", + "\n", + " keep_words = []\n", + " new_words = []\n", + "\n", + " for k, v in self.word2count.items():\n", + " if v >= min_freq:\n", + " keep_words.append(k)\n", + " new_words.extend([k] * v)\n", + " if verbose:\n", + " before_len = len(keep_words)\n", + " after_len = len(self.word2idx) - len(self.init_tokens)\n", + " logger.info('vocab after be trimmed, keep words [{} / {}] = {:.2f}%'.format(before_len, after_len, before_len / after_len * 100))\n", + "\n", + " # Reinitialize dictionaries\n", + " self.word2idx = {}\n", + " self.word2count = {}\n", + " self.idx2word = {}\n", + " self.count = 0\n", + " self._add_init_tokens()\n", + " self.add_words(new_words)" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "# Functions required for preprocessing\n", + "Path = str\n", + "\n", + "def load_csv(fp: Path, is_tsv: bool = False, verbose: bool = True) -> List:\n", + " if verbose:\n", + " logger.info(f'load csv from {fp}')\n", + "\n", + " dialect = 'excel-tab' if is_tsv else 'excel'\n", + " with open(fp, encoding='utf-8') as f:\n", + " reader = csv.DictReader(f, dialect=dialect)\n", + " return list(reader)\n", + "\n", + " \n", + "def load_pkl(fp: Path, verbose: bool = True) -> Any:\n", + " if verbose:\n", + " logger.info(f'load data from {fp}')\n", + "\n", + " with open(fp, 'rb') as f:\n", + " data = pickle.load(f)\n", + " return data\n", + "\n", + "\n", + "def save_pkl(data: Any, fp: Path, verbose: bool = True) -> None:\n", + " if verbose:\n", + " logger.info(f'save data in {fp}')\n", + "\n", + " with open(fp, 'wb') as f:\n", + " pickle.dump(data, f)\n", + " \n", + "def _handle_attribute_data(attribute_data: List[Dict]) -> Dict:\n", + " atts = OrderedDict()\n", + " attribute_data = sorted(attribute_data, key=lambda i: int(i['index']))\n", + " for d in attribute_data:\n", + " atts[d['attribute']] = {\n", + " 'index': int(d['index'])\n", + " }\n", + " return atts\n", + "\n", + "def _add_attribute_data(atts: Dict, data: List) -> None:\n", + " for d in data:\n", + " d['att2idx'] = atts[d['attribute']]['index']\n", + "\n", + "def _convert_tokens_into_index(data: List[Dict], vocab):\n", + " unk_str = '[UNK]'\n", + " unk_idx = vocab.word2idx[unk_str]\n", + "\n", + " for d in data:\n", + " d['token2idx'] = [vocab.word2idx.get(i, unk_idx) for i in d['tokens']]\n", + "\n", + "def _add_pos_seq(train_data: List[Dict], cfg):\n", + " for d in train_data:\n", + " d['entity_pos'] = list(map(lambda i: i - d['entity_index'], list(range(d['seq_len']))))\n", + " d['entity_pos'] = _handle_pos_limit(d['entity_pos'],int(cfg.pos_limit))\n", + "\n", + " d['attribute_value_pos'] = list(map(lambda i: i - d['attribute_value_index'], list(range(d['seq_len']))))\n", + " d['attribute_value_pos'] = _handle_pos_limit(d['attribute_value_pos'],int(cfg.pos_limit))\n", + " \n", + "def _handle_pos_limit(pos: List[int], limit: int) -> List[int]:\n", + " for i,p in enumerate(pos):\n", + " if p > limit:\n", + " pos[i] = limit\n", + " if p < -limit:\n", + " pos[i] = -limit\n", + " return [p + limit + 1 for p in pos]\n", + "\n", + "def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None, mask_pos_to_true=True):\n", + " if isinstance(seq_len, list):\n", + " seq_len = 
np.array(seq_len)\n", + "\n", + " if isinstance(seq_len, np.ndarray):\n", + " seq_len = torch.from_numpy(seq_len)\n", + "\n", + " if isinstance(seq_len, torch.Tensor):\n", + " assert seq_len.dim() == 1, logger.error(f\"seq_len can only have one dimension, got {seq_len.dim()} != 1.\")\n", + " batch_size = seq_len.size(0)\n", + " max_len = int(max_len) if max_len else seq_len.max().long()\n", + " broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len.device)\n", + " if mask_pos_to_true:\n", + " mask = broad_cast_seq_len.ge(seq_len.unsqueeze(1))\n", + " else:\n", + " mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))\n", + " else:\n", + " raise logger.error(\"Only support 1-d list or 1-d numpy.ndarray or 1-d torch.Tensor.\")\n", + "\n", + " return mask\n" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "# Preprocess\n", + "logger.info('load raw files...')\n", + "train_fp = os.path.join('data/train.csv')\n", + "valid_fp = os.path.join('data/valid.csv')\n", + "test_fp = os.path.join('data/test.csv')\n", + "attribute_fp = os.path.join('data/attribute.csv')\n", + "\n", + "train_data = load_csv(train_fp)\n", + "valid_data = load_csv(valid_fp)\n", + "test_data = load_csv(test_fp)\n", + "attribute_data = load_csv(attribute_fp)\n", + "\n", + "for d in train_data:\n", + " d['tokens'] = eval(d['tokens'])\n", + "for d in valid_data:\n", + " d['tokens'] = eval(d['tokens'])\n", + "for d in test_data:\n", + " d['tokens'] = eval(d['tokens'])\n", + "\n", + "logger.info('convert relation into index...')\n", + "atts = _handle_attribute_data(attribute_data)\n", + "_add_attribute_data(atts,train_data)\n", + "_add_attribute_data(atts,test_data)\n", + "_add_attribute_data(atts,valid_data)\n", + "\n", + "logger.info('build vocabulary...')\n", + "vocab = Vocab('word')\n", + "train_tokens = [d['tokens'] for d in train_data]\n", + "valid_tokens = [d['tokens'] for d in valid_data]\n", + "test_tokens = [d['tokens'] for d in test_data]\n", + "sent_tokens = [*train_tokens, *valid_tokens, *test_tokens]\n", + "for sent in sent_tokens:\n", + " vocab.add_words(sent)\n", + "vocab.trim(min_freq=cfg.min_freq)\n", + "\n", + "logger.info('convert tokens into index...')\n", + "_convert_tokens_into_index(train_data, vocab)\n", + "_convert_tokens_into_index(valid_data, vocab)\n", + "_convert_tokens_into_index(test_data, vocab)\n", + "\n", + "logger.info('build position sequence...')\n", + "_add_pos_seq(train_data, cfg)\n", + "_add_pos_seq(valid_data, cfg)\n", + "_add_pos_seq(test_data, cfg)\n", + "\n", + "logger.info('save data for backup...')\n", + "os.makedirs(cfg.out_path, exist_ok=True)\n", + "train_save_fp = os.path.join(cfg.out_path, 'train.pkl')\n", + "valid_save_fp = os.path.join(cfg.out_path, 'valid.pkl')\n", + "test_save_fp = os.path.join(cfg.out_path, 'test.pkl')\n", + "save_pkl(train_data, train_save_fp)\n", + "save_pkl(valid_data, valid_save_fp)\n", + "save_pkl(test_data, test_save_fp)\n", + "\n", + "vocab_save_fp = os.path.join(cfg.out_path, 'vocab.pkl')\n", + "vocab_txt = os.path.join(cfg.out_path, 'vocab.txt')\n", + "save_pkl(vocab, vocab_save_fp)\n", + "logger.info('save vocab in txt file, for watching...')\n", + "with open(vocab_txt, 'w', encoding='utf-8') as f:\n", + " f.write(os.linesep.join(vocab.word2idx.keys()))\n", + "\n" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "# pytorch construct Dataset\n", + "def collate_fn(cfg):\n", + " def 
collate_fn_intra(batch):\n",
+    "        batch.sort(key=lambda data: data['seq_len'], reverse=True)\n",
+    "        max_len = batch[0]['seq_len']\n",
+    "\n",
+    "        def _padding(x, max_len):\n",
+    "            return x + [0] * (max_len - len(x))\n",
+    "\n",
+    "        x, y = dict(), []\n",
+    "        word, word_len = [], []\n",
+    "        head_pos, tail_pos = [], []\n",
+    "        pcnn_mask = []\n",
+    "        for data in batch:\n",
+    "            word.append(_padding(data['token2idx'], max_len))\n",
+    "            word_len.append(data['seq_len'])\n",
+    "            y.append(int(data['att2idx']))\n",
+    "\n",
+    "            if cfg.model_name != 'lm':\n",
+    "                head_pos.append(_padding(data['entity_pos'], max_len))\n",
+    "                tail_pos.append(_padding(data['attribute_value_pos'], max_len))\n",
+    "                if cfg.model_name == 'cnn':\n",
+    "                    if cfg.use_pcnn:\n",
+    "                        pcnn_mask.append(_padding(data['entities_pos'], max_len))\n",
+    "\n",
+    "        x['word'] = torch.tensor(word)\n",
+    "        x['lens'] = torch.tensor(word_len)\n",
+    "        y = torch.tensor(y)\n",
+    "\n",
+    "        if cfg.model_name != 'lm':\n",
+    "            x['entity_pos'] = torch.tensor(head_pos)\n",
+    "            x['attribute_value_pos'] = torch.tensor(tail_pos)\n",
+    "            if cfg.model_name == 'cnn' and cfg.use_pcnn:\n",
+    "                x['pcnn_mask'] = torch.tensor(pcnn_mask)\n",
+    "            if cfg.model_name == 'gcn':\n",
+    "                # No suitable parsing-tree tool was found, so the adjacency matrix is randomly initialized for now\n",
+    "                B, L = len(batch), max_len\n",
+    "                adj = torch.empty(B, L, L).random_(2)\n",
+    "                x['adj'] = adj\n",
+    "        return x, y\n",
+    "\n",
+    "    return collate_fn_intra\n",
+    "\n",
+    "\n",
+    "class CustomDataset(Dataset):\n",
+    "    \"\"\"\n",
+    "    Data is stored in a List by default\n",
+    "    \"\"\"\n",
+    "    def __init__(self, fp):\n",
+    "        self.file = load_pkl(fp)\n",
+    "\n",
+    "    def __getitem__(self, item):\n",
+    "        sample = self.file[item]\n",
+    "        return sample\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return len(self.file)"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "source": [
+    "# embedding layer\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "\n",
+    "\n",
+    "class Embedding(nn.Module):\n",
+    "    def __init__(self, config):\n",
+    "        \"\"\"\n",
+    "        word embedding: index 0 is reserved for padding\n",
+    "        pos embedding: index 0 is reserved for padding\n",
+    "        dim_strategy: [cat, sum] whether the embeddings are concatenated or summed\n",
+    "        \"\"\"\n",
+    "        super(Embedding, self).__init__()\n",
+    "\n",
+    "        # self.xxx = config.xxx\n",
+    "        self.vocab_size = config.vocab_size\n",
+    "        self.word_dim = config.word_dim\n",
+    "        self.pos_size = config.pos_size\n",
+    "        self.pos_dim = config.pos_dim if config.dim_strategy == 'cat' else config.word_dim\n",
+    "        self.dim_strategy = config.dim_strategy\n",
+    "\n",
+    "        self.wordEmbed = nn.Embedding(self.vocab_size, self.word_dim, padding_idx=0)\n",
+    "        self.entityPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)\n",
+    "        self.attribute_keyPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)\n",
+    "\n",
+    "        self.layer_norm = nn.LayerNorm(self.word_dim)\n",
+    "\n",
+    "    def forward(self, *x):\n",
+    "        word, entity, attribute_key = x\n",
+    "        word_embedding = self.wordEmbed(word)\n",
+    "        entity_embedding = self.entityPosEmbed(entity)\n",
+    "        attribute_key_embedding = self.attribute_keyPosEmbed(attribute_key)\n",
+    "\n",
+    "        if self.dim_strategy == 'cat':\n",
+    "            return torch.cat((word_embedding, entity_embedding, attribute_key_embedding), -1)\n",
+    "        elif self.dim_strategy == 'sum':\n",
+    "            # in this case pos_dim == word_dim\n",
+    "            return self.layer_norm(word_embedding + entity_embedding + attribute_key_embedding)\n",
+    "        else:\n",
+    "            raise Exception('dim_strategy must choose from [sum, cat]')\n"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {
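+   "cell_type": "code",
+   "execution_count": null,
+   "source": [
+    "# (Added illustrative cell) A minimal shape check of the Embedding layer defined above.\n",
+    "# This is only a sketch: `DemoEmbedConfig` is a hypothetical stand-in for `cfg` that carries\n",
+    "# just the fields Embedding needs, and `pos_size = 2 * pos_limit + 2` is an assumption about\n",
+    "# how the relative positions (shifted into [0, 2 * pos_limit + 1]) are indexed.\n",
+    "class DemoEmbedConfig:\n",
+    "    vocab_size = 20\n",
+    "    word_dim = 10\n",
+    "    pos_size = 2 * cfg.pos_limit + 2\n",
+    "    pos_dim = 5\n",
+    "    dim_strategy = 'sum'  # with 'sum', pos_dim falls back to word_dim\n",
+    "\n",
+    "demo_embed = Embedding(DemoEmbedConfig())\n",
+    "demo_word = torch.randint(0, DemoEmbedConfig.vocab_size, (2, 7))      # [batch, seq_len] token ids\n",
+    "demo_entity_pos = torch.randint(0, DemoEmbedConfig.pos_size, (2, 7))  # relative positions w.r.t. the entity\n",
+    "demo_attr_pos = torch.randint(0, DemoEmbedConfig.pos_size, (2, 7))    # relative positions w.r.t. the attribute value\n",
+    "print(demo_embed(demo_word, demo_entity_pos, demo_attr_pos).shape)    # expected: torch.Size([2, 7, 10])\n"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {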
+ "cell_type": "code", + "execution_count": null, + "source": [ + "# Gelu activation function, specified by transformer, works better than relu\n", + "class GELU(nn.Module):\n", + " def __init__(self):\n", + " super(GELU, self).__init__()\n", + "\n", + " def forward(self, x):\n", + " return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "#rnn model\n", + "class RNN(nn.Module):\n", + " def __init__(self, config):\n", + " \"\"\"\n", + " type_rnn: RNN, GRU, LSTM 可选\n", + " \"\"\"\n", + " super(RNN, self).__init__()\n", + "\n", + " # self.xxx = config.xxx\n", + " self.input_size = config.input_size\n", + " self.hidden_size = config.hidden_size // 2 if config.bidirectional else config.hidden_size\n", + " self.num_layers = config.num_layers\n", + " self.dropout = config.dropout\n", + " self.bidirectional = config.bidirectional\n", + " self.last_layer_hn = config.last_layer_hn\n", + " self.type_rnn = config.type_rnn\n", + "\n", + " rnn = eval(f'nn.{self.type_rnn}')\n", + " self.rnn = rnn(input_size=self.input_size,\n", + " hidden_size=self.hidden_size,\n", + " num_layers=self.num_layers,\n", + " dropout=self.dropout,\n", + " bidirectional=self.bidirectional,\n", + " bias=True,\n", + " batch_first=True)\n", + "\n", + " # 有bug\n", + " # self._init_weights()\n", + "\n", + " def _init_weights(self):\n", + " \"\"\"orthogonal init yields generally good results than uniform init\"\"\"\n", + " gain = 1 # use default value\n", + " for nth in range(self.num_layers * self.bidirectional):\n", + " # w_ih, (4 * hidden_size x input_size)\n", + " nn.init.orthogonal_(self.rnn.all_weights[nth][0], gain=gain)\n", + " # w_hh, (4 * hidden_size x hidden_size)\n", + " nn.init.orthogonal_(self.rnn.all_weights[nth][1], gain=gain)\n", + " # b_ih, (4 * hidden_size)\n", + " nn.init.zeros_(self.rnn.all_weights[nth][2])\n", + " # b_hh, (4 * hidden_size)\n", + " nn.init.zeros_(self.rnn.all_weights[nth][3])\n", + "\n", + " def forward(self, x, x_len):\n", + " \"\"\"\n", + " Args: \n", + " torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H_in] 一般是经过embedding后的值\n", + " x_len: torch.Tensor [L] 已经排好序的句长值\n", + " Returns:\n", + " output: torch.Tensor [B, L, H_out] 序列标注的使用结果\n", + " hn: torch.Tensor [B, N, H_out] / [B, H_out] 分类的结果,当 last_layer_hn 时只有最后一层结果\n", + " \"\"\"\n", + " B, L, _ = x.size()\n", + " H, N = self.hidden_size, self.num_layers\n", + "\n", + " x_len = x_len.cpu()\n", + " x = pack_padded_sequence(x, x_len, batch_first=True, enforce_sorted=True)\n", + " output, hn = self.rnn(x)\n", + " output, _ = pad_packed_sequence(output, batch_first=True, total_length=L)\n", + "\n", + " if self.type_rnn == 'LSTM':\n", + " hn = hn[0]\n", + " if self.bidirectional:\n", + " hn = hn.view(N, 2, B, H).transpose(1, 2).contiguous().view(N, B, 2 * H).transpose(0, 1)\n", + " else:\n", + " hn = hn.transpose(0, 1)\n", + " if self.last_layer_hn:\n", + " hn = hn[:, -1, :]\n", + "\n", + " return output, hn" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "# p,r,f1 measurement\n", + "class PRMetric():\n", + " def __init__(self):\n", + " \n", + " self.y_true = np.empty(0)\n", + " self.y_pred = np.empty(0)\n", + "\n", + " def reset(self):\n", + " self.y_true = np.empty(0)\n", + " self.y_pred = np.empty(0)\n", + "\n", + " def update(self, y_true:torch.Tensor, y_pred:torch.Tensor):\n", + " y_true = y_true.cpu().detach().numpy()\n", + " y_pred = 
y_pred.cpu().detach().numpy()\n", + " y_pred = np.argmax(y_pred,axis=-1)\n", + "\n", + " self.y_true = np.append(self.y_true, y_true)\n", + " self.y_pred = np.append(self.y_pred, y_pred)\n", + "\n", + " def compute(self):\n", + " p, r, f1, _ = precision_recall_fscore_support(self.y_true,self.y_pred,average='macro',warn_for=tuple())\n", + " _, _, acc, _ = precision_recall_fscore_support(self.y_true,self.y_pred,average='micro',warn_for=tuple())\n", + "\n", + " return acc,p,r,f1" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "# Iteration in training process\n", + "def train(epoch, model, dataloader, optimizer, criterion, cfg):\n", + " model.train()\n", + "\n", + " metric = PRMetric()\n", + " losses = []\n", + "\n", + " for batch_idx, (x, y) in enumerate(dataloader, 1):\n", + " optimizer.zero_grad()\n", + " y_pred = model(x)\n", + " loss = criterion(y_pred, y)\n", + "\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " metric.update(y_true=y, y_pred=y_pred)\n", + " losses.append(loss.item())\n", + "\n", + " data_total = len(dataloader.dataset)\n", + " data_cal = data_total if batch_idx == len(dataloader) else batch_idx * len(y)\n", + " if (cfg.train_log and batch_idx % cfg.log_interval == 0) or batch_idx == len(dataloader):\n", + " acc,p,r,f1 = metric.compute()\n", + " print(f'Train Epoch {epoch}: [{data_cal}/{data_total} ({100. * data_cal / data_total:.0f}%)]\\t'\n", + " f'Loss: {loss.item():.6f}')\n", + " print(f'Train Epoch {epoch}: Acc: {100. * acc:.2f}%\\t'\n", + " f'macro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')\n", + "\n", + " if cfg.show_plot and not cfg.only_comparison_plot:\n", + " if cfg.plot_utils == 'matplot':\n", + " plt.plot(losses)\n", + " plt.title(f'epoch {epoch} train loss')\n", + " plt.show()\n", + "\n", + " return losses[-1]\n", + "\n", + "# Iteration in testing process\n", + "def validate(epoch, model, dataloader, criterion,verbose=True):\n", + " model.eval()\n", + "\n", + " metric = PRMetric()\n", + " losses = []\n", + "\n", + " for batch_idx, (x, y) in enumerate(dataloader, 1):\n", + " with torch.no_grad():\n", + " y_pred = model(x)\n", + " loss = criterion(y_pred, y)\n", + "\n", + " metric.update(y_true=y, y_pred=y_pred)\n", + " losses.append(loss.item())\n", + "\n", + " loss = sum(losses) / len(losses)\n", + " acc,p,r,f1 = metric.compute()\n", + " data_total = len(dataloader.dataset)\n", + " if verbose:\n", + " print(f'Valid Epoch {epoch}: [{data_total}/{data_total}](100%)\\t Loss: {loss:.6f}')\n", + " print(f'Valid Epoch {epoch}: Acc: {100. 
* acc:.2f}%\\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]\\n\\n')\n", + "\n", + " return f1,loss" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "# Load dataset\n", + "train_dataset = CustomDataset(train_save_fp)\n", + "valid_dataset = CustomDataset(valid_save_fp)\n", + "test_dataset = CustomDataset(test_save_fp)\n", + "\n", + "train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n", + "valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n", + "test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "# After the preprocessed data is loaded, vocab_size is known\n", + "vocab = load_pkl(vocab_save_fp)\n", + "vocab_size = vocab.count\n", + "cfg.vocab_size = vocab_size" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "# main entry, define optimization function, loss function and so on\n", + "# start epoch\n", + "# Use the loss of the valid dataset to make an early stop judgment. When it does not decline, this is the time when the model generalization is the best.\n", + "model = RNN(cfg)\n", + "print(model)\n", + "\n", + "optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)\n", + "scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)\n", + "criterion = nn.CrossEntropyLoss()\n", + "\n", + "best_f1, best_epoch = -1, 0\n", + "es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, = 1000, -1, 0, 0, 0, -1\n", + "train_losses, valid_losses = [], []\n", + "\n", + "logger.info('=' * 10 + ' Start training ' + '=' * 10)\n", + "for epoch in range(1, cfg.epoch + 1):\n", + " train_loss = train(epoch, model, train_dataloader, optimizer, criterion, cfg)\n", + " valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion)\n", + " scheduler.step(valid_loss)\n", + "\n", + " train_losses.append(train_loss)\n", + " valid_losses.append(valid_loss)\n", + " if best_f1 < valid_f1:\n", + " best_f1 = valid_f1\n", + " best_epoch = epoch\n", + " # 使用 valid loss 做 early stopping 的判断标准\n", + " if es_loss > valid_loss:\n", + " es_loss = valid_loss\n", + " es_f1 = valid_f1\n", + " best_es_f1 = valid_f1\n", + " es_epoch = epoch\n", + " best_es_epoch = epoch\n", + " es_patience = 0\n", + " else:\n", + " es_patience += 1\n", + " if es_patience >= cfg.early_stopping_patience:\n", + " best_es_epoch = es_epoch\n", + " best_es_f1 = es_f1\n", + "\n", + "if cfg.show_plot:\n", + " if cfg.plot_utils == 'matplot':\n", + " plt.plot(train_losses, 'x-')\n", + " plt.plot(valid_losses, '+-')\n", + " plt.legend(['train', 'valid'])\n", + " plt.title('train/valid comparison loss')\n", + " plt.show()\n", + "\n", + "\n", + "print(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '\n", + " f'this epoch macro f1: {best_es_f1:0.4f}')\n", + "print(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '\n", + " f'this epoch macro f1: {best_f1:.4f}')\n", + "\n", + "test_f1, _ = validate(0, model, test_dataloader, criterion,verbose=False)\n", + "print(f'after {cfg.epoch} epochs, final test data macro f1: {test_f1:.4f}')\n" + ], + "outputs": [], + "metadata": {} + }, + { + 
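"cell_type": "code",
+   "execution_count": null,
+   "source": [
+    "# (Added illustrative cell) Optionally persist the trained weights so the result can be reused\n",
+    "# without retraining. This is a sketch that assumes the training loop above completed\n",
+    "# successfully; the file name 'rnn_ae_model.pth' is an arbitrary choice.\n",
+    "model_save_fp = os.path.join(cfg.out_path, 'rnn_ae_model.pth')\n",
+    "torch.save(model.state_dict(), model_save_fp)\n",
+    "print(f'model weights saved to {model_save_fp}')\n",
+    "\n",
+    "# To reload later, rebuild the model with the same config and load the saved state dict:\n",
+    "# model = RNN(cfg)\n",
+    "# model.load_state_dict(torch.load(model_save_fp))\n"
+   ],
+   "outputs": [],
+   "metadata": {}
+  },
+  {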
"cell_type": "markdown", + "source": [ + "This demo does not include parameter adjustment. Interested students can go to [deepke] by themselves( http://openkg.cn/tool/deepke )Warehouse, download and use more models:)" + ], + "metadata": {} + } + ], + "metadata": { + "orig_nbformat": 4, + "language_info": { + "name": "plaintext" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} \ No newline at end of file diff --git a/tutorial-notebooks/ae/regular/data/attribute.csv b/tutorial-notebooks/ae/regular/data/attribute.csv new file mode 100644 index 0000000..b364e74 --- /dev/null +++ b/tutorial-notebooks/ae/regular/data/attribute.csv @@ -0,0 +1,4 @@ +attribute,index +None,0 +民族,1 +字,2 diff --git a/tutorial-notebooks/ae/regular/data/test.csv b/tutorial-notebooks/ae/regular/data/test.csv new file mode 100644 index 0000000..3daad05 --- /dev/null +++ b/tutorial-notebooks/ae/regular/data/test.csv @@ -0,0 +1,3 @@ +sentence,attribute,entity,entity_offset,attribute_value,attribute_value_offset +柳为易,女,1989年5月出生,中共党员 ,汉族,重庆市人,民族,柳为易,0,汉族,22 +庄肇奎 (1728-1798) 榜姓杜,字星堂,号胥园,江苏武进籍,浙江秀水(今嘉兴)人,字,庄肇奎,0,星堂,23 diff --git a/tutorial-notebooks/ae/regular/data/train.csv b/tutorial-notebooks/ae/regular/data/train.csv new file mode 100644 index 0000000..69c5369 --- /dev/null +++ b/tutorial-notebooks/ae/regular/data/train.csv @@ -0,0 +1,7 @@ +sentence,attribute,entity,entity_offset,attribute_value,attribute_value_offset +苏轼(1037~1101年),字子瞻,又字和仲,号“东坡居士”,眉州眉山(即今四川眉州)人,是宋代(北宋)著名的文学家、书画家,字,苏轼,0,和仲,21 +屈中乾,男,汉族,中共党员,特级教师,民族,屈中乾,0,汉族,6 +阳成俊,男,汉族,贵州省委党校大学学历,民族,阳成俊,0,汉族,6 +黄向静,女,汉族,1965年5月生,大学学历,1986年17月参加工作,中共党员,身体健康,民族,黄向静,0,汉族,6 +生平简介陈执中(990-1059),字昭誉,名相陈恕之子,北宋洪州南昌(今属江西)人,字,陈执中,4,昭誉,19 +司马懿,字仲达,河南温县人,字,司马懿,0,仲达,5 diff --git a/tutorial-notebooks/ae/regular/data/valid.csv b/tutorial-notebooks/ae/regular/data/valid.csv new file mode 100644 index 0000000..880d24a --- /dev/null +++ b/tutorial-notebooks/ae/regular/data/valid.csv @@ -0,0 +1,3 @@ +sentence,attribute,entity,entity_offset,attribute_value,attribute_value_offset +田承冉 男,1952年生,汉族,山东桓台人,共党员,民族,田承冉,0,汉族,13 +冷家骥,字展麒,山东招远人,字,冷家骥,0,展麒,5 diff --git a/tutorial-notebooks/ae/regular/img/LSTM.jpg b/tutorial-notebooks/ae/regular/img/LSTM.jpg new file mode 100644 index 0000000..0ba1ad7 Binary files /dev/null and b/tutorial-notebooks/ae/regular/img/LSTM.jpg differ