fix infer_rec for benchmark

This commit is contained in:
tink2123 2020-05-21 11:11:36 +08:00
parent 6de43fbb47
commit dd0112f52b
2 changed files with 44 additions and 31 deletions

View File

@ -10,3 +10,4 @@ EvalReader:
TestReader:
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
infer_img:

View File

@ -42,9 +42,11 @@ class LMDBReader(object):
self.mode = params['mode']
if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card']
else:
elif params['mode'] == "eval":
self.batch_size = params['test_batch_size_per_card']
elif params['mode'] == "test":
self.batch_size = 1
self.infer_img = params["infer_img"]
def load_hierarchical_lmdb_dataset(self):
lmdb_sets = {}
dataset_idx = 0
@ -97,6 +99,15 @@ class LMDBReader(object):
process_id = 0
def sample_iter_reader():
if self.mode == 'test':
image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list:
img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape)
yield norm_img
else:
lmdb_sets = self.load_hierarchical_lmdb_dataset()
if process_id == 0:
self.print_lmdb_sets_info(lmdb_sets)
@ -124,7 +135,6 @@ class LMDBReader(object):
if finish_read_num == len(lmdb_sets):
break
self.close_lmdb_dataset(lmdb_sets)
def batch_iter_reader():
batch_outs = []
for outs in sample_iter_reader():
@ -135,7 +145,9 @@ class LMDBReader(object):
if len(batch_outs) != 0:
yield batch_outs
if self.mode != 'test':
return batch_iter_reader
return sample_iter_reader
class SimpleReader(object):