xxupiano 2021-10-07 16:49:04 +08:00
parent aa7468660e
commit 2bbb9fb8a3
94 changed files with 1336257 additions and 4984759 deletions

View File

@ -45,7 +45,7 @@ DeepKE includes three modules that can perform relation extraction, named entity recognition, and
> python == 3.8
- torch >= 1.2
- torch >= 1.5
- hydra-core == 1.0.6
- tensorboard >= 2.0
- matplotlib >= 3.1
@ -62,11 +62,11 @@ DeepKE includes three modules that can perform relation extraction, named entity recognition, and
1. **Named Entity Recognition (NER)**
The data is stored in txt files; sample rows are shown below:
| Sentence | Person | Location | Organization |
| :----------------------------------------------------------: | :------------------------: | :----------: | :----------------------------: |
| 本报北京9月4日讯记者杨涌报道部分省区人民日报宣传发行工作座谈会9月3日在4日在京举行。 | 杨涌 | 北京,京 | 人民日报 |
| 《红楼梦》是中央电视台和中国电视剧制作中心根据中国古典文学名著《红楼梦》摄制于1987年的一部古装连续剧由王扶林导演周汝昌、王蒙、周岭等多位红学家参与制作。 | 王扶林,周汝昌,王蒙,周岭 | 中国 | 中央电视台,中国电视剧制作中心 |
| 秦始皇兵马俑位于陕西省西安市1961年被国务院公布为第一批全国重点文物保护单位是世界八大奇迹之一。 | 秦始皇 | 陕西省西安市 | 国务院 |
| Sentence | Person | Location | Organization |
| :----------------------------------------------------------: | :------------------------: | :------------: | :----------------------------: |
| 本报北京9月4日讯记者杨涌报道部分省区人民日报宣传发行工作座谈会9月3日在4日在京举行。 | 杨涌 | 北京 | 人民日报 |
| 《红楼梦》是中央电视台和中国电视剧制作中心根据中国古典文学名著《红楼梦》摄制于1987年的一部古装连续剧由王扶林导演周汝昌、王蒙、周岭等多位红学家参与制作。 | 王扶林,周汝昌,王蒙,周岭 | 中国 | 中央电视台,中国电视剧制作中心 |
| 秦始皇兵马俑位于陕西省西安市1961年被国务院公布为第一批全国重点文物保护单位是世界八大奇迹之一。 | 秦始皇 | 陕西省西安市 | 国务院 |
For the detailed workflow, please see the dedicated README.

View File

@ -5,11 +5,11 @@
> python == 3.8
- pytorch-transformers == 1.2.0
- torch == 1.2.0
- torch == 1.5.0
- hydra-core == 1.0.6
- seqeval == 0.0.5
- tqdm == 4.31.1
- nltk == 3.4.5
- matplotlib == 3.4.1
- deepke

View File

@ -1,8 +1,7 @@
data_dir: "data/"
bert_model: "bert-base-chinese" # ["bert-base-chinese", "bert-base-cased"]
language: "cn" # ["cn", "en"]
bert_model: "bert-base-chinese"
task_name: "ner"
output_dir: "checkpoint"
output_dir: "checkpoints"
max_seq_length: 128
do_train: True
do_eval: True
@ -11,17 +10,16 @@ do_lower_case: True
train_batch_size: 32
eval_batch_size: 8
learning_rate: 5e-5
num_train_epochs: 1 # the number of training epochs
num_train_epochs: 3 # the number of training epochs
warmup_proportion: 0.1
weight_decay: 0.01
adam_epsilon: 1e-8
max_grad_norm: 1.0
no_cuda: False
use_gpu: True # use gpu or not
gpu_id: 1
local_rank: -1
seed: 42
gradient_accumulation_steps: 1
fp16: False
fp16_opt_level: "01"
loss_scale: 0.0
use_gpu: True
gpu_id: 1
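
The fields above are read through Hydra by the training and prediction scripts. As a minimal sketch (not code from this commit), assuming a `conf/config.yaml` laid out exactly as shown, the values can be consumed like this:

```python
# Minimal Hydra sketch: load conf/config.yaml and pick a device from the fields above.
# Illustrative only; field names mirror the config shown in this diff.
import hydra
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    device = f"cuda:{cfg.gpu_id}" if cfg.use_gpu else "cpu"
    print(cfg.bert_model, cfg.num_train_epochs, cfg.output_dir, device)


if __name__ == "__main__":
    main()
```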

View File

@ -1,28 +0,0 @@
## People's Daily(人民日报) dataset
### Task
Named Entity Recognition
### Description
**Tags**: LOC (location names), ORG (organization names), PER (person names)
**Tag Strategy**: BIO
**Split**: '*space*' (北 B-LOC)
**Data Size**:
Train data set ( [example.train](example.train) ):
|Sentences|Characters|LOC count|ORG count|PER count|
|:-:|:-:|:-:|:-:|:-:|
|20864|979180|16571|9277|8144|
Dev data set ( [example.dev](example.dev) ):
|Sentences|Characters|LOC count|ORG count|PER count|
|:-:|:-:|:-:|:-:|:-:|
|2318|109870|1951|984|884|
Test data set ( [example.test](example.test) ):
|Sentences|Characters|LOC count|ORG count|PER count|
|:-:|:-:|:-:|:-:|:-:|
|4636|219197|3658|2185|1864|
**Reference**:
<https://github.com/zjy-ucas/ChineseNER>
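
Given the BIO scheme and the space-separated `character tag` lines described above (e.g. `北 B-LOC`), a file such as [example.train](example.train) can be parsed sentence by sentence. The reader below is an illustrative sketch, not code from the repository:

```python
# Illustrative reader for BIO-tagged files: one "character tag" pair per line,
# sentences separated by blank lines (the format described in this README).
def read_bio_file(path):
    sentences = []                      # list of (characters, tags) pairs
    chars, tags = [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line:                # a blank line closes the current sentence
                if chars:
                    sentences.append((chars, tags))
                    chars, tags = [], []
                continue
            char, tag = line.split(" ")
            chars.append(char)
            tags.append(tag)
    if chars:                           # flush the last sentence if the file lacks a trailing blank line
        sentences.append((chars, tags))
    return sentences


# Example: sentences = read_bio_file("example.train")
```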

View File

@ -1,98 +0,0 @@
Pretrained Punkt Models -- Jan Strunk (New version trained after issues 313 and 514 had been corrected)
Most models were prepared using the test corpora from Kiss and Strunk (2006). Additional models have
been contributed by various people using NLTK for sentence boundary detection.
For information about how to use these models, please confer the tokenization HOWTO:
http://nltk.googlecode.com/svn/trunk/doc/howto/tokenize.html
and chapter 3.8 of the NLTK book:
http://nltk.googlecode.com/svn/trunk/doc/book/ch03.html#sec-segmentation
There are pretrained tokenizers for the following languages:
File Language Source Contents Size of training corpus(in tokens) Model contributed by
=======================================================================================================================================================================
czech.pickle Czech Multilingual Corpus 1 (ECI) Lidove Noviny ~345,000 Jan Strunk / Tibor Kiss
Literarni Noviny
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
danish.pickle Danish Avisdata CD-Rom Ver. 1.1. 1995 Berlingske Tidende ~550,000 Jan Strunk / Tibor Kiss
(Berlingske Avisdata, Copenhagen) Weekend Avisen
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
dutch.pickle Dutch Multilingual Corpus 1 (ECI) De Limburger ~340,000 Jan Strunk / Tibor Kiss
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
english.pickle English Penn Treebank (LDC) Wall Street Journal ~469,000 Jan Strunk / Tibor Kiss
(American)
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
estonian.pickle Estonian University of Tartu, Estonia Eesti Ekspress ~359,000 Jan Strunk / Tibor Kiss
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
finnish.pickle Finnish Finnish Parole Corpus, Finnish Books and major national ~364,000 Jan Strunk / Tibor Kiss
Text Bank (Suomen Kielen newspapers
Tekstipankki)
Finnish Center for IT Science
(CSC)
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
french.pickle French Multilingual Corpus 1 (ECI) Le Monde ~370,000 Jan Strunk / Tibor Kiss
(European)
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
german.pickle German Neue Zürcher Zeitung AG Neue Zürcher Zeitung ~847,000 Jan Strunk / Tibor Kiss
(Switzerland) CD-ROM
(Uses "ss"
instead of "ß")
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
greek.pickle Greek Efstathios Stamatatos To Vima (TO BHMA) ~227,000 Jan Strunk / Tibor Kiss
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
italian.pickle Italian Multilingual Corpus 1 (ECI) La Stampa, Il Mattino ~312,000 Jan Strunk / Tibor Kiss
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
norwegian.pickle Norwegian Centre for Humanities Bergens Tidende ~479,000 Jan Strunk / Tibor Kiss
(Bokmål and Information Technologies,
Nynorsk) Bergen
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
polish.pickle Polish Polish National Corpus Literature, newspapers, etc. ~1,000,000 Krzysztof Langner
(http://www.nkjp.pl/)
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
portuguese.pickle Portuguese CETENFolha Corpus Folha de São Paulo ~321,000 Jan Strunk / Tibor Kiss
(Brazilian) (Linguateca)
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
slovene.pickle Slovene TRACTOR Delo ~354,000 Jan Strunk / Tibor Kiss
Slovene Academy for Arts
and Sciences
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
spanish.pickle Spanish Multilingual Corpus 1 (ECI) Sur ~353,000 Jan Strunk / Tibor Kiss
(European)
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
swedish.pickle Swedish Multilingual Corpus 1 (ECI) Dagens Nyheter ~339,000 Jan Strunk / Tibor Kiss
(and some other texts)
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
turkish.pickle Turkish METU Turkish Corpus Milliyet ~333,000 Jan Strunk / Tibor Kiss
(Türkçe Derlem Projesi)
University of Ankara
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
The corpora contained about 400,000 tokens on average and mostly consisted of newspaper text converted to
Unicode using the codecs module.
Kiss, Tibor and Strunk, Jan (2006): Unsupervised Multilingual Sentence Boundary Detection.
Computational Linguistics 32: 485-525.
---- Training Code ----
# import punkt
import nltk.tokenize.punkt
# Make a new Tokenizer
tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer()
# Read in training corpus (one example: Slovene)
import codecs
text = codecs.open("slovene.plain","Ur","iso-8859-2").read()
# Train tokenizer
tokenizer.train(text)
# Dump pickled tokenizer
import pickle
out = open("slovene.pickle","wb")
pickle.dump(tokenizer, out)
out.close()
---------
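
As a follow-up to the training snippet above, a pickled Punkt model can be loaded back and used for sentence splitting. This is a hedged usage sketch reusing the `slovene.pickle` file name from the example; it is not part of the NLTK README itself:

```python
# Load a pickled Punkt sentence tokenizer and split raw text into sentences.
import pickle

with open("slovene.pickle", "rb") as f:
    tokenizer = pickle.load(f)          # a trained PunktSentenceTokenizer

print(tokenizer.tokenize("Prvi stavek. Drugi stavek."))  # -> list of sentence strings
```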

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -1,169 +1,10 @@
"""BERT NER Inference."""
from __future__ import absolute_import, division, print_function
import json
import os
import torch
import torch.nn.functional as F
from nltk import word_tokenize
from pytorch_transformers import (BertConfig, BertForTokenClassification,
BertTokenizer)
from collections import OrderedDict
import argparse
import nltk
nltk.data.path.insert(0,'./data/nltk_data')
from deepke.name_entity_re.standard import *
import hydra
from hydra import utils
class BertNer(BertForTokenClassification):
def forward(self, input_ids, token_type_ids=None, attention_mask=None, valid_ids=None):
sequence_output = self.bert(input_ids, token_type_ids, attention_mask, head_mask=None)[0]
batch_size,max_len,feat_dim = sequence_output.shape
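# Re-align sub-word outputs to word positions: keep only each word's first sub-token (valid_ids == 1), compacted to the front of valid_output.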
valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device='cuda' if torch.cuda.is_available() else 'cpu')
for i in range(batch_size):
jj = -1
for j in range(max_len):
if valid_ids[i][j].item() == 1:
jj += 1
valid_output[i][jj] = sequence_output[i][j]
sequence_output = self.dropout(valid_output)
logits = self.classifier(sequence_output)
return logits
class InferNer:
def __init__(self,model_dir:str,language):
self.model , self.tokenizer, self.model_config = self.load_model(model_dir)
self.label_map = self.model_config["label_map"]
self.max_seq_length = self.model_config["max_seq_length"]
self.label_map = {int(k):v for k,v in self.label_map.items()}
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = self.model.to(self.device)
self.model.eval()
self.language = language
def load_model(self, model_dir: str, model_config: str = "model_config.json"):
model_config = os.path.join(model_dir,model_config)
model_config = json.load(open(model_config))
model = BertNer.from_pretrained(model_dir)
tokenizer = BertTokenizer.from_pretrained(model_dir, do_lower_case=model_config["do_lower"])
return model, tokenizer, model_config
def tokenize(self, text: str):
""" tokenize input"""
words = list(text)
tokens = []
valid_positions = []
for i,word in enumerate(words):
token = self.tokenizer.tokenize(word)
tokens.extend(token)
for i in range(len(token)):
if i == 0:
valid_positions.append(1)
else:
valid_positions.append(0)
return tokens, valid_positions
def preprocess(self, text: str):
""" preprocess """
tokens, valid_positions = self.tokenize(text)
## insert "[CLS]"
tokens.insert(0,"[CLS]")
valid_positions.insert(0,1)
## insert "[SEP]"
tokens.append("[SEP]")
valid_positions.append(1)
segment_ids = []
for i in range(len(tokens)):
segment_ids.append(0)
input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
input_mask = [1] * len(input_ids)
while len(input_ids) < self.max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
valid_positions.append(0)
return input_ids,input_mask,segment_ids,valid_positions
def predict(self, text: str):
input_ids,input_mask,segment_ids,valid_ids = self.preprocess(text)
input_ids = torch.tensor([input_ids],dtype=torch.long,device=self.device)
input_mask = torch.tensor([input_mask],dtype=torch.long,device=self.device)
segment_ids = torch.tensor([segment_ids],dtype=torch.long,device=self.device)
valid_ids = torch.tensor([valid_ids],dtype=torch.long,device=self.device)
with torch.no_grad():
logits = self.model(input_ids, segment_ids, input_mask,valid_ids)
logits = F.softmax(logits,dim=2)
logits_label = torch.argmax(logits,dim=2)
logits_label = logits_label.detach().cpu().numpy().tolist()[0]
logits_confidence = [values[label].item() for values,label in zip(logits[0],logits_label)]
logits = []
pos = 0
for index,mask in enumerate(valid_ids[0]):
if index == 0:
continue
if mask == 1:
logits.append((logits_label[index-pos],logits_confidence[index-pos]))
else:
pos += 1
logits.pop()
labels = [(self.label_map[label],confidence) for label,confidence in logits]
if self.language == 'bert-base-chinese':
words = list(text)
else:
nltk.download('punkt')
words = word_tokenize(text)
assert len(labels) == len(words)
result = []
for word, (label, confidence) in zip(words, labels):
if label!='O':
result.append((word,label))
tmp = []
tag = OrderedDict()
tag['PER'] = []
tag['LOC'] = []
tag['ORG'] = []
tag['MISC'] = []
for i, (word, label) in enumerate(result):
if label=='B-PER' or label=='B-LOC' or label=='B-ORG' or label=='B-MISC':
if i==0:
tmp.append(word)
else:
wordstype = result[i-1][1][2:]
if self.language=='bert-base-chinese':
tag[wordstype].append(''.join(tmp))
else:
tag[wordstype].append(' '.join(tmp))
tmp.clear()
tmp.append(word)
elif i==len(result)-1:
tmp.append(word)
wordstype = result[i][1][2:]
if self.language=='bert-base-chinese':
tag[wordstype].append(''.join(tmp))
else:
tag[wordstype].append(' '.join(tmp))
else:
tmp.append(word)
return tag
@hydra.main(config_path="conf", config_name='config')
def main(cfg):
model = InferNer(utils.get_original_cwd()+'/'+"checkpoint/", cfg.bert_model)
model = InferNer(utils.get_original_cwd()+'/'+"checkpoints/")
text = cfg.text
print("NER句子:")
@ -180,9 +21,7 @@ def main(cfg):
print('Location')
elif k=='ORG':
print('Organization')
elif k=='MISC':
print('Miscellaneous')
if __name__ == "__main__":
main()
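
For reference, the inference flow above reduces to constructing `InferNer` with the checkpoint directory and calling `predict` on one sentence; `predict` returns the `OrderedDict` of PER/LOC/ORG/MISC entity lists built at the end of the method. A hedged usage sketch (the checkpoint path and sentence are illustrative, and the constructor call follows the new one-argument form used in `main()`):

```python
# Illustrative standalone use of the InferNer class defined above.
# Assumes a fine-tuned model has already been saved under ./checkpoints/.
model = InferNer("checkpoints/")
tag = model.predict("秦始皇兵马俑位于陕西省西安市。")
for entity_type, entities in tag.items():   # keys: 'PER', 'LOC', 'ORG', 'MISC'
    if entities:
        print(entity_type, entities)
```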

View File

@ -1,5 +1,7 @@
pytorch-transformers==1.2.0
torch==1.2.0
torch==1.5.0
hydra-core==1.0.6
seqeval==0.0.5
tqdm==4.31.1
nltk==3.4.5
matplotlib==3.4.1
deepke

View File

@ -6,35 +6,31 @@ import logging
import os
import random
import sys
import numpy as np
import torch
import torch.nn.functional as F
from pytorch_transformers import (WEIGHTS_NAME, AdamW, BertConfig,
BertForTokenClassification, BertTokenizer,
WarmupLinearSchedule)
from pytorch_transformers import (WEIGHTS_NAME, AdamW, BertConfig, BertForTokenClassification, BertTokenizer, WarmupLinearSchedule)
from torch import nn
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from seqeval.metrics import classification_report
import hydra
from hydra import utils
from deepke.name_entity_re.standard import *
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)
import hydra
from hydra import utils
class TrainNer(BertForTokenClassification):
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,valid_ids=None,attention_mask_label=None):
sequence_output = self.bert(input_ids, token_type_ids, attention_mask,head_mask=None)[0]
batch_size,max_len,feat_dim = sequence_output.shape
valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device=1) # device: change this to cfg.gpu_id when using a GPU, otherwise 'cpu'
valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device=1) #device=cfg.gpu_id if use_gpu else 'cpu'
for i in range(batch_size):
jj = -1
for j in range(max_len):
@ -46,8 +42,6 @@ class TrainNer(BertForTokenClassification):
if labels is not None:
loss_fct = nn.CrossEntropyLoss(ignore_index=0)
# Only keep active parts of the loss
#attention_mask_label = None
if attention_mask_label is not None:
active_loss = attention_mask_label.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)[active_loss]
@ -60,202 +54,17 @@ class TrainNer(BertForTokenClassification):
return logits
class InputExample(object):
"""A single training/test example for simple sequence classification."""
def __init__(self, guid, text_a, text_b=None, label=None):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.guid = guid
self.text_a = text_a
self.text_b = text_b
self.label = label
class InputFeatures(object):
"""A single set of features of data."""
def __init__(self, input_ids, input_mask, segment_ids, label_id, valid_ids=None, label_mask=None):
self.input_ids = input_ids
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
self.valid_ids = valid_ids
self.label_mask = label_mask
def readfile(filename):
'''
read file
'''
f = open(filename)
data = []
sentence = []
label= []
for line in f:
if len(line)==0 or line.startswith('-DOCSTART') or line[0]=="\n":
if len(sentence) > 0:
data.append((sentence,label))
sentence = []
label = []
continue
splits = line.split(' ')
sentence.append(splits[0])
label.append(splits[-1][:-1])
if len(sentence) >0:
data.append((sentence,label))
sentence = []
label = []
return data
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_labels(self):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
return readfile(input_file)
class NerProcessor(DataProcessor):
"""Processor for the dataset."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.txt")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "valid.txt")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.txt")), "test")
def get_labels(self):
return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "[CLS]", "[SEP]"]
def _create_examples(self,lines,set_type):
examples = []
for i,(sentence,label) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = ' '.join(sentence)
text_b = None
label = label
examples.append(InputExample(guid=guid,text_a=text_a,text_b=text_b,label=label))
return examples
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
"""Loads a data file into a list of `InputBatch`s."""
label_map = {label : i for i, label in enumerate(label_list,1)}
features = []
for (ex_index,example) in enumerate(examples):
textlist = example.text_a.split(' ')
labellist = example.label
tokens = []
labels = []
valid = []
label_mask = []
for i, word in enumerate(textlist):
token = tokenizer.tokenize(word)
tokens.extend(token)
label_1 = labellist[i]
for m in range(len(token)):
if m == 0:
labels.append(label_1)
valid.append(1)
label_mask.append(1)
else:
valid.append(0)
if len(tokens) >= max_seq_length - 1:
tokens = tokens[0:(max_seq_length - 2)]
labels = labels[0:(max_seq_length - 2)]
valid = valid[0:(max_seq_length - 2)]
label_mask = label_mask[0:(max_seq_length - 2)]
ntokens = []
segment_ids = []
label_ids = []
ntokens.append("[CLS]")
segment_ids.append(0)
valid.insert(0,1)
label_mask.insert(0,1)
label_ids.append(label_map["[CLS]"])
for i, token in enumerate(tokens):
ntokens.append(token)
segment_ids.append(0)
if len(labels) > i:
label_ids.append(label_map[labels[i]])
ntokens.append("[SEP]")
segment_ids.append(0)
valid.append(1)
label_mask.append(1)
label_ids.append(label_map["[SEP]"])
input_ids = tokenizer.convert_tokens_to_ids(ntokens)
input_mask = [1] * len(input_ids)
label_mask = [1] * len(label_ids)
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
label_ids.append(0)
valid.append(1)
label_mask.append(0)
while len(label_ids) < max_seq_length:
label_ids.append(0)
label_mask.append(0)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
assert len(label_ids) == max_seq_length
assert len(valid) == max_seq_length
assert len(label_mask) == max_seq_length
features.append(
InputFeatures(input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_id=label_ids,
valid_ids=valid,
label_mask=label_mask))
return features
@hydra.main(config_path="conf", config_name='config')
def main(cfg):
processors = {"ner":NerProcessor}
# Use gpu or not
if cfg.use_gpu and torch.cuda.is_available():
device = torch.device('cuda', cfg.gpu_id)
else:
device = torch.device('cpu')
if cfg.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
cfg.gradient_accumulation_steps))
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(cfg.gradient_accumulation_steps))
cfg.train_batch_size = cfg.train_batch_size // cfg.gradient_accumulation_steps
@ -266,28 +75,25 @@ def main(cfg):
if not cfg.do_train and not cfg.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
# Checkpoints
if os.path.exists(utils.get_original_cwd()+'/'+cfg.output_dir) and os.listdir(utils.get_original_cwd()+'/'+cfg.output_dir) and cfg.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(utils.get_original_cwd()+'/'+cfg.output_dir))
if not os.path.exists(utils.get_original_cwd()+'/'+cfg.output_dir):
os.makedirs(utils.get_original_cwd()+'/'+cfg.output_dir)
task_name = cfg.task_name.lower()
if task_name not in processors:
raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()
# Preprocess the input dataset
processor = NerProcessor()
label_list = processor.get_labels()
num_labels = len(label_list) + 1
# Prepare the model
tokenizer = BertTokenizer.from_pretrained(cfg.bert_model, do_lower_case=cfg.do_lower_case)
train_examples = None
num_train_optimization_steps = 0
if cfg.do_train:
train_examples = processor.get_train_examples(utils.get_original_cwd()+'/'+cfg.data_dir)
num_train_optimization_steps = int(
len(train_examples) / cfg.train_batch_size / cfg.gradient_accumulation_steps) * cfg.num_train_epochs
num_train_optimization_steps = int(len(train_examples) / cfg.train_batch_size / cfg.gradient_accumulation_steps) * cfg.num_train_epochs
config = BertConfig.from_pretrained(cfg.bert_model, num_labels=num_labels, finetuning_task=cfg.task_name)
model = TrainNer.from_pretrained(cfg.bert_model,from_tf = False,config = config)
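
The `num_train_optimization_steps` computed above is simply the number of training examples divided by the effective batch size, times the number of epochs. With the People's Daily split and the defaults from `conf/config.yaml` (20,864 training sentences, `train_batch_size: 32`, `gradient_accumulation_steps: 1`, `num_train_epochs: 3`), the arithmetic works out as follows; this is a worked example, not output from the script:

```python
# Worked example of the optimizer-step count using the default configuration values.
num_examples = 20864                 # training sentences in the People's Daily split
train_batch_size = 32
gradient_accumulation_steps = 1
num_train_epochs = 3

steps = int(num_examples / train_batch_size / gradient_accumulation_steps) * num_train_epochs
print(steps)                         # int(20864 / 32 / 1) * 3 = 652 * 3 = 1956
```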

View File

@ -1,7 +1,7 @@
from setuptools import setup, find_packages
setup(
name='deepke', # package name
version='0.2.67', # version number
version='0.2.79', # version number
keywords=["pip", "RE","NER","AE"], # keywords
description='DeepKE 是基于 Pytorch 的深度学习中文关系抽取处理套件。', # description (DeepKE is a PyTorch-based deep learning toolkit for Chinese relation extraction)
long_description="client", # long description

View File

@ -1,51 +0,0 @@
from __future__ import absolute_import, division, print_function
import argparse
import csv
import json
import logging
import os
import random
import sys
import numpy as np
import torch
import torch.nn.functional as F
from pytorch_transformers import (WEIGHTS_NAME, AdamW, BertConfig,
BertForTokenClassification, BertTokenizer,
WarmupLinearSchedule)
from torch import nn
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from seqeval.metrics import classification_report
class TrainNer(BertForTokenClassification):
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,valid_ids=None,attention_mask_label=None):
sequence_output = self.bert(input_ids, token_type_ids, attention_mask,head_mask=None)[0]
batch_size,max_len,feat_dim = sequence_output.shape
valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device='cuda')
for i in range(batch_size):
jj = -1
for j in range(max_len):
if valid_ids[i][j].item() == 1:
jj += 1
valid_output[i][jj] = sequence_output[i][j]
sequence_output = self.dropout(valid_output)
logits = self.classifier(sequence_output)
if labels is not None:
loss_fct = nn.CrossEntropyLoss(ignore_index=0)
# Only keep active parts of the loss
#attention_mask_label = None
if attention_mask_label is not None:
active_loss = attention_mask_label.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)[active_loss]
active_labels = labels.view(-1)[active_loss]
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss
else:
return logits

View File

@ -33,7 +33,7 @@ class BertNer(BertForTokenClassification):
class InferNer:
def __init__(self,model_dir:str):
def __init__(self,model_dir: str):
self.model , self.tokenizer, self.model_config = self.load_model(model_dir)
self.label_map = self.model_config["label_map"]
self.max_seq_length = self.model_config["max_seq_length"]
@ -136,12 +136,8 @@ class InferNer:
elif i==len(result)-1:
tmp.append(word)
wordstype = result[i][1][2:]
if self.language=='bert-base-chinese':
tag[wordstype].append(''.join(tmp))
else:
tag[wordstype].append(' '.join(tmp))
tag[wordstype].append(''.join(tmp))
else:
tmp.append(word)
return tag

View File

@ -1,2 +1 @@
from .BasicNer import *
from .InferBert import *

View File

@ -1,3 +1,2 @@
from .dataset import *
from .preprocess import *
from .trainer import *
from .preprocess import *

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -28,35 +28,49 @@
"metadata": {},
"source": [
"## Dataset\n",
"In this demo, we use dataset [**CoNLL-2003**](https://www.clips.uantwerpen.be/conll2003/ner/). It is a dataset for NER, concentrating on four types of named entities related to persons, locations, organizations, and names of miscellaneous entities.\n",
"In this demo, we use [**People's Daily(人民日报) dataset**](https://github.com/OYE93/Chinese-NLP-Corpus/tree/master/NER/People's%20Daily). It is a dataset for NER, concentrating on their types of named entities related to persons(PER), locations(LOC), and organizations(ORG).\n",
"\n",
"| Word | Part-of-speech (POS) tag | Syntactic chunk tag | Named entity tag |\n",
"| :--------: | :----------------------: | :-----------------: | :--------------: |\n",
"| Pakistan | NNP | B-NP | B-LOC |\n",
"| , | , | O | O |\n",
"| who | WP | B-NP | O |\n",
"| arrive | VBP | B-VP | O |\n",
"| in | IN | B-PP | O |\n",
"| Australia | NNP | B-NP | B-LOC |\n",
"| later | JJ | B-NP | O |\n",
"| this | DT | I-NP | O |\n",
"| month | NN | I-NP | O |\n",
"| , | , | O | O |\n",
"| are | VBP | B-VP | O |\n",
"| the | DT | B-NP | O |\n",
"| other | JJ | I-NP | O |\n",
"| team | NN | I-NP | O |\n",
"| competing | VBG | B-VP | O |\n",
"| in | IN | B-PP | O |\n",
"| the | DT | B-NP | O |\n",
"| World | NNP | I-NP | B-MISC |\n",
"| Series | NNP | I-NP | I-MISC |\n",
"| tournament | NN | I-NP | O |\n",
"| . | . | O | O |\n",
"| Word | Named entity tag |\n",
"| :--: | :--------------: |\n",
"| 早 | O |\n",
"| 在 | O |\n",
"| 1 | O |\n",
"| 9 | O |\n",
"| 7 | O |\n",
"| 5 | O |\n",
"| 年 | O |\n",
"| | O |\n",
"| 张 | B-PER |\n",
"| 鸿 | I-PER |\n",
"| 飞 | I-PER |\n",
"| 就 | O |\n",
"| 有 | O |\n",
"| 《 | O |\n",
"| 草 | O |\n",
"| 原 | O |\n",
"| 新 | O |\n",
"| 医 | O |\n",
"| 》 | O |\n",
"| 赴 | O |\n",
"| 法 | B-LOC |\n",
"| 展 | O |\n",
"| 览 | O |\n",
"| | O |\n",
"| 为 | O |\n",
"| 我 | O |\n",
"| 国 | O |\n",
"| 驻 | B-ORG |\n",
"| 法 | I-ORG |\n",
"| 使 | I-ORG |\n",
"| 馆 | I-ORG |\n",
"| 收 | O |\n",
"| 藏 | O |\n",
"| 。 | O |\n",
"\n",
"- train.txt: It contains 14,987 sentences\n",
"- valid.txt: It contains 3,466 sentences\n",
"- test.txt: It contains 3,684 sentences"
"\n",
"- train.txt: It contains 20,864 sentences, including 979,180 named entity tags.\n",
"- valid.txt: It contains 2,318 sentences, including 109,870 named entity tags.\n",
"- test.txt: It contains 4,636 sentences, including 219,197 named entity tags."
]
},
{
@ -65,7 +79,7 @@
"metadata": {},
"source": [
"## BERT\n",
"[**Bidirectional Encoder Representations from Transformers (BERT)**](https://github.com/google-research/bert) is a transformer-based machine learning technique for natural language processing (NLP) pre-training developed by Google.\n",
"[**Bidirectional Encoder Representations from Transformers (BERT)**](https://github.com/google-research/bert) \n",
"\n",
"![BERT](img/BERT.png)"
]
@ -81,15 +95,17 @@
{
"cell_type": "code",
"execution_count": null,
"id": "bd23f8d2",
"id": "ddb0f3e4",
"metadata": {},
"outputs": [],
"source": [
"!pip install pytorch-transformers==1.2.0\n",
"!pip install torch==1.2.0\n",
"!pip install torch==1.5.0\n",
"!pip install hydra-core==1.0.6\n",
"!pip install seqeval==0.0.5\n",
"!pip install tqdm==4.31.1\n",
"!pip install nltk==3.4.5"
"!pip install matplotlib==3.4.1\n",
"!pip install deepke"
]
},
{
@ -103,12 +119,12 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f33f17d8",
"id": "e21416ae",
"metadata": {},
"outputs": [],
"source": [
"from __future__ import absolute_import, division, print_function\n",
"import argparse\n",
"\n",
"import csv\n",
"import json\n",
"import logging\n",
@ -124,6 +140,9 @@
"from torch.utils.data.distributed import DistributedSampler\n",
"from tqdm import tqdm, trange\n",
"from seqeval.metrics import classification_report\n",
"import hydra\n",
"from hydra import utils\n",
"from deepke.name_entity_re.standard import *\n",
"\n",
"logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',\n",
" datefmt = '%m/%d/%Y %H:%M:%S',\n",
@ -133,235 +152,71 @@
},
{
"cell_type": "markdown",
"id": "de20573b",
"id": "029e661b",
"metadata": {},
"source": [
"## Preprocess the dataset"
"## Configure model parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01ff2c72",
"id": "2178ffdf",
"metadata": {},
"outputs": [],
"source": [
"def readfile(filename):\n",
" '''\n",
" read file\n",
" '''\n",
" f = open(filename)\n",
" data = []\n",
" sentence = []\n",
" label= []\n",
" for line in f:\n",
" if len(line)==0 or line.startswith('-DOCSTART') or line[0]==\"\\n\":\n",
" if len(sentence) > 0:\n",
" data.append((sentence,label))\n",
" sentence = []\n",
" label = []\n",
" continue\n",
" splits = line.split(' ')\n",
" sentence.append(splits[0])\n",
" label.append(splits[-1][:-1])\n",
"class Config(object):\n",
" data_dir = \"data/\" # The input data dir\n",
" bert_model = \"bert-base-chinese\"\n",
" task_name = \"ner\"\n",
" output_dir = \"checkpoints\"\n",
" max_seq_length = 128\n",
" do_train = True # Fine-tune or not\n",
" do_eval = True # Evaluate or not\n",
" eval_on = \"dev\"\n",
" do_lower_case = True\n",
" train_batch_size = 32\n",
" eval_batch_size = 8\n",
" learning_rate = 5e-5\n",
" num_train_epochs = 3 # The number of training epochs\n",
" warmup_proportion = 0.1\n",
" weight_decay = 0.01\n",
" adam_epsilon = 1e-8\n",
" max_grad_norm = 1.0\n",
" use_gpu = True # Use gpu or not\n",
" gpu_id = 1 # Which gpu to be used\n",
" local_rank = -1\n",
" seed = 42\n",
" gradient_accumulation_steps = 1\n",
" fp16 = False\n",
" fp16_opt_level = \"01\"\n",
" loss_scale = 0.0\n",
" text = \"秦始皇兵马俑位于陕西省西安市1961年被国务院公布为第一批全国重点文物保护单位是世界八大奇迹之一。\"\n",
"\n",
" if len(sentence) >0:\n",
" data.append((sentence,label))\n",
" sentence = []\n",
" label = []\n",
" return data\n",
"\n",
"class InputExample(object):\n",
" \"\"\"A single training/test example for simple sequence classification.\"\"\"\n",
"\n",
" def __init__(self, guid, text_a, text_b=None, label=None):\n",
" \"\"\"Constructs a InputExample.\n",
" Args:\n",
" guid: Unique id for the example.\n",
" text_a: string. The untokenized text of the first sequence. For single\n",
" sequence tasks, only this sequence must be specified.\n",
" text_b: (Optional) string. The untokenized text of the second sequence.\n",
" Only must be specified for sequence pair tasks.\n",
" label: (Optional) string. The label of the example. This should be\n",
" specified for train and dev examples, but not for test examples.\n",
" \"\"\"\n",
" self.guid = guid\n",
" self.text_a = text_a\n",
" self.text_b = text_b\n",
" self.label = label\n",
"\n",
"class InputFeatures(object):\n",
" \"\"\"A single set of features of data.\"\"\"\n",
"\n",
" def __init__(self, input_ids, input_mask, segment_ids, label_id, valid_ids=None, label_mask=None):\n",
" self.input_ids = input_ids\n",
" self.input_mask = input_mask\n",
" self.segment_ids = segment_ids\n",
" self.label_id = label_id\n",
" self.valid_ids = valid_ids\n",
" self.label_mask = label_mask\n",
"\n",
"class DataProcessor(object):\n",
" \"\"\"Base class for data converters for sequence classification data sets.\"\"\"\n",
"\n",
" def get_train_examples(self, data_dir):\n",
" \"\"\"Gets a collection of `InputExample`s for the train set.\"\"\"\n",
" raise NotImplementedError()\n",
"\n",
" def get_dev_examples(self, data_dir):\n",
" \"\"\"Gets a collection of `InputExample`s for the dev set.\"\"\"\n",
" raise NotImplementedError()\n",
"\n",
" def get_labels(self):\n",
" \"\"\"Gets the list of labels for this data set.\"\"\"\n",
" raise NotImplementedError()\n",
"\n",
" @classmethod\n",
" def _read_tsv(cls, input_file, quotechar=None):\n",
" \"\"\"Reads a tab separated value file.\"\"\"\n",
" return readfile(input_file)\n",
"\n",
"\n",
"class NerProcessor(DataProcessor):\n",
" \"\"\"Processor for the CoNLL-2003 data set.\"\"\"\n",
"\n",
" def get_train_examples(self, data_dir):\n",
" \"\"\"See base class.\"\"\"\n",
" return self._create_examples(\n",
" self._read_tsv(os.path.join(data_dir, \"train.txt\")), \"train\")\n",
"\n",
" def get_dev_examples(self, data_dir):\n",
" \"\"\"See base class.\"\"\"\n",
" return self._create_examples(\n",
" self._read_tsv(os.path.join(data_dir, \"valid.txt\")), \"dev\")\n",
"\n",
" def get_test_examples(self, data_dir):\n",
" \"\"\"See base class.\"\"\"\n",
" return self._create_examples(\n",
" self._read_tsv(os.path.join(data_dir, \"test.txt\")), \"test\")\n",
"\n",
" def get_labels(self):\n",
" return [\"O\", \"B-MISC\", \"I-MISC\", \"B-PER\", \"I-PER\", \"B-ORG\", \"I-ORG\", \"B-LOC\", \"I-LOC\", \"[CLS]\", \"[SEP]\"]\n",
"\n",
" def _create_examples(self,lines,set_type):\n",
" examples = []\n",
" for i,(sentence,label) in enumerate(lines):\n",
" guid = \"%s-%s\" % (set_type, i)\n",
" text_a = ' '.join(sentence)\n",
" text_b = None\n",
" label = label\n",
" examples.append(InputExample(guid=guid,text_a=text_a,text_b=text_b,label=label))\n",
" return examples\n",
"\n",
"def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):\n",
" \"\"\"Loads a data file into a list of `InputBatch`s.\"\"\"\n",
"\n",
" label_map = {label : i for i, label in enumerate(label_list,1)}\n",
"\n",
" features = []\n",
" for (ex_index,example) in enumerate(examples):\n",
" textlist = example.text_a.split(' ')\n",
" labellist = example.label\n",
" tokens = []\n",
" labels = []\n",
" valid = []\n",
" label_mask = []\n",
" for i, word in enumerate(textlist):\n",
" token = tokenizer.tokenize(word)\n",
" tokens.extend(token)\n",
" label_1 = labellist[i]\n",
" for m in range(len(token)):\n",
" if m == 0:\n",
" labels.append(label_1)\n",
" valid.append(1)\n",
" label_mask.append(1)\n",
" else:\n",
" valid.append(0)\n",
" if len(tokens) >= max_seq_length - 1:\n",
" tokens = tokens[0:(max_seq_length - 2)]\n",
" labels = labels[0:(max_seq_length - 2)]\n",
" valid = valid[0:(max_seq_length - 2)]\n",
" label_mask = label_mask[0:(max_seq_length - 2)]\n",
" ntokens = []\n",
" segment_ids = []\n",
" label_ids = []\n",
" ntokens.append(\"[CLS]\")\n",
" segment_ids.append(0)\n",
" valid.insert(0,1)\n",
" label_mask.insert(0,1)\n",
" label_ids.append(label_map[\"[CLS]\"])\n",
" for i, token in enumerate(tokens):\n",
" ntokens.append(token)\n",
" segment_ids.append(0)\n",
" if len(labels) > i:\n",
" label_ids.append(label_map[labels[i]])\n",
" ntokens.append(\"[SEP]\")\n",
" segment_ids.append(0)\n",
" valid.append(1)\n",
" label_mask.append(1)\n",
" label_ids.append(label_map[\"[SEP]\"])\n",
" input_ids = tokenizer.convert_tokens_to_ids(ntokens)\n",
" input_mask = [1] * len(input_ids)\n",
" label_mask = [1] * len(label_ids)\n",
" while len(input_ids) < max_seq_length:\n",
" input_ids.append(0)\n",
" input_mask.append(0)\n",
" segment_ids.append(0)\n",
" label_ids.append(0)\n",
" valid.append(1)\n",
" label_mask.append(0)\n",
" while len(label_ids) < max_seq_length:\n",
" label_ids.append(0)\n",
" label_mask.append(0)\n",
" assert len(input_ids) == max_seq_length\n",
" assert len(input_mask) == max_seq_length\n",
" assert len(segment_ids) == max_seq_length\n",
" assert len(label_ids) == max_seq_length\n",
" assert len(valid) == max_seq_length\n",
" assert len(label_mask) == max_seq_length\n",
"\n",
" if ex_index < 5:\n",
" logger.info(\"*** Example ***\")\n",
" logger.info(\"guid: %s\" % (example.guid))\n",
" logger.info(\"tokens: %s\" % \" \".join(\n",
" [str(x) for x in tokens]))\n",
" logger.info(\"input_ids: %s\" % \" \".join([str(x) for x in input_ids]))\n",
" logger.info(\"input_mask: %s\" % \" \".join([str(x) for x in input_mask]))\n",
" logger.info(\n",
" \"segment_ids: %s\" % \" \".join([str(x) for x in segment_ids]))\n",
" # logger.info(\"label: %s (id = %d)\" % (example.label, label_ids))\n",
"\n",
" features.append(\n",
" InputFeatures(input_ids=input_ids,\n",
" input_mask=input_mask,\n",
" segment_ids=segment_ids,\n",
" label_id=label_ids,\n",
" valid_ids=valid,\n",
" label_mask=label_mask))\n",
" return features"
"cfg = Config()"
]
},
{
"cell_type": "markdown",
"id": "1aedb530",
"id": "5a67afe9",
"metadata": {},
"source": [
"## BERT model"
"## Prepare the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4abd61c",
"id": "f45cc648",
"metadata": {},
"outputs": [],
"source": [
"class Ner(BertForTokenClassification):\n",
"class TrainNer(BertForTokenClassification):\n",
"\n",
" def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,valid_ids=None,attention_mask_label=None):\n",
" sequence_output = self.bert(input_ids, token_type_ids, attention_mask,head_mask=None)[0]\n",
" batch_size,max_len,feat_dim = sequence_output.shape\n",
" valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device='cuda')\n",
" valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device=1) #device 如用gpu需要修改为cfg.gpu_id的值 不用则为cpu\n",
" for i in range(batch_size):\n",
" jj = -1\n",
" for j in range(max_len):\n",
@@ -373,8 +228,6 @@
"\n",
" if labels is not None:\n",
" loss_fct = nn.CrossEntropyLoss(ignore_index=0)\n",
" # Only keep active parts of the loss\n",
" # attention_mask_label = None\n",
" if attention_mask_label is not None:\n",
" active_loss = attention_mask_label.view(-1) == 1\n",
" active_logits = logits.view(-1, self.num_labels)[active_loss]\n",
@@ -387,140 +240,65 @@
" return logits"
]
},
{
"cell_type": "markdown",
"id": "029e661b",
"metadata": {},
"source": [
"## Fine-Tune"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ae701820",
"id": "9ed1c3dd",
"metadata": {},
"outputs": [],
"source": [
"# Required parameters\n",
"data_dir = \"data/\"\n",
"bert_model = \"bert-base-cased\"\n",
"task_name = \"ner\"\n",
"output_dir = \"out_ner\"\n",
"max_seq_length = 128\n",
"do_train = True\n",
"do_eval = True\n",
"eval_on = \"dev\"\n",
"do_lower_case = \"True\"\n",
"train_batch_size = 32\n",
"eval_batch_size = 8\n",
"learning_rate = 5e-5\n",
"num_train_epochs = 5.0 # the number of training epochs\n",
"warmup_proportion = 0.1\n",
"weight_decay = 0.01\n",
"adam_epsilon = 1e-8\n",
"max_grad_norm = 1.0\n",
"no_cuda = False\n",
"local_rank = -1\n",
"seed = 42\n",
"gradient_accumulation_steps = 1\n",
"fp16 = False\n",
"fp16_opt_level = \"01\"\n",
"loss_scale = 0.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f4ced0f",
"metadata": {},
"outputs": [],
"source": [
"processors = {\"ner\":NerProcessor}\n",
"\n",
"if local_rank ==-1 or no_cuda:\n",
" device = torch.device(\"cuda\" if torch.cuda.is_available() and not no_cuda else \"cpu\")\n",
" n_gpu = torch.cuda.device_count()\n",
"# Use gpu or not\n",
"if cfg.use_gpu and torch.cuda.is_available():\n",
" device = torch.device('cuda', cfg.gpu_id)\n",
"else:\n",
" torch.cuda.set_device(local_rank)\n",
" device = torch.device(\"cuda\", local_rank)\n",
" n_gpu = 1\n",
" # Initializes the distributed backend which will take care of sychronizing nodes/GPUs\n",
" torch.distributed.init_process_group(backend='nccl')\n",
" device = torch.device('cpu')\n",
"\n",
"logger.info(\"device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}\".format(device, n_gpu, bool(local_rank != -1), fp16))\n",
"if cfg.gradient_accumulation_steps < 1:\n",
" raise ValueError(\"Invalid gradient_accumulation_steps parameter: {}, should be >= 1\".format(cfg.gradient_accumulation_steps))\n",
"\n",
"train_batch_size = train_batch_size // gradient_accumulation_steps\n",
"cfg.train_batch_size = cfg.train_batch_size // cfg.gradient_accumulation_steps\n",
"\n",
"random.seed(seed)\n",
"np.random.seed(seed)\n",
"torch.manual_seed(seed)\n",
"random.seed(cfg.seed)\n",
"np.random.seed(cfg.seed)\n",
"torch.manual_seed(cfg.seed)\n",
"\n",
"if os.path.exists(output_dir) and os.listdir(output_dir) and do_train:\n",
" raise ValueError(\"Output directory ({}) already exists and is not empty.\".format(output_dir))\n",
"if not os.path.exists(output_dir):\n",
" os.makedirs(output_dir)\n",
"if not cfg.do_train and not cfg.do_eval:\n",
" raise ValueError(\"At least one of `do_train` or `do_eval` must be True.\")\n",
"\n",
"task_name = task_name.lower()\n",
"# Checkpoints\n",
"if os.path.exists(cfg.output_dir) and os.listdir(cfg.output_dir) and cfg.do_train:\n",
" raise ValueError(\"Output directory ({}) already exists and is not empty.\".format(cfg.output_dir))\n",
"if not os.path.exists(cfg.output_dir):\n",
" os.makedirs(cfg.output_dir)\n",
"\n",
"processor = processors[task_name]()\n",
"# Preprocess the input dataset\n",
"processor = NerProcessor()\n",
"label_list = processor.get_labels()\n",
"num_labels = len(label_list) + 1\n",
"\n",
"tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=do_lower_case)\n",
"# Prepare the model\n",
"tokenizer = BertTokenizer.from_pretrained(cfg.bert_model, do_lower_case=cfg.do_lower_case)\n",
"\n",
"train_examples = None\n",
"num_train_optimization_steps = 0\n",
"if do_train:\n",
" train_examples = processor.get_train_examples(data_dir)\n",
" num_train_optimization_steps = int(len(train_examples) / train_batch_size / gradient_accumulation_steps) * num_train_epochs\n",
" if local_rank != -1:\n",
" num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()\n",
"\n",
"if local_rank not in [-1, 0]:\n",
" torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "43846012",
"metadata": {},
"outputs": [],
"source": [
"# Prepare model\n",
"config = BertConfig.from_pretrained(bert_model, num_labels=num_labels, finetuning_task=task_name)\n",
"model = Ner.from_pretrained(bert_model, from_tf = False, config = config)\n",
"\n",
"if local_rank == 0:\n",
" torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab\n",
"if cfg.do_train:\n",
" train_examples = processor.get_train_examples(cfg.data_dir)\n",
" num_train_optimization_steps = int(len(train_examples) / cfg.train_batch_size / cfg.gradient_accumulation_steps) * cfg.num_train_epochs\n",
"\n",
"config = BertConfig.from_pretrained(cfg.bert_model, num_labels=num_labels, finetuning_task=cfg.task_name)\n",
"model = TrainNer.from_pretrained(cfg.bert_model,from_tf = False,config = config)\n",
"model.to(device)\n",
"\n",
"param_optimizer = list(model.named_parameters())\n",
"no_decay = ['bias','LayerNorm.weight']\n",
"optimizer_grouped_parameters = [\n",
" {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},\n",
" {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': cfg.weight_decay},\n",
" {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
" ]\n",
"warmup_steps = int(warmup_proportion * num_train_optimization_steps)\n",
"optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)\n",
"warmup_steps = int(cfg.warmup_proportion * num_train_optimization_steps)\n",
"optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.learning_rate, eps=cfg.adam_epsilon)\n",
"scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=num_train_optimization_steps)\n",
"if fp16:\n",
" try:\n",
" from apex import amp\n",
" except ImportError:\n",
" raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.\")\n",
" model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)\n",
"\n",
"# multi-gpu training (should be after apex fp16 initialization)\n",
"if n_gpu > 1:\n",
" model = torch.nn.DataParallel(model)\n",
"\n",
"if local_rank != -1:\n",
" model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],\n",
" output_device=local_rank,\n",
" find_unused_parameters=True)\n",
"global_step = 0\n",
"nb_tr_steps = 0\n",
"tr_loss = 0\n",
@@ -538,18 +316,12 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a8e71425",
"id": "95e37ce2",
"metadata": {},
"outputs": [],
"source": [
"# Train model\n",
"if do_train:\n",
" train_features = convert_examples_to_features(\n",
" train_examples, label_list, max_seq_length, tokenizer)\n",
" logger.info(\"***** Running training *****\")\n",
" logger.info(\" Num examples = %d\", len(train_examples))\n",
" logger.info(\" Batch size = %d\", train_batch_size)\n",
" logger.info(\" Num steps = %d\", num_train_optimization_steps)\n",
"if cfg.do_train:\n",
" train_features = convert_examples_to_features(train_examples, label_list, cfg.max_seq_length, tokenizer)\n",
" all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)\n",
" all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)\n",
" all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)\n",
@@ -557,37 +329,29 @@
" all_valid_ids = torch.tensor([f.valid_ids for f in train_features], dtype=torch.long)\n",
" all_lmask_ids = torch.tensor([f.label_mask for f in train_features], dtype=torch.long)\n",
" train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids,all_valid_ids,all_lmask_ids)\n",
" if local_rank == -1:\n",
" train_sampler = RandomSampler(train_data)\n",
" else:\n",
" train_sampler = DistributedSampler(train_data)\n",
" train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)\n",
" train_sampler = RandomSampler(train_data)\n",
"\n",
" train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=cfg.train_batch_size)\n",
"\n",
" model.train()\n",
" for _ in trange(int(num_train_epochs), desc=\"Epoch\"):\n",
"\n",
" for _ in trange(int(cfg.num_train_epochs), desc=\"Epoch\"):\n",
" tr_loss = 0\n",
" nb_tr_examples, nb_tr_steps = 0, 0\n",
" for step, batch in enumerate(tqdm(train_dataloader, desc=\"Iteration\")):\n",
" batch = tuple(t.to(device) for t in batch)\n",
" input_ids, input_mask, segment_ids, label_ids, valid_ids,l_mask = batch\n",
" loss = model(input_ids, segment_ids, input_mask, label_ids,valid_ids,l_mask)\n",
" if n_gpu > 1:\n",
" loss = loss.mean() # mean() to average on multi-gpu.\n",
" if gradient_accumulation_steps > 1:\n",
" loss = loss / gradient_accumulation_steps\n",
" if cfg.gradient_accumulation_steps > 1:\n",
" loss = loss / cfg.gradient_accumulation_steps\n",
"\n",
" if fp16:\n",
" with amp.scale_loss(loss, optimizer) as scaled_loss:\n",
" scaled_loss.backward()\n",
" torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)\n",
" else:\n",
" loss.backward()\n",
" torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)\n",
" loss.backward()\n",
" torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)\n",
"\n",
" tr_loss += loss.item()\n",
" nb_tr_examples += input_ids.size(0)\n",
" nb_tr_steps += 1\n",
" if (step + 1) % gradient_accumulation_steps == 0:\n",
" if (step + 1) % cfg.gradient_accumulation_steps == 0:\n",
" optimizer.step()\n",
" scheduler.step() # Update learning rate schedule\n",
" model.zero_grad()\n",
@@ -595,16 +359,17 @@
"\n",
" # Save a trained model and the associated configuration\n",
" model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self\n",
" model_to_save.save_pretrained(output_dir)\n",
" tokenizer.save_pretrained(output_dir)\n",
" model_to_save.save_pretrained(cfg.output_dir)\n",
" tokenizer.save_pretrained(cfg.output_dir)\n",
" label_map = {i : label for i, label in enumerate(label_list,1)}\n",
" model_config = {\"bert_model\":bert_model,\"do_lower\":do_lower_case,\"max_seq_length\":max_seq_length,\"num_labels\":len(label_list)+1,\"label_map\":label_map}\n",
" json.dump(model_config,open(os.path.join(output_dir,\"model_config.json\"),\"w\"))\n",
" model_config = {\"bert_model\":cfg.bert_model,\"do_lower\":cfg.do_lower_case, \"max_seq_length\":cfg.max_seq_length,\"num_labels\":len(label_list)+1,\"label_map\":label_map}\n",
" json.dump(model_config,open(os.path.join(cfg.output_dir,\"model_config.json\"),\"w\"))\n",
" # Load a trained model and config that you have fine-tuned\n",
"else:\n",
" # Load a trained model and vocabulary that you have fine-tuned\n",
" model = Ner.from_pretrained(output_dir)\n",
" tokenizer = BertTokenizer.from_pretrained(output_dir, do_lower_case=do_lower_case)\n",
" model = TrainNer.from_pretrained(cfg.output_dir)\n",
" tokenizer = BertTokenizer.from_pretrained(cfg.output_dir, do_lower_case=cfg.do_lower_case)\n",
"\n",
"model.to(device)"
]
},
@@ -619,22 +384,18 @@
{
"cell_type": "code",
"execution_count": null,
"id": "a0fecf19",
"id": "5cf1972a",
"metadata": {},
"outputs": [],
"source": [
"# Evaluation\n",
"if do_eval and (local_rank == -1 or torch.distributed.get_rank() == 0):\n",
" if eval_on == \"dev\":\n",
" eval_examples = processor.get_dev_examples(data_dir)\n",
" elif eval_on == \"test\":\n",
" eval_examples = processor.get_test_examples(data_dir)\n",
"if cfg.do_eval:\n",
" if cfg.eval_on == \"dev\":\n",
" eval_examples = processor.get_dev_examples(cfg.data_dir)\n",
" elif cfg.eval_on == \"test\":\n",
" eval_examples = processor.get_test_examples(cfg.data_dir)\n",
" else:\n",
" raise ValueError(\"eval on dev or test set only\")\n",
" eval_features = convert_examples_to_features(eval_examples, label_list, max_seq_length, tokenizer)\n",
" logger.info(\"***** Running evaluation *****\")\n",
" logger.info(\" Num examples = %d\", len(eval_examples))\n",
" logger.info(\" Batch size = %d\", eval_batch_size)\n",
" eval_features = convert_examples_to_features(eval_examples, label_list, cfg.max_seq_length, tokenizer)\n",
" all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)\n",
" all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)\n",
" all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)\n",
@@ -644,7 +405,7 @@
" eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids,all_valid_ids,all_lmask_ids)\n",
" # Run prediction for full data\n",
" eval_sampler = SequentialSampler(eval_data)\n",
" eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=eval_batch_size)\n",
" eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=cfg.eval_batch_size)\n",
" model.eval()\n",
" eval_loss, eval_accuracy = 0, 0\n",
" nb_eval_steps, nb_eval_examples = 0, 0\n",
@@ -683,7 +444,7 @@
"\n",
" report = classification_report(y_true, y_pred,digits=4)\n",
" logger.info(\"\\n%s\", report)\n",
" output_eval_file = os.path.join(output_dir, \"eval_results.txt\")\n",
" output_eval_file = os.path.join(cfg.output_dir, \"eval_results.txt\")\n",
" with open(output_eval_file, \"w\") as writer:\n",
" logger.info(\"***** Eval results *****\")\n",
" logger.info(\"\\n%s\", report)\n",
@@ -701,174 +462,16 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f78c5102",
"id": "b0e33a75",
"metadata": {},
"outputs": [],
"source": [
"# make changes for better representation and display of 'entity detected' and their 'entity types' for the given sentence to test or inference\n",
"from nltk import word_tokenize\n",
"from collections import OrderedDict\n",
"model = InferNer(\"checkpoints/\")\n",
"text = cfg.text\n",
"\n",
"class BertNer(BertForTokenClassification):\n",
"\n",
" def forward(self, input_ids, token_type_ids=None, attention_mask=None, valid_ids=None):\n",
" sequence_output = self.bert(input_ids, token_type_ids, attention_mask, head_mask=None)[0]\n",
" batch_size,max_len,feat_dim = sequence_output.shape\n",
" valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device='cuda' if torch.cuda.is_available() else 'cpu')\n",
" for i in range(batch_size):\n",
" jj = -1\n",
" for j in range(max_len):\n",
" if valid_ids[i][j].item() == 1:\n",
" jj += 1\n",
" valid_output[i][jj] = sequence_output[i][j]\n",
" sequence_output = self.dropout(valid_output)\n",
" logits = self.classifier(sequence_output)\n",
" return logits\n",
"\n",
"class Ner:\n",
"\n",
" def __init__(self,model_dir: str):\n",
" self.model , self.tokenizer, self.model_config = self.load_model(model_dir)\n",
" self.label_map = self.model_config[\"label_map\"]\n",
" self.max_seq_length = self.model_config[\"max_seq_length\"]\n",
" self.label_map = {int(k):v for k,v in self.label_map.items()}\n",
" self.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
" self.model = self.model.to(self.device)\n",
" self.model.eval()\n",
"\n",
" def load_model(self, model_dir: str, model_config: str = \"model_config.json\"):\n",
" model_config = os.path.join(model_dir,model_config)\n",
" model_config = json.load(open(model_config))\n",
" model = BertNer.from_pretrained(model_dir)\n",
" tokenizer = BertTokenizer.from_pretrained(model_dir, do_lower_case=model_config[\"do_lower\"])\n",
" return model, tokenizer, model_config\n",
"\n",
" def tokenize(self, text: str):\n",
" \"\"\" tokenize input\"\"\"\n",
" words = word_tokenize(text)\n",
" tokens = []\n",
" valid_positions = []\n",
" for i,word in enumerate(words):\n",
" token = self.tokenizer.tokenize(word)\n",
" tokens.extend(token)\n",
" for i in range(len(token)):\n",
" if i == 0:\n",
" valid_positions.append(1)\n",
" else:\n",
" valid_positions.append(0)\n",
" return tokens, valid_positions\n",
"\n",
" def preprocess(self, text: str):\n",
" \"\"\" preprocess \"\"\"\n",
" tokens, valid_positions = self.tokenize(text)\n",
" ## insert \"[CLS]\"\n",
" tokens.insert(0,\"[CLS]\")\n",
" valid_positions.insert(0,1)\n",
" ## insert \"[SEP]\"\n",
" tokens.append(\"[SEP]\")\n",
" valid_positions.append(1)\n",
" segment_ids = []\n",
" for i in range(len(tokens)):\n",
" segment_ids.append(0)\n",
" input_ids = self.tokenizer.convert_tokens_to_ids(tokens)\n",
" input_mask = [1] * len(input_ids)\n",
" while len(input_ids) < self.max_seq_length:\n",
" input_ids.append(0)\n",
" input_mask.append(0)\n",
" segment_ids.append(0)\n",
" valid_positions.append(0)\n",
" return input_ids,input_mask,segment_ids,valid_positions\n",
"\n",
" def predict(self, text: str):\n",
" input_ids,input_mask,segment_ids,valid_ids = self.preprocess(text)\n",
" input_ids = torch.tensor([input_ids],dtype=torch.long,device=self.device)\n",
" input_mask = torch.tensor([input_mask],dtype=torch.long,device=self.device)\n",
" segment_ids = torch.tensor([segment_ids],dtype=torch.long,device=self.device)\n",
" valid_ids = torch.tensor([valid_ids],dtype=torch.long,device=self.device)\n",
" with torch.no_grad():\n",
" logits = self.model(input_ids, segment_ids, input_mask,valid_ids)\n",
" logits = F.softmax(logits,dim=2)\n",
" logits_label = torch.argmax(logits,dim=2)\n",
" logits_label = logits_label.detach().cpu().numpy().tolist()[0]\n",
"\n",
" logits_confidence = [values[label].item() for values,label in zip(logits[0],logits_label)]\n",
"\n",
" logits = []\n",
" pos = 0\n",
" for index,mask in enumerate(valid_ids[0]):\n",
" if index == 0:\n",
" continue\n",
" if mask == 1:\n",
" logits.append((logits_label[index-pos],logits_confidence[index-pos]))\n",
" else:\n",
" pos += 1\n",
" logits.pop()\n",
"\n",
" labels = [(self.label_map[label],confidence) for label,confidence in logits]\n",
" words = word_tokenize(text)\n",
" assert len(labels) == len(words)\n",
"\n",
" result = []\n",
" for word, (label, confidence) in zip(words, labels):\n",
" if label!='O':\n",
" result.append((word,label))\n",
" tmp = []\n",
" tag = OrderedDict()\n",
" tag['PER'] = []\n",
" tag['LOC'] = []\n",
" tag['ORG'] = []\n",
" tag['MISC'] = []\n",
" \n",
" for i, (word, label) in enumerate(result):\n",
" if label=='B-PER' or label=='B-LOC' or label=='B-ORG' or label=='B-MISC':\n",
" if i==0:\n",
" tmp.append(word)\n",
" else:\n",
" wordstype = result[i-1][1][2:]\n",
" tag[wordstype].append(' '.join(tmp))\n",
" tmp.clear()\n",
" tmp.append(word)\n",
" elif i==len(result)-1:\n",
" tmp.append(word)\n",
" wordstype = result[i][1][2:]\n",
" tag[wordstype].append(' '.join(tmp))\n",
" else:\n",
" tmp.append(word)\n",
"\n",
" return tag\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6fec9c7b",
"metadata": {},
"outputs": [],
"source": [
"# Run below command for import and download 'nltk' library as it is important for predictions of entities of the sentence\n",
"import nltk\n",
"nltk.download('punkt')\n",
"\n",
"# If it's too slow to download 'nltk_data', we offer it in 'data/nltk_data' and you can use it by running the following code.\n",
"# import nltk\n",
"# nltk.data.path.insert(0,'./data/nltk_data')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5ce89e90",
"metadata": {},
"outputs": [],
"source": [
"model = Ner(\"out_ner/\")\n",
"\n",
"# Text to be NERed\n",
"text = \"Irene, a master student in Zhejiang University, Hangzhou, is traveling in Warsaw for Chopin Music Festival.\"\n",
"\n",
"print(\"The text to be NERed:\")\n",
"print(\"NER句子:\")\n",
"print(text)\n",
"print('Results of NER:')\n",
"print('NER结果:')\n",
"\n",
"result = model.predict(text)\n",
"for k,v in result.items():\n",
@@ -879,9 +482,7 @@
" elif k=='LOC':\n",
" print('Location')\n",
" elif k=='ORG':\n",
" print('Organization')\n",
" elif k=='MISC':\n",
" print('Miscellaneous')"
" print('Organization')"
]
},
{
@@ -889,14 +490,14 @@
"id": "97fc8159",
"metadata": {},
"source": [
"This demo does not include parameter adjustment. If you can interested in this, you can go to [deepke](http://openkg.cn/tool/deepke)\n",
"This demo does not include parameter adjustment. If you are interested in this, you can go to [deepke](http://openkg.cn/tool/deepke)\n",
"Warehouse, download and use more models:)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -910,7 +511,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.11"
"version": "3.8.11"
},
"latex_envs": {
"LaTeX_envs_menu_present": true,