分类支持传参置信度
This commit is contained in:
parent
06430c9359
commit
cf2a483369
|
@ -39,6 +39,7 @@ class TextClassifier(object):
|
|||
self.cls_batch_num = args.rec_batch_num
|
||||
self.label_list = args.label_list
|
||||
self.use_zero_copy_run = args.use_zero_copy_run
|
||||
self.cls_thresh = args.cls_thresh
|
||||
|
||||
def resize_norm_img(self, img):
|
||||
imgC, imgH, imgW = self.cls_image_shape
|
||||
|
@ -110,7 +111,7 @@ class TextClassifier(object):
|
|||
score = prob_out[rno][label_idx]
|
||||
label = self.label_list[label_idx]
|
||||
cls_res[indices[beg_img_no + rno]] = [label, score]
|
||||
if '180' in label and score > 0.9999:
|
||||
if '180' in label and score > self.cls_thresh:
|
||||
img_list[indices[beg_img_no + rno]] = cv2.rotate(
|
||||
img_list[indices[beg_img_no + rno]], 1)
|
||||
return img_list, cls_res, predict_time
|
||||
|
|
|
@ -78,6 +78,7 @@ def parse_args():
|
|||
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
|
||||
parser.add_argument("--label_list", type=list, default=['0', '180'])
|
||||
parser.add_argument("--cls_batch_num", type=int, default=30)
|
||||
parser.add_argument("--cls_thresh", type=float, default=0.9)
|
||||
|
||||
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
|
||||
parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)
|
||||
|
|
Loading…
Reference in New Issue