Merge pull request #665 from tink2123/support_srn_inference
Support srn inference
This commit is contained in:
commit
5fb3c419c9
|
@ -136,7 +136,7 @@ class RecModel(object):
|
||||||
else:
|
else:
|
||||||
labels = None
|
labels = None
|
||||||
loader = None
|
loader = None
|
||||||
if self.char_type == "ch" and self.infer_img:
|
if self.char_type == "ch" and self.infer_img and self.loss_type != "srn":
|
||||||
image_shape[-1] = -1
|
image_shape[-1] = -1
|
||||||
if self.tps != None:
|
if self.tps != None:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -172,16 +172,13 @@ class RecModel(object):
|
||||||
self.max_text_length
|
self.max_text_length
|
||||||
],
|
],
|
||||||
dtype="float32")
|
dtype="float32")
|
||||||
feed_list = [
|
|
||||||
image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
|
||||||
gsrm_slf_attn_bias2
|
|
||||||
]
|
|
||||||
labels = {
|
labels = {
|
||||||
'encoder_word_pos': encoder_word_pos,
|
'encoder_word_pos': encoder_word_pos,
|
||||||
'gsrm_word_pos': gsrm_word_pos,
|
'gsrm_word_pos': gsrm_word_pos,
|
||||||
'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
|
'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
|
||||||
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2
|
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2
|
||||||
}
|
}
|
||||||
|
|
||||||
return image, labels, loader
|
return image, labels, loader
|
||||||
|
|
||||||
def __call__(self, mode):
|
def __call__(self, mode):
|
||||||
|
@ -218,8 +215,13 @@ class RecModel(object):
|
||||||
if self.loss_type == "ctc":
|
if self.loss_type == "ctc":
|
||||||
predict = fluid.layers.softmax(predict)
|
predict = fluid.layers.softmax(predict)
|
||||||
if self.loss_type == "srn":
|
if self.loss_type == "srn":
|
||||||
raise Exception(
|
return [
|
||||||
"Warning! SRN does not support export model currently")
|
image, labels, {
|
||||||
|
'decoded_out': decoded_out,
|
||||||
|
'predicts': predict
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
|
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
|
||||||
else:
|
else:
|
||||||
predict = predicts['predict']
|
predict = predicts['predict']
|
||||||
|
|
|
@ -40,6 +40,7 @@ class TextRecognizer(object):
|
||||||
self.character_type = args.rec_char_type
|
self.character_type = args.rec_char_type
|
||||||
self.rec_batch_num = args.rec_batch_num
|
self.rec_batch_num = args.rec_batch_num
|
||||||
self.rec_algorithm = args.rec_algorithm
|
self.rec_algorithm = args.rec_algorithm
|
||||||
|
self.text_len = args.max_text_length
|
||||||
self.use_zero_copy_run = args.use_zero_copy_run
|
self.use_zero_copy_run = args.use_zero_copy_run
|
||||||
char_ops_params = {
|
char_ops_params = {
|
||||||
"character_type": args.rec_char_type,
|
"character_type": args.rec_char_type,
|
||||||
|
@ -47,12 +48,15 @@ class TextRecognizer(object):
|
||||||
"use_space_char": args.use_space_char,
|
"use_space_char": args.use_space_char,
|
||||||
"max_text_length": args.max_text_length
|
"max_text_length": args.max_text_length
|
||||||
}
|
}
|
||||||
if self.rec_algorithm != "RARE":
|
if self.rec_algorithm in ["CRNN", "Rosetta", "STAR-Net"]:
|
||||||
char_ops_params['loss_type'] = 'ctc'
|
char_ops_params['loss_type'] = 'ctc'
|
||||||
self.loss_type = 'ctc'
|
self.loss_type = 'ctc'
|
||||||
else:
|
elif self.rec_algorithm == "RARE":
|
||||||
char_ops_params['loss_type'] = 'attention'
|
char_ops_params['loss_type'] = 'attention'
|
||||||
self.loss_type = 'attention'
|
self.loss_type = 'attention'
|
||||||
|
elif self.rec_algorithm == "SRN":
|
||||||
|
char_ops_params['loss_type'] = 'srn'
|
||||||
|
self.loss_type = 'srn'
|
||||||
self.char_ops = CharacterOps(char_ops_params)
|
self.char_ops = CharacterOps(char_ops_params)
|
||||||
|
|
||||||
def resize_norm_img(self, img, max_wh_ratio):
|
def resize_norm_img(self, img, max_wh_ratio):
|
||||||
|
@ -75,6 +79,83 @@ class TextRecognizer(object):
|
||||||
padding_im[:, :, 0:resized_w] = resized_image
|
padding_im[:, :, 0:resized_w] = resized_image
|
||||||
return padding_im
|
return padding_im
|
||||||
|
|
||||||
|
def resize_norm_img_srn(self, img, image_shape):
|
||||||
|
imgC, imgH, imgW = image_shape
|
||||||
|
|
||||||
|
img_black = np.zeros((imgH, imgW))
|
||||||
|
im_hei = img.shape[0]
|
||||||
|
im_wid = img.shape[1]
|
||||||
|
|
||||||
|
if im_wid <= im_hei * 1:
|
||||||
|
img_new = cv2.resize(img, (imgH * 1, imgH))
|
||||||
|
elif im_wid <= im_hei * 2:
|
||||||
|
img_new = cv2.resize(img, (imgH * 2, imgH))
|
||||||
|
elif im_wid <= im_hei * 3:
|
||||||
|
img_new = cv2.resize(img, (imgH * 3, imgH))
|
||||||
|
else:
|
||||||
|
img_new = cv2.resize(img, (imgW, imgH))
|
||||||
|
|
||||||
|
img_np = np.asarray(img_new)
|
||||||
|
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
|
||||||
|
img_black[:, 0:img_np.shape[1]] = img_np
|
||||||
|
img_black = img_black[:, :, np.newaxis]
|
||||||
|
|
||||||
|
row, col, c = img_black.shape
|
||||||
|
c = 1
|
||||||
|
|
||||||
|
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||||
|
|
||||||
|
def srn_other_inputs(self, image_shape, num_heads, max_text_length,
|
||||||
|
char_num):
|
||||||
|
|
||||||
|
imgC, imgH, imgW = image_shape
|
||||||
|
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||||
|
|
||||||
|
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
|
||||||
|
(feature_dim, 1)).astype('int64')
|
||||||
|
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
|
||||||
|
(max_text_length, 1)).astype('int64')
|
||||||
|
|
||||||
|
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||||
|
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
|
||||||
|
[-1, 1, max_text_length, max_text_length])
|
||||||
|
gsrm_slf_attn_bias1 = np.tile(
|
||||||
|
gsrm_slf_attn_bias1,
|
||||||
|
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
||||||
|
|
||||||
|
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
|
||||||
|
[-1, 1, max_text_length, max_text_length])
|
||||||
|
gsrm_slf_attn_bias2 = np.tile(
|
||||||
|
gsrm_slf_attn_bias2,
|
||||||
|
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
||||||
|
|
||||||
|
encoder_word_pos = encoder_word_pos[np.newaxis, :]
|
||||||
|
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
|
||||||
|
|
||||||
|
return [
|
||||||
|
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||||
|
gsrm_slf_attn_bias2
|
||||||
|
]
|
||||||
|
|
||||||
|
def process_image_srn(self,
|
||||||
|
img,
|
||||||
|
image_shape,
|
||||||
|
num_heads,
|
||||||
|
max_text_length,
|
||||||
|
char_ops=None):
|
||||||
|
norm_img = self.resize_norm_img_srn(img, image_shape)
|
||||||
|
norm_img = norm_img[np.newaxis, :]
|
||||||
|
char_num = char_ops.get_char_num()
|
||||||
|
|
||||||
|
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
||||||
|
self.srn_other_inputs(image_shape, num_heads, max_text_length, char_num)
|
||||||
|
|
||||||
|
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
|
||||||
|
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
|
||||||
|
|
||||||
|
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||||
|
gsrm_slf_attn_bias2)
|
||||||
|
|
||||||
def __call__(self, img_list):
|
def __call__(self, img_list):
|
||||||
img_num = len(img_list)
|
img_num = len(img_list)
|
||||||
# Calculate the aspect ratio of all text bars
|
# Calculate the aspect ratio of all text bars
|
||||||
|
@ -98,13 +179,55 @@ class TextRecognizer(object):
|
||||||
wh_ratio = w * 1.0 / h
|
wh_ratio = w * 1.0 / h
|
||||||
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||||
for ino in range(beg_img_no, end_img_no):
|
for ino in range(beg_img_no, end_img_no):
|
||||||
# norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
|
if self.loss_type != "srn":
|
||||||
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
||||||
max_wh_ratio)
|
max_wh_ratio)
|
||||||
norm_img = norm_img[np.newaxis, :]
|
norm_img = norm_img[np.newaxis, :]
|
||||||
norm_img_batch.append(norm_img)
|
norm_img_batch.append(norm_img)
|
||||||
norm_img_batch = np.concatenate(norm_img_batch)
|
else:
|
||||||
|
norm_img = self.process_image_srn(img_list[indices[ino]],
|
||||||
|
self.rec_image_shape, 8,
|
||||||
|
25, self.char_ops)
|
||||||
|
encoder_word_pos_list = []
|
||||||
|
gsrm_word_pos_list = []
|
||||||
|
gsrm_slf_attn_bias1_list = []
|
||||||
|
gsrm_slf_attn_bias2_list = []
|
||||||
|
encoder_word_pos_list.append(norm_img[1])
|
||||||
|
gsrm_word_pos_list.append(norm_img[2])
|
||||||
|
gsrm_slf_attn_bias1_list.append(norm_img[3])
|
||||||
|
gsrm_slf_attn_bias2_list.append(norm_img[4])
|
||||||
|
norm_img_batch.append(norm_img[0])
|
||||||
|
|
||||||
|
norm_img_batch = np.concatenate(norm_img_batch, axis=0)
|
||||||
norm_img_batch = norm_img_batch.copy()
|
norm_img_batch = norm_img_batch.copy()
|
||||||
|
|
||||||
|
if self.loss_type == "srn":
|
||||||
|
starttime = time.time()
|
||||||
|
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
|
||||||
|
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
|
||||||
|
gsrm_slf_attn_bias1_list = np.concatenate(
|
||||||
|
gsrm_slf_attn_bias1_list)
|
||||||
|
gsrm_slf_attn_bias2_list = np.concatenate(
|
||||||
|
gsrm_slf_attn_bias2_list)
|
||||||
|
starttime = time.time()
|
||||||
|
|
||||||
|
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
|
||||||
|
encoder_word_pos_list = fluid.core.PaddleTensor(
|
||||||
|
encoder_word_pos_list)
|
||||||
|
gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list)
|
||||||
|
gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor(
|
||||||
|
gsrm_slf_attn_bias1_list)
|
||||||
|
gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor(
|
||||||
|
gsrm_slf_attn_bias2_list)
|
||||||
|
|
||||||
|
inputs = [
|
||||||
|
norm_img_batch, encoder_word_pos_list,
|
||||||
|
gsrm_slf_attn_bias1_list, gsrm_slf_attn_bias2_list,
|
||||||
|
gsrm_word_pos_list
|
||||||
|
]
|
||||||
|
|
||||||
|
self.predictor.run(inputs)
|
||||||
|
else:
|
||||||
starttime = time.time()
|
starttime = time.time()
|
||||||
if self.use_zero_copy_run:
|
if self.use_zero_copy_run:
|
||||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||||
|
@ -136,6 +259,26 @@ class TextRecognizer(object):
|
||||||
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
||||||
# rec_res.append([preds_text, score])
|
# rec_res.append([preds_text, score])
|
||||||
rec_res[indices[beg_img_no + rno]] = [preds_text, score]
|
rec_res[indices[beg_img_no + rno]] = [preds_text, score]
|
||||||
|
elif self.loss_type == 'srn':
|
||||||
|
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
|
||||||
|
probs = self.output_tensors[1].copy_to_cpu()
|
||||||
|
char_num = self.char_ops.get_char_num()
|
||||||
|
preds = rec_idx_batch.reshape(-1)
|
||||||
|
elapse = time.time() - starttime
|
||||||
|
predict_time += elapse
|
||||||
|
total_preds = preds.copy()
|
||||||
|
for ino in range(int(len(rec_idx_batch) / self.text_len)):
|
||||||
|
preds = total_preds[ino * self.text_len:(ino + 1) *
|
||||||
|
self.text_len]
|
||||||
|
ind = np.argmax(probs, axis=1)
|
||||||
|
valid_ind = np.where(preds != int(char_num - 1))[0]
|
||||||
|
if len(valid_ind) == 0:
|
||||||
|
continue
|
||||||
|
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
||||||
|
preds = preds[:valid_ind[-1] + 1]
|
||||||
|
preds_text = self.char_ops.decode(preds)
|
||||||
|
|
||||||
|
rec_res[indices[beg_img_no + ino]] = [preds_text, score]
|
||||||
else:
|
else:
|
||||||
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
|
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
|
||||||
predict_batch = self.output_tensors[1].copy_to_cpu()
|
predict_batch = self.output_tensors[1].copy_to_cpu()
|
||||||
|
@ -170,6 +313,7 @@ def main(args):
|
||||||
continue
|
continue
|
||||||
valid_image_file_list.append(image_file)
|
valid_image_file_list.append(image_file)
|
||||||
img_list.append(img)
|
img_list.append(img)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rec_res, predict_time = text_recognizer(img_list)
|
rec_res, predict_time = text_recognizer(img_list)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -114,7 +114,8 @@ def create_predictor(args, mode):
|
||||||
|
|
||||||
predictor = create_paddle_predictor(config)
|
predictor = create_paddle_predictor(config)
|
||||||
input_names = predictor.get_input_names()
|
input_names = predictor.get_input_names()
|
||||||
input_tensor = predictor.get_input_tensor(input_names[0])
|
for name in input_names:
|
||||||
|
input_tensor = predictor.get_input_tensor(name)
|
||||||
output_names = predictor.get_output_names()
|
output_names = predictor.get_output_names()
|
||||||
output_tensors = []
|
output_tensors = []
|
||||||
for output_name in output_names:
|
for output_name in output_names:
|
||||||
|
|
|
@ -208,10 +208,19 @@ def build_export(config, main_prog, startup_prog):
|
||||||
with fluid.unique_name.guard():
|
with fluid.unique_name.guard():
|
||||||
func_infor = config['Architecture']['function']
|
func_infor = config['Architecture']['function']
|
||||||
model = create_module(func_infor)(params=config)
|
model = create_module(func_infor)(params=config)
|
||||||
|
algorithm = config['Global']['algorithm']
|
||||||
|
if algorithm == "SRN":
|
||||||
|
image, others, outputs = model(mode='export')
|
||||||
|
else:
|
||||||
image, outputs = model(mode='export')
|
image, outputs = model(mode='export')
|
||||||
fetches_var_name = sorted([name for name in outputs.keys()])
|
fetches_var_name = sorted([name for name in outputs.keys()])
|
||||||
fetches_var = [outputs[name] for name in fetches_var_name]
|
fetches_var = [outputs[name] for name in fetches_var_name]
|
||||||
|
if algorithm == "SRN":
|
||||||
|
others_var_names = sorted([name for name in others.keys()])
|
||||||
|
feeded_var_names = [image.name] + others_var_names
|
||||||
|
else:
|
||||||
feeded_var_names = [image.name]
|
feeded_var_names = [image.name]
|
||||||
|
|
||||||
target_vars = fetches_var
|
target_vars = fetches_var
|
||||||
return feeded_var_names, target_vars, fetches_var_name
|
return feeded_var_names, target_vars, fetches_var_name
|
||||||
|
|
||||||
|
@ -409,7 +418,9 @@ def preprocess():
|
||||||
check_gpu(use_gpu)
|
check_gpu(use_gpu)
|
||||||
|
|
||||||
alg = config['Global']['algorithm']
|
alg = config['Global']['algorithm']
|
||||||
assert alg in ['EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']
|
assert alg in [
|
||||||
|
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN'
|
||||||
|
]
|
||||||
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
|
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
|
||||||
config['Global']['char_ops'] = CharacterOps(config['Global'])
|
config['Global']['char_ops'] = CharacterOps(config['Global'])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue