support srn inference

This commit is contained in:
tink2123 2020-09-03 15:51:50 +08:00
parent b626aa3e41
commit bc1ad20701
5 changed files with 187 additions and 29 deletions

View File

@ -136,7 +136,7 @@ class RecModel(object):
else:
labels = None
loader = None
if self.char_type == "ch" and self.infer_img:
if self.char_type == "ch" and self.infer_img and self.loss_type != "srn":
image_shape[-1] = -1
if self.tps != None:
logger.info(
@ -172,16 +172,13 @@ class RecModel(object):
self.max_text_length
],
dtype="float32")
feed_list = [
image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
labels = {
'encoder_word_pos': encoder_word_pos,
'gsrm_word_pos': gsrm_word_pos,
'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2
}
return image, labels, loader
def __call__(self, mode):
@ -218,8 +215,13 @@ class RecModel(object):
if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict)
if self.loss_type == "srn":
raise Exception(
"Warning! SRN does not support export model currently")
return [
image, labels, {
'decoded_out': decoded_out,
'predicts': predict
}
]
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
else:
predict = predicts['predict']

View File

@ -26,6 +26,7 @@ import copy
import numpy as np
import math
import time
import paddle.fluid as fluid
from ppocr.utils.character import CharacterOps
@ -37,18 +38,22 @@ class TextRecognizer(object):
self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
self.text_len = args.max_text_length
char_ops_params = {
"character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char,
"max_text_length": args.max_text_length
}
if self.rec_algorithm != "RARE":
if self.rec_algorithm in ["CRNN", "Rosetta", "STAR-Net"]:
char_ops_params['loss_type'] = 'ctc'
self.loss_type = 'ctc'
else:
elif self.rec_algorithm == "RARE":
char_ops_params['loss_type'] = 'attention'
self.loss_type = 'attention'
elif self.rec_algorithm == "SRN":
char_ops_params['loss_type'] = 'srn'
self.loss_type = 'srn'
self.char_ops = CharacterOps(char_ops_params)
def resize_norm_img(self, img, max_wh_ratio):
@ -71,6 +76,83 @@ class TextRecognizer(object):
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def resize_norm_img_srn(self, img, image_shape):
imgC, imgH, imgW = image_shape
img_black = np.zeros((imgH, imgW))
im_hei = img.shape[0]
im_wid = img.shape[1]
if im_wid <= im_hei * 1:
img_new = cv2.resize(img, (imgH * 1, imgH))
elif im_wid <= im_hei * 2:
img_new = cv2.resize(img, (imgH * 2, imgH))
elif im_wid <= im_hei * 3:
img_new = cv2.resize(img, (imgH * 3, imgH))
else:
img_new = cv2.resize(img, (imgW, imgH))
img_np = np.asarray(img_new)
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
img_black[:, 0:img_np.shape[1]] = img_np
img_black = img_black[:, :, np.newaxis]
row, col, c = img_black.shape
c = 1
return np.reshape(img_black, (c, row, col)).astype(np.float32)
def srn_other_inputs(self, image_shape, num_heads, max_text_length,
char_num):
imgC, imgH, imgW = image_shape
feature_dim = int((imgH / 8) * (imgW / 8))
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
(feature_dim, 1)).astype('int64')
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
(max_text_length, 1)).astype('int64')
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias1 = np.tile(
gsrm_slf_attn_bias1,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
[-1, 1, max_text_length, max_text_length])
gsrm_slf_attn_bias2 = np.tile(
gsrm_slf_attn_bias2,
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
encoder_word_pos = encoder_word_pos[np.newaxis, :]
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
return [
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
]
def process_image_srn(self,
img,
image_shape,
num_heads,
max_text_length,
char_ops=None):
norm_img = self.resize_norm_img_srn(img, image_shape)
norm_img = norm_img[np.newaxis, :]
char_num = char_ops.get_char_num()
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
self.srn_other_inputs(image_shape, num_heads, max_text_length, char_num)
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2)
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
@ -94,16 +176,52 @@ class TextRecognizer(object):
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
# norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
if self.loss_type != "srn":
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
else:
norm_img = self.process_image_srn(img_list[indices[ino]],
self.rec_image_shape, 8,
25, self.char_ops)
encoder_word_pos_list = []
gsrm_word_pos_list = []
gsrm_slf_attn_bias1_list = []
gsrm_slf_attn_bias2_list = []
encoder_word_pos_list.append(norm_img[1])
gsrm_word_pos_list.append(norm_img[2])
gsrm_slf_attn_bias1_list.append(norm_img[3])
gsrm_slf_attn_bias2_list.append(norm_img[4])
norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch, axis=0)
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list)
starttime = time.time()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run()
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
encoder_word_pos_list = fluid.core.PaddleTensor(
encoder_word_pos_list)
gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor(
gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor(
gsrm_slf_attn_bias2_list)
inputs = [
norm_img_batch, encoder_word_pos_list, gsrm_slf_attn_bias1_list,
gsrm_slf_attn_bias2_list, gsrm_word_pos_list
]
self.predictor.run(inputs)
if self.loss_type == "ctc":
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
@ -128,6 +246,26 @@ class TextRecognizer(object):
score = np.mean(probs[valid_ind, ind[valid_ind]])
# rec_res.append([preds_text, score])
rec_res[indices[beg_img_no + rno]] = [preds_text, score]
elif self.loss_type == 'srn':
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
probs = self.output_tensors[1].copy_to_cpu()
char_num = self.char_ops.get_char_num()
preds = rec_idx_batch.reshape(-1)
elapse = time.time() - starttime
predict_time += elapse
total_preds = preds.copy()
for ino in range(int(len(rec_idx_batch) / self.text_len)):
preds = total_preds[ino * self.text_len:(ino + 1) *
self.text_len]
ind = np.argmax(probs, axis=1)
valid_ind = np.where(preds != int(char_num - 1))[0]
if len(valid_ind) == 0:
continue
score = np.mean(probs[valid_ind, ind[valid_ind]])
preds = preds[:valid_ind[-1] + 1]
preds_text = self.char_ops.decode(preds)
rec_res[indices[beg_img_no + ino]] = [preds_text, score]
else:
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
predict_batch = self.output_tensors[1].copy_to_cpu()
@ -162,6 +300,7 @@ def main(args):
continue
valid_image_file_list.append(image_file)
img_list.append(img)
try:
rec_res, predict_time = text_recognizer(img_list)
except Exception as e:

View File

@ -59,10 +59,10 @@ def parse_args():
parser.add_argument("--det_sast_polygon", type=bool, default=False)
#params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_algorithm", type=str, default='SRN')
parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_image_shape", type=str, default="3, 64, 256")
parser.add_argument("--rec_char_type", type=str, default='en')
parser.add_argument("--rec_batch_num", type=int, default=30)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument(
@ -107,10 +107,13 @@ def create_predictor(args, mode):
# use zero copy
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
config.switch_use_feed_fetch_ops(False)
#config.switch_use_feed_fetch_ops(False)
config.switch_use_feed_fetch_ops(True)
predictor = create_paddle_predictor(config)
input_names = predictor.get_input_names()
input_tensor = predictor.get_input_tensor(input_names[0])
print(input_names)
for name in input_names:
input_tensor = predictor.get_input_tensor(name)
output_names = predictor.get_output_names()
output_tensors = []
for output_name in output_names:

View File

@ -162,7 +162,10 @@ def main():
fluid.io.save_inference_model(
"./output/",
feeded_var_names=['image'],
feeded_var_names=[
'image', 'encoder_word_pos', 'gsrm_slf_attn_bias1',
'gsrm_slf_attn_bias2', 'gsrm_word_pos'
],
target_vars=target_var,
executor=exe,
main_program=eval_prog,

View File

@ -208,10 +208,19 @@ def build_export(config, main_prog, startup_prog):
with fluid.unique_name.guard():
func_infor = config['Architecture']['function']
model = create_module(func_infor)(params=config)
loss_type = config['Global']['loss_type']
if loss_type == "srn":
image, others, outputs = model(mode='export')
else:
image, outputs = model(mode='export')
fetches_var_name = sorted([name for name in outputs.keys()])
fetches_var = [outputs[name] for name in fetches_var_name]
if loss_type == "srn":
others_var_names = sorted([name for name in others.keys()])
feeded_var_names = [image.name] + others_var_names
else:
feeded_var_names = [image.name]
target_vars = fetches_var
return feeded_var_names, target_vars, fetches_var_name
@ -409,7 +418,9 @@ def preprocess():
check_gpu(use_gpu)
alg = config['Global']['algorithm']
assert alg in ['EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN'
]
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
config['Global']['char_ops'] = CharacterOps(config['Global'])