diff --git a/tools/train.py b/tools/train.py index 29205483..c9d6e221 100755 --- a/tools/train.py +++ b/tools/train.py @@ -42,27 +42,10 @@ from ppocr.utils.utility import initial_logger logger = initial_logger() from ppocr.data.reader_main import reader_main from ppocr.utils.save_load import init_model -from ppocr.utils.character import CharacterOps from paddle.fluid.contrib.model_stat import summary def main(): - config = program.load_config(FLAGS.config) - program.merge_config(FLAGS.opt) - logger.info(config) - - # check if set use_gpu=True in paddlepaddle cpu version - use_gpu = config['Global']['use_gpu'] - program.check_gpu(use_gpu) - - alg = config['Global']['algorithm'] - assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE'] - if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']: - config['Global']['char_ops'] = CharacterOps(config['Global']) - - place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() - startup_program = fluid.Program() - train_program = fluid.Program() train_build_outputs = program.build( config, train_program, startup_program, mode='train') train_loader = train_build_outputs[0] @@ -109,7 +92,7 @@ def main(): 'fetch_name_list':eval_fetch_name_list,\ 'fetch_varname_list':eval_fetch_varname_list} - if alg in ['EAST', 'DB']: + if isContain_det: program.train_eval_det_run(config, exe, train_info_dict, eval_info_dict) else: program.train_eval_rec_run(config, exe, train_info_dict, eval_info_dict) @@ -136,7 +119,6 @@ def test_reader(): if __name__ == '__main__': - parser = program.ArgsParser() - FLAGS = parser.parse_args() + startup_program, train_program, place, config, isContain_det = program.preProcess() main() # test_reader()