This commit is contained in:
tlk-dsg 2021-12-20 15:16:53 +08:00
parent f96b4ca300
commit 75bcfb3e04
1 changed files with 1 additions and 1 deletions

View File

@ -33,7 +33,7 @@ class TrainNer(BertForTokenClassification):
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,valid_ids=None,attention_mask_label=None,device=None):
sequence_output = self.bert(input_ids, token_type_ids, attention_mask,head_mask=None)[0]
batch_size,max_len,feat_dim = sequence_output.shape
valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device=device) #device=cfg.gpu_id if use_gpu else 'cpu'
valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device=device)
for i in range(batch_size):
jj = -1
for j in range(max_len):