add comment
This commit is contained in:
parent
725185cd6a
commit
90734ca685
|
@ -84,19 +84,29 @@ def parse_args():
|
||||||
|
|
||||||
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
|
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
|
||||||
parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)
|
parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)
|
||||||
|
|
||||||
parser.add_argument("--use_pdserving", type=str2bool, default=False)
|
parser.add_argument("--use_pdserving", type=str2bool, default=False)
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def create_predictor(args, mode):
|
def create_predictor(args, mode):
|
||||||
|
"""
|
||||||
|
create predictor for inference
|
||||||
|
:param args: params for prediction engine
|
||||||
|
:param mode: mode
|
||||||
|
:return: predictor
|
||||||
|
"""
|
||||||
if mode == "det":
|
if mode == "det":
|
||||||
model_dir = args.det_model_dir
|
model_dir = args.det_model_dir
|
||||||
elif mode == 'cls':
|
elif mode == 'cls':
|
||||||
model_dir = args.cls_model_dir
|
model_dir = args.cls_model_dir
|
||||||
else:
|
elif mode == 'rec':
|
||||||
model_dir = args.rec_model_dir
|
model_dir = args.rec_model_dir
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"'mode' of create_predictor() can only be one of ['det', 'cls', 'rec']"
|
||||||
|
)
|
||||||
|
|
||||||
if model_dir is None:
|
if model_dir is None:
|
||||||
logger.info("not find {} model file path {}".format(mode, model_dir))
|
logger.info("not find {} model file path {}".format(mode, model_dir))
|
||||||
|
@ -144,6 +154,12 @@ def create_predictor(args, mode):
|
||||||
|
|
||||||
|
|
||||||
def draw_text_det_res(dt_boxes, img_path):
|
def draw_text_det_res(dt_boxes, img_path):
|
||||||
|
"""
|
||||||
|
Visualize the results of detection
|
||||||
|
:param dt_boxes: The boxes predicted by detection model
|
||||||
|
:param img_path: Image path
|
||||||
|
:return: Visualized image
|
||||||
|
"""
|
||||||
src_im = cv2.imread(img_path)
|
src_im = cv2.imread(img_path)
|
||||||
for box in dt_boxes:
|
for box in dt_boxes:
|
||||||
box = np.array(box).astype(np.int32).reshape(-1, 2)
|
box = np.array(box).astype(np.int32).reshape(-1, 2)
|
||||||
|
|
Loading…
Reference in New Issue