python端预测完成

This commit is contained in:
WenmuZhou 2020-11-17 17:28:28 +08:00
parent 903b102f5f
commit 0c287c41ea
5 changed files with 20 additions and 21 deletions

View File

@ -31,6 +31,8 @@ from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
logger = get_logger()
class TextClassifier(object):
def __init__(self, args):
@ -147,5 +149,4 @@ def main(args):
if __name__ == "__main__":
logger = get_logger()
main(utility.parse_args())

View File

@ -30,6 +30,8 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
logger = get_logger()
class TextDetector(object):
def __init__(self, args):
@ -158,9 +160,7 @@ class TextDetector(object):
if __name__ == "__main__":
args = utility.parse_args()
image_file_list = get_image_file_list(args.image_dir)
logger = get_logger()
text_detector = TextDetector(args)
count = 0
total_time = 0

View File

@ -13,12 +13,12 @@
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import cv2
import copy
import numpy as np
import math
import time
@ -30,6 +30,8 @@ from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
logger = get_logger()
class TextRecognizer(object):
def __init__(self, args):
@ -80,7 +82,7 @@ class TextRecognizer(object):
# rec_res = []
rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num
predict_time = 0
elapse = 0
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
@ -110,7 +112,9 @@ class TextRecognizer(object):
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = outputs[0]
rec_res = self.postprocess_op(preds)
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
elapse = time.time() - starttime
return rec_res, elapse
@ -147,5 +151,4 @@ def main(args):
if __name__ == "__main__":
logger = get_logger()
main(utility.parse_args())

View File

@ -17,20 +17,17 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import tools.infer.utility as utility
from ppocr.utils.utility import initial_logger
logger = initial_logger()
import cv2
import tools.infer.predict_det as predict_det
import tools.infer.predict_rec as predict_rec
import copy
import numpy as np
import math
import time
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from PIL import Image
import tools.infer.utility as utility
from tools.infer.utility import draw_ocr
from tools.infer.utility import draw_ocr_box_txt
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger
class TextSystem(object):
@ -153,11 +150,7 @@ def main(args):
scores = [rec_res[i][1] for i in range(len(rec_res))]
draw_img = draw_ocr(
image,
boxes,
txts,
scores,
drop_score=drop_score)
image, boxes, txts, scores, drop_score=drop_score)
draw_img_save = "./inference_results/"
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
@ -169,4 +162,5 @@ def main(args):
if __name__ == "__main__":
logger = get_logger()
main(utility.parse_args())

View File

@ -39,7 +39,8 @@ def parse_args():
parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_model_dir", type=str)
parser.add_argument("--det_max_side_len", type=float, default=960)
parser.add_argument("--det_limit_side_len", type=float, default=960)
parser.add_argument("--det_limit_type", type=str, default='max')
# DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3)