This commit is contained in:
LDOUBLEV 2021-07-07 02:45:32 +00:00
parent 185d1e1f92
commit 6ce4419819
5 changed files with 24 additions and 25 deletions

View File

@ -88,7 +88,7 @@ Loss:
- DistillationDMLLoss: - DistillationDMLLoss:
model_name_pairs: model_name_pairs:
- ["Student", "Student2"] - ["Student", "Student2"]
maps_name: ["thrink_maps"] maps_name: "thrink_maps"
weight: 1.0 weight: 1.0
act: "softmax" act: "softmax"
model_name_pairs: ["Student", "Student2"] model_name_pairs: ["Student", "Student2"]
@ -96,7 +96,7 @@ Loss:
- DistillationDBLoss: - DistillationDBLoss:
weight: 1.0 weight: 1.0
model_name_list: ["Student", "Student2"] model_name_list: ["Student", "Student2"]
key: maps # key: maps
name: DBLoss name: DBLoss
balance_loss: true balance_loss: true
main_loss_type: DiceLoss main_loss_type: DiceLoss

View File

@ -50,11 +50,11 @@ class CombinedLoss(nn.Layer):
if isinstance(loss, paddle.Tensor): if isinstance(loss, paddle.Tensor):
loss = {"loss_{}_{}".format(str(loss), idx): loss} loss = {"loss_{}_{}".format(str(loss), idx): loss}
weight = self.loss_weight[idx] weight = self.loss_weight[idx]
for key in loss: for key in loss.keys():
if key == "loss": if key == "loss":
loss_all += loss[key] * weight loss_all += loss[key] * weight
else: else:
loss["{}_{}".format(key, idx)] = loss[key] loss_dict["{}_{}".format(key, idx)] = loss[key]
# loss[f"{key}_{idx}"] = loss[key] # loss[f"{key}_{idx}"] = loss[key]
loss_dict.update(loss) loss_dict.update(loss)
loss_dict["loss"] = loss_all loss_dict["loss"] = loss_all

View File

@ -24,7 +24,6 @@ from .det_db_loss import DBLoss
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
def _sum_loss(loss_dict): def _sum_loss(loss_dict):
if "loss" in loss_dict.keys(): if "loss" in loss_dict.keys():
return loss_dict return loss_dict
@ -51,10 +50,18 @@ class DistillationDMLLoss(DMLLoss):
super().__init__(act=act) super().__init__(act=act)
assert isinstance(model_name_pairs, list) assert isinstance(model_name_pairs, list)
self.key = key self.key = key
self.model_name_pairs = model_name_pairs self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.name = name self.name = name
self.maps_name = maps_name self.maps_name = maps_name
def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str):
return model_name_pairs
else:
return [model_name_pairs]
def _check_maps_name(self, maps_name): def _check_maps_name(self, maps_name):
if maps_name is None: if maps_name is None:
return None return None
@ -69,13 +76,14 @@ class DistillationDMLLoss(DMLLoss):
new_outs = {} new_outs = {}
for k in self.maps_name: for k in self.maps_name:
if k == "thrink_maps": if k == "thrink_maps":
new_outs[k] = paddle.slice(outs, axes=1, starts=0, ends=1) new_outs[k] = paddle.slice(outs, axes=[1], starts=[0], ends=[1])
elif k == "threshold_maps": elif k == "threshold_maps":
new_outs[k] = paddle.slice(outs, axes=1, starts=1, ends=2) new_outs[k] = paddle.slice(outs, axes=[1], starts=[1], ends=[2])
elif k == "binary_maps": elif k == "binary_maps":
new_outs[k] = paddle.slice(outs, axes=1, starts=2, ends=3) new_outs[k] = paddle.slice(outs, axes=[1], starts=[2], ends=[3])
else: else:
continue continue
return new_outs
def forward(self, predicts, batch): def forward(self, predicts, batch):
loss_dict = dict() loss_dict = dict()
@ -104,7 +112,7 @@ class DistillationDMLLoss(DMLLoss):
loss_dict["{}_{}_{}_{}_{}".format(key, pair[ loss_dict["{}_{}_{}_{}_{}".format(key, pair[
0], pair[1], map_name, idx)] = loss[key] 0], pair[1], map_name, idx)] = loss[key]
else: else:
loss_dict["{}_{}_{}".format(self.name, map_name, loss_dict["{}_{}_{}".format(self.name, self.maps_name,
idx)] = loss idx)] = loss
loss_dict = _sum_loss(loss_dict) loss_dict = _sum_loss(loss_dict)
@ -151,7 +159,7 @@ class DistillationDBLoss(DBLoss):
self.name = name self.name = name
self.key = None self.key = None
def forward(self, preicts, batch): def forward(self, predicts, batch):
loss_dict = {} loss_dict = {}
for idx, model_name in enumerate(self.model_name_list): for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name] out = predicts[model_name]

View File

@ -34,7 +34,8 @@ def build_post_process(config, global_config=None):
support_dict = [ support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess' 'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)

View File

@ -200,11 +200,8 @@ class DistillationDBPostProcess(DBPostProcess):
use_dilation=False, use_dilation=False,
score_mode="fast", score_mode="fast",
**kwargs): **kwargs):
super(DistillationDBPostProcess, self).__init__(thresh, super(DistillationDBPostProcess, self).__init__(
box_thresh, thresh, box_thresh, max_candidates, unclip_ratio, use_dilation,
max_candidates,
unclip_ratio,
use_dilation,
score_mode) score_mode)
if not isinstance(model_name, list): if not isinstance(model_name, list):
model_name = [model_name] model_name = [model_name]
@ -221,10 +218,3 @@ class DistillationDBPostProcess(DBPostProcess):
results[name] = super().__call__(pred, shape_list=label) results[name] = super().__call__(pred, shape_list=label)
return results return results