diff --git a/configs/rec/rec_chinese_reader.yml b/configs/rec/rec_icdar15_reader.yml similarity index 64% rename from configs/rec/rec_chinese_reader.yml rename to configs/rec/rec_icdar15_reader.yml index 95e1b500..f09a1ea7 100755 --- a/configs/rec/rec_chinese_reader.yml +++ b/configs/rec/rec_icdar15_reader.yml @@ -1,13 +1,13 @@ TrainReader: reader_function: ppocr.data.rec.dataset_traversal,SimpleReader num_workers: 8 - img_set_dir: . - label_file_path: ./train_data/hard_label.txt + img_set_dir: ./train_data + label_file_path: ./train_data/rec_gt_train.txt EvalReader: reader_function: ppocr.data.rec.dataset_traversal,SimpleReader - img_set_dir: . - label_file_path: ./train_data/label_val_all.txt + img_set_dir: ./train_data + label_file_path: ./train_data/rec_gt_test.txt TestReader: reader_function: ppocr.data.rec.dataset_traversal,SimpleReader diff --git a/configs/rec/rec_chinese_lite_train.yml b/configs/rec/rec_icdar15_train.yml similarity index 77% rename from configs/rec/rec_chinese_lite_train.yml rename to configs/rec/rec_icdar15_train.yml index e821a623..e16264fa 100755 --- a/configs/rec/rec_chinese_lite_train.yml +++ b/configs/rec/rec_icdar15_train.yml @@ -4,7 +4,7 @@ Global: epoch_num: 300 log_smooth_window: 20 print_batch_step: 10 - save_model_dir: output + save_model_dir: output_ic15 save_epoch_step: 3 eval_batch_step: 2000 train_batch_size_per_card: 256 @@ -12,11 +12,12 @@ Global: image_shape: [3, 32, 100] max_text_length: 25 character_type: ch - character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt + character_dict_path: ./ppocr/utils/ic15_dict.txt loss_type: ctc - reader_yml: ./configs/rec/rec_chinese_reader.yml - pretrain_weights: - + reader_yml: ./configs/rec/rec_icdar15_reader.yml + pretrain_weights: ./pretrain_models/CRNN/best_accuracy + checkpoints: + save_inference_dir: Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/ppocr/data/rec/dataset_traversal.py b/ppocr/data/rec/dataset_traversal.py index b6a5fc10..839448e4 100755 --- a/ppocr/data/rec/dataset_traversal.py +++ b/ppocr/data/rec/dataset_traversal.py @@ -22,6 +22,7 @@ import string import lmdb from ppocr.utils.utility import initial_logger +from ppocr.utils.utility import get_image_file_list logger = initial_logger() from .img_tools import process_image, get_img_data @@ -143,8 +144,9 @@ class SimpleReader(object): self.num_workers = 1 else: self.num_workers = params['num_workers'] - self.img_set_dir = params['img_set_dir'] - self.label_file_path = params['label_file_path'] + if params['mode'] != 'test': + self.img_set_dir = params['img_set_dir'] + self.label_file_path = params['label_file_path'] self.char_ops = params['char_ops'] self.image_shape = params['image_shape'] self.loss_type = params['loss_type'] @@ -164,29 +166,34 @@ class SimpleReader(object): def sample_iter_reader(): if self.mode == 'test': - print("infer_img:", self.infer_img) - img = cv2.imread(self.infer_img) - norm_img = process_image(img, self.image_shape) - yield norm_img - with open(self.label_file_path, "rb") as fin: - label_infor_list = fin.readlines() - img_num = len(label_infor_list) - img_id_list = list(range(img_num)) - random.shuffle(img_id_list) - for img_id in range(process_id, img_num, self.num_workers): - label_infor = label_infor_list[img_id_list[img_id]] - substr = label_infor.decode('utf-8').strip("\n").split("\t") - img_path = self.img_set_dir + "/" + substr[0] - img = cv2.imread(img_path) - if img is None: - continue - label = substr[1] - outs = process_image(img, self.image_shape, label, - self.char_ops, self.loss_type, - self.max_text_length) - if outs is None: - continue - yield outs + image_file_list = get_image_file_list(self.infer_img) + for single_img in image_file_list: + img = cv2.imread(single_img) + if img.shape[-1]==1 or len(list(img.shape))==2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + norm_img = process_image(img, self.image_shape) + yield norm_img + else: + with open(self.label_file_path, "rb") as fin: + label_infor_list = fin.readlines() + img_num = len(label_infor_list) + img_id_list = list(range(img_num)) + random.shuffle(img_id_list) + for img_id in range(process_id, img_num, self.num_workers): + label_infor = label_infor_list[img_id_list[img_id]] + substr = label_infor.decode('utf-8').strip("\n").split("\t") + img_path = self.img_set_dir + "/" + substr[0] + img = cv2.imread(img_path) + if img is None: + logger.info("{} does not exist!".format(img_path)) + continue + label = substr[1] + outs = process_image(img, self.image_shape, label, + self.char_ops, self.loss_type, + self.max_text_length) + if outs is None: + continue + yield outs def batch_iter_reader(): batch_outs = [] @@ -198,4 +205,6 @@ class SimpleReader(object): if len(batch_outs) != 0: yield batch_outs - return batch_iter_reader + if self.mode != 'test': + return batch_iter_reader + return sample_iter_reader diff --git a/ppocr/utils/ic15_dict.txt b/ppocr/utils/ic15_dict.txt new file mode 100644 index 00000000..c1f9993d --- /dev/null +++ b/ppocr/utils/ic15_dict.txt @@ -0,0 +1,36 @@ +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 diff --git a/set_env.sh b/set_env.sh deleted file mode 100644 index b32e58e3..00000000 --- a/set_env.sh +++ /dev/null @@ -1,12 +0,0 @@ -#. /paddle/set_env.sh↩ -export CUDA_VISIBLE_DEVICES="0,1,2,3"↩ -export PYTHONPATH=$PYTHONPATH:.↩ -export FLAGS_fraction_of_gpu_memory_to_use=1.0↩ -↩ -python_bin_dir="/opt/_internal/cpython-3.7.0/bin/"↩ -alias python=$python_bin_dir"python3.7"↩ -alias pip=$python_bin_dir"pip3.7"↩ -alias ipython=$python_bin_dir"ipython3"↩ -export LD_LIBRARY_PATH=/opt/_internal/cpython-3.7.0/lib:$LD_LIBRARY_PATH↩ -export PYTHONPATH=$PYTHONPATH:.↩ -ldconfig↩ diff --git a/tools/eval.py b/tools/eval.py index 3b176648..949f3a34 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -80,7 +80,7 @@ def main(): metrics = eval_det_run(exe, config, eval_info_dict, "test") else: reader_type = config['Global']['reader_yml'] - if "chinese" in reader_type: + if "benchmark" not in reader_type: eval_reader = reader_main(config=config, mode="eval") eval_info_dict = {'program': eval_program, \ 'reader': eval_reader, \ diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 71977391..de7799d0 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -21,7 +21,6 @@ import time import multiprocessing import numpy as np - def set_paddle_flags(**kwargs): for key, value in kwargs.items(): if os.environ.get(key, None) is None: @@ -47,7 +46,7 @@ from ppocr.data.reader_main import reader_main from ppocr.utils.save_load import init_model from ppocr.utils.character import CharacterOps from ppocr.utils.utility import create_module - +from ppocr.utils.utility import get_image_file_list logger = initial_logger() @@ -79,9 +78,15 @@ def main(): init_model(config, eval_prog, exe) - blobs = reader_main(config, 'test') - imgs = next(blobs()) - for img in imgs: + blobs = reader_main(config, 'test')() + infer_img = config['TestReader']['infer_img'] + infer_list = get_image_file_list(infer_img) + max_img_num = len(infer_list) + if len(infer_list) == 0: + logger.info("Can not find img in infer_img dir.") + for i in range(max_img_num): + print("infer_img:",infer_list[i]) + img = next(blobs) predict = exe.run(program=eval_prog, feed={"image": img}, fetch_list=fetch_varname_list, @@ -101,8 +106,8 @@ def main(): preds_text = preds_text.reshape(-1) preds_text = char_ops.decode(preds_text) - print(preds) - print(preds_text) + print("\t index:",preds) + print("\t word :",preds_text) # save for inference model target_var = []