commit e634ece0d9
parent 1a7ce771bb

    test
@@ -18,13 +18,6 @@
 <p>基于深度学习的开源中文知识图谱抽取框架</p>
 </h3>
 
-<p align="center">
-<a href="pics/logo_zju_klab.png">
-<img src="pics/logo_zju_klab.png">
-</a>
-</p>
-
-
 DeepKE 提供了多种知识抽取模型。
 
 ## 在线演示
@@ -36,9 +36,9 @@ class BertNer(BertForTokenClassification):
         logits = self.classifier(sequence_output)
         return logits
 
-class Ner:
+class InferNer:
 
-    def __init__(self,model_dir: str):
+    def __init__(self,model_dir:str,language):
         self.model , self.tokenizer, self.model_config = self.load_model(model_dir)
         self.label_map = self.model_config["label_map"]
         self.max_seq_length = self.model_config["max_seq_length"]
@@ -46,6 +46,7 @@ class Ner:
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model = self.model.to(self.device)
         self.model.eval()
+        self.language = language
 
     def load_model(self, model_dir: str, model_config: str = "model_config.json"):
         model_config = os.path.join(model_dir,model_config)
@@ -56,7 +57,6 @@ class Ner:
 
     def tokenize(self, text: str):
         """ tokenize input"""
-        # words = word_tokenize(text)
         words = list(text)
         tokens = []
         valid_positions = []
@@ -117,7 +117,11 @@ class Ner:
         logits.pop()
 
         labels = [(self.label_map[label],confidence) for label,confidence in logits]
-        words = list(text)
+        if self.language == 'bert-base-chinese':
+            words = list(text)
+        else:
+            nltk.download('punkt')
+            words = word_tokenize(text)
         assert len(labels) == len(words)
 
         result = []
@@ -137,27 +141,34 @@ class Ner:
                     tmp.append(word)
                 else:
                     wordstype = result[i-1][1][2:]
-                    tag[wordstype].append(''.join(tmp))
+                    if self.language=='bert-base-chinese':
+                        tag[wordstype].append(''.join(tmp))
+                    else:
+                        tag[wordstype].append(' '.join(tmp))
                     tmp.clear()
                     tmp.append(word)
             elif i==len(result)-1:
                 tmp.append(word)
                 wordstype = result[i][1][2:]
-                tag[wordstype].append(''.join(tmp))
+                if self.language=='bert-base-chinese':
+                    tag[wordstype].append(''.join(tmp))
+                else:
+                    tag[wordstype].append(' '.join(tmp))
             else:
                 tmp.append(word)
 
         return tag
 
 
 
 @hydra.main(config_path="conf", config_name='config')
 def main(cfg):
-    model = Ner(utils.get_original_cwd()+'/'+"checkpoint/")
+    model = InferNer(utils.get_original_cwd()+'/'+"checkpoint/", cfg.bert_model)
     text = cfg.text
 
-    print("The text to be NERed:")
+    print("NER句子:")
     print(text)
-    print('Results of NER:')
+    print('NER结果:')
 
     result = model.predict(text)
     for k,v in result.items():
@@ -171,8 +182,6 @@ def main(cfg):
             print('Organization')
         elif k=='MISC':
             print('Miscellaneous')
 
-
-
 
 if __name__ == "__main__":
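Note: a minimal usage sketch of the renamed inference entry point changed above. The checkpoint path, the conf/config.yaml keys (text, bert_model) and the InferBert import come from the hunks in this diff; anything beyond that is an assumption, not the project's documented API.

    import hydra
    from hydra import utils
    from models.InferBert import InferNer

    @hydra.main(config_path="conf", config_name='config')
    def main(cfg):
        # cfg.bert_model selects the tokenization branch ('bert-base-chinese' vs. an English BERT)
        model = InferNer(utils.get_original_cwd() + '/' + "checkpoint/", cfg.bert_model)
        result = model.predict(cfg.text)  # dict-like: entity type ('PER', 'LOC', ...) -> recognized spans
        for k, v in result.items():
            if v:
                print(k, v)

    if __name__ == "__main__":
        main()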
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 import hydra
 from hydra import utils
 
-class Ner(BertForTokenClassification):
+class TrainNer(BertForTokenClassification):
 
     def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,valid_ids=None,attention_mask_label=None):
         sequence_output = self.bert(input_ids, token_type_ids, attention_mask,head_mask=None)[0]
@@ -303,7 +303,7 @@ def main(cfg):
 
     # Prepare model
     config = BertConfig.from_pretrained(cfg.bert_model, num_labels=num_labels, finetuning_task=cfg.task_name)
-    model = Ner.from_pretrained(cfg.bert_model,
+    model = TrainNer.from_pretrained(cfg.bert_model,
              from_tf = False,
              config = config)
 
@@ -1,51 +1,51 @@
-from __future__ import absolute_import, division, print_function
-
-import argparse
-import csv
-import json
-import logging
-import os
-import random
-import sys
-import numpy as np
-import torch
-import torch.nn.functional as F
-from pytorch_transformers import (WEIGHTS_NAME, AdamW, BertConfig,
-                                  BertForTokenClassification, BertTokenizer,
-                                  WarmupLinearSchedule)
-from torch import nn
-from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
-                              TensorDataset)
-from torch.utils.data.distributed import DistributedSampler
-from tqdm import tqdm, trange
-from seqeval.metrics import classification_report
-
-class Ner(BertForTokenClassification):
-
-    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,valid_ids=None,attention_mask_label=None):
-        sequence_output = self.bert(input_ids, token_type_ids, attention_mask,head_mask=None)[0]
-        batch_size,max_len,feat_dim = sequence_output.shape
-        valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device='cuda')
-        for i in range(batch_size):
-            jj = -1
-            for j in range(max_len):
-                if valid_ids[i][j].item() == 1:
-                    jj += 1
-                    valid_output[i][jj] = sequence_output[i][j]
-        sequence_output = self.dropout(valid_output)
-        logits = self.classifier(sequence_output)
-
-        if labels is not None:
-            loss_fct = nn.CrossEntropyLoss(ignore_index=0)
-            # Only keep active parts of the loss
-            #attention_mask_label = None
-            if attention_mask_label is not None:
-                active_loss = attention_mask_label.view(-1) == 1
-                active_logits = logits.view(-1, self.num_labels)[active_loss]
-                active_labels = labels.view(-1)[active_loss]
-                loss = loss_fct(active_logits, active_labels)
-            else:
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            return loss
-        else:
-            return logits
+from __future__ import absolute_import, division, print_function
+
+import argparse
+import csv
+import json
+import logging
+import os
+import random
+import sys
+import numpy as np
+import torch
+import torch.nn.functional as F
+from pytorch_transformers import (WEIGHTS_NAME, AdamW, BertConfig,
+                                  BertForTokenClassification, BertTokenizer,
+                                  WarmupLinearSchedule)
+from torch import nn
+from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
+                              TensorDataset)
+from torch.utils.data.distributed import DistributedSampler
+from tqdm import tqdm, trange
+from seqeval.metrics import classification_report
+
+class TrainNer(BertForTokenClassification):
+
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,valid_ids=None,attention_mask_label=None):
+        sequence_output = self.bert(input_ids, token_type_ids, attention_mask,head_mask=None)[0]
+        batch_size,max_len,feat_dim = sequence_output.shape
+        valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device='cuda')
+        for i in range(batch_size):
+            jj = -1
+            for j in range(max_len):
+                if valid_ids[i][j].item() == 1:
+                    jj += 1
+                    valid_output[i][jj] = sequence_output[i][j]
+        sequence_output = self.dropout(valid_output)
+        logits = self.classifier(sequence_output)
+
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss(ignore_index=0)
+            # Only keep active parts of the loss
+            #attention_mask_label = None
+            if attention_mask_label is not None:
+                active_loss = attention_mask_label.view(-1) == 1
+                active_logits = logits.view(-1, self.num_labels)[active_loss]
+                active_labels = labels.view(-1)[active_loss]
+                loss = loss_fct(active_logits, active_labels)
+            else:
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            return loss
+        else:
+            return logits
@@ -7,12 +7,14 @@ import os
 
 import torch
 import torch.nn.functional as F
 import nltk
 from nltk import word_tokenize
 from pytorch_transformers import (BertConfig, BertForTokenClassification,
                                   BertTokenizer)
 from collections import OrderedDict
+import nltk
+nltk.data.path.insert(0,os.path.dirname(os.getcwd())+'/module/data/nltk_data')
 
 import hydra
 from hydra import utils
 
 
 class BertNer(BertForTokenClassification):
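Side note: the hunk above points NLTK at a bundled nltk_data directory, while the prediction path later calls nltk.download('punkt') on demand. A small, hedged sketch of the usual pattern for making punkt available locally; the directory name comes from the hunk, the find/download fallback is an assumption rather than code from this commit.

    import os
    import nltk

    # search a local data directory first (path as in the hunk above)
    nltk.data.path.insert(0, os.path.dirname(os.getcwd()) + '/module/data/nltk_data')
    try:
        nltk.data.find('tokenizers/punkt')   # already present locally?
    except LookupError:
        nltk.download('punkt')               # otherwise fetch it once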
@@ -31,9 +33,9 @@ class BertNer(BertForTokenClassification):
         logits = self.classifier(sequence_output)
         return logits
 
-class Ner:
+class InferNer:
 
-    def __init__(self,model_dir: str):
+    def __init__(self,model_dir:str,language):
         self.model , self.tokenizer, self.model_config = self.load_model(model_dir)
         self.label_map = self.model_config["label_map"]
         self.max_seq_length = self.model_config["max_seq_length"]
@@ -41,6 +43,7 @@ class Ner:
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model = self.model.to(self.device)
         self.model.eval()
+        self.language = language
 
     def load_model(self, model_dir: str, model_config: str = "model_config.json"):
         model_config = os.path.join(model_dir,model_config)
@@ -51,7 +54,7 @@ class Ner:
 
     def tokenize(self, text: str):
         """ tokenize input"""
-        words = word_tokenize(text)
+        words = list(text)
         tokens = []
         valid_positions = []
         for i,word in enumerate(words):
@@ -111,7 +114,11 @@ class Ner:
         logits.pop()
 
         labels = [(self.label_map[label],confidence) for label,confidence in logits]
-        words = word_tokenize(text)
+        if self.language == 'bert-base-chinese':
+            words = list(text)
+        else:
+            nltk.download('punkt')
+            words = word_tokenize(text)
         assert len(labels) == len(words)
 
         result = []
@@ -131,14 +138,21 @@ class Ner:
                     tmp.append(word)
                 else:
                     wordstype = result[i-1][1][2:]
-                    tag[wordstype].append(' '.join(tmp))
+                    if self.language=='bert-base-chinese':
+                        tag[wordstype].append(''.join(tmp))
+                    else:
+                        tag[wordstype].append(' '.join(tmp))
                     tmp.clear()
                     tmp.append(word)
             elif i==len(result)-1:
                 tmp.append(word)
                 wordstype = result[i][1][2:]
-                tag[wordstype].append(' '.join(tmp))
+                if self.language=='bert-base-chinese':
+                    tag[wordstype].append(''.join(tmp))
+                else:
+                    tag[wordstype].append(' '.join(tmp))
             else:
                 tmp.append(word)
 
         return tag
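Note: the hunks above make both tokenization and span joining language-dependent. A minimal sketch of that branch in isolation (the model names and sample strings are illustrative only, not values taken from this repository):

    import nltk
    from nltk import word_tokenize

    def split_words(text, bert_model):
        # Chinese BERT works on characters, so the text is split character by character;
        # other models fall back to NLTK's punkt word tokenizer, as in the diff above.
        if bert_model == 'bert-base-chinese':
            return list(text)
        nltk.download('punkt')
        return word_tokenize(text)

    def join_span(tokens, bert_model):
        # characters are re-joined without spaces, English tokens with spaces
        return ''.join(tokens) if bert_model == 'bert-base-chinese' else ' '.join(tokens)

    print(join_span(split_words('浙江大学位于杭州', 'bert-base-chinese'), 'bert-base-chinese'))
    print(join_span(split_words('Irene is traveling in Warsaw', 'bert-base-cased'), 'bert-base-cased'))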
@@ -1,2 +1,2 @@
-from .BasicNer import Ner
-from .InferBert import Ner
+from .BasicNer import TrainNer
+from .InferBert import InferNer
@@ -1,3 +1,3 @@
 from .dataset import *
 from .preprocess import *
-from .trainer import *
+from .trainer import train
@@ -1,98 +0,0 @@
-Pretrained Punkt Models -- Jan Strunk (New version trained after issues 313 and 514 had been corrected)
-
-Most models were prepared using the test corpora from Kiss and Strunk (2006). Additional models have
-been contributed by various people using NLTK for sentence boundary detection.
-
-For information about how to use these models, please confer the tokenization HOWTO:
-http://nltk.googlecode.com/svn/trunk/doc/howto/tokenize.html
-and chapter 3.8 of the NLTK book:
-http://nltk.googlecode.com/svn/trunk/doc/book/ch03.html#sec-segmentation
-
-There are pretrained tokenizers for the following languages:
-
-File                Language     Source                             Contents                        Size of training corpus(in tokens)   Model contributed by
-=============================================================================================================================================================
-czech.pickle        Czech        Multilingual Corpus 1 (ECI)        Lidove Noviny                   ~345,000                             Jan Strunk / Tibor Kiss
-                                                                    Literarni Noviny
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-danish.pickle       Danish       Avisdata CD-Rom Ver. 1.1. 1995     Berlingske Tidende              ~550,000                             Jan Strunk / Tibor Kiss
-                                 (Berlingske Avisdata, Copenhagen)  Weekend Avisen
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-dutch.pickle        Dutch        Multilingual Corpus 1 (ECI)        De Limburger                    ~340,000                             Jan Strunk / Tibor Kiss
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-english.pickle      English      Penn Treebank (LDC)                Wall Street Journal             ~469,000                             Jan Strunk / Tibor Kiss
-                    (American)
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-estonian.pickle     Estonian     University of Tartu, Estonia       Eesti Ekspress                  ~359,000                             Jan Strunk / Tibor Kiss
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-finnish.pickle      Finnish      Finnish Parole Corpus, Finnish     Books and major national        ~364,000                             Jan Strunk / Tibor Kiss
-                                 Text Bank (Suomen Kielen           newspapers
-                                 Tekstipankki)
-                                 Finnish Center for IT Science
-                                 (CSC)
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-french.pickle       French       Multilingual Corpus 1 (ECI)        Le Monde                        ~370,000                             Jan Strunk / Tibor Kiss
-                    (European)
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-german.pickle       German       Neue Zürcher Zeitung AG            Neue Zürcher Zeitung            ~847,000                             Jan Strunk / Tibor Kiss
-                                 (Switzerland) CD-ROM
-                                 (Uses "ss"
-                                 instead of "ß")
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-greek.pickle        Greek        Efstathios Stamatatos              To Vima (TO BHMA)               ~227,000                             Jan Strunk / Tibor Kiss
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-italian.pickle      Italian      Multilingual Corpus 1 (ECI)        La Stampa, Il Mattino           ~312,000                             Jan Strunk / Tibor Kiss
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-norwegian.pickle    Norwegian    Centre for Humanities              Bergens Tidende                 ~479,000                             Jan Strunk / Tibor Kiss
-                    (Bokmål and  Information Technologies,
-                    Nynorsk)     Bergen
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-polish.pickle       Polish       Polish National Corpus             Literature, newspapers, etc.    ~1,000,000                           Krzysztof Langner
-                                 (http://www.nkjp.pl/)
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-portuguese.pickle   Portuguese   CETENFolha Corpus                  Folha de São Paulo              ~321,000                             Jan Strunk / Tibor Kiss
-                    (Brazilian)  (Linguateca)
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-slovene.pickle      Slovene      TRACTOR                            Delo                            ~354,000                             Jan Strunk / Tibor Kiss
-                                 Slovene Academy for Arts
-                                 and Sciences
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-spanish.pickle      Spanish      Multilingual Corpus 1 (ECI)        Sur                             ~353,000                             Jan Strunk / Tibor Kiss
-                    (European)
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-swedish.pickle      Swedish      Multilingual Corpus 1 (ECI)        Dagens Nyheter                  ~339,000                             Jan Strunk / Tibor Kiss
-                                                                    (and some other texts)
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-turkish.pickle      Turkish      METU Turkish Corpus                Milliyet                        ~333,000                             Jan Strunk / Tibor Kiss
-                                 (Türkçe Derlem Projesi)
-                                 University of Ankara
--------------------------------------------------------------------------------------------------------------------------------------------------------------
-
-The corpora contained about 400,000 tokens on average and mostly consisted of newspaper text converted to
-Unicode using the codecs module.
-
-Kiss, Tibor and Strunk, Jan (2006): Unsupervised Multilingual Sentence Boundary Detection.
-Computational Linguistics 32: 485-525.
-
----- Training Code ----
-
-# import punkt
-import nltk.tokenize.punkt
-
-# Make a new Tokenizer
-tokenizer = nltk.tokenize.punkt.PunktSentenceTokenizer()
-
-# Read in training corpus (one example: Slovene)
-import codecs
-text = codecs.open("slovene.plain","Ur","iso-8859-2").read()
-
-# Train tokenizer
-tokenizer.train(text)
-
-# Dump pickled tokenizer
-import pickle
-out = open("slovene.pickle","wb")
-pickle.dump(tokenizer, out)
-out.close()
-
----------
@@ -1,22 +0,0 @@
-import sys
-sys.path.append("..")
-from models.InferBert import Ner
-model = Ner("out_ner/")
-
-text= "Irene, a master student in Zhejiang University, Hangzhou, is traveling in Warsaw for Chopin Music Festival."
-print("Text to predict Entity:")
-print(text)
-print('Results of NER:')
-
-result = model.predict(text)
-for k,v in result.items():
-    if v:
-        print(v,end=': ')
-        if k=='PER':
-            print('Person')
-        elif k=='LOC':
-            print('Location')
-        elif k=='ORG':
-            print('Organization')
-        elif k=='MISC':
-            print('Miscellaneous')
@@ -1,4 +1,4 @@
-from dataset import *
+from .dataset import *
 
 import argparse
 import csv
@@ -10,7 +10,7 @@ import sys
 import numpy as np
 
 class NerProcessor(DataProcessor):
-    """Processor for the CoNLL-2003 data set."""
+    """Processor for the dataset."""
 
     def get_train_examples(self, data_dir):
         """See base class."""
@@ -1,12 +1,12 @@
 from __future__ import absolute_import, division, print_function
 
 import argparse
 import csv
 import json
 import logging
 import os
 import random
 import sys
 
 import numpy as np
 import torch
 import torch.nn.functional as F
@@ -18,156 +18,61 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange
 
 from seqeval.metrics import classification_report
 
-from dataset import *
-from preprocess import *
-sys.path.append("..")
-from tools.BasicNer import Ner
-logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
-                    datefmt = '%m/%d/%Y %H:%M:%S',
-                    level = logging.INFO)
-logger = logging.getLogger(__name__)
+from .dataset import *
+from .preprocess import *
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))
+from models.BasicNer import TrainNer
 
 import hydra
 from hydra import utils
 
-def main():
-    parser = argparse.ArgumentParser()
-    ## Required parameters
-    parser.add_argument("--data_dir",
-                        default=None,
-                        type=str,
-                        required=True,
-                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
-    parser.add_argument("--bert_model", default=None, type=str, required=True,
-                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
-                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
-                        "bert-base-multilingual-cased, bert-base-chinese.")
-    parser.add_argument("--task_name",
-                        default=None,
-                        type=str,
-                        required=True,
-                        help="The name of the task to train.")
-    parser.add_argument("--output_dir",
-                        default=None,
-                        type=str,
-                        required=True,
-                        help="The output directory where the model predictions and checkpoints will be written.")
-
-    ## Other parameters
-    parser.add_argument("--cache_dir",
-                        default="",
-                        type=str,
-                        help="Where do you want to store the pre-trained models downloaded from s3")
-    parser.add_argument("--max_seq_length",
-                        default=128,
-                        type=int,
-                        help="The maximum total input sequence length after WordPiece tokenization. \n"
-                        "Sequences longer than this will be truncated, and sequences shorter \n"
-                        "than this will be padded.")
-    parser.add_argument("--do_train",
-                        action='store_true',
-                        help="Whether to run training.")
-    parser.add_argument("--do_eval",
-                        action='store_true',
-                        help="Whether to run eval or not.")
-    parser.add_argument("--eval_on",
-                        default="dev",
-                        help="Whether to run eval on the dev set or test set.")
-    parser.add_argument("--do_lower_case",
-                        action='store_true',
-                        help="Set this flag if you are using an uncased model.")
-    parser.add_argument("--train_batch_size",
-                        default=32,
-                        type=int,
-                        help="Total batch size for training.")
-    parser.add_argument("--eval_batch_size",
-                        default=8,
-                        type=int,
-                        help="Total batch size for eval.")
-    parser.add_argument("--learning_rate",
-                        default=5e-5,
-                        type=float,
-                        help="The initial learning rate for Adam.")
-    parser.add_argument("--num_train_epochs",
-                        default=3.0,
-                        type=float,
-                        help="Total number of training epochs to perform.")
-    parser.add_argument("--warmup_proportion",
-                        default=0.1,
-                        type=float,
-                        help="Proportion of training to perform linear learning rate warmup for. "
-                        "E.g., 0.1 = 10%% of training.")
-    parser.add_argument("--weight_decay", default=0.01, type=float,
-                        help="Weight deay if we apply some.")
-    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
-                        help="Epsilon for Adam optimizer.")
-    parser.add_argument("--max_grad_norm", default=1.0, type=float,
-                        help="Max gradient norm.")
-    parser.add_argument("--no_cuda",
-                        action='store_true',
-                        help="Whether not to use CUDA when available")
-    parser.add_argument("--local_rank",
-                        type=int,
-                        default=-1,
-                        help="local_rank for distributed training on gpus")
-    parser.add_argument('--seed',
-                        type=int,
-                        default=42,
-                        help="random seed for initialization")
-    parser.add_argument('--gradient_accumulation_steps',
-                        type=int,
-                        default=1,
-                        help="Number of updates steps to accumulate before performing a backward/update pass.")
-    parser.add_argument('--fp16',
-                        action='store_true',
-                        help="Whether to use 16-bit float precision instead of 32-bit")
-    parser.add_argument('--fp16_opt_level', type=str, default='O1',
-                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
-                        "See details at https://nvidia.github.io/apex/amp.html")
-    parser.add_argument('--loss_scale',
-                        type=float, default=0,
-                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
-                        "0 (default value): dynamic loss scaling.\n"
-                        "Positive power of 2: static loss scaling value.\n")
-    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
-    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
-    args = parser.parse_args()
-
-    if args.server_ip and args.server_port:
-        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
-        import ptvsd
-        print("Waiting for debugger attach")
-        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
-        ptvsd.wait_for_attach()
+logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                    datefmt = '%m/%d/%Y %H:%M:%S',
+                    level = logging.INFO)
+logger = logging.getLogger(__name__)
 
+def train(cfg):
     processors = {"ner":NerProcessor}
 
-    if args.local_rank == -1 or args.no_cuda:
-        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
+    if cfg.local_rank == -1 or cfg.no_cuda:
+        device = torch.device("cuda" if torch.cuda.is_available() and not cfg.no_cuda else "cpu")
         n_gpu = torch.cuda.device_count()
     else:
-        torch.cuda.set_device(args.local_rank)
-        device = torch.device("cuda", args.local_rank)
+        torch.cuda.set_device(cfg.local_rank)
+        device = torch.device("cuda", cfg.local_rank)
         n_gpu = 1
         # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(cfg.local_rank != -1), cfg.fp16))
 
-    if args.gradient_accumulation_steps < 1:
+    if cfg.gradient_accumulation_steps < 1:
         raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
-            args.gradient_accumulation_steps))
+            cfg.gradient_accumulation_steps))
 
-    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
+    cfg.train_batch_size = cfg.train_batch_size // cfg.gradient_accumulation_steps
 
-    random.seed(args.seed)
-    np.random.seed(args.seed)
-    torch.manual_seed(args.seed)
+    random.seed(cfg.seed)
+    np.random.seed(cfg.seed)
+    torch.manual_seed(cfg.seed)
 
-    if not args.do_train and not args.do_eval:
+    if not cfg.do_train and not cfg.do_eval:
         raise ValueError("At least one of `do_train` or `do_eval` must be True.")
 
-    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
-        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
-    if not os.path.exists(args.output_dir):
-        os.makedirs(args.output_dir)
+    if os.path.exists(utils.get_original_cwd()+'/'+cfg.output_dir) and os.listdir(utils.get_original_cwd()+'/'+cfg.output_dir) and cfg.do_train:
+        raise ValueError("Output directory ({}) already exists and is not empty.".format(utils.get_original_cwd()+'/'+cfg.output_dir))
+    if not os.path.exists(utils.get_original_cwd()+'/'+cfg.output_dir):
+        os.makedirs(utils.get_original_cwd()+'/'+cfg.output_dir)
 
-    task_name = args.task_name.lower()
+    task_name = cfg.task_name.lower()
 
     if task_name not in processors:
         raise ValueError("Task not found: %s" % (task_name))
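Note: the hunk above swaps the argparse namespace for a Hydra config object that is accessed the same way (args.x becomes cfg.x). A hedged sketch of a stand-in cfg carrying the keys the old flags provided; only the key names come from this diff, the values are illustrative.

    from omegaconf import OmegaConf  # Hydra builds its cfg objects on top of OmegaConf

    cfg = OmegaConf.create({
        "bert_model": "bert-base-cased", "task_name": "ner",
        "data_dir": "data/", "output_dir": "checkpoint/",
        "do_train": True, "do_eval": True, "eval_on": "dev", "do_lower_case": False,
        "max_seq_length": 128, "train_batch_size": 32, "eval_batch_size": 8,
        "learning_rate": 5e-5, "num_train_epochs": 3.0, "warmup_proportion": 0.1,
        "weight_decay": 0.01, "adam_epsilon": 1e-8, "max_grad_norm": 1.0,
        "gradient_accumulation_steps": 1, "seed": 42,
        "local_rank": -1, "no_cuda": False, "fp16": False, "fp16_opt_level": "O1",
    })
    # attribute access mirrors the old argparse usage in the hunk above
    print(cfg.train_batch_size // cfg.gradient_accumulation_steps)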
@@ -176,27 +81,27 @@ def main():
     label_list = processor.get_labels()
     num_labels = len(label_list) + 1
 
-    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
+    tokenizer = BertTokenizer.from_pretrained(cfg.bert_model, do_lower_case=cfg.do_lower_case)
 
     train_examples = None
     num_train_optimization_steps = 0
-    if args.do_train:
-        train_examples = processor.get_train_examples(args.data_dir)
+    if cfg.do_train:
+        train_examples = processor.get_train_examples(utils.get_original_cwd()+'/'+cfg.data_dir)
         num_train_optimization_steps = int(
-            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
-        if args.local_rank != -1:
+            len(train_examples) / cfg.train_batch_size / cfg.gradient_accumulation_steps) * cfg.num_train_epochs
+        if cfg.local_rank != -1:
             num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
 
-    if args.local_rank not in [-1, 0]:
+    if cfg.local_rank not in [-1, 0]:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
 
     # Prepare model
-    config = BertConfig.from_pretrained(args.bert_model, num_labels=num_labels, finetuning_task=args.task_name)
-    model = Ner.from_pretrained(args.bert_model,
+    config = BertConfig.from_pretrained(cfg.bert_model, num_labels=num_labels, finetuning_task=cfg.task_name)
+    model = TrainNer.from_pretrained(cfg.bert_model,
              from_tf = False,
              config = config)
 
-    if args.local_rank == 0:
+    if cfg.local_rank == 0:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
 
     model.to(device)
@@ -204,35 +109,35 @@ def main():
     param_optimizer = list(model.named_parameters())
     no_decay = ['bias','LayerNorm.weight']
     optimizer_grouped_parameters = [
-        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
+        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': cfg.weight_decay},
         {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
         ]
-    warmup_steps = int(args.warmup_proportion * num_train_optimization_steps)
-    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+    warmup_steps = int(cfg.warmup_proportion * num_train_optimization_steps)
+    optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.learning_rate, eps=cfg.adam_epsilon)
     scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=num_train_optimization_steps)
-    if args.fp16:
+    if cfg.fp16:
         try:
             from apex import amp
         except ImportError:
             raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
+        model, optimizer = amp.initialize(model, optimizer, opt_level=cfg.fp16_opt_level)
 
     # multi-gpu training (should be after apex fp16 initialization)
     if n_gpu > 1:
         model = torch.nn.DataParallel(model)
 
-    if args.local_rank != -1:
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
-                                                          output_device=args.local_rank,
+    if cfg.local_rank != -1:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[cfg.local_rank],
+                                                          output_device=cfg.local_rank,
                                                           find_unused_parameters=True)
 
     global_step = 0
     nb_tr_steps = 0
     tr_loss = 0
     label_map = {i : label for i, label in enumerate(label_list,1)}
-    if args.do_train:
+    if cfg.do_train:
         train_features = convert_examples_to_features(
-            train_examples, label_list, args.max_seq_length, tokenizer)
+            train_examples, label_list, cfg.max_seq_length, tokenizer)
         all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
         all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
         all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
@@ -240,14 +145,14 @@ def main():
         all_valid_ids = torch.tensor([f.valid_ids for f in train_features], dtype=torch.long)
         all_lmask_ids = torch.tensor([f.label_mask for f in train_features], dtype=torch.long)
         train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids,all_valid_ids,all_lmask_ids)
-        if args.local_rank == -1:
+        if cfg.local_rank == -1:
             train_sampler = RandomSampler(train_data)
         else:
             train_sampler = DistributedSampler(train_data)
-        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
+        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=cfg.train_batch_size)
 
         model.train()
-        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
+        for _ in trange(int(cfg.num_train_epochs), desc="Epoch"):
             tr_loss = 0
             nb_tr_examples, nb_tr_steps = 0, 0
             for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
@@ -256,21 +161,21 @@ def main():
                 loss = model(input_ids, segment_ids, input_mask, label_ids,valid_ids,l_mask)
                 if n_gpu > 1:
                     loss = loss.mean() # mean() to average on multi-gpu.
-                if args.gradient_accumulation_steps > 1:
-                    loss = loss / args.gradient_accumulation_steps
+                if cfg.gradient_accumulation_steps > 1:
+                    loss = loss / cfg.gradient_accumulation_steps
 
-                if args.fp16:
+                if cfg.fp16:
                     with amp.scale_loss(loss, optimizer) as scaled_loss:
                         scaled_loss.backward()
-                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
+                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), cfg.max_grad_norm)
                 else:
                     loss.backward()
-                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
 
                 tr_loss += loss.item()
                 nb_tr_examples += input_ids.size(0)
                 nb_tr_steps += 1
-                if (step + 1) % args.gradient_accumulation_steps == 0:
+                if (step + 1) % cfg.gradient_accumulation_steps == 0:
                     optimizer.step()
                     scheduler.step()  # Update learning rate schedule
                     model.zero_grad()
@@ -278,27 +183,27 @@ def main():
 
         # Save a trained model and the associated configuration
         model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-        model_to_save.save_pretrained(args.output_dir)
-        tokenizer.save_pretrained(args.output_dir)
+        model_to_save.save_pretrained(utils.get_original_cwd()+'/'+cfg.output_dir)
+        tokenizer.save_pretrained(utils.get_original_cwd()+'/'+cfg.output_dir)
         label_map = {i : label for i, label in enumerate(label_list,1)}
-        model_config = {"bert_model":args.bert_model,"do_lower":args.do_lower_case,"max_seq_length":args.max_seq_length,"num_labels":len(label_list)+1,"label_map":label_map}
-        json.dump(model_config,open(os.path.join(args.output_dir,"model_config.json"),"w"))
+        model_config = {"bert_model":cfg.bert_model,"do_lower":cfg.do_lower_case,"max_seq_length":cfg.max_seq_length,"num_labels":len(label_list)+1,"label_map":label_map}
+        json.dump(model_config,open(os.path.join(utils.get_original_cwd()+'/'+cfg.output_dir,"model_config.json"),"w"))
         # Load a trained model and config that you have fine-tuned
     else:
         # Load a trained model and vocabulary that you have fine-tuned
-        model = Ner.from_pretrained(args.output_dir)
-        tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
+        model = Ner.from_pretrained(utils.get_original_cwd()+'/'+cfg.output_dir)
+        tokenizer = BertTokenizer.from_pretrained(utils.get_original_cwd()+'/'+cfg.output_dir, do_lower_case=cfg.do_lower_case)
 
     model.to(device)
 
-    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
-        if args.eval_on == "dev":
-            eval_examples = processor.get_dev_examples(args.data_dir)
-        elif args.eval_on == "test":
-            eval_examples = processor.get_test_examples(args.data_dir)
+    if cfg.do_eval and (cfg.local_rank == -1 or torch.distributed.get_rank() == 0):
+        if cfg.eval_on == "dev":
+            eval_examples = processor.get_dev_examples(utils.get_original_cwd()+'/'+cfg.data_dir)
+        elif cfg.eval_on == "test":
+            eval_examples = processor.get_test_examples(utils.get_original_cwd()+'/'+cfg.data_dir)
         else:
             raise ValueError("eval on dev or test set only")
-        eval_features = convert_examples_to_features(eval_examples, label_list, args.max_seq_length, tokenizer)
+        eval_features = convert_examples_to_features(eval_examples, label_list, cfg.max_seq_length, tokenizer)
         all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
         all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
         all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
@@ -308,7 +213,7 @@ def main():
         eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids,all_valid_ids,all_lmask_ids)
         # Run prediction for full data
         eval_sampler = SequentialSampler(eval_data)
-        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
+        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=cfg.eval_batch_size)
         model.eval()
         eval_loss, eval_accuracy = 0, 0
         nb_eval_steps, nb_eval_examples = 0, 0
@@ -346,13 +251,9 @@ def main():
                     temp_2.append(label_map[logits[i][j]])
 
         report = classification_report(y_true, y_pred,digits=4)
-        # logger.info("\n%s", report)
-        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
+        logger.info("\n%s", report)
+        output_eval_file = os.path.join(utils.get_original_cwd()+'/'+cfg.output_dir, "eval_results.txt")
         with open(output_eval_file, "w") as writer:
-            # logger.info("***** Eval results *****")
-            # logger.info("\n%s", report)
-            writer.write(report)
-
-
-if __name__ == "__main__":
-    main()
+            logger.info("***** Eval results *****")
+            logger.info("\n%s", report)
+            writer.write(report)