add for SEED
This commit is contained in:
parent
38801c7f5e
commit
59cc4efdc5
|
@ -0,0 +1,101 @@
|
|||
Global:
|
||||
use_gpu: False
|
||||
epoch_num: 400
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/b3_rare_r34_none_gru/
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 2000 iterations, starting from iteration 0
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
character_type: EN_symbol
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_b3_rare_r34_none_gru.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
learning_rate: 0.0005
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0.00000
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: ASTER
|
||||
Transform:
|
||||
name: STN_ON
|
||||
tps_inputsize: [32, 64]
|
||||
tps_outputsize: [32, 100]
|
||||
num_control_points: 20
|
||||
tps_margins: [0.05,0.05]
|
||||
stn_activation: none
|
||||
Backbone:
|
||||
name: ResNet_ASTER
|
||||
Head:
|
||||
name: AsterHead # AttentionHead
|
||||
sDim: 512
|
||||
attDim: 512
|
||||
max_len_labels: 100
|
||||
|
||||
Loss:
|
||||
name: AsterLoss
|
||||
|
||||
PostProcess:
|
||||
name: AttnLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/ic15_data/
|
||||
label_file_list: ["./train_data/ic15_data/1.txt"]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- AttnLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 100]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 2
|
||||
drop_last: True
|
||||
num_workers: 8
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/ic15_data/
|
||||
label_file_list: ["./train_data/ic15_data/1.txt"]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- AttnLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 100]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 2
|
||||
num_workers: 8
|
|
@ -104,6 +104,7 @@ class BaseRecLabelEncode(object):
|
|||
self.max_text_len = max_text_length
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
self.unknown = "UNKNOWN"
|
||||
if character_type == "en":
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
|
@ -275,7 +276,9 @@ class AttnLabelEncode(BaseRecLabelEncode):
|
|||
def add_special_char(self, dict_character):
    """Frame the character list with the start, end and unknown tokens.

    Also stores the three token strings on the instance as ``beg_str``,
    ``end_str`` and ``unknown``.

    Args:
        dict_character: list of the ordinary characters.

    Returns:
        A new list ``["sos"] + dict_character + ["eos", "UNKNOWN"]``.
    """
    self.beg_str = "sos"
    self.end_str = "eos"
    self.unknown = "UNKNOWN"
    # NOTE(review): "UNKNOWN" is now the last entry. ``__call__`` of this
    # class uses ``len(self.character) - 1`` as the end-token index; if
    # ``self.character`` is built from this list that index now points at
    # "UNKNOWN" rather than "eos" -- confirm this is intended.
    return [self.beg_str] + dict_character + [self.end_str, self.unknown]
|
||||
|
||||
def __call__(self, data):
|
||||
|
@ -288,6 +291,7 @@ class AttnLabelEncode(BaseRecLabelEncode):
|
|||
data['length'] = np.array(len(text))
|
||||
text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
|
||||
- len(text) - 2)
|
||||
|
||||
data['label'] = np.array(text)
|
||||
return data
|
||||
|
||||
|
@ -352,19 +356,22 @@ class SRNLabelEncode(BaseRecLabelEncode):
|
|||
% beg_or_end
|
||||
return idx
|
||||
|
||||
|
||||
class TableLabelEncode(object):
|
||||
""" Convert between text-label and text-index """
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
max_elem_length,
|
||||
max_cell_num,
|
||||
character_dict_path,
|
||||
span_weight = 1.0,
|
||||
**kwargs):
|
||||
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
max_elem_length,
|
||||
max_cell_num,
|
||||
character_dict_path,
|
||||
span_weight=1.0,
|
||||
**kwargs):
|
||||
self.max_text_length = max_text_length
|
||||
self.max_elem_length = max_elem_length
|
||||
self.max_cell_num = max_cell_num
|
||||
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
|
||||
list_character, list_elem = self.load_char_elem_dict(
|
||||
character_dict_path)
|
||||
list_character = self.add_special_char(list_character)
|
||||
list_elem = self.add_special_char(list_elem)
|
||||
self.dict_character = {}
|
||||
|
@ -374,7 +381,7 @@ class TableLabelEncode(object):
|
|||
for i, elem in enumerate(list_elem):
|
||||
self.dict_elem[elem] = i
|
||||
self.span_weight = span_weight
|
||||
|
||||
|
||||
def load_char_elem_dict(self, character_dict_path):
|
||||
list_character = []
|
||||
list_elem = []
|
||||
|
@ -383,27 +390,27 @@ class TableLabelEncode(object):
|
|||
substr = lines[0].decode('utf-8').strip("\n").split("\t")
|
||||
character_num = int(substr[0])
|
||||
elem_num = int(substr[1])
|
||||
for cno in range(1, 1+character_num):
|
||||
for cno in range(1, 1 + character_num):
|
||||
character = lines[cno].decode('utf-8').strip("\n")
|
||||
list_character.append(character)
|
||||
for eno in range(1+character_num, 1+character_num+elem_num):
|
||||
for eno in range(1 + character_num, 1 + character_num + elem_num):
|
||||
elem = lines[eno].decode('utf-8').strip("\n")
|
||||
list_elem.append(elem)
|
||||
return list_character, list_elem
|
||||
|
||||
|
||||
def add_special_char(self, list_character):
    """Return ``list_character`` framed by the "sos" and "eos" markers.

    The two marker strings are also recorded on the instance as
    ``beg_str`` and ``end_str``.
    """
    self.beg_str, self.end_str = "sos", "eos"
    framed = [self.beg_str] + (list_character + [self.end_str])
    return framed
|
||||
|
||||
|
||||
def get_span_idx_list(self):
    """Collect the indices of every element token whose text contains 'span'.

    Returns:
        The values of ``self.dict_elem`` whose key contains the substring
        'span', in the dictionary's iteration order.
    """
    return [
        index for token, index in self.dict_elem.items() if 'span' in token
    ]
|
||||
|
||||
|
||||
def __call__(self, data):
|
||||
cells = data['cells']
|
||||
structure = data['structure']['tokens']
|
||||
|
@ -412,18 +419,22 @@ class TableLabelEncode(object):
|
|||
return None
|
||||
elem_num = len(structure)
|
||||
structure = [0] + structure + [len(self.dict_elem) - 1]
|
||||
structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
|
||||
structure = structure + [0] * (self.max_elem_length + 2 - len(structure)
|
||||
)
|
||||
structure = np.array(structure)
|
||||
data['structure'] = structure
|
||||
elem_char_idx1 = self.dict_elem['<td>']
|
||||
elem_char_idx2 = self.dict_elem['<td']
|
||||
span_idx_list = self.get_span_idx_list()
|
||||
td_idx_list = np.logical_or(structure == elem_char_idx1, structure == elem_char_idx2)
|
||||
td_idx_list = np.logical_or(structure == elem_char_idx1,
|
||||
structure == elem_char_idx2)
|
||||
td_idx_list = np.where(td_idx_list)[0]
|
||||
|
||||
structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32)
|
||||
|
||||
structure_mask = np.ones(
|
||||
(self.max_elem_length + 2, 1), dtype=np.float32)
|
||||
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
|
||||
bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32)
|
||||
bbox_list_mask = np.zeros(
|
||||
(self.max_elem_length + 2, 1), dtype=np.float32)
|
||||
img_height, img_width, img_ch = data['image'].shape
|
||||
if len(span_idx_list) > 0:
|
||||
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
|
||||
|
@ -450,9 +461,11 @@ class TableLabelEncode(object):
|
|||
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
|
||||
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
|
||||
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
|
||||
data['sp_tokens'] = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
|
||||
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
|
||||
self.max_elem_length, self.max_cell_num, elem_num])
|
||||
data['sp_tokens'] = np.array([
|
||||
char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
|
||||
elem_char_idx1, elem_char_idx2, self.max_text_length,
|
||||
self.max_elem_length, self.max_cell_num, elem_num
|
||||
])
|
||||
return data
|
||||
|
||||
def encode(self, text, char_or_elem):
|
||||
|
@ -504,9 +517,8 @@ class TableLabelEncode(object):
|
|||
idx = np.array(self.dict_elem[self.end_str])
|
||||
else:
|
||||
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
|
||||
% beg_or_end
|
||||
% beg_or_end
|
||||
else:
|
||||
assert False, "Unsupport type %s in char_or_elem" \
|
||||
% char_or_elem
|
||||
% char_or_elem
|
||||
return idx
|
||||
|
|
@ -22,6 +22,7 @@ from .imaug import transform, create_operators
|
|||
|
||||
class SimpleDataSet(Dataset):
|
||||
def __init__(self, config, mode, logger, seed=None):
|
||||
print("===== simpledataset ========")
|
||||
super(SimpleDataSet, self).__init__()
|
||||
self.logger = logger
|
||||
self.mode = mode.lower()
|
||||
|
|
|
@ -41,10 +41,13 @@ from .combined_loss import CombinedLoss
|
|||
# table loss
|
||||
from .table_att_loss import TableAttentionLoss
|
||||
|
||||
from .rec_aster_loss import AsterLoss
|
||||
|
||||
|
||||
def build_loss(config):
|
||||
support_dict = [
|
||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||
'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
|
||||
'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss', 'AsterLoss'
|
||||
]
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import fasttext
|
||||
|
||||
|
||||
class AsterLoss(nn.Layer):
    """Masked negative log-likelihood loss over the recognition output.

    Args:
        weight: stored on the instance; not used by ``forward``.
        size_average: stored on the instance; not used by ``forward``.
        ignore_index: stored on the instance; not used by ``forward``.
        sequence_normalize: when true, the summed loss is divided by the
            number of unmasked positions.
        sample_normalize: when true, the loss is divided by the batch size.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 weight=None,
                 size_average=True,
                 ignore_index=-100,
                 sequence_normalize=False,
                 sample_normalize=True,
                 **kwargs):
        super(AsterLoss, self).__init__()
        self.weight = weight
        self.size_average = size_average
        self.ignore_index = ignore_index
        self.sequence_normalize = sequence_normalize
        self.sample_normalize = sample_normalize
        # Created for the semantic (embedding) loss, which is not computed
        # by ``forward`` yet.
        self.loss_func = paddle.nn.CosineSimilarity()

    def forward(self, predicts, batch):
        """Compute the recognition loss.

        Args:
            predicts: dict; ``predicts['rec_pred']`` is read and indexed as a
                3-D tensor (batch, steps, classes).
            batch: sequence; ``batch[1]`` holds the target indices and
                ``batch[2]`` the label lengths.

        Returns:
            ``{'loss': loss}``.
        """
        targets = batch[1].astype("int64")
        label_lengths = batch[2].astype('int64')
        rec_pred = predicts['rec_pred']

        batch_size, num_steps = rec_pred.shape[0], rec_pred.shape[1]
        assert len(targets.shape) == len(list(rec_pred.shape)) - 1, \
            "The target's shape and inputs's shape is [N, d] and [N, num_steps]"

        # 1 for positions inside each label, 0 for padding.
        mask = paddle.zeros([batch_size, num_steps])
        for i in range(batch_size):
            mask[i, :label_lengths[i]] = 1
        mask = paddle.cast(mask, "float32")
        max_length = max(label_lengths)
        assert max_length == rec_pred.shape[1]
        targets = targets[:, :max_length]
        mask = mask[:, :max_length]

        # Flatten to one row per (sample, step).
        rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[-1]])
        log_probs = nn.functional.log_softmax(rec_pred, axis=1)
        targets = paddle.reshape(targets, [-1, 1])
        mask = paddle.reshape(mask, [-1, 1])

        # Pick log_probs[i, targets[i]] for every row. paddle.gather with
        # axis=1 would instead select the same columns for *all* rows
        # (index_select semantics), which is not a per-row lookup.
        picked = paddle.index_sample(log_probs, targets)
        output = paddle.sum(-picked * mask)
        if self.sequence_normalize:
            output = output / paddle.sum(mask)
        if self.sample_normalize:
            output = output / batch_size
        return {'loss': output}
|
|
@ -35,5 +35,7 @@ class AttentionLoss(nn.Layer):
|
|||
|
||||
inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
|
||||
targets = paddle.reshape(targets, [-1])
|
||||
print("input:", paddle.argmax(inputs, axis=1))
|
||||
print("targets:", targets)
|
||||
|
||||
return {'loss': paddle.sum(self.loss_func(inputs, targets))}
|
||||
|
|
|
@ -26,8 +26,10 @@ def build_backbone(config, model_type):
|
|||
from .rec_resnet_vd import ResNet
|
||||
from .rec_resnet_fpn import ResNetFPN
|
||||
from .rec_mv1_enhance import MobileNetV1Enhance
|
||||
from .rec_resnet_aster import ResNet_ASTER
|
||||
support_dict = [
|
||||
"MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
|
||||
"MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN",
|
||||
"ResNet_ASTER"
|
||||
]
|
||||
elif model_type == "e2e":
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
|
|
|
@ -0,0 +1,707 @@
|
|||
# Copyright (c) 2015-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
|
||||
# Modified from
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
# Copyright 2020 Ross Wightman, Apache-2.0 License
|
||||
|
||||
import paddle
|
||||
import itertools
|
||||
#import utils
|
||||
import math
|
||||
import warnings
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn.initializer import TruncatedNormal, Constant
|
||||
|
||||
#from timm.models.vision_transformer import trunc_normal_
|
||||
#from timm.models.registry import register_model
|
||||
|
||||
# Per-variant hyper-parameters consumed by ``model_factory``:
#   C: '_'-separated embedding widths, D: key dimension,
#   N: '_'-separated head counts, X: '_'-separated depths,
#   drop_path: drop rate handed to the residual blocks,
#   weights: URL of the reference checkpoint.
specification = {
    'LeViT_128S': {
        'C': '128_256_384',
        'D': 16,
        'N': '4_6_8',
        'X': '2_3_4',
        'drop_path': 0,
        'weights':
        'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'
    },
    'LeViT_128': {
        'C': '128_256_384',
        'D': 16,
        'N': '4_8_12',
        'X': '4_4_4',
        'drop_path': 0,
        'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'
    },
    'LeViT_192': {
        'C': '192_288_384',
        'D': 32,
        'N': '3_5_6',
        'X': '4_4_4',
        'drop_path': 0,
        'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'
    },
    'LeViT_256': {
        'C': '256_384_512',
        'D': 32,
        'N': '4_6_8',
        'X': '4_4_4',
        'drop_path': 0,
        'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'
    },
    'LeViT_384': {
        'C': '384_512_768',
        'D': 32,
        'N': '6_9_12',
        'X': '4_4_4',
        'drop_path': 0.1,
        'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'
    },
}

# ``__all__`` must be a list of strings; wrapping the dict_keys object in a
# list literal produced a single non-string entry and broke ``import *``.
__all__ = list(specification.keys())

# Shared parameter initializers.
trunc_normal_ = TruncatedNormal(std=.02)
zeros_ = Constant(value=0.)
ones_ = Constant(value=1.)
|
||||
|
||||
|
||||
#@register_model
|
||||
def LeViT_128S(class_dim=1000, distillation=True, pretrained=False, fuse=False):
    """Build a LeViT model from the 'LeViT_128S' entry of ``specification``."""
    options = dict(specification['LeViT_128S'])
    options.update(
        class_dim=class_dim,
        distillation=distillation,
        pretrained=pretrained,
        fuse=fuse)
    return model_factory(**options)
|
||||
|
||||
|
||||
#@register_model
|
||||
def LeViT_128(class_dim=1000, distillation=True, pretrained=False, fuse=False):
    """Build a LeViT model from the 'LeViT_128' entry of ``specification``."""
    options = dict(specification['LeViT_128'])
    options.update(
        class_dim=class_dim,
        distillation=distillation,
        pretrained=pretrained,
        fuse=fuse)
    return model_factory(**options)
|
||||
|
||||
|
||||
#@register_model
|
||||
def LeViT_192(class_dim=1000, distillation=True, pretrained=False, fuse=False):
    """Build a LeViT model from the 'LeViT_192' entry of ``specification``."""
    options = dict(specification['LeViT_192'])
    options.update(
        class_dim=class_dim,
        distillation=distillation,
        pretrained=pretrained,
        fuse=fuse)
    return model_factory(**options)
|
||||
|
||||
|
||||
#@register_model
|
||||
def LeViT_256(class_dim=1000, distillation=False, pretrained=False, fuse=False):
    """Build a LeViT model from the 'LeViT_256' entry of ``specification``.

    Unlike the other variants, ``distillation`` defaults to False here.
    """
    options = dict(specification['LeViT_256'])
    options.update(
        class_dim=class_dim,
        distillation=distillation,
        pretrained=pretrained,
        fuse=fuse)
    return model_factory(**options)
|
||||
|
||||
|
||||
#@register_model
|
||||
def LeViT_384(class_dim=1000, distillation=True, pretrained=False, fuse=False):
    """Build a LeViT model from the 'LeViT_384' entry of ``specification``."""
    options = dict(specification['LeViT_384'])
    options.update(
        class_dim=class_dim,
        distillation=distillation,
        pretrained=pretrained,
        fuse=fuse)
    return model_factory(**options)
|
||||
|
||||
|
||||
# Module-level accumulator: the layer constructors in this file add their
# estimated operation counts to it, and LeViT.__init__ copies the total into
# ``self.FLOPS`` and then resets it to 0.
FLOPS_COUNTER = 0
|
||||
|
||||
|
||||
class Conv2d_BN(paddle.nn.Sequential):
    """Bias-free 2-D convolution followed by batch normalisation.

    The sub-layers are registered as ``c`` (convolution) and ``bn``.
    Constructing an instance also adds an estimated operation count to the
    module-level ``FLOPS_COUNTER``.

    Args:
        a: input channels of the convolution.
        b: output channels of the convolution (and BatchNorm features).
        ks: kernel size.
        stride: convolution stride.
        pad: convolution padding.
        dilation: convolution dilation.
        groups: convolution groups.
        bn_weight_init: constant used to initialise the BatchNorm scale.
        resolution: input side length; only used for the FLOPS estimate.
    """

    def __init__(self,
                 a,
                 b,
                 ks=1,
                 stride=1,
                 pad=0,
                 dilation=1,
                 groups=1,
                 bn_weight_init=1,
                 resolution=-10000):
        super().__init__()
        self.add_sublayer(
            'c',
            paddle.nn.Conv2D(
                a, b, ks, stride, pad, dilation, groups, bias_attr=False))
        bn = paddle.nn.BatchNorm2D(b)
        # Honour ``bn_weight_init`` (it used to be ignored and the scale was
        # always initialised to 1).
        Constant(value=float(bn_weight_init))(bn.weight)
        zeros_(bn.bias)
        self.add_sublayer('bn', bn)

        global FLOPS_COUNTER
        output_points = (
            (resolution + 2 * pad - dilation * (ks - 1) - 1) // stride + 1)**2
        FLOPS_COUNTER += a * b * output_points * (ks**2)

    @paddle.no_grad()
    def fuse(self):
        """Return a single Conv2D equivalent to conv + BatchNorm in eval mode.

        NOTE(review): relies on Paddle's private attribute names
        (``_variance``, ``_mean``, ``_epsilon``, ``_stride``, ``_padding``,
        ``_dilation``, ``_groups``) -- confirm against the installed version.
        """
        conv, bn = self._sub_layers.values()
        inv_std = bn.weight / (bn._variance + bn._epsilon)**0.5
        fused_w = conv.weight * paddle.reshape(inv_std, [-1, 1, 1, 1])
        fused_b = bn.bias - bn._mean * inv_std
        fused = paddle.nn.Conv2D(
            # weight is [out, in / groups, kh, kw]
            fused_w.shape[1] * conv._groups,
            fused_w.shape[0],
            fused_w.shape[2:],
            stride=conv._stride,
            padding=conv._padding,
            dilation=conv._dilation,
            groups=conv._groups)
        fused.weight.set_value(fused_w)
        fused.bias.set_value(fused_b)
        return fused
|
||||
|
||||
|
||||
class Linear_BN(paddle.nn.Sequential):
    """Bias-free linear layer followed by 1-D batch normalisation.

    The sub-layers are registered as ``c`` (linear) and ``bn``. Constructing
    an instance also adds an estimated operation count to ``FLOPS_COUNTER``.

    Args:
        a: input features.
        b: output features.
        bn_weight_init: constant used to initialise the BatchNorm scale.
        resolution: token-grid side length; only used for the FLOPS estimate.
    """

    def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
        super().__init__()
        self.add_sublayer('c', paddle.nn.Linear(a, b, bias_attr=False))
        bn = paddle.nn.BatchNorm1D(b)
        # Honour ``bn_weight_init`` (it used to be ignored and the scale was
        # always initialised to 1).
        Constant(value=float(bn_weight_init))(bn.weight)
        zeros_(bn.bias)
        self.add_sublayer('bn', bn)

        global FLOPS_COUNTER
        output_points = resolution**2
        FLOPS_COUNTER += a * b * output_points

    @paddle.no_grad()
    def fuse(self):
        """Return a single Linear equivalent to linear + BatchNorm in eval mode.

        NOTE(review): relies on Paddle's private BatchNorm attributes
        (``_variance``, ``_mean``, ``_epsilon``) and on ``Linear.weight``
        being laid out as [in_features, out_features] -- confirm against the
        installed version.
        """
        linear, bn = self._sub_layers.values()
        inv_std = bn.weight / (bn._variance + bn._epsilon)**0.5
        # Scale each output column of the [in, out] weight.
        fused_w = linear.weight * paddle.reshape(inv_std, [1, -1])
        fused_b = bn.bias - bn._mean * inv_std
        fused = paddle.nn.Linear(fused_w.shape[0], fused_w.shape[1])
        fused.weight.set_value(fused_w)
        fused.bias.set_value(fused_b)
        return fused

    def forward(self, x):
        """Apply the linear layer, then BatchNorm over the flattened tokens."""
        linear, bn = self._sub_layers.values()
        x = linear(x)
        # BatchNorm1D runs on the first two axes merged; restore the shape.
        return paddle.reshape(bn(x.flatten(0, 1)), x.shape)
|
||||
|
||||
|
||||
class BN_Linear(paddle.nn.Sequential):
    """1-D batch normalisation followed by a linear layer.

    The sub-layers are registered as ``bn`` and ``l``. Constructing an
    instance also adds ``a * b`` to ``FLOPS_COUNTER``.

    Args:
        a: input features.
        b: output features.
        bias: passed to the linear layer as ``bias_attr``.
        std: standard deviation of the truncated-normal weight init.
    """

    def __init__(self, a, b, bias=True, std=0.02):
        super().__init__()
        self.add_sublayer('bn', paddle.nn.BatchNorm1D(a))
        linear = paddle.nn.Linear(a, b, bias_attr=bias)
        # Honour ``std`` (a fixed 0.02 initializer was used regardless).
        TruncatedNormal(std=std)(linear.weight)
        if bias:
            zeros_(linear.bias)
        self.add_sublayer('l', linear)
        global FLOPS_COUNTER
        FLOPS_COUNTER += a * b

    @paddle.no_grad()
    def fuse(self):
        """Return a single Linear equivalent to BatchNorm + linear in eval mode.

        NOTE(review): relies on Paddle's private BatchNorm attributes
        (``_variance``, ``_mean``, ``_epsilon``) and on ``Linear.weight``
        being laid out as [in_features, out_features] -- confirm against the
        installed version.
        """
        bn, linear = self._sub_layers.values()
        inv_std = bn.weight / (bn._variance + bn._epsilon)**0.5
        shift = bn.bias - bn._mean * inv_std
        # bn(x) = x * inv_std + shift, so fold the scale into the weight
        # rows and push the shift through the linear layer.
        fused_w = linear.weight * paddle.reshape(inv_std, [-1, 1])
        fused_b = paddle.matmul(shift, linear.weight)
        if linear.bias is not None:
            fused_b = fused_b + linear.bias
        fused = paddle.nn.Linear(fused_w.shape[0], fused_w.shape[1])
        fused.weight.set_value(fused_w)
        fused.bias.set_value(fused_b)
        return fused
|
||||
|
||||
|
||||
def b16(n, activation, resolution=224):
    """Build the convolutional stem: four stride-2 3x3 ``Conv2d_BN`` stages.

    Channel widths go 3 -> n//8 -> n//4 -> n//2 -> n, with ``activation()``
    inserted between consecutive stages. ``resolution`` is halved for each
    stage and forwarded to ``Conv2d_BN``.
    """
    widths = [3, n // 8, n // 4, n // 2, n]
    layers = []
    for stage in range(4):
        if stage > 0:
            layers.append(activation())
        layers.append(
            Conv2d_BN(
                widths[stage],
                widths[stage + 1],
                3,
                2,
                1,
                resolution=resolution // (2**stage)))
    return paddle.nn.Sequential(*layers)
|
||||
|
||||
|
||||
class Residual(paddle.nn.Layer):
    """Residual wrapper ``x + m(x)`` with optional per-sample branch dropping.

    Args:
        m: the wrapped layer.
        drop: probability of zeroing the branch for a sample during training;
            kept branches are rescaled by ``1 / (1 - drop)``.
    """

    def __init__(self, m, drop):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        """Return ``x + m(x)``, randomly dropping the branch while training."""
        branch = self.m(x)
        if self.training and self.drop > 0:
            # One keep/drop decision per sample, broadcast over the other
            # two axes. The previous code used PyTorch-only calls
            # (x.size(0), device=, .ge_, .div) that fail under Paddle.
            keep = paddle.rand([x.shape[0], 1, 1]) >= self.drop
            scale = paddle.cast(keep, x.dtype) / (1 - self.drop)
            scale.stop_gradient = True
            return x + branch * scale
        return x + branch
|
||||
|
||||
|
||||
class Attention(paddle.nn.Layer):
    """Multi-head self-attention over a square token grid with learned
    per-offset attention biases.

    Args:
        dim: input/output feature size.
        key_dim: per-head query/key size.
        num_heads: number of heads.
        attn_ratio: per-head value size is ``int(attn_ratio * key_dim)``.
        activation: layer class instantiated in front of the output
            projection; it is called unconditionally, so it must be given.
        resolution: side length of the token grid.
    """

    def __init__(self,
                 dim,
                 key_dim,
                 num_heads=8,
                 attn_ratio=4,
                 activation=None,
                 resolution=14):
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        self.h = self.dh + nh_kd * 2
        self.qkv = Linear_BN(dim, self.h, resolution=resolution)
        self.proj = paddle.nn.Sequential(
            activation(),
            Linear_BN(
                self.dh, dim, bn_weight_init=0, resolution=resolution))

        # Map every (query, key) pair of grid points to the index of its
        # absolute offset; pairs with the same offset share one bias.
        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = self.create_parameter(
            shape=(num_heads, len(attention_offsets)),
            default_initializer=zeros_)
        tensor_idxs = paddle.to_tensor(idxs, dtype='int64')
        self.register_buffer('attention_bias_idxs',
                             paddle.reshape(tensor_idxs, [N, N]))

        global FLOPS_COUNTER
        # queries * keys
        FLOPS_COUNTER += num_heads * (resolution**4) * key_dim
        # softmax
        FLOPS_COUNTER += num_heads * (resolution**4)
        # attention * v
        FLOPS_COUNTER += num_heads * self.d * (resolution**4)

    def _gathered_biases(self):
        """Expand ``attention_biases`` to (num_heads, rows, cols) using the
        index buffer, with a single gather over the flattened indices."""
        idxs = self.attention_bias_idxs
        picked = paddle.gather(
            paddle.transpose(self.attention_biases, (1, 0)),
            paddle.reshape(idxs, [-1]))
        return paddle.transpose(picked, (1, 0)).reshape(
            (0, idxs.shape[0], idxs.shape[1]))

    @paddle.no_grad()
    def train(self, mode=True):
        """Switch mode; cache the expanded biases as ``self.ab`` for eval and
        drop the cache when going back to training."""
        if mode:
            super().train()
            if hasattr(self, 'ab'):
                del self.ab
        else:
            super().eval()
            self.ab = self._gathered_biases()

    def forward(self, x):  # x (B,N,C)
        """Apply attention to ``x`` and return the projected result."""
        B, N, _ = x.shape
        qkv = self.qkv(x)
        qkv = paddle.reshape(qkv,
                             [B, N, self.num_heads, self.h // self.num_heads])
        q, k, v = paddle.split(
            qkv, [self.key_dim, self.key_dim, self.d], axis=3)
        q = paddle.transpose(q, perm=[0, 2, 1, 3])
        k = paddle.transpose(k, perm=[0, 2, 1, 3])
        v = paddle.transpose(v, perm=[0, 2, 1, 3])

        # Always expand the biases from the current parameter instead of
        # forcing ``self.training = True`` to dodge the cached branch.
        attention_biases = self._gathered_biases()

        attn = paddle.matmul(q, k, transpose_y=True) * self.scale \
            + attention_biases
        attn = F.softmax(attn)
        x = paddle.transpose(paddle.matmul(attn, v), perm=[0, 2, 1, 3])
        x = paddle.reshape(x, [B, N, self.dh])
        return self.proj(x)
|
||||
|
||||
|
||||
class Subsample(paddle.nn.Layer):
    """Keep every ``stride``-th token along both axes of a square token grid.

    Args:
        stride: sampling step along each grid axis.
        resolution: side length used to reshape the token axis into a grid.
    """

    def __init__(self, stride, resolution):
        super().__init__()
        self.stride = stride
        self.resolution = resolution

    def forward(self, x):
        """Reshape tokens to a grid, stride it, and flatten back to tokens."""
        batch, _, channels = x.shape
        side, step = self.resolution, self.stride
        grid = paddle.reshape(x, [batch, side, side, channels])
        strided = grid[:, ::step, ::step]
        return paddle.reshape(strided, [batch, -1, channels])
|
||||
|
||||
|
||||
class AttentionSubsample(paddle.nn.Layer):
    """Attention block whose queries come from a strided subset of the token
    grid, so the output has fewer tokens than the input.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        key_dim: per-head query/key size.
        num_heads: number of heads.
        attn_ratio: per-head value size is ``int(attn_ratio * key_dim)``.
        activation: layer class instantiated in front of the output
            projection; it is called unconditionally, so it must be given.
        stride: sub-sampling step for the queries.
        resolution: side length of the input token grid.
        resolution_: side length of the output token grid.
    """

    def __init__(self,
                 in_dim,
                 out_dim,
                 key_dim,
                 num_heads=8,
                 attn_ratio=2,
                 activation=None,
                 stride=2,
                 resolution=14,
                 resolution_=7):
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim**-0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * self.num_heads
        self.attn_ratio = attn_ratio
        self.resolution_ = resolution_
        self.resolution_2 = resolution_**2
        self.training = True
        h = self.dh + nh_kd
        self.kv = Linear_BN(in_dim, h, resolution=resolution)

        self.q = paddle.nn.Sequential(
            Subsample(stride, resolution),
            Linear_BN(
                in_dim, nh_kd, resolution=resolution_))
        self.proj = paddle.nn.Sequential(
            activation(), Linear_BN(
                self.dh, out_dim, resolution=resolution_))

        self.stride = stride
        self.resolution = resolution
        points = list(itertools.product(range(resolution), range(resolution)))
        points_ = list(
            itertools.product(range(resolution_), range(resolution_)))

        # Map every (query, key) pair to the index of its absolute offset;
        # query coordinates are scaled by ``stride`` onto the input grid.
        N = len(points)
        N_ = len(points_)
        attention_offsets = {}
        idxs = []
        size = 1
        for p1 in points_:
            for p2 in points:
                offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2),
                          abs(p1[1] * stride - p2[1] + (size - 1) / 2))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = self.create_parameter(
            shape=(num_heads, len(attention_offsets)),
            default_initializer=zeros_)

        tensor_idxs_ = paddle.to_tensor(idxs, dtype='int64')
        self.register_buffer('attention_bias_idxs',
                             paddle.reshape(tensor_idxs_, [N_, N]))

        global FLOPS_COUNTER
        # queries * keys
        FLOPS_COUNTER += num_heads * \
            (resolution**2) * (resolution_**2) * key_dim
        # softmax
        FLOPS_COUNTER += num_heads * (resolution**2) * (resolution_**2)
        # attention * v
        FLOPS_COUNTER += num_heads * \
            (resolution**2) * (resolution_**2) * self.d

    def _gathered_biases(self):
        """Expand ``attention_biases`` to (num_heads, rows, cols) using the
        index buffer, with a single gather over the flattened indices."""
        idxs = self.attention_bias_idxs
        picked = paddle.gather(
            paddle.transpose(self.attention_biases, (1, 0)),
            paddle.reshape(idxs, [-1]))
        return paddle.transpose(picked, (1, 0)).reshape(
            (0, idxs.shape[0], idxs.shape[1]))

    @paddle.no_grad()
    def train(self, mode=True):
        """Switch mode; cache the expanded biases as ``self.ab`` for eval and
        drop the cache when going back to training."""
        if mode:
            super().train()
            if hasattr(self, 'ab'):
                del self.ab
        else:
            super().eval()
            self.ab = self._gathered_biases()

    def forward(self, x):
        """Attend from the sub-sampled queries to all input tokens."""
        B, N, _ = x.shape
        kv = self.kv(x)
        kv = paddle.reshape(kv, [B, N, self.num_heads, -1])
        k, v = paddle.split(kv, [self.key_dim, self.d], axis=3)
        k = paddle.transpose(k, perm=[0, 2, 1, 3])  # BHNC
        v = paddle.transpose(v, perm=[0, 2, 1, 3])
        q = paddle.reshape(
            self.q(x), [B, self.resolution_2, self.num_heads, self.key_dim])
        q = paddle.transpose(q, perm=[0, 2, 1, 3])

        # Always expand the biases from the current parameter instead of
        # forcing ``self.training = True`` to dodge the cached branch.
        attention_biases = self._gathered_biases()

        attn = paddle.matmul(q, k, transpose_y=True) * self.scale \
            + attention_biases
        attn = F.softmax(attn)

        x = paddle.reshape(
            paddle.transpose(
                paddle.matmul(attn, v), perm=[0, 2, 1, 3]),
            [B, -1, self.dh])
        return self.proj(x)
|
||||
|
||||
|
||||
class LeViT(paddle.nn.Layer):
|
||||
""" Vision Transformer with support for patch or hybrid CNN input stage
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
class_dim=1000,
|
||||
embed_dim=[192],
|
||||
key_dim=[64],
|
||||
depth=[12],
|
||||
num_heads=[3],
|
||||
attn_ratio=[2],
|
||||
mlp_ratio=[2],
|
||||
hybrid_backbone=None,
|
||||
down_ops=[],
|
||||
attention_activation=paddle.nn.Hardswish,
|
||||
mlp_activation=paddle.nn.Hardswish,
|
||||
distillation=True,
|
||||
drop_path=0):
|
||||
super().__init__()
|
||||
global FLOPS_COUNTER
|
||||
|
||||
self.class_dim = class_dim
|
||||
self.num_features = embed_dim[-1]
|
||||
self.embed_dim = embed_dim
|
||||
self.distillation = distillation
|
||||
|
||||
self.patch_embed = hybrid_backbone
|
||||
|
||||
self.blocks = []
|
||||
down_ops.append([''])
|
||||
resolution = img_size // patch_size
|
||||
for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
|
||||
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio,
|
||||
down_ops)):
|
||||
for _ in range(dpth):
|
||||
self.blocks.append(
|
||||
Residual(
|
||||
Attention(
|
||||
ed,
|
||||
kd,
|
||||
nh,
|
||||
attn_ratio=ar,
|
||||
activation=attention_activation,
|
||||
resolution=resolution, ),
|
||||
drop_path))
|
||||
if mr > 0:
|
||||
h = int(ed * mr)
|
||||
self.blocks.append(
|
||||
Residual(
|
||||
paddle.nn.Sequential(
|
||||
Linear_BN(
|
||||
ed, h, resolution=resolution),
|
||||
mlp_activation(),
|
||||
Linear_BN(
|
||||
h,
|
||||
ed,
|
||||
bn_weight_init=0,
|
||||
resolution=resolution), ),
|
||||
drop_path))
|
||||
if do[0] == 'Subsample':
|
||||
#('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
|
||||
resolution_ = (resolution - 1) // do[5] + 1
|
||||
self.blocks.append(
|
||||
AttentionSubsample(
|
||||
*embed_dim[i:i + 2],
|
||||
key_dim=do[1],
|
||||
num_heads=do[2],
|
||||
attn_ratio=do[3],
|
||||
activation=attention_activation,
|
||||
stride=do[5],
|
||||
resolution=resolution,
|
||||
resolution_=resolution_))
|
||||
resolution = resolution_
|
||||
if do[4] > 0: # mlp_ratio
|
||||
h = int(embed_dim[i + 1] * do[4])
|
||||
self.blocks.append(
|
||||
Residual(
|
||||
paddle.nn.Sequential(
|
||||
Linear_BN(
|
||||
embed_dim[i + 1], h, resolution=resolution),
|
||||
mlp_activation(),
|
||||
Linear_BN(
|
||||
h,
|
||||
embed_dim[i + 1],
|
||||
bn_weight_init=0,
|
||||
resolution=resolution), ),
|
||||
drop_path))
|
||||
self.blocks = paddle.nn.Sequential(*self.blocks)
|
||||
|
||||
# Classifier head
|
||||
self.head = BN_Linear(
|
||||
embed_dim[-1], class_dim) if class_dim > 0 else paddle.nn.Identity()
|
||||
if distillation:
|
||||
self.head_dist = BN_Linear(
|
||||
embed_dim[-1],
|
||||
class_dim) if class_dim > 0 else paddle.nn.Identity()
|
||||
|
||||
self.FLOPS = FLOPS_COUNTER
|
||||
FLOPS_COUNTER = 0
|
||||
|
||||
def no_weight_decay(self):
    """Return the names of state-dict entries that hold attention biases.

    The result is the set of every key of ``self.state_dict()`` whose name
    contains the substring ``'attention_biases'``.
    """
    bias_keys = set()
    for key in self.state_dict().keys():
        if 'attention_biases' in key:
            bias_keys.add(key)
    return bias_keys
|
||||
|
||||
def forward(self, x):
    """Run the network on an input batch.

    The input goes through ``self.patch_embed``, is flattened from axis 2
    onward, transposed so that the flattened axis comes second, passed
    through ``self.blocks`` and averaged over axis 1.

    Without distillation the result of ``self.head`` is returned. With
    distillation, ``(head, head_dist)`` outputs are returned as a tuple
    while training and as their element-wise mean otherwise.
    """
    tokens = self.patch_embed(x).flatten(2)
    tokens = paddle.transpose(tokens, perm=[0, 2, 1])
    pooled = self.blocks(tokens).mean(1)

    if not self.distillation:
        return self.head(pooled)

    logits = self.head(pooled)
    logits_dist = self.head_dist(pooled)
    if self.training:
        return logits, logits_dist
    return (logits + logits_dist) / 2
|
||||
|
||||
|
||||
def model_factory(C, D, X, N, drop_path, weights, class_dim, distillation,
                  pretrained, fuse):
    """Build a LeViT model from underscore-separated size strings.

    Args:
        C: embedding dims, integers joined by ``'_'``.
        D: key dimension; used for all three stages and both subsample ops.
        X: block depths, integers joined by ``'_'``.
        N: head counts, integers joined by ``'_'``.
        drop_path: forwarded to ``LeViT``.
        weights: not used by this function (weight loading is not ported).
        class_dim: forwarded to ``LeViT``.
        distillation: forwarded to ``LeViT``.
        pretrained: not used by this function.
        fuse: when truthy, ``utils.replace_batchnorm`` is applied to the model.

    Returns:
        The constructed ``LeViT`` instance.
    """
    stage_dims, stage_heads, stage_depths = (
        [int(tok) for tok in spec.split('_')] for spec in (C, N, X))
    activation = paddle.nn.Hardswish

    # Each entry: ('Subsample', key_dim, num_heads, attn_ratio, mlp_ratio, stride)
    subsample_ops = [
        ['Subsample', D, stage_dims[0] // D, 4, 2, 2],
        ['Subsample', D, stage_dims[1] // D, 4, 2, 2],
    ]

    levit_kwargs = dict(
        patch_size=16,
        embed_dim=stage_dims,
        num_heads=stage_heads,
        key_dim=[D] * 3,
        depth=stage_depths,
        attn_ratio=[2, 2, 2],
        mlp_ratio=[2, 2, 2],
        down_ops=subsample_ops,
        attention_activation=activation,
        mlp_activation=activation,
        hybrid_backbone=b16(stage_dims[0], activation=activation),
        class_dim=class_dim,
        drop_path=drop_path,
        distillation=distillation)
    model = LeViT(**levit_kwargs)

    # Loading of pretrained weights (torch.hub in the original port source)
    # is not implemented here; `pretrained` and `weights` are ignored.
    if fuse:
        utils.replace_batchnorm(model)

    return model
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Smoke test: build LeViT_256, load a local parameter file and print the
    # first ten outputs for a fixed random input.
    #
    # Reference snippet for converting a PyTorch checkpoint into a dict that
    # can be passed to ``set_state_dict`` (formerly kept in a string literal):
    #
    #   import torch
    #   checkpoint = torch.load('../LeViT/pretrained256.pth')
    #   torch_dict = checkpoint['net']
    #   paddle_dict = {}
    #   fc_names = ["c.weight", "l.weight", "qkv.weight", "fc1.weight",
    #               "fc2.weight", "downsample.reduction.weight",
    #               "head.weight", "attn.proj.weight"]
    #   for key in torch_dict:
    #       weight = torch_dict[key].cpu().numpy()
    #       if any(name in key for name in fc_names) and "emb" not in key:
    #           print("weight {} need to be trans".format(key))
    #           weight = weight.transpose()
    #       key = key.replace("running_mean", "_mean")
    #       key = key.replace("running_var", "_variance")
    #       paddle_dict[key] = weight
    import numpy as np

    # Call the factory by name directly instead of via globals()[...].
    net = LeViT_256(fuse=False, pretrained=False, distillation=False)
    load_layer_state_dict = paddle.load(
        "./LeViT_256_official_nodistillation_paddle.pdparams")
    #net.set_state_dict(paddle_dict)
    net.set_state_dict(load_layer_state_dict)
    net.eval()
    #paddle.save(net.state_dict(), "./LeViT_256_official_paddle.pdparams")
    #model = paddle.jit.to_static(net,input_spec=[paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype='float32')])
    #paddle.jit.save(model, "./LeViT_256_official_inference/inference")

    # Fixed seed so the printed values are reproducible between runs.
    np.random.seed(123)
    img = np.random.rand(1, 3, 224, 224).astype('float32')
    img = paddle.to_tensor(img)
    outputs = net(img).numpy()
    print(outputs[0][:10])
|
|
@ -0,0 +1,147 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
import sys
|
||||
import math
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2D`` with padding 1 and the given stride."""
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias_attr=False)
    return nn.Conv2D(in_planes, out_planes, **conv_kwargs)
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2D`` with the given stride."""
    conv_kwargs = dict(kernel_size=1, stride=stride, bias_attr=False)
    return nn.Conv2D(in_planes, out_planes, **conv_kwargs)
|
||||
|
||||
|
||||
def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000):
    """Return an ``[n_position, feat_dim]`` float32 sinusoidal position table.

    Entry ``(p, d)`` is built from the angle
    ``p / wave_length ** (2 * (d // 2) / feat_dim)``; even columns hold its
    sine and odd columns its cosine.

    Args:
        n_position: number of positions (rows).
        feat_dim: feature dimension (columns).
        wave_length: base of the geometric frequency progression.
    """
    # [n_position]
    positions = paddle.arange(0, n_position, dtype="float32")
    # [feat_dim]; the exponent is computed in floating point so the division
    # by feat_dim is not subject to integer-tensor division rules.
    dim_index = paddle.arange(0, feat_dim, dtype="float32")
    exponent = 2.0 * paddle.floor(dim_index / 2.0) / feat_dim
    # paddle.pow expects a Tensor as its first argument, so the scalar base
    # is materialised as a tensor instead of being passed as a Python int.
    base = paddle.full([feat_dim], float(wave_length), dtype="float32")
    dim_range = paddle.pow(base, exponent)
    # [n_position, feat_dim]
    angles = paddle.unsqueeze(
        positions, axis=1) / paddle.unsqueeze(
            dim_range, axis=0)
    angles = paddle.cast(angles, "float32")
    angles[:, 0::2] = paddle.sin(angles[:, 0::2])
    angles[:, 1::2] = paddle.cos(angles[:, 1::2])
    return angles
|
||||
|
||||
|
||||
class AsterBlock(nn.Layer):
    """Residual block: 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN, plus skip.

    ``stride`` is applied by the 1x1 convolution. When ``downsample`` is
    given it is applied to the input to form the skip connection; otherwise
    the input itself is added.
    """

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv1x1(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2D(planes)
        self.relu = nn.ReLU()
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2D(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the two convolutions, add the skip path and a final ReLU."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        shortcut = x if self.downsample is None else self.downsample(x)
        branch += shortcut
        return self.relu(branch)
|
||||
|
||||
|
||||
class ResNet_ASTER(nn.Layer):
    """ResNet-style sequence encoder (for ASTER or CRNN).

    A 3x3 stem is followed by five residual stages. The height axis is then
    squeezed away and the result transposed to ``[N, w, c]``. With
    ``with_lstm`` a 2-layer bidirectional LSTM is run on top.
    ``out_channels`` is 512 in both configurations.
    """

    def __init__(self, with_lstm=True, n_group=1, in_channels=3):
        super().__init__()
        self.with_lstm = with_lstm
        self.n_group = n_group

        stem_conv = nn.Conv2D(
            in_channels,
            32,
            kernel_size=(3, 3),
            stride=1,
            padding=1,
            bias_attr=False)
        self.layer0 = nn.Sequential(stem_conv, nn.BatchNorm2D(32), nn.ReLU())

        self.inplanes = 32
        # (planes, number of blocks, stride); the trailing comments give the
        # spatial size after each stage for a 32x100 input.
        stage_specs = [
            (32, 3, [2, 2]),  # [16, 50]
            (64, 4, [2, 2]),  # [8, 25]
            (128, 6, [2, 1]),  # [4, 25]
            (256, 6, [2, 1]),  # [2, 25]
            (512, 3, [2, 1]),  # [1, 25]
        ]
        for index, (planes, blocks, stride) in enumerate(stage_specs, start=1):
            setattr(self, "layer%d" % index,
                    self._make_layer(planes, blocks, stride))

        if with_lstm:
            self.rnn = nn.LSTM(512, 256, direction="bidirect", num_layers=2)
            self.out_channels = 2 * 256
        else:
            self.out_channels = 512

    def _make_layer(self, planes, blocks, stride):
        """Stack ``blocks`` AsterBlocks; only the first one gets ``stride``.

        A 1x1 conv + BN projection is used on the first block's skip path
        when the stride is not ``[1, 1]`` or the channel count changes.
        Updates ``self.inplanes`` to ``planes``.
        """
        needs_projection = stride != [1, 1] or self.inplanes != planes
        downsample = None
        if needs_projection:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))

        layers = [AsterBlock(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(AsterBlock(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` and return ``[N, w, c]`` features (LSTM output if enabled)."""
        feat = x
        for stage in (self.layer0, self.layer1, self.layer2, self.layer3,
                      self.layer4, self.layer5):
            feat = stage(feat)

        cnn_feat = feat.squeeze(2)  # [N, c, w]
        cnn_feat = paddle.transpose(cnn_feat, perm=[0, 2, 1])
        if not self.with_lstm:
            return cnn_feat
        rnn_feat, _ = self.rnn(cnn_feat)
        return rnn_feat
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: push a random batch of three 3x32x100 images through the
    # encoder and print the shape of the resulting features.
    dummy_images = paddle.randn([3, 3, 32, 100])
    encoder = ResNet_ASTER()
    print(encoder(dummy_images).shape)
|
|
@ -26,12 +26,15 @@ def build_head(config):
|
|||
from .rec_ctc_head import CTCHead
|
||||
from .rec_att_head import AttentionHead
|
||||
from .rec_srn_head import SRNHead
|
||||
from .rec_aster_head import AttentionRecognitionHead, AsterHead
|
||||
|
||||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
support_dict = [
|
||||
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
||||
'SRNHead', 'PGHead', 'TableAttentionHead']
|
||||
'SRNHead', 'PGHead', 'TableAttentionHead', 'AttentionRecognitionHead',
|
||||
'AsterHead'
|
||||
]
|
||||
|
||||
#table head
|
||||
from .table_att_head import TableAttentionHead
|
||||
|
@ -39,5 +42,6 @@ def build_head(config):
|
|||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('head only support {}'.format(
|
||||
support_dict))
|
||||
print(config)
|
||||
module_class = eval(module_name)(**config)
|
||||
return module_class
|
||||
|
|
|
@ -0,0 +1,258 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
|
||||
|
||||
class AsterHead(nn.Layer):
    """Recognition head combining a sentence-embedding branch and an
    attention decoder.

    Args:
        in_channels: channel width of the encoder features.
        out_channels: number of output classes.
        sDim: decoder state size.
        attDim: attention projection size.
        max_len_labels: maximum number of decoding steps.
        time_step: number of encoder time steps fed to ``Embedding``.
        beam_width: beam width handed to the decoder's beam search.
        eos: index of the end-of-sequence class, required for inference.
            Defaults to ``None`` because the correct index depends on the
            label encoding, which is not visible from this module.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 sDim,
                 attDim,
                 max_len_labels,
                 time_step=25,
                 beam_width=5,
                 eos=None,
                 **kwargs):
        super(AsterHead, self).__init__()
        self.num_classes = out_channels
        self.in_planes = in_channels
        self.sDim = sDim
        self.attDim = attDim
        self.max_len_labels = max_len_labels
        self.decoder = AttentionRecognitionHead(in_channels, out_channels, sDim,
                                                attDim, max_len_labels)
        self.time_step = time_step
        self.embeder = Embedding(self.time_step, in_channels)
        self.beam_width = beam_width
        # Previously read in forward() without ever being assigned.
        self.eos = eos

    def forward(self, x, targets=None, embed=None):
        """Decode ``x``.

        In training mode ``targets`` must be a ``(rec_targets, rec_lengths)``
        pair and the result holds ``rec_pred`` and ``embedding_vectors``.
        In eval mode ``targets`` is ignored and the result additionally
        holds ``rec_pred_scores``.

        Raises:
            ValueError: in eval mode when ``eos`` was not configured.
        """
        return_dict = {}
        embedding_vectors = self.embeder(x)

        if self.training:
            # Only unpack targets here: at inference time they are None.
            rec_targets, rec_lengths = targets
            rec_pred = self.decoder([x, rec_targets, rec_lengths],
                                    embedding_vectors)
            return_dict['rec_pred'] = rec_pred
            return_dict['embedding_vectors'] = embedding_vectors
        else:
            if self.eos is None:
                raise ValueError(
                    "AsterHead requires `eos` to be set for inference")
            # NOTE(review): AttentionRecognitionHead as defined in this file
            # does not provide `beam_search`; confirm it exists before
            # relying on the eval path.
            rec_pred, rec_pred_scores = self.decoder.beam_search(
                x, self.beam_width, self.eos, embedding_vectors)
            return_dict['rec_pred'] = rec_pred
            return_dict['rec_pred_scores'] = rec_pred_scores
            return_dict['embedding_vectors'] = embedding_vectors

        return return_dict
|
||||
|
||||
|
||||
class Embedding(nn.Layer):
    """Project the flattened encoder output to an ``embed_dim`` vector.

    The linear layer expects ``in_timestep * in_planes`` input features.
    ``mid_dim`` is stored but not used by any layer in this class.
    """

    def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
        super().__init__()
        self.in_timestep = in_timestep
        self.in_planes = in_planes
        self.embed_dim = embed_dim
        self.mid_dim = mid_dim
        # Embed encoder output to a word-embedding like vector.
        self.eEmbed = nn.Linear(in_timestep * in_planes, self.embed_dim)

    def forward(self, x):
        """Flatten everything but the batch axis and apply the projection."""
        flat = paddle.reshape(x, [paddle.shape(x)[0], -1])
        return self.eEmbed(flat)
|
||||
|
||||
|
||||
class AttentionRecognitionHead(nn.Layer):
    """Attention-based sequence decoder.

    input: encoder features ``[b x T x in_channels]``
    output: per-step class scores ``[b x steps x num_classes]``
    """

    def __init__(self, in_channels, out_channels, sDim, attDim, max_len_labels):
        super(AttentionRecognitionHead, self).__init__()
        self.num_classes = out_channels  # this is the output classes. So it includes the <EOS>.
        self.in_planes = in_channels
        self.sDim = sDim
        self.attDim = attDim
        self.max_len_labels = max_len_labels

        self.decoder = DecoderUnit(
            sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim)

    def forward(self, x, embed):
        """Teacher-forced decoding.

        Args:
            x: ``[features, targets, lengths]``; decoding runs for
                ``max(lengths)`` steps and step ``i > 0`` is fed
                ``targets[:, i - 1]``.
            embed: embedding used to initialise the decoder state.

        Returns:
            Scores of shape ``[b x max(lengths) x num_classes]``.
        """
        x, targets, lengths = x
        batch_size = paddle.shape(x)[0]
        # Decoder
        state = self.decoder.get_initial_state(embed)
        outputs = []

        for i in range(max(lengths)):
            if i == 0:
                # Index num_classes is the extra <BOS> row of the embedding.
                y_prev = paddle.full(
                    shape=[batch_size], fill_value=self.num_classes)
            else:
                y_prev = targets[:, i - 1]

            output, state = self.decoder(x, state, y_prev)
            outputs.append(output)
        outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1)
        return outputs

    # inference stage.
    def sample(self, x):
        """Greedy decoding for ``max_len_labels`` steps from a zero state.

        Args:
            x: ``[features, _, _]``; only the features are used.

        Returns:
            ``(predicted_ids, predicted_scores)``, each ``[b x max_len_labels]``.
        """
        x, _, _ = x
        # Same batch-size lookup as forward(); Tensor.size is not callable.
        batch_size = paddle.shape(x)[0]
        # Decoder
        state = paddle.zeros([1, batch_size, self.sDim])

        predicted_ids, predicted_scores = [], []
        predicted = None
        for i in range(self.max_len_labels):
            if i == 0:
                y_prev = paddle.full(
                    shape=[batch_size], fill_value=self.num_classes)
            else:
                y_prev = predicted

            output, state = self.decoder(x, state, y_prev)
            output = F.softmax(output, axis=1)
            # paddle's max returns values only, so the winning index is
            # obtained separately with argmax.
            score = paddle.max(output, axis=1)
            predicted = paddle.argmax(output, axis=1)
            predicted_ids.append(predicted.unsqueeze(1))
            predicted_scores.append(score.unsqueeze(1))
        # concat takes the tensor list and the axis as separate arguments.
        predicted_ids = paddle.concat(predicted_ids, axis=1)
        predicted_scores = paddle.concat(predicted_scores, axis=1)
        # return predicted_ids.squeeze(), predicted_scores.squeeze()
        return predicted_ids, predicted_scores
|
||||
|
||||
|
||||
class AttentionUnit(nn.Layer):
    """Additive attention over a feature sequence.

    Scores are ``wEmbed(tanh(sEmbed(state) + xEmbed(features)))`` and are
    normalised with a softmax over the time axis.
    """

    def __init__(self, sDim, xDim, attDim):
        super().__init__()

        self.sDim = sDim
        self.xDim = xDim
        self.attDim = attDim

        def _projection(in_features, out_features):
            # All three projections share the same initialisation scheme.
            return nn.Linear(
                in_features,
                out_features,
                weight_attr=paddle.nn.initializer.Normal(std=0.01),
                bias_attr=paddle.nn.initializer.Constant(0.0))

        self.sEmbed = _projection(sDim, attDim)
        self.xEmbed = _projection(xDim, attDim)
        self.wEmbed = _projection(attDim, 1)

    def forward(self, x, sPrev):
        """Return attention weights ``[b x T]`` for features ``x`` ``[b x T x xDim]``."""
        batch_size, T, _ = x.shape  # [b x T x xDim]
        flat_x = paddle.reshape(x, [-1, self.xDim])  # [(b x T) x xDim]
        x_proj = paddle.reshape(
            self.xEmbed(flat_x), [batch_size, T, -1])  # [b x T x attDim]

        s_proj = self.sEmbed(sPrev.squeeze(0))  # [b x attDim]
        s_proj = paddle.expand(
            paddle.unsqueeze(s_proj, 1),  # [b x 1 x attDim]
            [batch_size, T, self.attDim])  # [b x T x attDim]

        energy = paddle.tanh(s_proj + x_proj)
        energy = paddle.reshape(energy, [-1, self.attDim])

        scores = self.wEmbed(energy)  # [(b x T) x 1]
        scores = paddle.reshape(scores, [batch_size, T])

        # attention weights for each sample in the minibatch
        return F.softmax(scores, axis=1)
|
||||
|
||||
|
||||
class DecoderUnit(nn.Layer):
    """Single decoding step: attention, previous-symbol embedding, GRU cell
    and an output projection to ``yDim`` classes.
    """

    def __init__(self, sDim, xDim, yDim, attDim):
        super().__init__()
        self.sDim = sDim
        self.xDim = xDim
        self.yDim = yDim
        self.attDim = attDim
        self.emdDim = attDim

        self.attention_unit = AttentionUnit(sDim, xDim, attDim)
        # One extra row: the last index is used for <BOS>.
        self.tgt_embedding = nn.Embedding(
            yDim + 1,
            self.emdDim,
            weight_attr=nn.initializer.Normal(std=0.01))
        self.gru = nn.GRUCell(input_size=xDim + self.emdDim, hidden_size=sDim)
        self.fc = nn.Linear(
            sDim,
            yDim,
            weight_attr=nn.initializer.Normal(std=0.01),
            bias_attr=nn.initializer.Constant(value=0))
        self.embed_fc = nn.Linear(300, self.sDim)

    def get_initial_state(self, embed, tile_times=1):
        """Map a width-300 embedding to the initial state ``[1 x N x sDim]``.

        With ``tile_times != 1`` every sample's state is repeated
        ``tile_times`` times consecutively along the batch axis.
        """
        assert embed.shape[1] == 300
        state = self.embed_fc(embed)  # N * sDim
        if tile_times != 1:
            tiled = paddle.transpose(state.unsqueeze(1), perm=[1, 0, 2])
            tiled = paddle.tile(tiled, repeat_times=[tile_times, 1, 1])
            tiled = paddle.transpose(tiled, perm=[1, 0, 2])
            state = paddle.reshape(tiled, shape=[-1, self.sDim])
        return state.unsqueeze(0)  # 1 * N * sDim

    def forward(self, x, sPrev, yPrev):
        """Advance the decoder by one step.

        Args:
            x: feature sequence from the image encoder, ``[b x T x xDim]``.
            sPrev: previous decoder state.
            yPrev: previous symbol indices (cast to int64 here).

        Returns:
            ``(output, state)``: class scores from ``self.fc`` and the new
            GRU state.
        """
        batch_size, T, _ = x.shape
        attn = self.attention_unit(x, sPrev)
        context = paddle.squeeze(paddle.matmul(attn.unsqueeze(1), x), axis=1)
        prev_embed = self.tgt_embedding(paddle.cast(yPrev, dtype="int64"))

        gru_input = paddle.concat([prev_embed, context], 1)
        gru_input = paddle.squeeze(gru_input, 1)
        prev_state = paddle.squeeze(sPrev, 0)
        output, state = self.gru(gru_input, prev_state)
        output = self.fc(paddle.squeeze(output, axis=1))
        return output, state
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test for the teacher-forced decoding path.
    # The constructor has no `num_classes` parameter (the class count is
    # `out_channels`), so only the declared arguments are passed.
    model = AttentionRecognitionHead(
        in_channels=30,
        out_channels=38,
        sDim=512,
        attDim=512,
        max_len_labels=25)

    # The feature width must equal in_channels: AttentionUnit reshapes the
    # features to [-1, xDim] before its linear projection.
    data = paddle.ones([16, 64, 30])
    targets = paddle.ones([16, 25])
    length = paddle.to_tensor([20])
    # forward() needs the embedding used for the initial decoder state;
    # DecoderUnit.get_initial_state asserts that it has width 300.
    embed = paddle.ones([16, 300])
    x = [data, targets, length]
    output = model(x, embed)
    print(output.shape)
|
|
@ -44,10 +44,13 @@ class AttentionHead(nn.Layer):
|
|||
hidden = paddle.zeros((batch_size, self.hidden_size))
|
||||
output_hiddens = []
|
||||
|
||||
targets = targets[0]
|
||||
print(targets)
|
||||
if targets is not None:
|
||||
for i in range(num_steps):
|
||||
char_onehots = self._char_to_onehot(
|
||||
targets[:, i], onehot_dim=self.num_classes)
|
||||
# print("char_onehots:", char_onehots)
|
||||
(outputs, hidden), alpha = self.attention_cell(hidden, inputs,
|
||||
char_onehots)
|
||||
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
|
||||
|
@ -104,6 +107,8 @@ class AttentionGRUCell(nn.Layer):
|
|||
alpha = paddle.transpose(alpha, [0, 2, 1])
|
||||
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
|
||||
concat_context = paddle.concat([context, char_onehots], 1)
|
||||
# print("concat_context:", concat_context.shape)
|
||||
# print("prev_hidden:", prev_hidden.shape)
|
||||
|
||||
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||
|
||||
|
|
|
@ -17,8 +17,9 @@ __all__ = ['build_transform']
|
|||
|
||||
def build_transform(config):
|
||||
from .tps import TPS
|
||||
from .tps import STN_ON
|
||||
|
||||
support_dict = ['TPS']
|
||||
support_dict = ['TPS', 'STN_ON']
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception(
|
||||
|
|
|
@ -0,0 +1,121 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import nn, ParamAttr
|
||||
from paddle.nn import functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
def conv3x3_block(in_channels, out_channels, stride=1):
    """Return ``Conv2D(3x3, padding 1) -> BatchNorm2D -> ReLU`` as a Sequential.

    Conv weights are drawn from ``Normal(0, sqrt(2 / (9 * out_channels)))``
    and the bias is initialised to zero.
    """
    fan = 3 * 3 * out_channels
    weight_std = math.sqrt(2. / fan)
    conv = nn.Conv2D(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        weight_attr=nn.initializer.Normal(
            mean=0.0, std=weight_std),
        bias_attr=nn.initializer.Constant(0))
    return nn.Sequential(conv, nn.BatchNorm2D(out_channels), nn.ReLU())
|
||||
|
||||
|
||||
class STN(nn.Layer):
    """Localisation network predicting ``num_ctrlpoints`` 2-D control points.

    Six conv blocks (with 2x2 max-pooling between them) feed two fully
    connected layers. The last layer starts with zero weights and a bias
    equal to an evenly spaced top/bottom control-point layout, so the
    initial prediction is that layout regardless of the input.
    """

    def __init__(self, in_channels, num_ctrlpoints, activation='none'):
        super().__init__()
        self.in_channels = in_channels
        self.num_ctrlpoints = num_ctrlpoints
        self.activation = activation

        # Channel plan of the conv blocks; the size comments refer to a
        # 32x64 input.
        channel_plan = [
            (in_channels, 32),  # 32x64
            (32, 64),  # 16x32
            (64, 128),  # 8*16
            (128, 256),  # 4*8
            (256, 256),  # 2*4
            (256, 256),  # 1*2
        ]
        conv_layers = []
        for index, (c_in, c_out) in enumerate(channel_plan):
            if index > 0:
                conv_layers.append(nn.MaxPool2D(kernel_size=2, stride=2))
            conv_layers.append(conv3x3_block(c_in, c_out))
        self.stn_convnet = nn.Sequential(*conv_layers)

        fc1_linear = nn.Linear(
            2 * 256,
            512,
            weight_attr=nn.initializer.Normal(0, 0.001),
            bias_attr=nn.initializer.Constant(0))
        self.stn_fc1 = nn.Sequential(fc1_linear, nn.BatchNorm1D(512), nn.ReLU())

        initial_bias = self.init_stn()
        self.stn_fc2 = nn.Linear(
            512,
            num_ctrlpoints * 2,
            weight_attr=nn.initializer.Constant(0.0),
            bias_attr=nn.initializer.Assign(initial_bias))

    def init_stn(self):
        """Return the flattened initial control points used as the fc2 bias.

        Points are spaced evenly in x along a top row (y = margin) and a
        bottom row (y = 1 - margin). For the ``'sigmoid'`` activation the
        inverse sigmoid of the points is returned instead.
        """
        margin = 0.01
        per_side = int(self.num_ctrlpoints / 2)
        xs = np.linspace(margin, 1. - margin, per_side)
        top_row = np.stack([xs, np.ones(per_side) * margin], axis=1)
        bottom_row = np.stack([xs, np.ones(per_side) * (1 - margin)], axis=1)
        ctrl_points = np.concatenate(
            [top_row, bottom_row], axis=0).astype(np.float32)
        if self.activation == 'sigmoid':
            ctrl_points = -np.log(1. / ctrl_points - 1.)
        ctrl_points = paddle.to_tensor(ctrl_points)
        flat_length = ctrl_points.shape[0] * ctrl_points.shape[1]
        return paddle.reshape(ctrl_points, shape=[flat_length])

    def forward(self, x):
        """Return ``(img_feat, ctrl_points)``.

        ``img_feat`` is the fc1 output and ``ctrl_points`` has shape
        ``[-1, num_ctrlpoints, 2]``.
        """
        conv_feat = self.stn_convnet(x)
        batch_size, _, h, w = conv_feat.shape
        flat = paddle.reshape(conv_feat, shape=(batch_size, -1))
        img_feat = self.stn_fc1(flat)
        points = self.stn_fc2(0.1 * img_feat)
        if self.activation == 'sigmoid':
            points = F.sigmoid(points)
        points = paddle.reshape(points, shape=[-1, self.num_ctrlpoints, 2])
        return img_feat, points
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: run the localisation network on a seeded random batch.
    in_planes, num_ctrlpoints = 3, 20
    np.random.seed(100)
    activation = 'none'  # 'sigmoid'
    stn_head = STN(in_planes, num_ctrlpoints, activation)

    data = np.random.randn(10, 3, 32, 64).astype("float32")
    print("data:", np.sum(data))
    input = paddle.to_tensor(data)
    #input = paddle.randn([10, 3, 32, 64])
    control_points = stn_head(input)
|
|
@ -22,6 +22,9 @@ from paddle import nn, ParamAttr
|
|||
from paddle.nn import functional as F
|
||||
import numpy as np
|
||||
|
||||
from .tps_spatial_transformer import TPSSpatialTransformer
|
||||
from .stn import STN
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
|
@ -231,7 +234,8 @@ class GridGenerator(nn.Layer):
|
|||
""" Return inv_delta_C which is needed to calculate T """
|
||||
F = self.F
|
||||
hat_eye = paddle.eye(F, dtype='float64') # F x F
|
||||
hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
|
||||
hat_C = paddle.norm(
|
||||
C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
|
||||
hat_C = (hat_C**2) * paddle.log(hat_C)
|
||||
delta_C = paddle.concat( # F+3 x F+3
|
||||
[
|
||||
|
@ -301,3 +305,26 @@ class TPS(nn.Layer):
|
|||
[-1, image.shape[2], image.shape[3], 2])
|
||||
batch_I_r = F.grid_sample(x=image, grid=batch_P_prime)
|
||||
return batch_I_r
|
||||
|
||||
|
||||
class STN_ON(nn.Layer):
    """Rectification transform: STN control-point prediction followed by a
    TPS warp of the original image.

    ``out_channels`` equals ``in_channels``.
    """

    def __init__(self, in_channels, tps_inputsize, tps_outputsize,
                 num_control_points, tps_margins, stn_activation):
        super().__init__()
        self.tps = TPSSpatialTransformer(
            output_image_size=tuple(tps_outputsize),
            num_control_points=num_control_points,
            margins=tuple(tps_margins))
        self.stn_head = STN(
            in_channels=in_channels,
            num_ctrlpoints=num_control_points,
            activation=stn_activation)
        self.tps_inputsize = tps_inputsize
        self.out_channels = in_channels

    def forward(self, image):
        """Return the TPS-warped version of ``image``.

        Control points are predicted from a bilinearly resized copy of the
        image (size ``tps_inputsize``); the warp is applied to the
        full-size input.
        """
        resized = F.interpolate(
            image, self.tps_inputsize, mode="bilinear", align_corners=True)
        _, ctrl_points = self.stn_head(resized)
        rectified, _ = self.tps(image, ctrl_points)
        return rectified
|
||||
|
|
|
@ -0,0 +1,178 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import nn, ParamAttr
|
||||
from paddle.nn import functional as F
|
||||
import numpy as np
|
||||
import itertools
|
||||
|
||||
|
||||
def grid_sample(input, grid, canvas=None):
    """Sample ``input`` at ``grid`` and optionally blend with ``canvas``.

    ``input.stop_gradient`` is set to ``False`` as a side effect. When
    ``canvas`` is given, an all-ones mask is sampled with the same grid and
    the result is ``output * mask + canvas * (1 - mask)``.
    """
    input.stop_gradient = False
    sampled = F.grid_sample(input, grid)
    if canvas is None:
        return sampled

    coverage = F.grid_sample(paddle.ones(shape=input.shape), grid)
    return sampled * coverage + canvas * (1 - coverage)
|
||||
|
||||
|
||||
# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
|
||||
def compute_partial_repr(input_points, control_points):
    """Return the ``[N, M]`` matrix ``0.5 * d2 * log(d2)``.

    ``d2`` is the squared Euclidean distance between each of the ``N``
    input points and each of the ``M`` control points. NaN entries (from
    ``0 * log(0)``) are replaced by 0.
    """
    num_inputs = input_points.shape[0]
    num_controls = control_points.shape[0]
    lhs = paddle.reshape(input_points, shape=[num_inputs, 1, 2])
    rhs = paddle.reshape(control_points, shape=[1, num_controls, 2])
    diff = lhs - rhs
    # Summing the two coordinate slices explicitly; the original author
    # noted that a sum-reduction over the last axis was very slow.
    diff_sq = diff * diff
    dist_sq = diff_sq[:, :, 0] + diff_sq[:, :, 1]
    repr_matrix = 0.5 * dist_sq * paddle.log(dist_sq)
    # fix numerical error for 0 * log(0), substitute all nan with 0
    nan_mask = repr_matrix != repr_matrix
    repr_matrix[nan_mask] = 0
    return repr_matrix
|
||||
|
||||
|
||||
# output_ctrl_pts are specified, according to our task.
|
||||
def build_output_control_points(num_control_points, margins):
    """Return the target control points as a ``[2 * (n // 2), 2]`` tensor.

    Half of the points lie on a top row (``y = margin_y``) and half on a
    bottom row (``y = 1 - margin_y``), both evenly spaced in x between
    ``margin_x`` and ``1 - margin_x``. The top row comes first.
    """
    margin_x, margin_y = margins
    per_side = num_control_points // 2
    xs = np.linspace(margin_x, 1.0 - margin_x, per_side)
    top_row = np.stack([xs, np.ones(per_side) * margin_y], axis=1)
    bottom_row = np.stack([xs, np.ones(per_side) * (1.0 - margin_y)], axis=1)
    # ctrl_pts_top = ctrl_pts_top[1:-1,:]
    # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:]
    stacked = np.concatenate([top_row, bottom_row], axis=0)
    return paddle.to_tensor(stacked)
|
||||
|
||||
|
||||
class TPSSpatialTransformer(nn.Layer):
|
||||
def __init__(self,
|
||||
output_image_size=None,
|
||||
num_control_points=None,
|
||||
margins=None):
|
||||
super(TPSSpatialTransformer, self).__init__()
|
||||
self.output_image_size = output_image_size
|
||||
self.num_control_points = num_control_points
|
||||
self.margins = margins
|
||||
|
||||
self.target_height, self.target_width = output_image_size
|
||||
target_control_points = build_output_control_points(num_control_points,
|
||||
margins)
|
||||
N = num_control_points
|
||||
# N = N - 4
|
||||
|
||||
# create padded kernel matrix
|
||||
forward_kernel = paddle.zeros(shape=[N + 3, N + 3])
|
||||
target_control_partial_repr = compute_partial_repr(
|
||||
target_control_points, target_control_points)
|
||||
target_control_partial_repr = paddle.cast(target_control_partial_repr,
|
||||
forward_kernel.dtype)
|
||||
forward_kernel[:N, :N] = target_control_partial_repr
|
||||
forward_kernel[:N, -3] = 1
|
||||
forward_kernel[-3, :N] = 1
|
||||
target_control_points = paddle.cast(target_control_points,
|
||||
forward_kernel.dtype)
|
||||
forward_kernel[:N, -2:] = target_control_points
|
||||
forward_kernel[-2:, :N] = paddle.transpose(
|
||||
target_control_points, perm=[1, 0])
|
||||
# compute inverse matrix
|
||||
inverse_kernel = paddle.inverse(forward_kernel)
|
||||
|
||||
# create target cordinate matrix
|
||||
HW = self.target_height * self.target_width
|
||||
target_coordinate = list(
|
||||
itertools.product(
|
||||
range(self.target_height), range(self.target_width)))
|
||||
target_coordinate = paddle.to_tensor(target_coordinate) # HW x 2
|
||||
Y, X = paddle.split(
|
||||
target_coordinate, target_coordinate.shape[1], axis=1)
|
||||
#Y, X = target_coordinate.split(1, dim = 1)
|
||||
Y = Y / (self.target_height - 1)
|
||||
X = X / (self.target_width - 1)
|
||||
target_coordinate = paddle.concat(
|
||||
[X, Y], axis=1) # convert from (y, x) to (x, y)
|
||||
target_coordinate_partial_repr = compute_partial_repr(
|
||||
target_coordinate, target_control_points)
|
||||
target_coordinate_repr = paddle.concat(
|
||||
[
|
||||
target_coordinate_partial_repr, paddle.ones(shape=[HW, 1]),
|
||||
target_coordinate
|
||||
],
|
||||
axis=1)
|
||||
|
||||
# register precomputed matrices
|
||||
self.inverse_kernel = inverse_kernel
|
||||
self.padding_matrix = paddle.zeros(shape=[3, 2])
|
||||
self.target_coordinate_repr = target_coordinate_repr
|
||||
self.target_control_points = target_control_points
|
||||
|
||||
def forward(self, input, source_control_points):
    """Warp ``input`` with the spline defined by ``source_control_points``.

    Args:
        input: Batch passed unchanged to ``grid_sample`` as the tensor to
            sample from.
        source_control_points: Tensor asserted to have three dimensions
            with shape ``[batch, self.num_control_points, 2]``.

    Returns:
        Tuple ``(output_maps, source_coordinate)``: the result of
        ``grid_sample`` and the un-clipped product of
        ``self.target_coordinate_repr`` with the per-sample mapping matrix.
    """
    assert source_control_points.ndimension() == 3
    assert source_control_points.shape[1] == self.num_control_points
    assert source_control_points.shape[2] == 2
    batch_size = source_control_points.shape[0]

    # Expand into a local tensor. Writing the expanded result back to
    # self.padding_matrix would leave a [batch, 3, 2] tensor on the layer,
    # and the next call with a different batch size could no longer expand
    # it (e.g. the smaller last batch of an epoch).
    padding_matrix = paddle.expand(
        self.padding_matrix, shape=[batch_size, 3, 2])
    Y = paddle.concat([source_control_points, padding_matrix], 1)
    mapping_matrix = paddle.matmul(self.inverse_kernel, Y)
    source_coordinate = paddle.matmul(self.target_coordinate_repr,
                                      mapping_matrix)

    grid = paddle.reshape(
        source_coordinate,
        shape=[-1, self.target_height, self.target_width, 2])
    # the source_control_points may be out of [0, 1].
    grid = paddle.clip(grid, 0, 1)
    # NOTE(review): the grid is left in [0, 1]; the rescale below is
    # disabled although the original comment says grid_sample expects
    # [-1, 1]. Confirm that the grid_sample helper used by this file
    # accepts the [0, 1] range before relying on the output.
    # grid = 2.0 * grid - 1.0
    output_maps = grid_sample(input, grid, canvas=None)
    return output_maps, source_coordinate
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: predict control points with the STN head on random
    # images, rectify the images with the TPS layer and print both shapes.
    from stn import STN

    np.random.seed(100)
    localization_net = STN(3, 20, 'none')  # activation may also be 'sigmoid'
    images = paddle.to_tensor(
        np.random.randn(10, 3, 32, 64).astype("float32"))
    predicted = localization_net(images)

    transformer = TPSSpatialTransformer(
        output_image_size=[32, 320],
        num_control_points=20,
        margins=[0.05, 0.05])
    rectified, sampling_coords = transformer(images, predicted[1])
    print("out 0 :", rectified.shape)
    print("out 1:", sampling_coords.shape)
|
|
@ -0,0 +1,149 @@
|
|||
from __future__ import absolute_import
|
||||
|
||||
import numpy as np
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def grid_sample(input, grid, canvas=None):
    """Sample ``input`` at ``grid`` and optionally composite onto ``canvas``.

    Args:
        input: Tensor forwarded to ``F.grid_sample`` as the source.
        grid: Sampling grid forwarded to ``F.grid_sample``.
        canvas: Optional background. When given, it is blended with the
            sampled result using a mask obtained by sampling an all-ones
            tensor shaped like ``input`` with the same grid.

    Returns:
        The sampled tensor, or ``sampled * mask + canvas * (1 - mask)``
        when ``canvas`` is provided.
    """
    sampled = F.grid_sample(input, grid)
    if canvas is None:
        return sampled

    # Sampling a constant-one image tells how much of each output pixel
    # was actually covered by the input.
    coverage = F.grid_sample(torch.ones_like(input), grid)
    return sampled * coverage + canvas * (1 - coverage)
|
||||
|
||||
|
||||
# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
|
||||
def compute_partial_repr(input_points, control_points):
    """Evaluate the radial basis term between two sets of 2-D points.

    For every pair the value is ``0.5 * d * log(d)`` where ``d`` is the
    squared Euclidean distance between the two points. Entries that come
    out as NaN (``0 * log(0)`` for coincident points) are replaced by 0.

    Args:
        input_points: Tensor viewed as ``(N, 1, 2)``.
        control_points: Tensor viewed as ``(1, M, 2)``.

    Returns:
        Tensor of shape ``(N, M)`` holding the basis values.
    """
    num_inputs = input_points.size(0)
    num_controls = control_points.size(0)
    deltas = (input_points.view(num_inputs, 1, 2) -
              control_points.view(1, num_controls, 2))

    # Add the two squared components by hand; the original author noted
    # that torch.sum over the last axis was very slow here.
    dx = deltas[:, :, 0]
    dy = deltas[:, :, 1]
    squared_dist = dx * dx + dy * dy

    basis = 0.5 * squared_dist * torch.log(squared_dist)
    # 0 * log(0) evaluates to NaN; overwrite those entries in place.
    basis.masked_fill_(torch.isnan(basis), 0)
    return basis
|
||||
|
||||
|
||||
# output_ctrl_pts are specified, according to our task.
|
||||
def build_output_control_points(num_control_points, margins):
    """Build the fixed target control points along the top and bottom edges.

    Args:
        num_control_points: Total number of points requested; half of
            them (integer division) are placed on each edge.
        margins: Pair ``(margin_x, margin_y)``. The x coordinates are
            evenly spaced from ``margin_x`` to ``1 - margin_x``; the top
            row sits at ``y = margin_y`` and the bottom row at
            ``y = 1 - margin_y``.

    Returns:
        ``torch.Tensor`` of ``(x, y)`` rows, top edge first and then the
        bottom edge.
    """
    margin_x, margin_y = margins
    per_side = num_control_points // 2
    xs = np.linspace(margin_x, 1.0 - margin_x, per_side)

    # One (per_side, 2) array per edge, sharing the same x positions.
    edges = [
        np.stack([xs, np.ones(per_side) * edge_y], axis=1)
        for edge_y in (margin_y, 1.0 - margin_y)
    ]
    return torch.Tensor(np.concatenate(edges, axis=0))
|
||||
|
||||
|
||||
# demo: ~/test/models/test_tps_transformation.py
|
||||
class TPSSpatialTransformer(nn.Module):
    """Spatial transformer that warps images through control-point pairs.

    Everything that depends only on the target (output) side is computed
    once in ``__init__`` and stored as buffers: the inverse of the padded
    kernel matrix, a ``3 x 2`` zero padding block, the representation of
    every target pixel, and the target control points themselves.
    ``forward`` then only needs two matrix products and one sampling call.
    """

    def __init__(self,
                 output_image_size=None,
                 num_control_points=None,
                 margins=None):
        """Precompute the target-side matrices.

        Args:
            output_image_size: Pair ``(height, width)`` of the output.
            num_control_points: Number of control points ``N``.
            margins: Pair passed to ``build_output_control_points``.
        """
        super(TPSSpatialTransformer, self).__init__()
        self.output_image_size = output_image_size
        self.num_control_points = num_control_points
        self.margins = margins
        self.target_height, self.target_width = output_image_size

        ctrl_pts = build_output_control_points(num_control_points, margins)
        n = num_control_points

        # Padded (N + 3) x (N + 3) kernel: the N x N basis block, a column
        # and row of ones, and the control point coordinates with their
        # transpose. The bottom-right 3 x 3 corner stays zero.
        kernel = torch.zeros(n + 3, n + 3)
        kernel[:n, :n] = compute_partial_repr(ctrl_pts, ctrl_pts)
        kernel[:n, -3] = 1
        kernel[-3, :n] = 1
        kernel[:n, -2:] = ctrl_pts
        kernel[-2:, :n] = ctrl_pts.transpose(0, 1)
        inverse_kernel = torch.inverse(kernel)

        # Enumerate every output pixel as (y, x) in row-major order, scale
        # both axes so the last index maps to 1, then store as (x, y).
        num_pixels = self.target_height * self.target_width
        pixel_yx = torch.Tensor(
            list(
                itertools.product(
                    range(self.target_height), range(self.target_width))))
        ys = pixel_yx[:, 0:1] / (self.target_height - 1)
        xs = pixel_yx[:, 1:2] / (self.target_width - 1)
        pixel_xy = torch.cat([xs, ys], dim=1)

        # Each pixel row is [basis values vs. control points, 1, x, y].
        target_coordinate_repr = torch.cat(
            [
                compute_partial_repr(pixel_xy, ctrl_pts),
                torch.ones(num_pixels, 1),
                pixel_xy,
            ],
            dim=1)

        # Buffers follow the module across devices without being trained.
        self.register_buffer('inverse_kernel', inverse_kernel)
        self.register_buffer('padding_matrix', torch.zeros(3, 2))
        self.register_buffer('target_coordinate_repr', target_coordinate_repr)
        self.register_buffer('target_control_points', ctrl_pts)

    def forward(self, input, source_control_points):
        """Warp ``input`` according to ``source_control_points``.

        Args:
            input: Batch handed to ``grid_sample`` as the sampling source.
            source_control_points: Tensor asserted to be of shape
                ``(batch, self.num_control_points, 2)``.

        Returns:
            Tuple ``(output_maps, source_coordinate)``: the sampled output
            and the un-clamped source coordinate of every target pixel.
        """
        assert source_control_points.dim() == 3
        assert source_control_points.shape[1] == self.num_control_points
        assert source_control_points.shape[2] == 2
        batch_size = source_control_points.shape[0]

        # Append three zero rows per sample to match the padded kernel.
        zero_rows = self.padding_matrix.expand(batch_size, 3, 2)
        rhs = torch.cat([source_control_points, zero_rows], 1)
        mapping_matrix = torch.matmul(self.inverse_kernel, rhs)
        source_coordinate = torch.matmul(self.target_coordinate_repr,
                                         mapping_matrix)

        grid = source_coordinate.view(-1, self.target_height,
                                      self.target_width, 2)
        # The source control points may fall outside [0, 1], so clamp
        # first; grid_sample expects [-1, 1], so rescale afterwards.
        grid = 2.0 * torch.clamp(grid, 0, 1) - 1.0
        return grid_sample(input, grid, canvas=None), source_coordinate
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: predict control points with the STN head on random
    # images, rectify the images with the TPS layer and print both shapes.
    from stn_torch import STNHead

    torch.manual_seed(10)
    localization_net = STNHead(3, 20, 'none')  # activation may be 'sigmoid'
    np.random.seed(100)
    images = torch.tensor(np.random.randn(10, 3, 32, 64).astype("float32"))
    predicted = localization_net(images)

    transformer = TPSSpatialTransformer(
        output_image_size=[32, 320],
        num_control_points=20,
        margins=[0.05, 0.05])
    rectified, sampling_coords = transformer(images, predicted[1])
    print("out 0 :", rectified.shape)
    print("out 1:", sampling_coords.shape)
|
|
@ -170,8 +170,10 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
|||
def add_special_char(self, dict_character):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
self.unkonwn = "UNKNOWN"
|
||||
dict_character = dict_character
|
||||
dict_character = [self.beg_str] + dict_character + [self.end_str]
|
||||
dict_character = [self.beg_str] + dict_character + [self.end_str
|
||||
] + [self.unkonwn]
|
||||
return dict_character
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
|
@ -212,6 +214,7 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
|||
label = self.decode(label, is_remove_duplicate=False)
|
||||
return text, label
|
||||
"""
|
||||
preds = preds["rec_pred"]
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
|
||||
|
@ -324,10 +327,9 @@ class SRNLabelDecode(BaseRecLabelDecode):
|
|||
class TableLabelDecode(object):
|
||||
""" """
|
||||
|
||||
def __init__(self,
|
||||
character_dict_path,
|
||||
**kwargs):
|
||||
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
|
||||
def __init__(self, character_dict_path, **kwargs):
|
||||
list_character, list_elem = self.load_char_elem_dict(
|
||||
character_dict_path)
|
||||
list_character = self.add_special_char(list_character)
|
||||
list_elem = self.add_special_char(list_elem)
|
||||
self.dict_character = {}
|
||||
|
@ -366,14 +368,14 @@ class TableLabelDecode(object):
|
|||
def __call__(self, preds):
|
||||
structure_probs = preds['structure_probs']
|
||||
loc_preds = preds['loc_preds']
|
||||
if isinstance(structure_probs,paddle.Tensor):
|
||||
if isinstance(structure_probs, paddle.Tensor):
|
||||
structure_probs = structure_probs.numpy()
|
||||
if isinstance(loc_preds,paddle.Tensor):
|
||||
if isinstance(loc_preds, paddle.Tensor):
|
||||
loc_preds = loc_preds.numpy()
|
||||
structure_idx = structure_probs.argmax(axis=2)
|
||||
structure_probs = structure_probs.max(axis=2)
|
||||
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
|
||||
structure_probs, 'elem')
|
||||
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
|
||||
structure_idx, structure_probs, 'elem')
|
||||
res_html_code_list = []
|
||||
res_loc_list = []
|
||||
batch_num = len(structure_str)
|
||||
|
@ -388,8 +390,13 @@ class TableLabelDecode(object):
|
|||
res_loc = np.array(res_loc)
|
||||
res_html_code_list.append(res_html_code)
|
||||
res_loc_list.append(res_loc)
|
||||
return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
|
||||
'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str}
|
||||
return {
|
||||
'res_html_code': res_html_code_list,
|
||||
'res_loc': res_loc_list,
|
||||
'res_score_list': result_score_list,
|
||||
'res_elem_idx_list': result_elem_idx_list,
|
||||
'structure_str_list': structure_str
|
||||
}
|
||||
|
||||
def decode(self, text_index, structure_probs, char_or_elem):
|
||||
"""convert text-label into text-index.
|
||||
|
|
|
@ -105,13 +105,16 @@ def load_dygraph_params(config, model, logger, optimizer):
|
|||
params = paddle.load(pm)
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
||||
if list(state_dict[k1].shape) == list(params[k2].shape):
|
||||
new_state_dict[k1] = params[k2]
|
||||
else:
|
||||
logger.info(
|
||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
||||
)
|
||||
# for k1, k2 in zip(state_dict.keys(), params.keys()):
|
||||
for k1 in state_dict.keys():
|
||||
if k1 not in params:
|
||||
continue
|
||||
if list(state_dict[k1].shape) == list(params[k1].shape):
|
||||
new_state_dict[k1] = params[k1]
|
||||
else:
|
||||
logger.info(
|
||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k1} {params[k1].shape} !"
|
||||
)
|
||||
model.set_state_dict(new_state_dict)
|
||||
logger.info(f"loaded pretrained_model successful from {pm}")
|
||||
return {}
|
||||
|
|
|
@ -187,6 +187,7 @@ def train(config,
|
|||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
model_type = config['Architecture']['model_type']
|
||||
algorithm = config['Architecture']['algorithm']
|
||||
|
||||
if 'start_epoch' in best_model_dict:
|
||||
start_epoch = best_model_dict['start_epoch']
|
||||
|
@ -210,10 +211,14 @@ def train(config,
|
|||
images = batch[0]
|
||||
if use_srn:
|
||||
model_average = True
|
||||
if use_srn or model_type == 'table':
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
# if use_srn or model_type == 'table' or algorithm == "ASTER":
|
||||
# preds = model(images, data=batch[1:])
|
||||
# else:
|
||||
# preds = model(images)
|
||||
preds = model(images, data=batch[1:])
|
||||
state_dict = model.state_dict()
|
||||
# for key in state_dict:
|
||||
# print(key)
|
||||
loss = loss_class(preds, batch)
|
||||
avg_loss = loss['loss']
|
||||
avg_loss.backward()
|
||||
|
@ -395,7 +400,7 @@ def preprocess(is_train=False):
|
|||
alg = config['Architecture']['algorithm']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'TableAttn'
|
||||
'CLS', 'PGNet', 'Distillation', 'TableAttn', 'ASTER'
|
||||
]
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
|
|
|
@ -72,6 +72,8 @@ def main(config, device, logger, vdl_writer):
|
|||
# for rec algorithm
|
||||
if hasattr(post_process_class, 'character'):
|
||||
char_num = len(getattr(post_process_class, 'character'))
|
||||
character = getattr(post_process_class, 'character')
|
||||
print("getattr character:", character)
|
||||
if config['Architecture']["algorithm"] in ["Distillation",
|
||||
]: # distillation model
|
||||
for key in config['Architecture']["Models"]:
|
||||
|
|
Loading…
Reference in New Issue