From f67e6e13876aa4a5d9fee4fd49292c08d8cd5fc0 Mon Sep 17 00:00:00 2001
From: Jethong <1147925384@qq.com>
Date: Sat, 10 Apr 2021 14:44:32 +0800
Subject: [PATCH] fix eval score
---
configs/e2e/e2e_r50_vd_pg.yml | 3 +-
doc/doc_ch/pgnet.md | 8 +++--
ppocr/data/pgnet_dataset.py | 15 +++------
ppocr/metrics/e2e_metric.py | 47 ++++++-----------------------
ppocr/postprocess/pg_postprocess.py | 1 +
ppocr/utils/e2e_metric/Deteval.py | 33 ++++----------------
6 files changed, 28 insertions(+), 79 deletions(-)
diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml
index 0a232f7a..5a593ad8 100644
--- a/configs/e2e/e2e_r50_vd_pg.yml
+++ b/configs/e2e/e2e_r50_vd_pg.yml
@@ -61,6 +61,7 @@ PostProcess:
score_thresh: 0.5
Metric:
name: E2EMetric
+ gt_mat_dir: # the dir of gt_mat
character_dict_path: ppocr/utils/ic15_dict.txt
main_indicator: f_score_e2e
@@ -106,7 +107,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
- keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags' ]
+ keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id']
loader:
shuffle: False
drop_last: False
diff --git a/doc/doc_ch/pgnet.md b/doc/doc_ch/pgnet.md
index 4d3b8208..d82bb796 100644
--- a/doc/doc_ch/pgnet.md
+++ b/doc/doc_ch/pgnet.md
@@ -2,7 +2,7 @@
- [一、简介](#简介)
- [二、环境配置](#环境配置)
- [三、快速使用](#快速使用)
-- [四、模型训练、评估、推理](#快速训练)
+- [四、模型训练、评估、推理](#模型训练、评估、推理)
## 一、简介
@@ -20,7 +20,9 @@ PGNet算法细节详见[论文](https://www.aaai.org/AAAI21Papers/AAAI-2885.Wang
![](../pgnet_framework.png)
输入图像经过特征提取送入四个分支,分别是:文本边缘偏移量预测TBO模块,文本中心线预测TCL模块,文本方向偏移量预测TDO模块,以及文本字符分类图预测TCC模块。
其中TBO以及TCL的输出经过后处理后可以得到文本的检测结果,TCL、TDO、TCC负责文本识别。
+
其检测识别效果图如下:
+
![](../imgs_results/e2e_res_img293_pgnet.png)
![](../imgs_results/e2e_res_img295_pgnet.png)
@@ -61,12 +63,12 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
可视化文本检测结果默认保存到./inference_results文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/e2e_res_img623_pgnet.jpg)
-
+
## 四、模型训练、评估、推理
本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。
### 准备数据
-下载解压[totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md)数据集到PaddleOCR/train_data/目录,数据集组织结构:
+下载解压[totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md) 数据集到PaddleOCR/train_data/目录,数据集组织结构:
```
/PaddleOCR/train_data/total_text/train/
|- rgb/ # total_text数据集的训练数据
diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py
index ae063835..543dbe79 100644
--- a/ppocr/data/pgnet_dataset.py
+++ b/ppocr/data/pgnet_dataset.py
@@ -64,9 +64,6 @@ class PGDataSet(Dataset):
for line in f.readlines():
poly_str, txt = line.strip().split('\t')
poly = list(map(float, poly_str.split(',')))
- if self.mode.lower() == "eval":
- while len(poly) < 100:
- poly.append(-1)
text_polys.append(
np.array(
poly, dtype=np.float32).reshape(-1, 2))
@@ -139,23 +136,21 @@ class PGDataSet(Dataset):
try:
if self.data_format == 'icdar':
im_path = os.path.join(data_path, 'rgb', data_line)
- if self.mode.lower() == "eval":
- poly_path = os.path.join(data_path, 'poly_gt',
- data_line.split('.')[0] + '.txt')
- else:
- poly_path = os.path.join(data_path, 'poly',
- data_line.split('.')[0] + '.txt')
+ poly_path = os.path.join(data_path, 'poly',
+ data_line.split('.')[0] + '.txt')
text_polys, text_tags, text_strs = self.extract_polys(poly_path)
else:
image_dir = os.path.join(os.path.dirname(data_path), 'image')
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
data_line, image_dir)
+ img_id = int(data_line.split(".")[0][3:])
data = {
'img_path': im_path,
'polys': text_polys,
'tags': text_tags,
- 'strs': text_strs
+ 'strs': text_strs,
+ 'img_id': img_id
}
with open(data['img_path'], 'rb') as f:
img = f.read()
diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py
index 684d7742..ef14ad48 100644
--- a/ppocr/metrics/e2e_metric.py
+++ b/ppocr/metrics/e2e_metric.py
@@ -24,53 +24,24 @@ from ppocr.utils.e2e_utils.extract_textpoint import get_dict
class E2EMetric(object):
def __init__(self,
+ gt_mat_dir,
character_dict_path,
main_indicator='f_score_e2e',
**kwargs):
+ self.gt_mat_dir = gt_mat_dir
self.label_list = get_dict(character_dict_path)
self.max_index = len(self.label_list)
self.main_indicator = main_indicator
self.reset()
def __call__(self, preds, batch, **kwargs):
- temp_gt_polyons_batch = batch[2]
- temp_gt_strs_batch = batch[3]
- ignore_tags_batch = batch[4]
- gt_polyons_batch = []
- gt_strs_batch = []
-
- temp_gt_polyons_batch = temp_gt_polyons_batch[0].tolist()
- for temp_list in temp_gt_polyons_batch:
- t = []
- for index in temp_list:
- if index[0] != -1 and index[1] != -1:
- t.append(index)
- gt_polyons_batch.append(t)
-
- temp_gt_strs_batch = temp_gt_strs_batch[0].tolist()
- for temp_list in temp_gt_strs_batch:
- t = ""
- for index in temp_list:
- if index < self.max_index:
- t += self.label_list[index]
- gt_strs_batch.append(t)
-
- for pred, gt_polyons, gt_strs, ignore_tags in zip(
- [preds], [gt_polyons_batch], [gt_strs_batch], ignore_tags_batch):
- # prepare gt
- gt_info_list = [{
- 'points': gt_polyon,
- 'text': gt_str,
- 'ignore': ignore_tag
- } for gt_polyon, gt_str, ignore_tag in
- zip(gt_polyons, gt_strs, ignore_tags)]
- # prepare det
- e2e_info_list = [{
- 'points': det_polyon,
- 'text': pred_str
- } for det_polyon, pred_str in zip(pred['points'], pred['strs'])]
- result = get_socre(gt_info_list, e2e_info_list)
- self.results.append(result)
+ img_id = batch[5][0]
+ e2e_info_list = [{
+ 'points': det_polyon,
+ 'text': pred_str
+ } for det_polyon, pred_str in zip(preds['points'], preds['strs'])]
+ result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
+ self.results.append(result)
def get_metric(self):
metircs = combine_results(self.results)
diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py
index d9c0048f..f9118d87 100644
--- a/ppocr/postprocess/pg_postprocess.py
+++ b/ppocr/postprocess/pg_postprocess.py
@@ -138,6 +138,7 @@ class PGPostProcess(object):
continue
keep_str_list.append(keep_str)
+ detected_poly = np.round(detected_poly).astype('int32')
if self.valid_set == 'partvgg':
middle_point = len(detected_poly) // 2
detected_poly = detected_poly[
diff --git a/ppocr/utils/e2e_metric/Deteval.py b/ppocr/utils/e2e_metric/Deteval.py
index 8033a9ff..e30a498e 100755
--- a/ppocr/utils/e2e_metric/Deteval.py
+++ b/ppocr/utils/e2e_metric/Deteval.py
@@ -13,10 +13,11 @@
# limitations under the License.
import numpy as np
+import scipy.io as io
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
-def get_socre(gt_dict, pred_dict):
+def get_socre(gt_dir, img_id, pred_dict):
allInputs = 1
def input_reading_mod(pred_dict):
@@ -30,31 +31,9 @@ def get_socre(gt_dict, pred_dict):
det.append([point, text])
return det
- def gt_reading_mod(gt_dict):
- """This helper reads groundtruths from mat files"""
- gt = []
- n = len(gt_dict)
- for i in range(n):
- points = gt_dict[i]['points']
- h = len(points)
- text = gt_dict[i]['text']
- xx = [
- np.array(
- ['x:'], dtype='