添加分类模型
This commit is contained in:
parent
7c09c97d70
commit
e11b2108fa
|
@ -0,0 +1,43 @@
|
||||||
|
Global:
|
||||||
|
algorithm: CLS
|
||||||
|
use_gpu: false
|
||||||
|
epoch_num: 30
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 10
|
||||||
|
save_model_dir: output/cls_mb3
|
||||||
|
save_epoch_step: 3
|
||||||
|
eval_batch_step: 100
|
||||||
|
train_batch_size_per_card: 256
|
||||||
|
test_batch_size_per_card: 256
|
||||||
|
image_shape: [3, 32, 100]
|
||||||
|
label_list: [0,180]
|
||||||
|
reader_yml: ./configs/cls/cls_reader.yml
|
||||||
|
pretrain_weights:
|
||||||
|
checkpoints: /Users/zhoujun20/Desktop/code/class_model/cls_mb3_ultra_small_0.35/best_accuracy
|
||||||
|
save_inference_dir:
|
||||||
|
infer_img: /Users/zhoujun20/Desktop/code/PaddleOCR/doc/imgs_words/ch/word_1.jpg
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
function: ppocr.modeling.architectures.cls_model,ClsModel
|
||||||
|
|
||||||
|
Backbone:
|
||||||
|
function: ppocr.modeling.backbones.rec_mobilenet_v3,MobileNetV3
|
||||||
|
scale: 0.35
|
||||||
|
model_name: Ultra_small
|
||||||
|
|
||||||
|
Head:
|
||||||
|
function: ppocr.modeling.heads.cls_head,ClsHead
|
||||||
|
class_dim: 2
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
function: ppocr.modeling.losses.cls_loss,ClsLoss
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
function: ppocr.optimizer,AdamDecay
|
||||||
|
base_lr: 0.001
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
decay:
|
||||||
|
function: piecewise_decay
|
||||||
|
boundaries: [20,30]
|
||||||
|
decay_rate: 0.1
|
|
@ -0,0 +1,13 @@
|
||||||
|
TrainReader:
|
||||||
|
reader_function: ppocr.data.cls.dataset_traversal,SimpleReader
|
||||||
|
num_workers: 1
|
||||||
|
img_set_dir: /
|
||||||
|
label_file_path: /Users/zhoujun20/Downloads/direction/rotate_ver/train.txt
|
||||||
|
|
||||||
|
EvalReader:
|
||||||
|
reader_function: ppocr.data.cls.dataset_traversal,SimpleReader
|
||||||
|
img_set_dir: /
|
||||||
|
label_file_path: /Users/zhoujun20/Downloads/direction/rotate_ver/train.txt
|
||||||
|
|
||||||
|
TestReader:
|
||||||
|
reader_function: ppocr.data.cls.dataset_traversal,SimpleReader
|
|
@ -55,6 +55,10 @@ public:
|
||||||
|
|
||||||
this->char_list_file.assign(config_map_["char_list_file"]);
|
this->char_list_file.assign(config_map_["char_list_file"]);
|
||||||
|
|
||||||
|
this->cls_model_dir.assign(config_map_["cls_model_dir"]);
|
||||||
|
|
||||||
|
this->cls_thresh = stod(config_map_["cls_thresh"]);
|
||||||
|
|
||||||
this->visualize = bool(stoi(config_map_["visualize"]));
|
this->visualize = bool(stoi(config_map_["visualize"]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,6 +86,10 @@ public:
|
||||||
|
|
||||||
std::string char_list_file;
|
std::string char_list_file;
|
||||||
|
|
||||||
|
std::string cls_model_dir;
|
||||||
|
|
||||||
|
double cls_thresh;
|
||||||
|
|
||||||
bool visualize = true;
|
bool visualize = true;
|
||||||
|
|
||||||
void PrintConfigInfo();
|
void PrintConfigInfo();
|
||||||
|
|
|
@ -0,0 +1,79 @@
|
||||||
|
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "opencv2/core.hpp"
|
||||||
|
#include "opencv2/imgcodecs.hpp"
|
||||||
|
#include "opencv2/imgproc.hpp"
|
||||||
|
#include "paddle_api.h"
|
||||||
|
#include "paddle_inference_api.h"
|
||||||
|
#include <chrono>
|
||||||
|
#include <iomanip>
|
||||||
|
#include <iostream>
|
||||||
|
#include <ostream>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
|
#include <fstream>
|
||||||
|
#include <numeric>
|
||||||
|
|
||||||
|
#include <include/preprocess_op.h>
|
||||||
|
#include <include/utility.h>
|
||||||
|
|
||||||
|
namespace PaddleOCR {
|
||||||
|
|
||||||
|
class Classifier {
|
||||||
|
public:
|
||||||
|
explicit Classifier(const std::string &model_dir, const bool &use_gpu,
|
||||||
|
const int &gpu_id, const int &gpu_mem,
|
||||||
|
const int &cpu_math_library_num_threads,
|
||||||
|
const bool &use_mkldnn, const double &cls_thresh) {
|
||||||
|
this->use_gpu_ = use_gpu;
|
||||||
|
this->gpu_id_ = gpu_id;
|
||||||
|
this->gpu_mem_ = gpu_mem;
|
||||||
|
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
|
||||||
|
this->use_mkldnn_ = use_mkldnn;
|
||||||
|
|
||||||
|
this->cls_thresh = cls_thresh;
|
||||||
|
|
||||||
|
LoadModel(model_dir);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load Paddle inference model
|
||||||
|
void LoadModel(const std::string &model_dir);
|
||||||
|
|
||||||
|
cv::Mat Run(cv::Mat &img);
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<PaddlePredictor> predictor_;
|
||||||
|
|
||||||
|
bool use_gpu_ = false;
|
||||||
|
int gpu_id_ = 0;
|
||||||
|
int gpu_mem_ = 4000;
|
||||||
|
int cpu_math_library_num_threads_ = 4;
|
||||||
|
bool use_mkldnn_ = false;
|
||||||
|
|
||||||
|
double cls_thresh = 0.5;
|
||||||
|
|
||||||
|
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
|
||||||
|
std::vector<float> scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
|
||||||
|
bool is_scale_ = true;
|
||||||
|
|
||||||
|
// pre-process
|
||||||
|
ClsResizeImg resize_op_;
|
||||||
|
Normalize normalize_op_;
|
||||||
|
Permute permute_op_;
|
||||||
|
|
||||||
|
}; // class Classifier
|
||||||
|
|
||||||
|
} // namespace PaddleOCR
|
|
@ -27,6 +27,7 @@
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
|
#include <include/ocr_cls.h>
|
||||||
#include <include/postprocess_op.h>
|
#include <include/postprocess_op.h>
|
||||||
#include <include/preprocess_op.h>
|
#include <include/preprocess_op.h>
|
||||||
#include <include/utility.h>
|
#include <include/utility.h>
|
||||||
|
@ -54,7 +55,8 @@ public:
|
||||||
// Load Paddle inference model
|
// Load Paddle inference model
|
||||||
void LoadModel(const std::string &model_dir);
|
void LoadModel(const std::string &model_dir);
|
||||||
|
|
||||||
void Run(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat &img);
|
void Run(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat &img,
|
||||||
|
Classifier &cls);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<PaddlePredictor> predictor_;
|
std::shared_ptr<PaddlePredictor> predictor_;
|
||||||
|
|
|
@ -56,4 +56,10 @@ public:
|
||||||
const std::vector<int> &rec_image_shape = {3, 32, 320});
|
const std::vector<int> &rec_image_shape = {3, 32, 320});
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class ClsResizeImg {
|
||||||
|
public:
|
||||||
|
virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||||
|
const std::vector<int> &rec_image_shape = {3, 32, 320});
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace PaddleOCR
|
} // namespace PaddleOCR
|
|
@ -53,6 +53,9 @@ int main(int argc, char **argv) {
|
||||||
config.use_mkldnn, config.max_side_len, config.det_db_thresh,
|
config.use_mkldnn, config.max_side_len, config.det_db_thresh,
|
||||||
config.det_db_box_thresh, config.det_db_unclip_ratio,
|
config.det_db_box_thresh, config.det_db_unclip_ratio,
|
||||||
config.visualize);
|
config.visualize);
|
||||||
|
Classifier cls(config.cls_model_dir, config.use_gpu, config.gpu_id,
|
||||||
|
config.gpu_mem, config.cpu_math_library_num_threads,
|
||||||
|
config.use_mkldnn, config.cls_thresh);
|
||||||
CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id,
|
CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id,
|
||||||
config.gpu_mem, config.cpu_math_library_num_threads,
|
config.gpu_mem, config.cpu_math_library_num_threads,
|
||||||
config.use_mkldnn, config.char_list_file);
|
config.use_mkldnn, config.char_list_file);
|
||||||
|
@ -61,7 +64,7 @@ int main(int argc, char **argv) {
|
||||||
std::vector<std::vector<std::vector<int>>> boxes;
|
std::vector<std::vector<std::vector<int>>> boxes;
|
||||||
det.Run(srcimg, boxes);
|
det.Run(srcimg, boxes);
|
||||||
|
|
||||||
rec.Run(boxes, srcimg);
|
rec.Run(boxes, srcimg, cls);
|
||||||
|
|
||||||
auto end = std::chrono::system_clock::now();
|
auto end = std::chrono::system_clock::now();
|
||||||
auto duration =
|
auto duration =
|
||||||
|
|
|
@ -0,0 +1,100 @@
|
||||||
|
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <include/ocr_cls.h>
|
||||||
|
|
||||||
|
namespace PaddleOCR {
|
||||||
|
|
||||||
|
cv::Mat Classifier::Run(cv::Mat &img) {
|
||||||
|
cv::Mat src_img;
|
||||||
|
img.copyTo(src_img);
|
||||||
|
cv::Mat resize_img;
|
||||||
|
|
||||||
|
std::vector<int> rec_image_shape = {3, 32, 100};
|
||||||
|
int index = 0;
|
||||||
|
float wh_ratio = float(img.cols) / float(img.rows);
|
||||||
|
|
||||||
|
this->resize_op_.Run(img, resize_img, rec_image_shape);
|
||||||
|
|
||||||
|
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
|
||||||
|
this->is_scale_);
|
||||||
|
|
||||||
|
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
|
||||||
|
|
||||||
|
this->permute_op_.Run(&resize_img, input.data());
|
||||||
|
|
||||||
|
auto input_names = this->predictor_->GetInputNames();
|
||||||
|
auto input_t = this->predictor_->GetInputTensor(input_names[0]);
|
||||||
|
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
|
||||||
|
input_t->copy_from_cpu(input.data());
|
||||||
|
|
||||||
|
this->predictor_->ZeroCopyRun();
|
||||||
|
|
||||||
|
std::vector<float> softmax_out;
|
||||||
|
std::vector<int64_t> label_out;
|
||||||
|
auto output_names = this->predictor_->GetOutputNames();
|
||||||
|
auto softmax_out_t = this->predictor_->GetOutputTensor(output_names[0]);
|
||||||
|
auto label_out_t = this->predictor_->GetOutputTensor(output_names[1]);
|
||||||
|
auto softmax_shape_out = softmax_out_t->shape();
|
||||||
|
auto label_shape_out = label_out_t->shape();
|
||||||
|
|
||||||
|
int softmax_out_num =
|
||||||
|
std::accumulate(softmax_shape_out.begin(), softmax_shape_out.end(), 1,
|
||||||
|
std::multiplies<int>());
|
||||||
|
|
||||||
|
int label_out_num =
|
||||||
|
std::accumulate(label_shape_out.begin(), label_shape_out.end(), 1,
|
||||||
|
std::multiplies<int>());
|
||||||
|
softmax_out.resize(softmax_out_num);
|
||||||
|
label_out.resize(label_out_num);
|
||||||
|
|
||||||
|
softmax_out_t->copy_to_cpu(softmax_out.data());
|
||||||
|
label_out_t->copy_to_cpu(label_out.data());
|
||||||
|
|
||||||
|
int label = label_out[0];
|
||||||
|
float score = softmax_out[label];
|
||||||
|
// std::cout << "\nlabel "<<label<<" score: "<<score;
|
||||||
|
if (label % 2 == 1 && score > this->cls_thresh) {
|
||||||
|
cv::rotate(src_img, src_img, 1);
|
||||||
|
}
|
||||||
|
return src_img;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Classifier::LoadModel(const std::string &model_dir) {
|
||||||
|
AnalysisConfig config;
|
||||||
|
config.SetModel(model_dir + "/model", model_dir + "/params");
|
||||||
|
|
||||||
|
if (this->use_gpu_) {
|
||||||
|
config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
|
||||||
|
} else {
|
||||||
|
config.DisableGpu();
|
||||||
|
if (this->use_mkldnn_) {
|
||||||
|
config.EnableMKLDNN();
|
||||||
|
}
|
||||||
|
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// false for zero copy tensor
|
||||||
|
config.SwitchUseFeedFetchOps(false);
|
||||||
|
// true for multiple input
|
||||||
|
config.SwitchSpecifyInputNames(true);
|
||||||
|
|
||||||
|
config.SwitchIrOptim(true);
|
||||||
|
|
||||||
|
config.EnableMemoryOptim();
|
||||||
|
config.DisableGlogInfo();
|
||||||
|
|
||||||
|
this->predictor_ = CreatePaddlePredictor(config);
|
||||||
|
}
|
||||||
|
} // namespace PaddleOCR
|
|
@ -17,7 +17,7 @@
|
||||||
namespace PaddleOCR {
|
namespace PaddleOCR {
|
||||||
|
|
||||||
void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
||||||
cv::Mat &img) {
|
cv::Mat &img, Classifier &cls) {
|
||||||
cv::Mat srcimg;
|
cv::Mat srcimg;
|
||||||
img.copyTo(srcimg);
|
img.copyTo(srcimg);
|
||||||
cv::Mat crop_img;
|
cv::Mat crop_img;
|
||||||
|
@ -28,6 +28,8 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
||||||
for (int i = boxes.size() - 1; i >= 0; i--) {
|
for (int i = boxes.size() - 1; i >= 0; i--) {
|
||||||
crop_img = GetRotateCropImage(srcimg, boxes[i]);
|
crop_img = GetRotateCropImage(srcimg, boxes[i]);
|
||||||
|
|
||||||
|
crop_img = cls.Run(crop_img);
|
||||||
|
|
||||||
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
|
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
|
||||||
|
|
||||||
this->resize_op_.Run(crop_img, resize_img, wh_ratio);
|
this->resize_op_.Run(crop_img, resize_img, wh_ratio);
|
||||||
|
|
|
@ -116,4 +116,26 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
|
||||||
cv::INTER_LINEAR);
|
cv::INTER_LINEAR);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||||
|
const std::vector<int> &rec_image_shape) {
|
||||||
|
int imgC, imgH, imgW;
|
||||||
|
imgC = rec_image_shape[0];
|
||||||
|
imgH = rec_image_shape[1];
|
||||||
|
imgW = rec_image_shape[2];
|
||||||
|
|
||||||
|
float ratio = float(img.cols) / float(img.rows);
|
||||||
|
int resize_w, resize_h;
|
||||||
|
if (ceilf(imgH * ratio) > imgW)
|
||||||
|
resize_w = imgW;
|
||||||
|
else
|
||||||
|
resize_w = int(ceilf(imgH * ratio));
|
||||||
|
|
||||||
|
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
|
||||||
|
cv::INTER_LINEAR);
|
||||||
|
if (resize_w < imgW) {
|
||||||
|
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
|
||||||
|
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace PaddleOCR
|
} // namespace PaddleOCR
|
|
@ -40,8 +40,8 @@ CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SY
|
||||||
|
|
||||||
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
|
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
|
||||||
|
|
||||||
ocr_db_crnn: fetch_opencv ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o
|
ocr_db_crnn: fetch_opencv ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o cls_process.o
|
||||||
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o -o ocr_db_crnn $(CXX_LIBS) $(LDFLAGS)
|
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o cls_process.o -o ocr_db_crnn $(CXX_LIBS) $(LDFLAGS)
|
||||||
|
|
||||||
ocr_db_crnn.o: ocr_db_crnn.cc
|
ocr_db_crnn.o: ocr_db_crnn.cc
|
||||||
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o ocr_db_crnn.o -c ocr_db_crnn.cc
|
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o ocr_db_crnn.o -c ocr_db_crnn.cc
|
||||||
|
@ -49,6 +49,9 @@ ocr_db_crnn.o: ocr_db_crnn.cc
|
||||||
crnn_process.o: fetch_opencv crnn_process.cc
|
crnn_process.o: fetch_opencv crnn_process.cc
|
||||||
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o crnn_process.o -c crnn_process.cc
|
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o crnn_process.o -c crnn_process.cc
|
||||||
|
|
||||||
|
cls_process.o: fetch_opencv cls_process.cc
|
||||||
|
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o cls_process.o -c cls_process.cc
|
||||||
|
|
||||||
db_post_process.o: fetch_clipper fetch_opencv db_post_process.cc
|
db_post_process.o: fetch_clipper fetch_opencv db_post_process.cc
|
||||||
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o db_post_process.o -c db_post_process.cc
|
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o db_post_process.o -c db_post_process.cc
|
||||||
|
|
||||||
|
@ -73,5 +76,5 @@ fetch_opencv:
|
||||||
|
|
||||||
.PHONY: clean
|
.PHONY: clean
|
||||||
clean:
|
clean:
|
||||||
rm -f ocr_db_crnn.o clipper.o db_post_process.o crnn_process.o
|
rm -f ocr_db_crnn.o clipper.o db_post_process.o crnn_process.o cls_process.o
|
||||||
rm -f ocr_db_crnn
|
rm -f ocr_db_crnn
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "cls_process.h" //NOLINT
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
const std::vector<int> rec_image_shape{3, 32, 100};
|
||||||
|
|
||||||
|
cv::Mat ClsResizeImg(cv::Mat img) {
|
||||||
|
int imgC, imgH, imgW;
|
||||||
|
imgC = rec_image_shape[0];
|
||||||
|
imgH = rec_image_shape[1];
|
||||||
|
imgW = rec_image_shape[2];
|
||||||
|
|
||||||
|
float ratio = static_cast<float>(img.cols) / static_cast<float>(img.rows);
|
||||||
|
|
||||||
|
int resize_w, resize_h;
|
||||||
|
if (ceilf(imgH * ratio) > imgW)
|
||||||
|
resize_w = imgW;
|
||||||
|
else
|
||||||
|
resize_w = int(ceilf(imgH * ratio));
|
||||||
|
cv::Mat resize_img;
|
||||||
|
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
|
||||||
|
cv::INTER_LINEAR);
|
||||||
|
if (resize_w < imgW) {
|
||||||
|
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
|
||||||
|
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
|
||||||
|
}
|
||||||
|
return resize_img;
|
||||||
|
}
|
|
@ -0,0 +1,29 @@
|
||||||
|
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <cstring>
|
||||||
|
#include <fstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "math.h" //NOLINT
|
||||||
|
#include "opencv2/core.hpp"
|
||||||
|
#include "opencv2/imgcodecs.hpp"
|
||||||
|
#include "opencv2/imgproc.hpp"
|
||||||
|
|
||||||
|
cv::Mat ClsResizeImg(cv::Mat img);
|
|
@ -15,6 +15,7 @@
|
||||||
#include "paddle_api.h" // NOLINT
|
#include "paddle_api.h" // NOLINT
|
||||||
#include <chrono>
|
#include <chrono>
|
||||||
|
|
||||||
|
#include "cls_process.h"
|
||||||
#include "crnn_process.h"
|
#include "crnn_process.h"
|
||||||
#include "db_post_process.h"
|
#include "db_post_process.h"
|
||||||
|
|
||||||
|
@ -105,11 +106,55 @@ cv::Mat DetResizeImg(const cv::Mat img, int max_size_len,
|
||||||
return resize_img;
|
return resize_img;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cv::Mat RunClsModel(cv::Mat img, std::shared_ptr<PaddlePredictor> predictor_cls,
|
||||||
|
const float thresh = 0.5) {
|
||||||
|
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
|
||||||
|
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
|
||||||
|
|
||||||
|
cv::Mat srcimg;
|
||||||
|
img.copyTo(srcimg);
|
||||||
|
cv::Mat crop_img;
|
||||||
|
cv::Mat resize_img;
|
||||||
|
|
||||||
|
int index = 0;
|
||||||
|
float wh_ratio =
|
||||||
|
static_cast<float>(crop_img.cols) / static_cast<float>(crop_img.rows);
|
||||||
|
|
||||||
|
resize_img = ClsResizeImg(crop_img);
|
||||||
|
resize_img.convertTo(resize_img, CV_32FC3, 1 / 255.f);
|
||||||
|
|
||||||
|
const float *dimg = reinterpret_cast<const float *>(resize_img.data);
|
||||||
|
|
||||||
|
std::unique_ptr<Tensor> input_tensor0(std::move(predictor_cls->GetInput(0)));
|
||||||
|
input_tensor0->Resize({1, 3, resize_img.rows, resize_img.cols});
|
||||||
|
auto *data0 = input_tensor0->mutable_data<float>();
|
||||||
|
|
||||||
|
NeonMeanScale(dimg, data0, resize_img.rows * resize_img.cols, mean, scale);
|
||||||
|
// Run CLS predictor
|
||||||
|
predictor_cls->Run();
|
||||||
|
|
||||||
|
// Get output and run postprocess
|
||||||
|
std::unique_ptr<const Tensor> softmax_out(
|
||||||
|
std::move(predictor_cls->GetOutput(0)));
|
||||||
|
std::unique_ptr<const Tensor> label_out(
|
||||||
|
std::move(predictor_cls->GetOutput(1)));
|
||||||
|
auto *softmax_scores = softmax_out->mutable_data<float>();
|
||||||
|
auto *label_idxs = label_out->data<int64>();
|
||||||
|
int label_idx = label_idxs[0];
|
||||||
|
float score = softmax_scores[label_idx];
|
||||||
|
|
||||||
|
if (label_idx % 2 == 1 && score > thresh) {
|
||||||
|
cv::rotate(srcimg, srcimg, 1);
|
||||||
|
}
|
||||||
|
return srcimg;
|
||||||
|
}
|
||||||
|
|
||||||
void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img,
|
void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img,
|
||||||
std::shared_ptr<PaddlePredictor> predictor_crnn,
|
std::shared_ptr<PaddlePredictor> predictor_crnn,
|
||||||
std::vector<std::string> &rec_text,
|
std::vector<std::string> &rec_text,
|
||||||
std::vector<float> &rec_text_score,
|
std::vector<float> &rec_text_score,
|
||||||
std::vector<std::string> charactor_dict) {
|
std::vector<std::string> charactor_dict,
|
||||||
|
std::shared_ptr<PaddlePredictor> predictor_cls) {
|
||||||
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
|
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
|
||||||
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
|
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
|
||||||
|
|
||||||
|
@ -121,6 +166,7 @@ void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img,
|
||||||
int index = 0;
|
int index = 0;
|
||||||
for (int i = boxes.size() - 1; i >= 0; i--) {
|
for (int i = boxes.size() - 1; i >= 0; i--) {
|
||||||
crop_img = GetRotateCropImage(srcimg, boxes[i]);
|
crop_img = GetRotateCropImage(srcimg, boxes[i]);
|
||||||
|
crop_img = RunClsModel(crop_img, predictor_cls);
|
||||||
float wh_ratio =
|
float wh_ratio =
|
||||||
static_cast<float>(crop_img.cols) / static_cast<float>(crop_img.rows);
|
static_cast<float>(crop_img.cols) / static_cast<float>(crop_img.rows);
|
||||||
|
|
||||||
|
@ -323,8 +369,9 @@ int main(int argc, char **argv) {
|
||||||
}
|
}
|
||||||
std::string det_model_file = argv[1];
|
std::string det_model_file = argv[1];
|
||||||
std::string rec_model_file = argv[2];
|
std::string rec_model_file = argv[2];
|
||||||
std::string img_path = argv[3];
|
std::string cls_model_file = argv[3];
|
||||||
std::string dict_path = argv[4];
|
std::string img_path = argv[4];
|
||||||
|
std::string dict_path = argv[5];
|
||||||
|
|
||||||
//// load config from txt file
|
//// load config from txt file
|
||||||
auto Config = LoadConfigTxt("./config.txt");
|
auto Config = LoadConfigTxt("./config.txt");
|
||||||
|
@ -333,6 +380,7 @@ int main(int argc, char **argv) {
|
||||||
|
|
||||||
auto det_predictor = loadModel(det_model_file);
|
auto det_predictor = loadModel(det_model_file);
|
||||||
auto rec_predictor = loadModel(rec_model_file);
|
auto rec_predictor = loadModel(rec_model_file);
|
||||||
|
auto cls_predictor = loadModel(cls_model_file);
|
||||||
|
|
||||||
auto charactor_dict = ReadDict(dict_path);
|
auto charactor_dict = ReadDict(dict_path);
|
||||||
charactor_dict.push_back(" ");
|
charactor_dict.push_back(" ");
|
||||||
|
@ -343,7 +391,7 @@ int main(int argc, char **argv) {
|
||||||
std::vector<std::string> rec_text;
|
std::vector<std::string> rec_text;
|
||||||
std::vector<float> rec_text_score;
|
std::vector<float> rec_text_score;
|
||||||
RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score,
|
RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score,
|
||||||
charactor_dict);
|
charactor_dict, cls_predictor);
|
||||||
|
|
||||||
auto end = std::chrono::system_clock::now();
|
auto end = std::chrono::system_clock::now();
|
||||||
auto duration =
|
auto duration =
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
|
@ -0,0 +1,128 @@
|
||||||
|
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
from ppocr.utils.utility import initial_logger
|
||||||
|
from ppocr.utils.utility import get_image_file_list
|
||||||
|
|
||||||
|
logger = initial_logger()
|
||||||
|
|
||||||
|
from ppocr.data.rec.img_tools import warp, resize_norm_img
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleReader(object):
|
||||||
|
def __init__(self, params):
|
||||||
|
if params['mode'] != 'train':
|
||||||
|
self.num_workers = 1
|
||||||
|
else:
|
||||||
|
self.num_workers = params['num_workers']
|
||||||
|
if params['mode'] != 'test':
|
||||||
|
self.img_set_dir = params['img_set_dir']
|
||||||
|
self.label_file_path = params['label_file_path']
|
||||||
|
self.use_gpu = params['use_gpu']
|
||||||
|
self.image_shape = params['image_shape']
|
||||||
|
self.mode = params['mode']
|
||||||
|
self.infer_img = params['infer_img']
|
||||||
|
self.use_distort = False
|
||||||
|
self.label_list = params['label_list']
|
||||||
|
if "distort" in params:
|
||||||
|
self.use_distort = params['distort'] and params['use_gpu']
|
||||||
|
if not params['use_gpu']:
|
||||||
|
logger.info(
|
||||||
|
"Distort operation can only support in GPU.Distort will be set to False."
|
||||||
|
)
|
||||||
|
if params['mode'] == 'train':
|
||||||
|
self.batch_size = params['train_batch_size_per_card']
|
||||||
|
self.drop_last = True
|
||||||
|
else:
|
||||||
|
self.batch_size = params['test_batch_size_per_card']
|
||||||
|
self.drop_last = False
|
||||||
|
self.use_distort = False
|
||||||
|
|
||||||
|
def __call__(self, process_id):
|
||||||
|
if self.mode != 'train':
|
||||||
|
process_id = 0
|
||||||
|
|
||||||
|
def get_device_num():
|
||||||
|
if self.use_gpu:
|
||||||
|
gpus = os.environ.get("CUDA_VISIBLE_DEVICES", 1)
|
||||||
|
gpu_num = len(gpus.split(','))
|
||||||
|
return gpu_num
|
||||||
|
else:
|
||||||
|
cpu_num = os.environ.get("CPU_NUM", 1)
|
||||||
|
return int(cpu_num)
|
||||||
|
|
||||||
|
def sample_iter_reader():
|
||||||
|
if self.mode != 'train' and self.infer_img is not None:
|
||||||
|
image_file_list = get_image_file_list(self.infer_img)
|
||||||
|
for single_img in image_file_list:
|
||||||
|
img = cv2.imread(single_img)
|
||||||
|
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
|
norm_img = resize_norm_img(img, self.image_shape)
|
||||||
|
norm_img = norm_img[np.newaxis, :]
|
||||||
|
yield norm_img
|
||||||
|
else:
|
||||||
|
with open(self.label_file_path, "rb") as fin:
|
||||||
|
label_infor_list = fin.readlines()
|
||||||
|
img_num = len(label_infor_list)
|
||||||
|
img_id_list = list(range(img_num))
|
||||||
|
random.shuffle(img_id_list)
|
||||||
|
if sys.platform == "win32" and self.num_workers != 1:
|
||||||
|
print("multiprocess is not fully compatible with Windows."
|
||||||
|
"num_workers will be 1.")
|
||||||
|
self.num_workers = 1
|
||||||
|
if self.batch_size * get_device_num(
|
||||||
|
) * self.num_workers > img_num:
|
||||||
|
raise Exception(
|
||||||
|
"The number of the whole data ({}) is smaller than the batch_size * devices_num * num_workers ({})".
|
||||||
|
format(img_num, self.batch_size * get_device_num() *
|
||||||
|
self.num_workers))
|
||||||
|
for img_id in range(process_id, img_num, self.num_workers):
|
||||||
|
label_infor = label_infor_list[img_id_list[img_id]]
|
||||||
|
substr = label_infor.decode('utf-8').strip("\n").split("\t")
|
||||||
|
img_path = self.img_set_dir + "/" + substr[0]
|
||||||
|
img = cv2.imread(img_path)
|
||||||
|
if img is None:
|
||||||
|
logger.info("{} does not exist!".format(img_path))
|
||||||
|
continue
|
||||||
|
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
|
|
||||||
|
label = substr[1]
|
||||||
|
if self.use_distort:
|
||||||
|
img = warp(img, 10)
|
||||||
|
norm_img = resize_norm_img(img, self.image_shape)
|
||||||
|
norm_img = norm_img[np.newaxis, :]
|
||||||
|
yield (norm_img, self.label_list.index(int(label)))
|
||||||
|
|
||||||
|
def batch_iter_reader():
|
||||||
|
batch_outs = []
|
||||||
|
for outs in sample_iter_reader():
|
||||||
|
batch_outs.append(outs)
|
||||||
|
if len(batch_outs) == self.batch_size:
|
||||||
|
yield batch_outs
|
||||||
|
batch_outs = []
|
||||||
|
if not self.drop_last:
|
||||||
|
if len(batch_outs) != 0:
|
||||||
|
yield batch_outs
|
||||||
|
|
||||||
|
if self.infer_img is None:
|
||||||
|
return batch_iter_reader
|
||||||
|
return sample_iter_reader
|
|
@ -0,0 +1,84 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from paddle import fluid
|
||||||
|
|
||||||
|
from ppocr.utils.utility import create_module
|
||||||
|
from ppocr.utils.utility import initial_logger
|
||||||
|
|
||||||
|
logger = initial_logger()
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
|
||||||
|
class ClsModel(object):
|
||||||
|
def __init__(self, params):
|
||||||
|
super(ClsModel, self).__init__()
|
||||||
|
global_params = params['Global']
|
||||||
|
self.infer_img = global_params['infer_img']
|
||||||
|
|
||||||
|
backbone_params = deepcopy(params["Backbone"])
|
||||||
|
backbone_params.update(global_params)
|
||||||
|
self.backbone = create_module(backbone_params['function']) \
|
||||||
|
(params=backbone_params)
|
||||||
|
|
||||||
|
head_params = deepcopy(params["Head"])
|
||||||
|
head_params.update(global_params)
|
||||||
|
self.head = create_module(head_params['function']) \
|
||||||
|
(params=head_params)
|
||||||
|
|
||||||
|
loss_params = deepcopy(params["Loss"])
|
||||||
|
loss_params.update(global_params)
|
||||||
|
self.loss = create_module(loss_params['function']) \
|
||||||
|
(params=loss_params)
|
||||||
|
|
||||||
|
self.image_shape = global_params['image_shape']
|
||||||
|
|
||||||
|
def create_feed(self, mode):
|
||||||
|
image_shape = deepcopy(self.image_shape)
|
||||||
|
image_shape.insert(0, -1)
|
||||||
|
if mode == "train":
|
||||||
|
image = fluid.data(name='image', shape=image_shape, dtype='float32')
|
||||||
|
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
|
||||||
|
feed_list = [image, label]
|
||||||
|
labels = {'label': label}
|
||||||
|
loader = fluid.io.DataLoader.from_generator(
|
||||||
|
feed_list=feed_list,
|
||||||
|
capacity=64,
|
||||||
|
use_double_buffer=True,
|
||||||
|
iterable=False)
|
||||||
|
else:
|
||||||
|
labels = None
|
||||||
|
loader = None
|
||||||
|
image = fluid.data(name='image', shape=image_shape, dtype='float32')
|
||||||
|
return image, labels, loader
|
||||||
|
|
||||||
|
def __call__(self, mode):
|
||||||
|
image, labels, loader = self.create_feed(mode)
|
||||||
|
inputs = image
|
||||||
|
conv_feas = self.backbone(inputs)
|
||||||
|
predicts = self.head(conv_feas, labels, mode)
|
||||||
|
if mode == "train":
|
||||||
|
loss = self.loss(predicts, labels)
|
||||||
|
label = labels['label']
|
||||||
|
acc = fluid.layers.accuracy(predicts['predict'], label, k=1)
|
||||||
|
outputs = {'total_loss': loss, 'decoded_out': \
|
||||||
|
predicts['decoded_out'], 'label': label, 'acc': acc}
|
||||||
|
return loader, outputs
|
||||||
|
|
||||||
|
else:
|
||||||
|
return loader, predicts
|
|
@ -0,0 +1,46 @@
|
||||||
|
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
#you may not use this file except in compliance with the License.
|
||||||
|
#You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
#Unless required by applicable law or agreed to in writing, software
|
||||||
|
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
#See the License for the specific language governing permissions and
|
||||||
|
#limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
|
||||||
|
|
||||||
|
class ClsHead(object):
|
||||||
|
def __init__(self, params):
|
||||||
|
super(ClsHead, self).__init__()
|
||||||
|
self.class_dim = params['class_dim']
|
||||||
|
|
||||||
|
def __call__(self, inputs, labels=None, mode=None):
|
||||||
|
pool = fluid.layers.pool2d(
|
||||||
|
input=inputs, pool_type='avg', global_pooling=True)
|
||||||
|
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
|
||||||
|
|
||||||
|
out = fluid.layers.fc(
|
||||||
|
input=pool,
|
||||||
|
size=self.class_dim,
|
||||||
|
param_attr=fluid.param_attr.ParamAttr(
|
||||||
|
name="fc_0.w_0",
|
||||||
|
initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||||
|
bias_attr=fluid.param_attr.ParamAttr(name="fc_0.b_0"))
|
||||||
|
|
||||||
|
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
|
||||||
|
out_label = fluid.layers.argmax(out, axis=1)
|
||||||
|
predicts = {'predict': softmax_out, 'decoded_out': out_label}
|
||||||
|
return predicts
|
|
@ -0,0 +1,33 @@
|
||||||
|
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
|
||||||
|
|
||||||
|
class ClsLoss(object):
|
||||||
|
def __init__(self, params):
|
||||||
|
super(ClsLoss, self).__init__()
|
||||||
|
self.loss_func = fluid.layers.cross_entropy
|
||||||
|
|
||||||
|
def __call__(self, predicts, labels):
|
||||||
|
predict = predicts['predict']
|
||||||
|
label = labels['label']
|
||||||
|
# softmax_out = fluid.layers.softmax(predict, use_cudnn=False)
|
||||||
|
cost = fluid.layers.cross_entropy(input=predict, label=label)
|
||||||
|
sum_cost = fluid.layers.mean(cost)
|
||||||
|
return sum_cost
|
|
@ -45,10 +45,12 @@ from ppocr.utils.save_load import init_model
|
||||||
from eval_utils.eval_det_utils import eval_det_run
|
from eval_utils.eval_det_utils import eval_det_run
|
||||||
from eval_utils.eval_rec_utils import test_rec_benchmark
|
from eval_utils.eval_rec_utils import test_rec_benchmark
|
||||||
from eval_utils.eval_rec_utils import eval_rec_run
|
from eval_utils.eval_rec_utils import eval_rec_run
|
||||||
|
from eval_utils.eval_cls_utils import eval_cls_run
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
startup_prog, eval_program, place, config, train_alg_type = program.preprocess()
|
startup_prog, eval_program, place, config, train_alg_type = program.preprocess(
|
||||||
|
)
|
||||||
eval_build_outputs = program.build(
|
eval_build_outputs = program.build(
|
||||||
config, eval_program, startup_prog, mode='test')
|
config, eval_program, startup_prog, mode='test')
|
||||||
eval_fetch_name_list = eval_build_outputs[1]
|
eval_fetch_name_list = eval_build_outputs[1]
|
||||||
|
@ -67,6 +69,14 @@ def main():
|
||||||
'fetch_varname_list':eval_fetch_varname_list}
|
'fetch_varname_list':eval_fetch_varname_list}
|
||||||
metrics = eval_det_run(exe, config, eval_info_dict, "eval")
|
metrics = eval_det_run(exe, config, eval_info_dict, "eval")
|
||||||
logger.info("Eval result: {}".format(metrics))
|
logger.info("Eval result: {}".format(metrics))
|
||||||
|
elif train_alg_type == 'cls':
|
||||||
|
eval_reader = reader_main(config=config, mode="eval")
|
||||||
|
eval_info_dict = {'program': eval_program, \
|
||||||
|
'reader': eval_reader, \
|
||||||
|
'fetch_name_list': eval_fetch_name_list, \
|
||||||
|
'fetch_varname_list': eval_fetch_varname_list}
|
||||||
|
metrics = eval_cls_run(exe, eval_info_dict)
|
||||||
|
logger.info("Eval result: {}".format(metrics))
|
||||||
else:
|
else:
|
||||||
reader_type = config['Global']['reader_yml']
|
reader_type = config['Global']['reader_yml']
|
||||||
if "benchmark" not in reader_type:
|
if "benchmark" not in reader_type:
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
|
||||||
|
__all__ = ['eval_class_run']
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
|
||||||
|
logging.basicConfig(level=logging.INFO, format=FORMAT)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def eval_cls_run(exe, eval_info_dict):
|
||||||
|
"""
|
||||||
|
Run evaluation program, return program outputs.
|
||||||
|
"""
|
||||||
|
total_sample_num = 0
|
||||||
|
total_acc_num = 0
|
||||||
|
total_batch_num = 0
|
||||||
|
|
||||||
|
for data in eval_info_dict['reader']():
|
||||||
|
img_num = len(data)
|
||||||
|
img_list = []
|
||||||
|
label_list = []
|
||||||
|
for ino in range(img_num):
|
||||||
|
img_list.append(data[ino][0])
|
||||||
|
label_list.append(data[ino][1])
|
||||||
|
|
||||||
|
img_list = np.concatenate(img_list, axis=0)
|
||||||
|
outs = exe.run(eval_info_dict['program'], \
|
||||||
|
feed={'image': img_list}, \
|
||||||
|
fetch_list=eval_info_dict['fetch_varname_list'], \
|
||||||
|
return_numpy=False)
|
||||||
|
softmax_outs = np.array(outs[1])
|
||||||
|
|
||||||
|
acc, acc_num = cal_cls_acc(softmax_outs, label_list)
|
||||||
|
total_acc_num += acc_num
|
||||||
|
total_sample_num += len(label_list)
|
||||||
|
# logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
|
||||||
|
total_batch_num += 1
|
||||||
|
avg_acc = total_acc_num * 1.0 / total_sample_num
|
||||||
|
metrics = {'avg_acc': avg_acc, "total_acc_num": total_acc_num, \
|
||||||
|
"total_sample_num": total_sample_num}
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
def cal_cls_acc(preds, labels):
|
||||||
|
acc_num = 0
|
||||||
|
for pred, label in zip(preds, labels):
|
||||||
|
if pred == label:
|
||||||
|
acc_num += 1
|
||||||
|
return acc_num / len(preds), acc_num
|
|
@ -0,0 +1,143 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
sys.path.append(__dir__)
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||||
|
|
||||||
|
import tools.infer.utility as utility
|
||||||
|
from ppocr.utils.utility import initial_logger
|
||||||
|
|
||||||
|
logger = initial_logger()
|
||||||
|
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||||
|
import cv2
|
||||||
|
import copy
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class TextClassifier(object):
|
||||||
|
def __init__(self, args):
|
||||||
|
self.predictor, self.input_tensor, self.output_tensors = \
|
||||||
|
utility.create_predictor(args, mode="cls")
|
||||||
|
self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
|
||||||
|
self.cls_batch_num = args.rec_batch_num
|
||||||
|
self.label_list = args.label_list
|
||||||
|
|
||||||
|
def resize_norm_img(self, img):
|
||||||
|
imgC, imgH, imgW = self.cls_image_shape
|
||||||
|
h = img.shape[0]
|
||||||
|
w = img.shape[1]
|
||||||
|
ratio = w / float(h)
|
||||||
|
if math.ceil(imgH * ratio) > imgW:
|
||||||
|
resized_w = imgW
|
||||||
|
else:
|
||||||
|
resized_w = int(math.ceil(imgH * ratio))
|
||||||
|
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||||
|
resized_image = resized_image.astype('float32')
|
||||||
|
if self.cls_image_shape[0] == 1:
|
||||||
|
resized_image = resized_image / 255
|
||||||
|
resized_image = resized_image[np.newaxis, :]
|
||||||
|
else:
|
||||||
|
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||||
|
resized_image -= 0.5
|
||||||
|
resized_image /= 0.5
|
||||||
|
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
||||||
|
padding_im[:, :, 0:resized_w] = resized_image
|
||||||
|
return padding_im
|
||||||
|
|
||||||
|
def __call__(self, img_list):
|
||||||
|
img_list = copy.deepcopy(img_list)
|
||||||
|
img_num = len(img_list)
|
||||||
|
# Calculate the aspect ratio of all text bars
|
||||||
|
width_list = []
|
||||||
|
for img in img_list:
|
||||||
|
width_list.append(img.shape[1] / float(img.shape[0]))
|
||||||
|
# Sorting can speed up the cls process
|
||||||
|
indices = np.argsort(np.array(width_list))
|
||||||
|
|
||||||
|
cls_res = [['', 0.0]] * img_num
|
||||||
|
batch_num = self.cls_batch_num
|
||||||
|
predict_time = 0
|
||||||
|
for beg_img_no in range(0, img_num, batch_num):
|
||||||
|
end_img_no = min(img_num, beg_img_no + batch_num)
|
||||||
|
norm_img_batch = []
|
||||||
|
max_wh_ratio = 0
|
||||||
|
for ino in range(beg_img_no, end_img_no):
|
||||||
|
h, w = img_list[indices[ino]].shape[0:2]
|
||||||
|
wh_ratio = w * 1.0 / h
|
||||||
|
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||||
|
for ino in range(beg_img_no, end_img_no):
|
||||||
|
norm_img = self.resize_norm_img(img_list[indices[ino]])
|
||||||
|
norm_img = norm_img[np.newaxis, :]
|
||||||
|
norm_img_batch.append(norm_img)
|
||||||
|
norm_img_batch = np.concatenate(norm_img_batch)
|
||||||
|
norm_img_batch = norm_img_batch.copy()
|
||||||
|
starttime = time.time()
|
||||||
|
|
||||||
|
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||||
|
self.predictor.zero_copy_run()
|
||||||
|
|
||||||
|
prob_out = self.output_tensors[0].copy_to_cpu()
|
||||||
|
label_out = self.output_tensors[1].copy_to_cpu()
|
||||||
|
|
||||||
|
elapse = time.time() - starttime
|
||||||
|
predict_time += elapse
|
||||||
|
for rno in range(len(label_out)):
|
||||||
|
label_idx = label_out[rno]
|
||||||
|
score = prob_out[rno][label_idx]
|
||||||
|
label = self.label_list[label_idx]
|
||||||
|
cls_res[indices[beg_img_no + rno]] = [label, score]
|
||||||
|
if label == 180:
|
||||||
|
img_list[indices[beg_img_no + rno]] = cv2.rotate(
|
||||||
|
img_list[indices[beg_img_no + rno]], 1)
|
||||||
|
return img_list, cls_res, predict_time
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
image_file_list = get_image_file_list(args.image_dir)
|
||||||
|
text_classifier = TextClassifier(args)
|
||||||
|
valid_image_file_list = []
|
||||||
|
img_list = []
|
||||||
|
for image_file in image_file_list[:10]:
|
||||||
|
img, flag = check_and_read_gif(image_file)
|
||||||
|
if not flag:
|
||||||
|
img = cv2.imread(image_file)
|
||||||
|
if img is None:
|
||||||
|
logger.info("error in loading image:{}".format(image_file))
|
||||||
|
continue
|
||||||
|
valid_image_file_list.append(image_file)
|
||||||
|
img_list.append(img)
|
||||||
|
try:
|
||||||
|
img_list, cls_res, predict_time = text_classifier(img_list)
|
||||||
|
print(cls_res)
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
for img, angle in zip(img_list, cls_res):
|
||||||
|
plt.title(str(angle))
|
||||||
|
plt.imshow(img)
|
||||||
|
plt.show()
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
exit()
|
||||||
|
for ino in range(len(img_list)):
|
||||||
|
print("Predicts of %s:%s" % (valid_image_file_list[ino], cls_res[ino]))
|
||||||
|
print("Total predict time for %d images:%.3f" %
|
||||||
|
(len(img_list), predict_time))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main(utility.parse_args())
|
|
@ -13,16 +13,19 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||||
sys.path.append(__dir__)
|
sys.path.append(__dir__)
|
||||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||||
|
|
||||||
import tools.infer.utility as utility
|
import tools.infer.utility as utility
|
||||||
from ppocr.utils.utility import initial_logger
|
from ppocr.utils.utility import initial_logger
|
||||||
|
|
||||||
logger = initial_logger()
|
logger = initial_logger()
|
||||||
import cv2
|
import cv2
|
||||||
import tools.infer.predict_det as predict_det
|
import tools.infer.predict_det as predict_det
|
||||||
import tools.infer.predict_rec as predict_rec
|
import tools.infer.predict_rec as predict_rec
|
||||||
|
import tools.infer.predict_cls as predict_cls
|
||||||
import copy
|
import copy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
|
@ -37,6 +40,7 @@ class TextSystem(object):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
self.text_detector = predict_det.TextDetector(args)
|
self.text_detector = predict_det.TextDetector(args)
|
||||||
self.text_recognizer = predict_rec.TextRecognizer(args)
|
self.text_recognizer = predict_rec.TextRecognizer(args)
|
||||||
|
self.text_classifier = predict_cls.TextClassifier(args)
|
||||||
|
|
||||||
def get_rotate_crop_image(self, img, points):
|
def get_rotate_crop_image(self, img, points):
|
||||||
'''
|
'''
|
||||||
|
@ -91,7 +95,10 @@ class TextSystem(object):
|
||||||
tmp_box = copy.deepcopy(dt_boxes[bno])
|
tmp_box = copy.deepcopy(dt_boxes[bno])
|
||||||
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
|
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
|
||||||
img_crop_list.append(img_crop)
|
img_crop_list.append(img_crop)
|
||||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
img_rotate_list, angle_list, elapse = self.text_classifier(
|
||||||
|
img_crop_list)
|
||||||
|
print("cls num : {}, elapse : {}".format(len(img_rotate_list), elapse))
|
||||||
|
rec_res, elapse = self.text_recognizer(img_rotate_list)
|
||||||
print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse))
|
print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse))
|
||||||
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
|
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
|
||||||
return dt_boxes, rec_res
|
return dt_boxes, rec_res
|
||||||
|
@ -110,8 +117,8 @@ def sorted_boxes(dt_boxes):
|
||||||
_boxes = list(sorted_boxes)
|
_boxes = list(sorted_boxes)
|
||||||
|
|
||||||
for i in range(num_boxes - 1):
|
for i in range(num_boxes - 1):
|
||||||
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
|
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
|
||||||
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
|
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
|
||||||
tmp = _boxes[i]
|
tmp = _boxes[i]
|
||||||
_boxes[i] = _boxes[i + 1]
|
_boxes[i] = _boxes[i + 1]
|
||||||
_boxes[i + 1] = tmp
|
_boxes[i + 1] = tmp
|
||||||
|
|
|
@ -65,6 +65,13 @@ def parse_args():
|
||||||
type=str,
|
type=str,
|
||||||
default="./ppocr/utils/ppocr_keys_v1.txt")
|
default="./ppocr/utils/ppocr_keys_v1.txt")
|
||||||
parser.add_argument("--use_space_char", type=bool, default=True)
|
parser.add_argument("--use_space_char", type=bool, default=True)
|
||||||
|
|
||||||
|
# params for text classifier
|
||||||
|
parser.add_argument("--cls_model_dir", type=str)
|
||||||
|
parser.add_argument("--cls_image_shape", type=str, default="3, 32, 100")
|
||||||
|
parser.add_argument("--label_list", type=list, default=[0, 180])
|
||||||
|
parser.add_argument("--cls_batch_num", type=int, default=30)
|
||||||
|
|
||||||
parser.add_argument("--enable_mkldnn", type=bool, default=False)
|
parser.add_argument("--enable_mkldnn", type=bool, default=False)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
@ -72,6 +79,8 @@ def parse_args():
|
||||||
def create_predictor(args, mode):
|
def create_predictor(args, mode):
|
||||||
if mode == "det":
|
if mode == "det":
|
||||||
model_dir = args.det_model_dir
|
model_dir = args.det_model_dir
|
||||||
|
elif mode == 'cls':
|
||||||
|
model_dir = args.cls_model_dir
|
||||||
else:
|
else:
|
||||||
model_dir = args.rec_model_dir
|
model_dir = args.rec_model_dir
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,109 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
__dir__ = os.path.dirname(__file__)
|
||||||
|
sys.path.append(__dir__)
|
||||||
|
sys.path.append(os.path.join(__dir__, '..'))
|
||||||
|
|
||||||
|
|
||||||
|
def set_paddle_flags(**kwargs):
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if os.environ.get(key, None) is None:
|
||||||
|
os.environ[key] = str(value)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(paddle-dev): All of these flags should be
|
||||||
|
# set before `import paddle`. Otherwise, it would
|
||||||
|
# not take any effect.
|
||||||
|
set_paddle_flags(
|
||||||
|
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
|
||||||
|
)
|
||||||
|
|
||||||
|
import tools.program as program
|
||||||
|
from paddle import fluid
|
||||||
|
from ppocr.utils.utility import initial_logger
|
||||||
|
logger = initial_logger()
|
||||||
|
from ppocr.data.reader_main import reader_main
|
||||||
|
from ppocr.utils.save_load import init_model
|
||||||
|
from ppocr.utils.utility import create_module
|
||||||
|
from ppocr.utils.utility import get_image_file_list
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config = program.load_config(FLAGS.config)
|
||||||
|
program.merge_config(FLAGS.opt)
|
||||||
|
logger.info(config)
|
||||||
|
|
||||||
|
# check if set use_gpu=True in paddlepaddle cpu version
|
||||||
|
use_gpu = config['Global']['use_gpu']
|
||||||
|
# check_gpu(use_gpu)
|
||||||
|
|
||||||
|
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
|
||||||
|
exe = fluid.Executor(place)
|
||||||
|
|
||||||
|
rec_model = create_module(config['Architecture']['function'])(params=config)
|
||||||
|
startup_prog = fluid.Program()
|
||||||
|
eval_prog = fluid.Program()
|
||||||
|
with fluid.program_guard(eval_prog, startup_prog):
|
||||||
|
with fluid.unique_name.guard():
|
||||||
|
_, outputs = rec_model(mode="test")
|
||||||
|
fetch_name_list = list(outputs.keys())
|
||||||
|
fetch_varname_list = [outputs[v].name for v in fetch_name_list]
|
||||||
|
eval_prog = eval_prog.clone(for_test=True)
|
||||||
|
exe.run(startup_prog)
|
||||||
|
|
||||||
|
init_model(config, eval_prog, exe)
|
||||||
|
|
||||||
|
blobs = reader_main(config, 'test')()
|
||||||
|
infer_img = config['Global']['infer_img']
|
||||||
|
infer_list = get_image_file_list(infer_img)
|
||||||
|
max_img_num = len(infer_list)
|
||||||
|
if len(infer_list) == 0:
|
||||||
|
logger.info("Can not find img in infer_img dir.")
|
||||||
|
for i in range(max_img_num):
|
||||||
|
logger.info("infer_img:%s" % infer_list[i])
|
||||||
|
img = next(blobs)
|
||||||
|
predict = exe.run(program=eval_prog,
|
||||||
|
feed={"image": img},
|
||||||
|
fetch_list=fetch_varname_list,
|
||||||
|
return_numpy=False)
|
||||||
|
for k in predict:
|
||||||
|
k = np.array(k)
|
||||||
|
print(k)
|
||||||
|
# save for inference model
|
||||||
|
target_var = []
|
||||||
|
for key, values in outputs.items():
|
||||||
|
target_var.append(values)
|
||||||
|
|
||||||
|
fluid.io.save_inference_model(
|
||||||
|
"./output",
|
||||||
|
feeded_var_names=['image'],
|
||||||
|
target_vars=target_var,
|
||||||
|
executor=exe,
|
||||||
|
main_program=eval_prog,
|
||||||
|
model_filename="model",
|
||||||
|
params_filename="params")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = program.ArgsParser()
|
||||||
|
FLAGS = parser.parse_args()
|
||||||
|
main()
|
|
@ -30,6 +30,7 @@ import time
|
||||||
from ppocr.utils.stats import TrainingStats
|
from ppocr.utils.stats import TrainingStats
|
||||||
from eval_utils.eval_det_utils import eval_det_run
|
from eval_utils.eval_det_utils import eval_det_run
|
||||||
from eval_utils.eval_rec_utils import eval_rec_run
|
from eval_utils.eval_rec_utils import eval_rec_run
|
||||||
|
from eval_utils.eval_cls_utils import eval_cls_run
|
||||||
from ppocr.utils.save_load import save_model
|
from ppocr.utils.save_load import save_model
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps
|
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps
|
||||||
|
@ -398,6 +399,87 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
|
||||||
|
train_batch_id = 0
|
||||||
|
log_smooth_window = config['Global']['log_smooth_window']
|
||||||
|
epoch_num = config['Global']['epoch_num']
|
||||||
|
print_batch_step = config['Global']['print_batch_step']
|
||||||
|
eval_batch_step = config['Global']['eval_batch_step']
|
||||||
|
start_eval_step = 0
|
||||||
|
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
|
||||||
|
start_eval_step = eval_batch_step[0]
|
||||||
|
eval_batch_step = eval_batch_step[1]
|
||||||
|
logger.info(
|
||||||
|
"During the training process, after the {}th iteration, an evaluation is run every {} iterations".
|
||||||
|
format(start_eval_step, eval_batch_step))
|
||||||
|
save_epoch_step = config['Global']['save_epoch_step']
|
||||||
|
save_model_dir = config['Global']['save_model_dir']
|
||||||
|
if not os.path.exists(save_model_dir):
|
||||||
|
os.makedirs(save_model_dir)
|
||||||
|
train_stats = TrainingStats(log_smooth_window, ['loss', 'acc'])
|
||||||
|
best_eval_acc = -1
|
||||||
|
best_batch_id = 0
|
||||||
|
best_epoch = 0
|
||||||
|
train_loader = train_info_dict['reader']
|
||||||
|
for epoch in range(epoch_num):
|
||||||
|
train_loader.start()
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
t1 = time.time()
|
||||||
|
train_outs = exe.run(
|
||||||
|
program=train_info_dict['compile_program'],
|
||||||
|
fetch_list=train_info_dict['fetch_varname_list'],
|
||||||
|
return_numpy=False)
|
||||||
|
fetch_map = dict(
|
||||||
|
zip(train_info_dict['fetch_name_list'],
|
||||||
|
range(len(train_outs))))
|
||||||
|
|
||||||
|
loss = np.mean(np.array(train_outs[fetch_map['total_loss']]))
|
||||||
|
lr = np.mean(np.array(train_outs[fetch_map['lr']]))
|
||||||
|
acc = np.mean(np.array(train_outs[fetch_map['acc']]))
|
||||||
|
|
||||||
|
t2 = time.time()
|
||||||
|
train_batch_elapse = t2 - t1
|
||||||
|
stats = {'loss': loss, 'acc': acc}
|
||||||
|
train_stats.update(stats)
|
||||||
|
if train_batch_id > start_eval_step and (train_batch_id - start_eval_step) \
|
||||||
|
% print_batch_step == 0:
|
||||||
|
logs = train_stats.log()
|
||||||
|
strs = 'epoch: {}, iter: {}, lr: {:.6f}, {}, time: {:.3f}'.format(
|
||||||
|
epoch, train_batch_id, lr, logs, train_batch_elapse)
|
||||||
|
logger.info(strs)
|
||||||
|
|
||||||
|
if train_batch_id > 0 and\
|
||||||
|
train_batch_id % eval_batch_step == 0:
|
||||||
|
model_average = train_info_dict['model_average']
|
||||||
|
if model_average != None:
|
||||||
|
model_average.apply(exe)
|
||||||
|
metrics = eval_cls_run(exe, eval_info_dict)
|
||||||
|
eval_acc = metrics['avg_acc']
|
||||||
|
eval_sample_num = metrics['total_sample_num']
|
||||||
|
if eval_acc > best_eval_acc:
|
||||||
|
best_eval_acc = eval_acc
|
||||||
|
best_batch_id = train_batch_id
|
||||||
|
best_epoch = epoch
|
||||||
|
save_path = save_model_dir + "/best_accuracy"
|
||||||
|
save_model(train_info_dict['train_program'], save_path)
|
||||||
|
strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format(
|
||||||
|
train_batch_id, eval_acc, best_eval_acc, best_epoch,
|
||||||
|
best_batch_id, eval_sample_num)
|
||||||
|
logger.info(strs)
|
||||||
|
train_batch_id += 1
|
||||||
|
|
||||||
|
except fluid.core.EOFException:
|
||||||
|
train_loader.reset()
|
||||||
|
if epoch == 0 and save_epoch_step == 1:
|
||||||
|
save_path = save_model_dir + "/iter_epoch_0"
|
||||||
|
save_model(train_info_dict['train_program'], save_path)
|
||||||
|
if epoch > 0 and epoch % save_epoch_step == 0:
|
||||||
|
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
||||||
|
save_model(train_info_dict['train_program'], save_path)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
def preprocess():
|
def preprocess():
|
||||||
FLAGS = ArgsParser().parse_args()
|
FLAGS = ArgsParser().parse_args()
|
||||||
config = load_config(FLAGS.config)
|
config = load_config(FLAGS.config)
|
||||||
|
@ -409,7 +491,9 @@ def preprocess():
|
||||||
check_gpu(use_gpu)
|
check_gpu(use_gpu)
|
||||||
|
|
||||||
alg = config['Global']['algorithm']
|
alg = config['Global']['algorithm']
|
||||||
assert alg in ['EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']
|
assert alg in [
|
||||||
|
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
|
||||||
|
]
|
||||||
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
|
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
|
||||||
config['Global']['char_ops'] = CharacterOps(config['Global'])
|
config['Global']['char_ops'] = CharacterOps(config['Global'])
|
||||||
|
|
||||||
|
@ -419,7 +503,9 @@ def preprocess():
|
||||||
|
|
||||||
if alg in ['EAST', 'DB', 'SAST']:
|
if alg in ['EAST', 'DB', 'SAST']:
|
||||||
train_alg_type = 'det'
|
train_alg_type = 'det'
|
||||||
else:
|
elif alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
|
||||||
train_alg_type = 'rec'
|
train_alg_type = 'rec'
|
||||||
|
else:
|
||||||
|
train_alg_type = 'cls'
|
||||||
|
|
||||||
return startup_program, train_program, place, config, train_alg_type
|
return startup_program, train_program, place, config, train_alg_type
|
||||||
|
|
|
@ -75,7 +75,8 @@ def main():
|
||||||
|
|
||||||
# dump mode structure
|
# dump mode structure
|
||||||
if config['Global']['debug']:
|
if config['Global']['debug']:
|
||||||
if train_alg_type == 'rec' and 'attention' in config['Global']['loss_type']:
|
if train_alg_type == 'rec' and 'attention' in config['Global'][
|
||||||
|
'loss_type']:
|
||||||
logger.warning('Does not suport dump attention...')
|
logger.warning('Does not suport dump attention...')
|
||||||
else:
|
else:
|
||||||
summary(train_program)
|
summary(train_program)
|
||||||
|
@ -96,8 +97,10 @@ def main():
|
||||||
|
|
||||||
if train_alg_type == 'det':
|
if train_alg_type == 'det':
|
||||||
program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict)
|
program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict)
|
||||||
else:
|
elif train_alg_type == 'rec':
|
||||||
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
|
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
|
||||||
|
else:
|
||||||
|
program.train_eval_cls_run(config, exe, train_info_dict, eval_info_dict)
|
||||||
|
|
||||||
|
|
||||||
def test_reader():
|
def test_reader():
|
||||||
|
@ -119,6 +122,7 @@ def test_reader():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
startup_program, train_program, place, config, train_alg_type = program.preprocess()
|
startup_program, train_program, place, config, train_alg_type = program.preprocess(
|
||||||
|
)
|
||||||
main()
|
main()
|
||||||
# test_reader()
|
# test_reader()
|
||||||
|
|
Loading…
Reference in New Issue