Merge pull request #1513 from WenmuZhou/dygraph_rc
Save configuration files and logs only during training
This commit is contained in:
commit
99175c6106
|
@ -34,7 +34,6 @@ def parse_args():
|
|||
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
||||
parser.add_argument("--use_fp16", type=str2bool, default=False)
|
||||
parser.add_argument("--max_batch_size", type=int, default=10)
|
||||
parser.add_argument("--gpu_mem", type=int, default=8000)
|
||||
|
||||
# params for text detector
|
||||
|
|
|
@ -332,7 +332,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
|
|||
return metirc
|
||||
|
||||
|
||||
def preprocess():
|
||||
def preprocess(is_train=False):
|
||||
FLAGS = ArgsParser().parse_args()
|
||||
config = load_config(FLAGS.config)
|
||||
merge_config(FLAGS.opt)
|
||||
|
@ -350,15 +350,17 @@ def preprocess():
|
|||
device = paddle.set_device(device)
|
||||
|
||||
config['Global']['distributed'] = dist.get_world_size() != 1
|
||||
|
||||
if is_train:
|
||||
# save_config
|
||||
save_model_dir = config['Global']['save_model_dir']
|
||||
os.makedirs(save_model_dir, exist_ok=True)
|
||||
with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
|
||||
yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
|
||||
|
||||
logger = get_logger(
|
||||
name='root', log_file='{}/train.log'.format(save_model_dir))
|
||||
yaml.dump(
|
||||
dict(config), f, default_flow_style=False, sort_keys=False)
|
||||
log_file = '{}/train.log'.format(save_model_dir)
|
||||
else:
|
||||
log_file = None
|
||||
logger = get_logger(name='root', log_file=log_file)
|
||||
if config['Global']['use_visualdl']:
|
||||
from visualdl import LogWriter
|
||||
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
|
||||
|
|
|
@ -110,6 +110,6 @@ def test_reader(config, device, logger):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config, device, logger, vdl_writer = program.preprocess()
|
||||
config, device, logger, vdl_writer = program.preprocess(is_train=True)
|
||||
main(config, device, logger, vdl_writer)
|
||||
# test_reader(config, device, logger)
|
||||
|
|
Loading…
Reference in New Issue