diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml index 0bacf12d..8128bde1 100644 --- a/configs/e2e/e2e_r50_vd_pg.yml +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -60,8 +60,10 @@ PostProcess: name: PGPostProcess score_thresh: 0.5 mode: fast # fast or slow two ways + Metric: name: E2EMetric + mode: A # two ways for eval, A: label from txt, B: label from gt_mat gt_mat_dir: ./train_data/total_text/gt # the dir of gt_mat character_dict_path: ppocr/utils/ic15_dict.txt main_indicator: f_score_e2e @@ -70,13 +72,13 @@ Train: dataset: name: PGDataSet data_dir: ./train_data/total_text/train - label_file_list: [./train_data/total_text/train/] + label_file_list: [./train_data/total_text/train/train.txt] ratio_list: [1.0] transforms: - DecodeImage: # load image img_mode: BGR channel_first: False - - E2ELabelEncode: + - E2ELabelEncodeTrain: - PGProcessTrain: batch_size: 14 # same as loader: batch_size_per_card min_crop_size: 24 @@ -94,11 +96,12 @@ Eval: dataset: name: PGDataSet data_dir: ./train_data/total_text/test - label_file_list: [./train_data/total_text/test/] + label_file_list: [./train_data/total_text/test/test.txt] transforms: - DecodeImage: # load image img_mode: RGB channel_first: False + - E2ELabelEncodeTest: - E2EResizeForTest: max_side_len: 768 - NormalizeImage: @@ -108,7 +111,7 @@ Eval: order: 'hwc' - ToCHWImage: - KeepKeys: - keep_keys: [ 'image', 'shape', 'img_id'] + keep_keys: [ 'image', 'shape', 'polys', 'texts', 'ignore_tags', 'img_id'] loader: shuffle: False drop_last: False diff --git a/configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml b/configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml index 6a524e22..717c1681 100644 --- a/configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml +++ b/configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: True + save_res_path: ./output/rec/predicts_chinese_common_v2.0.txt Optimizer: diff --git a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml index c96621c5..660465f3 100644 --- a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml +++ b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: True + save_res_path: ./output/rec/predicts_chinese_lite_v2.0.txt Optimizer: diff --git a/configs/rec/rec_icdar15_train.yml b/configs/rec/rec_icdar15_train.yml index 5ae47c67..79e3ff88 100644 --- a/configs/rec/rec_icdar15_train.yml +++ b/configs/rec/rec_icdar15_train.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_ic15.txt Optimizer: name: Adam diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml index 900e98b6..9e0bd23e 100644 --- a/configs/rec/rec_mv3_none_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_mv3_none_bilstm_ctc.txt Optimizer: name: Adam diff --git a/configs/rec/rec_mv3_none_none_ctc.yml b/configs/rec/rec_mv3_none_none_ctc.yml index 6d86b90c..904afe11 100644 --- a/configs/rec/rec_mv3_none_none_ctc.yml +++ b/configs/rec/rec_mv3_none_none_ctc.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_mv3_none_none_ctc.txt Optimizer: name: Adam diff --git a/configs/rec/rec_mv3_tps_bilstm_att.yml b/configs/rec/rec_mv3_tps_bilstm_att.yml index 33aed74d..feaeb054 100644 --- a/configs/rec/rec_mv3_tps_bilstm_att.yml +++ b/configs/rec/rec_mv3_tps_bilstm_att.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_mv3_tps_bilstm_att.txt Optimizer: diff --git a/configs/rec/rec_mv3_tps_bilstm_ctc.yml b/configs/rec/rec_mv3_tps_bilstm_ctc.yml index 026c6a9d..65ab23c4 100644 --- a/configs/rec/rec_mv3_tps_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_tps_bilstm_ctc.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_mv3_tps_bilstm_ctc.txt Optimizer: name: Adam diff --git a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml index 4052d426..331bb36e 100644 --- a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_r34_vd_none_bilstm_ctc.txt Optimizer: name: Adam diff --git a/configs/rec/rec_r34_vd_none_none_ctc.yml b/configs/rec/rec_r34_vd_none_none_ctc.yml index c3e1d9a3..695a4695 100644 --- a/configs/rec/rec_r34_vd_none_none_ctc.yml +++ b/configs/rec/rec_r34_vd_none_none_ctc.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_r34_vd_none_none_ctc.txt Optimizer: name: Adam diff --git a/configs/rec/rec_r34_vd_tps_bilstm_att.yml b/configs/rec/rec_r34_vd_tps_bilstm_att.yml index 87a14559..fdd3588c 100644 --- a/configs/rec/rec_r34_vd_tps_bilstm_att.yml +++ b/configs/rec/rec_r34_vd_tps_bilstm_att.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_b3_rare_r34_none_gru.txt Optimizer: diff --git a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml index 9c51962e..67108a6e 100644 --- a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml @@ -19,6 +19,7 @@ Global: max_text_length: 25 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_r34_vd_tps_bilstm_ctc.txt Optimizer: name: Adam diff --git a/configs/rec/rec_r50_fpn_srn.yml b/configs/rec/rec_r50_fpn_srn.yml index 34a997f3..fa7b1ae4 100644 --- a/configs/rec/rec_r50_fpn_srn.yml +++ b/configs/rec/rec_r50_fpn_srn.yml @@ -20,6 +20,7 @@ Global: num_heads: 8 infer_mode: False use_space_char: False + save_res_path: ./output/rec/predicts_srn.txt Optimizer: diff --git a/deploy/cpp_infer/include/postprocess_op.h b/deploy/cpp_infer/include/postprocess_op.h index 44ca3531..a600ea6d 100644 --- a/deploy/cpp_infer/include/postprocess_op.h +++ b/deploy/cpp_infer/include/postprocess_op.h @@ -51,6 +51,7 @@ public: float &ssid); float BoxScoreFast(std::vector> box_array, cv::Mat pred); + float PolygonScoreAcc(std::vector contour, cv::Mat pred); std::vector>> BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap, diff --git a/deploy/cpp_infer/src/postprocess_op.cpp b/deploy/cpp_infer/src/postprocess_op.cpp index 8c44a54a..1b71c210 100644 --- a/deploy/cpp_infer/src/postprocess_op.cpp +++ b/deploy/cpp_infer/src/postprocess_op.cpp @@ -159,6 +159,39 @@ std::vector> PostProcessor::GetMiniBoxes(cv::RotatedRect box, return array; } +float PostProcessor::PolygonScoreAcc(std::vector contour, + cv::Mat pred){ + int width = pred.cols; + int height = pred.rows; + std::vector box_x; + std::vector box_y; + for(int i=0; i> box_array, cv::Mat pred) { auto array = box_array; @@ -235,6 +268,8 @@ PostProcessor::BoxesFromBitmap(const cv::Mat pred, const cv::Mat bitmap, float score; score = BoxScoreFast(array, pred); + /* compute using polygon*/ + // score = PolygonScoreAcc(contours[_i], pred); if (score < box_thresh) continue; diff --git a/deploy/cpp_infer/src/preprocess_op.cpp b/deploy/cpp_infer/src/preprocess_op.cpp old mode 100644 new mode 100755 index 87d8dbbd..fb7590e3 --- a/deploy/cpp_infer/src/preprocess_op.cpp +++ b/deploy/cpp_infer/src/preprocess_op.cpp @@ -77,19 +77,10 @@ void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img, int resize_h = int(float(h) * ratio); int resize_w = int(float(w) * ratio); - if (resize_h % 32 == 0) - resize_h = resize_h; - else if (resize_h / 32 < 1 + 1e-5) - resize_h = 32; - else - resize_h = (resize_h / 32) * 32; + + resize_h = max(int(round(float(resize_h) / 32) * 32), 32); + resize_w = max(int(round(float(resize_w) / 32) * 32), 32); - if (resize_w % 32 == 0) - resize_w = resize_w; - else if (resize_w / 32 < 1 + 1e-5) - resize_w = 32; - else - resize_w = (resize_w / 32) * 32; if (!use_tensorrt) { cv::resize(img, resize_img, cv::Size(resize_w, resize_h)); ratio_h = float(resize_h) / float(h); diff --git a/deploy/hubserving/ocr_cls/module.py b/deploy/hubserving/ocr_cls/module.py index 1b91580c..e159e0d3 100644 --- a/deploy/hubserving/ocr_cls/module.py +++ b/deploy/hubserving/ocr_cls/module.py @@ -6,6 +6,7 @@ from __future__ import print_function import os import sys sys.path.insert(0, ".") +import copy from paddlehub.common.logger import logger from paddlehub.module.module import moduleinfo, runnable, serving @@ -14,6 +15,8 @@ import paddlehub as hub from tools.infer.utility import base64_to_cv2 from tools.infer.predict_cls import TextClassifier +from tools.infer.utility import parse_args +from deploy.hubserving.ocr_cls.params import read_params @moduleinfo( @@ -28,8 +31,7 @@ class OCRCls(hub.Module): """ initialize with the necessary elements """ - from ocr_cls.params import read_params - cfg = read_params() + cfg = self.merge_configs() cfg.use_gpu = use_gpu if use_gpu: @@ -48,6 +50,20 @@ class OCRCls(hub.Module): self.text_classifier = TextClassifier(cfg) + def merge_configs(self, ): + # deafult cfg + backup_argv = copy.deepcopy(sys.argv) + sys.argv = sys.argv[:1] + cfg = parse_args() + + update_cfg_map = vars(read_params()) + + for key in update_cfg_map: + cfg.__setattr__(key, update_cfg_map[key]) + + sys.argv = copy.deepcopy(backup_argv) + return cfg + def read_images(self, paths=[]): images = [] for img_path in paths: diff --git a/deploy/hubserving/ocr_det/module.py b/deploy/hubserving/ocr_det/module.py index 5f7bd6c4..c7d253f5 100644 --- a/deploy/hubserving/ocr_det/module.py +++ b/deploy/hubserving/ocr_det/module.py @@ -7,6 +7,8 @@ import os import sys sys.path.insert(0, ".") +import copy + from paddlehub.common.logger import logger from paddlehub.module.module import moduleinfo, runnable, serving import cv2 @@ -15,6 +17,8 @@ import paddlehub as hub from tools.infer.utility import base64_to_cv2 from tools.infer.predict_det import TextDetector +from tools.infer.utility import parse_args +from deploy.hubserving.ocr_system.params import read_params @moduleinfo( @@ -29,8 +33,7 @@ class OCRDet(hub.Module): """ initialize with the necessary elements """ - from ocr_det.params import read_params - cfg = read_params() + cfg = self.merge_configs() cfg.use_gpu = use_gpu if use_gpu: @@ -49,6 +52,20 @@ class OCRDet(hub.Module): self.text_detector = TextDetector(cfg) + def merge_configs(self, ): + # deafult cfg + backup_argv = copy.deepcopy(sys.argv) + sys.argv = sys.argv[:1] + cfg = parse_args() + + update_cfg_map = vars(read_params()) + + for key in update_cfg_map: + cfg.__setattr__(key, update_cfg_map[key]) + + sys.argv = copy.deepcopy(backup_argv) + return cfg + def read_images(self, paths=[]): images = [] for img_path in paths: diff --git a/deploy/hubserving/ocr_det/params.py b/deploy/hubserving/ocr_det/params.py index 7be88e9b..bc75cc40 100755 --- a/deploy/hubserving/ocr_det/params.py +++ b/deploy/hubserving/ocr_det/params.py @@ -22,6 +22,7 @@ def read_params(): cfg.det_db_box_thresh = 0.5 cfg.det_db_unclip_ratio = 1.6 cfg.use_dilation = False + cfg.det_db_score_mode = "fast" # #EAST parmas # cfg.det_east_score_thresh = 0.8 diff --git a/deploy/hubserving/ocr_rec/module.py b/deploy/hubserving/ocr_rec/module.py index 41a42104..2bec3fcd 100644 --- a/deploy/hubserving/ocr_rec/module.py +++ b/deploy/hubserving/ocr_rec/module.py @@ -6,6 +6,7 @@ from __future__ import print_function import os import sys sys.path.insert(0, ".") +import copy from paddlehub.common.logger import logger from paddlehub.module.module import moduleinfo, runnable, serving @@ -14,6 +15,8 @@ import paddlehub as hub from tools.infer.utility import base64_to_cv2 from tools.infer.predict_rec import TextRecognizer +from tools.infer.utility import parse_args +from deploy.hubserving.ocr_rec.params import read_params @moduleinfo( @@ -28,8 +31,7 @@ class OCRRec(hub.Module): """ initialize with the necessary elements """ - from ocr_rec.params import read_params - cfg = read_params() + cfg = self.merge_configs() cfg.use_gpu = use_gpu if use_gpu: @@ -48,6 +50,20 @@ class OCRRec(hub.Module): self.text_recognizer = TextRecognizer(cfg) + def merge_configs(self, ): + # deafult cfg + backup_argv = copy.deepcopy(sys.argv) + sys.argv = sys.argv[:1] + cfg = parse_args() + + update_cfg_map = vars(read_params()) + + for key in update_cfg_map: + cfg.__setattr__(key, update_cfg_map[key]) + + sys.argv = copy.deepcopy(backup_argv) + return cfg + def read_images(self, paths=[]): images = [] for img_path in paths: diff --git a/deploy/hubserving/ocr_system/module.py b/deploy/hubserving/ocr_system/module.py index 7f361733..cbef8086 100644 --- a/deploy/hubserving/ocr_system/module.py +++ b/deploy/hubserving/ocr_system/module.py @@ -6,6 +6,7 @@ from __future__ import print_function import os import sys sys.path.insert(0, ".") +import copy import time @@ -17,6 +18,8 @@ import paddlehub as hub from tools.infer.utility import base64_to_cv2 from tools.infer.predict_system import TextSystem +from tools.infer.utility import parse_args +from deploy.hubserving.ocr_system.params import read_params @moduleinfo( @@ -31,8 +34,7 @@ class OCRSystem(hub.Module): """ initialize with the necessary elements """ - from ocr_system.params import read_params - cfg = read_params() + cfg = self.merge_configs() cfg.use_gpu = use_gpu if use_gpu: @@ -51,6 +53,20 @@ class OCRSystem(hub.Module): self.text_sys = TextSystem(cfg) + def merge_configs(self, ): + # deafult cfg + backup_argv = copy.deepcopy(sys.argv) + sys.argv = sys.argv[:1] + cfg = parse_args() + + update_cfg_map = vars(read_params()) + + for key in update_cfg_map: + cfg.__setattr__(key, update_cfg_map[key]) + + sys.argv = copy.deepcopy(backup_argv) + return cfg + def read_images(self, paths=[]): images = [] for img_path in paths: diff --git a/deploy/hubserving/ocr_system/params.py b/deploy/hubserving/ocr_system/params.py index bd56dc2e..bee53bfd 100755 --- a/deploy/hubserving/ocr_system/params.py +++ b/deploy/hubserving/ocr_system/params.py @@ -22,6 +22,7 @@ def read_params(): cfg.det_db_box_thresh = 0.5 cfg.det_db_unclip_ratio = 1.6 cfg.use_dilation = False + cfg.det_db_score_mode = "fast" #EAST parmas cfg.det_east_score_thresh = 0.8 diff --git a/doc/doc_ch/pgnet.md b/doc/doc_ch/pgnet.md index 26585386..e79c2388 100644 --- a/doc/doc_ch/pgnet.md +++ b/doc/doc_ch/pgnet.md @@ -83,19 +83,19 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im 本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。 ### 准备数据 -下载解压[totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md) 数据集到PaddleOCR/train_data/目录,数据集组织结构: +下载解压[totaltext](https://paddleocr.bj.bcebos.com/dataset/total_text.tar) 数据集到PaddleOCR/train_data/目录,数据集组织结构: ``` /PaddleOCR/train_data/total_text/train/ |- rgb/ # total_text数据集的训练数据 - |- gt_0.png + |- img11.jpg | ... - |- total_text.txt # total_text数据集的训练标注 + |- train.txt # total_text数据集的训练标注 ``` total_text.txt标注文件格式如下,文件名和标注信息中间用"\t"分隔: ``` " 图像文件名 json.dumps编码的图像标注信息" -rgb/gt_0.png [{"transcription": "EST", "points": [[1004.0,689.0],[1019.0,698.0],[1034.0,708.0],[1049.0,718.0],[1064.0,728.0],[1079.0,738.0],[1095.0,748.0],[1094.0,774.0],[1079.0,765.0],[1065.0,756.0],[1050.0,747.0],[1036.0,738.0],[1021.0,729.0],[1007.0,721.0]]}, {...}] +rgb/img11.jpg [{"transcription": "ASRAMA", "points": [[214.0, 325.0], [235.0, 308.0], [259.0, 296.0], [286.0, 291.0], [313.0, 295.0], [338.0, 305.0], [362.0, 320.0], [349.0, 347.0], [330.0, 337.0], [310.0, 329.0], [290.0, 324.0], [269.0, 328.0], [249.0, 336.0], [231.0, 346.0]]}, {...}] ``` json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 `points` 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。 `transcription` 表示当前文本框的文字,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。** diff --git a/doc/doc_en/pgnet_en.md b/doc/doc_en/pgnet_en.md index 2ab1116c..9c46802b 100644 --- a/doc/doc_en/pgnet_en.md +++ b/doc/doc_en/pgnet_en.md @@ -76,19 +76,19 @@ The visualized end-to-end results are saved to the `./inference_results` folder This section takes the totaltext dataset as an example to introduce the training, evaluation and testing of the end-to-end model in PaddleOCR. ### Data Preparation -Download and unzip [totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md) dataset to PaddleOCR/train_data/, dataset organization structure is as follow: +Download and unzip [totaltext](https://paddleocr.bj.bcebos.com/dataset/total_text.tar) dataset to PaddleOCR/train_data/, dataset organization structure is as follow: ``` /PaddleOCR/train_data/total_text/train/ |- rgb/ # total_text training data of dataset - |- gt_0.png + |- img11.png | ... - |- total_text.txt # total_text training annotation of dataset + |- train.txt # total_text training annotation of dataset ``` total_text.txt: the format of dimension file is as follows,the file name and annotation information are separated by "\t": ``` " Image file name Image annotation information encoded by json.dumps" -rgb/gt_0.png [{"transcription": "EST", "points": [[1004.0,689.0],[1019.0,698.0],[1034.0,708.0],[1049.0,718.0],[1064.0,728.0],[1079.0,738.0],[1095.0,748.0],[1094.0,774.0],[1079.0,765.0],[1065.0,756.0],[1050.0,747.0],[1036.0,738.0],[1021.0,729.0],[1007.0,721.0]]}, {...}] +rgb/img11.jpg [{"transcription": "ASRAMA", "points": [[214.0, 325.0], [235.0, 308.0], [259.0, 296.0], [286.0, 291.0], [313.0, 295.0], [338.0, 305.0], [362.0, 320.0], [349.0, 347.0], [330.0, 337.0], [310.0, 329.0], [290.0, 324.0], [269.0, 328.0], [249.0, 336.0], [231.0, 346.0]]}, {...}] ``` The image annotation after **json.dumps()** encoding is a list containing multiple dictionaries. diff --git a/paddleocr.py b/paddleocr.py index 016b00c1..c5da7248 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -193,6 +193,7 @@ def parse_args(mMain=True, add_help=True): parser.add_argument("--det_db_box_thresh", type=float, default=0.5) parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) parser.add_argument("--use_dilation", type=bool, default=False) + parser.add_argument("--det_db_score_mode", type=str, default="fast") # EAST parmas parser.add_argument("--det_east_score_thresh", type=float, default=0.8) @@ -241,6 +242,7 @@ def parse_args(mMain=True, add_help=True): det_db_box_thresh=0.5, det_db_unclip_ratio=1.6, use_dilation=False, + det_db_score_mode="fast", det_east_score_thresh=0.8, det_east_cover_thresh=0.1, det_east_nms_thresh=0.2, diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index 61cc7303..bba3209f 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -187,7 +187,51 @@ class CTCLabelEncode(BaseRecLabelEncode): return dict_character -class E2ELabelEncode(object): +class E2ELabelEncodeTest(BaseRecLabelEncode): + def __init__(self, + max_text_length, + character_dict_path=None, + character_type='EN', + use_space_char=False, + **kwargs): + super(E2ELabelEncodeTest, + self).__init__(max_text_length, character_dict_path, + character_type, use_space_char) + + def __call__(self, data): + import json + padnum = len(self.dict) + label = data['label'] + label = json.loads(label) + nBox = len(label) + boxes, txts, txt_tags = [], [], [] + for bno in range(0, nBox): + box = label[bno]['points'] + txt = label[bno]['transcription'] + boxes.append(box) + txts.append(txt) + if txt in ['*', '###']: + txt_tags.append(True) + else: + txt_tags.append(False) + boxes = np.array(boxes, dtype=np.float32) + txt_tags = np.array(txt_tags, dtype=np.bool) + data['polys'] = boxes + data['ignore_tags'] = txt_tags + temp_texts = [] + for text in txts: + text = text.lower() + text = self.encode(text) + if text is None: + return None + text = text + [padnum] * (self.max_text_len - len(text) + ) # use 36 to pad + temp_texts.append(text) + data['texts'] = np.array(temp_texts) + return data + + +class E2ELabelEncodeTrain(object): def __init__(self, **kwargs): pass diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py index b9819358..5adcd02c 100644 --- a/ppocr/data/pgnet_dataset.py +++ b/ppocr/data/pgnet_dataset.py @@ -72,6 +72,7 @@ class PGDataSet(Dataset): def __getitem__(self, idx): file_idx = self.data_idx_order_list[idx] data_line = self.data_lines[file_idx] + img_id = 0 try: data_line = data_line.decode('utf-8') substr = data_line.strip("\n").split(self.delimiter) @@ -79,9 +80,10 @@ class PGDataSet(Dataset): label = substr[1] img_path = os.path.join(self.data_dir, file_name) if self.mode.lower() == 'eval': - img_id = int(data_line.split(".")[0][7:]) - else: - img_id = 0 + try: + img_id = int(data_line.split(".")[0][7:]) + except: + img_id = 0 data = {'img_path': img_path, 'label': label, 'img_id': img_id} if not os.path.exists(img_path): raise Exception("{} does not exist!".format(img_path)) diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py index 525aa003..41b7ac2b 100644 --- a/ppocr/metrics/e2e_metric.py +++ b/ppocr/metrics/e2e_metric.py @@ -18,16 +18,18 @@ from __future__ import print_function __all__ = ['E2EMetric'] -from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results +from ppocr.utils.e2e_metric.Deteval import get_socre_A, get_socre_B, combine_results from ppocr.utils.e2e_utils.extract_textpoint_slow import get_dict class E2EMetric(object): def __init__(self, + mode, gt_mat_dir, character_dict_path, main_indicator='f_score_e2e', **kwargs): + self.mode = mode self.gt_mat_dir = gt_mat_dir self.label_list = get_dict(character_dict_path) self.max_index = len(self.label_list) @@ -35,13 +37,45 @@ class E2EMetric(object): self.reset() def __call__(self, preds, batch, **kwargs): - img_id = batch[2][0] - e2e_info_list = [{ - 'points': det_polyon, - 'texts': pred_str - } for det_polyon, pred_str in zip(preds['points'], preds['texts'])] - result = get_socre(self.gt_mat_dir, img_id, e2e_info_list) - self.results.append(result) + if self.mode == 'A': + gt_polyons_batch = batch[2] + temp_gt_strs_batch = batch[3][0] + ignore_tags_batch = batch[4] + gt_strs_batch = [] + + for temp_list in temp_gt_strs_batch: + t = "" + for index in temp_list: + if index < self.max_index: + t += self.label_list[index] + gt_strs_batch.append(t) + + for pred, gt_polyons, gt_strs, ignore_tags in zip( + [preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch): + # prepare gt + gt_info_list = [{ + 'points': gt_polyon, + 'text': gt_str, + 'ignore': ignore_tag + } for gt_polyon, gt_str, ignore_tag in + zip(gt_polyons, gt_strs, ignore_tags)] + # prepare det + e2e_info_list = [{ + 'points': det_polyon, + 'texts': pred_str + } for det_polyon, pred_str in + zip(pred['points'], pred['texts'])] + + result = get_socre_A(gt_info_list, e2e_info_list) + self.results.append(result) + else: + img_id = batch[5][0] + e2e_info_list = [{ + 'points': det_polyon, + 'texts': pred_str + } for det_polyon, pred_str in zip(preds['points'], preds['texts'])] + result = get_socre_B(self.gt_mat_dir, img_id, e2e_info_list) + self.results.append(result) def get_metric(self): metircs = combine_results(self.results) diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index 91729e0a..769ddbe2 100755 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -34,12 +34,18 @@ class DBPostProcess(object): max_candidates=1000, unclip_ratio=2.0, use_dilation=False, + score_mode="fast", **kwargs): self.thresh = thresh self.box_thresh = box_thresh self.max_candidates = max_candidates self.unclip_ratio = unclip_ratio self.min_size = 3 + self.score_mode = score_mode + assert score_mode in [ + "slow", "fast" + ], "Score mode must be in [slow, fast] but got: {}".format(score_mode) + self.dilation_kernel = None if not use_dilation else np.array( [[1, 1], [1, 1]]) @@ -69,7 +75,10 @@ class DBPostProcess(object): if sside < self.min_size: continue points = np.array(points) - score = self.box_score_fast(pred, points.reshape(-1, 2)) + if self.score_mode == "fast": + score = self.box_score_fast(pred, points.reshape(-1, 2)) + else: + score = self.box_score_slow(pred, contour) if self.box_thresh > score: continue @@ -120,6 +129,9 @@ class DBPostProcess(object): return box, min(bounding_box[1]) def box_score_fast(self, bitmap, _box): + ''' + box_score_fast: use bbox mean score as the mean score + ''' h, w = bitmap.shape[:2] box = _box.copy() xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) @@ -133,6 +145,27 @@ class DBPostProcess(object): cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + def box_score_slow(self, bitmap, contour): + ''' + box_score_slow: use polyon mean score as the mean score + ''' + h, w = bitmap.shape[:2] + contour = contour.copy() + contour = np.reshape(contour, (-1, 2)) + + xmin = np.clip(np.min(contour[:, 0]), 0, w - 1) + xmax = np.clip(np.max(contour[:, 0]), 0, w - 1) + ymin = np.clip(np.min(contour[:, 1]), 0, h - 1) + ymax = np.clip(np.max(contour[:, 1]), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + + contour[:, 0] = contour[:, 0] - xmin + contour[:, 1] = contour[:, 1] - ymin + + cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + def __call__(self, outs_dict, shape_list): pred = outs_dict['maps'] if isinstance(pred, paddle.Tensor): diff --git a/ppocr/utils/e2e_metric/Deteval.py b/ppocr/utils/e2e_metric/Deteval.py index 2aa09304..45567a7d 100755 --- a/ppocr/utils/e2e_metric/Deteval.py +++ b/ppocr/utils/e2e_metric/Deteval.py @@ -17,7 +17,144 @@ import scipy.io as io from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area -def get_socre(gt_dir, img_id, pred_dict): +def get_socre_A(gt_dir, pred_dict): + allInputs = 1 + + def input_reading_mod(pred_dict): + """This helper reads input from txt files""" + det = [] + n = len(pred_dict) + for i in range(n): + points = pred_dict[i]['points'] + text = pred_dict[i]['texts'] + point = ",".join(map(str, points.reshape(-1, ))) + det.append([point, text]) + return det + + def gt_reading_mod(gt_dict): + """This helper reads groundtruths from mat files""" + gt = [] + n = len(gt_dict) + for i in range(n): + points = gt_dict[i]['points'].tolist() + h = len(points) + text = gt_dict[i]['text'] + xx = [ + np.array( + ['x:'], dtype=' 1): + gt_x = list(map(int, np.squeeze(gt[1]))) + gt_y = list(map(int, np.squeeze(gt[3]))) + for det_id, detection in enumerate(detections): + detection_orig = detection + detection = [float(x) for x in detection[0].split(',')] + detection = list(map(int, detection)) + det_x = detection[0::2] + det_y = detection[1::2] + det_gt_iou = iod(det_x, det_y, gt_x, gt_y) + if det_gt_iou > threshold: + detections[det_id] = [] + + detections[:] = [item for item in detections if item != []] + return detections + + def sigma_calculation(det_x, det_y, gt_x, gt_y): + """ + sigma = inter_area / gt_area + """ + return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / + area(gt_x, gt_y)), 2) + + def tau_calculation(det_x, det_y, gt_x, gt_y): + if area(det_x, det_y) == 0.0: + return 0 + return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) / + area(det_x, det_y)), 2) + + ##############################Initialization################################### + # global_sigma = [] + # global_tau = [] + # global_pred_str = [] + # global_gt_str = [] + ############################################################################### + + for input_id in range(allInputs): + if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and ( + input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and ( + input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \ + and (input_id != 'Deteval_result_non_curved.txt'): + detections = input_reading_mod(pred_dict) + groundtruths = gt_reading_mod(gt_dir) + detections = detection_filtering( + detections, + groundtruths) # filters detections overlapping with DC area + dc_id = [] + for i in range(len(groundtruths)): + if groundtruths[i][5] == '#': + dc_id.append(i) + cnt = 0 + for a in dc_id: + num = a - cnt + del groundtruths[num] + cnt += 1 + + local_sigma_table = np.zeros((len(groundtruths), len(detections))) + local_tau_table = np.zeros((len(groundtruths), len(detections))) + local_pred_str = {} + local_gt_str = {} + + for gt_id, gt in enumerate(groundtruths): + if len(detections) > 0: + for det_id, detection in enumerate(detections): + detection_orig = detection + detection = [float(x) for x in detection[0].split(',')] + detection = list(map(int, detection)) + pred_seq_str = detection_orig[1].strip() + det_x = detection[0::2] + det_y = detection[1::2] + gt_x = list(map(int, np.squeeze(gt[1]))) + gt_y = list(map(int, np.squeeze(gt[3]))) + gt_seq_str = str(gt[4].tolist()[0]) + + local_sigma_table[gt_id, det_id] = sigma_calculation( + det_x, det_y, gt_x, gt_y) + local_tau_table[gt_id, det_id] = tau_calculation( + det_x, det_y, gt_x, gt_y) + local_pred_str[det_id] = pred_seq_str + local_gt_str[gt_id] = gt_seq_str + + global_sigma = local_sigma_table + global_tau = local_tau_table + global_pred_str = local_pred_str + global_gt_str = local_gt_str + + single_data = {} + single_data['sigma'] = global_sigma + single_data['global_tau'] = global_tau + single_data['global_pred_str'] = global_pred_str + single_data['global_gt_str'] = global_gt_str + return single_data + + +def get_socre_B(gt_dir, img_id, pred_dict): allInputs = 1 def input_reading_mod(pred_dict): diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index f5ea0504..59bb49f9 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -39,7 +39,10 @@ class TextDetector(object): self.args = args self.det_algorithm = args.det_algorithm pre_process_list = [{ - 'DetResizeForTest': None + 'DetResizeForTest': { + 'limit_side_len': args.det_limit_side_len, + 'limit_type': args.det_limit_type + } }, { 'NormalizeImage': { 'std': [0.229, 0.224, 0.225], @@ -62,6 +65,7 @@ class TextDetector(object): postprocess_params["max_candidates"] = 1000 postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["use_dilation"] = args.use_dilation + postprocess_params["score_mode"] = args.det_db_score_mode elif self.det_algorithm == "EAST": postprocess_params['name'] = 'EASTPostProcess' postprocess_params["score_thresh"] = args.det_east_score_thresh diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 9fa51d80..b5fe3ba9 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -48,6 +48,7 @@ def parse_args(): parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) parser.add_argument("--max_batch_size", type=int, default=10) parser.add_argument("--use_dilation", type=bool, default=False) + parser.add_argument("--det_db_score_mode", type=str, default="fast") # EAST parmas parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 075ec261..2563f5a8 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -73,35 +73,45 @@ def main(): global_config['infer_mode'] = True ops = create_operators(transforms, global_config) + save_res_path = config['Global'].get('save_res_path', + "./output/rec/predicts_rec.txt") + if not os.path.exists(os.path.dirname(save_res_path)): + os.makedirs(os.path.dirname(save_res_path)) + model.eval() - for file in get_image_file_list(config['Global']['infer_img']): - logger.info("infer_img: {}".format(file)) - with open(file, 'rb') as f: - img = f.read() - data = {'image': img} - batch = transform(data, ops) - if config['Architecture']['algorithm'] == "SRN": - encoder_word_pos_list = np.expand_dims(batch[1], axis=0) - gsrm_word_pos_list = np.expand_dims(batch[2], axis=0) - gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0) - gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0) - others = [ - paddle.to_tensor(encoder_word_pos_list), - paddle.to_tensor(gsrm_word_pos_list), - paddle.to_tensor(gsrm_slf_attn_bias1_list), - paddle.to_tensor(gsrm_slf_attn_bias2_list) - ] + with open(save_res_path, "w") as fout: + for file in get_image_file_list(config['Global']['infer_img']): + logger.info("infer_img: {}".format(file)) + with open(file, 'rb') as f: + img = f.read() + data = {'image': img} + batch = transform(data, ops) + if config['Architecture']['algorithm'] == "SRN": + encoder_word_pos_list = np.expand_dims(batch[1], axis=0) + gsrm_word_pos_list = np.expand_dims(batch[2], axis=0) + gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0) + gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0) - images = np.expand_dims(batch[0], axis=0) - images = paddle.to_tensor(images) - if config['Architecture']['algorithm'] == "SRN": - preds = model(images, others) - else: - preds = model(images) - post_result = post_process_class(preds) - for rec_reuslt in post_result: - logger.info('\t result: {}'.format(rec_reuslt)) + others = [ + paddle.to_tensor(encoder_word_pos_list), + paddle.to_tensor(gsrm_word_pos_list), + paddle.to_tensor(gsrm_slf_attn_bias1_list), + paddle.to_tensor(gsrm_slf_attn_bias2_list) + ] + + images = np.expand_dims(batch[0], axis=0) + images = paddle.to_tensor(images) + if config['Architecture']['algorithm'] == "SRN": + preds = model(images, others) + else: + preds = model(images) + post_result = post_process_class(preds) + for rec_reuslt in post_result: + logger.info('\t result: {}'.format(rec_reuslt)) + if len(rec_reuslt) >= 2: + fout.write(file + "\t" + rec_reuslt[0] + "\t" + str( + rec_reuslt[1]) + "\n") logger.info("success!")