commit 7fd8d6a205
@@ -80,9 +80,9 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
     """
     load model from checkpoint or pretrained_model
     """
-    gloabl_config = config['Global']
-    checkpoints = gloabl_config.get('checkpoints')
-    pretrained_model = gloabl_config.get('pretrained_model')
+    global_config = config['Global']
+    checkpoints = global_config.get('checkpoints')
+    pretrained_model = global_config.get('pretrained_model')
     best_model_dict = {}
     if checkpoints:
         assert os.path.exists(checkpoints + ".pdparams"), \
@@ -105,7 +105,7 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
 
         logger.info("resume from {}".format(checkpoints))
     elif pretrained_model:
-        load_static_weights = gloabl_config.get('load_static_weights', False)
+        load_static_weights = global_config.get('load_static_weights', False)
         if not isinstance(pretrained_model, list):
             pretrained_model = [pretrained_model]
         if not isinstance(load_static_weights, list):
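For context, a minimal caller-side sketch of how the renamed global_config keys are read. The config layout is inferred from the hunks above; the concrete paths and the surrounding training script are assumptions, not part of this commit:

    # Hypothetical config: init_model() reads everything from config['Global'].
    # If 'checkpoints' is set, it resumes from '<checkpoints>.pdparams';
    # otherwise 'pretrained_model' (a path or a list of paths) is loaded,
    # with 'load_static_weights' defaulting to False.
    config = {
        'Global': {
            'checkpoints': None,                          # example path, without the .pdparams suffix
            'pretrained_model': 'pretrain/rec_backbone',  # example path, may also be a list
            'load_static_weights': False,
        }
    }
    init_model(config, model, logger, optimizer, lr_scheduler)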