ner-few-shot

lilei 2021-09-27 01:21:36 -05:00 committed by GitHub
parent d4ba08a709
commit 8330b2d896
7 changed files with 3079 additions and 0 deletions


@@ -0,0 +1,703 @@
import torch
from torch import nn
from torch.nn import functional as F
from transformers.configuration_bart import BartConfig
from .modeling_bart import BartModel, _prepare_bart_decoder_inputs
from ..utils.utils import avg_token_embeddings, seq_to_mask, _get_model_device
from functools import partial
from typing import Union
class PromptBartEncoder(nn.Module):
def __init__(self, encoder):
super(PromptBartEncoder, self).__init__()
self.bart_encoder = encoder
def forward(self, src_tokens, attention_mask=None, past_key_values=None):
encoder_dicts = self.bart_encoder(input_ids=src_tokens, attention_mask=attention_mask, past_key_values=past_key_values, return_dict=True, output_hidden_states=True)
return encoder_dicts.last_hidden_state, encoder_dicts.hidden_states
class PromptBartDecoder(nn.Module):
def __init__(self, decoder, pad_token_id, label_ids, use_prompt=False, prompt_len=10, learn_weights=False):
super(PromptBartDecoder, self).__init__()
self.bart_decoder = decoder
self.pad_token_id = pad_token_id
self.use_prompt = use_prompt
self.prompt_len = prompt_len
self.learn_weights = learn_weights
self.label_ids = label_ids
print(label_ids)
if self.learn_weights: # set learnable average weights
self.averge_weights = nn.ParameterList(parameters=None)
for id in label_ids:
if len(id) > 1:
self.averge_weights.append(nn.Parameter(torch.FloatTensor(len(id))))
print(self.averge_weights)
mapping = [0, 2]
for id in label_ids:
mapping += id[:1]
mapping = torch.LongTensor(mapping)
else:
mapping = torch.LongTensor([0, 2]+label_ids)
self.label_start_id = min(label_ids)
self.label_end_id = max(label_ids)+1
self.register_buffer('mapping', mapping)
self.src_start_index = len(mapping)
hidden_size = decoder.embed_tokens.weight.size(1)
self.bart_mlp = nn.Sequential(nn.Linear(hidden_size, hidden_size),
nn.Dropout(0.3),
nn.ReLU(),
nn.Linear(hidden_size, hidden_size))
self.dropout_layer = nn.Dropout(0.3)
def forward(self, tgt_tokens, prompt_state):
cumsum = tgt_tokens.eq(1).flip(dims=[1]).cumsum(dim=-1)
tgt_pad_mask = cumsum.flip(dims=[-1]).ne(cumsum[:, -1:])
encoder_outputs = prompt_state.encoder_output # last_hidden_state
attention_mask = prompt_state.encoder_mask # attention_mask
first = prompt_state.first
src_tokens = prompt_state.src_tokens
past_key_values = prompt_state.past_key_values
# mapping target tokens
mapping_token_mask = tgt_tokens.lt(self.src_start_index)
mapped_tokens = tgt_tokens.masked_fill(tgt_tokens.ge(self.src_start_index), 0)
tag_mapped_tokens = self.mapping[mapped_tokens]
src_tokens_index = tgt_tokens - self.src_start_index # bsz x num_src_token
src_tokens_index = src_tokens_index.masked_fill(src_tokens_index.lt(0), 0)
if first is not None:
src_tokens = src_tokens.gather(index=first, dim=1)
word_mapped_tokens = src_tokens.gather(index=src_tokens_index, dim=1)
tokens = torch.where(mapping_token_mask, tag_mapped_tokens, word_mapped_tokens) # bsz x max_len
tokens = tokens.masked_fill(tgt_pad_mask, self.pad_token_id)
decoder_input_ids, _, causal_mask = _prepare_bart_decoder_inputs(
self.pad_token_id,
tokens,
decoder_input_ids=None,
decoder_padding_mask=None,
causal_mask_dtype=self.bart_decoder.embed_tokens.weight.dtype
)
if self.use_prompt:
assert past_key_values is not None
_, _, seqlen, _ = past_key_values[0]['self']['prev_value'].shape
tgt_len = decoder_input_ids.size(1)
temp_mask = torch.zeros(tgt_len, seqlen).to(causal_mask.device) #tgtlen, preseqlen
causal_mask = torch.cat([temp_mask, causal_mask],dim=1) #tgtlen, preseqlen+tgtlen
if self.training:
tokens = tokens[:, :-1]
decoder_pad_mask = tokens.eq(self.pad_token_id)
decoder_output = self.bart_decoder(input_ids=tokens,
encoder_hidden_states=encoder_outputs, # last_hidden_state
encoder_padding_mask=attention_mask, # attention_mask
decoder_padding_mask=decoder_pad_mask,
decoder_causal_mask=causal_mask[:tokens.size(1), :self.prompt_len+tokens.size(1)],
output_hidden_states=True,
past_key_values=past_key_values,
return_dict=True)
else:
past_key_values = prompt_state.past_key_values
decoder_output = self.bart_decoder(input_ids=tokens,
encoder_hidden_states=encoder_outputs,
encoder_padding_mask=attention_mask,
decoder_padding_mask=None,
decoder_causal_mask=None,
past_key_values=past_key_values,
use_cache=True,
return_dict=True)
hidden_state = decoder_output.last_hidden_state # bsz x max_len x hidden_size
hidden_state = self.dropout_layer(hidden_state)
if not self.training:
prompt_state.past_key_values = decoder_output.past_key_values
logits = hidden_state.new_full((hidden_state.size(0), hidden_state.size(1), self.src_start_index+src_tokens.size(-1)),
fill_value=-1e24)
# compute eos scores
eos_scores = F.linear(hidden_state, self.dropout_layer(self.bart_decoder.embed_tokens.weight[2:3])) # bsz x max_len x 1
if self.learn_weights: # use the learned average weights to compute entity label scores
tag_scores = None
idx = 0
for ids in self.label_ids: # bsz x max_len x num_class
if len(ids) <= 1:
temp_score = F.linear(hidden_state, self.dropout_layer(self.bart_decoder.embed_tokens.weight[ids]))
else:
weight = F.softmax(self.averge_weights[idx], dim=0)
temp_score = F.linear(hidden_state, self.dropout_layer(self.bart_decoder.embed_tokens.weight[[ids[0]]])) * weight[0]
for i in range(1, len(ids)):
temp_score = temp_score + F.linear(hidden_state, self.dropout_layer(self.bart_decoder.embed_tokens.weight[[ids[i]]])) * weight[i]
idx += 1
if tag_scores is None:
tag_scores = temp_score
else:
tag_scores = torch.cat((tag_scores, temp_score), dim=2)
else:
tag_scores = F.linear(hidden_state, self.dropout_layer(self.bart_decoder.embed_tokens.weight[self.label_start_id:self.label_end_id])) # bsz x max_len x num_class
# bsz x max_bpe_len x hidden_size
src_outputs = encoder_outputs
if hasattr(self, 'bart_mlp'): # the MLP defined in __init__ is named bart_mlp, not encoder_mlp
    src_outputs = self.bart_mlp(src_outputs)
if first is not None:
mask = first.eq(0) # bsz x 1 x max_word_len
# bsz x max_word_len x hidden_size
src_outputs = src_outputs.gather(index=first.unsqueeze(2).repeat(1, 1, src_outputs.size(-1)), dim=1)
else:
mask = attention_mask.eq(0)
# src_outputs = self.decoder.embed_tokens(src_tokens)
mask = mask.unsqueeze(1)
input_embed = self.dropout_layer(self.bart_decoder.embed_tokens(src_tokens)) # bsz x max_word_len x hidden_size
src_outputs = (src_outputs + input_embed)/2
word_scores = torch.einsum('blh,bnh->bln', hidden_state, src_outputs) # bsz x max_len x max_word_len
mask = mask.__or__(src_tokens.eq(2).cumsum(dim=1).ge(1).unsqueeze(1))
word_scores = word_scores.masked_fill(mask, -1e32)
logits[:, :, 1:2] = eos_scores
logits[:, :, 2:self.src_start_index] = tag_scores
logits[:, :, self.src_start_index:] = word_scores
return logits, prompt_state
def decode(self, tokens, state):
return self(tokens, state)[0][:, -1]
class PromptBartModel(nn.Module):
def __init__(self, tokenizer, label_ids, args):
super(PromptBartModel, self).__init__()
self.use_prompt = args.use_prompt
self.prompt_len = args.prompt_len
self.prompt_dim = args.prompt_dim
self.learn_weights = args.learn_weights
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
bart_name = args.bart_name
self.bart_config = BartConfig.from_pretrained(bart_name)
self.bart_config.use_prompt = args.use_prompt
self.bart_config.preseqlen = args.prompt_len
bart_config = self.bart_config
bart_model = BartModel.from_pretrained(bart_name, config=bart_config)
num_tokens, _ = bart_model.encoder.embed_tokens.weight.shape
bart_model.resize_token_embeddings(len(tokenizer.unique_no_split_tokens)+num_tokens)
bart_model = avg_token_embeddings(tokenizer, bart_model, bart_name, num_tokens)
self.prompt_encoder = PromptBartEncoder(bart_model.encoder)
self.prompt_decoder = PromptBartDecoder(bart_model.decoder, tokenizer.pad_token_id, label_ids, self.use_prompt, self.prompt_len, self.learn_weights)
self.prompt_inputs = torch.arange(self.prompt_len).long()
self.encoder_prompt_embed = nn.Embedding(self.prompt_len, bart_config.d_model)
self.encoder_mlp = nn.Sequential(
nn.Linear(bart_config.d_model, self.prompt_dim),
nn.Tanh(),
nn.Linear(self.prompt_dim, bart_config.decoder_layers * 2 * bart_config.d_model))
self.decoder_prompt_embed = nn.Embedding(self.prompt_len, bart_config.d_model)
self.decoder_mlp = nn.Sequential(
nn.Linear(bart_config.d_model, self.prompt_dim),
nn.Tanh(),
nn.Linear(self.prompt_dim, bart_config.decoder_layers * 2 * bart_config.d_model))
self.prompt_cross_embed = nn.Embedding(self.prompt_len, bart_config.d_model)
self.cross_mlp = nn.Sequential(
nn.Linear(bart_config.d_model, self.prompt_dim),
nn.Tanh(),
nn.Linear(self.prompt_dim, bart_config.decoder_layers * 2 * bart_config.d_model))
self.dropout = nn.Dropout(0.0)
def forward(self, src_tokens, tgt_tokens, src_seq_len, first):
prompt_state = self.generator(src_tokens, src_seq_len, first)
decoder_outputs, prompt_state = self.prompt_decoder(tgt_tokens, prompt_state)
return decoder_outputs
def generator(self, src_tokens, src_seq_len, first):
batch_size = src_tokens.size(0)
past_key_values = self.get_prompt(batch_size) if self.use_prompt else None
attention_mask = seq_to_mask(src_seq_len, max_len=src_tokens.size(1))
encoder_outputs, hidden_states = self.prompt_encoder(src_tokens, attention_mask=attention_mask, past_key_values=past_key_values)
prompt_state = PromptBartState(encoder_outputs, attention_mask, past_key_values, src_tokens, first, hidden_states[0], self.bart_config.preseqlen)
return prompt_state
def get_prompt(self, batch_size):
input_tokens = self.prompt_inputs.unsqueeze(0).expand(batch_size, -1).to(self.device)
# encoder prompt
encoder_embed = self.encoder_prompt_embed(input_tokens)
past_key_values = self.encoder_mlp(encoder_embed) #bsz, seqlen, layer*emb
bsz, seqlen, _ = past_key_values.shape
past_key_values = past_key_values.view(bsz, seqlen, self.bart_config.decoder_layers * 2,
self.bart_config.decoder_attention_heads, self.bart_config.d_model // self.bart_config.decoder_attention_heads)
past_key_values = self.dropout(past_key_values)
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) # key + value
# decoder prompt
decoder_embed = self.decoder_prompt_embed(input_tokens)
past_key_values2 = self.decoder_mlp(decoder_embed) # bsz, seqlen, layer*emb
past_key_values2 = past_key_values2.view(bsz, seqlen, self.bart_config.decoder_layers * 2,
self.bart_config.decoder_attention_heads, self.bart_config.d_model // self.bart_config.decoder_attention_heads)
past_key_values2 = self.dropout(past_key_values2)
past_key_values2 = past_key_values2.permute([2, 0, 3, 1, 4]).split(2)
# cross prompt
cross_embed = self.prompt_cross_embed(input_tokens)
past_key_values_enc = self.cross_mlp(cross_embed) # bsz, seqlen, layer*emb
past_key_values_enc = past_key_values_enc.view(bsz, seqlen, self.bart_config.decoder_layers * 2,
self.bart_config.decoder_attention_heads, self.bart_config.d_model // self.bart_config.decoder_attention_heads)
past_key_values_enc = self.dropout(past_key_values_enc)
past_key_values_enc = past_key_values_enc.permute([2, 0, 3, 1, 4]).split(2)
result = []
for i, key_val in enumerate(past_key_values):
temp_dict = {'self': {"prev_key": key_val[0].contiguous(),
"prev_value": key_val[1].contiguous(),
"prev_key_padding_mask": torch.zeros(bsz, seqlen).to(key_val.device).bool() #bsz, preseqlen
},
}
key_val2 = past_key_values2[i]
temp_dict['encoder_decoder'] = {"prev_key": key_val2[0].contiguous(),
"prev_value": key_val2[1].contiguous(),
"prev_key_padding_mask": torch.zeros(bsz, seqlen).to(key_val2.device).bool()
}
key_val_enc = past_key_values_enc[i]
temp_dict['encoder'] = {"prev_key": key_val_enc[0].contiguous(),
"prev_value": key_val_enc[1].contiguous(),
"prev_key_padding_mask": torch.zeros(bsz, seqlen).to(key_val_enc.device).bool()
}
result.append(temp_dict)
return result
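# Hedged sketch (not part of the original commit): the view/permute/split in
# get_prompt turns one MLP output of shape (bsz, prompt_len, layers*2*d_model)
# into per-layer prompt key/value pairs; the sizes below are assumed toy values.
def _demo_prompt_reshape():
    bsz, prompt_len, layers, heads, d_model = 2, 10, 6, 12, 768
    flat = torch.randn(bsz, prompt_len, layers * 2 * d_model)
    kv = flat.view(bsz, prompt_len, layers * 2, heads, d_model // heads)
    kv = kv.permute([2, 0, 3, 1, 4]).split(2)  # tuple of `layers` tensors
    assert len(kv) == layers
    assert kv[0].shape == (2, bsz, heads, prompt_len, d_model // heads)
    prev_key, prev_value = kv[0][0], kv[0][1]  # layer-0 prompt key and value
    return prev_key.shape, prev_value.shape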
class PromptBartState(object):
def __init__(self, encoder_output, encoder_mask, past_key_values, src_tokens, first, src_embed_outputs, preseqlen):
self.encoder_output = encoder_output
self.encoder_mask = encoder_mask
self.past_key_values = past_key_values
self.src_tokens = src_tokens
self.first = first
self.src_embed_outputs = src_embed_outputs
self.preseqlen = preseqlen
def _reorder_state(self, state: Union[torch.Tensor, list, tuple], indices: torch.LongTensor, dim: int = 0):
if isinstance(state, torch.Tensor):
state = state.index_select(index=indices, dim=dim)
elif isinstance(state, list):
for i in range(len(state)):
assert state[i] is not None
state[i] = self._reorder_state(state[i], indices, dim)
elif isinstance(state, tuple):
tmp_list = []
for i in range(len(state)):
assert state[i] is not None
tmp_list.append(self._reorder_state(state[i], indices, dim))
state = tuple(tmp_list)
else:
raise TypeError(f"Cannot reorder data of type:{type(state)}")
return state
def reorder_state(self, indices: torch.LongTensor):
self.src_tokens = self._reorder_state(self.src_tokens, indices)
if self.first is not None:
self.first = self._reorder_state(self.first, indices)
self.src_embed_outputs = self._reorder_state(self.src_embed_outputs, indices)
if self.past_key_values is not None:
new = []
for layer in self.past_key_values:
new_layer = {}
for key1 in list(layer.keys()):
new_layer_ = {}
for key2 in list(layer[key1].keys()):
if layer[key1][key2] is not None:
layer[key1][key2] = self._reorder_state(layer[key1][key2], indices)
new_layer_[key2] = layer[key1][key2]
new_layer[key1] = new_layer_
new.append(new_layer)
self.past_key_values = new
def num_samples(self):
if self.encoder_output is not None:
return self.encoder_output.size(0)
else:
return None
class PromptGeneratorModel(nn.Module):
def __init__(self, prompt_model, max_length=20, max_len_a=0.0, num_beams=1,
do_sample=False, bos_token_id=None, eos_token_id=None,
repetition_penalty=1, length_penalty=1.0, pad_token_id=0, restricter=None):
super(PromptGeneratorModel, self).__init__()
self.prompt_model = prompt_model
self.decoder = prompt_model.prompt_decoder
self.generate_func = partial(greedy_generate, decoder=self.decoder, max_length=max_length, max_len_a=max_len_a,
num_beams=num_beams,
bos_token_id=bos_token_id, eos_token_id=eos_token_id,
repetition_penalty=repetition_penalty,
length_penalty=length_penalty, pad_token_id=pad_token_id,
restricter=restricter)
self.do_sample = do_sample
self.max_length = max_length
self.num_beams = num_beams
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.repetition_penalty = repetition_penalty
self.length_penalty = length_penalty
self.pad_token_id = pad_token_id
self.restricter = restricter
self.max_len_a = max_len_a
def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None, first=None):
"""
:param torch.LongTensor src_tokens: bsz x max_len
:param torch.LongTensor tgt_tokens: bsz x max_len'
:param torch.LongTensor src_seq_len: bsz
:param torch.LongTensor tgt_seq_len: bsz
:return:
"""
return self.prompt_model(src_tokens, tgt_tokens, src_seq_len, first)
def predict(self, src_tokens, src_seq_len=None, first=None):
"""
:param torch.LongTensor src_tokens: bsz x max_len
:param torch.LongTensor src_seq_len: bsz
:return:
"""
prompt_state = self.prompt_model.generator(src_tokens, src_seq_len, first) # encoder output
result = self.generate_func(tokens=None, state=prompt_state)
return result
@torch.no_grad()
def greedy_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1,
bos_token_id=None, eos_token_id=None, pad_token_id=0,
repetition_penalty=1, length_penalty=1.0, restricter=None):
if num_beams == 1:
token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
bos_token_id=bos_token_id, eos_token_id=eos_token_id,
repetition_penalty=repetition_penalty, length_penalty=length_penalty,
pad_token_id=pad_token_id, restricter=restricter)
else:
token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
num_beams=num_beams,
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
repetition_penalty=repetition_penalty, length_penalty=length_penalty,
pad_token_id=pad_token_id, restricter=restricter)
return token_ids
def _no_beam_search_generate(decoder: PromptBartDecoder, state, tokens=None, max_length=20, max_len_a=0.0, bos_token_id=None,
eos_token_id=None,
repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0,
restricter=None):
device = _get_model_device(decoder)
if tokens is None:
if bos_token_id is None:
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
batch_size = state.num_samples()
if batch_size is None:
raise RuntimeError("Cannot infer the number of samples from `state`.")
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
batch_size = tokens.size(0)
if state.num_samples() is not None:
assert state.num_samples() == batch_size, "The number of samples in `tokens` and `state` should match."
if eos_token_id is None:
_eos_token_id = -1
else:
_eos_token_id = eos_token_id
scores = decoder.decode(tokens=tokens, state=state) # update state
if restricter is not None:
_, next_tokens = restricter(state, tokens, scores, num_beams=1)
else:
next_tokens = scores.argmax(dim=-1, keepdim=True)
token_ids = torch.cat([tokens, next_tokens], dim=1)
cur_len = token_ids.size(1)
dones = token_ids.new_zeros(batch_size).eq(1).__or__(next_tokens.squeeze(1).eq(_eos_token_id))
# tokens = tokens[:, -1:]
if max_len_a!=0:
# (bsz x num_beams, )
if state.encoder_mask is not None:
max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
else:
max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long)
real_max_length = max_lengths.max().item()
else:
real_max_length = max_length
if state.encoder_mask is not None:
max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
else:
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
while cur_len < real_max_length:
scores = decoder.decode(tokens=token_ids, state=state) # batch_size x vocab_size
if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
lt_zero_mask = token_scores.lt(0).float()
ge_zero_mask = lt_zero_mask.eq(0).float()
token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
scores.scatter_(dim=1, index=token_ids, src=token_scores)
if eos_token_id is not None and length_penalty != 1.0:
token_scores = scores / cur_len ** length_penalty # batch_size x vocab_size
eos_mask = scores.new_ones(scores.size(1))
eos_mask[eos_token_id] = 0
eos_mask = eos_mask.unsqueeze(0).eq(1)
scores = scores.masked_scatter(eos_mask, token_scores)
if restricter is not None:
_, next_tokens = restricter(state, token_ids, scores, 1)
else:
next_tokens = scores.argmax(dim=-1, keepdim=True)
next_tokens = next_tokens.squeeze(-1)
if _eos_token_id!=-1:
next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len+1), _eos_token_id)
next_tokens = next_tokens.masked_fill(dones, pad_token_id)
tokens = next_tokens.unsqueeze(1)
token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len
end_mask = next_tokens.eq(_eos_token_id)
dones = dones.__or__(end_mask)
cur_len += 1
if dones.min() == 1:
break
return token_ids
def _beam_search_generate(decoder: PromptBartDecoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=4,
bos_token_id=None, eos_token_id=None, do_sample=True,
repetition_penalty=1.0, length_penalty=None, pad_token_id=0,
restricter=None) -> torch.LongTensor:
assert do_sample is False
# beam search
device = _get_model_device(decoder)
if tokens is None:
if bos_token_id is None:
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
batch_size = state.num_samples()
if batch_size is None:
raise RuntimeError("Cannot infer the number of samples from `state`.")
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
batch_size = tokens.size(0)
if state.num_samples() is not None:
assert state.num_samples() == batch_size, "The number of samples in `tokens` and `state` should match."
if eos_token_id is None:
_eos_token_id = -1
else:
_eos_token_id = eos_token_id
scores = decoder.decode(tokens=tokens, state=state)
vocab_size = scores.size(1)
assert vocab_size >= num_beams, "num_beams should not exceed the vocabulary size."
scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size)
if restricter is not None:
_next_scores, _next_tokens = restricter(state, tokens, scores, num_beams+1)
else:
# bsz x (num_beams+1)
_next_scores, _next_tokens = torch.topk(scores, num_beams+1, dim=1, largest=True, sorted=True)
indices = torch.arange(batch_size, dtype=torch.long).to(device)
indices = indices.repeat_interleave(num_beams)
state.reorder_state(indices)
tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length
if max_len_a!=0:
# (bsz x num_beams, )
if state.encoder_mask is not None:
max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
else:
max_lengths = tokens.new_full((batch_size*num_beams, ), fill_value=max_length, dtype=torch.long)
real_max_length = max_lengths.max().item()
else:
real_max_length = max_length
if state.encoder_mask is not None:
max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
else:
max_lengths = tokens.new_full((batch_size*num_beams,), fill_value=max_length, dtype=torch.long)
hypos = [
BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
]
not_eos_mask = _next_tokens.ne(_eos_token_id)
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)
keep_mask = not_eos_mask.__and__(keep_mask)
next_tokens = _next_tokens.masked_select(keep_mask).view(batch_size, num_beams)
next_scores = _next_scores.masked_select(keep_mask).view(batch_size, num_beams)
rows, cols = not_eos_mask.eq(0)[:, :num_beams].nonzero(as_tuple=True)
if len(rows)>0:
for row, col in zip(rows.tolist(), cols.tolist()):
_token = torch.cat([tokens[row*num_beams], _next_tokens[row, col:col+1]], dim=0)
hypos[row].add(_token.clone(), _next_scores[row, col].item())
# (batch_size, cur_len)
token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1)
dones = [False] * batch_size
beam_scores = next_scores.view(-1) # batch_size * num_beams
cur_len = token_ids.size(1)
# 0, num_beams, 2*num_beams, ...
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)
while cur_len < real_max_length:
scores = decoder.decode(token_ids, state) # (bsz x num_beams, vocab_size)
if repetition_penalty != 1.0:
token_scores = scores.gather(dim=1, index=token_ids)
lt_zero_mask = token_scores.lt(0).float()
ge_zero_mask = lt_zero_mask.eq(0).float()
token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
scores.scatter_(dim=1, index=token_ids, src=token_scores)
if _eos_token_id!=-1:
max_len_eos_mask = max_lengths.eq(cur_len+1)
eos_scores = scores[:, _eos_token_id]
scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e32, eos_scores)
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
_scores = scores + beam_scores[:, None] # (batch_size * num_beams, vocab_size)
_scores = _scores.view(batch_size, -1) # (batch_size, num_beams*vocab_size)
if restricter is not None:
next_scores, ids = restricter(state, token_ids, _scores, 2 * num_beams)
else:
next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) # (bsz, 2*num_beams)
from_which_beam = ids // vocab_size # (batch_size, 2*num_beams)
next_tokens = ids % vocab_size # (batch_size, 2*num_beams)
not_eos_mask = next_tokens.ne(_eos_token_id)
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)
keep_mask = not_eos_mask.__and__(keep_mask)
_next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1)
_from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams)
_next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
beam_scores = _next_scores.view(-1)
flag = True
if cur_len+1 == real_max_length:
eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0)
eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)
eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)
else:
effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id) # batch_size x num_beams
if effective_eos_mask.sum().gt(0):
eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True)
eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind
eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx]
else:
flag = False
if flag:
_token_ids = torch.cat([token_ids, _next_tokens], dim=-1)
for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(),
eos_beam_idx.tolist()):
if not dones[batch_idx]:
score = next_scores[batch_idx, beam_ind].item()
if _eos_token_id!=-1:
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
else:
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx].clone(), score)
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten
state.reorder_state(reorder_inds)
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), _next_tokens], dim=-1)
for batch_idx in range(batch_size):
dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) or \
max_lengths[batch_idx*num_beams]==cur_len+1
cur_len += 1
if all(dones):
break
# select the best hypotheses
tgt_len = token_ids.new_zeros(batch_size)
best = []
for i, hypotheses in enumerate(hypos):
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
if _eos_token_id!=-1:
best_hyp = torch.cat([best_hyp, best_hyp.new_ones(1)*_eos_token_id])
tgt_len[i] = len(best_hyp)
best.append(best_hyp)
# generate target batch
decoded = token_ids.new_zeros(batch_size, tgt_len.max().item()).fill_(pad_token_id)
for i, hypo in enumerate(best):
decoded[i, :tgt_len[i]] = hypo
return decoded
class BeamHypotheses(object):
def __init__(self, num_beams, max_length, length_penalty, early_stopping):
"""
Initialize n-best list of hypotheses.
"""
self.max_length = max_length - 1 # ignoring bos_token
self.length_penalty = length_penalty
self.early_stopping = early_stopping
self.num_beams = num_beams
self.hyp = []
self.worst_score = 1e9
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.hyp)
def add(self, hyp, sum_logprobs):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / len(hyp) ** self.length_penalty
if len(self) < self.num_beams or score > self.worst_score:
self.hyp.append((score, hyp))
if len(self) > self.num_beams:
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
del self.hyp[sorted_scores[0][1]]
self.worst_score = sorted_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs):
"""
If there are enough hypotheses and none of the hypotheses still being
generated can become better than the worst one in the heap, we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
elif self.early_stopping:
return True
else:
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
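# Hedged usage sketch of BeamHypotheses (toy scores, not from the original
# commit): it keeps the num_beams best hypotheses ranked by length-normalized
# log-probability and drops anything below the worst kept score.
def _demo_beam_hypotheses():
    hypos = BeamHypotheses(num_beams=2, max_length=10, length_penalty=1.0, early_stopping=False)
    hypos.add(torch.tensor([0, 5, 1]), sum_logprobs=-1.2)
    hypos.add(torch.tensor([0, 7, 8, 1]), sum_logprobs=-1.0)
    hypos.add(torch.tensor([0, 9, 1]), sum_logprobs=-5.0)  # below the worst kept score, dropped
    assert len(hypos) == 2
    assert hypos.is_done(best_sum_logprobs=-50.0)  # nothing pending can beat the worst kept hypothesis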

File diff suppressed because it is too large


@@ -0,0 +1,240 @@
import torch
from tqdm import tqdm
import numpy as np
from itertools import chain
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import BartTokenizer
import logging
logger = logging.getLogger(__name__)
# load file and process bio
class ConllNERProcessor(object):
def __init__(self, data_path, mapping, bart_name, learn_weights) -> None:
self.data_path = data_path
self.tokenizer = BartTokenizer.from_pretrained(bart_name)
self.mapping = mapping # maps each original tag to its converted '<<...>>' tag string
self.original_token_nums = self.tokenizer.vocab_size
self.learn_weights = learn_weights
self._add_tags_to_tokens()
def load_from_file(self, mode='train'):
"""load conll ner from file
Args:
mode (str, optional): train/test/dev. Defaults to 'train'.
Return:
outputs (dict)
raw_words: ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.']
raw_targets: ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
entities: [['EU'], ['German'], ['British']]
entity_tags: ['org', 'misc', 'misc']
entity_spans: [[0, 1], [2, 3], [6, 7]]
"""
load_file = self.data_path[mode]
logger.info("Loading data from {}".format(load_file))
# extract bio
outputs = {'raw_words':[], 'raw_targets':[], 'entities':[], 'entity_tags':[], 'entity_spans':[]}
with open(load_file, "r", encoding="utf-8") as f:
lines = f.readlines()
raw_words, raw_targets = [], []
raw_word, raw_target = [], []
for line in lines:
if line != "\n":
raw_word.append(line.split('\t')[0])
raw_target.append(line.split('\t')[1][:-1])
else:
raw_words.append(raw_word)
raw_targets.append(raw_target)
raw_word, raw_target = [], []
for words, targets in zip(raw_words, raw_targets):
entities, entity_tags, entity_spans = [], [], []
start, end, start_flag = 0, 0, False
for idx, tag in enumerate(targets):
if tag.startswith('B-'): # start of an entity; the previous entity (ending with I-) may close here
end = idx
if start_flag: # the previous entity ended with I- immediately before this B-
entities.append(words[start:end])
entity_tags.append(targets[start][2:].lower())
entity_spans.append([start, end])
start_flag = False
start = idx
start_flag = True
elif tag.startswith('I-'): # inside an entity: neither start nor end, just advance end
end = idx
elif tag.startswith('O'): # outside any entity; may mark the end of the previous one
end = idx
if start_flag: # the previous entity ends here
entities.append(words[start:end])
entity_tags.append(targets[start][2:].lower())
entity_spans.append([start, end])
start_flag = False
if start_flag: # the sentence ends inside an entity (I-) that has not been added yet
entities.append(words[start:end+1])
entity_tags.append(targets[start][2:].lower())
entity_spans.append([start, end+1])
start_flag = False
if len(entities) != 0:
outputs['raw_words'].append(words)
outputs['raw_targets'].append(targets)
outputs['entities'].append(entities)
outputs['entity_tags'].append(entity_tags)
outputs['entity_spans'].append(entity_spans)
return outputs
def process(self, data_dict):
target_shift = len(self.mapping) + 2
def prepare_target(item):
raw_word = item['raw_word']
word_bpes = [[self.tokenizer.bos_token_id]]
first = []
cur_bpe_len = 1
for word in raw_word:
bpes = self.tokenizer.tokenize(word, add_prefix_space=True)
bpes = self.tokenizer.convert_tokens_to_ids(bpes)
first.append(cur_bpe_len)
cur_bpe_len += len(bpes)
word_bpes.append(bpes)
assert first[-1] + len(bpes) == sum(map(len, word_bpes))
word_bpes.append([self.tokenizer.eos_token_id])
assert len(first) == len(raw_word) == len(word_bpes) - 2
lens = list(map(len, word_bpes))
cum_lens = np.cumsum(lens).tolist()
entity_spans = item['entity_span'] # [(s1, e1, s2, e2), ()]
entity_tags = item['entity_tag'] # [tag1, tag2...]
entities = item['entity'] # [[ent1, ent2,], [ent1, ent2]]
target = [0]
pairs = []
first = list(range(cum_lens[-1]))
assert len(entity_spans) == len(entity_tags)
for idx, (entity, tag) in enumerate(zip(entity_spans, entity_tags)):
cur_pair = []
num_ent = len(entity) // 2
for i in range(num_ent):
start = entity[2 * i]
end = entity[2 * i + 1]
cur_pair_ = []
cur_pair_.extend([cum_lens[k] for k in list(range(start, end))])
cur_pair.extend([p + target_shift for p in cur_pair_])
assert all([cur_pair[i] < cum_lens[-1] + target_shift for i in range(len(cur_pair))])
cur_pair.append(self.mapping2targetid[tag] + 2)
pairs.append([p for p in cur_pair])
target.extend(list(chain(*pairs)))
target.append(1)
word_bpes = list(chain(*word_bpes))
assert len(word_bpes)<500
target_dict = {'tgt_tokens': target, 'target_span': pairs, 'src_tokens': word_bpes,
               'first': first, 'src_seq_len':len(word_bpes), 'tgt_seq_len':len(target)}
return target_dict
logger.info("Process data...")
for raw_word, raw_target, entity, entity_tag, entity_span in tqdm(zip(data_dict['raw_words'], data_dict['raw_targets'], data_dict['entities'],
data_dict['entity_tags'], data_dict['entity_spans']), total=len(data_dict['raw_words']), desc='Processing'):
item_dict = prepare_target({'raw_word': raw_word, 'raw_target':raw_target, 'entity': entity, 'entity_tag': entity_tag, 'entity_span': entity_span})
# add item_dict to data_dict
for key, value in item_dict.items():
if key in data_dict:
data_dict[key].append(value)
else:
data_dict[key] = [value]
return data_dict
def _add_tags_to_tokens(self):
mapping = self.mapping
if self.learn_weights: # reuse existing sub-token ids for each tag instead of adding new tokens
self.mapping2id = {}
self.mapping2targetid = {}
for key, value in self.mapping.items():
key_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(value[2:-2], add_prefix_space=True))
self.mapping2id[value] = key_id # may be list
self.mapping2targetid[key] = len(self.mapping2targetid)
else:
tokens_to_add = sorted(list(mapping.values()), key=lambda x: len(x), reverse=True)
unique_no_split_tokens = self.tokenizer.unique_no_split_tokens # no split
sorted_add_tokens = sorted(list(tokens_to_add), key=lambda x: len(x), reverse=True)
for tok in sorted_add_tokens:
assert self.tokenizer.convert_tokens_to_ids([tok])[0] == self.tokenizer.unk_token_id # the token must not exist in the vocabulary yet
self.tokenizer.unique_no_split_tokens = unique_no_split_tokens + sorted_add_tokens # add to no_split_tokens
self.tokenizer.add_tokens(sorted_add_tokens)
self.mapping2id = {} # tag to id
self.mapping2targetid = {} # tag to number
for key, value in self.mapping.items():
key_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(value))
assert len(key_id) == 1, value
assert key_id[0] >= self.original_token_nums
self.mapping2id[value] = key_id[0]
self.mapping2targetid[key] = len(self.mapping2targetid)
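# Hedged toy walk-through (assumed BPE lengths, no tokenizer required) of the
# pointer arithmetic in prepare_target above: an entity over words [start, end)
# is encoded as BPE start offsets shifted by target_shift, then its tag id + 2,
# and the whole target sequence is framed by bos (0) and eos (1).
def _demo_prepare_target_pointers():
    lens = [1, 1, 2, 1, 1]  # [bos, 'EU', 'rejects' (2 BPEs), 'German', eos]
    cum_lens = np.cumsum(lens).tolist()  # [1, 2, 4, 5, 6]
    target_shift = 6  # len(mapping) + 2, assuming 4 entity tags
    start, end, tag_id = 0, 1, 0  # entity 'EU' with tag index 0
    pointers = [cum_lens[k] + target_shift for k in range(start, end)]
    target = [0] + pointers + [tag_id + 2] + [1]
    assert target == [0, 7, 2, 1]
    return target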
class ConllNERDataset(Dataset):
def __init__(self, data_processor, mode='train') -> None:
self.data_processor = data_processor
self.data_dict = data_processor.load_from_file(mode=mode)
self.complet_data = data_processor.process(self.data_dict)
self.mode = mode
def __len__(self):
return len(self.complet_data['src_tokens'])
def __getitem__(self, index):
if self.mode == 'test':
return torch.tensor(self.complet_data['src_tokens'][index]), torch.tensor(self.complet_data['src_seq_len'][index]), \
torch.tensor(self.complet_data['first'][index]), self.complet_data['raw_words'][index]
return torch.tensor(self.complet_data['src_tokens'][index]), torch.tensor(self.complet_data['tgt_tokens'][index]), \
torch.tensor(self.complet_data['src_seq_len'][index]), torch.tensor(self.complet_data['tgt_seq_len'][index]), \
torch.tensor(self.complet_data['first'][index]), self.complet_data['target_span'][index]
def collate_fn(self, batch):
src_tokens, src_seq_len, first = [], [], []
tgt_tokens, tgt_seq_len, target_span = [], [], []
if self.mode == "test":
raw_words = []
for tup in batch:
src_tokens.append(tup[0])
src_seq_len.append(tup[1])
first.append(tup[2])
raw_words.append(tup[3])
src_tokens = pad_sequence(src_tokens, batch_first=True, padding_value=self.data_processor.tokenizer.pad_token_id)
first = pad_sequence(first, batch_first=True, padding_value=0)
return src_tokens, torch.stack(src_seq_len, 0), first, raw_words
for tup in batch:
src_tokens.append(tup[0])
tgt_tokens.append(tup[1])
src_seq_len.append(tup[2])
tgt_seq_len.append(tup[3])
first.append(tup[4])
target_span.append(tup[5])
src_tokens = pad_sequence(src_tokens, batch_first=True, padding_value=self.data_processor.tokenizer.pad_token_id)
tgt_tokens = pad_sequence(tgt_tokens, batch_first=True, padding_value=1)
first = pad_sequence(first, batch_first=True, padding_value=0)
return src_tokens, tgt_tokens, torch.stack(src_seq_len, 0), torch.stack(tgt_seq_len, 0), first, target_span
if __name__ == '__main__':
data_path = {'train':'data/conll2003/train.txt'}
bart_name = '../BARTNER-AMAX/facebook/'
# ConllNERProcessor also requires a tag mapping and learn_weights; a plausible CoNLL-2003 mapping is assumed here
conll_mapping = {'loc': '<<location>>', 'per': '<<person>>', 'org': '<<organization>>', 'misc': '<<others>>'}
conll_processor = ConllNERProcessor(data_path, conll_mapping, bart_name, learn_weights=False)
conll_datasets = ConllNERDataset(conll_processor, mode='train')
conll_dataloader = DataLoader(conll_datasets, collate_fn=conll_datasets.collate_fn, batch_size=8)
for idx, data in enumerate(conll_dataloader):
print(data)
break


@@ -0,0 +1,111 @@
atis_mapping = {
'depart_date.day_number':'<<depart date day number>>',
'arrive_date.day_name':'<<arrive date day name>>',
'airline_name':'<<airline name>>',
'depart_date.year':'<<depart date year>>',
'flight_mod':'<<flight mod>>',
'return_date.day_name':'<<return date day name>>',
'toloc.city_name':'<<toloc city name>>',
'return_date.day_number':'<<return date day number>>',
'time_relative':'<<time relative>>',
'city_name':'<<city name>>',
'state_code':'<<state code>>',
'transport_type':'<<transport type>>',
'class_type':'<<class type>>',
'days_code':'<<days code>>',
'toloc.country_name':'<<toloc country name>>',
'arrive_date.today_relative':'<<arrive date today relative>>',
'round_trip':'<<round trip>>',
'toloc.state_name':'<<toloc state name>>',
'aircraft_code':'<<aircraft code>>',
'arrive_date.month_name':'<<arrive date month name>>',
'depart_date.today_relative':'<<depart date today relative>>',
'depart_time.start_time':'<<depart time start time>>',
'compartment':'<<compartment>>',
'day_number':'<<day number>>',
'depart_date.date_relative':'<<depart date date relative>>',
'arrive_date.day_number':'<<arrive date day number>>',
'depart_time.time':'<<depart time time>>',
'fare_amount':'<<fare amount>>',
'depart_date.month_name':'<<depart date month name>>',
'period_of_day':'<<period of day>>',
'cost_relative':'<<cost relative>>',
'fromloc.airport_name':'<<fromloc airport name>>',
'fare_basis_code':'<<fare basis code>>',
'arrive_time.start_time':'<<arrive time start time>>',
'stoploc.airport_name':'<<stoploc airport name>>',
'time':'<<time>>',
'depart_time.time_relative':'<<depart time time relative>>',
'return_time.period_of_day':'<<return time period of day>>',
'depart_time.period_of_day':'<<depart time period of day>>',
'economy':'<<economy>>',
'mod':'<<mod>>',
'stoploc.airport_code':'<<stoploc airport code>>',
'stoploc.state_code':'<<stoploc state code>>',
'arrive_time.end_time':'<<arrive time end time>>',
'state_name':'<<state name>>',
'airport_name':'<<airport name>>',
'depart_date.day_name':'<<depart date day name>>',
'fromloc.state_name':'<<fromloc state name>>',
'arrive_time.time_relative':'<<arrive time time relative>>',
'today_relative':'<<today relative>>',
'day_name':'<<day name>>',
'flight_stop':'<<flight stop>>',
'month_name':'<<month name>>',
'fromloc.city_name':'<<fromloc city name>>',
'meal':'<<meal>>',
'arrive_time.period_of_day':'<<arrive time period of day>>',
'return_time.period_mod':'<<return time period mod>>',
'toloc.airport_code':'<<toloc airport code>>',
'airport_code':'<<airport code>>',
'restriction_code':'<<restriction code>>',
'flight_time':'<<flight time>>',
'airline_code':'<<airline code>>',
'depart_time.end_time':'<<depart time end time>>',
'flight_days':'<<flight days>>',
'booking_class':'<<booking class>>',
'flight_number':'<<flight number>>',
'or':'<<or>>',
'fromloc.airport_code':'<<fromloc airport code>>',
'meal_description':'<<meal description>>',
'return_date.date_relative':'<<return date date relative>>',
'return_date.month_name':'<<return date month name>>',
'arrive_date.date_relative':'<<arrive date date relative>>',
'return_date.today_relative':'<<return date today relative>>',
'arrive_time.period_mod':'<<arrive time period mod>>',
'depart_time.period_mod':'<<depart time period mod>>',
'meal_code':'<<meal code>>',
'flight':'<<flight>>',
'toloc.airport_name':'<<toloc airport name>>',
'stoploc.city_name':'<<stoploc city name>>',
'connect':'<<connect>>',
'arrive_time.time':'<<arrive time time>>',
'toloc.state_code':'<<toloc state code>>',
'fromloc.state_code':'<<fromloc state code>>'
}
mit_movie_mapping = {
'genre': '<<genre>>',
'actor': '<<actor>>',
'year': '<<year>>',
'title': '<<title>>',
'rating': '<<rating>>',
'ratings_average': '<<ratings average>>',
'director': '<<director>>',
'plot': '<<plot>>',
'character': '<<character>>',
'song': '<<song>>',
'review': '<<review>>',
'trailer': '<<trailer>>'
}
mit_restaurant_mapping = {
'location': '<<location>>',
'cuisine': '<<cuisine>>',
'amenity': '<<amenity>>',
'restaurant_name': '<<restaurant name>>',
'rating': '<<rating>>',
'dish': '<<dish>>',
'hours': '<<hours>>',
'price': '<<price>>'
}
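# Illustrative invariant check (not in the original commit): every label in
# these mappings is a single '<<...>>' marker, which the data processors later
# register with the tokenizer as an unsplittable token.
for _mapping in (atis_mapping, mit_movie_mapping, mit_restaurant_mapping):
    assert all(v.startswith('<<') and v.endswith('>>') for v in _mapping.values())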


@@ -0,0 +1,110 @@
import numpy as np
class Seq2SeqSpanMetric(object):
def __init__(self, eos_token_id, num_labels, target_type='word'):
self.eos_token_id = eos_token_id
self.num_labels = num_labels
self.word_start_index = num_labels+2
self.fp = 0
self.tp = 0
self.fn = 0
self.em = 0
self.total = 0
self.target_type = target_type
def evaluate(self, target_span, pred, tgt_tokens):
self.total += pred.size(0)
pred_eos_index = pred.flip(dims=[1]).eq(self.eos_token_id).cumsum(dim=1).long()
target_eos_index = tgt_tokens.flip(dims=[1]).eq(self.eos_token_id).cumsum(dim=1).long()
pred = pred[:, 1:]
tgt_tokens = tgt_tokens[:, 1:]
pred_seq_len = pred_eos_index.flip(dims=[1]).eq(pred_eos_index[:, -1:]).sum(dim=1) # bsz
pred_seq_len = (pred_seq_len - 2).tolist()
target_seq_len = target_eos_index.flip(dims=[1]).eq(target_eos_index[:, -1:]).sum(dim=1) # bsz
target_seq_len = (target_seq_len-2).tolist()
pred_spans = []
for i, (ts, ps) in enumerate(zip(target_span, pred.tolist())):
em = 0
ps = ps[:pred_seq_len[i]]
if pred_seq_len[i]==target_seq_len[i]:
em = int(tgt_tokens[i, :target_seq_len[i]].eq(pred[i, :target_seq_len[i]]).sum().item()==target_seq_len[i])
self.em += em
pairs = []
cur_pair = []
if len(ps):
for j in ps:
if j<self.word_start_index:
if self.target_type == 'span':
if len(cur_pair)>0 and len(cur_pair)%2==0:
if all([cur_pair[i]<=cur_pair[i+1] for i in range(len(cur_pair)-1)]):
pairs.append(tuple(cur_pair+[j]))
else:
if len(cur_pair) > 0:
if all([cur_pair[i]<cur_pair[i+1] for i in range(len(cur_pair)-1)]):
pairs.append(tuple(cur_pair + [j]))
cur_pair = []
else:
cur_pair.append(j)
pred_spans.append(pairs.copy())
tp, fn, fp = _compute_tp_fn_fp(pairs, ts)
self.fn += fn
self.tp += tp
self.fp += fp
def get_metric(self, reset=True):
res = {}
f, pre, rec = _compute_f_pre_rec(1, self.tp, self.fn, self.fp)
res['f'] = round(f, 4)*100
res['rec'] = round(rec, 4)*100
res['pre'] = round(pre, 4)*100
res['em'] = round(self.em/self.total, 4)
if reset:
self.total = 0
self.fp = 0
self.tp = 0
self.fn = 0
self.em = 0
return res
def _compute_f_pre_rec(beta_square, tp, fn, fp):
r"""
:param tp: int, true positive
:param fn: int, false negative
:param fp: int, false positive
:return: (f, pre, rec)
"""
pre = tp / (fp + tp + 1e-13)
rec = tp / (fn + tp + 1e-13)
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
return f, pre, rec
def _compute_tp_fn_fp(ps, ts):
ps = ps.copy()
tp = 0
fp = 0
fn = 0
if isinstance(ts, (set, list, np.ndarray)):
ts = {tuple(key):1 for key in list(ts)}
if isinstance(ps, (set, list, np.ndarray)):
ps = {tuple(key):1 for key in list(ps)}
for key in ts.keys():
t_num = ts[key]
if key not in ps:
p_num = 0
else:
p_num = ps[key]
tp += min(p_num, t_num)
fp += max(p_num - t_num, 0)
fn += max(t_num - p_num, 0)
if key in ps:
ps.pop(key)
fp += sum(ps.values())
return tp, fn, fp
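# Hedged toy check of the span-metric helpers above (spans invented for
# illustration): predictions and gold spans are (bpe offsets..., tag id)
# tuples, counted as multisets.
def _demo_span_metric():
    pred = [(7, 2), (9, 3)]
    gold = [(7, 2), (10, 3)]
    tp, fn, fp = _compute_tp_fn_fp(pred, gold)  # one hit, one miss, one false alarm
    assert (tp, fn, fp) == (1, 1, 1)
    f, pre, rec = _compute_f_pre_rec(1, tp, fn, fp)
    assert abs(f - 0.5) < 1e-6 and abs(pre - 0.5) < 1e-6 and abs(rec - 0.5) < 1e-6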


@@ -0,0 +1,193 @@
import torch
from torch import optim
from tqdm import tqdm
from ..utils.utils import convert_preds_to_outputs, write_predictions
import random
class Trainer(object):
def __init__(self, train_data=None, dev_data=None, test_data=None, model=None, process=None, args=None, logger=None, loss=None, metrics=None, writer=None) -> None:
self.train_data = train_data
self.dev_data = dev_data
self.test_data = test_data
self.model = model
self.process = process
self.logger = logger
self.metrics = metrics
self.writer = writer
self.loss = loss
self.num_epochs = args.num_epochs
self.batch_size = args.batch_size
self.lr = args.learning_rate
self.eval_begin_epoch = args.eval_begin_epoch
self.device = args.device
self.load_path = args.load_path
self.save_path = args.save_path
self.refresh_step = 1
self.best_metric = 0
self.best_dev_epoch = None
self.optimizer = None
if self.train_data is not None:
self.train_num_steps = len(self.train_data) * self.num_epochs
self.step = 0
self.args = args
def train(self):
self.before_train() # something should do before training
self.step = 0
self.model.train()
self.logger.info("***** Running training *****")
self.logger.info(" Num instance = %d", len(self.train_data)*self.batch_size)
self.logger.info(" Num epoch = %d", self.num_epochs)
self.logger.info(" Batch size = %d", self.batch_size)
self.logger.info(" Learning rate = {}".format(self.lr))
self.logger.info(" Evaluate begin = %d", self.eval_begin_epoch)
if self.load_path is not None: # load model from load_path
self.logger.info("Loading model from {}".format(self.load_path))
self.model.load_state_dict(torch.load(self.load_path))
self.logger.info("Load model successful!")
with tqdm(total=self.train_num_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True, initial=self.step) as pbar:
self.pbar = pbar
avg_loss = 0
for epoch in range(self.num_epochs):
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.num_epochs))
for batch in self.train_data:
self.step += 1
batch = (tup.to(self.device) if isinstance(tup, torch.Tensor) else tup for tup in batch)
loss = self._step(batch, mode="train")
avg_loss += loss.item()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
if self.step % self.refresh_step == 0:
avg_loss = float(avg_loss) / self.refresh_step
print_output = "loss:{:<6.5f}".format(avg_loss)
pbar.update(1)
pbar.set_postfix_str(print_output)
self.writer.add_scalar(tag='loss', scalar_value=avg_loss, global_step=self.step) # tensorboardX
avg_loss = 0
if epoch >= self.eval_begin_epoch:
self.evaluate(epoch) # run generation-based evaluation on the dev set
pbar.close()
self.pbar = None
self.logger.info("Get best performance at epoch {}, best f1 score is {:.2f}".format(self.best_dev_epoch, self.best_metric))
def evaluate(self, epoch):
self.model.eval()
self.logger.info("***** Running evaluate *****")
self.logger.info(" Num instance = %d", len(self.dev_data)*self.batch_size)
self.logger.info(" Batch size = %d", self.batch_size)
with torch.no_grad():
with tqdm(total=len(self.dev_data), leave=False, dynamic_ncols=True) as pbar:
pbar.set_description_str(desc="Dev")
for batch in self.dev_data:
batch = (tup.to(self.device) if isinstance(tup, torch.Tensor) else tup for tup in batch) # to cpu/cuda device
self._step(batch, mode="dev")
pbar.update()
# evaluate done
eva_result = self.metrics.get_metric()
pbar.close()
self.logger.info("Epoch {}/{}, best f1: {}, current f1 score: {:.2f}, recall: {:.2f}, precision: {:.2f}."\
.format(epoch, self.num_epochs, self.best_metric, eva_result['f'], eva_result['rec'], eva_result['pre']))
self.writer.add_scalars('evaluate', {'f1': eva_result['f'],
'recall': eva_result['rec'],
'precision': eva_result['pre']}, epoch)
if eva_result['f'] >= self.best_metric: # this epoch get best performance
self.logger.info("Get better performance at epoch {}".format(epoch))
self.best_dev_epoch = epoch
self.best_metric = eva_result['f'] # update best metric(f1 score)
if self.save_path is not None: # need to save model
torch.save(self.model.state_dict(), self.save_path+"/best_model.pth")
self.logger.info("Save best model at {}".format(self.save_path))
self.model.train()
def predict(self):
assert self.load_path is not None and self.test_data is not None
self.model.eval()
self.logger.info("***** Running testing *****")
self.logger.info(" Num instance = %d", len(self.test_data)*self.batch_size)
self.logger.info(" Batch size = %d", self.batch_size)
if self.load_path is not None: # load model from load_path
self.logger.info("Loading model from {}".format(self.load_path))
self.model.load_state_dict(torch.load(self.load_path))
self.logger.info("Load model successful!")
self.model.to(self.device)
with torch.no_grad():
with tqdm(total=len(self.test_data), leave=False, dynamic_ncols=True) as pbar:
pbar.set_description_str(desc="Test")
texts = []
labels = []
for batch in self.test_data:
batch = (tup.to(self.device) if isinstance(tup, torch.Tensor) else tup for tup in batch) # to cpu/cuda device
src_tokens, src_seq_len, first, raw_words = batch
preds = self._step((src_tokens, src_seq_len, first), mode="test")
outputs = convert_preds_to_outputs(preds, raw_words, self.process.mapping, self.process.tokenizer)
texts.extend(raw_words)
labels.extend(outputs)
pbar.update()
self.logger.info("***** Predict example *****")
idx = random.randint(0, len(texts) - 1)
print(len(texts), len(labels))
self.logger.info("Raw texts: " + " ".join(texts[idx]))
self.logger.info("Prediction: " + " ".join(labels[idx]))
if self.args.write_path is not None: # write predict
write_predictions(self.args.write_path, texts, labels)
self.logger.info("Write into {}!".format(self.args.write_path))
def _step(self, batch, mode="train"):
if mode=="dev": # dev: compute metric
src_tokens, tgt_tokens, src_seq_len, tgt_seq_len, first, target_span = batch
pred = self.model.predict(src_tokens, src_seq_len, first)
self.metrics.evaluate(target_span, pred, tgt_tokens)
return
elif mode=="test": # test: just get pred
src_tokens, src_seq_len, first = batch
pred = self.model.predict(src_tokens, src_seq_len, first)
return pred
else: # train: get loss
src_tokens, tgt_tokens, src_seq_len, tgt_seq_len, first, target_span = batch
pred = self.model(src_tokens, tgt_tokens, src_seq_len, first)
loss = self.loss(tgt_tokens, tgt_seq_len, pred)
return loss
def before_train(self):
parameters = []
params = {'lr':self.lr, 'weight_decay':1e-2}
params['params'] = [param for name, param in self.model.named_parameters() if not ('bart_encoder' in name or 'bart_decoder' in name)]
parameters.append(params)
params = {'lr':self.lr, 'weight_decay':1e-2}
params['params'] = []
for name, param in self.model.named_parameters():
if ('bart_encoder' in name or 'bart_decoder' in name) and not ('layernorm' in name or 'layer_norm' in name):
params['params'].append(param)
parameters.append(params)
params = {'lr':self.lr, 'weight_decay':0}
params['params'] = []
for name, param in self.model.named_parameters():
if ('bart_encoder' in name or 'bart_decoder' in name) and ('layernorm' in name or 'layer_norm' in name):
params['params'].append(param)
parameters.append(params)
self.optimizer = optim.AdamW(parameters)
if self.args.freeze_plm: # freeze the pretrained language model (BART)
    for name, par in self.model.named_parameters():
        if ('prompt_encoder' in name or 'prompt_decoder' in name) and 'bart_mlp' not in name:
par.requires_grad = False
self.model.to(self.device)


@@ -0,0 +1,157 @@
import torch
import numpy as np
import random
from torch import nn
import torch.nn.functional as F
from transformers import BartModel, BartTokenizer
def avg_token_embeddings(tokenizer: BartTokenizer, bart_model: BartModel, bart_name, num_tokens):
    """When initializing newly added tokens, use the average of their sub-token embeddings.
    Args:
        tokenizer (BartTokenizer): tokenizer that already contains the added '<<...>>' tokens
        bart_model (BartModel): model whose embedding table is updated in place
        bart_name (str): name/path of the original pretrained BART checkpoint
        num_tokens (int): vocabulary size before the new tokens were added
    Raises:
        RuntimeError: if an added token is split into more than one piece
    Returns:
        BartModel: the model with initialized embeddings for the added tokens
    """
_tokenizer = BartTokenizer.from_pretrained(bart_name)
for token in tokenizer.unique_no_split_tokens:
if token[:2] == '<<': # special label token
index = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(token))
if len(index)>1:
raise RuntimeError(f"{token} wrong split")
else:
index = index[0]
assert index>=num_tokens, (index, num_tokens, token)
indexes = _tokenizer.convert_tokens_to_ids(_tokenizer.tokenize(token[2:-2]))
# clone so the in-place += and /= below do not overwrite the first sub-token's own embedding
embed = bart_model.encoder.embed_tokens.weight.data[indexes[0]].clone()
for i in indexes[1:]:
embed += bart_model.decoder.embed_tokens.weight.data[i]
embed /= len(indexes)
bart_model.decoder.embed_tokens.weight.data[index] = embed
return bart_model
def seq_to_mask(seq_len, max_len):
"""[get attention mask with sequence length]
Args:
seq_len ([torch.tensor]): [shape: bsz, each sequence length in a batch]
"""
max_len = int(max_len) if max_len else seq_len.max().long()
cast_seq = torch.arange(max_len).expand(seq_len.size(0), -1).to(seq_len)
mask = cast_seq.lt(seq_len.unsqueeze(1))
return mask
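# Quick sanity check of seq_to_mask (toy lengths, illustrative only): positions
# at or beyond each sequence length are masked out.
def _demo_seq_to_mask():
    lengths = torch.tensor([2, 4])
    mask = seq_to_mask(lengths, max_len=4)
    expected = torch.tensor([[True, True, False, False],
                             [True, True, True, True]])
    assert torch.equal(mask, expected)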
def get_loss(tgt_tokens, tgt_seq_len, pred):
"""
:param tgt_tokens: bsz x max_len, of the form [sos, tokens..., eos]
:param tgt_seq_len: bsz, target sequence lengths
:param pred: bsz x (max_len-1) x vocab_size
:return: cross-entropy loss over non-padded target positions
"""
tgt_seq_len = tgt_seq_len - 1
mask = seq_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1) - 1).eq(0)
tgt_tokens = tgt_tokens[:, 1:].masked_fill(mask, -100)
loss = F.cross_entropy(target=tgt_tokens, input=pred.transpose(1, 2))
return loss
def _get_model_device(model):
assert isinstance(model, nn.Module)
parameters = list(model.parameters())
if len(parameters) == 0:
return None
else:
return parameters[0].device
def set_seed(seed=2021):
"""sets random seed"""
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
np.random.seed(seed)
random.seed(seed)
def convert_preds_to_outputs(preds, raw_words, mapping, tokenizer):
"""convet model predicitons to BIO outputs
Args:
preds ([torch.Tensor]): [prompt model predictions, (bsz x seq_len x labels)]
raw_words ([List]): [source raw words]
mapping ([dict]): [map entity labels to <<>>]
tokenizer : [BartTokenizer]
Returns:
outputs (List): each item has the same length as its raw_words entry, in BIO format
"""
id2label = list(mapping.keys())
pred_eos_index = preds.flip(dims=[1]).eq(1).cumsum(dim=1).long()
preds = preds[:, 1:]
pred_seq_len = pred_eos_index.flip(dims=[1]).eq(pred_eos_index[:, -1:]).sum(dim=1) # bsz
pred_seq_len = (pred_seq_len - 2).tolist()
word_start_index = len(mapping) + 2
outputs = []
for i, pred_item in enumerate(preds.tolist()):
pred_item = pred_item[:pred_seq_len[i]] # single sentence prediction
pairs, cur_pair = [], []
if len(pred_item): # this sentence's prediction is not empty
for idx in pred_item:
if idx < word_start_index: # is entity
if len(cur_pair) > 0:
# assert word[i] < word[i+1]
if all([cur_pair[i] < cur_pair[i + 1] for i in range(len(cur_pair) - 1)]):
pairs.append(tuple(cur_pair + [idx])) # add valid words and current entity id
cur_pair = [] # clear word pairs
else: # is word
cur_pair.append(idx) # add word id to word pairs
raw_words_item = raw_words[i]
cum_lens = [1]
start_idx = 1
for word in raw_words_item:
start_idx += len(tokenizer.tokenize(word, add_prefix_space=True))
cum_lens.append(start_idx)
cum_lens.append(start_idx+1)
output = ['O' for _ in range(len(raw_words_item))]
# pairs: List[(word id, ... , entity id), (...), ...]
for pair in pairs: # (word id, ... , entity id)
entity = pair[-1]
words = []
for word in pair[:-1]:
if word-word_start_index in cum_lens:
words.append(cum_lens.index(word-word_start_index))
if len(words) == 0: continue
start_idx = words[0]
end_idx = words[-1]
output[start_idx] = f'B-{id2label[entity-2]}'
for _ in range(start_idx+1, end_idx+1):
output[_] = f'I-{id2label[entity-2]}'
outputs.append(output)
return outputs
def write_predictions(path, texts, labels):
"""[write model predictions to path (conll format)]
Args:
path ([str]): [save path]
texts ([List]): [raw texts]
labels ([List]): [predict labels]
"""
print(len(texts), len(labels))
assert len(texts) == len(labels)
with open(path, "w", encoding="utf-8") as f:
f.writelines("-DOCSTART- O\n\n")
for i in range(len(texts)):
for j in range(len(texts[i])):
f.writelines("{}\t{}\n".format(texts[i][j], labels[i][j]))
f.writelines("\n")