fix eval bug
This commit is contained in:
parent
0742f5c521
commit
b8a65d4333
|
@ -55,9 +55,9 @@ class DetMetric(object):
|
|||
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
|
||||
self.results.append(result)
|
||||
|
||||
metircs = self.evaluator.combine_results(self.results)
|
||||
self.reset()
|
||||
return metircs
|
||||
# metircs = self.evaluator.combine_results(self.results)
|
||||
# self.reset()
|
||||
# return metircs
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
|
|
|
@ -24,8 +24,8 @@ from .cls_metric import ClsMetric
|
|||
class DistillationMetric(object):
|
||||
def __init__(self,
|
||||
key=None,
|
||||
base_metric_name="RecMetric",
|
||||
main_indicator='acc',
|
||||
base_metric_name=None,
|
||||
main_indicator=None,
|
||||
**kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.key = key
|
||||
|
@ -42,16 +42,13 @@ class DistillationMetric(object):
|
|||
main_indicator=self.main_indicator, **self.kwargs)
|
||||
self.metrics[key].reset()
|
||||
|
||||
def __call__(self, preds, *args, **kwargs):
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
assert isinstance(preds, dict)
|
||||
if self.metrics is None:
|
||||
self._init_metrcis(preds)
|
||||
output = dict()
|
||||
for key in preds:
|
||||
metric = self.metrics[key].__call__(preds[key], *args, **kwargs)
|
||||
for sub_key in metric:
|
||||
output["{}_{}".format(key, sub_key)] = metric[sub_key]
|
||||
return output
|
||||
self.metrics[key].__call__(preds[key], batch, **kwargs)
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
|
|
|
@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
|
|||
pretrained = model_config.pop("pretrained")
|
||||
model = BaseModel(model_config)
|
||||
if pretrained is not None:
|
||||
load_pretrained_params(model, pretrained)
|
||||
model = load_pretrained_params(model, pretrained)
|
||||
if freeze_params:
|
||||
for param in model.parameters():
|
||||
param.trainable = False
|
||||
|
|
|
@ -189,29 +189,27 @@ class DBPostProcess(object):
|
|||
return boxes_batch
|
||||
|
||||
|
||||
class DistillationDBPostProcess(DBPostProcess):
|
||||
def __init__(self,
|
||||
model_name=["student"],
|
||||
class DistillationDBPostProcess(object):
|
||||
def __init__(self, model_name=["student"],
|
||||
key=None,
|
||||
thresh=0.3,
|
||||
box_thresh=0.7,
|
||||
box_thresh=0.6,
|
||||
max_candidates=1000,
|
||||
unclip_ratio=2.0,
|
||||
unclip_ratio=1.5,
|
||||
use_dilation=False,
|
||||
score_mode="fast",
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
if not isinstance(model_name, list):
|
||||
model_name = [model_name]
|
||||
self.model_name = model_name
|
||||
self.key = key
|
||||
self.post_process = DBPostProcess(thresh=thresh,
|
||||
box_thresh=box_thresh,
|
||||
max_candidates=max_candidates,
|
||||
unclip_ratio=unclip_ratio,
|
||||
use_dilation=use_dilation,
|
||||
score_mode=score_mode)
|
||||
|
||||
def __call__(self, predicts, shape_list):
|
||||
results = {}
|
||||
for name in self.model_name:
|
||||
pred = predicts[name]
|
||||
if self.key is not None:
|
||||
pred = pred[self.key]
|
||||
results[name] = super().__call__(pred, shape_list=shape_list)
|
||||
|
||||
for k in self.model_name:
|
||||
results[k] = self.post_process(predicts[k], shape_list=shape_list)
|
||||
return results
|
||||
|
|
|
@ -136,7 +136,7 @@ def load_pretrained_params(model, path):
|
|||
)
|
||||
model.set_state_dict(new_state_dict)
|
||||
print(f"load pretrain successful from {path}")
|
||||
return True
|
||||
return model
|
||||
|
||||
def save_model(model,
|
||||
optimizer,
|
||||
|
|
|
@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
|
|||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.metrics import build_metric
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import init_model, load_pretrained_params
|
||||
from ppocr.utils.utility import print_dict
|
||||
import tools.program as program
|
||||
|
||||
|
@ -59,7 +59,8 @@ def main():
|
|||
model_type = config['Architecture']['model_type']
|
||||
else:
|
||||
model_type = None
|
||||
best_model_dict = init_model(config, model)
|
||||
|
||||
best_model_dict = init_model(config, model, model_type)
|
||||
if len(best_model_dict):
|
||||
logger.info('metric in ckpt ***************')
|
||||
for k, v in best_model_dict.items():
|
||||
|
|
|
@ -374,6 +374,7 @@ def eval(model,
|
|||
eval_class(preds, batch)
|
||||
else:
|
||||
post_result = post_process_class(preds, batch[1])
|
||||
# post_result = post_result_["Student"]
|
||||
eval_class(post_result, batch)
|
||||
pbar.update(1)
|
||||
total_frame += len(images)
|
||||
|
|
|
@ -97,8 +97,8 @@ def main(config, device, logger, vdl_writer):
|
|||
# build metric
|
||||
eval_class = build_metric(config['Metric'])
|
||||
# load pretrain model
|
||||
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
|
||||
|
||||
#pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
|
||||
pre_best_model_dict = {}
|
||||
logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
|
||||
if valid_dataloader is not None:
|
||||
logger.info('valid dataloader has {} iters'.format(
|
||||
|
|
Loading…
Reference in New Issue