diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py index 54da7017..639265ed 100644 --- a/ppocr/losses/combined_loss.py +++ b/ppocr/losses/combined_loss.py @@ -44,15 +44,17 @@ class CombinedLoss(nn.Layer): def forward(self, input, batch, **kargs): loss_dict = {} + loss_all = 0. for idx, loss_func in enumerate(self.loss_func): loss = loss_func(input, batch, **kargs) if isinstance(loss, paddle.Tensor): loss = {"loss_{}_{}".format(str(loss), idx): loss} weight = self.loss_weight[idx] - loss = { - "{}_{}".format(key, idx): loss[key] * weight - for key in loss - } + for key in loss: + if key == "loss": + loss_all += loss[key] * weight + # else: + # loss[f"{key}_{idx}"] = loss[key] loss_dict.update(loss) - loss_dict["loss"] = paddle.add_n(list(loss_dict.values())) + loss_dict["loss"] = loss_all return loss_dict