fix eval score
This commit is contained in:
parent
2a15d7e505
commit
f67e6e1387
|
@ -61,6 +61,7 @@ PostProcess:
|
|||
score_thresh: 0.5
|
||||
Metric:
|
||||
name: E2EMetric
|
||||
gt_mat_dir: # the dir of gt_mat
|
||||
character_dict_path: ppocr/utils/ic15_dict.txt
|
||||
main_indicator: f_score_e2e
|
||||
|
||||
|
@ -106,7 +107,7 @@ Eval:
|
|||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags' ]
|
||||
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
- [一、简介](#简介)
|
||||
- [二、环境配置](#环境配置)
|
||||
- [三、快速使用](#快速使用)
|
||||
- [四、模型训练、评估、推理](#快速训练)
|
||||
- [四、模型训练、评估、推理](#模型训练、评估、推理)
|
||||
|
||||
<a name="简介"></a>
|
||||
## 一、简介
|
||||
|
@ -20,7 +20,9 @@ PGNet算法细节详见[论文](https://www.aaai.org/AAAI21Papers/AAAI-2885.Wang
|
|||
![](../pgnet_framework.png)
|
||||
输入图像经过特征提取送入四个分支,分别是:文本边缘偏移量预测TBO模块,文本中心线预测TCL模块,文本方向偏移量预测TDO模块,以及文本字符分类图预测TCC模块。
|
||||
其中TBO以及TCL的输出经过后处理后可以得到文本的检测结果,TCL、TDO、TCC负责文本识别。
|
||||
|
||||
其检测识别效果图如下:
|
||||
|
||||
![](../imgs_results/e2e_res_img293_pgnet.png)
|
||||
![](../imgs_results/e2e_res_img295_pgnet.png)
|
||||
|
||||
|
@ -61,12 +63,12 @@ python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/im
|
|||
可视化文本检测结果默认保存到./inference_results文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
|
||||
![](../imgs_results/e2e_res_img623_pgnet.jpg)
|
||||
|
||||
<a name="快速训练"></a>
|
||||
<a name="模型训练、评估、推理"></a>
|
||||
## 四、模型训练、评估、推理
|
||||
本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。
|
||||
|
||||
### 准备数据
|
||||
下载解压[totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md)数据集到PaddleOCR/train_data/目录,数据集组织结构:
|
||||
下载解压[totaltext](https://github.com/cs-chan/Total-Text-Dataset/blob/master/Dataset/README.md) 数据集到PaddleOCR/train_data/目录,数据集组织结构:
|
||||
```
|
||||
/PaddleOCR/train_data/total_text/train/
|
||||
|- rgb/ # total_text数据集的训练数据
|
||||
|
|
|
@ -64,9 +64,6 @@ class PGDataSet(Dataset):
|
|||
for line in f.readlines():
|
||||
poly_str, txt = line.strip().split('\t')
|
||||
poly = list(map(float, poly_str.split(',')))
|
||||
if self.mode.lower() == "eval":
|
||||
while len(poly) < 100:
|
||||
poly.append(-1)
|
||||
text_polys.append(
|
||||
np.array(
|
||||
poly, dtype=np.float32).reshape(-1, 2))
|
||||
|
@ -139,23 +136,21 @@ class PGDataSet(Dataset):
|
|||
try:
|
||||
if self.data_format == 'icdar':
|
||||
im_path = os.path.join(data_path, 'rgb', data_line)
|
||||
if self.mode.lower() == "eval":
|
||||
poly_path = os.path.join(data_path, 'poly_gt',
|
||||
data_line.split('.')[0] + '.txt')
|
||||
else:
|
||||
poly_path = os.path.join(data_path, 'poly',
|
||||
data_line.split('.')[0] + '.txt')
|
||||
poly_path = os.path.join(data_path, 'poly',
|
||||
data_line.split('.')[0] + '.txt')
|
||||
text_polys, text_tags, text_strs = self.extract_polys(poly_path)
|
||||
else:
|
||||
image_dir = os.path.join(os.path.dirname(data_path), 'image')
|
||||
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
|
||||
data_line, image_dir)
|
||||
img_id = int(data_line.split(".")[0][3:])
|
||||
|
||||
data = {
|
||||
'img_path': im_path,
|
||||
'polys': text_polys,
|
||||
'tags': text_tags,
|
||||
'strs': text_strs
|
||||
'strs': text_strs,
|
||||
'img_id': img_id
|
||||
}
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
|
|
|
@ -24,53 +24,24 @@ from ppocr.utils.e2e_utils.extract_textpoint import get_dict
|
|||
|
||||
class E2EMetric(object):
|
||||
def __init__(self,
|
||||
gt_mat_dir,
|
||||
character_dict_path,
|
||||
main_indicator='f_score_e2e',
|
||||
**kwargs):
|
||||
self.gt_mat_dir = gt_mat_dir
|
||||
self.label_list = get_dict(character_dict_path)
|
||||
self.max_index = len(self.label_list)
|
||||
self.main_indicator = main_indicator
|
||||
self.reset()
|
||||
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
temp_gt_polyons_batch = batch[2]
|
||||
temp_gt_strs_batch = batch[3]
|
||||
ignore_tags_batch = batch[4]
|
||||
gt_polyons_batch = []
|
||||
gt_strs_batch = []
|
||||
|
||||
temp_gt_polyons_batch = temp_gt_polyons_batch[0].tolist()
|
||||
for temp_list in temp_gt_polyons_batch:
|
||||
t = []
|
||||
for index in temp_list:
|
||||
if index[0] != -1 and index[1] != -1:
|
||||
t.append(index)
|
||||
gt_polyons_batch.append(t)
|
||||
|
||||
temp_gt_strs_batch = temp_gt_strs_batch[0].tolist()
|
||||
for temp_list in temp_gt_strs_batch:
|
||||
t = ""
|
||||
for index in temp_list:
|
||||
if index < self.max_index:
|
||||
t += self.label_list[index]
|
||||
gt_strs_batch.append(t)
|
||||
|
||||
for pred, gt_polyons, gt_strs, ignore_tags in zip(
|
||||
[preds], [gt_polyons_batch], [gt_strs_batch], ignore_tags_batch):
|
||||
# prepare gt
|
||||
gt_info_list = [{
|
||||
'points': gt_polyon,
|
||||
'text': gt_str,
|
||||
'ignore': ignore_tag
|
||||
} for gt_polyon, gt_str, ignore_tag in
|
||||
zip(gt_polyons, gt_strs, ignore_tags)]
|
||||
# prepare det
|
||||
e2e_info_list = [{
|
||||
'points': det_polyon,
|
||||
'text': pred_str
|
||||
} for det_polyon, pred_str in zip(pred['points'], pred['strs'])]
|
||||
result = get_socre(gt_info_list, e2e_info_list)
|
||||
self.results.append(result)
|
||||
img_id = batch[5][0]
|
||||
e2e_info_list = [{
|
||||
'points': det_polyon,
|
||||
'text': pred_str
|
||||
} for det_polyon, pred_str in zip(preds['points'], preds['strs'])]
|
||||
result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
|
||||
self.results.append(result)
|
||||
|
||||
def get_metric(self):
|
||||
metircs = combine_results(self.results)
|
||||
|
|
|
@ -138,6 +138,7 @@ class PGPostProcess(object):
|
|||
continue
|
||||
|
||||
keep_str_list.append(keep_str)
|
||||
detected_poly = np.round(detected_poly).astype('int32')
|
||||
if self.valid_set == 'partvgg':
|
||||
middle_point = len(detected_poly) // 2
|
||||
detected_poly = detected_poly[
|
||||
|
|
|
@ -13,10 +13,11 @@
|
|||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import scipy.io as io
|
||||
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
|
||||
|
||||
|
||||
def get_socre(gt_dict, pred_dict):
|
||||
def get_socre(gt_dir, img_id, pred_dict):
|
||||
allInputs = 1
|
||||
|
||||
def input_reading_mod(pred_dict):
|
||||
|
@ -30,31 +31,9 @@ def get_socre(gt_dict, pred_dict):
|
|||
det.append([point, text])
|
||||
return det
|
||||
|
||||
def gt_reading_mod(gt_dict):
|
||||
"""This helper reads groundtruths from mat files"""
|
||||
gt = []
|
||||
n = len(gt_dict)
|
||||
for i in range(n):
|
||||
points = gt_dict[i]['points']
|
||||
h = len(points)
|
||||
text = gt_dict[i]['text']
|
||||
xx = [
|
||||
np.array(
|
||||
['x:'], dtype='<U2'), 0, np.array(
|
||||
['y:'], dtype='<U2'), 0, np.array(
|
||||
['#'], dtype='<U1'), np.array(
|
||||
['#'], dtype='<U1')
|
||||
]
|
||||
t_x, t_y = [], []
|
||||
for j in range(h):
|
||||
t_x.append(points[j][0])
|
||||
t_y.append(points[j][1])
|
||||
xx[1] = np.array([t_x], dtype='int16')
|
||||
xx[3] = np.array([t_y], dtype='int16')
|
||||
if text != "" and "#" not in text:
|
||||
xx[4] = np.array([text], dtype='U{}'.format(len(text)))
|
||||
xx[5] = np.array(['c'], dtype='<U1')
|
||||
gt.append(xx)
|
||||
def gt_reading_mod(gt_dir, gt_id):
|
||||
gt = io.loadmat('%s/poly_gt_img%s.mat' % (gt_dir, gt_id))
|
||||
gt = gt['polygt']
|
||||
return gt
|
||||
|
||||
def detection_filtering(detections, groundtruths, threshold=0.5):
|
||||
|
@ -101,7 +80,7 @@ def get_socre(gt_dict, pred_dict):
|
|||
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
|
||||
and (input_id != 'Deteval_result_non_curved.txt'):
|
||||
detections = input_reading_mod(pred_dict)
|
||||
groundtruths = gt_reading_mod(gt_dict)
|
||||
groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
|
||||
detections = detection_filtering(
|
||||
detections,
|
||||
groundtruths) # filters detections overlapping with DC area
|
||||
|
|
Loading…
Reference in New Issue