fix infer_rec for benchmark
This commit is contained in:
parent
6de43fbb47
commit
dd0112f52b
|
@ -9,4 +9,5 @@ EvalReader:
|
||||||
|
|
||||||
TestReader:
|
TestReader:
|
||||||
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
|
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
|
||||||
lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
|
lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
|
||||||
|
infer_img:
|
||||||
|
|
|
@ -42,9 +42,11 @@ class LMDBReader(object):
|
||||||
self.mode = params['mode']
|
self.mode = params['mode']
|
||||||
if params['mode'] == 'train':
|
if params['mode'] == 'train':
|
||||||
self.batch_size = params['train_batch_size_per_card']
|
self.batch_size = params['train_batch_size_per_card']
|
||||||
else:
|
elif params['mode'] == "eval":
|
||||||
self.batch_size = params['test_batch_size_per_card']
|
self.batch_size = params['test_batch_size_per_card']
|
||||||
|
elif params['mode'] == "test":
|
||||||
|
self.batch_size = 1
|
||||||
|
self.infer_img = params["infer_img"]
|
||||||
def load_hierarchical_lmdb_dataset(self):
|
def load_hierarchical_lmdb_dataset(self):
|
||||||
lmdb_sets = {}
|
lmdb_sets = {}
|
||||||
dataset_idx = 0
|
dataset_idx = 0
|
||||||
|
@ -97,34 +99,42 @@ class LMDBReader(object):
|
||||||
process_id = 0
|
process_id = 0
|
||||||
|
|
||||||
def sample_iter_reader():
|
def sample_iter_reader():
|
||||||
lmdb_sets = self.load_hierarchical_lmdb_dataset()
|
if self.mode == 'test':
|
||||||
if process_id == 0:
|
image_file_list = get_image_file_list(self.infer_img)
|
||||||
self.print_lmdb_sets_info(lmdb_sets)
|
for single_img in image_file_list:
|
||||||
cur_index_sets = [1 + process_id] * len(lmdb_sets)
|
img = cv2.imread(single_img)
|
||||||
while True:
|
if img.shape[-1]==1 or len(list(img.shape))==2:
|
||||||
finish_read_num = 0
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
for dataset_idx in range(len(lmdb_sets)):
|
norm_img = process_image(img, self.image_shape)
|
||||||
cur_index = cur_index_sets[dataset_idx]
|
yield norm_img
|
||||||
if cur_index > lmdb_sets[dataset_idx]['num_samples']:
|
else:
|
||||||
finish_read_num += 1
|
lmdb_sets = self.load_hierarchical_lmdb_dataset()
|
||||||
else:
|
if process_id == 0:
|
||||||
sample_info = self.get_lmdb_sample_info(
|
self.print_lmdb_sets_info(lmdb_sets)
|
||||||
lmdb_sets[dataset_idx]['txn'], cur_index)
|
cur_index_sets = [1 + process_id] * len(lmdb_sets)
|
||||||
cur_index_sets[dataset_idx] += self.num_workers
|
while True:
|
||||||
if sample_info is None:
|
finish_read_num = 0
|
||||||
continue
|
for dataset_idx in range(len(lmdb_sets)):
|
||||||
img, label = sample_info
|
cur_index = cur_index_sets[dataset_idx]
|
||||||
outs = process_image(img, self.image_shape, label,
|
if cur_index > lmdb_sets[dataset_idx]['num_samples']:
|
||||||
self.char_ops, self.loss_type,
|
finish_read_num += 1
|
||||||
self.max_text_length)
|
else:
|
||||||
if outs is None:
|
sample_info = self.get_lmdb_sample_info(
|
||||||
continue
|
lmdb_sets[dataset_idx]['txn'], cur_index)
|
||||||
yield outs
|
cur_index_sets[dataset_idx] += self.num_workers
|
||||||
|
if sample_info is None:
|
||||||
if finish_read_num == len(lmdb_sets):
|
continue
|
||||||
break
|
img, label = sample_info
|
||||||
self.close_lmdb_dataset(lmdb_sets)
|
outs = process_image(img, self.image_shape, label,
|
||||||
|
self.char_ops, self.loss_type,
|
||||||
|
self.max_text_length)
|
||||||
|
if outs is None:
|
||||||
|
continue
|
||||||
|
yield outs
|
||||||
|
|
||||||
|
if finish_read_num == len(lmdb_sets):
|
||||||
|
break
|
||||||
|
self.close_lmdb_dataset(lmdb_sets)
|
||||||
def batch_iter_reader():
|
def batch_iter_reader():
|
||||||
batch_outs = []
|
batch_outs = []
|
||||||
for outs in sample_iter_reader():
|
for outs in sample_iter_reader():
|
||||||
|
@ -135,7 +145,9 @@ class LMDBReader(object):
|
||||||
if len(batch_outs) != 0:
|
if len(batch_outs) != 0:
|
||||||
yield batch_outs
|
yield batch_outs
|
||||||
|
|
||||||
return batch_iter_reader
|
if self.mode != 'test':
|
||||||
|
return batch_iter_reader
|
||||||
|
return sample_iter_reader
|
||||||
|
|
||||||
|
|
||||||
class SimpleReader(object):
|
class SimpleReader(object):
|
||||||
|
|
Loading…
Reference in New Issue