sorted outputs when export model
This commit is contained in:
parent
d539508e76
commit
72ebbf2de1
|
@ -142,8 +142,8 @@ class TextDetector(object):
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
outs_dict = {}
|
outs_dict = {}
|
||||||
if self.det_algorithm == "EAST":
|
if self.det_algorithm == "EAST":
|
||||||
outs_dict['f_score'] = outputs[0]
|
outs_dict['f_geo'] = outputs[0]
|
||||||
outs_dict['f_geo'] = outputs[1]
|
outs_dict['f_score'] = outputs[1]
|
||||||
else:
|
else:
|
||||||
outs_dict['maps'] = outputs[0]
|
outs_dict['maps'] = outputs[0]
|
||||||
dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list])
|
dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list])
|
||||||
|
@ -169,14 +169,7 @@ if __name__ == "__main__":
|
||||||
total_time += elapse
|
total_time += elapse
|
||||||
count += 1
|
count += 1
|
||||||
print("Predict time of %s:" % image_file, elapse)
|
print("Predict time of %s:" % image_file, elapse)
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
"""
|
||||||
draw_img = draw_ocr(img, dt_boxes, None, None, False)
|
add visualized code
|
||||||
draw_img_save = "./inference_results/"
|
"""
|
||||||
if not os.path.exists(draw_img_save):
|
|
||||||
os.makedirs(draw_img_save)
|
|
||||||
cv2.imwrite(
|
|
||||||
os.path.join(draw_img_save, os.path.basename(image_file)),
|
|
||||||
draw_img[:, :, ::-1])
|
|
||||||
print("The visualized image saved in {}".format(
|
|
||||||
os.path.join(draw_img_save, os.path.basename(image_file))))
|
|
||||||
print("Avg Time:", total_time / (count - 1))
|
print("Avg Time:", total_time / (count - 1))
|
||||||
|
|
|
@ -114,7 +114,6 @@ if __name__ == "__main__":
|
||||||
valid_image_file_list.append(image_file)
|
valid_image_file_list.append(image_file)
|
||||||
img_list.append(img)
|
img_list.append(img)
|
||||||
rec_res, predict_time = text_recognizer(img_list)
|
rec_res, predict_time = text_recognizer(img_list)
|
||||||
rec_res, predict_time = text_recognizer(img_list)
|
|
||||||
for ino in range(len(img_list)):
|
for ino in range(len(img_list)):
|
||||||
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
|
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
|
||||||
print("Total predict time for %d images:%.3f" %
|
print("Total predict time for %d images:%.3f" %
|
||||||
|
|
|
@ -191,8 +191,8 @@ def build_export(config, main_prog, startup_prog):
|
||||||
func_infor = config['Architecture']['function']
|
func_infor = config['Architecture']['function']
|
||||||
model = create_module(func_infor)(params=config)
|
model = create_module(func_infor)(params=config)
|
||||||
image, outputs = model(mode='export')
|
image, outputs = model(mode='export')
|
||||||
fetches_var = [outputs[name] for name in outputs]
|
fetches_var = sorted([outputs[name] for name in outputs])
|
||||||
fetches_var_name = [name for name in outputs]
|
fetches_var_name = [name for name in fetches_var]
|
||||||
feeded_var_names = [image.name]
|
feeded_var_names = [image.name]
|
||||||
target_vars = fetches_var
|
target_vars = fetches_var
|
||||||
return feeded_var_names, target_vars, fetches_var_name
|
return feeded_var_names, target_vars, fetches_var_name
|
||||||
|
|
Loading…
Reference in New Issue