del lable, polish data_reader
This commit is contained in:
parent
d27ae42a45
commit
d6778e2519
|
@ -170,10 +170,14 @@ class SimpleReader(object):
|
|||
image_file_list = [self.infer_img]
|
||||
elif os.path.isdir(self.infer_img):
|
||||
for single_file in os.listdir(self.infer_img):
|
||||
if single_file.endswith('png') or single_file.endswith('jpg'):
|
||||
image_file_list.append(os.path.join(self.infer_img, single_file))
|
||||
if single_file.split('.')[
|
||||
-1] not in ['bmp', 'jpg', 'jpeg', 'png', 'JPEG', 'JPG', 'PNG']:
|
||||
continue
|
||||
image_file_list.append(os.path.join(self.infer_img, single_file))
|
||||
for single_img in image_file_list:
|
||||
img = cv2.imread(single_img)
|
||||
if img.shape[-1]==1 or len(list(img.shape))==2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
norm_img = process_image(img, self.image_shape)
|
||||
yield norm_img
|
||||
with open(self.label_file_path, "rb") as fin:
|
||||
|
|
|
@ -79,7 +79,11 @@ def main():
|
|||
init_model(config, eval_prog, exe)
|
||||
|
||||
blobs = reader_main(config, 'test')()
|
||||
infer_list = os.listdir(config['Global']['infer_img'])
|
||||
infer_img = config['Global']['infer_img']
|
||||
if os.path.isfile(infer_img):
|
||||
infer_list = [infer_img]
|
||||
elif os.path.isdir(infer_img):
|
||||
infer_list = os.listdir(config['Global']['infer_img'])
|
||||
max_img_num = len(infer_list)
|
||||
if len(infer_list) == 0:
|
||||
logger.info("Can not find img in infer_img dir.")
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue