rc版本适配
This commit is contained in:
parent
44840726ff
commit
4d775dc98f
|
@ -68,11 +68,11 @@ def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
|
|||
param_state_dict[key] = pre_state_dict[weight_name]
|
||||
else:
|
||||
param_state_dict[key] = model_dict[key]
|
||||
model.set_dict(param_state_dict)
|
||||
model.set_state_dict(param_state_dict)
|
||||
return
|
||||
|
||||
param_state_dict, optim_state_dict = paddle.load(path)
|
||||
model.set_dict(param_state_dict)
|
||||
param_state_dict = paddle.load(path + '.pdparams')
|
||||
model.set_state_dict(param_state_dict)
|
||||
return
|
||||
|
||||
|
||||
|
@ -91,7 +91,7 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
|
|||
"Given dir {}.pdopt not exist.".format(checkpoints)
|
||||
para_dict = paddle.load(checkpoints + '.pdparams')
|
||||
opti_dict = paddle.load(checkpoints + '.pdopt')
|
||||
model.set_dict(para_dict)
|
||||
model.set_state_dict(para_dict)
|
||||
if optimizer is not None:
|
||||
optimizer.set_state_dict(opti_dict)
|
||||
|
||||
|
|
Loading…
Reference in New Issue