Merge pull request #279 from LDOUBLEV/fixocr

fix Nan results and add test_reader func
This commit is contained in:
Double_V 2020-07-01 17:12:01 +08:00 committed by GitHub
commit f2c4d6697e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 1 deletions

View File

@ -122,6 +122,8 @@ class TextRecognizer(object):
blank = probs.shape[1] blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0] valid_ind = np.where(ind != (blank - 1))[0]
score = np.mean(probs[valid_ind, ind[valid_ind]]) score = np.mean(probs[valid_ind, ind[valid_ind]])
if len(valid_ind) == 0:
continue
# rec_res.append([preds_text, score]) # rec_res.append([preds_text, score])
rec_res[indices[beg_img_no + rno]] = [preds_text, score] rec_res[indices[beg_img_no + rno]] = [preds_text, score]
else: else:

View File

@ -99,6 +99,8 @@ def main():
ind = np.argmax(probs, axis=1) ind = np.argmax(probs, axis=1)
blank = probs.shape[1] blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0] valid_ind = np.where(ind != (blank - 1))[0]
if len(valid_ind) == 0:
continue
score = np.mean(probs[valid_ind, ind[valid_ind]]) score = np.mean(probs[valid_ind, ind[valid_ind]])
elif loss_type == "attention": elif loss_type == "attention":
preds = np.array(predict[0]) preds = np.array(predict[0])

View File

@ -36,7 +36,7 @@ set_paddle_flags(
FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory FLAGS_eager_delete_tensor_gb=0, # enable GC to save memory
) )
import program import tools.program as program
from paddle import fluid from paddle import fluid
from ppocr.utils.utility import initial_logger from ppocr.utils.utility import initial_logger
logger = initial_logger() logger = initial_logger()
@ -106,6 +106,26 @@ def main():
program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict) program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict)
def test_reader():
config = program.load_config(FLAGS.config)
program.merge_config(FLAGS.opt)
print(config)
train_reader = reader_main(config=config, mode="train")
import time
starttime = time.time()
count = 0
try:
for data in train_reader():
count += 1
if count % 1 == 0:
batch_time = time.time() - starttime
starttime = time.time()
print("reader:", count, len(data), batch_time)
except Exception as e:
logger.info(e)
logger.info("finish reader: {}, Success!".format(count))
if __name__ == '__main__': if __name__ == '__main__':
parser = program.ArgsParser() parser = program.ArgsParser()
FLAGS = parser.parse_args() FLAGS = parser.parse_args()