fix metric
This commit is contained in:
parent
b48f760982
commit
0343756e46
|
@ -95,17 +95,17 @@ Loss:
|
|||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
key: backbone_out
|
||||
|
||||
|
||||
|
||||
PostProcess:
|
||||
name: DistillationCTCLabelDecode
|
||||
model_name: "Student"
|
||||
model_name: ["Student", "Teacher"]
|
||||
key: head_out
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
name: DistillationMetric
|
||||
base_metric_name: RecMetric
|
||||
main_indicator: acc
|
||||
key: "Student"
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
|
|
|
@ -22,9 +22,8 @@ from paddle.nn import SmoothL1Loss
|
|||
|
||||
|
||||
class CELoss(nn.Layer):
|
||||
def __init__(self, name="loss_ce", epsilon=None):
|
||||
def __init__(self, epsilon=None):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
|
||||
epsilon = None
|
||||
self.epsilon = epsilon
|
||||
|
@ -52,9 +51,7 @@ class CELoss(nn.Layer):
|
|||
else:
|
||||
soft_label = False
|
||||
loss = F.cross_entropy(x, label=label, soft_label=soft_label)
|
||||
|
||||
loss_dict[self.name] = paddle.mean(loss)
|
||||
return loss_dict
|
||||
return loss
|
||||
|
||||
|
||||
class DMLLoss(nn.Layer):
|
||||
|
@ -62,11 +59,10 @@ class DMLLoss(nn.Layer):
|
|||
DMLLoss
|
||||
"""
|
||||
|
||||
def __init__(self, act=None, name="loss_dml"):
|
||||
def __init__(self, act=None):
|
||||
super().__init__()
|
||||
if act is not None:
|
||||
assert act in ["softmax", "sigmoid"]
|
||||
self.name = name
|
||||
if act == "softmax":
|
||||
self.act = nn.Softmax(axis=-1)
|
||||
elif act == "sigmoid":
|
||||
|
@ -75,7 +71,6 @@ class DMLLoss(nn.Layer):
|
|||
self.act = None
|
||||
|
||||
def forward(self, out1, out2):
|
||||
loss_dict = {}
|
||||
if self.act is not None:
|
||||
out1 = self.act(out1)
|
||||
out2 = self.act(out2)
|
||||
|
@ -85,18 +80,16 @@ class DMLLoss(nn.Layer):
|
|||
loss = (F.kl_div(
|
||||
log_out1, out2, reduction='batchmean') + F.kl_div(
|
||||
log_out2, log_out1, reduction='batchmean')) / 2.0
|
||||
loss_dict[self.name] = loss
|
||||
return loss_dict
|
||||
return loss
|
||||
|
||||
|
||||
class DistanceLoss(nn.Layer):
|
||||
"""
|
||||
DistanceLoss:
|
||||
mode: loss mode
|
||||
name: loss key in the output dict
|
||||
"""
|
||||
|
||||
def __init__(self, mode="l2", name="loss_dist", **kargs):
|
||||
def __init__(self, mode="l2", **kargs):
|
||||
super().__init__()
|
||||
assert mode in ["l1", "l2", "smooth_l1"]
|
||||
if mode == "l1":
|
||||
|
@ -106,7 +99,5 @@ class DistanceLoss(nn.Layer):
|
|||
elif mode == "smooth_l1":
|
||||
self.loss_func = nn.SmoothL1Loss(**kargs)
|
||||
|
||||
self.name = "{}_{}".format(name, mode)
|
||||
|
||||
def forward(self, x, y):
|
||||
return {self.name: self.loss_func(x, y)}
|
||||
return self.loss_func(x, y)
|
||||
|
|
|
@ -26,10 +26,11 @@ class DistillationDMLLoss(DMLLoss):
|
|||
|
||||
def __init__(self, model_name_pairs=[], act=None, key=None,
|
||||
name="loss_dml"):
|
||||
super().__init__(act=act, name=name)
|
||||
super().__init__(act=act)
|
||||
assert isinstance(model_name_pairs, list)
|
||||
self.key = key
|
||||
self.model_name_pairs = model_name_pairs
|
||||
self.name = name
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
|
@ -42,8 +43,8 @@ class DistillationDMLLoss(DMLLoss):
|
|||
loss = super().forward(out1, out2)
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
|
||||
key]
|
||||
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
|
||||
idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, idx)] = loss
|
||||
return loss_dict
|
||||
|
@ -82,10 +83,11 @@ class DistillationDistanceLoss(DistanceLoss):
|
|||
key=None,
|
||||
name="loss_distance",
|
||||
**kargs):
|
||||
super().__init__(mode=mode, name=name, **kargs)
|
||||
super().__init__(mode=mode, **kargs)
|
||||
assert isinstance(model_name_pairs, list)
|
||||
self.key = key
|
||||
self.model_name_pairs = model_name_pairs
|
||||
self.name = name + "_l2"
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
|
@ -101,5 +103,6 @@ class DistillationDistanceLoss(DistanceLoss):
|
|||
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
|
||||
key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, idx)] = loss
|
||||
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
|
||||
idx)] = loss
|
||||
return loss_dict
|
||||
|
|
|
@ -19,20 +19,23 @@ from __future__ import unicode_literals
|
|||
|
||||
import copy
|
||||
|
||||
__all__ = ['build_metric']
|
||||
__all__ = ["build_metric"]
|
||||
|
||||
from .det_metric import DetMetric
|
||||
from .rec_metric import RecMetric
|
||||
from .cls_metric import ClsMetric
|
||||
from .e2e_metric import E2EMetric
|
||||
from .distillation_metric import DistillationMetric
|
||||
|
||||
|
||||
def build_metric(config):
|
||||
from .det_metric import DetMetric
|
||||
from .rec_metric import RecMetric
|
||||
from .cls_metric import ClsMetric
|
||||
from .e2e_metric import E2EMetric
|
||||
|
||||
support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric']
|
||||
support_dict = [
|
||||
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric"
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
module_name = config.pop("name")
|
||||
assert module_name in support_dict, Exception(
|
||||
'metric only support {}'.format(support_dict))
|
||||
"metric only support {}".format(support_dict))
|
||||
module_class = eval(module_name)(**config)
|
||||
return module_class
|
||||
|
|
|
@ -135,19 +135,25 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
|
|||
character_dict_path=None,
|
||||
character_type='ch',
|
||||
use_space_char=False,
|
||||
model_name="student",
|
||||
model_name=["student"],
|
||||
key=None,
|
||||
**kwargs):
|
||||
super(DistillationCTCLabelDecode, self).__init__(
|
||||
character_dict_path, character_type, use_space_char)
|
||||
if not isinstance(model_name, list):
|
||||
model_name = [model_name]
|
||||
self.model_name = model_name
|
||||
|
||||
self.key = key
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
pred = preds[self.model_name]
|
||||
if self.key is not None:
|
||||
pred = pred[self.key]
|
||||
return super().__call__(pred, label=label, *args, **kwargs)
|
||||
output = dict()
|
||||
for name in self.model_name:
|
||||
pred = preds[name]
|
||||
if self.key is not None:
|
||||
pred = pred[self.key]
|
||||
output[name] = super().__call__(pred, label=label, *args, **kwargs)
|
||||
return output
|
||||
|
||||
|
||||
class AttnLabelDecode(BaseRecLabelDecode):
|
||||
|
|
Loading…
Reference in New Issue