This commit is contained in:
LDOUBLEV 2021-07-07 02:45:32 +00:00
parent 185d1e1f92
commit 6ce4419819
5 changed files with 24 additions and 25 deletions

View File

@ -88,7 +88,7 @@ Loss:
- DistillationDMLLoss:
model_name_pairs:
- ["Student", "Student2"]
maps_name: ["thrink_maps"]
maps_name: "thrink_maps"
weight: 1.0
act: "softmax"
model_name_pairs: ["Student", "Student2"]
@ -96,7 +96,7 @@ Loss:
- DistillationDBLoss:
weight: 1.0
model_name_list: ["Student", "Student2"]
key: maps
# key: maps
name: DBLoss
balance_loss: true
main_loss_type: DiceLoss

View File

@ -50,11 +50,11 @@ class CombinedLoss(nn.Layer):
if isinstance(loss, paddle.Tensor):
loss = {"loss_{}_{}".format(str(loss), idx): loss}
weight = self.loss_weight[idx]
for key in loss:
for key in loss.keys():
if key == "loss":
loss_all += loss[key] * weight
else:
loss["{}_{}".format(key, idx)] = loss[key]
loss_dict["{}_{}".format(key, idx)] = loss[key]
# loss[f"{key}_{idx}"] = loss[key]
loss_dict.update(loss)
loss_dict["loss"] = loss_all

View File

@ -24,7 +24,6 @@ from .det_db_loss import DBLoss
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
def _sum_loss(loss_dict):
if "loss" in loss_dict.keys():
return loss_dict
@ -51,10 +50,18 @@ class DistillationDMLLoss(DMLLoss):
super().__init__(act=act)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.name = name
self.maps_name = maps_name
def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str):
return model_name_pairs
else:
return [model_name_pairs]
def _check_maps_name(self, maps_name):
if maps_name is None:
return None
@ -69,13 +76,14 @@ class DistillationDMLLoss(DMLLoss):
new_outs = {}
for k in self.maps_name:
if k == "thrink_maps":
new_outs[k] = paddle.slice(outs, axes=1, starts=0, ends=1)
new_outs[k] = paddle.slice(outs, axes=[1], starts=[0], ends=[1])
elif k == "threshold_maps":
new_outs[k] = paddle.slice(outs, axes=1, starts=1, ends=2)
new_outs[k] = paddle.slice(outs, axes=[1], starts=[1], ends=[2])
elif k == "binary_maps":
new_outs[k] = paddle.slice(outs, axes=1, starts=2, ends=3)
new_outs[k] = paddle.slice(outs, axes=[1], starts=[2], ends=[3])
else:
continue
return new_outs
def forward(self, predicts, batch):
loss_dict = dict()
@ -104,7 +112,7 @@ class DistillationDMLLoss(DMLLoss):
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
0], pair[1], map_name, idx)] = loss[key]
else:
loss_dict["{}_{}_{}".format(self.name, map_name,
loss_dict["{}_{}_{}".format(self.name, self.maps_name,
idx)] = loss
loss_dict = _sum_loss(loss_dict)
@ -151,7 +159,7 @@ class DistillationDBLoss(DBLoss):
self.name = name
self.key = None
def forward(self, preicts, batch):
def forward(self, predicts, batch):
loss_dict = {}
for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name]

View File

@ -34,7 +34,8 @@ def build_post_process(config, global_config=None):
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess'
'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess'
]
config = copy.deepcopy(config)

View File

@ -200,12 +200,9 @@ class DistillationDBPostProcess(DBPostProcess):
use_dilation=False,
score_mode="fast",
**kwargs):
super(DistillationDBPostProcess, self).__init__(thresh,
box_thresh,
max_candidates,
unclip_ratio,
use_dilation,
score_mode)
super(DistillationDBPostProcess, self).__init__(
thresh, box_thresh, max_candidates, unclip_ratio, use_dilation,
score_mode)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
@ -221,10 +218,3 @@ class DistillationDBPostProcess(DBPostProcess):
results[name] = super().__call__(pred, shape_list=label)
return results