support dict output for basemodel
This commit is contained in:
parent
e5d3a2d880
commit
ab4db2acce
|
@ -39,6 +39,7 @@ Architecture:
|
|||
Student:
|
||||
pretrained: null
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: rec
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
|
@ -57,6 +58,7 @@ Architecture:
|
|||
Teacher:
|
||||
pretrained: null
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: rec
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
|
@ -80,18 +82,26 @@ Loss:
|
|||
- DistillationCTCLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Teacher"]
|
||||
key: null
|
||||
key: head_out
|
||||
- DistillationDMLLoss:
|
||||
weight: 1.0
|
||||
act: "softmax"
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
key: null
|
||||
key: head_out
|
||||
- DistillationDistanceLoss:
|
||||
weight: 1.0
|
||||
mode: "l2"
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
key: backbone_out
|
||||
|
||||
|
||||
|
||||
PostProcess:
|
||||
name: DistillationCTCLabelDecode
|
||||
model_name: "Student"
|
||||
key_out: null
|
||||
key: head_out
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
|
|
|
@ -97,6 +97,7 @@ class DistanceLoss(nn.Layer):
|
|||
"""
|
||||
|
||||
def __init__(self, mode="l2", name="loss_dist", **kargs):
|
||||
super().__init__()
|
||||
assert mode in ["l1", "l2", "smooth_l1"]
|
||||
if mode == "l1":
|
||||
self.loss_func = nn.L1Loss(**kargs)
|
||||
|
|
|
@ -17,6 +17,7 @@ import paddle.nn as nn
|
|||
|
||||
from .distillation_loss import DistillationCTCLoss
|
||||
from .distillation_loss import DistillationDMLLoss
|
||||
from .distillation_loss import DistillationDistanceLoss
|
||||
|
||||
|
||||
class CombinedLoss(nn.Layer):
|
||||
|
|
|
@ -17,6 +17,7 @@ import paddle.nn as nn
|
|||
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
from .basic_loss import DMLLoss
|
||||
from .basic_loss import DistanceLoss
|
||||
|
||||
|
||||
class DistillationDMLLoss(DMLLoss):
|
||||
|
@ -69,3 +70,36 @@ class DistillationCTCLoss(CTCLoss):
|
|||
else:
|
||||
loss_dict["{}_{}".format(self.name, model_name)] = loss
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDistanceLoss(DistanceLoss):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
mode="l2",
|
||||
model_name_pairs=[],
|
||||
key=None,
|
||||
name="loss_distance",
|
||||
**kargs):
|
||||
super().__init__(mode=mode, name=name)
|
||||
assert isinstance(model_name_pairs, list)
|
||||
self.key = key
|
||||
self.model_name_pairs = model_name_pairs
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
for idx, pair in enumerate(self.model_name_pairs):
|
||||
out1 = predicts[pair[0]]
|
||||
out2 = predicts[pair[1]]
|
||||
if self.key is not None:
|
||||
out1 = out1[self.key]
|
||||
out2 = out2[self.key]
|
||||
loss = super().forward(out1, out2)
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
|
||||
key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, idx)] = loss
|
||||
return loss_dict
|
||||
|
|
|
@ -67,14 +67,23 @@ class BaseModel(nn.Layer):
|
|||
config["Head"]['in_channels'] = in_channels
|
||||
self.head = build_head(config["Head"])
|
||||
|
||||
self.return_all_feats = config.get("return_all_feats", False)
|
||||
|
||||
def forward(self, x, data=None):
|
||||
y = dict()
|
||||
if self.use_transform:
|
||||
x = self.transform(x)
|
||||
x = self.backbone(x)
|
||||
y["backbone_out"] = x
|
||||
if self.use_neck:
|
||||
x = self.neck(x)
|
||||
y["neck_out"] = x
|
||||
if data is None:
|
||||
x = self.head(x)
|
||||
else:
|
||||
x = self.head(x, data)
|
||||
return x
|
||||
y["head_out"] = x
|
||||
if self.return_all_feats:
|
||||
return y
|
||||
else:
|
||||
return x
|
||||
|
|
|
@ -136,17 +136,17 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
|
|||
character_type='ch',
|
||||
use_space_char=False,
|
||||
model_name="student",
|
||||
key_out=None,
|
||||
key=None,
|
||||
**kwargs):
|
||||
super(DistillationCTCLabelDecode, self).__init__(
|
||||
character_dict_path, character_type, use_space_char)
|
||||
self.model_name = model_name
|
||||
self.key_out = key_out
|
||||
self.key = key
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
pred = preds[self.model_name]
|
||||
if self.key_out is not None:
|
||||
pred = pred[self.key_out]
|
||||
if self.key is not None:
|
||||
pred = pred[self.key]
|
||||
return super().__call__(pred, label=label, *args, **kwargs)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue