Merge pull request #83 from tink2123/fix_infer
fix infer_rec for benchmark
This commit is contained in:
commit
96ead92e45
|
@ -42,9 +42,11 @@ class LMDBReader(object):
|
|||
self.mode = params['mode']
|
||||
if params['mode'] == 'train':
|
||||
self.batch_size = params['train_batch_size_per_card']
|
||||
else:
|
||||
elif params['mode'] == "eval":
|
||||
self.batch_size = params['test_batch_size_per_card']
|
||||
|
||||
elif params['mode'] == "test":
|
||||
self.batch_size = 1
|
||||
self.infer_img = params["infer_img"]
|
||||
def load_hierarchical_lmdb_dataset(self):
|
||||
lmdb_sets = {}
|
||||
dataset_idx = 0
|
||||
|
@ -97,6 +99,15 @@ class LMDBReader(object):
|
|||
process_id = 0
|
||||
|
||||
def sample_iter_reader():
|
||||
if self.mode == 'test':
|
||||
image_file_list = get_image_file_list(self.infer_img)
|
||||
for single_img in image_file_list:
|
||||
img = cv2.imread(single_img)
|
||||
if img.shape[-1]==1 or len(list(img.shape))==2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
norm_img = process_image(img, self.image_shape)
|
||||
yield norm_img
|
||||
else:
|
||||
lmdb_sets = self.load_hierarchical_lmdb_dataset()
|
||||
if process_id == 0:
|
||||
self.print_lmdb_sets_info(lmdb_sets)
|
||||
|
@ -124,7 +135,6 @@ class LMDBReader(object):
|
|||
if finish_read_num == len(lmdb_sets):
|
||||
break
|
||||
self.close_lmdb_dataset(lmdb_sets)
|
||||
|
||||
def batch_iter_reader():
|
||||
batch_outs = []
|
||||
for outs in sample_iter_reader():
|
||||
|
@ -135,7 +145,9 @@ class LMDBReader(object):
|
|||
if len(batch_outs) != 0:
|
||||
yield batch_outs
|
||||
|
||||
if self.mode != 'test':
|
||||
return batch_iter_reader
|
||||
return sample_iter_reader
|
||||
|
||||
|
||||
class SimpleReader(object):
|
||||
|
|
Loading…
Reference in New Issue