2020-09-01 13:44:51 +08:00
|
|
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
from __future__ import absolute_import
|
|
|
|
from __future__ import division
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
2020-09-14 10:41:43 +08:00
|
|
|
# Public API of this module: only the evaluation entry point is exported.
__all__ = ["eval_cls_run"]
|
2020-09-01 13:44:51 +08:00
|
|
|
|
|
|
|
import logging
|
|
|
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Log line layout: "<asctime>-<LEVELNAME>: <message>".
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
# NOTE(review): basicConfig at import time configures the process-wide root
# logger (and is a no-op if the root logger already has handlers).
logging.basicConfig(format=FORMAT, level=logging.INFO)
|
|
|
|
|
|
|
|
|
|
|
|
def eval_cls_run(exe, eval_info_dict):
    """
    Run the evaluation program over every batch of a reader and report accuracy.

    Args:
        exe: executor object; its ``run(program, feed=..., fetch_list=...,
            return_numpy=False)`` method executes the eval program.
        eval_info_dict (dict): must contain
            ``'reader'`` -- callable returning an iterable of batches, each
                batch being a sequence of samples where ``sample[0]`` is the
                image array and ``sample[1]`` is its label;
            ``'program'`` -- program passed to ``exe.run``;
            ``'fetch_varname_list'`` -- fetch list passed to ``exe.run``.

    Returns:
        dict: ``{'avg_acc': float, 'total_acc_num': int,
        'total_sample_num': int}``. ``avg_acc`` is 0.0 when the reader
        yields no samples (instead of raising ZeroDivisionError).
    """
    total_sample_num = 0
    total_acc_num = 0

    for batch_id, data in enumerate(eval_info_dict['reader']()):
        # Images are stacked along axis 0 into one feed array; labels stay a
        # plain list for element-wise comparison in cal_cls_acc.
        img_list = np.concatenate([sample[0] for sample in data], axis=0)
        label_list = [sample[1] for sample in data]

        outs = exe.run(eval_info_dict['program'],
                       feed={'image': img_list},
                       fetch_list=eval_info_dict['fetch_varname_list'],
                       return_numpy=False)

        # The prediction output is expected to be 1-D (presumably one class
        # id per sample -- confirm against the fetch list). Prefer the second
        # fetched output when it is 1-D, otherwise fall back to the first;
        # the fallback also covers a fetch list with a single variable.
        preds = np.array(outs[1]) if len(outs) > 1 else None
        if preds is None or len(preds.shape) != 1:
            preds = np.array(outs[0])

        acc, acc_num = cal_cls_acc(preds, label_list)
        total_acc_num += acc_num
        total_sample_num += len(label_list)
        logger.debug("eval batch id: %s, acc: %s", batch_id, acc)

    # Guard against an empty reader, which would otherwise divide by zero.
    if total_sample_num:
        avg_acc = float(total_acc_num) / total_sample_num
    else:
        avg_acc = 0.0

    metrics = {
        'avg_acc': avg_acc,
        "total_acc_num": total_acc_num,
        "total_sample_num": total_sample_num,
    }
    return metrics
|
|
|
|
|
|
|
|
|
|
|
|
def cal_cls_acc(preds, labels):
    """
    Count predictions that equal their labels and compute the batch accuracy.

    Args:
        preds: sized sequence (list or 1-D array) of predictions.
        labels: sequence of ground-truth labels, compared pairwise with
            ``preds`` using ``==`` (pairs beyond the shorter input are ignored).

    Returns:
        tuple: ``(accuracy, acc_num)`` where ``acc_num`` is the number of
        matching pairs and ``accuracy = acc_num / len(preds)``. Returns
        ``(0.0, 0)`` when ``preds`` is empty instead of raising
        ZeroDivisionError.
    """
    # len() rather than truthiness: preds may be a numpy array, whose
    # truth value is ambiguous.
    num_preds = len(preds)
    if num_preds == 0:
        return 0.0, 0

    acc_num = sum(1 for pred, label in zip(preds, labels) if pred == label)
    # float() keeps true division even without `from __future__ import division`.
    return float(acc_num) / num_preds, acc_num
|