diff --git a/ppocr/data/imaug/randaugment.py b/ppocr/data/imaug/randaugment.py index 0bfac353..56f114d2 100644 --- a/ppocr/data/imaug/randaugment.py +++ b/ppocr/data/imaug/randaugment.py @@ -117,13 +117,16 @@ class RawRandAugment(object): class RandAugment(RawRandAugment): """ RandAugment wrapper to auto fit different img types """ - def __init__(self, *args, **kwargs): + def __init__(self, prob=0.5, *args, **kwargs): + self.prob = prob if six.PY2: super(RandAugment, self).__init__(*args, **kwargs) else: super().__init__(*args, **kwargs) def __call__(self, data): + if np.random.rand() > self.prob: + return data img = data['image'] if not isinstance(img, Image.Image): img = np.ascontiguousarray(img) diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index 074172cc..d2592c6c 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -98,10 +98,10 @@ class TextClassifier(object): norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = norm_img_batch.copy() starttime = time.time() - self.input_tensor.copy_from_cpu(norm_img_batch) self.predictor.run() prob_out = self.output_tensors[0].copy_to_cpu() + self.predictor.try_shrink_memory() cls_result = self.postprocess_op(prob_out) elapse += time.time() - starttime for rno in range(len(cls_result)): diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index b14825bd..f5ea0504 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -180,7 +180,7 @@ class TextDetector(object): preds['maps'] = outputs[0] else: raise NotImplementedError - + self.predictor.try_shrink_memory() post_result = self.postprocess_op(preds, shape_list) dt_boxes = post_result[0]['points'] if self.det_algorithm == "SAST" and self.det_sast_polygon: diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index b24e57dd..1cb6e01b 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -237,7 +237,7 @@ class TextRecognizer(object): output = output_tensor.copy_to_cpu() outputs.append(output) preds = outputs[0] - + self.predictor.try_shrink_memory() rec_result = self.postprocess_op(preds) for rno in range(len(rec_result)): rec_res[indices[beg_img_no + rno]] = rec_result[rno] diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 9aa0afed..7391e936 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -145,7 +145,8 @@ def create_predictor(args, mode, logger): #config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'}) args.rec_batch_num = 1 - # config.enable_memory_optim() + # enable memory optim + config.enable_memory_optim() config.disable_glog_info() config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")