From f67e6e13876aa4a5d9fee4fd49292c08d8cd5fc0 Mon Sep 17 00:00:00 2001 From: Jethong <1147925384@qq.com> Date: Sat, 10 Apr 2021 14:44:32 +0800 Subject: [PATCH] fix eval score --- configs/e2e/e2e_r50_vd_pg.yml | 3 +- doc/doc_ch/pgnet.md | 8 +++-- ppocr/data/pgnet_dataset.py | 15 +++------ ppocr/metrics/e2e_metric.py | 47 ++++++----------------------- ppocr/postprocess/pg_postprocess.py | 1 + ppocr/utils/e2e_metric/Deteval.py | 33 ++++---------------- 6 files changed, 28 insertions(+), 79 deletions(-) diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml index 0a232f7a..5a593ad8 100644 --- a/configs/e2e/e2e_r50_vd_pg.yml +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -61,6 +61,7 @@ PostProcess: score_thresh: 0.5 Metric: name: E2EMetric + gt_mat_dir: # the dir of gt_mat character_dict_path: ppocr/utils/ic15_dict.txt main_indicator: f_score_e2e @@ -106,7 +107,7 @@ Eval: order: 'hwc' - ToCHWImage: - KeepKeys: - keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags' ] + keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id'] loader: shuffle: False drop_last: False diff --git a/doc/doc_ch/pgnet.md b/doc/doc_ch/pgnet.md index 4d3b8208..d82bb796 100644 --- a/doc/doc_ch/pgnet.md +++ b/doc/doc_ch/pgnet.md @@ -2,7 +2,7 @@ - [一、简介](#简介) - [二、环境配置](#环境配置) - [三、快速使用](#快速使用) -- [四、模型训练、评估、推理](#快速训练) +- [四、模型训练、评估、推理](#模型训练、评估、推理) ## 一、简介 @@ -20,7 +20,9 @@ PGNet算法细节详见[论文](https://www.aaai.org/AAAI21Papers/AAAI-2885.Wang ![](../pgnet_framework.png) 输入图像经过特征提取送入四个分支,分别是:文本边缘偏移量预测TBO模块,文本中心线预测TCL模块,文本方向偏移量预测TDO模块,以及文本字符分类图预测TCC模块。 其中TBO以及TCL的输出经过后处理后可以得到文本的检测结果,TCL、TDO、TCC负责文本识别。 + 其检测识别效果图如下: + ![](../imgs_results/e2e_res_img293_pgnet.png) ![](../imgs_results/e2e_res_img295_pgnet.png) @@ -61,12 +63,12 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im 可视化文本检测结果默认保存到./inference_results文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下: ![](../imgs_results/e2e_res_img623_pgnet.jpg) - + ## 四、模型训练、评估、推理 本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。 ### 准备数据 -下载解压[totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md)数据集到PaddleOCR/train_data/目录,数据集组织结构: +下载解压[totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md) 数据集到PaddleOCR/train_data/目录,数据集组织结构: ``` /PaddleOCR/train_data/total_text/train/ |- rgb/ # total_text数据集的训练数据 diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py index ae063835..543dbe79 100644 --- a/ppocr/data/pgnet_dataset.py +++ b/ppocr/data/pgnet_dataset.py @@ -64,9 +64,6 @@ class PGDataSet(Dataset): for line in f.readlines(): poly_str, txt = line.strip().split('\t') poly = list(map(float, poly_str.split(','))) - if self.mode.lower() == "eval": - while len(poly) < 100: - poly.append(-1) text_polys.append( np.array( poly, dtype=np.float32).reshape(-1, 2)) @@ -139,23 +136,21 @@ class PGDataSet(Dataset): try: if self.data_format == 'icdar': im_path = os.path.join(data_path, 'rgb', data_line) - if self.mode.lower() == "eval": - poly_path = os.path.join(data_path, 'poly_gt', - data_line.split('.')[0] + '.txt') - else: - poly_path = os.path.join(data_path, 'poly', - data_line.split('.')[0] + '.txt') + poly_path = os.path.join(data_path, 'poly', + data_line.split('.')[0] + '.txt') text_polys, text_tags, text_strs = self.extract_polys(poly_path) else: image_dir = os.path.join(os.path.dirname(data_path), 'image') im_path, text_polys, text_tags, text_strs = self.extract_info_textnet( data_line, image_dir) + img_id = int(data_line.split(".")[0][3:]) data = { 'img_path': im_path, 'polys': text_polys, 'tags': text_tags, - 'strs': text_strs + 'strs': text_strs, + 'img_id': img_id } with open(data['img_path'], 'rb') as f: img = f.read() diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py index 684d7742..ef14ad48 100644 --- a/ppocr/metrics/e2e_metric.py +++ b/ppocr/metrics/e2e_metric.py @@ -24,53 +24,24 @@ from ppocr.utils.e2e_utils.extract_textpoint import get_dict class E2EMetric(object): def __init__(self, + gt_mat_dir, character_dict_path, main_indicator='f_score_e2e', **kwargs): + self.gt_mat_dir = gt_mat_dir self.label_list = get_dict(character_dict_path) self.max_index = len(self.label_list) self.main_indicator = main_indicator self.reset() def __call__(self, preds, batch, **kwargs): - temp_gt_polyons_batch = batch[2] - temp_gt_strs_batch = batch[3] - ignore_tags_batch = batch[4] - gt_polyons_batch = [] - gt_strs_batch = [] - - temp_gt_polyons_batch = temp_gt_polyons_batch[0].tolist() - for temp_list in temp_gt_polyons_batch: - t = [] - for index in temp_list: - if index[0] != -1 and index[1] != -1: - t.append(index) - gt_polyons_batch.append(t) - - temp_gt_strs_batch = temp_gt_strs_batch[0].tolist() - for temp_list in temp_gt_strs_batch: - t = "" - for index in temp_list: - if index < self.max_index: - t += self.label_list[index] - gt_strs_batch.append(t) - - for pred, gt_polyons, gt_strs, ignore_tags in zip( - [preds], [gt_polyons_batch], [gt_strs_batch], ignore_tags_batch): - # prepare gt - gt_info_list = [{ - 'points': gt_polyon, - 'text': gt_str, - 'ignore': ignore_tag - } for gt_polyon, gt_str, ignore_tag in - zip(gt_polyons, gt_strs, ignore_tags)] - # prepare det - e2e_info_list = [{ - 'points': det_polyon, - 'text': pred_str - } for det_polyon, pred_str in zip(pred['points'], pred['strs'])] - result = get_socre(gt_info_list, e2e_info_list) - self.results.append(result) + img_id = batch[5][0] + e2e_info_list = [{ + 'points': det_polyon, + 'text': pred_str + } for det_polyon, pred_str in zip(preds['points'], preds['strs'])] + result = get_socre(self.gt_mat_dir, img_id, e2e_info_list) + self.results.append(result) def get_metric(self): metircs = combine_results(self.results) diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py index d9c0048f..f9118d87 100644 --- a/ppocr/postprocess/pg_postprocess.py +++ b/ppocr/postprocess/pg_postprocess.py @@ -138,6 +138,7 @@ class PGPostProcess(object): continue keep_str_list.append(keep_str) + detected_poly = np.round(detected_poly).astype('int32') if self.valid_set == 'partvgg': middle_point = len(detected_poly) // 2 detected_poly = detected_poly[ diff --git a/ppocr/utils/e2e_metric/Deteval.py b/ppocr/utils/e2e_metric/Deteval.py index 8033a9ff..e30a498e 100755 --- a/ppocr/utils/e2e_metric/Deteval.py +++ b/ppocr/utils/e2e_metric/Deteval.py @@ -13,10 +13,11 @@ # limitations under the License. import numpy as np +import scipy.io as io from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area -def get_socre(gt_dict, pred_dict): +def get_socre(gt_dir, img_id, pred_dict): allInputs = 1 def input_reading_mod(pred_dict): @@ -30,31 +31,9 @@ def get_socre(gt_dict, pred_dict): det.append([point, text]) return det - def gt_reading_mod(gt_dict): - """This helper reads groundtruths from mat files""" - gt = [] - n = len(gt_dict) - for i in range(n): - points = gt_dict[i]['points'] - h = len(points) - text = gt_dict[i]['text'] - xx = [ - np.array( - ['x:'], dtype='