polish code
This commit is contained in:
parent
bc1ad20701
commit
aa7e9ac34e
|
@ -196,32 +196,36 @@ class TextRecognizer(object):
|
|||
norm_img_batch.append(norm_img[0])
|
||||
|
||||
norm_img_batch = np.concatenate(norm_img_batch, axis=0)
|
||||
norm_img_batch = norm_img_batch.copy()
|
||||
|
||||
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
|
||||
if self.loss_type == "srn":
|
||||
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
|
||||
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
|
||||
gsrm_slf_attn_bias1_list = np.concatenate(
|
||||
gsrm_slf_attn_bias1_list)
|
||||
gsrm_slf_attn_bias2_list = np.concatenate(
|
||||
gsrm_slf_attn_bias2_list)
|
||||
starttime = time.time()
|
||||
|
||||
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
|
||||
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
|
||||
encoder_word_pos_list = fluid.core.PaddleTensor(
|
||||
encoder_word_pos_list)
|
||||
gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list)
|
||||
gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor(
|
||||
gsrm_slf_attn_bias1_list)
|
||||
gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor(
|
||||
gsrm_slf_attn_bias2_list)
|
||||
|
||||
gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list)
|
||||
inputs = [
|
||||
norm_img_batch, encoder_word_pos_list,
|
||||
gsrm_slf_attn_bias1_list, gsrm_slf_attn_bias2_list,
|
||||
gsrm_word_pos_list
|
||||
]
|
||||
|
||||
gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list)
|
||||
|
||||
starttime = time.time()
|
||||
|
||||
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
|
||||
encoder_word_pos_list = fluid.core.PaddleTensor(
|
||||
encoder_word_pos_list)
|
||||
gsrm_word_pos_list = fluid.core.PaddleTensor(gsrm_word_pos_list)
|
||||
gsrm_slf_attn_bias1_list = fluid.core.PaddleTensor(
|
||||
gsrm_slf_attn_bias1_list)
|
||||
gsrm_slf_attn_bias2_list = fluid.core.PaddleTensor(
|
||||
gsrm_slf_attn_bias2_list)
|
||||
|
||||
inputs = [
|
||||
norm_img_batch, encoder_word_pos_list, gsrm_slf_attn_bias1_list,
|
||||
gsrm_slf_attn_bias2_list, gsrm_word_pos_list
|
||||
]
|
||||
|
||||
self.predictor.run(inputs)
|
||||
self.predictor.run(inputs)
|
||||
else:
|
||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||
self.predictor.zero_copy_run()
|
||||
|
||||
if self.loss_type == "ctc":
|
||||
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
|
||||
|
|
|
@ -59,10 +59,10 @@ def parse_args():
|
|||
parser.add_argument("--det_sast_polygon", type=bool, default=False)
|
||||
|
||||
#params for text recognizer
|
||||
parser.add_argument("--rec_algorithm", type=str, default='SRN')
|
||||
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
|
||||
parser.add_argument("--rec_model_dir", type=str)
|
||||
parser.add_argument("--rec_image_shape", type=str, default="3, 64, 256")
|
||||
parser.add_argument("--rec_char_type", type=str, default='en')
|
||||
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
|
||||
parser.add_argument("--rec_char_type", type=str, default='ch')
|
||||
parser.add_argument("--rec_batch_num", type=int, default=30)
|
||||
parser.add_argument("--max_text_length", type=int, default=25)
|
||||
parser.add_argument(
|
||||
|
@ -107,11 +107,9 @@ def create_predictor(args, mode):
|
|||
|
||||
# use zero copy
|
||||
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
|
||||
#config.switch_use_feed_fetch_ops(False)
|
||||
config.switch_use_feed_fetch_ops(True)
|
||||
config.switch_use_feed_fetch_ops(False)
|
||||
predictor = create_paddle_predictor(config)
|
||||
input_names = predictor.get_input_names()
|
||||
print(input_names)
|
||||
for name in input_names:
|
||||
input_tensor = predictor.get_input_tensor(name)
|
||||
output_names = predictor.get_output_names()
|
||||
|
|
|
@ -162,10 +162,7 @@ def main():
|
|||
|
||||
fluid.io.save_inference_model(
|
||||
"./output/",
|
||||
feeded_var_names=[
|
||||
'image', 'encoder_word_pos', 'gsrm_slf_attn_bias1',
|
||||
'gsrm_slf_attn_bias2', 'gsrm_word_pos'
|
||||
],
|
||||
feeded_var_names=['image'],
|
||||
target_vars=target_var,
|
||||
executor=exe,
|
||||
main_program=eval_prog,
|
||||
|
|
Loading…
Reference in New Issue