add new load dygraph func (#3088)
* add new load dygraph func
* update load_pretrain_params
* update load_dygraph_params
* Update save_load.py
* Update train.py
* Update save_load.py
* return {} when path is None
* return {} when path is None
This commit is contained in:
parent 6f64faeab4
commit be181cb3bd
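
For context, a minimal sketch of how the new loader is driven from a training entry point, assuming only what is visible in the diff below: the Global.checkpoints / Global.pretrained_model config keys and the load_dygraph_params(config, model, logger, optimizer) signature. The config values, model, and optimizer here are illustrative placeholders, not the actual PaddleOCR setup:

import paddle
from ppocr.utils.logging import get_logger
from ppocr.utils.save_load import load_dygraph_params

# Only the keys read by load_dygraph_params are shown; a real PaddleOCR config
# comes from a YAML file and contains much more.
config = {
    'Global': {
        'checkpoints': None,  # resume path; takes priority when set
        'pretrained_model': './pretrain_models/best_accuracy',  # '.pdparams' suffix optional
    }
}

logger = get_logger()
model = paddle.nn.Linear(10, 10)  # placeholder for a real detection/recognition model
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

# Returns the best-metric dict when resuming from a checkpoint, otherwise {}.
pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
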
ppocr/utils/save_load.py
@@ -25,7 +25,7 @@ import paddle
 
 from ppocr.utils.logging import get_logger
 
-__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
+__all__ = ['init_model', 'save_model', 'load_dygraph_params']
 
 
 def _mkdir_if_not_exist(path, logger):
@@ -89,6 +89,34 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
     return best_model_dict
 
 
+def load_dygraph_params(config, model, logger, optimizer):
+    ckp = config['Global']['checkpoints']
+    if ckp and os.path.exists(ckp):
+        pre_best_model_dict = init_model(config, model, optimizer)
+        return pre_best_model_dict
+    else:
+        pm = config['Global']['pretrained_model']
+        if pm is None:
+            return {}
+        if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"):
+            logger.info(f"The pretrained_model {pm} does not exist!")
+            return {}
+        pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'
+        params = paddle.load(pm)
+        state_dict = model.state_dict()
+        new_state_dict = {}
+        for k1, k2 in zip(state_dict.keys(), params.keys()):
+            if list(state_dict[k1].shape) == list(params[k2].shape):
+                new_state_dict[k1] = params[k2]
+            else:
+                logger.info(
+                    f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
+                )
+        model.set_state_dict(new_state_dict)
+        logger.info(f"loaded pretrained_model successfully from {pm}")
+        return {}
+
+
 def save_model(model,
                optimizer,
                model_path,
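
The loading branch above pairs parameters by their position in the two state dicts and copies only tensors whose shapes agree, so layers that differ (for example a re-sized output head) keep their fresh initialization. A small standalone sketch of that pattern, with toy paddle.nn models invented purely for illustration:

import paddle

pretrained = paddle.nn.Sequential(paddle.nn.Linear(8, 4), paddle.nn.Linear(4, 2))
model = paddle.nn.Sequential(paddle.nn.Linear(8, 4), paddle.nn.Linear(4, 3))  # head differs

params = pretrained.state_dict()
state_dict = model.state_dict()
new_state_dict = {}
for k1, k2 in zip(state_dict.keys(), params.keys()):
    if list(state_dict[k1].shape) == list(params[k2].shape):
        new_state_dict[k1] = params[k2]  # shapes agree: take the pretrained tensor
    else:
        print(f"skip {k1}: {state_dict[k1].shape} vs {params[k2].shape}")
model.set_state_dict(new_state_dict)  # unmatched layers keep their initial values
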
tools/train.py
@@ -35,7 +35,7 @@ from ppocr.losses import build_loss
 from ppocr.optimizer import build_optimizer
 from ppocr.postprocess import build_post_process
 from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import init_model, load_dygraph_params
 import tools.program as program
 
 dist.get_world_size()
@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
     # build metric
     eval_class = build_metric(config['Metric'])
     # load pretrain model
-    pre_best_model_dict = init_model(config, model, optimizer)
+    pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)
 
     logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
     if valid_dataloader is not None: