fix infer rec
This commit is contained in:
parent
115955f722
commit
bd1820b784
|
@ -19,6 +19,7 @@ Global:
|
|||
infer_mode: false
|
||||
use_space_char: false
|
||||
distributed: true
|
||||
save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
|
@ -98,7 +99,7 @@ Loss:
|
|||
|
||||
PostProcess:
|
||||
name: DistillationCTCLabelDecode
|
||||
model_name: ["Student"]
|
||||
model_name: ["Student", "Teacher"]
|
||||
key: head_out
|
||||
|
||||
Metric:
|
||||
|
|
|
@ -20,6 +20,7 @@ import numpy as np
|
|||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
|
@ -113,11 +114,23 @@ def main():
|
|||
else:
|
||||
preds = model(images)
|
||||
post_result = post_process_class(preds)
|
||||
for rec_reuslt in post_result:
|
||||
logger.info('\t result: {}'.format(rec_reuslt))
|
||||
if len(rec_reuslt) >= 2:
|
||||
fout.write(file + "\t" + rec_reuslt[0] + "\t" + str(
|
||||
rec_reuslt[1]) + "\n")
|
||||
info = None
|
||||
if isinstance(post_result, dict):
|
||||
rec_info = dict()
|
||||
for key in post_result:
|
||||
if len(post_result[key][0]) >= 2:
|
||||
rec_info[key] = {
|
||||
"label": post_result[key][0][0],
|
||||
"score": post_result[key][0][1],
|
||||
}
|
||||
info = json.dumps(rec_info)
|
||||
else:
|
||||
if len(post_result[0]) >= 2:
|
||||
info = post_result[0][0] + "\t" + str(post_result[0][1])
|
||||
|
||||
if info is not None:
|
||||
logger.info("\t result: {}".format(info))
|
||||
fout.write(file + "\t" + info)
|
||||
logger.info("success!")
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue