fix eval score
This commit is contained in:
parent
2a15d7e505
commit
f67e6e1387
|
@ -61,6 +61,7 @@ PostProcess:
|
||||||
score_thresh: 0.5
|
score_thresh: 0.5
|
||||||
Metric:
|
Metric:
|
||||||
name: E2EMetric
|
name: E2EMetric
|
||||||
|
gt_mat_dir: # the dir of gt_mat
|
||||||
character_dict_path: ppocr/utils/ic15_dict.txt
|
character_dict_path: ppocr/utils/ic15_dict.txt
|
||||||
main_indicator: f_score_e2e
|
main_indicator: f_score_e2e
|
||||||
|
|
||||||
|
@ -106,7 +107,7 @@ Eval:
|
||||||
order: 'hwc'
|
order: 'hwc'
|
||||||
- ToCHWImage:
|
- ToCHWImage:
|
||||||
- KeepKeys:
|
- KeepKeys:
|
||||||
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags' ]
|
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id']
|
||||||
loader:
|
loader:
|
||||||
shuffle: False
|
shuffle: False
|
||||||
drop_last: False
|
drop_last: False
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
- [一、简介](#简介)
|
- [一、简介](#简介)
|
||||||
- [二、环境配置](#环境配置)
|
- [二、环境配置](#环境配置)
|
||||||
- [三、快速使用](#快速使用)
|
- [三、快速使用](#快速使用)
|
||||||
- [四、模型训练、评估、推理](#快速训练)
|
- [四、模型训练、评估、推理](#模型训练、评估、推理)
|
||||||
|
|
||||||
<a name="简介"></a>
|
<a name="简介"></a>
|
||||||
## 一、简介
|
## 一、简介
|
||||||
|
@ -20,7 +20,9 @@ PGNet算法细节详见[论文](https://www.aaai.org/AAAI21Papers/AAAI-2885.Wang
|
||||||
![](../pgnet_framework.png)
|
![](../pgnet_framework.png)
|
||||||
输入图像经过特征提取送入四个分支,分别是:文本边缘偏移量预测TBO模块,文本中心线预测TCL模块,文本方向偏移量预测TDO模块,以及文本字符分类图预测TCC模块。
|
输入图像经过特征提取送入四个分支,分别是:文本边缘偏移量预测TBO模块,文本中心线预测TCL模块,文本方向偏移量预测TDO模块,以及文本字符分类图预测TCC模块。
|
||||||
其中TBO以及TCL的输出经过后处理后可以得到文本的检测结果,TCL、TDO、TCC负责文本识别。
|
其中TBO以及TCL的输出经过后处理后可以得到文本的检测结果,TCL、TDO、TCC负责文本识别。
|
||||||
|
|
||||||
其检测识别效果图如下:
|
其检测识别效果图如下:
|
||||||
|
|
||||||
![](../imgs_results/e2e_res_img293_pgnet.png)
|
![](../imgs_results/e2e_res_img293_pgnet.png)
|
||||||
![](../imgs_results/e2e_res_img295_pgnet.png)
|
![](../imgs_results/e2e_res_img295_pgnet.png)
|
||||||
|
|
||||||
|
@ -61,12 +63,12 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
|
||||||
可视化文本检测结果默认保存到./inference_results文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
|
可视化文本检测结果默认保存到./inference_results文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
|
||||||
![](../imgs_results/e2e_res_img623_pgnet.jpg)
|
![](../imgs_results/e2e_res_img623_pgnet.jpg)
|
||||||
|
|
||||||
<a name="快速训练"></a>
|
<a name="模型训练、评估、推理"></a>
|
||||||
## 四、模型训练、评估、推理
|
## 四、模型训练、评估、推理
|
||||||
本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。
|
本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。
|
||||||
|
|
||||||
### 准备数据
|
### 准备数据
|
||||||
下载解压[totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md)数据集到PaddleOCR/train_data/目录,数据集组织结构:
|
下载解压[totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md) 数据集到PaddleOCR/train_data/目录,数据集组织结构:
|
||||||
```
|
```
|
||||||
/PaddleOCR/train_data/total_text/train/
|
/PaddleOCR/train_data/total_text/train/
|
||||||
|- rgb/ # total_text数据集的训练数据
|
|- rgb/ # total_text数据集的训练数据
|
||||||
|
|
|
@ -64,9 +64,6 @@ class PGDataSet(Dataset):
|
||||||
for line in f.readlines():
|
for line in f.readlines():
|
||||||
poly_str, txt = line.strip().split('\t')
|
poly_str, txt = line.strip().split('\t')
|
||||||
poly = list(map(float, poly_str.split(',')))
|
poly = list(map(float, poly_str.split(',')))
|
||||||
if self.mode.lower() == "eval":
|
|
||||||
while len(poly) < 100:
|
|
||||||
poly.append(-1)
|
|
||||||
text_polys.append(
|
text_polys.append(
|
||||||
np.array(
|
np.array(
|
||||||
poly, dtype=np.float32).reshape(-1, 2))
|
poly, dtype=np.float32).reshape(-1, 2))
|
||||||
|
@ -139,23 +136,21 @@ class PGDataSet(Dataset):
|
||||||
try:
|
try:
|
||||||
if self.data_format == 'icdar':
|
if self.data_format == 'icdar':
|
||||||
im_path = os.path.join(data_path, 'rgb', data_line)
|
im_path = os.path.join(data_path, 'rgb', data_line)
|
||||||
if self.mode.lower() == "eval":
|
poly_path = os.path.join(data_path, 'poly',
|
||||||
poly_path = os.path.join(data_path, 'poly_gt',
|
data_line.split('.')[0] + '.txt')
|
||||||
data_line.split('.')[0] + '.txt')
|
|
||||||
else:
|
|
||||||
poly_path = os.path.join(data_path, 'poly',
|
|
||||||
data_line.split('.')[0] + '.txt')
|
|
||||||
text_polys, text_tags, text_strs = self.extract_polys(poly_path)
|
text_polys, text_tags, text_strs = self.extract_polys(poly_path)
|
||||||
else:
|
else:
|
||||||
image_dir = os.path.join(os.path.dirname(data_path), 'image')
|
image_dir = os.path.join(os.path.dirname(data_path), 'image')
|
||||||
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
|
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
|
||||||
data_line, image_dir)
|
data_line, image_dir)
|
||||||
|
img_id = int(data_line.split(".")[0][3:])
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
'img_path': im_path,
|
'img_path': im_path,
|
||||||
'polys': text_polys,
|
'polys': text_polys,
|
||||||
'tags': text_tags,
|
'tags': text_tags,
|
||||||
'strs': text_strs
|
'strs': text_strs,
|
||||||
|
'img_id': img_id
|
||||||
}
|
}
|
||||||
with open(data['img_path'], 'rb') as f:
|
with open(data['img_path'], 'rb') as f:
|
||||||
img = f.read()
|
img = f.read()
|
||||||
|
|
|
@ -24,53 +24,24 @@ from ppocr.utils.e2e_utils.extract_textpoint import get_dict
|
||||||
|
|
||||||
class E2EMetric(object):
|
class E2EMetric(object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
gt_mat_dir,
|
||||||
character_dict_path,
|
character_dict_path,
|
||||||
main_indicator='f_score_e2e',
|
main_indicator='f_score_e2e',
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
self.gt_mat_dir = gt_mat_dir
|
||||||
self.label_list = get_dict(character_dict_path)
|
self.label_list = get_dict(character_dict_path)
|
||||||
self.max_index = len(self.label_list)
|
self.max_index = len(self.label_list)
|
||||||
self.main_indicator = main_indicator
|
self.main_indicator = main_indicator
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
def __call__(self, preds, batch, **kwargs):
|
def __call__(self, preds, batch, **kwargs):
|
||||||
temp_gt_polyons_batch = batch[2]
|
img_id = batch[5][0]
|
||||||
temp_gt_strs_batch = batch[3]
|
e2e_info_list = [{
|
||||||
ignore_tags_batch = batch[4]
|
'points': det_polyon,
|
||||||
gt_polyons_batch = []
|
'text': pred_str
|
||||||
gt_strs_batch = []
|
} for det_polyon, pred_str in zip(preds['points'], preds['strs'])]
|
||||||
|
result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
|
||||||
temp_gt_polyons_batch = temp_gt_polyons_batch[0].tolist()
|
self.results.append(result)
|
||||||
for temp_list in temp_gt_polyons_batch:
|
|
||||||
t = []
|
|
||||||
for index in temp_list:
|
|
||||||
if index[0] != -1 and index[1] != -1:
|
|
||||||
t.append(index)
|
|
||||||
gt_polyons_batch.append(t)
|
|
||||||
|
|
||||||
temp_gt_strs_batch = temp_gt_strs_batch[0].tolist()
|
|
||||||
for temp_list in temp_gt_strs_batch:
|
|
||||||
t = ""
|
|
||||||
for index in temp_list:
|
|
||||||
if index < self.max_index:
|
|
||||||
t += self.label_list[index]
|
|
||||||
gt_strs_batch.append(t)
|
|
||||||
|
|
||||||
for pred, gt_polyons, gt_strs, ignore_tags in zip(
|
|
||||||
[preds], [gt_polyons_batch], [gt_strs_batch], ignore_tags_batch):
|
|
||||||
# prepare gt
|
|
||||||
gt_info_list = [{
|
|
||||||
'points': gt_polyon,
|
|
||||||
'text': gt_str,
|
|
||||||
'ignore': ignore_tag
|
|
||||||
} for gt_polyon, gt_str, ignore_tag in
|
|
||||||
zip(gt_polyons, gt_strs, ignore_tags)]
|
|
||||||
# prepare det
|
|
||||||
e2e_info_list = [{
|
|
||||||
'points': det_polyon,
|
|
||||||
'text': pred_str
|
|
||||||
} for det_polyon, pred_str in zip(pred['points'], pred['strs'])]
|
|
||||||
result = get_socre(gt_info_list, e2e_info_list)
|
|
||||||
self.results.append(result)
|
|
||||||
|
|
||||||
def get_metric(self):
|
def get_metric(self):
|
||||||
metircs = combine_results(self.results)
|
metircs = combine_results(self.results)
|
||||||
|
|
|
@ -138,6 +138,7 @@ class PGPostProcess(object):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
keep_str_list.append(keep_str)
|
keep_str_list.append(keep_str)
|
||||||
|
detected_poly = np.round(detected_poly).astype('int32')
|
||||||
if self.valid_set == 'partvgg':
|
if self.valid_set == 'partvgg':
|
||||||
middle_point = len(detected_poly) // 2
|
middle_point = len(detected_poly) // 2
|
||||||
detected_poly = detected_poly[
|
detected_poly = detected_poly[
|
||||||
|
|
|
@ -13,10 +13,11 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import scipy.io as io
|
||||||
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
|
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
|
||||||
|
|
||||||
|
|
||||||
def get_socre(gt_dict, pred_dict):
|
def get_socre(gt_dir, img_id, pred_dict):
|
||||||
allInputs = 1
|
allInputs = 1
|
||||||
|
|
||||||
def input_reading_mod(pred_dict):
|
def input_reading_mod(pred_dict):
|
||||||
|
@ -30,31 +31,9 @@ def get_socre(gt_dict, pred_dict):
|
||||||
det.append([point, text])
|
det.append([point, text])
|
||||||
return det
|
return det
|
||||||
|
|
||||||
def gt_reading_mod(gt_dict):
|
def gt_reading_mod(gt_dir, gt_id):
|
||||||
"""This helper reads groundtruths from mat files"""
|
gt = io.loadmat('%s/poly_gt_img%s.mat' % (gt_dir, gt_id))
|
||||||
gt = []
|
gt = gt['polygt']
|
||||||
n = len(gt_dict)
|
|
||||||
for i in range(n):
|
|
||||||
points = gt_dict[i]['points']
|
|
||||||
h = len(points)
|
|
||||||
text = gt_dict[i]['text']
|
|
||||||
xx = [
|
|
||||||
np.array(
|
|
||||||
['x:'], dtype='<U2'), 0, np.array(
|
|
||||||
['y:'], dtype='<U2'), 0, np.array(
|
|
||||||
['#'], dtype='<U1'), np.array(
|
|
||||||
['#'], dtype='<U1')
|
|
||||||
]
|
|
||||||
t_x, t_y = [], []
|
|
||||||
for j in range(h):
|
|
||||||
t_x.append(points[j][0])
|
|
||||||
t_y.append(points[j][1])
|
|
||||||
xx[1] = np.array([t_x], dtype='int16')
|
|
||||||
xx[3] = np.array([t_y], dtype='int16')
|
|
||||||
if text != "" and "#" not in text:
|
|
||||||
xx[4] = np.array([text], dtype='U{}'.format(len(text)))
|
|
||||||
xx[5] = np.array(['c'], dtype='<U1')
|
|
||||||
gt.append(xx)
|
|
||||||
return gt
|
return gt
|
||||||
|
|
||||||
def detection_filtering(detections, groundtruths, threshold=0.5):
|
def detection_filtering(detections, groundtruths, threshold=0.5):
|
||||||
|
@ -101,7 +80,7 @@ def get_socre(gt_dict, pred_dict):
|
||||||
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
|
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
|
||||||
and (input_id != 'Deteval_result_non_curved.txt'):
|
and (input_id != 'Deteval_result_non_curved.txt'):
|
||||||
detections = input_reading_mod(pred_dict)
|
detections = input_reading_mod(pred_dict)
|
||||||
groundtruths = gt_reading_mod(gt_dict)
|
groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
|
||||||
detections = detection_filtering(
|
detections = detection_filtering(
|
||||||
detections,
|
detections,
|
||||||
groundtruths) # filters detections overlapping with DC area
|
groundtruths) # filters detections overlapping with DC area
|
||||||
|
|
Loading…
Reference in New Issue