rc版本适配

This commit is contained in:
WenmuZhou 2020-11-09 18:20:03 +08:00
parent 44840726ff
commit 4d775dc98f
1 changed files with 4 additions and 4 deletions

View File

@ -68,11 +68,11 @@ def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
param_state_dict[key] = pre_state_dict[weight_name]
else:
param_state_dict[key] = model_dict[key]
model.set_dict(param_state_dict)
model.set_state_dict(param_state_dict)
return
param_state_dict, optim_state_dict = paddle.load(path)
model.set_dict(param_state_dict)
param_state_dict = paddle.load(path + '.pdparams')
model.set_state_dict(param_state_dict)
return
@ -91,7 +91,7 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
"Given dir {}.pdopt not exist.".format(checkpoints)
para_dict = paddle.load(checkpoints + '.pdparams')
opti_dict = paddle.load(checkpoints + '.pdopt')
model.set_dict(para_dict)
model.set_state_dict(para_dict)
if optimizer is not None:
optimizer.set_state_dict(opti_dict)