test

This commit is contained in:
parent aa7468660e
commit 2bbb9fb8a3

README.md | 12
@@ -45,7 +45,7 @@ DeepKE包括了三个模块,可以进行关系抽取、实体命名识别以

> python == 3.8

- torch >= 1.2
- torch >= 1.5
- hydra-core == 1.0.6
- tensorboard >= 2.0
- matplotlib >= 3.1

@@ -62,11 +62,11 @@ DeepKE包括了三个模块,可以进行关系抽取、实体命名识别以

1. **命名实体识别NER**

   数据为txt文件,样式范例为:

   | Sentence | Person | Location | Organization |
   | :--: | :--: | :--: | :--: |
   | 本报北京9月4日讯记者杨涌报道:部分省区人民日报宣传发行工作座谈会9月3日在4日在京举行。 | 杨涌 | 北京,京 | 人民日报 |
   | 《红楼梦》是中央电视台和中国电视剧制作中心根据中国古典文学名著《红楼梦》摄制于1987年的一部古装连续剧,由王扶林导演,周汝昌、王蒙、周岭等多位红学家参与制作。 | 王扶林,周汝昌,王蒙,周岭 | 中国 | 中央电视台,中国电视剧制作中心 |
   | 秦始皇兵马俑位于陕西省西安市,1961年被国务院公布为第一批全国重点文物保护单位,是世界八大奇迹之一。 | 秦始皇 | 陕西省西安市 | 国务院 |

   | Sentence | Person | Location | Organization |
   | :--: | :--: | :--: | :--: |
   | 本报北京9月4日讯记者杨涌报道:部分省区人民日报宣传发行工作座谈会9月3日在4日在京举行。 | 杨涌 | 北京 | 人民日报 |
   | 《红楼梦》是中央电视台和中国电视剧制作中心根据中国古典文学名著《红楼梦》摄制于1987年的一部古装连续剧,由王扶林导演,周汝昌、王蒙、周岭等多位红学家参与制作。 | 王扶林,周汝昌,王蒙,周岭 | 中国 | 中央电视台,中国电视剧制作中心 |
   | 秦始皇兵马俑位于陕西省西安市,1961年被国务院公布为第一批全国重点文物保护单位,是世界八大奇迹之一。 | 秦始皇 | 陕西省,西安市 | 国务院 |

   具体流程请进入详细的README中:
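The tables above pair each raw sentence with the Person/Location/Organization mentions extracted from it. As an illustration only (this code is not part of the commit), a minimal sketch of how character-level BIO predictions can be collapsed into such entity lists, assuming tags of the form B-PER/I-PER, B-LOC/I-LOC, B-ORG/I-ORG:

# Illustration only: collapse character-level BIO tags into entity strings,
# producing the kind of Person/Location/Organization columns shown above.
def bio_to_entities(chars, tags):
    """chars: list of characters; tags: list of BIO tags such as 'B-PER', 'I-PER', 'O'."""
    entities = {"PER": [], "LOC": [], "ORG": []}
    current, current_type = [], None
    for ch, tag in zip(chars, tags):
        if tag.startswith("B-"):
            if current:
                entities.setdefault(current_type, []).append("".join(current))
            current, current_type = [ch], tag[2:]
        elif tag.startswith("I-") and current_type == tag[2:]:
            current.append(ch)
        else:
            if current:
                entities.setdefault(current_type, []).append("".join(current))
            current, current_type = [], None
    if current:
        entities.setdefault(current_type, []).append("".join(current))
    return entities

# Example: "杨涌" comes out as a PER mention.
print(bio_to_entities(list("记者杨涌报道"), ["O", "O", "B-PER", "I-PER", "O", "O"]))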
@@ -5,11 +5,11 @@

> python == 3.8

- pytorch-transformers == 1.2.0
- torch == 1.2.0
- torch == 1.5.0
- hydra-core == 1.0.6
- seqeval == 0.0.5
- tqdm == 4.31.1
- nltk == 3.4.5
- matplotlib == 3.4.1
- deepke
@@ -1,8 +1,7 @@
data_dir: "data/"
bert_model: "bert-base-chinese" # ["bert-base-chinese", "bert-base-cased"]
language: "cn" # ["cn", "en"]
bert_model: "bert-base-chinese"
task_name: "ner"
output_dir: "checkpoint"
output_dir: "checkpoints"
max_seq_length: 128
do_train: True
do_eval: True
@@ -11,17 +10,16 @@ do_lower_case: True
train_batch_size: 32
eval_batch_size: 8
learning_rate: 5e-5
num_train_epochs: 1 # the number of training epochs
num_train_epochs: 3 # the number of training epochs
warmup_proportion: 0.1
weight_decay: 0.01
adam_epsilon: 1e-8
max_grad_norm: 1.0
no_cuda: False
use_gpu: True # use gpu or not
gpu_id: 1
local_rank: -1
seed: 42
gradient_accumulation_steps: 1
fp16: False
fp16_opt_level: "01"
loss_scale: 0.0
use_gpu: True
gpu_id: 1
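These YAML keys are consumed through Hydra: the training and inference scripts further down in this diff decorate their main() with @hydra.main. A minimal sketch of how the values above would be read, assuming hydra-core is installed and the file sits at conf/config.yaml next to the script:

# Minimal sketch (assumption, not the commit's code): reading conf/config.yaml via Hydra.
import hydra
from omegaconf import DictConfig

@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    # mirrors the use_gpu / gpu_id keys above
    device = f"cuda:{cfg.gpu_id}" if cfg.use_gpu else "cpu"
    print(cfg.bert_model, cfg.train_batch_size, cfg.num_train_epochs, device)

if __name__ == "__main__":
    main()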
@@ -1,28 +0,0 @@
## People's Daily(人民日报) dataset

### Task
Named Entity Recognition

### Description
**Tags**: LOC(地名), ORG(机构名), PER(人名)

**Tag Strategy**: BIO

**Split**: '*space*' (北 B-LOC)

**Data Size**:

Train data set ( [example.train](example.train) ):

|句数|字符数|LOC数|ORG数|PER数|
|:-:|:-:|:-:|:-:|:-:|
|20864|979180|16571|9277|8144|

Dev data set ( [example.dev](example.dev) ):

|句数|字符数|LOC数|ORG数|PER数|
|:-:|:-:|:-:|:-:|:-:|
|2318|109870|1951|984|884|

Test data set ( [example.test](example.test) )

|句数|字符数|LOC数|ORG数|PER数|
|:-:|:-:|:-:|:-:|:-:|
|4636|219197|3658|2185|1864|

**Reference**:
<https://github.com/zjy-ucas/ChineseNER>
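The dataset README above describes the on-disk layout: one character and its BIO tag per line, separated by a space (e.g. 北 B-LOC), with a blank line between sentences. A minimal reader for that layout, shown only as an illustration; the readfile helper in the training script later in this diff plays the same role:

# Illustration only: read a People's Daily-style BIO file, where each line is
# "<char> <tag>" (e.g. "北 B-LOC") and sentences are separated by blank lines.
def read_bio_file(path):
    sentences = []          # list of (chars, tags) pairs
    chars, tags = [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.rstrip("\n")
            if not line.strip():            # blank line ends the current sentence
                if chars:
                    sentences.append((chars, tags))
                    chars, tags = [], []
                continue
            token, tag = line.split(" ", 1)
            chars.append(token)
            tags.append(tag.strip())
    if chars:                               # flush a trailing sentence without a blank line
        sentences.append((chars, tags))
    return sentences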
Binary file not shown.
@@ -1,98 +0,0 @@
Pretrained Punkt Models -- Jan Strunk (New version trained after issues 313 and 514 had been corrected)

Most models were prepared using the test corpora from Kiss and Strunk (2006). Additional models have
been contributed by various people using NLTK for sentence boundary detection.

For information about how to use these models, please confer the tokenization HOWTO:
http://nltk.googlecode.com/svn/trunk/doc/howto/tokenize.html
and chapter 3.8 of the NLTK book:
http://nltk.googlecode.com/svn/trunk/doc/book/ch03.html#sec-segmentation

There are pretrained tokenizers for the following languages
(file -- language -- source: contents -- size of training corpus in tokens -- model contributed by):

czech.pickle      -- Czech      -- Multilingual Corpus 1 (ECI): Lidove Noviny, Literarni Noviny -- ~345,000 -- Jan Strunk / Tibor Kiss
danish.pickle     -- Danish     -- Avisdata CD-Rom Ver. 1.1. 1995 (Berlingske Avisdata, Copenhagen): Berlingske Tidende, Weekend Avisen -- ~550,000 -- Jan Strunk / Tibor Kiss
dutch.pickle      -- Dutch      -- Multilingual Corpus 1 (ECI): De Limburger -- ~340,000 -- Jan Strunk / Tibor Kiss
english.pickle    -- English (American) -- Penn Treebank (LDC): Wall Street Journal -- ~469,000 -- Jan Strunk / Tibor Kiss
estonian.pickle   -- Estonian   -- University of Tartu, Estonia: Eesti Ekspress -- ~359,000 -- Jan Strunk / Tibor Kiss
finnish.pickle    -- Finnish    -- Finnish Parole Corpus, Finnish Text Bank (Suomen Kielen Tekstipankki), Finnish Center for IT Science (CSC): books and major national newspapers -- ~364,000 -- Jan Strunk / Tibor Kiss
french.pickle     -- French (European) -- Multilingual Corpus 1 (ECI): Le Monde -- ~370,000 -- Jan Strunk / Tibor Kiss
german.pickle     -- German     -- Neue Zürcher Zeitung AG (Switzerland): Neue Zürcher Zeitung CD-ROM (uses "ss" instead of "ß") -- ~847,000 -- Jan Strunk / Tibor Kiss
greek.pickle      -- Greek      -- Efstathios Stamatatos: To Vima (TO BHMA) -- ~227,000 -- Jan Strunk / Tibor Kiss
italian.pickle    -- Italian    -- Multilingual Corpus 1 (ECI): La Stampa, Il Mattino -- ~312,000 -- Jan Strunk / Tibor Kiss
norwegian.pickle  -- Norwegian (Bokmål and Nynorsk) -- Centre for Humanities Information Technologies, Bergen: Bergens Tidende -- ~479,000 -- Jan Strunk / Tibor Kiss
polish.pickle     -- Polish     -- Polish National Corpus (http://www.nkjp.pl/): literature, newspapers, etc. -- ~1,000,000 -- Krzysztof Langner
portuguese.pickle -- Portuguese (Brazilian) -- CETENFolha Corpus (Linguateca): Folha de São Paulo -- ~321,000 -- Jan Strunk / Tibor Kiss
slovene.pickle    -- Slovene    -- TRACTOR, Slovene Academy for Arts and Sciences: Delo -- ~354,000 -- Jan Strunk / Tibor Kiss
spanish.pickle    -- Spanish (European) -- Multilingual Corpus 1 (ECI): Sur -- ~353,000 -- Jan Strunk / Tibor Kiss
swedish.pickle    -- Swedish    -- Multilingual Corpus 1 (ECI): Dagens Nyheter (and some other texts) -- ~339,000 -- Jan Strunk / Tibor Kiss
turkish.pickle    -- Turkish    -- METU Turkish Corpus (Türkçe Derlem Projesi), University of Ankara: Milliyet -- ~333,000 -- Jan Strunk / Tibor Kiss

The corpora contained about 400,000 tokens on average and mostly consisted of newspaper text converted to
Unicode using the codecs module.

Kiss, Tibor and Strunk, Jan (2006): Unsupervised Multilingual Sentence Boundary Detection.
Computational Linguistics 32: 485-525.

---- Training Code ----

# import punkt
import nltk.tokenize.punkt

# Make a new Tokenizer
tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer()

# Read in training corpus (one example: Slovene)
import codecs
text = codecs.open("slovene.plain","Ur","iso-8859-2").read()

# Train tokenizer
tokenizer.train(text)

# Dump pickled tokenizer
import pickle
out = open("slovene.pickle","wb")
pickle.dump(tokenizer, out)
out.close()

---------
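The training snippet quoted in that README predates current Python releases: the "Ur" mode it passes to codecs.open no longer exists. A rough modern equivalent, offered only as a hedged sketch and assuming NLTK is installed and "corpus.txt" is a plain-text corpus you supply, would be:

# Hedged sketch: punkt training on a current Python/NLTK, assuming "corpus.txt"
# is a plain-text training corpus you provide (its encoding depends on the corpus).
import pickle
import nltk.tokenize.punkt

tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer()

with open("corpus.txt", encoding="iso-8859-2") as f:
    text = f.read()

tokenizer.train(text)

with open("corpus.pickle", "wb") as out:
    pickle.dump(tokenizer, out)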
Binary file not shown.
File diff suppressed because it is too large
Binary file not shown.
@@ -1,169 +1,10 @@
"""BERT NER Inference."""

from __future__ import absolute_import, division, print_function

import json
import os

import torch
import torch.nn.functional as F
from nltk import word_tokenize
from pytorch_transformers import (BertConfig, BertForTokenClassification,
                                  BertTokenizer)
from collections import OrderedDict
import argparse

import nltk
nltk.data.path.insert(0,'./data/nltk_data')

from deepke.name_entity_re.standard import *
import hydra
from hydra import utils


class BertNer(BertForTokenClassification):

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, valid_ids=None):
        sequence_output = self.bert(input_ids, token_type_ids, attention_mask, head_mask=None)[0]
        batch_size,max_len,feat_dim = sequence_output.shape
        valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device='cuda' if torch.cuda.is_available() else 'cpu')
        for i in range(batch_size):
            jj = -1
            for j in range(max_len):
                if valid_ids[i][j].item() == 1:
                    jj += 1
                    valid_output[i][jj] = sequence_output[i][j]
        sequence_output = self.dropout(valid_output)
        logits = self.classifier(sequence_output)
        return logits


class InferNer:

    def __init__(self,model_dir:str,language):
        self.model , self.tokenizer, self.model_config = self.load_model(model_dir)
        self.label_map = self.model_config["label_map"]
        self.max_seq_length = self.model_config["max_seq_length"]
        self.label_map = {int(k):v for k,v in self.label_map.items()}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)
        self.model.eval()
        self.language = language

    def load_model(self, model_dir: str, model_config: str = "model_config.json"):
        model_config = os.path.join(model_dir,model_config)
        model_config = json.load(open(model_config))
        model = BertNer.from_pretrained(model_dir)
        tokenizer = BertTokenizer.from_pretrained(model_dir, do_lower_case=model_config["do_lower"])
        return model, tokenizer, model_config

    def tokenize(self, text: str):
        """ tokenize input"""
        words = list(text)
        tokens = []
        valid_positions = []
        for i,word in enumerate(words):
            token = self.tokenizer.tokenize(word)
            tokens.extend(token)
            for i in range(len(token)):
                if i == 0:
                    valid_positions.append(1)
                else:
                    valid_positions.append(0)
        return tokens, valid_positions

    def preprocess(self, text: str):
        """ preprocess """
        tokens, valid_positions = self.tokenize(text)
        ## insert "[CLS]"
        tokens.insert(0,"[CLS]")
        valid_positions.insert(0,1)
        ## insert "[SEP]"
        tokens.append("[SEP]")
        valid_positions.append(1)
        segment_ids = []
        for i in range(len(tokens)):
            segment_ids.append(0)
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)
        while len(input_ids) < self.max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)
            valid_positions.append(0)
        return input_ids,input_mask,segment_ids,valid_positions

    def predict(self, text: str):
        input_ids,input_mask,segment_ids,valid_ids = self.preprocess(text)
        input_ids = torch.tensor([input_ids],dtype=torch.long,device=self.device)
        input_mask = torch.tensor([input_mask],dtype=torch.long,device=self.device)
        segment_ids = torch.tensor([segment_ids],dtype=torch.long,device=self.device)
        valid_ids = torch.tensor([valid_ids],dtype=torch.long,device=self.device)
        with torch.no_grad():
            logits = self.model(input_ids, segment_ids, input_mask,valid_ids)
        logits = F.softmax(logits,dim=2)
        logits_label = torch.argmax(logits,dim=2)
        logits_label = logits_label.detach().cpu().numpy().tolist()[0]

        logits_confidence = [values[label].item() for values,label in zip(logits[0],logits_label)]

        logits = []
        pos = 0
        for index,mask in enumerate(valid_ids[0]):
            if index == 0:
                continue
            if mask == 1:
                logits.append((logits_label[index-pos],logits_confidence[index-pos]))
            else:
                pos += 1
        logits.pop()

        labels = [(self.label_map[label],confidence) for label,confidence in logits]
        if self.language == 'bert-base-chinese':
            words = list(text)
        else:
            nltk.download('punkt')
            words = word_tokenize(text)
        assert len(labels) == len(words)

        result = []
        for word, (label, confidence) in zip(words, labels):
            if label!='O':
                result.append((word,label))
        tmp = []
        tag = OrderedDict()
        tag['PER'] = []
        tag['LOC'] = []
        tag['ORG'] = []
        tag['MISC'] = []

        for i, (word, label) in enumerate(result):
            if label=='B-PER' or label=='B-LOC' or label=='B-ORG' or label=='B-MISC':
                if i==0:
                    tmp.append(word)
                else:
                    wordstype = result[i-1][1][2:]
                    if self.language=='bert-base-chinese':
                        tag[wordstype].append(''.join(tmp))
                    else:
                        tag[wordstype].append(' '.join(tmp))
                    tmp.clear()
                    tmp.append(word)
            elif i==len(result)-1:
                tmp.append(word)
                wordstype = result[i][1][2:]
                if self.language=='bert-base-chinese':
                    tag[wordstype].append(''.join(tmp))
                else:
                    tag[wordstype].append(' '.join(tmp))
            else:
                tmp.append(word)

        return tag


@hydra.main(config_path="conf", config_name='config')
def main(cfg):
    model = InferNer(utils.get_original_cwd()+'/'+"checkpoint/", cfg.bert_model)
    model = InferNer(utils.get_original_cwd()+'/'+"checkpoints/")
    text = cfg.text

    print("NER句子:")

@@ -180,9 +21,7 @@ def main(cfg):

            print('Location')
        elif k=='ORG':
            print('Organization')
        elif k=='MISC':
            print('Miscellaneous')


if __name__ == "__main__":
    main()
@@ -1,5 +1,7 @@
pytorch-transformers==1.2.0
torch==1.2.0
torch==1.5.0
hydra-core==1.0.6
seqeval==0.0.5
tqdm==4.31.1
nltk==3.4.5
matplotlib==3.4.1
deepke
@@ -6,35 +6,31 @@ import logging
import os
import random
import sys

import numpy as np
import torch
import torch.nn.functional as F
from pytorch_transformers import (WEIGHTS_NAME, AdamW, BertConfig,
                                  BertForTokenClassification, BertTokenizer,
                                  WarmupLinearSchedule)
from pytorch_transformers import (WEIGHTS_NAME, AdamW, BertConfig, BertForTokenClassification, BertTokenizer, WarmupLinearSchedule)
from torch import nn
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from seqeval.metrics import classification_report
import hydra
from hydra import utils
from deepke.name_entity_re.standard import *

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)

import hydra
from hydra import utils

class TrainNer(BertForTokenClassification):

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,valid_ids=None,attention_mask_label=None):
        sequence_output = self.bert(input_ids, token_type_ids, attention_mask,head_mask=None)[0]
        batch_size,max_len,feat_dim = sequence_output.shape
        valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device=1) #device 如用gpu需要修改为cfg.gpu_id的值 不用则为cpu
        valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device=1) #device=cfg.gpu_id if use_gpu else 'cpu'
        for i in range(batch_size):
            jj = -1
            for j in range(max_len):

@@ -46,8 +42,6 @@ class TrainNer(BertForTokenClassification):

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=0)
            # Only keep active parts of the loss
            #attention_mask_label = None
            if attention_mask_label is not None:
                active_loss = attention_mask_label.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]

@@ -60,202 +54,17 @@ class TrainNer(BertForTokenClassification):
            return logits


class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs a InputExample.
        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id, valid_ids=None, label_mask=None):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id
        self.valid_ids = valid_ids
        self.label_mask = label_mask


def readfile(filename):
    '''
    read file
    '''
    f = open(filename)
    data = []
    sentence = []
    label= []
    for line in f:
        if len(line)==0 or line.startswith('-DOCSTART') or line[0]=="\n":
            if len(sentence) > 0:
                data.append((sentence,label))
                sentence = []
                label = []
            continue
        splits = line.split(' ')
        sentence.append(splits[0])
        label.append(splits[-1][:-1])

    if len(sentence) >0:
        data.append((sentence,label))
        sentence = []
        label = []
    return data


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        return readfile(input_file)


class NerProcessor(DataProcessor):
    """Processor for the dataset."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.txt")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "valid.txt")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.txt")), "test")

    def get_labels(self):
        return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "[CLS]", "[SEP]"]

    def _create_examples(self,lines,set_type):
        examples = []
        for i,(sentence,label) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = ' '.join(sentence)
            text_b = None
            label = label
            examples.append(InputExample(guid=guid,text_a=text_a,text_b=text_b,label=label))
        return examples


def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
    """Loads a data file into a list of `InputBatch`s."""

    label_map = {label : i for i, label in enumerate(label_list,1)}

    features = []
    for (ex_index,example) in enumerate(examples):
        textlist = example.text_a.split(' ')
        labellist = example.label
        tokens = []
        labels = []
        valid = []
        label_mask = []
        for i, word in enumerate(textlist):
            token = tokenizer.tokenize(word)
            tokens.extend(token)
            label_1 = labellist[i]
            for m in range(len(token)):
                if m == 0:
                    labels.append(label_1)
                    valid.append(1)
                    label_mask.append(1)
                else:
                    valid.append(0)
        if len(tokens) >= max_seq_length - 1:
            tokens = tokens[0:(max_seq_length - 2)]
            labels = labels[0:(max_seq_length - 2)]
            valid = valid[0:(max_seq_length - 2)]
            label_mask = label_mask[0:(max_seq_length - 2)]
        ntokens = []
        segment_ids = []
        label_ids = []
        ntokens.append("[CLS]")
        segment_ids.append(0)
        valid.insert(0,1)
        label_mask.insert(0,1)
        label_ids.append(label_map["[CLS]"])
        for i, token in enumerate(tokens):
            ntokens.append(token)
            segment_ids.append(0)
            if len(labels) > i:
                label_ids.append(label_map[labels[i]])
        ntokens.append("[SEP]")
        segment_ids.append(0)
        valid.append(1)
        label_mask.append(1)
        label_ids.append(label_map["[SEP]"])
        input_ids = tokenizer.convert_tokens_to_ids(ntokens)
        input_mask = [1] * len(input_ids)
        label_mask = [1] * len(label_ids)
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)
            label_ids.append(0)
            valid.append(1)
            label_mask.append(0)
        while len(label_ids) < max_seq_length:
            label_ids.append(0)
            label_mask.append(0)
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ids) == max_seq_length
        assert len(valid) == max_seq_length
        assert len(label_mask) == max_seq_length

        features.append(
            InputFeatures(input_ids=input_ids,
                          input_mask=input_mask,
                          segment_ids=segment_ids,
                          label_id=label_ids,
                          valid_ids=valid,
                          label_mask=label_mask))
    return features


@hydra.main(config_path="conf", config_name='config')
def main(cfg):
    processors = {"ner":NerProcessor}

    # Use gpu or not
    if cfg.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda', cfg.gpu_id)
    else:
        device = torch.device('cpu')

    if cfg.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            cfg.gradient_accumulation_steps))
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(cfg.gradient_accumulation_steps))

    cfg.train_batch_size = cfg.train_batch_size // cfg.gradient_accumulation_steps

@@ -266,28 +75,25 @@ def main(cfg):

    if not cfg.do_train and not cfg.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    # Checkpoints
    if os.path.exists(utils.get_original_cwd()+'/'+cfg.output_dir) and os.listdir(utils.get_original_cwd()+'/'+cfg.output_dir) and cfg.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(utils.get_original_cwd()+'/'+cfg.output_dir))
    if not os.path.exists(utils.get_original_cwd()+'/'+cfg.output_dir):
        os.makedirs(utils.get_original_cwd()+'/'+cfg.output_dir)

    task_name = cfg.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    # Preprocess the input dataset
    processor = NerProcessor()
    label_list = processor.get_labels()
    num_labels = len(label_list) + 1

    # Prepare the model
    tokenizer = BertTokenizer.from_pretrained(cfg.bert_model, do_lower_case=cfg.do_lower_case)

    train_examples = None
    num_train_optimization_steps = 0
    if cfg.do_train:
        train_examples = processor.get_train_examples(utils.get_original_cwd()+'/'+cfg.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / cfg.train_batch_size / cfg.gradient_accumulation_steps) * cfg.num_train_epochs
        num_train_optimization_steps = int(len(train_examples) / cfg.train_batch_size / cfg.gradient_accumulation_steps) * cfg.num_train_epochs

    config = BertConfig.from_pretrained(cfg.bert_model, num_labels=num_labels, finetuning_task=cfg.task_name)
    model = TrainNer.from_pretrained(cfg.bert_model,from_tf = False,config = config)
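The hunks above stop before the training loop, but the imports (TensorDataset, RandomSampler, DataLoader) indicate how the InputFeatures produced by convert_examples_to_features are consumed. A hedged sketch of that step, not the commit's exact code:

# Hedged sketch (not this commit's code): batch the InputFeatures built by
# convert_examples_to_features into tensors for training.
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

def build_train_dataloader(train_features, batch_size):
    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
    all_valid_ids = torch.tensor([f.valid_ids for f in train_features], dtype=torch.long)
    all_lmask_ids = torch.tensor([f.label_mask for f in train_features], dtype=torch.long)
    dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                            all_label_ids, all_valid_ids, all_lmask_ids)
    return DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=batch_size)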
setup.py | 2

@@ -1,7 +1,7 @@
from setuptools import setup, find_packages
setup(
    name='deepke', # 打包后的包文件名
    version='0.2.67', #版本号
    version='0.2.79', #版本号
    keywords=["pip", "RE","NER","AE"], # 关键字
    description='DeepKE 是基于 Pytorch 的深度学习中文关系抽取处理套件。', # 说明
    long_description="client", #详细说明
@@ -1,51 +0,0 @@
from __future__ import absolute_import, division, print_function

import argparse
import csv
import json
import logging
import os
import random
import sys
import numpy as np
import torch
import torch.nn.functional as F
from pytorch_transformers import (WEIGHTS_NAME, AdamW, BertConfig,
                                  BertForTokenClassification, BertTokenizer,
                                  WarmupLinearSchedule)
from torch import nn
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from seqeval.metrics import classification_report


class TrainNer(BertForTokenClassification):

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,valid_ids=None,attention_mask_label=None):
        sequence_output = self.bert(input_ids, token_type_ids, attention_mask,head_mask=None)[0]
        batch_size,max_len,feat_dim = sequence_output.shape
        valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device='cuda')
        for i in range(batch_size):
            jj = -1
            for j in range(max_len):
                if valid_ids[i][j].item() == 1:
                    jj += 1
                    valid_output[i][jj] = sequence_output[i][j]
        sequence_output = self.dropout(valid_output)
        logits = self.classifier(sequence_output)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=0)
            # Only keep active parts of the loss
            #attention_mask_label = None
            if attention_mask_label is not None:
                active_loss = attention_mask_label.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss
        else:
            return logits
@@ -33,7 +33,7 @@ class BertNer(BertForTokenClassification):

class InferNer:

    def __init__(self,model_dir:str):
    def __init__(self,model_dir: str):
        self.model , self.tokenizer, self.model_config = self.load_model(model_dir)
        self.label_map = self.model_config["label_map"]
        self.max_seq_length = self.model_config["max_seq_length"]

@@ -136,12 +136,8 @@ class InferNer:

            elif i==len(result)-1:
                tmp.append(word)
                wordstype = result[i][1][2:]
                if self.language=='bert-base-chinese':
                    tag[wordstype].append(''.join(tmp))
                else:
                    tag[wordstype].append(' '.join(tmp))
                tag[wordstype].append(''.join(tmp))
            else:
                tmp.append(word)

        return tag
@@ -1,2 +1 @@
from .BasicNer import *
from .InferBert import *
@@ -1,3 +1,2 @@
from .dataset import *
from .preprocess import *
from .trainer import *
from .preprocess import *
Binary file not shown.
File diff suppressed because it is too large
Binary file not shown.
@ -28,35 +28,49 @@
|
|||
"metadata": {},
|
||||
"source": [
|
||||
"## Dataset\n",
|
||||
"In this demo, we use dataset [**CoNLL-2003**](https://www.clips.uantwerpen.be/conll2003/ner/). It is a dataset for NER, concentrating on four types of named entities related to persons, locations, organizations, and names of miscellaneous entities.\n",
|
||||
"In this demo, we use [**People's Daily(人民日报) dataset**](https://github.com/OYE93/Chinese-NLP-Corpus/tree/master/NER/People's%20Daily). It is a dataset for NER, concentrating on their types of named entities related to persons(PER), locations(LOC), and organizations(ORG).\n",
|
||||
"\n",
|
||||
"| Word | Part-of-speech (POS) tag | Syntactic chunk tag | Named entity tag |\n",
|
||||
"| :--------: | :----------------------: | :-----------------: | :--------------: |\n",
|
||||
"| Pakistan | NNP | B-NP | B-LOC |\n",
|
||||
"| , | , | O | O |\n",
|
||||
"| who | WP | B-NP | O |\n",
|
||||
"| arrive | VBP | B-VP | O |\n",
|
||||
"| in | IN | B-PP | O |\n",
|
||||
"| Australia | NNP | B-NP | B-LOC |\n",
|
||||
"| later | JJ | B-NP | O |\n",
|
||||
"| this | DT | I-NP | O |\n",
|
||||
"| month | NN | I-NP | O |\n",
|
||||
"| , | , | O | O |\n",
|
||||
"| are | VBP | B-VP | O |\n",
|
||||
"| the | DT | B-NP | O |\n",
|
||||
"| other | JJ | I-NP | O |\n",
|
||||
"| team | NN | I-NP | O |\n",
|
||||
"| competing | VBG | B-VP | O |\n",
|
||||
"| in | IN | B-PP | O |\n",
|
||||
"| the | DT | B-NP | O |\n",
|
||||
"| World | NNP | I-NP | B-MISC |\n",
|
||||
"| Series | NNP | I-NP | I-MISC |\n",
|
||||
"| tournament | NN | I-NP | O |\n",
|
||||
"| . | . | O | O |\n",
|
||||
"| Word | Named entity tag |\n",
|
||||
"| :--: | :--------------: |\n",
|
||||
"| 早 | O |\n",
|
||||
"| 在 | O |\n",
|
||||
"| 1 | O |\n",
|
||||
"| 9 | O |\n",
|
||||
"| 7 | O |\n",
|
||||
"| 5 | O |\n",
|
||||
"| 年 | O |\n",
|
||||
"| , | O |\n",
|
||||
"| 张 | B-PER |\n",
|
||||
"| 鸿 | I-PER |\n",
|
||||
"| 飞 | I-PER |\n",
|
||||
"| 就 | O |\n",
|
||||
"| 有 | O |\n",
|
||||
"| 《 | O |\n",
|
||||
"| 草 | O |\n",
|
||||
"| 原 | O |\n",
|
||||
"| 新 | O |\n",
|
||||
"| 医 | O |\n",
|
||||
"| 》 | O |\n",
|
||||
"| 赴 | O |\n",
|
||||
"| 法 | B-LOC |\n",
|
||||
"| 展 | O |\n",
|
||||
"| 览 | O |\n",
|
||||
"| , | O |\n",
|
||||
"| 为 | O |\n",
|
||||
"| 我 | O |\n",
|
||||
"| 国 | O |\n",
|
||||
"| 驻 | B-ORG |\n",
|
||||
"| 法 | I-ORG |\n",
|
||||
"| 使 | I-ORG |\n",
|
||||
"| 馆 | I-ORG |\n",
|
||||
"| 收 | O |\n",
|
||||
"| 藏 | O |\n",
|
||||
"| 。 | O |\n",
|
||||
"\n",
|
||||
"- train.txt: It contains 14,987 sentences\n",
|
||||
"- valid.txt: It contains 3,466 sentences\n",
|
||||
"- test.txt: It contains 3,684 sentences"
|
||||
"\n",
|
||||
"- train.txt: It contains 20,864 sentences, including 979,180 named entity tags.\n",
|
||||
"- valid.txt: It contains 2,318 sentences, including 109,870 named entity tags.\n",
|
||||
"- test.txt: It contains 4,636 sentences, including 219,197 named entity tags."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -65,7 +79,7 @@
|
|||
"metadata": {},
|
||||
"source": [
|
||||
"## BERT\n",
|
||||
"[**Bidirectional Encoder Representations from Transformers (BERT)**](https://github.com/google-research/bert) is a transformer-based machine learning technique for natural language processing (NLP) pre-training developed by Google.\n",
|
||||
"[**Bidirectional Encoder Representations from Transformers (BERT)**](https://github.com/google-research/bert) \n",
|
||||
"\n",
|
||||
""
|
||||
]
|
||||
|
@ -81,15 +95,17 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "bd23f8d2",
|
||||
"id": "ddb0f3e4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install pytorch-transformers==1.2.0\n",
|
||||
"!pip install torch==1.2.0\n",
|
||||
"!pip install torch==1.5.0\n",
|
||||
"!pip install hydra-core==1.0.6\n",
|
||||
"!pip install seqeval==0.0.5\n",
|
||||
"!pip install tqdm==4.31.1\n",
|
||||
"!pip install nltk==3.4.5"
|
||||
"!pip install matplotlib==3.4.1\n",
|
||||
"!pip install deepke"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -103,12 +119,12 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f33f17d8",
|
||||
"id": "e21416ae",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from __future__ import absolute_import, division, print_function\n",
|
||||
"import argparse\n",
|
||||
"\n",
|
||||
"import csv\n",
|
||||
"import json\n",
|
||||
"import logging\n",
|
||||
|
@ -124,6 +140,9 @@
|
|||
"from torch.utils.data.distributed import DistributedSampler\n",
|
||||
"from tqdm import tqdm, trange\n",
|
||||
"from seqeval.metrics import classification_report\n",
|
||||
"import hydra\n",
|
||||
"from hydra import utils\n",
|
||||
"from deepke.name_entity_re.standard import *\n",
|
||||
"\n",
|
||||
"logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',\n",
|
||||
" datefmt = '%m/%d/%Y %H:%M:%S',\n",
|
||||
|
@ -133,235 +152,71 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "de20573b",
|
||||
"id": "029e661b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Preprocess the dataset"
|
||||
"## Configure model parameters"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "01ff2c72",
|
||||
"id": "2178ffdf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def readfile(filename):\n",
|
||||
" '''\n",
|
||||
" read file\n",
|
||||
" '''\n",
|
||||
" f = open(filename)\n",
|
||||
" data = []\n",
|
||||
" sentence = []\n",
|
||||
" label= []\n",
|
||||
" for line in f:\n",
|
||||
" if len(line)==0 or line.startswith('-DOCSTART') or line[0]==\"\\n\":\n",
|
||||
" if len(sentence) > 0:\n",
|
||||
" data.append((sentence,label))\n",
|
||||
" sentence = []\n",
|
||||
" label = []\n",
|
||||
" continue\n",
|
||||
" splits = line.split(' ')\n",
|
||||
" sentence.append(splits[0])\n",
|
||||
" label.append(splits[-1][:-1])\n",
|
||||
"class Config(object):\n",
|
||||
" data_dir = \"data/\" # The input data dir\n",
|
||||
" bert_model = \"bert-base-chinese\"\n",
|
||||
" task_name = \"ner\"\n",
|
||||
" output_dir = \"checkpoints\"\n",
|
||||
" max_seq_length = 128\n",
|
||||
" do_train = True # Fine-tune or not\n",
|
||||
" do_eval = True # Evaluate or not\n",
|
||||
" eval_on = \"dev\"\n",
|
||||
" do_lower_case = True\n",
|
||||
" train_batch_size = 32\n",
|
||||
" eval_batch_size = 8\n",
|
||||
" learning_rate = 5e-5\n",
|
||||
" num_train_epochs = 3 # The number of training epochs\n",
|
||||
" warmup_proportion = 0.1\n",
|
||||
" weight_decay = 0.01\n",
|
||||
" adam_epsilon = 1e-8\n",
|
||||
" max_grad_norm = 1.0\n",
|
||||
" use_gpu = True # Use gpu or not\n",
|
||||
" gpu_id = 1 # Which gpu to be used\n",
|
||||
" local_rank = -1\n",
|
||||
" seed = 42\n",
|
||||
" gradient_accumulation_steps = 1\n",
|
||||
" fp16 = False\n",
|
||||
" fp16_opt_level = \"01\"\n",
|
||||
" loss_scale = 0.0\n",
|
||||
" text = \"秦始皇兵马俑位于陕西省西安市,1961年被国务院公布为第一批全国重点文物保护单位,是世界八大奇迹之一。\"\n",
|
||||
"\n",
|
||||
" if len(sentence) >0:\n",
|
||||
" data.append((sentence,label))\n",
|
||||
" sentence = []\n",
|
||||
" label = []\n",
|
||||
" return data\n",
|
||||
"\n",
|
||||
"class InputExample(object):\n",
|
||||
" \"\"\"A single training/test example for simple sequence classification.\"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(self, guid, text_a, text_b=None, label=None):\n",
|
||||
" \"\"\"Constructs a InputExample.\n",
|
||||
" Args:\n",
|
||||
" guid: Unique id for the example.\n",
|
||||
" text_a: string. The untokenized text of the first sequence. For single\n",
|
||||
" sequence tasks, only this sequence must be specified.\n",
|
||||
" text_b: (Optional) string. The untokenized text of the second sequence.\n",
|
||||
" Only must be specified for sequence pair tasks.\n",
|
||||
" label: (Optional) string. The label of the example. This should be\n",
|
||||
" specified for train and dev examples, but not for test examples.\n",
|
||||
" \"\"\"\n",
|
||||
" self.guid = guid\n",
|
||||
" self.text_a = text_a\n",
|
||||
" self.text_b = text_b\n",
|
||||
" self.label = label\n",
|
||||
"\n",
|
||||
"class InputFeatures(object):\n",
|
||||
" \"\"\"A single set of features of data.\"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(self, input_ids, input_mask, segment_ids, label_id, valid_ids=None, label_mask=None):\n",
|
||||
" self.input_ids = input_ids\n",
|
||||
" self.input_mask = input_mask\n",
|
||||
" self.segment_ids = segment_ids\n",
|
||||
" self.label_id = label_id\n",
|
||||
" self.valid_ids = valid_ids\n",
|
||||
" self.label_mask = label_mask\n",
|
||||
"\n",
|
||||
"class DataProcessor(object):\n",
|
||||
" \"\"\"Base class for data converters for sequence classification data sets.\"\"\"\n",
|
||||
"\n",
|
||||
" def get_train_examples(self, data_dir):\n",
|
||||
" \"\"\"Gets a collection of `InputExample`s for the train set.\"\"\"\n",
|
||||
" raise NotImplementedError()\n",
|
||||
"\n",
|
||||
" def get_dev_examples(self, data_dir):\n",
|
||||
" \"\"\"Gets a collection of `InputExample`s for the dev set.\"\"\"\n",
|
||||
" raise NotImplementedError()\n",
|
||||
"\n",
|
||||
" def get_labels(self):\n",
|
||||
" \"\"\"Gets the list of labels for this data set.\"\"\"\n",
|
||||
" raise NotImplementedError()\n",
|
||||
"\n",
|
||||
" @classmethod\n",
|
||||
" def _read_tsv(cls, input_file, quotechar=None):\n",
|
||||
" \"\"\"Reads a tab separated value file.\"\"\"\n",
|
||||
" return readfile(input_file)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class NerProcessor(DataProcessor):\n",
|
||||
" \"\"\"Processor for the CoNLL-2003 data set.\"\"\"\n",
|
||||
"\n",
|
||||
" def get_train_examples(self, data_dir):\n",
|
||||
" \"\"\"See base class.\"\"\"\n",
|
||||
" return self._create_examples(\n",
|
||||
" self._read_tsv(os.path.join(data_dir, \"train.txt\")), \"train\")\n",
|
||||
"\n",
|
||||
" def get_dev_examples(self, data_dir):\n",
|
||||
" \"\"\"See base class.\"\"\"\n",
|
||||
" return self._create_examples(\n",
|
||||
" self._read_tsv(os.path.join(data_dir, \"valid.txt\")), \"dev\")\n",
|
||||
"\n",
|
||||
" def get_test_examples(self, data_dir):\n",
|
||||
" \"\"\"See base class.\"\"\"\n",
|
||||
" return self._create_examples(\n",
|
||||
" self._read_tsv(os.path.join(data_dir, \"test.txt\")), \"test\")\n",
|
||||
"\n",
|
||||
" def get_labels(self):\n",
|
||||
" return [\"O\", \"B-MISC\", \"I-MISC\", \"B-PER\", \"I-PER\", \"B-ORG\", \"I-ORG\", \"B-LOC\", \"I-LOC\", \"[CLS]\", \"[SEP]\"]\n",
|
||||
"\n",
|
||||
" def _create_examples(self,lines,set_type):\n",
|
||||
" examples = []\n",
|
||||
" for i,(sentence,label) in enumerate(lines):\n",
|
||||
" guid = \"%s-%s\" % (set_type, i)\n",
|
||||
" text_a = ' '.join(sentence)\n",
|
||||
" text_b = None\n",
|
||||
" label = label\n",
|
||||
" examples.append(InputExample(guid=guid,text_a=text_a,text_b=text_b,label=label))\n",
|
||||
" return examples\n",
|
||||
"\n",
|
||||
"def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):\n",
|
||||
" \"\"\"Loads a data file into a list of `InputBatch`s.\"\"\"\n",
|
||||
"\n",
|
||||
" label_map = {label : i for i, label in enumerate(label_list,1)}\n",
|
||||
"\n",
|
||||
" features = []\n",
|
||||
" for (ex_index,example) in enumerate(examples):\n",
|
||||
" textlist = example.text_a.split(' ')\n",
|
||||
" labellist = example.label\n",
|
||||
" tokens = []\n",
|
||||
" labels = []\n",
|
||||
" valid = []\n",
|
||||
" label_mask = []\n",
|
||||
" for i, word in enumerate(textlist):\n",
|
||||
" token = tokenizer.tokenize(word)\n",
|
||||
" tokens.extend(token)\n",
|
||||
" label_1 = labellist[i]\n",
|
||||
" for m in range(len(token)):\n",
|
||||
" if m == 0:\n",
|
||||
" labels.append(label_1)\n",
|
||||
" valid.append(1)\n",
|
||||
" label_mask.append(1)\n",
|
||||
" else:\n",
|
||||
" valid.append(0)\n",
|
||||
" if len(tokens) >= max_seq_length - 1:\n",
|
||||
" tokens = tokens[0:(max_seq_length - 2)]\n",
|
||||
" labels = labels[0:(max_seq_length - 2)]\n",
|
||||
" valid = valid[0:(max_seq_length - 2)]\n",
|
||||
" label_mask = label_mask[0:(max_seq_length - 2)]\n",
|
||||
" ntokens = []\n",
|
||||
" segment_ids = []\n",
|
||||
" label_ids = []\n",
|
||||
" ntokens.append(\"[CLS]\")\n",
|
||||
" segment_ids.append(0)\n",
|
||||
" valid.insert(0,1)\n",
|
||||
" label_mask.insert(0,1)\n",
|
||||
" label_ids.append(label_map[\"[CLS]\"])\n",
|
||||
" for i, token in enumerate(tokens):\n",
|
||||
" ntokens.append(token)\n",
|
||||
" segment_ids.append(0)\n",
|
||||
" if len(labels) > i:\n",
|
||||
" label_ids.append(label_map[labels[i]])\n",
|
||||
" ntokens.append(\"[SEP]\")\n",
|
||||
" segment_ids.append(0)\n",
|
||||
" valid.append(1)\n",
|
||||
" label_mask.append(1)\n",
|
||||
" label_ids.append(label_map[\"[SEP]\"])\n",
|
||||
" input_ids = tokenizer.convert_tokens_to_ids(ntokens)\n",
|
||||
" input_mask = [1] * len(input_ids)\n",
|
||||
" label_mask = [1] * len(label_ids)\n",
|
||||
" while len(input_ids) < max_seq_length:\n",
|
||||
" input_ids.append(0)\n",
|
||||
" input_mask.append(0)\n",
|
||||
" segment_ids.append(0)\n",
|
||||
" label_ids.append(0)\n",
|
||||
" valid.append(1)\n",
|
||||
" label_mask.append(0)\n",
|
||||
" while len(label_ids) < max_seq_length:\n",
|
||||
" label_ids.append(0)\n",
|
||||
" label_mask.append(0)\n",
|
||||
" assert len(input_ids) == max_seq_length\n",
|
||||
" assert len(input_mask) == max_seq_length\n",
|
||||
" assert len(segment_ids) == max_seq_length\n",
|
||||
" assert len(label_ids) == max_seq_length\n",
|
||||
" assert len(valid) == max_seq_length\n",
|
||||
" assert len(label_mask) == max_seq_length\n",
|
||||
"\n",
|
||||
" if ex_index < 5:\n",
|
||||
" logger.info(\"*** Example ***\")\n",
|
||||
" logger.info(\"guid: %s\" % (example.guid))\n",
|
||||
" logger.info(\"tokens: %s\" % \" \".join(\n",
|
||||
" [str(x) for x in tokens]))\n",
|
||||
" logger.info(\"input_ids: %s\" % \" \".join([str(x) for x in input_ids]))\n",
|
||||
" logger.info(\"input_mask: %s\" % \" \".join([str(x) for x in input_mask]))\n",
|
||||
" logger.info(\n",
|
||||
" \"segment_ids: %s\" % \" \".join([str(x) for x in segment_ids]))\n",
|
||||
" # logger.info(\"label: %s (id = %d)\" % (example.label, label_ids))\n",
|
||||
"\n",
|
||||
" features.append(\n",
|
||||
" InputFeatures(input_ids=input_ids,\n",
|
||||
" input_mask=input_mask,\n",
|
||||
" segment_ids=segment_ids,\n",
|
||||
" label_id=label_ids,\n",
|
||||
" valid_ids=valid,\n",
|
||||
" label_mask=label_mask))\n",
|
||||
" return features"
|
||||
"cfg = Config()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "1aedb530",
|
||||
"id": "5a67afe9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## BERT model"
|
||||
"## Prepare the model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b4abd61c",
|
||||
"id": "f45cc648",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class Ner(BertForTokenClassification):\n",
|
||||
"class TrainNer(BertForTokenClassification):\n",
|
||||
"\n",
|
||||
" def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,valid_ids=None,attention_mask_label=None):\n",
|
||||
" sequence_output = self.bert(input_ids, token_type_ids, attention_mask,head_mask=None)[0]\n",
|
||||
" batch_size,max_len,feat_dim = sequence_output.shape\n",
|
||||
" valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device='cuda')\n",
|
||||
" valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device=1) #device 如用gpu需要修改为cfg.gpu_id的值 不用则为cpu\n",
|
||||
" for i in range(batch_size):\n",
|
||||
" jj = -1\n",
|
||||
" for j in range(max_len):\n",
|
||||
|
@ -373,8 +228,6 @@
|
|||
"\n",
|
||||
" if labels is not None:\n",
|
||||
" loss_fct = nn.CrossEntropyLoss(ignore_index=0)\n",
|
||||
" # Only keep active parts of the loss\n",
|
||||
" # attention_mask_label = None\n",
|
||||
" if attention_mask_label is not None:\n",
|
||||
" active_loss = attention_mask_label.view(-1) == 1\n",
|
||||
" active_logits = logits.view(-1, self.num_labels)[active_loss]\n",
|
||||
|
@ -387,140 +240,65 @@
|
|||
" return logits"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "029e661b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Fine-Tune"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ae701820",
|
||||
"id": "9ed1c3dd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Required parameters\n",
|
||||
"data_dir = \"data/\"\n",
|
||||
"bert_model = \"bert-base-cased\"\n",
|
||||
"task_name = \"ner\"\n",
|
||||
"output_dir = \"out_ner\"\n",
|
||||
"max_seq_length = 128\n",
|
||||
"do_train = True\n",
|
||||
"do_eval = True\n",
|
||||
"eval_on = \"dev\"\n",
|
||||
"do_lower_case = \"True\"\n",
|
||||
"train_batch_size = 32\n",
|
||||
"eval_batch_size = 8\n",
|
||||
"learning_rate = 5e-5\n",
|
||||
"num_train_epochs = 5.0 # the number of training epochs\n",
|
||||
"warmup_proportion = 0.1\n",
|
||||
"weight_decay = 0.01\n",
|
||||
"adam_epsilon = 1e-8\n",
|
||||
"max_grad_norm = 1.0\n",
|
||||
"no_cuda = False\n",
|
||||
"local_rank = -1\n",
|
||||
"seed = 42\n",
|
||||
"gradient_accumulation_steps = 1\n",
|
||||
"fp16 = False\n",
|
||||
"fp16_opt_level = \"01\"\n",
|
||||
"loss_scale = 0.0"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3f4ced0f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"processors = {\"ner\":NerProcessor}\n",
|
||||
"\n",
|
||||
"if local_rank ==-1 or no_cuda:\n",
|
||||
" device = torch.device(\"cuda\" if torch.cuda.is_available() and not no_cuda else \"cpu\")\n",
|
||||
" n_gpu = torch.cuda.device_count()\n",
|
||||
"# Use gpu or not\n",
|
||||
"if cfg.use_gpu and torch.cuda.is_available():\n",
|
||||
" device = torch.device('cuda', cfg.gpu_id)\n",
|
||||
"else:\n",
|
||||
" torch.cuda.set_device(local_rank)\n",
|
||||
" device = torch.device(\"cuda\", local_rank)\n",
|
||||
" n_gpu = 1\n",
|
||||
" # Initializes the distributed backend which will take care of sychronizing nodes/GPUs\n",
|
||||
" torch.distributed.init_process_group(backend='nccl')\n",
|
||||
" device = torch.device('cpu')\n",
|
||||
"\n",
|
||||
"logger.info(\"device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}\".format(device, n_gpu, bool(local_rank != -1), fp16))\n",
|
||||
"if cfg.gradient_accumulation_steps < 1:\n",
|
||||
" raise ValueError(\"Invalid gradient_accumulation_steps parameter: {}, should be >= 1\".format(cfg.gradient_accumulation_steps))\n",
|
||||
"\n",
|
||||
"train_batch_size = train_batch_size // gradient_accumulation_steps\n",
|
||||
"cfg.train_batch_size = cfg.train_batch_size // cfg.gradient_accumulation_steps\n",
|
||||
"\n",
|
||||
"random.seed(seed)\n",
|
||||
"np.random.seed(seed)\n",
|
||||
"torch.manual_seed(seed)\n",
|
||||
"random.seed(cfg.seed)\n",
|
||||
"np.random.seed(cfg.seed)\n",
|
||||
"torch.manual_seed(cfg.seed)\n",
|
||||
"\n",
|
||||
"if os.path.exists(output_dir) and os.listdir(output_dir) and do_train:\n",
|
||||
" raise ValueError(\"Output directory ({}) already exists and is not empty.\".format(output_dir))\n",
|
||||
"if not os.path.exists(output_dir):\n",
|
||||
" os.makedirs(output_dir)\n",
|
||||
"if not cfg.do_train and not cfg.do_eval:\n",
|
||||
" raise ValueError(\"At least one of `do_train` or `do_eval` must be True.\")\n",
|
||||
"\n",
|
||||
"task_name = task_name.lower()\n",
|
||||
"# Checkpoints\n",
|
||||
"if os.path.exists(cfg.output_dir) and os.listdir(cfg.output_dir) and cfg.do_train:\n",
|
||||
" raise ValueError(\"Output directory ({}) already exists and is not empty.\".format(cfg.output_dir))\n",
|
||||
"if not os.path.exists(cfg.output_dir):\n",
|
||||
" os.makedirs(cfg.output_dir)\n",
|
||||
"\n",
|
||||
"processor = processors[task_name]()\n",
|
||||
"# Preprocess the input dataset\n",
|
||||
"processor = NerProcessor()\n",
|
||||
"label_list = processor.get_labels()\n",
|
||||
"num_labels = len(label_list) + 1\n",
|
||||
"\n",
|
||||
"tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=do_lower_case)\n",
|
||||
"# Prepare the model\n",
|
||||
"tokenizer = BertTokenizer.from_pretrained(cfg.bert_model, do_lower_case=cfg.do_lower_case)\n",
|
||||
"\n",
|
||||
"train_examples = None\n",
|
||||
"num_train_optimization_steps = 0\n",
|
||||
"if do_train:\n",
|
||||
" train_examples = processor.get_train_examples(data_dir)\n",
|
||||
" num_train_optimization_steps = int(len(train_examples) / train_batch_size / gradient_accumulation_steps) * num_train_epochs\n",
|
||||
" if local_rank != -1:\n",
|
||||
" num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()\n",
|
||||
"\n",
|
||||
"if local_rank not in [-1, 0]:\n",
|
||||
" torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "43846012",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Prepare model\n",
|
||||
"config = BertConfig.from_pretrained(bert_model, num_labels=num_labels, finetuning_task=task_name)\n",
|
||||
"model = Ner.from_pretrained(bert_model, from_tf = False, config = config)\n",
|
||||
"\n",
|
||||
"if local_rank == 0:\n",
|
||||
" torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab\n",
|
||||
"if cfg.do_train:\n",
|
||||
" train_examples = processor.get_train_examples(cfg.data_dir)\n",
|
||||
" num_train_optimization_steps = int(len(train_examples) / cfg.train_batch_size / cfg.gradient_accumulation_steps) * cfg.num_train_epochs\n",
|
||||
"\n",
|
||||
"config = BertConfig.from_pretrained(cfg.bert_model, num_labels=num_labels, finetuning_task=cfg.task_name)\n",
|
||||
"model = TrainNer.from_pretrained(cfg.bert_model,from_tf = False,config = config)\n",
|
||||
"model.to(device)\n",
|
||||
"\n",
|
||||
"param_optimizer = list(model.named_parameters())\n",
|
||||
"no_decay = ['bias','LayerNorm.weight']\n",
|
||||
"optimizer_grouped_parameters = [\n",
|
||||
" {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},\n",
|
||||
" {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': cfg.weight_decay},\n",
|
||||
" {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
|
||||
" ]\n",
|
||||
"warmup_steps = int(warmup_proportion * num_train_optimization_steps)\n",
|
||||
"optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_epsilon)\n",
|
||||
"warmup_steps = int(cfg.warmup_proportion * num_train_optimization_steps)\n",
|
||||
"optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.learning_rate, eps=cfg.adam_epsilon)\n",
|
||||
"scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=num_train_optimization_steps)\n",
|
||||
"if fp16:\n",
|
||||
" try:\n",
|
||||
" from apex import amp\n",
|
||||
" except ImportError:\n",
|
||||
" raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.\")\n",
|
||||
" model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)\n",
|
||||
"\n",
|
||||
"# multi-gpu training (should be after apex fp16 initialization)\n",
|
||||
"if n_gpu > 1:\n",
|
||||
" model = torch.nn.DataParallel(model)\n",
|
||||
"\n",
|
||||
"if local_rank != -1:\n",
|
||||
" model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],\n",
|
||||
" output_device=local_rank,\n",
|
||||
" find_unused_parameters=True)\n",
|
||||
"global_step = 0\n",
|
||||
"nb_tr_steps = 0\n",
|
||||
"tr_loss = 0\n",
|
||||
|
@ -538,18 +316,12 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a8e71425",
|
||||
"id": "95e37ce2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Train model\n",
|
||||
"if do_train:\n",
|
||||
" train_features = convert_examples_to_features(\n",
|
||||
" train_examples, label_list, max_seq_length, tokenizer)\n",
|
||||
" logger.info(\"***** Running training *****\")\n",
|
||||
" logger.info(\" Num examples = %d\", len(train_examples))\n",
|
||||
" logger.info(\" Batch size = %d\", train_batch_size)\n",
|
||||
" logger.info(\" Num steps = %d\", num_train_optimization_steps)\n",
|
||||
"if cfg.do_train:\n",
|
||||
" train_features = convert_examples_to_features(train_examples, label_list, cfg.max_seq_length, tokenizer)\n",
|
||||
" all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)\n",
|
||||
" all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)\n",
|
||||
" all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)\n",
|
||||
|
@ -557,37 +329,29 @@
|
|||
" all_valid_ids = torch.tensor([f.valid_ids for f in train_features], dtype=torch.long)\n",
|
||||
" all_lmask_ids = torch.tensor([f.label_mask for f in train_features], dtype=torch.long)\n",
|
||||
" train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids,all_valid_ids,all_lmask_ids)\n",
|
||||
" if local_rank == -1:\n",
|
||||
" train_sampler = RandomSampler(train_data)\n",
|
||||
" else:\n",
|
||||
" train_sampler = DistributedSampler(train_data)\n",
|
||||
" train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)\n",
|
||||
" train_sampler = RandomSampler(train_data)\n",
|
||||
"\n",
|
||||
" train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=cfg.train_batch_size)\n",
|
||||
"\n",
|
||||
" model.train()\n",
|
||||
" for _ in trange(int(num_train_epochs), desc=\"Epoch\"):\n",
|
||||
"\n",
|
||||
" for _ in trange(int(cfg.num_train_epochs), desc=\"Epoch\"):\n",
|
||||
" tr_loss = 0\n",
|
||||
" nb_tr_examples, nb_tr_steps = 0, 0\n",
|
||||
" for step, batch in enumerate(tqdm(train_dataloader, desc=\"Iteration\")):\n",
|
||||
" batch = tuple(t.to(device) for t in batch)\n",
|
||||
" input_ids, input_mask, segment_ids, label_ids, valid_ids,l_mask = batch\n",
|
||||
" loss = model(input_ids, segment_ids, input_mask, label_ids,valid_ids,l_mask)\n",
|
||||
" if n_gpu > 1:\n",
|
||||
" loss = loss.mean() # mean() to average on multi-gpu.\n",
|
||||
" if gradient_accumulation_steps > 1:\n",
|
||||
" loss = loss / gradient_accumulation_steps\n",
|
||||
" if cfg.gradient_accumulation_steps > 1:\n",
|
||||
" loss = loss / cfg.gradient_accumulation_steps\n",
|
||||
"\n",
|
||||
" if fp16:\n",
|
||||
" with amp.scale_loss(loss, optimizer) as scaled_loss:\n",
|
||||
" scaled_loss.backward()\n",
|
||||
" torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)\n",
|
||||
" else:\n",
|
||||
" loss.backward()\n",
|
||||
" torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)\n",
|
||||
" loss.backward()\n",
|
||||
" torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)\n",
|
||||
"\n",
|
||||
" tr_loss += loss.item()\n",
|
||||
" nb_tr_examples += input_ids.size(0)\n",
|
||||
" nb_tr_steps += 1\n",
|
||||
" if (step + 1) % gradient_accumulation_steps == 0:\n",
|
||||
" if (step + 1) % cfg.gradient_accumulation_steps == 0:\n",
|
||||
" optimizer.step()\n",
|
||||
" scheduler.step() # Update learning rate schedule\n",
|
||||
" model.zero_grad()\n",
|
||||
|
@ -595,16 +359,17 @@
|
|||
"\n",
|
||||
" # Save a trained model and the associated configuration\n",
|
||||
" model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self\n",
|
||||
" model_to_save.save_pretrained(output_dir)\n",
|
||||
" tokenizer.save_pretrained(output_dir)\n",
|
||||
" model_to_save.save_pretrained(cfg.output_dir)\n",
|
||||
" tokenizer.save_pretrained(cfg.output_dir)\n",
|
||||
" label_map = {i : label for i, label in enumerate(label_list,1)}\n",
|
||||
" model_config = {\"bert_model\":bert_model,\"do_lower\":do_lower_case,\"max_seq_length\":max_seq_length,\"num_labels\":len(label_list)+1,\"label_map\":label_map}\n",
|
||||
" json.dump(model_config,open(os.path.join(output_dir,\"model_config.json\"),\"w\"))\n",
|
||||
" model_config = {\"bert_model\":cfg.bert_model,\"do_lower\":cfg.do_lower_case, \"max_seq_length\":cfg.max_seq_length,\"num_labels\":len(label_list)+1,\"label_map\":label_map}\n",
|
||||
" json.dump(model_config,open(os.path.join(cfg.output_dir,\"model_config.json\"),\"w\"))\n",
|
||||
" # Load a trained model and config that you have fine-tuned\n",
|
||||
"else:\n",
|
||||
" # Load a trained model and vocabulary that you have fine-tuned\n",
|
||||
" model = Ner.from_pretrained(output_dir)\n",
|
||||
" tokenizer = BertTokenizer.from_pretrained(output_dir, do_lower_case=do_lower_case)\n",
|
||||
" model = TrainNer.from_pretrained(cfg.output_dir)\n",
|
||||
" tokenizer = BertTokenizer.from_pretrained(cfg.output_dir, do_lower_case=cfg.do_lower_case)\n",
|
||||
"\n",
|
||||
"model.to(device)"
|
||||
]
|
||||
},
|
||||
|
@ -619,22 +384,18 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a0fecf19",
|
||||
"id": "5cf1972a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Evaluation\n",
|
||||
"if do_eval and (local_rank == -1 or torch.distributed.get_rank() == 0):\n",
|
||||
" if eval_on == \"dev\":\n",
|
||||
" eval_examples = processor.get_dev_examples(data_dir)\n",
|
||||
" elif eval_on == \"test\":\n",
|
||||
" eval_examples = processor.get_test_examples(data_dir)\n",
|
||||
"if cfg.do_eval:\n",
|
||||
" if cfg.eval_on == \"dev\":\n",
|
||||
" eval_examples = processor.get_dev_examples(cfg.data_dir)\n",
|
||||
" elif cfg.eval_on == \"test\":\n",
|
||||
" eval_examples = processor.get_test_examples(cfg.data_dir)\n",
|
||||
" else:\n",
|
||||
" raise ValueError(\"eval on dev or test set only\")\n",
|
||||
" eval_features = convert_examples_to_features(eval_examples, label_list, max_seq_length, tokenizer)\n",
|
||||
" logger.info(\"***** Running evaluation *****\")\n",
|
||||
" logger.info(\" Num examples = %d\", len(eval_examples))\n",
|
||||
" logger.info(\" Batch size = %d\", eval_batch_size)\n",
|
||||
" eval_features = convert_examples_to_features(eval_examples, label_list, cfg.max_seq_length, tokenizer)\n",
|
||||
" all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)\n",
|
||||
" all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)\n",
|
||||
" all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)\n",
|
||||
|
@ -644,7 +405,7 @@
|
|||
" eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids,all_valid_ids,all_lmask_ids)\n",
|
||||
" # Run prediction for full data\n",
|
||||
" eval_sampler = SequentialSampler(eval_data)\n",
|
||||
" eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=eval_batch_size)\n",
|
||||
" eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=cfg.eval_batch_size)\n",
|
||||
" model.eval()\n",
|
||||
" eval_loss, eval_accuracy = 0, 0\n",
|
||||
" nb_eval_steps, nb_eval_examples = 0, 0\n",
|
||||
|
@ -683,7 +444,7 @@
|
|||
"\n",
|
||||
" report = classification_report(y_true, y_pred,digits=4)\n",
|
||||
" logger.info(\"\\n%s\", report)\n",
|
||||
" output_eval_file = os.path.join(output_dir, \"eval_results.txt\")\n",
|
||||
" output_eval_file = os.path.join(cfg.output_dir, \"eval_results.txt\")\n",
|
||||
" with open(output_eval_file, \"w\") as writer:\n",
|
||||
" logger.info(\"***** Eval results *****\")\n",
|
||||
" logger.info(\"\\n%s\", report)\n",
|
||||
|
@ -701,174 +462,16 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f78c5102",
|
||||
"id": "b0e33a75",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# make changes for better representation and display of 'entity detected' and their 'entity types' for the given sentence to test or inference\n",
|
||||
"from nltk import word_tokenize\n",
|
||||
"from collections import OrderedDict\n",
|
||||
"model = InferNer(\"checkpoints/\")\n",
|
||||
"text = cfg.text\n",
|
||||
"\n",
|
||||
"class BertNer(BertForTokenClassification):\n",
|
||||
"\n",
|
||||
" def forward(self, input_ids, token_type_ids=None, attention_mask=None, valid_ids=None):\n",
|
||||
" sequence_output = self.bert(input_ids, token_type_ids, attention_mask, head_mask=None)[0]\n",
|
||||
" batch_size,max_len,feat_dim = sequence_output.shape\n",
|
||||
" valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device='cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||
" for i in range(batch_size):\n",
|
||||
" jj = -1\n",
|
||||
" for j in range(max_len):\n",
|
||||
" if valid_ids[i][j].item() == 1:\n",
|
||||
" jj += 1\n",
|
||||
" valid_output[i][jj] = sequence_output[i][j]\n",
|
||||
" sequence_output = self.dropout(valid_output)\n",
|
||||
" logits = self.classifier(sequence_output)\n",
|
||||
" return logits\n",
|
||||
"\n",
|
||||
"class Ner:\n",
|
||||
"\n",
|
||||
" def __init__(self,model_dir: str):\n",
|
||||
" self.model , self.tokenizer, self.model_config = self.load_model(model_dir)\n",
|
||||
" self.label_map = self.model_config[\"label_map\"]\n",
|
||||
" self.max_seq_length = self.model_config[\"max_seq_length\"]\n",
|
||||
" self.label_map = {int(k):v for k,v in self.label_map.items()}\n",
|
||||
" self.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
||||
" self.model = self.model.to(self.device)\n",
|
||||
" self.model.eval()\n",
|
||||
"\n",
|
||||
" def load_model(self, model_dir: str, model_config: str = \"model_config.json\"):\n",
|
||||
" model_config = os.path.join(model_dir,model_config)\n",
|
||||
" model_config = json.load(open(model_config))\n",
|
||||
" model = BertNer.from_pretrained(model_dir)\n",
|
||||
" tokenizer = BertTokenizer.from_pretrained(model_dir, do_lower_case=model_config[\"do_lower\"])\n",
|
||||
" return model, tokenizer, model_config\n",
|
||||
"\n",
|
||||
" def tokenize(self, text: str):\n",
|
||||
" \"\"\" tokenize input\"\"\"\n",
|
||||
" words = word_tokenize(text)\n",
|
||||
" tokens = []\n",
|
||||
" valid_positions = []\n",
|
||||
" for i,word in enumerate(words):\n",
|
||||
" token = self.tokenizer.tokenize(word)\n",
|
||||
" tokens.extend(token)\n",
|
||||
" for i in range(len(token)):\n",
|
||||
" if i == 0:\n",
|
||||
" valid_positions.append(1)\n",
|
||||
" else:\n",
|
||||
" valid_positions.append(0)\n",
|
||||
" return tokens, valid_positions\n",
|
||||
"\n",
|
||||
" def preprocess(self, text: str):\n",
|
||||
" \"\"\" preprocess \"\"\"\n",
|
||||
" tokens, valid_positions = self.tokenize(text)\n",
|
||||
" ## insert \"[CLS]\"\n",
|
||||
" tokens.insert(0,\"[CLS]\")\n",
|
||||
" valid_positions.insert(0,1)\n",
|
||||
" ## insert \"[SEP]\"\n",
|
||||
" tokens.append(\"[SEP]\")\n",
|
||||
" valid_positions.append(1)\n",
|
||||
" segment_ids = []\n",
|
||||
" for i in range(len(tokens)):\n",
|
||||
" segment_ids.append(0)\n",
|
||||
" input_ids = self.tokenizer.convert_tokens_to_ids(tokens)\n",
|
||||
" input_mask = [1] * len(input_ids)\n",
|
||||
" while len(input_ids) < self.max_seq_length:\n",
|
||||
" input_ids.append(0)\n",
|
||||
" input_mask.append(0)\n",
|
||||
" segment_ids.append(0)\n",
|
||||
" valid_positions.append(0)\n",
|
||||
" return input_ids,input_mask,segment_ids,valid_positions\n",
|
||||
"\n",
|
||||
" def predict(self, text: str):\n",
|
||||
" input_ids,input_mask,segment_ids,valid_ids = self.preprocess(text)\n",
|
||||
" input_ids = torch.tensor([input_ids],dtype=torch.long,device=self.device)\n",
|
||||
" input_mask = torch.tensor([input_mask],dtype=torch.long,device=self.device)\n",
|
||||
" segment_ids = torch.tensor([segment_ids],dtype=torch.long,device=self.device)\n",
|
||||
" valid_ids = torch.tensor([valid_ids],dtype=torch.long,device=self.device)\n",
|
||||
" with torch.no_grad():\n",
|
||||
" logits = self.model(input_ids, segment_ids, input_mask,valid_ids)\n",
|
||||
" logits = F.softmax(logits,dim=2)\n",
|
||||
" logits_label = torch.argmax(logits,dim=2)\n",
|
||||
" logits_label = logits_label.detach().cpu().numpy().tolist()[0]\n",
|
||||
"\n",
|
||||
" logits_confidence = [values[label].item() for values,label in zip(logits[0],logits_label)]\n",
|
||||
"\n",
|
||||
" logits = []\n",
|
||||
" pos = 0\n",
|
||||
" for index,mask in enumerate(valid_ids[0]):\n",
|
||||
" if index == 0:\n",
|
||||
" continue\n",
|
||||
" if mask == 1:\n",
|
||||
" logits.append((logits_label[index-pos],logits_confidence[index-pos]))\n",
|
||||
" else:\n",
|
||||
" pos += 1\n",
|
||||
" logits.pop()\n",
|
||||
"\n",
|
||||
" labels = [(self.label_map[label],confidence) for label,confidence in logits]\n",
|
||||
" words = word_tokenize(text)\n",
|
||||
" assert len(labels) == len(words)\n",
|
||||
"\n",
|
||||
" result = []\n",
|
||||
" for word, (label, confidence) in zip(words, labels):\n",
|
||||
" if label!='O':\n",
|
||||
" result.append((word,label))\n",
|
||||
" tmp = []\n",
|
||||
" tag = OrderedDict()\n",
|
||||
" tag['PER'] = []\n",
|
||||
" tag['LOC'] = []\n",
|
||||
" tag['ORG'] = []\n",
|
||||
" tag['MISC'] = []\n",
|
||||
" \n",
|
||||
" for i, (word, label) in enumerate(result):\n",
|
||||
" if label=='B-PER' or label=='B-LOC' or label=='B-ORG' or label=='B-MISC':\n",
|
||||
" if i==0:\n",
|
||||
" tmp.append(word)\n",
|
||||
" else:\n",
|
||||
" wordstype = result[i-1][1][2:]\n",
|
||||
" tag[wordstype].append(' '.join(tmp))\n",
|
||||
" tmp.clear()\n",
|
||||
" tmp.append(word)\n",
|
||||
" elif i==len(result)-1:\n",
|
||||
" tmp.append(word)\n",
|
||||
" wordstype = result[i][1][2:]\n",
|
||||
" tag[wordstype].append(' '.join(tmp))\n",
|
||||
" else:\n",
|
||||
" tmp.append(word)\n",
|
||||
"\n",
|
||||
" return tag\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6fec9c7b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Run below command for import and download 'nltk' library as it is important for predictions of entities of the sentence\n",
|
||||
"import nltk\n",
|
||||
"nltk.download('punkt')\n",
|
||||
"\n",
|
||||
"# If it's too slow to download 'nltk_data', we offer it in 'data/nltk_data' and you can use it by running the following code.\n",
|
||||
"# import nltk\n",
|
||||
"# nltk.data.path.insert(0,'./data/nltk_data')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5ce89e90",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = Ner(\"out_ner/\")\n",
|
||||
"\n",
|
||||
"# Text to be NERed\n",
|
||||
"text = \"Irene, a master student in Zhejiang University, Hangzhou, is traveling in Warsaw for Chopin Music Festival.\"\n",
|
||||
"\n",
|
||||
"print(\"The text to be NERed:\")\n",
|
||||
"print(\"NER句子:\")\n",
|
||||
"print(text)\n",
|
||||
"print('Results of NER:')\n",
|
||||
"print('NER结果:')\n",
|
||||
"\n",
|
||||
"result = model.predict(text)\n",
|
||||
"for k,v in result.items():\n",
|
||||
|
@ -879,9 +482,7 @@
|
|||
" elif k=='LOC':\n",
|
||||
" print('Location')\n",
|
||||
" elif k=='ORG':\n",
|
||||
" print('Organization')\n",
|
||||
" elif k=='MISC':\n",
|
||||
" print('Miscellaneous')"
|
||||
" print('Organization')"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -889,14 +490,14 @@
|
|||
"id": "97fc8159",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This demo does not include parameter adjustment. If you can interested in this, you can go to [deepke](http://openkg.cn/tool/deepke)\n",
|
||||
"This demo does not include parameter adjustment. If you are interested in this, you can go to [deepke](http://openkg.cn/tool/deepke)\n",
|
||||
"Warehouse, download and use more models:)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
|
@ -910,7 +511,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.11"
|
||||
"version": "3.8.11"
|
||||
},
|
||||
"latex_envs": {
|
||||
"LaTeX_envs_menu_present": true,
|
||||
|
|