diff --git a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml b/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml similarity index 94% rename from configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml rename to configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml index 016788ea..6b60ae08 100644 --- a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml +++ b/configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml @@ -8,9 +8,9 @@ Global: save_epoch_step: 3 eval_batch_step: [0, 2000] cal_metric_during_train: true - pretrained_model: null - checkpoints: null - save_inference_dir: null + pretrained_model: + checkpoints: + save_inference_dir: use_visualdl: false infer_img: doc/imgs_words/ch/word_1.jpg character_dict_path: ppocr/utils/ppocr_keys_v1.txt @@ -38,7 +38,7 @@ Architecture: algorithm: Distillation Models: Student: - pretrained: null + pretrained: freeze_params: false return_all_feats: true model_type: rec @@ -57,7 +57,7 @@ Architecture: name: CTCHead fc_decay: 0.00001 Teacher: - pretrained: null + pretrained: freeze_params: false return_all_feats: true model_type: rec @@ -118,8 +118,8 @@ Train: - DecodeImage: img_mode: BGR channel_first: false - - RecAug: null - - CTCLabelEncode: null + - RecAug: + - CTCLabelEncode: - RecResizeImg: image_shape: [3, 32, 320] - KeepKeys: @@ -143,7 +143,7 @@ Eval: - DecodeImage: img_mode: BGR channel_first: false - - CTCLabelEncode: null + - CTCLabelEncode: - RecResizeImg: image_shape: [3, 32, 320] - KeepKeys: diff --git a/ppocr/modeling/architectures/distillation_model.py b/ppocr/modeling/architectures/distillation_model.py index 255ff32b..2e512331 100644 --- a/ppocr/modeling/architectures/distillation_model.py +++ b/ppocr/modeling/architectures/distillation_model.py @@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone from ppocr.modeling.necks import build_neck from ppocr.modeling.heads import build_head from .base_model import BaseModel -from ppocr.utils.save_load import load_dygraph_pretrain +from ppocr.utils.save_load import init_model __all__ = ['DistillationModel'] @@ -46,7 +46,7 @@ class DistillationModel(nn.Layer): pretrained = model_config.pop("pretrained") model = BaseModel(model_config) if pretrained is not None: - load_dygraph_pretrain(model, path=pretrained) + init_model(model, path=pretrained) if freeze_params: for param in model.parameters(): param.trainable = False diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 951132c3..23f5401b 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -23,6 +23,8 @@ import six import paddle +from ppocr.utils.logging import get_logger + __all__ = ['init_model', 'save_model', 'load_dygraph_pretrain'] @@ -42,19 +44,11 @@ def _mkdir_if_not_exist(path, logger): raise OSError('Failed to mkdir {}'.format(path)) -def load_dygraph_pretrain(model, logger=None, path=None): - if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): - raise ValueError("Model pretrain path {} does not " - "exists.".format(path)) - param_state_dict = paddle.load(path + '.pdparams') - model.set_state_dict(param_state_dict) - return - - -def init_model(config, model, logger, optimizer=None, lr_scheduler=None): +def init_model(config, model, optimizer=None, lr_scheduler=None): """ load model from checkpoint or pretrained_model """ + logger = get_logger() global_config = config['Global'] checkpoints = global_config.get('checkpoints') pretrained_model = global_config.get('pretrained_model') @@ -77,13 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None): best_model_dict = states_dict.get('best_model_dict', {}) if 'epoch' in states_dict: best_model_dict['start_epoch'] = states_dict['epoch'] + 1 - logger.info("resume from {}".format(checkpoints)) elif pretrained_model: if not isinstance(pretrained_model, list): pretrained_model = [pretrained_model] for pretrained in pretrained_model: - load_dygraph_pretrain(model, logger, path=pretrained) + if not (os.path.isdir(pretrained) or + os.path.exists(pretrained + '.pdparams')): + raise ValueError("Model pretrain path {} does not " + "exists.".format(pretrained)) + param_state_dict = paddle.load(pretrained + '.pdparams') + model.set_state_dict(param_state_dict) logger.info("load pretrained model from {}".format( pretrained_model)) else: diff --git a/tools/eval.py b/tools/eval.py index 9817fa75..66eb315f 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -49,7 +49,7 @@ def main(): model = build_model(config['Architecture']) use_srn = config['Architecture']['algorithm'] == "SRN" - best_model_dict = init_model(config, model, logger) + best_model_dict = init_model(config, model) if len(best_model_dict): logger.info('metric in ckpt ***************') for k, v in best_model_dict.items(): diff --git a/tools/export_model.py b/tools/export_model.py index 1d4538c8..625c8246 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -95,7 +95,7 @@ def main(): else: # base rec model config["Architecture"]["Head"]["out_channels"] = char_num model = build_model(config["Architecture"]) - init_model(config, model, logger) + init_model(config, model) model.eval() save_path = config["Global"]["save_inference_dir"] diff --git a/tools/infer_cls.py b/tools/infer_cls.py index 49696482..a588cab4 100755 --- a/tools/infer_cls.py +++ b/tools/infer_cls.py @@ -47,7 +47,7 @@ def main(): # build model model = build_model(config['Architecture']) - init_model(config, model, logger) + init_model(config, model) # create data ops transforms = [] diff --git a/tools/infer_det.py b/tools/infer_det.py index 913d617d..674f52ee 100755 --- a/tools/infer_det.py +++ b/tools/infer_det.py @@ -61,7 +61,7 @@ def main(): # build model model = build_model(config['Architecture']) - init_model(config, model, logger) + init_model(config, model) # build post process post_process_class = build_post_process(config['PostProcess']) diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py index 9c079f60..1cd468b8 100755 --- a/tools/infer_e2e.py +++ b/tools/infer_e2e.py @@ -68,7 +68,7 @@ def main(): # build model model = build_model(config['Architecture']) - init_model(config, model, logger) + init_model(config, model) # build post process post_process_class = build_post_process(config['PostProcess'], diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 6894207d..09f5a0c7 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -58,7 +58,7 @@ def main(): model = build_model(config['Architecture']) - init_model(config, model, logger) + init_model(config, model) # create data ops transforms = [] diff --git a/tools/train.py b/tools/train.py index 555d3367..b024240b 100755 --- a/tools/train.py +++ b/tools/train.py @@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer): # build metric eval_class = build_metric(config['Metric']) # load pretrain model - pre_best_model_dict = init_model(config, model, logger, optimizer) + pre_best_model_dict = init_model(config, model, optimizer) logger.info('train dataloader has {} iters'.format(len(train_dataloader))) if valid_dataloader is not None: