{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Relation Extraction in Practice\n",
    "> Tutorial author: 余海阳 (yuhaiyang@zju.edu.cn)\n",
    "\n",
    "In this demo we use a `pcnn` model to perform Chinese relation extraction.\n",
    "The goal is to show how triples are extracted when a knowledge graph is built, and which methods are commonly used for it.\n",
    "\n",
    "The demo runs on `python3`.\n",
    "\n",
    "### Dataset\n",
    "For this example we sampled a few Chinese sentences and extract the triples they contain.\n",
    "\n",
    "sentence|relation|head|tail\n",
    ":---:|:---:|:---:|:---:\n",
    "孔正锡在2005年以一部温馨的爱情电影《长腿叔叔》敲开电影界大门。|导演|长腿叔叔|孔正锡\n",
    "《伤心的树》是吴宗宪的音乐作品,收录在《你比从前快乐》专辑中。|所属专辑|伤心的树|你比从前快乐\n",
    "2000年8月,「天坛大佛」荣获「香港十大杰出工程项目」第四名。|所在城市|天坛大佛|香港\n",
    "\n",
    "\n",
    "- train.csv: 6 training triples; each line of the file is one triple, ordered as sentence, relation, head entity, tail entity and separated by `,`.\n",
    "- valid.csv: 3 validation triples; each line of the file is one triple, ordered as sentence, relation, head entity, tail entity and separated by `,`.\n",
    "- test.csv: 3 test triples; each line of the file is one triple, ordered as sentence, relation, head entity, tail entity and separated by `,`.\n",
    "- relation.csv: 4 relation types; each line of the file is one relation type, ordered as head entity type, tail entity type, relation, index and separated by `,`.\n",
    "\n",
    "The preprocessing code below additionally expects each data row to provide `tokens`, `head_offset`, `tail_offset` and `lens` columns produced by tokenization."
   ]
  },
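  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the cell below (an addition to the original tutorial; it assumes `data/train.csv` exists with the comma-separated columns described above) prints the column names and the first row of the training file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration only: peek at data/train.csv, assuming it exists and follows the format described above.\n",
    "import csv\n",
    "\n",
    "with open('data/train.csv', encoding='utf-8') as f:\n",
    "    reader = csv.DictReader(f)\n",
    "    print('columns:', reader.fieldnames)\n",
    "    print('first row:', next(iter(reader)))"
   ]
  },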
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PCNN: A Quick Review\n",
    "\n",
    "![PCNN](img/PCNN.jpg)\n",
    "\n",
    "The sentence representation is built mainly from word embeddings and position embeddings. After the convolution layer, the feature map is split into three segments around the positions of the head and tail entities and each segment is max-pooled separately; a fully connected layer then maps the pooled features to the relation expressed by the sentence."
   ]
  },
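  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the piecewise pooling idea concrete, the next cell (added for illustration, not part of the model code) splits a toy feature sequence into three segments around assumed entity positions and max-pools each segment separately."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration of piecewise max pooling (assumes torch is installed; the entity positions are made up).\n",
    "import torch\n",
    "\n",
    "feat = torch.arange(8, dtype=torch.float)   # pretend this is one convolution channel over 8 tokens\n",
    "head_idx, tail_idx = 2, 5                   # assumed head / tail entity positions\n",
    "\n",
    "left = feat[:head_idx + 1].max()            # segment up to and including the head entity\n",
    "middle = feat[head_idx + 1:tail_idx].max()  # segment strictly between the two entities\n",
    "right = feat[tail_idx:].max()               # segment from the tail entity to the end\n",
    "\n",
    "print('piecewise pooled features:', torch.stack([left, middle, right]))"
   ]
  },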
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The neural network is built with PyTorch; make sure the dependencies are installed before running\n",
    "!pip install torch\n",
    "!pip install matplotlib\n",
    "!pip install transformers\n",
    "!pip install scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the modules used below\n",
    "import os\n",
    "import csv\n",
    "import math\n",
    "import pickle\n",
    "import logging\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from torch import optim\n",
    "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.metrics import precision_recall_fscore_support\n",
    "from typing import List, Tuple, Dict, Any, Sequence, Optional, Union\n",
    "from transformers import BertTokenizer, BertModel\n",
    "\n",
    "logger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration object holding all tunable hyperparameters\n",
    "# use_pcnn controls whether piecewise max pooling is used\n",
    "class Config(object):\n",
    "    model_name = 'cnn'  # ['cnn', 'gcn', 'lm']\n",
    "    use_pcnn = True\n",
    "    min_freq = 1\n",
    "    pos_limit = 20\n",
    "    out_path = 'data/out'\n",
    "    batch_size = 2\n",
    "    word_dim = 10\n",
    "    pos_dim = 5\n",
    "    dim_strategy = 'sum'  # ['sum', 'cat']\n",
    "    out_channels = 20\n",
    "    intermediate = 10\n",
    "    kernel_sizes = [3, 5, 7]\n",
    "    activation = 'gelu'\n",
    "    pooling_strategy = 'max'\n",
    "    dropout = 0.3\n",
    "    epoch = 10\n",
    "    num_relations = 4\n",
    "    learning_rate = 3e-4\n",
    "    lr_factor = 0.7  # decay factor of the learning rate\n",
    "    lr_patience = 3  # epochs to wait before decaying the learning rate\n",
    "    weight_decay = 1e-3  # L2 regularization\n",
    "    early_stopping_patience = 6\n",
    "    train_log = True\n",
    "    log_interval = 1\n",
    "    show_plot = True\n",
    "    only_comparison_plot = False\n",
    "    plot_utils = 'matplot'\n",
    "    lm_file = 'bert-base-chinese'\n",
    "    lm_num_hidden_layers = 2\n",
    "    rnn_layers = 2\n",
    "\n",
    "\n",
    "cfg = Config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a token-to-index vocabulary; the indices are later fed to the embedding layer to look up word vectors\n",
    "# By convention index 0 is [PAD] and index 1 is [UNK]\n",
    "class Vocab(object):\n",
    "    def __init__(self, name: str = 'basic', init_tokens = [\"[PAD]\", \"[UNK]\"]):\n",
    "        self.name = name\n",
    "        self.init_tokens = init_tokens\n",
    "        self.trimed = False\n",
    "        self.word2idx = {}\n",
    "        self.word2count = {}\n",
    "        self.idx2word = {}\n",
    "        self.count = 0\n",
    "        self._add_init_tokens()\n",
    "\n",
    "    def _add_init_tokens(self):\n",
    "        for token in self.init_tokens:\n",
    "            self._add_word(token)\n",
    "\n",
    "    def _add_word(self, word: str):\n",
    "        if word not in self.word2idx:\n",
    "            self.word2idx[word] = self.count\n",
    "            self.word2count[word] = 1\n",
    "            self.idx2word[self.count] = word\n",
    "            self.count += 1\n",
    "        else:\n",
    "            self.word2count[word] += 1\n",
    "\n",
    "    def add_words(self, words: Sequence):\n",
    "        for word in words:\n",
    "            self._add_word(word)\n",
    "\n",
    "    def trim(self, min_freq=2, verbose: Optional[bool] = True):\n",
    "        '''Remove words whose frequency is below min_freq from the vocabulary.\n",
    "\n",
    "        Args:\n",
    "            min_freq: minimum word frequency to keep\n",
    "        '''\n",
    "        assert min_freq == int(min_freq), f'min_freq must be integer, can\\'t be {min_freq}'\n",
    "        min_freq = int(min_freq)\n",
    "        if min_freq < 2:\n",
    "            return\n",
    "        if self.trimed:\n",
    "            return\n",
    "        self.trimed = True\n",
    "\n",
    "        keep_words = []\n",
    "        new_words = []\n",
    "\n",
    "        for k, v in self.word2count.items():\n",
    "            if v >= min_freq:\n",
    "                keep_words.append(k)\n",
    "                new_words.extend([k] * v)\n",
    "        if verbose:\n",
    "            kept_len = len(keep_words)\n",
    "            orig_len = len(self.word2idx) - len(self.init_tokens)\n",
    "            logger.info('vocab trimmed, kept words [{} / {}] = {:.2f}%'.format(kept_len, orig_len, kept_len / orig_len * 100))\n",
    "\n",
    "        # Reinitialize dictionaries\n",
    "        self.word2idx = {}\n",
    "        self.word2count = {}\n",
    "        self.idx2word = {}\n",
    "        self.count = 0\n",
    "        self._add_init_tokens()\n",
    "        self.add_words(new_words)"
   ]
  },
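  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick check of the `Vocab` class (added for illustration; the real vocabulary is built during preprocessing below): add a few toy tokens and inspect the resulting index maps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration only: build a tiny vocabulary from made-up character tokens.\n",
    "demo_vocab = Vocab('demo')\n",
    "demo_vocab.add_words(list('孔正锡在2005年'))\n",
    "print('size:', demo_vocab.count)\n",
    "print('word2idx:', demo_vocab.word2idx)"
   ]
  },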
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper functions used during preprocessing\n",
    "Path = str\n",
    "\n",
    "\n",
    "def load_csv(fp: Path, is_tsv: bool = False, verbose: bool = True) -> List:\n",
    "    if verbose:\n",
    "        logger.info(f'load csv from {fp}')\n",
    "\n",
    "    dialect = 'excel-tab' if is_tsv else 'excel'\n",
    "    with open(fp, encoding='utf-8') as f:\n",
    "        reader = csv.DictReader(f, dialect=dialect)\n",
    "        return list(reader)\n",
    "\n",
    "\n",
    "def load_pkl(fp: Path, verbose: bool = True) -> Any:\n",
    "    if verbose:\n",
    "        logger.info(f'load data from {fp}')\n",
    "\n",
    "    with open(fp, 'rb') as f:\n",
    "        data = pickle.load(f)\n",
    "        return data\n",
    "\n",
    "\n",
    "def save_pkl(data: Any, fp: Path, verbose: bool = True) -> None:\n",
    "    if verbose:\n",
    "        logger.info(f'save data in {fp}')\n",
    "\n",
    "    with open(fp, 'wb') as f:\n",
    "        pickle.dump(data, f)\n",
    "\n",
    "\n",
    "def _handle_relation_data(relation_data: List[Dict]) -> Dict:\n",
    "    rels = dict()\n",
    "    for d in relation_data:\n",
    "        rels[d['relation']] = {\n",
    "            'index': int(d['index']),\n",
    "            'head_type': d['head_type'],\n",
    "            'tail_type': d['tail_type'],\n",
    "        }\n",
    "    return rels\n",
    "\n",
    "\n",
    "def _add_relation_data(rels: Dict, data: List) -> None:\n",
    "    for d in data:\n",
    "        d['rel2idx'] = rels[d['relation']]['index']\n",
    "        d['head_type'] = rels[d['relation']]['head_type']\n",
    "        d['tail_type'] = rels[d['relation']]['tail_type']\n",
    "\n",
    "\n",
    "def _convert_tokens_into_index(data: List[Dict], vocab):\n",
    "    unk_str = '[UNK]'\n",
    "    unk_idx = vocab.word2idx[unk_str]\n",
    "\n",
    "    for d in data:\n",
    "        d['token2idx'] = [vocab.word2idx.get(i, unk_idx) for i in d['tokens']]\n",
    "\n",
    "\n",
    "def _add_pos_seq(train_data: List[Dict], cfg):\n",
    "    for d in train_data:\n",
    "        d['head_offset'], d['tail_offset'], d['lens'] = int(d['head_offset']), int(d['tail_offset']), int(d['lens'])\n",
    "        entities_idx = [d['head_offset'], d['tail_offset']] if d['head_offset'] < d['tail_offset'] else [d['tail_offset'], d['head_offset']]\n",
    "\n",
    "        d['head_pos'] = list(map(lambda i: i - d['head_offset'], list(range(d['lens']))))\n",
    "        d['head_pos'] = _handle_pos_limit(d['head_pos'], int(cfg.pos_limit))\n",
    "\n",
    "        d['tail_pos'] = list(map(lambda i: i - d['tail_offset'], list(range(d['lens']))))\n",
    "        d['tail_pos'] = _handle_pos_limit(d['tail_pos'], int(cfg.pos_limit))\n",
    "\n",
    "        if cfg.use_pcnn:\n",
    "            d['entities_pos'] = [1] * (entities_idx[0] + 1) + [2] * (entities_idx[1] - entities_idx[0] - 1) + \\\n",
    "                                [3] * (d['lens'] - entities_idx[1])\n",
    "\n",
    "\n",
    "def _handle_pos_limit(pos: List[int], limit: int) -> List[int]:\n",
    "    for i, p in enumerate(pos):\n",
    "        if p > limit:\n",
    "            pos[i] = limit\n",
    "        if p < -limit:\n",
    "            pos[i] = -limit\n",
    "    return [p + limit + 1 for p in pos]\n",
    "\n",
    "\n",
    "def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None, mask_pos_to_true=True):\n",
    "    \"\"\"\n",
    "    Convert a 1-d array of sequence lengths into a 2-d mask; by default padded positions are marked True (1).\n",
    "\n",
    "    :param list, np.ndarray, torch.LongTensor seq_len: shape (B,)\n",
    "    :param int max_len: pad every row to this length. By default (None) the maximum value in seq_len is used,\n",
    "        but under nn.DataParallel different devices may see different seq_len, so an explicit max_len keeps the masks aligned.\n",
    "    :return: torch.Tensor of shape (B, max_len), with bool / torch.uint8 elements\n",
    "    \"\"\"\n",
    "    if isinstance(seq_len, list):\n",
    "        seq_len = np.array(seq_len)\n",
    "\n",
    "    if isinstance(seq_len, np.ndarray):\n",
    "        seq_len = torch.from_numpy(seq_len)\n",
    "\n",
    "    if isinstance(seq_len, torch.Tensor):\n",
    "        assert seq_len.dim() == 1, f\"seq_len can only have one dimension, got {seq_len.dim()} != 1.\"\n",
    "        batch_size = seq_len.size(0)\n",
    "        max_len = int(max_len) if max_len else seq_len.max().long()\n",
    "        broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len.device)\n",
    "        if mask_pos_to_true:\n",
    "            mask = broad_cast_seq_len.ge(seq_len.unsqueeze(1))\n",
    "        else:\n",
    "            mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))\n",
    "    else:\n",
    "        raise TypeError(\"Only support 1-d list or 1-d numpy.ndarray or 1-d torch.Tensor.\")\n",
    "\n",
    "    return mask"
   ]
  },
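  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the relative-position features and the padding mask look like, here is a small check on hand-made inputs (added for illustration; the offsets and lengths are arbitrary)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration only: relative positions w.r.t. a head entity at offset 2 in a length-6 sentence,\n",
    "# clipped to pos_limit and shifted to be non-negative, plus the padding mask for two toy sequences.\n",
    "toy_pos = [i - 2 for i in range(6)]  # [-2, -1, 0, 1, 2, 3]\n",
    "print('clipped positions:', _handle_pos_limit(toy_pos, limit=20))\n",
    "print('padding mask (True = padded):', seq_len_to_mask([3, 5]))"
   ]
  },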
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocessing pipeline\n",
    "logger.info('load raw files...')\n",
    "train_fp = os.path.join('data/train.csv')\n",
    "valid_fp = os.path.join('data/valid.csv')\n",
    "test_fp = os.path.join('data/test.csv')\n",
    "relation_fp = os.path.join('data/relation.csv')\n",
    "\n",
    "train_data = load_csv(train_fp)\n",
    "valid_data = load_csv(valid_fp)\n",
    "test_data = load_csv(test_fp)\n",
    "relation_data = load_csv(relation_fp)\n",
    "\n",
    "for d in train_data:\n",
    "    d['tokens'] = eval(d['tokens'])\n",
    "for d in valid_data:\n",
    "    d['tokens'] = eval(d['tokens'])\n",
    "for d in test_data:\n",
    "    d['tokens'] = eval(d['tokens'])\n",
    "\n",
    "logger.info('convert relation into index...')\n",
    "rels = _handle_relation_data(relation_data)\n",
    "_add_relation_data(rels, train_data)\n",
    "_add_relation_data(rels, valid_data)\n",
    "_add_relation_data(rels, test_data)\n",
    "\n",
    "logger.info('check whether a pretrained language model is used...')\n",
    "logger.info('build vocabulary...')\n",
    "vocab = Vocab('word')\n",
    "train_tokens = [d['tokens'] for d in train_data]\n",
    "valid_tokens = [d['tokens'] for d in valid_data]\n",
    "test_tokens = [d['tokens'] for d in test_data]\n",
    "sent_tokens = [*train_tokens, *valid_tokens, *test_tokens]\n",
    "for sent in sent_tokens:\n",
    "    vocab.add_words(sent)\n",
    "vocab.trim(min_freq=cfg.min_freq)\n",
    "\n",
    "logger.info('convert tokens into index...')\n",
    "_convert_tokens_into_index(train_data, vocab)\n",
    "_convert_tokens_into_index(valid_data, vocab)\n",
    "_convert_tokens_into_index(test_data, vocab)\n",
    "\n",
    "logger.info('build position sequence...')\n",
    "_add_pos_seq(train_data, cfg)\n",
    "_add_pos_seq(valid_data, cfg)\n",
    "_add_pos_seq(test_data, cfg)\n",
    "\n",
    "logger.info('save data for backup...')\n",
    "os.makedirs(cfg.out_path, exist_ok=True)\n",
    "train_save_fp = os.path.join(cfg.out_path, 'train.pkl')\n",
    "valid_save_fp = os.path.join(cfg.out_path, 'valid.pkl')\n",
    "test_save_fp = os.path.join(cfg.out_path, 'test.pkl')\n",
    "save_pkl(train_data, train_save_fp)\n",
    "save_pkl(valid_data, valid_save_fp)\n",
    "save_pkl(test_data, test_save_fp)\n",
    "\n",
    "vocab_save_fp = os.path.join(cfg.out_path, 'vocab.pkl')\n",
    "vocab_txt = os.path.join(cfg.out_path, 'vocab.txt')\n",
    "save_pkl(vocab, vocab_save_fp)\n",
    "logger.info('save vocab in a txt file for inspection...')\n",
    "with open(vocab_txt, 'w', encoding='utf-8') as f:\n",
    "    f.write(os.linesep.join(vocab.word2idx.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The collate function and a custom PyTorch Dataset\n",
    "def collate_fn(cfg):\n",
    "    def collate_fn_intra(batch):\n",
    "        batch.sort(key=lambda data: int(data['lens']), reverse=True)\n",
    "        max_len = int(batch[0]['lens'])\n",
    "\n",
    "        def _padding(x, max_len):\n",
    "            return x + [0] * (max_len - len(x))\n",
    "\n",
    "        def _pad_adj(adj, max_len):\n",
    "            adj = np.array(adj)\n",
    "            pad_len = max_len - adj.shape[0]\n",
    "            for i in range(pad_len):\n",
    "                adj = np.insert(adj, adj.shape[-1], 0, axis=1)\n",
    "            for i in range(pad_len):\n",
    "                adj = np.insert(adj, adj.shape[0], 0, axis=0)\n",
    "            return adj\n",
    "\n",
    "        x, y = dict(), []\n",
    "        word, word_len = [], []\n",
    "        head_pos, tail_pos = [], []\n",
    "        pcnn_mask = []\n",
    "        adj_matrix = []\n",
    "        for data in batch:\n",
    "            word.append(_padding(data['token2idx'], max_len))\n",
    "            word_len.append(int(data['lens']))\n",
    "            y.append(int(data['rel2idx']))\n",
    "\n",
    "            if cfg.model_name != 'lm':\n",
    "                head_pos.append(_padding(data['head_pos'], max_len))\n",
    "                tail_pos.append(_padding(data['tail_pos'], max_len))\n",
    "                if cfg.model_name == 'gcn':\n",
    "                    head = eval(data['dependency'])\n",
    "                    adj = head_to_adj(head, directed=True, self_loop=True)\n",
    "                    adj_matrix.append(_pad_adj(adj, max_len))\n",
    "\n",
    "                if cfg.use_pcnn:\n",
    "                    pcnn_mask.append(_padding(data['entities_pos'], max_len))\n",
    "\n",
    "        x['word'] = torch.tensor(word)\n",
    "        x['lens'] = torch.tensor(word_len)\n",
    "        y = torch.tensor(y)\n",
    "\n",
    "        if cfg.model_name != 'lm':\n",
    "            x['head_pos'] = torch.tensor(head_pos)\n",
    "            x['tail_pos'] = torch.tensor(tail_pos)\n",
    "            if cfg.model_name == 'gcn':\n",
    "                x['adj'] = torch.tensor(adj_matrix)\n",
    "            if cfg.model_name == 'cnn' and cfg.use_pcnn:\n",
    "                x['pcnn_mask'] = torch.tensor(pcnn_mask)\n",
    "\n",
    "        return x, y\n",
    "\n",
    "    return collate_fn_intra\n",
    "\n",
    "\n",
    "class CustomDataset(Dataset):\n",
    "    \"\"\"A simple Dataset backed by the list loaded from a pickle file\"\"\"\n",
    "    def __init__(self, fp):\n",
    "        self.file = load_pkl(fp)\n",
    "\n",
    "    def __getitem__(self, item):\n",
    "        sample = self.file[item]\n",
    "        return sample\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.file)"
   ]
  },
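  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below (illustration only; the two samples and all field values are hand-made) runs `collate_fn` on a toy batch to show the padded tensors it produces."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration only: collate two hand-made samples into one padded batch.\n",
    "toy_batch = [\n",
    "    {'lens': 4, 'token2idx': [2, 3, 4, 5], 'rel2idx': 1,\n",
    "     'head_pos': [21, 22, 23, 24], 'tail_pos': [19, 20, 21, 22], 'entities_pos': [1, 2, 3, 3]},\n",
    "    {'lens': 3, 'token2idx': [2, 6, 7], 'rel2idx': 0,\n",
    "     'head_pos': [21, 22, 23], 'tail_pos': [19, 20, 21], 'entities_pos': [1, 2, 3]},\n",
    "]\n",
    "toy_x, toy_y = collate_fn(cfg)(toy_batch)\n",
    "print({k: v.shape for k, v in toy_x.items()}, toy_y)"
   ]
  },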
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embedding layer\n",
    "class Embedding(nn.Module):\n",
    "    def __init__(self, config):\n",
    "        \"\"\"\n",
    "        word embedding: index 0 is reserved for padding\n",
    "        pos embedding: index 0 is reserved for padding\n",
    "        dim_strategy: [cat, sum] whether the embeddings are concatenated or summed\n",
    "        \"\"\"\n",
    "        super(Embedding, self).__init__()\n",
    "\n",
    "        # self.xxx = config.xxx\n",
    "        self.vocab_size = config.vocab_size\n",
    "        self.word_dim = config.word_dim\n",
    "        self.pos_size = config.pos_limit * 2 + 2\n",
    "        self.pos_dim = config.pos_dim if config.dim_strategy == 'cat' else config.word_dim\n",
    "        self.dim_strategy = config.dim_strategy\n",
    "\n",
    "        self.wordEmbed = nn.Embedding(self.vocab_size, self.word_dim, padding_idx=0)\n",
    "        self.headPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)\n",
    "        self.tailPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)\n",
    "\n",
    "    def forward(self, *x):\n",
    "        word, head, tail = x\n",
    "        word_embedding = self.wordEmbed(word)\n",
    "        head_embedding = self.headPosEmbed(head)\n",
    "        tail_embedding = self.tailPosEmbed(tail)\n",
    "\n",
    "        if self.dim_strategy == 'cat':\n",
    "            return torch.cat((word_embedding, head_embedding, tail_embedding), -1)\n",
    "        elif self.dim_strategy == 'sum':\n",
    "            # in this case pos_dim == word_dim\n",
    "            return word_embedding + head_embedding + tail_embedding\n",
    "        else:\n",
    "            raise Exception('dim_strategy must choose from [sum, cat]')"
   ]
  },
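  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small shape check of the embedding layer (illustration only): random indices stand in for real token and position ids, and `cfg.vocab_size` is set to a toy value here; the real value is assigned later from the built vocabulary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration only: check the output shape of the Embedding layer with random indices.\n",
    "cfg.vocab_size = 30  # toy value; overwritten later with the real vocabulary size\n",
    "toy_embed = Embedding(cfg)\n",
    "toy_word = torch.randint(0, cfg.vocab_size, (2, 7))         # [batch, seq_len]\n",
    "toy_head = torch.randint(1, cfg.pos_limit * 2 + 2, (2, 7))\n",
    "toy_tail = torch.randint(1, cfg.pos_limit * 2 + 2, (2, 7))\n",
    "print(toy_embed(toy_word, toy_head, toy_tail).shape)        # [2, 7, word_dim] under the 'sum' strategy"
   ]
  },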
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GELU activation, as used in Transformers; it often works better than ReLU\n",
    "class GELU(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(GELU, self).__init__()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CNN model\n",
    "class CNN(nn.Module):\n",
    "    \"\"\"\n",
    "    In NLP, to keep output length == input length we usually use odd kernel sizes such as [3, 5, 7, 9],\n",
    "    in which case padding = k // 2.\n",
    "    The stride is usually 1.\n",
    "    \"\"\"\n",
    "    def __init__(self, config):\n",
    "        \"\"\"\n",
    "        in_channels      : usually the word embedding dimension, or the hidden size\n",
    "        out_channels     : int\n",
    "        kernel_sizes     : list; must be odd (3, 5, 7, ...) so that output length == input length\n",
    "        activation       : [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]\n",
    "        pooling_strategy : [max, avg, cls]\n",
    "        dropout          : float\n",
    "        \"\"\"\n",
    "        super(CNN, self).__init__()\n",
    "\n",
    "        # self.xxx = config.xxx\n",
    "        # self.in_channels = config.in_channels\n",
    "        if config.dim_strategy == 'cat':\n",
    "            self.in_channels = config.word_dim + 2 * config.pos_dim\n",
    "        else:\n",
    "            self.in_channels = config.word_dim\n",
    "\n",
    "        self.out_channels = config.out_channels\n",
    "        self.kernel_sizes = config.kernel_sizes\n",
    "        self.activation = config.activation\n",
    "        self.pooling_strategy = config.pooling_strategy\n",
    "        self.dropout = config.dropout\n",
    "        for kernel_size in self.kernel_sizes:\n",
    "            assert kernel_size % 2 == 1, \"kernel size has to be an odd number.\"\n",
    "\n",
    "        # convolution\n",
    "        self.convs = nn.ModuleList([\n",
    "            nn.Conv1d(in_channels=self.in_channels,\n",
    "                      out_channels=self.out_channels,\n",
    "                      kernel_size=k,\n",
    "                      stride=1,\n",
    "                      padding=k // 2,\n",
    "                      dilation=1,\n",
    "                      groups=1,\n",
    "                      bias=False) for k in self.kernel_sizes\n",
    "        ])\n",
    "\n",
    "        # activation function\n",
    "        assert self.activation in ['relu', 'lrelu', 'prelu', 'selu', 'celu', 'gelu', 'sigmoid', 'tanh'], \\\n",
    "            'activation function must choose from [relu, lrelu, prelu, selu, celu, gelu, sigmoid, tanh]'\n",
    "        self.activations = nn.ModuleDict([\n",
    "            ['relu', nn.ReLU()],\n",
    "            ['lrelu', nn.LeakyReLU()],\n",
    "            ['prelu', nn.PReLU()],\n",
    "            ['selu', nn.SELU()],\n",
    "            ['celu', nn.CELU()],\n",
    "            ['gelu', GELU()],\n",
    "            ['sigmoid', nn.Sigmoid()],\n",
    "            ['tanh', nn.Tanh()],\n",
    "        ])\n",
    "\n",
    "        # pooling\n",
    "        assert self.pooling_strategy in ['max', 'avg', 'cls'], 'pooling strategy must choose from [max, avg, cls]'\n",
    "\n",
    "        self.dropout = nn.Dropout(self.dropout)\n",
    "\n",
    "    def forward(self, x, mask=None):\n",
    "        \"\"\"\n",
    "        :param x: torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H], usually the output of the embedding layer\n",
    "        :param mask: [batch_size, max_len], 0 for real tokens and 1 for padding. It does not affect the convolution,\n",
    "            but it guarantees that max pooling never picks a padded position.\n",
    "        :return:\n",
    "        \"\"\"\n",
    "        # [B, L, H] -> [B, H, L] (treat the hidden dimension H as the input channels)\n",
    "        x = torch.transpose(x, 1, 2)\n",
    "\n",
    "        # convolution + activation  [[B, H, L], ... ]\n",
    "        act_fn = self.activations[self.activation]\n",
    "\n",
    "        x = [act_fn(conv(x)) for conv in self.convs]\n",
    "        x = torch.cat(x, dim=1)\n",
    "\n",
    "        # mask\n",
    "        if mask is not None:\n",
    "            # [B, L] -> [B, 1, L]\n",
    "            mask = mask.unsqueeze(1)\n",
    "            x = x.masked_fill_(mask, 1e-12)\n",
    "\n",
    "        # pooling\n",
    "        # [[B, H, L], ... ] -> [[B, H], ... ]\n",
    "        if self.pooling_strategy == 'max':\n",
    "            xp = F.max_pool1d(x, kernel_size=x.size(2)).squeeze(2)\n",
    "            # equivalent to xp = torch.max(x, dim=2)[0]\n",
    "\n",
    "        elif self.pooling_strategy == 'avg':\n",
    "            x_len = mask.squeeze().eq(0).sum(-1).unsqueeze(-1).to(torch.float).to(device=mask.device)\n",
    "            xp = torch.sum(x, dim=-1) / x_len\n",
    "\n",
    "        else:\n",
    "            # self.pooling_strategy == 'cls'\n",
    "            xp = x[:, :, 0]\n",
    "\n",
    "        x = x.transpose(1, 2)\n",
    "        x = self.dropout(x)\n",
    "        xp = self.dropout(xp)\n",
    "\n",
    "        return x, xp  # [B, L, Hs], [B, Hs]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PCNN model\n",
    "class PCNN(nn.Module):\n",
    "    def __init__(self, cfg):\n",
    "        super(PCNN, self).__init__()\n",
    "\n",
    "        self.use_pcnn = cfg.use_pcnn\n",
    "\n",
    "        self.embedding = Embedding(cfg)\n",
    "        self.cnn = CNN(cfg)\n",
    "        self.fc1 = nn.Linear(len(cfg.kernel_sizes) * cfg.out_channels, cfg.intermediate)\n",
    "        self.fc2 = nn.Linear(cfg.intermediate, cfg.num_relations)\n",
    "        self.dropout = nn.Dropout(cfg.dropout)\n",
    "\n",
    "        if self.use_pcnn:\n",
    "            self.fc_pcnn = nn.Linear(3 * len(cfg.kernel_sizes) * cfg.out_channels,\n",
    "                                     len(cfg.kernel_sizes) * cfg.out_channels)\n",
    "            self.pcnn_mask_embedding = nn.Embedding(4, 3)\n",
    "            masks = torch.tensor([[0, 0, 0], [100, 0, 0], [0, 100, 0], [0, 0, 100]])\n",
    "            self.pcnn_mask_embedding.weight.data.copy_(masks)\n",
    "            self.pcnn_mask_embedding.weight.requires_grad = False\n",
    "\n",
    "    def forward(self, x):\n",
    "        word, lens, head_pos, tail_pos = x['word'], x['lens'], x['head_pos'], x['tail_pos']\n",
    "        mask = seq_len_to_mask(lens)\n",
    "\n",
    "        inputs = self.embedding(word, head_pos, tail_pos)\n",
    "        out, out_pool = self.cnn(inputs, mask=mask)\n",
    "\n",
    "        if self.use_pcnn:\n",
    "            out = out.unsqueeze(-1)  # [B, L, Hs, 1]\n",
    "            pcnn_mask = x['pcnn_mask']\n",
    "            pcnn_mask = self.pcnn_mask_embedding(pcnn_mask).unsqueeze(-2)  # [B, L, 1, 3]\n",
    "            out = out + pcnn_mask  # [B, L, Hs, 3]\n",
    "            out = out.max(dim=1)[0] - 100  # [B, Hs, 3]\n",
    "            out_pool = out.view(out.size(0), -1)  # [B, 3 * Hs]\n",
    "            out_pool = F.leaky_relu(self.fc_pcnn(out_pool))  # [B, Hs]\n",
    "            out_pool = self.dropout(out_pool)\n",
    "\n",
    "        output = self.fc1(out_pool)\n",
    "        output = F.leaky_relu(output)\n",
    "        output = self.dropout(output)\n",
    "        output = self.fc2(output)\n",
    "\n",
    "        return output"
   ]
  },
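  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `+100` trick in `PCNN.forward` may look opaque: each position gets a large constant added in the column of its own segment, so taking the max over the length dimension and then subtracting the constant again yields one maximum per segment. The toy check below (illustration only; a single feature channel and hand-made segment labels) reproduces the idea."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustration only: piecewise pooling via the mask-embedding trick used in PCNN.forward.\n",
    "feat = torch.tensor([0., 1., 2., 3., 4., 5.])  # one channel over 6 positions\n",
    "seg = torch.tensor([1, 1, 2, 2, 3, 3])         # segment label per position (1/2/3, as in entities_pos)\n",
    "mask_table = torch.tensor([[0., 0., 0.], [100., 0., 0.], [0., 100., 0.], [0., 0., 100.]])\n",
    "masked = feat.unsqueeze(-1) + mask_table[seg]  # [L, 3]: +100 only in the column of a position's own segment\n",
    "print(masked.max(dim=0)[0] - 100)              # per-segment maxima: tensor([1., 3., 5.])"
   ]
  },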
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Precision / recall / F1 metrics\n",
    "class PRMetric():\n",
    "    def __init__(self):\n",
    "        \"\"\"\n",
    "        For now we simply delegate to sklearn\n",
    "        \"\"\"\n",
    "        self.y_true = np.empty(0)\n",
    "        self.y_pred = np.empty(0)\n",
    "\n",
    "    def reset(self):\n",
    "        self.y_true = np.empty(0)\n",
    "        self.y_pred = np.empty(0)\n",
    "\n",
    "    def update(self, y_true: torch.Tensor, y_pred: torch.Tensor):\n",
    "        y_true = y_true.cpu().detach().numpy()\n",
    "        y_pred = y_pred.cpu().detach().numpy()\n",
    "        y_pred = np.argmax(y_pred, axis=-1)\n",
    "\n",
    "        self.y_true = np.append(self.y_true, y_true)\n",
    "        self.y_pred = np.append(self.y_pred, y_pred)\n",
    "\n",
    "    def compute(self):\n",
    "        p, r, f1, _ = precision_recall_fscore_support(self.y_true, self.y_pred, average='macro', warn_for=tuple())\n",
    "        _, _, acc, _ = precision_recall_fscore_support(self.y_true, self.y_pred, average='micro', warn_for=tuple())\n",
    "\n",
    "        return acc, p, r, f1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One training epoch\n",
    "def train(epoch, model, dataloader, optimizer, criterion, cfg):\n",
    "    model.train()\n",
    "\n",
    "    metric = PRMetric()\n",
    "    losses = []\n",
    "\n",
    "    for batch_idx, (x, y) in enumerate(dataloader, 1):\n",
    "        optimizer.zero_grad()\n",
    "        y_pred = model(x)\n",
    "        loss = criterion(y_pred, y)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        metric.update(y_true=y, y_pred=y_pred)\n",
    "        losses.append(loss.item())\n",
    "\n",
    "        data_total = len(dataloader.dataset)\n",
    "        data_cal = data_total if batch_idx == len(dataloader) else batch_idx * len(y)\n",
    "        if (cfg.train_log and batch_idx % cfg.log_interval == 0) or batch_idx == len(dataloader):\n",
    "            # p, r, f1 are macro-averaged; with micro averaging all three are equal, so that value is reported as acc\n",
    "            acc, p, r, f1 = metric.compute()\n",
    "            print(f'Train Epoch {epoch}: [{data_cal}/{data_total} ({100. * data_cal / data_total:.0f}%)]\\t'\n",
    "                  f'Loss: {loss.item():.6f}')\n",
    "            print(f'Train Epoch {epoch}: Acc: {100. * acc:.2f}%\\t'\n",
    "                  f'macro metrics: [p: {p:.4f}, r: {r:.4f}, f1: {f1:.4f}]')\n",
    "\n",
    "    if cfg.show_plot and not cfg.only_comparison_plot:\n",
    "        if cfg.plot_utils == 'matplot':\n",
    "            plt.plot(losses)\n",
    "            plt.title(f'epoch {epoch} train loss')\n",
    "            plt.show()\n",
    "\n",
    "    return losses[-1]\n",
    "\n",
    "\n",
    "# Evaluation loop (used for validation and test)\n",
    "def validate(epoch, model, dataloader, criterion, verbose=True):\n",
    "    model.eval()\n",
    "\n",
    "    metric = PRMetric()\n",
    "    losses = []\n",
    "\n",
    "    for batch_idx, (x, y) in enumerate(dataloader, 1):\n",
    "        with torch.no_grad():\n",
    "            y_pred = model(x)\n",
    "            loss = criterion(y_pred, y)\n",
    "\n",
    "            metric.update(y_true=y, y_pred=y_pred)\n",
    "            losses.append(loss.item())\n",
    "\n",
    "    loss = sum(losses) / len(losses)\n",
    "    acc, p, r, f1 = metric.compute()\n",
    "    data_total = len(dataloader.dataset)\n",
    "    if verbose:\n",
    "        print(f'Valid Epoch {epoch}: [{data_total}/{data_total}](100%)\\t Loss: {loss:.6f}')\n",
    "        print(f'Valid Epoch {epoch}: Acc: {100. * acc:.2f}%\\tmacro metrics: [p: {p:.4f}, r: {r:.4f}, f1: {f1:.4f}]\\n\\n')\n",
    "\n",
    "    return f1, loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the preprocessed datasets\n",
    "train_dataset = CustomDataset(train_save_fp)\n",
    "valid_dataset = CustomDataset(valid_save_fp)\n",
    "test_dataset = CustomDataset(test_save_fp)\n",
    "\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n",
    "valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The vocabulary size is only known after loading the preprocessed data\n",
    "vocab = load_pkl(vocab_save_fp)\n",
    "vocab_size = vocab.count\n",
    "cfg.vocab_size = vocab_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Main entry: define the optimizer, the loss function, etc.\n",
    "# Then iterate over epochs.\n",
    "# The validation loss is used for early stopping: the point where it stops decreasing is where the model generalizes best.\n",
    "model = PCNN(cfg)\n",
    "print(model)\n",
    "\n",
    "optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)\n",
    "scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "best_f1, best_epoch = -1, 0\n",
    "es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1 = 1000, -1, 0, 0, 0, -1\n",
    "train_losses, valid_losses = [], []\n",
    "\n",
    "logger.info('=' * 10 + ' Start training ' + '=' * 10)\n",
    "for epoch in range(1, cfg.epoch + 1):\n",
    "    train_loss = train(epoch, model, train_dataloader, optimizer, criterion, cfg)\n",
    "    valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion)\n",
    "    scheduler.step(valid_loss)\n",
    "\n",
    "    train_losses.append(train_loss)\n",
    "    valid_losses.append(valid_loss)\n",
    "    if best_f1 < valid_f1:\n",
    "        best_f1 = valid_f1\n",
    "        best_epoch = epoch\n",
    "    # use the validation loss as the early-stopping criterion\n",
    "    if es_loss > valid_loss:\n",
    "        es_loss = valid_loss\n",
    "        es_f1 = valid_f1\n",
    "        best_es_f1 = valid_f1\n",
    "        es_epoch = epoch\n",
    "        best_es_epoch = epoch\n",
    "        es_patience = 0\n",
    "    else:\n",
    "        es_patience += 1\n",
    "        if es_patience >= cfg.early_stopping_patience:\n",
    "            best_es_epoch = es_epoch\n",
    "            best_es_f1 = es_f1\n",
    "\n",
    "if cfg.show_plot:\n",
    "    if cfg.plot_utils == 'matplot':\n",
    "        plt.plot(train_losses, 'x-')\n",
    "        plt.plot(valid_losses, '+-')\n",
    "        plt.legend(['train', 'valid'])\n",
    "        plt.title('train/valid comparison loss')\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "print(f'best (valid loss criterion) early stopping epoch: {best_es_epoch}, '\n",
    "      f'its macro f1: {best_es_f1:0.4f}')\n",
    "print(f'total {cfg.epoch} epochs, best (valid macro f1) epoch: {best_epoch}, '\n",
    "      f'its macro f1: {best_f1:.4f}')\n",
    "\n",
    "test_f1, _ = validate(0, model, test_dataloader, criterion, verbose=False)\n",
    "print(f'after {cfg.epoch} epochs, final test data macro f1: {test_f1:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This demo does not cover hyperparameter tuning; if you are interested, head over to the [deepke](http://openkg.cn/tool/deepke) repository to download and try more models :)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}