fix bug
This commit is contained in:
parent
185d1e1f92
commit
6ce4419819
|
@ -88,7 +88,7 @@ Loss:
|
|||
- DistillationDMLLoss:
|
||||
model_name_pairs:
|
||||
- ["Student", "Student2"]
|
||||
maps_name: ["thrink_maps"]
|
||||
maps_name: "thrink_maps"
|
||||
weight: 1.0
|
||||
act: "softmax"
|
||||
model_name_pairs: ["Student", "Student2"]
|
||||
|
@ -96,7 +96,7 @@ Loss:
|
|||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Student2"]
|
||||
key: maps
|
||||
# key: maps
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
|
|
|
@ -50,11 +50,11 @@ class CombinedLoss(nn.Layer):
|
|||
if isinstance(loss, paddle.Tensor):
|
||||
loss = {"loss_{}_{}".format(str(loss), idx): loss}
|
||||
weight = self.loss_weight[idx]
|
||||
for key in loss:
|
||||
for key in loss.keys():
|
||||
if key == "loss":
|
||||
loss_all += loss[key] * weight
|
||||
else:
|
||||
loss["{}_{}".format(key, idx)] = loss[key]
|
||||
loss_dict["{}_{}".format(key, idx)] = loss[key]
|
||||
# loss[f"{key}_{idx}"] = loss[key]
|
||||
loss_dict.update(loss)
|
||||
loss_dict["loss"] = loss_all
|
||||
|
|
|
@ -24,7 +24,6 @@ from .det_db_loss import DBLoss
|
|||
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
|
||||
|
||||
|
||||
|
||||
def _sum_loss(loss_dict):
|
||||
if "loss" in loss_dict.keys():
|
||||
return loss_dict
|
||||
|
@ -51,10 +50,18 @@ class DistillationDMLLoss(DMLLoss):
|
|||
super().__init__(act=act)
|
||||
assert isinstance(model_name_pairs, list)
|
||||
self.key = key
|
||||
self.model_name_pairs = model_name_pairs
|
||||
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
|
||||
self.name = name
|
||||
self.maps_name = maps_name
|
||||
|
||||
def _check_model_name_pairs(self, model_name_pairs):
|
||||
if not isinstance(model_name_pairs, list):
|
||||
return []
|
||||
elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str):
|
||||
return model_name_pairs
|
||||
else:
|
||||
return [model_name_pairs]
|
||||
|
||||
def _check_maps_name(self, maps_name):
|
||||
if maps_name is None:
|
||||
return None
|
||||
|
@ -69,13 +76,14 @@ class DistillationDMLLoss(DMLLoss):
|
|||
new_outs = {}
|
||||
for k in self.maps_name:
|
||||
if k == "thrink_maps":
|
||||
new_outs[k] = paddle.slice(outs, axes=1, starts=0, ends=1)
|
||||
new_outs[k] = paddle.slice(outs, axes=[1], starts=[0], ends=[1])
|
||||
elif k == "threshold_maps":
|
||||
new_outs[k] = paddle.slice(outs, axes=1, starts=1, ends=2)
|
||||
new_outs[k] = paddle.slice(outs, axes=[1], starts=[1], ends=[2])
|
||||
elif k == "binary_maps":
|
||||
new_outs[k] = paddle.slice(outs, axes=1, starts=2, ends=3)
|
||||
new_outs[k] = paddle.slice(outs, axes=[1], starts=[2], ends=[3])
|
||||
else:
|
||||
continue
|
||||
return new_outs
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
|
@ -104,7 +112,7 @@ class DistillationDMLLoss(DMLLoss):
|
|||
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
||||
0], pair[1], map_name, idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}_{}".format(self.name, map_name,
|
||||
loss_dict["{}_{}_{}".format(self.name, self.maps_name,
|
||||
idx)] = loss
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
|
@ -151,7 +159,7 @@ class DistillationDBLoss(DBLoss):
|
|||
self.name = name
|
||||
self.key = None
|
||||
|
||||
def forward(self, preicts, batch):
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = {}
|
||||
for idx, model_name in enumerate(self.model_name_list):
|
||||
out = predicts[model_name]
|
||||
|
|
|
@ -34,7 +34,8 @@ def build_post_process(config, global_config=None):
|
|||
support_dict = [
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess'
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode',
|
||||
'DistillationDBPostProcess'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -200,11 +200,8 @@ class DistillationDBPostProcess(DBPostProcess):
|
|||
use_dilation=False,
|
||||
score_mode="fast",
|
||||
**kwargs):
|
||||
super(DistillationDBPostProcess, self).__init__(thresh,
|
||||
box_thresh,
|
||||
max_candidates,
|
||||
unclip_ratio,
|
||||
use_dilation,
|
||||
super(DistillationDBPostProcess, self).__init__(
|
||||
thresh, box_thresh, max_candidates, unclip_ratio, use_dilation,
|
||||
score_mode)
|
||||
if not isinstance(model_name, list):
|
||||
model_name = [model_name]
|
||||
|
@ -221,10 +218,3 @@ class DistillationDBPostProcess(DBPostProcess):
|
|||
results[name] = super().__call__(pred, shape_list=label)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue