test
This commit is contained in:
parent
79163abc4f
commit
1f9041d9b3
|
@@ -0,0 +1,706 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## relation extraction experiment\n",
|
||||
"> Tutorial author:余海阳(yuhaiyang@zju.edu.cn)\n",
|
||||
"\n",
|
||||
"On this demo,we use `GCN` to extract relations.\n",
|
||||
"We hope this demo can help you understand the process of conctruction knowledge graph and the the principles and common methods of triplet extraction.\n",
|
||||
"\n",
|
||||
"This demo uses `Python3`.\n",
|
||||
"\n",
|
||||
"### Dataset\n",
|
||||
"In this example,we get some Chinese text to extract the triples.\n",
|
||||
"\n",
|
||||
"sentence|relation|head|tail\n",
|
||||
":---:|:---:|:---:|:---:\n",
|
||||
"孔正锡在2005年以一部温馨的爱情电影《长腿叔叔》敲开电影界大门。|导演|长腿叔叔|孔正锡\n",
|
||||
"《伤心的树》是吴宗宪的音乐作品,收录在《你比从前快乐》专辑中。|所属专辑|伤心的树|你比从前快乐\n",
|
||||
"2000年8月,「天坛大佛」荣获「香港十大杰出工程项目」第四名。|所在城市|天坛大佛|香港\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"- train.csv: It contains 6 training triples,each lines represent one triple,sorted by sentence, relationship, head entity and tail entity, and separated by `,`.\n",
|
||||
"- valid.csv: It contains 3 validing triples,each lines represent one triple,sorted by sentence, relationship, head entity and tail entity, and separated by `,`.\n",
|
||||
"- test.csv: It contains 3 testing triples,each lines represent one triple,sorted by sentence, relationship, head entity and tail entity, and separated by `,`.\n",
|
||||
"- relation.csv: It contains 4 relation triples,each lines represent one triple,sorted by sentence, relationship, head entity and tail entity, and separated by `,`."
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### GCN \n",
|
||||
"\n",
|
||||
"![GCN](img/GCN.png)\n",
|
||||
"\n",
|
||||
"The sentence information mainly includes word embedding and position embedding and the adjacency matrix adj_matrix\n",
|
||||
"The nodes in the adjacency matrix are each word token.\n",
|
||||
"After input to the multi-layer (generally take 2 or 3 layers, and the result will not be significantly improved if it is too multi-layer), the relationship information of the sentence can be obtained through the maximum pool input to the full connection layer.\n",
|
||||
"\n"
|
||||
],
|
||||
"metadata": {}
|
||||
},
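{
"cell_type": "markdown",
"source": [
"The cell below is an illustrative sketch (not part of the original pipeline) of how the adjacency matrix `adj_matrix` mentioned above could be built from a dependency parse: each token is linked to its syntactic head, with optional self-loops. The helper name and the head-index convention are assumptions for clarity only."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Illustration only: build a token-level adjacency matrix from dependency heads.\n",
"# Assumption: head[i] is the 1-based index of token i's head, with 0 marking the root.\n",
"import numpy as np\n",
"\n",
"def head_to_adj_sketch(head, directed=True, self_loop=True):\n",
"    n = len(head)\n",
"    adj = np.zeros((n, n), dtype=np.float32)\n",
"    for i, h in enumerate(head):\n",
"        if h > 0:\n",
"            adj[h - 1, i] = 1\n",
"            if not directed:\n",
"                adj[i, h - 1] = 1\n",
"    if self_loop:\n",
"        np.fill_diagonal(adj, 1)\n",
"    return adj\n",
"\n",
"# Example: a 4-token sentence where tokens 2, 3 and 4 all depend on token 1\n",
"head_to_adj_sketch([0, 1, 1, 1])"
],
"outputs": [],
"metadata": {}
},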
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Run the neural network with pytorch and confirm whether it is installed before running\n",
|
||||
"!pip install torch\n",
|
||||
"!pip install matplotlib\n",
|
||||
"!pip install transformers"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# import the whole modules\n",
|
||||
"import os\n",
|
||||
"import csv\n",
|
||||
"import math\n",
|
||||
"import pickle\n",
|
||||
"import logging\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from torch import optim\n",
|
||||
"from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
|
||||
"from torch.utils.data import Dataset,DataLoader\n",
|
||||
"from sklearn.metrics import precision_recall_fscore_support\n",
|
||||
"from typing import List, Tuple, Dict, Any, Sequence, Optional, Union\n",
|
||||
"from transformers import BertTokenizer, BertModel\n",
|
||||
"\n",
|
||||
"logger = logging.getLogger(__name__)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Configuration file of model parameters\n",
|
||||
"class Config(object):\n",
|
||||
" model_name = 'gcn' # ['cnn', 'gcn', 'lm']\n",
|
||||
" use_pcnn = True\n",
|
||||
" min_freq = 1\n",
|
||||
" pos_limit = 20\n",
|
||||
" out_path = 'data/out' \n",
|
||||
" batch_size = 2 \n",
|
||||
" word_dim = 10\n",
|
||||
" pos_dim = 5\n",
|
||||
" dim_strategy = 'sum' # ['sum', 'cat']\n",
|
||||
" out_channels = 20\n",
|
||||
" intermediate = 10\n",
|
||||
" kernel_sizes = [3, 5, 7]\n",
|
||||
" activation = 'gelu'\n",
|
||||
" pooling_strategy = 'max'\n",
|
||||
" dropout = 0.3\n",
|
||||
" epoch = 10\n",
|
||||
" num_relations = 4\n",
|
||||
" learning_rate = 3e-4\n",
|
||||
" lr_factor = 0.7 # 学习率的衰减率\n",
|
||||
" lr_patience = 3 # 学习率衰减的等待epoch\n",
|
||||
" weight_decay = 1e-3 # L2正则\n",
|
||||
" early_stopping_patience = 6\n",
|
||||
" train_log = True\n",
|
||||
" log_interval = 1\n",
|
||||
" show_plot = True\n",
|
||||
" only_comparison_plot = False\n",
|
||||
" plot_utils = 'matplot'\n",
|
||||
" lm_file = 'bert-base-chinese'\n",
|
||||
" lm_num_hidden_layers = 2\n",
|
||||
" rnn_layers = 2\n",
|
||||
" \n",
|
||||
"cfg = Config()"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Word token builds a one hot dictionary, and then inputs it to the embedding layer to obtain the corresponding word information matrix\n",
|
||||
"# 0 is pad by default and 1 is unknown\n",
|
||||
"\n",
|
||||
"class Vocab(object):\n",
|
||||
" def __init__(self, name: str = 'basic', init_tokens = [\"[PAD]\", \"[UNK]\"]):\n",
|
||||
" self.name = name\n",
|
||||
" self.init_tokens = init_tokens\n",
|
||||
" self.trimed = False\n",
|
||||
" self.word2idx = {}\n",
|
||||
" self.word2count = {}\n",
|
||||
" self.idx2word = {}\n",
|
||||
" self.count = 0\n",
|
||||
" self._add_init_tokens()\n",
|
||||
"\n",
|
||||
" def _add_init_tokens(self):\n",
|
||||
" for token in self.init_tokens:\n",
|
||||
" self._add_word(token)\n",
|
||||
"\n",
|
||||
" def _add_word(self, word: str):\n",
|
||||
" if word not in self.word2idx:\n",
|
||||
" self.word2idx[word] = self.count\n",
|
||||
" self.word2count[word] = 1\n",
|
||||
" self.idx2word[self.count] = word\n",
|
||||
" self.count += 1\n",
|
||||
" else:\n",
|
||||
" self.word2count[word] += 1\n",
|
||||
"\n",
|
||||
" def add_words(self, words: Sequence):\n",
|
||||
" for word in words:\n",
|
||||
" self._add_word(word)\n",
|
||||
"\n",
|
||||
" def trim(self, min_freq=2, verbose: Optional[bool] = True):\n",
|
||||
" \n",
|
||||
" assert min_freq == int(min_freq), f'min_freq must be integer, can\\'t be {min_freq}'\n",
|
||||
" min_freq = int(min_freq)\n",
|
||||
" if min_freq < 2:\n",
|
||||
" return\n",
|
||||
" if self.trimed:\n",
|
||||
" return\n",
|
||||
" self.trimed = True\n",
|
||||
"\n",
|
||||
" keep_words = []\n",
|
||||
" new_words = []\n",
|
||||
"\n",
|
||||
" for k, v in self.word2count.items():\n",
|
||||
" if v >= min_freq:\n",
|
||||
" keep_words.append(k)\n",
|
||||
" new_words.extend([k] * v)\n",
|
||||
" if verbose:\n",
|
||||
" before_len = len(keep_words)\n",
|
||||
" after_len = len(self.word2idx) - len(self.init_tokens)\n",
|
||||
" logger.info('vocab after be trimmed, keep words [{} / {}] = {:.2f}%'.format(before_len, after_len, before_len / after_len * 100))\n",
|
||||
"\n",
|
||||
" # Reinitialize dictionaries\n",
|
||||
" self.word2idx = {}\n",
|
||||
" self.word2count = {}\n",
|
||||
" self.idx2word = {}\n",
|
||||
" self.count = 0\n",
|
||||
" self._add_init_tokens()\n",
|
||||
" self.add_words(new_words)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
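{
"cell_type": "markdown",
"source": [
"A quick, optional illustration of the `Vocab` class defined above: tokens are mapped to integer ids, with index 0 reserved for `[PAD]` and index 1 for `[UNK]`, and repeated tokens only increase their count."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"_demo_vocab = Vocab('demo')\n",
"_demo_vocab.add_words(['孔', '正', '锡', '孔'])\n",
"print(_demo_vocab.word2idx)          # {'[PAD]': 0, '[UNK]': 1, '孔': 2, '正': 3, '锡': 4}\n",
"print(_demo_vocab.word2count['孔'])   # 2"
],
"outputs": [],
"metadata": {}
},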
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Functions required for preprocessing\n",
|
||||
"Path = str\n",
|
||||
"\n",
|
||||
"def load_csv(fp: Path, is_tsv: bool = False, verbose: bool = True) -> List:\n",
|
||||
" if verbose:\n",
|
||||
" logger.info(f'load csv from {fp}')\n",
|
||||
"\n",
|
||||
" dialect = 'excel-tab' if is_tsv else 'excel'\n",
|
||||
" with open(fp, encoding='utf-8') as f:\n",
|
||||
" reader = csv.DictReader(f, dialect=dialect)\n",
|
||||
" return list(reader)\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"def load_pkl(fp: Path, verbose: bool = True) -> Any:\n",
|
||||
" if verbose:\n",
|
||||
" logger.info(f'load data from {fp}')\n",
|
||||
"\n",
|
||||
" with open(fp, 'rb') as f:\n",
|
||||
" data = pickle.load(f)\n",
|
||||
" return data\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def save_pkl(data: Any, fp: Path, verbose: bool = True) -> None:\n",
|
||||
" if verbose:\n",
|
||||
" logger.info(f'save data in {fp}')\n",
|
||||
"\n",
|
||||
" with open(fp, 'wb') as f:\n",
|
||||
" pickle.dump(data, f)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
"def _handle_relation_data(relation_data: List[Dict]) -> Dict:\n",
|
||||
" rels = dict()\n",
|
||||
" for d in relation_data:\n",
|
||||
" rels[d['relation']] = {\n",
|
||||
" 'index': int(d['index']),\n",
|
||||
" 'head_type': d['head_type'],\n",
|
||||
" 'tail_type': d['tail_type'],\n",
|
||||
" }\n",
|
||||
" return rels\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _add_relation_data(rels: Dict,data: List) -> None:\n",
|
||||
" for d in data:\n",
|
||||
" d['rel2idx'] = rels[d['relation']]['index']\n",
|
||||
" d['head_type'] = rels[d['relation']]['head_type']\n",
|
||||
" d['tail_type'] = rels[d['relation']]['tail_type']\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _convert_tokens_into_index(data: List[Dict], vocab):\n",
|
||||
" unk_str = '[UNK]'\n",
|
||||
" unk_idx = vocab.word2idx[unk_str]\n",
|
||||
"\n",
|
||||
" for d in data:\n",
|
||||
" d['token2idx'] = [vocab.word2idx.get(i, unk_idx) for i in d['tokens']]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _add_pos_seq(train_data: List[Dict], cfg):\n",
|
||||
" for d in train_data:\n",
|
||||
" d['head_offset'], d['tail_offset'], d['lens'] = int(d['head_offset']), int(d['tail_offset']), int(d['lens'])\n",
|
||||
" entities_idx = [d['head_offset'], d['tail_offset']] if d['head_offset'] < d['tail_offset'] else [d['tail_offset'], d['head_offset']]\n",
|
||||
"\n",
|
||||
" d['head_pos'] = list(map(lambda i: i - d['head_offset'], list(range(d['lens']))))\n",
|
||||
" d['head_pos'] = _handle_pos_limit(d['head_pos'], int(cfg.pos_limit))\n",
|
||||
"\n",
|
||||
" d['tail_pos'] = list(map(lambda i: i - d['tail_offset'], list(range(d['lens']))))\n",
|
||||
" d['tail_pos'] = _handle_pos_limit(d['tail_pos'], int(cfg.pos_limit))\n",
|
||||
"\n",
|
||||
" if cfg.use_pcnn:\n",
|
||||
" d['entities_pos'] = [1] * (entities_idx[0] + 1) + [2] * (entities_idx[1] - entities_idx[0] - 1) +\\\n",
|
||||
" [3] * (d['lens'] - entities_idx[1])\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"def _handle_pos_limit(pos: List[int], limit: int) -> List[int]:\n",
|
||||
" for i, p in enumerate(pos):\n",
|
||||
" if p > limit:\n",
|
||||
" pos[i] = limit\n",
|
||||
" if p < -limit:\n",
|
||||
" pos[i] = -limit\n",
|
||||
" return [p + limit + 1 for p in pos]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None, mask_pos_to_true=True):\n",
|
||||
" \n",
|
||||
" if isinstance(seq_len, list):\n",
|
||||
" seq_len = np.array(seq_len)\n",
|
||||
"\n",
|
||||
" if isinstance(seq_len, np.ndarray):\n",
|
||||
" seq_len = torch.from_numpy(seq_len)\n",
|
||||
"\n",
|
||||
" if isinstance(seq_len, torch.Tensor):\n",
|
||||
" assert seq_len.dim() == 1, logger.error(f\"seq_len can only have one dimension, got {seq_len.dim()} != 1.\")\n",
|
||||
" batch_size = seq_len.size(0)\n",
|
||||
" max_len = int(max_len) if max_len else seq_len.max().long()\n",
|
||||
" broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len.device)\n",
|
||||
" if mask_pos_to_true:\n",
|
||||
" mask = broad_cast_seq_len.ge(seq_len.unsqueeze(1))\n",
|
||||
" else:\n",
|
||||
" mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))\n",
|
||||
" else:\n",
|
||||
" raise logger.error(\"Only support 1-d list or 1-d numpy.ndarray or 1-d torch.Tensor.\")\n",
|
||||
"\n",
|
||||
" return mask\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _lm_serialize(data: List[Dict], cfg):\n",
|
||||
" logger.info('use bert tokenizer...')\n",
|
||||
" tokenizer = BertTokenizer.from_pretrained(cfg.lm_file)\n",
|
||||
" for d in data:\n",
|
||||
" sent = d['sentence'].strip()\n",
|
||||
" sent = sent.replace(d['head'], d['head_type'], 1).replace(d['tail'], d['tail_type'], 1)\n",
|
||||
" sent += '[SEP]' + d['head'] + '[SEP]' + d['tail']\n",
|
||||
" d['token2idx'] = tokenizer.encode(sent, add_special_tokens=True)\n",
|
||||
" d['lens'] = len(d['token2idx'])"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Preprocess\n",
|
||||
"logger.info('load raw files...')\n",
|
||||
"train_fp = os.path.join('data/train.csv')\n",
|
||||
"valid_fp = os.path.join('data/valid.csv')\n",
|
||||
"test_fp = os.path.join('data/test.csv')\n",
|
||||
"relation_fp = os.path.join('data/relation.csv')\n",
|
||||
"\n",
|
||||
"train_data = load_csv(train_fp)\n",
|
||||
"valid_data = load_csv(valid_fp)\n",
|
||||
"test_data = load_csv(test_fp)\n",
|
||||
"relation_data = load_csv(relation_fp)\n",
|
||||
"\n",
|
||||
"for d in train_data:\n",
|
||||
" d['tokens'] = eval(d['tokens'])\n",
|
||||
"for d in valid_data:\n",
|
||||
" d['tokens'] = eval(d['tokens'])\n",
|
||||
"for d in test_data:\n",
|
||||
" d['tokens'] = eval(d['tokens'])\n",
|
||||
" \n",
|
||||
"logger.info('convert relation into index...')\n",
|
||||
"rels = _handle_relation_data(relation_data)\n",
|
||||
"_add_relation_data(rels, train_data)\n",
|
||||
"_add_relation_data(rels, valid_data)\n",
|
||||
"_add_relation_data(rels, test_data)\n",
|
||||
"\n",
|
||||
"logger.info('verify whether use pretrained language models...')\n",
|
||||
"if cfg.model_name == 'lm':\n",
|
||||
" logger.info('use pretrained language models serialize sentence...')\n",
|
||||
" _lm_serialize(train_data, cfg)\n",
|
||||
" _lm_serialize(valid_data, cfg)\n",
|
||||
" _lm_serialize(test_data, cfg)\n",
|
||||
"else:\n",
|
||||
" logger.info('build vocabulary...')\n",
|
||||
" vocab = Vocab('word')\n",
|
||||
" train_tokens = [d['tokens'] for d in train_data]\n",
|
||||
" valid_tokens = [d['tokens'] for d in valid_data]\n",
|
||||
" test_tokens = [d['tokens'] for d in test_data]\n",
|
||||
" sent_tokens = [*train_tokens, *valid_tokens, *test_tokens]\n",
|
||||
" for sent in sent_tokens:\n",
|
||||
" vocab.add_words(sent)\n",
|
||||
" vocab.trim(min_freq=cfg.min_freq)\n",
|
||||
"\n",
|
||||
" logger.info('convert tokens into index...')\n",
|
||||
" _convert_tokens_into_index(train_data, vocab)\n",
|
||||
" _convert_tokens_into_index(valid_data, vocab)\n",
|
||||
" _convert_tokens_into_index(test_data, vocab)\n",
|
||||
"\n",
|
||||
" logger.info('build position sequence...')\n",
|
||||
" _add_pos_seq(train_data, cfg)\n",
|
||||
" _add_pos_seq(valid_data, cfg)\n",
|
||||
" _add_pos_seq(test_data, cfg)\n",
|
||||
"\n",
|
||||
"logger.info('save data for backup...')\n",
|
||||
"os.makedirs(cfg.out_path, exist_ok=True)\n",
|
||||
"train_save_fp = os.path.join(cfg.out_path, 'train.pkl')\n",
|
||||
"valid_save_fp = os.path.join(cfg.out_path, 'valid.pkl')\n",
|
||||
"test_save_fp = os.path.join(cfg.out_path, 'test.pkl')\n",
|
||||
"save_pkl(train_data, train_save_fp)\n",
|
||||
"save_pkl(valid_data, valid_save_fp)\n",
|
||||
"save_pkl(test_data, test_save_fp)\n",
|
||||
"\n",
|
||||
"if cfg.model_name != 'lm':\n",
|
||||
" vocab_save_fp = os.path.join(cfg.out_path, 'vocab.pkl')\n",
|
||||
" vocab_txt = os.path.join(cfg.out_path, 'vocab.txt')\n",
|
||||
" save_pkl(vocab, vocab_save_fp)\n",
|
||||
" logger.info('save vocab in txt file, for watching...')\n",
|
||||
" with open(vocab_txt, 'w', encoding='utf-8') as f:\n",
|
||||
" f.write(os.linesep.join(vocab.word2idx.keys()))"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# embedding layer\n",
|
||||
"class Embedding(nn.Module):\n",
|
||||
" def __init__(self, config):\n",
|
||||
" \"\"\"\n",
|
||||
" word embedding: 一般 0 为 padding\n",
|
||||
" pos embedding: 一般 0 为 padding\n",
|
||||
" dim_strategy: [cat, sum] 多个 embedding 是拼接还是相加\n",
|
||||
" \"\"\"\n",
|
||||
" super(Embedding, self).__init__()\n",
|
||||
"\n",
|
||||
" # self.xxx = config.xxx\n",
|
||||
" self.vocab_size = config.vocab_size\n",
|
||||
" self.word_dim = config.word_dim\n",
|
||||
" self.pos_size = config.pos_limit * 2 + 2\n",
|
||||
" self.pos_dim = config.pos_dim if config.dim_strategy == 'cat' else config.word_dim\n",
|
||||
" self.dim_strategy = config.dim_strategy\n",
|
||||
"\n",
|
||||
" self.wordEmbed = nn.Embedding(self.vocab_size,self.word_dim,padding_idx=0)\n",
|
||||
" self.headPosEmbed = nn.Embedding(self.pos_size,self.pos_dim,padding_idx=0)\n",
|
||||
" self.tailPosEmbed = nn.Embedding(self.pos_size,self.pos_dim,padding_idx=0)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" def forward(self, *x):\n",
|
||||
" word, head, tail = x\n",
|
||||
" word_embedding = self.wordEmbed(word)\n",
|
||||
" head_embedding = self.headPosEmbed(head)\n",
|
||||
" tail_embedding = self.tailPosEmbed(tail)\n",
|
||||
"\n",
|
||||
" if self.dim_strategy == 'cat':\n",
|
||||
" return torch.cat((word_embedding,head_embedding, tail_embedding), -1)\n",
|
||||
" elif self.dim_strategy == 'sum':\n",
|
||||
" # 此时 pos_dim == word_dim\n",
|
||||
" return word_embedding + head_embedding + tail_embedding\n",
|
||||
" else:\n",
|
||||
" raise Exception('dim_strategy must choose from [sum, cat]')"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# gcn model\n",
|
||||
"class GCN(nn.Module):\n",
|
||||
" def __init__(self,cfg):\n",
|
||||
" super(GCN , self).__init__()\n",
|
||||
"\n",
|
||||
" self.num_layers = cfg.num_layers\n",
|
||||
" self.input_size = cfg.input_size\n",
|
||||
" self.hidden_size = cfg.hidden_size\n",
|
||||
" self.dropout = cfg.dropout\n",
|
||||
"\n",
|
||||
" self.fc1 = nn.Linear(self.input_size , self.hidden_size)\n",
|
||||
" self.fc = nn.Linear(self.hidden_size , self.hidden_size)\n",
|
||||
" self.weight_list = nn.ModuleList()\n",
|
||||
" for i in range(self.num_layers):\n",
|
||||
" self.weight_list.append(nn.Linear(self.hidden_size * (i + 1),self.hidden_size))\n",
|
||||
" self.dropout = nn.Dropout(self.dropout)\n",
|
||||
"\n",
|
||||
" def forward(self , x, adj):\n",
|
||||
" L = adj.sum(2).unsqueeze(2) + 1\n",
|
||||
" outputs = self.fc1(x)\n",
|
||||
" cache_list = [outputs]\n",
|
||||
" output_list = []\n",
|
||||
" for l in range(self.num_layers):\n",
|
||||
" Ax = adj.bmm(outputs)\n",
|
||||
" AxW = self.weight_list[l](Ax)\n",
|
||||
" AxW = AxW + self.weight_list[l](outputs)\n",
|
||||
" AxW = AxW / L\n",
|
||||
" gAxW = F.relu(AxW)\n",
|
||||
" cache_list.append(gAxW)\n",
|
||||
" outputs = torch.cat(cache_list , dim=2)\n",
|
||||
" output_list.append(self.dropout(gAxW))\n",
|
||||
" # gcn_outputs = torch.cat(output_list, dim=2)\n",
|
||||
" gcn_outputs = output_list[self.num_layers - 1]\n",
|
||||
" gcn_outputs = gcn_outputs + self.fc1(x)\n",
|
||||
"\n",
|
||||
" out = self.fc(gcn_outputs)\n",
|
||||
" return out"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# p,r,f1 measurement\n",
|
||||
"class PRMetric():\n",
|
||||
" def __init__(self):\n",
|
||||
" \"\"\"\n",
|
||||
" 暂时调用 sklearn 的方法\n",
|
||||
" \"\"\"\n",
|
||||
" self.y_true = np.empty(0)\n",
|
||||
" self.y_pred = np.empty(0)\n",
|
||||
"\n",
|
||||
" def reset(self):\n",
|
||||
" self.y_true = np.empty(0)\n",
|
||||
" self.y_pred = np.empty(0)\n",
|
||||
"\n",
|
||||
" def update(self, y_true:torch.Tensor, y_pred:torch.Tensor):\n",
|
||||
" y_true = y_true.cpu().detach().numpy()\n",
|
||||
" y_pred = y_pred.cpu().detach().numpy()\n",
|
||||
" y_pred = np.argmax(y_pred,axis=-1)\n",
|
||||
"\n",
|
||||
" self.y_true = np.append(self.y_true, y_true)\n",
|
||||
" self.y_pred = np.append(self.y_pred, y_pred)\n",
|
||||
"\n",
|
||||
" def compute(self):\n",
|
||||
" p, r, f1, _ = precision_recall_fscore_support(self.y_true,self.y_pred,average='macro',warn_for=tuple())\n",
|
||||
" _, _, acc, _ = precision_recall_fscore_support(self.y_true,self.y_pred,average='micro',warn_for=tuple())\n",
|
||||
"\n",
|
||||
" return acc,p,r,f1"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Iteration in training process\n",
|
||||
"def train(epoch, model, dataloader, optimizer, criterion, cfg):\n",
|
||||
" model.train()\n",
|
||||
"\n",
|
||||
" metric = PRMetric()\n",
|
||||
" losses = []\n",
|
||||
"\n",
|
||||
" for batch_idx, (x, y) in enumerate(dataloader, 1):\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" y_pred = model(x)\n",
|
||||
" loss = criterion(y_pred, y)\n",
|
||||
"\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
" metric.update(y_true=y, y_pred=y_pred)\n",
|
||||
" losses.append(loss.item())\n",
|
||||
"\n",
|
||||
" data_total = len(dataloader.dataset)\n",
|
||||
" data_cal = data_total if batch_idx == len(dataloader) else batch_idx * len(y)\n",
|
||||
" if (cfg.train_log and batch_idx % cfg.log_interval == 0) or batch_idx == len(dataloader):\n",
|
||||
" acc,p,r,f1 = metric.compute()\n",
|
||||
" print(f'Train Epoch {epoch}: [{data_cal}/{data_total} ({100. * data_cal / data_total:.0f}%)]\\t'\n",
|
||||
" f'Loss: {loss.item():.6f}')\n",
|
||||
" print(f'Train Epoch {epoch}: Acc: {100. * acc:.2f}%\\t'\n",
|
||||
" f'macro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')\n",
|
||||
"\n",
|
||||
" if cfg.show_plot and not cfg.only_comparison_plot:\n",
|
||||
" if cfg.plot_utils == 'matplot':\n",
|
||||
" plt.plot(losses)\n",
|
||||
" plt.title(f'epoch {epoch} train loss')\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
" return losses[-1]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Iteration in testing process\n",
|
||||
"def validate(epoch, model, dataloader, criterion,verbose=True):\n",
|
||||
" model.eval()\n",
|
||||
"\n",
|
||||
" metric = PRMetric()\n",
|
||||
" losses = []\n",
|
||||
"\n",
|
||||
" for batch_idx, (x, y) in enumerate(dataloader, 1):\n",
|
||||
" with torch.no_grad():\n",
|
||||
" y_pred = model(x)\n",
|
||||
" loss = criterion(y_pred, y)\n",
|
||||
"\n",
|
||||
" metric.update(y_true=y, y_pred=y_pred)\n",
|
||||
" losses.append(loss.item())\n",
|
||||
"\n",
|
||||
" loss = sum(losses) / len(losses)\n",
|
||||
" acc,p,r,f1 = metric.compute()\n",
|
||||
" data_total = len(dataloader.dataset)\n",
|
||||
" if verbose:\n",
|
||||
" print(f'Valid Epoch {epoch}: [{data_total}/{data_total}](100%)\\t Loss: {loss:.6f}')\n",
|
||||
" print(f'Valid Epoch {epoch}: Acc: {100. * acc:.2f}%\\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]\\n\\n')\n",
|
||||
"\n",
|
||||
" return f1,loss"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Load dataset\n",
|
||||
"train_dataset = CustomDataset(train_save_fp)\n",
|
||||
"valid_dataset = CustomDataset(valid_save_fp)\n",
|
||||
"test_dataset = CustomDataset(test_save_fp)\n",
|
||||
"\n",
|
||||
"train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n",
|
||||
"valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n",
|
||||
"test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# After the preprocessed data is loaded, vocab_size is known\n",
|
||||
"vocab = load_pkl(vocab_save_fp)\n",
|
||||
"vocab_size = vocab.count\n",
|
||||
"cfg.vocab_size = vocab_size"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# main entry, define optimization function, loss function and so on\n",
|
||||
"# start epoch\n",
|
||||
"# Use the loss of the valid dataset to make an early stop judgment. When it does not decline, this is the time when the model generalization is the best.\n",
|
||||
"model = GCN(cfg)\n",
|
||||
"print(model)\n",
|
||||
"\n",
|
||||
"optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)\n",
|
||||
"scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)\n",
|
||||
"criterion = nn.CrossEntropyLoss()\n",
|
||||
"\n",
|
||||
"best_f1, best_epoch = -1, 0\n",
|
||||
"es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, = 1000, -1, 0, 0, 0, -1\n",
|
||||
"train_losses, valid_losses = [], []\n",
|
||||
"\n",
|
||||
"logger.info('=' * 10 + ' Start training ' + '=' * 10)\n",
|
||||
"for epoch in range(1, cfg.epoch + 1):\n",
|
||||
" train_loss = train(epoch, model, train_dataloader, optimizer, criterion, cfg)\n",
|
||||
" valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion)\n",
|
||||
" scheduler.step(valid_loss)\n",
|
||||
"\n",
|
||||
" train_losses.append(train_loss)\n",
|
||||
" valid_losses.append(valid_loss)\n",
|
||||
" if best_f1 < valid_f1:\n",
|
||||
" best_f1 = valid_f1\n",
|
||||
" best_epoch = epoch\n",
|
||||
" if es_loss > valid_loss:\n",
|
||||
" es_loss = valid_loss\n",
|
||||
" es_f1 = valid_f1\n",
|
||||
" best_es_f1 = valid_f1\n",
|
||||
" es_epoch = epoch\n",
|
||||
" best_es_epoch = epoch\n",
|
||||
" es_patience = 0\n",
|
||||
" else:\n",
|
||||
" es_patience += 1\n",
|
||||
" if es_patience >= cfg.early_stopping_patience:\n",
|
||||
" best_es_epoch = es_epoch\n",
|
||||
" best_es_f1 = es_f1\n",
|
||||
"\n",
|
||||
"if cfg.show_plot:\n",
|
||||
" if cfg.plot_utils == 'matplot':\n",
|
||||
" plt.plot(train_losses, 'x-')\n",
|
||||
" plt.plot(valid_losses, '+-')\n",
|
||||
" plt.legend(['train', 'valid'])\n",
|
||||
" plt.title('train/valid comparison loss')\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '\n",
|
||||
" f'this epoch macro f1: {best_es_f1:0.4f}')\n",
|
||||
"print(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '\n",
|
||||
" f'this epoch macro f1: {best_f1:.4f}')\n",
|
||||
"\n",
|
||||
"test_f1, _ = validate(0, model, test_dataloader, criterion,verbose=False)\n",
|
||||
"print(f'after {cfg.epoch} epochs, final test data macro f1: {test_f1:.4f}')"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"This demo does not include parameter adjustment. Interested students can go to [deepke] by themselves( http://openkg.cn/tool/deepke )Warehouse, download and use more models:)"
|
||||
],
|
||||
"metadata": {}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@@ -0,0 +1,579 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## relation extraction experiment\n",
|
||||
"> Tutorial author:余海阳(yuhaiyang@zju.edu.cn)\n",
|
||||
"\n",
|
||||
"On this demo,we use `pretrain_language model` to extract relations.\n",
|
||||
"We hope this demo can help you understand the process of conctruction knowledge graph and the the principles and common methods of triplet extraction.\n",
|
||||
"\n",
|
||||
"This demo uses `Python3`.\n",
|
||||
"\n",
|
||||
"### Dataset\n",
|
||||
"In this example,we get some Chinese text to extract the triples.\n",
|
||||
"\n",
|
||||
"sentence|relation|head|tail\n",
|
||||
":---:|:---:|:---:|:---:\n",
|
||||
"孔正锡在2005年以一部温馨的爱情电影《长腿叔叔》敲开电影界大门。|导演|长腿叔叔|孔正锡\n",
|
||||
"《伤心的树》是吴宗宪的音乐作品,收录在《你比从前快乐》专辑中。|所属专辑|伤心的树|你比从前快乐\n",
|
||||
"2000年8月,「天坛大佛」荣获「香港十大杰出工程项目」第四名。|所在城市|天坛大佛|香港\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"- train.csv: It contains 6 training triples,each lines represent one triple,sorted by sentence, relationship, head entity and tail entity, and separated by `,`.\n",
|
||||
"- valid.csv: It contains 3 validing triples,each lines represent one triple,sorted by sentence, relationship, head entity and tail entity, and separated by `,`.\n",
|
||||
"- test.csv: It contains 3 testing triples,each lines represent one triple,sorted by sentence, relationship, head entity and tail entity, and separated by `,`.\n",
|
||||
"- relation.csv: It contains 4 relation triples,each lines represent one triple,sorted by sentence, relationship, head entity and tail entity, and separated by `,`."
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### BERT \n",
|
||||
"\n",
|
||||
"![BERT](img/Bert.png)\n",
|
||||
"\n",
|
||||
"After Bert coding, the original sentence can get rich semantic information. The obtained results are input into the bidirectional LSTM, and the output results can obtain the relationship information of the sentence.\n"
|
||||
],
|
||||
"metadata": {}
|
||||
},
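{
"cell_type": "markdown",
"source": [
"The cell below is an illustrative sketch of how a raw example is rewritten before BERT tokenization (the same string manipulation that `_lm_serialize` performs later in this notebook): the head and tail mentions are replaced by their types, and the two mentions are appended after `[SEP]` markers. The `head_type`/`tail_type` strings here are placeholders, not values taken from relation.csv."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Illustration only: entity-type substitution and [SEP] concatenation, as done in _lm_serialize below.\n",
"_d = {'sentence': '孔正锡在2005年以一部温馨的爱情电影《长腿叔叔》敲开电影界大门。',\n",
"      'head': '长腿叔叔', 'head_type': '影视作品', 'tail': '孔正锡', 'tail_type': '人物'}\n",
"_sent = _d['sentence'].strip()\n",
"_sent = _sent.replace(_d['head'], _d['head_type'], 1).replace(_d['tail'], _d['tail_type'], 1)\n",
"_sent += '[SEP]' + _d['head'] + '[SEP]' + _d['tail']\n",
"print(_sent)"
],
"outputs": [],
"metadata": {}
},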
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### Code experience\n",
|
||||
"\n",
|
||||
"Important tips:\n",
|
||||
"- When we use pretrain language model, we need to load about 500MB of model data,so it is more recommended to download to the local and run.At this time, you only need to add `lm_file` value to the address of the local folder. See the link [transformers](https://huggingface.co/transformers/)"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Run the neural network with pytorch and confirm whether it is installed before running\n",
|
||||
"!pip install torch\n",
|
||||
"!pip install matplotlib\n",
|
||||
"!pip install transformers"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# import the whole modules\n",
|
||||
"import os\n",
|
||||
"import csv\n",
|
||||
"import math\n",
|
||||
"import pickle\n",
|
||||
"import logging\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from torch import optim\n",
|
||||
"from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
|
||||
"from torch.utils.data import Dataset,DataLoader\n",
|
||||
"from sklearn.metrics import precision_recall_fscore_support\n",
|
||||
"from typing import List, Tuple, Dict, Any, Sequence, Optional, Union\n",
|
||||
"from transformers import BertTokenizer, BertModel\n",
|
||||
"\n",
|
||||
"logger = logging.getLogger(__name__)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Configuration file of model parameters\n",
|
||||
"class Config(object):\n",
|
||||
" model_name = 'lm' # ['cnn', 'gcn', 'lm']\n",
|
||||
" use_pcnn = True\n",
|
||||
" min_freq = 1\n",
|
||||
" pos_limit = 20\n",
|
||||
" out_path = 'data/out' \n",
|
||||
" batch_size = 2 \n",
|
||||
" word_dim = 10\n",
|
||||
" pos_dim = 5\n",
|
||||
" dim_strategy = 'sum' # ['sum', 'cat']\n",
|
||||
" out_channels = 20\n",
|
||||
" intermediate = 10\n",
|
||||
" kernel_sizes = [3, 5, 7]\n",
|
||||
" activation = 'gelu'\n",
|
||||
" pooling_strategy = 'max'\n",
|
||||
" dropout = 0.3\n",
|
||||
" epoch = 10\n",
|
||||
" num_relations = 4\n",
|
||||
" learning_rate = 3e-4\n",
|
||||
" lr_factor = 0.7 # 学习率的衰减率\n",
|
||||
" lr_patience = 3 # 学习率衰减的等待epoch\n",
|
||||
" weight_decay = 1e-3 # L2正则\n",
|
||||
" early_stopping_patience = 6\n",
|
||||
" train_log = True\n",
|
||||
" log_interval = 1\n",
|
||||
" show_plot = True\n",
|
||||
" only_comparison_plot = False\n",
|
||||
" plot_utils = 'matplot'\n",
|
||||
" lm_file = 'bert-base-chinese'\n",
|
||||
" # lm_file = '/Users/yuhaiyang/transformers/bert-base-chinese'\n",
|
||||
" lm_num_hidden_layers = 2\n",
|
||||
" rnn_layers = 2\n",
|
||||
" \n",
|
||||
"cfg = Config()"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Functions required for preprocessing\n",
|
||||
"Path = str\n",
|
||||
"\n",
|
||||
"def load_csv(fp: Path, is_tsv: bool = False, verbose: bool = True) -> List:\n",
|
||||
" if verbose:\n",
|
||||
" logger.info(f'load csv from {fp}')\n",
|
||||
"\n",
|
||||
" dialect = 'excel-tab' if is_tsv else 'excel'\n",
|
||||
" with open(fp, encoding='utf-8') as f:\n",
|
||||
" reader = csv.DictReader(f, dialect=dialect)\n",
|
||||
" return list(reader)\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"def load_pkl(fp: Path, verbose: bool = True) -> Any:\n",
|
||||
" if verbose:\n",
|
||||
" logger.info(f'load data from {fp}')\n",
|
||||
"\n",
|
||||
" with open(fp, 'rb') as f:\n",
|
||||
" data = pickle.load(f)\n",
|
||||
" return data\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def save_pkl(data: Any, fp: Path, verbose: bool = True) -> None:\n",
|
||||
" if verbose:\n",
|
||||
" logger.info(f'save data in {fp}')\n",
|
||||
"\n",
|
||||
" with open(fp, 'wb') as f:\n",
|
||||
" pickle.dump(data, f)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
"def _handle_relation_data(relation_data: List[Dict]) -> Dict:\n",
|
||||
" rels = dict()\n",
|
||||
" for d in relation_data:\n",
|
||||
" rels[d['relation']] = {\n",
|
||||
" 'index': int(d['index']),\n",
|
||||
" 'head_type': d['head_type'],\n",
|
||||
" 'tail_type': d['tail_type'],\n",
|
||||
" }\n",
|
||||
" return rels\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _add_relation_data(rels: Dict,data: List) -> None:\n",
|
||||
" for d in data:\n",
|
||||
" d['rel2idx'] = rels[d['relation']]['index']\n",
|
||||
" d['head_type'] = rels[d['relation']]['head_type']\n",
|
||||
" d['tail_type'] = rels[d['relation']]['tail_type']\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None, mask_pos_to_true=True):\n",
|
||||
" \n",
|
||||
" if isinstance(seq_len, list):\n",
|
||||
" seq_len = np.array(seq_len)\n",
|
||||
"\n",
|
||||
" if isinstance(seq_len, np.ndarray):\n",
|
||||
" seq_len = torch.from_numpy(seq_len)\n",
|
||||
"\n",
|
||||
" if isinstance(seq_len, torch.Tensor):\n",
|
||||
" assert seq_len.dim() == 1, logger.error(f\"seq_len can only have one dimension, got {seq_len.dim()} != 1.\")\n",
|
||||
" batch_size = seq_len.size(0)\n",
|
||||
" max_len = int(max_len) if max_len else seq_len.max().long()\n",
|
||||
" broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len.device)\n",
|
||||
" if mask_pos_to_true:\n",
|
||||
" mask = broad_cast_seq_len.ge(seq_len.unsqueeze(1))\n",
|
||||
" else:\n",
|
||||
" mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))\n",
|
||||
" else:\n",
|
||||
" raise logger.error(\"Only support 1-d list or 1-d numpy.ndarray or 1-d torch.Tensor.\")\n",
|
||||
"\n",
|
||||
" return mask\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _lm_serialize(data: List[Dict], cfg):\n",
|
||||
" logger.info('use bert tokenizer...')\n",
|
||||
" tokenizer = BertTokenizer.from_pretrained(cfg.lm_file)\n",
|
||||
" for d in data:\n",
|
||||
" sent = d['sentence'].strip()\n",
|
||||
" sent = sent.replace(d['head'], d['head_type'], 1).replace(d['tail'], d['tail_type'], 1)\n",
|
||||
" sent += '[SEP]' + d['head'] + '[SEP]' + d['tail']\n",
|
||||
" d['token2idx'] = tokenizer.encode(sent, add_special_tokens=True)\n",
|
||||
" d['lens'] = len(d['token2idx'])"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Preprocess\n",
|
||||
"logger.info('load raw files...')\n",
|
||||
"train_fp = os.path.join('data/train.csv')\n",
|
||||
"valid_fp = os.path.join('data/valid.csv')\n",
|
||||
"test_fp = os.path.join('data/test.csv')\n",
|
||||
"relation_fp = os.path.join('data/relation.csv')\n",
|
||||
"\n",
|
||||
"train_data = load_csv(train_fp)\n",
|
||||
"valid_data = load_csv(valid_fp)\n",
|
||||
"test_data = load_csv(test_fp)\n",
|
||||
"relation_data = load_csv(relation_fp)\n",
|
||||
"\n",
|
||||
"for d in train_data:\n",
|
||||
" d['tokens'] = eval(d['tokens'])\n",
|
||||
"for d in valid_data:\n",
|
||||
" d['tokens'] = eval(d['tokens'])\n",
|
||||
"for d in test_data:\n",
|
||||
" d['tokens'] = eval(d['tokens'])\n",
|
||||
" \n",
|
||||
"logger.info('convert relation into index...')\n",
|
||||
"rels = _handle_relation_data(relation_data)\n",
|
||||
"_add_relation_data(rels, train_data)\n",
|
||||
"_add_relation_data(rels, valid_data)\n",
|
||||
"_add_relation_data(rels, test_data)\n",
|
||||
"\n",
|
||||
"logger.info('verify whether use pretrained language models...')\n",
|
||||
"\n",
|
||||
"logger.info('use pretrained language models serialize sentence...')\n",
|
||||
"_lm_serialize(train_data, cfg)\n",
|
||||
"_lm_serialize(valid_data, cfg)\n",
|
||||
"_lm_serialize(test_data, cfg)\n",
|
||||
"\n",
|
||||
"logger.info('save data for backup...')\n",
|
||||
"os.makedirs(cfg.out_path, exist_ok=True)\n",
|
||||
"train_save_fp = os.path.join(cfg.out_path, 'train.pkl')\n",
|
||||
"valid_save_fp = os.path.join(cfg.out_path, 'valid.pkl')\n",
|
||||
"test_save_fp = os.path.join(cfg.out_path, 'test.pkl')\n",
|
||||
"save_pkl(train_data, train_save_fp)\n",
|
||||
"save_pkl(valid_data, valid_save_fp)\n",
|
||||
"save_pkl(test_data, test_save_fp)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# pytorch construct Dataset\n",
|
||||
"def collate_fn(cfg):\n",
|
||||
" def collate_fn_intra(batch):\n",
|
||||
" batch.sort(key=lambda data: int(data['lens']), reverse=True)\n",
|
||||
" max_len = int(batch[0]['lens'])\n",
|
||||
" \n",
|
||||
" def _padding(x, max_len):\n",
|
||||
" return x + [0] * (max_len - len(x))\n",
|
||||
" \n",
|
||||
" def _pad_adj(adj, max_len):\n",
|
||||
" adj = np.array(adj)\n",
|
||||
" pad_len = max_len - adj.shape[0]\n",
|
||||
" for i in range(pad_len):\n",
|
||||
" adj = np.insert(adj, adj.shape[-1], 0, axis=1)\n",
|
||||
" for i in range(pad_len):\n",
|
||||
" adj = np.insert(adj, adj.shape[0], 0, axis=0)\n",
|
||||
" return adj\n",
|
||||
" \n",
|
||||
" x, y = dict(), []\n",
|
||||
" word, word_len = [], []\n",
|
||||
" head_pos, tail_pos = [], []\n",
|
||||
" pcnn_mask = []\n",
|
||||
" adj_matrix = []\n",
|
||||
" for data in batch:\n",
|
||||
" word.append(_padding(data['token2idx'], max_len))\n",
|
||||
" word_len.append(int(data['lens']))\n",
|
||||
" y.append(int(data['rel2idx']))\n",
|
||||
" \n",
|
||||
" if cfg.model_name != 'lm':\n",
|
||||
" head_pos.append(_padding(data['head_pos'], max_len))\n",
|
||||
" tail_pos.append(_padding(data['tail_pos'], max_len))\n",
|
||||
" if cfg.model_name == 'gcn':\n",
|
||||
" head = eval(data['dependency'])\n",
|
||||
" adj = head_to_adj(head, directed=True, self_loop=True)\n",
|
||||
" adj_matrix.append(_pad_adj(adj, max_len))\n",
|
||||
"\n",
|
||||
" if cfg.use_pcnn:\n",
|
||||
" pcnn_mask.append(_padding(data['entities_pos'], max_len))\n",
|
||||
"\n",
|
||||
" x['word'] = torch.tensor(word)\n",
|
||||
" x['lens'] = torch.tensor(word_len)\n",
|
||||
" y = torch.tensor(y)\n",
|
||||
" \n",
|
||||
" if cfg.model_name != 'lm':\n",
|
||||
" x['head_pos'] = torch.tensor(head_pos)\n",
|
||||
" x['tail_pos'] = torch.tensor(tail_pos)\n",
|
||||
" if cfg.model_name == 'gcn':\n",
|
||||
" x['adj'] = torch.tensor(adj_matrix)\n",
|
||||
" if cfg.model_name == 'cnn' and cfg.use_pcnn:\n",
|
||||
" x['pcnn_mask'] = torch.tensor(pcnn_mask)\n",
|
||||
"\n",
|
||||
" return x, y\n",
|
||||
" \n",
|
||||
" return collate_fn_intra\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class CustomDataset(Dataset):\n",
|
||||
" def __init__(self, fp):\n",
|
||||
" self.file = load_pkl(fp)\n",
|
||||
"\n",
|
||||
" def __getitem__(self, item):\n",
|
||||
" sample = self.file[item]\n",
|
||||
" return sample\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.file)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# pretrain language model\n",
|
||||
"class PretrainLM(nn.Module):\n",
|
||||
" def __init__(self, cfg):\n",
|
||||
" super(PretrainLM, self).__init__()\n",
|
||||
" self.num_layers = cfg.rnn_layers\n",
|
||||
" self.lm = BertModel.from_pretrained(cfg.lm_file, num_hidden_layers=cfg.lm_num_hidden_layers)\n",
|
||||
" self.bilstm = nn.LSTM(768,10,batch_first=True,bidirectional=True,num_layers=cfg.rnn_layers,dropout=cfg.dropout)\n",
|
||||
" self.fc = nn.Linear(20, cfg.num_relations)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" N = self.num_layers\n",
|
||||
" word, lens = x['word'], x['lens']\n",
|
||||
" B = word.size(0)\n",
|
||||
" output, pooler_output = self.lm(word)\n",
|
||||
" output = pack_padded_sequence(output, lens, batch_first=True, enforce_sorted=True)\n",
|
||||
" _, (output,_) = self.bilstm(output)\n",
|
||||
" output = output.view(N, 2, B, 10).transpose(1, 2).contiguous().view(N, B, 20).transpose(0, 1)\n",
|
||||
" output = output[:,-1,:]\n",
|
||||
" output = self.fc(output)\n",
|
||||
" \n",
|
||||
" return output"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# p,r,f1 measurement\n",
|
||||
"class PRMetric():\n",
|
||||
" def __init__(self):\n",
|
||||
" \n",
|
||||
" self.y_true = np.empty(0)\n",
|
||||
" self.y_pred = np.empty(0)\n",
|
||||
"\n",
|
||||
" def reset(self):\n",
|
||||
" self.y_true = np.empty(0)\n",
|
||||
" self.y_pred = np.empty(0)\n",
|
||||
"\n",
|
||||
" def update(self, y_true:torch.Tensor, y_pred:torch.Tensor):\n",
|
||||
" y_true = y_true.cpu().detach().numpy()\n",
|
||||
" y_pred = y_pred.cpu().detach().numpy()\n",
|
||||
" y_pred = np.argmax(y_pred,axis=-1)\n",
|
||||
"\n",
|
||||
" self.y_true = np.append(self.y_true, y_true)\n",
|
||||
" self.y_pred = np.append(self.y_pred, y_pred)\n",
|
||||
"\n",
|
||||
" def compute(self):\n",
|
||||
" p, r, f1, _ = precision_recall_fscore_support(self.y_true,self.y_pred,average='macro',warn_for=tuple())\n",
|
||||
" _, _, acc, _ = precision_recall_fscore_support(self.y_true,self.y_pred,average='micro',warn_for=tuple())\n",
|
||||
"\n",
|
||||
" return acc,p,r,f1"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Iteration in training process\n",
|
||||
"def train(epoch, model, dataloader, optimizer, criterion, cfg):\n",
|
||||
" model.train()\n",
|
||||
"\n",
|
||||
" metric = PRMetric()\n",
|
||||
" losses = []\n",
|
||||
"\n",
|
||||
" for batch_idx, (x, y) in enumerate(dataloader, 1):\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" y_pred = model(x)\n",
|
||||
" loss = criterion(y_pred, y)\n",
|
||||
"\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
" metric.update(y_true=y, y_pred=y_pred)\n",
|
||||
" losses.append(loss.item())\n",
|
||||
"\n",
|
||||
" data_total = len(dataloader.dataset)\n",
|
||||
" data_cal = data_total if batch_idx == len(dataloader) else batch_idx * len(y)\n",
|
||||
" if (cfg.train_log and batch_idx % cfg.log_interval == 0) or batch_idx == len(dataloader):\n",
|
||||
" acc,p,r,f1 = metric.compute()\n",
|
||||
" print(f'Train Epoch {epoch}: [{data_cal}/{data_total} ({100. * data_cal / data_total:.0f}%)]\\t'\n",
|
||||
" f'Loss: {loss.item():.6f}')\n",
|
||||
" print(f'Train Epoch {epoch}: Acc: {100. * acc:.2f}%\\t'\n",
|
||||
" f'macro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')\n",
|
||||
"\n",
|
||||
" if cfg.show_plot and not cfg.only_comparison_plot:\n",
|
||||
" if cfg.plot_utils == 'matplot':\n",
|
||||
" plt.plot(losses)\n",
|
||||
" plt.title(f'epoch {epoch} train loss')\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
" return losses[-1]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Iteration in testing process\n",
|
||||
"def validate(epoch, model, dataloader, criterion,verbose=True):\n",
|
||||
" model.eval()\n",
|
||||
"\n",
|
||||
" metric = PRMetric()\n",
|
||||
" losses = []\n",
|
||||
"\n",
|
||||
" for batch_idx, (x, y) in enumerate(dataloader, 1):\n",
|
||||
" with torch.no_grad():\n",
|
||||
" y_pred = model(x)\n",
|
||||
" loss = criterion(y_pred, y)\n",
|
||||
"\n",
|
||||
" metric.update(y_true=y, y_pred=y_pred)\n",
|
||||
" losses.append(loss.item())\n",
|
||||
"\n",
|
||||
" loss = sum(losses) / len(losses)\n",
|
||||
" acc,p,r,f1 = metric.compute()\n",
|
||||
" data_total = len(dataloader.dataset)\n",
|
||||
" if verbose:\n",
|
||||
" print(f'Valid Epoch {epoch}: [{data_total}/{data_total}](100%)\\t Loss: {loss:.6f}')\n",
|
||||
" print(f'Valid Epoch {epoch}: Acc: {100. * acc:.2f}%\\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]\\n\\n')\n",
|
||||
"\n",
|
||||
" return f1,loss"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Load dataset\n",
|
||||
"train_dataset = CustomDataset(train_save_fp)\n",
|
||||
"valid_dataset = CustomDataset(valid_save_fp)\n",
|
||||
"test_dataset = CustomDataset(test_save_fp)\n",
|
||||
"\n",
|
||||
"train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n",
|
||||
"valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n",
|
||||
"test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# main entry, define optimization function, loss function and so on\n",
|
||||
"# start epoch\n",
|
||||
"# Use the loss of the valid dataset to make an early stop judgment. When it does not decline, this is the time when the model generalization is the best.\n",
|
||||
"model = PretrainLM(cfg)\n",
|
||||
"print(model)\n",
|
||||
"\n",
|
||||
"optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)\n",
|
||||
"scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)\n",
|
||||
"criterion = nn.CrossEntropyLoss()\n",
|
||||
"\n",
|
||||
"best_f1, best_epoch = -1, 0\n",
|
||||
"es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, = 1000, -1, 0, 0, 0, -1\n",
|
||||
"train_losses, valid_losses = [], []\n",
|
||||
"\n",
|
||||
"logger.info('=' * 10 + ' Start training ' + '=' * 10)\n",
|
||||
"for epoch in range(1, cfg.epoch + 1):\n",
|
||||
" train_loss = train(epoch, model, train_dataloader, optimizer, criterion, cfg)\n",
|
||||
" valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion)\n",
|
||||
" scheduler.step(valid_loss)\n",
|
||||
"\n",
|
||||
" train_losses.append(train_loss)\n",
|
||||
" valid_losses.append(valid_loss)\n",
|
||||
" if best_f1 < valid_f1:\n",
|
||||
" best_f1 = valid_f1\n",
|
||||
" best_epoch = epoch\n",
|
||||
" if es_loss > valid_loss:\n",
|
||||
" es_loss = valid_loss\n",
|
||||
" es_f1 = valid_f1\n",
|
||||
" best_es_f1 = valid_f1\n",
|
||||
" es_epoch = epoch\n",
|
||||
" best_es_epoch = epoch\n",
|
||||
" es_patience = 0\n",
|
||||
" else:\n",
|
||||
" es_patience += 1\n",
|
||||
" if es_patience >= cfg.early_stopping_patience:\n",
|
||||
" best_es_epoch = es_epoch\n",
|
||||
" best_es_f1 = es_f1\n",
|
||||
"\n",
|
||||
"if cfg.show_plot:\n",
|
||||
" if cfg.plot_utils == 'matplot':\n",
|
||||
" plt.plot(train_losses, 'x-')\n",
|
||||
" plt.plot(valid_losses, '+-')\n",
|
||||
" plt.legend(['train', 'valid'])\n",
|
||||
" plt.title('train/valid comparison loss')\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '\n",
|
||||
" f'this epoch macro f1: {best_es_f1:0.4f}')\n",
|
||||
"print(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '\n",
|
||||
" f'this epoch macro f1: {best_f1:.4f}')\n",
|
||||
"\n",
|
||||
"test_f1, _ = validate(0, model, test_dataloader, criterion,verbose=False)\n",
|
||||
"print(f'after {cfg.epoch} epochs, final test data macro f1: {test_f1:.4f}')"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"This demo does not include parameter adjustment. Interested students can go to [deepke] by themselves( http://openkg.cn/tool/deepke )Warehouse, download and use more models:)"
|
||||
],
|
||||
"metadata": {}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@@ -0,0 +1,852 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## relation extraction experiment\n",
|
||||
"> Tutorial author:余海阳(yuhaiyang@zju.edu.cn)\n",
|
||||
"\n",
|
||||
"On this demo,we use `pcnn` model to extract relations.\n",
|
||||
"We hope this demo can help you understand the process of conctruction knowledge graph and the the principles and common methods of triplet extraction.\n",
|
||||
"\n",
|
||||
"This demo uses `Python3`.\n",
|
||||
"\n",
|
||||
"### Dataset\n",
|
||||
"In this example,we get some Chinese text to extract the triples.\n",
|
||||
"\n",
|
||||
"sentence|relation|head|tail\n",
|
||||
":---:|:---:|:---:|:---:\n",
|
||||
"孔正锡在2005年以一部温馨的爱情电影《长腿叔叔》敲开电影界大门。|导演|长腿叔叔|孔正锡\n",
|
||||
"《伤心的树》是吴宗宪的音乐作品,收录在《你比从前快乐》专辑中。|所属专辑|伤心的树|你比从前快乐\n",
|
||||
"2000年8月,「天坛大佛」荣获「香港十大杰出工程项目」第四名。|所在城市|天坛大佛|香港\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"- train.csv: It contains 6 training triples,each lines represent one triple,sorted by sentence, relationship, head entity and tail entity, and separated by `,`.\n",
|
||||
"- valid.csv: It contains 3 validing triples,each lines represent one triple,sorted by sentence, relationship, head entity and tail entity, and separated by `,`.\n",
|
||||
"- test.csv: It contains 3 testing triples,each lines represent one triple,sorted by sentence, relationship, head entity and tail entity, and separated by `,`.\n",
|
||||
"- relation.csv: It contains 4 relation triples,each lines represent one triple,sorted by sentence, relationship, head entity and tail entity, and separated by `,`."
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### PCNN \n",
|
||||
"\n",
|
||||
"![PCNN](img/PCNN.jpg)\n",
|
||||
"\n",
|
||||
"The sentence information mainly includes word embedding and position embedding.After the convolution layer,according to the position of head tail, it is divided into three sections for maximum pooling,and then through the full connection layer, the relationship information of the sentence can be obtained.\n"
|
||||
],
|
||||
"metadata": {}
|
||||
},
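{
"cell_type": "markdown",
"source": [
"The cell below is an illustrative sketch of the piece-wise max pooling idea described above, assuming convolution features of shape `[batch, seq_len, channels]` and the 1/2/3 segment mask that `_add_pos_seq` builds later (`entities_pos`). It is not the repository's implementation, and the function name is hypothetical."
],
"metadata": {}
},
{
"cell_type": "code",
"execution_count": null,
"source": [
"# Illustration only: max-pool each of the three segments (before head, between head and tail, after tail)\n",
"# separately and concatenate the results, which is what use_pcnn = True enables conceptually.\n",
"import torch\n",
"\n",
"def piecewise_max_pool_sketch(features, mask):\n",
"    # features: [B, L, C]; mask: [B, L] with values 1, 2, 3 marking the three segments\n",
"    pooled = []\n",
"    for seg in (1, 2, 3):\n",
"        seg_mask = (mask == seg).unsqueeze(-1)                      # [B, L, 1]\n",
"        seg_feat = features.masked_fill(~seg_mask, float('-inf'))   # hide the other segments\n",
"        pooled.append(seg_feat.max(dim=1).values)                   # [B, C]\n",
"    return torch.cat(pooled, dim=-1)                                # [B, 3 * C]\n",
"\n",
"piecewise_max_pool_sketch(torch.randn(2, 7, 20),\n",
"                          torch.tensor([[1, 1, 2, 2, 3, 3, 3], [1, 2, 2, 3, 3, 3, 3]])).shape"
],
"outputs": [],
"metadata": {}
},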
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Run the neural network with pytorch and confirm whether it is installed before running\n",
|
||||
"!pip install torch\n",
|
||||
"!pip install matplotlib\n",
|
||||
"!pip install transformers"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# import the whole modules\n",
|
||||
"import os\n",
|
||||
"import csv\n",
|
||||
"import math\n",
|
||||
"import pickle\n",
|
||||
"import logging\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from torch import optim\n",
|
||||
"from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
|
||||
"from torch.utils.data import Dataset,DataLoader\n",
|
||||
"from sklearn.metrics import precision_recall_fscore_support\n",
|
||||
"from typing import List, Tuple, Dict, Any, Sequence, Optional, Union\n",
|
||||
"from transformers import BertTokenizer, BertModel\n",
|
||||
"\n",
|
||||
"logger = logging.getLogger(__name__)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Configuration file of model parameters\n",
|
||||
"# use_pcnn Parameter controls whether there is a piece_ Wise pooling\n",
|
||||
"\n",
|
||||
"class Config(object):\n",
|
||||
" model_name = 'cnn' # ['cnn', 'gcn', 'lm']\n",
|
||||
" use_pcnn = True\n",
|
||||
" min_freq = 1\n",
|
||||
" pos_limit = 20\n",
|
||||
" out_path = 'data/out' \n",
|
||||
" batch_size = 2 \n",
|
||||
" word_dim = 10\n",
|
||||
" pos_dim = 5\n",
|
||||
" dim_strategy = 'sum' # ['sum', 'cat']\n",
|
||||
" out_channels = 20\n",
|
||||
" intermediate = 10\n",
|
||||
" kernel_sizes = [3, 5, 7]\n",
|
||||
" activation = 'gelu'\n",
|
||||
" pooling_strategy = 'max'\n",
|
||||
" dropout = 0.3\n",
|
||||
" epoch = 10\n",
|
||||
" num_relations = 4\n",
|
||||
" learning_rate = 3e-4\n",
|
||||
" lr_factor = 0.7 # 学习率的衰减率\n",
|
||||
" lr_patience = 3 # 学习率衰减的等待epoch\n",
|
||||
" weight_decay = 1e-3 # L2正则\n",
|
||||
" early_stopping_patience = 6\n",
|
||||
" train_log = True\n",
|
||||
" log_interval = 1\n",
|
||||
" show_plot = True\n",
|
||||
" only_comparison_plot = False\n",
|
||||
" plot_utils = 'matplot'\n",
|
||||
" lm_file = 'bert-base-chinese'\n",
|
||||
" lm_num_hidden_layers = 2\n",
|
||||
" rnn_layers = 2\n",
|
||||
" \n",
|
||||
"cfg = Config()"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Word token builds a one hot dictionary, and then inputs it to the embedding layer to obtain the corresponding word information matrix\n",
|
||||
"# 0 is pad by default and 1 is unknown\n",
|
||||
"class Vocab(object):\n",
|
||||
" def __init__(self, name: str = 'basic', init_tokens = [\"[PAD]\", \"[UNK]\"]):\n",
|
||||
" self.name = name\n",
|
||||
" self.init_tokens = init_tokens\n",
|
||||
" self.trimed = False\n",
|
||||
" self.word2idx = {}\n",
|
||||
" self.word2count = {}\n",
|
||||
" self.idx2word = {}\n",
|
||||
" self.count = 0\n",
|
||||
" self._add_init_tokens()\n",
|
||||
"\n",
|
||||
" def _add_init_tokens(self):\n",
|
||||
" for token in self.init_tokens:\n",
|
||||
" self._add_word(token)\n",
|
||||
"\n",
|
||||
" def _add_word(self, word: str):\n",
|
||||
" if word not in self.word2idx:\n",
|
||||
" self.word2idx[word] = self.count\n",
|
||||
" self.word2count[word] = 1\n",
|
||||
" self.idx2word[self.count] = word\n",
|
||||
" self.count += 1\n",
|
||||
" else:\n",
|
||||
" self.word2count[word] += 1\n",
|
||||
"\n",
|
||||
" def add_words(self, words: Sequence):\n",
|
||||
" for word in words:\n",
|
||||
" self._add_word(word)\n",
|
||||
"\n",
|
||||
" def trim(self, min_freq=2, verbose: Optional[bool] = True):\n",
|
||||
" assert min_freq == int(min_freq), f'min_freq must be integer, can\\'t be {min_freq}'\n",
|
||||
" min_freq = int(min_freq)\n",
|
||||
" if min_freq < 2:\n",
|
||||
" return\n",
|
||||
" if self.trimed:\n",
|
||||
" return\n",
|
||||
" self.trimed = True\n",
|
||||
"\n",
|
||||
" keep_words = []\n",
|
||||
" new_words = []\n",
|
||||
"\n",
|
||||
" for k, v in self.word2count.items():\n",
|
||||
" if v >= min_freq:\n",
|
||||
" keep_words.append(k)\n",
|
||||
" new_words.extend([k] * v)\n",
|
||||
" if verbose:\n",
|
||||
" before_len = len(keep_words)\n",
|
||||
" after_len = len(self.word2idx) - len(self.init_tokens)\n",
|
||||
" logger.info('vocab after be trimmed, keep words [{} / {}] = {:.2f}%'.format(before_len, after_len, before_len / after_len * 100))\n",
|
||||
"\n",
|
||||
" # Reinitialize dictionaries\n",
|
||||
" self.word2idx = {}\n",
|
||||
" self.word2count = {}\n",
|
||||
" self.idx2word = {}\n",
|
||||
" self.count = 0\n",
|
||||
" self._add_init_tokens()\n",
|
||||
" self.add_words(new_words)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Functions required for preprocessing\n",
|
||||
"Path = str\n",
|
||||
"\n",
|
||||
"def load_csv(fp: Path, is_tsv: bool = False, verbose: bool = True) -> List:\n",
|
||||
" if verbose:\n",
|
||||
" logger.info(f'load csv from {fp}')\n",
|
||||
"\n",
|
||||
" dialect = 'excel-tab' if is_tsv else 'excel'\n",
|
||||
" with open(fp, encoding='utf-8') as f:\n",
|
||||
" reader = csv.DictReader(f, dialect=dialect)\n",
|
||||
" return list(reader)\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"def load_pkl(fp: Path, verbose: bool = True) -> Any:\n",
|
||||
" if verbose:\n",
|
||||
" logger.info(f'load data from {fp}')\n",
|
||||
"\n",
|
||||
" with open(fp, 'rb') as f:\n",
|
||||
" data = pickle.load(f)\n",
|
||||
" return data\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def save_pkl(data: Any, fp: Path, verbose: bool = True) -> None:\n",
|
||||
" if verbose:\n",
|
||||
" logger.info(f'save data in {fp}')\n",
|
||||
"\n",
|
||||
" with open(fp, 'wb') as f:\n",
|
||||
" pickle.dump(data, f)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
"def _handle_relation_data(relation_data: List[Dict]) -> Dict:\n",
|
||||
" rels = dict()\n",
|
||||
" for d in relation_data:\n",
|
||||
" rels[d['relation']] = {\n",
|
||||
" 'index': int(d['index']),\n",
|
||||
" 'head_type': d['head_type'],\n",
|
||||
" 'tail_type': d['tail_type'],\n",
|
||||
" }\n",
|
||||
" return rels\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _add_relation_data(rels: Dict,data: List) -> None:\n",
|
||||
" for d in data:\n",
|
||||
" d['rel2idx'] = rels[d['relation']]['index']\n",
|
||||
" d['head_type'] = rels[d['relation']]['head_type']\n",
|
||||
" d['tail_type'] = rels[d['relation']]['tail_type']\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _convert_tokens_into_index(data: List[Dict], vocab):\n",
|
||||
" unk_str = '[UNK]'\n",
|
||||
" unk_idx = vocab.word2idx[unk_str]\n",
|
||||
"\n",
|
||||
" for d in data:\n",
|
||||
" d['token2idx'] = [vocab.word2idx.get(i, unk_idx) for i in d['tokens']]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
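"# build relative-position sequences w.r.t. the head and tail entities; with PCNN, also mark the three entity segments\n",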
"def _add_pos_seq(train_data: List[Dict], cfg):\n",
|
||||
" for d in train_data:\n",
|
||||
" d['head_offset'], d['tail_offset'], d['lens'] = int(d['head_offset']), int(d['tail_offset']), int(d['lens'])\n",
|
||||
" entities_idx = [d['head_offset'], d['tail_offset']] if d['head_offset'] < d['tail_offset'] else [d['tail_offset'], d['head_offset']]\n",
|
||||
"\n",
|
||||
" d['head_pos'] = list(map(lambda i: i - d['head_offset'], list(range(d['lens']))))\n",
|
||||
" d['head_pos'] = _handle_pos_limit(d['head_pos'], int(cfg.pos_limit))\n",
|
||||
"\n",
|
||||
" d['tail_pos'] = list(map(lambda i: i - d['tail_offset'], list(range(d['lens']))))\n",
|
||||
" d['tail_pos'] = _handle_pos_limit(d['tail_pos'], int(cfg.pos_limit))\n",
|
||||
"\n",
|
||||
" if cfg.use_pcnn:\n",
|
||||
" d['entities_pos'] = [1] * (entities_idx[0] + 1) + [2] * (entities_idx[1] - entities_idx[0] - 1) +\\\n",
|
||||
" [3] * (d['lens'] - entities_idx[1])\n",
|
||||
"\n",
|
||||
" \n",
|
||||
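"# clamp relative positions to [-limit, limit], then shift them into [1, 2*limit + 1] so that 0 can be used for padding\n",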
"def _handle_pos_limit(pos: List[int], limit: int) -> List[int]:\n",
|
||||
" for i, p in enumerate(pos):\n",
|
||||
" if p > limit:\n",
|
||||
" pos[i] = limit\n",
|
||||
" if p < -limit:\n",
|
||||
" pos[i] = -limit\n",
|
||||
" return [p + limit + 1 for p in pos]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None, mask_pos_to_true=True):\n",
|
||||
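"    # convert a batch of sequence lengths into a [batch_size, max_len] boolean mask\n",
"    # (padded positions are True when mask_pos_to_true, otherwise valid positions are True)\n",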
" if isinstance(seq_len, list):\n",
|
||||
" seq_len = np.array(seq_len)\n",
|
||||
"\n",
|
||||
" if isinstance(seq_len, np.ndarray):\n",
|
||||
" seq_len = torch.from_numpy(seq_len)\n",
|
||||
"\n",
|
||||
" if isinstance(seq_len, torch.Tensor):\n",
|
||||
" assert seq_len.dim() == 1, logger.error(f\"seq_len can only have one dimension, got {seq_len.dim()} != 1.\")\n",
|
||||
" batch_size = seq_len.size(0)\n",
|
||||
" max_len = int(max_len) if max_len else seq_len.max().long()\n",
|
||||
" broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len.device)\n",
|
||||
" if mask_pos_to_true:\n",
|
||||
" mask = broad_cast_seq_len.ge(seq_len.unsqueeze(1))\n",
|
||||
" else:\n",
|
||||
" mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))\n",
|
||||
" else:\n",
|
||||
" raise logger.error(\"Only support 1-d list or 1-d numpy.ndarray or 1-d torch.Tensor.\")\n",
|
||||
"\n",
|
||||
" return mask"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Preprocess\n",
|
||||
"logger.info('load raw files...')\n",
|
||||
"train_fp = os.path.join('data/train.csv')\n",
|
||||
"valid_fp = os.path.join('data/valid.csv')\n",
|
||||
"test_fp = os.path.join('data/test.csv')\n",
|
||||
"relation_fp = os.path.join('data/relation.csv')\n",
|
||||
"\n",
|
||||
"train_data = load_csv(train_fp)\n",
|
||||
"valid_data = load_csv(valid_fp)\n",
|
||||
"test_data = load_csv(test_fp)\n",
|
||||
"relation_data = load_csv(relation_fp)\n",
|
||||
"\n",
|
||||
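"# the tokens column is stored in the csv as a stringified Python list, so eval() parses it back into a list\n",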
"for d in train_data:\n",
|
||||
" d['tokens'] = eval(d['tokens'])\n",
|
||||
"for d in valid_data:\n",
|
||||
" d['tokens'] = eval(d['tokens'])\n",
|
||||
"for d in test_data:\n",
|
||||
" d['tokens'] = eval(d['tokens'])\n",
|
||||
" \n",
|
||||
"logger.info('convert relation into index...')\n",
|
||||
"rels = _handle_relation_data(relation_data)\n",
|
||||
"_add_relation_data(rels, train_data)\n",
|
||||
"_add_relation_data(rels, valid_data)\n",
|
||||
"_add_relation_data(rels, test_data)\n",
|
||||
"\n",
|
||||
"logger.info('verify whether use pretrained language models...')\n",
|
||||
"logger.info('build vocabulary...')\n",
|
||||
"vocab = Vocab('word')\n",
|
||||
"train_tokens = [d['tokens'] for d in train_data]\n",
|
||||
"valid_tokens = [d['tokens'] for d in valid_data]\n",
|
||||
"test_tokens = [d['tokens'] for d in test_data]\n",
|
||||
"sent_tokens = [*train_tokens, *valid_tokens, *test_tokens]\n",
|
||||
"for sent in sent_tokens:\n",
|
||||
" vocab.add_words(sent)\n",
|
||||
"vocab.trim(min_freq=cfg.min_freq)\n",
|
||||
"\n",
|
||||
"logger.info('convert tokens into index...')\n",
|
||||
"_convert_tokens_into_index(train_data, vocab)\n",
|
||||
"_convert_tokens_into_index(valid_data, vocab)\n",
|
||||
"_convert_tokens_into_index(test_data, vocab)\n",
|
||||
"\n",
|
||||
"logger.info('build position sequence...')\n",
|
||||
"_add_pos_seq(train_data, cfg)\n",
|
||||
"_add_pos_seq(valid_data, cfg)\n",
|
||||
"_add_pos_seq(test_data, cfg)\n",
|
||||
"\n",
|
||||
"logger.info('save data for backup...')\n",
|
||||
"os.makedirs(cfg.out_path, exist_ok=True)\n",
|
||||
"train_save_fp = os.path.join(cfg.out_path, 'train.pkl')\n",
|
||||
"valid_save_fp = os.path.join(cfg.out_path, 'valid.pkl')\n",
|
||||
"test_save_fp = os.path.join(cfg.out_path, 'test.pkl')\n",
|
||||
"save_pkl(train_data, train_save_fp)\n",
|
||||
"save_pkl(valid_data, valid_save_fp)\n",
|
||||
"save_pkl(test_data, test_save_fp)\n",
|
||||
"\n",
|
||||
"vocab_save_fp = os.path.join(cfg.out_path, 'vocab.pkl')\n",
|
||||
"vocab_txt = os.path.join(cfg.out_path, 'vocab.txt')\n",
|
||||
"save_pkl(vocab, vocab_save_fp)\n",
|
||||
"logger.info('save vocab in txt file, for watching...')\n",
|
||||
"with open(vocab_txt, 'w', encoding='utf-8') as f:\n",
|
||||
" f.write(os.linesep.join(vocab.word2idx.keys()))"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# pytorch construct Dataset\n",
|
||||
"def collate_fn(cfg):\n",
|
||||
" def collate_fn_intra(batch):\n",
|
||||
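"        # sort the batch by sentence length and pad every field to the longest sentence in the batch\n",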
" batch.sort(key=lambda data: int(data['lens']), reverse=True)\n",
|
||||
" max_len = int(batch[0]['lens'])\n",
|
||||
" \n",
|
||||
" def _padding(x, max_len):\n",
|
||||
" return x + [0] * (max_len - len(x))\n",
|
||||
" \n",
|
||||
" def _pad_adj(adj, max_len):\n",
|
||||
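"            # pad the dependency adjacency matrix with zero rows and columns up to max_len\n",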
" adj = np.array(adj)\n",
|
||||
" pad_len = max_len - adj.shape[0]\n",
|
||||
" for i in range(pad_len):\n",
|
||||
" adj = np.insert(adj, adj.shape[-1], 0, axis=1)\n",
|
||||
" for i in range(pad_len):\n",
|
||||
" adj = np.insert(adj, adj.shape[0], 0, axis=0)\n",
|
||||
" return adj\n",
|
||||
" \n",
|
||||
" x, y = dict(), []\n",
|
||||
" word, word_len = [], []\n",
|
||||
" head_pos, tail_pos = [], []\n",
|
||||
" pcnn_mask = []\n",
|
||||
" adj_matrix = []\n",
|
||||
" for data in batch:\n",
|
||||
" word.append(_padding(data['token2idx'], max_len))\n",
|
||||
" word_len.append(int(data['lens']))\n",
|
||||
" y.append(int(data['rel2idx']))\n",
|
||||
" \n",
|
||||
" if cfg.model_name != 'lm':\n",
|
||||
" head_pos.append(_padding(data['head_pos'], max_len))\n",
|
||||
" tail_pos.append(_padding(data['tail_pos'], max_len))\n",
|
||||
" if cfg.model_name == 'gcn':\n",
|
||||
" head = eval(data['dependency'])\n",
|
||||
" adj = head_to_adj(head, directed=True, self_loop=True)\n",
|
||||
" adj_matrix.append(_pad_adj(adj, max_len))\n",
|
||||
"\n",
|
||||
" if cfg.use_pcnn:\n",
|
||||
" pcnn_mask.append(_padding(data['entities_pos'], max_len))\n",
|
||||
"\n",
|
||||
" x['word'] = torch.tensor(word)\n",
|
||||
" x['lens'] = torch.tensor(word_len)\n",
|
||||
" y = torch.tensor(y)\n",
|
||||
" \n",
|
||||
" if cfg.model_name != 'lm':\n",
|
||||
" x['head_pos'] = torch.tensor(head_pos)\n",
|
||||
" x['tail_pos'] = torch.tensor(tail_pos)\n",
|
||||
" if cfg.model_name == 'gcn':\n",
|
||||
" x['adj'] = torch.tensor(adj_matrix)\n",
|
||||
" if cfg.model_name == 'cnn' and cfg.use_pcnn:\n",
|
||||
" x['pcnn_mask'] = torch.tensor(pcnn_mask)\n",
|
||||
"\n",
|
||||
" return x, y\n",
|
||||
" \n",
|
||||
" return collate_fn_intra\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class CustomDataset(Dataset):\n",
|
||||
" def __init__(self, fp):\n",
|
||||
" self.file = load_pkl(fp)\n",
|
||||
"\n",
|
||||
" def __getitem__(self, item):\n",
|
||||
" sample = self.file[item]\n",
|
||||
" return sample\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.file)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# embedding layer\n",
|
||||
"class Embedding(nn.Module):\n",
|
||||
" def __init__(self, config):\n",
|
||||
" super(Embedding, self).__init__()\n",
|
||||
"\n",
|
||||
" # self.xxx = config.xxx\n",
|
||||
" self.vocab_size = config.vocab_size\n",
|
||||
" self.word_dim = config.word_dim\n",
|
||||
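"        # shifted positions occupy [1, 2*pos_limit + 1]; index 0 is reserved for padding\n",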
" self.pos_size = config.pos_limit * 2 + 2\n",
|
||||
" self.pos_dim = config.pos_dim if config.dim_strategy == 'cat' else config.word_dim\n",
|
||||
" self.dim_strategy = config.dim_strategy\n",
|
||||
"\n",
|
||||
" self.wordEmbed = nn.Embedding(self.vocab_size,self.word_dim,padding_idx=0)\n",
|
||||
" self.headPosEmbed = nn.Embedding(self.pos_size,self.pos_dim,padding_idx=0)\n",
|
||||
" self.tailPosEmbed = nn.Embedding(self.pos_size,self.pos_dim,padding_idx=0)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" def forward(self, *x):\n",
|
||||
" word, head, tail = x\n",
|
||||
" word_embedding = self.wordEmbed(word)\n",
|
||||
" head_embedding = self.headPosEmbed(head)\n",
|
||||
" tail_embedding = self.tailPosEmbed(tail)\n",
|
||||
"\n",
|
||||
" if self.dim_strategy == 'cat':\n",
|
||||
" return torch.cat((word_embedding,head_embedding, tail_embedding), -1)\n",
|
||||
" elif self.dim_strategy == 'sum':\n",
|
||||
" # 此时 pos_dim == word_dim\n",
|
||||
" return word_embedding + head_embedding + tail_embedding\n",
|
||||
" else:\n",
|
||||
" raise Exception('dim_strategy must choose from [sum, cat]')"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Gelu activation function, specified by transformer, works better than relu\n",
|
||||
"class GELU(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(GELU, self).__init__()\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
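"        # exact GELU: x * Phi(x), where Phi is the standard normal CDF\n",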
" return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# cnn model\n",
|
||||
"class CNN(nn.Module):\n",
|
||||
" def __init__(self, config):\n",
|
||||
"\n",
|
||||
" super(CNN, self).__init__()\n",
|
||||
"\n",
|
||||
" if config.dim_strategy == 'cat':\n",
|
||||
" self.in_channels = config.word_dim + 2 * config.pos_dim\n",
|
||||
" else:\n",
|
||||
" self.in_channels = config.word_dim\n",
|
||||
"\n",
|
||||
" self.out_channels = config.out_channels\n",
|
||||
" self.kernel_sizes = config.kernel_sizes\n",
|
||||
" self.activation = config.activation\n",
|
||||
" self.pooling_strategy = config.pooling_strategy\n",
|
||||
" self.dropout = config.dropout\n",
|
||||
" for kernel_size in self.kernel_sizes:\n",
|
||||
" assert kernel_size % 2 == 1, \"kernel size has to be odd numbers.\"\n",
|
||||
"\n",
|
||||
" self.convs = nn.ModuleList([\n",
|
||||
" nn.Conv1d(in_channels=self.in_channels,\n",
|
||||
" out_channels=self.out_channels,\n",
|
||||
" kernel_size=k,\n",
|
||||
" stride=1,\n",
|
||||
" padding=k // 2,\n",
|
||||
" dilation=1,\n",
|
||||
" groups=1,\n",
|
||||
" bias=False) for k in self.kernel_sizes\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
" assert self.activation in ['relu', 'lrelu', 'prelu', 'selu', 'celu', 'gelu', 'sigmoid', 'tanh'], \\\n",
|
||||
" 'activation function must choose from [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]'\n",
|
||||
" self.activations = nn.ModuleDict([\n",
|
||||
" ['relu', nn.ReLU()],\n",
|
||||
" ['lrelu', nn.LeakyReLU()],\n",
|
||||
" ['prelu', nn.PReLU()],\n",
|
||||
" ['selu', nn.SELU()],\n",
|
||||
" ['celu', nn.CELU()],\n",
|
||||
" ['gelu', GELU()],\n",
|
||||
" ['sigmoid', nn.Sigmoid()],\n",
|
||||
" ['tanh', nn.Tanh()],\n",
|
||||
" ])\n",
|
||||
"\n",
|
||||
" # pooling\n",
|
||||
" assert self.pooling_strategy in ['max', 'avg', 'cls'], 'pooling strategy must choose from [max, avg, cls]'\n",
|
||||
"\n",
|
||||
" self.dropout = nn.Dropout(self.dropout)\n",
|
||||
"\n",
|
||||
" def forward(self, x, mask=None):\n",
|
||||
" \n",
|
||||
" x = torch.transpose(x, 1, 2)\n",
|
||||
"\n",
|
||||
" act_fn = self.activations[self.activation]\n",
|
||||
"\n",
|
||||
" x = [act_fn(conv(x)) for conv in self.convs]\n",
|
||||
" x = torch.cat(x, dim=1)\n",
|
||||
"\n",
|
||||
" if mask is not None:\n",
|
||||
" mask = mask.unsqueeze(1)\n",
|
||||
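"            # overwrite padded positions with a very small value before pooling\n",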
" x = x.masked_fill_(mask, 1e-12)\n",
|
||||
"\n",
|
||||
" if self.pooling_strategy == 'max':\n",
|
||||
" xp = F.max_pool1d(x, kernel_size=x.size(2)).squeeze(2)\n",
|
||||
"\n",
|
||||
" elif self.pooling_strategy == 'avg':\n",
|
||||
" x_len = mask.squeeze().eq(0).sum(-1).unsqueeze(-1).to(torch.float).to(device=mask.device)\n",
|
||||
" xp = torch.sum(x, dim=-1) / x_len\n",
|
||||
"\n",
|
||||
" else:\n",
|
||||
" xp = x[:, :, 0]\n",
|
||||
"\n",
|
||||
" x = x.transpose(1, 2)\n",
|
||||
" x = self.dropout(x)\n",
|
||||
" xp = self.dropout(xp)\n",
|
||||
"\n",
|
||||
" return x, xp \n"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# pcnn model\n",
|
||||
"class PCNN(nn.Module):\n",
|
||||
" def __init__(self, cfg):\n",
|
||||
" super(PCNN, self).__init__()\n",
|
||||
"\n",
|
||||
" self.use_pcnn = cfg.use_pcnn\n",
|
||||
"\n",
|
||||
" self.embedding = Embedding(cfg)\n",
|
||||
" self.cnn = CNN(cfg)\n",
|
||||
" self.fc1 = nn.Linear(len(cfg.kernel_sizes) * cfg.out_channels, cfg.intermediate)\n",
|
||||
" self.fc2 = nn.Linear(cfg.intermediate, cfg.num_relations)\n",
|
||||
" self.dropout = nn.Dropout(cfg.dropout)\n",
|
||||
"\n",
|
||||
" if self.use_pcnn:\n",
|
||||
" self.fc_pcnn = nn.Linear(3 * len(cfg.kernel_sizes) * cfg.out_channels,\n",
|
||||
" len(cfg.kernel_sizes) * cfg.out_channels)\n",
|
||||
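"            # piece-wise pooling trick: each of the three segments (before head, between entities, after tail)\n",
"            # receives a +100 offset in its own channel, so one max over the sequence dimension yields per-segment pooling\n",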
" self.pcnn_mask_embedding = nn.Embedding(4, 3)\n",
|
||||
" masks = torch.tensor([[0, 0, 0], [100, 0, 0], [0, 100, 0], [0, 0, 100]])\n",
|
||||
" self.pcnn_mask_embedding.weight.data.copy_(masks)\n",
|
||||
" self.pcnn_mask_embedding.weight.requires_grad = False\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" word, lens, head_pos, tail_pos = x['word'], x['lens'], x['head_pos'], x['tail_pos']\n",
|
||||
" mask = seq_len_to_mask(lens)\n",
|
||||
"\n",
|
||||
" inputs = self.embedding(word, head_pos, tail_pos)\n",
|
||||
" out, out_pool = self.cnn(inputs, mask=mask)\n",
|
||||
"\n",
|
||||
" if self.use_pcnn:\n",
|
||||
" out = out.unsqueeze(-1) # [B, L, Hs, 1]\n",
|
||||
" pcnn_mask = x['pcnn_mask']\n",
|
||||
" pcnn_mask = self.pcnn_mask_embedding(pcnn_mask).unsqueeze(-2) # [B, L, 1, 3]\n",
|
||||
" out = out + pcnn_mask # [B, L, Hs, 3]\n",
|
||||
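"            # max over the sequence dimension = per-segment max pooling; subtract the +100 offset afterwards\n",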
" out = out.max(dim=1)[0] - 100 # [B, Hs, 3]\n",
|
||||
" out_pool = out.view(out.size(0), -1) # [B, 3 * Hs]\n",
|
||||
" out_pool = F.leaky_relu(self.fc_pcnn(out_pool)) # [B, Hs]\n",
|
||||
" out_pool = self.dropout(out_pool)\n",
|
||||
"\n",
|
||||
" output = self.fc1(out_pool)\n",
|
||||
" output = F.leaky_relu(output)\n",
|
||||
" output = self.dropout(output)\n",
|
||||
" output = self.fc2(output)\n",
|
||||
"\n",
|
||||
" return output"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# p,r,f1 measurement\n",
|
||||
"class PRMetric():\n",
|
||||
" def __init__(self):\n",
|
||||
" \n",
|
||||
" self.y_true = np.empty(0)\n",
|
||||
" self.y_pred = np.empty(0)\n",
|
||||
"\n",
|
||||
" def reset(self):\n",
|
||||
" self.y_true = np.empty(0)\n",
|
||||
" self.y_pred = np.empty(0)\n",
|
||||
"\n",
|
||||
" def update(self, y_true:torch.Tensor, y_pred:torch.Tensor):\n",
|
||||
" y_true = y_true.cpu().detach().numpy()\n",
|
||||
" y_pred = y_pred.cpu().detach().numpy()\n",
|
||||
" y_pred = np.argmax(y_pred,axis=-1)\n",
|
||||
"\n",
|
||||
" self.y_true = np.append(self.y_true, y_true)\n",
|
||||
" self.y_pred = np.append(self.y_pred, y_pred)\n",
|
||||
"\n",
|
||||
" def compute(self):\n",
|
||||
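"        # macro-averaged precision/recall/F1; micro-averaged F1 equals accuracy for single-label classification\n",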
" p, r, f1, _ = precision_recall_fscore_support(self.y_true,self.y_pred,average='macro',warn_for=tuple())\n",
|
||||
" _, _, acc, _ = precision_recall_fscore_support(self.y_true,self.y_pred,average='micro',warn_for=tuple())\n",
|
||||
"\n",
|
||||
" return acc,p,r,f1"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Iteration in training process\n",
|
||||
"def train(epoch, model, dataloader, optimizer, criterion, cfg):\n",
|
||||
" model.train()\n",
|
||||
"\n",
|
||||
" metric = PRMetric()\n",
|
||||
" losses = []\n",
|
||||
"\n",
|
||||
" for batch_idx, (x, y) in enumerate(dataloader, 1):\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" y_pred = model(x)\n",
|
||||
" loss = criterion(y_pred, y)\n",
|
||||
"\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
" metric.update(y_true=y, y_pred=y_pred)\n",
|
||||
" losses.append(loss.item())\n",
|
||||
"\n",
|
||||
" data_total = len(dataloader.dataset)\n",
|
||||
" data_cal = data_total if batch_idx == len(dataloader) else batch_idx * len(y)\n",
|
||||
" if (cfg.train_log and batch_idx % cfg.log_interval == 0) or batch_idx == len(dataloader):\n",
|
||||
" acc,p,r,f1 = metric.compute()\n",
|
||||
" print(f'Train Epoch {epoch}: [{data_cal}/{data_total} ({100. * data_cal / data_total:.0f}%)]\\t'\n",
|
||||
" f'Loss: {loss.item():.6f}')\n",
|
||||
" print(f'Train Epoch {epoch}: Acc: {100. * acc:.2f}%\\t'\n",
|
||||
" f'macro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')\n",
|
||||
"\n",
|
||||
" if cfg.show_plot and not cfg.only_comparison_plot:\n",
|
||||
" if cfg.plot_utils == 'matplot':\n",
|
||||
" plt.plot(losses)\n",
|
||||
" plt.title(f'epoch {epoch} train loss')\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
" return losses[-1]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Iteration in testing process\n",
|
||||
"def validate(epoch, model, dataloader, criterion,verbose=True):\n",
|
||||
" model.eval()\n",
|
||||
"\n",
|
||||
" metric = PRMetric()\n",
|
||||
" losses = []\n",
|
||||
"\n",
|
||||
" for batch_idx, (x, y) in enumerate(dataloader, 1):\n",
|
||||
" with torch.no_grad():\n",
|
||||
" y_pred = model(x)\n",
|
||||
" loss = criterion(y_pred, y)\n",
|
||||
"\n",
|
||||
" metric.update(y_true=y, y_pred=y_pred)\n",
|
||||
" losses.append(loss.item())\n",
|
||||
"\n",
|
||||
" loss = sum(losses) / len(losses)\n",
|
||||
" acc,p,r,f1 = metric.compute()\n",
|
||||
" data_total = len(dataloader.dataset)\n",
|
||||
" if verbose:\n",
|
||||
" print(f'Valid Epoch {epoch}: [{data_total}/{data_total}](100%)\\t Loss: {loss:.6f}')\n",
|
||||
" print(f'Valid Epoch {epoch}: Acc: {100. * acc:.2f}%\\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]\\n\\n')\n",
|
||||
"\n",
|
||||
" return f1,loss"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Load dataset\n",
|
||||
"train_dataset = CustomDataset(train_save_fp)\n",
|
||||
"valid_dataset = CustomDataset(valid_save_fp)\n",
|
||||
"test_dataset = CustomDataset(test_save_fp)\n",
|
||||
"\n",
|
||||
"train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n",
|
||||
"valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n",
|
||||
"test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# After the preprocessed data is loaded, vocab_size is known\n",
|
||||
"vocab = load_pkl(vocab_save_fp)\n",
|
||||
"vocab_size = vocab.count\n",
|
||||
"cfg.vocab_size = vocab_size"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# main entry, define optimization function, loss function and so on\n",
|
||||
"# start epoch\n",
|
||||
"# Use the loss of the valid dataset to make an early stop judgment. When it does not decline, this is the time when the model generalization is the best.\n",
|
||||
"model = PCNN(cfg)\n",
|
||||
"print(model)\n",
|
||||
"\n",
|
||||
"optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)\n",
|
||||
"scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)\n",
|
||||
"criterion = nn.CrossEntropyLoss()\n",
|
||||
"\n",
|
||||
"best_f1, best_epoch = -1, 0\n",
|
||||
"es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, = 1000, -1, 0, 0, 0, -1\n",
|
||||
"train_losses, valid_losses = [], []\n",
|
||||
"\n",
|
||||
"logger.info('=' * 10 + ' Start training ' + '=' * 10)\n",
|
||||
"for epoch in range(1, cfg.epoch + 1):\n",
|
||||
" train_loss = train(epoch, model, train_dataloader, optimizer, criterion, cfg)\n",
|
||||
" valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion)\n",
|
||||
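"    # reduce the learning rate when the validation loss plateaus\n",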
" scheduler.step(valid_loss)\n",
|
||||
"\n",
|
||||
" train_losses.append(train_loss)\n",
|
||||
" valid_losses.append(valid_loss)\n",
|
||||
" if best_f1 < valid_f1:\n",
|
||||
" best_f1 = valid_f1\n",
|
||||
" best_epoch = epoch\n",
|
||||
" # 使用 valid loss 做 early stopping 的判断标准\n",
|
||||
" if es_loss > valid_loss:\n",
|
||||
" es_loss = valid_loss\n",
|
||||
" es_f1 = valid_f1\n",
|
||||
" best_es_f1 = valid_f1\n",
|
||||
" es_epoch = epoch\n",
|
||||
" best_es_epoch = epoch\n",
|
||||
" es_patience = 0\n",
|
||||
" else:\n",
|
||||
" es_patience += 1\n",
|
||||
" if es_patience >= cfg.early_stopping_patience:\n",
|
||||
" best_es_epoch = es_epoch\n",
|
||||
" best_es_f1 = es_f1\n",
|
||||
"\n",
|
||||
"if cfg.show_plot:\n",
|
||||
" if cfg.plot_utils == 'matplot':\n",
|
||||
" plt.plot(train_losses, 'x-')\n",
|
||||
" plt.plot(valid_losses, '+-')\n",
|
||||
" plt.legend(['train', 'valid'])\n",
|
||||
" plt.title('train/valid comparison loss')\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '\n",
|
||||
" f'this epoch macro f1: {best_es_f1:0.4f}')\n",
|
||||
"print(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '\n",
|
||||
" f'this epoch macro f1: {best_f1:.4f}')\n",
|
||||
"\n",
|
||||
"test_f1, _ = validate(0, model, test_dataloader, criterion,verbose=False)\n",
|
||||
"print(f'after {cfg.epoch} epochs, final test data macro f1: {test_f1:.4f}')"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"This demo does not include parameter adjustment. Interested students can go to [deepke] by themselves( http://openkg.cn/tool/deepke )Warehouse, download and use more models:)"
|
||||
],
|
||||
"metadata": {}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
head_type,tail_type,relation,index
|
||||
None,None,None,0
|
||||
影视作品,人物,导演,1
|
||||
景点,城市,所在城市,2
|
||||
歌曲,音乐专辑,所属专辑,3
|
|
|
@ -0,0 +1,4 @@
|
|||
sentence,relation,head,tail,tokens,lens,head_type,head_offset,tail_type,tail_offset,dependency
|
||||
建国后南京于1951年将莫愁湖列为第一区人民公园。,所在城市,莫愁湖,南京,"['建国', '后', '南京', '于', '1951年', '将', '莫愁湖', '列为', '第一', '区', '人民', '公园', '。']",13,景点,6,城市,2,"[2, 8, 8, 8, 4, 8, 6, 0, 10, 12, 12, 8, 8]"
|
||||
2001年李三光被张黎选中让其饰演《走向共和》中的光绪皇帝而出道。,导演,走向共和,李三光,"['2001年', '李三光', '被', '张黎', '选中', '让', '其', '饰演', '《', '走向', '共和', '》', '中的', '光绪皇帝', '而', '出道', '。']",17,影视作品,9,人物,1,"[2, 5, 5, 5, 0, 5, 6, 6, 10, 13, 10, 10, 8, 13, 16, 13, 5]"
|
||||
《我爱秋莲》是收录于高胜美专辑《雷射》的一首金曲。,所属专辑,我爱秋莲,雷射,"['《', '我', '爱秋莲', '》', '是', '收录于', '高胜', '美专', '辑', '《', '雷射', '》', '的', '一首', '金曲', '。']",16,歌曲,1,音乐专辑,10,"[3, 3, 5, 3, 0, 15, 8, 9, 15, 11, 15, 11, 11, 15, 5, 5]"
|
|
|
@ -0,0 +1,7 @@
|
|||
sentence,relation,head,tail,tokens,lens,head_type,head_offset,tail_type,tail_offset,dependency
|
||||
孔正锡在2005年以一部温馨的爱情电影《长腿叔叔》敲开电影界大门。,导演,长腿叔叔,孔正锡,"['孔正锡', '在', '2005年', '以', '一部', '温馨', '的', '爱情', '电影', '《', '长腿', '叔叔', '》', '敲开', '电影界', '大门', '。']",17,影视作品,10,人物,0,"[14, 14, 2, 14, 9, 9, 6, 9, 12, 12, 12, 4, 12, 0, 16, 14, 14]"
|
||||
2014年8月,韩兆导演的电影《好命先生》正式上映。,导演,好命先生,韩兆,"['2014年8月', ',', '韩兆', '导演', '的', '电影', '《', '好命先生', '》', '正式', '上映', '。']",12,影视作品,7,人物,2,"[11, 1, 4, 6, 4, 8, 8, 11, 8, 11, 0, 11]"
|
||||
2000年8月,「天坛大佛」荣获「香港十大杰出工程项目」第四名。,所在城市,天坛大佛,香港,"['2000年8月', ',', '「', '天坛', '大佛', '」', '荣获', '「', '香港', '十', '大', '杰出', '工程项目', '」', '第四', '名', '。']",17,景点,3,城市,8,"[7, 1, 5, 5, 7, 5, 0, 13, 13, 11, 13, 13, 16, 13, 16, 7, 7]"
|
||||
地安门是北京主皇城四门之一。,所在城市,地安门,北京,"['地安门', '是', '北京', '主', '皇城', '四', '门', '之一', '。']",9,景点,0,城市,2,"[2, 0, 5, 5, 7, 7, 8, 2, 2]"
|
||||
《伤心的树》是吴宗宪的音乐作品,收录在《你比从前快乐》专辑中。,所属专辑,伤心的树,你比从前快乐,"['《', '伤心', '的', '树', '》', '是', '吴宗宪', '的', '音乐作品', ',', '收录', '在', '《', '你', '比', '从前', '快乐', '》', '专辑', '中', '。']",21,歌曲,1,音乐专辑,13,"[4, 4, 2, 6, 4, 0, 9, 7, 6, 6, 6, 11, 17, 17, 17, 15, 19, 17, 20, 12, 6]"
|
||||
请不要认错我是关淑怡专辑《冬恋》里的一首歌曲。,所属专辑,请不要认错我,冬恋,"['请', '不要', '认错', '我', '是', '关淑怡', '专辑', '《', '冬恋', '》', '里', '的', '一', '首', '歌曲', '。']",16,歌曲,0,音乐专辑,8,"[0, 3, 1, 5, 3, 7, 9, 9, 11, 9, 15, 11, 14, 15, 5, 1]"
|
|
|
@ -0,0 +1,4 @@
|
|||
sentence,relation,head,tail,tokens,lens,head_type,head_offset,tail_type,tail_offset,dependency
|
||||
《岳父也是爹》是王军执导的电视剧,马恩然、范明主演。,导演,岳父也是爹,王军,"['《', '岳父', '也是', '爹', '》', '是', '王军', '执导', '的', '电视剧', ',', '马恩然', '、', '范明', '主演', '。']",16,影视作品,1,人物,6,"[4, 4, 4, 6, 4, 0, 8, 10, 8, 6, 6, 15, 14, 12, 6, 6]"
|
||||
渔人码头的落成使得澳门旅游业展现全新的旅游面貌。,所在城市,渔人码头,澳门,"['渔人码头', '的', '落成', '使得', '澳门', '旅游业', '展现', '全新', '的', '旅游', '面貌', '。']",12,景点,0,城市,4,"[3, 1, 4, 0, 6, 4, 4, 11, 8, 11, 7, 4]"
|
||||
《寄梦》是黄露仪的音乐作品,收录在《宁愿相信》专辑中。,所属专辑,寄梦,宁愿相信,"['《', '寄', '梦', '》', '是', '黄露仪', '的', '音乐作品', ',', '收录', '在', '《', '宁愿', '相信', '》', '专辑', '中', '。']",18,歌曲,1,音乐专辑,12,"[2, 5, 2, 2, 0, 8, 6, 5, 5, 5, 10, 14, 14, 17, 14, 17, 11, 5]"
|
|
Binary file not shown.
After Width: | Height: | Size: 68 KiB |
Binary file not shown.
After Width: | Height: | Size: 292 KiB |
Binary file not shown.
After Width: | Height: | Size: 137 KiB |