add metric mode
This commit is contained in:
parent
e2b84da866
commit
2f978f638b
|
@ -60,6 +60,7 @@ PostProcess:
|
||||||
name: PGPostProcess
|
name: PGPostProcess
|
||||||
score_thresh: 0.5
|
score_thresh: 0.5
|
||||||
mode: fast # fast or slow two ways
|
mode: fast # fast or slow two ways
|
||||||
|
|
||||||
Metric:
|
Metric:
|
||||||
name: E2EMetric
|
name: E2EMetric
|
||||||
mode: A # A or B
|
mode: A # A or B
|
||||||
|
|
|
@ -199,14 +199,32 @@ class E2ELabelEncode_test(BaseRecLabelEncode):
|
||||||
character_type, use_space_char)
|
character_type, use_space_char)
|
||||||
|
|
||||||
def __call__(self, data):
|
def __call__(self, data):
|
||||||
texts = data['texts']
|
import json
|
||||||
|
padnum = len(self.dict)
|
||||||
|
label = data['label']
|
||||||
|
label = json.loads(label)
|
||||||
|
nBox = len(label)
|
||||||
|
boxes, txts, txt_tags = [], [], []
|
||||||
|
for bno in range(0, nBox):
|
||||||
|
box = label[bno]['points']
|
||||||
|
txt = label[bno]['transcription']
|
||||||
|
boxes.append(box)
|
||||||
|
txts.append(txt)
|
||||||
|
if txt in ['*', '###']:
|
||||||
|
txt_tags.append(True)
|
||||||
|
else:
|
||||||
|
txt_tags.append(False)
|
||||||
|
boxes = np.array(boxes, dtype=np.float32)
|
||||||
|
txt_tags = np.array(txt_tags, dtype=np.bool)
|
||||||
|
data['polys'] = boxes
|
||||||
|
data['ignore_tags'] = txt_tags
|
||||||
temp_texts = []
|
temp_texts = []
|
||||||
for text in texts:
|
for text in txts:
|
||||||
text = text.lower()
|
text = text.lower()
|
||||||
text = self.encode(text)
|
text = self.encode(text)
|
||||||
if text is None:
|
if text is None:
|
||||||
return None
|
return None
|
||||||
text = text + [36] * (self.max_text_len - len(text)
|
text = text + [padnum] * (self.max_text_len - len(text)
|
||||||
) # use 36 to pad
|
) # use 36 to pad
|
||||||
temp_texts.append(text)
|
temp_texts.append(text)
|
||||||
data['texts'] = np.array(temp_texts)
|
data['texts'] = np.array(temp_texts)
|
||||||
|
|
|
@ -39,7 +39,7 @@ class E2EMetric(object):
|
||||||
def __call__(self, preds, batch, **kwargs):
|
def __call__(self, preds, batch, **kwargs):
|
||||||
if self.mode == 'A':
|
if self.mode == 'A':
|
||||||
gt_polyons_batch = batch[2]
|
gt_polyons_batch = batch[2]
|
||||||
temp_gt_strs_batch = batch[3]
|
temp_gt_strs_batch = batch[3][0]
|
||||||
ignore_tags_batch = batch[4]
|
ignore_tags_batch = batch[4]
|
||||||
gt_strs_batch = []
|
gt_strs_batch = []
|
||||||
|
|
||||||
|
@ -51,8 +51,7 @@ class E2EMetric(object):
|
||||||
gt_strs_batch.append(t)
|
gt_strs_batch.append(t)
|
||||||
|
|
||||||
for pred, gt_polyons, gt_strs, ignore_tags in zip(
|
for pred, gt_polyons, gt_strs, ignore_tags in zip(
|
||||||
[preds], [gt_polyons_batch], [gt_strs_batch],
|
[preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch):
|
||||||
ignore_tags_batch):
|
|
||||||
# prepare gt
|
# prepare gt
|
||||||
gt_info_list = [{
|
gt_info_list = [{
|
||||||
'points': gt_polyon,
|
'points': gt_polyon,
|
||||||
|
|
Loading…
Reference in New Issue