fix bug
This commit is contained in:
parent
f96b4ca300
commit
75bcfb3e04
|
@ -33,7 +33,7 @@ class TrainNer(BertForTokenClassification):
|
|||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,valid_ids=None,attention_mask_label=None,device=None):
|
||||
sequence_output = self.bert(input_ids, token_type_ids, attention_mask,head_mask=None)[0]
|
||||
batch_size,max_len,feat_dim = sequence_output.shape
|
||||
valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device=device) #device=cfg.gpu_id if use_gpu else 'cpu'
|
||||
valid_output = torch.zeros(batch_size,max_len,feat_dim,dtype=torch.float32,device=device)
|
||||
for i in range(batch_size):
|
||||
jj = -1
|
||||
for j in range(max_len):
|
||||
|
|
Loading…
Reference in New Issue