ner-few-shot
This commit is contained in:
parent
d4ba08a709
commit
8330b2d896
|
@ -0,0 +1,703 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from transformers.configuration_bart import BartConfig
|
||||
from .modeling_bart import BartModel, _prepare_bart_decoder_inputs
|
||||
|
||||
from ..utils.utils import avg_token_embeddings, seq_to_mask, _get_model_device
|
||||
from functools import partial
|
||||
from typing import Union
|
||||
|
||||
|
||||
class PromptBartEncoder(nn.Module):
|
||||
def __init__(self, encoder):
|
||||
super(PromptBartEncoder, self).__init__()
|
||||
self.bart_encoder = encoder
|
||||
|
||||
def forward(self, src_tokens, attention_mask=None, past_key_values=None):
|
||||
encoder_dicts = self.bart_encoder(input_ids=src_tokens, attention_mask=attention_mask, past_key_values=past_key_values, return_dict=True, output_hidden_states=True)
|
||||
return encoder_dicts.last_hidden_state, encoder_dicts.hidden_states
|
||||
|
||||
class PromptBartDecoder(nn.Module):
|
||||
def __init__(self, decoder, pad_token_id, label_ids, use_prompt=False, prompt_len=10, learn_weights=False):
|
||||
super(PromptBartDecoder, self).__init__()
|
||||
self.bart_decoder = decoder
|
||||
self.pad_token_id = pad_token_id
|
||||
self.use_prompt = use_prompt
|
||||
self.prompt_len = prompt_len
|
||||
self.learn_weights = learn_weights
|
||||
self.label_ids = label_ids
|
||||
|
||||
print(label_ids)
|
||||
if self.learn_weights: # set learnable averge weights
|
||||
self.averge_weights = nn.ParameterList(parameters=None)
|
||||
for id in label_ids:
|
||||
if len(id) > 1:
|
||||
self.averge_weights.append(nn.Parameter(torch.FloatTensor(len(id))))
|
||||
print(self.averge_weights)
|
||||
mapping = [0, 2]
|
||||
for id in label_ids:
|
||||
mapping += id[:1]
|
||||
mapping = torch.LongTensor(mapping)
|
||||
else:
|
||||
mapping = torch.LongTensor([0, 2]+label_ids)
|
||||
self.label_start_id = min(label_ids)
|
||||
self.label_end_id = max(label_ids)+1
|
||||
|
||||
self.register_buffer('mapping', mapping)
|
||||
self.src_start_index = len(mapping)
|
||||
|
||||
hidden_size = decoder.embed_tokens.weight.size(1)
|
||||
self.bart_mlp = nn.Sequential(nn.Linear(hidden_size, hidden_size),
|
||||
nn.Dropout(0.3),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_size, hidden_size))
|
||||
self.dropout_layer = nn.Dropout(0.3)
|
||||
|
||||
def forward(self, tgt_tokens, prompt_state):
|
||||
cumsum = tgt_tokens.eq(1).flip(dims=[1]).cumsum(dim=-1)
|
||||
tgt_pad_mask = cumsum.flip(dims=[-1]).ne(cumsum[:, -1:])
|
||||
|
||||
encoder_outputs = prompt_state.encoder_output # last_hidden_state
|
||||
attention_mask = prompt_state.encoder_mask # attention_mask
|
||||
first = prompt_state.first
|
||||
src_tokens = prompt_state.src_tokens
|
||||
past_key_values = prompt_state.past_key_values
|
||||
|
||||
# mapping target tokens
|
||||
mapping_token_mask = tgt_tokens.lt(self.src_start_index)
|
||||
mapped_tokens = tgt_tokens.masked_fill(tgt_tokens.ge(self.src_start_index), 0)
|
||||
tag_mapped_tokens = self.mapping[mapped_tokens]
|
||||
|
||||
src_tokens_index = tgt_tokens - self.src_start_index # bsz x num_src_token
|
||||
src_tokens_index = src_tokens_index.masked_fill(src_tokens_index.lt(0), 0)
|
||||
if first is not None:
|
||||
src_tokens = src_tokens.gather(index=first, dim=1)
|
||||
word_mapped_tokens = src_tokens.gather(index=src_tokens_index, dim=1)
|
||||
|
||||
tokens = torch.where(mapping_token_mask, tag_mapped_tokens, word_mapped_tokens) # bsz x max_len
|
||||
tokens = tokens.masked_fill(tgt_pad_mask, self.pad_token_id)
|
||||
|
||||
decoder_input_ids, _, causal_mask = _prepare_bart_decoder_inputs(
|
||||
self.pad_token_id,
|
||||
tokens,
|
||||
decoder_input_ids=None,
|
||||
decoder_padding_mask=None,
|
||||
causal_mask_dtype=self.bart_decoder.embed_tokens.weight.dtype
|
||||
)
|
||||
|
||||
if self.use_prompt:
|
||||
assert past_key_values is not None
|
||||
_, _, seqlen, _ = past_key_values[0]['self']['prev_value'].shape
|
||||
tgt_len = decoder_input_ids.size(1)
|
||||
temp_mask = torch.zeros(tgt_len, seqlen).to(causal_mask.device) #tgtlen, preseqlen
|
||||
causal_mask = torch.cat([temp_mask, causal_mask],dim=1) #tgtlen, preseqlen+tgtlen
|
||||
|
||||
if self.training:
|
||||
tokens = tokens[:, :-1]
|
||||
decoder_pad_mask = tokens.eq(self.pad_token_id)
|
||||
dict = self.bart_decoder(input_ids=tokens,
|
||||
encoder_hidden_states=encoder_outputs, # last_hidden_state
|
||||
encoder_padding_mask=attention_mask, # attention_mask
|
||||
decoder_padding_mask=decoder_pad_mask,
|
||||
decoder_causal_mask=causal_mask[:tokens.size(1), :self.prompt_len+tokens.size(1)],
|
||||
output_hidden_states=True,
|
||||
past_key_values=past_key_values,
|
||||
return_dict=True)
|
||||
else:
|
||||
past_key_values = prompt_state.past_key_values
|
||||
dict = self.bart_decoder(input_ids=tokens,
|
||||
encoder_hidden_states=encoder_outputs,
|
||||
encoder_padding_mask=attention_mask,
|
||||
decoder_padding_mask=None,
|
||||
decoder_causal_mask=None,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
return_dict=True)
|
||||
hidden_state = dict.last_hidden_state # bsz x max_len x hidden_size
|
||||
hidden_state = self.dropout_layer(hidden_state)
|
||||
if not self.training:
|
||||
prompt_state.past_key_values = dict.past_key_values
|
||||
|
||||
logits = hidden_state.new_full((hidden_state.size(0), hidden_state.size(1), self.src_start_index+src_tokens.size(-1)),
|
||||
fill_value=-1e24)
|
||||
|
||||
# compute eos scores
|
||||
eos_scores = F.linear(hidden_state, self.dropout_layer(self.bart_decoder.embed_tokens.weight[2:3])) # bsz x max_len x 1
|
||||
|
||||
if self.learn_weights: # use averge_weights compute entity labels scores
|
||||
tag_scores = None
|
||||
idx = 0
|
||||
for ids in self.label_ids: # bsz x max_len x num_class
|
||||
if len(ids) <= 1:
|
||||
temp_score = F.linear(hidden_state, self.dropout_layer(self.bart_decoder.embed_tokens.weight[ids]))
|
||||
else:
|
||||
weight = F.softmax(self.averge_weights[idx])
|
||||
temp_score = F.linear(hidden_state, self.dropout_layer(self.bart_decoder.embed_tokens.weight[[ids[0]]])) * weight[0]
|
||||
for i in range(1, len(ids)):
|
||||
temp_score = temp_score + F.linear(hidden_state, self.dropout_layer(self.bart_decoder.embed_tokens.weight[[ids[i]]])) * weight[i]
|
||||
idx += 1
|
||||
if tag_scores is None:
|
||||
tag_scores = temp_score
|
||||
else:
|
||||
tag_scores = torch.cat((tag_scores, temp_score), dim=2)
|
||||
else:
|
||||
tag_scores = F.linear(hidden_state, self.dropout_layer(self.bart_decoder.embed_tokens.weight[self.label_start_id:self.label_end_id])) # bsz x max_len x num_class
|
||||
|
||||
# bsz x max_bpe_len x hidden_size
|
||||
src_outputs = encoder_outputs
|
||||
if hasattr(self, 'encoder_mlp'):
|
||||
src_outputs = self.encoder_mlp(src_outputs)
|
||||
|
||||
if first is not None:
|
||||
mask = first.eq(0) # bsz x 1 x max_word_len
|
||||
# bsz x max_word_len x hidden_size
|
||||
src_outputs = src_outputs.gather(index=first.unsqueeze(2).repeat(1, 1, src_outputs.size(-1)), dim=1)
|
||||
else:
|
||||
mask = attention_mask.eq(0)
|
||||
# src_outputs = self.decoder.embed_tokens(src_tokens)
|
||||
mask = mask.unsqueeze(1)
|
||||
input_embed = self.dropout_layer(self.bart_decoder.embed_tokens(src_tokens)) # bsz x max_word_len x hidden_size
|
||||
src_outputs = (src_outputs + input_embed)/2
|
||||
word_scores = torch.einsum('blh,bnh->bln', hidden_state, src_outputs) # bsz x max_len x max_word_len
|
||||
mask = mask.__or__(src_tokens.eq(2).cumsum(dim=1).ge(1).unsqueeze(1))
|
||||
word_scores = word_scores.masked_fill(mask, -1e32)
|
||||
|
||||
logits[:, :, 1:2] = eos_scores
|
||||
logits[:, :, 2:self.src_start_index] = tag_scores
|
||||
logits[:, :, self.src_start_index:] = word_scores
|
||||
return logits, prompt_state
|
||||
|
||||
def decode(self, tokens, state):
|
||||
return self(tokens, state)[0][:, -1]
|
||||
|
||||
class PromptBartModel(nn.Module):
|
||||
def __init__(self, tokenizer, label_ids, args):
|
||||
super(PromptBartModel, self).__init__()
|
||||
self.use_prompt = args.use_prompt
|
||||
self.prompt_len = args.prompt_len
|
||||
self.prompt_dim = args.prompt_dim
|
||||
self.learn_weights = args.learn_weights
|
||||
self.device = 'cuda' if torch.cuda.is_available else 'cpu'
|
||||
bart_name = args.bart_name
|
||||
|
||||
self.bart_config = BartConfig.from_pretrained(bart_name)
|
||||
self.bart_config.use_prompt = args.use_prompt
|
||||
self.bart_config.preseqlen = args.prompt_len
|
||||
bart_config = self.bart_config
|
||||
bart_model = BartModel.from_pretrained(bart_name, config=bart_config)
|
||||
num_tokens, _ = bart_model.encoder.embed_tokens.weight.shape
|
||||
bart_model.resize_token_embeddings(len(tokenizer.unique_no_split_tokens)+num_tokens)
|
||||
bart_model = avg_token_embeddings(tokenizer, bart_model, bart_name, num_tokens)
|
||||
|
||||
self.prompt_encoder = PromptBartEncoder(bart_model.encoder)
|
||||
self.prompt_decoder = PromptBartDecoder(bart_model.decoder, tokenizer.pad_token_id, label_ids, self.use_prompt, self.prompt_len, self.learn_weights)
|
||||
|
||||
self.prompt_inputs = torch.arange(self.prompt_len).long()
|
||||
self.encoder_prompt_embed = nn.Embedding(self.prompt_len, bart_config.d_model)
|
||||
self.encoder_mlp = nn.Sequential(
|
||||
nn.Linear(bart_config.d_model, self.prompt_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(self.prompt_dim, bart_config.decoder_layers * 2 * bart_config.d_model))
|
||||
|
||||
self.decoder_prompt_embed = nn.Embedding(self.prompt_len, bart_config.d_model)
|
||||
self.decoder_mlp = nn.Sequential(
|
||||
nn.Linear(bart_config.d_model, self.prompt_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(self.prompt_dim, bart_config.decoder_layers * 2 * bart_config.d_model))
|
||||
|
||||
self.prompt_cross_embed = nn.Embedding(self.prompt_len, bart_config.d_model)
|
||||
self.cross_mlp = nn.Sequential(
|
||||
nn.Linear(bart_config.d_model, self.prompt_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(self.prompt_dim, bart_config.decoder_layers * 2 * bart_config.d_model))
|
||||
|
||||
self.dropout = nn.Dropout(0.0)
|
||||
|
||||
|
||||
def forward(self, src_tokens, tgt_tokens, src_seq_len, first):
|
||||
prompt_state = self.generator(src_tokens, src_seq_len, first)
|
||||
|
||||
decoder_outputs, prompt_state = self.prompt_decoder(tgt_tokens, prompt_state)
|
||||
return decoder_outputs
|
||||
|
||||
|
||||
def generator(self, src_tokens, src_seq_len, first):
|
||||
batch_size = src_tokens.size(0)
|
||||
past_key_values = self.get_prompt(batch_size) if self.use_prompt else None
|
||||
attention_mask = seq_to_mask(src_seq_len, max_len=src_tokens.size(1))
|
||||
encoder_outputs, hidden_states = self.prompt_encoder(src_tokens, attention_mask=attention_mask, past_key_values=past_key_values)
|
||||
prompt_state = PromptBartState(encoder_outputs, attention_mask, past_key_values, src_tokens, first, hidden_states[0], self.bart_config.preseqlen)
|
||||
|
||||
return prompt_state
|
||||
|
||||
|
||||
def get_prompt(self, batch_size):
|
||||
input_tokens = self.prompt_inputs.unsqueeze(0).expand(batch_size, -1).to(self.device)
|
||||
|
||||
# encoder prompt
|
||||
encoder_embed = self.encoder_prompt_embed(input_tokens)
|
||||
past_key_values = self.encoder_mlp(encoder_embed) #bsz, seqlen, layer*emb
|
||||
bsz, seqlen, _ = past_key_values.shape
|
||||
past_key_values = past_key_values.view(bsz, seqlen, self.bart_config.decoder_layers * 2,
|
||||
self.bart_config.decoder_attention_heads, self.bart_config.d_model // self.bart_config.decoder_attention_heads)
|
||||
past_key_values = self.dropout(past_key_values)
|
||||
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) # key + value
|
||||
|
||||
# decoder prompt
|
||||
decoder_embed = self.decoder_prompt_embed(input_tokens)
|
||||
past_key_values2 = self.decoder_mlp(decoder_embed) # bsz, seqlen, layer*emb
|
||||
past_key_values2 = past_key_values2.view(bsz, seqlen, self.bart_config.decoder_layers * 2,
|
||||
self.bart_config.decoder_attention_heads, self.bart_config.d_model // self.bart_config.decoder_attention_heads)
|
||||
past_key_values2 = self.dropout(past_key_values2)
|
||||
past_key_values2 = past_key_values2.permute([2, 0, 3, 1, 4]).split(2)
|
||||
|
||||
# cross prompt
|
||||
cross_embed = self.prompt_cross_embed(input_tokens)
|
||||
past_key_values_enc = self.cross_mlp(cross_embed) # bsz, seqlen, layer*emb
|
||||
past_key_values_enc = past_key_values_enc.view(bsz, seqlen, self.bart_config.decoder_layers * 2,
|
||||
self.bart_config.decoder_attention_heads, self.bart_config.d_model // self.bart_config.decoder_attention_heads)
|
||||
past_key_values_enc = self.dropout(past_key_values_enc)
|
||||
past_key_values_enc = past_key_values_enc.permute([2, 0, 3, 1, 4]).split(2)
|
||||
|
||||
result = []
|
||||
for i, key_val in enumerate(past_key_values):
|
||||
temp_dict = {'self': {"prev_key": key_val[0].contiguous(),
|
||||
"prev_value": key_val[1].contiguous(),
|
||||
"prev_key_padding_mask": torch.zeros(bsz, seqlen).to(key_val.device).bool() #bsz, preseqlen
|
||||
},
|
||||
}
|
||||
key_val2 = past_key_values2[i]
|
||||
temp_dict['encoder_decoder'] = {"prev_key": key_val2[0].contiguous(),
|
||||
"prev_value": key_val2[1].contiguous(),
|
||||
"prev_key_padding_mask": torch.zeros(bsz, seqlen).to(key_val2.device).bool()
|
||||
}
|
||||
key_val_enc = past_key_values_enc[i]
|
||||
temp_dict['encoder'] = {"prev_key": key_val_enc[0].contiguous(),
|
||||
"prev_value": key_val_enc[1].contiguous(),
|
||||
"prev_key_padding_mask": torch.zeros(bsz, seqlen).to(key_val_enc.device).bool()
|
||||
}
|
||||
result.append(temp_dict)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class PromptBartState(object):
|
||||
def __init__(self, encoder_output, encoder_mask, past_key_values, src_tokens, first, src_embed_outputs, preseqlen):
|
||||
self.encoder_output = encoder_output
|
||||
self.encoder_mask = encoder_mask
|
||||
self.past_key_values = past_key_values
|
||||
self.src_tokens = src_tokens
|
||||
self.first = first
|
||||
self.src_embed_outputs = src_embed_outputs
|
||||
self.preseqlen = preseqlen
|
||||
|
||||
def _reorder_state(self, state: Union[torch.Tensor, list, tuple], indices: torch.LongTensor, dim: int = 0):
|
||||
if isinstance(state, torch.Tensor):
|
||||
state = state.index_select(index=indices, dim=dim)
|
||||
elif isinstance(state, list):
|
||||
for i in range(len(state)):
|
||||
assert state[i] is not None
|
||||
state[i] = self._reorder_state(state[i], indices, dim)
|
||||
elif isinstance(state, tuple):
|
||||
tmp_list = []
|
||||
for i in range(len(state)):
|
||||
assert state[i] is not None
|
||||
tmp_list.append(self._reorder_state(state[i], indices, dim))
|
||||
state = tuple(tmp_list)
|
||||
else:
|
||||
raise TypeError(f"Cannot reorder data of type:{type(state)}")
|
||||
|
||||
return state
|
||||
|
||||
def reorder_state(self, indices: torch.LongTensor):
|
||||
super().reorder_state(indices)
|
||||
self.src_tokens = self._reorder_state(self.src_tokens, indices)
|
||||
if self.first is not None:
|
||||
self.first = self._reorder_state(self.first, indices)
|
||||
self.src_embed_outputs = self._reorder_state(self.src_embed_outputs, indices)
|
||||
if self.past_key_values is not None:
|
||||
new = []
|
||||
for layer in self.past_key_values:
|
||||
new_layer = {}
|
||||
for key1 in list(layer.keys()):
|
||||
new_layer_ = {}
|
||||
for key2 in list(layer[key1].keys()):
|
||||
if layer[key1][key2] is not None:
|
||||
layer[key1][key2] = self._reorder_state(layer[key1][key2], indices)
|
||||
new_layer_[key2] = layer[key1][key2]
|
||||
new_layer[key1] = new_layer_
|
||||
new.append(new_layer)
|
||||
self.past_key_values = new
|
||||
|
||||
def num_samples(self):
|
||||
if self.encoder_output is not None:
|
||||
return self.encoder_output.size(0)
|
||||
else:
|
||||
return None
|
||||
|
||||
class PromptGeneratorModel(nn.Module):
|
||||
def __init__(self, prompt_model, max_length=20, max_len_a=0.0, num_beams=1,
|
||||
do_sample=False, bos_token_id=None, eos_token_id=None,
|
||||
repetition_penalty=1, length_penalty=1.0, pad_token_id=0, restricter=None):
|
||||
super(PromptGeneratorModel, self).__init__()
|
||||
self.prompt_model = prompt_model
|
||||
self.decoder = prompt_model.prompt_decoder
|
||||
|
||||
self.generate_func = partial(greedy_generate, decoder=self.decoder, max_length=max_length, max_len_a=max_len_a,
|
||||
num_beams=num_beams,
|
||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id,
|
||||
repetition_penalty=repetition_penalty,
|
||||
length_penalty=length_penalty, pad_token_id=pad_token_id,
|
||||
restricter=restricter)
|
||||
self.do_sample = do_sample
|
||||
self.max_length = max_length
|
||||
self.num_beams = num_beams
|
||||
self.bos_token_id = bos_token_id
|
||||
self.eos_token_id = eos_token_id
|
||||
self.repetition_penalty = repetition_penalty
|
||||
self.length_penalty = length_penalty
|
||||
self.pad_token_id = pad_token_id
|
||||
self.restricter = restricter
|
||||
self.max_len_a = max_len_a
|
||||
|
||||
|
||||
def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None, first=None):
|
||||
"""
|
||||
:param torch.LongTensor src_tokens: bsz x max_len
|
||||
:param torch.LongTensor tgt_tokens: bsz x max_len'
|
||||
:param torch.LongTensor src_seq_len: bsz
|
||||
:param torch.LongTensor tgt_seq_len: bsz
|
||||
:return:
|
||||
"""
|
||||
return self.prompt_model(src_tokens, tgt_tokens, src_seq_len, first)
|
||||
|
||||
|
||||
def predict(self, src_tokens, src_seq_len=None, first=None):
|
||||
"""
|
||||
:param torch.LongTensor src_tokens: bsz x max_len
|
||||
:param torch.LongTensor src_seq_len: bsz
|
||||
:return:
|
||||
"""
|
||||
prompt_state = self.prompt_model.generator(src_tokens, src_seq_len, first) # encoder output
|
||||
result = self.generate_func(tokens=None, state=prompt_state)
|
||||
return result
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def greedy_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1,
|
||||
bos_token_id=None, eos_token_id=None, pad_token_id=0,
|
||||
repetition_penalty=1, length_penalty=1.0, restricter=None):
|
||||
if num_beams == 1:
|
||||
token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
|
||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id,
|
||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty,
|
||||
pad_token_id=pad_token_id, restricter=restricter)
|
||||
else:
|
||||
token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length, max_len_a=max_len_a,
|
||||
num_beams=num_beams,
|
||||
bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
|
||||
repetition_penalty=repetition_penalty, length_penalty=length_penalty,
|
||||
pad_token_id=pad_token_id, restricter=restricter)
|
||||
|
||||
return token_ids
|
||||
|
||||
|
||||
def _no_beam_search_generate(decoder: PromptBartDecoder, state, tokens=None, max_length=20, max_len_a=0.0, bos_token_id=None,
|
||||
eos_token_id=None,
|
||||
repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0,
|
||||
restricter=None):
|
||||
device = _get_model_device(decoder)
|
||||
if tokens is None:
|
||||
if bos_token_id is None:
|
||||
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
|
||||
batch_size = state.num_samples()
|
||||
if batch_size is None:
|
||||
raise RuntimeError("Cannot infer the number of samples from `state`.")
|
||||
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
|
||||
batch_size = tokens.size(0)
|
||||
if state.num_samples:
|
||||
assert state.num_samples() == batch_size, "The number of samples in `tokens` and `state` should match."
|
||||
|
||||
if eos_token_id is None:
|
||||
_eos_token_id = -1
|
||||
else:
|
||||
_eos_token_id = eos_token_id
|
||||
scores = decoder.decode(tokens=tokens, state=state) # update state
|
||||
if restricter is not None:
|
||||
_, next_tokens = restricter(state, tokens, scores, num_beams=1)
|
||||
else:
|
||||
next_tokens = scores.argmax(dim=-1, keepdim=True)
|
||||
token_ids = torch.cat([tokens, next_tokens], dim=1)
|
||||
cur_len = token_ids.size(1)
|
||||
dones = token_ids.new_zeros(batch_size).eq(1).__or__(next_tokens.squeeze(1).eq(eos_token_id))
|
||||
# tokens = tokens[:, -1:]
|
||||
|
||||
if max_len_a!=0:
|
||||
# (bsz x num_beams, )
|
||||
if state.encoder_mask is not None:
|
||||
max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
|
||||
else:
|
||||
max_lengths = tokens.new_full((tokens.size(0), ), fill_value=max_length, dtype=torch.long)
|
||||
real_max_length = max_lengths.max().item()
|
||||
else:
|
||||
real_max_length = max_length
|
||||
if state.encoder_mask is not None:
|
||||
max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
|
||||
else:
|
||||
max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
|
||||
|
||||
while cur_len < real_max_length:
|
||||
scores = decoder.decode(tokens=token_ids, state=state) # batch_size x vocab_size
|
||||
|
||||
if repetition_penalty != 1.0:
|
||||
token_scores = scores.gather(dim=1, index=token_ids)
|
||||
lt_zero_mask = token_scores.lt(0).float()
|
||||
ge_zero_mask = lt_zero_mask.eq(0).float()
|
||||
token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
|
||||
scores.scatter_(dim=1, index=token_ids, src=token_scores)
|
||||
|
||||
if eos_token_id is not None and length_penalty != 1.0:
|
||||
token_scores = scores / cur_len ** length_penalty # batch_size x vocab_size
|
||||
eos_mask = scores.new_ones(scores.size(1))
|
||||
eos_mask[eos_token_id] = 0
|
||||
eos_mask = eos_mask.unsqueeze(0).eq(1)
|
||||
scores = scores.masked_scatter(eos_mask, token_scores)
|
||||
|
||||
if restricter is not None:
|
||||
_, next_tokens = restricter(state, token_ids, scores, 1)
|
||||
else:
|
||||
next_tokens = scores.argmax(dim=-1, keepdim=True)
|
||||
next_tokens = next_tokens.squeeze(-1)
|
||||
|
||||
if _eos_token_id!=-1:
|
||||
next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len+1), _eos_token_id)
|
||||
next_tokens = next_tokens.masked_fill(dones, pad_token_id)
|
||||
tokens = next_tokens.unsqueeze(1)
|
||||
|
||||
token_ids = torch.cat([token_ids, tokens], dim=-1) # batch_size x max_len
|
||||
|
||||
end_mask = next_tokens.eq(_eos_token_id)
|
||||
dones = dones.__or__(end_mask)
|
||||
cur_len += 1
|
||||
|
||||
if dones.min() == 1:
|
||||
break
|
||||
return token_ids
|
||||
|
||||
|
||||
def _beam_search_generate(decoder: PromptBartDecoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=4,
|
||||
bos_token_id=None, eos_token_id=None, do_sample=True,
|
||||
repetition_penalty=1.0, length_penalty=None, pad_token_id=0,
|
||||
restricter=None) -> torch.LongTensor:
|
||||
assert do_sample is False
|
||||
# beam search
|
||||
device = _get_model_device(decoder)
|
||||
if tokens is None:
|
||||
if bos_token_id is None:
|
||||
raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
|
||||
batch_size = state.num_samples
|
||||
if batch_size is None:
|
||||
raise RuntimeError("Cannot infer the number of samples from `state`.")
|
||||
tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
|
||||
batch_size = tokens.size(0)
|
||||
if state.num_samples:
|
||||
assert state.num_samples == batch_size, "The number of samples in `tokens` and `state` should match."
|
||||
|
||||
if eos_token_id is None:
|
||||
_eos_token_id = -1
|
||||
else:
|
||||
_eos_token_id = eos_token_id
|
||||
|
||||
scores = decoder.decode(tokens=tokens, state=state)
|
||||
vocab_size = scores.size(1)
|
||||
assert vocab_size >= num_beams, "num_beams should be smaller than the number of vocabulary size."
|
||||
|
||||
scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size)
|
||||
if restricter is not None:
|
||||
_next_scores, _next_tokens = restricter(state, tokens, scores, num_beams+1)
|
||||
else:
|
||||
# bsz x (num_beams+1)
|
||||
_next_scores, _next_tokens = torch.topk(scores, num_beams+1, dim=1, largest=True, sorted=True)
|
||||
|
||||
indices = torch.arange(batch_size, dtype=torch.long).to(device)
|
||||
indices = indices.repeat_interleave(num_beams)
|
||||
state.reorder_state(indices)
|
||||
tokens = tokens.index_select(dim=0, index=indices) # batch_size * num_beams x length
|
||||
|
||||
if max_len_a!=0:
|
||||
# (bsz x num_beams, )
|
||||
if state.encoder_mask is not None:
|
||||
max_lengths = (state.encoder_mask.sum(dim=1).float()*max_len_a).long() + max_length
|
||||
else:
|
||||
max_lengths = tokens.new_full((batch_size*num_beams, ), fill_value=max_length, dtype=torch.long)
|
||||
real_max_length = max_lengths.max().item()
|
||||
else:
|
||||
real_max_length = max_length
|
||||
if state.encoder_mask is not None:
|
||||
max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long()*max_length
|
||||
else:
|
||||
max_lengths = tokens.new_full((batch_size*num_beams,), fill_value=max_length, dtype=torch.long)
|
||||
hypos = [
|
||||
BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
|
||||
]
|
||||
|
||||
not_eos_mask = _next_tokens.ne(_eos_token_id)
|
||||
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)
|
||||
keep_mask = not_eos_mask.__and__(keep_mask)
|
||||
|
||||
next_tokens = _next_tokens.masked_select(keep_mask).view(batch_size, num_beams)
|
||||
next_scores = _next_scores.masked_select(keep_mask).view(batch_size, num_beams)
|
||||
|
||||
rows, cols = not_eos_mask.eq(0)[:, :num_beams].nonzero(as_tuple=True)
|
||||
|
||||
if len(rows)>0:
|
||||
for row, col in zip(rows.tolist(), cols.tolist()):
|
||||
_token = torch.cat([tokens[row*num_beams], _next_tokens[row, col:col+1]], dim=0)
|
||||
hypos[row].add(_token.clone(), _next_scores[row, col].item())
|
||||
|
||||
# (batch_size, cur_len)
|
||||
token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1)
|
||||
dones = [False] * batch_size
|
||||
|
||||
beam_scores = next_scores.view(-1) # batch_size * num_beams
|
||||
|
||||
cur_len = token_ids.size(1)
|
||||
|
||||
# 0, num_beams, 2*num_beams, ...
|
||||
batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)
|
||||
|
||||
while cur_len < real_max_length:
|
||||
scores = decoder.decode(token_ids, state) # (bsz x num_beams, vocab_size)
|
||||
if repetition_penalty != 1.0:
|
||||
token_scores = scores.gather(dim=1, index=token_ids)
|
||||
lt_zero_mask = token_scores.lt(0).float()
|
||||
ge_zero_mask = lt_zero_mask.eq(0).float()
|
||||
token_scores = lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores
|
||||
scores.scatter_(dim=1, index=token_ids, src=token_scores)
|
||||
|
||||
if _eos_token_id!=-1:
|
||||
max_len_eos_mask = max_lengths.eq(cur_len+1)
|
||||
eos_scores = scores[:, _eos_token_id]
|
||||
scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e32, eos_scores)
|
||||
|
||||
scores = F.log_softmax(scores, dim=-1) # (batch_size * num_beams, vocab_size)
|
||||
_scores = scores + beam_scores[:, None] # (batch_size * num_beams, vocab_size)
|
||||
_scores = _scores.view(batch_size, -1) # (batch_size, num_beams*vocab_size)
|
||||
|
||||
if restricter is not None:
|
||||
next_scores, ids = restricter(state, token_ids, _scores, 2 * num_beams)
|
||||
else:
|
||||
next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) # (bsz, 2*num_beams)
|
||||
from_which_beam = ids // vocab_size # (batch_size, 2*num_beams)
|
||||
next_tokens = ids % vocab_size # (batch_size, 2*num_beams)
|
||||
|
||||
not_eos_mask = next_tokens.ne(_eos_token_id)
|
||||
keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)
|
||||
keep_mask = not_eos_mask.__and__(keep_mask)
|
||||
|
||||
_next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1)
|
||||
_from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams)
|
||||
_next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
|
||||
beam_scores = _next_scores.view(-1)
|
||||
|
||||
flag = True
|
||||
if cur_len+1 == real_max_length:
|
||||
eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0)
|
||||
eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)
|
||||
eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)
|
||||
else:
|
||||
effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id) # batch_size x num_beams
|
||||
if effective_eos_mask.sum().gt(0):
|
||||
eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True)
|
||||
eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind
|
||||
eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx]
|
||||
else:
|
||||
flag = False
|
||||
|
||||
if flag:
|
||||
_token_ids = torch.cat([token_ids, _next_tokens], dim=-1)
|
||||
for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(),
|
||||
eos_beam_idx.tolist()):
|
||||
if not dones[batch_idx]:
|
||||
score = next_scores[batch_idx, beam_ind].item()
|
||||
if _eos_token_id!=-1:
|
||||
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
|
||||
else:
|
||||
hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx].clone(), score)
|
||||
|
||||
reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1) # flatten
|
||||
state.reorder_state(reorder_inds)
|
||||
token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), _next_tokens], dim=-1)
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) or \
|
||||
max_lengths[batch_idx*num_beams]==cur_len+1
|
||||
|
||||
cur_len += 1
|
||||
|
||||
if all(dones):
|
||||
break
|
||||
|
||||
# select the best hypotheses
|
||||
tgt_len = token_ids.new_zeros(batch_size)
|
||||
best = []
|
||||
|
||||
for i, hypotheses in enumerate(hypos):
|
||||
best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
|
||||
if _eos_token_id!=-1:
|
||||
best_hyp = torch.cat([best_hyp, best_hyp.new_ones(1)*_eos_token_id])
|
||||
tgt_len[i] = len(best_hyp)
|
||||
best.append(best_hyp)
|
||||
|
||||
# generate target batch
|
||||
decoded = token_ids.new_zeros(batch_size, tgt_len.max().item()).fill_(pad_token_id)
|
||||
for i, hypo in enumerate(best):
|
||||
decoded[i, :tgt_len[i]] = hypo
|
||||
|
||||
return decoded
|
||||
|
||||
|
||||
class BeamHypotheses(object):
|
||||
def __init__(self, num_beams, max_length, length_penalty, early_stopping):
|
||||
"""
|
||||
Initialize n-best list of hypotheses.
|
||||
"""
|
||||
self.max_length = max_length - 1 # ignoring bos_token
|
||||
self.length_penalty = length_penalty
|
||||
self.early_stopping = early_stopping
|
||||
self.num_beams = num_beams
|
||||
self.hyp = []
|
||||
self.worst_score = 1e9
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Number of hypotheses in the list.
|
||||
"""
|
||||
return len(self.hyp)
|
||||
|
||||
def add(self, hyp, sum_logprobs):
|
||||
"""
|
||||
Add a new hypothesis to the list.
|
||||
"""
|
||||
score = sum_logprobs / len(hyp) ** self.length_penalty
|
||||
if len(self) < self.num_beams or score > self.worst_score:
|
||||
self.hyp.append((score, hyp))
|
||||
if len(self) > self.num_beams:
|
||||
sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
|
||||
del self.hyp[sorted_scores[0][1]]
|
||||
self.worst_score = sorted_scores[1][0]
|
||||
else:
|
||||
self.worst_score = min(score, self.worst_score)
|
||||
|
||||
def is_done(self, best_sum_logprobs):
|
||||
"""
|
||||
If there are enough hypotheses and that none of the hypotheses being generated
|
||||
can become better than the worst one in the heap, then we are done with this sentence.
|
||||
"""
|
||||
if len(self) < self.num_beams:
|
||||
return False
|
||||
elif self.early_stopping:
|
||||
return True
|
||||
else:
|
||||
return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,240 @@
|
|||
import torch
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from itertools import chain
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from transformers import BartTokenizer
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# load file and process bio
|
||||
class ConllNERProcessor(object):
|
||||
def __init__(self, data_path, mapping, bart_name, learn_weights) -> None:
|
||||
self.data_path = data_path
|
||||
self.tokenizer = BartTokenizer.from_pretrained(bart_name)
|
||||
self.mapping = mapping # 记录的是原始tag与转换后的tag的str的匹配关系
|
||||
self.original_token_nums = self.tokenizer.vocab_size
|
||||
self.learn_weights = learn_weights
|
||||
self._add_tags_to_tokens()
|
||||
|
||||
def load_from_file(self, mode='train'):
|
||||
"""load conll ner from file
|
||||
|
||||
Args:
|
||||
mode (str, optional): train/test/dev. Defaults to 'train'.
|
||||
Return:
|
||||
outputs (dict)
|
||||
raw_words: ['EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British', 'lamb', '.']
|
||||
raw_targets: ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O']
|
||||
entities: [['EU'], ['German'], ['British']]
|
||||
entity_tags: ['org', 'misc', 'misc']
|
||||
entity_spans: [[0, 1], [2, 3], [6, 7]]
|
||||
"""
|
||||
load_file = self.data_path[mode]
|
||||
logger.info("Loading data from {}".format(load_file))
|
||||
|
||||
# extract bio
|
||||
outputs = {'raw_words':[], 'raw_targets':[], 'entities':[], 'entity_tags':[], 'entity_spans':[]}
|
||||
with open(load_file, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
raw_words, raw_targets = [], []
|
||||
raw_word, raw_target = [], []
|
||||
for line in lines:
|
||||
if line != "\n":
|
||||
raw_word.append(line.split('\t')[0])
|
||||
raw_target.append(line.split('\t')[1][:-1])
|
||||
else:
|
||||
raw_words.append(raw_word)
|
||||
raw_targets.append(raw_target)
|
||||
raw_word, raw_target = [], []
|
||||
|
||||
for words, targets in zip(raw_words, raw_targets):
|
||||
entities, entity_tags, entity_spans = [], [], []
|
||||
start, end, start_flag = 0, 0, False
|
||||
for idx, tag in enumerate(targets):
|
||||
if tag.startswith('B-'): # 一个实体开头 另一个实体(I-)结束
|
||||
end = idx
|
||||
if start_flag: # 另一个实体以I-结束,紧接着当前实体B-出现
|
||||
entities.append(words[start:end])
|
||||
entity_tags.append(targets[start][2:].lower())
|
||||
entity_spans.append([start, end])
|
||||
start_flag = False
|
||||
start = idx
|
||||
start_flag = True
|
||||
elif tag.startswith('I-'): # 实体中间,不是开头也不是结束,end+1即可
|
||||
end = idx
|
||||
elif tag.startswith('O'): # 无实体,可能是上一个实体的结束
|
||||
end = idx
|
||||
if start_flag: # 上一个实体结束
|
||||
entities.append(words[start:end])
|
||||
entity_tags.append(targets[start][2:].lower())
|
||||
entity_spans.append([start, end])
|
||||
start_flag = False
|
||||
if start_flag: # 句子以实体I-结束,未被添加
|
||||
entities.append(words[start:end+1])
|
||||
entity_tags.append(targets[start][2:].lower())
|
||||
entity_spans.append([start, end+1])
|
||||
start_flag = False
|
||||
|
||||
if len(entities) != 0:
|
||||
outputs['raw_words'].append(words)
|
||||
outputs['raw_targets'].append(targets)
|
||||
outputs['entities'].append(entities)
|
||||
outputs['entity_tags'].append(entity_tags)
|
||||
outputs['entity_spans'].append(entity_spans)
|
||||
return outputs
|
||||
|
||||
def process(self, data_dict):
|
||||
target_shift = len(self.mapping) + 2
|
||||
def prepare_target(item):
|
||||
raw_word = item['raw_word']
|
||||
word_bpes = [[self.tokenizer.bos_token_id]]
|
||||
first = []
|
||||
cur_bpe_len = 1
|
||||
for word in raw_word:
|
||||
bpes = self.tokenizer.tokenize(word, add_prefix_space=True)
|
||||
bpes = self.tokenizer.convert_tokens_to_ids(bpes)
|
||||
first.append(cur_bpe_len)
|
||||
cur_bpe_len += len(bpes)
|
||||
word_bpes.append(bpes)
|
||||
assert first[-1] + len(bpes) == sum(map(len, word_bpes))
|
||||
word_bpes.append([self.tokenizer.eos_token_id])
|
||||
assert len(first) == len(raw_word) == len(word_bpes) - 2
|
||||
|
||||
lens = list(map(len, word_bpes))
|
||||
cum_lens = np.cumsum(lens).tolist()
|
||||
|
||||
entity_spans = item['entity_span'] # [(s1, e1, s2, e2), ()]
|
||||
entity_tags = item['entity_tag'] # [tag1, tag2...]
|
||||
entities = item['entity'] # [[ent1, ent2,], [ent1, ent2]]
|
||||
target = [0]
|
||||
pairs = []
|
||||
|
||||
first = list(range(cum_lens[-1]))
|
||||
|
||||
assert len(entity_spans) == len(entity_tags) #
|
||||
for idx, (entity, tag) in enumerate(zip(entity_spans, entity_tags)):
|
||||
cur_pair = []
|
||||
num_ent = len(entity) // 2
|
||||
for i in range(num_ent):
|
||||
start = entity[2 * i]
|
||||
end = entity[2 * i + 1]
|
||||
cur_pair_ = []
|
||||
cur_pair_.extend([cum_lens[k] for k in list(range(start, end))])
|
||||
cur_pair.extend([p + target_shift for p in cur_pair_])
|
||||
for _, (j, word_idx) in enumerate(zip((cur_pair[0], cur_pair[-1]), (0, -1))):
|
||||
j = j - target_shift
|
||||
assert all([cur_pair[i] < cum_lens[-1] + target_shift for i in range(len(cur_pair))])
|
||||
|
||||
cur_pair.append(self.mapping2targetid[tag] + 2)
|
||||
pairs.append([p for p in cur_pair])
|
||||
target.extend(list(chain(*pairs)))
|
||||
target.append(1)
|
||||
|
||||
word_bpes = list(chain(*word_bpes))
|
||||
assert len(word_bpes)<500
|
||||
|
||||
dict = {'tgt_tokens': target, 'target_span': pairs, 'src_tokens': word_bpes,
|
||||
'first': first, 'src_seq_len':len(word_bpes), 'tgt_seq_len':len(target)}
|
||||
return dict
|
||||
|
||||
logger.info("Process data...")
|
||||
for raw_word, raw_target, entity, entity_tag, entity_span in tqdm(zip(data_dict['raw_words'], data_dict['raw_targets'], data_dict['entities'],
|
||||
data_dict['entity_tags'], data_dict['entity_spans']), total=len(data_dict['raw_words']), desc='Processing'):
|
||||
item_dict = prepare_target({'raw_word': raw_word, 'raw_target':raw_target, 'entity': entity, 'entity_tag': entity_tag, 'entity_span': entity_span})
|
||||
# add item_dict to data_dict
|
||||
for key, value in item_dict.items():
|
||||
if key in data_dict:
|
||||
data_dict[key].append(value)
|
||||
else:
|
||||
data_dict[key] = [value]
|
||||
return data_dict
|
||||
|
||||
def _add_tags_to_tokens(self):
|
||||
mapping = self.mapping
|
||||
if self.learn_weights: # add extra tokens to huggingface tokenizer
|
||||
self.mapping2id = {}
|
||||
self.mapping2targetid = {}
|
||||
for key, value in self.mapping.items():
|
||||
key_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(value[2:-2], add_prefix_space=True))
|
||||
self.mapping2id[value] = key_id # may be list
|
||||
self.mapping2targetid[key] = len(self.mapping2targetid)
|
||||
else:
|
||||
tokens_to_add = sorted(list(mapping.values()), key=lambda x: len(x), reverse=True) #
|
||||
unique_no_split_tokens = self.tokenizer.unique_no_split_tokens # no split
|
||||
sorted_add_tokens = sorted(list(tokens_to_add), key=lambda x: len(x), reverse=True)
|
||||
for tok in sorted_add_tokens:
|
||||
assert self.tokenizer.convert_tokens_to_ids([tok])[0] == self.tokenizer.unk_token_id #
|
||||
self.tokenizer.unique_no_split_tokens = unique_no_split_tokens + sorted_add_tokens # add to no_split_tokens
|
||||
self.tokenizer.add_tokens(sorted_add_tokens)
|
||||
self.mapping2id = {} # tag to id
|
||||
self.mapping2targetid = {} # tag to number
|
||||
|
||||
for key, value in self.mapping.items():
|
||||
key_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(value))
|
||||
assert len(key_id) == 1, value
|
||||
assert key_id[0] >= self.original_token_nums
|
||||
self.mapping2id[value] = key_id[0] #
|
||||
self.mapping2targetid[key] = len(self.mapping2targetid)
|
||||
|
||||
|
||||
class ConllNERDataset(Dataset):
|
||||
def __init__(self, data_processor, mode='train') -> None:
|
||||
self.data_processor = data_processor
|
||||
self.data_dict = data_processor.load_from_file(mode=mode)
|
||||
self.complet_data = data_processor.process(self.data_dict)
|
||||
self.mode = mode
|
||||
|
||||
def __len__(self):
|
||||
return len(self.complet_data['src_tokens'])
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.mode == 'test':
|
||||
return torch.tensor(self.complet_data['src_tokens'][index]), torch.tensor(self.complet_data['src_seq_len'][index]), \
|
||||
torch.tensor(self.complet_data['first'][index]), self.complet_data['raw_words'][index]
|
||||
|
||||
return torch.tensor(self.complet_data['src_tokens'][index]), torch.tensor(self.complet_data['tgt_tokens'][index]), \
|
||||
torch.tensor(self.complet_data['src_seq_len'][index]), torch.tensor(self.complet_data['tgt_seq_len'][index]), \
|
||||
torch.tensor(self.complet_data['first'][index]), self.complet_data['target_span'][index]
|
||||
|
||||
|
||||
def collate_fn(self, batch):
|
||||
src_tokens, src_seq_len, first = [], [], []
|
||||
tgt_tokens, tgt_seq_len, target_span = [], [], []
|
||||
if self.mode == "test":
|
||||
raw_words = []
|
||||
for tup in batch:
|
||||
src_tokens.append(tup[0])
|
||||
src_seq_len.append(tup[1])
|
||||
first.append(tup[2])
|
||||
raw_words.append(tup[3])
|
||||
src_tokens = pad_sequence(src_tokens, batch_first=True, padding_value=self.data_processor.tokenizer.pad_token_id)
|
||||
first = pad_sequence(first, batch_first=True, padding_value=0)
|
||||
return src_tokens, torch.stack(src_seq_len, 0), first, raw_words
|
||||
|
||||
for tup in batch:
|
||||
src_tokens.append(tup[0])
|
||||
tgt_tokens.append(tup[1])
|
||||
src_seq_len.append(tup[2])
|
||||
tgt_seq_len.append(tup[3])
|
||||
first.append(tup[4])
|
||||
target_span.append(tup[5])
|
||||
src_tokens = pad_sequence(src_tokens, batch_first=True, padding_value=self.data_processor.tokenizer.pad_token_id)
|
||||
tgt_tokens = pad_sequence(tgt_tokens, batch_first=True, padding_value=1)
|
||||
first = pad_sequence(first, batch_first=True, padding_value=0)
|
||||
return src_tokens, tgt_tokens, torch.stack(src_seq_len, 0), torch.stack(tgt_seq_len, 0), first, target_span
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_path = {'train':'data/conll2003/train.txt'}
|
||||
bart_name = '../BARTNER-AMAX/facebook/'
|
||||
conll_processor = ConllNERProcessor(data_path, bart_name)
|
||||
conll_datasets = ConllNERDataset(conll_processor, mode='train')
|
||||
conll_dataloader = DataLoader(conll_datasets, collate_fn=conll_datasets.collate_fn, batch_size=8)
|
||||
for idx, data in enumerate(conll_dataloader):
|
||||
print(data)
|
||||
break
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
atis_mapping = {
|
||||
'depart_date.day_number':'<<depart date day number>>',
|
||||
'arrive_date.day_name':'<<arrive date day name>>',
|
||||
'airline_name':'<<airline name>>',
|
||||
'depart_date.year':'<<depart date year>>',
|
||||
'flight_mod':'<<flight mod>>',
|
||||
'return_date.day_name':'<<return date day name>>',
|
||||
'toloc.city_name':'<<toloc city name>>',
|
||||
'return_date.day_number':'<<return date day number>>',
|
||||
'time_relative':'<<time relative>>',
|
||||
'city_name':'<<city name>>',
|
||||
'state_code':'<<state code>>',
|
||||
'transport_type':'<<transport type>>',
|
||||
'class_type':'<<class type>>',
|
||||
'days_code':'<<days code>>',
|
||||
'toloc.country_name':'<<toloc country name>>',
|
||||
'arrive_date.today_relative':'<<arrive date today relative>>',
|
||||
'round_trip':'<<round trip>>',
|
||||
'toloc.state_name':'<<toloc state name>>',
|
||||
'aircraft_code':'<<aircraft code>>',
|
||||
'arrive_date.month_name':'<<arrive date month name>>',
|
||||
'depart_date.today_relative':'<<depart date today relative>>',
|
||||
'depart_time.start_time':'<<depart time start time>>',
|
||||
'compartment':'<<compartment>>',
|
||||
'day_number':'<<day number>>',
|
||||
'depart_date.date_relative':'<<depart date date relative>>',
|
||||
'arrive_date.day_number':'<<arrive date day number>>',
|
||||
'depart_time.time':'<<depart time time>>',
|
||||
'fare_amount':'<<fare amount>>',
|
||||
'depart_date.month_name':'<<depart date month name>>',
|
||||
'period_of_day':'<<period of day>>',
|
||||
'cost_relative':'<<cost relative>>',
|
||||
'fromloc.airport_name':'<<fromloc airport name>>',
|
||||
'fare_basis_code':'<<fare basis code>>',
|
||||
'arrive_time.start_time':'<<arrive time start time>>',
|
||||
'stoploc.airport_name':'<<stoploc airport name>>',
|
||||
'time':'<<time>>',
|
||||
'depart_time.time_relative':'<<depart time time relative>>',
|
||||
'return_time.period_of_day':'<<return time period of day>>',
|
||||
'depart_time.period_of_day':'<<depart time period of day>>',
|
||||
'economy':'<<economy>>',
|
||||
'mod':'<<mod>>',
|
||||
'stoploc.airport_code':'<<stoploc airport code>>',
|
||||
'stoploc.state_code':'<<stoploc state code>>',
|
||||
'arrive_time.end_time':'<<arrive time end time>>',
|
||||
'state_name':'<<state name>>',
|
||||
'airport_name':'<<airport name>>',
|
||||
'depart_date.day_name':'<<depart date day name>>',
|
||||
'fromloc.state_name':'<<fromloc state name>>',
|
||||
'arrive_time.time_relative':'<<arrive time time relative>>',
|
||||
'today_relative':'<<today relative>>',
|
||||
'day_name':'<<day name>>',
|
||||
'flight_stop':'<<flight stop>>',
|
||||
'month_name':'<<month name>>',
|
||||
'fromloc.city_name':'<<fromloc city name>>',
|
||||
'meal':'<<meal>>',
|
||||
'arrive_time.period_of_day':'<<arrive time period of day>>',
|
||||
'return_time.period_mod':'<<return time period mod>>',
|
||||
'toloc.airport_code':'<<toloc airport code>>',
|
||||
'airport_code':'<<airport code>>',
|
||||
'restriction_code':'<<restriction code>>',
|
||||
'flight_time':'<<flight time>>',
|
||||
'airline_code':'<<airline code>>',
|
||||
'depart_time.end_time':'<<depart time end time>>',
|
||||
'flight_days':'<<flight days>>',
|
||||
'booking_class':'<<booking class>>',
|
||||
'flight_number':'<<flight number>>',
|
||||
'or':'<<or>>',
|
||||
'fromloc.airport_code':'<<fromloc airport code>>',
|
||||
'meal_description':'<<meal description>>',
|
||||
'return_date.date_relative':'<<return date date relative>>',
|
||||
'return_date.month_name':'<<return date month name>>',
|
||||
'arrive_date.date_relative':'<<arrive date date relative>>',
|
||||
'return_date.today_relative':'<<return date today relative>>',
|
||||
'arrive_time.period_mod':'<<arrive time period mod>>',
|
||||
'depart_time.period_mod':'<<depart time period mod>>',
|
||||
'meal_code':'<<meal code>>',
|
||||
'flight':'<<flight>>',
|
||||
'toloc.airport_name':'<<toloc airport name>>',
|
||||
'stoploc.city_name':'<<stoploc city name>>',
|
||||
'connect':'<<connect>>',
|
||||
'arrive_time.time':'<<arrive time time>>',
|
||||
'toloc.state_code':'<<toloc state code>>',
|
||||
'fromloc.state_code':'<<fromloc state code>>'
|
||||
}
|
||||
|
||||
mit_movie_mapping = {
|
||||
'genre': '<<genre>>',
|
||||
'actor': '<<actor>>',
|
||||
'year': '<<year>>',
|
||||
'title': '<<title>>',
|
||||
'rating': '<<rating>>',
|
||||
'ratings_average': '<<ratings average>>',
|
||||
'director': '<<director>>',
|
||||
'plot': '<<plot>>',
|
||||
'character': '<<character>>',
|
||||
'song': '<<song>>',
|
||||
'review': '<<review>>',
|
||||
'trailer': '<<trailer>>'
|
||||
}
|
||||
|
||||
mit_restaurant_mapping = {
|
||||
'location': '<<location>>',
|
||||
'cuisine': '<<cuisine>>',
|
||||
'amenity': '<<amenity>>',
|
||||
'restaurant_name': '<<restaurant name>>',
|
||||
'rating': '<<rating>>',
|
||||
'dish': '<<dish>>',
|
||||
'hours': '<<hours>>',
|
||||
'price': '<<price>>'
|
||||
}
|
|
@ -0,0 +1,110 @@
|
|||
import numpy as np
|
||||
|
||||
|
||||
class Seq2SeqSpanMetric(object):
|
||||
def __init__(self, eos_token_id, num_labels, target_type='word'):
|
||||
self.eos_token_id = eos_token_id
|
||||
self.num_labels = num_labels
|
||||
self.word_start_index = num_labels+2
|
||||
|
||||
self.fp = 0
|
||||
self.tp = 0
|
||||
self.fn = 0
|
||||
self.em = 0
|
||||
self.total = 0
|
||||
self.target_type = target_type
|
||||
|
||||
def evaluate(self, target_span, pred, tgt_tokens):
|
||||
self.total += pred.size(0)
|
||||
pred_eos_index = pred.flip(dims=[1]).eq(self.eos_token_id).cumsum(dim=1).long()
|
||||
target_eos_index = tgt_tokens.flip(dims=[1]).eq(self.eos_token_id).cumsum(dim=1).long()
|
||||
|
||||
pred = pred[:, 1:]
|
||||
tgt_tokens = tgt_tokens[:, 1:]
|
||||
pred_seq_len = pred_eos_index.flip(dims=[1]).eq(pred_eos_index[:, -1:]).sum(dim=1) # bsz
|
||||
pred_seq_len = (pred_seq_len - 2).tolist()
|
||||
target_seq_len = target_eos_index.flip(dims=[1]).eq(target_eos_index[:, -1:]).sum(dim=1) # bsz
|
||||
target_seq_len = (target_seq_len-2).tolist()
|
||||
pred_spans = []
|
||||
for i, (ts, ps) in enumerate(zip(target_span, pred.tolist())):
|
||||
em = 0
|
||||
ps = ps[:pred_seq_len[i]]
|
||||
if pred_seq_len[i]==target_seq_len[i]:
|
||||
em = int(tgt_tokens[i, :target_seq_len[i]].eq(pred[i, :target_seq_len[i]]).sum().item()==target_seq_len[i])
|
||||
self.em += em
|
||||
pairs = []
|
||||
cur_pair = []
|
||||
if len(ps):
|
||||
for j in ps:
|
||||
if j<self.word_start_index:
|
||||
if self.target_type == 'span':
|
||||
if len(cur_pair)>0 and len(cur_pair)%2==0:
|
||||
if all([cur_pair[i]<=cur_pair[i+1] for i in range(len(cur_pair)-1)]):
|
||||
pairs.append(tuple(cur_pair+[j]))
|
||||
else:
|
||||
if len(cur_pair) > 0:
|
||||
if all([cur_pair[i]<cur_pair[i+1] for i in range(len(cur_pair)-1)]):
|
||||
pairs.append(tuple(cur_pair + [j]))
|
||||
cur_pair = []
|
||||
else:
|
||||
cur_pair.append(j)
|
||||
pred_spans.append(pairs.copy())
|
||||
|
||||
tp, fn, fp = _compute_tp_fn_fp(pairs, ts)
|
||||
self.fn += fn
|
||||
self.tp += tp
|
||||
self.fp += fp
|
||||
|
||||
def get_metric(self, reset=True):
|
||||
res = {}
|
||||
f, pre, rec = _compute_f_pre_rec(1, self.tp, self.fn, self.fp)
|
||||
res['f'] = round(f, 4)*100
|
||||
res['rec'] = round(rec, 4)*100
|
||||
res['pre'] = round(pre, 4)*100
|
||||
res['em'] = round(self.em/self.total, 4)
|
||||
if reset:
|
||||
self.total = 0
|
||||
self.fp = 0
|
||||
self.tp = 0
|
||||
self.fn = 0
|
||||
self.em = 0
|
||||
return res
|
||||
|
||||
def _compute_f_pre_rec(beta_square, tp, fn, fp):
|
||||
r"""
|
||||
|
||||
:param tp: int, true positive
|
||||
:param fn: int, false negative
|
||||
:param fp: int, false positive
|
||||
:return: (f, pre, rec)
|
||||
"""
|
||||
pre = tp / (fp + tp + 1e-13)
|
||||
rec = tp / (fn + tp + 1e-13)
|
||||
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
|
||||
|
||||
return f, pre, rec
|
||||
|
||||
|
||||
def _compute_tp_fn_fp(ps, ts):
|
||||
ps = ps.copy()
|
||||
tp = 0
|
||||
fp = 0
|
||||
fn = 0
|
||||
if isinstance(ts, (set, list, np.ndarray)):
|
||||
ts = {tuple(key):1 for key in list(ts)}
|
||||
if isinstance(ps, (set, list, np.ndarray)):
|
||||
ps = {tuple(key):1 for key in list(ps)}
|
||||
|
||||
for key in ts.keys():
|
||||
t_num = ts[key]
|
||||
if key not in ps:
|
||||
p_num = 0
|
||||
else:
|
||||
p_num = ps[key]
|
||||
tp += min(p_num, t_num)
|
||||
fp += max(p_num - t_num, 0)
|
||||
fn += max(t_num - p_num, 0)
|
||||
if key in ps:
|
||||
ps.pop(key)
|
||||
fp += sum(ps.values())
|
||||
return tp, fn, fp
|
|
@ -0,0 +1,193 @@
|
|||
import torch
|
||||
from torch import optim
|
||||
from tqdm import tqdm
|
||||
from ..utils.utils import convert_preds_to_outputs, write_predictions
|
||||
import random
|
||||
|
||||
class Trainer(object):
|
||||
def __init__(self, train_data=None, dev_data=None, test_data=None, model=None, process=None, args=None, logger=None, loss=None, metrics=None, writer=None) -> None:
|
||||
self.train_data = train_data
|
||||
self.dev_data = dev_data
|
||||
self.test_data = test_data
|
||||
self.model = model
|
||||
self.process = process
|
||||
self.logger = logger
|
||||
self.metrics = metrics
|
||||
self.writer = writer
|
||||
self.loss = loss
|
||||
self.num_epochs = args.num_epochs
|
||||
self.batch_size = args.batch_size
|
||||
self.lr = args.learning_rate
|
||||
self.eval_begin_epoch = args.eval_begin_epoch
|
||||
self.device = args.device
|
||||
self.load_path = args.load_path
|
||||
self.save_path = args.save_path
|
||||
self.refresh_step = 1
|
||||
self.best_metric = 0
|
||||
self.best_dev_epoch = None
|
||||
self.optimizer = None
|
||||
if self.train_data is not None:
|
||||
self.train_num_steps = len(self.train_data) * self.num_epochs
|
||||
self.step = 0
|
||||
self.args = args
|
||||
|
||||
|
||||
def train(self):
|
||||
self.before_train() # something should do before training
|
||||
self.step = 0
|
||||
self.model.train()
|
||||
self.logger.info("***** Running training *****")
|
||||
self.logger.info(" Num instance = %d", len(self.train_data)*self.batch_size)
|
||||
self.logger.info(" Num epoch = %d", self.num_epochs)
|
||||
self.logger.info(" Batch size = %d", self.batch_size)
|
||||
self.logger.info(" Learning rate = {}".format(self.lr))
|
||||
self.logger.info(" Evaluate begin = %d", self.eval_begin_epoch)
|
||||
|
||||
if self.load_path is not None: # load model from load_path
|
||||
self.logger.info("Loading model from {}".format(self.load_path))
|
||||
self.model.load_state_dict(torch.load(self.load_path))
|
||||
self.logger.info("Load model successful!")
|
||||
|
||||
with tqdm(total=self.train_num_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True, initial=self.step) as pbar:
|
||||
self.pbar = pbar
|
||||
avg_loss = 0
|
||||
for epoch in range(self.num_epochs):
|
||||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.num_epochs))
|
||||
for batch in self.train_data:
|
||||
self.step += 1
|
||||
batch = (tup.to(self.device) if isinstance(tup, torch.Tensor) else tup for tup in batch)
|
||||
loss = self._step(batch, mode="train")
|
||||
avg_loss += loss.item()
|
||||
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
if self.step % self.refresh_step == 0:
|
||||
avg_loss = float(avg_loss) / self.refresh_step
|
||||
print_output = "loss:{:<6.5f}".format(avg_loss)
|
||||
pbar.update(1)
|
||||
pbar.set_postfix_str(print_output)
|
||||
self.writer.add_scalar(tag='loss', scalar_value=avg_loss, global_step=self.step) # tensorbordx
|
||||
avg_loss = 0
|
||||
if epoch >= self.eval_begin_epoch:
|
||||
self.evaluate(epoch) # generator to dev.
|
||||
pbar.close()
|
||||
self.pbar = None
|
||||
self.logger.info("Get best performance at epoch {}, best f1 score is {:.2f}".format(self.best_dev_epoch, self.best_metric))
|
||||
|
||||
|
||||
|
||||
def evaluate(self, epoch):
|
||||
self.model.eval()
|
||||
self.logger.info("***** Running evaluate *****")
|
||||
self.logger.info(" Num instance = %d", len(self.dev_data)*self.batch_size)
|
||||
self.logger.info(" Batch size = %d", self.batch_size)
|
||||
with torch.no_grad():
|
||||
with tqdm(total=len(self.dev_data), leave=False, dynamic_ncols=True) as pbar:
|
||||
pbar.set_description_str(desc="Dev")
|
||||
for batch in self.dev_data:
|
||||
batch = (tup.to(self.device) if isinstance(tup, torch.Tensor) else tup for tup in batch) # to cpu/cuda device
|
||||
self._step(batch, mode="dev")
|
||||
pbar.update()
|
||||
# evaluate done
|
||||
eva_result = self.metrics.get_metric()
|
||||
pbar.close()
|
||||
self.logger.info("Epoch {}/{}, best f1: {}, current f1 score: {:.2f}, recall: {:.2f}, precision: {:.2f}."\
|
||||
.format(epoch, self.num_epochs, self.best_metric, eva_result['f'], eva_result['rec'], eva_result['pre']))
|
||||
self.writer.add_scalars('evaluate', {'f1': eva_result['f'],
|
||||
'recall': eva_result['rec'],
|
||||
'precision': eva_result['pre']}, epoch)
|
||||
if eva_result['f'] >= self.best_metric: # this epoch get best performance
|
||||
self.logger.info("Get better performance at epoch {}".format(epoch))
|
||||
self.best_dev_epoch = epoch
|
||||
self.best_metric = eva_result['f'] # update best metric(f1 score)
|
||||
if self.save_path is not None: # need to save model
|
||||
torch.save(self.model.state_dict(), self.save_path+"/best_model.pth")
|
||||
self.logger.info("Save best model at {}".format(self.save_path))
|
||||
|
||||
self.model.train()
|
||||
|
||||
def predict(self):
    assert self.load_path is not None and self.test_data is not None

    self.model.eval()
    self.logger.info("***** Running testing *****")
    self.logger.info(" Num instance = %d", len(self.test_data)*self.batch_size)
    self.logger.info(" Batch size = %d", self.batch_size)
    if self.load_path is not None:  # load model from load_path
        self.logger.info("Loading model from {}".format(self.load_path))
        self.model.load_state_dict(torch.load(self.load_path))
        self.logger.info("Model loaded successfully!")
        self.model.to(self.device)

    with torch.no_grad():
        with tqdm(total=len(self.test_data), leave=False, dynamic_ncols=True) as pbar:
            pbar.set_description_str(desc="Test")
            texts = []
            labels = []
            for batch in self.test_data:
                batch = (tup.to(self.device) if isinstance(tup, torch.Tensor) else tup for tup in batch)  # to cpu/cuda device
                src_tokens, src_seq_len, first, raw_words = batch
                preds = self._step((src_tokens, src_seq_len, first), mode="test")
                outputs = convert_preds_to_outputs(preds, raw_words, self.process.mapping, self.process.tokenizer)
                texts.extend(raw_words)
                labels.extend(outputs)
                pbar.update()

        self.logger.info("***** Predict example *****")
        idx = random.randint(0, len(texts) - 1)  # randint is inclusive at both ends
        print(len(texts), len(labels))
        self.logger.info("Raw texts: " + " ".join(texts[idx]))
        self.logger.info("Prediction: " + " ".join(labels[idx]))
        if self.args.write_path is not None:  # write predictions to file
            write_predictions(self.args.write_path, texts, labels)
            self.logger.info("Write into {}!".format(self.args.write_path))

def _step(self, batch, mode="train"):
    if mode == "dev":  # dev: compute metric
        src_tokens, tgt_tokens, src_seq_len, tgt_seq_len, first, target_span = batch
        pred = self.model.predict(src_tokens, src_seq_len, first)
        self.metrics.evaluate(target_span, pred, tgt_tokens)
        return
    elif mode == "test":  # test: just get pred
        src_tokens, src_seq_len, first = batch
        pred = self.model.predict(src_tokens, src_seq_len, first)
        return pred
    else:  # train: get loss
        src_tokens, tgt_tokens, src_seq_len, tgt_seq_len, first, target_span = batch
        pred = self.model(src_tokens, tgt_tokens, src_seq_len, first)
        loss = self.loss(tgt_tokens, tgt_seq_len, pred)
        return loss

def before_train(self):
    parameters = []
    # group 1: prompt-specific parameters (everything outside the BART encoder/decoder)
    params = {'lr': self.lr, 'weight_decay': 1e-2}
    params['params'] = [param for name, param in self.model.named_parameters() if not ('bart_encoder' in name or 'bart_decoder' in name)]
    parameters.append(params)

    # group 2: BART parameters with weight decay (all except LayerNorm)
    params = {'lr': self.lr, 'weight_decay': 1e-2}
    params['params'] = []
    for name, param in self.model.named_parameters():
        if ('bart_encoder' in name or 'bart_decoder' in name) and not ('layernorm' in name or 'layer_norm' in name):
            params['params'].append(param)
    parameters.append(params)

    # group 3: BART LayerNorm parameters without weight decay
    params = {'lr': self.lr, 'weight_decay': 0}
    params['params'] = []
    for name, param in self.model.named_parameters():
        if ('bart_encoder' in name or 'bart_decoder' in name) and ('layernorm' in name or 'layer_norm' in name):
            params['params'].append(param)
    parameters.append(params)

    self.optimizer = optim.AdamW(parameters)

    if self.args.freeze_plm:  # freeze the pretrained language model (BART), keep the prompt MLP trainable
        for name, par in self.model.named_parameters():
            if ('prompt_encoder' in name or 'prompt_decoder' in name) and "bart_mlp" not in name:
                par.requires_grad = False

    self.model.to(self.device)

@ -0,0 +1,157 @@
import torch
import numpy as np
import random
from torch import nn
import torch.nn.functional as F
from transformers import BartModel, BartTokenizer

def avg_token_embeddings(tokenizer: BartTokenizer, bart_model: BartModel, bart_name, num_tokens):
    """Initialise newly added label tokens with the average of their sub-token embeddings.

    Args:
        tokenizer (BartTokenizer): tokenizer that already contains the added ``<<label>>`` tokens
        bart_model (BartModel): model whose embedding rows will be (re-)initialised
        bart_name (str): name/path of the original BART checkpoint used to re-tokenize the label words
        num_tokens (int): size of the original vocabulary; added tokens must have ids >= num_tokens

    Raises:
        RuntimeError: if an added token is split into more than one id by the tokenizer

    Returns:
        BartModel: the model with updated embeddings for the added tokens
    """
    _tokenizer = BartTokenizer.from_pretrained(bart_name)
    for token in tokenizer.unique_no_split_tokens:
        if token[:2] == '<<':  # special label token
            index = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(token))
            if len(index) > 1:
                raise RuntimeError(f"{token} wrong split")
            else:
                index = index[0]
            assert index >= num_tokens, (index, num_tokens, token)
            # average the embeddings of the label word's sub-tokens (e.g. "person" for "<<person>>")
            indexes = _tokenizer.convert_tokens_to_ids(_tokenizer.tokenize(token[2:-2]))
            embed = bart_model.encoder.embed_tokens.weight.data[indexes[0]].clone()  # clone so the in-place ops below do not overwrite the original row
            for i in indexes[1:]:
                embed += bart_model.decoder.embed_tokens.weight.data[i]
            embed /= len(indexes)
            bart_model.decoder.embed_tokens.weight.data[index] = embed
    return bart_model
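# A rough usage sketch (names below are illustrative, not part of this module): after adding
# label tokens such as "<<person>>" to the tokenizer and resizing the model's embeddings,
# re-initialise those rows from the sub-token embeddings of the underlying word, e.g.
#
#   tokenizer.add_tokens(['<<person>>', '<<location>>'])
#   bart_model.resize_token_embeddings(len(tokenizer))
#   bart_model = avg_token_embeddings(tokenizer, bart_model, 'facebook/bart-large',
#                                     num_tokens=original_vocab_size)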

def seq_to_mask(seq_len, max_len):
    """Build a boolean attention mask from sequence lengths.

    Args:
        seq_len (torch.Tensor): shape (bsz,), the length of each sequence in the batch
        max_len (int): maximum sequence length; if falsy, the batch maximum is used

    Returns:
        torch.Tensor: bool mask of shape (bsz, max_len), True for valid positions
    """
    max_len = int(max_len) if max_len else seq_len.max().long()
    cast_seq = torch.arange(max_len).expand(seq_len.size(0), -1).to(seq_len)
    mask = cast_seq.lt(seq_len.unsqueeze(1))
    return mask
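# Minimal example of the mask layout (a sketch, values written out by hand):
#   seq_to_mask(torch.tensor([2, 4]), max_len=4)
#   -> tensor([[True, True, False, False],
#              [True, True, True,  True ]])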

def get_loss(tgt_tokens, tgt_seq_len, pred):
    """Token-level cross-entropy between the shifted targets and the predictions.

    :param tgt_tokens: bsz x max_len, target ids including [sos, tokens..., eos]
    :param tgt_seq_len: bsz, length of each target sequence
    :param pred: bsz x (max_len-1) x vocab_size, scores for every position after sos
    :return: scalar loss
    """
    tgt_seq_len = tgt_seq_len - 1
    # positions beyond each sequence's length are filled with -100 and ignored by cross_entropy
    mask = seq_to_mask(tgt_seq_len, max_len=tgt_tokens.size(1) - 1).eq(0)
    tgt_tokens = tgt_tokens[:, 1:].masked_fill(mask, -100)
    loss = F.cross_entropy(target=tgt_tokens, input=pred.transpose(1, 2))
    return loss
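# Shape sketch (illustrative numbers only): with a batch of 2 targets padded to length 6,
#   tgt_tokens: (2, 6), tgt_seq_len: tensor([6, 4]), pred: (2, 5, vocab_size)
# the leading sos column of tgt_tokens is dropped, padded positions become -100, and
# F.cross_entropy skips them because its default ignore_index is -100.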

def _get_model_device(model):
    """Return the device of the model's first parameter, or None if it has no parameters."""
    assert isinstance(model, nn.Module)

    parameters = list(model.parameters())
    if len(parameters) == 0:
        return None
    else:
        return parameters[0].device

def set_seed(seed=2021):
    """sets random seed"""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    np.random.seed(seed)
    random.seed(seed)


def convert_preds_to_outputs(preds, raw_words, mapping, tokenizer):
    """Convert model predictions to BIO outputs.

    Args:
        preds (torch.Tensor): prompt model predictions, predicted index sequences of shape (bsz x seq_len)
        raw_words (List): source raw words
        mapping (dict): maps entity labels to their <<label>> tokens
        tokenizer (BartTokenizer): tokenizer used to compute sub-token offsets

    Returns:
        List: one BIO tag sequence per sentence, each with the same length as its raw words
    """
    id2label = list(mapping.keys())
    pred_eos_index = preds.flip(dims=[1]).eq(1).cumsum(dim=1).long()
    preds = preds[:, 1:]
    pred_seq_len = pred_eos_index.flip(dims=[1]).eq(pred_eos_index[:, -1:]).sum(dim=1)  # bsz
    pred_seq_len = (pred_seq_len - 2).tolist()

    word_start_index = len(mapping) + 2
    outputs = []
    for i, pred_item in enumerate(preds.tolist()):
        pred_item = pred_item[:pred_seq_len[i]]  # single sentence prediction
        pairs, cur_pair = [], []
        if len(pred_item):  # this sentence's prediction is not empty
            for idx in pred_item:
                if idx < word_start_index:  # is an entity label
                    if len(cur_pair) > 0:
                        # keep the span only if the word positions are strictly increasing
                        if all([cur_pair[i] < cur_pair[i + 1] for i in range(len(cur_pair) - 1)]):
                            pairs.append(tuple(cur_pair + [idx]))  # add valid words and current entity id
                    cur_pair = []  # clear word pairs
                else:  # is a word
                    cur_pair.append(idx)  # add word id to word pairs
        raw_words_item = raw_words[i]
        cum_lens = [1]
        start_idx = 1
        for word in raw_words_item:
            start_idx += len(tokenizer.tokenize(word, add_prefix_space=True))
            cum_lens.append(start_idx)
        cum_lens.append(start_idx + 1)
        output = ['O' for _ in range(len(raw_words_item))]
        # pairs: List[(word id, ..., entity id), ...]
        for pair in pairs:  # (word id, ..., entity id)
            entity = pair[-1]
            words = []
            for word in pair[:-1]:
                if word - word_start_index in cum_lens:
                    words.append(cum_lens.index(word - word_start_index))
            if len(words) == 0:
                continue
            start_idx = words[0]
            end_idx = words[-1]
            output[start_idx] = f'B-{id2label[entity-2]}'
            for _ in range(start_idx + 1, end_idx + 1):
                output[_] = f'I-{id2label[entity-2]}'
        outputs.append(output)
    return outputs
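# Index layout assumed above (an illustrative sketch): ids 0/1 are reserved special positions
# (bos/eos in the target index space), ids 2 .. len(mapping)+1 are the entity labels in
# `mapping` order, and ids from word_start_index = len(mapping) + 2 onward point at source
# sub-token positions. With two labels, word_start_index is 4, so a predicted pair (5, 2)
# marks the first raw word (cumulative sub-token offset 1) as 'B-<first label>'.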


def write_predictions(path, texts, labels):
    """Write model predictions to `path` in CoNLL format.

    Args:
        path (str): save path
        texts (List): raw texts
        labels (List): predicted label sequences
    """
    print(len(texts), len(labels))
    assert len(texts) == len(labels)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines("-DOCSTART- O\n\n")
        for i in range(len(texts)):
            for j in range(len(texts[i])):
                f.writelines("{}\t{}\n".format(texts[i][j], labels[i][j]))
            f.writelines("\n")
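# Resulting file layout (illustrative): one "token<TAB>label" line per token and a blank
# line between sentences, e.g.
#   -DOCSTART- O
#
#   EU        B-ORG
#   rejects   O
#   German    B-MISC
#   ...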