fix cls type
This commit is contained in:
parent
a81b88a01a
commit
66c3294cd2
|
@ -25,6 +25,6 @@ class ClsLoss(nn.Layer):
|
|||
self.loss_func = nn.CrossEntropyLoss(reduction='mean')
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
label = batch[1]
|
||||
label = batch[1].astype("int64")
|
||||
loss = self.loss_func(input=predicts, label=label)
|
||||
return {'loss': loss}
|
||||
|
|
Loading…
Reference in New Issue