support dict output for basemodel

This commit is contained in:
littletomatodonkey 2021-06-03 05:57:31 +00:00
parent e5d3a2d880
commit ab4db2acce
6 changed files with 63 additions and 8 deletions

View File

@ -39,6 +39,7 @@ Architecture:
Student:
pretrained: null
freeze_params: false
return_all_feats: true
model_type: rec
algorithm: CRNN
Transform:
@ -57,6 +58,7 @@ Architecture:
Teacher:
pretrained: null
freeze_params: false
return_all_feats: true
model_type: rec
algorithm: CRNN
Transform:
@ -80,18 +82,26 @@ Loss:
- DistillationCTCLoss:
weight: 1.0
model_name_list: ["Student", "Teacher"]
key: null
key: head_out
- DistillationDMLLoss:
weight: 1.0
act: "softmax"
model_name_pairs:
- ["Student", "Teacher"]
key: null
key: head_out
- DistillationDistanceLoss:
weight: 1.0
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: backbone_out
PostProcess:
name: DistillationCTCLabelDecode
model_name: "Student"
key_out: null
key: head_out
Metric:
name: RecMetric

View File

@ -97,6 +97,7 @@ class DistanceLoss(nn.Layer):
"""
def __init__(self, mode="l2", name="loss_dist", **kargs):
super().__init__()
assert mode in ["l1", "l2", "smooth_l1"]
if mode == "l1":
self.loss_func = nn.L1Loss(**kargs)

View File

@ -17,6 +17,7 @@ import paddle.nn as nn
from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss
class CombinedLoss(nn.Layer):

View File

@ -17,6 +17,7 @@ import paddle.nn as nn
from .rec_ctc_loss import CTCLoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
class DistillationDMLLoss(DMLLoss):
@ -69,3 +70,36 @@ class DistillationCTCLoss(CTCLoss):
else:
loss_dict["{}_{}".format(self.name, model_name)] = loss
return loss_dict
class DistillationDistanceLoss(DistanceLoss):
"""
"""
def __init__(self,
mode="l2",
model_name_pairs=[],
key=None,
name="loss_distance",
**kargs):
super().__init__(mode=mode, name=name)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
return loss_dict

View File

@ -67,14 +67,23 @@ class BaseModel(nn.Layer):
config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"])
self.return_all_feats = config.get("return_all_feats", False)
def forward(self, x, data=None):
y = dict()
if self.use_transform:
x = self.transform(x)
x = self.backbone(x)
y["backbone_out"] = x
if self.use_neck:
x = self.neck(x)
y["neck_out"] = x
if data is None:
x = self.head(x)
else:
x = self.head(x, data)
return x
y["head_out"] = x
if self.return_all_feats:
return y
else:
return x

View File

@ -136,17 +136,17 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
character_type='ch',
use_space_char=False,
model_name="student",
key_out=None,
key=None,
**kwargs):
super(DistillationCTCLabelDecode, self).__init__(
character_dict_path, character_type, use_space_char)
self.model_name = model_name
self.key_out = key_out
self.key = key
def __call__(self, preds, label=None, *args, **kwargs):
pred = preds[self.model_name]
if self.key_out is not None:
pred = pred[self.key_out]
if self.key is not None:
pred = pred[self.key]
return super().__call__(pred, label=label, *args, **kwargs)