commit 6881ad77f5 (parent 778caf8520): "test"

README.md | 10
README.md | 10
@@ -42,12 +42,16 @@ DeepKE provides a variety of knowledge extraction models.

1.NER

```
REGULAR
```

2.RE

1.REGULAR

2.FEW-SHOT

3.DOCUMENT

3.AE

@@ -66,6 +70,8 @@ DeepKE includes the following functions (each block links to the corresponding module's README):

1.NER

**[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ner/regular/README.md)**

2.RE, where RE includes the following three sub-functions

**[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/re/regular/README.md)**
@@ -44,12 +44,16 @@ demo's urls

1.NER

```
REGULAR
```

2.RE

1.REGULAR

2.FEW-SHOT

3.DOCUMENT

3.AE

@@ -68,6 +72,8 @@ DeepKE contains these models:

1.NER

**[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/ner/regular/README.md)**

2.RE

**[REGULAR](https://github.com/zjunlp/deepke/blob/test_new_deepke/example/re/regular/README.md)**

Binary file not shown. (Before: 149 KiB)
@@ -0,0 +1,41 @@ (new file)

# Quick Start

## Clone the code

```
git clone git@github.com:zjunlp/DeepKE.git
```

## Set up the environment

Create a Python virtual environment (python >= 3.7).

Install the dependencies:

```
pip install -r requirements.txt
```

## Use the tool

Train first; the trained model parameters are saved in the out_ner folder.

```
python run.py --data_dir=data/ --bert_model=bert-base-cased --task_name=ner --output_dir=out_ner --max_seq_length=128 --do_train --num_train_epochs 5 --do_eval --warmup_proportion=0.1
```

Then run prediction.<br>

Run the following command for an example:

```
python predict.py
```

If you need to specify the text for NER, pass it with the --text argument:

```
python predict.py --text="Irene, a master student in Zhejiang University, Hangzhou, is traveling in Warsaw for Chopin Music Festival."
```
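For reference, prediction can also be driven from Python rather than the command line. A minimal sketch, assuming training has already produced `out_ner/` and that the inference script shown below is saved as `predict.py` (module name is an assumption based on the commands above):

```python
# Minimal usage sketch; assumes out_ner/ exists and that the inference script
# below is importable as predict.py (module name is an assumption).
from predict import Ner

model = Ner("out_ner/")
tags = model.predict("Irene, a master student in Zhejiang University, Hangzhou, "
                     "is traveling in Warsaw for Chopin Music Festival.")
# tags is an OrderedDict keyed by entity type: 'PER', 'LOC', 'ORG', 'MISC'
for entity_type, mentions in tags.items():
    if mentions:
        print(entity_type, mentions)
```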
@@ -0,0 +1,174 @@ (new file)

"""BERT NER Inference."""

from __future__ import absolute_import, division, print_function

import json
import os

import torch
import torch.nn.functional as F
from nltk import word_tokenize
from pytorch_transformers import (BertConfig, BertForTokenClassification,
                                  BertTokenizer)
from collections import OrderedDict
import argparse

import nltk
nltk.data.path.insert(0,'./data/nltk_data')


class BertNer(BertForTokenClassification):

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, valid_ids=None):
        sequence_output = self.bert(input_ids, token_type_ids, attention_mask, head_mask=None)[0]
        batch_size,max_len,feat_dim = sequence_output.shape
        valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device='cuda' if torch.cuda.is_available() else 'cpu')
        for i in range(batch_size):
            jj = -1
            for j in range(max_len):
                if valid_ids[i][j].item() == 1:
                    jj += 1
                    valid_output[i][jj] = sequence_output[i][j]
        sequence_output = self.dropout(valid_output)
        logits = self.classifier(sequence_output)
        return logits


class Ner:

    def __init__(self,model_dir: str):
        self.model , self.tokenizer, self.model_config = self.load_model(model_dir)
        self.label_map = self.model_config["label_map"]
        self.max_seq_length = self.model_config["max_seq_length"]
        self.label_map = {int(k):v for k,v in self.label_map.items()}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)
        self.model.eval()

    def load_model(self, model_dir: str, model_config: str = "model_config.json"):
        model_config = os.path.join(model_dir,model_config)
        model_config = json.load(open(model_config))
        model = BertNer.from_pretrained(model_dir)
        tokenizer = BertTokenizer.from_pretrained(model_dir, do_lower_case=model_config["do_lower"])
        return model, tokenizer, model_config

    def tokenize(self, text: str):
        """ tokenize input"""
        words = word_tokenize(text)
        tokens = []
        valid_positions = []
        for i,word in enumerate(words):
            token = self.tokenizer.tokenize(word)
            tokens.extend(token)
            for i in range(len(token)):
                if i == 0:
                    valid_positions.append(1)
                else:
                    valid_positions.append(0)
        return tokens, valid_positions

    def preprocess(self, text: str):
        """ preprocess """
        tokens, valid_positions = self.tokenize(text)
        ## insert "[CLS]"
        tokens.insert(0,"[CLS]")
        valid_positions.insert(0,1)
        ## insert "[SEP]"
        tokens.append("[SEP]")
        valid_positions.append(1)
        segment_ids = []
        for i in range(len(tokens)):
            segment_ids.append(0)
        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(input_ids)
        while len(input_ids) < self.max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)
            valid_positions.append(0)
        return input_ids,input_mask,segment_ids,valid_positions

    def predict(self, text: str):
        input_ids,input_mask,segment_ids,valid_ids = self.preprocess(text)
        input_ids = torch.tensor([input_ids],dtype=torch.long,device=self.device)
        input_mask = torch.tensor([input_mask],dtype=torch.long,device=self.device)
        segment_ids = torch.tensor([segment_ids],dtype=torch.long,device=self.device)
        valid_ids = torch.tensor([valid_ids],dtype=torch.long,device=self.device)
        with torch.no_grad():
            logits = self.model(input_ids, segment_ids, input_mask,valid_ids)
        logits = F.softmax(logits,dim=2)
        logits_label = torch.argmax(logits,dim=2)
        logits_label = logits_label.detach().cpu().numpy().tolist()[0]

        logits_confidence = [values[label].item() for values,label in zip(logits[0],logits_label)]

        logits = []
        pos = 0
        for index,mask in enumerate(valid_ids[0]):
            if index == 0:
                continue
            if mask == 1:
                logits.append((logits_label[index-pos],logits_confidence[index-pos]))
            else:
                pos += 1
        logits.pop()

        labels = [(self.label_map[label],confidence) for label,confidence in logits]
        words = word_tokenize(text)
        assert len(labels) == len(words)

        result = []
        for word, (label, confidence) in zip(words, labels):
            if label!='O':
                result.append((word,label))
        tmp = []
        tag = OrderedDict()
        tag['PER'] = []
        tag['LOC'] = []
        tag['ORG'] = []
        tag['MISC'] = []

        for i, (word, label) in enumerate(result):
            if label=='B-PER' or label=='B-LOC' or label=='B-ORG' or label=='B-MISC':
                if i==0:
                    tmp.append(word)
                else:
                    wordstype = result[i-1][1][2:]
                    tag[wordstype].append(' '.join(tmp))
                    tmp.clear()
                    tmp.append(word)
            elif i==len(result)-1:
                tmp.append(word)
                wordstype = result[i][1][2:]
                tag[wordstype].append(' '.join(tmp))
            else:
                tmp.append(word)

        return tag


if __name__ == "__main__":
    model = Ner("out_ner/")

    parser = argparse.ArgumentParser()
    parser.add_argument("--text",
                        default="Irene, a master student in Zhejiang University, Hangzhou, is traveling in Warsaw for Chopin Music Festival.",
                        type=str,
                        help="The text to be NERed")
    text = parser.parse_args().text

    print("The text to be NERed:")
    print(text)
    print('Results of NER:')

    result = model.predict(text)
    for k,v in result.items():
        if v:
            print(v,end=': ')
            if k=='PER':
                print('Person')
            elif k=='LOC':
                print('Location')
            elif k=='ORG':
                print('Organization')
            elif k=='MISC':
                print('Miscellaneous')
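The `valid_ids` bookkeeping in `tokenize` and `BertNer.forward` above is the core alignment trick: only the first WordPiece of each word keeps a label position. A tiny standalone illustration with toy values (no model or tokenizer involved):

```python
# Toy illustration of the valid_ids convention used above (no model required):
# 1 marks the first WordPiece of a word, 0 marks continuation pieces.
pieces_per_word = [["War", "##saw"], ["Festival"]]

tokens, valid_positions = [], []
for word_pieces in pieces_per_word:
    for i, piece in enumerate(word_pieces):
        tokens.append(piece)
        valid_positions.append(1 if i == 0 else 0)

print(tokens)           # ['War', '##saw', 'Festival']
print(valid_positions)  # [1, 0, 1] -> only positions marked 1 receive a label
```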
@@ -1,38 +0,0 @@ (deleted file)

# Quick Start

## Clone the code

```
git clone git@github.com:xxupiano/BERTNER.git
```

## Set up the environment

Create a Python virtual environment (python >= 3.7).

Install the dependencies:

```
pip install -r requirements.txt
```

## Use the tool

Train first:

```
python run_ner.py --data_dir=data/ --bert_model=bert-base-cased --task_name=ner --output_dir=out_ner --max_seq_length=128 --do_train --num_train_epochs 5 --do_eval --warmup_proportion=0.1
```

Then predict:

- Modify `text` in main.py to the sentence to be NERed
- ```
  python main.py
  ```
@@ -12,7 +12,7 @@ from pytorch_transformers import (BertConfig, BertForTokenClassification,
                                   BertTokenizer)
 from collections import OrderedDict
 import nltk
-nltk.data.path.insert(0,'./data/nltk_data')
+nltk.data.path.insert(0,os.path.dirname(os.getcwd())+'/module/data/nltk_data')


 class BertNer(BertForTokenClassification):
@@ -0,0 +1,51 @@ (new file)

from __future__ import absolute_import, division, print_function

import argparse
import csv
import json
import logging
import os
import random
import sys
import numpy as np
import torch
import torch.nn.functional as F
from pytorch_transformers import (WEIGHTS_NAME, AdamW, BertConfig,
                                  BertForTokenClassification, BertTokenizer,
                                  WarmupLinearSchedule)
from torch import nn
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from seqeval.metrics import classification_report


class Ner(BertForTokenClassification):

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,valid_ids=None,attention_mask_label=None):
        sequence_output = self.bert(input_ids, token_type_ids, attention_mask,head_mask=None)[0]
        batch_size,max_len,feat_dim = sequence_output.shape
        valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device='cuda')
        for i in range(batch_size):
            jj = -1
            for j in range(max_len):
                if valid_ids[i][j].item() == 1:
                    jj += 1
                    valid_output[i][jj] = sequence_output[i][j]
        sequence_output = self.dropout(valid_output)
        logits = self.classifier(sequence_output)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=0)
            # Only keep active parts of the loss
            #attention_mask_label = None
            if attention_mask_label is not None:
                active_loss = attention_mask_label.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            return loss
        else:
            return logits
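The loss branch above only scores positions whose label mask is 1 and reserves index 0 for padding. A small self-contained sketch of that "active loss" selection with toy tensors (shapes and values are illustrative only):

```python
# Standalone sketch of the active-loss masking used in Ner.forward (toy values).
import torch
from torch import nn

num_labels = 5
logits = torch.randn(2, 4, num_labels)                 # (batch, seq_len, num_labels)
labels = torch.tensor([[2, 3, 0, 0], [1, 4, 2, 0]])    # 0 = padding
label_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 0]])

loss_fct = nn.CrossEntropyLoss(ignore_index=0)
active = label_mask.view(-1) == 1                      # keep only real label positions
loss = loss_fct(logits.view(-1, num_labels)[active], labels.view(-1)[active])
print(loss.item())
```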
Binary file not shown.
@@ -0,0 +1,98 @@ (new file)

Pretrained Punkt Models -- Jan Strunk (New version trained after issues 313 and 514 had been corrected)

Most models were prepared using the test corpora from Kiss and Strunk (2006). Additional models have
been contributed by various people using NLTK for sentence boundary detection.

For information about how to use these models, please confer the tokenization HOWTO:
http://nltk.googlecode.com/svn/trunk/doc/howto/tokenize.html
and chapter 3.8 of the NLTK book:
http://nltk.googlecode.com/svn/trunk/doc/book/ch03.html#sec-segmentation

There are pretrained tokenizers for the following languages:

File                Language    Source                              Contents                      Size of training corpus (in tokens)    Model contributed by
==============================================================================================================================================================
czech.pickle        Czech       Multilingual Corpus 1 (ECI)         Lidove Noviny                 ~345,000                               Jan Strunk / Tibor Kiss
                                                                    Literarni Noviny
--------------------------------------------------------------------------------------------------------------------------------------------------------------
danish.pickle       Danish      Avisdata CD-Rom Ver. 1.1. 1995      Berlingske Tidende            ~550,000                               Jan Strunk / Tibor Kiss
                                (Berlingske Avisdata, Copenhagen)   Weekend Avisen
--------------------------------------------------------------------------------------------------------------------------------------------------------------
dutch.pickle        Dutch       Multilingual Corpus 1 (ECI)         De Limburger                  ~340,000                               Jan Strunk / Tibor Kiss
--------------------------------------------------------------------------------------------------------------------------------------------------------------
english.pickle      English     Penn Treebank (LDC)                 Wall Street Journal           ~469,000                               Jan Strunk / Tibor Kiss
                    (American)
--------------------------------------------------------------------------------------------------------------------------------------------------------------
estonian.pickle     Estonian    University of Tartu, Estonia        Eesti Ekspress                ~359,000                               Jan Strunk / Tibor Kiss
--------------------------------------------------------------------------------------------------------------------------------------------------------------
finnish.pickle      Finnish     Finnish Parole Corpus, Finnish      Books and major national      ~364,000                               Jan Strunk / Tibor Kiss
                                Text Bank (Suomen Kielen            newspapers
                                Tekstipankki)
                                Finnish Center for IT Science
                                (CSC)
--------------------------------------------------------------------------------------------------------------------------------------------------------------
french.pickle       French      Multilingual Corpus 1 (ECI)         Le Monde                      ~370,000                               Jan Strunk / Tibor Kiss
                    (European)
--------------------------------------------------------------------------------------------------------------------------------------------------------------
german.pickle       German      Neue Zürcher Zeitung AG             Neue Zürcher Zeitung          ~847,000                               Jan Strunk / Tibor Kiss
                                (Switzerland) CD-ROM
                                (Uses "ss"
                                instead of "ß")
--------------------------------------------------------------------------------------------------------------------------------------------------------------
greek.pickle        Greek       Efstathios Stamatatos               To Vima (TO BHMA)             ~227,000                               Jan Strunk / Tibor Kiss
--------------------------------------------------------------------------------------------------------------------------------------------------------------
italian.pickle      Italian     Multilingual Corpus 1 (ECI)         La Stampa, Il Mattino         ~312,000                               Jan Strunk / Tibor Kiss
--------------------------------------------------------------------------------------------------------------------------------------------------------------
norwegian.pickle    Norwegian   Centre for Humanities               Bergens Tidende               ~479,000                               Jan Strunk / Tibor Kiss
                    (Bokmål and Information Technologies,
                    Nynorsk)    Bergen
--------------------------------------------------------------------------------------------------------------------------------------------------------------
polish.pickle       Polish      Polish National Corpus              Literature, newspapers, etc.  ~1,000,000                             Krzysztof Langner
                                (http://www.nkjp.pl/)
--------------------------------------------------------------------------------------------------------------------------------------------------------------
portuguese.pickle   Portuguese  CETENFolha Corpus                   Folha de São Paulo            ~321,000                               Jan Strunk / Tibor Kiss
                    (Brazilian) (Linguateca)
--------------------------------------------------------------------------------------------------------------------------------------------------------------
slovene.pickle      Slovene     TRACTOR                             Delo                          ~354,000                               Jan Strunk / Tibor Kiss
                                Slovene Academy for Arts
                                and Sciences
--------------------------------------------------------------------------------------------------------------------------------------------------------------
spanish.pickle      Spanish     Multilingual Corpus 1 (ECI)         Sur                           ~353,000                               Jan Strunk / Tibor Kiss
                    (European)
--------------------------------------------------------------------------------------------------------------------------------------------------------------
swedish.pickle      Swedish     Multilingual Corpus 1 (ECI)         Dagens Nyheter                ~339,000                               Jan Strunk / Tibor Kiss
                                (and some other texts)
--------------------------------------------------------------------------------------------------------------------------------------------------------------
turkish.pickle      Turkish     METU Turkish Corpus                 Milliyet                      ~333,000                               Jan Strunk / Tibor Kiss
                                (Türkçe Derlem Projesi)
                                University of Ankara
--------------------------------------------------------------------------------------------------------------------------------------------------------------

The corpora contained about 400,000 tokens on average and mostly consisted of newspaper text converted to
Unicode using the codecs module.

Kiss, Tibor and Strunk, Jan (2006): Unsupervised Multilingual Sentence Boundary Detection.
Computational Linguistics 32: 485-525.

---- Training Code ----

# import punkt
import nltk.tokenize.punkt

# Make a new Tokenizer
tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer()

# Read in training corpus (one example: Slovene)
import codecs
text = codecs.open("slovene.plain","Ur","iso-8859-2").read()

# Train tokenizer
tokenizer.train(text)

# Dump pickled tokenizer
import pickle
out = open("slovene.pickle","wb")
pickle.dump(tokenizer, out)
out.close()

---------
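Since the scripts in this commit put this nltk_data directory on `nltk.data.path`, one of these pickled Punkt models can be loaded directly through `nltk.data`. A minimal sketch, assuming the directory follows the usual `tokenizers/punkt/` layout (a layout assumption, not confirmed by the diff):

```python
# Sketch: loading a pretrained Punkt sentence tokenizer shipped with this repo.
# Assumes ./data/nltk_data/tokenizers/punkt/english.pickle exists (layout assumption).
import nltk

nltk.data.path.insert(0, './data/nltk_data')
sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
print(sent_tokenizer.tokenize("Dr. Smith went to Warsaw. He enjoyed the festival."))
```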
Binary files not shown (18 files).
@@ -0,0 +1,98 @@ (new file: an identical copy of the Punkt README above, apparently added at a second nltk_data location)
File diffs suppressed because they are too large (20 files); 1 binary file not shown.
@@ -0,0 +1,75 @@ (new file)

class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs a InputExample.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label

class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id, valid_ids=None, label_mask=None):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id
        self.valid_ids = valid_ids
        self.label_mask = label_mask

def readfile(filename):
    '''
    read file
    '''
    f = open(filename)
    data = []
    sentence = []
    label= []
    for line in f:
        if len(line)==0 or line.startswith('-DOCSTART') or line[0]=="\n":
            if len(sentence) > 0:
                data.append((sentence,label))
                sentence = []
                label = []
            continue
        splits = line.split(' ')
        sentence.append(splits[0])
        label.append(splits[-1][:-1])

    if len(sentence) >0:
        data.append((sentence,label))
        sentence = []
        label = []
    return data

class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        return readfile(input_file)
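`readfile` expects whitespace-separated CoNLL-2003 style lines, with the word in the first column and the NER tag in the last, and blank or -DOCSTART- lines as sentence boundaries. A small self-contained demonstration with toy data (the sample sentence is made up):

```python
# Demonstrates the input format readfile expects (toy CoNLL-style data).
import os
import tempfile
from dataset import readfile

sample = (
    "-DOCSTART- -X- -X- O\n"
    "\n"
    "Irene NNP B-NP B-PER\n"
    "is VBZ B-VP O\n"
    "in IN B-PP O\n"
    "Warsaw NNP B-NP B-LOC\n"
    "\n"
)
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write(sample)
    path = f.name

print(readfile(path))
# [(['Irene', 'is', 'in', 'Warsaw'], ['B-PER', 'O', 'O', 'B-LOC'])]
os.unlink(path)
```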
@@ -0,0 +1,358 @@ (new file)

from __future__ import absolute_import, division, print_function

import argparse
import csv
import json
import logging
import os
import random
import sys
import numpy as np
import torch
import torch.nn.functional as F
from pytorch_transformers import (WEIGHTS_NAME, AdamW, BertConfig,
                                  BertForTokenClassification, BertTokenizer,
                                  WarmupLinearSchedule)
from torch import nn
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from seqeval.metrics import classification_report

from dataset import *
from preprocess import *
sys.path.append("..")
from models.NER import Ner


def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                        "Sequences longer than this will be truncated, and sequences shorter \n"
                        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval or not.")
    parser.add_argument("--eval_on",
                        default="dev",
                        help="Whether to run eval on the dev set or test set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--weight_decay", default=0.01, type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                        "0 (default value): dynamic loss scaling.\n"
                        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {"ner":NerProcessor}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    label_list = processor.get_labels()
    num_labels = len(label_list) + 1

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = 0
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    # Prepare model
    config = BertConfig.from_pretrained(args.bert_model, num_labels=num_labels, finetuning_task=args.task_name)
    model = Ner.from_pretrained(args.bert_model,
                                from_tf = False,
                                config = config)

    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    model.to(device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias','LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    warmup_steps = int(args.warmup_proportion * num_train_optimization_steps)
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=num_train_optimization_steps)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    label_map = {i : label for i, label in enumerate(label_list,1)}
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
        all_valid_ids = torch.tensor([f.valid_ids for f in train_features], dtype=torch.long)
        all_lmask_ids = torch.tensor([f.label_mask for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids,all_valid_ids,all_lmask_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids, valid_ids,l_mask = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids,valid_ids,l_mask)
                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
        label_map = {i : label for i, label in enumerate(label_list,1)}
        model_config = {"bert_model":args.bert_model,"do_lower":args.do_lower_case,"max_seq_length":args.max_seq_length,"num_labels":len(label_list)+1,"label_map":label_map}
        json.dump(model_config,open(os.path.join(args.output_dir,"model_config.json"),"w"))
        # Load a trained model and config that you have fine-tuned
    else:
        # Load a trained model and vocabulary that you have fine-tuned
        model = Ner.from_pretrained(args.output_dir)
        tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)

    model.to(device)

    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        if args.eval_on == "dev":
            eval_examples = processor.get_dev_examples(args.data_dir)
        elif args.eval_on == "test":
            eval_examples = processor.get_test_examples(args.data_dir)
        else:
            raise ValueError("eval on dev or test set only")
        eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        all_valid_ids = torch.tensor([f.valid_ids for f in eval_features], dtype=torch.long)
        all_lmask_ids = torch.tensor([f.label_mask for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids,all_valid_ids,all_lmask_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        y_true = []
        y_pred = []
        label_map = {i : label for i, label in enumerate(label_list,1)}
        for input_ids, input_mask, segment_ids, label_ids,valid_ids,l_mask in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            valid_ids = valid_ids.to(device)
            label_ids = label_ids.to(device)
            l_mask = l_mask.to(device)

            with torch.no_grad():
                logits = model(input_ids, segment_ids, input_mask,valid_ids=valid_ids,attention_mask_label=l_mask)

            logits = torch.argmax(F.log_softmax(logits,dim=2),dim=2)
            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            input_mask = input_mask.to('cpu').numpy()

            for i, label in enumerate(label_ids):
                temp_1 = []
                temp_2 = []
                for j,m in enumerate(label):
                    if j == 0:
                        continue
                    elif label_ids[i][j] == len(label_map):
                        y_true.append(temp_1)
                        y_pred.append(temp_2)
                        break
                    else:
                        temp_1.append(label_map[label_ids[i][j]])
                        temp_2.append(label_map[logits[i][j]])

        report = classification_report(y_true, y_pred,digits=4)
        # logger.info("\n%s", report)
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            # logger.info("***** Eval results *****")
            # logger.info("\n%s", report)
            writer.write(report)


if __name__ == "__main__":
    main()
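After training, the script above saves the weights with `save_pretrained` and writes `model_config.json` next to them; the inference script's `load_model` reads exactly these fields. A rough sketch of what that file ends up containing (values below are illustrative, not taken from a real run):

```python
# Illustrative shape of out_ner/model_config.json as written by the training script.
# Values are examples only; json.dump stores the label_map keys as strings,
# and the inference code converts them back with int(k).
example_model_config = {
    "bert_model": "bert-base-cased",
    "do_lower": False,
    "max_seq_length": 128,
    "num_labels": 12,  # len(label_list) + 1, with index 0 reserved for padding
    "label_map": {
        "1": "O", "2": "B-MISC", "3": "I-MISC", "4": "B-PER", "5": "I-PER",
        "6": "B-ORG", "7": "I-ORG", "8": "B-LOC", "9": "I-LOC",
        "10": "[CLS]", "11": "[SEP]",
    },
}
print(example_model_config["label_map"]["4"])  # B-PER
```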
@@ -1,4 +1,6 @@
-from bert import Ner
+import sys
+sys.path.append("..")
+from models.BERTNER import Ner
 model = Ner("out_ner/")

 text= "Irene, a master student in Zhejiang University, Hangzhou, is traveling in Warsaw for Chopin Music Festival."
@@ -17,4 +19,4 @@ for k,v in result.items():
         elif k=='ORG':
             print('Organization')
         elif k=='MISC':
             print('Miscellaneous')
@@ -0,0 +1,117 @@ (new file)

from dataset import *

import argparse
import csv
import json
import logging
import os
import random
import sys
import numpy as np

class NerProcessor(DataProcessor):
    """Processor for the CoNLL-2003 data set."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.txt")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "valid.txt")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.txt")), "test")

    def get_labels(self):
        return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "[CLS]", "[SEP]"]

    def _create_examples(self,lines,set_type):
        examples = []
        for i,(sentence,label) in enumerate(lines):
            guid = "%s-%s" % (set_type, i)
            text_a = ' '.join(sentence)
            text_b = None
            label = label
            examples.append(InputExample(guid=guid,text_a=text_a,text_b=text_b,label=label))
        return examples

def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
    """Loads a data file into a list of `InputBatch`s."""

    label_map = {label : i for i, label in enumerate(label_list,1)}

    features = []
    for (ex_index,example) in enumerate(examples):
        textlist = example.text_a.split(' ')
        labellist = example.label
        tokens = []
        labels = []
        valid = []
        label_mask = []
        for i, word in enumerate(textlist):
            token = tokenizer.tokenize(word)
            tokens.extend(token)
            label_1 = labellist[i]
            for m in range(len(token)):
                if m == 0:
                    labels.append(label_1)
                    valid.append(1)
                    label_mask.append(1)
                else:
                    valid.append(0)
        if len(tokens) >= max_seq_length - 1:
            tokens = tokens[0:(max_seq_length - 2)]
            labels = labels[0:(max_seq_length - 2)]
            valid = valid[0:(max_seq_length - 2)]
            label_mask = label_mask[0:(max_seq_length - 2)]
        ntokens = []
        segment_ids = []
        label_ids = []
        ntokens.append("[CLS]")
        segment_ids.append(0)
        valid.insert(0,1)
        label_mask.insert(0,1)
        label_ids.append(label_map["[CLS]"])
        for i, token in enumerate(tokens):
            ntokens.append(token)
            segment_ids.append(0)
            if len(labels) > i:
                label_ids.append(label_map[labels[i]])
        ntokens.append("[SEP]")
        segment_ids.append(0)
        valid.append(1)
        label_mask.append(1)
        label_ids.append(label_map["[SEP]"])
        input_ids = tokenizer.convert_tokens_to_ids(ntokens)
        input_mask = [1] * len(input_ids)
        label_mask = [1] * len(label_ids)
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)
            label_ids.append(0)
            valid.append(1)
            label_mask.append(0)
        while len(label_ids) < max_seq_length:
            label_ids.append(0)
            label_mask.append(0)
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_ids) == max_seq_length
        assert len(valid) == max_seq_length
        assert len(label_mask) == max_seq_length

        features.append(
                InputFeatures(input_ids=input_ids,
                              input_mask=input_mask,
                              segment_ids=segment_ids,
                              label_id=label_ids,
                              valid_ids=valid,
                              label_mask=label_mask))
    return features
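One detail worth noting in `convert_examples_to_features`: the label map is built with `enumerate(label_list, 1)`, so real labels start at index 1 and index 0 is left free for padding, which matches `CrossEntropyLoss(ignore_index=0)` in the model. A two-line sketch:

```python
# Label ids start at 1; 0 is reserved for padding (cf. ignore_index=0 in the loss).
label_list = ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG",
              "B-LOC", "I-LOC", "[CLS]", "[SEP]"]
label_map = {label: i for i, label in enumerate(label_list, 1)}
print(label_map["O"], label_map["[SEP]"])  # 1 11
```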
@@ -1,157 +0,0 @@ (deleted file: Jupyter notebook)

{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "08e09d48",
   "metadata": {},
   "source": [
    "# BERT based NER using CoNLL-2003\n",
    "> Author: Xin Xu <xxucs@zju.edu.cn>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14365b62",
   "metadata": {},
   "source": [
    "## Overview\n",
    "- **Named-entity recognition (NER)** (also known as named entity identification, entity chunking, and entity extraction) is a subtask of information extraction that seeks to locate and classify named entities mentioned in unstructured text into pre-defined categories such as person names, organizations, locations, medical codes, time expressions, quantities, monetary values, percentages, etc.\n",
    "- [**CoNLL-2003**](https://www.clips.uantwerpen.be/conll2003/ner/) is a dataset for NER, concentrating on four types of named entities related to persons, locations, organizations, and names of miscellaneous entities. The dataset is in the 'data' folder, containing *train.txt*, *valid.txt* and *test.txt*\n",
    "- [**Bidirectional Encoder Representations from Transformers (BERT)**](https://github.com/google-research/bert) is a transformer-based machine learning technique for natural language processing (NLP) pre-training developed by Google."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "733b418c",
   "metadata": {},
   "source": [
    "## Clone Repository\n",
    "The 1st step is to clone the DeepKE GitHub repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65822b98",
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://github.com/xxupiano/BERTNER.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb6b8798",
   "metadata": {},
   "outputs": [],
   "source": [
    "!cd BERTNER"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3b0cf3f",
   "metadata": {},
   "source": [
    "## Prepare the runtime environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e46f572",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4b2319dc",
   "metadata": {},
   "source": [
    "## Fine-Tune\n",
    "- Finetune or train the **bert-base** model by running 'run_ner.py'\n",
    "- In the command below we pass several arguments:\n",
    "  - '--data_dir' is the dataset directory. Pass 'data/', the data directory inside the 'BERT-NER' folder.\n",
    "  - '--bert_model' selects the pretrained BERT model from Hugging Face transformers. Among the model names suggested by Hugging Face, here we select 'bert-base-cased'.\n",
    "  - '--task_name' is the task to perform. Enter 'ner' as we will train the model for Named Entity Recognition (NER).\n",
    "  - '--output_dir' is where the fine-tuned model is stored. We use the directory name 'out_ner'.\n",
    "  - For other arguments like '--max_seq_length', '--num_train_epochs' and '--warmup_proportion', just use the values suggested in the repository.\n",
    "  - For training pass the argument '--do_train', and to evaluate the results afterwards pass the argument '--do_eval'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cdd7e86",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python run_ner.py --data_dir=data/ --bert_model=bert-base-cased --task_name=ner --output_dir=out_ner --max_seq_length=128 --do_train --num_train_epochs 5 --do_eval --warmup_proportion=0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c0f79a8",
   "metadata": {},
   "source": [
    "## Prediction\n",
    "- Set the variable *text* in the following cell as the sentence to be NERed\n",
    "- Run the following cell to get the NER result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0da6a2f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from bert import Ner\n",
    "model = Ner(\"out_ner/\")\n",
    "\n",
    "text= \"Irene, a master student in Zhejiang University, Hangzhou, is traveling in Warsaw for Chopin Music Festival.\"\n",
    "print(\"Text to predict Entity:\")\n",
    "print(text)\n",
    "print('Results of NER:')\n",
    "\n",
    "result = model.predict(text)\n",
    "for k,v in result.items():\n",
    "    if v:\n",
    "        print(v,end=': ')\n",
    "        if k=='PER':\n",
    "            print('Person')\n",
    "        elif k=='LOC':\n",
    "            print('Location')\n",
    "        elif k=='ORG':\n",
    "            print('Organization')\n",
    "        elif k=='MISC':\n",
    "            print('Miscellaneous')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
Binary file not shown.
Some files were not shown because too many files have changed in this diff.