parent
d8719969ba
commit
3ce97f18ec
|
@ -248,9 +248,11 @@ class TextRecognizer(object):
|
|||
def main(args):
|
||||
image_file_list = get_image_file_list(args.image_dir)
|
||||
text_recognizer = TextRecognizer(args)
|
||||
total_run_time = 0.0
|
||||
total_images_num = 0
|
||||
valid_image_file_list = []
|
||||
img_list = []
|
||||
for image_file in image_file_list:
|
||||
for idx, image_file in enumerate(image_file_list):
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
if not flag:
|
||||
img = cv2.imread(image_file)
|
||||
|
@ -259,22 +261,29 @@ def main(args):
|
|||
continue
|
||||
valid_image_file_list.append(image_file)
|
||||
img_list.append(img)
|
||||
try:
|
||||
rec_res, predict_time = text_recognizer(img_list)
|
||||
except:
|
||||
logger.info(traceback.format_exc())
|
||||
logger.info(
|
||||
"ERROR!!!! \n"
|
||||
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
|
||||
"If your model has tps module: "
|
||||
"TPS does not support variable shape.\n"
|
||||
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
|
||||
exit()
|
||||
for ino in range(len(img_list)):
|
||||
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
|
||||
rec_res[ino]))
|
||||
if len(img_list) >= args.rec_batch_num or idx == len(
|
||||
image_file_list) - 1:
|
||||
try:
|
||||
rec_res, predict_time = text_recognizer(img_list)
|
||||
total_run_time += predict_time
|
||||
except:
|
||||
logger.info(traceback.format_exc())
|
||||
logger.info(
|
||||
"ERROR!!!! \n"
|
||||
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
|
||||
"If your model has tps module: "
|
||||
"TPS does not support variable shape.\n"
|
||||
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' "
|
||||
)
|
||||
exit()
|
||||
for ino in range(len(img_list)):
|
||||
logger.info("Predicts of {}:{}".format(valid_image_file_list[
|
||||
ino], rec_res[ino]))
|
||||
total_images_num += len(valid_image_file_list)
|
||||
valid_image_file_list = []
|
||||
img_list = []
|
||||
logger.info("Total predict time for {} images, cost: {:.3f}".format(
|
||||
len(img_list), predict_time))
|
||||
total_images_num, total_run_time))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue