fix bug

parent a91bbd7432
commit 185d1e1f92
@@ -20,7 +20,7 @@ Architecture:
   algorithm: Distillation
   Models:
     Student:
-      pretrained:
+      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
       freeze_params: false
       return_all_feats: false
       model_type: det
@@ -37,7 +37,7 @@ Architecture:
         name: DBHead
         k: 50
     Student2:
-      pretrained:
+      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
       freeze_params: false
       return_all_feats: false
       model_type: det
@@ -55,6 +55,9 @@ Architecture:
         name: DBHead
         k: 50
     Teacher:
+      pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
+      freeze_params: true
+      return_all_feats: false
       model_type: det
       algorithm: DB
       Transform:
@@ -73,7 +76,9 @@ Loss:
   loss_config_list:
   - DistillationDilaDBLoss:
       weight: 1.0
-      model_name_list: ["Student", "Student2", "Teacher"]
+      model_name_pairs:
+      - ["Student", "Teacher"]
+      - ["Student2", "Teacher"]
       key: maps
       balance_loss: true
       main_loss_type: DiceLoss
@@ -81,13 +86,16 @@ Loss:
       beta: 10
       ohem_ratio: 3
   - DistillationDMLLoss:
+      model_name_pairs:
+      - ["Student", "Student2"]
       maps_name: ["thrink_maps"]
       weight: 1.0
       act: "softmax"
       model_name_pairs: ["Student", "Student2"]
       key: maps
   - DistillationDBLoss:
-      model_name_list: ["Student", "Teacher"]
+      weight: 1.0
+      model_name_list: ["Student", "Student2"]
       key: maps
       name: DBLoss
       balance_loss: true
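Note on the DistillationDilaDBLoss change above: the flat model_name_list is replaced by explicit (student, teacher) pairs in model_name_pairs, and each pair is compared on the tensor selected by key. Below is a minimal, self-contained sketch of how such pairs are typically walked; the dict layout and the placeholder distance are illustrative assumptions, not the real loss code.

predicts = {
    "Student": {"maps": [0.10, 0.20]},
    "Student2": {"maps": [0.12, 0.18]},
    "Teacher": {"maps": [0.11, 0.21]},
}
model_name_pairs = [["Student", "Teacher"], ["Student2", "Teacher"]]
key = "maps"

for student_name, teacher_name in model_name_pairs:
    s_out = predicts[student_name][key]
    t_out = predicts[teacher_name][key]
    # stand-in for the real dilation DB loss computed between the pair
    pair_loss = sum(abs(s - t) for s, t in zip(s_out, t_out))
    print(f"{student_name} vs {teacher_name}: {pair_loss:.3f}")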
@@ -110,7 +118,7 @@ Optimizer:
     factor: 0
 
 PostProcess:
-  name: DistillationCTDBPostProcessCLabelDecode
+  name: DistillationDBPostProcess
   model_name: ["Student", "Student2"]
   key: head_out
   thresh: 0.3
@@ -14,6 +14,8 @@
 
 import paddle
 import paddle.nn as nn
+import numpy as np
+import cv2
 
 from .rec_ctc_loss import CTCLoss
 from .basic_loss import DMLLoss
@@ -22,6 +24,7 @@ from .det_db_loss import DBLoss
 from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
 
 
+
 def _sum_loss(loss_dict):
     if "loss" in loss_dict.keys():
         return loss_dict
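The new numpy and cv2 imports support the map dilation that the dilation-based DB loss applies to the teacher's shrink map. The snippet below only illustrates that operation; the 2x2 all-ones kernel and the call site are assumptions based on the usual DB dilation trick, not lines from this diff.

import cv2
import numpy as np

teacher_map = np.zeros((8, 8), dtype=np.float32)
teacher_map[3, 3] = 1.0                              # a single positive pixel
kernel = np.array([[1, 1], [1, 1]], dtype=np.uint8)  # assumed dilation kernel
dilated = cv2.dilate(teacher_map, kernel)            # grows it into a 2x2 block
print(dilated.sum())                                 # 4.0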
@@ -50,7 +53,7 @@ class DistillationDMLLoss(DMLLoss):
         self.key = key
         self.model_name_pairs = model_name_pairs
         self.name = name
-        self.maps_name = self.maps_name
+        self.maps_name = maps_name
 
     def _check_maps_name(self, maps_name):
         if maps_name is None:
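The one-line fix above replaces a self-assignment: `self.maps_name = self.maps_name` reads an attribute that nothing has set yet, so the constructor argument was never stored. A toy reproduction (not the real class):

class Before:
    def __init__(self, maps_name=None):
        self.maps_name = self.maps_name  # fails: the attribute does not exist yet

class After:
    def __init__(self, maps_name=None):
        self.maps_name = maps_name       # the constructor argument is stored

print(After(maps_name=["thrink_maps"]).maps_name)  # ['thrink_maps']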
@ -172,6 +175,7 @@ class DistillationDBLoss(DBLoss):
|
||||||
class DistillationDilaDBLoss(DBLoss):
|
class DistillationDilaDBLoss(DBLoss):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_name_pairs=[],
|
model_name_pairs=[],
|
||||||
|
key=None,
|
||||||
balance_loss=True,
|
balance_loss=True,
|
||||||
main_loss_type='DiceLoss',
|
main_loss_type='DiceLoss',
|
||||||
alpha=5,
|
alpha=5,
|
||||||
|
@ -182,6 +186,7 @@ class DistillationDilaDBLoss(DBLoss):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model_name_pairs = model_name_pairs
|
self.model_name_pairs = model_name_pairs
|
||||||
self.name = name
|
self.name = name
|
||||||
|
self.key = key
|
||||||
|
|
||||||
def forward(self, predicts, batch):
|
def forward(self, predicts, batch):
|
||||||
loss_dict = dict()
|
loss_dict = dict()
|
||||||
|
@@ -219,7 +224,7 @@ class DistillationDilaDBLoss(DBLoss):
             loss_dict[k] = bce_loss + loss_binary_maps
 
         loss_dict = _sum_loss(loss_dict)
-        return loss
+        return loss_dict
 
 
 class DistillationDistanceLoss(DistanceLoss):
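The fix above makes forward() return the dict produced by _sum_loss. A sketch of that helper's behaviour; the hunk near the top of the file shows only its early-return branch, so the aggregation line below is an assumption:

def _sum_loss(loss_dict):
    if "loss" in loss_dict.keys():
        return loss_dict
    loss_dict["loss"] = sum(loss_dict.values())  # assumed aggregation branch
    return loss_dict

per_pair_losses = {"pair_0": 0.8, "pair_1": 0.7}  # hypothetical per-pair entries
print(_sum_loss(per_pair_losses))                 # adds 'loss': 1.5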
@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
 from ppocr.modeling.necks import build_neck
 from ppocr.modeling.heads import build_head
 from .base_model import BaseModel
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import init_model, load_pretrained_params
 
 __all__ = ['DistillationModel']
 
@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
             pretrained = model_config.pop("pretrained")
             model = BaseModel(model_config)
             if pretrained is not None:
-                init_model(model, path=pretrained)
+                load_pretrained_params(model, pretrained)
             if freeze_params:
                 for param in model.parameters():
                     param.trainable = False
@@ -21,7 +21,7 @@ import copy
 
 __all__ = ['build_post_process']
 
-from .db_postprocess import DBPostProcess
+from .db_postprocess import DBPostProcess, DistillationDBPostProcess
 from .east_postprocess import EASTPostProcess
 from .sast_postprocess import SASTPostProcess
 from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
@@ -34,7 +34,7 @@ def build_post_process(config, global_config=None):
     support_dict = [
         'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
         'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
-        'DistillationCTCLabelDecode', 'TableLabelDecode'
+        'DistillationCTCLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess'
     ]
 
     config = copy.deepcopy(config)
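With 'DistillationDBPostProcess' registered above, the PostProcess block from the config hunk earlier resolves to the new class. A usage sketch, assuming build_post_process follows its usual registry pattern of looking up config["name"] and passing the remaining keys as keyword arguments:

from ppocr.postprocess import build_post_process

post_process = build_post_process({
    "name": "DistillationDBPostProcess",
    "model_name": ["Student", "Student2"],
    "key": "head_out",
    "thresh": 0.3,
})
# post_process.forward(predicts, shape_list) then yields one result per listed sub-model.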
@@ -187,3 +187,44 @@ class DBPostProcess(object):
 
             boxes_batch.append({'points': boxes})
         return boxes_batch
+
+
+class DistillationDBPostProcess(DBPostProcess):
+    def __init__(self,
+                 model_name=["student"],
+                 key=None,
+                 thresh=0.3,
+                 box_thresh=0.7,
+                 max_candidates=1000,
+                 unclip_ratio=2.0,
+                 use_dilation=False,
+                 score_mode="fast",
+                 **kwargs):
+        super(DistillationDBPostProcess, self).__init__(thresh,
+                                                        box_thresh,
+                                                        max_candidates,
+                                                        unclip_ratio,
+                                                        use_dilation,
+                                                        score_mode)
+        if not isinstance(model_name, list):
+            model_name = [model_name]
+        self.model_name = model_name
+
+        self.key = key
+
+    def forward(self, predicts, shape_list):
+        results = {}
+        for name in self.model_name:
+            pred = predicts[name]
+            if self.key is not None:
+                pred = pred[self.key]
+            results[name] = super().__call__(pred, shape_list=shape_list)
+
+        return results
+
+
+
+
+
+
+
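The new class is a thin dispatcher: it runs the parent DBPostProcess once per configured sub-model and collects the results by name. A toy mock-up of that pattern with stand-in classes, so it runs without the real DB post-processing dependencies:

class ToyDBPostProcess:
    def __call__(self, pred, shape_list):
        return {"points": pred}              # stand-in for real box extraction

class ToyDistillationPostProcess(ToyDBPostProcess):
    def __init__(self, model_name, key=None):
        self.model_name = model_name if isinstance(model_name, list) else [model_name]
        self.key = key

    def forward(self, predicts, shape_list):
        results = {}
        for name in self.model_name:
            pred = predicts[name]
            if self.key is not None:
                pred = pred[self.key]        # e.g. key: head_out in the config
            results[name] = super().__call__(pred, shape_list=shape_list)
        return results

pp = ToyDistillationPostProcess(["Student", "Student2"], key="maps")
print(pp.forward({"Student": {"maps": [1, 2]}, "Student2": {"maps": [3, 4]}}, None))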
@@ -116,6 +116,26 @@ def load_dygraph_params(config, model, logger, optimizer):
         logger.info(f"loaded pretrained_model successful from {pm}")
         return {}
 
+def load_pretrained_params(model, path):
+    if path is None:
+        return False
+    if not os.path.exists(path) and not os.path.exists(path + ".pdparams"):
+        print(f"The pretrained_model {path} does not exist!")
+        return False
+
+    path = path if path.endswith('.pdparams') else path + '.pdparams'
+    params = paddle.load(path)
+    state_dict = model.state_dict()
+    new_state_dict = {}
+    for k1, k2 in zip(state_dict.keys(), params.keys()):
+        if list(state_dict[k1].shape) == list(params[k2].shape):
+            new_state_dict[k1] = params[k2]
+        else:
+            print(
+                f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
+            )
+    model.set_state_dict(new_state_dict)
+    return True
 
 def save_model(model,
                optimizer,
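A minimal usage sketch for the helper added above. The checkpoint path is a placeholder: load_pretrained_params appends ".pdparams" itself when needed, copies only tensors whose shapes match the model, and returns False when the path is missing.

import paddle
from ppocr.utils.save_load import load_pretrained_params

model = paddle.nn.Linear(10, 2)   # stand-in for a real detection model
ok = load_pretrained_params(model, "./pretrain_models/some_checkpoint")  # hypothetical path
print("loaded:", ok)              # False here, since the placeholder path does not exist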
@@ -186,7 +186,10 @@ def train(config,
     model.train()
 
     use_srn = config['Architecture']['algorithm'] == "SRN"
-    model_type = config['Architecture']['model_type']
+    try:
+        model_type = config['Architecture']['model_type']
+    except:
+        model_type = None
 
     if 'start_epoch' in best_model_dict:
         start_epoch = best_model_dict['start_epoch']
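The try/except above lets configs without a top-level Architecture.model_type, such as the distillation config in this commit, fall back to None. For the missing-key case this behaves like a plain dict lookup with a default:

config = {"Architecture": {"algorithm": "Distillation"}}  # no model_type key
model_type = config["Architecture"].get("model_type", None)
print(model_type)                                         # None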