update get_img_file_list

This commit is contained in:
tink2123 2020-05-12 20:51:28 +08:00
parent 250cd37a36
commit 2025ed1dfb
2 changed files with 23 additions and 22 deletions

View File

@ -22,7 +22,7 @@ import string
import lmdb import lmdb
from ppocr.utils.utility import initial_logger from ppocr.utils.utility import initial_logger
from tools.infer.utility import get_image_file_list from ppocr.utils.utility import get_image_file_list
logger = initial_logger() logger = initial_logger()
from .img_tools import process_image, get_img_data from .img_tools import process_image, get_img_data
@ -173,26 +173,27 @@ class SimpleReader(object):
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape) norm_img = process_image(img, self.image_shape)
yield norm_img yield norm_img
with open(self.label_file_path, "rb") as fin: else:
label_infor_list = fin.readlines() with open(self.label_file_path, "rb") as fin:
img_num = len(label_infor_list) label_infor_list = fin.readlines()
img_id_list = list(range(img_num)) img_num = len(label_infor_list)
random.shuffle(img_id_list) img_id_list = list(range(img_num))
for img_id in range(process_id, img_num, self.num_workers): random.shuffle(img_id_list)
label_infor = label_infor_list[img_id_list[img_id]] for img_id in range(process_id, img_num, self.num_workers):
substr = label_infor.decode('utf-8').strip("\n").split("\t") label_infor = label_infor_list[img_id_list[img_id]]
img_path = self.img_set_dir + "/" + substr[0] substr = label_infor.decode('utf-8').strip("\n").split("\t")
img = cv2.imread(img_path) img_path = self.img_set_dir + "/" + substr[0]
if img is None: img = cv2.imread(img_path)
logger.info("{} does not exist!".format(img_path)) if img is None:
continue logger.info("{} does not exist!".format(img_path))
label = substr[1] continue
outs = process_image(img, self.image_shape, label, label = substr[1]
self.char_ops, self.loss_type, outs = process_image(img, self.image_shape, label,
self.max_text_length) self.char_ops, self.loss_type,
if outs is None: self.max_text_length)
continue if outs is None:
yield outs continue
yield outs
def batch_iter_reader(): def batch_iter_reader():
batch_outs = [] batch_outs = []

View File

@ -46,7 +46,7 @@ from ppocr.data.reader_main import reader_main
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import init_model
from ppocr.utils.character import CharacterOps from ppocr.utils.character import CharacterOps
from ppocr.utils.utility import create_module from ppocr.utils.utility import create_module
from tools.infer.utility import get_image_file_list from ppocr.utils.utility import get_image_file_list
logger = initial_logger() logger = initial_logger()