add distillation function
This commit is contained in:
parent
551a6827f0
commit
ed02b91d26
|
@ -13,9 +13,9 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
|
||||||
|
|
||||||
def build_loss(config):
|
|
||||||
# det loss
|
# det loss
|
||||||
from .det_db_loss import DBLoss
|
from .det_db_loss import DBLoss
|
||||||
from .det_east_loss import EASTLoss
|
from .det_east_loss import EASTLoss
|
||||||
|
@ -31,10 +31,19 @@ def build_loss(config):
|
||||||
|
|
||||||
# e2e loss
|
# e2e loss
|
||||||
from .e2e_pg_loss import PGLoss
|
from .e2e_pg_loss import PGLoss
|
||||||
|
|
||||||
|
# basic loss function
|
||||||
|
from .basic_loss import DistanceLoss
|
||||||
|
|
||||||
|
# combined loss function
|
||||||
|
from .combined_loss import CombinedLoss
|
||||||
|
|
||||||
|
|
||||||
|
def build_loss(config):
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||||
'SRNLoss', 'PGLoss']
|
'SRNLoss', 'PGLoss', 'CombinedLoss'
|
||||||
|
]
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
assert module_name in support_dict, Exception('loss only support {}'.format(
|
assert module_name in support_dict, Exception('loss only support {}'.format(
|
||||||
|
|
|
@ -0,0 +1,101 @@
|
||||||
|
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
#you may not use this file except in compliance with the License.
|
||||||
|
#You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
#Unless required by applicable law or agreed to in writing, software
|
||||||
|
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
#See the License for the specific language governing permissions and
|
||||||
|
#limitations under the License.
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
|
||||||
|
from paddle.nn import L1Loss
|
||||||
|
from paddle.nn import MSELoss as L2Loss
|
||||||
|
from paddle.nn import SmoothL1Loss
|
||||||
|
|
||||||
|
|
||||||
|
class CELoss(nn.Layer):
|
||||||
|
def __init__(self, name="loss_ce", epsilon=None):
|
||||||
|
super().__init__()
|
||||||
|
self.name = name
|
||||||
|
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
|
||||||
|
epsilon = None
|
||||||
|
self.epsilon = epsilon
|
||||||
|
|
||||||
|
def _labelsmoothing(self, target, class_num):
|
||||||
|
if target.shape[-1] != class_num:
|
||||||
|
one_hot_target = F.one_hot(target, class_num)
|
||||||
|
else:
|
||||||
|
one_hot_target = target
|
||||||
|
soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
|
||||||
|
soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
|
||||||
|
return soft_target
|
||||||
|
|
||||||
|
def forward(self, x, label):
|
||||||
|
loss_dict = {}
|
||||||
|
if self.epsilon is not None:
|
||||||
|
class_num = x.shape[-1]
|
||||||
|
label = self._labelsmoothing(label, class_num)
|
||||||
|
x = -F.log_softmax(x, axis=-1)
|
||||||
|
loss = paddle.sum(x * label, axis=-1)
|
||||||
|
else:
|
||||||
|
if label.shape[-1] == x.shape[-1]:
|
||||||
|
label = F.softmax(label, axis=-1)
|
||||||
|
soft_label = True
|
||||||
|
else:
|
||||||
|
soft_label = False
|
||||||
|
loss = F.cross_entropy(x, label=label, soft_label=soft_label)
|
||||||
|
|
||||||
|
loss_dict[self.name] = paddle.mean(loss)
|
||||||
|
return loss_dict
|
||||||
|
|
||||||
|
|
||||||
|
class DMLLoss(nn.Layer):
|
||||||
|
"""
|
||||||
|
DMLLoss
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name="loss_dml"):
|
||||||
|
super().__init__()
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def forward(self, out1, out2):
|
||||||
|
loss_dict = {}
|
||||||
|
soft_out1 = F.softmax(out1, axis=-1)
|
||||||
|
log_soft_out1 = paddle.log(soft_out1)
|
||||||
|
soft_out2 = F.softmax(out2, axis=-1)
|
||||||
|
log_soft_out2 = paddle.log(soft_out2)
|
||||||
|
loss = (F.kl_div(
|
||||||
|
log_soft_out1, soft_out2, reduction='batchmean') + F.kl_div(
|
||||||
|
log_soft_out2, soft_out1, reduction='batchmean')) / 2.0
|
||||||
|
loss_dict[self.name] = loss
|
||||||
|
return loss_dict
|
||||||
|
|
||||||
|
|
||||||
|
class DistanceLoss(nn.Layer):
|
||||||
|
"""
|
||||||
|
DistanceLoss:
|
||||||
|
mode: loss mode
|
||||||
|
name: loss key in the output dict
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mode="l2", name="loss_dist", **kargs):
|
||||||
|
assert mode in ["l1", "l2", "smooth_l1"]
|
||||||
|
if mode == "l1":
|
||||||
|
self.loss_func = nn.L1Loss(**kargs)
|
||||||
|
elif mode == "l1":
|
||||||
|
self.loss_func = nn.MSELoss(**kargs)
|
||||||
|
elif mode == "smooth_l1":
|
||||||
|
self.loss_func = nn.SmoothL1Loss(**kargs)
|
||||||
|
|
||||||
|
self.name = "{}_{}".format(name, mode)
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
return {self.name: self.loss_func(x, y)}
|
|
@ -24,7 +24,7 @@ class ClsLoss(nn.Layer):
|
||||||
super(ClsLoss, self).__init__()
|
super(ClsLoss, self).__init__()
|
||||||
self.loss_func = nn.CrossEntropyLoss(reduction='mean')
|
self.loss_func = nn.CrossEntropyLoss(reduction='mean')
|
||||||
|
|
||||||
def __call__(self, predicts, batch):
|
def forward(self, predicts, batch):
|
||||||
label = batch[1]
|
label = batch[1]
|
||||||
loss = self.loss_func(input=predicts, label=label)
|
loss = self.loss_func(input=predicts, label=label)
|
||||||
return {'loss': loss}
|
return {'loss': loss}
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
|
||||||
|
from .distillation_loss import DistillationCTCLoss
|
||||||
|
from .distillation_loss import DistillationDMLLoss
|
||||||
|
|
||||||
|
|
||||||
|
class CombinedLoss(nn.Layer):
|
||||||
|
"""
|
||||||
|
CombinedLoss:
|
||||||
|
a combionation of loss function
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, loss_config_list=None):
|
||||||
|
super().__init__()
|
||||||
|
self.loss_func = []
|
||||||
|
self.loss_weight = []
|
||||||
|
assert isinstance(loss_config_list, list), (
|
||||||
|
'operator config should be a list')
|
||||||
|
for config in loss_config_list:
|
||||||
|
assert isinstance(config,
|
||||||
|
dict) and len(config) == 1, "yaml format error"
|
||||||
|
name = list(config)[0]
|
||||||
|
param = config[name]
|
||||||
|
assert "weight" in param, "weight must be in param, but param just contains {}".format(
|
||||||
|
param.keys())
|
||||||
|
self.loss_weight.append(param.pop("weight"))
|
||||||
|
self.loss_func.append(eval(name)(**param))
|
||||||
|
|
||||||
|
def forward(self, input, batch, **kargs):
|
||||||
|
loss_dict = {}
|
||||||
|
for idx, loss_func in enumerate(self.loss_func):
|
||||||
|
loss = loss_func(input, batch, **kargs)
|
||||||
|
if isinstance(loss, paddle.Tensor):
|
||||||
|
loss = {"loss_{}_{}".format(str(loss), idx): loss}
|
||||||
|
weight = self.loss_weight[idx]
|
||||||
|
loss = {
|
||||||
|
"{}_{}".format(key, idx): loss[key] * weight
|
||||||
|
for key in loss
|
||||||
|
}
|
||||||
|
loss_dict.update(loss)
|
||||||
|
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
|
||||||
|
return loss_dict
|
|
@ -0,0 +1,76 @@
|
||||||
|
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
#you may not use this file except in compliance with the License.
|
||||||
|
#You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
#Unless required by applicable law or agreed to in writing, software
|
||||||
|
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
#See the License for the specific language governing permissions and
|
||||||
|
#limitations under the License.
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
|
||||||
|
from .rec_ctc_loss import CTCLoss
|
||||||
|
from .basic_loss import DMLLoss
|
||||||
|
|
||||||
|
|
||||||
|
class DistillationDMLLoss(DMLLoss):
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model_name_list1=[],
|
||||||
|
model_name_list2=[],
|
||||||
|
key=None,
|
||||||
|
name="loss_dml"):
|
||||||
|
super().__init__(name=name)
|
||||||
|
if not isinstance(model_name_list1, (list, )):
|
||||||
|
model_name_list1 = [model_name_list1]
|
||||||
|
if not isinstance(model_name_list2, (list, )):
|
||||||
|
model_name_list2 = [model_name_list2]
|
||||||
|
|
||||||
|
assert len(model_name_list1) == len(model_name_list2)
|
||||||
|
self.model_name_list1 = model_name_list1
|
||||||
|
self.model_name_list2 = model_name_list2
|
||||||
|
self.key = key
|
||||||
|
|
||||||
|
def forward(self, predicts, batch):
|
||||||
|
loss_dict = dict()
|
||||||
|
for idx in range(len(self.model_name_list1)):
|
||||||
|
out1 = predicts[self.model_name_list1[idx]]
|
||||||
|
out2 = predicts[self.model_name_list2[idx]]
|
||||||
|
if self.key is not None:
|
||||||
|
out1 = out1[self.key]
|
||||||
|
out2 = out2[self.key]
|
||||||
|
loss = super().forward(out1, out2)
|
||||||
|
if isinstance(loss, dict):
|
||||||
|
assert len(loss) == 1
|
||||||
|
loss = list(loss.values())[0]
|
||||||
|
loss_dict["{}_{}".format(self.name, idx)] = loss
|
||||||
|
return loss_dict
|
||||||
|
|
||||||
|
|
||||||
|
class DistillationCTCLoss(CTCLoss):
|
||||||
|
def __init__(self, model_name_list=[], key=None, name="loss_ctc"):
|
||||||
|
super().__init__()
|
||||||
|
self.model_name_list = model_name_list
|
||||||
|
self.key = key
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def forward(self, predicts, batch):
|
||||||
|
loss_dict = dict()
|
||||||
|
for model_name in self.model_name_list:
|
||||||
|
out = predicts[model_name]
|
||||||
|
if self.key is not None:
|
||||||
|
out = out[self.key]
|
||||||
|
loss = super().forward(out, batch)
|
||||||
|
if isinstance(loss, dict):
|
||||||
|
assert len(loss) == 1
|
||||||
|
loss = list(loss.values())[0]
|
||||||
|
loss_dict["{}_{}".format(self.name, model_name)] = loss
|
||||||
|
return loss_dict
|
|
@ -25,7 +25,7 @@ class CTCLoss(nn.Layer):
|
||||||
super(CTCLoss, self).__init__()
|
super(CTCLoss, self).__init__()
|
||||||
self.loss_func = nn.CTCLoss(blank=0, reduction='none')
|
self.loss_func = nn.CTCLoss(blank=0, reduction='none')
|
||||||
|
|
||||||
def __call__(self, predicts, batch):
|
def forward(self, predicts, batch):
|
||||||
predicts = predicts.transpose((1, 0, 2))
|
predicts = predicts.transpose((1, 0, 2))
|
||||||
N, B, _ = predicts.shape
|
N, B, _ = predicts.shape
|
||||||
preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
|
preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
|
||||||
|
|
|
@ -13,12 +13,20 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from .base_model import BaseModel
|
||||||
|
from .distillation_model import DistillationModel
|
||||||
|
|
||||||
__all__ = ['build_model']
|
__all__ = ['build_model']
|
||||||
|
|
||||||
def build_model(config):
|
|
||||||
from .base_model import BaseModel
|
|
||||||
|
|
||||||
|
def build_model(config):
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
module_class = BaseModel(config)
|
if not "name" in config:
|
||||||
return module_class
|
arch = BaseModel(config)
|
||||||
|
else:
|
||||||
|
name = config.pop("name")
|
||||||
|
mod = importlib.import_module(__name__)
|
||||||
|
arch = getattr(mod, name)(config)
|
||||||
|
return arch
|
||||||
|
|
|
@ -32,7 +32,6 @@ class BaseModel(nn.Layer):
|
||||||
config (dict): the super parameters for module.
|
config (dict): the super parameters for module.
|
||||||
"""
|
"""
|
||||||
super(BaseModel, self).__init__()
|
super(BaseModel, self).__init__()
|
||||||
|
|
||||||
in_channels = config.get('in_channels', 3)
|
in_channels = config.get('in_channels', 3)
|
||||||
model_type = config['model_type']
|
model_type = config['model_type']
|
||||||
# build transfrom,
|
# build transfrom,
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from paddle import nn
|
||||||
|
from ppocr.modeling.transforms import build_transform
|
||||||
|
from ppocr.modeling.backbones import build_backbone
|
||||||
|
from ppocr.modeling.necks import build_neck
|
||||||
|
from ppocr.modeling.heads import build_head
|
||||||
|
from .base_model import BaseModel
|
||||||
|
from ppocr.utils.save_load import load_dygraph_pretrain
|
||||||
|
|
||||||
|
__all__ = ['DistillationModel']
|
||||||
|
|
||||||
|
|
||||||
|
class DistillationModel(nn.Layer):
|
||||||
|
def __init__(self, config):
|
||||||
|
"""
|
||||||
|
the module for OCR distillation.
|
||||||
|
args:
|
||||||
|
config (dict): the super parameters for module.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
freeze_params = config["freeze_params"]
|
||||||
|
pretrained = config["pretrained"]
|
||||||
|
if not isinstance(freeze_params, list):
|
||||||
|
freeze_params = [freeze_params]
|
||||||
|
assert len(config["Models"]) == len(freeze_params)
|
||||||
|
|
||||||
|
if not isinstance(pretrained, list):
|
||||||
|
pretrained = [pretrained] * len(config["Models"])
|
||||||
|
assert len(config["Models"]) == len(pretrained)
|
||||||
|
|
||||||
|
self.model_dict = dict()
|
||||||
|
index = 0
|
||||||
|
for key in config["Models"]:
|
||||||
|
model_config = config["Models"][key]
|
||||||
|
model = BaseModel(model_config)
|
||||||
|
if pretrained[index] is not None:
|
||||||
|
load_dygraph_pretrain(model, path=pretrained[index])
|
||||||
|
if freeze_params[index]:
|
||||||
|
for param in model.parameters():
|
||||||
|
param.trainable = False
|
||||||
|
self.model_dict[key] = self.add_sublayer(key, model)
|
||||||
|
index += 1
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
result_dict = dict()
|
||||||
|
for key in self.model_dict:
|
||||||
|
result_dict[key] = self.model_dict[key](x)
|
||||||
|
return result_dict
|
|
@ -102,8 +102,7 @@ class MobileNetV3(nn.Layer):
|
||||||
padding=1,
|
padding=1,
|
||||||
groups=1,
|
groups=1,
|
||||||
if_act=True,
|
if_act=True,
|
||||||
act='hardswish',
|
act='hardswish')
|
||||||
name='conv1')
|
|
||||||
|
|
||||||
self.stages = []
|
self.stages = []
|
||||||
self.out_channels = []
|
self.out_channels = []
|
||||||
|
@ -125,8 +124,7 @@ class MobileNetV3(nn.Layer):
|
||||||
kernel_size=k,
|
kernel_size=k,
|
||||||
stride=s,
|
stride=s,
|
||||||
use_se=se,
|
use_se=se,
|
||||||
act=nl,
|
act=nl))
|
||||||
name="conv" + str(i + 2)))
|
|
||||||
inplanes = make_divisible(scale * c)
|
inplanes = make_divisible(scale * c)
|
||||||
i += 1
|
i += 1
|
||||||
block_list.append(
|
block_list.append(
|
||||||
|
@ -138,8 +136,7 @@ class MobileNetV3(nn.Layer):
|
||||||
padding=0,
|
padding=0,
|
||||||
groups=1,
|
groups=1,
|
||||||
if_act=True,
|
if_act=True,
|
||||||
act='hardswish',
|
act='hardswish'))
|
||||||
name='conv_last'))
|
|
||||||
self.stages.append(nn.Sequential(*block_list))
|
self.stages.append(nn.Sequential(*block_list))
|
||||||
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
|
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
|
||||||
for i, stage in enumerate(self.stages):
|
for i, stage in enumerate(self.stages):
|
||||||
|
@ -163,8 +160,7 @@ class ConvBNLayer(nn.Layer):
|
||||||
padding,
|
padding,
|
||||||
groups=1,
|
groups=1,
|
||||||
if_act=True,
|
if_act=True,
|
||||||
act=None,
|
act=None):
|
||||||
name=None):
|
|
||||||
super(ConvBNLayer, self).__init__()
|
super(ConvBNLayer, self).__init__()
|
||||||
self.if_act = if_act
|
self.if_act = if_act
|
||||||
self.act = act
|
self.act = act
|
||||||
|
@ -175,16 +171,9 @@ class ConvBNLayer(nn.Layer):
|
||||||
stride=stride,
|
stride=stride,
|
||||||
padding=padding,
|
padding=padding,
|
||||||
groups=groups,
|
groups=groups,
|
||||||
weight_attr=ParamAttr(name=name + '_weights'),
|
|
||||||
bias_attr=False)
|
bias_attr=False)
|
||||||
|
|
||||||
self.bn = nn.BatchNorm(
|
self.bn = nn.BatchNorm(num_channels=out_channels, act=None)
|
||||||
num_channels=out_channels,
|
|
||||||
act=None,
|
|
||||||
param_attr=ParamAttr(name=name + "_bn_scale"),
|
|
||||||
bias_attr=ParamAttr(name=name + "_bn_offset"),
|
|
||||||
moving_mean_name=name + "_bn_mean",
|
|
||||||
moving_variance_name=name + "_bn_variance")
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv(x)
|
x = self.conv(x)
|
||||||
|
@ -209,8 +198,7 @@ class ResidualUnit(nn.Layer):
|
||||||
kernel_size,
|
kernel_size,
|
||||||
stride,
|
stride,
|
||||||
use_se,
|
use_se,
|
||||||
act=None,
|
act=None):
|
||||||
name=''):
|
|
||||||
super(ResidualUnit, self).__init__()
|
super(ResidualUnit, self).__init__()
|
||||||
self.if_shortcut = stride == 1 and in_channels == out_channels
|
self.if_shortcut = stride == 1 and in_channels == out_channels
|
||||||
self.if_se = use_se
|
self.if_se = use_se
|
||||||
|
@ -222,8 +210,7 @@ class ResidualUnit(nn.Layer):
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0,
|
padding=0,
|
||||||
if_act=True,
|
if_act=True,
|
||||||
act=act,
|
act=act)
|
||||||
name=name + "_expand")
|
|
||||||
self.bottleneck_conv = ConvBNLayer(
|
self.bottleneck_conv = ConvBNLayer(
|
||||||
in_channels=mid_channels,
|
in_channels=mid_channels,
|
||||||
out_channels=mid_channels,
|
out_channels=mid_channels,
|
||||||
|
@ -232,10 +219,9 @@ class ResidualUnit(nn.Layer):
|
||||||
padding=int((kernel_size - 1) // 2),
|
padding=int((kernel_size - 1) // 2),
|
||||||
groups=mid_channels,
|
groups=mid_channels,
|
||||||
if_act=True,
|
if_act=True,
|
||||||
act=act,
|
act=act)
|
||||||
name=name + "_depthwise")
|
|
||||||
if self.if_se:
|
if self.if_se:
|
||||||
self.mid_se = SEModule(mid_channels, name=name + "_se")
|
self.mid_se = SEModule(mid_channels)
|
||||||
self.linear_conv = ConvBNLayer(
|
self.linear_conv = ConvBNLayer(
|
||||||
in_channels=mid_channels,
|
in_channels=mid_channels,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
|
@ -243,8 +229,7 @@ class ResidualUnit(nn.Layer):
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0,
|
padding=0,
|
||||||
if_act=False,
|
if_act=False,
|
||||||
act=None,
|
act=None)
|
||||||
name=name + "_linear")
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
x = self.expand_conv(inputs)
|
x = self.expand_conv(inputs)
|
||||||
|
@ -258,7 +243,7 @@ class ResidualUnit(nn.Layer):
|
||||||
|
|
||||||
|
|
||||||
class SEModule(nn.Layer):
|
class SEModule(nn.Layer):
|
||||||
def __init__(self, in_channels, reduction=4, name=""):
|
def __init__(self, in_channels, reduction=4):
|
||||||
super(SEModule, self).__init__()
|
super(SEModule, self).__init__()
|
||||||
self.avg_pool = nn.AdaptiveAvgPool2D(1)
|
self.avg_pool = nn.AdaptiveAvgPool2D(1)
|
||||||
self.conv1 = nn.Conv2D(
|
self.conv1 = nn.Conv2D(
|
||||||
|
@ -266,17 +251,13 @@ class SEModule(nn.Layer):
|
||||||
out_channels=in_channels // reduction,
|
out_channels=in_channels // reduction,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0,
|
padding=0)
|
||||||
weight_attr=ParamAttr(name=name + "_1_weights"),
|
|
||||||
bias_attr=ParamAttr(name=name + "_1_offset"))
|
|
||||||
self.conv2 = nn.Conv2D(
|
self.conv2 = nn.Conv2D(
|
||||||
in_channels=in_channels // reduction,
|
in_channels=in_channels // reduction,
|
||||||
out_channels=in_channels,
|
out_channels=in_channels,
|
||||||
kernel_size=1,
|
kernel_size=1,
|
||||||
stride=1,
|
stride=1,
|
||||||
padding=0,
|
padding=0)
|
||||||
weight_attr=ParamAttr(name + "_2_weights"),
|
|
||||||
bias_attr=ParamAttr(name=name + "_2_offset"))
|
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
outputs = self.avg_pool(inputs)
|
outputs = self.avg_pool(inputs)
|
||||||
|
|
|
@ -96,8 +96,7 @@ class MobileNetV3(nn.Layer):
|
||||||
padding=1,
|
padding=1,
|
||||||
groups=1,
|
groups=1,
|
||||||
if_act=True,
|
if_act=True,
|
||||||
act='hardswish',
|
act='hardswish')
|
||||||
name='conv1')
|
|
||||||
i = 0
|
i = 0
|
||||||
block_list = []
|
block_list = []
|
||||||
inplanes = make_divisible(inplanes * scale)
|
inplanes = make_divisible(inplanes * scale)
|
||||||
|
@ -110,8 +109,7 @@ class MobileNetV3(nn.Layer):
|
||||||
kernel_size=k,
|
kernel_size=k,
|
||||||
stride=s,
|
stride=s,
|
||||||
use_se=se,
|
use_se=se,
|
||||||
act=nl,
|
act=nl))
|
||||||
name='conv' + str(i + 2)))
|
|
||||||
inplanes = make_divisible(scale * c)
|
inplanes = make_divisible(scale * c)
|
||||||
i += 1
|
i += 1
|
||||||
self.blocks = nn.Sequential(*block_list)
|
self.blocks = nn.Sequential(*block_list)
|
||||||
|
@ -124,8 +122,7 @@ class MobileNetV3(nn.Layer):
|
||||||
padding=0,
|
padding=0,
|
||||||
groups=1,
|
groups=1,
|
||||||
if_act=True,
|
if_act=True,
|
||||||
act='hardswish',
|
act='hardswish')
|
||||||
name='conv_last')
|
|
||||||
|
|
||||||
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
|
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
|
||||||
self.out_channels = make_divisible(scale * cls_ch_squeeze)
|
self.out_channels = make_divisible(scale * cls_ch_squeeze)
|
||||||
|
|
|
@ -23,14 +23,12 @@ from paddle import ParamAttr, nn
|
||||||
from paddle.nn import functional as F
|
from paddle.nn import functional as F
|
||||||
|
|
||||||
|
|
||||||
def get_para_bias_attr(l2_decay, k, name):
|
def get_para_bias_attr(l2_decay, k):
|
||||||
regularizer = paddle.regularizer.L2Decay(l2_decay)
|
regularizer = paddle.regularizer.L2Decay(l2_decay)
|
||||||
stdv = 1.0 / math.sqrt(k * 1.0)
|
stdv = 1.0 / math.sqrt(k * 1.0)
|
||||||
initializer = nn.initializer.Uniform(-stdv, stdv)
|
initializer = nn.initializer.Uniform(-stdv, stdv)
|
||||||
weight_attr = ParamAttr(
|
weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
|
||||||
regularizer=regularizer, initializer=initializer, name=name + "_w_attr")
|
bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
|
||||||
bias_attr = ParamAttr(
|
|
||||||
regularizer=regularizer, initializer=initializer, name=name + "_b_attr")
|
|
||||||
return [weight_attr, bias_attr]
|
return [weight_attr, bias_attr]
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,13 +36,12 @@ class CTCHead(nn.Layer):
|
||||||
def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
|
def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
|
||||||
super(CTCHead, self).__init__()
|
super(CTCHead, self).__init__()
|
||||||
weight_attr, bias_attr = get_para_bias_attr(
|
weight_attr, bias_attr = get_para_bias_attr(
|
||||||
l2_decay=fc_decay, k=in_channels, name='ctc_fc')
|
l2_decay=fc_decay, k=in_channels)
|
||||||
self.fc = nn.Linear(
|
self.fc = nn.Linear(
|
||||||
in_channels,
|
in_channels,
|
||||||
out_channels,
|
out_channels,
|
||||||
weight_attr=weight_attr,
|
weight_attr=weight_attr,
|
||||||
bias_attr=bias_attr,
|
bias_attr=bias_attr)
|
||||||
name='ctc_fc')
|
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
|
||||||
def forward(self, x, labels=None):
|
def forward(self, x, labels=None):
|
||||||
|
|
|
@ -21,18 +21,19 @@ import copy
|
||||||
|
|
||||||
__all__ = ['build_post_process']
|
__all__ = ['build_post_process']
|
||||||
|
|
||||||
|
|
||||||
def build_post_process(config, global_config=None):
|
|
||||||
from .db_postprocess import DBPostProcess
|
from .db_postprocess import DBPostProcess
|
||||||
from .east_postprocess import EASTPostProcess
|
from .east_postprocess import EASTPostProcess
|
||||||
from .sast_postprocess import SASTPostProcess
|
from .sast_postprocess import SASTPostProcess
|
||||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
|
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode
|
||||||
from .cls_postprocess import ClsPostProcess
|
from .cls_postprocess import ClsPostProcess
|
||||||
from .pg_postprocess import PGPostProcess
|
from .pg_postprocess import PGPostProcess
|
||||||
|
|
||||||
|
|
||||||
|
def build_post_process(config, global_config=None):
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess'
|
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||||
|
'DistillationCTCLabelDecode'
|
||||||
]
|
]
|
||||||
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|
|
@ -125,6 +125,31 @@ class CTCLabelDecode(BaseRecLabelDecode):
|
||||||
return dict_character
|
return dict_character
|
||||||
|
|
||||||
|
|
||||||
|
class DistillationCTCLabelDecode(CTCLabelDecode):
|
||||||
|
"""
|
||||||
|
Convert
|
||||||
|
Convert between text-label and text-index
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
character_dict_path=None,
|
||||||
|
character_type='ch',
|
||||||
|
use_space_char=False,
|
||||||
|
model_name="student",
|
||||||
|
key_out=None,
|
||||||
|
**kwargs):
|
||||||
|
super(DistillationCTCLabelDecode, self).__init__(
|
||||||
|
character_dict_path, character_type, use_space_char)
|
||||||
|
self.model_name = model_name
|
||||||
|
self.key_out = key_out
|
||||||
|
|
||||||
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
|
pred = preds[self.model_name]
|
||||||
|
if self.key_out is not None:
|
||||||
|
pred = pred[self.key_out]
|
||||||
|
return super().__call__(pred, label=label, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class AttnLabelDecode(BaseRecLabelDecode):
|
class AttnLabelDecode(BaseRecLabelDecode):
|
||||||
""" Convert between text-label and text-index """
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
|
|
|
@ -42,7 +42,10 @@ def _mkdir_if_not_exist(path, logger):
|
||||||
raise OSError('Failed to mkdir {}'.format(path))
|
raise OSError('Failed to mkdir {}'.format(path))
|
||||||
|
|
||||||
|
|
||||||
def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
|
def load_dygraph_pretrain(model,
|
||||||
|
logger=None,
|
||||||
|
path=None,
|
||||||
|
load_static_weights=False):
|
||||||
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
|
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
|
||||||
raise ValueError("Model pretrain path {} does not "
|
raise ValueError("Model pretrain path {} does not "
|
||||||
"exists.".format(path))
|
"exists.".format(path))
|
||||||
|
|
|
@ -386,7 +386,7 @@ def preprocess(is_train=False):
|
||||||
alg = config['Architecture']['algorithm']
|
alg = config['Architecture']['algorithm']
|
||||||
assert alg in [
|
assert alg in [
|
||||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||||
'CLS', 'PGNet'
|
'CLS', 'PGNet', 'Distillation'
|
||||||
]
|
]
|
||||||
|
|
||||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||||
|
|
|
@ -72,7 +72,14 @@ def main(config, device, logger, vdl_writer):
|
||||||
# for rec algorithm
|
# for rec algorithm
|
||||||
if hasattr(post_process_class, 'character'):
|
if hasattr(post_process_class, 'character'):
|
||||||
char_num = len(getattr(post_process_class, 'character'))
|
char_num = len(getattr(post_process_class, 'character'))
|
||||||
|
if config['Architecture']["algorithm"] in ["Distillation",
|
||||||
|
]: # distillation model
|
||||||
|
for key in config['Architecture']["Models"]:
|
||||||
|
config['Architecture']["Models"][key]["Head"][
|
||||||
|
'out_channels'] = char_num
|
||||||
|
else: # base rec model
|
||||||
config['Architecture']["Head"]['out_channels'] = char_num
|
config['Architecture']["Head"]['out_channels'] = char_num
|
||||||
|
|
||||||
model = build_model(config['Architecture'])
|
model = build_model(config['Architecture'])
|
||||||
if config['Global']['distributed']:
|
if config['Global']['distributed']:
|
||||||
model = paddle.DataParallel(model)
|
model = paddle.DataParallel(model)
|
||||||
|
|
Loading…
Reference in New Issue