This commit is contained in:
tlk-dsg 2021-12-20 15:15:16 +08:00
parent 291d937f0e
commit 3fde66353d
1 changed files with 4 additions and 4 deletions

View File

@ -30,10 +30,10 @@ logger = logging.getLogger(__name__)
class TrainNer(BertForTokenClassification):
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,valid_ids=None,attention_mask_label=None):
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,valid_ids=None,attention_mask_label=None,device=None):
sequence_output = self.bert(input_ids, token_type_ids, attention_mask,head_mask=None)[0]
batch_size,max_len,feat_dim = sequence_output.shape
valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device=1) #device=cfg.gpu_id if use_gpu else 'cpu'
valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device=device) #device=cfg.gpu_id if use_gpu else 'cpu'
for i in range(batch_size):
jj = -1
for j in range(max_len):
@ -136,7 +136,7 @@ def main(cfg):
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, label_ids, valid_ids,l_mask = batch
loss = model(input_ids, segment_ids, input_mask, label_ids,valid_ids,l_mask)
loss = model(input_ids, segment_ids, input_mask, label_ids,valid_ids,l_mask,device)
if cfg.gradient_accumulation_steps > 1:
loss = loss / cfg.gradient_accumulation_steps
@ -202,7 +202,7 @@ def main(cfg):
l_mask = l_mask.to(device)
with torch.no_grad():
logits = model(input_ids, segment_ids, input_mask,valid_ids=valid_ids,attention_mask_label=l_mask)
logits = model(input_ids, segment_ids, input_mask,valid_ids=valid_ids,attention_mask_label=l_mask,device=device)
logits = torch.argmax(F.log_softmax(logits,dim=2),dim=2)
logits = logits.detach().cpu().numpy()