diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 6a379853..a55f671e 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -26,34 +26,27 @@ import time import paddle.fluid as fluid import tools.infer.utility as utility -from ppocr.utils.utility import initial_logger -logger = initial_logger() +from ppocr.postprocess import build_post_process +from ppocr.utils.logging import get_logger from ppocr.utils.utility import get_image_file_list, check_and_read_gif -from ppocr.utils.character import CharacterOps class TextRecognizer(object): def __init__(self, args): - self.predictor, self.input_tensor, self.output_tensors =\ - utility.create_predictor(args, mode="rec") self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")] self.character_type = args.rec_char_type self.rec_batch_num = args.rec_batch_num self.rec_algorithm = args.rec_algorithm self.use_zero_copy_run = args.use_zero_copy_run - char_ops_params = { + postprocess_params = { + 'name': 'CTCLabelDecode', "character_type": args.rec_char_type, "character_dict_path": args.rec_char_dict_path, - "use_space_char": args.use_space_char, - "max_text_length": args.max_text_length + "use_space_char": args.use_space_char } - if self.rec_algorithm != "RARE": - char_ops_params['loss_type'] = 'ctc' - self.loss_type = 'ctc' - else: - char_ops_params['loss_type'] = 'attention' - self.loss_type = 'attention' - self.char_ops = CharacterOps(char_ops_params) + self.postprocess_op = build_post_process(postprocess_params) + self.predictor, self.input_tensor, self.output_tensors = \ + utility.create_predictor(args, 'rec', logger) def resize_norm_img(self, img, max_wh_ratio): imgC, imgH, imgW = self.rec_image_shape @@ -112,48 +105,14 @@ class TextRecognizer(object): else: norm_img_batch = fluid.core.PaddleTensor(norm_img_batch) self.predictor.run([norm_img_batch]) - - if self.loss_type == "ctc": - rec_idx_batch = self.output_tensors[0].copy_to_cpu() - rec_idx_lod = self.output_tensors[0].lod()[0] - predict_batch = self.output_tensors[1].copy_to_cpu() - predict_lod = self.output_tensors[1].lod()[0] - elapse = time.time() - starttime - predict_time += elapse - for rno in range(len(rec_idx_lod) - 1): - beg = rec_idx_lod[rno] - end = rec_idx_lod[rno + 1] - rec_idx_tmp = rec_idx_batch[beg:end, 0] - preds_text = self.char_ops.decode(rec_idx_tmp) - beg = predict_lod[rno] - end = predict_lod[rno + 1] - probs = predict_batch[beg:end, :] - ind = np.argmax(probs, axis=1) - blank = probs.shape[1] - valid_ind = np.where(ind != (blank - 1))[0] - if len(valid_ind) == 0: - continue - score = np.mean(probs[valid_ind, ind[valid_ind]]) - # rec_res.append([preds_text, score]) - rec_res[indices[beg_img_no + rno]] = [preds_text, score] - else: - rec_idx_batch = self.output_tensors[0].copy_to_cpu() - predict_batch = self.output_tensors[1].copy_to_cpu() - elapse = time.time() - starttime - predict_time += elapse - for rno in range(len(rec_idx_batch)): - end_pos = np.where(rec_idx_batch[rno, :] == 1)[0] - if len(end_pos) <= 1: - preds = rec_idx_batch[rno, 1:] - score = np.mean(predict_batch[rno, 1:]) - else: - preds = rec_idx_batch[rno, 1:end_pos[1]] - score = np.mean(predict_batch[rno, 1:end_pos[1]]) - preds_text = self.char_ops.decode(preds) - # rec_res.append([preds_text, score]) - rec_res[indices[beg_img_no + rno]] = [preds_text, score] - - return rec_res, predict_time + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + preds = outputs[0] + rec_res = self.postprocess_op(preds) + elapse = time.time() - starttime + return rec_res, elapse def main(args): @@ -183,9 +142,10 @@ def main(args): exit() for ino in range(len(img_list)): print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) - print("Total predict time for %d images:%.3f" % + print("Total predict time for %d images, cost: %.3f" % (len(img_list), predict_time)) if __name__ == "__main__": + logger = get_logger() main(utility.parse_args()) diff --git a/tools/program.py b/tools/program.py index 8bae0fd5..c2b9306c 100755 --- a/tools/program.py +++ b/tools/program.py @@ -323,6 +323,20 @@ def eval(model, valid_dataloader, post_process_class, eval_class): return metirc +def save_inference_mode(model, config, logger): + model.eval() + save_path = '{}/infer/{}'.format(config['Global']['save_model_dir'], + config['Architecture']['model_type']) + if config['Architecture']['model_type'] == 'rec': + input_shape = [None, 3, 32, None] + jit_model = paddle.jit.to_static( + model, input_spec=[paddle.static.InputSpec(input_shape)]) + paddle.jit.save(jit_model, save_path) + logger.info('inference model save to {}'.format(save_path)) + + model.train() + + def preprocess(): FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) @@ -334,7 +348,7 @@ def preprocess(): alg = config['Architecture']['algorithm'] assert alg in [ - 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN' + 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'CLS' ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' diff --git a/tools/train.py b/tools/train.py index c1622379..1cf644e6 100755 --- a/tools/train.py +++ b/tools/train.py @@ -89,6 +89,7 @@ def main(config, device, logger, vdl_writer): program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, eval_class, pre_best_model_dict, logger, vdl_writer) + program.save_inference_mode(model, config, logger) def test_reader(config, device, logger): @@ -102,8 +103,8 @@ def test_reader(config, device, logger): if count % 1 == 0: batch_time = time.time() - starttime starttime = time.time() - logger.info("reader: {}, {}, {}".format(count, - len(data), batch_time)) + logger.info("reader: {}, {}, {}".format( + count, len(data[0]), batch_time)) except Exception as e: logger.info(e) logger.info("finish reader: {}, Success!".format(count)) @@ -112,4 +113,4 @@ def test_reader(config, device, logger): if __name__ == '__main__': config, device, logger, vdl_writer = program.preprocess() main(config, device, logger, vdl_writer) -# test_reader(config, device, logger) + # test_reader(config, device, logger)