support srn inference
This commit is contained in:
parent
b626aa3e41
commit
bc1ad20701
|
@ -136,7 +136,7 @@ class RecModel(object):
|
|||
else:
|
||||
labels = None
|
||||
loader = None
|
||||
if self.char_type == "ch" and self.infer_img:
|
||||
if self.char_type == "ch" and self.infer_img and self.loss_type != "srn":
|
||||
image_shape[-1] = -1
|
||||
if self.tps != None:
|
||||
logger.info(
|
||||
|
@ -172,16 +172,13 @@ class RecModel(object):
|
|||
self.max_text_length
|
||||
],
|
||||
dtype="float32")
|
||||
feed_list = [
|
||||
image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2
|
||||
]
|
||||
labels = {
|
||||
'encoder_word_pos': encoder_word_pos,
|
||||
'gsrm_word_pos': gsrm_word_pos,
|
||||
'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
|
||||
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2
|
||||
}
|
||||
|
||||
return image, labels, loader
|
||||
|
||||
def __call__(self, mode):
|
||||
|
@ -218,8 +215,13 @@ class RecModel(object):
|
|||
if self.loss_type == "ctc":
|
||||
predict = fluid.layers.softmax(predict)
|
||||
if self.loss_type == "srn":
|
||||
raise Exception(
|
||||
"Warning! SRN does not support export model currently")
|
||||
return [
|
||||
image, labels, {
|
||||
'decoded_out': decoded_out,
|
||||
'predicts': predict
|
||||
}
|
||||
]
|
||||
|
||||
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
|
||||
else:
|
||||
predict = predicts['predict']
|
||||
|
|
|
@ -26,6 +26,7 @@ import copy
|
|||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
import paddle.fluid as fluid
|
||||
from ppocr.utils.character import CharacterOps
|
||||
|
||||
|
||||
|
@ -37,18 +38,22 @@ class TextRecognizer(object):
|
|||
self.character_type = args.rec_char_type
|
||||
self.rec_batch_num = args.rec_batch_num
|
||||
self.rec_algorithm = args.rec_algorithm
|
||||
self.text_len = args.max_text_length
|
||||
char_ops_params = {
|
||||
"character_type": args.rec_char_type,
|
||||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char,
|
||||
"max_text_length": args.max_text_length
|
||||
}
|
||||
if self.rec_algorithm != "RARE":
|
||||
if self.rec_algorithm in ["CRNN", "Rosetta", "STAR-Net"]:
|
||||
char_ops_params['loss_type'] = 'ctc'
|
||||
self.loss_type = 'ctc'
|
||||
else:
|
||||
elif self.rec_algorithm == "RARE":
|
||||
char_ops_params['loss_type'] = 'attention'
|
||||
self.loss_type = 'attention'
|
||||
elif self.rec_algorithm == "SRN":
|
||||
char_ops_params['loss_type'] = 'srn'
|
||||
self.loss_type = 'srn'
|
||||
self.char_ops = CharacterOps(char_ops_params)
|
||||
|
||||
def resize_norm_img(self, img, max_wh_ratio):
|
||||
|
@ -71,6 +76,83 @@ class TextRecognizer(object):
|
|||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
return padding_im
|
||||
|
||||
def resize_norm_img_srn(self, img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
img_black = np.zeros((imgH, imgW))
|
||||
im_hei = img.shape[0]
|
||||
im_wid = img.shape[1]
|
||||
|
||||
if im_wid <= im_hei * 1:
|
||||
img_new = cv2.resize(img, (imgH * 1, imgH))
|
||||
elif im_wid <= im_hei * 2:
|
||||
img_new = cv2.resize(img, (imgH * 2, imgH))
|
||||
elif im_wid <= im_hei * 3:
|
||||
img_new = cv2.resize(img, (imgH * 3, imgH))
|
||||
else:
|
||||
img_new = cv2.resize(img, (imgW, imgH))
|
||||
|
||||
img_np = np.asarray(img_new)
|
||||
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
|
||||
img_black[:, 0:img_np.shape[1]] = img_np
|
||||
img_black = img_black[:, :, np.newaxis]
|
||||
|
||||
row, col, c = img_black.shape
|
||||
c = 1
|
||||
|
||||
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||
|
||||
def srn_other_inputs(self, image_shape, num_heads, max_text_length,
|
||||
char_num):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||
|
||||
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
|
||||
(feature_dim, 1)).astype('int64')
|
||||
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
|
||||
(max_text_length, 1)).astype('int64')
|
||||
|
||||
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
|
||||
[-1, 1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias1 = np.tile(
|
||||
gsrm_slf_attn_bias1,
|
||||
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
||||
|
||||
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
|
||||
[-1, 1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias2 = np.tile(
|
||||
gsrm_slf_attn_bias2,
|
||||
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
||||
|
||||
encoder_word_pos = encoder_word_pos[np.newaxis, :]
|
||||
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
|
||||
|
||||
return [
|
||||
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2
|
||||
]
|
||||
|
||||
def process_image_srn(self,
|
||||
img,
|
||||
image_shape,
|
||||
num_heads,
|
||||
max_text_length,
|
||||
char_ops=None):
|
||||
norm_img = self.resize_norm_img_srn(img, image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
char_num = char_ops.get_char_num()
|
||||
|
||||
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
||||
self.srn_other_inputs(image_shape, num_heads, max_text_length, char_num)
|
||||
|
||||
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
|
||||
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
|
||||
|
||||
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2)
|
||||
|
||||
def __call__(self, img_list):
|
||||
img_num = len(img_list)
|
||||
# Calculate the aspect ratio of all text bars
|
||||
|
@ -94,16 +176,52 @@ class TextRecognizer(object):
|
|||
wh_ratio = w * 1.0 / h
|
||||
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
# norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
|
||||
if self.loss_type != "srn":
|
||||
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
||||
max_wh_ratio)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
norm_img_batch = np.concatenate(norm_img_batch)
|
||||
norm_img_batch = norm_img_batch.copy()
|
||||
else:
|
||||
norm_img = self.process_image_srn(img_list[indices[ino]],
|
||||
self.rec_image_shape, 8,
|
||||
25, self.char_ops)
|
||||
encoder_word_pos_list = []
|
||||
gsrm_word_pos_list = []
|
||||
gsrm_slf_attn_bias1_list = []
|
||||
gsrm_slf_attn_bias2_list = []
|
||||
encoder_word_pos_list.append(norm_img[1])
|
||||
gsrm_word_pos_list.append(norm_img[2])
|
||||
gsrm_slf_attn_bias1_list.append(norm_img[3])
|
||||
gsrm_slf_attn_bias2_list.append(norm_img[4])
|
||||
norm_img_batch.append(norm_img[0])
|
||||
|
||||
norm_img_batch = np.concatenate(norm_img_batch, axis=0)
|
||||
|
||||
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
|
||||
|
||||
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
|
||||
|
||||
gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list)
|
||||
|
||||
gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list)
|
||||
|
||||
starttime = time.time()
|
||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||
self.predictor.zero_copy_run()
|
||||
|
||||
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
|
||||
encoder_word_pos_list = fluid.core.PaddleTensor(
|
||||
encoder_word_pos_list)
|
||||
gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list)
|
||||
gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor(
|
||||
gsrm_slf_attn_bias1_list)
|
||||
gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor(
|
||||
gsrm_slf_attn_bias2_list)
|
||||
|
||||
inputs = [
|
||||
norm_img_batch, encoder_word_pos_list, gsrm_slf_attn_bias1_list,
|
||||
gsrm_slf_attn_bias2_list, gsrm_word_pos_list
|
||||
]
|
||||
|
||||
self.predictor.run(inputs)
|
||||
|
||||
if self.loss_type == "ctc":
|
||||
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
|
||||
|
@ -128,6 +246,26 @@ class TextRecognizer(object):
|
|||
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
||||
# rec_res.append([preds_text, score])
|
||||
rec_res[indices[beg_img_no + rno]] = [preds_text, score]
|
||||
elif self.loss_type == 'srn':
|
||||
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
|
||||
probs = self.output_tensors[1].copy_to_cpu()
|
||||
char_num = self.char_ops.get_char_num()
|
||||
preds = rec_idx_batch.reshape(-1)
|
||||
elapse = time.time() - starttime
|
||||
predict_time += elapse
|
||||
total_preds = preds.copy()
|
||||
for ino in range(int(len(rec_idx_batch) / self.text_len)):
|
||||
preds = total_preds[ino * self.text_len:(ino + 1) *
|
||||
self.text_len]
|
||||
ind = np.argmax(probs, axis=1)
|
||||
valid_ind = np.where(preds != int(char_num - 1))[0]
|
||||
if len(valid_ind) == 0:
|
||||
continue
|
||||
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
||||
preds = preds[:valid_ind[-1] + 1]
|
||||
preds_text = self.char_ops.decode(preds)
|
||||
|
||||
rec_res[indices[beg_img_no + ino]] = [preds_text, score]
|
||||
else:
|
||||
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
|
||||
predict_batch = self.output_tensors[1].copy_to_cpu()
|
||||
|
@ -162,6 +300,7 @@ def main(args):
|
|||
continue
|
||||
valid_image_file_list.append(image_file)
|
||||
img_list.append(img)
|
||||
|
||||
try:
|
||||
rec_res, predict_time = text_recognizer(img_list)
|
||||
except Exception as e:
|
||||
|
|
|
@ -59,10 +59,10 @@ def parse_args():
|
|||
parser.add_argument("--det_sast_polygon", type=bool, default=False)
|
||||
|
||||
#params for text recognizer
|
||||
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
|
||||
parser.add_argument("--rec_algorithm", type=str, default='SRN')
|
||||
parser.add_argument("--rec_model_dir", type=str)
|
||||
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
|
||||
parser.add_argument("--rec_char_type", type=str, default='ch')
|
||||
parser.add_argument("--rec_image_shape", type=str, default="3, 64, 256")
|
||||
parser.add_argument("--rec_char_type", type=str, default='en')
|
||||
parser.add_argument("--rec_batch_num", type=int, default=30)
|
||||
parser.add_argument("--max_text_length", type=int, default=25)
|
||||
parser.add_argument(
|
||||
|
@ -107,10 +107,13 @@ def create_predictor(args, mode):
|
|||
|
||||
# use zero copy
|
||||
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
|
||||
config.switch_use_feed_fetch_ops(False)
|
||||
#config.switch_use_feed_fetch_ops(False)
|
||||
config.switch_use_feed_fetch_ops(True)
|
||||
predictor = create_paddle_predictor(config)
|
||||
input_names = predictor.get_input_names()
|
||||
input_tensor = predictor.get_input_tensor(input_names[0])
|
||||
print(input_names)
|
||||
for name in input_names:
|
||||
input_tensor = predictor.get_input_tensor(name)
|
||||
output_names = predictor.get_output_names()
|
||||
output_tensors = []
|
||||
for output_name in output_names:
|
||||
|
|
|
@ -162,7 +162,10 @@ def main():
|
|||
|
||||
fluid.io.save_inference_model(
|
||||
"./output/",
|
||||
feeded_var_names=['image'],
|
||||
feeded_var_names=[
|
||||
'image', 'encoder_word_pos', 'gsrm_slf_attn_bias1',
|
||||
'gsrm_slf_attn_bias2', 'gsrm_word_pos'
|
||||
],
|
||||
target_vars=target_var,
|
||||
executor=exe,
|
||||
main_program=eval_prog,
|
||||
|
|
|
@ -208,10 +208,19 @@ def build_export(config, main_prog, startup_prog):
|
|||
with fluid.unique_name.guard():
|
||||
func_infor = config['Architecture']['function']
|
||||
model = create_module(func_infor)(params=config)
|
||||
loss_type = config['Global']['loss_type']
|
||||
if loss_type == "srn":
|
||||
image, others, outputs = model(mode='export')
|
||||
else:
|
||||
image, outputs = model(mode='export')
|
||||
fetches_var_name = sorted([name for name in outputs.keys()])
|
||||
fetches_var = [outputs[name] for name in fetches_var_name]
|
||||
if loss_type == "srn":
|
||||
others_var_names = sorted([name for name in others.keys()])
|
||||
feeded_var_names = [image.name] + others_var_names
|
||||
else:
|
||||
feeded_var_names = [image.name]
|
||||
|
||||
target_vars = fetches_var
|
||||
return feeded_var_names, target_vars, fetches_var_name
|
||||
|
||||
|
@ -409,7 +418,9 @@ def preprocess():
|
|||
check_gpu(use_gpu)
|
||||
|
||||
alg = config['Global']['algorithm']
|
||||
assert alg in ['EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN'
|
||||
]
|
||||
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
|
||||
config['Global']['char_ops'] = CharacterOps(config['Global'])
|
||||
|
||||
|
|
Loading…
Reference in New Issue