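Fix detection distillation bugs (config keys, loss init/return value, pretrained-parameter loading, DistillationDBPostProcess registration)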
This commit is contained in:
parent
a91bbd7432
commit
185d1e1f92
|
@ -20,7 +20,7 @@ Architecture:
|
|||
algorithm: Distillation
|
||||
Models:
|
||||
Student:
|
||||
pretrained:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
|
@ -37,7 +37,7 @@ Architecture:
|
|||
name: DBHead
|
||||
k: 50
|
||||
Student2:
|
||||
pretrained:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
|
@ -55,6 +55,9 @@ Architecture:
|
|||
name: DBHead
|
||||
k: 50
|
||||
Teacher:
|
||||
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
|
||||
freeze_params: true
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
|
@ -73,7 +76,9 @@ Loss:
|
|||
loss_config_list:
|
||||
- DistillationDilaDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Student2", "Teacher"]
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
- ["Student2", "Teacher"]
|
||||
key: maps
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
|
@ -81,13 +86,16 @@ Loss:
|
|||
beta: 10
|
||||
ohem_ratio: 3
|
||||
- DistillationDMLLoss:
|
||||
model_name_pairs:
|
||||
- ["Student", "Student2"]
|
||||
maps_name: ["thrink_maps"]
|
||||
weight: 1.0
|
||||
act: "softmax"
|
||||
model_name_pairs: ["Student", "Student2"]
|
||||
key: maps
|
||||
- DistillationDBLoss:
|
||||
model_name_list: ["Student", "Teacher"]
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Student2"]
|
||||
key: maps
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
|
@ -110,7 +118,7 @@ Optimizer:
|
|||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DistillationCTDBPostProcessCLabelDecode
|
||||
name: DistillationDBPostProcess
|
||||
model_name: ["Student", "Student2"]
|
||||
key: head_out
|
||||
thresh: 0.3
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
from .basic_loss import DMLLoss
|
||||
|
@ -22,6 +24,7 @@ from .det_db_loss import DBLoss
|
|||
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
|
||||
|
||||
|
||||
|
||||
def _sum_loss(loss_dict):
|
||||
if "loss" in loss_dict.keys():
|
||||
return loss_dict
|
||||
|
@ -50,7 +53,7 @@ class DistillationDMLLoss(DMLLoss):
|
|||
self.key = key
|
||||
self.model_name_pairs = model_name_pairs
|
||||
self.name = name
|
||||
self.maps_name = self.maps_name
|
||||
self.maps_name = maps_name
|
||||
|
||||
def _check_maps_name(self, maps_name):
|
||||
if maps_name is None:
|
||||
|
@ -172,6 +175,7 @@ class DistillationDBLoss(DBLoss):
|
|||
class DistillationDilaDBLoss(DBLoss):
|
||||
def __init__(self,
|
||||
model_name_pairs=[],
|
||||
key=None,
|
||||
balance_loss=True,
|
||||
main_loss_type='DiceLoss',
|
||||
alpha=5,
|
||||
|
@ -182,6 +186,7 @@ class DistillationDilaDBLoss(DBLoss):
|
|||
super().__init__()
|
||||
self.model_name_pairs = model_name_pairs
|
||||
self.name = name
|
||||
self.key = key
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
|
@ -219,7 +224,7 @@ class DistillationDilaDBLoss(DBLoss):
|
|||
loss_dict[k] = bce_loss + loss_binary_maps
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
return loss
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDistanceLoss(DistanceLoss):
|
||||
|
|
|
@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
|
|||
from ppocr.modeling.necks import build_neck
|
||||
from ppocr.modeling.heads import build_head
|
||||
from .base_model import BaseModel
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import init_model, load_pretrained_params
|
||||
|
||||
__all__ = ['DistillationModel']
|
||||
|
||||
|
@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
|
|||
pretrained = model_config.pop("pretrained")
|
||||
model = BaseModel(model_config)
|
||||
if pretrained is not None:
|
||||
init_model(model, path=pretrained)
|
||||
load_pretrained_params(model, pretrained)
|
||||
if freeze_params:
|
||||
for param in model.parameters():
|
||||
param.trainable = False
|
||||
|
|
|
@ -21,7 +21,7 @@ import copy
|
|||
|
||||
__all__ = ['build_post_process']
|
||||
|
||||
from .db_postprocess import DBPostProcess
|
||||
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
|
||||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
|
||||
|
@ -34,7 +34,7 @@ def build_post_process(config, global_config=None):
|
|||
support_dict = [
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode'
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -187,3 +187,44 @@ class DBPostProcess(object):
|
|||
|
||||
boxes_batch.append({'points': boxes})
|
||||
return boxes_batch
|
||||
|
||||
|
||||
class DistillationDBPostProcess(DBPostProcess):
    """Apply DB post-processing to the outputs of several distillation sub-models.

    The constructor forwards the DB thresholds to ``DBPostProcess`` and
    remembers which sub-model outputs to decode. ``forward`` then runs the
    parent's ``__call__`` once per configured sub-model.

    Args:
        model_name: A sub-model name or a list of names; each must be a key
            of the ``predicts`` mapping passed to ``forward``.
        key: Optional key used to select a nested entry of each sub-model's
            prediction before it is handed to the parent post-process.
            ``None`` passes the prediction through unchanged.
        thresh, box_thresh, max_candidates, unclip_ratio, use_dilation,
        score_mode: Forwarded positionally to ``DBPostProcess.__init__``.
        **kwargs: Accepted and ignored, so extra config entries do not fail.
    """

    def __init__(self,
                 model_name=["student"],
                 key=None,
                 thresh=0.3,
                 box_thresh=0.7,
                 max_candidates=1000,
                 unclip_ratio=2.0,
                 use_dilation=False,
                 score_mode="fast",
                 **kwargs):
        super(DistillationDBPostProcess, self).__init__(
            thresh, box_thresh, max_candidates, unclip_ratio, use_dilation,
            score_mode)
        if not isinstance(model_name, list):
            model_name = [model_name]
        # Copy so the instance never aliases the shared default list (or a
        # list owned by the caller's config).
        self.model_name = list(model_name)

        self.key = key

    def forward(self, predicts, shape_list):
        """Decode the predictions of every configured sub-model.

        Args:
            predicts: Mapping from sub-model name to that model's prediction.
            shape_list: Passed unchanged to ``DBPostProcess.__call__`` for
                every sub-model.

        Returns:
            dict: sub-model name -> result of ``DBPostProcess.__call__``.

        Raises:
            KeyError: If a configured model name (or ``key``) is missing
                from ``predicts``.
        """
        # NOTE(review): only ``forward`` is defined here; calling the instance
        # directly still dispatches to ``DBPostProcess.__call__`` — confirm
        # which entry point the callers use.
        results = {}
        for name in self.model_name:
            pred = predicts[name]
            if self.key is not None:
                pred = pred[self.key]
            # Fixed: the original passed the undefined name ``label`` here,
            # which raised NameError on every call.
            results[name] = super().__call__(pred, shape_list=shape_list)

        return results
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -116,6 +116,26 @@ def load_dygraph_params(config, model, logger, optimizer):
|
|||
logger.info(f"loaded pretrained_model successful from {pm}")
|
||||
return {}
|
||||
|
||||
def load_pretrained_params(model, path):
    """Load shape-compatible pretrained parameters from ``path`` into ``model``.

    Args:
        model: Object exposing ``state_dict()`` and ``set_state_dict()``.
        path: Path to the weights, with or without the ``.pdparams`` suffix.
            ``None`` means "nothing to load".

    Returns:
        bool: ``False`` when ``path`` is ``None`` or the ``.pdparams`` file
        does not exist; ``True`` after the parameters have been set.
    """
    if path is None:
        return False

    # Normalise before checking existence, so the check targets the file that
    # is actually loaded. The original accepted any existing ``path`` (for
    # example a directory) and then crashed inside ``paddle.load``.
    weights_path = path if path.endswith('.pdparams') else path + '.pdparams'
    if not os.path.exists(weights_path):
        print(f"The pretrained_model {path} does not exists!")
        return False

    params = paddle.load(weights_path)
    state_dict = model.state_dict()
    new_state_dict = {}
    # NOTE(review): parameters are paired by position, not by name, so this
    # relies on both dicts iterating in the same order — confirm that this is
    # intended for checkpoints whose key names differ from the model's.
    for k1, k2 in zip(state_dict.keys(), params.keys()):
        if list(state_dict[k1].shape) == list(params[k2].shape):
            new_state_dict[k1] = params[k2]
        else:
            # Mismatched shapes are reported and skipped, not loaded.
            print(
                f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
            )
    model.set_state_dict(new_state_dict)
    return True
|
||||
|
||||
def save_model(model,
|
||||
optimizer,
|
||||
|
|
|
@ -186,7 +186,10 @@ def train(config,
|
|||
model.train()
|
||||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
try:
|
||||
model_type = config['Architecture']['model_type']
|
||||
except:
|
||||
model_type = None
|
||||
|
||||
if 'start_epoch' in best_model_dict:
|
||||
start_epoch = best_model_dict['start_epoch']
|
||||
|
|
Loading…
Reference in New Issue