Merge branch 'dygraph' into dygraph
This commit is contained in:
commit
ca14b8655e
|
@ -60,8 +60,10 @@ PostProcess:
|
|||
name: PGPostProcess
|
||||
score_thresh: 0.5
|
||||
mode: fast # fast or slow two ways
|
||||
|
||||
Metric:
|
||||
name: E2EMetric
|
||||
mode: A # two ways for eval, A: label from txt, B: label from gt_mat
|
||||
gt_mat_dir: ./train_data/total_text/gt # the dir of gt_mat
|
||||
character_dict_path: ppocr/utils/ic15_dict.txt
|
||||
main_indicator: f_score_e2e
|
||||
|
@ -70,13 +72,13 @@ Train:
|
|||
dataset:
|
||||
name: PGDataSet
|
||||
data_dir: ./train_data/total_text/train
|
||||
label_file_list: [./train_data/total_text/train/]
|
||||
label_file_list: [./train_data/total_text/train/train.txt]
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- E2ELabelEncode:
|
||||
- E2ELabelEncodeTrain:
|
||||
- PGProcessTrain:
|
||||
batch_size: 14 # same as loader: batch_size_per_card
|
||||
min_crop_size: 24
|
||||
|
@ -94,11 +96,12 @@ Eval:
|
|||
dataset:
|
||||
name: PGDataSet
|
||||
data_dir: ./train_data/total_text/test
|
||||
label_file_list: [./train_data/total_text/test/]
|
||||
label_file_list: [./train_data/total_text/test/test.txt]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
channel_first: False
|
||||
- E2ELabelEncodeTest:
|
||||
- E2EResizeForTest:
|
||||
max_side_len: 768
|
||||
- NormalizeImage:
|
||||
|
@ -108,7 +111,7 @@ Eval:
|
|||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: [ 'image', 'shape', 'img_id']
|
||||
keep_keys: [ 'image', 'shape', 'polys', 'texts', 'ignore_tags', 'img_id']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
|
|
|
@ -19,6 +19,7 @@ Global:
|
|||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: True
|
||||
save_res_path: ./output/rec/predicts_chinese_common_v2.0.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
|
|
|
@ -19,6 +19,7 @@ Global:
|
|||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: True
|
||||
save_res_path: ./output/rec/predicts_chinese_lite_v2.0.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
|
|
|
@ -19,6 +19,7 @@ Global:
|
|||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_ic15.txt
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
|
|
|
@ -19,6 +19,7 @@ Global:
|
|||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_mv3_none_bilstm_ctc.txt
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
|
|
|
@ -19,6 +19,7 @@ Global:
|
|||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_mv3_none_none_ctc.txt
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
|
|
|
@ -19,6 +19,7 @@ Global:
|
|||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_mv3_tps_bilstm_att.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
|
|
|
@ -19,6 +19,7 @@ Global:
|
|||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_mv3_tps_bilstm_ctc.txt
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
|
|
|
@ -19,6 +19,7 @@ Global:
|
|||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_r34_vd_none_bilstm_ctc.txt
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
|
|
|
@ -19,6 +19,7 @@ Global:
|
|||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_r34_vd_none_none_ctc.txt
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
|
|
|
@ -19,6 +19,7 @@ Global:
|
|||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_b3_rare_r34_none_gru.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
|
|
|
@ -19,6 +19,7 @@ Global:
|
|||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_r34_vd_tps_bilstm_ctc.txt
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
|
|
|
@ -20,6 +20,7 @@ Global:
|
|||
num_heads: 8
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_srn.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
|
|
|
@ -51,6 +51,7 @@ public:
|
|||
float &ssid);
|
||||
|
||||
float BoxScoreFast(std::vector<std::vector<float>> box_array, cv::Mat pred);
|
||||
float PolygonScoreAcc(std::vector<cv::Point> contour, cv::Mat pred);
|
||||
|
||||
std::vector<std::vector<std::vector<int>>>
|
||||
BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
|
||||
|
|
|
@ -159,6 +159,39 @@ std::vector<std::vector<float>> PostProcessor::GetMiniBoxes(cv::RotatedRect box,
|
|||
return array;
|
||||
}
|
||||
|
||||
float PostProcessor::PolygonScoreAcc(std::vector<cv::Point> contour,
|
||||
cv::Mat pred){
|
||||
int width = pred.cols;
|
||||
int height = pred.rows;
|
||||
std::vector<float> box_x;
|
||||
std::vector<float> box_y;
|
||||
for(int i=0; i<contour.size(); ++i){
|
||||
box_x.push_back(contour[i].x);
|
||||
box_y.push_back(contour[i].y);
|
||||
}
|
||||
|
||||
int xmin = clamp(int(std::floor(*(std::min_element(box_x.begin(), box_x.end())))), 0, width - 1);
|
||||
int xmax = clamp(int(std::ceil(*(std::max_element(box_x.begin(), box_x.end())))), 0, width - 1);
|
||||
int ymin = clamp(int(std::floor(*(std::min_element(box_y.begin(), box_y.end())))), 0, height - 1);
|
||||
int ymax = clamp(int(std::ceil(*(std::max_element(box_y.begin(), box_y.end())))), 0, height - 1);
|
||||
|
||||
cv::Mat mask;
|
||||
mask = cv::Mat::zeros(ymax - ymin + 1, xmax - xmin + 1, CV_8UC1);
|
||||
|
||||
cv::Point rook_point[contour.size()];
|
||||
for(int i=0; i<contour.size(); ++i){
|
||||
rook_point[i] = cv::Point(int(box_x[i]) - xmin, int(box_y[i]) - ymin);
|
||||
}
|
||||
const cv::Point *ppt[1] = {rook_point};
|
||||
int npt[] = {int(contour.size())};
|
||||
cv::fillPoly(mask, ppt, npt, 1, cv::Scalar(1));
|
||||
|
||||
cv::Mat croppedImg;
|
||||
pred(cv::Rect(xmin, ymin, xmax - xmin + 1, ymax - ymin + 1)).copyTo(croppedImg);
|
||||
float score = cv::mean(croppedImg, mask)[0];
|
||||
return score;
|
||||
}
|
||||
|
||||
float PostProcessor::BoxScoreFast(std::vector<std::vector<float>> box_array,
|
||||
cv::Mat pred) {
|
||||
auto array = box_array;
|
||||
|
@ -235,6 +268,8 @@ PostProcessor::BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap,
|
|||
|
||||
float score;
|
||||
score = BoxScoreFast(array, pred);
|
||||
/* compute using polygon*/
|
||||
// score = PolygonScoreAcc(contours[_i], pred);
|
||||
if (score < box_thresh)
|
||||
continue;
|
||||
|
||||
|
|
|
@ -77,19 +77,10 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
|
|||
|
||||
int resize_h = int(float(h) * ratio);
|
||||
int resize_w = int(float(w) * ratio);
|
||||
if (resize_h % 32 == 0)
|
||||
resize_h = resize_h;
|
||||
else if (resize_h / 32 < 1 + 1e-5)
|
||||
resize_h = 32;
|
||||
else
|
||||
resize_h = (resize_h / 32) * 32;
|
||||
|
||||
resize_h = max(int(round(float(resize_h) / 32) * 32), 32);
|
||||
resize_w = max(int(round(float(resize_w) / 32) * 32), 32);
|
||||
|
||||
if (resize_w % 32 == 0)
|
||||
resize_w = resize_w;
|
||||
else if (resize_w / 32 < 1 + 1e-5)
|
||||
resize_w = 32;
|
||||
else
|
||||
resize_w = (resize_w / 32) * 32;
|
||||
if (!use_tensorrt) {
|
||||
cv::resize(img, resize_img, cv::Size(resize_w, resize_h));
|
||||
ratio_h = float(resize_h) / float(h);
|
||||
|
|
|
@ -6,6 +6,7 @@ from __future__ import print_function
|
|||
import os
|
||||
import sys
|
||||
sys.path.insert(0, ".")
|
||||
import copy
|
||||
|
||||
from paddlehub.common.logger import logger
|
||||
from paddlehub.module.module import moduleinfo, runnable, serving
|
||||
|
@ -14,6 +15,8 @@ import paddlehub as hub
|
|||
|
||||
from tools.infer.utility import base64_to_cv2
|
||||
from tools.infer.predict_cls import TextClassifier
|
||||
from tools.infer.utility import parse_args
|
||||
from deploy.hubserving.ocr_cls.params import read_params
|
||||
|
||||
|
||||
@moduleinfo(
|
||||
|
@ -28,8 +31,7 @@ class OCRCls(hub.Module):
|
|||
"""
|
||||
initialize with the necessary elements
|
||||
"""
|
||||
from ocr_cls.params import read_params
|
||||
cfg = read_params()
|
||||
cfg = self.merge_configs()
|
||||
|
||||
cfg.use_gpu = use_gpu
|
||||
if use_gpu:
|
||||
|
@ -48,6 +50,20 @@ class OCRCls(hub.Module):
|
|||
|
||||
self.text_classifier = TextClassifier(cfg)
|
||||
|
||||
def merge_configs(self, ):
|
||||
# deafult cfg
|
||||
backup_argv = copy.deepcopy(sys.argv)
|
||||
sys.argv = sys.argv[:1]
|
||||
cfg = parse_args()
|
||||
|
||||
update_cfg_map = vars(read_params())
|
||||
|
||||
for key in update_cfg_map:
|
||||
cfg.__setattr__(key, update_cfg_map[key])
|
||||
|
||||
sys.argv = copy.deepcopy(backup_argv)
|
||||
return cfg
|
||||
|
||||
def read_images(self, paths=[]):
|
||||
images = []
|
||||
for img_path in paths:
|
||||
|
|
|
@ -7,6 +7,8 @@ import os
|
|||
import sys
|
||||
sys.path.insert(0, ".")
|
||||
|
||||
import copy
|
||||
|
||||
from paddlehub.common.logger import logger
|
||||
from paddlehub.module.module import moduleinfo, runnable, serving
|
||||
import cv2
|
||||
|
@ -15,6 +17,8 @@ import paddlehub as hub
|
|||
|
||||
from tools.infer.utility import base64_to_cv2
|
||||
from tools.infer.predict_det import TextDetector
|
||||
from tools.infer.utility import parse_args
|
||||
from deploy.hubserving.ocr_system.params import read_params
|
||||
|
||||
|
||||
@moduleinfo(
|
||||
|
@ -29,8 +33,7 @@ class OCRDet(hub.Module):
|
|||
"""
|
||||
initialize with the necessary elements
|
||||
"""
|
||||
from ocr_det.params import read_params
|
||||
cfg = read_params()
|
||||
cfg = self.merge_configs()
|
||||
|
||||
cfg.use_gpu = use_gpu
|
||||
if use_gpu:
|
||||
|
@ -49,6 +52,20 @@ class OCRDet(hub.Module):
|
|||
|
||||
self.text_detector = TextDetector(cfg)
|
||||
|
||||
def merge_configs(self, ):
|
||||
# deafult cfg
|
||||
backup_argv = copy.deepcopy(sys.argv)
|
||||
sys.argv = sys.argv[:1]
|
||||
cfg = parse_args()
|
||||
|
||||
update_cfg_map = vars(read_params())
|
||||
|
||||
for key in update_cfg_map:
|
||||
cfg.__setattr__(key, update_cfg_map[key])
|
||||
|
||||
sys.argv = copy.deepcopy(backup_argv)
|
||||
return cfg
|
||||
|
||||
def read_images(self, paths=[]):
|
||||
images = []
|
||||
for img_path in paths:
|
||||
|
|
|
@ -22,6 +22,7 @@ def read_params():
|
|||
cfg.det_db_box_thresh = 0.5
|
||||
cfg.det_db_unclip_ratio = 1.6
|
||||
cfg.use_dilation = False
|
||||
cfg.det_db_score_mode = "fast"
|
||||
|
||||
# #EAST parmas
|
||||
# cfg.det_east_score_thresh = 0.8
|
||||
|
|
|
@ -6,6 +6,7 @@ from __future__ import print_function
|
|||
import os
|
||||
import sys
|
||||
sys.path.insert(0, ".")
|
||||
import copy
|
||||
|
||||
from paddlehub.common.logger import logger
|
||||
from paddlehub.module.module import moduleinfo, runnable, serving
|
||||
|
@ -14,6 +15,8 @@ import paddlehub as hub
|
|||
|
||||
from tools.infer.utility import base64_to_cv2
|
||||
from tools.infer.predict_rec import TextRecognizer
|
||||
from tools.infer.utility import parse_args
|
||||
from deploy.hubserving.ocr_rec.params import read_params
|
||||
|
||||
|
||||
@moduleinfo(
|
||||
|
@ -28,8 +31,7 @@ class OCRRec(hub.Module):
|
|||
"""
|
||||
initialize with the necessary elements
|
||||
"""
|
||||
from ocr_rec.params import read_params
|
||||
cfg = read_params()
|
||||
cfg = self.merge_configs()
|
||||
|
||||
cfg.use_gpu = use_gpu
|
||||
if use_gpu:
|
||||
|
@ -48,6 +50,20 @@ class OCRRec(hub.Module):
|
|||
|
||||
self.text_recognizer = TextRecognizer(cfg)
|
||||
|
||||
def merge_configs(self, ):
|
||||
# deafult cfg
|
||||
backup_argv = copy.deepcopy(sys.argv)
|
||||
sys.argv = sys.argv[:1]
|
||||
cfg = parse_args()
|
||||
|
||||
update_cfg_map = vars(read_params())
|
||||
|
||||
for key in update_cfg_map:
|
||||
cfg.__setattr__(key, update_cfg_map[key])
|
||||
|
||||
sys.argv = copy.deepcopy(backup_argv)
|
||||
return cfg
|
||||
|
||||
def read_images(self, paths=[]):
|
||||
images = []
|
||||
for img_path in paths:
|
||||
|
|
|
@ -6,6 +6,7 @@ from __future__ import print_function
|
|||
import os
|
||||
import sys
|
||||
sys.path.insert(0, ".")
|
||||
import copy
|
||||
|
||||
import time
|
||||
|
||||
|
@ -17,6 +18,8 @@ import paddlehub as hub
|
|||
|
||||
from tools.infer.utility import base64_to_cv2
|
||||
from tools.infer.predict_system import TextSystem
|
||||
from tools.infer.utility import parse_args
|
||||
from deploy.hubserving.ocr_system.params import read_params
|
||||
|
||||
|
||||
@moduleinfo(
|
||||
|
@ -31,8 +34,7 @@ class OCRSystem(hub.Module):
|
|||
"""
|
||||
initialize with the necessary elements
|
||||
"""
|
||||
from ocr_system.params import read_params
|
||||
cfg = read_params()
|
||||
cfg = self.merge_configs()
|
||||
|
||||
cfg.use_gpu = use_gpu
|
||||
if use_gpu:
|
||||
|
@ -51,6 +53,20 @@ class OCRSystem(hub.Module):
|
|||
|
||||
self.text_sys = TextSystem(cfg)
|
||||
|
||||
def merge_configs(self, ):
|
||||
# deafult cfg
|
||||
backup_argv = copy.deepcopy(sys.argv)
|
||||
sys.argv = sys.argv[:1]
|
||||
cfg = parse_args()
|
||||
|
||||
update_cfg_map = vars(read_params())
|
||||
|
||||
for key in update_cfg_map:
|
||||
cfg.__setattr__(key, update_cfg_map[key])
|
||||
|
||||
sys.argv = copy.deepcopy(backup_argv)
|
||||
return cfg
|
||||
|
||||
def read_images(self, paths=[]):
|
||||
images = []
|
||||
for img_path in paths:
|
||||
|
|
|
@ -22,6 +22,7 @@ def read_params():
|
|||
cfg.det_db_box_thresh = 0.5
|
||||
cfg.det_db_unclip_ratio = 1.6
|
||||
cfg.use_dilation = False
|
||||
cfg.det_db_score_mode = "fast"
|
||||
|
||||
#EAST parmas
|
||||
cfg.det_east_score_thresh = 0.8
|
||||
|
|
|
@ -83,19 +83,19 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
|
|||
本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。
|
||||
|
||||
### 准备数据
|
||||
下载解压[totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md) 数据集到PaddleOCR/train_data/目录,数据集组织结构:
|
||||
下载解压[totaltext](https://paddleocr.bj.bcebos.com/dataset/total_text.tar) 数据集到PaddleOCR/train_data/目录,数据集组织结构:
|
||||
```
|
||||
/PaddleOCR/train_data/total_text/train/
|
||||
|- rgb/ # total_text数据集的训练数据
|
||||
|- gt_0.png
|
||||
|- img11.jpg
|
||||
| ...
|
||||
|- total_text.txt # total_text数据集的训练标注
|
||||
|- train.txt # total_text数据集的训练标注
|
||||
```
|
||||
|
||||
total_text.txt标注文件格式如下,文件名和标注信息中间用"\t"分隔:
|
||||
```
|
||||
" 图像文件名 json.dumps编码的图像标注信息"
|
||||
rgb/gt_0.png [{"transcription": "EST", "points": [[1004.0,689.0],[1019.0,698.0],[1034.0,708.0],[1049.0,718.0],[1064.0,728.0],[1079.0,738.0],[1095.0,748.0],[1094.0,774.0],[1079.0,765.0],[1065.0,756.0],[1050.0,747.0],[1036.0,738.0],[1021.0,729.0],[1007.0,721.0]]}, {...}]
|
||||
rgb/img11.jpg [{"transcription": "ASRAMA", "points": [[214.0, 325.0], [235.0, 308.0], [259.0, 296.0], [286.0, 291.0], [313.0, 295.0], [338.0, 305.0], [362.0, 320.0], [349.0, 347.0], [330.0, 337.0], [310.0, 329.0], [290.0, 324.0], [269.0, 328.0], [249.0, 336.0], [231.0, 346.0]]}, {...}]
|
||||
```
|
||||
json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 `points` 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。
|
||||
`transcription` 表示当前文本框的文字,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。**
|
||||
|
|
|
@ -76,19 +76,19 @@ The visualized end-to-end results are saved to the `./inference_results` folder
|
|||
This section takes the totaltext dataset as an example to introduce the training, evaluation and testing of the end-to-end model in PaddleOCR.
|
||||
|
||||
### Data Preparation
|
||||
Download and unzip [totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md) dataset to PaddleOCR/train_data/, dataset organization structure is as follow:
|
||||
Download and unzip [totaltext](https://paddleocr.bj.bcebos.com/dataset/total_text.tar) dataset to PaddleOCR/train_data/, dataset organization structure is as follow:
|
||||
```
|
||||
/PaddleOCR/train_data/total_text/train/
|
||||
|- rgb/ # total_text training data of dataset
|
||||
|- gt_0.png
|
||||
|- img11.png
|
||||
| ...
|
||||
|- total_text.txt # total_text training annotation of dataset
|
||||
|- train.txt # total_text training annotation of dataset
|
||||
```
|
||||
|
||||
total_text.txt: the format of dimension file is as follows,the file name and annotation information are separated by "\t":
|
||||
```
|
||||
" Image file name Image annotation information encoded by json.dumps"
|
||||
rgb/gt_0.png [{"transcription": "EST", "points": [[1004.0,689.0],[1019.0,698.0],[1034.0,708.0],[1049.0,718.0],[1064.0,728.0],[1079.0,738.0],[1095.0,748.0],[1094.0,774.0],[1079.0,765.0],[1065.0,756.0],[1050.0,747.0],[1036.0,738.0],[1021.0,729.0],[1007.0,721.0]]}, {...}]
|
||||
rgb/img11.jpg [{"transcription": "ASRAMA", "points": [[214.0, 325.0], [235.0, 308.0], [259.0, 296.0], [286.0, 291.0], [313.0, 295.0], [338.0, 305.0], [362.0, 320.0], [349.0, 347.0], [330.0, 337.0], [310.0, 329.0], [290.0, 324.0], [269.0, 328.0], [249.0, 336.0], [231.0, 346.0]]}, {...}]
|
||||
```
|
||||
The image annotation after **json.dumps()** encoding is a list containing multiple dictionaries.
|
||||
|
||||
|
|
|
@ -193,6 +193,7 @@ def parse_args(mMain=True, add_help=True):
|
|||
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
|
||||
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
|
||||
parser.add_argument("--use_dilation", type=bool, default=False)
|
||||
parser.add_argument("--det_db_score_mode", type=str, default="fast")
|
||||
|
||||
# EAST parmas
|
||||
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
|
||||
|
@ -241,6 +242,7 @@ def parse_args(mMain=True, add_help=True):
|
|||
det_db_box_thresh=0.5,
|
||||
det_db_unclip_ratio=1.6,
|
||||
use_dilation=False,
|
||||
det_db_score_mode="fast",
|
||||
det_east_score_thresh=0.8,
|
||||
det_east_cover_thresh=0.1,
|
||||
det_east_nms_thresh=0.2,
|
||||
|
|
|
@ -187,7 +187,51 @@ class CTCLabelEncode(BaseRecLabelEncode):
|
|||
return dict_character
|
||||
|
||||
|
||||
class E2ELabelEncode(object):
|
||||
class E2ELabelEncodeTest(BaseRecLabelEncode):
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
character_type='EN',
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
super(E2ELabelEncodeTest,
|
||||
self).__init__(max_text_length, character_dict_path,
|
||||
character_type, use_space_char)
|
||||
|
||||
def __call__(self, data):
|
||||
import json
|
||||
padnum = len(self.dict)
|
||||
label = data['label']
|
||||
label = json.loads(label)
|
||||
nBox = len(label)
|
||||
boxes, txts, txt_tags = [], [], []
|
||||
for bno in range(0, nBox):
|
||||
box = label[bno]['points']
|
||||
txt = label[bno]['transcription']
|
||||
boxes.append(box)
|
||||
txts.append(txt)
|
||||
if txt in ['*', '###']:
|
||||
txt_tags.append(True)
|
||||
else:
|
||||
txt_tags.append(False)
|
||||
boxes = np.array(boxes, dtype=np.float32)
|
||||
txt_tags = np.array(txt_tags, dtype=np.bool)
|
||||
data['polys'] = boxes
|
||||
data['ignore_tags'] = txt_tags
|
||||
temp_texts = []
|
||||
for text in txts:
|
||||
text = text.lower()
|
||||
text = self.encode(text)
|
||||
if text is None:
|
||||
return None
|
||||
text = text + [padnum] * (self.max_text_len - len(text)
|
||||
) # use 36 to pad
|
||||
temp_texts.append(text)
|
||||
data['texts'] = np.array(temp_texts)
|
||||
return data
|
||||
|
||||
|
||||
class E2ELabelEncodeTrain(object):
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
|
|
|
@ -72,6 +72,7 @@ class PGDataSet(Dataset):
|
|||
def __getitem__(self, idx):
|
||||
file_idx = self.data_idx_order_list[idx]
|
||||
data_line = self.data_lines[file_idx]
|
||||
img_id = 0
|
||||
try:
|
||||
data_line = data_line.decode('utf-8')
|
||||
substr = data_line.strip("\n").split(self.delimiter)
|
||||
|
@ -79,9 +80,10 @@ class PGDataSet(Dataset):
|
|||
label = substr[1]
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
if self.mode.lower() == 'eval':
|
||||
img_id = int(data_line.split(".")[0][7:])
|
||||
else:
|
||||
img_id = 0
|
||||
try:
|
||||
img_id = int(data_line.split(".")[0][7:])
|
||||
except:
|
||||
img_id = 0
|
||||
data = {'img_path': img_path, 'label': label, 'img_id': img_id}
|
||||
if not os.path.exists(img_path):
|
||||
raise Exception("{} does not exist!".format(img_path))
|
||||
|
|
|
@ -18,16 +18,18 @@ from __future__ import print_function
|
|||
|
||||
__all__ = ['E2EMetric']
|
||||
|
||||
from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results
|
||||
from ppocr.utils.e2e_metric.Deteval import get_socre_A, get_socre_B, combine_results
|
||||
from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict
|
||||
|
||||
|
||||
class E2EMetric(object):
|
||||
def __init__(self,
|
||||
mode,
|
||||
gt_mat_dir,
|
||||
character_dict_path,
|
||||
main_indicator='f_score_e2e',
|
||||
**kwargs):
|
||||
self.mode = mode
|
||||
self.gt_mat_dir = gt_mat_dir
|
||||
self.label_list = get_dict(character_dict_path)
|
||||
self.max_index = len(self.label_list)
|
||||
|
@ -35,13 +37,45 @@ class E2EMetric(object):
|
|||
self.reset()
|
||||
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
img_id = batch[2][0]
|
||||
e2e_info_list = [{
|
||||
'points': det_polyon,
|
||||
'texts': pred_str
|
||||
} for det_polyon, pred_str in zip(preds['points'], preds['texts'])]
|
||||
result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
|
||||
self.results.append(result)
|
||||
if self.mode == 'A':
|
||||
gt_polyons_batch = batch[2]
|
||||
temp_gt_strs_batch = batch[3][0]
|
||||
ignore_tags_batch = batch[4]
|
||||
gt_strs_batch = []
|
||||
|
||||
for temp_list in temp_gt_strs_batch:
|
||||
t = ""
|
||||
for index in temp_list:
|
||||
if index < self.max_index:
|
||||
t += self.label_list[index]
|
||||
gt_strs_batch.append(t)
|
||||
|
||||
for pred, gt_polyons, gt_strs, ignore_tags in zip(
|
||||
[preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch):
|
||||
# prepare gt
|
||||
gt_info_list = [{
|
||||
'points': gt_polyon,
|
||||
'text': gt_str,
|
||||
'ignore': ignore_tag
|
||||
} for gt_polyon, gt_str, ignore_tag in
|
||||
zip(gt_polyons, gt_strs, ignore_tags)]
|
||||
# prepare det
|
||||
e2e_info_list = [{
|
||||
'points': det_polyon,
|
||||
'texts': pred_str
|
||||
} for det_polyon, pred_str in
|
||||
zip(pred['points'], pred['texts'])]
|
||||
|
||||
result = get_socre_A(gt_info_list, e2e_info_list)
|
||||
self.results.append(result)
|
||||
else:
|
||||
img_id = batch[5][0]
|
||||
e2e_info_list = [{
|
||||
'points': det_polyon,
|
||||
'texts': pred_str
|
||||
} for det_polyon, pred_str in zip(preds['points'], preds['texts'])]
|
||||
result = get_socre_B(self.gt_mat_dir, img_id, e2e_info_list)
|
||||
self.results.append(result)
|
||||
|
||||
def get_metric(self):
|
||||
metircs = combine_results(self.results)
|
||||
|
|
|
@ -34,12 +34,18 @@ class DBPostProcess(object):
|
|||
max_candidates=1000,
|
||||
unclip_ratio=2.0,
|
||||
use_dilation=False,
|
||||
score_mode="fast",
|
||||
**kwargs):
|
||||
self.thresh = thresh
|
||||
self.box_thresh = box_thresh
|
||||
self.max_candidates = max_candidates
|
||||
self.unclip_ratio = unclip_ratio
|
||||
self.min_size = 3
|
||||
self.score_mode = score_mode
|
||||
assert score_mode in [
|
||||
"slow", "fast"
|
||||
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
|
||||
|
||||
self.dilation_kernel = None if not use_dilation else np.array(
|
||||
[[1, 1], [1, 1]])
|
||||
|
||||
|
@ -69,7 +75,10 @@ class DBPostProcess(object):
|
|||
if sside < self.min_size:
|
||||
continue
|
||||
points = np.array(points)
|
||||
score = self.box_score_fast(pred, points.reshape(-1, 2))
|
||||
if self.score_mode == "fast":
|
||||
score = self.box_score_fast(pred, points.reshape(-1, 2))
|
||||
else:
|
||||
score = self.box_score_slow(pred, contour)
|
||||
if self.box_thresh > score:
|
||||
continue
|
||||
|
||||
|
@ -120,6 +129,9 @@ class DBPostProcess(object):
|
|||
return box, min(bounding_box[1])
|
||||
|
||||
def box_score_fast(self, bitmap, _box):
|
||||
'''
|
||||
box_score_fast: use bbox mean score as the mean score
|
||||
'''
|
||||
h, w = bitmap.shape[:2]
|
||||
box = _box.copy()
|
||||
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1)
|
||||
|
@ -133,6 +145,27 @@ class DBPostProcess(object):
|
|||
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
|
||||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
||||
|
||||
def box_score_slow(self, bitmap, contour):
|
||||
'''
|
||||
box_score_slow: use polyon mean score as the mean score
|
||||
'''
|
||||
h, w = bitmap.shape[:2]
|
||||
contour = contour.copy()
|
||||
contour = np.reshape(contour, (-1, 2))
|
||||
|
||||
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
|
||||
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
|
||||
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
|
||||
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
|
||||
|
||||
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
|
||||
|
||||
contour[:, 0] = contour[:, 0] - xmin
|
||||
contour[:, 1] = contour[:, 1] - ymin
|
||||
|
||||
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
|
||||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
||||
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
pred = outs_dict['maps']
|
||||
if isinstance(pred, paddle.Tensor):
|
||||
|
|
|
@ -17,7 +17,144 @@ import scipy.io as io
|
|||
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
|
||||
|
||||
|
||||
def get_socre(gt_dir, img_id, pred_dict):
|
||||
def get_socre_A(gt_dir, pred_dict):
|
||||
allInputs = 1
|
||||
|
||||
def input_reading_mod(pred_dict):
|
||||
"""This helper reads input from txt files"""
|
||||
det = []
|
||||
n = len(pred_dict)
|
||||
for i in range(n):
|
||||
points = pred_dict[i]['points']
|
||||
text = pred_dict[i]['texts']
|
||||
point = ",".join(map(str, points.reshape(-1, )))
|
||||
det.append([point, text])
|
||||
return det
|
||||
|
||||
def gt_reading_mod(gt_dict):
|
||||
"""This helper reads groundtruths from mat files"""
|
||||
gt = []
|
||||
n = len(gt_dict)
|
||||
for i in range(n):
|
||||
points = gt_dict[i]['points'].tolist()
|
||||
h = len(points)
|
||||
text = gt_dict[i]['text']
|
||||
xx = [
|
||||
np.array(
|
||||
['x:'], dtype='<U2'), 0, np.array(
|
||||
['y:'], dtype='<U2'), 0, np.array(
|
||||
['#'], dtype='<U1'), np.array(
|
||||
['#'], dtype='<U1')
|
||||
]
|
||||
t_x, t_y = [], []
|
||||
for j in range(h):
|
||||
t_x.append(points[j][0])
|
||||
t_y.append(points[j][1])
|
||||
xx[1] = np.array([t_x], dtype='int16')
|
||||
xx[3] = np.array([t_y], dtype='int16')
|
||||
if text != "":
|
||||
xx[4] = np.array([text], dtype='U{}'.format(len(text)))
|
||||
xx[5] = np.array(['c'], dtype='<U1')
|
||||
gt.append(xx)
|
||||
return gt
|
||||
|
||||
def detection_filtering(detections, groundtruths, threshold=0.5):
|
||||
for gt_id, gt in enumerate(groundtruths):
|
||||
if (gt[5] == '#') and (gt[1].shape[1] > 1):
|
||||
gt_x = list(map(int, np.squeeze(gt[1])))
|
||||
gt_y = list(map(int, np.squeeze(gt[3])))
|
||||
for det_id, detection in enumerate(detections):
|
||||
detection_orig = detection
|
||||
detection = [float(x) for x in detection[0].split(',')]
|
||||
detection = list(map(int, detection))
|
||||
det_x = detection[0::2]
|
||||
det_y = detection[1::2]
|
||||
det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
|
||||
if det_gt_iou > threshold:
|
||||
detections[det_id] = []
|
||||
|
||||
detections[:] = [item for item in detections if item != []]
|
||||
return detections
|
||||
|
||||
def sigma_calculation(det_x, det_y, gt_x, gt_y):
|
||||
"""
|
||||
sigma = inter_area / gt_area
|
||||
"""
|
||||
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
|
||||
area(gt_x, gt_y)), 2)
|
||||
|
||||
def tau_calculation(det_x, det_y, gt_x, gt_y):
|
||||
if area(det_x, det_y) == 0.0:
|
||||
return 0
|
||||
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
|
||||
area(det_x, det_y)), 2)
|
||||
|
||||
##############################Initialization###################################
|
||||
# global_sigma = []
|
||||
# global_tau = []
|
||||
# global_pred_str = []
|
||||
# global_gt_str = []
|
||||
###############################################################################
|
||||
|
||||
for input_id in range(allInputs):
|
||||
if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
|
||||
input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
|
||||
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
|
||||
and (input_id != 'Deteval_result_non_curved.txt'):
|
||||
detections = input_reading_mod(pred_dict)
|
||||
groundtruths = gt_reading_mod(gt_dir)
|
||||
detections = detection_filtering(
|
||||
detections,
|
||||
groundtruths) # filters detections overlapping with DC area
|
||||
dc_id = []
|
||||
for i in range(len(groundtruths)):
|
||||
if groundtruths[i][5] == '#':
|
||||
dc_id.append(i)
|
||||
cnt = 0
|
||||
for a in dc_id:
|
||||
num = a - cnt
|
||||
del groundtruths[num]
|
||||
cnt += 1
|
||||
|
||||
local_sigma_table = np.zeros((len(groundtruths), len(detections)))
|
||||
local_tau_table = np.zeros((len(groundtruths), len(detections)))
|
||||
local_pred_str = {}
|
||||
local_gt_str = {}
|
||||
|
||||
for gt_id, gt in enumerate(groundtruths):
|
||||
if len(detections) > 0:
|
||||
for det_id, detection in enumerate(detections):
|
||||
detection_orig = detection
|
||||
detection = [float(x) for x in detection[0].split(',')]
|
||||
detection = list(map(int, detection))
|
||||
pred_seq_str = detection_orig[1].strip()
|
||||
det_x = detection[0::2]
|
||||
det_y = detection[1::2]
|
||||
gt_x = list(map(int, np.squeeze(gt[1])))
|
||||
gt_y = list(map(int, np.squeeze(gt[3])))
|
||||
gt_seq_str = str(gt[4].tolist()[0])
|
||||
|
||||
local_sigma_table[gt_id, det_id] = sigma_calculation(
|
||||
det_x, det_y, gt_x, gt_y)
|
||||
local_tau_table[gt_id, det_id] = tau_calculation(
|
||||
det_x, det_y, gt_x, gt_y)
|
||||
local_pred_str[det_id] = pred_seq_str
|
||||
local_gt_str[gt_id] = gt_seq_str
|
||||
|
||||
global_sigma = local_sigma_table
|
||||
global_tau = local_tau_table
|
||||
global_pred_str = local_pred_str
|
||||
global_gt_str = local_gt_str
|
||||
|
||||
single_data = {}
|
||||
single_data['sigma'] = global_sigma
|
||||
single_data['global_tau'] = global_tau
|
||||
single_data['global_pred_str'] = global_pred_str
|
||||
single_data['global_gt_str'] = global_gt_str
|
||||
return single_data
|
||||
|
||||
|
||||
def get_socre_B(gt_dir, img_id, pred_dict):
|
||||
allInputs = 1
|
||||
|
||||
def input_reading_mod(pred_dict):
|
||||
|
|
|
@ -39,7 +39,10 @@ class TextDetector(object):
|
|||
self.args = args
|
||||
self.det_algorithm = args.det_algorithm
|
||||
pre_process_list = [{
|
||||
'DetResizeForTest': None
|
||||
'DetResizeForTest': {
|
||||
'limit_side_len': args.det_limit_side_len,
|
||||
'limit_type': args.det_limit_type
|
||||
}
|
||||
}, {
|
||||
'NormalizeImage': {
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
|
@ -62,6 +65,7 @@ class TextDetector(object):
|
|||
postprocess_params["max_candidates"] = 1000
|
||||
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
|
||||
postprocess_params["use_dilation"] = args.use_dilation
|
||||
postprocess_params["score_mode"] = args.det_db_score_mode
|
||||
elif self.det_algorithm == "EAST":
|
||||
postprocess_params['name'] = 'EASTPostProcess'
|
||||
postprocess_params["score_thresh"] = args.det_east_score_thresh
|
||||
|
|
|
@ -48,6 +48,7 @@ def parse_args():
|
|||
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
|
||||
parser.add_argument("--max_batch_size", type=int, default=10)
|
||||
parser.add_argument("--use_dilation", type=bool, default=False)
|
||||
parser.add_argument("--det_db_score_mode", type=str, default="fast")
|
||||
# EAST parmas
|
||||
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
|
||||
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
|
||||
|
|
|
@ -73,35 +73,45 @@ def main():
|
|||
global_config['infer_mode'] = True
|
||||
ops = create_operators(transforms, global_config)
|
||||
|
||||
save_res_path = config['Global'].get('save_res_path',
|
||||
"./output/rec/predicts_rec.txt")
|
||||
if not os.path.exists(os.path.dirname(save_res_path)):
|
||||
os.makedirs(os.path.dirname(save_res_path))
|
||||
|
||||
model.eval()
|
||||
for file in get_image_file_list(config['Global']['infer_img']):
|
||||
logger.info("infer_img: {}".format(file))
|
||||
with open(file, 'rb') as f:
|
||||
img = f.read()
|
||||
data = {'image': img}
|
||||
batch = transform(data, ops)
|
||||
if config['Architecture']['algorithm'] == "SRN":
|
||||
encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
|
||||
gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
|
||||
gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
|
||||
gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
|
||||
|
||||
others = [
|
||||
paddle.to_tensor(encoder_word_pos_list),
|
||||
paddle.to_tensor(gsrm_word_pos_list),
|
||||
paddle.to_tensor(gsrm_slf_attn_bias1_list),
|
||||
paddle.to_tensor(gsrm_slf_attn_bias2_list)
|
||||
]
|
||||
with open(save_res_path, "w") as fout:
|
||||
for file in get_image_file_list(config['Global']['infer_img']):
|
||||
logger.info("infer_img: {}".format(file))
|
||||
with open(file, 'rb') as f:
|
||||
img = f.read()
|
||||
data = {'image': img}
|
||||
batch = transform(data, ops)
|
||||
if config['Architecture']['algorithm'] == "SRN":
|
||||
encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
|
||||
gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
|
||||
gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
|
||||
gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
|
||||
|
||||
images = np.expand_dims(batch[0], axis=0)
|
||||
images = paddle.to_tensor(images)
|
||||
if config['Architecture']['algorithm'] == "SRN":
|
||||
preds = model(images, others)
|
||||
else:
|
||||
preds = model(images)
|
||||
post_result = post_process_class(preds)
|
||||
for rec_reuslt in post_result:
|
||||
logger.info('\t result: {}'.format(rec_reuslt))
|
||||
others = [
|
||||
paddle.to_tensor(encoder_word_pos_list),
|
||||
paddle.to_tensor(gsrm_word_pos_list),
|
||||
paddle.to_tensor(gsrm_slf_attn_bias1_list),
|
||||
paddle.to_tensor(gsrm_slf_attn_bias2_list)
|
||||
]
|
||||
|
||||
images = np.expand_dims(batch[0], axis=0)
|
||||
images = paddle.to_tensor(images)
|
||||
if config['Architecture']['algorithm'] == "SRN":
|
||||
preds = model(images, others)
|
||||
else:
|
||||
preds = model(images)
|
||||
post_result = post_process_class(preds)
|
||||
for rec_reuslt in post_result:
|
||||
logger.info('\t result: {}'.format(rec_reuslt))
|
||||
if len(rec_reuslt) >= 2:
|
||||
fout.write(file + "\t" + rec_reuslt[0] + "\t" + str(
|
||||
rec_reuslt[1]) + "\n")
|
||||
logger.info("success!")
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue