fix metric, etc.
This commit is contained in:
parent
a7b32ca82b
commit
0742f5c521
|
@ -90,14 +90,14 @@ Loss:
|
||||||
- ["Student", "Student2"]
|
- ["Student", "Student2"]
|
||||||
maps_name: "thrink_maps"
|
maps_name: "thrink_maps"
|
||||||
weight: 1.0
|
weight: 1.0
|
||||||
act: "softmax"
|
# act: None
|
||||||
model_name_pairs: ["Student", "Student2"]
|
model_name_pairs: ["Student", "Student2"]
|
||||||
key: maps
|
key: maps
|
||||||
- DistillationDBLoss:
|
- DistillationDBLoss:
|
||||||
weight: 1.0
|
weight: 1.0
|
||||||
model_name_list: ["Student", "Student2"]
|
model_name_list: ["Student", "Student2"]
|
||||||
# key: maps
|
# key: maps
|
||||||
name: DBLoss
|
# name: DBLoss
|
||||||
balance_loss: true
|
balance_loss: true
|
||||||
main_loss_type: DiceLoss
|
main_loss_type: DiceLoss
|
||||||
alpha: 5
|
alpha: 5
|
||||||
|
@ -119,8 +119,8 @@ Optimizer:
|
||||||
|
|
||||||
PostProcess:
|
PostProcess:
|
||||||
name: DistillationDBPostProcess
|
name: DistillationDBPostProcess
|
||||||
model_name: ["Student", "Student2"]
|
model_name: ["Student", "Student2", "Teacher"]
|
||||||
key: head_out
|
# key: maps
|
||||||
thresh: 0.3
|
thresh: 0.3
|
||||||
box_thresh: 0.6
|
box_thresh: 0.6
|
||||||
max_candidates: 1000
|
max_candidates: 1000
|
||||||
|
|
|
@ -54,6 +54,27 @@ class CELoss(nn.Layer):
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
class KLJSLoss(object):
|
||||||
|
def __init__(self, mode='kl'):
|
||||||
|
assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
def __call__(self, p1, p2, reduction="mean"):
|
||||||
|
|
||||||
|
loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5))
|
||||||
|
|
||||||
|
if self.mode.lower() == "js":
|
||||||
|
loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5))
|
||||||
|
loss *= 0.5
|
||||||
|
if reduction == "mean":
|
||||||
|
loss = paddle.mean(loss, axis=[1,2])
|
||||||
|
elif reduction=="none" or reduction is None:
|
||||||
|
return loss
|
||||||
|
else:
|
||||||
|
loss = paddle.sum(loss, axis=[1,2])
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
class DMLLoss(nn.Layer):
|
class DMLLoss(nn.Layer):
|
||||||
"""
|
"""
|
||||||
DMLLoss
|
DMLLoss
|
||||||
|
@ -69,17 +90,21 @@ class DMLLoss(nn.Layer):
|
||||||
self.act = nn.Sigmoid()
|
self.act = nn.Sigmoid()
|
||||||
else:
|
else:
|
||||||
self.act = None
|
self.act = None
|
||||||
|
|
||||||
|
self.jskl_loss = KLJSLoss(mode="js")
|
||||||
|
|
||||||
def forward(self, out1, out2):
|
def forward(self, out1, out2):
|
||||||
if self.act is not None:
|
if self.act is not None:
|
||||||
out1 = self.act(out1)
|
out1 = self.act(out1)
|
||||||
out2 = self.act(out2)
|
out2 = self.act(out2)
|
||||||
|
if len(out1.shape) < 2:
|
||||||
log_out1 = paddle.log(out1)
|
log_out1 = paddle.log(out1)
|
||||||
log_out2 = paddle.log(out2)
|
log_out2 = paddle.log(out2)
|
||||||
loss = (F.kl_div(
|
loss = (F.kl_div(
|
||||||
log_out1, out2, reduction='batchmean') + F.kl_div(
|
log_out1, out2, reduction='batchmean') + F.kl_div(
|
||||||
log_out2, out1, reduction='batchmean')) / 2.0
|
log_out2, out1, reduction='batchmean')) / 2.0
|
||||||
|
else:
|
||||||
|
loss = self.jskl_loss(out1, out2)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -55,7 +55,5 @@ class CombinedLoss(nn.Layer):
|
||||||
loss_all += loss[key] * weight
|
loss_all += loss[key] * weight
|
||||||
else:
|
else:
|
||||||
loss_dict["{}_{}".format(key, idx)] = loss[key]
|
loss_dict["{}_{}".format(key, idx)] = loss[key]
|
||||||
# loss[f"{key}_{idx}"] = loss[key]
|
|
||||||
loss_dict.update(loss)
|
|
||||||
loss_dict["loss"] = loss_all
|
loss_dict["loss"] = loss_all
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
|
|
@ -46,13 +46,13 @@ class DistillationDMLLoss(DMLLoss):
|
||||||
act=None,
|
act=None,
|
||||||
key=None,
|
key=None,
|
||||||
maps_name=None,
|
maps_name=None,
|
||||||
name="loss_dml"):
|
name="dml"):
|
||||||
super().__init__(act=act)
|
super().__init__(act=act)
|
||||||
assert isinstance(model_name_pairs, list)
|
assert isinstance(model_name_pairs, list)
|
||||||
self.key = key
|
self.key = key
|
||||||
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
|
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
|
||||||
self.name = name
|
self.name = name
|
||||||
self.maps_name = maps_name
|
self.maps_name = self._check_maps_name(maps_name)
|
||||||
|
|
||||||
def _check_model_name_pairs(self, model_name_pairs):
|
def _check_model_name_pairs(self, model_name_pairs):
|
||||||
if not isinstance(model_name_pairs, list):
|
if not isinstance(model_name_pairs, list):
|
||||||
|
@ -76,11 +76,11 @@ class DistillationDMLLoss(DMLLoss):
|
||||||
new_outs = {}
|
new_outs = {}
|
||||||
for k in self.maps_name:
|
for k in self.maps_name:
|
||||||
if k == "thrink_maps":
|
if k == "thrink_maps":
|
||||||
new_outs[k] = paddle.slice(outs, axes=[1], starts=[0], ends=[1])
|
new_outs[k] = outs[:, 0, :, :]
|
||||||
elif k == "threshold_maps":
|
elif k == "threshold_maps":
|
||||||
new_outs[k] = paddle.slice(outs, axes=[1], starts=[1], ends=[2])
|
new_outs[k] = outs[:, 1, :, :]
|
||||||
elif k == "binary_maps":
|
elif k == "binary_maps":
|
||||||
new_outs[k] = paddle.slice(outs, axes=[1], starts=[2], ends=[3])
|
new_outs[k] = outs[:, 2, :, :]
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
return new_outs
|
return new_outs
|
||||||
|
@ -105,16 +105,16 @@ class DistillationDMLLoss(DMLLoss):
|
||||||
else:
|
else:
|
||||||
outs1 = self._slice_out(out1)
|
outs1 = self._slice_out(out1)
|
||||||
outs2 = self._slice_out(out2)
|
outs2 = self._slice_out(out2)
|
||||||
for k in outs1.keys():
|
for _c, k in enumerate(outs1.keys()):
|
||||||
loss = super().forward(outs1[k], outs2[k])
|
loss = super().forward(outs1[k], outs2[k])
|
||||||
if isinstance(loss, dict):
|
if isinstance(loss, dict):
|
||||||
for key in loss:
|
for key in loss:
|
||||||
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
||||||
0], pair[1], map_name, idx)] = loss[key]
|
0], pair[1], map_name, idx)] = loss[key]
|
||||||
else:
|
else:
|
||||||
loss_dict["{}_{}_{}".format(self.name, self.maps_name,
|
loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c],
|
||||||
idx)] = loss
|
idx)] = loss
|
||||||
|
|
||||||
loss_dict = _sum_loss(loss_dict)
|
loss_dict = _sum_loss(loss_dict)
|
||||||
|
|
||||||
return loss_dict
|
return loss_dict
|
||||||
|
@ -152,7 +152,7 @@ class DistillationDBLoss(DBLoss):
|
||||||
beta=10,
|
beta=10,
|
||||||
ohem_ratio=3,
|
ohem_ratio=3,
|
||||||
eps=1e-6,
|
eps=1e-6,
|
||||||
name="db_loss",
|
name="db",
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model_name_list = model_name_list
|
self.model_name_list = model_name_list
|
||||||
|
|
|
@ -55,6 +55,10 @@ class DetMetric(object):
|
||||||
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
|
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
|
||||||
self.results.append(result)
|
self.results.append(result)
|
||||||
|
|
||||||
|
metircs = self.evaluator.combine_results(self.results)
|
||||||
|
self.reset()
|
||||||
|
return metircs
|
||||||
|
|
||||||
def get_metric(self):
|
def get_metric(self):
|
||||||
"""
|
"""
|
||||||
return metrics {
|
return metrics {
|
||||||
|
|
|
@ -200,21 +200,18 @@ class DistillationDBPostProcess(DBPostProcess):
|
||||||
use_dilation=False,
|
use_dilation=False,
|
||||||
score_mode="fast",
|
score_mode="fast",
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(DistillationDBPostProcess, self).__init__(
|
super().__init__()
|
||||||
thresh, box_thresh, max_candidates, unclip_ratio, use_dilation,
|
|
||||||
score_mode)
|
|
||||||
if not isinstance(model_name, list):
|
if not isinstance(model_name, list):
|
||||||
model_name = [model_name]
|
model_name = [model_name]
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
self.key = key
|
self.key = key
|
||||||
|
|
||||||
def forward(self, predicts, shape_list):
|
def __call__(self, predicts, shape_list):
|
||||||
results = {}
|
results = {}
|
||||||
for name in self.model_name:
|
for name in self.model_name:
|
||||||
pred = predicts[name]
|
pred = predicts[name]
|
||||||
if self.key is not None:
|
if self.key is not None:
|
||||||
pred = pred[self.key]
|
pred = pred[self.key]
|
||||||
results[name] = super().__call__(pred, shape_list=label)
|
results[name] = super().__call__(pred, shape_list=shape_list)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
|
@ -130,11 +130,12 @@ def load_pretrained_params(model, path):
|
||||||
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
||||||
if list(state_dict[k1].shape) == list(params[k2].shape):
|
if list(state_dict[k1].shape) == list(params[k2].shape):
|
||||||
new_state_dict[k1] = params[k2]
|
new_state_dict[k1] = params[k2]
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
||||||
)
|
)
|
||||||
model.set_state_dict(new_state_dict)
|
model.set_state_dict(new_state_dict)
|
||||||
|
print(f"load pretrain successful from {path}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def save_model(model,
|
def save_model(model,
|
||||||
|
|
|
@ -55,8 +55,10 @@ def main():
|
||||||
|
|
||||||
model = build_model(config['Architecture'])
|
model = build_model(config['Architecture'])
|
||||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||||
model_type = config['Architecture']['model_type']
|
if "model_type" in config['Architecture'].keys():
|
||||||
|
model_type = config['Architecture']['model_type']
|
||||||
|
else:
|
||||||
|
model_type = None
|
||||||
best_model_dict = init_model(config, model)
|
best_model_dict = init_model(config, model)
|
||||||
if len(best_model_dict):
|
if len(best_model_dict):
|
||||||
logger.info('metric in ckpt ***************')
|
logger.info('metric in ckpt ***************')
|
||||||
|
|
Loading…
Reference in New Issue