fix metric etc.al
This commit is contained in:
parent
a7b32ca82b
commit
0742f5c521
|
@ -90,14 +90,14 @@ Loss:
|
|||
- ["Student", "Student2"]
|
||||
maps_name: "thrink_maps"
|
||||
weight: 1.0
|
||||
act: "softmax"
|
||||
# act: None
|
||||
model_name_pairs: ["Student", "Student2"]
|
||||
key: maps
|
||||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Student2"]
|
||||
# key: maps
|
||||
name: DBLoss
|
||||
# name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
|
@ -119,8 +119,8 @@ Optimizer:
|
|||
|
||||
PostProcess:
|
||||
name: DistillationDBPostProcess
|
||||
model_name: ["Student", "Student2"]
|
||||
key: head_out
|
||||
model_name: ["Student", "Student2", "Teacher"]
|
||||
# key: maps
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
|
|
|
@ -54,6 +54,27 @@ class CELoss(nn.Layer):
|
|||
return loss
|
||||
|
||||
|
||||
class KLJSLoss(object):
|
||||
def __init__(self, mode='kl'):
|
||||
assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
|
||||
self.mode = mode
|
||||
|
||||
def __call__(self, p1, p2, reduction="mean"):
|
||||
|
||||
loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5))
|
||||
|
||||
if self.mode.lower() == "js":
|
||||
loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5))
|
||||
loss *= 0.5
|
||||
if reduction == "mean":
|
||||
loss = paddle.mean(loss, axis=[1,2])
|
||||
elif reduction=="none" or reduction is None:
|
||||
return loss
|
||||
else:
|
||||
loss = paddle.sum(loss, axis=[1,2])
|
||||
|
||||
return loss
|
||||
|
||||
class DMLLoss(nn.Layer):
|
||||
"""
|
||||
DMLLoss
|
||||
|
@ -70,16 +91,20 @@ class DMLLoss(nn.Layer):
|
|||
else:
|
||||
self.act = None
|
||||
|
||||
self.jskl_loss = KLJSLoss(mode="js")
|
||||
|
||||
def forward(self, out1, out2):
|
||||
if self.act is not None:
|
||||
out1 = self.act(out1)
|
||||
out2 = self.act(out2)
|
||||
|
||||
if len(out1.shape) < 2:
|
||||
log_out1 = paddle.log(out1)
|
||||
log_out2 = paddle.log(out2)
|
||||
loss = (F.kl_div(
|
||||
log_out1, out2, reduction='batchmean') + F.kl_div(
|
||||
log_out2, out1, reduction='batchmean')) / 2.0
|
||||
else:
|
||||
loss = self.jskl_loss(out1, out2)
|
||||
return loss
|
||||
|
||||
|
||||
|
|
|
@ -55,7 +55,5 @@ class CombinedLoss(nn.Layer):
|
|||
loss_all += loss[key] * weight
|
||||
else:
|
||||
loss_dict["{}_{}".format(key, idx)] = loss[key]
|
||||
# loss[f"{key}_{idx}"] = loss[key]
|
||||
loss_dict.update(loss)
|
||||
loss_dict["loss"] = loss_all
|
||||
return loss_dict
|
||||
|
|
|
@ -46,13 +46,13 @@ class DistillationDMLLoss(DMLLoss):
|
|||
act=None,
|
||||
key=None,
|
||||
maps_name=None,
|
||||
name="loss_dml"):
|
||||
name="dml"):
|
||||
super().__init__(act=act)
|
||||
assert isinstance(model_name_pairs, list)
|
||||
self.key = key
|
||||
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
|
||||
self.name = name
|
||||
self.maps_name = maps_name
|
||||
self.maps_name = self._check_maps_name(maps_name)
|
||||
|
||||
def _check_model_name_pairs(self, model_name_pairs):
|
||||
if not isinstance(model_name_pairs, list):
|
||||
|
@ -76,11 +76,11 @@ class DistillationDMLLoss(DMLLoss):
|
|||
new_outs = {}
|
||||
for k in self.maps_name:
|
||||
if k == "thrink_maps":
|
||||
new_outs[k] = paddle.slice(outs, axes=[1], starts=[0], ends=[1])
|
||||
new_outs[k] = outs[:, 0, :, :]
|
||||
elif k == "threshold_maps":
|
||||
new_outs[k] = paddle.slice(outs, axes=[1], starts=[1], ends=[2])
|
||||
new_outs[k] = outs[:, 1, :, :]
|
||||
elif k == "binary_maps":
|
||||
new_outs[k] = paddle.slice(outs, axes=[1], starts=[2], ends=[3])
|
||||
new_outs[k] = outs[:, 2, :, :]
|
||||
else:
|
||||
continue
|
||||
return new_outs
|
||||
|
@ -105,14 +105,14 @@ class DistillationDMLLoss(DMLLoss):
|
|||
else:
|
||||
outs1 = self._slice_out(out1)
|
||||
outs2 = self._slice_out(out2)
|
||||
for k in outs1.keys():
|
||||
for _c, k in enumerate(outs1.keys()):
|
||||
loss = super().forward(outs1[k], outs2[k])
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
||||
0], pair[1], map_name, idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}_{}".format(self.name, self.maps_name,
|
||||
loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c],
|
||||
idx)] = loss
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
|
@ -152,7 +152,7 @@ class DistillationDBLoss(DBLoss):
|
|||
beta=10,
|
||||
ohem_ratio=3,
|
||||
eps=1e-6,
|
||||
name="db_loss",
|
||||
name="db",
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.model_name_list = model_name_list
|
||||
|
|
|
@ -55,6 +55,10 @@ class DetMetric(object):
|
|||
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
|
||||
self.results.append(result)
|
||||
|
||||
metircs = self.evaluator.combine_results(self.results)
|
||||
self.reset()
|
||||
return metircs
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
return metrics {
|
||||
|
|
|
@ -200,21 +200,18 @@ class DistillationDBPostProcess(DBPostProcess):
|
|||
use_dilation=False,
|
||||
score_mode="fast",
|
||||
**kwargs):
|
||||
super(DistillationDBPostProcess, self).__init__(
|
||||
thresh, box_thresh, max_candidates, unclip_ratio, use_dilation,
|
||||
score_mode)
|
||||
super().__init__()
|
||||
if not isinstance(model_name, list):
|
||||
model_name = [model_name]
|
||||
self.model_name = model_name
|
||||
|
||||
self.key = key
|
||||
|
||||
def forward(self, predicts, shape_list):
|
||||
def __call__(self, predicts, shape_list):
|
||||
results = {}
|
||||
for name in self.model_name:
|
||||
pred = predicts[name]
|
||||
if self.key is not None:
|
||||
pred = pred[self.key]
|
||||
results[name] = super().__call__(pred, shape_list=label)
|
||||
results[name] = super().__call__(pred, shape_list=shape_list)
|
||||
|
||||
return results
|
||||
|
|
|
@ -135,6 +135,7 @@ def load_pretrained_params(model, path):
|
|||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
||||
)
|
||||
model.set_state_dict(new_state_dict)
|
||||
print(f"load pretrain successful from {path}")
|
||||
return True
|
||||
|
||||
def save_model(model,
|
||||
|
|
|
@ -55,8 +55,10 @@ def main():
|
|||
|
||||
model = build_model(config['Architecture'])
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
if "model_type" in config['Architecture'].keys():
|
||||
model_type = config['Architecture']['model_type']
|
||||
|
||||
else:
|
||||
model_type = None
|
||||
best_model_dict = init_model(config, model)
|
||||
if len(best_model_dict):
|
||||
logger.info('metric in ckpt ***************')
|
||||
|
|
Loading…
Reference in New Issue