Merge pull request #592 from littletomatodonkey/fix_predictor_run
Replace zero_copy_run with run to fix a memory leak
This commit is contained in:
commit
967115b8e9
|
@ -41,6 +41,8 @@ public:
|
|||
|
||||
this->use_mkldnn = bool(stoi(config_map_["use_mkldnn"]));
|
||||
|
||||
this->use_zero_copy_run = bool(stoi(config_map_["use_zero_copy_run"]));
|
||||
|
||||
this->max_side_len = stoi(config_map_["max_side_len"]);
|
||||
|
||||
this->det_db_thresh = stod(config_map_["det_db_thresh"]);
|
||||
|
@ -68,6 +70,8 @@ public:
|
|||
|
||||
bool use_mkldnn = false;
|
||||
|
||||
bool use_zero_copy_run = false;
|
||||
|
||||
int max_side_len = 960;
|
||||
|
||||
double det_db_thresh = 0.3;
|
||||
|
|
|
@ -39,8 +39,8 @@ public:
|
|||
explicit DBDetector(const std::string &model_dir, const bool &use_gpu,
|
||||
const int &gpu_id, const int &gpu_mem,
|
||||
const int &cpu_math_library_num_threads,
|
||||
const bool &use_mkldnn, const int &max_side_len,
|
||||
const double &det_db_thresh,
|
||||
const bool &use_mkldnn, const bool &use_zero_copy_run,
|
||||
const int &max_side_len, const double &det_db_thresh,
|
||||
const double &det_db_box_thresh,
|
||||
const double &det_db_unclip_ratio,
|
||||
const bool &visualize) {
|
||||
|
@ -49,6 +49,7 @@ public:
|
|||
this->gpu_mem_ = gpu_mem;
|
||||
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
|
||||
this->use_mkldnn_ = use_mkldnn;
|
||||
this->use_zero_copy_run_ = use_zero_copy_run;
|
||||
|
||||
this->max_side_len_ = max_side_len;
|
||||
|
||||
|
@ -75,6 +76,7 @@ private:
|
|||
int gpu_mem_ = 4000;
|
||||
int cpu_math_library_num_threads_ = 4;
|
||||
bool use_mkldnn_ = false;
|
||||
bool use_zero_copy_run_ = false;
|
||||
|
||||
int max_side_len_ = 960;
|
||||
|
||||
|
|
|
@ -38,12 +38,14 @@ public:
|
|||
explicit CRNNRecognizer(const std::string &model_dir, const bool &use_gpu,
|
||||
const int &gpu_id, const int &gpu_mem,
|
||||
const int &cpu_math_library_num_threads,
|
||||
const bool &use_mkldnn, const string &label_path) {
|
||||
const bool &use_mkldnn, const bool &use_zero_copy_run,
|
||||
const string &label_path) {
|
||||
this->use_gpu_ = use_gpu;
|
||||
this->gpu_id_ = gpu_id;
|
||||
this->gpu_mem_ = gpu_mem;
|
||||
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
|
||||
this->use_mkldnn_ = use_mkldnn;
|
||||
this->use_zero_copy_run_ = use_zero_copy_run;
|
||||
|
||||
this->label_list_ = Utility::ReadDict(label_path);
|
||||
this->label_list_.push_back(" ");
|
||||
|
@ -64,6 +66,7 @@ private:
|
|||
int gpu_mem_ = 4000;
|
||||
int cpu_math_library_num_threads_ = 4;
|
||||
bool use_mkldnn_ = false;
|
||||
bool use_zero_copy_run_ = false;
|
||||
|
||||
std::vector<std::string> label_list_;
|
||||
|
||||
|
|
|
@ -48,14 +48,15 @@ int main(int argc, char **argv) {
|
|||
|
||||
cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
|
||||
|
||||
DBDetector det(config.det_model_dir, config.use_gpu, config.gpu_id,
|
||||
config.gpu_mem, config.cpu_math_library_num_threads,
|
||||
config.use_mkldnn, config.max_side_len, config.det_db_thresh,
|
||||
config.det_db_box_thresh, config.det_db_unclip_ratio,
|
||||
config.visualize);
|
||||
DBDetector det(
|
||||
config.det_model_dir, config.use_gpu, config.gpu_id, config.gpu_mem,
|
||||
config.cpu_math_library_num_threads, config.use_mkldnn,
|
||||
config.use_zero_copy_run, config.max_side_len, config.det_db_thresh,
|
||||
config.det_db_box_thresh, config.det_db_unclip_ratio, config.visualize);
|
||||
CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id,
|
||||
config.gpu_mem, config.cpu_math_library_num_threads,
|
||||
config.use_mkldnn, config.char_list_file);
|
||||
config.use_mkldnn, config.use_zero_copy_run,
|
||||
config.char_list_file);
|
||||
|
||||
auto start = std::chrono::system_clock::now();
|
||||
std::vector<std::vector<std::vector<int>>> boxes;
|
||||
|
|
|
@ -31,7 +31,8 @@ void DBDetector::LoadModel(const std::string &model_dir) {
|
|||
}
|
||||
|
||||
// false for zero copy tensor
|
||||
config.SwitchUseFeedFetchOps(false);
|
||||
// true for common tensor
|
||||
config.SwitchUseFeedFetchOps(!this->use_zero_copy_run_);
|
||||
// true for multiple input
|
||||
config.SwitchSpecifyInputNames(true);
|
||||
|
||||
|
@ -59,12 +60,22 @@ void DBDetector::Run(cv::Mat &img,
|
|||
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
|
||||
this->permute_op_.Run(&resize_img, input.data());
|
||||
|
||||
auto input_names = this->predictor_->GetInputNames();
|
||||
auto input_t = this->predictor_->GetInputTensor(input_names[0]);
|
||||
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
|
||||
input_t->copy_from_cpu(input.data());
|
||||
|
||||
this->predictor_->ZeroCopyRun();
|
||||
// Inference.
|
||||
if (this->use_zero_copy_run_) {
|
||||
auto input_names = this->predictor_->GetInputNames();
|
||||
auto input_t = this->predictor_->GetInputTensor(input_names[0]);
|
||||
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
|
||||
input_t->copy_from_cpu(input.data());
|
||||
this->predictor_->ZeroCopyRun();
|
||||
} else {
|
||||
paddle::PaddleTensor input_t;
|
||||
input_t.shape = {1, 3, resize_img.rows, resize_img.cols};
|
||||
input_t.data =
|
||||
paddle::PaddleBuf(input.data(), input.size() * sizeof(float));
|
||||
input_t.dtype = PaddleDType::FLOAT32;
|
||||
std::vector<paddle::PaddleTensor> outputs;
|
||||
this->predictor_->Run({input_t}, &outputs, 1);
|
||||
}
|
||||
|
||||
std::vector<float> out_data;
|
||||
auto output_names = this->predictor_->GetOutputNames();
|
||||
|
|
|
@ -39,18 +39,29 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
|||
|
||||
this->permute_op_.Run(&resize_img, input.data());
|
||||
|
||||
auto input_names = this->predictor_->GetInputNames();
|
||||
auto input_t = this->predictor_->GetInputTensor(input_names[0]);
|
||||
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
|
||||
input_t->copy_from_cpu(input.data());
|
||||
|
||||
this->predictor_->ZeroCopyRun();
|
||||
// Inference.
|
||||
if (this->use_zero_copy_run_) {
|
||||
auto input_names = this->predictor_->GetInputNames();
|
||||
auto input_t = this->predictor_->GetInputTensor(input_names[0]);
|
||||
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
|
||||
input_t->copy_from_cpu(input.data());
|
||||
this->predictor_->ZeroCopyRun();
|
||||
} else {
|
||||
paddle::PaddleTensor input_t;
|
||||
input_t.shape = {1, 3, resize_img.rows, resize_img.cols};
|
||||
input_t.data =
|
||||
paddle::PaddleBuf(input.data(), input.size() * sizeof(float));
|
||||
input_t.dtype = PaddleDType::FLOAT32;
|
||||
std::vector<paddle::PaddleTensor> outputs;
|
||||
this->predictor_->Run({input_t}, &outputs, 1);
|
||||
}
|
||||
|
||||
std::vector<int64_t> rec_idx;
|
||||
auto output_names = this->predictor_->GetOutputNames();
|
||||
auto output_t = this->predictor_->GetOutputTensor(output_names[0]);
|
||||
auto rec_idx_lod = output_t->lod();
|
||||
auto shape_out = output_t->shape();
|
||||
|
||||
int out_num = std::accumulate(shape_out.begin(), shape_out.end(), 1,
|
||||
std::multiplies<int>());
|
||||
|
||||
|
@ -120,7 +131,8 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
|
|||
}
|
||||
|
||||
// false for zero copy tensor
|
||||
config.SwitchUseFeedFetchOps(false);
|
||||
// true for common tensor
|
||||
config.SwitchUseFeedFetchOps(!this->use_zero_copy_run_);
|
||||
// true for multiple input
|
||||
config.SwitchSpecifyInputNames(true);
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ gpu_id 0
|
|||
gpu_mem 4000
|
||||
cpu_math_library_num_threads 10
|
||||
use_mkldnn 0
|
||||
use_zero_copy_run 1
|
||||
|
||||
# det config
|
||||
max_side_len 960
|
||||
|
|
|
@ -17,28 +17,32 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
|
|||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
import tools.infer.utility as utility
|
||||
from ppocr.utils.utility import initial_logger
|
||||
logger = initial_logger()
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
import cv2
|
||||
from ppocr.data.det.sast_process import SASTProcessTest
|
||||
from ppocr.data.det.east_process import EASTProcessTest
|
||||
from ppocr.data.det.db_process import DBProcessTest
|
||||
from ppocr.postprocess.db_postprocess import DBPostProcess
|
||||
from ppocr.postprocess.east_postprocess import EASTPostPocess
|
||||
from ppocr.postprocess.sast_postprocess import SASTPostProcess
|
||||
import copy
|
||||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
import sys
|
||||
|
||||
import paddle.fluid as fluid
|
||||
|
||||
import tools.infer.utility as utility
|
||||
from ppocr.utils.utility import initial_logger
|
||||
logger = initial_logger()
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
from ppocr.data.det.sast_process import SASTProcessTest
|
||||
from ppocr.data.det.east_process import EASTProcessTest
|
||||
from ppocr.data.det.db_process import DBProcessTest
|
||||
from ppocr.postprocess.db_postprocess import DBPostProcess
|
||||
from ppocr.postprocess.east_postprocess import EASTPostPocess
|
||||
from ppocr.postprocess.sast_postprocess import SASTPostProcess
|
||||
|
||||
|
||||
class TextDetector(object):
|
||||
def __init__(self, args):
|
||||
max_side_len = args.det_max_side_len
|
||||
self.det_algorithm = args.det_algorithm
|
||||
self.use_zero_copy_run = args.use_zero_copy_run
|
||||
preprocess_params = {'max_side_len': max_side_len}
|
||||
postprocess_params = {}
|
||||
if self.det_algorithm == "DB":
|
||||
|
@ -127,7 +131,7 @@ class TextDetector(object):
|
|||
dt_boxes_new.append(box)
|
||||
dt_boxes = np.array(dt_boxes_new)
|
||||
return dt_boxes
|
||||
|
||||
|
||||
def __call__(self, img):
|
||||
ori_im = img.copy()
|
||||
im, ratio_list = self.preprocess_op(img)
|
||||
|
@ -135,8 +139,12 @@ class TextDetector(object):
|
|||
return None, 0
|
||||
im = im.copy()
|
||||
starttime = time.time()
|
||||
self.input_tensor.copy_from_cpu(im)
|
||||
self.predictor.zero_copy_run()
|
||||
if self.use_zero_copy_run:
|
||||
self.input_tensor.copy_from_cpu(im)
|
||||
self.predictor.zero_copy_run()
|
||||
else:
|
||||
im = fluid.core.PaddleTensor(im)
|
||||
self.predictor.run([im])
|
||||
outputs = []
|
||||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
|
@ -152,7 +160,7 @@ class TextDetector(object):
|
|||
outs_dict['f_tvo'] = outputs[3]
|
||||
else:
|
||||
outs_dict['maps'] = outputs[0]
|
||||
|
||||
|
||||
dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list])
|
||||
dt_boxes = dt_boxes_list[0]
|
||||
if self.det_algorithm == "SAST" and self.det_sast_polygon:
|
||||
|
|
|
@ -17,15 +17,18 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
|
|||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
import tools.infer.utility as utility
|
||||
from ppocr.utils.utility import initial_logger
|
||||
logger = initial_logger()
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
import cv2
|
||||
import copy
|
||||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
|
||||
import paddle.fluid as fluid
|
||||
|
||||
import tools.infer.utility as utility
|
||||
from ppocr.utils.utility import initial_logger
|
||||
logger = initial_logger()
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
from ppocr.utils.character import CharacterOps
|
||||
|
||||
|
||||
|
@ -37,6 +40,7 @@ class TextRecognizer(object):
|
|||
self.character_type = args.rec_char_type
|
||||
self.rec_batch_num = args.rec_batch_num
|
||||
self.rec_algorithm = args.rec_algorithm
|
||||
self.use_zero_copy_run = args.use_zero_copy_run
|
||||
char_ops_params = {
|
||||
"character_type": args.rec_char_type,
|
||||
"character_dict_path": args.rec_char_dict_path,
|
||||
|
@ -102,8 +106,12 @@ class TextRecognizer(object):
|
|||
norm_img_batch = np.concatenate(norm_img_batch)
|
||||
norm_img_batch = norm_img_batch.copy()
|
||||
starttime = time.time()
|
||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||
self.predictor.zero_copy_run()
|
||||
if self.use_zero_copy_run:
|
||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||
self.predictor.zero_copy_run()
|
||||
else:
|
||||
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
|
||||
self.predictor.run([norm_img_batch])
|
||||
|
||||
if self.loss_type == "ctc":
|
||||
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
|
||||
|
|
|
@ -71,6 +71,7 @@ def parse_args():
|
|||
default="./ppocr/utils/ppocr_keys_v1.txt")
|
||||
parser.add_argument("--use_space_char", type=bool, default=True)
|
||||
parser.add_argument("--enable_mkldnn", type=bool, default=False)
|
||||
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
|
@ -105,9 +106,12 @@ def create_predictor(args, mode):
|
|||
#config.enable_memory_optim()
|
||||
config.disable_glog_info()
|
||||
|
||||
# use zero copy
|
||||
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
|
||||
config.switch_use_feed_fetch_ops(False)
|
||||
if args.use_zero_copy_run:
|
||||
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
|
||||
config.switch_use_feed_fetch_ops(False)
|
||||
else:
|
||||
config.switch_use_feed_fetch_ops(True)
|
||||
|
||||
predictor = create_paddle_predictor(config)
|
||||
input_names = predictor.get_input_names()
|
||||
input_tensor = predictor.get_input_tensor(input_names[0])
|
||||
|
|
Loading…
Reference in New Issue