添加分类模型
This commit is contained in:
parent
7c09c97d70
commit
e11b2108fa
|
@ -0,0 +1,43 @@
|
|||
Global:
|
||||
algorithm: CLS
|
||||
use_gpu: false
|
||||
epoch_num: 30
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: output/cls_mb3
|
||||
save_epoch_step: 3
|
||||
eval_batch_step: 100
|
||||
train_batch_size_per_card: 256
|
||||
test_batch_size_per_card: 256
|
||||
image_shape: [3, 32, 100]
|
||||
label_list: [0,180]
|
||||
reader_yml: ./configs/cls/cls_reader.yml
|
||||
pretrain_weights:
|
||||
checkpoints: /Users/zhoujun20/Desktop/code/class_model/cls_mb3_ultra_small_0.35/best_accuracy
|
||||
save_inference_dir:
|
||||
infer_img: /Users/zhoujun20/Desktop/code/PaddleOCR/doc/imgs_words/ch/word_1.jpg
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.cls_model,ClsModel
|
||||
|
||||
Backbone:
|
||||
function: ppocr.modeling.backbones.rec_mobilenet_v3,MobileNetV3
|
||||
scale: 0.35
|
||||
model_name: Ultra_small
|
||||
|
||||
Head:
|
||||
function: ppocr.modeling.heads.cls_head,ClsHead
|
||||
class_dim: 2
|
||||
|
||||
Loss:
|
||||
function: ppocr.modeling.losses.cls_loss,ClsLoss
|
||||
|
||||
Optimizer:
|
||||
function: ppocr.optimizer,AdamDecay
|
||||
base_lr: 0.001
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
decay:
|
||||
function: piecewise_decay
|
||||
boundaries: [20,30]
|
||||
decay_rate: 0.1
|
|
@ -0,0 +1,13 @@
|
|||
TrainReader:
|
||||
reader_function: ppocr.data.cls.dataset_traversal,SimpleReader
|
||||
num_workers: 1
|
||||
img_set_dir: /
|
||||
label_file_path: /Users/zhoujun20/Downloads/direction/rotate_ver/train.txt
|
||||
|
||||
EvalReader:
|
||||
reader_function: ppocr.data.cls.dataset_traversal,SimpleReader
|
||||
img_set_dir: /
|
||||
label_file_path: /Users/zhoujun20/Downloads/direction/rotate_ver/train.txt
|
||||
|
||||
TestReader:
|
||||
reader_function: ppocr.data.cls.dataset_traversal,SimpleReader
|
|
@ -55,6 +55,10 @@ public:
|
|||
|
||||
this->char_list_file.assign(config_map_["char_list_file"]);
|
||||
|
||||
this->cls_model_dir.assign(config_map_["cls_model_dir"]);
|
||||
|
||||
this->cls_thresh = stod(config_map_["cls_thresh"]);
|
||||
|
||||
this->visualize = bool(stoi(config_map_["visualize"]));
|
||||
}
|
||||
|
||||
|
@ -82,6 +86,10 @@ public:
|
|||
|
||||
std::string char_list_file;
|
||||
|
||||
std::string cls_model_dir;
|
||||
|
||||
double cls_thresh;
|
||||
|
||||
bool visualize = true;
|
||||
|
||||
void PrintConfigInfo();
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "opencv2/core.hpp"
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
#include "opencv2/imgproc.hpp"
|
||||
#include "paddle_api.h"
|
||||
#include "paddle_inference_api.h"
|
||||
#include <chrono>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <vector>
|
||||
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
|
||||
#include <include/preprocess_op.h>
|
||||
#include <include/utility.h>
|
||||
|
||||
namespace PaddleOCR {
|
||||
|
||||
class Classifier {
|
||||
public:
|
||||
explicit Classifier(const std::string &model_dir, const bool &use_gpu,
|
||||
const int &gpu_id, const int &gpu_mem,
|
||||
const int &cpu_math_library_num_threads,
|
||||
const bool &use_mkldnn, const double &cls_thresh) {
|
||||
this->use_gpu_ = use_gpu;
|
||||
this->gpu_id_ = gpu_id;
|
||||
this->gpu_mem_ = gpu_mem;
|
||||
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
|
||||
this->use_mkldnn_ = use_mkldnn;
|
||||
|
||||
this->cls_thresh = cls_thresh;
|
||||
|
||||
LoadModel(model_dir);
|
||||
}
|
||||
|
||||
// Load Paddle inference model
|
||||
void LoadModel(const std::string &model_dir);
|
||||
|
||||
cv::Mat Run(cv::Mat &img);
|
||||
|
||||
private:
|
||||
std::shared_ptr<PaddlePredictor> predictor_;
|
||||
|
||||
bool use_gpu_ = false;
|
||||
int gpu_id_ = 0;
|
||||
int gpu_mem_ = 4000;
|
||||
int cpu_math_library_num_threads_ = 4;
|
||||
bool use_mkldnn_ = false;
|
||||
|
||||
double cls_thresh = 0.5;
|
||||
|
||||
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
|
||||
std::vector<float> scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
|
||||
bool is_scale_ = true;
|
||||
|
||||
// pre-process
|
||||
ClsResizeImg resize_op_;
|
||||
Normalize normalize_op_;
|
||||
Permute permute_op_;
|
||||
|
||||
}; // class Classifier
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -27,6 +27,7 @@
|
|||
#include <fstream>
|
||||
#include <numeric>
|
||||
|
||||
#include <include/ocr_cls.h>
|
||||
#include <include/postprocess_op.h>
|
||||
#include <include/preprocess_op.h>
|
||||
#include <include/utility.h>
|
||||
|
@ -54,7 +55,8 @@ public:
|
|||
// Load Paddle inference model
|
||||
void LoadModel(const std::string &model_dir);
|
||||
|
||||
void Run(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat &img);
|
||||
void Run(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat &img,
|
||||
Classifier &cls);
|
||||
|
||||
private:
|
||||
std::shared_ptr<PaddlePredictor> predictor_;
|
||||
|
|
|
@ -56,4 +56,10 @@ public:
|
|||
const std::vector<int> &rec_image_shape = {3, 32, 320});
|
||||
};
|
||||
|
||||
class ClsResizeImg {
|
||||
public:
|
||||
virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||
const std::vector<int> &rec_image_shape = {3, 32, 320});
|
||||
};
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -53,6 +53,9 @@ int main(int argc, char **argv) {
|
|||
config.use_mkldnn, config.max_side_len, config.det_db_thresh,
|
||||
config.det_db_box_thresh, config.det_db_unclip_ratio,
|
||||
config.visualize);
|
||||
Classifier cls(config.cls_model_dir, config.use_gpu, config.gpu_id,
|
||||
config.gpu_mem, config.cpu_math_library_num_threads,
|
||||
config.use_mkldnn, config.cls_thresh);
|
||||
CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id,
|
||||
config.gpu_mem, config.cpu_math_library_num_threads,
|
||||
config.use_mkldnn, config.char_list_file);
|
||||
|
@ -61,7 +64,7 @@ int main(int argc, char **argv) {
|
|||
std::vector<std::vector<std::vector<int>>> boxes;
|
||||
det.Run(srcimg, boxes);
|
||||
|
||||
rec.Run(boxes, srcimg);
|
||||
rec.Run(boxes, srcimg, cls);
|
||||
|
||||
auto end = std::chrono::system_clock::now();
|
||||
auto duration =
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <include/ocr_cls.h>
|
||||
|
||||
namespace PaddleOCR {
|
||||
|
||||
cv::Mat Classifier::Run(cv::Mat &img) {
|
||||
cv::Mat src_img;
|
||||
img.copyTo(src_img);
|
||||
cv::Mat resize_img;
|
||||
|
||||
std::vector<int> rec_image_shape = {3, 32, 100};
|
||||
int index = 0;
|
||||
float wh_ratio = float(img.cols) / float(img.rows);
|
||||
|
||||
this->resize_op_.Run(img, resize_img, rec_image_shape);
|
||||
|
||||
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
|
||||
this->is_scale_);
|
||||
|
||||
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
|
||||
|
||||
this->permute_op_.Run(&resize_img, input.data());
|
||||
|
||||
auto input_names = this->predictor_->GetInputNames();
|
||||
auto input_t = this->predictor_->GetInputTensor(input_names[0]);
|
||||
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
|
||||
input_t->copy_from_cpu(input.data());
|
||||
|
||||
this->predictor_->ZeroCopyRun();
|
||||
|
||||
std::vector<float> softmax_out;
|
||||
std::vector<int64_t> label_out;
|
||||
auto output_names = this->predictor_->GetOutputNames();
|
||||
auto softmax_out_t = this->predictor_->GetOutputTensor(output_names[0]);
|
||||
auto label_out_t = this->predictor_->GetOutputTensor(output_names[1]);
|
||||
auto softmax_shape_out = softmax_out_t->shape();
|
||||
auto label_shape_out = label_out_t->shape();
|
||||
|
||||
int softmax_out_num =
|
||||
std::accumulate(softmax_shape_out.begin(), softmax_shape_out.end(), 1,
|
||||
std::multiplies<int>());
|
||||
|
||||
int label_out_num =
|
||||
std::accumulate(label_shape_out.begin(), label_shape_out.end(), 1,
|
||||
std::multiplies<int>());
|
||||
softmax_out.resize(softmax_out_num);
|
||||
label_out.resize(label_out_num);
|
||||
|
||||
softmax_out_t->copy_to_cpu(softmax_out.data());
|
||||
label_out_t->copy_to_cpu(label_out.data());
|
||||
|
||||
int label = label_out[0];
|
||||
float score = softmax_out[label];
|
||||
// std::cout << "\nlabel "<<label<<" score: "<<score;
|
||||
if (label % 2 == 1 && score > this->cls_thresh) {
|
||||
cv::rotate(src_img, src_img, 1);
|
||||
}
|
||||
return src_img;
|
||||
}
|
||||
|
||||
void Classifier::LoadModel(const std::string &model_dir) {
|
||||
AnalysisConfig config;
|
||||
config.SetModel(model_dir + "/model", model_dir + "/params");
|
||||
|
||||
if (this->use_gpu_) {
|
||||
config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
|
||||
} else {
|
||||
config.DisableGpu();
|
||||
if (this->use_mkldnn_) {
|
||||
config.EnableMKLDNN();
|
||||
}
|
||||
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
|
||||
}
|
||||
|
||||
// false for zero copy tensor
|
||||
config.SwitchUseFeedFetchOps(false);
|
||||
// true for multiple input
|
||||
config.SwitchSpecifyInputNames(true);
|
||||
|
||||
config.SwitchIrOptim(true);
|
||||
|
||||
config.EnableMemoryOptim();
|
||||
config.DisableGlogInfo();
|
||||
|
||||
this->predictor_ = CreatePaddlePredictor(config);
|
||||
}
|
||||
} // namespace PaddleOCR
|
|
@ -17,7 +17,7 @@
|
|||
namespace PaddleOCR {
|
||||
|
||||
void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
||||
cv::Mat &img) {
|
||||
cv::Mat &img, Classifier &cls) {
|
||||
cv::Mat srcimg;
|
||||
img.copyTo(srcimg);
|
||||
cv::Mat crop_img;
|
||||
|
@ -28,6 +28,8 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
|||
for (int i = boxes.size() - 1; i >= 0; i--) {
|
||||
crop_img = GetRotateCropImage(srcimg, boxes[i]);
|
||||
|
||||
crop_img = cls.Run(crop_img);
|
||||
|
||||
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
|
||||
|
||||
this->resize_op_.Run(crop_img, resize_img, wh_ratio);
|
||||
|
|
|
@ -116,4 +116,26 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
|
|||
cv::INTER_LINEAR);
|
||||
}
|
||||
|
||||
void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||
const std::vector<int> &rec_image_shape) {
|
||||
int imgC, imgH, imgW;
|
||||
imgC = rec_image_shape[0];
|
||||
imgH = rec_image_shape[1];
|
||||
imgW = rec_image_shape[2];
|
||||
|
||||
float ratio = float(img.cols) / float(img.rows);
|
||||
int resize_w, resize_h;
|
||||
if (ceilf(imgH * ratio) > imgW)
|
||||
resize_w = imgW;
|
||||
else
|
||||
resize_w = int(ceilf(imgH * ratio));
|
||||
|
||||
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
|
||||
cv::INTER_LINEAR);
|
||||
if (resize_w < imgW) {
|
||||
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
|
||||
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -40,8 +40,8 @@ CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SY
|
|||
|
||||
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
|
||||
|
||||
ocr_db_crnn: fetch_opencv ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o
|
||||
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o -o ocr_db_crnn $(CXX_LIBS) $(LDFLAGS)
|
||||
ocr_db_crnn: fetch_opencv ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o cls_process.o
|
||||
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o cls_process.o -o ocr_db_crnn $(CXX_LIBS) $(LDFLAGS)
|
||||
|
||||
ocr_db_crnn.o: ocr_db_crnn.cc
|
||||
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o ocr_db_crnn.o -c ocr_db_crnn.cc
|
||||
|
@ -49,6 +49,9 @@ ocr_db_crnn.o: ocr_db_crnn.cc
|
|||
crnn_process.o: fetch_opencv crnn_process.cc
|
||||
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o crnn_process.o -c crnn_process.cc
|
||||
|
||||
cls_process.o: fetch_opencv cls_process.cc
|
||||
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o cls_process.o -c cls_process.cc
|
||||
|
||||
db_post_process.o: fetch_clipper fetch_opencv db_post_process.cc
|
||||
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o db_post_process.o -c db_post_process.cc
|
||||
|
||||
|
@ -73,5 +76,5 @@ fetch_opencv:
|
|||
|
||||
.PHONY: clean
|
||||
clean:
|
||||
rm -f ocr_db_crnn.o clipper.o db_post_process.o crnn_process.o
|
||||
rm -f ocr_db_crnn.o clipper.o db_post_process.o crnn_process.o cls_process.o
|
||||
rm -f ocr_db_crnn
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "cls_process.h" //NOLINT
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
const std::vector<int> rec_image_shape{3, 32, 100};
|
||||
|
||||
cv::Mat ClsResizeImg(cv::Mat img) {
|
||||
int imgC, imgH, imgW;
|
||||
imgC = rec_image_shape[0];
|
||||
imgH = rec_image_shape[1];
|
||||
imgW = rec_image_shape[2];
|
||||
|
||||
float ratio = static_cast<float>(img.cols) / static_cast<float>(img.rows);
|
||||
|
||||
int resize_w, resize_h;
|
||||
if (ceilf(imgH * ratio) > imgW)
|
||||
resize_w = imgW;
|
||||
else
|
||||
resize_w = int(ceilf(imgH * ratio));
|
||||
cv::Mat resize_img;
|
||||
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
|
||||
cv::INTER_LINEAR);
|
||||
if (resize_w < imgW) {
|
||||
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
|
||||
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
|
||||
}
|
||||
return resize_img;
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "math.h" //NOLINT
|
||||
#include "opencv2/core.hpp"
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
#include "opencv2/imgproc.hpp"
|
||||
|
||||
cv::Mat ClsResizeImg(cv::Mat img);
|
|
@ -15,6 +15,7 @@
|
|||
#include "paddle_api.h" // NOLINT
|
||||
#include <chrono>
|
||||
|
||||
#include "cls_process.h"
|
||||
#include "crnn_process.h"
|
||||
#include "db_post_process.h"
|
||||
|
||||
|
@ -105,11 +106,55 @@ cv::Mat DetResizeImg(const cv::Mat img, int max_size_len,
|
|||
return resize_img;
|
||||
}
|
||||
|
||||
cv::Mat RunClsModel(cv::Mat img, std::shared_ptr<PaddlePredictor> predictor_cls,
|
||||
const float thresh = 0.5) {
|
||||
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
|
||||
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
|
||||
|
||||
cv::Mat srcimg;
|
||||
img.copyTo(srcimg);
|
||||
cv::Mat crop_img;
|
||||
cv::Mat resize_img;
|
||||
|
||||
int index = 0;
|
||||
float wh_ratio =
|
||||
static_cast<float>(crop_img.cols) / static_cast<float>(crop_img.rows);
|
||||
|
||||
resize_img = ClsResizeImg(crop_img);
|
||||
resize_img.convertTo(resize_img, CV_32FC3, 1 / 255.f);
|
||||
|
||||
const float *dimg = reinterpret_cast<const float *>(resize_img.data);
|
||||
|
||||
std::unique_ptr<Tensor> input_tensor0(std::move(predictor_cls->GetInput(0)));
|
||||
input_tensor0->Resize({1, 3, resize_img.rows, resize_img.cols});
|
||||
auto *data0 = input_tensor0->mutable_data<float>();
|
||||
|
||||
NeonMeanScale(dimg, data0, resize_img.rows * resize_img.cols, mean, scale);
|
||||
// Run CLS predictor
|
||||
predictor_cls->Run();
|
||||
|
||||
// Get output and run postprocess
|
||||
std::unique_ptr<const Tensor> softmax_out(
|
||||
std::move(predictor_cls->GetOutput(0)));
|
||||
std::unique_ptr<const Tensor> label_out(
|
||||
std::move(predictor_cls->GetOutput(1)));
|
||||
auto *softmax_scores = softmax_out->mutable_data<float>();
|
||||
auto *label_idxs = label_out->data<int64>();
|
||||
int label_idx = label_idxs[0];
|
||||
float score = softmax_scores[label_idx];
|
||||
|
||||
if (label_idx % 2 == 1 && score > thresh) {
|
||||
cv::rotate(srcimg, srcimg, 1);
|
||||
}
|
||||
return srcimg;
|
||||
}
|
||||
|
||||
void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img,
|
||||
std::shared_ptr<PaddlePredictor> predictor_crnn,
|
||||
std::vector<std::string> &rec_text,
|
||||
std::vector<float> &rec_text_score,
|
||||
std::vector<std::string> charactor_dict) {
|
||||
std::vector<std::string> charactor_dict,
|
||||
std::shared_ptr<PaddlePredictor> predictor_cls) {
|
||||
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
|
||||
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
|
||||
|
||||
|
@ -121,6 +166,7 @@ void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img,
|
|||
int index = 0;
|
||||
for (int i = boxes.size() - 1; i >= 0; i--) {
|
||||
crop_img = GetRotateCropImage(srcimg, boxes[i]);
|
||||
crop_img = RunClsModel(crop_img, predictor_cls);
|
||||
float wh_ratio =
|
||||
static_cast<float>(crop_img.cols) / static_cast<float>(crop_img.rows);
|
||||
|
||||
|
@ -323,8 +369,9 @@ int main(int argc, char **argv) {
|
|||
}
|
||||
std::string det_model_file = argv[1];
|
||||
std::string rec_model_file = argv[2];
|
||||
std::string img_path = argv[3];
|
||||
std::string dict_path = argv[4];
|
||||
std::string cls_model_file = argv[3];
|
||||
std::string img_path = argv[4];
|
||||
std::string dict_path = argv[5];
|
||||
|
||||
//// load config from txt file
|
||||
auto Config = LoadConfigTxt("./config.txt");
|
||||
|
@ -333,6 +380,7 @@ int main(int argc, char **argv) {
|
|||
|
||||
auto det_predictor = loadModel(det_model_file);
|
||||
auto rec_predictor = loadModel(rec_model_file);
|
||||
auto cls_predictor = loadModel(cls_model_file);
|
||||
|
||||
auto charactor_dict = ReadDict(dict_path);
|
||||
charactor_dict.push_back(" ");
|
||||
|
@ -343,7 +391,7 @@ int main(int argc, char **argv) {
|
|||
std::vector<std::string> rec_text;
|
||||
std::vector<float> rec_text_score;
|
||||
RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score,
|
||||
charactor_dict);
|
||||
charactor_dict, cls_predictor);
|
||||
|
||||
auto end = std::chrono::system_clock::now();
|
||||
auto duration =
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
|
@ -0,0 +1,128 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import random
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from ppocr.utils.utility import initial_logger
|
||||
from ppocr.utils.utility import get_image_file_list
|
||||
|
||||
logger = initial_logger()
|
||||
|
||||
from ppocr.data.rec.img_tools import warp, resize_norm_img
|
||||
|
||||
|
||||
class SimpleReader(object):
|
||||
def __init__(self, params):
|
||||
if params['mode'] != 'train':
|
||||
self.num_workers = 1
|
||||
else:
|
||||
self.num_workers = params['num_workers']
|
||||
if params['mode'] != 'test':
|
||||
self.img_set_dir = params['img_set_dir']
|
||||
self.label_file_path = params['label_file_path']
|
||||
self.use_gpu = params['use_gpu']
|
||||
self.image_shape = params['image_shape']
|
||||
self.mode = params['mode']
|
||||
self.infer_img = params['infer_img']
|
||||
self.use_distort = False
|
||||
self.label_list = params['label_list']
|
||||
if "distort" in params:
|
||||
self.use_distort = params['distort'] and params['use_gpu']
|
||||
if not params['use_gpu']:
|
||||
logger.info(
|
||||
"Distort operation can only support in GPU.Distort will be set to False."
|
||||
)
|
||||
if params['mode'] == 'train':
|
||||
self.batch_size = params['train_batch_size_per_card']
|
||||
self.drop_last = True
|
||||
else:
|
||||
self.batch_size = params['test_batch_size_per_card']
|
||||
self.drop_last = False
|
||||
self.use_distort = False
|
||||
|
||||
def __call__(self, process_id):
|
||||
if self.mode != 'train':
|
||||
process_id = 0
|
||||
|
||||
def get_device_num():
|
||||
if self.use_gpu:
|
||||
gpus = os.environ.get("CUDA_VISIBLE_DEVICES", 1)
|
||||
gpu_num = len(gpus.split(','))
|
||||
return gpu_num
|
||||
else:
|
||||
cpu_num = os.environ.get("CPU_NUM", 1)
|
||||
return int(cpu_num)
|
||||
|
||||
def sample_iter_reader():
|
||||
if self.mode != 'train' and self.infer_img is not None:
|
||||
image_file_list = get_image_file_list(self.infer_img)
|
||||
for single_img in image_file_list:
|
||||
img = cv2.imread(single_img)
|
||||
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
norm_img = resize_norm_img(img, self.image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
yield norm_img
|
||||
else:
|
||||
with open(self.label_file_path, "rb") as fin:
|
||||
label_infor_list = fin.readlines()
|
||||
img_num = len(label_infor_list)
|
||||
img_id_list = list(range(img_num))
|
||||
random.shuffle(img_id_list)
|
||||
if sys.platform == "win32" and self.num_workers != 1:
|
||||
print("multiprocess is not fully compatible with Windows."
|
||||
"num_workers will be 1.")
|
||||
self.num_workers = 1
|
||||
if self.batch_size * get_device_num(
|
||||
) * self.num_workers > img_num:
|
||||
raise Exception(
|
||||
"The number of the whole data ({}) is smaller than the batch_size * devices_num * num_workers ({})".
|
||||
format(img_num, self.batch_size * get_device_num() *
|
||||
self.num_workers))
|
||||
for img_id in range(process_id, img_num, self.num_workers):
|
||||
label_infor = label_infor_list[img_id_list[img_id]]
|
||||
substr = label_infor.decode('utf-8').strip("\n").split("\t")
|
||||
img_path = self.img_set_dir + "/" + substr[0]
|
||||
img = cv2.imread(img_path)
|
||||
if img is None:
|
||||
logger.info("{} does not exist!".format(img_path))
|
||||
continue
|
||||
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
|
||||
label = substr[1]
|
||||
if self.use_distort:
|
||||
img = warp(img, 10)
|
||||
norm_img = resize_norm_img(img, self.image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
yield (norm_img, self.label_list.index(int(label)))
|
||||
|
||||
def batch_iter_reader():
|
||||
batch_outs = []
|
||||
for outs in sample_iter_reader():
|
||||
batch_outs.append(outs)
|
||||
if len(batch_outs) == self.batch_size:
|
||||
yield batch_outs
|
||||
batch_outs = []
|
||||
if not self.drop_last:
|
||||
if len(batch_outs) != 0:
|
||||
yield batch_outs
|
||||
|
||||
if self.infer_img is None:
|
||||
return batch_iter_reader
|
||||
return sample_iter_reader
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle import fluid
|
||||
|
||||
from ppocr.utils.utility import create_module
|
||||
from ppocr.utils.utility import initial_logger
|
||||
|
||||
logger = initial_logger()
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class ClsModel(object):
|
||||
def __init__(self, params):
|
||||
super(ClsModel, self).__init__()
|
||||
global_params = params['Global']
|
||||
self.infer_img = global_params['infer_img']
|
||||
|
||||
backbone_params = deepcopy(params["Backbone"])
|
||||
backbone_params.update(global_params)
|
||||
self.backbone = create_module(backbone_params['function']) \
|
||||
(params=backbone_params)
|
||||
|
||||
head_params = deepcopy(params["Head"])
|
||||
head_params.update(global_params)
|
||||
self.head = create_module(head_params['function']) \
|
||||
(params=head_params)
|
||||
|
||||
loss_params = deepcopy(params["Loss"])
|
||||
loss_params.update(global_params)
|
||||
self.loss = create_module(loss_params['function']) \
|
||||
(params=loss_params)
|
||||
|
||||
self.image_shape = global_params['image_shape']
|
||||
|
||||
def create_feed(self, mode):
|
||||
image_shape = deepcopy(self.image_shape)
|
||||
image_shape.insert(0, -1)
|
||||
if mode == "train":
|
||||
image = fluid.data(name='image', shape=image_shape, dtype='float32')
|
||||
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
|
||||
feed_list = [image, label]
|
||||
labels = {'label': label}
|
||||
loader = fluid.io.DataLoader.from_generator(
|
||||
feed_list=feed_list,
|
||||
capacity=64,
|
||||
use_double_buffer=True,
|
||||
iterable=False)
|
||||
else:
|
||||
labels = None
|
||||
loader = None
|
||||
image = fluid.data(name='image', shape=image_shape, dtype='float32')
|
||||
return image, labels, loader
|
||||
|
||||
def __call__(self, mode):
|
||||
image, labels, loader = self.create_feed(mode)
|
||||
inputs = image
|
||||
conv_feas = self.backbone(inputs)
|
||||
predicts = self.head(conv_feas, labels, mode)
|
||||
if mode == "train":
|
||||
loss = self.loss(predicts, labels)
|
||||
label = labels['label']
|
||||
acc = fluid.layers.accuracy(predicts['predict'], label, k=1)
|
||||
outputs = {'total_loss': loss, 'decoded_out': \
|
||||
predicts['decoded_out'], 'label': label, 'acc': acc}
|
||||
return loader, outputs
|
||||
|
||||
else:
|
||||
return loader, predicts
|
|
@ -0,0 +1,46 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
|
||||
|
||||
class ClsHead(object):
|
||||
def __init__(self, params):
|
||||
super(ClsHead, self).__init__()
|
||||
self.class_dim = params['class_dim']
|
||||
|
||||
def __call__(self, inputs, labels=None, mode=None):
|
||||
pool = fluid.layers.pool2d(
|
||||
input=inputs, pool_type='avg', global_pooling=True)
|
||||
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
|
||||
|
||||
out = fluid.layers.fc(
|
||||
input=pool,
|
||||
size=self.class_dim,
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
name="fc_0.w_0",
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
bias_attr=fluid.param_attr.ParamAttr(name="fc_0.b_0"))
|
||||
|
||||
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
|
||||
out_label = fluid.layers.argmax(out, axis=1)
|
||||
predicts = {'predict': softmax_out, 'decoded_out': out_label}
|
||||
return predicts
|
|
@ -0,0 +1,33 @@
|
|||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle.fluid as fluid
|
||||
|
||||
|
||||
class ClsLoss(object):
|
||||
def __init__(self, params):
|
||||
super(ClsLoss, self).__init__()
|
||||
self.loss_func = fluid.layers.cross_entropy
|
||||
|
||||
def __call__(self, predicts, labels):
|
||||
predict = predicts['predict']
|
||||
label = labels['label']
|
||||
# softmax_out = fluid.layers.softmax(predict, use_cudnn=False)
|
||||
cost = fluid.layers.cross_entropy(input=predict, label=label)
|
||||
sum_cost = fluid.layers.mean(cost)
|
||||
return sum_cost
|
|
@ -45,10 +45,12 @@ from ppocr.utils.save_load import init_model
|
|||
from eval_utils.eval_det_utils import eval_det_run
|
||||
from eval_utils.eval_rec_utils import test_rec_benchmark
|
||||
from eval_utils.eval_rec_utils import eval_rec_run
|
||||
from eval_utils.eval_cls_utils import eval_cls_run
|
||||
|
||||
|
||||
def main():
|
||||
startup_prog, eval_program, place, config, train_alg_type = program.preprocess()
|
||||
startup_prog, eval_program, place, config, train_alg_type = program.preprocess(
|
||||
)
|
||||
eval_build_outputs = program.build(
|
||||
config, eval_program, startup_prog, mode='test')
|
||||
eval_fetch_name_list = eval_build_outputs[1]
|
||||
|
@ -67,6 +69,14 @@ def main():
|
|||
'fetch_varname_list':eval_fetch_varname_list}
|
||||
metrics = eval_det_run(exe, config, eval_info_dict, "eval")
|
||||
logger.info("Eval result: {}".format(metrics))
|
||||
elif train_alg_type == 'cls':
|
||||
eval_reader = reader_main(config=config, mode="eval")
|
||||
eval_info_dict = {'program': eval_program, \
|
||||
'reader': eval_reader, \
|
||||
'fetch_name_list': eval_fetch_name_list, \
|
||||
'fetch_varname_list': eval_fetch_varname_list}
|
||||
metrics = eval_cls_run(exe, eval_info_dict)
|
||||
logger.info("Eval result: {}".format(metrics))
|
||||
else:
|
||||
reader_type = config['Global']['reader_yml']
|
||||
if "benchmark" not in reader_type:
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
import paddle.fluid as fluid
|
||||
|
||||
__all__ = ['eval_class_run']
|
||||
|
||||
import logging
|
||||
|
||||
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
|
||||
logging.basicConfig(level=logging.INFO, format=FORMAT)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def eval_cls_run(exe, eval_info_dict):
|
||||
"""
|
||||
Run evaluation program, return program outputs.
|
||||
"""
|
||||
total_sample_num = 0
|
||||
total_acc_num = 0
|
||||
total_batch_num = 0
|
||||
|
||||
for data in eval_info_dict['reader']():
|
||||
img_num = len(data)
|
||||
img_list = []
|
||||
label_list = []
|
||||
for ino in range(img_num):
|
||||
img_list.append(data[ino][0])
|
||||
label_list.append(data[ino][1])
|
||||
|
||||
img_list = np.concatenate(img_list, axis=0)
|
||||
outs = exe.run(eval_info_dict['program'], \
|
||||
feed={'image': img_list}, \
|
||||
fetch_list=eval_info_dict['fetch_varname_list'], \
|
||||
return_numpy=False)
|
||||
softmax_outs = np.array(outs[1])
|
||||
|
||||
acc, acc_num = cal_cls_acc(softmax_outs, label_list)
|
||||
total_acc_num += acc_num
|
||||
total_sample_num += len(label_list)
|
||||
# logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
|
||||
total_batch_num += 1
|
||||
avg_acc = total_acc_num * 1.0 / total_sample_num
|
||||
metrics = {'avg_acc': avg_acc, "total_acc_num": total_acc_num, \
|
||||
"total_sample_num": total_sample_num}
|
||||
return metrics
|
||||
|
||||
|
||||
def cal_cls_acc(preds, labels):
|
||||
acc_num = 0
|
||||
for pred, label in zip(preds, labels):
|
||||
if pred == label:
|
||||
acc_num += 1
|
||||
return acc_num / len(preds), acc_num
|
|
@ -0,0 +1,143 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
import tools.infer.utility as utility
|
||||
from ppocr.utils.utility import initial_logger
|
||||
|
||||
logger = initial_logger()
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
import cv2
|
||||
import copy
|
||||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
|
||||
|
||||
class TextClassifier(object):
|
||||
def __init__(self, args):
|
||||
self.predictor, self.input_tensor, self.output_tensors = \
|
||||
utility.create_predictor(args, mode="cls")
|
||||
self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
|
||||
self.cls_batch_num = args.rec_batch_num
|
||||
self.label_list = args.label_list
|
||||
|
||||
def resize_norm_img(self, img):
|
||||
imgC, imgH, imgW = self.cls_image_shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
if self.cls_image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
resized_image = resized_image[np.newaxis, :]
|
||||
else:
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
||||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
return padding_im
|
||||
|
||||
def __call__(self, img_list):
|
||||
img_list = copy.deepcopy(img_list)
|
||||
img_num = len(img_list)
|
||||
# Calculate the aspect ratio of all text bars
|
||||
width_list = []
|
||||
for img in img_list:
|
||||
width_list.append(img.shape[1] / float(img.shape[0]))
|
||||
# Sorting can speed up the cls process
|
||||
indices = np.argsort(np.array(width_list))
|
||||
|
||||
cls_res = [['', 0.0]] * img_num
|
||||
batch_num = self.cls_batch_num
|
||||
predict_time = 0
|
||||
for beg_img_no in range(0, img_num, batch_num):
|
||||
end_img_no = min(img_num, beg_img_no + batch_num)
|
||||
norm_img_batch = []
|
||||
max_wh_ratio = 0
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
h, w = img_list[indices[ino]].shape[0:2]
|
||||
wh_ratio = w * 1.0 / h
|
||||
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
norm_img = self.resize_norm_img(img_list[indices[ino]])
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
norm_img_batch = np.concatenate(norm_img_batch)
|
||||
norm_img_batch = norm_img_batch.copy()
|
||||
starttime = time.time()
|
||||
|
||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||
self.predictor.zero_copy_run()
|
||||
|
||||
prob_out = self.output_tensors[0].copy_to_cpu()
|
||||
label_out = self.output_tensors[1].copy_to_cpu()
|
||||
|
||||
elapse = time.time() - starttime
|
||||
predict_time += elapse
|
||||
for rno in range(len(label_out)):
|
||||
label_idx = label_out[rno]
|
||||
score = prob_out[rno][label_idx]
|
||||
label = self.label_list[label_idx]
|
||||
cls_res[indices[beg_img_no + rno]] = [label, score]
|
||||
if label == 180:
|
||||
img_list[indices[beg_img_no + rno]] = cv2.rotate(
|
||||
img_list[indices[beg_img_no + rno]], 1)
|
||||
return img_list, cls_res, predict_time
|
||||
|
||||
|
||||
def main(args):
|
||||
image_file_list = get_image_file_list(args.image_dir)
|
||||
text_classifier = TextClassifier(args)
|
||||
valid_image_file_list = []
|
||||
img_list = []
|
||||
for image_file in image_file_list[:10]:
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
if not flag:
|
||||
img = cv2.imread(image_file)
|
||||
if img is None:
|
||||
logger.info("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
valid_image_file_list.append(image_file)
|
||||
img_list.append(img)
|
||||
try:
|
||||
img_list, cls_res, predict_time = text_classifier(img_list)
|
||||
print(cls_res)
|
||||
from matplotlib import pyplot as plt
|
||||
for img, angle in zip(img_list, cls_res):
|
||||
plt.title(str(angle))
|
||||
plt.imshow(img)
|
||||
plt.show()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
exit()
|
||||
for ino in range(len(img_list)):
|
||||
print("Predicts of %s:%s" % (valid_image_file_list[ino], cls_res[ino]))
|
||||
print("Total predict time for %d images:%.3f" %
|
||||
(len(img_list), predict_time))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(utility.parse_args())
|
|
@ -13,16 +13,19 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
import tools.infer.utility as utility
|
||||
from ppocr.utils.utility import initial_logger
|
||||
|
||||
logger = initial_logger()
|
||||
import cv2
|
||||
import tools.infer.predict_det as predict_det
|
||||
import tools.infer.predict_rec as predict_rec
|
||||
import tools.infer.predict_cls as predict_cls
|
||||
import copy
|
||||
import numpy as np
|
||||
import math
|
||||
|
@ -37,6 +40,7 @@ class TextSystem(object):
|
|||
def __init__(self, args):
|
||||
self.text_detector = predict_det.TextDetector(args)
|
||||
self.text_recognizer = predict_rec.TextRecognizer(args)
|
||||
self.text_classifier = predict_cls.TextClassifier(args)
|
||||
|
||||
def get_rotate_crop_image(self, img, points):
|
||||
'''
|
||||
|
@ -91,7 +95,10 @@ class TextSystem(object):
|
|||
tmp_box = copy.deepcopy(dt_boxes[bno])
|
||||
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
|
||||
img_crop_list.append(img_crop)
|
||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||||
img_rotate_list, angle_list, elapse = self.text_classifier(
|
||||
img_crop_list)
|
||||
print("cls num : {}, elapse : {}".format(len(img_rotate_list), elapse))
|
||||
rec_res, elapse = self.text_recognizer(img_rotate_list)
|
||||
print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse))
|
||||
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
|
||||
return dt_boxes, rec_res
|
||||
|
@ -110,8 +117,8 @@ def sorted_boxes(dt_boxes):
|
|||
_boxes = list(sorted_boxes)
|
||||
|
||||
for i in range(num_boxes - 1):
|
||||
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
|
||||
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
|
||||
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
|
||||
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
|
||||
tmp = _boxes[i]
|
||||
_boxes[i] = _boxes[i + 1]
|
||||
_boxes[i + 1] = tmp
|
||||
|
|
|
@ -65,6 +65,13 @@ def parse_args():
|
|||
type=str,
|
||||
default="./ppocr/utils/ppocr_keys_v1.txt")
|
||||
parser.add_argument("--use_space_char", type=bool, default=True)
|
||||
|
||||
# params for text classifier
|
||||
parser.add_argument("--cls_model_dir", type=str)
|
||||
parser.add_argument("--cls_image_shape", type=str, default="3, 32, 100")
|
||||
parser.add_argument("--label_list", type=list, default=[0, 180])
|
||||
parser.add_argument("--cls_batch_num", type=int, default=30)
|
||||
|
||||
parser.add_argument("--enable_mkldnn", type=bool, default=False)
|
||||
return parser.parse_args()
|
||||
|
||||
|
@ -72,6 +79,8 @@ def parse_args():
|
|||
def create_predictor(args, mode):
|
||||
if mode == "det":
|
||||
model_dir = args.det_model_dir
|
||||
elif mode == 'cls':
|
||||
model_dir = args.cls_model_dir
|
||||
else:
|
||||
model_dir = args.rec_model_dir
|
||||
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
|
||||
|
||||
def set_paddle_flags(**kwargs):
|
||||
for key, value in kwargs.items():
|
||||
if os.environ.get(key, None) is None:
|
||||
os.environ[key] = str(value)
|
||||
|
||||
|
||||
# NOTE(paddle-dev): All of these flags should be
|
||||
# set before `import paddle`. Otherwise, it would
|
||||
# not take any effect.
|
||||
set_paddle_flags(
|
||||
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
|
||||
)
|
||||
|
||||
import tools.program as program
|
||||
from paddle import fluid
|
||||
from ppocr.utils.utility import initial_logger
|
||||
logger = initial_logger()
|
||||
from ppocr.data.reader_main import reader_main
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.utility import create_module
|
||||
from ppocr.utils.utility import get_image_file_list
|
||||
|
||||
|
||||
def main():
|
||||
config = program.load_config(FLAGS.config)
|
||||
program.merge_config(FLAGS.opt)
|
||||
logger.info(config)
|
||||
|
||||
# check if set use_gpu=True in paddlepaddle cpu version
|
||||
use_gpu = config['Global']['use_gpu']
|
||||
# check_gpu(use_gpu)
|
||||
|
||||
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
|
||||
rec_model = create_module(config['Architecture']['function'])(params=config)
|
||||
startup_prog = fluid.Program()
|
||||
eval_prog = fluid.Program()
|
||||
with fluid.program_guard(eval_prog, startup_prog):
|
||||
with fluid.unique_name.guard():
|
||||
_, outputs = rec_model(mode="test")
|
||||
fetch_name_list = list(outputs.keys())
|
||||
fetch_varname_list = [outputs[v].name for v in fetch_name_list]
|
||||
eval_prog = eval_prog.clone(for_test=True)
|
||||
exe.run(startup_prog)
|
||||
|
||||
init_model(config, eval_prog, exe)
|
||||
|
||||
blobs = reader_main(config, 'test')()
|
||||
infer_img = config['Global']['infer_img']
|
||||
infer_list = get_image_file_list(infer_img)
|
||||
max_img_num = len(infer_list)
|
||||
if len(infer_list) == 0:
|
||||
logger.info("Can not find img in infer_img dir.")
|
||||
for i in range(max_img_num):
|
||||
logger.info("infer_img:%s" % infer_list[i])
|
||||
img = next(blobs)
|
||||
predict = exe.run(program=eval_prog,
|
||||
feed={"image": img},
|
||||
fetch_list=fetch_varname_list,
|
||||
return_numpy=False)
|
||||
for k in predict:
|
||||
k = np.array(k)
|
||||
print(k)
|
||||
# save for inference model
|
||||
target_var = []
|
||||
for key, values in outputs.items():
|
||||
target_var.append(values)
|
||||
|
||||
fluid.io.save_inference_model(
|
||||
"./output",
|
||||
feeded_var_names=['image'],
|
||||
target_vars=target_var,
|
||||
executor=exe,
|
||||
main_program=eval_prog,
|
||||
model_filename="model",
|
||||
params_filename="params")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = program.ArgsParser()
|
||||
FLAGS = parser.parse_args()
|
||||
main()
|
|
@ -30,6 +30,7 @@ import time
|
|||
from ppocr.utils.stats import TrainingStats
|
||||
from eval_utils.eval_det_utils import eval_det_run
|
||||
from eval_utils.eval_rec_utils import eval_rec_run
|
||||
from eval_utils.eval_cls_utils import eval_cls_run
|
||||
from ppocr.utils.save_load import save_model
|
||||
import numpy as np
|
||||
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps
|
||||
|
@ -398,6 +399,87 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
|
|||
return
|
||||
|
||||
|
||||
def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
|
||||
train_batch_id = 0
|
||||
log_smooth_window = config['Global']['log_smooth_window']
|
||||
epoch_num = config['Global']['epoch_num']
|
||||
print_batch_step = config['Global']['print_batch_step']
|
||||
eval_batch_step = config['Global']['eval_batch_step']
|
||||
start_eval_step = 0
|
||||
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
|
||||
start_eval_step = eval_batch_step[0]
|
||||
eval_batch_step = eval_batch_step[1]
|
||||
logger.info(
|
||||
"During the training process, after the {}th iteration, an evaluation is run every {} iterations".
|
||||
format(start_eval_step, eval_batch_step))
|
||||
save_epoch_step = config['Global']['save_epoch_step']
|
||||
save_model_dir = config['Global']['save_model_dir']
|
||||
if not os.path.exists(save_model_dir):
|
||||
os.makedirs(save_model_dir)
|
||||
train_stats = TrainingStats(log_smooth_window, ['loss', 'acc'])
|
||||
best_eval_acc = -1
|
||||
best_batch_id = 0
|
||||
best_epoch = 0
|
||||
train_loader = train_info_dict['reader']
|
||||
for epoch in range(epoch_num):
|
||||
train_loader.start()
|
||||
try:
|
||||
while True:
|
||||
t1 = time.time()
|
||||
train_outs = exe.run(
|
||||
program=train_info_dict['compile_program'],
|
||||
fetch_list=train_info_dict['fetch_varname_list'],
|
||||
return_numpy=False)
|
||||
fetch_map = dict(
|
||||
zip(train_info_dict['fetch_name_list'],
|
||||
range(len(train_outs))))
|
||||
|
||||
loss = np.mean(np.array(train_outs[fetch_map['total_loss']]))
|
||||
lr = np.mean(np.array(train_outs[fetch_map['lr']]))
|
||||
acc = np.mean(np.array(train_outs[fetch_map['acc']]))
|
||||
|
||||
t2 = time.time()
|
||||
train_batch_elapse = t2 - t1
|
||||
stats = {'loss': loss, 'acc': acc}
|
||||
train_stats.update(stats)
|
||||
if train_batch_id > start_eval_step and (train_batch_id - start_eval_step) \
|
||||
% print_batch_step == 0:
|
||||
logs = train_stats.log()
|
||||
strs = 'epoch: {}, iter: {}, lr: {:.6f}, {}, time: {:.3f}'.format(
|
||||
epoch, train_batch_id, lr, logs, train_batch_elapse)
|
||||
logger.info(strs)
|
||||
|
||||
if train_batch_id > 0 and\
|
||||
train_batch_id % eval_batch_step == 0:
|
||||
model_average = train_info_dict['model_average']
|
||||
if model_average != None:
|
||||
model_average.apply(exe)
|
||||
metrics = eval_cls_run(exe, eval_info_dict)
|
||||
eval_acc = metrics['avg_acc']
|
||||
eval_sample_num = metrics['total_sample_num']
|
||||
if eval_acc > best_eval_acc:
|
||||
best_eval_acc = eval_acc
|
||||
best_batch_id = train_batch_id
|
||||
best_epoch = epoch
|
||||
save_path = save_model_dir + "/best_accuracy"
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format(
|
||||
train_batch_id, eval_acc, best_eval_acc, best_epoch,
|
||||
best_batch_id, eval_sample_num)
|
||||
logger.info(strs)
|
||||
train_batch_id += 1
|
||||
|
||||
except fluid.core.EOFException:
|
||||
train_loader.reset()
|
||||
if epoch == 0 and save_epoch_step == 1:
|
||||
save_path = save_model_dir + "/iter_epoch_0"
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
if epoch > 0 and epoch % save_epoch_step == 0:
|
||||
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
return
|
||||
|
||||
|
||||
def preprocess():
|
||||
FLAGS = ArgsParser().parse_args()
|
||||
config = load_config(FLAGS.config)
|
||||
|
@ -409,7 +491,9 @@ def preprocess():
|
|||
check_gpu(use_gpu)
|
||||
|
||||
alg = config['Global']['algorithm']
|
||||
assert alg in ['EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
|
||||
]
|
||||
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
|
||||
config['Global']['char_ops'] = CharacterOps(config['Global'])
|
||||
|
||||
|
@ -419,7 +503,9 @@ def preprocess():
|
|||
|
||||
if alg in ['EAST', 'DB', 'SAST']:
|
||||
train_alg_type = 'det'
|
||||
else:
|
||||
elif alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
|
||||
train_alg_type = 'rec'
|
||||
else:
|
||||
train_alg_type = 'cls'
|
||||
|
||||
return startup_program, train_program, place, config, train_alg_type
|
||||
|
|
|
@ -75,7 +75,8 @@ def main():
|
|||
|
||||
# dump mode structure
|
||||
if config['Global']['debug']:
|
||||
if train_alg_type == 'rec' and 'attention' in config['Global']['loss_type']:
|
||||
if train_alg_type == 'rec' and 'attention' in config['Global'][
|
||||
'loss_type']:
|
||||
logger.warning('Does not suport dump attention...')
|
||||
else:
|
||||
summary(train_program)
|
||||
|
@ -96,8 +97,10 @@ def main():
|
|||
|
||||
if train_alg_type == 'det':
|
||||
program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict)
|
||||
else:
|
||||
elif train_alg_type == 'rec':
|
||||
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
|
||||
else:
|
||||
program.train_eval_cls_run(config, exe, train_info_dict, eval_info_dict)
|
||||
|
||||
|
||||
def test_reader():
|
||||
|
@ -119,6 +122,7 @@ def test_reader():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
startup_program, train_program, place, config, train_alg_type = program.preprocess()
|
||||
startup_program, train_program, place, config, train_alg_type = program.preprocess(
|
||||
)
|
||||
main()
|
||||
# test_reader()
|
||||
|
|
Loading…
Reference in New Issue