Merge pull request #1513 from WenmuZhou/dygraph_rc
Save configuration files and logs only during training
This commit is contained in:
commit
99175c6106
|
@ -34,7 +34,6 @@ def parse_args():
|
||||||
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||||
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
||||||
parser.add_argument("--use_fp16", type=str2bool, default=False)
|
parser.add_argument("--use_fp16", type=str2bool, default=False)
|
||||||
parser.add_argument("--max_batch_size", type=int, default=10)
|
|
||||||
parser.add_argument("--gpu_mem", type=int, default=8000)
|
parser.add_argument("--gpu_mem", type=int, default=8000)
|
||||||
|
|
||||||
# params for text detector
|
# params for text detector
|
||||||
|
|
|
@ -332,7 +332,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
|
||||||
return metirc
|
return metirc
|
||||||
|
|
||||||
|
|
||||||
def preprocess():
|
def preprocess(is_train=False):
|
||||||
FLAGS = ArgsParser().parse_args()
|
FLAGS = ArgsParser().parse_args()
|
||||||
config = load_config(FLAGS.config)
|
config = load_config(FLAGS.config)
|
||||||
merge_config(FLAGS.opt)
|
merge_config(FLAGS.opt)
|
||||||
|
@ -350,15 +350,17 @@ def preprocess():
|
||||||
device = paddle.set_device(device)
|
device = paddle.set_device(device)
|
||||||
|
|
||||||
config['Global']['distributed'] = dist.get_world_size() != 1
|
config['Global']['distributed'] = dist.get_world_size() != 1
|
||||||
|
if is_train:
|
||||||
# save_config
|
# save_config
|
||||||
save_model_dir = config['Global']['save_model_dir']
|
save_model_dir = config['Global']['save_model_dir']
|
||||||
os.makedirs(save_model_dir, exist_ok=True)
|
os.makedirs(save_model_dir, exist_ok=True)
|
||||||
with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
|
with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
|
||||||
yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
|
yaml.dump(
|
||||||
|
dict(config), f, default_flow_style=False, sort_keys=False)
|
||||||
logger = get_logger(
|
log_file = '{}/train.log'.format(save_model_dir)
|
||||||
name='root', log_file='{}/train.log'.format(save_model_dir))
|
else:
|
||||||
|
log_file = None
|
||||||
|
logger = get_logger(name='root', log_file=log_file)
|
||||||
if config['Global']['use_visualdl']:
|
if config['Global']['use_visualdl']:
|
||||||
from visualdl import LogWriter
|
from visualdl import LogWriter
|
||||||
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
|
vdl_writer_path = '{}/vdl/'.format(save_model_dir)
|
||||||
|
|
|
@ -110,6 +110,6 @@ def test_reader(config, device, logger):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
config, device, logger, vdl_writer = program.preprocess()
|
config, device, logger, vdl_writer = program.preprocess(is_train=True)
|
||||||
main(config, device, logger, vdl_writer)
|
main(config, device, logger, vdl_writer)
|
||||||
# test_reader(config, device, logger)
|
# test_reader(config, device, logger)
|
||||||
|
|
Loading…
Reference in New Issue