add metric mode

This commit is contained in:
Jethong 2021-04-21 14:15:51 +08:00
parent e2b84da866
commit 2f978f638b
3 changed files with 25 additions and 7 deletions

View File

@ -60,6 +60,7 @@ PostProcess:
name: PGPostProcess name: PGPostProcess
score_thresh: 0.5 score_thresh: 0.5
mode: fast # fast or slow two ways mode: fast # fast or slow two ways
Metric: Metric:
name: E2EMetric name: E2EMetric
mode: A # A or B mode: A # A or B

View File

@ -199,14 +199,32 @@ class E2ELabelEncode_test(BaseRecLabelEncode):
character_type, use_space_char) character_type, use_space_char)
def __call__(self, data): def __call__(self, data):
texts = data['texts'] import json
padnum = len(self.dict)
label = data['label']
label = json.loads(label)
nBox = len(label)
boxes, txts, txt_tags = [], [], []
for bno in range(0, nBox):
box = label[bno]['points']
txt = label[bno]['transcription']
boxes.append(box)
txts.append(txt)
if txt in ['*', '###']:
txt_tags.append(True)
else:
txt_tags.append(False)
boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool)
data['polys'] = boxes
data['ignore_tags'] = txt_tags
temp_texts = [] temp_texts = []
for text in texts: for text in txts:
text = text.lower() text = text.lower()
text = self.encode(text) text = self.encode(text)
if text is None: if text is None:
return None return None
text = text + [36] * (self.max_text_len - len(text) text = text + [padnum] * (self.max_text_len - len(text)
) # use 36 to pad ) # use 36 to pad
temp_texts.append(text) temp_texts.append(text)
data['texts'] = np.array(temp_texts) data['texts'] = np.array(temp_texts)

View File

@ -39,7 +39,7 @@ class E2EMetric(object):
def __call__(self, preds, batch, **kwargs): def __call__(self, preds, batch, **kwargs):
if self.mode == 'A': if self.mode == 'A':
gt_polyons_batch = batch[2] gt_polyons_batch = batch[2]
temp_gt_strs_batch = batch[3] temp_gt_strs_batch = batch[3][0]
ignore_tags_batch = batch[4] ignore_tags_batch = batch[4]
gt_strs_batch = [] gt_strs_batch = []
@ -51,8 +51,7 @@ class E2EMetric(object):
gt_strs_batch.append(t) gt_strs_batch.append(t)
for pred, gt_polyons, gt_strs, ignore_tags in zip( for pred, gt_polyons, gt_strs, ignore_tags in zip(
[preds], [gt_polyons_batch], [gt_strs_batch], [preds], gt_polyons_batch, [gt_strs_batch], ignore_tags_batch):
ignore_tags_batch):
# prepare gt # prepare gt
gt_info_list = [{ gt_info_list = [{
'points': gt_polyon, 'points': gt_polyon,