Merge branch 'develop' of https://github.com/PaddlePaddle/PaddleOCR into whl
This commit is contained in:
commit
ecba3f85d6
|
@ -0,0 +1,44 @@
|
|||
Global:
|
||||
algorithm: CLS
|
||||
use_gpu: False
|
||||
epoch_num: 100
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 100
|
||||
save_model_dir: output/cls_mv3
|
||||
save_epoch_step: 3
|
||||
eval_batch_step: 500
|
||||
train_batch_size_per_card: 512
|
||||
test_batch_size_per_card: 512
|
||||
image_shape: [3, 48, 192]
|
||||
label_list: ['0','180']
|
||||
distort: True
|
||||
reader_yml: ./configs/cls/cls_reader.yml
|
||||
pretrain_weights:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
infer_img:
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.cls_model,ClsModel
|
||||
|
||||
Backbone:
|
||||
function: ppocr.modeling.backbones.rec_mobilenet_v3,MobileNetV3
|
||||
scale: 0.35
|
||||
model_name: small
|
||||
|
||||
Head:
|
||||
function: ppocr.modeling.heads.cls_head,ClsHead
|
||||
class_dim: 2
|
||||
|
||||
Loss:
|
||||
function: ppocr.modeling.losses.cls_loss,ClsLoss
|
||||
|
||||
Optimizer:
|
||||
function: ppocr.optimizer,AdamDecay
|
||||
base_lr: 0.001
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
decay:
|
||||
function: cosine_decay
|
||||
step_each_epoch: 1169
|
||||
total_epoch: 100
|
|
@ -0,0 +1,13 @@
|
|||
TrainReader:
|
||||
reader_function: ppocr.data.cls.dataset_traversal,SimpleReader
|
||||
num_workers: 8
|
||||
img_set_dir: ./train_data/cls
|
||||
label_file_path: ./train_data/cls/train.txt
|
||||
|
||||
EvalReader:
|
||||
reader_function: ppocr.data.cls.dataset_traversal,SimpleReader
|
||||
img_set_dir: ./train_data/cls
|
||||
label_file_path: ./train_data/cls/test.txt
|
||||
|
||||
TestReader:
|
||||
reader_function: ppocr.data.cls.dataset_traversal,SimpleReader
|
Binary file not shown.
After Width: | Height: | Size: 198 KiB |
Binary file not shown.
After Width: | Height: | Size: 171 KiB |
Binary file not shown.
After Width: | Height: | Size: 61 KiB |
|
@ -4,29 +4,29 @@
|
|||
|
||||
#include "native.h"
|
||||
#include "ocr_ppredictor.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <paddle_api.h>
|
||||
#include <string>
|
||||
|
||||
static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode);
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jlong JNICALL
|
||||
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(JNIEnv *env, jobject thiz,
|
||||
jstring j_det_model_path,
|
||||
jstring j_rec_model_path,
|
||||
jint j_thread_num,
|
||||
jstring j_cpu_mode) {
|
||||
std::string det_model_path = jstring_to_cpp_string(env, j_det_model_path);
|
||||
std::string rec_model_path = jstring_to_cpp_string(env, j_rec_model_path);
|
||||
int thread_num = j_thread_num;
|
||||
std::string cpu_mode = jstring_to_cpp_string(env, j_cpu_mode);
|
||||
ppredictor::OCR_Config conf;
|
||||
conf.thread_num = thread_num;
|
||||
conf.mode = str_to_cpu_mode(cpu_mode);
|
||||
ppredictor::OCR_PPredictor *orc_predictor = new ppredictor::OCR_PPredictor{conf};
|
||||
orc_predictor->init_from_file(det_model_path, rec_model_path);
|
||||
return reinterpret_cast<jlong>(orc_predictor);
|
||||
extern "C" JNIEXPORT jlong JNICALL
|
||||
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(
|
||||
JNIEnv *env, jobject thiz, jstring j_det_model_path,
|
||||
jstring j_rec_model_path, jstring j_cls_model_path, jint j_thread_num,
|
||||
jstring j_cpu_mode) {
|
||||
std::string det_model_path = jstring_to_cpp_string(env, j_det_model_path);
|
||||
std::string rec_model_path = jstring_to_cpp_string(env, j_rec_model_path);
|
||||
std::string cls_model_path = jstring_to_cpp_string(env, j_cls_model_path);
|
||||
int thread_num = j_thread_num;
|
||||
std::string cpu_mode = jstring_to_cpp_string(env, j_cpu_mode);
|
||||
ppredictor::OCR_Config conf;
|
||||
conf.thread_num = thread_num;
|
||||
conf.mode = str_to_cpu_mode(cpu_mode);
|
||||
ppredictor::OCR_PPredictor *orc_predictor =
|
||||
new ppredictor::OCR_PPredictor{conf};
|
||||
orc_predictor->init_from_file(det_model_path, rec_model_path, cls_model_path);
|
||||
return reinterpret_cast<jlong>(orc_predictor);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -34,82 +34,81 @@ Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_init(JNIEnv *env, jobject
|
|||
* @param cpu_mode
|
||||
* @return
|
||||
*/
|
||||
static paddle::lite_api::PowerMode str_to_cpu_mode(const std::string &cpu_mode) {
|
||||
static std::map<std::string, paddle::lite_api::PowerMode> cpu_mode_map{
|
||||
{"LITE_POWER_HIGH", paddle::lite_api::LITE_POWER_HIGH},
|
||||
{"LITE_POWER_LOW", paddle::lite_api::LITE_POWER_HIGH},
|
||||
{"LITE_POWER_FULL", paddle::lite_api::LITE_POWER_FULL},
|
||||
{"LITE_POWER_NO_BIND", paddle::lite_api::LITE_POWER_NO_BIND},
|
||||
{"LITE_POWER_RAND_HIGH", paddle::lite_api::LITE_POWER_RAND_HIGH},
|
||||
{"LITE_POWER_RAND_LOW", paddle::lite_api::LITE_POWER_RAND_LOW}
|
||||
};
|
||||
std::string upper_key;
|
||||
std::transform(cpu_mode.cbegin(), cpu_mode.cend(), upper_key.begin(), ::toupper);
|
||||
auto index = cpu_mode_map.find(upper_key);
|
||||
if (index == cpu_mode_map.end()) {
|
||||
LOGE("cpu_mode not found %s", upper_key.c_str());
|
||||
return paddle::lite_api::LITE_POWER_HIGH;
|
||||
} else {
|
||||
return index->second;
|
||||
}
|
||||
|
||||
static paddle::lite_api::PowerMode
|
||||
str_to_cpu_mode(const std::string &cpu_mode) {
|
||||
static std::map<std::string, paddle::lite_api::PowerMode> cpu_mode_map{
|
||||
{"LITE_POWER_HIGH", paddle::lite_api::LITE_POWER_HIGH},
|
||||
{"LITE_POWER_LOW", paddle::lite_api::LITE_POWER_HIGH},
|
||||
{"LITE_POWER_FULL", paddle::lite_api::LITE_POWER_FULL},
|
||||
{"LITE_POWER_NO_BIND", paddle::lite_api::LITE_POWER_NO_BIND},
|
||||
{"LITE_POWER_RAND_HIGH", paddle::lite_api::LITE_POWER_RAND_HIGH},
|
||||
{"LITE_POWER_RAND_LOW", paddle::lite_api::LITE_POWER_RAND_LOW}};
|
||||
std::string upper_key;
|
||||
std::transform(cpu_mode.cbegin(), cpu_mode.cend(), upper_key.begin(),
|
||||
::toupper);
|
||||
auto index = cpu_mode_map.find(upper_key);
|
||||
if (index == cpu_mode_map.end()) {
|
||||
LOGE("cpu_mode not found %s", upper_key.c_str());
|
||||
return paddle::lite_api::LITE_POWER_HIGH;
|
||||
} else {
|
||||
return index->second;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jfloatArray JNICALL
|
||||
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward(JNIEnv *env, jobject thiz,
|
||||
jlong java_pointer, jfloatArray buf,
|
||||
jfloatArray ddims,
|
||||
jobject original_image) {
|
||||
LOGI("begin to run native forward");
|
||||
if (java_pointer == 0) {
|
||||
LOGE("JAVA pointer is NULL");
|
||||
return cpp_array_to_jfloatarray(env, nullptr, 0);
|
||||
}
|
||||
cv::Mat origin = bitmap_to_cv_mat(env, original_image);
|
||||
if (origin.size == 0) {
|
||||
LOGE("origin bitmap cannot convert to CV Mat");
|
||||
return cpp_array_to_jfloatarray(env, nullptr, 0);
|
||||
}
|
||||
ppredictor::OCR_PPredictor *ppredictor = (ppredictor::OCR_PPredictor *) java_pointer;
|
||||
std::vector<float> dims_float_arr = jfloatarray_to_float_vector(env, ddims);
|
||||
std::vector<int64_t> dims_arr;
|
||||
dims_arr.resize(dims_float_arr.size());
|
||||
std::copy(dims_float_arr.cbegin(), dims_float_arr.cend(), dims_arr.begin());
|
||||
extern "C" JNIEXPORT jfloatArray JNICALL
|
||||
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_forward(
|
||||
JNIEnv *env, jobject thiz, jlong java_pointer, jfloatArray buf,
|
||||
jfloatArray ddims, jobject original_image) {
|
||||
LOGI("begin to run native forward");
|
||||
if (java_pointer == 0) {
|
||||
LOGE("JAVA pointer is NULL");
|
||||
return cpp_array_to_jfloatarray(env, nullptr, 0);
|
||||
}
|
||||
cv::Mat origin = bitmap_to_cv_mat(env, original_image);
|
||||
if (origin.size == 0) {
|
||||
LOGE("origin bitmap cannot convert to CV Mat");
|
||||
return cpp_array_to_jfloatarray(env, nullptr, 0);
|
||||
}
|
||||
ppredictor::OCR_PPredictor *ppredictor =
|
||||
(ppredictor::OCR_PPredictor *)java_pointer;
|
||||
std::vector<float> dims_float_arr = jfloatarray_to_float_vector(env, ddims);
|
||||
std::vector<int64_t> dims_arr;
|
||||
dims_arr.resize(dims_float_arr.size());
|
||||
std::copy(dims_float_arr.cbegin(), dims_float_arr.cend(), dims_arr.begin());
|
||||
|
||||
// 这里值有点大,就不调用jfloatarray_to_float_vector了
|
||||
int64_t buf_len = (int64_t) env->GetArrayLength(buf);
|
||||
jfloat *buf_data = env->GetFloatArrayElements(buf, JNI_FALSE);
|
||||
float *data = (jfloat *) buf_data;
|
||||
std::vector<ppredictor::OCRPredictResult> results = ppredictor->infer_ocr(dims_arr, data,
|
||||
buf_len,
|
||||
NET_OCR, origin);
|
||||
LOGI("infer_ocr finished with boxes %ld", results.size());
|
||||
// 这里将std::vector<ppredictor::OCRPredictResult> 序列化成 float数组,传输到java层再反序列化
|
||||
std::vector<float> float_arr;
|
||||
for (const ppredictor::OCRPredictResult &r :results) {
|
||||
float_arr.push_back(r.points.size());
|
||||
float_arr.push_back(r.word_index.size());
|
||||
float_arr.push_back(r.score);
|
||||
for (const std::vector<int> &point : r.points) {
|
||||
float_arr.push_back(point.at(0));
|
||||
float_arr.push_back(point.at(1));
|
||||
}
|
||||
for (int index: r.word_index) {
|
||||
float_arr.push_back(index);
|
||||
}
|
||||
// 这里值有点大,就不调用jfloatarray_to_float_vector了
|
||||
int64_t buf_len = (int64_t)env->GetArrayLength(buf);
|
||||
jfloat *buf_data = env->GetFloatArrayElements(buf, JNI_FALSE);
|
||||
float *data = (jfloat *)buf_data;
|
||||
std::vector<ppredictor::OCRPredictResult> results =
|
||||
ppredictor->infer_ocr(dims_arr, data, buf_len, NET_OCR, origin);
|
||||
LOGI("infer_ocr finished with boxes %ld", results.size());
|
||||
// 这里将std::vector<ppredictor::OCRPredictResult> 序列化成
|
||||
// float数组,传输到java层再反序列化
|
||||
std::vector<float> float_arr;
|
||||
for (const ppredictor::OCRPredictResult &r : results) {
|
||||
float_arr.push_back(r.points.size());
|
||||
float_arr.push_back(r.word_index.size());
|
||||
float_arr.push_back(r.score);
|
||||
for (const std::vector<int> &point : r.points) {
|
||||
float_arr.push_back(point.at(0));
|
||||
float_arr.push_back(point.at(1));
|
||||
}
|
||||
return cpp_array_to_jfloatarray(env, float_arr.data(), float_arr.size());
|
||||
for (int index : r.word_index) {
|
||||
float_arr.push_back(index);
|
||||
}
|
||||
}
|
||||
return cpp_array_to_jfloatarray(env, float_arr.data(), float_arr.size());
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT void JNICALL
|
||||
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_release(JNIEnv *env, jobject thiz,
|
||||
jlong java_pointer){
|
||||
if (java_pointer == 0) {
|
||||
LOGE("JAVA pointer is NULL");
|
||||
return;
|
||||
}
|
||||
ppredictor::OCR_PPredictor *ppredictor = (ppredictor::OCR_PPredictor *) java_pointer;
|
||||
delete ppredictor;
|
||||
extern "C" JNIEXPORT void JNICALL
|
||||
Java_com_baidu_paddle_lite_demo_ocr_OCRPredictorNative_release(
|
||||
JNIEnv *env, jobject thiz, jlong java_pointer) {
|
||||
if (java_pointer == 0) {
|
||||
LOGE("JAVA pointer is NULL");
|
||||
return;
|
||||
}
|
||||
ppredictor::OCR_PPredictor *ppredictor =
|
||||
(ppredictor::OCR_PPredictor *)java_pointer;
|
||||
delete ppredictor;
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "ocr_cls_process.h"
|
||||
#include <cmath>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
const std::vector<int> CLS_IMAGE_SHAPE = {3, 32, 100};
|
||||
|
||||
cv::Mat cls_resize_img(const cv::Mat &img) {
|
||||
int imgC = CLS_IMAGE_SHAPE[0];
|
||||
int imgW = CLS_IMAGE_SHAPE[2];
|
||||
int imgH = CLS_IMAGE_SHAPE[1];
|
||||
|
||||
float ratio = float(img.cols) / float(img.rows);
|
||||
int resize_w = 0;
|
||||
if (ceilf(imgH * ratio) > imgW)
|
||||
resize_w = imgW;
|
||||
else
|
||||
resize_w = int(ceilf(imgH * ratio));
|
||||
|
||||
cv::Mat resize_img;
|
||||
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
|
||||
cv::INTER_CUBIC);
|
||||
|
||||
if (resize_w < imgW) {
|
||||
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, int(imgW - resize_w),
|
||||
cv::BORDER_CONSTANT, {0, 0, 0});
|
||||
}
|
||||
return resize_img;
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include <vector>
|
||||
|
||||
extern const std::vector<int> CLS_IMAGE_SHAPE;
|
||||
|
||||
cv::Mat cls_resize_img(const cv::Mat &img);
|
|
@ -3,38 +3,48 @@
|
|||
//
|
||||
|
||||
#include "ocr_ppredictor.h"
|
||||
#include "preprocess.h"
|
||||
#include "common.h"
|
||||
#include "ocr_db_post_process.h"
|
||||
#include "ocr_cls_process.h"
|
||||
#include "ocr_crnn_process.h"
|
||||
#include "ocr_db_post_process.h"
|
||||
#include "preprocess.h"
|
||||
|
||||
namespace ppredictor {
|
||||
|
||||
OCR_PPredictor::OCR_PPredictor(const OCR_Config &config) : _config(config) {
|
||||
OCR_PPredictor::OCR_PPredictor(const OCR_Config &config) : _config(config) {}
|
||||
|
||||
int OCR_PPredictor::init(const std::string &det_model_content,
|
||||
const std::string &rec_model_content,
|
||||
const std::string &cls_model_content) {
|
||||
_det_predictor = std::unique_ptr<PPredictor>(
|
||||
new PPredictor{_config.thread_num, NET_OCR, _config.mode});
|
||||
_det_predictor->init_nb(det_model_content);
|
||||
|
||||
_rec_predictor = std::unique_ptr<PPredictor>(
|
||||
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
|
||||
_rec_predictor->init_nb(rec_model_content);
|
||||
|
||||
_cls_predictor = std::unique_ptr<PPredictor>(
|
||||
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
|
||||
_cls_predictor->init_nb(cls_model_content);
|
||||
return RETURN_OK;
|
||||
}
|
||||
|
||||
int
|
||||
OCR_PPredictor::init(const std::string &det_model_content, const std::string &rec_model_content) {
|
||||
_det_predictor = std::unique_ptr<PPredictor>(
|
||||
new PPredictor{_config.thread_num, NET_OCR, _config.mode});
|
||||
_det_predictor->init_nb(det_model_content);
|
||||
int OCR_PPredictor::init_from_file(const std::string &det_model_path,
|
||||
const std::string &rec_model_path,
|
||||
const std::string &cls_model_path) {
|
||||
_det_predictor = std::unique_ptr<PPredictor>(
|
||||
new PPredictor{_config.thread_num, NET_OCR, _config.mode});
|
||||
_det_predictor->init_from_file(det_model_path);
|
||||
|
||||
_rec_predictor = std::unique_ptr<PPredictor>(
|
||||
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
|
||||
_rec_predictor->init_nb(rec_model_content);
|
||||
return RETURN_OK;
|
||||
}
|
||||
_rec_predictor = std::unique_ptr<PPredictor>(
|
||||
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
|
||||
_rec_predictor->init_from_file(rec_model_path);
|
||||
|
||||
int OCR_PPredictor::init_from_file(const std::string &det_model_path, const std::string &rec_model_path){
|
||||
_det_predictor = std::unique_ptr<PPredictor>(
|
||||
new PPredictor{_config.thread_num, NET_OCR, _config.mode});
|
||||
_det_predictor->init_from_file(det_model_path);
|
||||
|
||||
_rec_predictor = std::unique_ptr<PPredictor>(
|
||||
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
|
||||
_rec_predictor->init_from_file(rec_model_path);
|
||||
return RETURN_OK;
|
||||
_cls_predictor = std::unique_ptr<PPredictor>(
|
||||
new PPredictor{_config.thread_num, NET_OCR_INTERNAL, _config.mode});
|
||||
_cls_predictor->init_from_file(cls_model_path);
|
||||
return RETURN_OK;
|
||||
}
|
||||
/**
|
||||
* for debug use, show result of First Step
|
||||
|
@ -42,145 +52,188 @@ int OCR_PPredictor::init_from_file(const std::string &det_model_path, const std:
|
|||
* @param boxes
|
||||
* @param srcimg
|
||||
*/
|
||||
static void visual_img(const std::vector<std::vector<std::vector<int>>> &filter_boxes,
|
||||
const std::vector<std::vector<std::vector<int>>> &boxes,
|
||||
const cv::Mat &srcimg) {
|
||||
// visualization
|
||||
cv::Point rook_points[filter_boxes.size()][4];
|
||||
for (int n = 0; n < filter_boxes.size(); n++) {
|
||||
for (int m = 0; m < filter_boxes[0].size(); m++) {
|
||||
rook_points[n][m] = cv::Point(int(filter_boxes[n][m][0]), int(filter_boxes[n][m][1]));
|
||||
}
|
||||
static void
|
||||
visual_img(const std::vector<std::vector<std::vector<int>>> &filter_boxes,
|
||||
const std::vector<std::vector<std::vector<int>>> &boxes,
|
||||
const cv::Mat &srcimg) {
|
||||
// visualization
|
||||
cv::Point rook_points[filter_boxes.size()][4];
|
||||
for (int n = 0; n < filter_boxes.size(); n++) {
|
||||
for (int m = 0; m < filter_boxes[0].size(); m++) {
|
||||
rook_points[n][m] =
|
||||
cv::Point(int(filter_boxes[n][m][0]), int(filter_boxes[n][m][1]));
|
||||
}
|
||||
}
|
||||
|
||||
cv::Mat img_vis;
|
||||
srcimg.copyTo(img_vis);
|
||||
for (int n = 0; n < boxes.size(); n++) {
|
||||
const cv::Point *ppt[1] = {rook_points[n]};
|
||||
int npt[] = {4};
|
||||
cv::polylines(img_vis, ppt, npt, 1, 1, CV_RGB(0, 255, 0), 2, 8, 0);
|
||||
}
|
||||
// 调试用,自行替换需要修改的路径
|
||||
cv::imwrite("/sdcard/1/vis.png", img_vis);
|
||||
cv::Mat img_vis;
|
||||
srcimg.copyTo(img_vis);
|
||||
for (int n = 0; n < boxes.size(); n++) {
|
||||
const cv::Point *ppt[1] = {rook_points[n]};
|
||||
int npt[] = {4};
|
||||
cv::polylines(img_vis, ppt, npt, 1, 1, CV_RGB(0, 255, 0), 2, 8, 0);
|
||||
}
|
||||
// 调试用,自行替换需要修改的路径
|
||||
cv::imwrite("/sdcard/1/vis.png", img_vis);
|
||||
}
|
||||
|
||||
std::vector<OCRPredictResult>
|
||||
OCR_PPredictor::infer_ocr(const std::vector<int64_t> &dims, const float *input_data, int input_len,
|
||||
int net_flag, cv::Mat &origin) {
|
||||
OCR_PPredictor::infer_ocr(const std::vector<int64_t> &dims,
|
||||
const float *input_data, int input_len, int net_flag,
|
||||
cv::Mat &origin) {
|
||||
PredictorInput input = _det_predictor->get_first_input();
|
||||
input.set_dims(dims);
|
||||
input.set_data(input_data, input_len);
|
||||
std::vector<PredictorOutput> results = _det_predictor->infer();
|
||||
PredictorOutput &res = results.at(0);
|
||||
std::vector<std::vector<std::vector<int>>> filtered_box = calc_filtered_boxes(
|
||||
res.get_float_data(), res.get_size(), (int)dims[2], (int)dims[3], origin);
|
||||
LOGI("Filter_box size %ld", filtered_box.size());
|
||||
return infer_rec(filtered_box, origin);
|
||||
}
|
||||
|
||||
PredictorInput input = _det_predictor->get_first_input();
|
||||
std::vector<OCRPredictResult> OCR_PPredictor::infer_rec(
|
||||
const std::vector<std::vector<std::vector<int>>> &boxes,
|
||||
const cv::Mat &origin_img) {
|
||||
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
|
||||
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
|
||||
std::vector<int64_t> dims = {1, 3, 0, 0};
|
||||
std::vector<OCRPredictResult> ocr_results;
|
||||
|
||||
PredictorInput input = _rec_predictor->get_first_input();
|
||||
for (auto bp = boxes.crbegin(); bp != boxes.crend(); ++bp) {
|
||||
const std::vector<std::vector<int>> &box = *bp;
|
||||
cv::Mat crop_img = get_rotate_crop_image(origin_img, box);
|
||||
crop_img = infer_cls(crop_img);
|
||||
|
||||
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
|
||||
cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio);
|
||||
input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f);
|
||||
const float *dimg = reinterpret_cast<const float *>(input_image.data);
|
||||
int input_size = input_image.rows * input_image.cols;
|
||||
|
||||
dims[2] = input_image.rows;
|
||||
dims[3] = input_image.cols;
|
||||
input.set_dims(dims);
|
||||
input.set_data(input_data, input_len);
|
||||
std::vector<PredictorOutput> results = _det_predictor->infer();
|
||||
PredictorOutput &res = results.at(0);
|
||||
std::vector<std::vector<std::vector<int>>> filtered_box
|
||||
= calc_filtered_boxes(res.get_float_data(), res.get_size(), (int) dims[2], (int) dims[3],
|
||||
origin);
|
||||
LOGI("Filter_box size %ld", filtered_box.size());
|
||||
return infer_rec(filtered_box, origin);
|
||||
|
||||
neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean,
|
||||
scale);
|
||||
|
||||
std::vector<PredictorOutput> results = _rec_predictor->infer();
|
||||
|
||||
OCRPredictResult res;
|
||||
res.word_index = postprocess_rec_word_index(results.at(0));
|
||||
if (res.word_index.empty()) {
|
||||
continue;
|
||||
}
|
||||
res.score = postprocess_rec_score(results.at(1));
|
||||
res.points = box;
|
||||
ocr_results.emplace_back(std::move(res));
|
||||
}
|
||||
LOGI("ocr_results finished %lu", ocr_results.size());
|
||||
return ocr_results;
|
||||
}
|
||||
|
||||
std::vector<OCRPredictResult>
|
||||
OCR_PPredictor::infer_rec(const std::vector<std::vector<std::vector<int>>> &boxes,
|
||||
const cv::Mat &origin_img) {
|
||||
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
|
||||
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
|
||||
std::vector<int64_t> dims = {1, 3, 0, 0};
|
||||
std::vector<OCRPredictResult> ocr_results;
|
||||
cv::Mat OCR_PPredictor::infer_cls(const cv::Mat &img, float thresh) {
|
||||
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
|
||||
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
|
||||
std::vector<int64_t> dims = {1, 3, 0, 0};
|
||||
std::vector<OCRPredictResult> ocr_results;
|
||||
|
||||
PredictorInput input = _rec_predictor->get_first_input();
|
||||
for (auto bp = boxes.crbegin(); bp != boxes.crend(); ++bp) {
|
||||
const std::vector<std::vector<int>> &box = *bp;
|
||||
cv::Mat crop_img = get_rotate_crop_image(origin_img, box);
|
||||
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
|
||||
cv::Mat input_image = crnn_resize_img(crop_img, wh_ratio);
|
||||
input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f);
|
||||
const float *dimg = reinterpret_cast<const float *>(input_image.data);
|
||||
int input_size = input_image.rows * input_image.cols;
|
||||
PredictorInput input = _cls_predictor->get_first_input();
|
||||
|
||||
dims[2] = input_image.rows;
|
||||
dims[3] = input_image.cols;
|
||||
input.set_dims(dims);
|
||||
cv::Mat input_image = cls_resize_img(img);
|
||||
input_image.convertTo(input_image, CV_32FC3, 1 / 255.0f);
|
||||
const float *dimg = reinterpret_cast<const float *>(input_image.data);
|
||||
int input_size = input_image.rows * input_image.cols;
|
||||
|
||||
neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean, scale);
|
||||
dims[2] = input_image.rows;
|
||||
dims[3] = input_image.cols;
|
||||
input.set_dims(dims);
|
||||
|
||||
std::vector<PredictorOutput> results = _rec_predictor->infer();
|
||||
neon_mean_scale(dimg, input.get_mutable_float_data(), input_size, mean,
|
||||
scale);
|
||||
|
||||
OCRPredictResult res;
|
||||
res.word_index = postprocess_rec_word_index(results.at(0));
|
||||
if (res.word_index.empty()) {
|
||||
continue;
|
||||
}
|
||||
res.score = postprocess_rec_score(results.at(1));
|
||||
res.points = box;
|
||||
ocr_results.emplace_back(std::move(res));
|
||||
}
|
||||
LOGI("ocr_results finished %lu", ocr_results.size());
|
||||
return ocr_results;
|
||||
std::vector<PredictorOutput> results = _cls_predictor->infer();
|
||||
|
||||
const float *scores = results.at(0).get_float_data();
|
||||
const int *labels = results.at(1).get_int_data();
|
||||
for (int64_t i = 0; i < results.at(0).get_size(); i++) {
|
||||
LOGI("output scores [%f]", scores[i]);
|
||||
}
|
||||
for (int64_t i = 0; i < results.at(1).get_size(); i++) {
|
||||
LOGI("output label [%d]", labels[i]);
|
||||
}
|
||||
int label_idx = labels[0];
|
||||
float score = scores[label_idx];
|
||||
|
||||
cv::Mat srcimg;
|
||||
img.copyTo(srcimg);
|
||||
if (label_idx % 2 == 1 && score > thresh) {
|
||||
cv::rotate(srcimg, srcimg, 1);
|
||||
}
|
||||
return srcimg;
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::vector<int>>>
|
||||
OCR_PPredictor::calc_filtered_boxes(const float *pred, int pred_size, int output_height,
|
||||
int output_width, const cv::Mat &origin) {
|
||||
const double threshold = 0.3;
|
||||
const double maxvalue = 1;
|
||||
OCR_PPredictor::calc_filtered_boxes(const float *pred, int pred_size,
|
||||
int output_height, int output_width,
|
||||
const cv::Mat &origin) {
|
||||
const double threshold = 0.3;
|
||||
const double maxvalue = 1;
|
||||
|
||||
cv::Mat pred_map = cv::Mat::zeros(output_height, output_width, CV_32F);
|
||||
memcpy(pred_map.data, pred, pred_size * sizeof(float));
|
||||
cv::Mat cbuf_map;
|
||||
pred_map.convertTo(cbuf_map, CV_8UC1);
|
||||
cv::Mat pred_map = cv::Mat::zeros(output_height, output_width, CV_32F);
|
||||
memcpy(pred_map.data, pred, pred_size * sizeof(float));
|
||||
cv::Mat cbuf_map;
|
||||
pred_map.convertTo(cbuf_map, CV_8UC1);
|
||||
|
||||
cv::Mat bit_map;
|
||||
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
|
||||
cv::Mat bit_map;
|
||||
cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);
|
||||
|
||||
std::vector<std::vector<std::vector<int>>> boxes = boxes_from_bitmap(pred_map, bit_map);
|
||||
float ratio_h = output_height * 1.0f / origin.rows;
|
||||
float ratio_w = output_width * 1.0f / origin.cols;
|
||||
std::vector<std::vector<std::vector<int>>> filter_boxes = filter_tag_det_res(boxes, ratio_h,
|
||||
ratio_w, origin);
|
||||
return filter_boxes;
|
||||
std::vector<std::vector<std::vector<int>>> boxes =
|
||||
boxes_from_bitmap(pred_map, bit_map);
|
||||
float ratio_h = output_height * 1.0f / origin.rows;
|
||||
float ratio_w = output_width * 1.0f / origin.cols;
|
||||
std::vector<std::vector<std::vector<int>>> filter_boxes =
|
||||
filter_tag_det_res(boxes, ratio_h, ratio_w, origin);
|
||||
return filter_boxes;
|
||||
}
|
||||
|
||||
std::vector<int> OCR_PPredictor::postprocess_rec_word_index(const PredictorOutput &res) {
|
||||
const int *rec_idx = res.get_int_data();
|
||||
const std::vector<std::vector<uint64_t>> rec_idx_lod = res.get_lod();
|
||||
std::vector<int>
|
||||
OCR_PPredictor::postprocess_rec_word_index(const PredictorOutput &res) {
|
||||
const int *rec_idx = res.get_int_data();
|
||||
const std::vector<std::vector<uint64_t>> rec_idx_lod = res.get_lod();
|
||||
|
||||
std::vector<int> pred_idx;
|
||||
for (int n = int(rec_idx_lod[0][0]); n < int(rec_idx_lod[0][1] * 2); n += 2) {
|
||||
pred_idx.emplace_back(rec_idx[n]);
|
||||
}
|
||||
return pred_idx;
|
||||
std::vector<int> pred_idx;
|
||||
for (int n = int(rec_idx_lod[0][0]); n < int(rec_idx_lod[0][1] * 2); n += 2) {
|
||||
pred_idx.emplace_back(rec_idx[n]);
|
||||
}
|
||||
return pred_idx;
|
||||
}
|
||||
|
||||
float OCR_PPredictor::postprocess_rec_score(const PredictorOutput &res) {
|
||||
const float *predict_batch = res.get_float_data();
|
||||
const std::vector<int64_t> predict_shape = res.get_shape();
|
||||
const std::vector<std::vector<uint64_t>> predict_lod = res.get_lod();
|
||||
int blank = predict_shape[1];
|
||||
float score = 0.f;
|
||||
int count = 0;
|
||||
for (int n = predict_lod[0][0]; n < predict_lod[0][1] - 1; n++) {
|
||||
int argmax_idx = argmax(predict_batch + n * predict_shape[1],
|
||||
predict_batch + (n + 1) * predict_shape[1]);
|
||||
float max_value = predict_batch[n * predict_shape[1] + argmax_idx];
|
||||
if (blank - 1 - argmax_idx > 1e-5) {
|
||||
score += max_value;
|
||||
count += 1;
|
||||
}
|
||||
|
||||
const float *predict_batch = res.get_float_data();
|
||||
const std::vector<int64_t> predict_shape = res.get_shape();
|
||||
const std::vector<std::vector<uint64_t>> predict_lod = res.get_lod();
|
||||
int blank = predict_shape[1];
|
||||
float score = 0.f;
|
||||
int count = 0;
|
||||
for (int n = predict_lod[0][0]; n < predict_lod[0][1] - 1; n++) {
|
||||
int argmax_idx = argmax(predict_batch + n * predict_shape[1],
|
||||
predict_batch + (n + 1) * predict_shape[1]);
|
||||
float max_value = predict_batch[n * predict_shape[1] + argmax_idx];
|
||||
if (blank - 1 - argmax_idx > 1e-5) {
|
||||
score += max_value;
|
||||
count += 1;
|
||||
}
|
||||
if (count == 0) {
|
||||
LOGE("calc score count 0");
|
||||
} else {
|
||||
score /= count;
|
||||
}
|
||||
LOGI("calc score: %f", score);
|
||||
return score;
|
||||
|
||||
}
|
||||
if (count == 0) {
|
||||
LOGE("calc score count 0");
|
||||
} else {
|
||||
score /= count;
|
||||
}
|
||||
LOGI("calc score: %f", score);
|
||||
return score;
|
||||
}
|
||||
|
||||
|
||||
NET_TYPE OCR_PPredictor::get_net_flag() const {
|
||||
return NET_OCR;
|
||||
}
|
||||
NET_TYPE OCR_PPredictor::get_net_flag() const { return NET_OCR; }
|
||||
}
|
|
@ -4,10 +4,10 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ppredictor.h"
|
||||
#include <opencv2/opencv.hpp>
|
||||
#include <paddle_api.h>
|
||||
#include "ppredictor.h"
|
||||
#include <string>
|
||||
|
||||
namespace ppredictor {
|
||||
|
||||
|
@ -15,17 +15,18 @@ namespace ppredictor {
|
|||
* Config
|
||||
*/
|
||||
struct OCR_Config {
|
||||
int thread_num = 4; // Thread num
|
||||
paddle::lite_api::PowerMode mode = paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode
|
||||
int thread_num = 4; // Thread num
|
||||
paddle::lite_api::PowerMode mode =
|
||||
paddle::lite_api::LITE_POWER_HIGH; // PaddleLite Mode
|
||||
};
|
||||
|
||||
/**
|
||||
* PolyGone Result
|
||||
*/
|
||||
struct OCRPredictResult {
|
||||
std::vector<int> word_index;
|
||||
std::vector<std::vector<int>> points;
|
||||
float score;
|
||||
std::vector<int> word_index;
|
||||
std::vector<std::vector<int>> points;
|
||||
float score;
|
||||
};
|
||||
|
||||
/**
|
||||
|
@ -35,78 +36,87 @@ struct OCRPredictResult {
|
|||
*/
|
||||
class OCR_PPredictor : public PPredictor_Interface {
|
||||
public:
|
||||
OCR_PPredictor(const OCR_Config &config);
|
||||
OCR_PPredictor(const OCR_Config &config);
|
||||
|
||||
virtual ~OCR_PPredictor() {
|
||||
virtual ~OCR_PPredictor() {}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* 初始化二个模型的Predictor
|
||||
* @param det_model_content
|
||||
* @param rec_model_content
|
||||
* @return
|
||||
*/
|
||||
int init(const std::string &det_model_content, const std::string &rec_model_content);
|
||||
int init_from_file(const std::string &det_model_path, const std::string &rec_model_path);
|
||||
/**
|
||||
* Return OCR result
|
||||
* @param dims
|
||||
* @param input_data
|
||||
* @param input_len
|
||||
* @param net_flag
|
||||
* @param origin
|
||||
* @return
|
||||
*/
|
||||
virtual std::vector<OCRPredictResult>
|
||||
infer_ocr(const std::vector<int64_t> &dims, const float *input_data, int input_len,
|
||||
int net_flag, cv::Mat &origin);
|
||||
|
||||
|
||||
virtual NET_TYPE get_net_flag() const;
|
||||
/**
|
||||
* 初始化二个模型的Predictor
|
||||
* @param det_model_content
|
||||
* @param rec_model_content
|
||||
* @return
|
||||
*/
|
||||
int init(const std::string &det_model_content,
|
||||
const std::string &rec_model_content,
|
||||
const std::string &cls_model_content);
|
||||
int init_from_file(const std::string &det_model_path,
|
||||
const std::string &rec_model_path,
|
||||
const std::string &cls_model_path);
|
||||
/**
|
||||
* Return OCR result
|
||||
* @param dims
|
||||
* @param input_data
|
||||
* @param input_len
|
||||
* @param net_flag
|
||||
* @param origin
|
||||
* @return
|
||||
*/
|
||||
virtual std::vector<OCRPredictResult>
|
||||
infer_ocr(const std::vector<int64_t> &dims, const float *input_data,
|
||||
int input_len, int net_flag, cv::Mat &origin);
|
||||
|
||||
virtual NET_TYPE get_net_flag() const;
|
||||
|
||||
private:
|
||||
/**
|
||||
* calcul Polygone from the result image of first model
|
||||
* @param pred
|
||||
* @param output_height
|
||||
* @param output_width
|
||||
* @param origin
|
||||
* @return
|
||||
*/
|
||||
std::vector<std::vector<std::vector<int>>>
|
||||
calc_filtered_boxes(const float *pred, int pred_size, int output_height,
|
||||
int output_width, const cv::Mat &origin);
|
||||
|
||||
/**
|
||||
* calcul Polygone from the result image of first model
|
||||
* @param pred
|
||||
* @param output_height
|
||||
* @param output_width
|
||||
* @param origin
|
||||
* @return
|
||||
*/
|
||||
std::vector<std::vector<std::vector<int>>>
|
||||
calc_filtered_boxes(const float *pred, int pred_size, int output_height, int output_width,
|
||||
const cv::Mat &origin);
|
||||
/**
|
||||
* infer for second model
|
||||
*
|
||||
* @param boxes
|
||||
* @param origin
|
||||
* @return
|
||||
*/
|
||||
std::vector<OCRPredictResult>
|
||||
infer_rec(const std::vector<std::vector<std::vector<int>>> &boxes,
|
||||
const cv::Mat &origin);
|
||||
|
||||
/**
|
||||
* infer for second model
|
||||
*
|
||||
* @param boxes
|
||||
* @param origin
|
||||
* @return
|
||||
*/
|
||||
std::vector<OCRPredictResult>
|
||||
infer_rec(const std::vector<std::vector<std::vector<int>>> &boxes, const cv::Mat &origin);
|
||||
/**
|
||||
* infer for cls model
|
||||
*
|
||||
* @param boxes
|
||||
* @param origin
|
||||
* @return
|
||||
*/
|
||||
cv::Mat infer_cls(const cv::Mat &origin, float thresh = 0.5);
|
||||
|
||||
/**
|
||||
* Postprocess or sencod model to extract text
|
||||
* @param res
|
||||
* @return
|
||||
*/
|
||||
std::vector<int> postprocess_rec_word_index(const PredictorOutput &res);
|
||||
/**
|
||||
* Postprocess or sencod model to extract text
|
||||
* @param res
|
||||
* @return
|
||||
*/
|
||||
std::vector<int> postprocess_rec_word_index(const PredictorOutput &res);
|
||||
|
||||
/**
|
||||
* calculate confidence of second model text result
|
||||
* @param res
|
||||
* @return
|
||||
*/
|
||||
float postprocess_rec_score(const PredictorOutput &res);
|
||||
|
||||
std::unique_ptr<PPredictor> _det_predictor;
|
||||
std::unique_ptr<PPredictor> _rec_predictor;
|
||||
OCR_Config _config;
|
||||
/**
|
||||
* calculate confidence of second model text result
|
||||
* @param res
|
||||
* @return
|
||||
*/
|
||||
float postprocess_rec_score(const PredictorOutput &res);
|
||||
|
||||
std::unique_ptr<PPredictor> _det_predictor;
|
||||
std::unique_ptr<PPredictor> _rec_predictor;
|
||||
std::unique_ptr<PPredictor> _cls_predictor;
|
||||
OCR_Config _config;
|
||||
};
|
||||
}
|
||||
|
|
|
@ -29,7 +29,7 @@ public class OCRPredictorNative {
|
|||
public OCRPredictorNative(Config config) {
|
||||
this.config = config;
|
||||
loadLibrary();
|
||||
nativePointer = init(config.detModelFilename, config.recModelFilename,
|
||||
nativePointer = init(config.detModelFilename, config.recModelFilename,config.clsModelFilename,
|
||||
config.cpuThreadNum, config.cpuPower);
|
||||
Log.i("OCRPredictorNative", "load success " + nativePointer);
|
||||
|
||||
|
@ -38,7 +38,7 @@ public class OCRPredictorNative {
|
|||
public void release() {
|
||||
if (nativePointer != 0) {
|
||||
nativePointer = 0;
|
||||
destory(nativePointer);
|
||||
// destory(nativePointer);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -55,10 +55,11 @@ public class OCRPredictorNative {
|
|||
public String cpuPower;
|
||||
public String detModelFilename;
|
||||
public String recModelFilename;
|
||||
public String clsModelFilename;
|
||||
|
||||
}
|
||||
|
||||
protected native long init(String detModelPath, String recModelPath, int threadNum, String cpuMode);
|
||||
protected native long init(String detModelPath, String recModelPath,String clsModelPath, int threadNum, String cpuMode);
|
||||
|
||||
protected native float[] forward(long pointer, float[] buf, float[] ddims, Bitmap originalImage);
|
||||
|
||||
|
|
|
@ -121,7 +121,8 @@ public class Predictor {
|
|||
config.cpuThreadNum = cpuThreadNum;
|
||||
config.detModelFilename = realPath + File.separator + "ch_det_mv3_db_opt.nb";
|
||||
config.recModelFilename = realPath + File.separator + "ch_rec_mv3_crnn_opt.nb";
|
||||
Log.e("Predictor", "model path" + config.detModelFilename + " ; " + config.recModelFilename);
|
||||
config.clsModelFilename = realPath + File.separator + "cls_opt_arm.nb";
|
||||
Log.e("Predictor", "model path" + config.detModelFilename + " ; " + config.recModelFilename + ";" + config.clsModelFilename);
|
||||
config.cpuPower = cpuPowerMode;
|
||||
paddlePredictor = new OCRPredictorNative(config);
|
||||
|
||||
|
|
|
@ -57,6 +57,12 @@ public:
|
|||
|
||||
this->char_list_file.assign(config_map_["char_list_file"]);
|
||||
|
||||
this->use_angle_cls = bool(stoi(config_map_["use_angle_cls"]));
|
||||
|
||||
this->cls_model_dir.assign(config_map_["cls_model_dir"]);
|
||||
|
||||
this->cls_thresh = stod(config_map_["cls_thresh"]);
|
||||
|
||||
this->visualize = bool(stoi(config_map_["visualize"]));
|
||||
}
|
||||
|
||||
|
@ -84,8 +90,14 @@ public:
|
|||
|
||||
std::string rec_model_dir;
|
||||
|
||||
bool use_angle_cls;
|
||||
|
||||
std::string char_list_file;
|
||||
|
||||
std::string cls_model_dir;
|
||||
|
||||
double cls_thresh;
|
||||
|
||||
bool visualize = true;
|
||||
|
||||
void PrintConfigInfo();
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "opencv2/core.hpp"
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
#include "opencv2/imgproc.hpp"
|
||||
#include "paddle_api.h"
|
||||
#include "paddle_inference_api.h"
|
||||
#include <chrono>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <vector>
|
||||
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <numeric>
|
||||
|
||||
#include <include/preprocess_op.h>
|
||||
#include <include/utility.h>
|
||||
|
||||
namespace PaddleOCR {
|
||||
|
||||
class Classifier {
|
||||
public:
|
||||
explicit Classifier(const std::string &model_dir, const bool &use_gpu,
|
||||
const int &gpu_id, const int &gpu_mem,
|
||||
const int &cpu_math_library_num_threads,
|
||||
const bool &use_mkldnn, const bool &use_zero_copy_run,
|
||||
const double &cls_thresh) {
|
||||
this->use_gpu_ = use_gpu;
|
||||
this->gpu_id_ = gpu_id;
|
||||
this->gpu_mem_ = gpu_mem;
|
||||
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
|
||||
this->use_mkldnn_ = use_mkldnn;
|
||||
this->use_zero_copy_run_ = use_zero_copy_run;
|
||||
|
||||
this->cls_thresh = cls_thresh;
|
||||
|
||||
LoadModel(model_dir);
|
||||
}
|
||||
|
||||
// Load Paddle inference model
|
||||
void LoadModel(const std::string &model_dir);
|
||||
|
||||
cv::Mat Run(cv::Mat &img);
|
||||
|
||||
private:
|
||||
std::shared_ptr<PaddlePredictor> predictor_;
|
||||
|
||||
bool use_gpu_ = false;
|
||||
int gpu_id_ = 0;
|
||||
int gpu_mem_ = 4000;
|
||||
int cpu_math_library_num_threads_ = 4;
|
||||
bool use_mkldnn_ = false;
|
||||
bool use_zero_copy_run_ = false;
|
||||
double cls_thresh = 0.5;
|
||||
|
||||
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
|
||||
std::vector<float> scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
|
||||
bool is_scale_ = true;
|
||||
|
||||
// pre-process
|
||||
ClsResizeImg resize_op_;
|
||||
Normalize normalize_op_;
|
||||
Permute permute_op_;
|
||||
|
||||
}; // class Classifier
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -27,6 +27,7 @@
|
|||
#include <fstream>
|
||||
#include <numeric>
|
||||
|
||||
#include <include/ocr_cls.h>
|
||||
#include <include/postprocess_op.h>
|
||||
#include <include/preprocess_op.h>
|
||||
#include <include/utility.h>
|
||||
|
@ -56,7 +57,8 @@ public:
|
|||
// Load Paddle inference model
|
||||
void LoadModel(const std::string &model_dir);
|
||||
|
||||
void Run(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat &img);
|
||||
void Run(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat &img,
|
||||
Classifier *cls);
|
||||
|
||||
private:
|
||||
std::shared_ptr<PaddlePredictor> predictor_;
|
||||
|
|
|
@ -56,4 +56,10 @@ public:
|
|||
const std::vector<int> &rec_image_shape = {3, 32, 320});
|
||||
};
|
||||
|
||||
class ClsResizeImg {
|
||||
public:
|
||||
virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||
const std::vector<int> &rec_image_shape = {3, 32, 320});
|
||||
};
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -53,6 +53,15 @@ int main(int argc, char **argv) {
|
|||
config.cpu_math_library_num_threads, config.use_mkldnn,
|
||||
config.use_zero_copy_run, config.max_side_len, config.det_db_thresh,
|
||||
config.det_db_box_thresh, config.det_db_unclip_ratio, config.visualize);
|
||||
|
||||
Classifier *cls = nullptr;
|
||||
if (config.use_angle_cls == true) {
|
||||
cls = new Classifier(config.cls_model_dir, config.use_gpu, config.gpu_id,
|
||||
config.gpu_mem, config.cpu_math_library_num_threads,
|
||||
config.use_mkldnn, config.use_zero_copy_run,
|
||||
config.cls_thresh);
|
||||
}
|
||||
|
||||
CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id,
|
||||
config.gpu_mem, config.cpu_math_library_num_threads,
|
||||
config.use_mkldnn, config.use_zero_copy_run,
|
||||
|
@ -62,7 +71,7 @@ int main(int argc, char **argv) {
|
|||
std::vector<std::vector<std::vector<int>>> boxes;
|
||||
det.Run(srcimg, boxes);
|
||||
|
||||
rec.Run(boxes, srcimg);
|
||||
rec.Run(boxes, srcimg, cls);
|
||||
|
||||
auto end = std::chrono::system_clock::now();
|
||||
auto duration =
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <include/ocr_cls.h>
|
||||
|
||||
namespace PaddleOCR {
|
||||
|
||||
cv::Mat Classifier::Run(cv::Mat &img) {
|
||||
cv::Mat src_img;
|
||||
img.copyTo(src_img);
|
||||
cv::Mat resize_img;
|
||||
|
||||
std::vector<int> rec_image_shape = {3, 32, 100};
|
||||
int index = 0;
|
||||
float wh_ratio = float(img.cols) / float(img.rows);
|
||||
|
||||
this->resize_op_.Run(img, resize_img, rec_image_shape);
|
||||
|
||||
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
|
||||
this->is_scale_);
|
||||
|
||||
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
|
||||
|
||||
this->permute_op_.Run(&resize_img, input.data());
|
||||
|
||||
// Inference.
|
||||
if (this->use_zero_copy_run_) {
|
||||
auto input_names = this->predictor_->GetInputNames();
|
||||
auto input_t = this->predictor_->GetInputTensor(input_names[0]);
|
||||
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
|
||||
input_t->copy_from_cpu(input.data());
|
||||
this->predictor_->ZeroCopyRun();
|
||||
} else {
|
||||
paddle::PaddleTensor input_t;
|
||||
input_t.shape = {1, 3, resize_img.rows, resize_img.cols};
|
||||
input_t.data =
|
||||
paddle::PaddleBuf(input.data(), input.size() * sizeof(float));
|
||||
input_t.dtype = PaddleDType::FLOAT32;
|
||||
std::vector<paddle::PaddleTensor> outputs;
|
||||
this->predictor_->Run({input_t}, &outputs, 1);
|
||||
}
|
||||
|
||||
std::vector<float> softmax_out;
|
||||
std::vector<int64_t> label_out;
|
||||
auto output_names = this->predictor_->GetOutputNames();
|
||||
auto softmax_out_t = this->predictor_->GetOutputTensor(output_names[0]);
|
||||
auto label_out_t = this->predictor_->GetOutputTensor(output_names[1]);
|
||||
auto softmax_shape_out = softmax_out_t->shape();
|
||||
auto label_shape_out = label_out_t->shape();
|
||||
|
||||
int softmax_out_num =
|
||||
std::accumulate(softmax_shape_out.begin(), softmax_shape_out.end(), 1,
|
||||
std::multiplies<int>());
|
||||
|
||||
int label_out_num =
|
||||
std::accumulate(label_shape_out.begin(), label_shape_out.end(), 1,
|
||||
std::multiplies<int>());
|
||||
softmax_out.resize(softmax_out_num);
|
||||
label_out.resize(label_out_num);
|
||||
|
||||
softmax_out_t->copy_to_cpu(softmax_out.data());
|
||||
label_out_t->copy_to_cpu(label_out.data());
|
||||
|
||||
int label = label_out[0];
|
||||
float score = softmax_out[label];
|
||||
// std::cout << "\nlabel "<<label<<" score: "<<score;
|
||||
if (label % 2 == 1 && score > this->cls_thresh) {
|
||||
cv::rotate(src_img, src_img, 1);
|
||||
}
|
||||
return src_img;
|
||||
}
|
||||
|
||||
void Classifier::LoadModel(const std::string &model_dir) {
|
||||
AnalysisConfig config;
|
||||
config.SetModel(model_dir + "/model", model_dir + "/params");
|
||||
|
||||
if (this->use_gpu_) {
|
||||
config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
|
||||
} else {
|
||||
config.DisableGpu();
|
||||
if (this->use_mkldnn_) {
|
||||
config.EnableMKLDNN();
|
||||
}
|
||||
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
|
||||
}
|
||||
|
||||
// false for zero copy tensor
|
||||
config.SwitchUseFeedFetchOps(!this->use_zero_copy_run_);
|
||||
// true for multiple input
|
||||
config.SwitchSpecifyInputNames(true);
|
||||
|
||||
config.SwitchIrOptim(true);
|
||||
|
||||
config.EnableMemoryOptim();
|
||||
config.DisableGlogInfo();
|
||||
|
||||
this->predictor_ = CreatePaddlePredictor(config);
|
||||
}
|
||||
} // namespace PaddleOCR
|
|
@ -17,7 +17,7 @@
|
|||
namespace PaddleOCR {
|
||||
|
||||
void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
||||
cv::Mat &img) {
|
||||
cv::Mat &img, Classifier *cls) {
|
||||
cv::Mat srcimg;
|
||||
img.copyTo(srcimg);
|
||||
cv::Mat crop_img;
|
||||
|
@ -27,6 +27,9 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
|||
int index = 0;
|
||||
for (int i = boxes.size() - 1; i >= 0; i--) {
|
||||
crop_img = GetRotateCropImage(srcimg, boxes[i]);
|
||||
if (cls != nullptr) {
|
||||
crop_img = cls->Run(crop_img);
|
||||
}
|
||||
|
||||
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
|
||||
|
||||
|
|
|
@ -116,4 +116,26 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
|
|||
cv::INTER_LINEAR);
|
||||
}
|
||||
|
||||
void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||
const std::vector<int> &rec_image_shape) {
|
||||
int imgC, imgH, imgW;
|
||||
imgC = rec_image_shape[0];
|
||||
imgH = rec_image_shape[1];
|
||||
imgW = rec_image_shape[2];
|
||||
|
||||
float ratio = float(img.cols) / float(img.rows);
|
||||
int resize_w, resize_h;
|
||||
if (ceilf(imgH * ratio) > imgW)
|
||||
resize_w = imgW;
|
||||
else
|
||||
resize_w = int(ceilf(imgH * ratio));
|
||||
|
||||
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
|
||||
cv::INTER_LINEAR);
|
||||
if (resize_w < imgW) {
|
||||
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
|
||||
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -13,6 +13,11 @@ det_db_box_thresh 0.5
|
|||
det_db_unclip_ratio 2.0
|
||||
det_model_dir ./inference/det_db
|
||||
|
||||
# cls config
|
||||
use_angle_cls 0
|
||||
cls_model_dir ../inference/cls
|
||||
cls_thresh 0.9
|
||||
|
||||
# rec config
|
||||
rec_model_dir ./inference/rec_crnn
|
||||
char_list_file ../../ppocr/utils/ppocr_keys_v1.txt
|
||||
|
|
|
@ -40,8 +40,8 @@ CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SY
|
|||
|
||||
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
|
||||
|
||||
ocr_db_crnn: fetch_opencv ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o
|
||||
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o -o ocr_db_crnn $(CXX_LIBS) $(LDFLAGS)
|
||||
ocr_db_crnn: fetch_opencv ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o cls_process.o
|
||||
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o cls_process.o -o ocr_db_crnn $(CXX_LIBS) $(LDFLAGS)
|
||||
|
||||
ocr_db_crnn.o: ocr_db_crnn.cc
|
||||
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o ocr_db_crnn.o -c ocr_db_crnn.cc
|
||||
|
@ -49,6 +49,9 @@ ocr_db_crnn.o: ocr_db_crnn.cc
|
|||
crnn_process.o: fetch_opencv crnn_process.cc
|
||||
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o crnn_process.o -c crnn_process.cc
|
||||
|
||||
cls_process.o: fetch_opencv cls_process.cc
|
||||
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o cls_process.o -c cls_process.cc
|
||||
|
||||
db_post_process.o: fetch_clipper fetch_opencv db_post_process.cc
|
||||
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o db_post_process.o -c db_post_process.cc
|
||||
|
||||
|
@ -73,5 +76,5 @@ fetch_opencv:
|
|||
|
||||
.PHONY: clean
|
||||
clean:
|
||||
rm -f ocr_db_crnn.o clipper.o db_post_process.o crnn_process.o
|
||||
rm -f ocr_db_crnn.o clipper.o db_post_process.o crnn_process.o cls_process.o
|
||||
rm -f ocr_db_crnn
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "cls_process.h" //NOLINT
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
const std::vector<int> rec_image_shape{3, 32, 100};
|
||||
|
||||
cv::Mat ClsResizeImg(cv::Mat img) {
|
||||
int imgC, imgH, imgW;
|
||||
imgC = rec_image_shape[0];
|
||||
imgH = rec_image_shape[1];
|
||||
imgW = rec_image_shape[2];
|
||||
|
||||
float ratio = static_cast<float>(img.cols) / static_cast<float>(img.rows);
|
||||
|
||||
int resize_w, resize_h;
|
||||
if (ceilf(imgH * ratio) > imgW)
|
||||
resize_w = imgW;
|
||||
else
|
||||
resize_w = int(ceilf(imgH * ratio));
|
||||
cv::Mat resize_img;
|
||||
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
|
||||
cv::INTER_LINEAR);
|
||||
if (resize_w < imgW) {
|
||||
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
|
||||
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
|
||||
}
|
||||
return resize_img;
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "math.h" //NOLINT
|
||||
#include "opencv2/core.hpp"
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
#include "opencv2/imgproc.hpp"
|
||||
|
||||
cv::Mat ClsResizeImg(cv::Mat img);
|
|
@ -15,6 +15,7 @@
|
|||
#include "paddle_api.h" // NOLINT
|
||||
#include <chrono>
|
||||
|
||||
#include "cls_process.h"
|
||||
#include "crnn_process.h"
|
||||
#include "db_post_process.h"
|
||||
|
||||
|
@ -105,11 +106,55 @@ cv::Mat DetResizeImg(const cv::Mat img, int max_size_len,
|
|||
return resize_img;
|
||||
}
|
||||
|
||||
cv::Mat RunClsModel(cv::Mat img, std::shared_ptr<PaddlePredictor> predictor_cls,
|
||||
const float thresh = 0.5) {
|
||||
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
|
||||
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
|
||||
|
||||
cv::Mat srcimg;
|
||||
img.copyTo(srcimg);
|
||||
cv::Mat crop_img;
|
||||
cv::Mat resize_img;
|
||||
|
||||
int index = 0;
|
||||
float wh_ratio =
|
||||
static_cast<float>(crop_img.cols) / static_cast<float>(crop_img.rows);
|
||||
|
||||
resize_img = ClsResizeImg(crop_img);
|
||||
resize_img.convertTo(resize_img, CV_32FC3, 1 / 255.f);
|
||||
|
||||
const float *dimg = reinterpret_cast<const float *>(resize_img.data);
|
||||
|
||||
std::unique_ptr<Tensor> input_tensor0(std::move(predictor_cls->GetInput(0)));
|
||||
input_tensor0->Resize({1, 3, resize_img.rows, resize_img.cols});
|
||||
auto *data0 = input_tensor0->mutable_data<float>();
|
||||
|
||||
NeonMeanScale(dimg, data0, resize_img.rows * resize_img.cols, mean, scale);
|
||||
// Run CLS predictor
|
||||
predictor_cls->Run();
|
||||
|
||||
// Get output and run postprocess
|
||||
std::unique_ptr<const Tensor> softmax_out(
|
||||
std::move(predictor_cls->GetOutput(0)));
|
||||
std::unique_ptr<const Tensor> label_out(
|
||||
std::move(predictor_cls->GetOutput(1)));
|
||||
auto *softmax_scores = softmax_out->mutable_data<float>();
|
||||
auto *label_idxs = label_out->data<int64>();
|
||||
int label_idx = label_idxs[0];
|
||||
float score = softmax_scores[label_idx];
|
||||
|
||||
if (label_idx % 2 == 1 && score > thresh) {
|
||||
cv::rotate(srcimg, srcimg, 1);
|
||||
}
|
||||
return srcimg;
|
||||
}
|
||||
|
||||
void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img,
|
||||
std::shared_ptr<PaddlePredictor> predictor_crnn,
|
||||
std::vector<std::string> &rec_text,
|
||||
std::vector<float> &rec_text_score,
|
||||
std::vector<std::string> charactor_dict) {
|
||||
std::vector<std::string> charactor_dict,
|
||||
std::shared_ptr<PaddlePredictor> predictor_cls) {
|
||||
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
|
||||
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
|
||||
|
||||
|
@ -121,6 +166,7 @@ void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img,
|
|||
int index = 0;
|
||||
for (int i = boxes.size() - 1; i >= 0; i--) {
|
||||
crop_img = GetRotateCropImage(srcimg, boxes[i]);
|
||||
crop_img = RunClsModel(crop_img, predictor_cls);
|
||||
float wh_ratio =
|
||||
static_cast<float>(crop_img.cols) / static_cast<float>(crop_img.rows);
|
||||
|
||||
|
@ -323,8 +369,9 @@ int main(int argc, char **argv) {
|
|||
}
|
||||
std::string det_model_file = argv[1];
|
||||
std::string rec_model_file = argv[2];
|
||||
std::string img_path = argv[3];
|
||||
std::string dict_path = argv[4];
|
||||
std::string cls_model_file = argv[3];
|
||||
std::string img_path = argv[4];
|
||||
std::string dict_path = argv[5];
|
||||
|
||||
//// load config from txt file
|
||||
auto Config = LoadConfigTxt("./config.txt");
|
||||
|
@ -333,6 +380,7 @@ int main(int argc, char **argv) {
|
|||
|
||||
auto det_predictor = loadModel(det_model_file);
|
||||
auto rec_predictor = loadModel(rec_model_file);
|
||||
auto cls_predictor = loadModel(cls_model_file);
|
||||
|
||||
auto charactor_dict = ReadDict(dict_path);
|
||||
charactor_dict.push_back(" ");
|
||||
|
@ -343,7 +391,7 @@ int main(int argc, char **argv) {
|
|||
std::vector<std::string> rec_text;
|
||||
std::vector<float> rec_text_score;
|
||||
RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score,
|
||||
charactor_dict);
|
||||
charactor_dict, cls_predictor);
|
||||
|
||||
auto end = std::chrono::system_clock::now();
|
||||
auto duration =
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
## 文字角度分类
|
||||
|
||||
### 数据准备
|
||||
|
||||
请按如下步骤设置数据集:
|
||||
|
||||
训练数据的默认存储路径是 `PaddleOCR/train_data/cls`,如果您的磁盘上已有数据集,只需创建软链接至数据集目录:
|
||||
|
||||
```
|
||||
ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/cls/dataset
|
||||
```
|
||||
|
||||
请参考下文组织您的数据。
|
||||
- 训练集
|
||||
|
||||
首先请将训练图片放入同一个文件夹(train_images),并用一个txt文件(cls_gt_train.txt)记录图片路径和标签。
|
||||
|
||||
**注意:** 默认请将图片路径和图片标签用 `\t` 分割,如用其他方式分割将造成训练报错
|
||||
|
||||
0和180分别表示图片的角度为0度和180度
|
||||
|
||||
```
|
||||
" 图像文件名 图像标注信息 "
|
||||
|
||||
train_data/cls/word_001.jpg 0
|
||||
train_data/cls/word_002.jpg 180
|
||||
```
|
||||
|
||||
最终训练集应有如下文件结构:
|
||||
```
|
||||
|-train_data
|
||||
|-cls
|
||||
|- cls_gt_train.txt
|
||||
|- train
|
||||
|- word_001.png
|
||||
|- word_002.jpg
|
||||
|- word_003.jpg
|
||||
| ...
|
||||
```
|
||||
|
||||
- 测试集
|
||||
|
||||
同训练集类似,测试集也需要提供一个包含所有图片的文件夹(test)和一个cls_gt_test.txt,测试集的结构如下所示:
|
||||
|
||||
```
|
||||
|-train_data
|
||||
|-cls
|
||||
|- 和一个cls_gt_test.txt
|
||||
|- test
|
||||
|- word_001.jpg
|
||||
|- word_002.jpg
|
||||
|- word_003.jpg
|
||||
| ...
|
||||
```
|
||||
|
||||
### 启动训练
|
||||
|
||||
PaddleOCR提供了训练脚本、评估脚本和预测脚本。
|
||||
|
||||
开始训练:
|
||||
|
||||
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
|
||||
|
||||
```
|
||||
# 设置PYTHONPATH路径
|
||||
export PYTHONPATH=$PYTHONPATH:.
|
||||
# GPU训练 支持单卡,多卡训练,通过CUDA_VISIBLE_DEVICES指定卡号
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
# 启动训练
|
||||
python3 tools/train.py -c configs/cls/cls_mv3.yml
|
||||
```
|
||||
|
||||
- 数据增强
|
||||
|
||||
PaddleOCR提供了多种数据增强方式,如果您希望在训练时加入扰动,请在配置文件中设置 `distort: true`。
|
||||
|
||||
默认的扰动方式有:颜色空间转换(cvtColor)、模糊(blur)、抖动(jitter)、噪声(Gasuss noise)、随机切割(random crop)、透视(perspective)、颜色反转(reverse),随机数据增强(RandAugment)。
|
||||
|
||||
训练过程中除随机数据增强外每种扰动方式以50%的概率被选择,具体代码实现请参考:
|
||||
[randaugment.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/cls/randaugment.py)
|
||||
[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py)
|
||||
|
||||
*由于OpenCV的兼容性问题,扰动操作暂时只支持linux*
|
||||
|
||||
### 训练
|
||||
|
||||
PaddleOCR支持训练和评估交替进行, 可以在 `configs/cls/cls_mv3.yml` 中修改 `eval_batch_step` 设置评估频率,默认每500个iter评估一次。评估过程中默认将最佳acc模型,保存为 `output/cls_mv3/best_accuracy` 。
|
||||
|
||||
如果验证集很大,测试将会比较耗时,建议减少评估次数,或训练完再进行评估。
|
||||
|
||||
**注意,预测/评估时的配置文件请务必与训练一致。**
|
||||
|
||||
### 评估
|
||||
|
||||
评估数据集可以通过`configs/cls/cls_reader.yml` 修改EvalReader中的 `label_file_path` 设置。
|
||||
|
||||
*注意* 评估时必须确保配置文件中 infer_img 字段为空
|
||||
```
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
# GPU 评估, Global.checkpoints 为待测权重
|
||||
python3 tools/eval.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy
|
||||
```
|
||||
|
||||
### 预测
|
||||
|
||||
* 训练引擎的预测
|
||||
|
||||
使用 PaddleOCR 训练好的模型,可以通过以下脚本进行快速预测。
|
||||
|
||||
默认预测图片存储在 `infer_img` 里,通过 `-o Global.checkpoints` 指定权重:
|
||||
|
||||
```
|
||||
# 预测分类结果
|
||||
python3 tools/infer_cls.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
|
||||
```
|
||||
|
||||
预测图片:
|
||||
|
||||
![](../imgs_words/en/word_1.png)
|
||||
|
||||
得到输入图像的预测结果:
|
||||
|
||||
```
|
||||
infer_img: doc/imgs_words/en/word_1.png
|
||||
scores: [[0.93161047 0.06838956]]
|
||||
label: [0]
|
||||
```
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
PaddleOCR提供了EAST、DB两种文本检测算法,均支持MobileNetV3、ResNet50_vd两种骨干网络,根据需要选择相应的配置文件,启动训练。例如,训练使用MobileNetV3作为骨干网络的DB检测模型(即超轻量模型使用的配置):
|
||||
```
|
||||
python3 tools/train.py -c configs/det/det_mv3_db.yml
|
||||
python3 tools/train.py -c configs/det/det_mv3_db.yml 2>&1 | tee det_db.log
|
||||
```
|
||||
更详细的数据准备和训练教程参考文档教程中[文本检测模型训练/评估/预测](./detection.md)。
|
||||
|
||||
|
@ -14,7 +14,7 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml
|
|||
|
||||
PaddleOCR提供了CRNN、Rosetta、STAR-Net、RARE四种文本识别算法,均支持MobileNetV3、ResNet34_vd两种骨干网络,根据需要选择相应的配置文件,启动训练。例如,训练使用MobileNetV3作为骨干网络的CRNN识别模型(即超轻量模型使用的配置):
|
||||
```
|
||||
python3 tools/train.py -c configs/rec/rec_chinese_lite_train.yml
|
||||
python3 tools/train.py -c configs/rec/rec_chinese_lite_train.yml 2>&1 | tee rec_ch_lite.log
|
||||
```
|
||||
更详细的数据准备和训练教程参考文档教程中[文本识别模型训练/评估/预测](./recognition.md)。
|
||||
|
||||
|
|
|
@ -62,7 +62,10 @@ tar -xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_model
|
|||
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
|
||||
|
||||
```shell
|
||||
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/
|
||||
# 训练 mv3_db 模型,并将训练日志保存为 tain_det.log
|
||||
python3 tools/train.py -c configs/det/det_mv3_db.yml \
|
||||
-o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/ \
|
||||
2>&1 | tee train_det.log
|
||||
```
|
||||
|
||||
上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。
|
||||
|
|
|
@ -11,24 +11,28 @@ inference 模型(`fluid.io.save_inference_model`保存的模型)
|
|||
- [一、训练模型转inference模型](#训练模型转inference模型)
|
||||
- [检测模型转inference模型](#检测模型转inference模型)
|
||||
- [识别模型转inference模型](#识别模型转inference模型)
|
||||
|
||||
- [方向分类模型转inference模型](#方向分类模型转inference模型)
|
||||
|
||||
- [二、文本检测模型推理](#文本检测模型推理)
|
||||
- [1. 超轻量中文检测模型推理](#超轻量中文检测模型推理)
|
||||
- [2. DB文本检测模型推理](#DB文本检测模型推理)
|
||||
- [3. EAST文本检测模型推理](#EAST文本检测模型推理)
|
||||
- [4. SAST文本检测模型推理](#SAST文本检测模型推理)
|
||||
|
||||
|
||||
- [三、文本识别模型推理](#文本识别模型推理)
|
||||
- [1. 超轻量中文识别模型推理](#超轻量中文识别模型推理)
|
||||
- [2. 基于CTC损失的识别模型推理](#基于CTC损失的识别模型推理)
|
||||
- [3. 基于Attention损失的识别模型推理](#基于Attention损失的识别模型推理)
|
||||
- [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
|
||||
|
||||
- [四、文本检测、识别串联推理](#文本检测、识别串联推理)
|
||||
- [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
|
||||
|
||||
- [四、方向分类模型推理](#方向识别模型推理)
|
||||
- [1. 方向分类模型推理](#方向分类模型推理)
|
||||
|
||||
- [五、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理)
|
||||
- [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理)
|
||||
- [2. 其他模型推理](#其他模型推理)
|
||||
|
||||
|
||||
|
||||
|
||||
<a name="训练模型转inference模型"></a>
|
||||
## 一、训练模型转inference模型
|
||||
<a name="检测模型转inference模型"></a>
|
||||
|
@ -84,6 +88,32 @@ python3 tools/export_model.py -c configs/rec/rec_chinese_lite_train.yml -o Globa
|
|||
└─ params 识别inference模型的参数文件
|
||||
```
|
||||
|
||||
<a name="方向分类模型转inference模型"></a>
|
||||
### 方向分类模型转inference模型
|
||||
|
||||
下载方向分类模型:
|
||||
```
|
||||
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile-v1.1.cls_pre.tar && tar xf ./ch_lite/ch_ppocr_mobile-v1.1.cls_pre.tar -C ./ch_lite/
|
||||
```
|
||||
|
||||
方向分类模型转inference模型与检测的方式相同,如下:
|
||||
```
|
||||
# -c后面设置训练算法的yml配置文件
|
||||
# -o配置可选参数
|
||||
# Global.checkpoints参数设置待转换的训练模型地址,不用添加文件后缀.pdmodel,.pdopt或.pdparams。
|
||||
# Global.save_inference_dir参数设置转换的模型将保存的地址。
|
||||
|
||||
python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.checkpoints=./ch_lite/cls_model/best_accuracy \
|
||||
Global.save_inference_dir=./inference/cls/
|
||||
```
|
||||
|
||||
转换成功后,在目录下有两个文件:
|
||||
```
|
||||
/inference/cls/
|
||||
└─ model 识别inference模型的program文件
|
||||
└─ params 识别inference模型的参数文件
|
||||
```
|
||||
|
||||
<a name="文本检测模型推理"></a>
|
||||
## 二、文本检测模型推理
|
||||
|
||||
|
@ -275,15 +305,36 @@ dict_character = list(self.character_str)
|
|||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_char_dict_path="your text dict path"
|
||||
```
|
||||
|
||||
<a name="文本检测、识别串联推理"></a>
|
||||
## 四、文本检测、识别串联推理
|
||||
|
||||
<a name="方向分类模型推理"></a>
|
||||
## 四、方向分类模型推理
|
||||
|
||||
下面将介绍方向分类模型推理。
|
||||
|
||||
<a name="方向分类模型推理"></a>
|
||||
### 1. 方向分类模型推理
|
||||
|
||||
方向分类模型推理,可以执行如下命令:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_cls.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --cls_model_dir="./inference/cls/"
|
||||
```
|
||||
|
||||
![](../imgs_words/ch/word_4.jpg)
|
||||
|
||||
执行命令后,上面图像的预测结果(分类的方向和得分)会打印到屏幕上,示例如下:
|
||||
|
||||
Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999963]
|
||||
|
||||
<a name="文本检测、方向分类和文字识别串联推理"></a>
|
||||
## 五、文本检测、方向分类和文字识别串联推理
|
||||
<a name="超轻量中文OCR模型推理"></a>
|
||||
### 1. 超轻量中文OCR模型推理
|
||||
|
||||
在执行预测时,需要通过参数image_dir指定单张图像或者图像集合的路径、参数det_model_dir指定检测inference模型的路径和参数rec_model_dir指定识别inference模型的路径。可视化识别结果默认保存到 ./inference_results 文件夹里面。
|
||||
在执行预测时,需要通过参数`image_dir`指定单张图像或者图像集合的路径、参数`det_model_dir`,`cls_model_dir`和`rec_model_dir`分别指定检测,方向分类和识别的inference模型路径。参数`use_angle_cls`用于控制是否启用方向分类模型。可视化识别结果默认保存到 ./inference_results 文件夹里面。
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --rec_model_dir="./inference/rec_crnn/"
|
||||
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --cls_model_dir="./inference/cls/" --rec_model_dir="./inference/rec_crnn/" --use_angle_cls true
|
||||
```
|
||||
|
||||
执行命令后,识别结果图像如下:
|
||||
|
|
|
@ -128,8 +128,8 @@ tar -xf rec_mv3_none_bilstm_ctc.tar && rm -rf rec_mv3_none_bilstm_ctc.tar
|
|||
export PYTHONPATH=$PYTHONPATH:.
|
||||
# GPU训练 支持单卡,多卡训练,通过CUDA_VISIBLE_DEVICES指定卡号
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
# 训练icdar15英文数据
|
||||
python3 tools/train.py -c configs/rec/rec_icdar15_train.yml
|
||||
# 训练icdar15英文数据 并将训练日志保存为 tain_rec.log
|
||||
python3 tools/train.py -c configs/rec/rec_icdar15_train.yml 2>&1 | tee train_rec.log
|
||||
```
|
||||
|
||||
- 数据增强
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
## TEXT ANGLE CLASSIFICATION
|
||||
|
||||
### DATA PREPARATION
|
||||
|
||||
Please organize the dataset as follows:
|
||||
|
||||
The default storage path for training data is `PaddleOCR/train_data/cls`, if you already have a dataset on your disk, just create a soft link to the dataset directory:
|
||||
|
||||
```
|
||||
ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/cls/dataset
|
||||
```
|
||||
|
||||
please refer to the following to organize your data.
|
||||
|
||||
- Training set
|
||||
|
||||
First put the training images in the same folder (train_images), and use a txt file (cls_gt_train.txt) to store the image path and label.
|
||||
|
||||
* Note: by default, the image path and image label are split with `\t`, if you use other methods to split, it will cause training error
|
||||
|
||||
0 and 180 indicate that the angle of the image is 0 degrees and 180 degrees, respectively.
|
||||
|
||||
```
|
||||
" Image file name Image annotation "
|
||||
|
||||
train_data/word_001.jpg 0
|
||||
train_data/word_002.jpg 180
|
||||
```
|
||||
|
||||
The final training set should have the following file structure:
|
||||
|
||||
```
|
||||
|-train_data
|
||||
|-cls
|
||||
|- cls_gt_train.txt
|
||||
|- train
|
||||
|- word_001.png
|
||||
|- word_002.jpg
|
||||
|- word_003.jpg
|
||||
| ...
|
||||
```
|
||||
|
||||
- Test set
|
||||
|
||||
Similar to the training set, the test set also needs to be provided a folder
|
||||
containing all images (test) and a cls_gt_test.txt. The structure of the test set is as follows:
|
||||
|
||||
```
|
||||
|-train_data
|
||||
|-cls
|
||||
|- cls_gt_test.txt
|
||||
|- test
|
||||
|- word_001.jpg
|
||||
|- word_002.jpg
|
||||
|- word_003.jpg
|
||||
| ...
|
||||
```
|
||||
|
||||
### TRAINING
|
||||
|
||||
PaddleOCR provides training scripts, evaluation scripts, and prediction scripts.
|
||||
|
||||
Start training:
|
||||
|
||||
```
|
||||
# Set PYTHONPATH path
|
||||
export PYTHONPATH=$PYTHONPATH:.
|
||||
# GPU training Support single card and multi-card training, specify the card number through CUDA_VISIBLE_DEVICES
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
# Training icdar15 English data
|
||||
python3 tools/train.py -c configs/cls/cls_mv3.yml
|
||||
```
|
||||
|
||||
- Data Augmentation
|
||||
|
||||
PaddleOCR provides a variety of data augmentation methods. If you want to add disturbance during training, please set `distort: true` in the configuration file.
|
||||
|
||||
The default perturbation methods are: cvtColor, blur, jitter, Gasuss noise, random crop, perspective, color reverse, RandAugment.
|
||||
|
||||
Except for RandAugment, each disturbance method is selected with a 50% probability during the training process. For specific code implementation, please refer to:
|
||||
[randaugment.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/cls/randaugment.py)
|
||||
[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py)
|
||||
|
||||
|
||||
- Training
|
||||
|
||||
PaddleOCR supports alternating training and evaluation. You can modify `eval_batch_step` in `configs/cls/cls_mv3.yml` to set the evaluation frequency. By default, it is evaluated every 500 iter and the best acc model is saved under `output/cls_mv3/best_accuracy` during the evaluation process.
|
||||
|
||||
If the evaluation set is large, the test will be time-consuming. It is recommended to reduce the number of evaluations, or evaluate after training.
|
||||
|
||||
**Note that the configuration file for prediction/evaluation must be consistent with the training.**
|
||||
|
||||
### EVALUATION
|
||||
|
||||
The evaluation data set can be modified via `configs/cls/cls_reader.yml` setting of `label_file_path` in EvalReader.
|
||||
|
||||
```
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
# GPU evaluation, Global.checkpoints is the weight to be tested
|
||||
python3 tools/eval.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy
|
||||
```
|
||||
|
||||
### PREDICTION
|
||||
|
||||
* Training engine prediction
|
||||
|
||||
Using the model trained by paddleocr, you can quickly get prediction through the following script.
|
||||
|
||||
The default prediction picture is stored in `infer_img`, and the weight is specified via `-o Global.checkpoints`:
|
||||
|
||||
```
|
||||
# Predict English results
|
||||
python3 tools/infer_rec.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy TestReader.infer_img=doc/imgs_words/en/word_1.jpg
|
||||
```
|
||||
|
||||
Input image:
|
||||
|
||||
![](../imgs_words/en/word_1.png)
|
||||
|
||||
Get the prediction result of the input image:
|
||||
|
||||
```
|
||||
infer_img: doc/imgs_words/en/word_1.png
|
||||
scores: [[0.93161047 0.06838956]]
|
||||
label: [0]
|
||||
```
|
|
@ -6,7 +6,7 @@ The process of making a customized ultra-lightweight OCR models can be divided i
|
|||
|
||||
PaddleOCR provides two text detection algorithms: EAST and DB. Both support MobileNetV3 and ResNet50_vd backbone networks, select the corresponding configuration file as needed and start training. For example, to train with MobileNetV3 as the backbone network for DB detection model :
|
||||
```
|
||||
python3 tools/train.py -c configs/det/det_mv3_db.yml
|
||||
python3 tools/train.py -c configs/det/det_mv3_db.yml 2>&1 | tee det_db.log
|
||||
```
|
||||
For more details about data preparation and training tutorials, refer to the documentation [Text detection model training/evaluation/prediction](./detection_en.md)
|
||||
|
||||
|
@ -14,7 +14,7 @@ For more details about data preparation and training tutorials, refer to the doc
|
|||
|
||||
PaddleOCR provides four text recognition algorithms: CRNN, Rosetta, STAR-Net, and RARE. They all support two backbone networks: MobileNetV3 and ResNet34_vd, select the corresponding configuration files as needed to start training. For example, to train a CRNN recognition model that uses MobileNetV3 as the backbone network:
|
||||
```
|
||||
python3 tools/train.py -c configs/rec/rec_chinese_lite_train.yml
|
||||
python3 tools/train.py -c configs/rec/rec_chinese_lite_train.yml 2>&1 | tee rec_ch_lite.log
|
||||
```
|
||||
For more details about data preparation and training tutorials, refer to the documentation [Text recognition model training/evaluation/prediction](./recognition_en.md)
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ tar -xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_model
|
|||
#### START TRAINING
|
||||
*If CPU version installed, please set the parameter `use_gpu` to `false` in the configuration.*
|
||||
```shell
|
||||
python3 tools/train.py -c configs/det/det_mv3_db.yml
|
||||
python3 tools/train.py -c configs/det/det_mv3_db.yml 2>&1 | tee train_det.log
|
||||
```
|
||||
|
||||
In the above instruction, use `-c` to select the training to use the `configs/det/det_db_mv3.yml` configuration file.
|
||||
|
|
|
@ -12,25 +12,28 @@ Next, we first introduce how to convert a trained model into an inference model,
|
|||
- [CONVERT TRAINING MODEL TO INFERENCE MODEL](#CONVERT)
|
||||
- [Convert detection model to inference model](#Convert_detection_model)
|
||||
- [Convert recognition model to inference model](#Convert_recognition_model)
|
||||
|
||||
|
||||
- [Convert angle classification model to inference model](#Convert_angle_class_model)
|
||||
|
||||
|
||||
- [TEXT DETECTION MODEL INFERENCE](#DETECTION_MODEL_INFERENCE)
|
||||
- [1. LIGHTWEIGHT CHINESE DETECTION MODEL INFERENCE](#LIGHTWEIGHT_DETECTION)
|
||||
- [2. DB TEXT DETECTION MODEL INFERENCE](#DB_DETECTION)
|
||||
- [3. EAST TEXT DETECTION MODEL INFERENCE](#EAST_DETECTION)
|
||||
- [4. SAST TEXT DETECTION MODEL INFERENCE](#SAST_DETECTION)
|
||||
|
||||
|
||||
- [TEXT RECOGNITION MODEL INFERENCE](#RECOGNITION_MODEL_INFERENCE)
|
||||
- [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_RECOGNITION)
|
||||
- [2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE](#CTC-BASED_RECOGNITION)
|
||||
- [3. ATTENTION-BASED TEXT RECOGNITION MODEL INFERENCE](#ATTENTION-BASED_RECOGNITION)
|
||||
- [4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
|
||||
|
||||
|
||||
- [TEXT DETECTION AND RECOGNITION INFERENCE CONCATENATION](#CONCATENATION)
|
||||
|
||||
- [ANGLE CLASSIFICATION MODEL INFERENCE](#ANGLE_CLASS_MODEL_INFERENCE)
|
||||
- [1. ANGLE CLASSIFICATION MODEL INFERENCE](#ANGLE_CLASS_MODEL_INFERENCE)
|
||||
|
||||
- [TEXT DETECTION ANGLE CLASSIFICATION AND RECOGNITION INFERENCE CONCATENATION](#CONCATENATION)
|
||||
- [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_CHINESE_MODEL)
|
||||
- [2. OTHER MODELS](#OTHER_MODELS)
|
||||
|
||||
|
||||
<a name="CONVERT"></a>
|
||||
## CONVERT TRAINING MODEL TO INFERENCE MODEL
|
||||
<a name="Convert_detection_model"></a>
|
||||
|
@ -87,6 +90,33 @@ After the conversion is successful, there are two files in the directory:
|
|||
└─ params Identify the parameter files of the inference model
|
||||
```
|
||||
|
||||
<a name="Convert_angle_class_model"></a>
|
||||
### Convert angle classification model to inference model
|
||||
|
||||
Download the angle classification model:
|
||||
```
|
||||
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile-v1.1.cls_pre.tar && tar xf ./ch_lite/ch_ppocr_mobile-v1.1.cls_pre.tar -C ./ch_lite/
|
||||
```
|
||||
|
||||
The angle classification model is converted to the inference model in the same way as the detection, as follows:
|
||||
```
|
||||
# -c Set the training algorithm yml configuration file
|
||||
# -o Set optional parameters
|
||||
# Global.checkpoints parameter Set the training model address to be converted without adding the file suffix .pdmodel, .pdopt or .pdparams.
|
||||
# Global.save_inference_dir Set the address where the converted model will be saved.
|
||||
|
||||
python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.checkpoints=./ch_lite/cls_model/best_accuracy \
|
||||
Global.save_inference_dir=./inference/cls/
|
||||
```
|
||||
|
||||
After the conversion is successful, there are two files in the directory:
|
||||
```
|
||||
/inference/cls/
|
||||
└─ model Identify the saved model files
|
||||
└─ params Identify the parameter files of the inference model
|
||||
```
|
||||
|
||||
|
||||
<a name="DETECTION_MODEL_INFERENCE"></a>
|
||||
## TEXT DETECTION MODEL INFERENCE
|
||||
|
||||
|
@ -276,16 +306,39 @@ If the chars dictionary is modified during training, you need to specify the new
|
|||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_char_dict_path="your text dict path"
|
||||
```
|
||||
|
||||
|
||||
<a name="ANGLE_CLASSIFICATION_MODEL_INFERENCE"></a>
|
||||
## ANGLE CLASSIFICATION MODEL INFERENCE
|
||||
|
||||
The following will introduce the angle classification model inference.
|
||||
|
||||
|
||||
<a name="ANGLE_CLASS_MODEL_INFERENCE"></a>
|
||||
### 1.ANGLE CLASSIFICATION MODEL INFERENCE
|
||||
|
||||
For angle classification model inference, you can execute the following commands:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_cls.py --image_dir="./doc/imgs_words/ch/word_4.jpg" --cls_model_dir="./inference/cls/"
|
||||
```
|
||||
|
||||
![](../imgs_words/ch/word_4.jpg)
|
||||
|
||||
After executing the command, the prediction results (classification angle and score) of the above image will be printed on the screen.
|
||||
|
||||
Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999963]
|
||||
|
||||
|
||||
<a name="CONCATENATION"></a>
|
||||
## TEXT DETECTION AND RECOGNITION INFERENCE CONCATENATION
|
||||
## TEXT DETECTION ANGLE CLASSIFICATION AND RECOGNITION INFERENCE CONCATENATION
|
||||
|
||||
<a name="LIGHTWEIGHT_CHINESE_MODEL"></a>
|
||||
### 1. LIGHTWEIGHT CHINESE MODEL
|
||||
|
||||
When performing prediction, you need to specify the path of a single image or a folder of images through the parameter `image_dir`, the parameter `det_model_dir` specifies the path to detect the inference model, and the parameter `rec_model_dir` specifies the path to identify the inference model. The visualized recognition results are saved to the `./inference_results` folder by default.
|
||||
When performing prediction, you need to specify the path of a single image or a folder of images through the parameter `image_dir`, the parameter `det_model_dir` specifies the path to detect the inference model, the parameter `cls_model_dir` specifies the path to angle classification inference model and the parameter `rec_model_dir` specifies the path to identify the inference model. The parameter `use_angle_cls` is used to control whether to enable the angle classification model.The visualized recognition results are saved to the `./inference_results` folder by default.
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --rec_model_dir="./inference/rec_crnn/"
|
||||
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --cls_model_dir="./inference/cls/" --rec_model_dir="./inference/rec_crnn/" --use_angle_cls true
|
||||
```
|
||||
|
||||
After executing the command, the recognition result image is as follows:
|
||||
|
|
|
@ -130,8 +130,8 @@ Start training:
|
|||
export PYTHONPATH=$PYTHONPATH:.
|
||||
# GPU training Support single card and multi-card training, specify the card number through CUDA_VISIBLE_DEVICES
|
||||
export CUDA_VISIBLE_DEVICES=0,1,2,3
|
||||
# Training icdar15 English data
|
||||
python3 tools/train.py -c configs/rec/rec_icdar15_train.yml
|
||||
# Training icdar15 English data and saving the log as train_rec.log
|
||||
python3 tools/train.py -c configs/rec/rec_icdar15_train.yml 2>&1 | tee train_rec.log
|
||||
```
|
||||
|
||||
- Data Augmentation
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
|
@ -0,0 +1,144 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import random
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from ppocr.utils.utility import initial_logger
|
||||
from ppocr.utils.utility import get_image_file_list
|
||||
|
||||
logger = initial_logger()
|
||||
|
||||
from ppocr.data.rec.img_tools import resize_norm_img, warp
|
||||
from ppocr.data.cls.randaugment import RandAugment
|
||||
|
||||
|
||||
def random_crop(img):
|
||||
img_h, img_w = img.shape[:2]
|
||||
if img_w > img_h * 4:
|
||||
w = random.randint(img_h * 2, img_w)
|
||||
i = random.randint(0, img_w - w)
|
||||
|
||||
img = img[:, i:i + w, :]
|
||||
return img
|
||||
|
||||
|
||||
class SimpleReader(object):
|
||||
def __init__(self, params):
|
||||
if params['mode'] != 'train':
|
||||
self.num_workers = 1
|
||||
else:
|
||||
self.num_workers = params['num_workers']
|
||||
if params['mode'] != 'test':
|
||||
self.img_set_dir = params['img_set_dir']
|
||||
self.label_file_path = params['label_file_path']
|
||||
self.use_gpu = params['use_gpu']
|
||||
self.image_shape = params['image_shape']
|
||||
self.mode = params['mode']
|
||||
self.infer_img = params['infer_img']
|
||||
self.use_distort = params['mode'] == 'train' and params['distort']
|
||||
self.randaug = RandAugment()
|
||||
self.label_list = params['label_list']
|
||||
if "distort" in params:
|
||||
self.use_distort = params['distort'] and params['use_gpu']
|
||||
if not params['use_gpu']:
|
||||
logger.info(
|
||||
"Distort operation can only support in GPU.Distort will be set to False."
|
||||
)
|
||||
if params['mode'] == 'train':
|
||||
self.batch_size = params['train_batch_size_per_card']
|
||||
self.drop_last = True
|
||||
else:
|
||||
self.batch_size = params['test_batch_size_per_card']
|
||||
self.drop_last = False
|
||||
self.use_distort = False
|
||||
|
||||
def __call__(self, process_id):
|
||||
if self.mode != 'train':
|
||||
process_id = 0
|
||||
|
||||
def get_device_num():
|
||||
if self.use_gpu:
|
||||
gpus = os.environ.get("CUDA_VISIBLE_DEVICES", 1)
|
||||
gpu_num = len(gpus.split(','))
|
||||
return gpu_num
|
||||
else:
|
||||
cpu_num = os.environ.get("CPU_NUM", 1)
|
||||
return int(cpu_num)
|
||||
|
||||
def sample_iter_reader():
|
||||
if self.mode != 'train' and self.infer_img is not None:
|
||||
image_file_list = get_image_file_list(self.infer_img)
|
||||
for single_img in image_file_list:
|
||||
img = cv2.imread(single_img)
|
||||
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
norm_img = resize_norm_img(img, self.image_shape)
|
||||
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
yield norm_img
|
||||
else:
|
||||
with open(self.label_file_path, "rb") as fin:
|
||||
label_infor_list = fin.readlines()
|
||||
img_num = len(label_infor_list)
|
||||
img_id_list = list(range(img_num))
|
||||
random.shuffle(img_id_list)
|
||||
if sys.platform == "win32" and self.num_workers != 1:
|
||||
print("multiprocess is not fully compatible with Windows."
|
||||
"num_workers will be 1.")
|
||||
self.num_workers = 1
|
||||
if self.batch_size * get_device_num(
|
||||
) * self.num_workers > img_num:
|
||||
raise Exception(
|
||||
"The number of the whole data ({}) is smaller than the batch_size * devices_num * num_workers ({})".
|
||||
format(img_num, self.batch_size * get_device_num() *
|
||||
self.num_workers))
|
||||
for img_id in range(process_id, img_num, self.num_workers):
|
||||
label_infor = label_infor_list[img_id_list[img_id]]
|
||||
substr = label_infor.decode('utf-8').strip("\n").split("\t")
|
||||
label = self.label_list.index(substr[1])
|
||||
|
||||
img_path = self.img_set_dir + "/" + substr[0]
|
||||
img = cv2.imread(img_path)
|
||||
if img is None:
|
||||
logger.info("{} does not exist!".format(img_path))
|
||||
continue
|
||||
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
|
||||
if self.use_distort:
|
||||
img = warp(img, 10)
|
||||
img = self.randaug(img)
|
||||
norm_img = resize_norm_img(img, self.image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
yield (norm_img, label)
|
||||
|
||||
def batch_iter_reader():
|
||||
batch_outs = []
|
||||
for outs in sample_iter_reader():
|
||||
batch_outs.append(outs)
|
||||
if len(batch_outs) == self.batch_size:
|
||||
yield batch_outs
|
||||
batch_outs = []
|
||||
if not self.drop_last:
|
||||
if len(batch_outs) != 0:
|
||||
yield batch_outs
|
||||
|
||||
if self.infer_img is None:
|
||||
return batch_iter_reader
|
||||
return sample_iter_reader
|
|
@ -0,0 +1,135 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from PIL import Image, ImageEnhance, ImageOps
|
||||
import numpy as np
|
||||
import random
|
||||
import six
|
||||
|
||||
|
||||
class RawRandAugment(object):
|
||||
def __init__(self, num_layers=2, magnitude=5, fillcolor=(128, 128, 128)):
|
||||
self.num_layers = num_layers
|
||||
self.magnitude = magnitude
|
||||
self.max_level = 10
|
||||
|
||||
abso_level = self.magnitude / self.max_level
|
||||
self.level_map = {
|
||||
"shearX": 0.3 * abso_level,
|
||||
"shearY": 0.3 * abso_level,
|
||||
"translateX": 150.0 / 331 * abso_level,
|
||||
"translateY": 150.0 / 331 * abso_level,
|
||||
"rotate": 30 * abso_level,
|
||||
"color": 0.9 * abso_level,
|
||||
"posterize": int(4.0 * abso_level),
|
||||
"solarize": 256.0 * abso_level,
|
||||
"contrast": 0.9 * abso_level,
|
||||
"sharpness": 0.9 * abso_level,
|
||||
"brightness": 0.9 * abso_level,
|
||||
"autocontrast": 0,
|
||||
"equalize": 0,
|
||||
"invert": 0
|
||||
}
|
||||
|
||||
# from https://stackoverflow.com/questions/5252170/
|
||||
# specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
|
||||
def rotate_with_fill(img, magnitude):
|
||||
rot = img.convert("RGBA").rotate(magnitude)
|
||||
return Image.composite(rot,
|
||||
Image.new("RGBA", rot.size, (128, ) * 4),
|
||||
rot).convert(img.mode)
|
||||
|
||||
rnd_ch_op = random.choice
|
||||
|
||||
self.func = {
|
||||
"shearX": lambda img, magnitude: img.transform(
|
||||
img.size,
|
||||
Image.AFFINE,
|
||||
(1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0),
|
||||
Image.BICUBIC,
|
||||
fillcolor=fillcolor),
|
||||
"shearY": lambda img, magnitude: img.transform(
|
||||
img.size,
|
||||
Image.AFFINE,
|
||||
(1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0),
|
||||
Image.BICUBIC,
|
||||
fillcolor=fillcolor),
|
||||
"translateX": lambda img, magnitude: img.transform(
|
||||
img.size,
|
||||
Image.AFFINE,
|
||||
(1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0),
|
||||
fillcolor=fillcolor),
|
||||
"translateY": lambda img, magnitude: img.transform(
|
||||
img.size,
|
||||
Image.AFFINE,
|
||||
(1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])),
|
||||
fillcolor=fillcolor),
|
||||
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
|
||||
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
|
||||
1 + magnitude * rnd_ch_op([-1, 1])),
|
||||
"posterize": lambda img, magnitude:
|
||||
ImageOps.posterize(img, magnitude),
|
||||
"solarize": lambda img, magnitude:
|
||||
ImageOps.solarize(img, magnitude),
|
||||
"contrast": lambda img, magnitude:
|
||||
ImageEnhance.Contrast(img).enhance(
|
||||
1 + magnitude * rnd_ch_op([-1, 1])),
|
||||
"sharpness": lambda img, magnitude:
|
||||
ImageEnhance.Sharpness(img).enhance(
|
||||
1 + magnitude * rnd_ch_op([-1, 1])),
|
||||
"brightness": lambda img, magnitude:
|
||||
ImageEnhance.Brightness(img).enhance(
|
||||
1 + magnitude * rnd_ch_op([-1, 1])),
|
||||
"autocontrast": lambda img, magnitude:
|
||||
ImageOps.autocontrast(img),
|
||||
"equalize": lambda img, magnitude: ImageOps.equalize(img),
|
||||
"invert": lambda img, magnitude: ImageOps.invert(img)
|
||||
}
|
||||
|
||||
def __call__(self, img):
|
||||
avaiable_op_names = list(self.level_map.keys())
|
||||
for layer_num in range(self.num_layers):
|
||||
op_name = np.random.choice(avaiable_op_names)
|
||||
img = self.func[op_name](img, self.level_map[op_name])
|
||||
return img
|
||||
|
||||
|
||||
class RandAugment(RawRandAugment):
|
||||
""" RandAugment wrapper to auto fit different img types """
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if six.PY2:
|
||||
super(RandAugment, self).__init__(*args, **kwargs)
|
||||
else:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, img):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = np.ascontiguousarray(img)
|
||||
img = Image.fromarray(img)
|
||||
|
||||
if six.PY2:
|
||||
img = super(RandAugment, self).__call__(img)
|
||||
else:
|
||||
img = super().__call__(img)
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.asarray(img)
|
||||
|
||||
return img
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle import fluid
|
||||
|
||||
from ppocr.utils.utility import create_module
|
||||
from ppocr.utils.utility import initial_logger
|
||||
|
||||
logger = initial_logger()
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class ClsModel(object):
|
||||
def __init__(self, params):
|
||||
super(ClsModel, self).__init__()
|
||||
global_params = params['Global']
|
||||
self.infer_img = global_params['infer_img']
|
||||
|
||||
backbone_params = deepcopy(params["Backbone"])
|
||||
backbone_params.update(global_params)
|
||||
self.backbone = create_module(backbone_params['function']) \
|
||||
(params=backbone_params)
|
||||
|
||||
head_params = deepcopy(params["Head"])
|
||||
head_params.update(global_params)
|
||||
self.head = create_module(head_params['function']) \
|
||||
(params=head_params)
|
||||
|
||||
loss_params = deepcopy(params["Loss"])
|
||||
loss_params.update(global_params)
|
||||
self.loss = create_module(loss_params['function']) \
|
||||
(params=loss_params)
|
||||
|
||||
self.image_shape = global_params['image_shape']
|
||||
|
||||
def create_feed(self, mode):
|
||||
image_shape = deepcopy(self.image_shape)
|
||||
image_shape.insert(0, -1)
|
||||
if mode == "train":
|
||||
image = fluid.data(name='image', shape=image_shape, dtype='float32')
|
||||
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
|
||||
feed_list = [image, label]
|
||||
labels = {'label': label}
|
||||
loader = fluid.io.DataLoader.from_generator(
|
||||
feed_list=feed_list,
|
||||
capacity=64,
|
||||
use_double_buffer=True,
|
||||
iterable=False)
|
||||
else:
|
||||
labels = None
|
||||
loader = None
|
||||
image = fluid.data(name='image', shape=image_shape, dtype='float32')
|
||||
return image, labels, loader
|
||||
|
||||
def __call__(self, mode):
|
||||
image, labels, loader = self.create_feed(mode)
|
||||
inputs = image
|
||||
conv_feas = self.backbone(inputs)
|
||||
predicts = self.head(conv_feas, labels, mode)
|
||||
if mode == "train":
|
||||
loss = self.loss(predicts, labels)
|
||||
label = labels['label']
|
||||
acc = fluid.layers.accuracy(predicts['predict'], label, k=1)
|
||||
outputs = {'total_loss': loss, 'decoded_out': \
|
||||
predicts['decoded_out'], 'label': label, 'acc': acc}
|
||||
return loader, outputs
|
||||
elif mode == "export":
|
||||
return [image, predicts]
|
||||
else:
|
||||
return loader, predicts
|
|
@ -0,0 +1,46 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
|
||||
|
||||
class ClsHead(object):
|
||||
def __init__(self, params):
|
||||
super(ClsHead, self).__init__()
|
||||
self.class_dim = params['class_dim']
|
||||
|
||||
def __call__(self, inputs, labels=None, mode=None):
|
||||
pool = fluid.layers.pool2d(
|
||||
input=inputs, pool_type='avg', global_pooling=True)
|
||||
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
|
||||
|
||||
out = fluid.layers.fc(
|
||||
input=pool,
|
||||
size=self.class_dim,
|
||||
param_attr=fluid.param_attr.ParamAttr(
|
||||
name="fc_0.w_0",
|
||||
initializer=fluid.initializer.Uniform(-stdv, stdv)),
|
||||
bias_attr=fluid.param_attr.ParamAttr(name="fc_0.b_0"))
|
||||
|
||||
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
|
||||
out_label = fluid.layers.argmax(out, axis=1)
|
||||
predicts = {'predict': softmax_out, 'decoded_out': out_label}
|
||||
return predicts
|
|
@ -0,0 +1,33 @@
|
|||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle.fluid as fluid
|
||||
|
||||
|
||||
class ClsLoss(object):
|
||||
def __init__(self, params):
|
||||
super(ClsLoss, self).__init__()
|
||||
self.loss_func = fluid.layers.cross_entropy
|
||||
|
||||
def __call__(self, predicts, labels):
|
||||
predict = predicts['predict']
|
||||
label = labels['label']
|
||||
# softmax_out = fluid.layers.softmax(predict, use_cudnn=False)
|
||||
cost = fluid.layers.cross_entropy(input=predict, label=label)
|
||||
sum_cost = fluid.layers.mean(cost)
|
||||
return sum_cost
|
|
@ -45,10 +45,12 @@ from ppocr.utils.save_load import init_model
|
|||
from eval_utils.eval_det_utils import eval_det_run
|
||||
from eval_utils.eval_rec_utils import test_rec_benchmark
|
||||
from eval_utils.eval_rec_utils import eval_rec_run
|
||||
from eval_utils.eval_cls_utils import eval_cls_run
|
||||
|
||||
|
||||
def main():
|
||||
startup_prog, eval_program, place, config, train_alg_type = program.preprocess()
|
||||
startup_prog, eval_program, place, config, train_alg_type = program.preprocess(
|
||||
)
|
||||
eval_build_outputs = program.build(
|
||||
config, eval_program, startup_prog, mode='test')
|
||||
eval_fetch_name_list = eval_build_outputs[1]
|
||||
|
@ -67,6 +69,14 @@ def main():
|
|||
'fetch_varname_list':eval_fetch_varname_list}
|
||||
metrics = eval_det_run(exe, config, eval_info_dict, "eval")
|
||||
logger.info("Eval result: {}".format(metrics))
|
||||
elif train_alg_type == 'cls':
|
||||
eval_reader = reader_main(config=config, mode="eval")
|
||||
eval_info_dict = {'program': eval_program, \
|
||||
'reader': eval_reader, \
|
||||
'fetch_name_list': eval_fetch_name_list, \
|
||||
'fetch_varname_list': eval_fetch_varname_list}
|
||||
metrics = eval_cls_run(exe, eval_info_dict)
|
||||
logger.info("Eval result: {}".format(metrics))
|
||||
else:
|
||||
reader_type = config['Global']['reader_yml']
|
||||
if "benchmark" not in reader_type:
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['eval_cls_run']
|
||||
|
||||
import logging
|
||||
|
||||
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
|
||||
logging.basicConfig(level=logging.INFO, format=FORMAT)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def eval_cls_run(exe, eval_info_dict):
|
||||
"""
|
||||
Run evaluation program, return program outputs.
|
||||
"""
|
||||
total_sample_num = 0
|
||||
total_acc_num = 0
|
||||
total_batch_num = 0
|
||||
|
||||
for data in eval_info_dict['reader']():
|
||||
img_num = len(data)
|
||||
img_list = []
|
||||
label_list = []
|
||||
for ino in range(img_num):
|
||||
img_list.append(data[ino][0])
|
||||
label_list.append(data[ino][1])
|
||||
|
||||
img_list = np.concatenate(img_list, axis=0)
|
||||
outs = exe.run(eval_info_dict['program'], \
|
||||
feed={'image': img_list}, \
|
||||
fetch_list=eval_info_dict['fetch_varname_list'], \
|
||||
return_numpy=False)
|
||||
softmax_outs = np.array(outs[1])
|
||||
if len(softmax_outs.shape) != 1:
|
||||
softmax_outs = np.array(outs[0])
|
||||
acc, acc_num = cal_cls_acc(softmax_outs, label_list)
|
||||
total_acc_num += acc_num
|
||||
total_sample_num += len(label_list)
|
||||
# logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
|
||||
total_batch_num += 1
|
||||
avg_acc = total_acc_num * 1.0 / total_sample_num
|
||||
metrics = {'avg_acc': avg_acc, "total_acc_num": total_acc_num, \
|
||||
"total_sample_num": total_sample_num}
|
||||
return metrics
|
||||
|
||||
|
||||
def cal_cls_acc(preds, labels):
|
||||
acc_num = 0
|
||||
for pred, label in zip(preds, labels):
|
||||
if pred == label:
|
||||
acc_num += 1
|
||||
return acc_num / len(preds), acc_num
|
|
@ -0,0 +1,145 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
import tools.infer.utility as utility
|
||||
from ppocr.utils.utility import initial_logger
|
||||
|
||||
logger = initial_logger()
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
import cv2
|
||||
import copy
|
||||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
from paddle import fluid
|
||||
|
||||
|
||||
class TextClassifier(object):
|
||||
def __init__(self, args):
|
||||
self.predictor, self.input_tensor, self.output_tensors = \
|
||||
utility.create_predictor(args, mode="cls")
|
||||
self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
|
||||
self.cls_batch_num = args.rec_batch_num
|
||||
self.label_list = args.label_list
|
||||
self.use_zero_copy_run = args.use_zero_copy_run
|
||||
|
||||
def resize_norm_img(self, img):
|
||||
imgC, imgH, imgW = self.cls_image_shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
if self.cls_image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
resized_image = resized_image[np.newaxis, :]
|
||||
else:
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
||||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
return padding_im
|
||||
|
||||
def __call__(self, img_list):
|
||||
img_list = copy.deepcopy(img_list)
|
||||
img_num = len(img_list)
|
||||
# Calculate the aspect ratio of all text bars
|
||||
width_list = []
|
||||
for img in img_list:
|
||||
width_list.append(img.shape[1] / float(img.shape[0]))
|
||||
# Sorting can speed up the cls process
|
||||
indices = np.argsort(np.array(width_list))
|
||||
|
||||
cls_res = [['', 0.0]] * img_num
|
||||
batch_num = self.cls_batch_num
|
||||
predict_time = 0
|
||||
for beg_img_no in range(0, img_num, batch_num):
|
||||
end_img_no = min(img_num, beg_img_no + batch_num)
|
||||
norm_img_batch = []
|
||||
max_wh_ratio = 0
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
h, w = img_list[indices[ino]].shape[0:2]
|
||||
wh_ratio = w * 1.0 / h
|
||||
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
norm_img = self.resize_norm_img(img_list[indices[ino]])
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
norm_img_batch = np.concatenate(norm_img_batch)
|
||||
norm_img_batch = norm_img_batch.copy()
|
||||
starttime = time.time()
|
||||
|
||||
if self.use_zero_copy_run:
|
||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||
self.predictor.zero_copy_run()
|
||||
else:
|
||||
norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
|
||||
self.predictor.run([norm_img_batch])
|
||||
|
||||
prob_out = self.output_tensors[0].copy_to_cpu()
|
||||
label_out = self.output_tensors[1].copy_to_cpu()
|
||||
if len(label_out.shape) != 1:
|
||||
prob_out, label_out = label_out, prob_out
|
||||
|
||||
elapse = time.time() - starttime
|
||||
predict_time += elapse
|
||||
for rno in range(len(label_out)):
|
||||
label_idx = label_out[rno]
|
||||
score = prob_out[rno][label_idx]
|
||||
label = self.label_list[label_idx]
|
||||
cls_res[indices[beg_img_no + rno]] = [label, score]
|
||||
if '180' in label and score > 0.9999:
|
||||
img_list[indices[beg_img_no + rno]] = cv2.rotate(
|
||||
img_list[indices[beg_img_no + rno]], 1)
|
||||
return img_list, cls_res, predict_time
|
||||
|
||||
|
||||
def main(args):
|
||||
image_file_list = get_image_file_list(args.image_dir)
|
||||
text_classifier = TextClassifier(args)
|
||||
valid_image_file_list = []
|
||||
img_list = []
|
||||
for image_file in image_file_list[:10]:
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
if not flag:
|
||||
img = cv2.imread(image_file)
|
||||
if img is None:
|
||||
logger.info("error in loading image:{}".format(image_file))
|
||||
continue
|
||||
valid_image_file_list.append(image_file)
|
||||
img_list.append(img)
|
||||
try:
|
||||
img_list, cls_res, predict_time = text_classifier(img_list)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
exit()
|
||||
for ino in range(len(img_list)):
|
||||
print("Predicts of %s:%s" % (valid_image_file_list[ino], cls_res[ino]))
|
||||
print("Total predict time for %d images:%.3f" %
|
||||
(len(img_list), predict_time))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(utility.parse_args())
|
|
@ -13,16 +13,19 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
import tools.infer.utility as utility
|
||||
from ppocr.utils.utility import initial_logger
|
||||
|
||||
logger = initial_logger()
|
||||
import cv2
|
||||
import tools.infer.predict_det as predict_det
|
||||
import tools.infer.predict_rec as predict_rec
|
||||
import tools.infer.predict_cls as predict_cls
|
||||
import copy
|
||||
import numpy as np
|
||||
import math
|
||||
|
@ -37,6 +40,9 @@ class TextSystem(object):
|
|||
def __init__(self, args):
|
||||
self.text_detector = predict_det.TextDetector(args)
|
||||
self.text_recognizer = predict_rec.TextRecognizer(args)
|
||||
self.use_angle_cls = args.use_angle_cls
|
||||
if self.use_angle_cls:
|
||||
self.text_classifier = predict_cls.TextClassifier(args)
|
||||
|
||||
def get_rotate_crop_image(self, img, points):
|
||||
'''
|
||||
|
@ -91,6 +97,11 @@ class TextSystem(object):
|
|||
tmp_box = copy.deepcopy(dt_boxes[bno])
|
||||
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
|
||||
img_crop_list.append(img_crop)
|
||||
if self.use_angle_cls:
|
||||
img_crop_list, angle_list, elapse = self.text_classifier(
|
||||
img_crop_list)
|
||||
print("cls num : {}, elapse : {}".format(
|
||||
len(img_crop_list), elapse))
|
||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||||
print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse))
|
||||
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
|
||||
|
@ -110,8 +121,8 @@ def sorted_boxes(dt_boxes):
|
|||
_boxes = list(sorted_boxes)
|
||||
|
||||
for i in range(num_boxes - 1):
|
||||
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
|
||||
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
|
||||
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
|
||||
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
|
||||
tmp = _boxes[i]
|
||||
_boxes[i] = _boxes[i + 1]
|
||||
_boxes[i + 1] = tmp
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
import argparse
|
||||
import os, sys
|
||||
from ppocr.utils.utility import initial_logger
|
||||
|
||||
logger = initial_logger()
|
||||
from paddle.fluid.core import PaddleTensor
|
||||
from paddle.fluid.core import AnalysisConfig
|
||||
|
@ -31,34 +32,34 @@ def parse_args():
|
|||
return v.lower() in ("true", "t", "1")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
#params for prediction engine
|
||||
# params for prediction engine
|
||||
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
||||
parser.add_argument("--gpu_mem", type=int, default=8000)
|
||||
|
||||
#params for text detector
|
||||
# params for text detector
|
||||
parser.add_argument("--image_dir", type=str)
|
||||
parser.add_argument("--det_algorithm", type=str, default='DB')
|
||||
parser.add_argument("--det_model_dir", type=str)
|
||||
parser.add_argument("--det_max_side_len", type=float, default=960)
|
||||
|
||||
#DB parmas
|
||||
# DB parmas
|
||||
parser.add_argument("--det_db_thresh", type=float, default=0.3)
|
||||
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
|
||||
parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0)
|
||||
|
||||
#EAST parmas
|
||||
# EAST parmas
|
||||
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
|
||||
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
|
||||
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
|
||||
|
||||
#SAST parmas
|
||||
# SAST parmas
|
||||
parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
|
||||
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
|
||||
parser.add_argument("--det_sast_polygon", type=bool, default=False)
|
||||
|
||||
#params for text recognizer
|
||||
# params for text recognizer
|
||||
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
|
||||
parser.add_argument("--rec_model_dir", type=str)
|
||||
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
|
||||
|
@ -70,14 +71,24 @@ def parse_args():
|
|||
type=str,
|
||||
default="./ppocr/utils/ppocr_keys_v1.txt")
|
||||
parser.add_argument("--use_space_char", type=bool, default=True)
|
||||
parser.add_argument("--enable_mkldnn", type=bool, default=False)
|
||||
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
|
||||
|
||||
# params for text classifier
|
||||
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
|
||||
parser.add_argument("--cls_model_dir", type=str)
|
||||
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
|
||||
parser.add_argument("--label_list", type=list, default=['0', '180'])
|
||||
parser.add_argument("--cls_batch_num", type=int, default=30)
|
||||
|
||||
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
|
||||
parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def create_predictor(args, mode):
|
||||
if mode == "det":
|
||||
model_dir = args.det_model_dir
|
||||
elif mode == 'cls':
|
||||
model_dir = args.cls_model_dir
|
||||
else:
|
||||
model_dir = args.rec_model_dir
|
||||
|
||||
|
@ -105,7 +116,7 @@ def create_predictor(args, mode):
|
|||
config.set_mkldnn_cache_capacity(10)
|
||||
config.enable_mkldnn()
|
||||
|
||||
#config.enable_memory_optim()
|
||||
# config.enable_memory_optim()
|
||||
config.disable_glog_info()
|
||||
|
||||
if args.use_zero_copy_run:
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
|
||||
|
||||
def set_paddle_flags(**kwargs):
|
||||
for key, value in kwargs.items():
|
||||
if os.environ.get(key, None) is None:
|
||||
os.environ[key] = str(value)
|
||||
|
||||
|
||||
# NOTE(paddle-dev): All of these flags should be
|
||||
# set before `import paddle`. Otherwise, it would
|
||||
# not take any effect.
|
||||
set_paddle_flags(
|
||||
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
|
||||
)
|
||||
|
||||
import tools.program as program
|
||||
from paddle import fluid
|
||||
from ppocr.utils.utility import initial_logger
|
||||
|
||||
logger = initial_logger()
|
||||
from ppocr.data.reader_main import reader_main
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.utility import create_module
|
||||
from ppocr.utils.utility import get_image_file_list
|
||||
|
||||
|
||||
def main():
|
||||
config = program.load_config(FLAGS.config)
|
||||
program.merge_config(FLAGS.opt)
|
||||
logger.info(config)
|
||||
|
||||
# check if set use_gpu=True in paddlepaddle cpu version
|
||||
use_gpu = config['Global']['use_gpu']
|
||||
# check_gpu(use_gpu)
|
||||
|
||||
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
|
||||
exe = fluid.Executor(place)
|
||||
|
||||
rec_model = create_module(config['Architecture']['function'])(params=config)
|
||||
startup_prog = fluid.Program()
|
||||
eval_prog = fluid.Program()
|
||||
with fluid.program_guard(eval_prog, startup_prog):
|
||||
with fluid.unique_name.guard():
|
||||
_, outputs = rec_model(mode="test")
|
||||
fetch_name_list = list(outputs.keys())
|
||||
fetch_varname_list = [outputs[v].name for v in fetch_name_list]
|
||||
eval_prog = eval_prog.clone(for_test=True)
|
||||
exe.run(startup_prog)
|
||||
|
||||
init_model(config, eval_prog, exe)
|
||||
|
||||
blobs = reader_main(config, 'test')()
|
||||
infer_img = config['Global']['infer_img']
|
||||
infer_list = get_image_file_list(infer_img)
|
||||
max_img_num = len(infer_list)
|
||||
if len(infer_list) == 0:
|
||||
logger.info("Can not find img in infer_img dir.")
|
||||
for i in range(max_img_num):
|
||||
logger.info("infer_img:%s" % infer_list[i])
|
||||
img = next(blobs)
|
||||
predict = exe.run(program=eval_prog,
|
||||
feed={"image": img},
|
||||
fetch_list=fetch_varname_list,
|
||||
return_numpy=False)
|
||||
scores = np.array(predict[0])
|
||||
label = np.array(predict[1])
|
||||
if len(label.shape) != 1:
|
||||
label, scores = scores, label
|
||||
logger.info('\t scores: {}'.format(scores))
|
||||
logger.info('\t label: {}'.format(label))
|
||||
# save for inference model
|
||||
target_var = []
|
||||
for key, values in outputs.items():
|
||||
target_var.append(values)
|
||||
|
||||
fluid.io.save_inference_model(
|
||||
"./output",
|
||||
feeded_var_names=['image'],
|
||||
target_vars=target_var,
|
||||
executor=exe,
|
||||
main_program=eval_prog,
|
||||
model_filename="model",
|
||||
params_filename="params")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = program.ArgsParser()
|
||||
FLAGS = parser.parse_args()
|
||||
main()
|
|
@ -30,6 +30,7 @@ import time
|
|||
from ppocr.utils.stats import TrainingStats
|
||||
from eval_utils.eval_det_utils import eval_det_run
|
||||
from eval_utils.eval_rec_utils import eval_rec_run
|
||||
from eval_utils.eval_cls_utils import eval_cls_run
|
||||
from ppocr.utils.save_load import save_model
|
||||
import numpy as np
|
||||
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps
|
||||
|
@ -409,6 +410,87 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
|
|||
return
|
||||
|
||||
|
||||
def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
|
||||
train_batch_id = 0
|
||||
log_smooth_window = config['Global']['log_smooth_window']
|
||||
epoch_num = config['Global']['epoch_num']
|
||||
print_batch_step = config['Global']['print_batch_step']
|
||||
eval_batch_step = config['Global']['eval_batch_step']
|
||||
start_eval_step = 0
|
||||
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
|
||||
start_eval_step = eval_batch_step[0]
|
||||
eval_batch_step = eval_batch_step[1]
|
||||
logger.info(
|
||||
"During the training process, after the {}th iteration, an evaluation is run every {} iterations".
|
||||
format(start_eval_step, eval_batch_step))
|
||||
save_epoch_step = config['Global']['save_epoch_step']
|
||||
save_model_dir = config['Global']['save_model_dir']
|
||||
if not os.path.exists(save_model_dir):
|
||||
os.makedirs(save_model_dir)
|
||||
train_stats = TrainingStats(log_smooth_window, ['loss', 'acc'])
|
||||
best_eval_acc = -1
|
||||
best_batch_id = 0
|
||||
best_epoch = 0
|
||||
train_loader = train_info_dict['reader']
|
||||
for epoch in range(epoch_num):
|
||||
train_loader.start()
|
||||
try:
|
||||
while True:
|
||||
t1 = time.time()
|
||||
train_outs = exe.run(
|
||||
program=train_info_dict['compile_program'],
|
||||
fetch_list=train_info_dict['fetch_varname_list'],
|
||||
return_numpy=False)
|
||||
fetch_map = dict(
|
||||
zip(train_info_dict['fetch_name_list'],
|
||||
range(len(train_outs))))
|
||||
|
||||
loss = np.mean(np.array(train_outs[fetch_map['total_loss']]))
|
||||
lr = np.mean(np.array(train_outs[fetch_map['lr']]))
|
||||
acc = np.mean(np.array(train_outs[fetch_map['acc']]))
|
||||
|
||||
t2 = time.time()
|
||||
train_batch_elapse = t2 - t1
|
||||
stats = {'loss': loss, 'acc': acc}
|
||||
train_stats.update(stats)
|
||||
if train_batch_id > start_eval_step and (train_batch_id - start_eval_step) \
|
||||
% print_batch_step == 0:
|
||||
logs = train_stats.log()
|
||||
strs = 'epoch: {}, iter: {}, lr: {:.6f}, {}, time: {:.3f}'.format(
|
||||
epoch, train_batch_id, lr, logs, train_batch_elapse)
|
||||
logger.info(strs)
|
||||
|
||||
if train_batch_id > 0 and\
|
||||
train_batch_id % eval_batch_step == 0:
|
||||
model_average = train_info_dict['model_average']
|
||||
if model_average != None:
|
||||
model_average.apply(exe)
|
||||
metrics = eval_cls_run(exe, eval_info_dict)
|
||||
eval_acc = metrics['avg_acc']
|
||||
eval_sample_num = metrics['total_sample_num']
|
||||
if eval_acc > best_eval_acc:
|
||||
best_eval_acc = eval_acc
|
||||
best_batch_id = train_batch_id
|
||||
best_epoch = epoch
|
||||
save_path = save_model_dir + "/best_accuracy"
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format(
|
||||
train_batch_id, eval_acc, best_eval_acc, best_epoch,
|
||||
best_batch_id, eval_sample_num)
|
||||
logger.info(strs)
|
||||
train_batch_id += 1
|
||||
|
||||
except fluid.core.EOFException:
|
||||
train_loader.reset()
|
||||
if epoch == 0 and save_epoch_step == 1:
|
||||
save_path = save_model_dir + "/iter_epoch_0"
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
if epoch > 0 and epoch % save_epoch_step == 0:
|
||||
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
|
||||
save_model(train_info_dict['train_program'], save_path)
|
||||
return
|
||||
|
||||
|
||||
def preprocess():
|
||||
FLAGS = ArgsParser().parse_args()
|
||||
config = load_config(FLAGS.config)
|
||||
|
@ -421,7 +503,7 @@ def preprocess():
|
|||
|
||||
alg = config['Global']['algorithm']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN'
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
|
||||
]
|
||||
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
|
||||
config['Global']['char_ops'] = CharacterOps(config['Global'])
|
||||
|
@ -432,7 +514,9 @@ def preprocess():
|
|||
|
||||
if alg in ['EAST', 'DB', 'SAST']:
|
||||
train_alg_type = 'det'
|
||||
else:
|
||||
elif alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
|
||||
train_alg_type = 'rec'
|
||||
else:
|
||||
train_alg_type = 'cls'
|
||||
|
||||
return startup_program, train_program, place, config, train_alg_type
|
||||
|
|
|
@ -75,7 +75,8 @@ def main():
|
|||
|
||||
# dump mode structure
|
||||
if config['Global']['debug']:
|
||||
if train_alg_type == 'rec' and 'attention' in config['Global']['loss_type']:
|
||||
if train_alg_type == 'rec' and 'attention' in config['Global'][
|
||||
'loss_type']:
|
||||
logger.warning('Does not suport dump attention...')
|
||||
else:
|
||||
summary(train_program)
|
||||
|
@ -96,8 +97,10 @@ def main():
|
|||
|
||||
if train_alg_type == 'det':
|
||||
program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict)
|
||||
else:
|
||||
elif train_alg_type == 'rec':
|
||||
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
|
||||
else:
|
||||
program.train_eval_cls_run(config, exe, train_info_dict, eval_info_dict)
|
||||
|
||||
|
||||
def test_reader():
|
||||
|
@ -119,6 +122,7 @@ def test_reader():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
startup_program, train_program, place, config, train_alg_type = program.preprocess()
|
||||
startup_program, train_program, place, config, train_alg_type = program.preprocess(
|
||||
)
|
||||
main()
|
||||
# test_reader()
|
||||
|
|
Loading…
Reference in New Issue