add python interface
This commit is contained in:
parent
3ebebae3e5
commit
0a011e564a
|
@ -42,6 +42,7 @@ class TextDetector(object):
|
|||
def __init__(self, args):
|
||||
max_side_len = args.det_max_side_len
|
||||
self.det_algorithm = args.det_algorithm
|
||||
self.use_zero_copy_run = args.use_zero_copy_run
|
||||
preprocess_params = {'max_side_len': max_side_len}
|
||||
postprocess_params = {}
|
||||
if self.det_algorithm == "DB":
|
||||
|
@ -138,8 +139,12 @@ class TextDetector(object):
|
|||
return None, 0
|
||||
im = im.copy()
|
||||
starttime = time.time()
|
||||
im = fluid.core.PaddleTensor(im)
|
||||
self.predictor.run([im])
|
||||
if self.use_zero_copy_run:
|
||||
self.input_tensor.copy_from_cpu(im)
|
||||
self.predictor.zero_copy_run()
|
||||
else:
|
||||
im = fluid.core.PaddleTensor(im)
|
||||
self.predictor.run([im])
|
||||
outputs = []
|
||||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
|
|
|
@ -40,6 +40,7 @@ class TextRecognizer(object):
|
|||
self.character_type = args.rec_char_type
|
||||
self.rec_batch_num = args.rec_batch_num
|
||||
self.rec_algorithm = args.rec_algorithm
|
||||
self.use_zero_copy_run = args.use_zero_copy_run
|
||||
char_ops_params = {
|
||||
"character_type": args.rec_char_type,
|
||||
"character_dict_path": args.rec_char_dict_path,
|
||||
|
@ -105,8 +106,12 @@ class TextRecognizer(object):
|
|||
norm_img_batch = np.concatenate(norm_img_batch)
|
||||
norm_img_batch = norm_img_batch.copy()
|
||||
starttime = time.time()
|
||||
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
|
||||
self.predictor.run([norm_img_batch])
|
||||
if self.use_zero_copy_run:
|
||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||
self.predictor.zero_copy_run()
|
||||
else:
|
||||
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
|
||||
self.predictor.run([norm_img_batch])
|
||||
|
||||
if self.loss_type == "ctc":
|
||||
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
|
||||
|
|
|
@ -71,6 +71,7 @@ def parse_args():
|
|||
default="./ppocr/utils/ppocr_keys_v1.txt")
|
||||
parser.add_argument("--use_space_char", type=bool, default=True)
|
||||
parser.add_argument("--enable_mkldnn", type=bool, default=False)
|
||||
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -105,8 +106,12 @@ def create_predictor(args, mode):
|
|||
#config.enable_memory_optim()
|
||||
config.disable_glog_info()
|
||||
|
||||
# use zero copy
|
||||
config.switch_use_feed_fetch_ops(True)
|
||||
if args.use_zero_copy_run:
|
||||
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
|
||||
config.switch_use_feed_fetch_ops(False)
|
||||
else:
|
||||
config.switch_use_feed_fetch_ops(True)
|
||||
|
||||
predictor = create_paddle_predictor(config)
|
||||
input_names = predictor.get_input_names()
|
||||
input_tensor = predictor.get_input_tensor(input_names[0])
|
||||
|
|
Loading…
Reference in New Issue