diff --git a/configs/rec/rec_benchmark_reader.yml b/configs/rec/rec_benchmark_reader.yml index 3d1e3e0b..524f2f68 100755 --- a/configs/rec/rec_benchmark_reader.yml +++ b/configs/rec/rec_benchmark_reader.yml @@ -10,4 +10,3 @@ EvalReader: TestReader: reader_function: ppocr.data.rec.dataset_traversal,LMDBReader lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/ - infer_img: ./infer_img diff --git a/configs/rec/rec_chinese_lite_train.yml b/configs/rec/rec_chinese_lite_train.yml index ec1b7a69..cbc43e06 100755 --- a/configs/rec/rec_chinese_lite_train.yml +++ b/configs/rec/rec_chinese_lite_train.yml @@ -18,6 +18,8 @@ Global: pretrain_weights: checkpoints: save_inference_dir: + infer_img: + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_chinese_reader.yml b/configs/rec/rec_chinese_reader.yml index f09a1ea7..a44efd99 100755 --- a/configs/rec/rec_chinese_reader.yml +++ b/configs/rec/rec_chinese_reader.yml @@ -11,4 +11,3 @@ EvalReader: TestReader: reader_function: ppocr.data.rec.dataset_traversal,SimpleReader - infer_img: ./infer_img diff --git a/configs/rec/rec_icdar15_reader.yml b/configs/rec/rec_icdar15_reader.yml index 12facda1..322d5f25 100755 --- a/configs/rec/rec_icdar15_reader.yml +++ b/configs/rec/rec_icdar15_reader.yml @@ -11,4 +11,3 @@ EvalReader: TestReader: reader_function: ppocr.data.rec.dataset_traversal,SimpleReader - infer_img: ./infer_img diff --git a/configs/rec/rec_icdar15_train.yml b/configs/rec/rec_icdar15_train.yml index 6596fc33..dacf3243 100755 --- a/configs/rec/rec_icdar15_train.yml +++ b/configs/rec/rec_icdar15_train.yml @@ -17,6 +17,8 @@ Global: pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy checkpoints: save_inference_dir: + infer_img: + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml index 11a09ee9..951c83cc 100755 --- a/configs/rec/rec_mv3_none_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml @@ -1,6 +1,6 @@ Global: algorithm: CRNN - use_gpu: true + use_gpu: false epoch_num: 72 log_smooth_window: 20 print_batch_step: 10 @@ -14,7 +14,7 @@ Global: character_type: en loss_type: ctc reader_yml: ./configs/rec/rec_benchmark_reader.yml - pretrain_weights: + pretrain_weights: ./output/rec_CRNN/rec_mv3_none_bilstm_ctc/best_accuracy checkpoints: save_inference_dir: diff --git a/configs/rec/rec_mv3_none_none_ctc.yml b/configs/rec/rec_mv3_none_none_ctc.yml index bbbb6d1f..ceec09ce 100755 --- a/configs/rec/rec_mv3_none_none_ctc.yml +++ b/configs/rec/rec_mv3_none_none_ctc.yml @@ -17,6 +17,7 @@ Global: pretrain_weights: checkpoints: save_inference_dir: + infer_img: Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_mv3_tps_bilstm_attn.yml b/configs/rec/rec_mv3_tps_bilstm_attn.yml index 03a2e901..d2fb512f 100755 --- a/configs/rec/rec_mv3_tps_bilstm_attn.yml +++ b/configs/rec/rec_mv3_tps_bilstm_attn.yml @@ -17,7 +17,9 @@ Global: pretrain_weights: checkpoints: save_inference_dir: - + infer_img: + + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_mv3_tps_bilstm_ctc.yml b/configs/rec/rec_mv3_tps_bilstm_ctc.yml index 47247b72..bc5780bd 100755 --- a/configs/rec/rec_mv3_tps_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_tps_bilstm_ctc.yml @@ -17,6 +17,7 @@ Global: pretrain_weights: checkpoints: save_inference_dir: + infer_img: Architecture: diff --git a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml index 10181936..b71e8fea 100755 --- a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml @@ -17,7 +17,9 @@ Global: pretrain_weights: checkpoints: save_inference_dir: - + infer_img: + + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_r34_vd_none_none_ctc.yml b/configs/rec/rec_r34_vd_none_none_ctc.yml index ff4c5763..d9c9458d 100755 --- a/configs/rec/rec_r34_vd_none_none_ctc.yml +++ b/configs/rec/rec_r34_vd_none_none_ctc.yml @@ -17,6 +17,7 @@ Global: pretrain_weights: checkpoints: save_inference_dir: + infer_img: Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_r34_vd_tps_bilstm_attn.yml b/configs/rec/rec_r34_vd_tps_bilstm_attn.yml index 4d96e9e7..405082bd 100755 --- a/configs/rec/rec_r34_vd_tps_bilstm_attn.yml +++ b/configs/rec/rec_r34_vd_tps_bilstm_attn.yml @@ -17,6 +17,8 @@ Global: pretrain_weights: checkpoints: save_inference_dir: + infer_img: + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml index 844721a2..517322c3 100755 --- a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml @@ -17,6 +17,8 @@ Global: pretrain_weights: checkpoints: save_inference_dir: + infer_img: + Architecture: function: ppocr.modeling.architectures.rec_model,RecModel diff --git a/doc/recognition.md b/doc/recognition.md index a5a8119c..ea38c0f3 100644 --- a/doc/recognition.md +++ b/doc/recognition.md @@ -184,7 +184,7 @@ python3 tools/eval.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkp ``` # 预测英文结果 -python3 tools/infer_rec.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy TestReader.infer_img=doc/imgs_words/en/word_1.jpg +python3 tools/infer_rec.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png ``` 预测图片: diff --git a/ppocr/data/rec/dataset_traversal.py b/ppocr/data/rec/dataset_traversal.py index f60b9fe3..9d0d2e96 100755 --- a/ppocr/data/rec/dataset_traversal.py +++ b/ppocr/data/rec/dataset_traversal.py @@ -43,11 +43,10 @@ class LMDBReader(object): self.mode = params['mode'] if params['mode'] == 'train': self.batch_size = params['train_batch_size_per_card'] - elif params['mode'] == "eval": + else: self.batch_size = params['test_batch_size_per_card'] - elif params['mode'] == "test": - self.batch_size = 1 - self.infer_img = params["infer_img"] + self.infer_img = params['infer_img'] + def load_hierarchical_lmdb_dataset(self): lmdb_sets = {} dataset_idx = 0 @@ -100,11 +99,11 @@ class LMDBReader(object): process_id = 0 def sample_iter_reader(): - if self.mode == 'test': + if self.infer_img is not None: image_file_list = get_image_file_list(self.infer_img) for single_img in image_file_list: img = cv2.imread(single_img) - if img.shape[-1]==1 or len(list(img.shape))==2: + if img.shape[-1] == 1 or len(list(img.shape)) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) norm_img = process_image(img, self.image_shape) yield norm_img @@ -136,6 +135,7 @@ class LMDBReader(object): if finish_read_num == len(lmdb_sets): break self.close_lmdb_dataset(lmdb_sets) + def batch_iter_reader(): batch_outs = [] for outs in sample_iter_reader(): @@ -146,7 +146,7 @@ class LMDBReader(object): if len(batch_outs) != 0: yield batch_outs - if self.mode != 'test': + if self.infer_img is None: return batch_iter_reader return sample_iter_reader @@ -165,24 +165,22 @@ class SimpleReader(object): self.loss_type = params['loss_type'] self.max_text_length = params['max_text_length'] self.mode = params['mode'] + self.infer_img = params['infer_img'] if params['mode'] == 'train': self.batch_size = params['train_batch_size_per_card'] - elif params['mode'] == 'eval': - self.batch_size = params['test_batch_size_per_card'] else: - self.batch_size = 1 - self.infer_img = params['infer_img'] + self.batch_size = params['test_batch_size_per_card'] def __call__(self, process_id): if self.mode != 'train': process_id = 0 def sample_iter_reader(): - if self.mode == 'test': + if self.infer_img is not None: image_file_list = get_image_file_list(self.infer_img) for single_img in image_file_list: img = cv2.imread(single_img) - if img.shape[-1]==1 or len(list(img.shape))==2: + if img.shape[-1] == 1 or len(list(img.shape)) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) norm_img = process_image(img, self.image_shape) yield norm_img @@ -192,7 +190,7 @@ class SimpleReader(object): img_num = len(label_infor_list) img_id_list = list(range(img_num)) random.shuffle(img_id_list) - if sys.platform=="win32": + if sys.platform == "win32": print("multiprocess is not fully compatible with Windows." "num_workers will be 1.") self.num_workers = 1 @@ -204,7 +202,7 @@ class SimpleReader(object): if img is None: logger.info("{} does not exist!".format(img_path)) continue - if img.shape[-1]==1 or len(list(img.shape))==2: + if img.shape[-1] == 1 or len(list(img.shape)) == 2: img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) label = substr[1] @@ -225,6 +223,6 @@ class SimpleReader(object): if len(batch_outs) != 0: yield batch_outs - if self.mode != 'test': + if self.infer_img is None: return batch_iter_reader return sample_iter_reader diff --git a/tools/eval_utils/eval_rec_utils.py b/tools/eval_utils/eval_rec_utils.py index 3ceaa159..2d7d7e1d 100644 --- a/tools/eval_utils/eval_rec_utils.py +++ b/tools/eval_utils/eval_rec_utils.py @@ -48,7 +48,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode): total_sample_num = 0 total_acc_num = 0 total_batch_num = 0 - if mode == "test": + if mode == "eval": is_remove_duplicate = False else: is_remove_duplicate = True @@ -91,11 +91,11 @@ def test_rec_benchmark(exe, config, eval_info_dict): total_correct_number = 0 eval_data_acc_info = {} for eval_data in eval_data_list: - config['EvalReader']['lmdb_sets_dir'] = \ + config['TestReader']['lmdb_sets_dir'] = \ eval_data_dir + "/" + eval_data - eval_reader = reader_main(config=config, mode="eval") + eval_reader = reader_main(config=config, mode="test") eval_info_dict['reader'] = eval_reader - metrics = eval_rec_run(exe, config, eval_info_dict, "eval") + metrics = eval_rec_run(exe, config, eval_info_dict, "test") total_evaluation_data_number += metrics['total_sample_num'] total_correct_number += metrics['total_acc_num'] eval_data_acc_info[eval_data] = metrics diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 25bae1ca..67e61451 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -21,6 +21,7 @@ import time import multiprocessing import numpy as np + def set_paddle_flags(**kwargs): for key, value in kwargs.items(): if os.environ.get(key, None) is None: @@ -78,13 +79,13 @@ def main(): init_model(config, eval_prog, exe) blobs = reader_main(config, 'test')() - infer_img = config['TestReader']['infer_img'] + infer_img = config['Global']['infer_img'] infer_list = get_image_file_list(infer_img) max_img_num = len(infer_list) if len(infer_list) == 0: logger.info("Can not find img in infer_img dir.") for i in range(max_img_num): - print("infer_img:",infer_list[i]) + print("infer_img:", infer_list[i]) img = next(blobs) predict = exe.run(program=eval_prog, feed={"image": img}, @@ -105,8 +106,8 @@ def main(): preds_text = preds_text.reshape(-1) preds_text = char_ops.decode(preds_text) - print("\t index:",preds) - print("\t word :",preds_text) + print("\t index:", preds) + print("\t word :", preds_text) # save for inference model target_var = []