add doc、infer_det.py、requirments.txt
This commit is contained in:
parent
e3388a2440
commit
3f2d384faa
|
@ -12,6 +12,7 @@ Global:
|
||||||
image_shape: [3, 640, 640]
|
image_shape: [3, 640, 640]
|
||||||
reader_yml: ./configs/det/det_db_icdar15_reader.yml
|
reader_yml: ./configs/det/det_db_icdar15_reader.yml
|
||||||
pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/
|
pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/
|
||||||
|
checkpoints:
|
||||||
save_res_path: ./output/predicts_db.txt
|
save_res_path: ./output/predicts_db.txt
|
||||||
|
|
||||||
Architecture:
|
Architecture:
|
||||||
|
|
|
@ -0,0 +1,69 @@
|
||||||
|
# 文字检测
|
||||||
|
|
||||||
|
本节以icdar15数据集为例,介绍PaddleOCR中检测模型的使用方式。
|
||||||
|
|
||||||
|
## 3.1 数据准备
|
||||||
|
icdar2015数据集可以从[官网](https://rrc.cvc.uab.es/?ch=4&com=downloads)下载到,首次下载需注册。
|
||||||
|
|
||||||
|
将下载到的数据集解压到工作目录下,假设解压在/PaddleOCR/train_data/ 下。另外,PaddleOCR将零散的标注文件整理成单独的标注文件
|
||||||
|
,您可以通过wget的方式进行下载。
|
||||||
|
```
|
||||||
|
wget -P /PaddleOCR/train_data/ 训练标注文件链接
|
||||||
|
wget -P /PaddleOCR/train_data/ 测试标注文件链接
|
||||||
|
```
|
||||||
|
|
||||||
|
解压数据集和下载标注文件后,/PaddleOCR/train_data/ 有两个文件夹和两个文件,分别是:
|
||||||
|
```
|
||||||
|
/PaddleOCR/train_data/
|
||||||
|
└─ icdar_c4_train_imgs/ icdar数据集的训练数据
|
||||||
|
└─ ch4_test_images/ icdar数据集的测试数据
|
||||||
|
└─ train_icdar2015_label.txt icdar数据集的训练标注
|
||||||
|
└─ test_icdar2015_label.txt icdar数据集的测试标注
|
||||||
|
```
|
||||||
|
|
||||||
|
提供的标注文件格式为:
|
||||||
|
```
|
||||||
|
" 图像文件名 json.dumps编码的图像标注信息"
|
||||||
|
ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}]
|
||||||
|
```
|
||||||
|
json.dumps编码前的图像标注信息是包含多个字典的list,字典中的points表示文本框的位置,如果您想在其他数据集上训练PaddleOCR,
|
||||||
|
可以按照上述形式构建标注文件。
|
||||||
|
|
||||||
|
|
||||||
|
## 3.2 快速启动训练
|
||||||
|
|
||||||
|
首先下载pretrain model,目前支持两种backbone,分别是MobileNetV3、ResNet50,您可以根据需求使用PaddleClas中的模型更换
|
||||||
|
backbone。
|
||||||
|
```
|
||||||
|
# 下载MobileNetV3的预训练模型
|
||||||
|
wget -P /PaddleOCR/pretrained_model/ 模型链接
|
||||||
|
# 下载ResNet50的预训练模型
|
||||||
|
wget -P /PaddleOCR/pretrained_model/ 模型链接
|
||||||
|
```
|
||||||
|
|
||||||
|
**启动训练**
|
||||||
|
```
|
||||||
|
cd PaddleOCR/
|
||||||
|
python3 tools/train.py -c configs/det/det_db_mv3.yml
|
||||||
|
```
|
||||||
|
|
||||||
|
上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。
|
||||||
|
有关配置文件的详细解释,请参考[链接]()。
|
||||||
|
|
||||||
|
您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
|
||||||
|
```
|
||||||
|
python3 tools/train.py -c configs/det/det_db_mv3.yml -o Optimizer.base_lr=0.0001
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## 3.3 指标评估
|
||||||
|
|
||||||
|
PaddleOCR计算三个OCR检测相关的指标,分别是:Precision、Recall、Hmean。
|
||||||
|
|
||||||
|
运行如下代码,根据配置文件det_db_mv3.yml中save_res_path指定的测试集检测结果文件,计算评估指标。
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 tools/eval.py -c configs/det/det_db_mv3.yml -o checkpoints ./output/best_accuracy
|
||||||
|
```
|
||||||
|
|
||||||
|
## 3.4 测试检测效果
|
|
@ -0,0 +1,25 @@
|
||||||
|
### 2.1 快速安装
|
||||||
|
|
||||||
|
我们提供了PaddleOCR开发环境的docker,您可以pull我们提供的docker运行PaddleOCR的环境。
|
||||||
|
|
||||||
|
1. 准备docker环境。第一次使用这个镜像,会自动下载该镜像,请耐心等待。
|
||||||
|
```
|
||||||
|
# 切换到工作目录下
|
||||||
|
cd /home/Projects
|
||||||
|
# 创建一个名字为pdocr的docker容器,并将当前目录映射到容器的/data目录下
|
||||||
|
sudo nvidia-docker run --name pdocr -v $PWD:/data --network=host -it paddlepaddle/paddle:1.7.2-gpu-cuda10.0-cudnn7 /bin/bash
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 克隆PaddleOCR repo代码
|
||||||
|
```
|
||||||
|
apt-get update
|
||||||
|
apt-get install git
|
||||||
|
git clone https://github.com/PaddlePaddle/PaddleOCR
|
||||||
|
```
|
||||||
|
|
||||||
|
3. 安装第三方库
|
||||||
|
```
|
||||||
|
cd PaddleOCR
|
||||||
|
pip3 install --upgrade pip
|
||||||
|
pip3 install -r requirements.txt
|
||||||
|
```
|
|
@ -0,0 +1,3 @@
|
||||||
|
shapely
|
||||||
|
imgaug
|
||||||
|
pyclipper
|
|
@ -64,15 +64,39 @@ class TextSystem(object):
|
||||||
if dt_boxes is None:
|
if dt_boxes is None:
|
||||||
return None, None
|
return None, None
|
||||||
img_crop_list = []
|
img_crop_list = []
|
||||||
|
|
||||||
|
dt_boxes = sorted_boxes(dt_boxes)
|
||||||
|
|
||||||
for bno in range(len(dt_boxes)):
|
for bno in range(len(dt_boxes)):
|
||||||
tmp_box = copy.deepcopy(dt_boxes[bno])
|
tmp_box = copy.deepcopy(dt_boxes[bno])
|
||||||
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
|
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
|
||||||
img_crop_list.append(img_crop)
|
img_crop_list.append(img_crop)
|
||||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||||||
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
|
# self.print_draw_crop_rec_res(img_crop_list, rec_res)
|
||||||
return dt_boxes, rec_res
|
return dt_boxes, rec_res
|
||||||
|
|
||||||
|
|
||||||
|
def sorted_boxes(dt_boxes):
|
||||||
|
"""
|
||||||
|
Sort text boxes in order from top to bottom, left to right
|
||||||
|
args:
|
||||||
|
dt_boxes(array):detected text boxes with shape [4, 2]
|
||||||
|
return:
|
||||||
|
sorted boxes(array) with shape [4, 2]
|
||||||
|
"""
|
||||||
|
num_boxes = dt_boxes.shape[0]
|
||||||
|
sorted_boxes = sorted(dt_boxes, key=lambda x: x[0][1])
|
||||||
|
_boxes = list(sorted_boxes)
|
||||||
|
|
||||||
|
for i in range(num_boxes - 1):
|
||||||
|
if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \
|
||||||
|
(_boxes[i + 1][0][0] < _boxes[i][0][0]):
|
||||||
|
tmp = _boxes[i]
|
||||||
|
_boxes[i] = _boxes[i + 1]
|
||||||
|
_boxes[i + 1] = tmp
|
||||||
|
return _boxes
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = utility.parse_args()
|
args = utility.parse_args()
|
||||||
image_file_list = utility.get_image_file_list(args.image_dir)
|
image_file_list = utility.get_image_file_list(args.image_dir)
|
||||||
|
|
|
@ -106,7 +106,7 @@ def create_predictor(args, mode):
|
||||||
# if args.use_fp16 else AnalysisConfig.Precision.Float32,
|
# if args.use_fp16 else AnalysisConfig.Precision.Float32,
|
||||||
# max_batch_size=args.batch_size)
|
# max_batch_size=args.batch_size)
|
||||||
|
|
||||||
config.enable_memory_optim()
|
# config.enable_memory_optim()
|
||||||
# use zero copy
|
# use zero copy
|
||||||
config.switch_use_feed_fetch_ops(False)
|
config.switch_use_feed_fetch_ops(False)
|
||||||
predictor = create_paddle_predictor(config)
|
predictor = create_paddle_predictor(config)
|
||||||
|
@ -136,12 +136,16 @@ if __name__ == '__main__':
|
||||||
args.det_model_dir = root_path + "test_models/public_v1/ch_det_mv3_db"
|
args.det_model_dir = root_path + "test_models/public_v1/ch_det_mv3_db"
|
||||||
|
|
||||||
predictor, input_tensor, output_tensors = create_predictor(args, mode='det')
|
predictor, input_tensor, output_tensors = create_predictor(args, mode='det')
|
||||||
print(predictor.get_input_names())
|
print("det input", predictor.get_input_names())
|
||||||
print(predictor.get_output_names())
|
print("det output", predictor.get_output_names())
|
||||||
print(predictor.program(), file=open("det_program.txt", 'w'))
|
# print(predictor.program(), file=open("det_program.txt", 'w'))
|
||||||
|
outputs = []
|
||||||
|
for output_tensor in output_tensors:
|
||||||
|
output = output_tensor.copy_to_cpu()
|
||||||
|
outputs.append(output)
|
||||||
|
|
||||||
args.rec_model_dir = root_path + "test_models/public_v1/ch_rec_mv3_crnn/"
|
args.rec_model_dir = root_path + "test_models/public_v1/ch_rec_mv3_crnn/"
|
||||||
rec_predictor, input_tensor, output_tensors = create_predictor(
|
rec_predictor, input_tensor, output_tensors = create_predictor(
|
||||||
args, mode='rec')
|
args, mode='rec')
|
||||||
print(rec_predictor.get_input_names())
|
print("rec input", rec_predictor.get_input_names())
|
||||||
print(rec_predictor.get_output_names())
|
print("rec output", rec_predictor.get_output_names())
|
||||||
|
|
|
@ -0,0 +1,148 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from copy import deepcopy
|
||||||
|
import json
|
||||||
|
|
||||||
|
# from paddle.fluid.contrib.model_stat import summary
|
||||||
|
|
||||||
|
|
||||||
|
def set_paddle_flags(**kwargs):
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if os.environ.get(key, None) is None:
|
||||||
|
os.environ[key] = str(value)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(paddle-dev): All of these flags should be
|
||||||
|
# set before `import paddle`. Otherwise, it would
|
||||||
|
# not take any effect.
|
||||||
|
set_paddle_flags(
|
||||||
|
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
|
||||||
|
)
|
||||||
|
|
||||||
|
from paddle import fluid
|
||||||
|
from ppocr.utils.utility import create_module
|
||||||
|
import program
|
||||||
|
from ppocr.utils.save_load import init_model
|
||||||
|
from ppocr.data.reader_main import reader_main
|
||||||
|
|
||||||
|
from ppocr.utils.utility import initial_logger
|
||||||
|
logger = initial_logger()
|
||||||
|
|
||||||
|
|
||||||
|
def draw_det_res(dt_boxes, config, img_name, ino):
|
||||||
|
if len(dt_boxes) > 0:
|
||||||
|
img_set_path = config['TestReader']['img_set_dir']
|
||||||
|
img_path = img_set_path + img_name
|
||||||
|
import cv2
|
||||||
|
src_im = cv2.imread(img_path)
|
||||||
|
for box in dt_boxes:
|
||||||
|
box = box.astype(np.int32).reshape((-1, 1, 2))
|
||||||
|
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
|
||||||
|
save_det_path = os.path.basename(config['Global'][
|
||||||
|
'save_res_path']) + "/det_results/"
|
||||||
|
if not os.path.exists(save_det_path):
|
||||||
|
os.makedirs(save_det_path)
|
||||||
|
save_path = os.path.join(save_det_path, "det_{}.jpg".format(img_name))
|
||||||
|
cv2.imwrite(save_path, src_im)
|
||||||
|
logger.info("The detected Image saved in {}".format(save_path))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
config = program.load_config(FLAGS.config)
|
||||||
|
program.merge_config(FLAGS.opt)
|
||||||
|
print(config)
|
||||||
|
|
||||||
|
# check if set use_gpu=True in paddlepaddle cpu version
|
||||||
|
use_gpu = config['Global']['use_gpu']
|
||||||
|
program.check_gpu(use_gpu)
|
||||||
|
|
||||||
|
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
|
||||||
|
exe = fluid.Executor(place)
|
||||||
|
|
||||||
|
det_model = create_module(config['Architecture']['function'])(params=config)
|
||||||
|
|
||||||
|
startup_prog = fluid.Program()
|
||||||
|
eval_prog = fluid.Program()
|
||||||
|
with fluid.program_guard(eval_prog, startup_prog):
|
||||||
|
with fluid.unique_name.guard():
|
||||||
|
_, eval_outputs = det_model(mode="test")
|
||||||
|
fetch_name_list = list(eval_outputs.keys())
|
||||||
|
eval_fetch_list = [eval_outputs[v].name for v in fetch_name_list]
|
||||||
|
|
||||||
|
eval_prog = eval_prog.clone(for_test=True)
|
||||||
|
exe.run(startup_prog)
|
||||||
|
|
||||||
|
# load checkpoints
|
||||||
|
checkpoints = config['Global'].get('checkpoints')
|
||||||
|
if checkpoints:
|
||||||
|
path = checkpoints
|
||||||
|
fluid.load(eval_prog, path, exe)
|
||||||
|
logger.info("Finish initing model from {}".format(path))
|
||||||
|
else:
|
||||||
|
raise Exception("{} not exists!".format(checkpoints))
|
||||||
|
|
||||||
|
save_res_path = config['Global']['save_res_path']
|
||||||
|
with open(save_res_path, "wb") as fout:
|
||||||
|
test_reader = reader_main(config=config, mode='test')
|
||||||
|
tackling_num = 0
|
||||||
|
for data in test_reader():
|
||||||
|
img_num = len(data)
|
||||||
|
tackling_num = tackling_num + img_num
|
||||||
|
logger.info("tackling_num:%d", tackling_num)
|
||||||
|
img_list = []
|
||||||
|
ratio_list = []
|
||||||
|
img_name_list = []
|
||||||
|
for ino in range(img_num):
|
||||||
|
img_list.append(data[ino][0])
|
||||||
|
ratio_list.append(data[ino][1])
|
||||||
|
img_name_list.append(data[ino][2])
|
||||||
|
img_list = np.concatenate(img_list, axis=0)
|
||||||
|
outs = exe.run(eval_prog,\
|
||||||
|
feed={'image': img_list},\
|
||||||
|
fetch_list=eval_fetch_list)
|
||||||
|
|
||||||
|
global_params = config['Global']
|
||||||
|
postprocess_params = deepcopy(config["PostProcess"])
|
||||||
|
postprocess_params.update(global_params)
|
||||||
|
postprocess = create_module(postprocess_params['function'])\
|
||||||
|
(params=postprocess_params)
|
||||||
|
dt_boxes_list = postprocess(outs, ratio_list)
|
||||||
|
for ino in range(img_num):
|
||||||
|
dt_boxes = dt_boxes_list[ino]
|
||||||
|
img_name = img_name_list[ino]
|
||||||
|
dt_boxes_json = []
|
||||||
|
for box in dt_boxes:
|
||||||
|
tmp_json = {"transcription": ""}
|
||||||
|
tmp_json['points'] = box.tolist()
|
||||||
|
dt_boxes_json.append(tmp_json)
|
||||||
|
otstr = img_name + "\t" + json.dumps(dt_boxes_json) + "\n"
|
||||||
|
fout.write(otstr.encode())
|
||||||
|
draw_det_res(dt_boxes, config, img_name, ino)
|
||||||
|
|
||||||
|
logger.info("success!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = program.ArgsParser()
|
||||||
|
FLAGS = parser.parse_args()
|
||||||
|
main()
|
Loading…
Reference in New Issue