From ed02b91d26048537261f062c75efa027c64690c7 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Wed, 2 Jun 2021 08:31:57 +0000 Subject: [PATCH] add distillation function --- ppocr/losses/__init__.py | 43 +++++--- ppocr/losses/basic_loss.py | 101 ++++++++++++++++++ ppocr/losses/cls_loss.py | 2 +- ppocr/losses/combined_loss.py | 57 ++++++++++ ppocr/losses/distillation_loss.py | 76 +++++++++++++ ppocr/losses/rec_ctc_loss.py | 2 +- ppocr/modeling/architectures/__init__.py | 16 ++- ppocr/modeling/architectures/base_model.py | 1 - .../architectures/distillation_model.py | 65 +++++++++++ ppocr/modeling/backbones/det_mobilenet_v3.py | 45 +++----- ppocr/modeling/backbones/rec_mobilenet_v3.py | 9 +- ppocr/modeling/heads/rec_ctc_head.py | 13 +-- ppocr/postprocess/__init__.py | 17 +-- ppocr/postprocess/rec_postprocess.py | 25 +++++ ppocr/utils/save_load.py | 5 +- tools/program.py | 2 +- tools/train.py | 9 +- 17 files changed, 407 insertions(+), 81 deletions(-) create mode 100644 ppocr/losses/basic_loss.py create mode 100644 ppocr/losses/combined_loss.py create mode 100644 ppocr/losses/distillation_loss.py create mode 100644 ppocr/modeling/architectures/distillation_model.py diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 223ae6b1..bf10d298 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -13,28 +13,37 @@ # limitations under the License. import copy +import paddle +import paddle.nn as nn + +# det loss +from .det_db_loss import DBLoss +from .det_east_loss import EASTLoss +from .det_sast_loss import SASTLoss + +# rec loss +from .rec_ctc_loss import CTCLoss +from .rec_att_loss import AttentionLoss +from .rec_srn_loss import SRNLoss + +# cls loss +from .cls_loss import ClsLoss + +# e2e loss +from .e2e_pg_loss import PGLoss + +# basic loss function +from .basic_loss import DistanceLoss + +# combined loss function +from .combined_loss import CombinedLoss def build_loss(config): - # det loss - from .det_db_loss import DBLoss - from .det_east_loss import EASTLoss - from .det_sast_loss import SASTLoss - - # rec loss - from .rec_ctc_loss import CTCLoss - from .rec_att_loss import AttentionLoss - from .rec_srn_loss import SRNLoss - - # cls loss - from .cls_loss import ClsLoss - - # e2e loss - from .e2e_pg_loss import PGLoss support_dict = [ 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', - 'SRNLoss', 'PGLoss'] - + 'SRNLoss', 'PGLoss', 'CombinedLoss' + ] config = copy.deepcopy(config) module_name = config.pop('name') assert module_name in support_dict, Exception('loss only support {}'.format( diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py new file mode 100644 index 00000000..3321827b --- /dev/null +++ b/ppocr/losses/basic_loss.py @@ -0,0 +1,101 @@ +#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. 
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+from paddle.nn import L1Loss
+from paddle.nn import MSELoss as L2Loss
+from paddle.nn import SmoothL1Loss
+
+
+class CELoss(nn.Layer):
+    def __init__(self, name="loss_ce", epsilon=None):
+        super().__init__()
+        self.name = name
+        if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
+            epsilon = None
+        self.epsilon = epsilon
+
+    def _labelsmoothing(self, target, class_num):
+        if target.shape[-1] != class_num:
+            one_hot_target = F.one_hot(target, class_num)
+        else:
+            one_hot_target = target
+        soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
+        soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
+        return soft_target
+
+    def forward(self, x, label):
+        loss_dict = {}
+        if self.epsilon is not None:
+            class_num = x.shape[-1]
+            label = self._labelsmoothing(label, class_num)
+            x = -F.log_softmax(x, axis=-1)
+            loss = paddle.sum(x * label, axis=-1)
+        else:
+            if label.shape[-1] == x.shape[-1]:
+                label = F.softmax(label, axis=-1)
+                soft_label = True
+            else:
+                soft_label = False
+            loss = F.cross_entropy(x, label=label, soft_label=soft_label)
+
+        loss_dict[self.name] = paddle.mean(loss)
+        return loss_dict
+
+
+class DMLLoss(nn.Layer):
+    """
+    DMLLoss
+    """
+
+    def __init__(self, name="loss_dml"):
+        super().__init__()
+        self.name = name
+
+    def forward(self, out1, out2):
+        loss_dict = {}
+        soft_out1 = F.softmax(out1, axis=-1)
+        log_soft_out1 = paddle.log(soft_out1)
+        soft_out2 = F.softmax(out2, axis=-1)
+        log_soft_out2 = paddle.log(soft_out2)
+        loss = (F.kl_div(
+            log_soft_out1, soft_out2, reduction='batchmean') + F.kl_div(
+                log_soft_out2, soft_out1, reduction='batchmean')) / 2.0
+        loss_dict[self.name] = loss
+        return loss_dict
+
+
+class DistanceLoss(nn.Layer):
+    """
+    DistanceLoss:
+        mode: loss mode
+        name: loss key in the output dict
+    """
+
+    def __init__(self, mode="l2", name="loss_dist", **kargs):
+        super().__init__()
+        assert mode in ["l1", "l2", "smooth_l1"]
+        if mode == "l1":
+            self.loss_func = nn.L1Loss(**kargs)
+        elif mode == "l2":
+            self.loss_func = nn.MSELoss(**kargs)
+        elif mode == "smooth_l1":
+            self.loss_func = nn.SmoothL1Loss(**kargs)
+
+        self.name = "{}_{}".format(name, mode)
+
+    def forward(self, x, y):
+        return {self.name: self.loss_func(x, y)}
diff --git a/ppocr/losses/cls_loss.py b/ppocr/losses/cls_loss.py
index 41c7db02..ecca5d2e 100755
--- a/ppocr/losses/cls_loss.py
+++ b/ppocr/losses/cls_loss.py
@@ -24,7 +24,7 @@ class ClsLoss(nn.Layer):
         super(ClsLoss, self).__init__()
         self.loss_func = nn.CrossEntropyLoss(reduction='mean')
 
-    def __call__(self, predicts, batch):
+    def forward(self, predicts, batch):
         label = batch[1]
         loss = self.loss_func(input=predicts, label=label)
         return {'loss': loss}
diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py
new file mode 100644
index 00000000..49012e30
--- /dev/null
+++ b/ppocr/losses/combined_loss.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+
+from .distillation_loss import DistillationCTCLoss
+from .distillation_loss import DistillationDMLLoss
+
+
+class CombinedLoss(nn.Layer):
+    """
+    CombinedLoss:
+        a combination of several loss functions
+    """
+
+    def __init__(self, loss_config_list=None):
+        super().__init__()
+        self.loss_func = []
+        self.loss_weight = []
+        assert isinstance(loss_config_list, list), (
+            'loss_config_list should be a list')
+        for config in loss_config_list:
+            assert isinstance(config,
+                              dict) and len(config) == 1, "yaml format error"
+            name = list(config)[0]
+            param = config[name]
+            assert "weight" in param, "weight must be in param, but param just contains {}".format(
+                param.keys())
+            self.loss_weight.append(param.pop("weight"))
+            self.loss_func.append(eval(name)(**param))
+
+    def forward(self, input, batch, **kargs):
+        loss_dict = {}
+        for idx, loss_func in enumerate(self.loss_func):
+            loss = loss_func(input, batch, **kargs)
+            if isinstance(loss, paddle.Tensor):
+                loss = {"loss_{}_{}".format(str(loss), idx): loss}
+            weight = self.loss_weight[idx]
+            loss = {
+                "{}_{}".format(key, idx): loss[key] * weight
+                for key in loss
+            }
+            loss_dict.update(loss)
+        loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
+        return loss_dict
diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py
new file mode 100644
index 00000000..cc6d7d5a
--- /dev/null
+++ b/ppocr/losses/distillation_loss.py
@@ -0,0 +1,76 @@
+#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+import paddle
+import paddle.nn as nn
+
+from .rec_ctc_loss import CTCLoss
+from .basic_loss import DMLLoss
+
+
+class DistillationDMLLoss(DMLLoss):
+    """
+    DistillationDMLLoss: DML loss between the outputs of paired sub-models.
+    """
+
+    def __init__(self,
+                 model_name_list1=[],
+                 model_name_list2=[],
+                 key=None,
+                 name="loss_dml"):
+        super().__init__(name=name)
+        if not isinstance(model_name_list1, (list, )):
+            model_name_list1 = [model_name_list1]
+        if not isinstance(model_name_list2, (list, )):
+            model_name_list2 = [model_name_list2]
+
+        assert len(model_name_list1) == len(model_name_list2)
+        self.model_name_list1 = model_name_list1
+        self.model_name_list2 = model_name_list2
+        self.key = key
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for idx in range(len(self.model_name_list1)):
+            out1 = predicts[self.model_name_list1[idx]]
+            out2 = predicts[self.model_name_list2[idx]]
+            if self.key is not None:
+                out1 = out1[self.key]
+                out2 = out2[self.key]
+            loss = super().forward(out1, out2)
+            if isinstance(loss, dict):
+                assert len(loss) == 1
+                loss = list(loss.values())[0]
+            loss_dict["{}_{}".format(self.name, idx)] = loss
+        return loss_dict
+
+
+class DistillationCTCLoss(CTCLoss):
+    def __init__(self, model_name_list=[], key=None, name="loss_ctc"):
+        super().__init__()
+        self.model_name_list = model_name_list
+        self.key = key
+        self.name = name
+
+    def forward(self, predicts, batch):
+        loss_dict = dict()
+        for model_name in self.model_name_list:
+            out = predicts[model_name]
+            if self.key is not None:
+                out = out[self.key]
+            loss = super().forward(out, batch)
+            if isinstance(loss, dict):
+                assert len(loss) == 1
+                loss = list(loss.values())[0]
+            loss_dict["{}_{}".format(self.name, model_name)] = loss
+        return loss_dict
diff --git a/ppocr/losses/rec_ctc_loss.py b/ppocr/losses/rec_ctc_loss.py
index 425de587..6c0b56ff 100755
--- a/ppocr/losses/rec_ctc_loss.py
+++ b/ppocr/losses/rec_ctc_loss.py
@@ -25,7 +25,7 @@ class CTCLoss(nn.Layer):
         super(CTCLoss, self).__init__()
         self.loss_func = nn.CTCLoss(blank=0, reduction='none')
 
-    def __call__(self, predicts, batch):
+    def forward(self, predicts, batch):
         predicts = predicts.transpose((1, 0, 2))
         N, B, _ = predicts.shape
         preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
diff --git a/ppocr/modeling/architectures/__init__.py b/ppocr/modeling/architectures/__init__.py
index 86eaf7c9..e9a01cf0 100755
--- a/ppocr/modeling/architectures/__init__.py
+++ b/ppocr/modeling/architectures/__init__.py
@@ -13,12 +13,20 @@
 # limitations under the License.
 
 import copy
+import importlib
+
+from .base_model import BaseModel
+from .distillation_model import DistillationModel
 
 __all__ = ['build_model']
 
+
 def build_model(config):
-    from .base_model import BaseModel
-
     config = copy.deepcopy(config)
-    module_class = BaseModel(config)
-    return module_class
\ No newline at end of file
+    if "name" not in config:
+        arch = BaseModel(config)
+    else:
+        name = config.pop("name")
+        mod = importlib.import_module(__name__)
+        arch = getattr(mod, name)(config)
+    return arch
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index 09b6e034..5a41e507 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -32,7 +32,6 @@ class BaseModel(nn.Layer):
             config (dict): the super parameters for module.
""" super(BaseModel, self).__init__() - in_channels = config.get('in_channels', 3) model_type = config['model_type'] # build transfrom, diff --git a/ppocr/modeling/architectures/distillation_model.py b/ppocr/modeling/architectures/distillation_model.py new file mode 100644 index 00000000..cc3f2405 --- /dev/null +++ b/ppocr/modeling/architectures/distillation_model.py @@ -0,0 +1,65 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import nn +from ppocr.modeling.transforms import build_transform +from ppocr.modeling.backbones import build_backbone +from ppocr.modeling.necks import build_neck +from ppocr.modeling.heads import build_head +from .base_model import BaseModel +from ppocr.utils.save_load import load_dygraph_pretrain + +__all__ = ['DistillationModel'] + + +class DistillationModel(nn.Layer): + def __init__(self, config): + """ + the module for OCR distillation. + args: + config (dict): the super parameters for module. + """ + super().__init__() + + freeze_params = config["freeze_params"] + pretrained = config["pretrained"] + if not isinstance(freeze_params, list): + freeze_params = [freeze_params] + assert len(config["Models"]) == len(freeze_params) + + if not isinstance(pretrained, list): + pretrained = [pretrained] * len(config["Models"]) + assert len(config["Models"]) == len(pretrained) + + self.model_dict = dict() + index = 0 + for key in config["Models"]: + model_config = config["Models"][key] + model = BaseModel(model_config) + if pretrained[index] is not None: + load_dygraph_pretrain(model, path=pretrained[index]) + if freeze_params[index]: + for param in model.parameters(): + param.trainable = False + self.model_dict[key] = self.add_sublayer(key, model) + index += 1 + + def forward(self, x): + result_dict = dict() + for key in self.model_dict: + result_dict[key] = self.model_dict[key](x) + return result_dict diff --git a/ppocr/modeling/backbones/det_mobilenet_v3.py b/ppocr/modeling/backbones/det_mobilenet_v3.py index bb451bbe..05113ea8 100755 --- a/ppocr/modeling/backbones/det_mobilenet_v3.py +++ b/ppocr/modeling/backbones/det_mobilenet_v3.py @@ -102,8 +102,7 @@ class MobileNetV3(nn.Layer): padding=1, groups=1, if_act=True, - act='hardswish', - name='conv1') + act='hardswish') self.stages = [] self.out_channels = [] @@ -125,8 +124,7 @@ class MobileNetV3(nn.Layer): kernel_size=k, stride=s, use_se=se, - act=nl, - name="conv" + str(i + 2))) + act=nl)) inplanes = make_divisible(scale * c) i += 1 block_list.append( @@ -138,8 +136,7 @@ class MobileNetV3(nn.Layer): padding=0, groups=1, if_act=True, - act='hardswish', - name='conv_last')) + act='hardswish')) self.stages.append(nn.Sequential(*block_list)) self.out_channels.append(make_divisible(scale * cls_ch_squeeze)) for i, stage in enumerate(self.stages): @@ -163,8 +160,7 @@ class ConvBNLayer(nn.Layer): padding, groups=1, if_act=True, - act=None, 
- name=None): + act=None): super(ConvBNLayer, self).__init__() self.if_act = if_act self.act = act @@ -175,16 +171,9 @@ class ConvBNLayer(nn.Layer): stride=stride, padding=padding, groups=groups, - weight_attr=ParamAttr(name=name + '_weights'), bias_attr=False) - self.bn = nn.BatchNorm( - num_channels=out_channels, - act=None, - param_attr=ParamAttr(name=name + "_bn_scale"), - bias_attr=ParamAttr(name=name + "_bn_offset"), - moving_mean_name=name + "_bn_mean", - moving_variance_name=name + "_bn_variance") + self.bn = nn.BatchNorm(num_channels=out_channels, act=None) def forward(self, x): x = self.conv(x) @@ -209,8 +198,7 @@ class ResidualUnit(nn.Layer): kernel_size, stride, use_se, - act=None, - name=''): + act=None): super(ResidualUnit, self).__init__() self.if_shortcut = stride == 1 and in_channels == out_channels self.if_se = use_se @@ -222,8 +210,7 @@ class ResidualUnit(nn.Layer): stride=1, padding=0, if_act=True, - act=act, - name=name + "_expand") + act=act) self.bottleneck_conv = ConvBNLayer( in_channels=mid_channels, out_channels=mid_channels, @@ -232,10 +219,9 @@ class ResidualUnit(nn.Layer): padding=int((kernel_size - 1) // 2), groups=mid_channels, if_act=True, - act=act, - name=name + "_depthwise") + act=act) if self.if_se: - self.mid_se = SEModule(mid_channels, name=name + "_se") + self.mid_se = SEModule(mid_channels) self.linear_conv = ConvBNLayer( in_channels=mid_channels, out_channels=out_channels, @@ -243,8 +229,7 @@ class ResidualUnit(nn.Layer): stride=1, padding=0, if_act=False, - act=None, - name=name + "_linear") + act=None) def forward(self, inputs): x = self.expand_conv(inputs) @@ -258,7 +243,7 @@ class ResidualUnit(nn.Layer): class SEModule(nn.Layer): - def __init__(self, in_channels, reduction=4, name=""): + def __init__(self, in_channels, reduction=4): super(SEModule, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2D(1) self.conv1 = nn.Conv2D( @@ -266,17 +251,13 @@ class SEModule(nn.Layer): out_channels=in_channels // reduction, kernel_size=1, stride=1, - padding=0, - weight_attr=ParamAttr(name=name + "_1_weights"), - bias_attr=ParamAttr(name=name + "_1_offset")) + padding=0) self.conv2 = nn.Conv2D( in_channels=in_channels // reduction, out_channels=in_channels, kernel_size=1, stride=1, - padding=0, - weight_attr=ParamAttr(name + "_2_weights"), - bias_attr=ParamAttr(name=name + "_2_offset")) + padding=0) def forward(self, inputs): outputs = self.avg_pool(inputs) diff --git a/ppocr/modeling/backbones/rec_mobilenet_v3.py b/ppocr/modeling/backbones/rec_mobilenet_v3.py index 1ff17159..c5dcfdd5 100644 --- a/ppocr/modeling/backbones/rec_mobilenet_v3.py +++ b/ppocr/modeling/backbones/rec_mobilenet_v3.py @@ -96,8 +96,7 @@ class MobileNetV3(nn.Layer): padding=1, groups=1, if_act=True, - act='hardswish', - name='conv1') + act='hardswish') i = 0 block_list = [] inplanes = make_divisible(inplanes * scale) @@ -110,8 +109,7 @@ class MobileNetV3(nn.Layer): kernel_size=k, stride=s, use_se=se, - act=nl, - name='conv' + str(i + 2))) + act=nl)) inplanes = make_divisible(scale * c) i += 1 self.blocks = nn.Sequential(*block_list) @@ -124,8 +122,7 @@ class MobileNetV3(nn.Layer): padding=0, groups=1, if_act=True, - act='hardswish', - name='conv_last') + act='hardswish') self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) self.out_channels = make_divisible(scale * cls_ch_squeeze) diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py index 69d4ef50..481f93e4 100755 --- a/ppocr/modeling/heads/rec_ctc_head.py +++ 
b/ppocr/modeling/heads/rec_ctc_head.py @@ -23,14 +23,12 @@ from paddle import ParamAttr, nn from paddle.nn import functional as F -def get_para_bias_attr(l2_decay, k, name): +def get_para_bias_attr(l2_decay, k): regularizer = paddle.regularizer.L2Decay(l2_decay) stdv = 1.0 / math.sqrt(k * 1.0) initializer = nn.initializer.Uniform(-stdv, stdv) - weight_attr = ParamAttr( - regularizer=regularizer, initializer=initializer, name=name + "_w_attr") - bias_attr = ParamAttr( - regularizer=regularizer, initializer=initializer, name=name + "_b_attr") + weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer) + bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer) return [weight_attr, bias_attr] @@ -38,13 +36,12 @@ class CTCHead(nn.Layer): def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs): super(CTCHead, self).__init__() weight_attr, bias_attr = get_para_bias_attr( - l2_decay=fc_decay, k=in_channels, name='ctc_fc') + l2_decay=fc_decay, k=in_channels) self.fc = nn.Linear( in_channels, out_channels, weight_attr=weight_attr, - bias_attr=bias_attr, - name='ctc_fc') + bias_attr=bias_attr) self.out_channels = out_channels def forward(self, x, labels=None): diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 042654a1..cd2b7ea7 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -21,18 +21,19 @@ import copy __all__ = ['build_post_process'] +from .db_postprocess import DBPostProcess +from .east_postprocess import EASTPostProcess +from .sast_postprocess import SASTPostProcess +from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode +from .cls_postprocess import ClsPostProcess +from .pg_postprocess import PGPostProcess + def build_post_process(config, global_config=None): - from .db_postprocess import DBPostProcess - from .east_postprocess import EASTPostProcess - from .sast_postprocess import SASTPostProcess - from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode - from .cls_postprocess import ClsPostProcess - from .pg_postprocess import PGPostProcess - support_dict = [ 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', - 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess' + 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess', + 'DistillationCTCLabelDecode' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index d353391c..5cc7abe7 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -125,6 +125,31 @@ class CTCLabelDecode(BaseRecLabelDecode): return dict_character +class DistillationCTCLabelDecode(CTCLabelDecode): + """ + Convert + Convert between text-label and text-index + """ + + def __init__(self, + character_dict_path=None, + character_type='ch', + use_space_char=False, + model_name="student", + key_out=None, + **kwargs): + super(DistillationCTCLabelDecode, self).__init__( + character_dict_path, character_type, use_space_char) + self.model_name = model_name + self.key_out = key_out + + def __call__(self, preds, label=None, *args, **kwargs): + pred = preds[self.model_name] + if self.key_out is not None: + pred = pred[self.key_out] + return super().__call__(pred, label=label, *args, **kwargs) + + class AttnLabelDecode(BaseRecLabelDecode): """ Convert between text-label and text-index """ diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py 
index 3d1c5c35..c730d1ab 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -42,7 +42,10 @@ def _mkdir_if_not_exist(path, logger): raise OSError('Failed to mkdir {}'.format(path)) -def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False): +def load_dygraph_pretrain(model, + logger=None, + path=None, + load_static_weights=False): if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): raise ValueError("Model pretrain path {} does not " "exists.".format(path)) diff --git a/tools/program.py b/tools/program.py index 7e54a2f8..7641bed7 100755 --- a/tools/program.py +++ b/tools/program.py @@ -386,7 +386,7 @@ def preprocess(is_train=False): alg = config['Architecture']['algorithm'] assert alg in [ 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', - 'CLS', 'PGNet' + 'CLS', 'PGNet', 'Distillation' ] device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' diff --git a/tools/train.py b/tools/train.py index 47358ca4..555d3367 100755 --- a/tools/train.py +++ b/tools/train.py @@ -72,7 +72,14 @@ def main(config, device, logger, vdl_writer): # for rec algorithm if hasattr(post_process_class, 'character'): char_num = len(getattr(post_process_class, 'character')) - config['Architecture']["Head"]['out_channels'] = char_num + if config['Architecture']["algorithm"] in ["Distillation", + ]: # distillation model + for key in config['Architecture']["Models"]: + config['Architecture']["Models"][key]["Head"][ + 'out_channels'] = char_num + else: # base rec model + config['Architecture']["Head"]['out_channels'] = char_num + model = build_model(config['Architecture']) if config['Global']['distributed']: model = paddle.DataParallel(model)
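
Note (illustrative only, not part of the patch): the new CombinedLoss is driven by a list of single-key dicts, each carrying a "weight" plus the constructor arguments of one loss; every returned loss term is multiplied by its weight, suffixed with the loss index, and summed into the final "loss" entry. A minimal usage sketch, assuming sub-model names "Student" and "Teacher" (these keys come from the Models section of the user's config) and head outputs that are consumed directly (key=None):

    from ppocr.losses.combined_loss import CombinedLoss

    # each entry: {LossClassName: {"weight": ..., <constructor kwargs>}}
    loss_func = CombinedLoss(loss_config_list=[
        # CTC loss of the student branch against the ground-truth labels
        {"DistillationCTCLoss": {"weight": 1.0, "model_name_list": ["Student"]}},
        # mutual learning (DML) loss between student and teacher predictions
        {"DistillationDMLLoss": {"weight": 1.0,
                                 "model_name_list1": ["Student"],
                                 "model_name_list2": ["Teacher"]}},
    ])
    # predicts = distillation_model(images)   # {"Student": ..., "Teacher": ...}
    # loss_dict = loss_func(predicts, batch)  # weighted terms plus the summed "loss"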