Merge pull request #3232 from bingooo/dygraph

fix bug
This commit is contained in:
Double_V 2021-07-05 10:28:18 +08:00 committed by GitHub
commit 86d469d025
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -91,14 +91,14 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
def load_dygraph_params(config, model, logger, optimizer): def load_dygraph_params(config, model, logger, optimizer):
ckp = config['Global']['checkpoints'] ckp = config['Global']['checkpoints']
if ckp and os.path.exists(ckp): if ckp and os.path.exists(ckp + ".pdparams"):
pre_best_model_dict = init_model(config, model, optimizer) pre_best_model_dict = init_model(config, model, optimizer)
return pre_best_model_dict return pre_best_model_dict
else: else:
pm = config['Global']['pretrained_model'] pm = config['Global']['pretrained_model']
if pm is None: if pm is None:
return {} return {}
if not os.path.exists(pm) or not os.path.exists(pm + ".pdparams"): if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"):
logger.info(f"The pretrained_model {pm} does not exists!") logger.info(f"The pretrained_model {pm} does not exists!")
return {} return {}
pm = pm if pm.endswith('.pdparams') else pm + '.pdparams' pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'