commit
45f647db3f
|
@ -25,9 +25,9 @@
|
|||
|
||||
namespace PaddleOCR {
|
||||
|
||||
class Config {
|
||||
class OCRConfig {
|
||||
public:
|
||||
explicit Config(const std::string &config_file) {
|
||||
explicit OCRConfig(const std::string &config_file) {
|
||||
config_map_ = LoadConfig(config_file);
|
||||
|
||||
this->use_gpu = bool(stoi(config_map_["use_gpu"]));
|
||||
|
@ -41,8 +41,6 @@ public:
|
|||
|
||||
this->use_mkldnn = bool(stoi(config_map_["use_mkldnn"]));
|
||||
|
||||
this->use_zero_copy_run = bool(stoi(config_map_["use_zero_copy_run"]));
|
||||
|
||||
this->max_side_len = stoi(config_map_["max_side_len"]);
|
||||
|
||||
this->det_db_thresh = stod(config_map_["det_db_thresh"]);
|
||||
|
@ -76,8 +74,6 @@ public:
|
|||
|
||||
bool use_mkldnn = false;
|
||||
|
||||
bool use_zero_copy_run = false;
|
||||
|
||||
int max_side_len = 960;
|
||||
|
||||
double det_db_thresh = 0.3;
|
||||
|
|
|
@ -30,6 +30,8 @@
|
|||
#include <include/preprocess_op.h>
|
||||
#include <include/utility.h>
|
||||
|
||||
using namespace paddle_infer;
|
||||
|
||||
namespace PaddleOCR {
|
||||
|
||||
class Classifier {
|
||||
|
@ -37,14 +39,12 @@ public:
|
|||
explicit Classifier(const std::string &model_dir, const bool &use_gpu,
|
||||
const int &gpu_id, const int &gpu_mem,
|
||||
const int &cpu_math_library_num_threads,
|
||||
const bool &use_mkldnn, const bool &use_zero_copy_run,
|
||||
const double &cls_thresh) {
|
||||
const bool &use_mkldnn, const double &cls_thresh) {
|
||||
this->use_gpu_ = use_gpu;
|
||||
this->gpu_id_ = gpu_id;
|
||||
this->gpu_mem_ = gpu_mem;
|
||||
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
|
||||
this->use_mkldnn_ = use_mkldnn;
|
||||
this->use_zero_copy_run_ = use_zero_copy_run;
|
||||
|
||||
this->cls_thresh = cls_thresh;
|
||||
|
||||
|
@ -57,14 +57,13 @@ public:
|
|||
cv::Mat Run(cv::Mat &img);
|
||||
|
||||
private:
|
||||
std::shared_ptr<PaddlePredictor> predictor_;
|
||||
std::shared_ptr<Predictor> predictor_;
|
||||
|
||||
bool use_gpu_ = false;
|
||||
int gpu_id_ = 0;
|
||||
int gpu_mem_ = 4000;
|
||||
int cpu_math_library_num_threads_ = 4;
|
||||
bool use_mkldnn_ = false;
|
||||
bool use_zero_copy_run_ = false;
|
||||
double cls_thresh = 0.5;
|
||||
|
||||
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
|
||||
|
|
|
@ -32,6 +32,8 @@
|
|||
#include <include/postprocess_op.h>
|
||||
#include <include/preprocess_op.h>
|
||||
|
||||
using namespace paddle_infer;
|
||||
|
||||
namespace PaddleOCR {
|
||||
|
||||
class DBDetector {
|
||||
|
@ -39,8 +41,8 @@ public:
|
|||
explicit DBDetector(const std::string &model_dir, const bool &use_gpu,
|
||||
const int &gpu_id, const int &gpu_mem,
|
||||
const int &cpu_math_library_num_threads,
|
||||
const bool &use_mkldnn, const bool &use_zero_copy_run,
|
||||
const int &max_side_len, const double &det_db_thresh,
|
||||
const bool &use_mkldnn, const int &max_side_len,
|
||||
const double &det_db_thresh,
|
||||
const double &det_db_box_thresh,
|
||||
const double &det_db_unclip_ratio,
|
||||
const bool &visualize) {
|
||||
|
@ -49,7 +51,6 @@ public:
|
|||
this->gpu_mem_ = gpu_mem;
|
||||
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
|
||||
this->use_mkldnn_ = use_mkldnn;
|
||||
this->use_zero_copy_run_ = use_zero_copy_run;
|
||||
|
||||
this->max_side_len_ = max_side_len;
|
||||
|
||||
|
@ -69,14 +70,13 @@ public:
|
|||
void Run(cv::Mat &img, std::vector<std::vector<std::vector<int>>> &boxes);
|
||||
|
||||
private:
|
||||
std::shared_ptr<PaddlePredictor> predictor_;
|
||||
std::shared_ptr<Predictor> predictor_;
|
||||
|
||||
bool use_gpu_ = false;
|
||||
int gpu_id_ = 0;
|
||||
int gpu_mem_ = 4000;
|
||||
int cpu_math_library_num_threads_ = 4;
|
||||
bool use_mkldnn_ = false;
|
||||
bool use_zero_copy_run_ = false;
|
||||
|
||||
int max_side_len_ = 960;
|
||||
|
||||
|
|
|
@ -32,6 +32,8 @@
|
|||
#include <include/preprocess_op.h>
|
||||
#include <include/utility.h>
|
||||
|
||||
using namespace paddle_infer;
|
||||
|
||||
namespace PaddleOCR {
|
||||
|
||||
class CRNNRecognizer {
|
||||
|
@ -39,14 +41,12 @@ public:
|
|||
explicit CRNNRecognizer(const std::string &model_dir, const bool &use_gpu,
|
||||
const int &gpu_id, const int &gpu_mem,
|
||||
const int &cpu_math_library_num_threads,
|
||||
const bool &use_mkldnn, const bool &use_zero_copy_run,
|
||||
const string &label_path) {
|
||||
const bool &use_mkldnn, const string &label_path) {
|
||||
this->use_gpu_ = use_gpu;
|
||||
this->gpu_id_ = gpu_id;
|
||||
this->gpu_mem_ = gpu_mem;
|
||||
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
|
||||
this->use_mkldnn_ = use_mkldnn;
|
||||
this->use_zero_copy_run_ = use_zero_copy_run;
|
||||
|
||||
this->label_list_ = Utility::ReadDict(label_path);
|
||||
this->label_list_.insert(this->label_list_.begin(),
|
||||
|
@ -63,14 +63,13 @@ public:
|
|||
Classifier *cls);
|
||||
|
||||
private:
|
||||
std::shared_ptr<PaddlePredictor> predictor_;
|
||||
std::shared_ptr<Predictor> predictor_;
|
||||
|
||||
bool use_gpu_ = false;
|
||||
int gpu_id_ = 0;
|
||||
int gpu_mem_ = 4000;
|
||||
int cpu_math_library_num_threads_ = 4;
|
||||
bool use_mkldnn_ = false;
|
||||
bool use_zero_copy_run_ = false;
|
||||
|
||||
std::vector<std::string> label_list_;
|
||||
|
||||
|
|
|
@ -16,8 +16,8 @@
|
|||
|
||||
namespace PaddleOCR {
|
||||
|
||||
std::vector<std::string> Config::split(const std::string &str,
|
||||
const std::string &delim) {
|
||||
std::vector<std::string> OCRConfig::split(const std::string &str,
|
||||
const std::string &delim) {
|
||||
std::vector<std::string> res;
|
||||
if ("" == str)
|
||||
return res;
|
||||
|
@ -38,7 +38,7 @@ std::vector<std::string> Config::split(const std::string &str,
|
|||
}
|
||||
|
||||
std::map<std::string, std::string>
|
||||
Config::LoadConfig(const std::string &config_path) {
|
||||
OCRConfig::LoadConfig(const std::string &config_path) {
|
||||
auto config = Utility::ReadDict(config_path);
|
||||
|
||||
std::map<std::string, std::string> dict;
|
||||
|
@ -53,7 +53,7 @@ Config::LoadConfig(const std::string &config_path) {
|
|||
return dict;
|
||||
}
|
||||
|
||||
void Config::PrintConfigInfo() {
|
||||
void OCRConfig::PrintConfigInfo() {
|
||||
std::cout << "=======Paddle OCR inference config======" << std::endl;
|
||||
for (auto iter = config_map_.begin(); iter != config_map_.end(); iter++) {
|
||||
std::cout << iter->first << " : " << iter->second << std::endl;
|
||||
|
|
|
@ -42,7 +42,7 @@ int main(int argc, char **argv) {
|
|||
exit(1);
|
||||
}
|
||||
|
||||
Config config(argv[1]);
|
||||
OCRConfig config(argv[1]);
|
||||
|
||||
config.PrintConfigInfo();
|
||||
|
||||
|
@ -50,37 +50,22 @@ int main(int argc, char **argv) {
|
|||
|
||||
cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
|
||||
|
||||
DBDetector det(
|
||||
config.det_model_dir, config.use_gpu, config.gpu_id, config.gpu_mem,
|
||||
config.cpu_math_library_num_threads, config.use_mkldnn,
|
||||
config.use_zero_copy_run, config.max_side_len, config.det_db_thresh,
|
||||
config.det_db_box_thresh, config.det_db_unclip_ratio, config.visualize);
|
||||
DBDetector det(config.det_model_dir, config.use_gpu, config.gpu_id,
|
||||
config.gpu_mem, config.cpu_math_library_num_threads,
|
||||
config.use_mkldnn, config.max_side_len, config.det_db_thresh,
|
||||
config.det_db_box_thresh, config.det_db_unclip_ratio,
|
||||
config.visualize);
|
||||
|
||||
Classifier *cls = nullptr;
|
||||
if (config.use_angle_cls == true) {
|
||||
cls = new Classifier(config.cls_model_dir, config.use_gpu, config.gpu_id,
|
||||
config.gpu_mem, config.cpu_math_library_num_threads,
|
||||
config.use_mkldnn, config.use_zero_copy_run,
|
||||
config.cls_thresh);
|
||||
config.use_mkldnn, config.cls_thresh);
|
||||
}
|
||||
|
||||
CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id,
|
||||
config.gpu_mem, config.cpu_math_library_num_threads,
|
||||
config.use_mkldnn, config.use_zero_copy_run,
|
||||
config.char_list_file);
|
||||
|
||||
#ifdef USE_MKL
|
||||
#pragma omp parallel
|
||||
for (auto i = 0; i < 10; i++) {
|
||||
LOG_IF(WARNING,
|
||||
config.cpu_math_library_num_threads != omp_get_num_threads())
|
||||
<< "WARNING! MKL is running on " << omp_get_num_threads()
|
||||
<< " threads while cpu_math_library_num_threads is set to "
|
||||
<< config.cpu_math_library_num_threads
|
||||
<< ". Possible reason could be 1. You have set omp_set_num_threads() "
|
||||
"somewhere; 2. MKL is not linked properly";
|
||||
}
|
||||
#endif
|
||||
config.use_mkldnn, config.char_list_file);
|
||||
|
||||
auto start = std::chrono::system_clock::now();
|
||||
std::vector<std::vector<std::vector<int>>> boxes;
|
||||
|
|
|
@ -35,26 +35,16 @@ cv::Mat Classifier::Run(cv::Mat &img) {
|
|||
this->permute_op_.Run(&resize_img, input.data());
|
||||
|
||||
// Inference.
|
||||
if (this->use_zero_copy_run_) {
|
||||
auto input_names = this->predictor_->GetInputNames();
|
||||
auto input_t = this->predictor_->GetInputTensor(input_names[0]);
|
||||
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
|
||||
input_t->copy_from_cpu(input.data());
|
||||
this->predictor_->ZeroCopyRun();
|
||||
} else {
|
||||
paddle::PaddleTensor input_t;
|
||||
input_t.shape = {1, 3, resize_img.rows, resize_img.cols};
|
||||
input_t.data =
|
||||
paddle::PaddleBuf(input.data(), input.size() * sizeof(float));
|
||||
input_t.dtype = PaddleDType::FLOAT32;
|
||||
std::vector<paddle::PaddleTensor> outputs;
|
||||
this->predictor_->Run({input_t}, &outputs, 1);
|
||||
}
|
||||
auto input_names = this->predictor_->GetInputNames();
|
||||
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
|
||||
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
|
||||
input_t->CopyFromCpu(input.data());
|
||||
this->predictor_->Run();
|
||||
|
||||
std::vector<float> softmax_out;
|
||||
std::vector<int64_t> label_out;
|
||||
auto output_names = this->predictor_->GetOutputNames();
|
||||
auto softmax_out_t = this->predictor_->GetOutputTensor(output_names[0]);
|
||||
auto softmax_out_t = this->predictor_->GetOutputHandle(output_names[0]);
|
||||
auto softmax_shape_out = softmax_out_t->shape();
|
||||
|
||||
int softmax_out_num =
|
||||
|
@ -63,7 +53,7 @@ cv::Mat Classifier::Run(cv::Mat &img) {
|
|||
|
||||
softmax_out.resize(softmax_out_num);
|
||||
|
||||
softmax_out_t->copy_to_cpu(softmax_out.data());
|
||||
softmax_out_t->CopyToCpu(softmax_out.data());
|
||||
|
||||
float score = 0;
|
||||
int label = 0;
|
||||
|
@ -95,7 +85,7 @@ void Classifier::LoadModel(const std::string &model_dir) {
|
|||
}
|
||||
|
||||
// false for zero copy tensor
|
||||
config.SwitchUseFeedFetchOps(!this->use_zero_copy_run_);
|
||||
config.SwitchUseFeedFetchOps(false);
|
||||
// true for multiple input
|
||||
config.SwitchSpecifyInputNames(true);
|
||||
|
||||
|
@ -104,6 +94,6 @@ void Classifier::LoadModel(const std::string &model_dir) {
|
|||
config.EnableMemoryOptim();
|
||||
config.DisableGlogInfo();
|
||||
|
||||
this->predictor_ = CreatePaddlePredictor(config);
|
||||
this->predictor_ = CreatePredictor(config);
|
||||
}
|
||||
} // namespace PaddleOCR
|
||||
|
|
|
@ -17,12 +17,17 @@
|
|||
namespace PaddleOCR {
|
||||
|
||||
void DBDetector::LoadModel(const std::string &model_dir) {
|
||||
AnalysisConfig config;
|
||||
// AnalysisConfig config;
|
||||
paddle_infer::Config config;
|
||||
config.SetModel(model_dir + "/inference.pdmodel",
|
||||
model_dir + "/inference.pdiparams");
|
||||
|
||||
if (this->use_gpu_) {
|
||||
config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
|
||||
// config.EnableTensorRtEngine(
|
||||
// 1 << 20, 1, 3,
|
||||
// AnalysisConfig::Precision::kFloat32,
|
||||
// false, false);
|
||||
} else {
|
||||
config.DisableGpu();
|
||||
if (this->use_mkldnn_) {
|
||||
|
@ -32,10 +37,8 @@ void DBDetector::LoadModel(const std::string &model_dir) {
|
|||
}
|
||||
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
|
||||
}
|
||||
|
||||
// false for zero copy tensor
|
||||
// true for commom tensor
|
||||
config.SwitchUseFeedFetchOps(!this->use_zero_copy_run_);
|
||||
// use zero_copy_run as default
|
||||
config.SwitchUseFeedFetchOps(false);
|
||||
// true for multiple input
|
||||
config.SwitchSpecifyInputNames(true);
|
||||
|
||||
|
@ -44,7 +47,7 @@ void DBDetector::LoadModel(const std::string &model_dir) {
|
|||
config.EnableMemoryOptim();
|
||||
config.DisableGlogInfo();
|
||||
|
||||
this->predictor_ = CreatePaddlePredictor(config);
|
||||
this->predictor_ = CreatePredictor(config);
|
||||
}
|
||||
|
||||
void DBDetector::Run(cv::Mat &img,
|
||||
|
@ -64,31 +67,21 @@ void DBDetector::Run(cv::Mat &img,
|
|||
this->permute_op_.Run(&resize_img, input.data());
|
||||
|
||||
// Inference.
|
||||
if (this->use_zero_copy_run_) {
|
||||
auto input_names = this->predictor_->GetInputNames();
|
||||
auto input_t = this->predictor_->GetInputTensor(input_names[0]);
|
||||
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
|
||||
input_t->copy_from_cpu(input.data());
|
||||
this->predictor_->ZeroCopyRun();
|
||||
} else {
|
||||
paddle::PaddleTensor input_t;
|
||||
input_t.shape = {1, 3, resize_img.rows, resize_img.cols};
|
||||
input_t.data =
|
||||
paddle::PaddleBuf(input.data(), input.size() * sizeof(float));
|
||||
input_t.dtype = PaddleDType::FLOAT32;
|
||||
std::vector<paddle::PaddleTensor> outputs;
|
||||
this->predictor_->Run({input_t}, &outputs, 1);
|
||||
}
|
||||
auto input_names = this->predictor_->GetInputNames();
|
||||
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
|
||||
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
|
||||
input_t->CopyFromCpu(input.data());
|
||||
this->predictor_->Run();
|
||||
|
||||
std::vector<float> out_data;
|
||||
auto output_names = this->predictor_->GetOutputNames();
|
||||
auto output_t = this->predictor_->GetOutputTensor(output_names[0]);
|
||||
auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
|
||||
std::vector<int> output_shape = output_t->shape();
|
||||
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
|
||||
std::multiplies<int>());
|
||||
|
||||
out_data.resize(out_num);
|
||||
output_t->copy_to_cpu(out_data.data());
|
||||
output_t->CopyToCpu(out_data.data());
|
||||
|
||||
int n2 = output_shape[2];
|
||||
int n3 = output_shape[3];
|
||||
|
|
|
@ -43,32 +43,22 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
|||
this->permute_op_.Run(&resize_img, input.data());
|
||||
|
||||
// Inference.
|
||||
if (this->use_zero_copy_run_) {
|
||||
auto input_names = this->predictor_->GetInputNames();
|
||||
auto input_t = this->predictor_->GetInputTensor(input_names[0]);
|
||||
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
|
||||
input_t->copy_from_cpu(input.data());
|
||||
this->predictor_->ZeroCopyRun();
|
||||
} else {
|
||||
paddle::PaddleTensor input_t;
|
||||
input_t.shape = {1, 3, resize_img.rows, resize_img.cols};
|
||||
input_t.data =
|
||||
paddle::PaddleBuf(input.data(), input.size() * sizeof(float));
|
||||
input_t.dtype = PaddleDType::FLOAT32;
|
||||
std::vector<paddle::PaddleTensor> outputs;
|
||||
this->predictor_->Run({input_t}, &outputs, 1);
|
||||
}
|
||||
auto input_names = this->predictor_->GetInputNames();
|
||||
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
|
||||
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
|
||||
input_t->CopyFromCpu(input.data());
|
||||
this->predictor_->Run();
|
||||
|
||||
std::vector<float> predict_batch;
|
||||
auto output_names = this->predictor_->GetOutputNames();
|
||||
auto output_t = this->predictor_->GetOutputTensor(output_names[0]);
|
||||
auto output_t = this->predictor_->GetOutputHandle(output_names[0]);
|
||||
auto predict_shape = output_t->shape();
|
||||
|
||||
int out_num = std::accumulate(predict_shape.begin(), predict_shape.end(), 1,
|
||||
std::multiplies<int>());
|
||||
predict_batch.resize(out_num);
|
||||
|
||||
output_t->copy_to_cpu(predict_batch.data());
|
||||
output_t->CopyToCpu(predict_batch.data());
|
||||
|
||||
// ctc decode
|
||||
std::vector<std::string> str_res;
|
||||
|
@ -102,7 +92,8 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
|||
}
|
||||
|
||||
void CRNNRecognizer::LoadModel(const std::string &model_dir) {
|
||||
AnalysisConfig config;
|
||||
// AnalysisConfig config;
|
||||
paddle_infer::Config config;
|
||||
config.SetModel(model_dir + "/inference.pdmodel",
|
||||
model_dir + "/inference.pdiparams");
|
||||
|
||||
|
@ -118,9 +109,7 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
|
|||
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
|
||||
}
|
||||
|
||||
// false for zero copy tensor
|
||||
// true for commom tensor
|
||||
config.SwitchUseFeedFetchOps(!this->use_zero_copy_run_);
|
||||
config.SwitchUseFeedFetchOps(false);
|
||||
// true for multiple input
|
||||
config.SwitchSpecifyInputNames(true);
|
||||
|
||||
|
@ -129,7 +118,7 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
|
|||
config.EnableMemoryOptim();
|
||||
config.DisableGlogInfo();
|
||||
|
||||
this->predictor_ = CreatePaddlePredictor(config);
|
||||
this->predictor_ = CreatePredictor(config);
|
||||
}
|
||||
|
||||
cv::Mat CRNNRecognizer::GetRotateCropImage(const cv::Mat &srcimage,
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
# model load config
|
||||
use_gpu 0
|
||||
use_gpu 0
|
||||
gpu_id 0
|
||||
gpu_mem 4000
|
||||
cpu_math_library_num_threads 10
|
||||
use_mkldnn 0
|
||||
use_zero_copy_run 1
|
||||
|
||||
# det config
|
||||
max_side_len 960
|
||||
|
|
Loading…
Reference in New Issue