PaddleOCR/tools/infer_det.py

115 lines
3.8 KiB
Python
Raw Normal View History

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
2020-06-12 13:49:24 +08:00
import os
import sys
2020-10-13 17:13:33 +08:00
__dir__ = os.path.dirname(os.path.abspath(__file__))
2020-06-12 13:49:24 +08:00
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
2020-12-22 15:57:21 +08:00
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
2020-05-11 19:59:07 +08:00
import cv2
2020-10-13 17:13:33 +08:00
import json
import paddle
2020-10-13 17:13:33 +08:00
from ppocr.data import create_operators, transform
2020-11-09 16:40:24 +08:00
from ppocr.modeling.architectures import build_model
2020-10-13 17:13:33 +08:00
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
2020-11-09 16:40:24 +08:00
from ppocr.utils.utility import get_image_file_list
2020-10-13 17:13:33 +08:00
import tools.program as program
2020-05-15 14:22:57 +08:00
def draw_det_res(dt_boxes, config, img, img_name):
    """Draw detected text boxes on an image and save it to a det_results dir.

    The annotated image is written to a ``det_results`` directory created
    next to ``config['Global']['save_res_path']``, under the base name of
    ``img_name``. Nothing is drawn or written when ``dt_boxes`` is empty.

    Args:
        dt_boxes: Sequence (list or numpy array) of polygons. Each box is cast
            to int32 and reshaped to ``(-1, 1, 2)`` point pairs for drawing.
        config: Parsed config dict; only ``config['Global']['save_res_path']``
            is read.
        img: Image array handed to ``cv2.polylines``; it is drawn on in place.
        img_name: Path of the source image; only its base name is used.
    """
    # Use len() rather than truthiness: dt_boxes may be a numpy array, whose
    # truth value is ambiguous.
    if len(dt_boxes) == 0:
        return

    for box in dt_boxes:
        pts = box.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(img, [pts], True, color=(255, 255, 0), thickness=2)

    # os.path.join keeps the directory relative when save_res_path has no
    # directory component (plain concatenation produced "/det_results/").
    save_det_path = os.path.join(
        os.path.dirname(config['Global']['save_res_path']), "det_results")
    os.makedirs(save_det_path, exist_ok=True)

    save_path = os.path.join(save_det_path, os.path.basename(img_name))
    # cv2.imwrite reports failure via its return value, not an exception.
    if cv2.imwrite(save_path, img):
        logger.info("The detected Image saved in {}".format(save_path))
    else:
        logger.warning(
            "Failed to save the detected image to {}".format(save_path))
def _build_infer_transforms(config):
    """Return the Eval transform op configs adapted for label-free inference.

    Ops whose name contains 'Label' are dropped, and a 'KeepKeys' op is set to
    keep only 'image' and 'shape', the pair main() unpacks from each batch.
    NOTE: the KeepKeys entry is modified in place inside ``config``.
    """
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        if op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image', 'shape']
        transforms.append(op)
    return transforms


def _boxes_to_json(boxes):
    """Serialize boxes as a JSON list of {"transcription", "points"} dicts."""
    return json.dumps(
        [{"transcription": "", "points": box.tolist()} for box in boxes])


def main():
    """Run text detection on config['Global']['infer_img'] and save results.

    For each image, one ``<image path>\\t<json boxes>\\n`` line is written to
    ``config['Global']['save_res_path']``, and an annotated copy is saved via
    draw_det_res().

    Relies on the module-level ``config`` and ``logger`` globals bound in the
    ``__main__`` entry point.
    """
    global_config = config['Global']

    # build model
    model = build_model(config['Architecture'])
    init_model(config, model, logger)

    # build post process
    post_process_class = build_post_process(config['PostProcess'])

    # create data ops
    ops = create_operators(_build_infer_transforms(config), global_config)

    save_res_path = global_config['save_res_path']
    save_res_dir = os.path.dirname(save_res_path)
    # A bare file name has an empty dirname; os.makedirs('') would raise.
    if save_res_dir:
        os.makedirs(save_res_dir, exist_ok=True)

    model.eval()
    with open(save_res_path, "wb") as fout:
        for image_file in get_image_file_list(global_config['infer_img']):
            logger.info("infer_img: {}".format(image_file))
            with open(image_file, 'rb') as f:
                img_bytes = f.read()
            data = {'image': img_bytes}
            batch = transform(data, ops)

            # Add a leading batch dimension of 1 to the kept (image, shape).
            images = np.expand_dims(batch[0], axis=0)
            shape_list = np.expand_dims(batch[1], axis=0)
            images = paddle.to_tensor(images)
            preds = model(images)
            post_result = post_process_class(preds, shape_list)
            boxes = post_result[0]['points']

            # write result
            otstr = image_file + "\t" + _boxes_to_json(boxes) + "\n"
            fout.write(otstr.encode())
            src_img = cv2.imread(image_file)
            draw_det_res(boxes, config, src_img, image_file)
    logger.info("success!")
if __name__ == '__main__':
    # The names bound here are module globals: main() and draw_det_res()
    # read `config` and `logger` from module scope, so they must keep these
    # exact names. `device` and `vdl_writer` are not referenced elsewhere in
    # this script.
    config, device, logger, vdl_writer = program.preprocess()
    main()