parent
c852b91647
commit
bc563c642c
@@ -0,0 +1,97 @@
Global:
  use_gpu: true
  epoch_num: 100
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/cls/mv3/
  save_epoch_step: 3
  # evaluation is run every 1000 iterations, starting from iteration 0
  eval_batch_step: [0, 1000]
  # if pretrained_model is saved in static mode, load_static_weights must be set to True
  load_static_weights: True
  cal_metric_during_train: True
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words_en/word_10.png
  label_list: ['0','180']

Architecture:
  model_type: cls
  algorithm: CLS
  Transform:
  Backbone:
    name: MobileNetV3
    scale: 0.35
    model_name: small
  Neck:
  Head:
    name: ClsHead
    class_dim: 2

Loss:
  name: ClsLoss

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
  regularizer:
    name: 'L2'
    factor: 0

PostProcess:
  name: ClsPostProcess

Metric:
  name: ClsMetric
  main_indicator: acc

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/cls
    label_file_list:
      - ./train_data/cls/train.txt
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - ClsLabelEncode: # Class handling label
      - RecAug:
          use_tia: False
      - RandAugment:
      - ClsResizeImg:
          image_shape: [3, 48, 192]
      - KeepKeys:
          keep_keys: ['image', 'label'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 512
    drop_last: True
    num_workers: 8

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/cls
    label_file_list:
      - ./train_data/cls/test.txt
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - ClsLabelEncode: # Class handling label
      - ClsResizeImg:
          image_shape: [3, 48, 192]
      - KeepKeys:
          keep_keys: ['image', 'label'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 512
    num_workers: 4

@@ -0,0 +1,96 @@
Global:
  use_gpu: true
  epoch_num: 72
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec/r34_vd_none_bilstm_ctc/
  save_epoch_step: 3
  # evaluation is run every 2000 iterations, starting from iteration 0
  eval_batch_step: [0, 2000]
  # if pretrained_model is saved in static mode, load_static_weights must be set to True
  cal_metric_during_train: True
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words/ch/word_1.jpg
  # for data or label process
  character_dict_path:
  character_type: en
  max_text_length: 25
  infer_mode: False
  use_space_char: False


Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    learning_rate: 0.0005
  regularizer:
    name: 'L2'
    factor: 0

Architecture:
  model_type: rec
  algorithm: CRNN
  Transform:
  Backbone:
    name: ResNet
    layers: 34
  Neck:
    name: SequenceEncoder
    encoder_type: rnn
    hidden_size: 256
  Head:
    name: CTCHead
    fc_decay: 0

Loss:
  name: CTCLoss

PostProcess:
  name: CTCLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: LMDBDateSet
    data_dir: ./train_data/data_lmdb_release/training/
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
    batch_size_per_card: 256
    drop_last: True
    num_workers: 8

Eval:
  dataset:
    name: LMDBDateSet
    data_dir: ./train_data/data_lmdb_release/validation/
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 4

@@ -52,20 +52,29 @@ include_directories(${OpenCV_INCLUDE_DIRS})

if (WIN32)
    add_definitions("/DGOOGLE_GLOG_DLL_DECL=")
    set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd")
    set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT")
    set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd")
    set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT")
    if(WITH_MKL)
        set(FLAG_OPENMP "/openmp")
    endif()
    set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}")
    set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}")
    set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj /MTd ${FLAG_OPENMP}")
    set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj /MT ${FLAG_OPENMP}")
    if (WITH_STATIC_LIB)
        safe_set_static_flag()
        add_definitions(-DSTATIC_LIB)
    endif()
    message("cmake c debug flags " ${CMAKE_C_FLAGS_DEBUG})
    message("cmake c release flags " ${CMAKE_C_FLAGS_RELEASE})
    message("cmake cxx debug flags " ${CMAKE_CXX_FLAGS_DEBUG})
    message("cmake cxx release flags " ${CMAKE_CXX_FLAGS_RELEASE})
else()
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -o3 -std=c++11")
    if(WITH_MKL)
        set(FLAG_OPENMP "-fopenmp")
    endif()
    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -o3 ${FLAG_OPENMP} -std=c++11")
    set(CMAKE_STATIC_LIBRARY_PREFIX "")
    message("cmake cxx flags" ${CMAKE_CXX_FLAGS})
endif()
message("flags" ${CMAKE_CXX_FLAGS})


if (WITH_GPU)
    if (NOT DEFINED CUDA_LIB OR ${CUDA_LIB} STREQUAL "")

@@ -198,4 +207,4 @@ if (WIN32 AND WITH_MKL)
        COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_LIB}/third_party/install/mklml/lib/libiomp5md.dll ./release/libiomp5md.dll
        COMMAND ${CMAKE_COMMAND} -E copy_if_different ${PADDLE_LIB}/third_party/install/mkldnn/lib/mkldnn.dll ./release/mkldnn.dll
    )
endif()
endif()

@@ -57,6 +57,12 @@ public:

    this->char_list_file.assign(config_map_["char_list_file"]);

    this->use_angle_cls = bool(stoi(config_map_["use_angle_cls"]));

    this->cls_model_dir.assign(config_map_["cls_model_dir"]);

    this->cls_thresh = stod(config_map_["cls_thresh"]);

    this->visualize = bool(stoi(config_map_["visualize"]));
  }

@@ -84,8 +90,14 @@ public:
  std::string rec_model_dir;

  bool use_angle_cls;

  std::string char_list_file;

  std::string cls_model_dir;

  double cls_thresh;

  bool visualize = true;

  void PrintConfigInfo();

@@ -0,0 +1,81 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>

#include <cstring>
#include <fstream>
#include <numeric>

#include <include/preprocess_op.h>
#include <include/utility.h>

namespace PaddleOCR {

class Classifier {
public:
  explicit Classifier(const std::string &model_dir, const bool &use_gpu,
                      const int &gpu_id, const int &gpu_mem,
                      const int &cpu_math_library_num_threads,
                      const bool &use_mkldnn, const bool &use_zero_copy_run,
                      const double &cls_thresh) {
    this->use_gpu_ = use_gpu;
    this->gpu_id_ = gpu_id;
    this->gpu_mem_ = gpu_mem;
    this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
    this->use_mkldnn_ = use_mkldnn;
    this->use_zero_copy_run_ = use_zero_copy_run;

    this->cls_thresh = cls_thresh;

    LoadModel(model_dir);
  }

  // Load Paddle inference model
  void LoadModel(const std::string &model_dir);

  cv::Mat Run(cv::Mat &img);

private:
  std::shared_ptr<PaddlePredictor> predictor_;

  bool use_gpu_ = false;
  int gpu_id_ = 0;
  int gpu_mem_ = 4000;
  int cpu_math_library_num_threads_ = 4;
  bool use_mkldnn_ = false;
  bool use_zero_copy_run_ = false;
  double cls_thresh = 0.5;

  std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
  std::vector<float> scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
  bool is_scale_ = true;

  // pre-process
  ClsResizeImg resize_op_;
  Normalize normalize_op_;
  Permute permute_op_;

}; // class Classifier

} // namespace PaddleOCR

@@ -27,6 +27,7 @@
#include <fstream>
#include <numeric>

#include <include/ocr_cls.h>
#include <include/postprocess_op.h>
#include <include/preprocess_op.h>
#include <include/utility.h>

@@ -48,6 +49,8 @@ public:
    this->use_zero_copy_run_ = use_zero_copy_run;

    this->label_list_ = Utility::ReadDict(label_path);
    this->label_list_.insert(this->label_list_.begin(),
                             "#"); // blank char for ctc
    this->label_list_.push_back(" ");

    LoadModel(model_dir);

@@ -56,7 +59,8 @@ public:
  // Load Paddle inference model
  void LoadModel(const std::string &model_dir);

  void Run(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat &img);
  void Run(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat &img,
           Classifier *cls);

private:
  std::shared_ptr<PaddlePredictor> predictor_;

@@ -56,4 +56,10 @@ public:
                   const std::vector<int> &rec_image_shape = {3, 32, 320});
};

class ClsResizeImg {
public:
  virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
                   const std::vector<int> &rec_image_shape = {3, 48, 192});
};

} // namespace PaddleOCR

@@ -193,6 +193,39 @@ make -j
sh tools/run.sh
```

* To use the text direction classifier, set the `use_angle_cls` parameter in `tools/config.txt` to 1, which enables direction classification during prediction.
* In addition, the parameters in `tools/config.txt` and their meanings are as follows.

```
use_gpu 0 # whether to use the GPU: 1 means yes, 0 means no
gpu_id 0 # GPU id, effective when the GPU is used
gpu_mem 4000 # GPU memory to request
cpu_math_library_num_threads 10 # number of threads for CPU inference; when the machine has enough cores, the larger the value, the faster the prediction
use_mkldnn 1 # whether to use the mkldnn library
use_zero_copy_run 1 # whether to use zero_copy_run for inference

# det config
max_side_len 960 # when the longer side of the input image exceeds 960, the image is scaled proportionally so that its longest side becomes 960
det_db_thresh 0.3 # used to binarize the DB prediction map; values in the 0.0-0.3 range have no obvious effect on the result
det_db_box_thresh 0.5 # threshold for filtering boxes in DB post-processing; decrease it if text regions are missed
det_db_unclip_ratio 1.6 # controls how tightly the text box wraps the text; the smaller the value, the closer the box sits to the text
det_model_dir ./inference/det_db # path of the detection inference model

# cls config
use_angle_cls 0 # whether to use the direction classifier: 0 means no, 1 means yes
cls_model_dir ./inference/cls # path of the direction classifier inference model
cls_thresh 0.9 # score threshold of the direction classifier

# rec config
rec_model_dir ./inference/rec_crnn # path of the recognition inference model
char_list_file ../../ppocr/utils/ppocr_keys_v1.txt # dictionary file

# show the detection results
visualize 1 # whether to visualize the results; when set to 1, the prediction result is saved as `ocr_vis.png` in the current directory
```

* PaddleOCR also supports multilingual prediction; for details, see the multilingual dictionaries and models section of the [recognition documentation](../../doc/doc_ch/recognition.md).

The detection results are finally printed on the screen as follows.

<div align="center">

@@ -202,4 +235,4 @@ sh tools/run.sh

### 2.3 Notes

* MKLDNN is disabled by default for C++ inference (`use_mkldnn` in `tools/config.txt` is set to 0). To speed up inference with MKLDNN, set `use_mkldnn` to 1 and compile the inference library from the latest Paddle source code. When MKLDNN is used for CPU prediction and many images are predicted in one run, a memory leak occurs (the problem does not appear with MKLDNN disabled). The issue is being fixed; as a temporary workaround, re-initialize the recognition (`CRNNRecognizer`) and detection (`DBDetector`) classes roughly every 30 images when predicting a large number of images.
* When using the Paddle inference library, version 2.0.0-beta0 is recommended.

@@ -162,7 +162,7 @@ inference/
sh tools/build.sh
```

Specifically, the content of `tools/build.sh` is as follows.

```shell
OPENCV_DIR=your_opencv_dir

@@ -201,6 +201,40 @@ make -j
sh tools/run.sh
```

* If you want the orientation classifier to correct the detected boxes, set `use_angle_cls` in `tools/config.txt` to 1 to enable it.
* In addition, the parameters in `tools/config.txt` and their meanings are as follows.


```
use_gpu 0 # Whether to use the GPU: 0 means no, 1 means yes
gpu_id 0 # GPU id, used when use_gpu is 1
gpu_mem 4000 # GPU memory requested
cpu_math_library_num_threads 10 # Number of threads for CPU inference; when the machine has enough cores, the larger the value, the faster the inference
use_mkldnn 1 # Whether to use the mkldnn library
use_zero_copy_run 1 # Whether to use zero_copy_run for inference

max_side_len 960 # When the longest side of the input image exceeds 960, the image is resized proportionally so that its longest side equals 960
det_db_thresh 0.3 # Used to binarize the DB prediction map; values in the 0.0-0.3 range have no obvious effect on the result
det_db_box_thresh 0.5 # DB post-processing box-filtering threshold; reduce it if text regions are missed
det_db_unclip_ratio 1.6 # Controls the compactness of the text box; the smaller the value, the closer the box sits to the text
det_model_dir ./inference/det_db # Path of the detection inference model

# cls config
use_angle_cls 0 # Whether to use the direction classifier: 0 means no, 1 means yes
cls_model_dir ./inference/cls # Path of the direction classifier inference model
cls_thresh 0.9 # Score threshold of the direction classifier

# rec config
rec_model_dir ./inference/rec_crnn # Path of the recognition inference model
char_list_file ../../ppocr/utils/ppocr_keys_v1.txt # Dictionary file

# show the detection results
visualize 1 # Whether to visualize the results; when set to 1, the prediction result is saved as `./ocr_vis.png`
```

* Multilingual inference is also supported in PaddleOCR; for more details, please refer to the multilingual dictionaries and models section of the [recognition tutorial](../../doc/doc_en/recognition_en.md).


The detection results will be shown on the screen as follows.

<div align="center">

@@ -208,6 +242,6 @@ The detection results will be shown on the screen, which is as follows.
</div>


### 2.3 Notes

* `MKLDNN` is disabled by default for C++ inference (`use_mkldnn` in `tools/config.txt` is set to 0). If you need MKLDNN to accelerate inference, set `use_mkldnn` to 1 and compile the inference library with the latest Paddle source code. When MKLDNN is used for CPU prediction and many images are predicted in one run, a memory leak occurs (the problem does not appear when MKLDNN is disabled). The issue is currently being fixed; as a temporary workaround, re-initialize the recognition (`CRNNRecognizer`) and detection (`DBDetector`) classes roughly every 30 images.
* The Paddle 2.0.0-beta0 inference library is recommended for this tutorial.

@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "glog/logging.h"
#include "omp.h"
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"

@@ -53,17 +55,38 @@ int main(int argc, char **argv) {
                 config.cpu_math_library_num_threads, config.use_mkldnn,
                 config.use_zero_copy_run, config.max_side_len, config.det_db_thresh,
                 config.det_db_box_thresh, config.det_db_unclip_ratio, config.visualize);

  Classifier *cls = nullptr;
  if (config.use_angle_cls == true) {
    cls = new Classifier(config.cls_model_dir, config.use_gpu, config.gpu_id,
                         config.gpu_mem, config.cpu_math_library_num_threads,
                         config.use_mkldnn, config.use_zero_copy_run,
                         config.cls_thresh);
  }

  CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id,
                     config.gpu_mem, config.cpu_math_library_num_threads,
                     config.use_mkldnn, config.use_zero_copy_run,
                     config.char_list_file);

#ifdef USE_MKL
#pragma omp parallel
  for (auto i = 0; i < 10; i++) {
    LOG_IF(WARNING,
           config.cpu_math_library_num_threads != omp_get_num_threads())
        << "WARNING! MKL is running on " << omp_get_num_threads()
        << " threads while cpu_math_library_num_threads is set to "
        << config.cpu_math_library_num_threads
        << ". Possible reason could be 1. You have set omp_set_num_threads() "
           "somewhere; 2. MKL is not linked properly";
  }
#endif

  auto start = std::chrono::system_clock::now();
  std::vector<std::vector<std::vector<int>>> boxes;
  det.Run(srcimg, boxes);

  rec.Run(boxes, srcimg);

  rec.Run(boxes, srcimg, cls);
  auto end = std::chrono::system_clock::now();
  auto duration =
      std::chrono::duration_cast<std::chrono::microseconds>(end - start);

@@ -0,0 +1,108 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <include/ocr_cls.h>

namespace PaddleOCR {

cv::Mat Classifier::Run(cv::Mat &img) {
  cv::Mat src_img;
  img.copyTo(src_img);
  cv::Mat resize_img;

  std::vector<int> cls_image_shape = {3, 48, 192};
  int index = 0;
  float wh_ratio = float(img.cols) / float(img.rows);

  this->resize_op_.Run(img, resize_img, cls_image_shape);

  this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
                          this->is_scale_);

  std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);

  this->permute_op_.Run(&resize_img, input.data());

  // Inference.
  if (this->use_zero_copy_run_) {
    auto input_names = this->predictor_->GetInputNames();
    auto input_t = this->predictor_->GetInputTensor(input_names[0]);
    input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
    input_t->copy_from_cpu(input.data());
    this->predictor_->ZeroCopyRun();
  } else {
    paddle::PaddleTensor input_t;
    input_t.shape = {1, 3, resize_img.rows, resize_img.cols};
    input_t.data =
        paddle::PaddleBuf(input.data(), input.size() * sizeof(float));
    input_t.dtype = PaddleDType::FLOAT32;
    std::vector<paddle::PaddleTensor> outputs;
    this->predictor_->Run({input_t}, &outputs, 1);
  }

  std::vector<float> softmax_out;
  std::vector<int64_t> label_out;
  auto output_names = this->predictor_->GetOutputNames();
  auto softmax_out_t = this->predictor_->GetOutputTensor(output_names[0]);
  auto softmax_shape_out = softmax_out_t->shape();

  int softmax_out_num =
      std::accumulate(softmax_shape_out.begin(), softmax_shape_out.end(), 1,
                      std::multiplies<int>());

  softmax_out.resize(softmax_out_num);

  softmax_out_t->copy_to_cpu(softmax_out.data());

  // argmax over the softmax output to get the predicted class and its score
  float score = 0;
  int label = 0;
  for (int i = 0; i < softmax_out_num; i++) {
    if (softmax_out[i] > score) {
      score = softmax_out[i];
      label = i;
    }
  }
  // an odd class index ('180') above the threshold means the crop is upside
  // down; rotate it by 180 degrees (cv::ROTATE_180 == 1)
  if (label % 2 == 1 && score > this->cls_thresh) {
    cv::rotate(src_img, src_img, 1);
  }
  return src_img;
}

void Classifier::LoadModel(const std::string &model_dir) {
  AnalysisConfig config;
  config.SetModel(model_dir + "/model", model_dir + "/params");

  if (this->use_gpu_) {
    config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
  } else {
    config.DisableGpu();
    if (this->use_mkldnn_) {
      config.EnableMKLDNN();
    }
    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
  }

  // false for zero copy tensor
  config.SwitchUseFeedFetchOps(!this->use_zero_copy_run_);
  // true for multiple input
  config.SwitchSpecifyInputNames(true);

  config.SwitchIrOptim(true);

  config.EnableMemoryOptim();
  config.DisableGlogInfo();

  this->predictor_ = CreatePaddlePredictor(config);
}
} // namespace PaddleOCR

@@ -26,6 +26,8 @@ void DBDetector::LoadModel(const std::string &model_dir) {
    config.DisableGpu();
    if (this->use_mkldnn_) {
      config.EnableMKLDNN();
      // cache 10 different shapes for mkldnn to avoid memory leak
      config.SetMkldnnCacheCapacity(10);
    }
    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
  }

@@ -106,9 +108,12 @@ void DBDetector::Run(cv::Mat &img,
  const double maxvalue = 255;
  cv::Mat bit_map;
  cv::threshold(cbuf_map, bit_map, threshold, maxvalue, cv::THRESH_BINARY);

  boxes = post_processor_.BoxesFromBitmap(
      pred_map, bit_map, this->det_db_box_thresh_, this->det_db_unclip_ratio_);
  cv::Mat dilation_map;
  cv::Mat dila_ele = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(2, 2));
  cv::dilate(bit_map, dilation_map, dila_ele);
  boxes = post_processor_.BoxesFromBitmap(pred_map, dilation_map,
                                          this->det_db_box_thresh_,
                                          this->det_db_unclip_ratio_);

  boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);

@@ -17,7 +17,7 @@
namespace PaddleOCR {

void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
                         cv::Mat &img) {
                         cv::Mat &img, Classifier *cls) {
  cv::Mat srcimg;
  img.copyTo(srcimg);
  cv::Mat crop_img;

@@ -27,6 +27,9 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
  int index = 0;
  for (int i = boxes.size() - 1; i >= 0; i--) {
    crop_img = GetRotateCropImage(srcimg, boxes[i]);
    if (cls != nullptr) {
      crop_img = cls->Run(crop_img);
    }

    float wh_ratio = float(crop_img.cols) / float(crop_img.rows);

@@ -56,62 +59,44 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
      this->predictor_->Run({input_t}, &outputs, 1);
    }

    std::vector<int64_t> rec_idx;
    std::vector<float> predict_batch;
    auto output_names = this->predictor_->GetOutputNames();
    auto output_t = this->predictor_->GetOutputTensor(output_names[0]);
    auto rec_idx_lod = output_t->lod();
    auto shape_out = output_t->shape();
    auto predict_shape = output_t->shape();

    int out_num = std::accumulate(shape_out.begin(), shape_out.end(), 1,
    int out_num = std::accumulate(predict_shape.begin(), predict_shape.end(), 1,
                                  std::multiplies<int>());
    predict_batch.resize(out_num);

    rec_idx.resize(out_num);
    output_t->copy_to_cpu(rec_idx.data());

    std::vector<int> pred_idx;
    for (int n = int(rec_idx_lod[0][0]); n < int(rec_idx_lod[0][1]); n++) {
      pred_idx.push_back(int(rec_idx[n]));
    }

    if (pred_idx.size() < 1e-3)
      continue;

    index += 1;
    std::cout << index << "\t";
    for (int n = 0; n < pred_idx.size(); n++) {
      std::cout << label_list_[pred_idx[n]];
    }

    std::vector<float> predict_batch;
    auto output_t_1 = this->predictor_->GetOutputTensor(output_names[1]);

    auto predict_lod = output_t_1->lod();
    auto predict_shape = output_t_1->shape();
    int out_num_1 = std::accumulate(predict_shape.begin(), predict_shape.end(),
                                    1, std::multiplies<int>());

    predict_batch.resize(out_num_1);
    output_t_1->copy_to_cpu(predict_batch.data());
    output_t->copy_to_cpu(predict_batch.data());

    // ctc decode
    std::vector<std::string> str_res;
    int argmax_idx;
    int blank = predict_shape[1];
    int last_index = 0;
    float score = 0.f;
    int count = 0;
    float max_value = 0.0f;

    for (int n = predict_lod[0][0]; n < predict_lod[0][1] - 1; n++) {
    for (int n = 0; n < predict_shape[1]; n++) {
      argmax_idx =
          int(Utility::argmax(&predict_batch[n * predict_shape[1]],
                              &predict_batch[(n + 1) * predict_shape[1]]));
          int(Utility::argmax(&predict_batch[n * predict_shape[2]],
                              &predict_batch[(n + 1) * predict_shape[2]]));
      max_value =
          float(*std::max_element(&predict_batch[n * predict_shape[1]],
                                  &predict_batch[(n + 1) * predict_shape[1]]));
      if (blank - 1 - argmax_idx > 1e-5) {
          float(*std::max_element(&predict_batch[n * predict_shape[2]],
                                  &predict_batch[(n + 1) * predict_shape[2]]));

      if (argmax_idx > 0 && (not(i > 0 && argmax_idx == last_index))) {
        score += max_value;
        count += 1;
        str_res.push_back(label_list_[argmax_idx]);
      }
      last_index = argmax_idx;
    }
    score /= count;
    for (int i = 0; i < str_res.size(); i++) {
      std::cout << str_res[i];
    }
    std::cout << "\tscore: " << score << std::endl;
  }
}

@@ -126,6 +111,8 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
    config.DisableGpu();
    if (this->use_mkldnn_) {
      config.EnableMKLDNN();
      // cache 10 different shapes for mkldnn to avoid memory leak
      config.SetMkldnnCacheCapacity(10);
    }
    config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
  }

@@ -199,4 +186,4 @@ cv::Mat CRNNRecognizer::GetRotateCropImage(const cv::Mat &srcimage,
  }
}

} // namespace PaddleOCR
} // namespace PaddleOCR

@@ -294,7 +294,7 @@ PostProcessor::FilterTagDetRes(std::vector<std::vector<std::vector<int>>> boxes,
                               pow(boxes[n][0][1] - boxes[n][1][1], 2)));
    rect_height = int(sqrt(pow(boxes[n][0][0] - boxes[n][3][0], 2) +
                           pow(boxes[n][0][1] - boxes[n][3][1], 2)));
    if (rect_width <= 10 || rect_height <= 10)
    if (rect_width <= 4 || rect_height <= 4)
      continue;
    root_points.push_back(boxes[n]);
  }

@@ -85,7 +85,7 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,

  if (resize_w % 32 == 0)
    resize_w = resize_w;
  else if (resize_w / 32 < 1)
  else if (resize_w / 32 < 1 + 1e-5)
    resize_w = 32;
  else
    resize_w = (resize_w / 32 - 1) * 32;

@@ -116,4 +116,26 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
             cv::INTER_LINEAR);
}

} // namespace PaddleOCR
void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
                       const std::vector<int> &rec_image_shape) {
  int imgC, imgH, imgW;
  imgC = rec_image_shape[0];
  imgH = rec_image_shape[1];
  imgW = rec_image_shape[2];

  float ratio = float(img.cols) / float(img.rows);
  int resize_w, resize_h;
  if (ceilf(imgH * ratio) > imgW)
    resize_w = imgW;
  else
    resize_w = int(ceilf(imgH * ratio));

  cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
             cv::INTER_LINEAR);
  if (resize_w < imgW) {
    cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
                       cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
  }
}

} // namespace PaddleOCR

@@ -21,8 +21,8 @@ from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop

from .rec_img_aug import RecAug, RecResizeImg

from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg
from .randaugment import RandAugment
from .operators import *
from .label_ops import *

@@ -18,7 +18,19 @@ from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
from ppocr.utils.logging import get_logger


class ClsLabelEncode(object):
    def __init__(self, label_list, **kwargs):
        self.label_list = label_list

    def __call__(self, data):
        label = data['label']
        if label not in self.label_list:
            return None
        label = self.label_list.index(label)
        data['label'] = label
        return data


class DetLabelEncode(object):

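As a quick illustration of how the new `ClsLabelEncode` op behaves, here is a minimal sketch that feeds it label dicts by hand. It assumes the class is importable from `ppocr.data.imaug.label_ops` (the module modified in the hunk above) and that `label_list` matches the `['0','180']` entry of the classifier config; a sample whose label is not in `label_list` comes back as `None` and is dropped by the dataset.

```python
# Minimal sketch; import path assumed to be ppocr/data/imaug/label_ops.py as modified above.
from ppocr.data.imaug.label_ops import ClsLabelEncode

encode = ClsLabelEncode(label_list=['0', '180'])

print(encode({'label': '0'}))    # {'label': 0}  -> index into label_list
print(encode({'label': '180'}))  # {'label': 1}
print(encode({'label': '90'}))   # None          -> sample is discarded upstream
```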
@@ -0,0 +1,140 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
import six


class RawRandAugment(object):
    def __init__(self,
                 num_layers=2,
                 magnitude=5,
                 fillcolor=(128, 128, 128),
                 **kwargs):
        self.num_layers = num_layers
        self.magnitude = magnitude
        self.max_level = 10

        abso_level = self.magnitude / self.max_level
        self.level_map = {
            "shearX": 0.3 * abso_level,
            "shearY": 0.3 * abso_level,
            "translateX": 150.0 / 331 * abso_level,
            "translateY": 150.0 / 331 * abso_level,
            "rotate": 30 * abso_level,
            "color": 0.9 * abso_level,
            "posterize": int(4.0 * abso_level),
            "solarize": 256.0 * abso_level,
            "contrast": 0.9 * abso_level,
            "sharpness": 0.9 * abso_level,
            "brightness": 0.9 * abso_level,
            "autocontrast": 0,
            "equalize": 0,
            "invert": 0
        }

        # from https://stackoverflow.com/questions/5252170/
        # specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(rot,
                                   Image.new("RGBA", rot.size, (128, ) * 4),
                                   rot).convert(img.mode)

        rnd_ch_op = random.choice

        self.func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC,
                fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0),
                Image.BICUBIC,
                fillcolor=fillcolor),
            "translateX": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, magnitude * img.size[0] * rnd_ch_op([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size,
                Image.AFFINE,
                (1, 0, 0, 0, 1, magnitude * img.size[1] * rnd_ch_op([-1, 1])),
                fillcolor=fillcolor),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(
                1 + magnitude * rnd_ch_op([-1, 1])),
            "posterize": lambda img, magnitude:
            ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude:
            ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude:
            ImageEnhance.Contrast(img).enhance(
                1 + magnitude * rnd_ch_op([-1, 1])),
            "sharpness": lambda img, magnitude:
            ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * rnd_ch_op([-1, 1])),
            "brightness": lambda img, magnitude:
            ImageEnhance.Brightness(img).enhance(
                1 + magnitude * rnd_ch_op([-1, 1])),
            "autocontrast": lambda img, magnitude:
            ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img)
        }

    def __call__(self, img):
        avaiable_op_names = list(self.level_map.keys())
        for layer_num in range(self.num_layers):
            op_name = np.random.choice(avaiable_op_names)
            img = self.func[op_name](img, self.level_map[op_name])
        return img


class RandAugment(RawRandAugment):
    """ RandAugment wrapper to auto fit different img types """

    def __init__(self, *args, **kwargs):
        if six.PY2:
            super(RandAugment, self).__init__(*args, **kwargs)
        else:
            super().__init__(*args, **kwargs)

    def __call__(self, data):
        img = data['image']
        if not isinstance(img, Image.Image):
            img = np.ascontiguousarray(img)
            img = Image.fromarray(img)

        if six.PY2:
            img = super(RandAugment, self).__call__(img)
        else:
            img = super().__call__(img)

        if isinstance(img, Image.Image):
            img = np.asarray(img)
        data['image'] = img
        return data

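A hedged usage sketch for the `RandAugment` wrapper above: it assumes only that PIL and numpy are installed and that the new file is importable as `ppocr.data.imaug.randaugment` (as the `__init__.py` hunk earlier suggests). The wrapper accepts the same `data` dict the other imaug ops use, converts an ndarray image to PIL, applies `num_layers` randomly chosen ops, and converts back.

```python
# Minimal sketch; import path assumed to be ppocr/data/imaug/randaugment.py (the new file above).
import numpy as np
from ppocr.data.imaug.randaugment import RandAugment

aug = RandAugment(num_layers=2, magnitude=5)

# A dummy HWC uint8 image stands in for a decoded text crop.
data = {'image': np.random.randint(0, 255, (48, 192, 3), dtype=np.uint8)}
out = aug(data)
print(out['image'].shape, out['image'].dtype)  # (48, 192, 3) uint8, two random PIL ops applied
```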
@@ -35,16 +35,27 @@ from .text_image_aug import tia_perspective, tia_stretch, tia_distort


class RecAug(object):
    def __init__(self, **kwargs):
        pass
    def __init__(self, use_tia=True, **kwargs):
        self.use_tia = use_tia

    def __call__(self, data):
        img = data['image']
        img = warp(img, 10)
        img = warp(img, 10, self.use_tia)
        data['image'] = img
        return data


class ClsResizeImg(object):
    def __init__(self, image_shape, **kwargs):
        self.image_shape = image_shape

    def __call__(self, data):
        img = data['image']
        norm_img = resize_norm_img(img, self.image_shape)
        data['image'] = norm_img
        return data


class RecResizeImg(object):
    def __init__(self,
                 image_shape,

@@ -194,7 +205,7 @@ class Config:
    Config
    """

    def __init__(self, ):
    def __init__(self, use_tia):
        self.anglex = random.random() * 30
        self.angley = random.random() * 15
        self.anglez = random.random() * 10

@@ -203,6 +214,7 @@ class Config:
        self.shearx = random.random() * 0.3
        self.sheary = random.random() * 0.05
        self.borderMode = cv2.BORDER_REPLICATE
        self.use_tia = use_tia

    def make(self, w, h, ang):
        """

@@ -219,9 +231,9 @@ class Config:
        self.w = w
        self.h = h

        self.perspective = True
        self.stretch = True
        self.distort = True
        self.perspective = self.use_tia
        self.stretch = self.use_tia
        self.distort = self.use_tia

        self.crop = True
        self.affine = False

@@ -317,12 +329,12 @@ def get_warpAffine(config):
    return rz


def warp(img, ang):
def warp(img, ang, use_tia=True):
    """
    warp
    """
    h, w, _ = img.shape
    config = Config()
    config = Config(use_tia=use_tia)
    config.make(w, h, ang)
    new_img = img

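To make the data flow of the two new ops concrete, here is a small sketch that chains `RecAug` (with TIA disabled, as in the classifier config) and `ClsResizeImg` on a dummy crop. It assumes the module path `ppocr.data.imaug.rec_img_aug` (taken from the imports in the `__init__.py` hunk) and that `resize_norm_img` pads the normalized CHW image to the requested width, as the code above does.

```python
# Minimal sketch; module path assumed from ppocr/data/imaug/__init__.py above.
import numpy as np
from ppocr.data.imaug.rec_img_aug import RecAug, ClsResizeImg

img = np.random.randint(0, 255, (30, 120, 3), dtype=np.uint8)  # dummy BGR text crop
data = {'image': img}

data = RecAug(use_tia=False)(data)                   # geometric/colour jitter, TIA branches skipped
data = ClsResizeImg(image_shape=[3, 48, 192])(data)  # resize + normalize + pad to width 192

print(data['image'].shape, data['image'].dtype)      # (3, 48, 192) float32
```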
@@ -22,7 +22,10 @@ def build_loss(config):
    # rec loss
    from .rec_ctc_loss import CTCLoss

    support_dict = ['DBLoss', 'CTCLoss']
    # cls loss
    from .cls_loss import ClsLoss

    support_dict = ['DBLoss', 'CTCLoss', 'ClsLoss']

    config = copy.deepcopy(config)
    module_name = config.pop('name')

@@ -0,0 +1,30 @@
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import nn


class ClsLoss(nn.Layer):
    def __init__(self, **kwargs):
        super(ClsLoss, self).__init__()
        self.loss_func = nn.CrossEntropyLoss(reduction='mean')

    def __call__(self, predicts, batch):
        label = batch[1]
        loss = self.loss_func(input=predicts, label=label)
        return {'loss': loss}

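A hedged sketch of how `ClsLoss` is invoked by the trainer: `predicts` are the raw `[N, class_dim]` outputs of `ClsHead` and `batch[1]` carries the integer labels produced by `ClsLabelEncode`. It assumes Paddle 2.x is installed, that the module is importable as `ppocr.losses.cls_loss`, and that the label tensor is shaped `[N, 1]` (an assumption about the expected label layout, not spelled out in the diff).

```python
# Minimal sketch; assumes paddle 2.x and the new ppocr/losses/cls_loss.py shown above.
import paddle
from ppocr.losses.cls_loss import ClsLoss

loss_fn = ClsLoss()

logits = paddle.to_tensor([[2.0, -1.0],
                           [0.3, 1.2]])                      # [N, class_dim] from ClsHead
batch = [None, paddle.to_tensor([[0], [1]], dtype='int64')]  # batch[1] holds the labels

print(loss_fn(logits, batch))  # {'loss': Tensor(...)} - mean cross entropy over the batch
```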
@@ -32,5 +32,5 @@ class CTCLoss(nn.Layer):
        labels = batch[1].astype("int32")
        label_lengths = batch[2].astype('int64')
        loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
        loss = loss.mean()
        loss = loss.mean()  # sum
        return {'loss': loss}

@@ -25,8 +25,9 @@ __all__ = ['build_metric']
def build_metric(config):
    from .det_metric import DetMetric
    from .rec_metric import RecMetric
    from .cls_metric import ClsMetric

    support_dict = ['DetMetric', 'RecMetric']
    support_dict = ['DetMetric', 'RecMetric', 'ClsMetric']

    config = copy.deepcopy(config)
    module_name = config.pop('name')

@@ -0,0 +1,46 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class ClsMetric(object):
    def __init__(self, main_indicator='acc', **kwargs):
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, pred_label, *args, **kwargs):
        preds, labels = pred_label
        correct_num = 0
        all_num = 0
        for (pred, pred_conf), (target, _) in zip(preds, labels):
            if pred == target:
                correct_num += 1
            all_num += 1
        self.correct_num += correct_num
        self.all_num += all_num
        return {'acc': correct_num / all_num, }

    def get_metric(self):
        """
        return metrics {
            'acc': 0,
            'norm_edit_dis': 0,
        }
        """
        acc = self.correct_num / self.all_num
        self.reset()
        return {'acc': acc}

    def reset(self):
        self.correct_num = 0
        self.all_num = 0

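The metric consumes the `(label, score)` pairs that `ClsPostProcess` produces for both predictions and ground truth. A minimal sketch, assuming the module is importable as `ppocr.metrics.cls_metric` (the path implied by the `__init__.py` hunk above):

```python
# Minimal sketch; import path assumed from ppocr/metrics/__init__.py above.
from ppocr.metrics.cls_metric import ClsMetric

metric = ClsMetric(main_indicator='acc')

preds = [('0', 0.98), ('180', 0.61)]   # decoded predictions: (label, confidence)
labels = [('0', 1.0), ('0', 1.0)]      # decoded ground truth

print(metric((preds, labels)))  # {'acc': 0.5} for this batch
print(metric.get_metric())      # accumulated accuracy; internal counters are reset afterwards
```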
@@ -20,7 +20,7 @@ def build_backbone(config, model_type):
        from .det_mobilenet_v3 import MobileNetV3
        from .det_resnet_vd import ResNet
        support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
    elif model_type == 'rec':
    elif model_type == 'rec' or model_type == 'cls':
        from .rec_mobilenet_v3 import MobileNetV3
        from .rec_resnet_vd import ResNet
        support_dict = ['MobileNetV3', 'ResNet', 'ResNet_FPN']

@@ -136,13 +136,3 @@ class MobileNetV3(nn.Layer):
        x = self.conv2(x)
        x = self.pool(x)
        return x


if __name__ == '__main__':
    import paddle
    paddle.disable_static()
    x = paddle.zeros((1, 3, 32, 320))
    x = paddle.to_variable(x)
    net = MobileNetV3(model_name='small', small_stride=[1, 2, 2, 2])
    y = net(x)
    print(y.shape)

@@ -21,7 +21,10 @@ def build_head(config):

    # rec head
    from .rec_ctc_head import CTCHead
    support_dict = ['DBHead', 'CTCHead']

    # cls head
    from .cls_head import ClsHead
    support_dict = ['DBHead', 'CTCHead', 'ClsHead']

    module_name = config.pop('name')
    assert module_name in support_dict, Exception('head only support {}'.format(

@@ -0,0 +1,52 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import nn, ParamAttr
import paddle.nn.functional as F


class ClsHead(nn.Layer):
    """
    Classification head for text direction.

    Args:
        params(dict): hyper-parameters for building the classification network
    """

    def __init__(self, in_channels, class_dim, **kwargs):
        super(ClsHead, self).__init__()
        self.pool = nn.AdaptiveAvgPool2D(1)
        stdv = 1.0 / math.sqrt(in_channels * 1.0)
        self.fc = nn.Linear(
            in_channels,
            class_dim,
            weight_attr=ParamAttr(
                name="fc_0.w_0",
                initializer=nn.initializer.Uniform(-stdv, stdv)),
            bias_attr=ParamAttr(name="fc_0.b_0"), )

    def forward(self, x):
        x = self.pool(x)
        x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
        x = self.fc(x)
        if not self.training:
            x = F.softmax(x, axis=1)
        return x

@@ -25,8 +25,11 @@ __all__ = ['build_post_process']
def build_post_process(config, global_config=None):
    from .db_postprocess import DBPostProcess
    from .rec_postprocess import CTCLabelDecode, AttnLabelDecode

    support_dict = ['DBPostProcess', 'CTCLabelDecode', 'AttnLabelDecode']
    from .cls_postprocess import ClsPostProcess

    support_dict = [
        'DBPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess'
    ]

    config = copy.deepcopy(config)
    module_name = config.pop('name')

@@ -0,0 +1,33 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle


class ClsPostProcess(object):
    """ Convert between text-label and text-index """

    def __init__(self, label_list, **kwargs):
        super(ClsPostProcess, self).__init__()
        self.label_list = label_list

    def __call__(self, preds, label=None, *args, **kwargs):
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        pred_idxs = preds.argmax(axis=1)
        decode_out = [(self.label_list[idx], preds[i, idx])
                      for i, idx in enumerate(pred_idxs)]
        if label is None:
            return decode_out
        label = [(self.label_list[idx], 1.0) for idx in label]
        return decode_out, label

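To show what `ClsPostProcess` returns, the sketch below decodes a toy softmax output. It assumes Paddle and numpy are installed, that the new file is importable as `ppocr.postprocess.cls_postprocess` (the path implied by the `__init__.py` hunk above), and it reuses the `label_list` from the classifier config.

```python
# Minimal sketch; assumes the new ppocr/postprocess/cls_postprocess.py above.
import numpy as np
from ppocr.postprocess.cls_postprocess import ClsPostProcess

post = ClsPostProcess(label_list=['0', '180'])

probs = np.array([[0.92, 0.08],
                  [0.15, 0.85]], dtype=np.float32)  # softmax output, shape [N, 2]

print(post(probs))                          # [('0', 0.92), ('180', 0.85)]
print(post(probs, label=np.array([0, 1])))  # also decodes the ground-truth indices
```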
@@ -81,7 +81,7 @@ class BaseRecLabelDecode(object):
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            result_list.append((text, conf_list))
            result_list.append((text, np.mean(conf_list)))
        return result_list

    def get_ignored_tokens(self):

@@ -0,0 +1,152 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import cv2
import copy
import numpy as np
import math
import time

import paddle.fluid as fluid

import tools.infer.utility as utility
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif

logger = get_logger()


class TextClassifier(object):
    def __init__(self, args):
        self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
        self.cls_batch_num = args.rec_batch_num
        self.cls_thresh = args.cls_thresh
        self.use_zero_copy_run = args.use_zero_copy_run
        postprocess_params = {
            'name': 'ClsPostProcess',
            "label_list": args.label_list,
        }
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors = \
            utility.create_predictor(args, 'cls', logger)

    def resize_norm_img(self, img):
        imgC, imgH, imgW = self.cls_image_shape
        h = img.shape[0]
        w = img.shape[1]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        if self.cls_image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        img_list = copy.deepcopy(img_list)
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars
        width_list = []
        for img in img_list:
            width_list.append(img.shape[1] / float(img.shape[0]))
        # Sorting can speed up the cls process
        indices = np.argsort(np.array(width_list))

        cls_res = [['', 0.0]] * img_num
        batch_num = self.cls_batch_num
        predict_time = 0
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
            max_wh_ratio = 0
            for ino in range(beg_img_no, end_img_no):
                h, w = img_list[indices[ino]].shape[0:2]
                wh_ratio = w * 1.0 / h
                max_wh_ratio = max(max_wh_ratio, wh_ratio)
            for ino in range(beg_img_no, end_img_no):
                norm_img = self.resize_norm_img(img_list[indices[ino]])
                norm_img = norm_img[np.newaxis, :]
                norm_img_batch.append(norm_img)
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()
            starttime = time.time()

            if self.use_zero_copy_run:
                self.input_tensor.copy_from_cpu(norm_img_batch)
                self.predictor.zero_copy_run()
            else:
                norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
                self.predictor.run([norm_img_batch])
            prob_out = self.output_tensors[0].copy_to_cpu()
            cls_result = self.postprocess_op(prob_out)
            elapse = time.time() - starttime
            for rno in range(len(cls_result)):
                label, score = cls_result[rno]
                cls_res[indices[beg_img_no + rno]] = [label, score]
                if '180' in label and score > self.cls_thresh:
                    img_list[indices[beg_img_no + rno]] = cv2.rotate(
                        img_list[indices[beg_img_no + rno]], 1)
        return img_list, cls_res, predict_time


def main(args):
    image_file_list = get_image_file_list(args.image_dir)
    text_classifier = TextClassifier(args)
    valid_image_file_list = []
    img_list = []
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        img_list, cls_res, predict_time = text_classifier(img_list)
    except Exception as e:
        print(e)
        logger.info(
            "ERROR!!!! \n"
            "Please read the FAQ: https://github.com/PaddlePaddle/PaddleOCR#faq \n"
            "If your model has tps module: "
            "TPS does not support variable shape.\n"
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
        exit()
    for ino in range(len(img_list)):
        print("Predicts of %s:%s" % (valid_image_file_list[ino], cls_res[ino]))
    print("Total predict time for %d images, cost: %.3f" %
          (len(img_list), predict_time))


if __name__ == "__main__":
    main(utility.parse_args())

@@ -30,6 +30,8 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process

logger = get_logger()


class TextDetector(object):
    def __init__(self, args):

@@ -158,9 +160,7 @@ class TextDetector(object):

if __name__ == "__main__":
    args = utility.parse_args()

    image_file_list = get_image_file_list(args.image_dir)
    logger = get_logger()
    text_detector = TextDetector(args)
    count = 0
    total_time = 0

|
@ -13,12 +13,12 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
import cv2
|
||||
import copy
|
||||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
|
@ -30,6 +30,8 @@ from ppocr.postprocess import build_post_process
|
|||
from ppocr.utils.logging import get_logger
|
||||
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class TextRecognizer(object):
|
||||
def __init__(self, args):
|
||||
|
@ -80,7 +82,7 @@ class TextRecognizer(object):
|
|||
# rec_res = []
|
||||
rec_res = [['', 0.0]] * img_num
|
||||
batch_num = self.rec_batch_num
|
||||
predict_time = 0
|
||||
elapse = 0
|
||||
for beg_img_no in range(0, img_num, batch_num):
|
||||
end_img_no = min(img_num, beg_img_no + batch_num)
|
||||
norm_img_batch = []
|
||||
|
@ -110,7 +112,9 @@ class TextRecognizer(object):
|
|||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
preds = outputs[0]
|
||||
rec_res = self.postprocess_op(preds)
|
||||
rec_result = self.postprocess_op(preds)
|
||||
for rno in range(len(rec_result)):
|
||||
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
|
||||
elapse = time.time() - starttime
|
||||
return rec_res, elapse
|
||||
|
||||
|
@ -147,5 +151,4 @@ def main(args):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger = get_logger()
|
||||
main(utility.parse_args())
|
||||
|
|
|
@@ -17,20 +17,17 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import tools.infer.utility as utility
from ppocr.utils.utility import initial_logger
logger = initial_logger()
import cv2
import tools.infer.predict_det as predict_det
import tools.infer.predict_rec as predict_rec
import copy
import numpy as np
import math
import time
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from PIL import Image
import tools.infer.utility as utility
from tools.infer.utility import draw_ocr
from tools.infer.utility import draw_ocr_box_txt
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_det as predict_det
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.logging import get_logger


class TextSystem(object):
@@ -153,11 +150,7 @@ def main(args):
            scores = [rec_res[i][1] for i in range(len(rec_res))]

            draw_img = draw_ocr(
                image,
                boxes,
                txts,
                scores,
                drop_score=drop_score)
                image, boxes, txts, scores, drop_score=drop_score)
            draw_img_save = "./inference_results/"
            if not os.path.exists(draw_img_save):
                os.makedirs(draw_img_save)
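The single-line draw_ocr call above replaces the multi-line form; the surrounding code (truncated by this hunk) then writes the visualization into draw_img_save. A minimal sketch of that save step, assuming draw_ocr returns an RGB array and the output file name is taken from the input image path:

import os
import cv2

def save_visualization(draw_img, image_file, draw_img_save="./inference_results/"):
    # draw_img is assumed to be an RGB array; cv2.imwrite expects BGR,
    # hence the channel reversal before saving.
    os.makedirs(draw_img_save, exist_ok=True)
    out_path = os.path.join(draw_img_save, os.path.basename(image_file))
    cv2.imwrite(out_path, draw_img[:, :, ::-1])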
@@ -169,4 +162,5 @@ def main(args):


if __name__ == "__main__":
    logger = get_logger()
    main(utility.parse_args())
@@ -29,48 +29,62 @@ def parse_args():
        return v.lower() in ("true", "t", "1")

    parser = argparse.ArgumentParser()
    #params for prediction engine
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    parser.add_argument("--gpu_mem", type=int, default=8000)

    #params for text detector
    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--det_algorithm", type=str, default='DB')
    parser.add_argument("--det_model_dir", type=str)
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default='max')

    #DB parmas
    # DB parmas
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)

    #EAST parmas
    # EAST parmas
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    #SAST parmas
    # SAST parmas
    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
    parser.add_argument("--det_sast_polygon", type=bool, default=False)

    #params for text recognizer
    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
    parser.add_argument("--rec_model_dir", type=str)
    parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
    parser.add_argument("--rec_char_type", type=str, default='ch')
    parser.add_argument("--rec_batch_num", type=int, default=30)
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default="./ppocr/utils/ppocr_keys_v1.txt")
    parser.add_argument("--use_space_char", type=bool, default=True)
    parser.add_argument("--enable_mkldnn", type=bool, default=False)
    parser.add_argument("--use_zero_copy_run", type=bool, default=False)
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument(
        "--vis_font_path", type=str, default="./doc/simfang.ttf")

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_dir", type=str)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    parser.add_argument("--label_list", type=list, default=['0', '180'])
    parser.add_argument("--cls_batch_num", type=int, default=30)
    parser.add_argument("--cls_thresh", type=float, default=0.9)

    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)

    parser.add_argument("--use_pdserving", type=str2bool, default=False)

    return parser.parse_args()
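A quick way to sanity-check the reworked flags is to call parse_args() directly from Python; the sketch below uses placeholder image and model paths and assumes the repository root is on sys.path:

import sys
import tools.infer.utility as utility

# Hypothetical paths; substitute real exported inference models.
sys.argv = [
    "predict_system.py",
    "--image_dir", "./doc/imgs/",
    "--det_model_dir", "./inference/det_db/",
    "--rec_model_dir", "./inference/rec_crnn/",
    "--use_angle_cls", "true",
    "--cls_model_dir", "./inference/cls/",
]
args = utility.parse_args()
# Expected with the new defaults: True 0.9 6
print(args.use_angle_cls, args.cls_thresh, args.rec_batch_num)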
@@ -0,0 +1,81 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

import paddle

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program


def main():
    global_config = config['Global']

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    model = build_model(config['Architecture'])

    init_model(config, model, logger)

    # create data ops
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        op_name = list(op)[0]
        if 'Label' in op_name:
            continue
        elif op_name == 'KeepKeys':
            op[op_name]['keep_keys'] = ['image']
        transforms.append(op)
    global_config['infer_mode'] = True
    ops = create_operators(transforms, global_config)

    model.eval()
    for file in get_image_file_list(config['Global']['infer_img']):
        logger.info("infer_img: {}".format(file))
        with open(file, 'rb') as f:
            img = f.read()
            data = {'image': img}
        batch = transform(data, ops)

        images = np.expand_dims(batch[0], axis=0)
        images = paddle.to_tensor(images)
        preds = model(images)
        post_result = post_process_class(preds)
        for rec_result in post_result:
            logger.info('\t result: {}'.format(rec_result))
    logger.info("success!")


if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess()
    main()
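The transform-filtering loop in the new script drops any label-handling op and restricts KeepKeys to the image, so the Eval transforms can be reused at inference time. A self-contained sketch of that pattern, with a made-up transforms list standing in for config['Eval']['dataset']['transforms']:

# Hypothetical eval transforms with the same dict-per-op structure.
eval_transforms = [
    {'DecodeImage': {'img_mode': 'BGR', 'channel_first': False}},
    {'SomeLabelEncode': None},                       # any op with 'Label' in its name is skipped
    {'SomeResizeImg': {'image_shape': [3, 32, 100]}},
    {'KeepKeys': {'keep_keys': ['image', 'label']}},
]

infer_transforms = []
for op in eval_transforms:
    op_name = list(op)[0]
    if 'Label' in op_name:          # no labels are available at inference time
        continue
    if op_name == 'KeepKeys':       # only the image tensor is needed downstream
        op[op_name]['keep_keys'] = ['image']
    infer_transforms.append(op)

print(infer_transforms)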