rm load_dyg_pretrain

This commit is contained in: parent bd1820b784 · commit 48d8537959
@@ -8,9 +8,9 @@ Global:
   save_epoch_step: 3
   eval_batch_step: [0, 2000]
   cal_metric_during_train: true
-  pretrained_model: null
-  checkpoints: null
-  save_inference_dir: null
+  pretrained_model:
+  checkpoints:
+  save_inference_dir:
   use_visualdl: false
   infer_img: doc/imgs_words/ch/word_1.jpg
   character_dict_path: ppocr/utils/ppocr_keys_v1.txt
@@ -38,7 +38,7 @@ Architecture:
   algorithm: Distillation
   Models:
     Student:
-      pretrained: null
+      pretrained:
       freeze_params: false
       return_all_feats: true
       model_type: rec
@@ -57,7 +57,7 @@ Architecture:
         name: CTCHead
         fc_decay: 0.00001
     Teacher:
-      pretrained: null
+      pretrained:
       freeze_params: false
       return_all_feats: true
       model_type: rec
@@ -118,8 +118,8 @@ Train:
       - DecodeImage:
           img_mode: BGR
           channel_first: false
-      - RecAug: null
-      - CTCLabelEncode: null
+      - RecAug:
+      - CTCLabelEncode:
       - RecResizeImg:
           image_shape: [3, 32, 320]
       - KeepKeys:
@@ -143,7 +143,7 @@ Eval:
       - DecodeImage:
           img_mode: BGR
           channel_first: false
-      - CTCLabelEncode: null
+      - CTCLabelEncode:
       - RecResizeImg:
           image_shape: [3, 32, 320]
       - KeepKeys:
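
An editorial aside, not part of the commit: the config edits above change only notation. In YAML, an empty value and an explicit null parse to the same Python None, so the behavior of these keys is unchanged. A minimal check, assuming PyYAML is installed:

    # Editorial sketch (not from this commit): empty value == explicit null.
    import yaml

    assert yaml.safe_load("pretrained_model: null") == {"pretrained_model": None}
    assert yaml.safe_load("pretrained_model:") == {"pretrained_model": None}
    print("both forms parse to None")
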
@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
 from ppocr.modeling.necks import build_neck
 from ppocr.modeling.heads import build_head
 from .base_model import BaseModel
-from ppocr.utils.save_load import load_dygraph_pretrain
+from ppocr.utils.save_load import init_model

 __all__ = ['DistillationModel']

@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
             pretrained = model_config.pop("pretrained")
             model = BaseModel(model_config)
             if pretrained is not None:
-                load_dygraph_pretrain(model, path=pretrained)
+                init_model(model, path=pretrained)
             if freeze_params:
                 for param in model.parameters():
                     param.trainable = False
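
For context on the freeze_params branch kept above, here is a minimal, self-contained sketch (editorial, not from the commit; the nn.Linear stands in for a real BaseModel sub-network) of what setting param.trainable = False does in Paddle:

    # Editorial sketch: freezing a sub-model the way the loop above does.
    import paddle
    import paddle.nn as nn

    teacher = nn.Linear(8, 8)        # stand-in for a frozen sub-model
    for param in teacher.parameters():
        param.trainable = False      # same flag the diff sets under freeze_params

    x = paddle.randn([2, 8])
    loss = teacher(x).sum()
    loss.backward()
    print(teacher.weight.grad)       # None: frozen parameters accumulate no gradients
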
@@ -23,6 +23,8 @@ import six

 import paddle

+from ppocr.utils.logging import get_logger
+
 __all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']


@@ -42,19 +44,11 @@ def _mkdir_if_not_exist(path, logger):
             raise OSError('Failed to mkdir {}'.format(path))


-def load_dygraph_pretrain(model, logger=None, path=None):
-    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
-        raise ValueError("Model pretrain path {} does not "
-                         "exists.".format(path))
-    param_state_dict = paddle.load(path + '.pdparams')
-    model.set_state_dict(param_state_dict)
-    return
-
-
-def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
+def init_model(config, model, optimizer=None, lr_scheduler=None):
     """
     load model from checkpoint or pretrained_model
     """
+    logger = get_logger()
     global_config = config['Global']
     checkpoints = global_config.get('checkpoints')
     pretrained_model = global_config.get('pretrained_model')
@@ -77,13 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
         best_model_dict = states_dict.get('best_model_dict', {})
         if 'epoch' in states_dict:
             best_model_dict['start_epoch'] = states_dict['epoch'] + 1

         logger.info("resume from {}".format(checkpoints))
     elif pretrained_model:
         if not isinstance(pretrained_model, list):
             pretrained_model = [pretrained_model]
         for pretrained in pretrained_model:
-            load_dygraph_pretrain(model, logger, path=pretrained)
+            if not (os.path.isdir(pretrained) or
+                    os.path.exists(pretrained + '.pdparams')):
+                raise ValueError("Model pretrain path {} does not "
+                                 "exists.".format(pretrained))
+            param_state_dict = paddle.load(pretrained + '.pdparams')
+            model.set_state_dict(param_state_dict)
         logger.info("load pretrained model from {}".format(
             pretrained_model))
     else:
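
As a usage sketch of the loading pattern that init_model now inlines (editorial, not from the commit; DemoNet and the path prefix are invented for illustration):

    # Editorial sketch: save a .pdparams file, then restore it with the
    # same guard + paddle.load + set_state_dict steps shown in the hunk above.
    import os
    import paddle
    import paddle.nn as nn

    class DemoNet(nn.Layer):                 # toy stand-in for a built model
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    model = DemoNet()
    prefix = "./demo_rec"                    # hypothetical checkpoint prefix
    paddle.save(model.state_dict(), prefix + ".pdparams")

    if not (os.path.isdir(prefix) or os.path.exists(prefix + ".pdparams")):
        raise ValueError("Model pretrain path {} does not exists.".format(prefix))
    model.set_state_dict(paddle.load(prefix + ".pdparams"))
    print("restored weights from", prefix + ".pdparams")
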
@@ -49,7 +49,7 @@ def main():
     model = build_model(config['Architecture'])
     use_srn = config['Architecture']['algorithm'] == "SRN"

-    best_model_dict = init_model(config, model, logger)
+    best_model_dict = init_model(config, model)
     if len(best_model_dict):
         logger.info('metric in ckpt ***************')
         for k, v in best_model_dict.items():
@@ -95,7 +95,7 @@ def main():
     else:  # base rec model
         config["Architecture"]["Head"]["out_channels"] = char_num
     model = build_model(config["Architecture"])
-    init_model(config, model, logger)
+    init_model(config, model)
     model.eval()

     save_path = config["Global"]["save_inference_dir"]
@@ -47,7 +47,7 @@ def main():
     # build model
     model = build_model(config['Architecture'])

-    init_model(config, model, logger)
+    init_model(config, model)

     # create data ops
     transforms = []
@@ -61,7 +61,7 @@ def main():
     # build model
     model = build_model(config['Architecture'])

-    init_model(config, model, logger)
+    init_model(config, model)

     # build post process
     post_process_class = build_post_process(config['PostProcess'])
@@ -68,7 +68,7 @@ def main():
     # build model
     model = build_model(config['Architecture'])

-    init_model(config, model, logger)
+    init_model(config, model)

     # build post process
     post_process_class = build_post_process(config['PostProcess'],
@@ -58,7 +58,7 @@ def main():

     model = build_model(config['Architecture'])

-    init_model(config, model, logger)
+    init_model(config, model)

     # create data ops
     transforms = []
@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
     # build metric
     eval_class = build_metric(config['Metric'])
     # load pretrain model
-    pre_best_model_dict = init_model(config, model, logger, optimizer)
+    pre_best_model_dict = init_model(config, model, optimizer)

     logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
     if valid_dataloader is not None: