fix bug
This commit is contained in:
parent
7d2cb7f4cf
commit
6015c7d48a
|
@ -2,6 +2,7 @@
|
|||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## attribution extraction experiment\n",
|
||||
"> Tutorial author: 陶联宽(22051063@zju.edu.cn)\n",
|
||||
|
@ -24,45 +25,46 @@
|
|||
"- valid.csv: It contains 2 training triples,each lines represent one triple,sorted by sentence,attribute,entity,entity's offset,attribute value attribute value's offset,and separated by ,.\n",
|
||||
"- test.csv: It contains 2 training triples,each lines represent one triple,sorted by sentence,attribute,entity,entity's offset,attribute value attribute value's offset,and separated by ,.\n",
|
||||
"- attribute.csv: It contains 3 attribute triples,each lines sorted by attribute,index and separated by ,."
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### BERT \n",
|
||||
"\n",
|
||||
"![BERT](img/Bert.png)\n",
|
||||
"\n",
|
||||
"After Bert coding, the original sentence can get rich semantic information. The obtained results are input into the bidirectional LSTM, and the output results can obtain the relationship information of the sentence.\n"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Code experience\n",
|
||||
"\n",
|
||||
"Important tips:\n",
|
||||
"- When we use pretrain language model, we need to load about 500MB of model data,so it is more recommended to download to the local and run.At this time, you only need to add `lm_file` value to the address of the local folder. See the link [transformers](https://huggingface.co/transformers/)"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Run the neural network with pytorch and confirm whether it is installed before running\n",
|
||||
"!pip install torch\n",
|
||||
"!pip install matplotlib\n",
|
||||
"!pip install transformers"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# import the whole modules\n",
|
||||
"import os\n",
|
||||
|
@ -79,17 +81,17 @@
|
|||
"from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
|
||||
"from torch.utils.data import Dataset,DataLoader\n",
|
||||
"from sklearn.metrics import precision_recall_fscore_support\n",
|
||||
"from typing import List, Tuple, Dict, Any, Sequence, Optional, Union\n",
|
||||
"from typing import List, Tuple, Dict, Any, Sequence, Optional, Union, OrderedDict\n",
|
||||
"from transformers import BertTokenizer, BertModel\n",
|
||||
"\n",
|
||||
"logger = logging.getLogger(__name__)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Configuration file of model parameters\n",
|
||||
"class Config(object):\n",
|
||||
|
@ -126,13 +128,13 @@
|
|||
" rnn_layers = 2\n",
|
||||
" \n",
|
||||
"cfg = Config()"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Functions required for preprocessing\n",
|
||||
"Path = str\n",
|
||||
|
@ -210,13 +212,13 @@
|
|||
" sent += '[SEP]' + d['entity'] + '[SEP]' + d['attribute_value']\n",
|
||||
" d['token2idx'] = tokenizer.encode(sent, add_special_tokens=True)\n",
|
||||
" d['seq_len'] = len(d['token2idx'])"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Preprocess\n",
|
||||
"logger.info('load raw files...')\n",
|
||||
|
@ -237,7 +239,7 @@
|
|||
"for d in test_data:\n",
|
||||
" d['tokens'] = eval(d['tokens'])\n",
|
||||
" \n",
|
||||
"llogger.info('convert attribution into index...')\n",
|
||||
"logger.info('convert attribution into index...')\n",
|
||||
"atts = _handle_attribute_data(attribute_data)\n",
|
||||
"_add_attribute_data(atts,train_data)\n",
|
||||
"_add_attribute_data(atts,test_data)\n",
|
||||
|
@ -258,13 +260,13 @@
|
|||
"save_pkl(train_data, train_save_fp)\n",
|
||||
"save_pkl(valid_data, valid_save_fp)\n",
|
||||
"save_pkl(test_data, test_save_fp)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# pytorch construct Dataset\n",
|
||||
"def collate_fn(cfg):\n",
|
||||
|
@ -323,13 +325,13 @@
|
|||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.file)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# pretrain language model\n",
|
||||
"class PretrainLM(nn.Module):\n",
|
||||
|
@ -352,13 +354,13 @@
|
|||
" output = self.fc(output)\n",
|
||||
" \n",
|
||||
" return output"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# p,r,f1 measurement\n",
|
||||
"class PRMetric():\n",
|
||||
|
@ -384,13 +386,13 @@
|
|||
" _, _, acc, _ = precision_recall_fscore_support(self.y_true,self.y_pred,average='micro',warn_for=tuple())\n",
|
||||
"\n",
|
||||
" return acc,p,r,f1"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Iteration in training process\n",
|
||||
"def train(epoch, model, dataloader, optimizer, criterion, cfg):\n",
|
||||
|
@ -451,13 +453,13 @@
|
|||
" print(f'Valid Epoch {epoch}: Acc: {100. * acc:.2f}%\\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]\\n\\n')\n",
|
||||
"\n",
|
||||
" return f1,loss"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load dataset\n",
|
||||
"train_dataset = CustomDataset(train_save_fp)\n",
|
||||
|
@ -467,13 +469,13 @@
|
|||
"train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n",
|
||||
"valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n",
|
||||
"test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# main entry, define optimization function, loss function and so on\n",
|
||||
"# start epoch\n",
|
||||
|
@ -529,22 +531,22 @@
|
|||
"\n",
|
||||
"test_f1, _ = validate(0, model, test_dataloader, criterion,verbose=False)\n",
|
||||
"print(f'after {cfg.epoch} epochs, final test data macro f1: {test_f1:.4f}')"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This demo does not include parameter adjustment. Interested students can go to [deepke] by themselves( http://openkg.cn/tool/deepke )Warehouse, download and use more models:)"
|
||||
],
|
||||
"metadata": {}
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"interpreter": {
|
||||
"hash": "07ee17aed077b353900b50ce6f0ef17f1492499c86f09df07de696a5c0b76ad4"
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"display_name": "Python 3.8.11 64-bit ('deepke': conda)",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
|
@ -557,9 +559,9 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.3"
|
||||
"version": "3.8.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,771 +0,0 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## attribution extraction experiment\n",
|
||||
"> Tutorial author: 陶联宽(22051063@zju.edu.cn)\n",
|
||||
"\n",
|
||||
"On this demo, we use `lstm` to extract attributions.\n",
|
||||
"We hope this demo can help you understand the process of construction knowledge graph and the principles and common methods of triplet extraction.\n",
|
||||
"\n",
|
||||
"This demo uses `Python3`.\n",
|
||||
"\n",
|
||||
"### Dataset\n",
|
||||
"In this example,we get some Chinese text to extract the triples\n",
|
||||
"\n",
|
||||
"sentence|attribute|entity|entity_offset|attribute_value|attribute_value_offset\n",
|
||||
":---:|:---:|:---:|:---:|:---:|:---:\n",
|
||||
"苏轼(1037~1101年),字子瞻,又字和仲,号“东坡居士”,眉州眉山(即今四川眉州)人,是宋代(北宋)著名的文学家、书画家|字|苏轼|0|和仲|21\n",
|
||||
"阳成俊,男,汉族,贵州省委党校大学学历|民族|阳成俊|0|汉族|6\n",
|
||||
"司马懿,字仲达,河南温县人|字|司马懿|0|仲达|6\n",
|
||||
"\n",
|
||||
"- train.csv: It contains 6 training triples,each lines represent one triple,sorted by sentence,attribute,entity,entity's offset,attribute value attribute value's offset,and separated by ,.\n",
|
||||
"- valid.csv: It contains 2 training triples,each lines represent one triple,sorted by sentence,attribute,entity,entity's offset,attribute value attribute value's offset,and separated by ,.\n",
|
||||
"- test.csv: It contains 2 training triples,each lines represent one triple,sorted by sentence,attribute,entity,entity's offset,attribute value attribute value's offset,and separated by ,.\n",
|
||||
"- attribute.csv: It contains 3 attribute triples,each lines sorted by attribute,index and separated by ,."
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"### LSTM\n",
|
||||
"\n",
|
||||
"![LSTM](img/LSTM.jpg)\n",
|
||||
"\n",
|
||||
"The sentence information mainly includes wording embedding.After the rnn layer,according to the position of entity,attribute key,it through the full connection layer, the attribution information of the sentence can be obtained.\n"
|
||||
],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Run the neural network with pytorch and confirm whether it is installed before running\n",
|
||||
"!pip install torch\n",
|
||||
"!pip install matplotlib\n",
|
||||
"!pip install transformers"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# import the whole modules\n",
|
||||
"import os\n",
|
||||
"import csv\n",
|
||||
"import math\n",
|
||||
"import pickle\n",
|
||||
"import logging\n",
|
||||
"import torch\n",
|
||||
"import torch.nn as nn\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from torch import optim\n",
|
||||
"from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
|
||||
"from torch.utils.data import Dataset,DataLoader\n",
|
||||
"from sklearn.metrics import precision_recall_fscore_support\n",
|
||||
"from typing import List, Tuple, Dict, Any, Sequence, Optional, Union\n",
|
||||
"from transformers import BertTokenizer, BertModel\n",
|
||||
"\n",
|
||||
"logger = logging.getLogger(__name__)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Configuration file of model parameters\n",
|
||||
"# use_pcnn Parameter controls whether there is a piece_ Wise pooling\n",
|
||||
"\n",
|
||||
"class Config(object):\n",
|
||||
" model_name = 'rnn' # ['cnn', 'gcn', 'lm','rnn']\n",
|
||||
" use_pcnn = False\n",
|
||||
" min_freq = 1\n",
|
||||
" pos_limit = 20\n",
|
||||
" out_path = 'data/out' \n",
|
||||
" batch_size = 2 \n",
|
||||
" word_dim = 10\n",
|
||||
" pos_dim = 5\n",
|
||||
" dim_strategy = 'sum' # ['sum', 'cat']\n",
|
||||
" out_channels = 20\n",
|
||||
" intermediate = 10\n",
|
||||
" kernel_sizes = [3, 5, 7]\n",
|
||||
" activation = 'gelu'\n",
|
||||
" pooling_strategy = 'max'\n",
|
||||
" dropout = 0.3\n",
|
||||
" epoch = 10\n",
|
||||
" num_relations = 4\n",
|
||||
" learning_rate = 3e-4\n",
|
||||
" lr_factor = 0.7 # 学习率的衰减率\n",
|
||||
" lr_patience = 3 # 学习率衰减的等待epoch\n",
|
||||
" weight_decay = 1e-3 # L2正则\n",
|
||||
" early_stopping_patience = 6\n",
|
||||
" train_log = True\n",
|
||||
" log_interval = 1\n",
|
||||
" show_plot = True\n",
|
||||
" only_comparison_plot = False\n",
|
||||
" plot_utils = 'matplot'\n",
|
||||
" lm_file = 'bert-base-chinese'\n",
|
||||
" lm_num_hidden_layers = 2\n",
|
||||
" rnn_layers = 2\n",
|
||||
" \n",
|
||||
"cfg = Config()"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Word token builds a one hot dictionary, and then inputs it to the embedding layer to obtain the corresponding word information matrix\n",
|
||||
"# 0 is pad by default and 1 is unknown\n",
|
||||
"class Vocab(object):\n",
|
||||
" def __init__(self, name: str = 'basic', init_tokens = [\"[PAD]\", \"[UNK]\"]):\n",
|
||||
" self.name = name\n",
|
||||
" self.init_tokens = init_tokens\n",
|
||||
" self.trimed = False\n",
|
||||
" self.word2idx = {}\n",
|
||||
" self.word2count = {}\n",
|
||||
" self.idx2word = {}\n",
|
||||
" self.count = 0\n",
|
||||
" self._add_init_tokens()\n",
|
||||
"\n",
|
||||
" def _add_init_tokens(self):\n",
|
||||
" for token in self.init_tokens:\n",
|
||||
" self._add_word(token)\n",
|
||||
"\n",
|
||||
" def _add_word(self, word: str):\n",
|
||||
" if word not in self.word2idx:\n",
|
||||
" self.word2idx[word] = self.count\n",
|
||||
" self.word2count[word] = 1\n",
|
||||
" self.idx2word[self.count] = word\n",
|
||||
" self.count += 1\n",
|
||||
" else:\n",
|
||||
" self.word2count[word] += 1\n",
|
||||
"\n",
|
||||
" def add_words(self, words: Sequence):\n",
|
||||
" for word in words:\n",
|
||||
" self._add_word(word)\n",
|
||||
"\n",
|
||||
" def trim(self, min_freq=2, verbose: Optional[bool] = True):\n",
|
||||
" assert min_freq == int(min_freq), f'min_freq must be integer, can\\'t be {min_freq}'\n",
|
||||
" min_freq = int(min_freq)\n",
|
||||
" if min_freq < 2:\n",
|
||||
" return\n",
|
||||
" if self.trimed:\n",
|
||||
" return\n",
|
||||
" self.trimed = True\n",
|
||||
"\n",
|
||||
" keep_words = []\n",
|
||||
" new_words = []\n",
|
||||
"\n",
|
||||
" for k, v in self.word2count.items():\n",
|
||||
" if v >= min_freq:\n",
|
||||
" keep_words.append(k)\n",
|
||||
" new_words.extend([k] * v)\n",
|
||||
" if verbose:\n",
|
||||
" before_len = len(keep_words)\n",
|
||||
" after_len = len(self.word2idx) - len(self.init_tokens)\n",
|
||||
" logger.info('vocab after be trimmed, keep words [{} / {}] = {:.2f}%'.format(before_len, after_len, before_len / after_len * 100))\n",
|
||||
"\n",
|
||||
" # Reinitialize dictionaries\n",
|
||||
" self.word2idx = {}\n",
|
||||
" self.word2count = {}\n",
|
||||
" self.idx2word = {}\n",
|
||||
" self.count = 0\n",
|
||||
" self._add_init_tokens()\n",
|
||||
" self.add_words(new_words)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Functions required for preprocessing\n",
|
||||
"Path = str\n",
|
||||
"\n",
|
||||
"def load_csv(fp: Path, is_tsv: bool = False, verbose: bool = True) -> List:\n",
|
||||
" if verbose:\n",
|
||||
" logger.info(f'load csv from {fp}')\n",
|
||||
"\n",
|
||||
" dialect = 'excel-tab' if is_tsv else 'excel'\n",
|
||||
" with open(fp, encoding='utf-8') as f:\n",
|
||||
" reader = csv.DictReader(f, dialect=dialect)\n",
|
||||
" return list(reader)\n",
|
||||
"\n",
|
||||
" \n",
|
||||
"def load_pkl(fp: Path, verbose: bool = True) -> Any:\n",
|
||||
" if verbose:\n",
|
||||
" logger.info(f'load data from {fp}')\n",
|
||||
"\n",
|
||||
" with open(fp, 'rb') as f:\n",
|
||||
" data = pickle.load(f)\n",
|
||||
" return data\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def save_pkl(data: Any, fp: Path, verbose: bool = True) -> None:\n",
|
||||
" if verbose:\n",
|
||||
" logger.info(f'save data in {fp}')\n",
|
||||
"\n",
|
||||
" with open(fp, 'wb') as f:\n",
|
||||
" pickle.dump(data, f)\n",
|
||||
" \n",
|
||||
"def _handle_attribute_data(attribute_data: List[Dict]) -> Dict:\n",
|
||||
" atts = OrderedDict()\n",
|
||||
" attribute_data = sorted(attribute_data, key=lambda i: int(i['index']))\n",
|
||||
" for d in attribute_data:\n",
|
||||
" atts[d['attribute']] = {\n",
|
||||
" 'index': int(d['index'])\n",
|
||||
" }\n",
|
||||
" return atts\n",
|
||||
"\n",
|
||||
"def _add_attribute_data(atts: Dict, data: List) -> None:\n",
|
||||
" for d in data:\n",
|
||||
" d['att2idx'] = atts[d['attribute']]['index']\n",
|
||||
"\n",
|
||||
"def _convert_tokens_into_index(data: List[Dict], vocab):\n",
|
||||
" unk_str = '[UNK]'\n",
|
||||
" unk_idx = vocab.word2idx[unk_str]\n",
|
||||
"\n",
|
||||
" for d in data:\n",
|
||||
" d['token2idx'] = [vocab.word2idx.get(i, unk_idx) for i in d['tokens']]\n",
|
||||
"\n",
|
||||
"def _add_pos_seq(train_data: List[Dict], cfg):\n",
|
||||
" for d in train_data:\n",
|
||||
" d['entity_pos'] = list(map(lambda i: i - d['entity_index'], list(range(d['seq_len']))))\n",
|
||||
" d['entity_pos'] = _handle_pos_limit(d['entity_pos'],int(cfg.pos_limit))\n",
|
||||
"\n",
|
||||
" d['attribute_value_pos'] = list(map(lambda i: i - d['attribute_value_index'], list(range(d['seq_len']))))\n",
|
||||
" d['attribute_value_pos'] = _handle_pos_limit(d['attribute_value_pos'],int(cfg.pos_limit))\n",
|
||||
" \n",
|
||||
"def _handle_pos_limit(pos: List[int], limit: int) -> List[int]:\n",
|
||||
" for i,p in enumerate(pos):\n",
|
||||
" if p > limit:\n",
|
||||
" pos[i] = limit\n",
|
||||
" if p < -limit:\n",
|
||||
" pos[i] = -limit\n",
|
||||
" return [p + limit + 1 for p in pos]\n",
|
||||
"\n",
|
||||
"def seq_len_to_mask(seq_len: Union[List, np.ndarray, torch.Tensor], max_len=None, mask_pos_to_true=True):\n",
|
||||
" if isinstance(seq_len, list):\n",
|
||||
" seq_len = np.array(seq_len)\n",
|
||||
"\n",
|
||||
" if isinstance(seq_len, np.ndarray):\n",
|
||||
" seq_len = torch.from_numpy(seq_len)\n",
|
||||
"\n",
|
||||
" if isinstance(seq_len, torch.Tensor):\n",
|
||||
" assert seq_len.dim() == 1, logger.error(f\"seq_len can only have one dimension, got {seq_len.dim()} != 1.\")\n",
|
||||
" batch_size = seq_len.size(0)\n",
|
||||
" max_len = int(max_len) if max_len else seq_len.max().long()\n",
|
||||
" broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len.device)\n",
|
||||
" if mask_pos_to_true:\n",
|
||||
" mask = broad_cast_seq_len.ge(seq_len.unsqueeze(1))\n",
|
||||
" else:\n",
|
||||
" mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))\n",
|
||||
" else:\n",
|
||||
" raise logger.error(\"Only support 1-d list or 1-d numpy.ndarray or 1-d torch.Tensor.\")\n",
|
||||
"\n",
|
||||
" return mask"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Preprocess\n",
|
||||
"logger.info('load raw files...')\n",
|
||||
"train_fp = os.path.join('data/train.csv')\n",
|
||||
"valid_fp = os.path.join('data/valid.csv')\n",
|
||||
"test_fp = os.path.join('data/test.csv')\n",
|
||||
"attribute_fp = os.path.join('data/attribute.csv')\n",
|
||||
"\n",
|
||||
"train_data = load_csv(train_fp)\n",
|
||||
"valid_data = load_csv(valid_fp)\n",
|
||||
"test_data = load_csv(test_fp)\n",
|
||||
"attribute_data = load_csv(attribute_fp)\n",
|
||||
"\n",
|
||||
"for d in train_data:\n",
|
||||
" d['tokens'] = eval(d['tokens'])\n",
|
||||
"for d in valid_data:\n",
|
||||
" d['tokens'] = eval(d['tokens'])\n",
|
||||
"for d in test_data:\n",
|
||||
" d['tokens'] = eval(d['tokens'])\n",
|
||||
"\n",
|
||||
"logger.info('convert attribution into index...')\n",
|
||||
"atts = _handle_attribute_data(attribute_data)\n",
|
||||
"_add_attribute_data(atts,train_data)\n",
|
||||
"_add_attribute_data(atts,test_data)\n",
|
||||
"_add_attribute_data(atts,valid_data)\n",
|
||||
"\n",
|
||||
"logger.info('build vocabulary...')\n",
|
||||
"vocab = Vocab('word')\n",
|
||||
"train_tokens = [d['tokens'] for d in train_data]\n",
|
||||
"valid_tokens = [d['tokens'] for d in valid_data]\n",
|
||||
"test_tokens = [d['tokens'] for d in test_data]\n",
|
||||
"sent_tokens = [*train_tokens, *valid_tokens, *test_tokens]\n",
|
||||
"for sent in sent_tokens:\n",
|
||||
" vocab.add_words(sent)\n",
|
||||
"vocab.trim(min_freq=cfg.min_freq)\n",
|
||||
"\n",
|
||||
"logger.info('convert tokens into index...')\n",
|
||||
"_convert_tokens_into_index(train_data, vocab)\n",
|
||||
"_convert_tokens_into_index(valid_data, vocab)\n",
|
||||
"_convert_tokens_into_index(test_data, vocab)\n",
|
||||
"\n",
|
||||
"logger.info('build position sequence...')\n",
|
||||
"_add_pos_seq(train_data, cfg)\n",
|
||||
"_add_pos_seq(valid_data, cfg)\n",
|
||||
"_add_pos_seq(test_data, cfg)\n",
|
||||
"\n",
|
||||
"logger.info('save data for backup...')\n",
|
||||
"os.makedirs(cfg.out_path, exist_ok=True)\n",
|
||||
"train_save_fp = os.path.join(cfg.out_path, 'train.pkl')\n",
|
||||
"valid_save_fp = os.path.join(cfg.out_path, 'valid.pkl')\n",
|
||||
"test_save_fp = os.path.join(cfg.out_path, 'test.pkl')\n",
|
||||
"save_pkl(train_data, train_save_fp)\n",
|
||||
"save_pkl(valid_data, valid_save_fp)\n",
|
||||
"save_pkl(test_data, test_save_fp)\n",
|
||||
"\n",
|
||||
"vocab_save_fp = os.path.join(cfg.out_path, 'vocab.pkl')\n",
|
||||
"vocab_txt = os.path.join(cfg.out_path, 'vocab.txt')\n",
|
||||
"save_pkl(vocab, vocab_save_fp)\n",
|
||||
"logger.info('save vocab in txt file, for watching...')\n",
|
||||
"with open(vocab_txt, 'w', encoding='utf-8') as f:\n",
|
||||
" f.write(os.linesep.join(vocab.word2idx.keys()))"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# pytorch construct Dataset\n",
|
||||
"def collate_fn(cfg):\n",
|
||||
" def collate_fn_intra(batch):\n",
|
||||
" batch.sort(key=lambda data: data['seq_len'], reverse=True)\n",
|
||||
" max_len = batch[0]['seq_len']\n",
|
||||
"\n",
|
||||
" def _padding(x, max_len):\n",
|
||||
" return x + [0] * (max_len - len(x))\n",
|
||||
"\n",
|
||||
" x, y = dict(), []\n",
|
||||
" word, word_len = [], []\n",
|
||||
" head_pos, tail_pos = [], []\n",
|
||||
" pcnn_mask = []\n",
|
||||
" for data in batch:\n",
|
||||
" word.append(_padding(data['token2idx'], max_len))\n",
|
||||
" word_len.append(data['seq_len'])\n",
|
||||
" y.append(int(data['att2idx']))\n",
|
||||
"\n",
|
||||
" if cfg.model_name != 'lm':\n",
|
||||
" head_pos.append(_padding(data['entity_pos'], max_len))\n",
|
||||
" tail_pos.append(_padding(data['attribute_value_pos'], max_len))\n",
|
||||
" if cfg.model_name == 'cnn':\n",
|
||||
" if cfg.use_pcnn:\n",
|
||||
" pcnn_mask.append(_padding(data['entities_pos'], max_len))\n",
|
||||
"\n",
|
||||
" x['word'] = torch.tensor(word)\n",
|
||||
" x['lens'] = torch.tensor(word_len)\n",
|
||||
" y = torch.tensor(y)\n",
|
||||
"\n",
|
||||
" if cfg.model_name != 'lm':\n",
|
||||
" x['entity_pos'] = torch.tensor(head_pos)\n",
|
||||
" x['attribute_value_pos'] = torch.tensor(tail_pos)\n",
|
||||
" if cfg.model_name == 'cnn' and cfg.use_pcnn:\n",
|
||||
" x['pcnn_mask'] = torch.tensor(pcnn_mask)\n",
|
||||
" if cfg.model_name == 'gcn':\n",
|
||||
" # 没找到合适的做 parsing tree 的工具,暂时随机初始化\n",
|
||||
" B, L = len(batch), max_len\n",
|
||||
" adj = torch.empty(B, L, L).random_(2)\n",
|
||||
" x['adj'] = adj\n",
|
||||
" return x, y\n",
|
||||
"\n",
|
||||
" return collate_fn_intra\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class CustomDataset(Dataset):\n",
|
||||
" \"\"\"\n",
|
||||
" 默认使用 List 存储数据\n",
|
||||
" \"\"\"\n",
|
||||
" def __init__(self, fp):\n",
|
||||
" self.file = load_pkl(fp)\n",
|
||||
"\n",
|
||||
" def __getitem__(self, item):\n",
|
||||
" sample = self.file[item]\n",
|
||||
" return sample\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.file)"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# embedding layer\n",
|
||||
"class Embedding(nn.Module):\n",
|
||||
" def __init__(self, config):\n",
|
||||
" \"\"\"\n",
|
||||
" word embedding: 一般 0 为 padding\n",
|
||||
" pos embedding: 一般 0 为 padding\n",
|
||||
" dim_strategy: [cat, sum] 多个 embedding 是拼接还是相加\n",
|
||||
" \"\"\"\n",
|
||||
" super(Embedding, self).__init__()\n",
|
||||
"\n",
|
||||
" # self.xxx = config.xxx\n",
|
||||
" self.vocab_size = config.vocab_size\n",
|
||||
" self.word_dim = config.word_dim\n",
|
||||
" self.pos_size = config.pos_size\n",
|
||||
" self.pos_dim = config.pos_dim if config.dim_strategy == 'cat' else config.word_dim\n",
|
||||
" self.dim_strategy = config.dim_strategy\n",
|
||||
"\n",
|
||||
" self.wordEmbed = nn.Embedding(self.vocab_size, self.word_dim, padding_idx=0)\n",
|
||||
" self.entityPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)\n",
|
||||
" self.attribute_keyPosEmbed = nn.Embedding(self.pos_size, self.pos_dim, padding_idx=0)\n",
|
||||
" \n",
|
||||
" self.layer_norm = nn.LayerNorm(self.word_dim)\n",
|
||||
"\n",
|
||||
" def forward(self, *x):\n",
|
||||
" word, entity, attribute_key = x\n",
|
||||
" word_embedding = self.wordEmbed(word)\n",
|
||||
" entity_embedding = self.entityPosEmbed(entity)\n",
|
||||
" attribute_key_embedding = self.attribute_keyPosEmbed(attribute_key)\n",
|
||||
"\n",
|
||||
" if self.dim_strategy == 'cat':\n",
|
||||
" return torch.cat((word_embedding, entity_embedding, attribute_key_embedding), -1)\n",
|
||||
" elif self.dim_strategy == 'sum':\n",
|
||||
" # 此时 pos_dim == word_dim\n",
|
||||
" return self.layer_norm(word_embedding + entity_embedding + attribute_key_embedding)\n",
|
||||
" else:\n",
|
||||
" raise Exception('dim_strategy must choose from [sum, cat]')"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Gelu activation function, specified by transformer, works better than relu\n",
|
||||
"class GELU(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(GELU, self).__init__()\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# cnn model\n",
|
||||
"class RNN(nn.Module):\n",
|
||||
" def __init__(self, config):\n",
|
||||
" \"\"\"\n",
|
||||
" type_rnn: RNN, GRU, LSTM 可选\n",
|
||||
" \"\"\"\n",
|
||||
" super(RNN, self).__init__()\n",
|
||||
"\n",
|
||||
" # self.xxx = config.xxx\n",
|
||||
" self.input_size = config.input_size\n",
|
||||
" self.hidden_size = config.hidden_size // 2 if config.bidirectional else config.hidden_size\n",
|
||||
" self.num_layers = config.num_layers\n",
|
||||
" self.dropout = config.dropout\n",
|
||||
" self.bidirectional = config.bidirectional\n",
|
||||
" self.last_layer_hn = config.last_layer_hn\n",
|
||||
" self.type_rnn = config.type_rnn\n",
|
||||
"\n",
|
||||
" rnn = eval(f'nn.{self.type_rnn}')\n",
|
||||
" self.rnn = rnn(input_size=self.input_size,\n",
|
||||
" hidden_size=self.hidden_size,\n",
|
||||
" num_layers=self.num_layers,\n",
|
||||
" dropout=self.dropout,\n",
|
||||
" bidirectional=self.bidirectional,\n",
|
||||
" bias=True,\n",
|
||||
" batch_first=True)\n",
|
||||
"\n",
|
||||
" # 有bug\n",
|
||||
" # self._init_weights()\n",
|
||||
"\n",
|
||||
" def _init_weights(self):\n",
|
||||
" \"\"\"orthogonal init yields generally good results than uniform init\"\"\"\n",
|
||||
" gain = 1 # use default value\n",
|
||||
" for nth in range(self.num_layers * self.bidirectional):\n",
|
||||
" # w_ih, (4 * hidden_size x input_size)\n",
|
||||
" nn.init.orthogonal_(self.rnn.all_weights[nth][0], gain=gain)\n",
|
||||
" # w_hh, (4 * hidden_size x hidden_size)\n",
|
||||
" nn.init.orthogonal_(self.rnn.all_weights[nth][1], gain=gain)\n",
|
||||
" # b_ih, (4 * hidden_size)\n",
|
||||
" nn.init.zeros_(self.rnn.all_weights[nth][2])\n",
|
||||
" # b_hh, (4 * hidden_size)\n",
|
||||
" nn.init.zeros_(self.rnn.all_weights[nth][3])\n",
|
||||
"\n",
|
||||
" def forward(self, x, x_len):\n",
|
||||
" \"\"\"\n",
|
||||
" Args: \n",
|
||||
" torch.Tensor [batch_size, seq_max_length, input_size], [B, L, H_in] 一般是经过embedding后的值\n",
|
||||
" x_len: torch.Tensor [L] 已经排好序的句长值\n",
|
||||
" Returns:\n",
|
||||
" output: torch.Tensor [B, L, H_out] 序列标注的使用结果\n",
|
||||
" hn: torch.Tensor [B, N, H_out] / [B, H_out] 分类的结果,当 last_layer_hn 时只有最后一层结果\n",
|
||||
" \"\"\"\n",
|
||||
" B, L, _ = x.size()\n",
|
||||
" H, N = self.hidden_size, self.num_layers\n",
|
||||
"\n",
|
||||
" x_len = x_len.cpu()\n",
|
||||
" x = pack_padded_sequence(x, x_len, batch_first=True, enforce_sorted=True)\n",
|
||||
" output, hn = self.rnn(x)\n",
|
||||
" output, _ = pad_packed_sequence(output, batch_first=True, total_length=L)\n",
|
||||
"\n",
|
||||
" if self.type_rnn == 'LSTM':\n",
|
||||
" hn = hn[0]\n",
|
||||
" if self.bidirectional:\n",
|
||||
" hn = hn.view(N, 2, B, H).transpose(1, 2).contiguous().view(N, B, 2 * H).transpose(0, 1)\n",
|
||||
" else:\n",
|
||||
" hn = hn.transpose(0, 1)\n",
|
||||
" if self.last_layer_hn:\n",
|
||||
" hn = hn[:, -1, :]\n",
|
||||
"\n",
|
||||
" return output, hn\n"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# p,r,f1 measurement\n",
|
||||
"class PRMetric():\n",
|
||||
" def __init__(self):\n",
|
||||
" \n",
|
||||
" self.y_true = np.empty(0)\n",
|
||||
" self.y_pred = np.empty(0)\n",
|
||||
"\n",
|
||||
" def reset(self):\n",
|
||||
" self.y_true = np.empty(0)\n",
|
||||
" self.y_pred = np.empty(0)\n",
|
||||
"\n",
|
||||
" def update(self, y_true:torch.Tensor, y_pred:torch.Tensor):\n",
|
||||
" y_true = y_true.cpu().detach().numpy()\n",
|
||||
" y_pred = y_pred.cpu().detach().numpy()\n",
|
||||
" y_pred = np.argmax(y_pred,axis=-1)\n",
|
||||
"\n",
|
||||
" self.y_true = np.append(self.y_true, y_true)\n",
|
||||
" self.y_pred = np.append(self.y_pred, y_pred)\n",
|
||||
"\n",
|
||||
" def compute(self):\n",
|
||||
" p, r, f1, _ = precision_recall_fscore_support(self.y_true,self.y_pred,average='macro',warn_for=tuple())\n",
|
||||
" _, _, acc, _ = precision_recall_fscore_support(self.y_true,self.y_pred,average='micro',warn_for=tuple())\n",
|
||||
"\n",
|
||||
" return acc,p,r,f1"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Iteration in training process\n",
|
||||
"def train(epoch, model, dataloader, optimizer, criterion, cfg):\n",
|
||||
" model.train()\n",
|
||||
"\n",
|
||||
" metric = PRMetric()\n",
|
||||
" losses = []\n",
|
||||
"\n",
|
||||
" for batch_idx, (x, y) in enumerate(dataloader, 1):\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" y_pred = model(x)\n",
|
||||
" loss = criterion(y_pred, y)\n",
|
||||
"\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
"\n",
|
||||
" metric.update(y_true=y, y_pred=y_pred)\n",
|
||||
" losses.append(loss.item())\n",
|
||||
"\n",
|
||||
" data_total = len(dataloader.dataset)\n",
|
||||
" data_cal = data_total if batch_idx == len(dataloader) else batch_idx * len(y)\n",
|
||||
" if (cfg.train_log and batch_idx % cfg.log_interval == 0) or batch_idx == len(dataloader):\n",
|
||||
" acc,p,r,f1 = metric.compute()\n",
|
||||
" print(f'Train Epoch {epoch}: [{data_cal}/{data_total} ({100. * data_cal / data_total:.0f}%)]\\t'\n",
|
||||
" f'Loss: {loss.item():.6f}')\n",
|
||||
" print(f'Train Epoch {epoch}: Acc: {100. * acc:.2f}%\\t'\n",
|
||||
" f'macro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]')\n",
|
||||
"\n",
|
||||
" if cfg.show_plot and not cfg.only_comparison_plot:\n",
|
||||
" if cfg.plot_utils == 'matplot':\n",
|
||||
" plt.plot(losses)\n",
|
||||
" plt.title(f'epoch {epoch} train loss')\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
" return losses[-1]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Iteration in testing process\n",
|
||||
"def validate(epoch, model, dataloader, criterion,verbose=True):\n",
|
||||
" model.eval()\n",
|
||||
"\n",
|
||||
" metric = PRMetric()\n",
|
||||
" losses = []\n",
|
||||
"\n",
|
||||
" for batch_idx, (x, y) in enumerate(dataloader, 1):\n",
|
||||
" with torch.no_grad():\n",
|
||||
" y_pred = model(x)\n",
|
||||
" loss = criterion(y_pred, y)\n",
|
||||
"\n",
|
||||
" metric.update(y_true=y, y_pred=y_pred)\n",
|
||||
" losses.append(loss.item())\n",
|
||||
"\n",
|
||||
" loss = sum(losses) / len(losses)\n",
|
||||
" acc,p,r,f1 = metric.compute()\n",
|
||||
" data_total = len(dataloader.dataset)\n",
|
||||
" if verbose:\n",
|
||||
" print(f'Valid Epoch {epoch}: [{data_total}/{data_total}](100%)\\t Loss: {loss:.6f}')\n",
|
||||
" print(f'Valid Epoch {epoch}: Acc: {100. * acc:.2f}%\\tmacro metrics: [p: {p:.4f}, r:{r:.4f}, f1:{f1:.4f}]\\n\\n')\n",
|
||||
"\n",
|
||||
" return f1,loss"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# Load dataset\n",
|
||||
"train_dataset = CustomDataset(train_save_fp)\n",
|
||||
"valid_dataset = CustomDataset(valid_save_fp)\n",
|
||||
"test_dataset = CustomDataset(test_save_fp)\n",
|
||||
"\n",
|
||||
"train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n",
|
||||
"valid_dataloader = DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))\n",
|
||||
"test_dataloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# After the preprocessed data is loaded, vocab_size is known\n",
|
||||
"vocab = load_pkl(vocab_save_fp)\n",
|
||||
"vocab_size = vocab.count\n",
|
||||
"cfg.vocab_size = vocab_size"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"source": [
|
||||
"# main entry, define optimization function, loss function and so on\n",
|
||||
"# start epoch\n",
|
||||
"# Use the loss of the valid dataset to make an early stop judgment. When it does not decline, this is the time when the model generalization is the best.\n",
|
||||
"model = RNN(cfg)\n",
|
||||
"print(model)\n",
|
||||
"\n",
|
||||
"optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)\n",
|
||||
"scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=cfg.lr_factor, patience=cfg.lr_patience)\n",
|
||||
"criterion = nn.CrossEntropyLoss()\n",
|
||||
"\n",
|
||||
"best_f1, best_epoch = -1, 0\n",
|
||||
"es_loss, es_f1, es_epoch, es_patience, best_es_epoch, best_es_f1, = 1000, -1, 0, 0, 0, -1\n",
|
||||
"train_losses, valid_losses = [], []\n",
|
||||
"\n",
|
||||
"logger.info('=' * 10 + ' Start training ' + '=' * 10)\n",
|
||||
"for epoch in range(1, cfg.epoch + 1):\n",
|
||||
" train_loss = train(epoch, model, train_dataloader, optimizer, criterion, cfg)\n",
|
||||
" valid_f1, valid_loss = validate(epoch, model, valid_dataloader, criterion)\n",
|
||||
" scheduler.step(valid_loss)\n",
|
||||
"\n",
|
||||
" train_losses.append(train_loss)\n",
|
||||
" valid_losses.append(valid_loss)\n",
|
||||
" if best_f1 < valid_f1:\n",
|
||||
" best_f1 = valid_f1\n",
|
||||
" best_epoch = epoch\n",
|
||||
" # 使用 valid loss 做 early stopping 的判断标准\n",
|
||||
" if es_loss > valid_loss:\n",
|
||||
" es_loss = valid_loss\n",
|
||||
" es_f1 = valid_f1\n",
|
||||
" best_es_f1 = valid_f1\n",
|
||||
" es_epoch = epoch\n",
|
||||
" best_es_epoch = epoch\n",
|
||||
" es_patience = 0\n",
|
||||
" else:\n",
|
||||
" es_patience += 1\n",
|
||||
" if es_patience >= cfg.early_stopping_patience:\n",
|
||||
" best_es_epoch = es_epoch\n",
|
||||
" best_es_f1 = es_f1\n",
|
||||
"\n",
|
||||
"if cfg.show_plot:\n",
|
||||
" if cfg.plot_utils == 'matplot':\n",
|
||||
" plt.plot(train_losses, 'x-')\n",
|
||||
" plt.plot(valid_losses, '+-')\n",
|
||||
" plt.legend(['train', 'valid'])\n",
|
||||
" plt.title('train/valid comparison loss')\n",
|
||||
" plt.show()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(f'best(valid loss quota) early stopping epoch: {best_es_epoch}, '\n",
|
||||
" f'this epoch macro f1: {best_es_f1:0.4f}')\n",
|
||||
"print(f'total {cfg.epoch} epochs, best(valid macro f1) epoch: {best_epoch}, '\n",
|
||||
" f'this epoch macro f1: {best_f1:.4f}')\n",
|
||||
"\n",
|
||||
"test_f1, _ = validate(0, model, test_dataloader, criterion,verbose=False)\n",
|
||||
"print(f'after {cfg.epoch} epochs, final test data macro f1: {test_f1:.4f}')"
|
||||
],
|
||||
"outputs": [],
|
||||
"metadata": {}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"This demo does not include parameter adjustment. Interested students can go to [deepke] by themselves( http://openkg.cn/tool/deepke )Warehouse, download and use more models:)"
|
||||
],
|
||||
"metadata": {}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
|
@ -1,3 +1,3 @@
|
|||
sentence,attribute,entity,entity_offset,attribute_value,attribute_value_offset
|
||||
柳为易,女,1989年5月出生,中共党员 ,汉族,重庆市人,民族,柳为易,0,汉族,22
|
||||
庄肇奎 (1728-1798) 榜姓杜,字星堂,号胥园,江苏武进籍,浙江秀水(今嘉兴)人,字,庄肇奎,0,星堂,23
|
||||
sentence,attribute,entity,entity_offset,attribute_value,attribute_value_offset,len,tokens,dependency
|
||||
柳为易,女,1989年5月出生,中共党员 ,汉族,重庆市人,民族,柳为易,0,汉族,22,17,"['柳为', '易', ',', '女', ',', '1989', '年', '5', '月', '出生', ',', '中共党员', ',', '汉族', ',', '重庆市', '人']","[1, 1, 12, 1, 12, 1, 10, 1, 1, 1, 12, 16, 12, 15, 12, 1, 13]"
|
||||
庄肇奎 (1728-1798) 榜姓杜,字星堂,号胥园,江苏武进籍,浙江秀水(今嘉兴)人,字,庄肇奎,0,星堂,23,23,"['庄肇奎', '(', '1728', '-', '1798', ')', '榜姓', '杜', ',', '字星堂', ',', '号', '胥园', ',', '江苏', '武进', '籍', ',', '浙江', '秀水', '(', '今', '嘉兴', ')', '人']","[1, 9, 1, 1, 1, 11, 1, 1, 12, 1, 12, 1, 1, 12, 1, 1, 1, 12, 1, 1, 9, 1, 1, 11, 13]"
|
|
|
@ -1,7 +1,5 @@
|
|||
sentence,attribute,entity,entity_offset,attribute_value,attribute_value_offset
|
||||
苏轼(1037~1101年),字子瞻,又字和仲,号“东坡居士”,眉州眉山(即今四川眉州)人,是宋代(北宋)著名的文学家、书画家,字,苏轼,0,和仲,21
|
||||
屈中乾,男,汉族,中共党员,特级教师,民族,屈中乾,0,汉族,6
|
||||
阳成俊,男,汉族,贵州省委党校大学学历,民族,阳成俊,0,汉族,6
|
||||
黄向静,女,汉族,1965年5月生,大学学历,1986年17月参加工作,中共党员,身体健康,民族,黄向静,0,汉族,6
|
||||
生平简介陈执中(990-1059),字昭誉,名相陈恕之子,北宋洪州南昌(今属江西)人,字,陈执中,4,昭誉,19
|
||||
司马懿,字仲达,河南温县人,字,司马懿,0,仲达,5
|
||||
sentence,attribute,entity,entity_offset,attribute_value,attribute_value_offset,len,tokens,dependency
|
||||
苏轼(1037~1101年),字子瞻,又字和仲,号“东坡居士”,眉州眉山(即今四川眉州)人,是宋代(北宋)著名的文学家、书画家,字,苏轼,0,和仲,21,42,"['苏轼', '(', '1037', '~', '1101', '年', ')', ',', '字子', '瞻', ',', '又', '字', '和', '仲', ',', '号', '“', '东坡', '居士', '”', ',', '眉州', '眉山', '(', '即', '今', '四川', '眉州', ')', '人', ',', '是', '宋代', '(', '北宋', ')', '著名', '的', '文学家', '、', '书画家']","[1, 9, 1, 1, 1, 10, 11, 12, 1, 1, 12, 1, 1, 1, 1, 12, 1, 1, 1, 1, 1, 12, 1, 1, 9, 1, 1, 1, 1, 11, 13, 12, 1, 1, 9, 1, 11, 1, 1, 1, 1, 1]"
|
||||
屈中乾,男,汉族,中共党员,特级教师,民族,屈中乾,0,汉族,6,10,"['屈中', '乾', ',', '男', ',', '汉族', ',', '中共党员', ',', '特级教师']","[1, 1, 12, 14, 12, 15, 12, 16, 12, 1]"
|
||||
黄向静,女,汉族,1965年5月生,大学学历,1986年17月参加工作,中共党员,身体健康,民族,黄向静,0,汉族,6,24,"['黄向静', ',', '女', ',', '汉族', ',', '1965', '年', '5', '月生', ',', '大学', '学历', ',', '1986', '年', '17', '月', '参加', '工作', ',', '中共党员', ',', '身体健康']","[1, 12, 1, 12, 15, 12, 1, 10, 1, 1, 12, 1, 1, 12, 1, 10, 1, 1, 1, 1, 12, 16, 12, 1]"
|
||||
司马懿,字仲达,河南温县人,字,司马懿,0,仲达,5,7,"['司马懿', ',', '字仲达', ',', '河南', '温县', '人']","[1, 12, 1, 12, 1, 1, 13]"
|
|
|
@ -1,3 +1,3 @@
|
|||
sentence,attribute,entity,entity_offset,attribute_value,attribute_value_offset
|
||||
田承冉 男,1952年生,汉族,山东桓台人,共党员,民族,田承冉,0,汉族,13
|
||||
冷家骥,字展麒,山东招远人,字,冷家骥,0,展麒,5
|
||||
sentence,attribute,entity,entity_offset,attribute_value,attribute_value_offset,len,tokens,dependency
|
||||
田承冉 男,1952年生,汉族,山东桓台人,共党员,民族,田承冉,0,汉族,13,14,"['田承冉', '男', ',', '1952', '年生', ',', '汉族', ',', '山东', '桓台', '人', ',', '共', '党员']","[1, 14, 12, 1, 1, 12, 15, 12, 1, 1, 13, 12, 1, 1]"
|
||||
冷家骥,字展麒,山东招远人,字,冷家骥,0,展麒,5,8,"['冷家骥', ',', '字展', '麒', ',', '山东', '招远', '人']","[1, 12, 1, 1, 12, 1, 1, 13]"
|
|
Loading…
Reference in New Issue