Merge pull request #3002 from littletomatodonkey/dyg/add_distillation
add distillation
Commit 85aeae712a
New file: rec_chinese_lite distillation training config
@@ -0,0 +1,158 @@
Global:
  debug: false
  use_gpu: true
  epoch_num: 800
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_chinese_lite_distillation_v2.1
  save_epoch_step: 3
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  character_type: ch
  max_text_length: 25
  infer_mode: false
  use_space_char: false
  distributed: true
  save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt


Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.0005
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 1.0e-05

Architecture:
  name: DistillationModel
  algorithm: Distillation
  Models:
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: rec
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: small
        small_stride: [1, 2, 2, 2]
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 48
      Head:
        name: CTCHead
        fc_decay: 0.00001
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: rec
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: small
        small_stride: [1, 2, 2, 2]
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 48
      Head:
        name: CTCHead
        fc_decay: 0.00001


Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationCTCLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
  - DistillationDMLLoss:
      weight: 1.0
      act: "softmax"
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
  - DistillationDistanceLoss:
      weight: 1.0
      mode: "l2"
      model_name_pairs:
      - ["Student", "Teacher"]
      key: backbone_out

PostProcess:
  name: DistillationCTCLabelDecode
  model_name: ["Student", "Teacher"]
  key: head_out

Metric:
  name: DistillationMetric
  base_metric_name: RecMetric
  main_indicator: acc
  key: "Student"

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/
    label_file_list:
    - ./train_data/train_list.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - RecAug:
    - CTCLabelEncode:
    - RecResizeImg:
        image_shape: [3, 32, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label
        - length
  loader:
    shuffle: true
    batch_size_per_card: 128
    drop_last: true
    num_sections: 1
    num_workers: 8
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data
    label_file_list:
    - ./train_data/val_list.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - CTCLabelEncode:
    - RecResizeImg:
        image_shape: [3, 32, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label
        - length
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 128
    num_workers: 8
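The key structural points of this config: Student and Teacher are two complete CRNN recognizers with identical topology, return_all_feats: true makes each of them expose intermediate features, and the three Loss entries supervise head outputs (CTC and DML) plus backbone outputs (L2 distance). A quick way to sanity-check the structure with PyYAML (the local file name rec_distillation.yml is illustrative, not from this commit):

import yaml

# Load the distillation config and confirm the two sub-models mirror
# each other (assumes the YAML above was saved as rec_distillation.yml).
with open("rec_distillation.yml") as f:
    cfg = yaml.safe_load(f)

models = cfg["Architecture"]["Models"]
print(list(models))  # ['Student', 'Teacher']
for name, sub in models.items():
    # both branches are full CRNN recognizers that expose all features
    print(name, sub["algorithm"], sub["return_all_feats"])  # CRNN True
print([list(item)[0] for item in cfg["Loss"]["loss_config_list"]])
# ['DistillationCTCLoss', 'DistillationDMLLoss', 'DistillationDistanceLoss']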
ppocr/losses/__init__.py
@@ -13,28 +13,37 @@
 # limitations under the License.
 
 import copy
+import paddle
+import paddle.nn as nn
+
+# det loss
+from .det_db_loss import DBLoss
+from .det_east_loss import EASTLoss
+from .det_sast_loss import SASTLoss
+
+# rec loss
+from .rec_ctc_loss import CTCLoss
+from .rec_att_loss import AttentionLoss
+from .rec_srn_loss import SRNLoss
+
+# cls loss
+from .cls_loss import ClsLoss
+
+# e2e loss
+from .e2e_pg_loss import PGLoss
+
+# basic loss function
+from .basic_loss import DistanceLoss
+
+# combined loss function
+from .combined_loss import CombinedLoss
 
 
 def build_loss(config):
-    # det loss
-    from .det_db_loss import DBLoss
-    from .det_east_loss import EASTLoss
-    from .det_sast_loss import SASTLoss
-
-    # rec loss
-    from .rec_ctc_loss import CTCLoss
-    from .rec_att_loss import AttentionLoss
-    from .rec_srn_loss import SRNLoss
-
-    # cls loss
-    from .cls_loss import ClsLoss
-
-    # e2e loss
-    from .e2e_pg_loss import PGLoss
     support_dict = [
         'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
-        'SRNLoss', 'PGLoss']
+        'SRNLoss', 'PGLoss', 'CombinedLoss'
+    ]
     config = copy.deepcopy(config)
     module_name = config.pop('name')
     assert module_name in support_dict, Exception('loss only support {}'.format(
@ -0,0 +1,103 @@
|
|||
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from paddle.nn import L1Loss
|
||||
from paddle.nn import MSELoss as L2Loss
|
||||
from paddle.nn import SmoothL1Loss
|
||||
|
||||
|
||||
class CELoss(nn.Layer):
|
||||
def __init__(self, epsilon=None):
|
||||
super().__init__()
|
||||
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
|
||||
epsilon = None
|
||||
self.epsilon = epsilon
|
||||
|
||||
def _labelsmoothing(self, target, class_num):
|
||||
if target.shape[-1] != class_num:
|
||||
one_hot_target = F.one_hot(target, class_num)
|
||||
else:
|
||||
one_hot_target = target
|
||||
soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
|
||||
soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
|
||||
return soft_target
|
||||
|
||||
def forward(self, x, label):
|
||||
loss_dict = {}
|
||||
if self.epsilon is not None:
|
||||
class_num = x.shape[-1]
|
||||
label = self._labelsmoothing(label, class_num)
|
||||
x = -F.log_softmax(x, axis=-1)
|
||||
loss = paddle.sum(x * label, axis=-1)
|
||||
else:
|
||||
if label.shape[-1] == x.shape[-1]:
|
||||
label = F.softmax(label, axis=-1)
|
||||
soft_label = True
|
||||
else:
|
||||
soft_label = False
|
||||
loss = F.cross_entropy(x, label=label, soft_label=soft_label)
|
||||
return loss
|
||||
|
||||
|
||||
class DMLLoss(nn.Layer):
|
||||
"""
|
||||
DMLLoss
|
||||
"""
|
||||
|
||||
def __init__(self, act=None):
|
||||
super().__init__()
|
||||
if act is not None:
|
||||
assert act in ["softmax", "sigmoid"]
|
||||
if act == "softmax":
|
||||
self.act = nn.Softmax(axis=-1)
|
||||
elif act == "sigmoid":
|
||||
self.act = nn.Sigmoid()
|
||||
else:
|
||||
self.act = None
|
||||
|
||||
def forward(self, out1, out2):
|
||||
if self.act is not None:
|
||||
out1 = self.act(out1)
|
||||
out2 = self.act(out2)
|
||||
|
||||
log_out1 = paddle.log(out1)
|
||||
log_out2 = paddle.log(out2)
|
||||
loss = (F.kl_div(
|
||||
log_out1, out2, reduction='batchmean') + F.kl_div(
|
||||
log_out2, out1, reduction='batchmean')) / 2.0
|
||||
return loss
|
||||
|
||||
|
||||
class DistanceLoss(nn.Layer):
|
||||
"""
|
||||
DistanceLoss:
|
||||
mode: loss mode
|
||||
"""
|
||||
|
||||
def __init__(self, mode="l2", **kargs):
|
||||
super().__init__()
|
||||
assert mode in ["l1", "l2", "smooth_l1"]
|
||||
if mode == "l1":
|
||||
self.loss_func = nn.L1Loss(**kargs)
|
||||
elif mode == "l2":
|
||||
self.loss_func = nn.MSELoss(**kargs)
|
||||
elif mode == "smooth_l1":
|
||||
self.loss_func = nn.SmoothL1Loss(**kargs)
|
||||
|
||||
def forward(self, x, y):
|
||||
return self.loss_func(x, y)
|
|
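The DML term above is the symmetric KL divergence between the two softmax distributions. A standalone sketch of the same computation on random logits (shapes are made up; only public paddle APIs are used):

import paddle
import paddle.nn.functional as F

student_logits = paddle.randn([4, 10])
teacher_logits = paddle.randn([4, 10])

# act="softmax" turns logits into distributions before the two KL terms
p = F.softmax(student_logits, axis=-1)
q = F.softmax(teacher_logits, axis=-1)
loss = (F.kl_div(paddle.log(p), q, reduction='batchmean') +
        F.kl_div(paddle.log(q), p, reduction='batchmean')) / 2.0
print(float(loss))  # non-negative scalar; zero when p == q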
ppocr/losses/cls_loss.py
@@ -24,7 +24,7 @@ class ClsLoss(nn.Layer):
         super(ClsLoss, self).__init__()
         self.loss_func = nn.CrossEntropyLoss(reduction='mean')
 
-    def __call__(self, predicts, batch):
+    def forward(self, predicts, batch):
         label = batch[1]
         loss = self.loss_func(input=predicts, label=label)
         return {'loss': loss}
New file: ppocr/losses/combined_loss.py
@@ -0,0 +1,58 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn

from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss


class CombinedLoss(nn.Layer):
    """
    CombinedLoss:
        a combination of loss functions
    """

    def __init__(self, loss_config_list=None):
        super().__init__()
        self.loss_func = []
        self.loss_weight = []
        assert isinstance(loss_config_list, list), (
            'operator config should be a list')
        for config in loss_config_list:
            assert isinstance(config,
                              dict) and len(config) == 1, "yaml format error"
            name = list(config)[0]
            param = config[name]
            assert "weight" in param, "weight must be in param, but param just contains {}".format(
                param.keys())
            self.loss_weight.append(param.pop("weight"))
            self.loss_func.append(eval(name)(**param))

    def forward(self, input, batch, **kargs):
        loss_dict = {}
        for idx, loss_func in enumerate(self.loss_func):
            loss = loss_func(input, batch, **kargs)
            if isinstance(loss, paddle.Tensor):
                loss = {"loss_{}_{}".format(str(loss), idx): loss}
            weight = self.loss_weight[idx]
            loss = {
                "{}_{}".format(key, idx): loss[key] * weight
                for key in loss
            }
            loss_dict.update(loss)
        loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
        return loss_dict
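CombinedLoss walks loss_config_list, pops each weight, instantiates the named loss, and sums the weighted per-loss entries into a single "loss" value. A standalone mimic of that weighting loop (the loss names and values are invented for illustration):

import paddle

loss_config_list = [{"ctc": {"weight": 1.0}}, {"dml": {"weight": 0.5}}]
raw = {"ctc": paddle.to_tensor(2.0), "dml": paddle.to_tensor(4.0)}

loss_dict = {}
for idx, config in enumerate(loss_config_list):
    name = list(config)[0]
    weight = config[name]["weight"]
    loss_dict["{}_{}".format(name, idx)] = raw[name] * weight
# the summed entry is what the optimizer actually backpropagates
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
print({k: float(v) for k, v in loss_dict.items()})
# {'ctc_0': 2.0, 'dml_1': 2.0, 'loss': 4.0}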
New file: ppocr/losses/distillation_loss.py
@@ -0,0 +1,108 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn

from .rec_ctc_loss import CTCLoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss


class DistillationDMLLoss(DMLLoss):
    """
    DML loss applied between pairs of sub-model outputs.
    """

    def __init__(self, model_name_pairs=[], act=None, key=None,
                 name="loss_dml"):
        super().__init__(act=act)
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.model_name_pairs = model_name_pairs
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            loss = super().forward(out1, out2)
            if isinstance(loss, dict):
                for key in loss:
                    loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
                                                   idx)] = loss[key]
            else:
                loss_dict["{}_{}".format(self.name, idx)] = loss
        return loss_dict


class DistillationCTCLoss(CTCLoss):
    def __init__(self, model_name_list=[], key=None, name="loss_ctc"):
        super().__init__()
        self.model_name_list = model_name_list
        self.key = key
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, model_name in enumerate(self.model_name_list):
            out = predicts[model_name]
            if self.key is not None:
                out = out[self.key]
            loss = super().forward(out, batch)
            if isinstance(loss, dict):
                for key in loss:
                    loss_dict["{}_{}_{}".format(self.name, model_name,
                                                idx)] = loss[key]
            else:
                loss_dict["{}_{}".format(self.name, model_name)] = loss
        return loss_dict


class DistillationDistanceLoss(DistanceLoss):
    """
    Distance loss applied between pairs of sub-model outputs.
    """

    def __init__(self,
                 mode="l2",
                 model_name_pairs=[],
                 key=None,
                 name="loss_distance",
                 **kargs):
        super().__init__(mode=mode, **kargs)
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.model_name_pairs = model_name_pairs
        self.name = name + "_l2"

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            loss = super().forward(out1, out2)
            if isinstance(loss, dict):
                for key in loss:
                    loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
                        key]
            else:
                loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
                                               idx)] = loss
        return loss_dict
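All three wrappers share one idea: the prediction is no longer a tensor but a dict keyed by sub-model name, and the optional key picks one stage out of each sub-model's feature dict. A pure-Python sketch of that indexing (the values are placeholder strings, not real tensors):

predicts = {
    "Student": {"backbone_out": "s_feat", "head_out": "s_logits"},
    "Teacher": {"backbone_out": "t_feat", "head_out": "t_logits"},
}
key = "head_out"

# DistillationCTCLoss path: one loss per named model
for name in ["Student", "Teacher"]:
    out = predicts[name][key]  # 's_logits', then 't_logits'

# DistillationDMLLoss / DistillationDistanceLoss path: one loss per pair
for pair in [["Student", "Teacher"]]:
    out1, out2 = predicts[pair[0]][key], predicts[pair[1]][key]
    print(out1, out2)  # s_logits t_logits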
ppocr/losses/rec_ctc_loss.py
@@ -25,7 +25,7 @@ class CTCLoss(nn.Layer):
         super(CTCLoss, self).__init__()
         self.loss_func = nn.CTCLoss(blank=0, reduction='none')
 
-    def __call__(self, predicts, batch):
+    def forward(self, predicts, batch):
         predicts = predicts.transpose((1, 0, 2))
         N, B, _ = predicts.shape
         preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
ppocr/metrics/__init__.py
@@ -19,20 +19,23 @@ from __future__ import unicode_literals
 
 import copy
 
-__all__ = ['build_metric']
+__all__ = ["build_metric"]
+
+from .det_metric import DetMetric
+from .rec_metric import RecMetric
+from .cls_metric import ClsMetric
+from .e2e_metric import E2EMetric
+from .distillation_metric import DistillationMetric
 
 
 def build_metric(config):
-    from .det_metric import DetMetric
-    from .rec_metric import RecMetric
-    from .cls_metric import ClsMetric
-    from .e2e_metric import E2EMetric
-
-    support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric']
+    support_dict = [
+        "DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric"
+    ]
 
     config = copy.deepcopy(config)
-    module_name = config.pop('name')
+    module_name = config.pop("name")
     assert module_name in support_dict, Exception(
-        'metric only support {}'.format(support_dict))
+        "metric only support {}".format(support_dict))
     module_class = eval(module_name)(**config)
     return module_class
New file: ppocr/metrics/distillation_metric.py
@@ -0,0 +1,76 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import copy

from .rec_metric import RecMetric
from .det_metric import DetMetric
from .e2e_metric import E2EMetric
from .cls_metric import ClsMetric


class DistillationMetric(object):
    def __init__(self,
                 key=None,
                 base_metric_name="RecMetric",
                 main_indicator='acc',
                 **kwargs):
        self.main_indicator = main_indicator
        self.key = key
        self.base_metric_name = base_metric_name
        self.kwargs = kwargs
        self.metrics = None

    def _init_metrics(self, preds):
        self.metrics = dict()
        mod = importlib.import_module(__name__)
        for key in preds:
            self.metrics[key] = getattr(mod, self.base_metric_name)(
                main_indicator=self.main_indicator, **self.kwargs)
            self.metrics[key].reset()

    def __call__(self, preds, *args, **kwargs):
        assert isinstance(preds, dict)
        if self.metrics is None:
            self._init_metrics(preds)
        output = dict()
        for key in preds:
            metric = self.metrics[key].__call__(preds[key], *args, **kwargs)
            for sub_key in metric:
                output["{}_{}".format(key, sub_key)] = metric[sub_key]
        return output

    def get_metric(self):
        """
        return metrics {
                 'acc': 0,
                 'norm_edit_dis': 0,
            }
        """
        output = dict()
        for key in self.metrics:
            metric = self.metrics[key].get_metric()
            # main indicator
            if key == self.key:
                output.update(metric)
            else:
                for sub_key in metric:
                    output["{}_{}".format(key, sub_key)] = metric[sub_key]
        return output

    def reset(self):
        for key in self.metrics:
            self.metrics[key].reset()
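DistillationMetric lazily builds one base metric per sub-model and prefixes everything except the configured key. A standalone mimic of the merging in get_metric (the scores are invented):

metrics = {"Student": {"acc": 0.71}, "Teacher": {"acc": 0.74}}
main_key = "Student"  # the config above sets key: "Student"

output = {}
for name, metric in metrics.items():
    if name == main_key:
        output.update(metric)  # Student's acc becomes the main 'acc'
    else:
        for sub_key in metric:
            output["{}_{}".format(name, sub_key)] = metric[sub_key]
print(output)  # {'acc': 0.71, 'Teacher_acc': 0.74}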
ppocr/modeling/architectures/__init__.py
@@ -13,12 +13,20 @@
 # limitations under the License.
 
 import copy
+import importlib
+
+from .base_model import BaseModel
+from .distillation_model import DistillationModel
 
 __all__ = ['build_model']
 
 
 def build_model(config):
-    from .base_model import BaseModel
-
     config = copy.deepcopy(config)
-    module_class = BaseModel(config)
-    return module_class
+    if not "name" in config:
+        arch = BaseModel(config)
+    else:
+        name = config.pop("name")
+        mod = importlib.import_module(__name__)
+        arch = getattr(mod, name)(config)
+    return arch
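The dispatch is backward compatible: a config without a top-level name keeps the old single-model path, while name: DistillationModel routes through getattr on this module. A toy mimic with strings standing in for the classes:

def build_model_sketch(config):
    config = dict(config)
    if "name" not in config:
        return "BaseModel({})".format(sorted(config))
    return "{}({})".format(config.pop("name"), sorted(config))

print(build_model_sketch({"model_type": "rec", "Backbone": {}}))
print(build_model_sketch({"name": "DistillationModel", "Models": {}}))
# BaseModel(['Backbone', 'model_type'])
# DistillationModel(['Models'])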
ppocr/modeling/architectures/base_model.py
@@ -32,7 +32,6 @@ class BaseModel(nn.Layer):
             config (dict): the super parameters for module.
         """
         super(BaseModel, self).__init__()
-
         in_channels = config.get('in_channels', 3)
         model_type = config['model_type']
         # build transform
@@ -68,14 +67,23 @@ class BaseModel(nn.Layer):
             config["Head"]['in_channels'] = in_channels
             self.head = build_head(config["Head"])
 
+        self.return_all_feats = config.get("return_all_feats", False)
+
     def forward(self, x, data=None):
+        y = dict()
         if self.use_transform:
             x = self.transform(x)
         x = self.backbone(x)
+        y["backbone_out"] = x
         if self.use_neck:
             x = self.neck(x)
+        y["neck_out"] = x
         if data is None:
             x = self.head(x)
         else:
             x = self.head(x, data)
-        return x
+        y["head_out"] = x
+        if self.return_all_feats:
+            return y
+        else:
+            return x
New file: ppocr/modeling/architectures/distillation_model.py
@@ -0,0 +1,60 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
from .base_model import BaseModel
from ppocr.utils.save_load import init_model

__all__ = ['DistillationModel']


class DistillationModel(nn.Layer):
    def __init__(self, config):
        """
        the module for OCR distillation.
        args:
            config (dict): the super parameters for module.
        """
        super().__init__()
        self.model_list = []
        self.model_name_list = []
        for key in config["Models"]:
            model_config = config["Models"][key]
            freeze_params = False
            pretrained = None
            if "freeze_params" in model_config:
                freeze_params = model_config.pop("freeze_params")
            if "pretrained" in model_config:
                pretrained = model_config.pop("pretrained")
            model = BaseModel(model_config)
            if pretrained is not None:
                init_model(model, path=pretrained)
            if freeze_params:
                for param in model.parameters():
                    param.trainable = False
            self.model_list.append(self.add_sublayer(key, model))
            self.model_name_list.append(key)

    def forward(self, x):
        result_dict = dict()
        for idx, model_name in enumerate(self.model_name_list):
            result_dict[model_name] = self.model_list[idx](x)
        return result_dict
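Every sub-model receives the same input batch, and the outputs are collected under the sub-model's name, which is exactly the dict shape the Distillation* losses, metric and decoder expect. A toy forward with two Linear layers standing in for the CRNN branches:

import paddle
import paddle.nn as nn

models = {"Student": nn.Linear(8, 4), "Teacher": nn.Linear(8, 4)}
x = paddle.randn([2, 8])
result_dict = {name: m(x) for name, m in models.items()}
print({k: v.shape for k, v in result_dict.items()})
# {'Student': [2, 4], 'Teacher': [2, 4]}

In the real model, each value is itself the per-stage dict produced by return_all_feats (backbone_out / neck_out / head_out).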
ppocr/modeling/backbones/det_mobilenet_v3.py
@@ -102,8 +102,7 @@ class MobileNetV3(nn.Layer):
             padding=1,
             groups=1,
             if_act=True,
-            act='hardswish',
-            name='conv1')
+            act='hardswish')
 
         self.stages = []
         self.out_channels = []
@@ -125,8 +124,7 @@ class MobileNetV3(nn.Layer):
                     kernel_size=k,
                     stride=s,
                     use_se=se,
-                    act=nl,
-                    name="conv" + str(i + 2)))
+                    act=nl))
             inplanes = make_divisible(scale * c)
             i += 1
         block_list.append(
@@ -138,8 +136,7 @@ class MobileNetV3(nn.Layer):
                 padding=0,
                 groups=1,
                 if_act=True,
-                act='hardswish',
-                name='conv_last'))
+                act='hardswish'))
         self.stages.append(nn.Sequential(*block_list))
         self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
         for i, stage in enumerate(self.stages):
@@ -163,8 +160,7 @@ class ConvBNLayer(nn.Layer):
                  padding,
                  groups=1,
                  if_act=True,
-                 act=None,
-                 name=None):
+                 act=None):
         super(ConvBNLayer, self).__init__()
         self.if_act = if_act
         self.act = act
@@ -175,16 +171,9 @@ class ConvBNLayer(nn.Layer):
             stride=stride,
             padding=padding,
             groups=groups,
-            weight_attr=ParamAttr(name=name + '_weights'),
             bias_attr=False)
 
-        self.bn = nn.BatchNorm(
-            num_channels=out_channels,
-            act=None,
-            param_attr=ParamAttr(name=name + "_bn_scale"),
-            bias_attr=ParamAttr(name=name + "_bn_offset"),
-            moving_mean_name=name + "_bn_mean",
-            moving_variance_name=name + "_bn_variance")
+        self.bn = nn.BatchNorm(num_channels=out_channels, act=None)
 
     def forward(self, x):
         x = self.conv(x)
@@ -209,8 +198,7 @@ class ResidualUnit(nn.Layer):
                  kernel_size,
                  stride,
                  use_se,
-                 act=None,
-                 name=''):
+                 act=None):
         super(ResidualUnit, self).__init__()
         self.if_shortcut = stride == 1 and in_channels == out_channels
         self.if_se = use_se
@@ -222,8 +210,7 @@ class ResidualUnit(nn.Layer):
             stride=1,
             padding=0,
             if_act=True,
-            act=act,
-            name=name + "_expand")
+            act=act)
         self.bottleneck_conv = ConvBNLayer(
             in_channels=mid_channels,
             out_channels=mid_channels,
@@ -232,10 +219,9 @@ class ResidualUnit(nn.Layer):
             padding=int((kernel_size - 1) // 2),
             groups=mid_channels,
             if_act=True,
-            act=act,
-            name=name + "_depthwise")
+            act=act)
         if self.if_se:
-            self.mid_se = SEModule(mid_channels, name=name + "_se")
+            self.mid_se = SEModule(mid_channels)
         self.linear_conv = ConvBNLayer(
             in_channels=mid_channels,
             out_channels=out_channels,
@@ -243,8 +229,7 @@ class ResidualUnit(nn.Layer):
             stride=1,
             padding=0,
             if_act=False,
-            act=None,
-            name=name + "_linear")
+            act=None)
 
     def forward(self, inputs):
         x = self.expand_conv(inputs)
@@ -258,7 +243,7 @@ class ResidualUnit(nn.Layer):
 
 
 class SEModule(nn.Layer):
-    def __init__(self, in_channels, reduction=4, name=""):
+    def __init__(self, in_channels, reduction=4):
         super(SEModule, self).__init__()
         self.avg_pool = nn.AdaptiveAvgPool2D(1)
         self.conv1 = nn.Conv2D(
@@ -266,17 +251,13 @@ class SEModule(nn.Layer):
             out_channels=in_channels // reduction,
             kernel_size=1,
             stride=1,
-            padding=0,
-            weight_attr=ParamAttr(name=name + "_1_weights"),
-            bias_attr=ParamAttr(name=name + "_1_offset"))
+            padding=0)
         self.conv2 = nn.Conv2D(
             in_channels=in_channels // reduction,
             out_channels=in_channels,
             kernel_size=1,
             stride=1,
-            padding=0,
-            weight_attr=ParamAttr(name + "_2_weights"),
-            bias_attr=ParamAttr(name=name + "_2_offset"))
+            padding=0)
 
     def forward(self, inputs):
         outputs = self.avg_pool(inputs)
ppocr/modeling/backbones/rec_mobilenet_v3.py
@@ -96,8 +96,7 @@ class MobileNetV3(nn.Layer):
             padding=1,
             groups=1,
             if_act=True,
-            act='hardswish',
-            name='conv1')
+            act='hardswish')
         i = 0
         block_list = []
         inplanes = make_divisible(inplanes * scale)
@@ -110,8 +109,7 @@ class MobileNetV3(nn.Layer):
                     kernel_size=k,
                     stride=s,
                     use_se=se,
-                    act=nl,
-                    name='conv' + str(i + 2)))
+                    act=nl))
             inplanes = make_divisible(scale * c)
             i += 1
         self.blocks = nn.Sequential(*block_list)
@@ -124,8 +122,7 @@ class MobileNetV3(nn.Layer):
             padding=0,
             groups=1,
             if_act=True,
-            act='hardswish',
-            name='conv_last')
+            act='hardswish')
 
         self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
         self.out_channels = make_divisible(scale * cls_ch_squeeze)
ppocr/modeling/heads/det_db_head.py
@@ -23,10 +23,10 @@ import paddle.nn.functional as F
 from paddle import ParamAttr
 
 
-def get_bias_attr(k, name):
+def get_bias_attr(k):
     stdv = 1.0 / math.sqrt(k * 1.0)
     initializer = paddle.nn.initializer.Uniform(-stdv, stdv)
-    bias_attr = ParamAttr(initializer=initializer, name=name + "_b_attr")
+    bias_attr = ParamAttr(initializer=initializer)
     return bias_attr
 
 
@@ -38,18 +38,14 @@ class Head(nn.Layer):
             out_channels=in_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(name=name_list[0] + '.w_0'),
+            weight_attr=ParamAttr(),
             bias_attr=False)
         self.conv_bn1 = nn.BatchNorm(
             num_channels=in_channels // 4,
             param_attr=ParamAttr(
-                name=name_list[1] + '.w_0',
                 initializer=paddle.nn.initializer.Constant(value=1.0)),
             bias_attr=ParamAttr(
-                name=name_list[1] + '.b_0',
                 initializer=paddle.nn.initializer.Constant(value=1e-4)),
-            moving_mean_name=name_list[1] + '.w_1',
-            moving_variance_name=name_list[1] + '.w_2',
             act='relu')
         self.conv2 = nn.Conv2DTranspose(
             in_channels=in_channels // 4,
@@ -57,19 +53,14 @@ class Head(nn.Layer):
             kernel_size=2,
             stride=2,
             weight_attr=ParamAttr(
-                name=name_list[2] + '.w_0',
                 initializer=paddle.nn.initializer.KaimingUniform()),
-            bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv2"))
+            bias_attr=get_bias_attr(in_channels // 4))
         self.conv_bn2 = nn.BatchNorm(
             num_channels=in_channels // 4,
             param_attr=ParamAttr(
-                name=name_list[3] + '.w_0',
                 initializer=paddle.nn.initializer.Constant(value=1.0)),
             bias_attr=ParamAttr(
-                name=name_list[3] + '.b_0',
                 initializer=paddle.nn.initializer.Constant(value=1e-4)),
-            moving_mean_name=name_list[3] + '.w_1',
-            moving_variance_name=name_list[3] + '.w_2',
             act="relu")
         self.conv3 = nn.Conv2DTranspose(
             in_channels=in_channels // 4,
@@ -77,10 +68,8 @@ class Head(nn.Layer):
             kernel_size=2,
             stride=2,
             weight_attr=ParamAttr(
-                name=name_list[4] + '.w_0',
                 initializer=paddle.nn.initializer.KaimingUniform()),
-            bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv3"),
-        )
+            bias_attr=get_bias_attr(in_channels // 4), )
 
     def forward(self, x):
         x = self.conv1(x)
ppocr/modeling/heads/rec_ctc_head.py
@@ -23,14 +23,12 @@ from paddle import ParamAttr, nn
 from paddle.nn import functional as F
 
 
-def get_para_bias_attr(l2_decay, k, name):
+def get_para_bias_attr(l2_decay, k):
     regularizer = paddle.regularizer.L2Decay(l2_decay)
     stdv = 1.0 / math.sqrt(k * 1.0)
     initializer = nn.initializer.Uniform(-stdv, stdv)
-    weight_attr = ParamAttr(
-        regularizer=regularizer, initializer=initializer, name=name + "_w_attr")
-    bias_attr = ParamAttr(
-        regularizer=regularizer, initializer=initializer, name=name + "_b_attr")
+    weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
+    bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
     return [weight_attr, bias_attr]
 
 
@@ -38,13 +36,12 @@ class CTCHead(nn.Layer):
     def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
         super(CTCHead, self).__init__()
         weight_attr, bias_attr = get_para_bias_attr(
-            l2_decay=fc_decay, k=in_channels, name='ctc_fc')
+            l2_decay=fc_decay, k=in_channels)
         self.fc = nn.Linear(
             in_channels,
             out_channels,
             weight_attr=weight_attr,
-            bias_attr=bias_attr,
-            name='ctc_fc')
+            bias_attr=bias_attr)
         self.out_channels = out_channels
 
     def forward(self, x, labels=None):
ppocr/modeling/necks/db_fpn.py
@@ -32,61 +32,53 @@ class DBFPN(nn.Layer):
             in_channels=in_channels[0],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_51.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.in3_conv = nn.Conv2D(
             in_channels=in_channels[1],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_50.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.in4_conv = nn.Conv2D(
             in_channels=in_channels[2],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_49.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.in5_conv = nn.Conv2D(
             in_channels=in_channels[3],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_48.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p5_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_52.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p4_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_53.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p3_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_54.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p2_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_55.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
 
     def forward(self, x):
ppocr/postprocess/__init__.py
@@ -21,18 +21,19 @@ import copy
 
 __all__ = ['build_post_process']
 
+from .db_postprocess import DBPostProcess
+from .east_postprocess import EASTPostProcess
+from .sast_postprocess import SASTPostProcess
+from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode
+from .cls_postprocess import ClsPostProcess
+from .pg_postprocess import PGPostProcess
+
 
 def build_post_process(config, global_config=None):
-    from .db_postprocess import DBPostProcess
-    from .east_postprocess import EASTPostProcess
-    from .sast_postprocess import SASTPostProcess
-    from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
-    from .cls_postprocess import ClsPostProcess
-    from .pg_postprocess import PGPostProcess
-
     support_dict = [
         'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
-        'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess'
+        'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
+        'DistillationCTCLabelDecode'
     ]
 
     config = copy.deepcopy(config)
ppocr/postprocess/rec_postprocess.py
@@ -125,6 +125,37 @@ class CTCLabelDecode(BaseRecLabelDecode):
         return dict_character
 
 
+class DistillationCTCLabelDecode(CTCLabelDecode):
+    """
+    Convert between text-label and text-index
+    """
+
+    def __init__(self,
+                 character_dict_path=None,
+                 character_type='ch',
+                 use_space_char=False,
+                 model_name=["student"],
+                 key=None,
+                 **kwargs):
+        super(DistillationCTCLabelDecode, self).__init__(
+            character_dict_path, character_type, use_space_char)
+        if not isinstance(model_name, list):
+            model_name = [model_name]
+        self.model_name = model_name
+
+        self.key = key
+
+    def __call__(self, preds, label=None, *args, **kwargs):
+        output = dict()
+        for name in self.model_name:
+            pred = preds[name]
+            if self.key is not None:
+                pred = pred[self.key]
+            output[name] = super().__call__(pred, label=label, *args, **kwargs)
+        return output
+
+
 class AttnLabelDecode(BaseRecLabelDecode):
     """ Convert between text-label and text-index """
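Decoding follows the same per-model routing as the losses: one CTC decode per entry in model_name, returned under that name. A pure-Python sketch of the routing (the dict contents are placeholders, not real logits):

preds = {
    "Student": {"head_out": "student_logits"},
    "Teacher": {"head_out": "teacher_logits"},
}
model_name, key = ["Student", "Teacher"], "head_out"

output = {n: "ctc_decode({})".format(preds[n][key]) for n in model_name}
print(output)
# {'Student': 'ctc_decode(student_logits)',
#  'Teacher': 'ctc_decode(teacher_logits)'}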
ppocr/utils/save_load.py
@@ -23,6 +23,8 @@ import six
 
 import paddle
 
+from ppocr.utils.logging import get_logger
+
 __all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
 
 
@@ -42,44 +44,11 @@ def _mkdir_if_not_exist(path, logger):
             raise OSError('Failed to mkdir {}'.format(path))
 
 
-def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
-    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
-        raise ValueError("Model pretrain path {} does not "
-                         "exists.".format(path))
-    if load_static_weights:
-        pre_state_dict = paddle.static.load_program_state(path)
-        param_state_dict = {}
-        model_dict = model.state_dict()
-        for key in model_dict.keys():
-            weight_name = model_dict[key].name
-            weight_name = weight_name.replace('binarize', '').replace(
-                'thresh', '')  # for DB
-            if weight_name in pre_state_dict.keys():
-                # logger.info('Load weight: {}, shape: {}'.format(
-                #     weight_name, pre_state_dict[weight_name].shape))
-                if 'encoder_rnn' in key:
-                    # delete axis which is 1
-                    pre_state_dict[weight_name] = pre_state_dict[
-                        weight_name].squeeze()
-                    # change axis
-                    if len(pre_state_dict[weight_name].shape) > 1:
-                        pre_state_dict[weight_name] = pre_state_dict[
-                            weight_name].transpose((1, 0))
-                param_state_dict[key] = pre_state_dict[weight_name]
-            else:
-                param_state_dict[key] = model_dict[key]
-        model.set_state_dict(param_state_dict)
-        return
-
-    param_state_dict = paddle.load(path + '.pdparams')
-    model.set_state_dict(param_state_dict)
-    return
-
-
-def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
+def init_model(config, model, optimizer=None, lr_scheduler=None):
     """
     load model from checkpoint or pretrained_model
     """
+    logger = get_logger()
     global_config = config['Global']
     checkpoints = global_config.get('checkpoints')
     pretrained_model = global_config.get('pretrained_model')
@@ -102,18 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
         best_model_dict = states_dict.get('best_model_dict', {})
         if 'epoch' in states_dict:
             best_model_dict['start_epoch'] = states_dict['epoch'] + 1
 
         logger.info("resume from {}".format(checkpoints))
     elif pretrained_model:
-        load_static_weights = global_config.get('load_static_weights', False)
         if not isinstance(pretrained_model, list):
             pretrained_model = [pretrained_model]
-        if not isinstance(load_static_weights, list):
-            load_static_weights = [load_static_weights] * len(pretrained_model)
-        for idx, pretrained in enumerate(pretrained_model):
-            load_static = load_static_weights[idx]
-            load_dygraph_pretrain(
-                model, logger, path=pretrained, load_static_weights=load_static)
+        for pretrained in pretrained_model:
+            if not (os.path.isdir(pretrained) or
+                    os.path.exists(pretrained + '.pdparams')):
+                raise ValueError("Model pretrain path {} does not "
+                                 "exists.".format(pretrained))
+            param_state_dict = paddle.load(pretrained + '.pdparams')
+            model.set_state_dict(param_state_dict)
         logger.info("load pretrained model from {}".format(
             pretrained_model))
     else:
tools/eval.py
@@ -49,7 +49,7 @@ def main():
     model = build_model(config['Architecture'])
     use_srn = config['Architecture']['algorithm'] == "SRN"
 
-    best_model_dict = init_model(config, model, logger)
+    best_model_dict = init_model(config, model)
     if len(best_model_dict):
         logger.info('metric in ckpt ***************')
         for k, v in best_model_dict.items():
tools/export_model.py
@@ -17,7 +17,7 @@ import sys
 
 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
-sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+sys.path.append(os.path.abspath(os.path.join(__dir__, "..")))
 
 import argparse
 
@@ -31,32 +31,12 @@ from ppocr.utils.logging import get_logger
 from tools.program import load_config, merge_config, ArgsParser
 
 
-def main():
-    FLAGS = ArgsParser().parse_args()
-    config = load_config(FLAGS.config)
-    merge_config(FLAGS.opt)
-    logger = get_logger()
-    # build post process
-
-    post_process_class = build_post_process(config['PostProcess'],
-                                            config['Global'])
-
-    # build model
-    # for rec algorithm
-    if hasattr(post_process_class, 'character'):
-        char_num = len(getattr(post_process_class, 'character'))
-        config['Architecture']["Head"]['out_channels'] = char_num
-    model = build_model(config['Architecture'])
-    init_model(config, model, logger)
-    model.eval()
-
-    save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
-
-    if config['Architecture']['algorithm'] == "SRN":
-        max_text_length = config['Architecture']['Head']['max_text_length']
+def export_single_model(model, arch_config, save_path, logger):
+    if arch_config["algorithm"] == "SRN":
+        max_text_length = arch_config["Head"]["max_text_length"]
         other_shape = [
             paddle.static.InputSpec(
-                shape=[None, 1, 64, 256], dtype='float32'), [
+                shape=[None, 1, 64, 256], dtype="float32"), [
                 paddle.static.InputSpec(
                     shape=[None, 256, 1],
                     dtype="int64"), paddle.static.InputSpec(
@@ -71,24 +51,66 @@ def main():
         model = to_static(model, input_spec=other_shape)
     else:
         infer_shape = [3, -1, -1]
-        if config['Architecture']['model_type'] == "rec":
+        if arch_config["model_type"] == "rec":
             infer_shape = [3, 32, -1]  # for rec model, H must be 32
-            if 'Transform' in config['Architecture'] and config['Architecture'][
-                    'Transform'] is not None and config['Architecture'][
-                        'Transform']['name'] == 'TPS':
+            if "Transform" in arch_config and arch_config[
+                    "Transform"] is not None and arch_config["Transform"][
+                        "name"] == "TPS":
                 logger.info(
-                    'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
+                    "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
                 )
                 infer_shape[-1] = 100
 
         model = to_static(
             model,
             input_spec=[
                 paddle.static.InputSpec(
-                    shape=[None] + infer_shape, dtype='float32')
+                    shape=[None] + infer_shape, dtype="float32")
             ])
 
     paddle.jit.save(model, save_path)
-    logger.info('inference model is saved to {}'.format(save_path))
+    logger.info("inference model is saved to {}".format(save_path))
+    return
+
+
+def main():
+    FLAGS = ArgsParser().parse_args()
+    config = load_config(FLAGS.config)
+    merge_config(FLAGS.opt)
+    logger = get_logger()
+    # build post process
+
+    post_process_class = build_post_process(config["PostProcess"],
+                                            config["Global"])
+
+    # build model
+    # for rec algorithm
+    if hasattr(post_process_class, "character"):
+        char_num = len(getattr(post_process_class, "character"))
+        if config["Architecture"]["algorithm"] in ["Distillation",
+                                                   ]:  # distillation model
+            for key in config["Architecture"]["Models"]:
+                config["Architecture"]["Models"][key]["Head"][
+                    "out_channels"] = char_num
+        else:  # base rec model
+            config["Architecture"]["Head"]["out_channels"] = char_num
+    model = build_model(config["Architecture"])
+    init_model(config, model)
+    model.eval()
+
+    save_path = config["Global"]["save_inference_dir"]
+
+    arch_config = config["Architecture"]
+
+    if arch_config["algorithm"] in ["Distillation", ]:  # distillation model
+        archs = list(arch_config["Models"].values())
+        for idx, name in enumerate(model.model_name_list):
+            sub_model_save_path = os.path.join(save_path, name, "inference")
+            export_single_model(model.model_list[idx], archs[idx],
+                                sub_model_save_path, logger)
+    else:
+        save_path = os.path.join(save_path, "inference")
+        export_single_model(model, arch_config, save_path, logger)
+
+
 if __name__ == "__main__":
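With a distillation config, main() now exports one inference graph per sub-model. A sketch of the resulting layout, assuming save_inference_dir is set to the output directory used elsewhere in this config (the exact value is up to the user):

import os

save_path = "./output/rec_chinese_lite_distillation_v2.1"  # assumed value
for name in ["Student", "Teacher"]:
    print(os.path.join(save_path, name, "inference"))
# ./output/rec_chinese_lite_distillation_v2.1/Student/inference
# ./output/rec_chinese_lite_distillation_v2.1/Teacher/inference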
tools/infer_cls.py
@@ -47,7 +47,7 @@ def main():
     # build model
     model = build_model(config['Architecture'])
 
-    init_model(config, model, logger)
+    init_model(config, model)
 
     # create data ops
     transforms = []
tools/infer_det.py
@@ -61,7 +61,7 @@ def main():
     # build model
     model = build_model(config['Architecture'])
 
-    init_model(config, model, logger)
+    init_model(config, model)
 
     # build post process
     post_process_class = build_post_process(config['PostProcess'])
tools/infer_e2e.py
@@ -68,7 +68,7 @@ def main():
     # build model
     model = build_model(config['Architecture'])
 
-    init_model(config, model, logger)
+    init_model(config, model)
 
     # build post process
     post_process_class = build_post_process(config['PostProcess'],
tools/infer_rec.py
@@ -20,6 +20,7 @@ import numpy as np
 
 import os
 import sys
+import json
 
 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(__dir__)
@@ -46,12 +47,18 @@ def main():
 
     # build model
     if hasattr(post_process_class, 'character'):
-        config['Architecture']["Head"]['out_channels'] = len(
-            getattr(post_process_class, 'character'))
+        char_num = len(getattr(post_process_class, 'character'))
+        if config['Architecture']["algorithm"] in ["Distillation",
+                                                   ]:  # distillation model
+            for key in config['Architecture']["Models"]:
+                config['Architecture']["Models"][key]["Head"][
+                    'out_channels'] = char_num
+        else:  # base rec model
+            config['Architecture']["Head"]['out_channels'] = char_num
 
     model = build_model(config['Architecture'])
 
-    init_model(config, model, logger)
+    init_model(config, model)
 
     # create data ops
     transforms = []
@@ -107,11 +114,23 @@ def main():
         else:
             preds = model(images)
         post_result = post_process_class(preds)
-        for rec_reuslt in post_result:
-            logger.info('\t result: {}'.format(rec_reuslt))
-            if len(rec_reuslt) >= 2:
-                fout.write(file + "\t" + rec_reuslt[0] + "\t" + str(
-                    rec_reuslt[1]) + "\n")
+        info = None
+        if isinstance(post_result, dict):
+            rec_info = dict()
+            for key in post_result:
+                if len(post_result[key][0]) >= 2:
+                    rec_info[key] = {
+                        "label": post_result[key][0][0],
+                        "score": post_result[key][0][1],
+                    }
+            info = json.dumps(rec_info)
+        else:
+            if len(post_result[0]) >= 2:
+                info = post_result[0][0] + "\t" + str(post_result[0][1])
+
+        if info is not None:
+            logger.info("\t result: {}".format(info))
+            fout.write(file + "\t" + info)
     logger.info("success!")
tools/program.py
@@ -386,7 +386,7 @@ def preprocess(is_train=False):
     alg = config['Architecture']['algorithm']
     assert alg in [
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
-        'CLS', 'PGNet'
+        'CLS', 'PGNet', 'Distillation'
     ]
 
     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
tools/train.py
@@ -72,7 +72,14 @@ def main(config, device, logger, vdl_writer):
     # for rec algorithm
     if hasattr(post_process_class, 'character'):
         char_num = len(getattr(post_process_class, 'character'))
-        config['Architecture']["Head"]['out_channels'] = char_num
+        if config['Architecture']["algorithm"] in ["Distillation",
+                                                   ]:  # distillation model
+            for key in config['Architecture']["Models"]:
+                config['Architecture']["Models"][key]["Head"][
+                    'out_channels'] = char_num
+        else:  # base rec model
+            config['Architecture']["Head"]['out_channels'] = char_num
 
     model = build_model(config['Architecture'])
     if config['Global']['distributed']:
         model = paddle.DataParallel(model)
@@ -90,7 +97,7 @@ def main(config, device, logger, vdl_writer):
     # build metric
     eval_class = build_metric(config['Metric'])
     # load pretrain model
-    pre_best_model_dict = init_model(config, model, logger, optimizer)
+    pre_best_model_dict = init_model(config, model, optimizer)
 
     logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
     if valid_dataloader is not None:
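train.py, infer_rec.py and export_model.py now share this pattern: the character count discovered by the post-process step is pushed into every sub-model's head when the algorithm is Distillation. A standalone mimic of the propagation (the char_num value is invented; really it is len(post_process_class.character)):

config = {
    "Architecture": {
        "algorithm": "Distillation",
        "Models": {"Student": {"Head": {}}, "Teacher": {"Head": {}}},
    }
}
char_num = 6625  # illustrative
arch = config["Architecture"]
if arch["algorithm"] in ["Distillation"]:
    # distillation model: every sub-model head gets the same vocabulary size
    for key in arch["Models"]:
        arch["Models"][key]["Head"]["out_channels"] = char_num
else:
    # base rec model
    arch["Head"]["out_channels"] = char_num
print(arch["Models"]["Student"])  # {'Head': {'out_channels': 6625}}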