添加分类模型

This commit is contained in:
WenmuZhou 2020-09-01 13:44:51 +08:00
parent 7c09c97d70
commit e11b2108fa
27 changed files with 1164 additions and 19 deletions

43
configs/cls/cls_mv3.yml Executable file
View File

@ -0,0 +1,43 @@
Global:
algorithm: CLS
use_gpu: false
epoch_num: 30
log_smooth_window: 20
print_batch_step: 10
save_model_dir: output/cls_mb3
save_epoch_step: 3
eval_batch_step: 100
train_batch_size_per_card: 256
test_batch_size_per_card: 256
image_shape: [3, 32, 100]
label_list: [0,180]
reader_yml: ./configs/cls/cls_reader.yml
pretrain_weights:
checkpoints: /Users/zhoujun20/Desktop/code/class_model/cls_mb3_ultra_small_0.35/best_accuracy
save_inference_dir:
infer_img: /Users/zhoujun20/Desktop/code/PaddleOCR/doc/imgs_words/ch/word_1.jpg
Architecture:
function: ppocr.modeling.architectures.cls_model,ClsModel
Backbone:
function: ppocr.modeling.backbones.rec_mobilenet_v3,MobileNetV3
scale: 0.35
model_name: Ultra_small
Head:
function: ppocr.modeling.heads.cls_head,ClsHead
class_dim: 2
Loss:
function: ppocr.modeling.losses.cls_loss,ClsLoss
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.001
beta1: 0.9
beta2: 0.999
decay:
function: piecewise_decay
boundaries: [20,30]
decay_rate: 0.1

13
configs/cls/cls_reader.yml Executable file
View File

@ -0,0 +1,13 @@
TrainReader:
reader_function: ppocr.data.cls.dataset_traversal,SimpleReader
num_workers: 1
img_set_dir: /
label_file_path: /Users/zhoujun20/Downloads/direction/rotate_ver/train.txt
EvalReader:
reader_function: ppocr.data.cls.dataset_traversal,SimpleReader
img_set_dir: /
label_file_path: /Users/zhoujun20/Downloads/direction/rotate_ver/train.txt
TestReader:
reader_function: ppocr.data.cls.dataset_traversal,SimpleReader

View File

@ -55,6 +55,10 @@ public:
this->char_list_file.assign(config_map_["char_list_file"]);
this->cls_model_dir.assign(config_map_["cls_model_dir"]);
this->cls_thresh = stod(config_map_["cls_thresh"]);
this->visualize = bool(stoi(config_map_["visualize"]));
}
@ -82,6 +86,10 @@ public:
std::string char_list_file;
std::string cls_model_dir;
double cls_thresh;
bool visualize = true;
void PrintConfigInfo();

View File

@ -0,0 +1,79 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
#include "paddle_api.h"
#include "paddle_inference_api.h"
#include <chrono>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <vector>
#include <cstring>
#include <fstream>
#include <numeric>
#include <include/preprocess_op.h>
#include <include/utility.h>
namespace PaddleOCR {
class Classifier {
public:
explicit Classifier(const std::string &model_dir, const bool &use_gpu,
const int &gpu_id, const int &gpu_mem,
const int &cpu_math_library_num_threads,
const bool &use_mkldnn, const double &cls_thresh) {
this->use_gpu_ = use_gpu;
this->gpu_id_ = gpu_id;
this->gpu_mem_ = gpu_mem;
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
this->use_mkldnn_ = use_mkldnn;
this->cls_thresh = cls_thresh;
LoadModel(model_dir);
}
// Load Paddle inference model
void LoadModel(const std::string &model_dir);
cv::Mat Run(cv::Mat &img);
private:
std::shared_ptr<PaddlePredictor> predictor_;
bool use_gpu_ = false;
int gpu_id_ = 0;
int gpu_mem_ = 4000;
int cpu_math_library_num_threads_ = 4;
bool use_mkldnn_ = false;
double cls_thresh = 0.5;
std::vector<float> mean_ = {0.5f, 0.5f, 0.5f};
std::vector<float> scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
bool is_scale_ = true;
// pre-process
ClsResizeImg resize_op_;
Normalize normalize_op_;
Permute permute_op_;
}; // class Classifier
} // namespace PaddleOCR

View File

@ -27,6 +27,7 @@
#include <fstream>
#include <numeric>
#include <include/ocr_cls.h>
#include <include/postprocess_op.h>
#include <include/preprocess_op.h>
#include <include/utility.h>
@ -54,7 +55,8 @@ public:
// Load Paddle inference model
void LoadModel(const std::string &model_dir);
void Run(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat &img);
void Run(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat &img,
Classifier &cls);
private:
std::shared_ptr<PaddlePredictor> predictor_;

View File

@ -56,4 +56,10 @@ public:
const std::vector<int> &rec_image_shape = {3, 32, 320});
};
class ClsResizeImg {
public:
virtual void Run(const cv::Mat &img, cv::Mat &resize_img,
const std::vector<int> &rec_image_shape = {3, 32, 320});
};
} // namespace PaddleOCR

View File

@ -53,6 +53,9 @@ int main(int argc, char **argv) {
config.use_mkldnn, config.max_side_len, config.det_db_thresh,
config.det_db_box_thresh, config.det_db_unclip_ratio,
config.visualize);
Classifier cls(config.cls_model_dir, config.use_gpu, config.gpu_id,
config.gpu_mem, config.cpu_math_library_num_threads,
config.use_mkldnn, config.cls_thresh);
CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id,
config.gpu_mem, config.cpu_math_library_num_threads,
config.use_mkldnn, config.char_list_file);
@ -61,7 +64,7 @@ int main(int argc, char **argv) {
std::vector<std::vector<std::vector<int>>> boxes;
det.Run(srcimg, boxes);
rec.Run(boxes, srcimg);
rec.Run(boxes, srcimg, cls);
auto end = std::chrono::system_clock::now();
auto duration =

View File

@ -0,0 +1,100 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <include/ocr_cls.h>
namespace PaddleOCR {
cv::Mat Classifier::Run(cv::Mat &img) {
cv::Mat src_img;
img.copyTo(src_img);
cv::Mat resize_img;
std::vector<int> rec_image_shape = {3, 32, 100};
int index = 0;
float wh_ratio = float(img.cols) / float(img.rows);
this->resize_op_.Run(img, resize_img, rec_image_shape);
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
this->is_scale_);
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
this->permute_op_.Run(&resize_img, input.data());
auto input_names = this->predictor_->GetInputNames();
auto input_t = this->predictor_->GetInputTensor(input_names[0]);
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
input_t->copy_from_cpu(input.data());
this->predictor_->ZeroCopyRun();
std::vector<float> softmax_out;
std::vector<int64_t> label_out;
auto output_names = this->predictor_->GetOutputNames();
auto softmax_out_t = this->predictor_->GetOutputTensor(output_names[0]);
auto label_out_t = this->predictor_->GetOutputTensor(output_names[1]);
auto softmax_shape_out = softmax_out_t->shape();
auto label_shape_out = label_out_t->shape();
int softmax_out_num =
std::accumulate(softmax_shape_out.begin(), softmax_shape_out.end(), 1,
std::multiplies<int>());
int label_out_num =
std::accumulate(label_shape_out.begin(), label_shape_out.end(), 1,
std::multiplies<int>());
softmax_out.resize(softmax_out_num);
label_out.resize(label_out_num);
softmax_out_t->copy_to_cpu(softmax_out.data());
label_out_t->copy_to_cpu(label_out.data());
int label = label_out[0];
float score = softmax_out[label];
// std::cout << "\nlabel "<<label<<" score: "<<score;
if (label % 2 == 1 && score > this->cls_thresh) {
cv::rotate(src_img, src_img, 1);
}
return src_img;
}
void Classifier::LoadModel(const std::string &model_dir) {
AnalysisConfig config;
config.SetModel(model_dir + "/model", model_dir + "/params");
if (this->use_gpu_) {
config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
} else {
config.DisableGpu();
if (this->use_mkldnn_) {
config.EnableMKLDNN();
}
config.SetCpuMathLibraryNumThreads(this->cpu_math_library_num_threads_);
}
// false for zero copy tensor
config.SwitchUseFeedFetchOps(false);
// true for multiple input
config.SwitchSpecifyInputNames(true);
config.SwitchIrOptim(true);
config.EnableMemoryOptim();
config.DisableGlogInfo();
this->predictor_ = CreatePaddlePredictor(config);
}
} // namespace PaddleOCR

View File

@ -17,7 +17,7 @@
namespace PaddleOCR {
void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
cv::Mat &img) {
cv::Mat &img, Classifier &cls) {
cv::Mat srcimg;
img.copyTo(srcimg);
cv::Mat crop_img;
@ -28,6 +28,8 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
for (int i = boxes.size() - 1; i >= 0; i--) {
crop_img = GetRotateCropImage(srcimg, boxes[i]);
crop_img = cls.Run(crop_img);
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
this->resize_op_.Run(crop_img, resize_img, wh_ratio);

View File

@ -116,4 +116,26 @@ void CrnnResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img, float wh_ratio,
cv::INTER_LINEAR);
}
void ClsResizeImg::Run(const cv::Mat &img, cv::Mat &resize_img,
const std::vector<int> &rec_image_shape) {
int imgC, imgH, imgW;
imgC = rec_image_shape[0];
imgH = rec_image_shape[1];
imgW = rec_image_shape[2];
float ratio = float(img.cols) / float(img.rows);
int resize_w, resize_h;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_LINEAR);
if (resize_w < imgW) {
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
}
}
} // namespace PaddleOCR

View File

@ -40,8 +40,8 @@ CXX_LIBS = ${OPENCV_LIBS} -L$(LITE_ROOT)/cxx/lib/ -lpaddle_light_api_shared $(SY
#CXX_LIBS = $(LITE_ROOT)/cxx/lib/libpaddle_api_light_bundled.a $(SYSTEM_LIBS)
ocr_db_crnn: fetch_opencv ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o -o ocr_db_crnn $(CXX_LIBS) $(LDFLAGS)
ocr_db_crnn: fetch_opencv ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o cls_process.o
$(CC) $(SYSROOT_LINK) $(CXXFLAGS_LINK) ocr_db_crnn.o crnn_process.o db_post_process.o clipper.o cls_process.o -o ocr_db_crnn $(CXX_LIBS) $(LDFLAGS)
ocr_db_crnn.o: ocr_db_crnn.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o ocr_db_crnn.o -c ocr_db_crnn.cc
@ -49,6 +49,9 @@ ocr_db_crnn.o: ocr_db_crnn.cc
crnn_process.o: fetch_opencv crnn_process.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o crnn_process.o -c crnn_process.cc
cls_process.o: fetch_opencv cls_process.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o cls_process.o -c cls_process.cc
db_post_process.o: fetch_clipper fetch_opencv db_post_process.cc
$(CC) $(SYSROOT_COMPLILE) $(CXX_DEFINES) $(CXX_INCLUDES) $(CXX_FLAGS) -o db_post_process.o -c db_post_process.cc
@ -73,5 +76,5 @@ fetch_opencv:
.PHONY: clean
clean:
rm -f ocr_db_crnn.o clipper.o db_post_process.o crnn_process.o
rm -f ocr_db_crnn.o clipper.o db_post_process.o crnn_process.o cls_process.o
rm -f ocr_db_crnn

View File

@ -0,0 +1,43 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cls_process.h" //NOLINT
#include <algorithm>
#include <memory>
#include <string>
const std::vector<int> rec_image_shape{3, 32, 100};
cv::Mat ClsResizeImg(cv::Mat img) {
int imgC, imgH, imgW;
imgC = rec_image_shape[0];
imgH = rec_image_shape[1];
imgW = rec_image_shape[2];
float ratio = static_cast<float>(img.cols) / static_cast<float>(img.rows);
int resize_w, resize_h;
if (ceilf(imgH * ratio) > imgW)
resize_w = imgW;
else
resize_w = int(ceilf(imgH * ratio));
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_LINEAR);
if (resize_w < imgW) {
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, imgW - resize_w,
cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
}
return resize_img;
}

29
deploy/lite/cls_process.h Normal file
View File

@ -0,0 +1,29 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include "math.h" //NOLINT
#include "opencv2/core.hpp"
#include "opencv2/imgcodecs.hpp"
#include "opencv2/imgproc.hpp"
cv::Mat ClsResizeImg(cv::Mat img);

View File

@ -15,6 +15,7 @@
#include "paddle_api.h" // NOLINT
#include <chrono>
#include "cls_process.h"
#include "crnn_process.h"
#include "db_post_process.h"
@ -105,11 +106,55 @@ cv::Mat DetResizeImg(const cv::Mat img, int max_size_len,
return resize_img;
}
cv::Mat RunClsModel(cv::Mat img, std::shared_ptr<PaddlePredictor> predictor_cls,
const float thresh = 0.5) {
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
cv::Mat srcimg;
img.copyTo(srcimg);
cv::Mat crop_img;
cv::Mat resize_img;
int index = 0;
float wh_ratio =
static_cast<float>(crop_img.cols) / static_cast<float>(crop_img.rows);
resize_img = ClsResizeImg(crop_img);
resize_img.convertTo(resize_img, CV_32FC3, 1 / 255.f);
const float *dimg = reinterpret_cast<const float *>(resize_img.data);
std::unique_ptr<Tensor> input_tensor0(std::move(predictor_cls->GetInput(0)));
input_tensor0->Resize({1, 3, resize_img.rows, resize_img.cols});
auto *data0 = input_tensor0->mutable_data<float>();
NeonMeanScale(dimg, data0, resize_img.rows * resize_img.cols, mean, scale);
// Run CLS predictor
predictor_cls->Run();
// Get output and run postprocess
std::unique_ptr<const Tensor> softmax_out(
std::move(predictor_cls->GetOutput(0)));
std::unique_ptr<const Tensor> label_out(
std::move(predictor_cls->GetOutput(1)));
auto *softmax_scores = softmax_out->mutable_data<float>();
auto *label_idxs = label_out->data<int64>();
int label_idx = label_idxs[0];
float score = softmax_scores[label_idx];
if (label_idx % 2 == 1 && score > thresh) {
cv::rotate(srcimg, srcimg, 1);
}
return srcimg;
}
void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img,
std::shared_ptr<PaddlePredictor> predictor_crnn,
std::vector<std::string> &rec_text,
std::vector<float> &rec_text_score,
std::vector<std::string> charactor_dict) {
std::vector<std::string> charactor_dict,
std::shared_ptr<PaddlePredictor> predictor_cls) {
std::vector<float> mean = {0.5f, 0.5f, 0.5f};
std::vector<float> scale = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
@ -121,6 +166,7 @@ void RunRecModel(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat img,
int index = 0;
for (int i = boxes.size() - 1; i >= 0; i--) {
crop_img = GetRotateCropImage(srcimg, boxes[i]);
crop_img = RunClsModel(crop_img, predictor_cls);
float wh_ratio =
static_cast<float>(crop_img.cols) / static_cast<float>(crop_img.rows);
@ -323,8 +369,9 @@ int main(int argc, char **argv) {
}
std::string det_model_file = argv[1];
std::string rec_model_file = argv[2];
std::string img_path = argv[3];
std::string dict_path = argv[4];
std::string cls_model_file = argv[3];
std::string img_path = argv[4];
std::string dict_path = argv[5];
//// load config from txt file
auto Config = LoadConfigTxt("./config.txt");
@ -333,6 +380,7 @@ int main(int argc, char **argv) {
auto det_predictor = loadModel(det_model_file);
auto rec_predictor = loadModel(rec_model_file);
auto cls_predictor = loadModel(cls_model_file);
auto charactor_dict = ReadDict(dict_path);
charactor_dict.push_back(" ");
@ -343,7 +391,7 @@ int main(int argc, char **argv) {
std::vector<std::string> rec_text;
std::vector<float> rec_text_score;
RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score,
charactor_dict);
charactor_dict, cls_predictor);
auto end = std::chrono::system_clock::now();
auto duration =

13
ppocr/data/cls/__init__.py Executable file
View File

@ -0,0 +1,13 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,128 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import random
import numpy as np
import cv2
from ppocr.utils.utility import initial_logger
from ppocr.utils.utility import get_image_file_list
logger = initial_logger()
from ppocr.data.rec.img_tools import warp, resize_norm_img
class SimpleReader(object):
def __init__(self, params):
if params['mode'] != 'train':
self.num_workers = 1
else:
self.num_workers = params['num_workers']
if params['mode'] != 'test':
self.img_set_dir = params['img_set_dir']
self.label_file_path = params['label_file_path']
self.use_gpu = params['use_gpu']
self.image_shape = params['image_shape']
self.mode = params['mode']
self.infer_img = params['infer_img']
self.use_distort = False
self.label_list = params['label_list']
if "distort" in params:
self.use_distort = params['distort'] and params['use_gpu']
if not params['use_gpu']:
logger.info(
"Distort operation can only support in GPU.Distort will be set to False."
)
if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card']
self.drop_last = True
else:
self.batch_size = params['test_batch_size_per_card']
self.drop_last = False
self.use_distort = False
def __call__(self, process_id):
if self.mode != 'train':
process_id = 0
def get_device_num():
if self.use_gpu:
gpus = os.environ.get("CUDA_VISIBLE_DEVICES", 1)
gpu_num = len(gpus.split(','))
return gpu_num
else:
cpu_num = os.environ.get("CPU_NUM", 1)
return int(cpu_num)
def sample_iter_reader():
if self.mode != 'train' and self.infer_img is not None:
image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list:
img = cv2.imread(single_img)
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = resize_norm_img(img, self.image_shape)
norm_img = norm_img[np.newaxis, :]
yield norm_img
else:
with open(self.label_file_path, "rb") as fin:
label_infor_list = fin.readlines()
img_num = len(label_infor_list)
img_id_list = list(range(img_num))
random.shuffle(img_id_list)
if sys.platform == "win32" and self.num_workers != 1:
print("multiprocess is not fully compatible with Windows."
"num_workers will be 1.")
self.num_workers = 1
if self.batch_size * get_device_num(
) * self.num_workers > img_num:
raise Exception(
"The number of the whole data ({}) is smaller than the batch_size * devices_num * num_workers ({})".
format(img_num, self.batch_size * get_device_num() *
self.num_workers))
for img_id in range(process_id, img_num, self.num_workers):
label_infor = label_infor_list[img_id_list[img_id]]
substr = label_infor.decode('utf-8').strip("\n").split("\t")
img_path = self.img_set_dir + "/" + substr[0]
img = cv2.imread(img_path)
if img is None:
logger.info("{} does not exist!".format(img_path))
continue
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
label = substr[1]
if self.use_distort:
img = warp(img, 10)
norm_img = resize_norm_img(img, self.image_shape)
norm_img = norm_img[np.newaxis, :]
yield (norm_img, self.label_list.index(int(label)))
def batch_iter_reader():
batch_outs = []
for outs in sample_iter_reader():
batch_outs.append(outs)
if len(batch_outs) == self.batch_size:
yield batch_outs
batch_outs = []
if not self.drop_last:
if len(batch_outs) != 0:
yield batch_outs
if self.infer_img is None:
return batch_iter_reader
return sample_iter_reader

View File

@ -0,0 +1,84 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import fluid
from ppocr.utils.utility import create_module
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from copy import deepcopy
class ClsModel(object):
def __init__(self, params):
super(ClsModel, self).__init__()
global_params = params['Global']
self.infer_img = global_params['infer_img']
backbone_params = deepcopy(params["Backbone"])
backbone_params.update(global_params)
self.backbone = create_module(backbone_params['function']) \
(params=backbone_params)
head_params = deepcopy(params["Head"])
head_params.update(global_params)
self.head = create_module(head_params['function']) \
(params=head_params)
loss_params = deepcopy(params["Loss"])
loss_params.update(global_params)
self.loss = create_module(loss_params['function']) \
(params=loss_params)
self.image_shape = global_params['image_shape']
def create_feed(self, mode):
image_shape = deepcopy(self.image_shape)
image_shape.insert(0, -1)
if mode == "train":
image = fluid.data(name='image', shape=image_shape, dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
feed_list = [image, label]
labels = {'label': label}
loader = fluid.io.DataLoader.from_generator(
feed_list=feed_list,
capacity=64,
use_double_buffer=True,
iterable=False)
else:
labels = None
loader = None
image = fluid.data(name='image', shape=image_shape, dtype='float32')
return image, labels, loader
def __call__(self, mode):
image, labels, loader = self.create_feed(mode)
inputs = image
conv_feas = self.backbone(inputs)
predicts = self.head(conv_feas, labels, mode)
if mode == "train":
loss = self.loss(predicts, labels)
label = labels['label']
acc = fluid.layers.accuracy(predicts['predict'], label, k=1)
outputs = {'total_loss': loss, 'decoded_out': \
predicts['decoded_out'], 'label': label, 'acc': acc}
return loader, outputs
else:
return loader, predicts

View File

@ -0,0 +1,46 @@
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.fluid as fluid
class ClsHead(object):
def __init__(self, params):
super(ClsHead, self).__init__()
self.class_dim = params['class_dim']
def __call__(self, inputs, labels=None, mode=None):
pool = fluid.layers.pool2d(
input=inputs, pool_type='avg', global_pooling=True)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
out = fluid.layers.fc(
input=pool,
size=self.class_dim,
param_attr=fluid.param_attr.ParamAttr(
name="fc_0.w_0",
initializer=fluid.initializer.Uniform(-stdv, stdv)),
bias_attr=fluid.param_attr.ParamAttr(name="fc_0.b_0"))
softmax_out = fluid.layers.softmax(out, use_cudnn=False)
out_label = fluid.layers.argmax(out, axis=1)
predicts = {'predict': softmax_out, 'decoded_out': out_label}
return predicts

View File

@ -0,0 +1,33 @@
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid as fluid
class ClsLoss(object):
def __init__(self, params):
super(ClsLoss, self).__init__()
self.loss_func = fluid.layers.cross_entropy
def __call__(self, predicts, labels):
predict = predicts['predict']
label = labels['label']
# softmax_out = fluid.layers.softmax(predict, use_cudnn=False)
cost = fluid.layers.cross_entropy(input=predict, label=label)
sum_cost = fluid.layers.mean(cost)
return sum_cost

View File

@ -45,10 +45,12 @@ from ppocr.utils.save_load import init_model
from eval_utils.eval_det_utils import eval_det_run
from eval_utils.eval_rec_utils import test_rec_benchmark
from eval_utils.eval_rec_utils import eval_rec_run
from eval_utils.eval_cls_utils import eval_cls_run
def main():
startup_prog, eval_program, place, config, train_alg_type = program.preprocess()
startup_prog, eval_program, place, config, train_alg_type = program.preprocess(
)
eval_build_outputs = program.build(
config, eval_program, startup_prog, mode='test')
eval_fetch_name_list = eval_build_outputs[1]
@ -67,6 +69,14 @@ def main():
'fetch_varname_list':eval_fetch_varname_list}
metrics = eval_det_run(exe, config, eval_info_dict, "eval")
logger.info("Eval result: {}".format(metrics))
elif train_alg_type == 'cls':
eval_reader = reader_main(config=config, mode="eval")
eval_info_dict = {'program': eval_program, \
'reader': eval_reader, \
'fetch_name_list': eval_fetch_name_list, \
'fetch_varname_list': eval_fetch_varname_list}
metrics = eval_cls_run(exe, eval_info_dict)
logger.info("Eval result: {}".format(metrics))
else:
reader_type = config['Global']['reader_yml']
if "benchmark" not in reader_type:

View File

@ -0,0 +1,72 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import numpy as np
import paddle.fluid as fluid
__all__ = ['eval_class_run']
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def eval_cls_run(exe, eval_info_dict):
"""
Run evaluation program, return program outputs.
"""
total_sample_num = 0
total_acc_num = 0
total_batch_num = 0
for data in eval_info_dict['reader']():
img_num = len(data)
img_list = []
label_list = []
for ino in range(img_num):
img_list.append(data[ino][0])
label_list.append(data[ino][1])
img_list = np.concatenate(img_list, axis=0)
outs = exe.run(eval_info_dict['program'], \
feed={'image': img_list}, \
fetch_list=eval_info_dict['fetch_varname_list'], \
return_numpy=False)
softmax_outs = np.array(outs[1])
acc, acc_num = cal_cls_acc(softmax_outs, label_list)
total_acc_num += acc_num
total_sample_num += len(label_list)
# logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
total_batch_num += 1
avg_acc = total_acc_num * 1.0 / total_sample_num
metrics = {'avg_acc': avg_acc, "total_acc_num": total_acc_num, \
"total_sample_num": total_sample_num}
return metrics
def cal_cls_acc(preds, labels):
acc_num = 0
for pred, label in zip(preds, labels):
if pred == label:
acc_num += 1
return acc_num / len(preds), acc_num

143
tools/infer/predict_cls.py Executable file
View File

@ -0,0 +1,143 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import tools.infer.utility as utility
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
import cv2
import copy
import numpy as np
import math
import time
class TextClassifier(object):
def __init__(self, args):
self.predictor, self.input_tensor, self.output_tensors = \
utility.create_predictor(args, mode="cls")
self.cls_image_shape = [int(v) for v in args.cls_image_shape.split(",")]
self.cls_batch_num = args.rec_batch_num
self.label_list = args.label_list
def resize_norm_img(self, img):
imgC, imgH, imgW = self.cls_image_shape
h = img.shape[0]
w = img.shape[1]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if self.cls_image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def __call__(self, img_list):
img_list = copy.deepcopy(img_list)
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
for img in img_list:
width_list.append(img.shape[1] / float(img.shape[0]))
# Sorting can speed up the cls process
indices = np.argsort(np.array(width_list))
cls_res = [['', 0.0]] * img_num
batch_num = self.cls_batch_num
predict_time = 0
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
norm_img_batch = []
max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
h, w = img_list[indices[ino]].shape[0:2]
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
norm_img = self.resize_norm_img(img_list[indices[ino]])
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
starttime = time.time()
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run()
prob_out = self.output_tensors[0].copy_to_cpu()
label_out = self.output_tensors[1].copy_to_cpu()
elapse = time.time() - starttime
predict_time += elapse
for rno in range(len(label_out)):
label_idx = label_out[rno]
score = prob_out[rno][label_idx]
label = self.label_list[label_idx]
cls_res[indices[beg_img_no + rno]] = [label, score]
if label == 180:
img_list[indices[beg_img_no + rno]] = cv2.rotate(
img_list[indices[beg_img_no + rno]], 1)
return img_list, cls_res, predict_time
def main(args):
image_file_list = get_image_file_list(args.image_dir)
text_classifier = TextClassifier(args)
valid_image_file_list = []
img_list = []
for image_file in image_file_list[:10]:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
valid_image_file_list.append(image_file)
img_list.append(img)
try:
img_list, cls_res, predict_time = text_classifier(img_list)
print(cls_res)
from matplotlib import pyplot as plt
for img, angle in zip(img_list, cls_res):
plt.title(str(angle))
plt.imshow(img)
plt.show()
except Exception as e:
print(e)
exit()
for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], cls_res[ino]))
print("Total predict time for %d images:%.3f" %
(len(img_list), predict_time))
if __name__ == "__main__":
main(utility.parse_args())

View File

@ -13,16 +13,19 @@
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import tools.infer.utility as utility
from ppocr.utils.utility import initial_logger
logger = initial_logger()
import cv2
import tools.infer.predict_det as predict_det
import tools.infer.predict_rec as predict_rec
import tools.infer.predict_cls as predict_cls
import copy
import numpy as np
import math
@ -37,6 +40,7 @@ class TextSystem(object):
def __init__(self, args):
self.text_detector = predict_det.TextDetector(args)
self.text_recognizer = predict_rec.TextRecognizer(args)
self.text_classifier = predict_cls.TextClassifier(args)
def get_rotate_crop_image(self, img, points):
'''
@ -91,7 +95,10 @@ class TextSystem(object):
tmp_box = copy.deepcopy(dt_boxes[bno])
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
img_crop_list.append(img_crop)
rec_res, elapse = self.text_recognizer(img_crop_list)
img_rotate_list, angle_list, elapse = self.text_classifier(
img_crop_list)
print("cls num : {}, elapse : {}".format(len(img_rotate_list), elapse))
rec_res, elapse = self.text_recognizer(img_rotate_list)
print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse))
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
return dt_boxes, rec_res
@ -110,8 +117,8 @@ def sorted_boxes(dt_boxes):
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
tmp = _boxes[i]
_boxes[i] = _boxes[i + 1]
_boxes[i + 1] = tmp

View File

@ -65,6 +65,13 @@ def parse_args():
type=str,
default="./ppocr/utils/ppocr_keys_v1.txt")
parser.add_argument("--use_space_char", type=bool, default=True)
# params for text classifier
parser.add_argument("--cls_model_dir", type=str)
parser.add_argument("--cls_image_shape", type=str, default="3, 32, 100")
parser.add_argument("--label_list", type=list, default=[0, 180])
parser.add_argument("--cls_batch_num", type=int, default=30)
parser.add_argument("--enable_mkldnn", type=bool, default=False)
return parser.parse_args()
@ -72,6 +79,8 @@ def parse_args():
def create_predictor(args, mode):
if mode == "det":
model_dir = args.det_model_dir
elif mode == 'cls':
model_dir = args.cls_model_dir
else:
model_dir = args.rec_model_dir

109
tools/infer_cls.py Executable file
View File

@ -0,0 +1,109 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
def set_paddle_flags(**kwargs):
for key, value in kwargs.items():
if os.environ.get(key, None) is None:
os.environ[key] = str(value)
# NOTE(paddle-dev): All of these flags should be
# set before `import paddle`. Otherwise, it would
# not take any effect.
set_paddle_flags(
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
)
import tools.program as program
from paddle import fluid
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from ppocr.data.reader_main import reader_main
from ppocr.utils.save_load import init_model
from ppocr.utils.utility import create_module
from ppocr.utils.utility import get_image_file_list
def main():
config = program.load_config(FLAGS.config)
program.merge_config(FLAGS.opt)
logger.info(config)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
# check_gpu(use_gpu)
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
rec_model = create_module(config['Architecture']['function'])(params=config)
startup_prog = fluid.Program()
eval_prog = fluid.Program()
with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard():
_, outputs = rec_model(mode="test")
fetch_name_list = list(outputs.keys())
fetch_varname_list = [outputs[v].name for v in fetch_name_list]
eval_prog = eval_prog.clone(for_test=True)
exe.run(startup_prog)
init_model(config, eval_prog, exe)
blobs = reader_main(config, 'test')()
infer_img = config['Global']['infer_img']
infer_list = get_image_file_list(infer_img)
max_img_num = len(infer_list)
if len(infer_list) == 0:
logger.info("Can not find img in infer_img dir.")
for i in range(max_img_num):
logger.info("infer_img:%s" % infer_list[i])
img = next(blobs)
predict = exe.run(program=eval_prog,
feed={"image": img},
fetch_list=fetch_varname_list,
return_numpy=False)
for k in predict:
k = np.array(k)
print(k)
# save for inference model
target_var = []
for key, values in outputs.items():
target_var.append(values)
fluid.io.save_inference_model(
"./output",
feeded_var_names=['image'],
target_vars=target_var,
executor=exe,
main_program=eval_prog,
model_filename="model",
params_filename="params")
if __name__ == '__main__':
parser = program.ArgsParser()
FLAGS = parser.parse_args()
main()

View File

@ -30,6 +30,7 @@ import time
from ppocr.utils.stats import TrainingStats
from eval_utils.eval_det_utils import eval_det_run
from eval_utils.eval_rec_utils import eval_rec_run
from eval_utils.eval_cls_utils import eval_cls_run
from ppocr.utils.save_load import save_model
import numpy as np
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps
@ -398,6 +399,87 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
return
def train_eval_cls_run(config, exe, train_info_dict, eval_info_dict):
train_batch_id = 0
log_smooth_window = config['Global']['log_smooth_window']
epoch_num = config['Global']['epoch_num']
print_batch_step = config['Global']['print_batch_step']
eval_batch_step = config['Global']['eval_batch_step']
start_eval_step = 0
if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
start_eval_step = eval_batch_step[0]
eval_batch_step = eval_batch_step[1]
logger.info(
"During the training process, after the {}th iteration, an evaluation is run every {} iterations".
format(start_eval_step, eval_batch_step))
save_epoch_step = config['Global']['save_epoch_step']
save_model_dir = config['Global']['save_model_dir']
if not os.path.exists(save_model_dir):
os.makedirs(save_model_dir)
train_stats = TrainingStats(log_smooth_window, ['loss', 'acc'])
best_eval_acc = -1
best_batch_id = 0
best_epoch = 0
train_loader = train_info_dict['reader']
for epoch in range(epoch_num):
train_loader.start()
try:
while True:
t1 = time.time()
train_outs = exe.run(
program=train_info_dict['compile_program'],
fetch_list=train_info_dict['fetch_varname_list'],
return_numpy=False)
fetch_map = dict(
zip(train_info_dict['fetch_name_list'],
range(len(train_outs))))
loss = np.mean(np.array(train_outs[fetch_map['total_loss']]))
lr = np.mean(np.array(train_outs[fetch_map['lr']]))
acc = np.mean(np.array(train_outs[fetch_map['acc']]))
t2 = time.time()
train_batch_elapse = t2 - t1
stats = {'loss': loss, 'acc': acc}
train_stats.update(stats)
if train_batch_id > start_eval_step and (train_batch_id - start_eval_step) \
% print_batch_step == 0:
logs = train_stats.log()
strs = 'epoch: {}, iter: {}, lr: {:.6f}, {}, time: {:.3f}'.format(
epoch, train_batch_id, lr, logs, train_batch_elapse)
logger.info(strs)
if train_batch_id > 0 and\
train_batch_id % eval_batch_step == 0:
model_average = train_info_dict['model_average']
if model_average != None:
model_average.apply(exe)
metrics = eval_cls_run(exe, eval_info_dict)
eval_acc = metrics['avg_acc']
eval_sample_num = metrics['total_sample_num']
if eval_acc > best_eval_acc:
best_eval_acc = eval_acc
best_batch_id = train_batch_id
best_epoch = epoch
save_path = save_model_dir + "/best_accuracy"
save_model(train_info_dict['train_program'], save_path)
strs = 'Test iter: {}, acc:{:.6f}, best_acc:{:.6f}, best_epoch:{}, best_batch_id:{}, eval_sample_num:{}'.format(
train_batch_id, eval_acc, best_eval_acc, best_epoch,
best_batch_id, eval_sample_num)
logger.info(strs)
train_batch_id += 1
except fluid.core.EOFException:
train_loader.reset()
if epoch == 0 and save_epoch_step == 1:
save_path = save_model_dir + "/iter_epoch_0"
save_model(train_info_dict['train_program'], save_path)
if epoch > 0 and epoch % save_epoch_step == 0:
save_path = save_model_dir + "/iter_epoch_%d" % (epoch)
save_model(train_info_dict['train_program'], save_path)
return
def preprocess():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
@ -409,7 +491,9 @@ def preprocess():
check_gpu(use_gpu)
alg = config['Global']['algorithm']
assert alg in ['EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS'
]
if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
config['Global']['char_ops'] = CharacterOps(config['Global'])
@ -419,7 +503,9 @@ def preprocess():
if alg in ['EAST', 'DB', 'SAST']:
train_alg_type = 'det'
else:
elif alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
train_alg_type = 'rec'
else:
train_alg_type = 'cls'
return startup_program, train_program, place, config, train_alg_type

View File

@ -75,7 +75,8 @@ def main():
# dump mode structure
if config['Global']['debug']:
if train_alg_type == 'rec' and 'attention' in config['Global']['loss_type']:
if train_alg_type == 'rec' and 'attention' in config['Global'][
'loss_type']:
logger.warning('Does not suport dump attention...')
else:
summary(train_program)
@ -96,8 +97,10 @@ def main():
if train_alg_type == 'det':
program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict)
else:
elif train_alg_type == 'rec':
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
else:
program.train_eval_cls_run(config, exe, train_info_dict, eval_info_dict)
def test_reader():
@ -119,6 +122,7 @@ def test_reader():
if __name__ == '__main__':
startup_program, train_program, place, config, train_alg_type = program.preprocess()
startup_program, train_program, place, config, train_alg_type = program.preprocess(
)
main()
# test_reader()