From ae124590150341ca11a833f5f74f175c7fb7799a Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Fri, 18 Dec 2020 18:51:19 +0800 Subject: [PATCH] Save configuration files and logs only during training --- tools/program.py | 22 ++++++++++++---------- tools/train.py | 2 +- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tools/program.py b/tools/program.py index 787a59d4..4331f9d4 100755 --- a/tools/program.py +++ b/tools/program.py @@ -332,7 +332,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class): return metirc -def preprocess(): +def preprocess(is_train=False): FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) merge_config(FLAGS.opt) @@ -350,15 +350,17 @@ def preprocess(): device = paddle.set_device(device) config['Global']['distributed'] = dist.get_world_size() != 1 - - # save_config - save_model_dir = config['Global']['save_model_dir'] - os.makedirs(save_model_dir, exist_ok=True) - with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f: - yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False) - - logger = get_logger( - name='root', log_file='{}/train.log'.format(save_model_dir)) + if is_train: + # save_config + save_model_dir = config['Global']['save_model_dir'] + os.makedirs(save_model_dir, exist_ok=True) + with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f: + yaml.dump( + dict(config), f, default_flow_style=False, sort_keys=False) + log_file = '{}/train.log'.format(save_model_dir) + else: + log_file = None + logger = get_logger(name='root', log_file=log_file) if config['Global']['use_visualdl']: from visualdl import LogWriter vdl_writer_path = '{}/vdl/'.format(save_model_dir) diff --git a/tools/train.py b/tools/train.py index 6e44c598..383f8d83 100755 --- a/tools/train.py +++ b/tools/train.py @@ -110,6 +110,6 @@ def test_reader(config, device, logger): if __name__ == '__main__': - config, device, logger, vdl_writer = program.preprocess() + config, device, logger, vdl_writer = program.preprocess(is_train=True) main(config, device, logger, vdl_writer) # test_reader(config, device, logger)