fix all minor bugs
This commit is contained in:
parent
7cacfc97d9
commit
c4720557e8
|
@ -117,7 +117,7 @@ class OCRService(WebService):
|
|||
|
||||
if __name__ == "__main__":
|
||||
ocr_service = OCRService(name="ocr")
|
||||
ocr_service.load_model_config("cls_server")
|
||||
ocr_service.load_model_config(global_args.cls_model_dir)
|
||||
ocr_service.init_rec()
|
||||
if global_args.use_gpu:
|
||||
ocr_service.prepare_server(
|
||||
|
|
|
@ -37,4 +37,3 @@ for img_file in os.listdir(test_img_dir):
|
|||
data = {"feed": [{"image": image}], "fetch": ["res"]}
|
||||
r = requests.post(url=url, headers=headers, data=json.dumps(data))
|
||||
print(r.json())
|
||||
break
|
||||
|
|
|
@ -96,7 +96,7 @@ class DetService(WebService):
|
|||
|
||||
if __name__ == "__main__":
|
||||
ocr_service = DetService(name="ocr")
|
||||
ocr_service.load_model_config("serving_server_dir")
|
||||
ocr_service.load_model_config(global_args.det_model_dir)
|
||||
ocr_service.init_det()
|
||||
if global_args.use_gpu:
|
||||
ocr_service.prepare_server(
|
||||
|
|
|
@ -79,7 +79,6 @@ class TextDetectorHelper(TextDetector):
|
|||
class DetService(WebService):
|
||||
def init_det(self):
|
||||
self.text_detector = TextDetectorHelper(global_args)
|
||||
print("init finish")
|
||||
|
||||
def preprocess(self, feed=[], fetch=[]):
|
||||
data = base64.b64decode(feed[0]["image"].encode('utf8'))
|
||||
|
@ -96,7 +95,7 @@ class DetService(WebService):
|
|||
|
||||
if __name__ == "__main__":
|
||||
ocr_service = DetService(name="ocr")
|
||||
ocr_service.load_model_config("serving_server_dir")
|
||||
ocr_service.load_model_config(global_args.det_model_dir)
|
||||
ocr_service.init_det()
|
||||
if global_args.use_gpu:
|
||||
ocr_service.prepare_server(
|
||||
|
|
|
@ -44,17 +44,16 @@ class TextSystemHelper(TextSystem):
|
|||
if self.use_angle_cls:
|
||||
self.clas_client = Debugger()
|
||||
self.clas_client.load_model_config(
|
||||
"ocr_clas_server", gpu=True, profile=False)
|
||||
global_args.cls_model_dir, gpu=True, profile=False)
|
||||
self.text_classifier = TextClassifierHelper(args)
|
||||
self.det_client = Debugger()
|
||||
self.det_client.load_model_config(
|
||||
"serving_server_dir", gpu=True, profile=False)
|
||||
global_args.det_model_dir, gpu=True, profile=False)
|
||||
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
|
||||
|
||||
def preprocess(self, img):
|
||||
feed, fetch, self.tmp_args = self.text_detector.preprocess(img)
|
||||
fetch_map = self.det_client.predict(feed, fetch)
|
||||
print("det fetch_map", fetch_map)
|
||||
outputs = [fetch_map[x] for x in fetch]
|
||||
dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args)
|
||||
if dt_boxes is None:
|
||||
|
@ -90,12 +89,10 @@ class OCRService(WebService):
|
|||
|
||||
def preprocess(self, feed=[], fetch=[]):
|
||||
# TODO: to handle batch rec images
|
||||
print("start preprocess")
|
||||
data = base64.b64decode(feed[0]["image"].encode('utf8'))
|
||||
data = np.fromstring(data, np.uint8)
|
||||
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
|
||||
feed, fetch, self.tmp_args = self.text_system.preprocess(im)
|
||||
print("ocr preprocess done")
|
||||
return feed, fetch
|
||||
|
||||
def postprocess(self, feed={}, fetch=[], fetch_map=None):
|
||||
|
|
|
@ -25,7 +25,7 @@ from clas_rpc_server import TextClassifierHelper
|
|||
from det_rpc_server import TextDetectorHelper
|
||||
from rec_rpc_server import TextRecognizerHelper
|
||||
import tools.infer.utility as utility
|
||||
from tools.infer.predict_system import TextSystem
|
||||
from tools.infer.predict_system import TextSystem, sorted_boxes
|
||||
import copy
|
||||
|
||||
global_args = utility.parse_args()
|
||||
|
@ -48,7 +48,7 @@ class TextSystemHelper(TextSystem):
|
|||
self.text_classifier = TextClassifierHelper(args)
|
||||
self.det_client = Client()
|
||||
self.det_client.load_client_config(
|
||||
"ocr_det_server/serving_client_conf.prototxt")
|
||||
"det_db_client/serving_client_conf.prototxt")
|
||||
self.det_client.connect(["127.0.0.1:9293"])
|
||||
self.fetch = ["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
|
||||
|
||||
|
@ -57,10 +57,10 @@ class TextSystemHelper(TextSystem):
|
|||
fetch_map = self.det_client.predict(feed, fetch)
|
||||
outputs = [fetch_map[x] for x in fetch]
|
||||
dt_boxes = self.text_detector.postprocess(outputs, self.tmp_args)
|
||||
print(dt_boxes)
|
||||
if dt_boxes is None:
|
||||
return None, None
|
||||
img_crop_list = []
|
||||
sorted_boxes = SortedBoxes()
|
||||
dt_boxes = sorted_boxes(dt_boxes)
|
||||
for bno in range(len(dt_boxes)):
|
||||
tmp_box = copy.deepcopy(dt_boxes[bno])
|
||||
|
@ -70,6 +70,7 @@ class TextSystemHelper(TextSystem):
|
|||
feed, fetch, self.tmp_args = self.text_classifier.preprocess(
|
||||
img_crop_list)
|
||||
fetch_map = self.clas_client.predict(feed, fetch)
|
||||
print(fetch_map)
|
||||
outputs = [fetch_map[x] for x in self.text_classifier.fetch]
|
||||
for x in fetch_map.keys():
|
||||
if ".lod" in x:
|
||||
|
|
|
@ -36,8 +36,5 @@ for img_file in os.listdir(test_img_dir):
|
|||
image = cv2_to_base64(image_data1)
|
||||
data = {"feed": [{"image": image}], "fetch": ["res"]}
|
||||
r = requests.post(url=url, headers=headers, data=json.dumps(data))
|
||||
print(r)
|
||||
rjson = r.json()
|
||||
print(rjson)
|
||||
#for x in rjson["result"]["pred_text"]:
|
||||
# print(x)
|
||||
|
|
|
@ -85,7 +85,6 @@ class TextRecognizerHelper(TextRecognizer):
|
|||
rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"]
|
||||
predict_lod = args["softmax_0.tmp_0.lod"]
|
||||
indices = args["indices"]
|
||||
print("indices", indices, rec_idx_lod)
|
||||
rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1)
|
||||
for rno in range(len(rec_idx_lod) - 1):
|
||||
beg = rec_idx_lod[rno]
|
||||
|
@ -155,7 +154,6 @@ class OCRService(WebService):
|
|||
if ".lod" in x:
|
||||
self.tmp_args[x] = fetch_map[x]
|
||||
rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args)
|
||||
print("rec_res", rec_res)
|
||||
res = {
|
||||
"pred_text": [x[0] for x in rec_res],
|
||||
"score": [str(x[1]) for x in rec_res]
|
||||
|
|
|
@ -91,7 +91,6 @@ class TextRecognizerHelper(TextRecognizer):
|
|||
rec_idx_lod = args["ctc_greedy_decoder_0.tmp_0.lod"]
|
||||
predict_lod = args["softmax_0.tmp_0.lod"]
|
||||
indices = args["indices"]
|
||||
print("indices", indices, rec_idx_lod)
|
||||
rec_res = [['', 0.0]] * (len(rec_idx_lod) - 1)
|
||||
for rno in range(len(rec_idx_lod) - 1):
|
||||
beg = rec_idx_lod[rno]
|
||||
|
@ -161,7 +160,6 @@ class OCRService(WebService):
|
|||
if ".lod" in x:
|
||||
self.tmp_args[x] = fetch_map[x]
|
||||
rec_res = self.text_recognizer.postprocess(outputs, self.tmp_args)
|
||||
print("rec_res", rec_res)
|
||||
res = {
|
||||
"pred_text": [x[0] for x in rec_res],
|
||||
"score": [str(x[1]) for x in rec_res]
|
||||
|
|
|
@ -37,4 +37,3 @@ for img_file in os.listdir(test_img_dir):
|
|||
data = {"feed": [{"image": image}], "fetch": ["res"]}
|
||||
r = requests.post(url=url, headers=headers, data=json.dumps(data))
|
||||
print(r.json())
|
||||
break
|
||||
|
|
|
@ -33,7 +33,7 @@ from paddle import fluid
|
|||
|
||||
class TextClassifier(object):
|
||||
def __init__(self, args):
|
||||
if args.use_serving is False:
|
||||
if args.use_pdserving is False:
|
||||
self.predictor, self.input_tensor, self.output_tensors = \
|
||||
utility.create_predictor(args, mode="cls")
|
||||
self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
|
||||
|
|
|
@ -75,7 +75,7 @@ class TextDetector(object):
|
|||
else:
|
||||
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
|
||||
sys.exit(0)
|
||||
if args.use_gpu is False:
|
||||
if args.use_pdserving is False:
|
||||
self.predictor, self.input_tensor, self.output_tensors =\
|
||||
utility.create_predictor(args, mode="det")
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ from ppocr.utils.character import CharacterOps
|
|||
|
||||
class TextRecognizer(object):
|
||||
def __init__(self, args):
|
||||
if args.use_serving is False:
|
||||
if args.use_pdserving is False:
|
||||
self.predictor, self.input_tensor, self.output_tensors =\
|
||||
utility.create_predictor(args, mode="rec")
|
||||
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
|
||||
|
|
|
@ -161,7 +161,12 @@ def main(args):
|
|||
scores = [rec_res[i][1] for i in range(len(rec_res))]
|
||||
|
||||
draw_img = draw_ocr(
|
||||
image, boxes, txts, scores, drop_score=drop_score, font_path=font_path)
|
||||
image,
|
||||
boxes,
|
||||
txts,
|
||||
scores,
|
||||
drop_score=drop_score,
|
||||
font_path=font_path)
|
||||
draw_img_save = "./inference_results/"
|
||||
if not os.path.exists(draw_img_save):
|
||||
os.makedirs(draw_img_save)
|
||||
|
|
|
@ -37,7 +37,7 @@ def parse_args():
|
|||
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
||||
parser.add_argument("--gpu_mem", type=int, default=8000)
|
||||
parser.add_argument("--use_serving", type=str2bool, default=False)
|
||||
parser.add_argument("--use_pdserving", type=str2bool, default=False)
|
||||
|
||||
# params for text detector
|
||||
parser.add_argument("--image_dir", type=str)
|
||||
|
@ -73,9 +73,7 @@ def parse_args():
|
|||
default="./ppocr/utils/ppocr_keys_v1.txt")
|
||||
parser.add_argument("--use_space_char", type=str2bool, default=True)
|
||||
parser.add_argument(
|
||||
"--vis_font_path",
|
||||
type=str,
|
||||
default="./doc/simfang.ttf")
|
||||
"--vis_font_path", type=str, default="./doc/simfang.ttf")
|
||||
|
||||
# params for text classifier
|
||||
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
|
||||
|
@ -230,8 +228,7 @@ def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"):
|
|||
1])**2)
|
||||
if box_height > 2 * box_width:
|
||||
font_size = max(int(box_width * 0.9), 10)
|
||||
font = ImageFont.truetype(
|
||||
font_path, font_size, encoding="utf-8")
|
||||
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
|
||||
cur_y = box[0][1]
|
||||
for c in txt:
|
||||
char_size = font.getsize(c)
|
||||
|
@ -240,8 +237,7 @@ def draw_ocr_box_txt(image, boxes, txts, font_path="./doc/simfang.ttf"):
|
|||
cur_y += char_size[1]
|
||||
else:
|
||||
font_size = max(int(box_height * 0.8), 10)
|
||||
font = ImageFont.truetype(
|
||||
font_path, font_size, encoding="utf-8")
|
||||
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
|
||||
draw_right.text(
|
||||
[box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
|
||||
img_left = Image.blend(image, img_left, 0.5)
|
||||
|
|
Loading…
Reference in New Issue