commit
86d469d025
|
@ -91,14 +91,14 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
|
||||||
|
|
||||||
def load_dygraph_params(config, model, logger, optimizer):
|
def load_dygraph_params(config, model, logger, optimizer):
|
||||||
ckp = config['Global']['checkpoints']
|
ckp = config['Global']['checkpoints']
|
||||||
if ckp and os.path.exists(ckp):
|
if ckp and os.path.exists(ckp + ".pdparams"):
|
||||||
pre_best_model_dict = init_model(config, model, optimizer)
|
pre_best_model_dict = init_model(config, model, optimizer)
|
||||||
return pre_best_model_dict
|
return pre_best_model_dict
|
||||||
else:
|
else:
|
||||||
pm = config['Global']['pretrained_model']
|
pm = config['Global']['pretrained_model']
|
||||||
if pm is None:
|
if pm is None:
|
||||||
return {}
|
return {}
|
||||||
if not os.path.exists(pm) or not os.path.exists(pm + ".pdparams"):
|
if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"):
|
||||||
logger.info(f"The pretrained_model {pm} does not exists!")
|
logger.info(f"The pretrained_model {pm} does not exists!")
|
||||||
return {}
|
return {}
|
||||||
pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'
|
pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'
|
||||||
|
|
Loading…
Reference in New Issue