diff --git a/configs/cls/cls_mv3.yml b/configs/cls/cls_mv3.yml new file mode 100755 index 00000000..124eb482 --- /dev/null +++ b/configs/cls/cls_mv3.yml @@ -0,0 +1,43 @@ +Global: + algorithm: CLS + use_gpu: false + epoch_num: 30 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: output/cls_mb3 + save_epoch_step: 3 + eval_batch_step: 100 + train_batch_size_per_card: 256 + test_batch_size_per_card: 256 + image_shape: [3, 32, 100] + label_list: [0,180] + reader_yml: ./configs/cls/cls_reader.yml + pretrain_weights: + checkpoints: /Users/zhoujun20/Desktop/code/class_model/cls_mb3_ultra_small_0.35/best_accuracy + save_inference_dir: + infer_img: /Users/zhoujun20/Desktop/code/PaddleOCR/doc/imgs_words/ch/word_1.jpg + +Architecture: + function: ppocr.modeling.architectures.cls_model,ClsModel + +Backbone: + function: ppocr.modeling.backbones.rec_mobilenet_v3,MobileNetV3 + scale: 0.35 + model_name: Ultra_small + +Head: + function: ppocr.modeling.heads.cls_head,ClsHead + class_dim: 2 + +Loss: + function: ppocr.modeling.losses.cls_loss,ClsLoss + +Optimizer: + function: ppocr.optimizer,AdamDecay + base_lr: 0.001 + beta1: 0.9 + beta2: 0.999 + decay: + function: piecewise_decay + boundaries: [20,30] + decay_rate: 0.1 diff --git a/configs/cls/cls_reader.yml b/configs/cls/cls_reader.yml new file mode 100755 index 00000000..3002fcbd --- /dev/null +++ b/configs/cls/cls_reader.yml @@ -0,0 +1,13 @@ +TrainReader: + reader_function: ppocr.data.cls.dataset_traversal,SimpleReader + num_workers: 1 + img_set_dir: / + label_file_path: /Users/zhoujun20/Downloads/direction/rotate_ver/train.txt + +EvalReader: + reader_function: ppocr.data.cls.dataset_traversal,SimpleReader + img_set_dir: / + label_file_path: /Users/zhoujun20/Downloads/direction/rotate_ver/train.txt + +TestReader: + reader_function: ppocr.data.cls.dataset_traversal,SimpleReader diff --git a/deploy/cpp_infer/include/config.h b/deploy/cpp_infer/include/config.h index 2adefb73..9dc95eb8 100644 --- a/deploy/cpp_infer/include/config.h +++ b/deploy/cpp_infer/include/config.h @@ -55,6 +55,10 @@ public: this->char_list_file.assign(config_map_["char_list_file"]); + this->cls_model_dir.assign(config_map_["cls_model_dir"]); + + this->cls_thresh = stod(config_map_["cls_thresh"]); + this->visualize = bool(stoi(config_map_["visualize"])); } @@ -82,6 +86,10 @@ public: std::string char_list_file; + std::string cls_model_dir; + + double cls_thresh; + bool visualize = true; void PrintConfigInfo(); diff --git a/deploy/cpp_infer/include/ocr_cls.h b/deploy/cpp_infer/include/ocr_cls.h new file mode 100644 index 00000000..4d8f2a13 --- /dev/null +++ b/deploy/cpp_infer/include/ocr_cls.h @@ -0,0 +1,79 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" +#include "paddle_api.h" +#include "paddle_inference_api.h" +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +namespace PaddleOCR { + +class Classifier { +public: + explicit Classifier(const std::string &model_dir, const bool &use_gpu, + const int &gpu_id, const int &gpu_mem, + const int &cpu_math_library_num_threads, + const bool &use_mkldnn, const double &cls_thresh) { + this->use_gpu_ = use_gpu; + this->gpu_id_ = gpu_id; + this->gpu_mem_ = gpu_mem; + this->cpu_math_library_num_threads_ = cpu_math_library_num_threads; + this->use_mkldnn_ = use_mkldnn; + + this->cls_thresh = cls_thresh; + + LoadModel(model_dir); + } + + // Load Paddle inference model + void LoadModel(const std::string &model_dir); + + cv::Mat Run(cv::Mat &img); + +private: + std::shared_ptr predictor_; + + bool use_gpu_ = false; + int gpu_id_ = 0; + int gpu_mem_ = 4000; + int cpu_math_library_num_threads_ = 4; + bool use_mkldnn_ = false; + + double cls_thresh = 0.5; + + std::vector mean_ = {0.5f, 0.5f, 0.5f}; + std::vector scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; + bool is_scale_ = true; + + // pre-process + ClsResizeImg resize_op_; + Normalize normalize_op_; + Permute permute_op_; + +}; // class Classifier + +} // namespace PaddleOCR diff --git a/deploy/cpp_infer/include/ocr_rec.h b/deploy/cpp_infer/include/ocr_rec.h index 471aeb58..d2180b33 100644 --- a/deploy/cpp_infer/include/ocr_rec.h +++ b/deploy/cpp_infer/include/ocr_rec.h @@ -27,6 +27,7 @@ #include #include +#include #include #include #include @@ -54,7 +55,8 @@ public: // Load Paddle inference model void LoadModel(const std::string &model_dir); - void Run(std::vector>> boxes, cv::Mat &img); + void Run(std::vector>> boxes, cv::Mat &img, + Classifier &cls); private: std::shared_ptr predictor_; diff --git a/deploy/cpp_infer/include/preprocess_op.h b/deploy/cpp_infer/include/preprocess_op.h index 309d7fd4..5cbc5cd7 100644 --- a/deploy/cpp_infer/include/preprocess_op.h +++ b/deploy/cpp_infer/include/preprocess_op.h @@ -56,4 +56,10 @@ public: const std::vector &rec_image_shape = {3, 32, 320}); }; +class ClsResizeImg { +public: + virtual void Run(const cv::Mat &img, cv::Mat &resize_img, + const std::vector &rec_image_shape = {3, 32, 320}); +}; + } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/cpp_infer/src/main.cpp b/deploy/cpp_infer/src/main.cpp index 27c98e5b..d5c399fa 100644 --- a/deploy/cpp_infer/src/main.cpp +++ b/deploy/cpp_infer/src/main.cpp @@ -53,6 +53,9 @@ int main(int argc, char **argv) { config.use_mkldnn, config.max_side_len, config.det_db_thresh, config.det_db_box_thresh, config.det_db_unclip_ratio, config.visualize); + Classifier cls(config.cls_model_dir, config.use_gpu, config.gpu_id, + config.gpu_mem, config.cpu_math_library_num_threads, + config.use_mkldnn, config.cls_thresh); CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id, config.gpu_mem, config.cpu_math_library_num_threads, config.use_mkldnn, config.char_list_file); @@ -61,7 +64,7 @@ int main(int argc, char **argv) { std::vector>> boxes; det.Run(srcimg, boxes); - rec.Run(boxes, srcimg); + rec.Run(boxes, srcimg, cls); auto end = std::chrono::system_clock::now(); auto duration = diff --git a/deploy/cpp_infer/src/ocr_cls.cpp b/deploy/cpp_infer/src/ocr_cls.cpp new file mode 100644 index 00000000..15604fe2 --- /dev/null +++ b/deploy/cpp_infer/src/ocr_cls.cpp @@ -0,0 +1,100 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +namespace PaddleOCR { + +cv::Mat Classifier::Run(cv::Mat &img) { + cv::Mat src_img; + img.copyTo(src_img); + cv::Mat resize_img; + + std::vector rec_image_shape = {3, 32, 100}; + int index = 0; + float wh_ratio = float(img.cols) / float(img.rows); + + this->resize_op_.Run(img, resize_img, rec_image_shape); + + this->normalize_op_.Run(&resize_img, this->mean_, this->scale_, + this->is_scale_); + + std::vector input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f); + + this->permute_op_.Run(&resize_img, input.data()); + + auto input_names = this->predictor_->GetInputNames(); + auto input_t = this->predictor_->GetInputTensor(input_names[0]); + input_t->Reshape({1, 3, resize_img.rows, resize_img.cols}); + input_t->copy_from_cpu(input.data()); + + this->predictor_->ZeroCopyRun(); + + std::vector softmax_out; + std::vector label_out; + auto output_names = this->predictor_->GetOutputNames(); + auto softmax_out_t = this->predictor_->GetOutputTensor(output_names[0]); + auto label_out_t = this->predictor_->GetOutputTensor(output_names[1]); + auto softmax_shape_out = softmax_out_t->shape(); + auto label_shape_out = label_out_t->shape(); + + int softmax_out_num = + std::accumulate(softmax_shape_out.begin(), softmax_shape_out.end(), 1, + std::multiplies()); + + int label_out_num = + std::accumulate(label_shape_out.begin(), label_shape_out.end(), 1, + std::multiplies()); + softmax_out.resize(softmax_out_num); + label_out.resize(label_out_num); + + softmax_out_t->copy_to_cpu(softmax_out.data()); + label_out_t->copy_to_cpu(label_out.data()); + + int label = label_out[0]; + float score = softmax_out[label]; + // std::cout << "\nlabel "< this->cls_thresh) { + cv::rotate(src_img, src_img, 1); + } + return src_img; +} + +void Classifier::LoadModel(const std::string &model_dir) { + AnalysisConfig config; + config.SetModel(model_dir + "/model", model_dir + "/params"); + + if (this->use_gpu_) { + config.EnableUseGpu(this->gpu_mem_, this->gpu_id_); + } else { + config.DisableGpu(); + if (this->use_mkldnn_) { + config.EnableMKLDNN(); + } + config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_); + } + + // false for zero copy tensor + config.SwitchUseFeedFetchOps(false); + // true for multiple input + config.SwitchSpecifyInputNames(true); + + config.SwitchIrOptim(true); + + config.EnableMemoryOptim(); + config.DisableGlogInfo(); + + this->predictor_ = CreatePaddlePredictor(config); +} +} // namespace PaddleOCR diff --git a/deploy/cpp_infer/src/ocr_rec.cpp b/deploy/cpp_infer/src/ocr_rec.cpp index bbd7b9b2..8b5eaf9c 100644 --- a/deploy/cpp_infer/src/ocr_rec.cpp +++ b/deploy/cpp_infer/src/ocr_rec.cpp @@ -17,7 +17,7 @@ namespace PaddleOCR { void CRNNRecognizer::Run(std::vector>> boxes, - cv::Mat &img) { + cv::Mat &img, Classifier &cls) { cv::Mat srcimg; img.copyTo(srcimg); cv::Mat crop_img; @@ -28,6 +28,8 @@ void CRNNRecognizer::Run(std::vector>> boxes, for (int i = boxes.size() - 1; i >= 0; i--) { crop_img = GetRotateCropImage(srcimg, boxes[i]); + crop_img = cls.Run(crop_img); + float wh_ratio = float(crop_img.cols) / float(crop_img.rows); this->resize_op_.Run(crop_img, resize_img, wh_ratio); diff --git a/deploy/cpp_infer/src/preprocess_op.cpp b/deploy/cpp_infer/src/preprocess_op.cpp index 0078063e..b44e9d02 100644 --- a/deploy/cpp_infer/src/preprocess_op.cpp +++ b/deploy/cpp_infer/src/preprocess_op.cpp @@ -116,4 +116,26 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio, cv::INTER_LINEAR); } +void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, + const std::vector &rec_image_shape) { + int imgC, imgH, imgW; + imgC = rec_image_shape[0]; + imgH = rec_image_shape[1]; + imgW = rec_image_shape[2]; + + float ratio = float(img.cols) / float(img.rows); + int resize_w, resize_h; + if (ceilf(imgH * ratio) > imgW) + resize_w = imgW; + else + resize_w = int(ceilf(imgH * ratio)); + + cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, + cv::INTER_LINEAR); + if (resize_w < imgW) { + cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w, + cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0)); + } +} + } // namespace PaddleOCR \ No newline at end of file diff --git a/deploy/lite/Makefile b/deploy/lite/Makefile index 96e05ecf..4c30d644 100644 --- a/deploy/lite/Makefile +++ b/deploy/lite/Makefile @@ -40,8 +40,8 @@ CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SY #CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS) -ocr_db_crnn: fetch_opencv ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o - $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o -o ocr_db_crnn $(CXX_LIBS) $(LDFLAGS) +ocr_db_crnn: fetch_opencv ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o cls_process.o + $(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o cls_process.o -o ocr_db_crnn $(CXX_LIBS) $(LDFLAGS) ocr_db_crnn.o: ocr_db_crnn.cc $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o ocr_db_crnn.o -c ocr_db_crnn.cc @@ -49,6 +49,9 @@ ocr_db_crnn.o: ocr_db_crnn.cc crnn_process.o: fetch_opencv crnn_process.cc $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o crnn_process.o -c crnn_process.cc +cls_process.o: fetch_opencv cls_process.cc + $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o cls_process.o -c cls_process.cc + db_post_process.o: fetch_clipper fetch_opencv db_post_process.cc $(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o db_post_process.o -c db_post_process.cc @@ -73,5 +76,5 @@ fetch_opencv: .PHONY: clean clean: - rm -f ocr_db_crnn.o clipper.o db_post_process.o crnn_process.o + rm -f ocr_db_crnn.o clipper.o db_post_process.o crnn_process.o cls_process.o rm -f ocr_db_crnn diff --git a/deploy/lite/cls_process.cc b/deploy/lite/cls_process.cc new file mode 100644 index 00000000..f522e4bc --- /dev/null +++ b/deploy/lite/cls_process.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cls_process.h" //NOLINT +#include +#include +#include + +const std::vector rec_image_shape{3, 32, 100}; + +cv::Mat ClsResizeImg(cv::Mat img) { + int imgC, imgH, imgW; + imgC = rec_image_shape[0]; + imgH = rec_image_shape[1]; + imgW = rec_image_shape[2]; + + float ratio = static_cast(img.cols) / static_cast(img.rows); + + int resize_w, resize_h; + if (ceilf(imgH * ratio) > imgW) + resize_w = imgW; + else + resize_w = int(ceilf(imgH * ratio)); + cv::Mat resize_img; + cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, + cv::INTER_LINEAR); + if (resize_w < imgW) { + cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w, + cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0)); + } + return resize_img; +} \ No newline at end of file diff --git a/deploy/lite/cls_process.h b/deploy/lite/cls_process.h new file mode 100644 index 00000000..eedeeb9b --- /dev/null +++ b/deploy/lite/cls_process.h @@ -0,0 +1,29 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "math.h" //NOLINT +#include "opencv2/core.hpp" +#include "opencv2/imgcodecs.hpp" +#include "opencv2/imgproc.hpp" + +cv::Mat ClsResizeImg(cv::Mat img); \ No newline at end of file diff --git a/deploy/lite/ocr_db_crnn.cc b/deploy/lite/ocr_db_crnn.cc index c94062fd..fea093c3 100644 --- a/deploy/lite/ocr_db_crnn.cc +++ b/deploy/lite/ocr_db_crnn.cc @@ -15,6 +15,7 @@ #include "paddle_api.h" // NOLINT #include +#include "cls_process.h" #include "crnn_process.h" #include "db_post_process.h" @@ -105,11 +106,55 @@ cv::Mat DetResizeImg(const cv::Mat img, int max_size_len, return resize_img; } +cv::Mat RunClsModel(cv::Mat img, std::shared_ptr predictor_cls, + const float thresh = 0.5) { + std::vector mean = {0.5f, 0.5f, 0.5f}; + std::vector scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; + + cv::Mat srcimg; + img.copyTo(srcimg); + cv::Mat crop_img; + cv::Mat resize_img; + + int index = 0; + float wh_ratio = + static_cast(crop_img.cols) / static_cast(crop_img.rows); + + resize_img = ClsResizeImg(crop_img); + resize_img.convertTo(resize_img, CV_32FC3, 1 / 255.f); + + const float *dimg = reinterpret_cast(resize_img.data); + + std::unique_ptr input_tensor0(std::move(predictor_cls->GetInput(0))); + input_tensor0->Resize({1, 3, resize_img.rows, resize_img.cols}); + auto *data0 = input_tensor0->mutable_data(); + + NeonMeanScale(dimg, data0, resize_img.rows * resize_img.cols, mean, scale); + // Run CLS predictor + predictor_cls->Run(); + + // Get output and run postprocess + std::unique_ptr softmax_out( + std::move(predictor_cls->GetOutput(0))); + std::unique_ptr label_out( + std::move(predictor_cls->GetOutput(1))); + auto *softmax_scores = softmax_out->mutable_data(); + auto *label_idxs = label_out->data(); + int label_idx = label_idxs[0]; + float score = softmax_scores[label_idx]; + + if (label_idx % 2 == 1 && score > thresh) { + cv::rotate(srcimg, srcimg, 1); + } + return srcimg; +} + void RunRecModel(std::vector>> boxes, cv::Mat img, std::shared_ptr predictor_crnn, std::vector &rec_text, std::vector &rec_text_score, - std::vector charactor_dict) { + std::vector charactor_dict, + std::shared_ptr predictor_cls) { std::vector mean = {0.5f, 0.5f, 0.5f}; std::vector scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f}; @@ -121,6 +166,7 @@ void RunRecModel(std::vector>> boxes, cv::Mat img, int index = 0; for (int i = boxes.size() - 1; i >= 0; i--) { crop_img = GetRotateCropImage(srcimg, boxes[i]); + crop_img = RunClsModel(crop_img, predictor_cls); float wh_ratio = static_cast(crop_img.cols) / static_cast(crop_img.rows); @@ -323,8 +369,9 @@ int main(int argc, char **argv) { } std::string det_model_file = argv[1]; std::string rec_model_file = argv[2]; - std::string img_path = argv[3]; - std::string dict_path = argv[4]; + std::string cls_model_file = argv[3]; + std::string img_path = argv[4]; + std::string dict_path = argv[5]; //// load config from txt file auto Config = LoadConfigTxt("./config.txt"); @@ -333,6 +380,7 @@ int main(int argc, char **argv) { auto det_predictor = loadModel(det_model_file); auto rec_predictor = loadModel(rec_model_file); + auto cls_predictor = loadModel(cls_model_file); auto charactor_dict = ReadDict(dict_path); charactor_dict.push_back(" "); @@ -343,7 +391,7 @@ int main(int argc, char **argv) { std::vector rec_text; std::vector rec_text_score; RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score, - charactor_dict); + charactor_dict, cls_predictor); auto end = std::chrono::system_clock::now(); auto duration = diff --git a/ppocr/data/cls/__init__.py b/ppocr/data/cls/__init__.py new file mode 100755 index 00000000..abf198b9 --- /dev/null +++ b/ppocr/data/cls/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/ppocr/data/cls/dataset_traversal.py b/ppocr/data/cls/dataset_traversal.py new file mode 100755 index 00000000..fa688f46 --- /dev/null +++ b/ppocr/data/cls/dataset_traversal.py @@ -0,0 +1,128 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import random +import numpy as np +import cv2 + +from ppocr.utils.utility import initial_logger +from ppocr.utils.utility import get_image_file_list + +logger = initial_logger() + +from ppocr.data.rec.img_tools import warp, resize_norm_img + + +class SimpleReader(object): + def __init__(self, params): + if params['mode'] != 'train': + self.num_workers = 1 + else: + self.num_workers = params['num_workers'] + if params['mode'] != 'test': + self.img_set_dir = params['img_set_dir'] + self.label_file_path = params['label_file_path'] + self.use_gpu = params['use_gpu'] + self.image_shape = params['image_shape'] + self.mode = params['mode'] + self.infer_img = params['infer_img'] + self.use_distort = False + self.label_list = params['label_list'] + if "distort" in params: + self.use_distort = params['distort'] and params['use_gpu'] + if not params['use_gpu']: + logger.info( + "Distort operation can only support in GPU.Distort will be set to False." + ) + if params['mode'] == 'train': + self.batch_size = params['train_batch_size_per_card'] + self.drop_last = True + else: + self.batch_size = params['test_batch_size_per_card'] + self.drop_last = False + self.use_distort = False + + def __call__(self, process_id): + if self.mode != 'train': + process_id = 0 + + def get_device_num(): + if self.use_gpu: + gpus = os.environ.get("CUDA_VISIBLE_DEVICES", 1) + gpu_num = len(gpus.split(',')) + return gpu_num + else: + cpu_num = os.environ.get("CPU_NUM", 1) + return int(cpu_num) + + def sample_iter_reader(): + if self.mode != 'train' and self.infer_img is not None: + image_file_list = get_image_file_list(self.infer_img) + for single_img in image_file_list: + img = cv2.imread(single_img) + if img.shape[-1] == 1 or len(list(img.shape)) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + norm_img = resize_norm_img(img, self.image_shape) + norm_img = norm_img[np.newaxis, :] + yield norm_img + else: + with open(self.label_file_path, "rb") as fin: + label_infor_list = fin.readlines() + img_num = len(label_infor_list) + img_id_list = list(range(img_num)) + random.shuffle(img_id_list) + if sys.platform == "win32" and self.num_workers != 1: + print("multiprocess is not fully compatible with Windows." + "num_workers will be 1.") + self.num_workers = 1 + if self.batch_size * get_device_num( + ) * self.num_workers > img_num: + raise Exception( + "The number of the whole data ({}) is smaller than the batch_size * devices_num * num_workers ({})". + format(img_num, self.batch_size * get_device_num() * + self.num_workers)) + for img_id in range(process_id, img_num, self.num_workers): + label_infor = label_infor_list[img_id_list[img_id]] + substr = label_infor.decode('utf-8').strip("\n").split("\t") + img_path = self.img_set_dir + "/" + substr[0] + img = cv2.imread(img_path) + if img is None: + logger.info("{} does not exist!".format(img_path)) + continue + if img.shape[-1] == 1 or len(list(img.shape)) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + label = substr[1] + if self.use_distort: + img = warp(img, 10) + norm_img = resize_norm_img(img, self.image_shape) + norm_img = norm_img[np.newaxis, :] + yield (norm_img, self.label_list.index(int(label))) + + def batch_iter_reader(): + batch_outs = [] + for outs in sample_iter_reader(): + batch_outs.append(outs) + if len(batch_outs) == self.batch_size: + yield batch_outs + batch_outs = [] + if not self.drop_last: + if len(batch_outs) != 0: + yield batch_outs + + if self.infer_img is None: + return batch_iter_reader + return sample_iter_reader diff --git a/ppocr/modeling/architectures/cls_model.py b/ppocr/modeling/architectures/cls_model.py new file mode 100755 index 00000000..6df20770 --- /dev/null +++ b/ppocr/modeling/architectures/cls_model.py @@ -0,0 +1,84 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import fluid + +from ppocr.utils.utility import create_module +from ppocr.utils.utility import initial_logger + +logger = initial_logger() +from copy import deepcopy + + +class ClsModel(object): + def __init__(self, params): + super(ClsModel, self).__init__() + global_params = params['Global'] + self.infer_img = global_params['infer_img'] + + backbone_params = deepcopy(params["Backbone"]) + backbone_params.update(global_params) + self.backbone = create_module(backbone_params['function']) \ + (params=backbone_params) + + head_params = deepcopy(params["Head"]) + head_params.update(global_params) + self.head = create_module(head_params['function']) \ + (params=head_params) + + loss_params = deepcopy(params["Loss"]) + loss_params.update(global_params) + self.loss = create_module(loss_params['function']) \ + (params=loss_params) + + self.image_shape = global_params['image_shape'] + + def create_feed(self, mode): + image_shape = deepcopy(self.image_shape) + image_shape.insert(0, -1) + if mode == "train": + image = fluid.data(name='image', shape=image_shape, dtype='float32') + label = fluid.data(name='label', shape=[None, 1], dtype='int64') + feed_list = [image, label] + labels = {'label': label} + loader = fluid.io.DataLoader.from_generator( + feed_list=feed_list, + capacity=64, + use_double_buffer=True, + iterable=False) + else: + labels = None + loader = None + image = fluid.data(name='image', shape=image_shape, dtype='float32') + return image, labels, loader + + def __call__(self, mode): + image, labels, loader = self.create_feed(mode) + inputs = image + conv_feas = self.backbone(inputs) + predicts = self.head(conv_feas, labels, mode) + if mode == "train": + loss = self.loss(predicts, labels) + label = labels['label'] + acc = fluid.layers.accuracy(predicts['predict'], label, k=1) + outputs = {'total_loss': loss, 'decoded_out': \ + predicts['decoded_out'], 'label': label, 'acc': acc} + return loader, outputs + + else: + return loader, predicts diff --git a/ppocr/modeling/heads/cls_head.py b/ppocr/modeling/heads/cls_head.py new file mode 100644 index 00000000..4567adcb --- /dev/null +++ b/ppocr/modeling/heads/cls_head.py @@ -0,0 +1,46 @@ +#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +import paddle.fluid as fluid + + +class ClsHead(object): + def __init__(self, params): + super(ClsHead, self).__init__() + self.class_dim = params['class_dim'] + + def __call__(self, inputs, labels=None, mode=None): + pool = fluid.layers.pool2d( + input=inputs, pool_type='avg', global_pooling=True) + stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) + + out = fluid.layers.fc( + input=pool, + size=self.class_dim, + param_attr=fluid.param_attr.ParamAttr( + name="fc_0.w_0", + initializer=fluid.initializer.Uniform(-stdv, stdv)), + bias_attr=fluid.param_attr.ParamAttr(name="fc_0.b_0")) + + softmax_out = fluid.layers.softmax(out, use_cudnn=False) + out_label = fluid.layers.argmax(out, axis=1) + predicts = {'predict': softmax_out, 'decoded_out': out_label} + return predicts diff --git a/ppocr/modeling/losses/cls_loss.py b/ppocr/modeling/losses/cls_loss.py new file mode 100755 index 00000000..c187dce3 --- /dev/null +++ b/ppocr/modeling/losses/cls_loss.py @@ -0,0 +1,33 @@ +# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle.fluid as fluid + + +class ClsLoss(object): + def __init__(self, params): + super(ClsLoss, self).__init__() + self.loss_func = fluid.layers.cross_entropy + + def __call__(self, predicts, labels): + predict = predicts['predict'] + label = labels['label'] + # softmax_out = fluid.layers.softmax(predict, use_cudnn=False) + cost = fluid.layers.cross_entropy(input=predict, label=label) + sum_cost = fluid.layers.mean(cost) + return sum_cost diff --git a/tools/eval.py b/tools/eval.py index edd84a9d..041e825e 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -45,10 +45,12 @@ from ppocr.utils.save_load import init_model from eval_utils.eval_det_utils import eval_det_run from eval_utils.eval_rec_utils import test_rec_benchmark from eval_utils.eval_rec_utils import eval_rec_run +from eval_utils.eval_cls_utils import eval_cls_run def main(): - startup_prog, eval_program, place, config, train_alg_type = program.preprocess() + startup_prog, eval_program, place, config, train_alg_type = program.preprocess( + ) eval_build_outputs = program.build( config, eval_program, startup_prog, mode='test') eval_fetch_name_list = eval_build_outputs[1] @@ -67,6 +69,14 @@ def main(): 'fetch_varname_list':eval_fetch_varname_list} metrics = eval_det_run(exe, config, eval_info_dict, "eval") logger.info("Eval result: {}".format(metrics)) + elif train_alg_type == 'cls': + eval_reader = reader_main(config=config, mode="eval") + eval_info_dict = {'program': eval_program, \ + 'reader': eval_reader, \ + 'fetch_name_list': eval_fetch_name_list, \ + 'fetch_varname_list': eval_fetch_varname_list} + metrics = eval_cls_run(exe, eval_info_dict) + logger.info("Eval result: {}".format(metrics)) else: reader_type = config['Global']['reader_yml'] if "benchmark" not in reader_type: diff --git a/tools/eval_utils/eval_cls_utils.py b/tools/eval_utils/eval_cls_utils.py new file mode 100644 index 00000000..80a13111 --- /dev/null +++ b/tools/eval_utils/eval_cls_utils.py @@ -0,0 +1,72 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging +import numpy as np + +import paddle.fluid as fluid + +__all__ = ['eval_class_run'] + +import logging + +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def eval_cls_run(exe, eval_info_dict): + """ + Run evaluation program, return program outputs. + """ + total_sample_num = 0 + total_acc_num = 0 + total_batch_num = 0 + + for data in eval_info_dict['reader'](): + img_num = len(data) + img_list = [] + label_list = [] + for ino in range(img_num): + img_list.append(data[ino][0]) + label_list.append(data[ino][1]) + + img_list = np.concatenate(img_list, axis=0) + outs = exe.run(eval_info_dict['program'], \ + feed={'image': img_list}, \ + fetch_list=eval_info_dict['fetch_varname_list'], \ + return_numpy=False) + softmax_outs = np.array(outs[1]) + + acc, acc_num = cal_cls_acc(softmax_outs, label_list) + total_acc_num += acc_num + total_sample_num += len(label_list) + # logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc)) + total_batch_num += 1 + avg_acc = total_acc_num * 1.0 / total_sample_num + metrics = {'avg_acc': avg_acc, "total_acc_num": total_acc_num, \ + "total_sample_num": total_sample_num} + return metrics + + +def cal_cls_acc(preds, labels): + acc_num = 0 + for pred, label in zip(preds, labels): + if pred == label: + acc_num += 1 + return acc_num / len(preds), acc_num diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py new file mode 100755 index 00000000..d4434445 --- /dev/null +++ b/tools/infer/predict_cls.py @@ -0,0 +1,143 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) + +import tools.infer.utility as utility +from ppocr.utils.utility import initial_logger + +logger = initial_logger() +from ppocr.utils.utility import get_image_file_list, check_and_read_gif +import cv2 +import copy +import numpy as np +import math +import time + + +class TextClassifier(object): + def __init__(self, args): + self.predictor, self.input_tensor, self.output_tensors = \ + utility.create_predictor(args, mode="cls") + self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")] + self.cls_batch_num = args.rec_batch_num + self.label_list = args.label_list + + def resize_norm_img(self, img): + imgC, imgH, imgW = self.cls_image_shape + h = img.shape[0] + w = img.shape[1] + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + if self.cls_image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) + padding_im[:, :, 0:resized_w] = resized_image + return padding_im + + def __call__(self, img_list): + img_list = copy.deepcopy(img_list) + img_num = len(img_list) + # Calculate the aspect ratio of all text bars + width_list = [] + for img in img_list: + width_list.append(img.shape[1] / float(img.shape[0])) + # Sorting can speed up the cls process + indices = np.argsort(np.array(width_list)) + + cls_res = [['', 0.0]] * img_num + batch_num = self.cls_batch_num + predict_time = 0 + for beg_img_no in range(0, img_num, batch_num): + end_img_no = min(img_num, beg_img_no + batch_num) + norm_img_batch = [] + max_wh_ratio = 0 + for ino in range(beg_img_no, end_img_no): + h, w = img_list[indices[ino]].shape[0:2] + wh_ratio = w * 1.0 / h + max_wh_ratio = max(max_wh_ratio, wh_ratio) + for ino in range(beg_img_no, end_img_no): + norm_img = self.resize_norm_img(img_list[indices[ino]]) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + norm_img_batch = np.concatenate(norm_img_batch) + norm_img_batch = norm_img_batch.copy() + starttime = time.time() + + self.input_tensor.copy_from_cpu(norm_img_batch) + self.predictor.zero_copy_run() + + prob_out = self.output_tensors[0].copy_to_cpu() + label_out = self.output_tensors[1].copy_to_cpu() + + elapse = time.time() - starttime + predict_time += elapse + for rno in range(len(label_out)): + label_idx = label_out[rno] + score = prob_out[rno][label_idx] + label = self.label_list[label_idx] + cls_res[indices[beg_img_no + rno]] = [label, score] + if label == 180: + img_list[indices[beg_img_no + rno]] = cv2.rotate( + img_list[indices[beg_img_no + rno]], 1) + return img_list, cls_res, predict_time + + +def main(args): + image_file_list = get_image_file_list(args.image_dir) + text_classifier = TextClassifier(args) + valid_image_file_list = [] + img_list = [] + for image_file in image_file_list[:10]: + img, flag = check_and_read_gif(image_file) + if not flag: + img = cv2.imread(image_file) + if img is None: + logger.info("error in loading image:{}".format(image_file)) + continue + valid_image_file_list.append(image_file) + img_list.append(img) + try: + img_list, cls_res, predict_time = text_classifier(img_list) + print(cls_res) + from matplotlib import pyplot as plt + for img, angle in zip(img_list, cls_res): + plt.title(str(angle)) + plt.imshow(img) + plt.show() + except Exception as e: + print(e) + exit() + for ino in range(len(img_list)): + print("Predicts of %s:%s" % (valid_image_file_list[ino], cls_res[ino])) + print("Total predict time for %d images:%.3f" % + (len(img_list), predict_time)) + + +if __name__ == "__main__": + main(utility.parse_args()) diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index f8a62679..c34fb963 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -13,16 +13,19 @@ # limitations under the License. import os import sys + __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) import tools.infer.utility as utility from ppocr.utils.utility import initial_logger + logger = initial_logger() import cv2 import tools.infer.predict_det as predict_det import tools.infer.predict_rec as predict_rec +import tools.infer.predict_cls as predict_cls import copy import numpy as np import math @@ -37,6 +40,7 @@ class TextSystem(object): def __init__(self, args): self.text_detector = predict_det.TextDetector(args) self.text_recognizer = predict_rec.TextRecognizer(args) + self.text_classifier = predict_cls.TextClassifier(args) def get_rotate_crop_image(self, img, points): ''' @@ -91,7 +95,10 @@ class TextSystem(object): tmp_box = copy.deepcopy(dt_boxes[bno]) img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop_list.append(img_crop) - rec_res, elapse = self.text_recognizer(img_crop_list) + img_rotate_list, angle_list, elapse = self.text_classifier( + img_crop_list) + print("cls num : {}, elapse : {}".format(len(img_rotate_list), elapse)) + rec_res, elapse = self.text_recognizer(img_rotate_list) print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse)) # self.print_draw_crop_rec_res(img_crop_list, rec_res) return dt_boxes, rec_res @@ -110,8 +117,8 @@ def sorted_boxes(dt_boxes): _boxes = list(sorted_boxes) for i in range(num_boxes - 1): - if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ - (_boxes[i + 1][0][0] < _boxes[i][0][0]): + if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \ + (_boxes[i + 1][0][0] < _boxes[i][0][0]): tmp = _boxes[i] _boxes[i] = _boxes[i + 1] _boxes[i + 1] = tmp diff --git a/tools/infer/utility.py b/tools/infer/utility.py index b0a0ec1f..bde7a41c 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -65,6 +65,13 @@ def parse_args(): type=str, default="./ppocr/utils/ppocr_keys_v1.txt") parser.add_argument("--use_space_char", type=bool, default=True) + + # params for text classifier + parser.add_argument("--cls_model_dir", type=str) + parser.add_argument("--cls_image_shape", type=str, default="3, 32, 100") + parser.add_argument("--label_list", type=list, default=[0, 180]) + parser.add_argument("--cls_batch_num", type=int, default=30) + parser.add_argument("--enable_mkldnn", type=bool, default=False) return parser.parse_args() @@ -72,6 +79,8 @@ def parse_args(): def create_predictor(args, mode): if mode == "det": model_dir = args.det_model_dir + elif mode == 'cls': + model_dir = args.cls_model_dir else: model_dir = args.rec_model_dir diff --git a/tools/infer_cls.py b/tools/infer_cls.py new file mode 100755 index 00000000..443b1e05 --- /dev/null +++ b/tools/infer_cls.py @@ -0,0 +1,109 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import os +import sys +__dir__ = os.path.dirname(__file__) +sys.path.append(__dir__) +sys.path.append(os.path.join(__dir__, '..')) + + +def set_paddle_flags(**kwargs): + for key, value in kwargs.items(): + if os.environ.get(key, None) is None: + os.environ[key] = str(value) + + +# NOTE(paddle-dev): All of these flags should be +# set before `import paddle`. Otherwise, it would +# not take any effect. +set_paddle_flags( + FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory +) + +import tools.program as program +from paddle import fluid +from ppocr.utils.utility import initial_logger +logger = initial_logger() +from ppocr.data.reader_main import reader_main +from ppocr.utils.save_load import init_model +from ppocr.utils.utility import create_module +from ppocr.utils.utility import get_image_file_list + + +def main(): + config = program.load_config(FLAGS.config) + program.merge_config(FLAGS.opt) + logger.info(config) + + # check if set use_gpu=True in paddlepaddle cpu version + use_gpu = config['Global']['use_gpu'] + # check_gpu(use_gpu) + + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + + rec_model = create_module(config['Architecture']['function'])(params=config) + startup_prog = fluid.Program() + eval_prog = fluid.Program() + with fluid.program_guard(eval_prog, startup_prog): + with fluid.unique_name.guard(): + _, outputs = rec_model(mode="test") + fetch_name_list = list(outputs.keys()) + fetch_varname_list = [outputs[v].name for v in fetch_name_list] + eval_prog = eval_prog.clone(for_test=True) + exe.run(startup_prog) + + init_model(config, eval_prog, exe) + + blobs = reader_main(config, 'test')() + infer_img = config['Global']['infer_img'] + infer_list = get_image_file_list(infer_img) + max_img_num = len(infer_list) + if len(infer_list) == 0: + logger.info("Can not find img in infer_img dir.") + for i in range(max_img_num): + logger.info("infer_img:%s" % infer_list[i]) + img = next(blobs) + predict = exe.run(program=eval_prog, + feed={"image": img}, + fetch_list=fetch_varname_list, + return_numpy=False) + for k in predict: + k = np.array(k) + print(k) + # save for inference model + target_var = [] + for key, values in outputs.items(): + target_var.append(values) + + fluid.io.save_inference_model( + "./output", + feeded_var_names=['image'], + target_vars=target_var, + executor=exe, + main_program=eval_prog, + model_filename="model", + params_filename="params") + + +if __name__ == '__main__': + parser = program.ArgsParser() + FLAGS = parser.parse_args() + main() diff --git a/tools/program.py b/tools/program.py index 6d8b9937..34e3419c 100755 --- a/tools/program.py +++ b/tools/program.py @@ -30,6 +30,7 @@ import time from ppocr.utils.stats import TrainingStats from eval_utils.eval_det_utils import eval_det_run from eval_utils.eval_rec_utils import eval_rec_run +from eval_utils.eval_cls_utils import eval_cls_run from ppocr.utils.save_load import save_model import numpy as np from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps @@ -398,6 +399,87 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict): return +def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict): + train_batch_id = 0 + log_smooth_window = config['Global']['log_smooth_window'] + epoch_num = config['Global']['epoch_num'] + print_batch_step = config['Global']['print_batch_step'] + eval_batch_step = config['Global']['eval_batch_step'] + start_eval_step = 0 + if type(eval_batch_step) == list and len(eval_batch_step) >= 2: + start_eval_step = eval_batch_step[0] + eval_batch_step = eval_batch_step[1] + logger.info( + "During the training process, after the {}th iteration, an evaluation is run every {} iterations". + format(start_eval_step, eval_batch_step)) + save_epoch_step = config['Global']['save_epoch_step'] + save_model_dir = config['Global']['save_model_dir'] + if not os.path.exists(save_model_dir): + os.makedirs(save_model_dir) + train_stats = TrainingStats(log_smooth_window, ['loss', 'acc']) + best_eval_acc = -1 + best_batch_id = 0 + best_epoch = 0 + train_loader = train_info_dict['reader'] + for epoch in range(epoch_num): + train_loader.start() + try: + while True: + t1 = time.time() + train_outs = exe.run( + program=train_info_dict['compile_program'], + fetch_list=train_info_dict['fetch_varname_list'], + return_numpy=False) + fetch_map = dict( + zip(train_info_dict['fetch_name_list'], + range(len(train_outs)))) + + loss = np.mean(np.array(train_outs[fetch_map['total_loss']])) + lr = np.mean(np.array(train_outs[fetch_map['lr']])) + acc = np.mean(np.array(train_outs[fetch_map['acc']])) + + t2 = time.time() + train_batch_elapse = t2 - t1 + stats = {'loss': loss, 'acc': acc} + train_stats.update(stats) + if train_batch_id > start_eval_step and (train_batch_id - start_eval_step) \ + % print_batch_step == 0: + logs = train_stats.log() + strs = 'epoch: {}, iter: {}, lr: {:.6f}, {}, time: {:.3f}'.format( + epoch, train_batch_id, lr, logs, train_batch_elapse) + logger.info(strs) + + if train_batch_id > 0 and\ + train_batch_id % eval_batch_step == 0: + model_average = train_info_dict['model_average'] + if model_average != None: + model_average.apply(exe) + metrics = eval_cls_run(exe, eval_info_dict) + eval_acc = metrics['avg_acc'] + eval_sample_num = metrics['total_sample_num'] + if eval_acc > best_eval_acc: + best_eval_acc = eval_acc + best_batch_id = train_batch_id + best_epoch = epoch + save_path = save_model_dir + "/best_accuracy" + save_model(train_info_dict['train_program'], save_path) + strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format( + train_batch_id, eval_acc, best_eval_acc, best_epoch, + best_batch_id, eval_sample_num) + logger.info(strs) + train_batch_id += 1 + + except fluid.core.EOFException: + train_loader.reset() + if epoch == 0 and save_epoch_step == 1: + save_path = save_model_dir + "/iter_epoch_0" + save_model(train_info_dict['train_program'], save_path) + if epoch > 0 and epoch % save_epoch_step == 0: + save_path = save_model_dir + "/iter_epoch_%d" % (epoch) + save_model(train_info_dict['train_program'], save_path) + return + + def preprocess(): FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) @@ -409,7 +491,9 @@ def preprocess(): check_gpu(use_gpu) alg = config['Global']['algorithm'] - assert alg in ['EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN'] + assert alg in [ + 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS' + ] if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']: config['Global']['char_ops'] = CharacterOps(config['Global']) @@ -419,7 +503,9 @@ def preprocess(): if alg in ['EAST', 'DB', 'SAST']: train_alg_type = 'det' - else: + elif alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']: train_alg_type = 'rec' + else: + train_alg_type = 'cls' return startup_program, train_program, place, config, train_alg_type diff --git a/tools/train.py b/tools/train.py index 2ea9d0e0..e477d9c3 100755 --- a/tools/train.py +++ b/tools/train.py @@ -75,7 +75,8 @@ def main(): # dump mode structure if config['Global']['debug']: - if train_alg_type == 'rec' and 'attention' in config['Global']['loss_type']: + if train_alg_type == 'rec' and 'attention' in config['Global'][ + 'loss_type']: logger.warning('Does not suport dump attention...') else: summary(train_program) @@ -96,8 +97,10 @@ def main(): if train_alg_type == 'det': program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict) - else: + elif train_alg_type == 'rec': program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict) + else: + program.train_eval_cls_run(config, exe, train_info_dict, eval_info_dict) def test_reader(): @@ -119,6 +122,7 @@ def test_reader(): if __name__ == '__main__': - startup_program, train_program, place, config, train_alg_type = program.preprocess() + startup_program, train_program, place, config, train_alg_type = program.preprocess( + ) main() # test_reader()