分类支持传参置信度

This commit is contained in:
WenmuZhou 2020-09-18 11:29:55 +08:00
parent 06430c9359
commit cf2a483369
2 changed files with 3 additions and 1 deletions

View File

@ -39,6 +39,7 @@ class TextClassifier(object):
self.cls_batch_num = args.rec_batch_num
self.label_list = args.label_list
self.use_zero_copy_run = args.use_zero_copy_run
self.cls_thresh = args.cls_thresh
def resize_norm_img(self, img):
imgC, imgH, imgW = self.cls_image_shape
@ -110,7 +111,7 @@ class TextClassifier(object):
score = prob_out[rno][label_idx]
label = self.label_list[label_idx]
cls_res[indices[beg_img_no + rno]] = [label, score]
if '180' in label and score > 0.9999:
if '180' in label and score > self.cls_thresh:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1)
return img_list, cls_res, predict_time

View File

@ -78,6 +78,7 @@ def parse_args():
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
parser.add_argument("--label_list", type=list, default=['0', '180'])
parser.add_argument("--cls_batch_num", type=int, default=30)
parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)