fix bug
This commit is contained in:
parent
185d1e1f92
commit
6ce4419819
|
@ -88,7 +88,7 @@ Loss:
|
||||||
- DistillationDMLLoss:
|
- DistillationDMLLoss:
|
||||||
model_name_pairs:
|
model_name_pairs:
|
||||||
- ["Student", "Student2"]
|
- ["Student", "Student2"]
|
||||||
maps_name: ["thrink_maps"]
|
maps_name: "thrink_maps"
|
||||||
weight: 1.0
|
weight: 1.0
|
||||||
act: "softmax"
|
act: "softmax"
|
||||||
model_name_pairs: ["Student", "Student2"]
|
model_name_pairs: ["Student", "Student2"]
|
||||||
|
@ -96,7 +96,7 @@ Loss:
|
||||||
- DistillationDBLoss:
|
- DistillationDBLoss:
|
||||||
weight: 1.0
|
weight: 1.0
|
||||||
model_name_list: ["Student", "Student2"]
|
model_name_list: ["Student", "Student2"]
|
||||||
key: maps
|
# key: maps
|
||||||
name: DBLoss
|
name: DBLoss
|
||||||
balance_loss: true
|
balance_loss: true
|
||||||
main_loss_type: DiceLoss
|
main_loss_type: DiceLoss
|
||||||
|
|
|
@ -50,11 +50,11 @@ class CombinedLoss(nn.Layer):
|
||||||
if isinstance(loss, paddle.Tensor):
|
if isinstance(loss, paddle.Tensor):
|
||||||
loss = {"loss_{}_{}".format(str(loss), idx): loss}
|
loss = {"loss_{}_{}".format(str(loss), idx): loss}
|
||||||
weight = self.loss_weight[idx]
|
weight = self.loss_weight[idx]
|
||||||
for key in loss:
|
for key in loss.keys():
|
||||||
if key == "loss":
|
if key == "loss":
|
||||||
loss_all += loss[key] * weight
|
loss_all += loss[key] * weight
|
||||||
else:
|
else:
|
||||||
loss["{}_{}".format(key, idx)] = loss[key]
|
loss_dict["{}_{}".format(key, idx)] = loss[key]
|
||||||
# loss[f"{key}_{idx}"] = loss[key]
|
# loss[f"{key}_{idx}"] = loss[key]
|
||||||
loss_dict.update(loss)
|
loss_dict.update(loss)
|
||||||
loss_dict["loss"] = loss_all
|
loss_dict["loss"] = loss_all
|
||||||
|
|
|
@ -24,7 +24,6 @@ from .det_db_loss import DBLoss
|
||||||
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
|
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _sum_loss(loss_dict):
|
def _sum_loss(loss_dict):
|
||||||
if "loss" in loss_dict.keys():
|
if "loss" in loss_dict.keys():
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
@ -51,10 +50,18 @@ class DistillationDMLLoss(DMLLoss):
|
||||||
super().__init__(act=act)
|
super().__init__(act=act)
|
||||||
assert isinstance(model_name_pairs, list)
|
assert isinstance(model_name_pairs, list)
|
||||||
self.key = key
|
self.key = key
|
||||||
self.model_name_pairs = model_name_pairs
|
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
|
||||||
self.name = name
|
self.name = name
|
||||||
self.maps_name = maps_name
|
self.maps_name = maps_name
|
||||||
|
|
||||||
|
def _check_model_name_pairs(self, model_name_pairs):
|
||||||
|
if not isinstance(model_name_pairs, list):
|
||||||
|
return []
|
||||||
|
elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str):
|
||||||
|
return model_name_pairs
|
||||||
|
else:
|
||||||
|
return [model_name_pairs]
|
||||||
|
|
||||||
def _check_maps_name(self, maps_name):
|
def _check_maps_name(self, maps_name):
|
||||||
if maps_name is None:
|
if maps_name is None:
|
||||||
return None
|
return None
|
||||||
|
@ -69,13 +76,14 @@ class DistillationDMLLoss(DMLLoss):
|
||||||
new_outs = {}
|
new_outs = {}
|
||||||
for k in self.maps_name:
|
for k in self.maps_name:
|
||||||
if k == "thrink_maps":
|
if k == "thrink_maps":
|
||||||
new_outs[k] = paddle.slice(outs, axes=1, starts=0, ends=1)
|
new_outs[k] = paddle.slice(outs, axes=[1], starts=[0], ends=[1])
|
||||||
elif k == "threshold_maps":
|
elif k == "threshold_maps":
|
||||||
new_outs[k] = paddle.slice(outs, axes=1, starts=1, ends=2)
|
new_outs[k] = paddle.slice(outs, axes=[1], starts=[1], ends=[2])
|
||||||
elif k == "binary_maps":
|
elif k == "binary_maps":
|
||||||
new_outs[k] = paddle.slice(outs, axes=1, starts=2, ends=3)
|
new_outs[k] = paddle.slice(outs, axes=[1], starts=[2], ends=[3])
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
return new_outs
|
||||||
|
|
||||||
def forward(self, predicts, batch):
|
def forward(self, predicts, batch):
|
||||||
loss_dict = dict()
|
loss_dict = dict()
|
||||||
|
@ -104,7 +112,7 @@ class DistillationDMLLoss(DMLLoss):
|
||||||
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
||||||
0], pair[1], map_name, idx)] = loss[key]
|
0], pair[1], map_name, idx)] = loss[key]
|
||||||
else:
|
else:
|
||||||
loss_dict["{}_{}_{}".format(self.name, map_name,
|
loss_dict["{}_{}_{}".format(self.name, self.maps_name,
|
||||||
idx)] = loss
|
idx)] = loss
|
||||||
|
|
||||||
loss_dict = _sum_loss(loss_dict)
|
loss_dict = _sum_loss(loss_dict)
|
||||||
|
@ -151,7 +159,7 @@ class DistillationDBLoss(DBLoss):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.key = None
|
self.key = None
|
||||||
|
|
||||||
def forward(self, preicts, batch):
|
def forward(self, predicts, batch):
|
||||||
loss_dict = {}
|
loss_dict = {}
|
||||||
for idx, model_name in enumerate(self.model_name_list):
|
for idx, model_name in enumerate(self.model_name_list):
|
||||||
out = predicts[model_name]
|
out = predicts[model_name]
|
||||||
|
|
|
@ -34,7 +34,8 @@ def build_post_process(config, global_config=None):
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||||
'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess'
|
'DistillationCTCLabelDecode', 'TableLabelDecode',
|
||||||
|
'DistillationDBPostProcess'
|
||||||
]
|
]
|
||||||
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|
|
@ -200,11 +200,8 @@ class DistillationDBPostProcess(DBPostProcess):
|
||||||
use_dilation=False,
|
use_dilation=False,
|
||||||
score_mode="fast",
|
score_mode="fast",
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(DistillationDBPostProcess, self).__init__(thresh,
|
super(DistillationDBPostProcess, self).__init__(
|
||||||
box_thresh,
|
thresh, box_thresh, max_candidates, unclip_ratio, use_dilation,
|
||||||
max_candidates,
|
|
||||||
unclip_ratio,
|
|
||||||
use_dilation,
|
|
||||||
score_mode)
|
score_mode)
|
||||||
if not isinstance(model_name, list):
|
if not isinstance(model_name, list):
|
||||||
model_name = [model_name]
|
model_name = [model_name]
|
||||||
|
@ -221,10 +218,3 @@ class DistillationDBPostProcess(DBPostProcess):
|
||||||
results[name] = super().__call__(pred, shape_list=label)
|
results[name] = super().__call__(pred, shape_list=label)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue