diff --git a/tools/program.py b/tools/program.py index ff8743f1..6197b91b 100755 --- a/tools/program.py +++ b/tools/program.py @@ -22,6 +22,7 @@ import yaml import os from ppocr.utils.utility import create_module from ppocr.utils.utility import initial_logger + logger = initial_logger() import paddle.fluid as fluid @@ -31,8 +32,7 @@ from eval_utils.eval_det_utils import eval_det_run from eval_utils.eval_rec_utils import eval_rec_run from ppocr.utils.save_load import save_model import numpy as np -from ppocr.utils.character import cal_predicts_accuracy - +from ppocr.utils.character import cal_predicts_accuracy, CharacterOps class ArgsParser(ArgumentParser): def __init__(self): @@ -374,3 +374,28 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict): save_path = save_model_dir + "/iter_epoch_%d" % (epoch) save_model(train_info_dict['train_program'], save_path) return + +def preProcess(): + FLAGS = ArgsParser().parse_args() + config = load_config(FLAGS.config) + merge_config(FLAGS.opt) + logger.info(config) + + # check if set use_gpu=True in paddlepaddle cpu version + use_gpu = config['Global']['use_gpu'] + check_gpu(use_gpu) + + alg = config['Global']['algorithm'] + assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE'] + if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']: + config['Global']['char_ops'] = CharacterOps(config['Global']) + + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + startup_program = fluid.Program() + train_program = fluid.Program() + + isContain_det = False + if alg in ['EAST', 'DB']: + isContain_det = True + + return startup_program, train_program, place, config, isContain_det