fix infer_rec for benchmark
This commit is contained in:
parent
6de43fbb47
commit
dd0112f52b
|
@ -9,4 +9,5 @@ EvalReader:
|
|||
|
||||
TestReader:
|
||||
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
|
||||
lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
|
||||
lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
|
||||
infer_img:
|
||||
|
|
|
@ -42,9 +42,11 @@ class LMDBReader(object):
|
|||
self.mode = params['mode']
|
||||
if params['mode'] == 'train':
|
||||
self.batch_size = params['train_batch_size_per_card']
|
||||
else:
|
||||
elif params['mode'] == "eval":
|
||||
self.batch_size = params['test_batch_size_per_card']
|
||||
|
||||
elif params['mode'] == "test":
|
||||
self.batch_size = 1
|
||||
self.infer_img = params["infer_img"]
|
||||
def load_hierarchical_lmdb_dataset(self):
|
||||
lmdb_sets = {}
|
||||
dataset_idx = 0
|
||||
|
@ -97,34 +99,42 @@ class LMDBReader(object):
|
|||
process_id = 0
|
||||
|
||||
def sample_iter_reader():
|
||||
lmdb_sets = self.load_hierarchical_lmdb_dataset()
|
||||
if process_id == 0:
|
||||
self.print_lmdb_sets_info(lmdb_sets)
|
||||
cur_index_sets = [1 + process_id] * len(lmdb_sets)
|
||||
while True:
|
||||
finish_read_num = 0
|
||||
for dataset_idx in range(len(lmdb_sets)):
|
||||
cur_index = cur_index_sets[dataset_idx]
|
||||
if cur_index > lmdb_sets[dataset_idx]['num_samples']:
|
||||
finish_read_num += 1
|
||||
else:
|
||||
sample_info = self.get_lmdb_sample_info(
|
||||
lmdb_sets[dataset_idx]['txn'], cur_index)
|
||||
cur_index_sets[dataset_idx] += self.num_workers
|
||||
if sample_info is None:
|
||||
continue
|
||||
img, label = sample_info
|
||||
outs = process_image(img, self.image_shape, label,
|
||||
self.char_ops, self.loss_type,
|
||||
self.max_text_length)
|
||||
if outs is None:
|
||||
continue
|
||||
yield outs
|
||||
|
||||
if finish_read_num == len(lmdb_sets):
|
||||
break
|
||||
self.close_lmdb_dataset(lmdb_sets)
|
||||
if self.mode == 'test':
|
||||
image_file_list = get_image_file_list(self.infer_img)
|
||||
for single_img in image_file_list:
|
||||
img = cv2.imread(single_img)
|
||||
if img.shape[-1]==1 or len(list(img.shape))==2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
norm_img = process_image(img, self.image_shape)
|
||||
yield norm_img
|
||||
else:
|
||||
lmdb_sets = self.load_hierarchical_lmdb_dataset()
|
||||
if process_id == 0:
|
||||
self.print_lmdb_sets_info(lmdb_sets)
|
||||
cur_index_sets = [1 + process_id] * len(lmdb_sets)
|
||||
while True:
|
||||
finish_read_num = 0
|
||||
for dataset_idx in range(len(lmdb_sets)):
|
||||
cur_index = cur_index_sets[dataset_idx]
|
||||
if cur_index > lmdb_sets[dataset_idx]['num_samples']:
|
||||
finish_read_num += 1
|
||||
else:
|
||||
sample_info = self.get_lmdb_sample_info(
|
||||
lmdb_sets[dataset_idx]['txn'], cur_index)
|
||||
cur_index_sets[dataset_idx] += self.num_workers
|
||||
if sample_info is None:
|
||||
continue
|
||||
img, label = sample_info
|
||||
outs = process_image(img, self.image_shape, label,
|
||||
self.char_ops, self.loss_type,
|
||||
self.max_text_length)
|
||||
if outs is None:
|
||||
continue
|
||||
yield outs
|
||||
|
||||
if finish_read_num == len(lmdb_sets):
|
||||
break
|
||||
self.close_lmdb_dataset(lmdb_sets)
|
||||
def batch_iter_reader():
|
||||
batch_outs = []
|
||||
for outs in sample_iter_reader():
|
||||
|
@ -135,7 +145,9 @@ class LMDBReader(object):
|
|||
if len(batch_outs) != 0:
|
||||
yield batch_outs
|
||||
|
||||
return batch_iter_reader
|
||||
if self.mode != 'test':
|
||||
return batch_iter_reader
|
||||
return sample_iter_reader
|
||||
|
||||
|
||||
class SimpleReader(object):
|
||||
|
|
Loading…
Reference in New Issue