fix Nan results and add test_reader func
This commit is contained in:
parent
f1f9206b32
commit
3892a8ca02
|
@ -122,6 +122,8 @@ class TextRecognizer(object):
|
||||||
blank = probs.shape[1]
|
blank = probs.shape[1]
|
||||||
valid_ind = np.where(ind != (blank - 1))[0]
|
valid_ind = np.where(ind != (blank - 1))[0]
|
||||||
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
||||||
|
if not valid_ind:
|
||||||
|
continue
|
||||||
# rec_res.append([preds_text, score])
|
# rec_res.append([preds_text, score])
|
||||||
rec_res[indices[beg_img_no + rno]] = [preds_text, score]
|
rec_res[indices[beg_img_no + rno]] = [preds_text, score]
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -99,6 +99,8 @@ def main():
|
||||||
ind = np.argmax(probs, axis=1)
|
ind = np.argmax(probs, axis=1)
|
||||||
blank = probs.shape[1]
|
blank = probs.shape[1]
|
||||||
valid_ind = np.where(ind != (blank - 1))[0]
|
valid_ind = np.where(ind != (blank - 1))[0]
|
||||||
|
if not valid_ind:
|
||||||
|
continue
|
||||||
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
||||||
elif loss_type == "attention":
|
elif loss_type == "attention":
|
||||||
preds = np.array(predict[0])
|
preds = np.array(predict[0])
|
||||||
|
|
|
@ -36,7 +36,7 @@ set_paddle_flags(
|
||||||
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
|
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
|
||||||
)
|
)
|
||||||
|
|
||||||
import program
|
import tools.program as program
|
||||||
from paddle import fluid
|
from paddle import fluid
|
||||||
from ppocr.utils.utility import initial_logger
|
from ppocr.utils.utility import initial_logger
|
||||||
logger = initial_logger()
|
logger = initial_logger()
|
||||||
|
@ -106,6 +106,27 @@ def main():
|
||||||
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
|
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def test_reader():
|
||||||
|
config = program.load_config(FLAGS.config)
|
||||||
|
program.merge_config(FLAGS.opt)
|
||||||
|
print(config)
|
||||||
|
train_reader = reader_main(config=config, mode="train")
|
||||||
|
import time
|
||||||
|
starttime = time.time()
|
||||||
|
count = 0
|
||||||
|
try:
|
||||||
|
for data in train_reader():
|
||||||
|
count += 1
|
||||||
|
if count % 1 == 0:
|
||||||
|
batch_time = time.time() - starttime
|
||||||
|
starttime = time.time()
|
||||||
|
print("reader:", count, len(data), batch_time)
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
print("finish reader:", count)
|
||||||
|
print("success")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = program.ArgsParser()
|
parser = program.ArgsParser()
|
||||||
FLAGS = parser.parse_args()
|
FLAGS = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue