Merge pull request #4216 from andyjpaddle/fix_export_sar
fix sar export inference model
This commit is contained in: commit 6f7e07e6b3
@@ -49,6 +49,12 @@ def export_single_model(model, arch_config, save_path, logger):
                 ]
         ]
         model = to_static(model, input_spec=other_shape)
+    elif arch_config["algorithm"] == "SAR":
+        other_shape = [
+            paddle.static.InputSpec(
+                shape=[None, 3, 48, 160], dtype="float32"),
+        ]
+        model = to_static(model, input_spec=other_shape)
     else:
         infer_shape = [3, -1, -1]
         if arch_config["model_type"] == "rec":
@@ -68,6 +68,13 @@ class TextRecognizer(object):
                 "character_dict_path": args.rec_char_dict_path,
                 "use_space_char": args.use_space_char
             }
+        elif self.rec_algorithm == "SAR":
+            postprocess_params = {
+                'name': 'SARLabelDecode',
+                "character_type": args.rec_char_type,
+                "character_dict_path": args.rec_char_dict_path,
+                "use_space_char": args.use_space_char
+            }
         self.postprocess_op = build_post_process(postprocess_params)
         self.predictor, self.input_tensor, self.output_tensors, self.config = \
             utility.create_predictor(args, 'rec', logger)
@@ -194,6 +201,41 @@ class TextRecognizer(object):
         return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
                 gsrm_slf_attn_bias2)
 
+    def resize_norm_img_sar(self, img, image_shape,
+                            width_downsample_ratio=0.25):
+        imgC, imgH, imgW_min, imgW_max = image_shape
+        h = img.shape[0]
+        w = img.shape[1]
+        valid_ratio = 1.0
+        # make sure new_width is an integral multiple of width_divisor.
+        width_divisor = int(1 / width_downsample_ratio)
+        # resize
+        ratio = w / float(h)
+        resize_w = math.ceil(imgH * ratio)
+        if resize_w % width_divisor != 0:
+            resize_w = round(resize_w / width_divisor) * width_divisor
+        if imgW_min is not None:
+            resize_w = max(imgW_min, resize_w)
+        if imgW_max is not None:
+            valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
+            resize_w = min(imgW_max, resize_w)
+        resized_image = cv2.resize(img, (resize_w, imgH))
+        resized_image = resized_image.astype('float32')
+        # norm
+        if image_shape[0] == 1:
+            resized_image = resized_image / 255
+            resized_image = resized_image[np.newaxis, :]
+        else:
+            resized_image = resized_image.transpose((2, 0, 1)) / 255
+        resized_image -= 0.5
+        resized_image /= 0.5
+        resize_shape = resized_image.shape
+        padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
+        padding_im[:, :, 0:resize_w] = resized_image
+        pad_shape = padding_im.shape
+
+        return padding_im, resize_shape, pad_shape, valid_ratio
+
     def __call__(self, img_list):
         img_num = len(img_list)
         # Calculate the aspect ratio of all text bars
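A quick numeric trace of the resize/pad logic added above, using illustrative values (a 32x100 text crop and an assumed SAR rec_image_shape of (3, 48, 48, 160)):

import math

imgC, imgH, imgW_min, imgW_max = 3, 48, 48, 160    # assumed SAR image_shape
h, w = 32, 100                                      # example text crop
width_divisor = int(1 / 0.25)                       # 4
resize_w = math.ceil(imgH * (w / float(h)))         # ceil(48 * 3.125) = 150
if resize_w % width_divisor != 0:
    resize_w = round(resize_w / width_divisor) * width_divisor  # 152
resize_w = max(imgW_min, resize_w)                  # 152
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)   # 0.95
resize_w = min(imgW_max, resize_w)                  # 152
print(resize_w, valid_ratio)                        # 152 0.95
# The crop is resized to 48x152 and right-padded with -1.0 up to width 160.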
@@ -216,11 +258,19 @@ class TextRecognizer(object):
                 wh_ratio = w * 1.0 / h
                 max_wh_ratio = max(max_wh_ratio, wh_ratio)
             for ino in range(beg_img_no, end_img_no):
-                if self.rec_algorithm != "SRN":
+                if self.rec_algorithm != "SRN" and self.rec_algorithm != "SAR":
                     norm_img = self.resize_norm_img(img_list[indices[ino]],
                                                     max_wh_ratio)
                     norm_img = norm_img[np.newaxis, :]
                     norm_img_batch.append(norm_img)
+                elif self.rec_algorithm == "SAR":
+                    norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
+                        img_list[indices[ino]], self.rec_image_shape)
+                    norm_img = norm_img[np.newaxis, :]
+                    valid_ratio = np.expand_dims(valid_ratio, axis=0)
+                    valid_ratios = []
+                    valid_ratios.append(valid_ratio)
+                    norm_img_batch.append(norm_img)
                 else:
                     norm_img = self.process_image_srn(
                         img_list[indices[ino]], self.rec_image_shape, 8, 25)
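For a single-image batch, the valid_ratio bookkeeping above reduces to (illustrative values):

import numpy as np

valid_ratio = 0.95
valid_ratio = np.expand_dims(valid_ratio, axis=0)   # array([0.95]) with shape (1,)
valid_ratios = [valid_ratio]
print(np.concatenate(valid_ratios))                 # [0.95], consumed in the next hunk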
@@ -266,6 +316,25 @@ class TextRecognizer(object):
                 if self.benchmark:
                     self.autolog.times.stamp()
                 preds = {"predict": outputs[2]}
+            elif self.rec_algorithm == "SAR":
+                valid_ratios = np.concatenate(valid_ratios)
+                inputs = [
+                    norm_img_batch,
+                    valid_ratios,
+                ]
+                input_names = self.predictor.get_input_names()
+                for i in range(len(input_names)):
+                    input_tensor = self.predictor.get_input_handle(input_names[
+                        i])
+                    input_tensor.copy_from_cpu(inputs[i])
+                self.predictor.run()
+                outputs = []
+                for output_tensor in self.output_tensors:
+                    output = output_tensor.copy_to_cpu()
+                    outputs.append(output)
+                if self.benchmark:
+                    self.autolog.times.stamp()
+                preds = outputs[0]
             else:
                 self.input_tensor.copy_from_cpu(norm_img_batch)
                 self.predictor.run()
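Sketch of feeding the two SAR inputs (padded image batch and valid_ratios) to an exported model with the Paddle Inference API, mirroring the branch above. The model paths, input order, and float32 dtype for valid_ratios are assumptions, not taken from the PR:

import numpy as np
from paddle.inference import Config, create_predictor

config = Config("./inference/rec_sar/inference.pdmodel",
                "./inference/rec_sar/inference.pdiparams")   # placeholder paths
predictor = create_predictor(config)

norm_img_batch = np.zeros((1, 3, 48, 160), dtype="float32")  # padded image batch
valid_ratios = np.array([0.95], dtype="float32")             # one ratio per image

# assume get_input_names() returns the inputs in (image, valid_ratio) order
for name, data in zip(predictor.get_input_names(), [norm_img_batch, valid_ratios]):
    predictor.get_input_handle(name).copy_from_cpu(data)
predictor.run()
preds = predictor.get_output_handle(predictor.get_output_names()[0]).copy_to_cpu()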