polish code
This commit is contained in:
parent
59cc4efdc5
commit
c9e1077daa
|
@ -1,9 +1,9 @@
|
|||
Global:
|
||||
use_gpu: False
|
||||
use_gpu: True
|
||||
epoch_num: 400
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec/b3_rare_r34_none_gru/
|
||||
save_model_dir: ./output/rec/seed
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [0, 2000]
|
||||
|
@ -12,28 +12,32 @@ Global:
|
|||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
character_type: EN_symbol
|
||||
max_text_length: 25
|
||||
max_text_length: 100
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
save_res_path: ./output/rec/predicts_b3_rare_r34_none_gru.txt
|
||||
eval_filter: True
|
||||
save_res_path: ./output/rec/predicts_seed.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
name: Adadelta
|
||||
weight_deacy: 0.0
|
||||
momentum: 0.9
|
||||
lr:
|
||||
learning_rate: 0.0005
|
||||
name: Piecewise
|
||||
decay_epochs: [4,5,8]
|
||||
values: [1.0, 0.1, 0.01]
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0.00000
|
||||
factor: 2.0e-05
|
||||
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
model_type: seed
|
||||
algorithm: ASTER
|
||||
Transform:
|
||||
name: STN_ON
|
||||
|
@ -54,48 +58,49 @@ Loss:
|
|||
name: AsterLoss
|
||||
|
||||
PostProcess:
|
||||
name: AttnLabelDecode
|
||||
name: SEEDLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
is_filter: True
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/ic15_data/
|
||||
label_file_list: ["./train_data/ic15_data/1.txt"]
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- Fasttext:
|
||||
path: "./cc.en.300.bin"
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- AttnLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 100]
|
||||
- SEEDLabelEncode: # Class handling label
|
||||
- SEEDResize:
|
||||
image_shape: [3, 64, 256]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
keep_keys: ['image', 'label', 'length', 'fast_label'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 2
|
||||
batch_size_per_card: 256
|
||||
drop_last: True
|
||||
num_workers: 8
|
||||
num_workers: 6
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/ic15_data/
|
||||
label_file_list: ["./train_data/ic15_data/1.txt"]
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/evaluation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- AttnLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 100]
|
||||
- SEEDLabelEncode: # Class handling label
|
||||
- SEEDResize:
|
||||
image_shape: [3, 64, 256]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 2
|
||||
num_workers: 8
|
||||
drop_last: True
|
||||
batch_size_per_card: 256
|
||||
num_workers: 4
|
||||
|
|
|
@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
|
|||
from .make_shrink_map import MakeShrinkMap
|
||||
from .random_crop_data import EastRandomCropData, PSERandomCrop
|
||||
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, SEEDResize
|
||||
from .randaugment import RandAugment
|
||||
from .copy_paste import CopyPaste
|
||||
from .operators import *
|
||||
|
|
|
@ -276,9 +276,7 @@ class AttnLabelEncode(BaseRecLabelEncode):
|
|||
def add_special_char(self, dict_character):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
self.unknown = "UNKNOWN"
|
||||
dict_character = [self.beg_str] + dict_character + [self.end_str
|
||||
] + [self.unknown]
|
||||
dict_character = [self.beg_str] + dict_character + [self.end_str]
|
||||
return dict_character
|
||||
|
||||
def __call__(self, data):
|
||||
|
@ -291,7 +289,6 @@ class AttnLabelEncode(BaseRecLabelEncode):
|
|||
data['length'] = np.array(len(text))
|
||||
text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
|
||||
- len(text) - 2)
|
||||
|
||||
data['label'] = np.array(text)
|
||||
return data
|
||||
|
||||
|
@ -311,6 +308,39 @@ class AttnLabelEncode(BaseRecLabelEncode):
|
|||
return idx
|
||||
|
||||
|
||||
class SEEDLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
character_type='ch',
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
super(SEEDLabelEncode,
|
||||
self).__init__(max_text_length, character_dict_path,
|
||||
character_type, use_space_char)
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
dict_character = dict_character + [self.end_str]
|
||||
return dict_character
|
||||
|
||||
def __call__(self, data):
|
||||
text = data['label']
|
||||
text = self.encode(text)
|
||||
if text is None:
|
||||
return None
|
||||
if len(text) >= self.max_text_len:
|
||||
return None
|
||||
data['length'] = np.array(len(text)) + 1 # conclue eos
|
||||
text = text + [len(self.character) - 1] * (self.max_text_len - len(text)
|
||||
)
|
||||
data['label'] = np.array(text)
|
||||
return data
|
||||
|
||||
|
||||
class SRNLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ import sys
|
|||
import six
|
||||
import cv2
|
||||
import numpy as np
|
||||
import fasttext
|
||||
|
||||
|
||||
class DecodeImage(object):
|
||||
|
@ -81,7 +82,7 @@ class NormalizeImage(object):
|
|||
assert isinstance(img,
|
||||
np.ndarray), "invalid input 'img' in NormalizeImage"
|
||||
data['image'] = (
|
||||
img.astype('float32') * self.scale - self.mean) / self.std
|
||||
img.astype('float32') * self.scale - self.mean) / self.std
|
||||
return data
|
||||
|
||||
|
||||
|
@ -101,6 +102,17 @@ class ToCHWImage(object):
|
|||
return data
|
||||
|
||||
|
||||
class Fasttext(object):
|
||||
def __init__(self, path="None", **kwargs):
|
||||
self.fast_model = fasttext.load_model(path)
|
||||
|
||||
def __call__(self, data):
|
||||
label = data['label']
|
||||
fast_label = self.fast_model[label]
|
||||
data['fast_label'] = fast_label
|
||||
return data
|
||||
|
||||
|
||||
class KeepKeys(object):
|
||||
def __init__(self, keep_keys, **kwargs):
|
||||
self.keep_keys = keep_keys
|
||||
|
@ -183,7 +195,7 @@ class DetResizeForTest(object):
|
|||
else:
|
||||
ratio = 1.
|
||||
elif self.limit_type == 'resize_long':
|
||||
ratio = float(limit_side_len) / max(h,w)
|
||||
ratio = float(limit_side_len) / max(h, w)
|
||||
else:
|
||||
raise Exception('not support limit type, image ')
|
||||
resize_h = int(h * ratio)
|
||||
|
|
|
@ -63,6 +63,18 @@ class RecResizeImg(object):
|
|||
return data
|
||||
|
||||
|
||||
class SEEDResize(object):
|
||||
def __init__(self, image_shape, infer_mode=False, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.infer_mode = infer_mode
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
norm_img = resize_no_padding_img(img, self.image_shape)
|
||||
data['image'] = norm_img
|
||||
return data
|
||||
|
||||
|
||||
class SRNRecResizeImg(object):
|
||||
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
|
@ -106,6 +118,17 @@ def resize_norm_img(img, image_shape):
|
|||
return padding_im
|
||||
|
||||
|
||||
def resize_no_padding_img(img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
return resized_image
|
||||
|
||||
|
||||
def resize_norm_img_chinese(img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
# todo: change to 0 and modified image shape
|
||||
|
|
|
@ -22,7 +22,6 @@ from .imaug import transform, create_operators
|
|||
|
||||
class SimpleDataSet(Dataset):
|
||||
def __init__(self, config, mode, logger, seed=None):
|
||||
print("===== simpledataset ========")
|
||||
super(SimpleDataSet, self).__init__()
|
||||
self.logger = logger
|
||||
self.mode = mode.lower()
|
||||
|
|
|
@ -18,7 +18,26 @@ from __future__ import print_function
|
|||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import fasttext
|
||||
|
||||
|
||||
class CosineEmbeddingLoss(nn.Layer):
|
||||
def __init__(self, margin=0.):
|
||||
super(CosineEmbeddingLoss, self).__init__()
|
||||
self.margin = margin
|
||||
|
||||
def forward(self, x1, x2, target):
|
||||
similarity = paddle.fluid.layers.reduce_sum(
|
||||
x1 * x2, dim=-1) / (paddle.norm(
|
||||
x1, axis=-1) * paddle.norm(
|
||||
x2, axis=-1))
|
||||
one_list = paddle.full_like(target, fill_value=1)
|
||||
out = paddle.fluid.layers.reduce_mean(
|
||||
paddle.where(
|
||||
paddle.equal(target, one_list), 1. - similarity,
|
||||
paddle.maximum(
|
||||
paddle.zeros_like(similarity), similarity - self.margin)))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class AsterLoss(nn.Layer):
|
||||
|
@ -35,28 +54,28 @@ class AsterLoss(nn.Layer):
|
|||
self.ignore_index = ignore_index
|
||||
self.sequence_normalize = sequence_normalize
|
||||
self.sample_normalize = sample_normalize
|
||||
self.loss_func = paddle.nn.CosineSimilarity()
|
||||
self.loss_sem = CosineEmbeddingLoss()
|
||||
self.is_cosin_loss = True
|
||||
self.loss_func_rec = nn.CrossEntropyLoss(weight=None, reduction='none')
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
targets = batch[1].astype("int64")
|
||||
label_lengths = batch[2].astype('int64')
|
||||
# sem_target = batch[3].astype('float32')
|
||||
sem_target = batch[3].astype('float32')
|
||||
embedding_vectors = predicts['embedding_vectors']
|
||||
rec_pred = predicts['rec_pred']
|
||||
|
||||
# semantic loss
|
||||
# print(embedding_vectors)
|
||||
# print(embedding_vectors.shape)
|
||||
# targets = fasttext[targets]
|
||||
# sem_loss = 1 - self.loss_func(embedding_vectors, targets)
|
||||
if not self.is_cosin_loss:
|
||||
sem_loss = paddle.sum(self.loss_sem(embedding_vectors, sem_target))
|
||||
else:
|
||||
label_target = paddle.ones([embedding_vectors.shape[0]])
|
||||
sem_loss = paddle.sum(
|
||||
self.loss_sem(embedding_vectors, sem_target, label_target))
|
||||
|
||||
# rec loss
|
||||
batch_size, num_steps, num_classes = rec_pred.shape[0], rec_pred.shape[
|
||||
1], rec_pred.shape[2]
|
||||
assert len(targets.shape) == len(list(rec_pred.shape)) - 1, \
|
||||
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
|
||||
batch_size, def_max_length = targets.shape[0], targets.shape[1]
|
||||
|
||||
mask = paddle.zeros([batch_size, num_steps])
|
||||
mask = paddle.zeros([batch_size, def_max_length])
|
||||
for i in range(batch_size):
|
||||
mask[i, :label_lengths[i]] = 1
|
||||
mask = paddle.cast(mask, "float32")
|
||||
|
@ -64,16 +83,16 @@ class AsterLoss(nn.Layer):
|
|||
assert max_length == rec_pred.shape[1]
|
||||
targets = targets[:, :max_length]
|
||||
mask = mask[:, :max_length]
|
||||
rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[-1]])
|
||||
rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[2]])
|
||||
input = nn.functional.log_softmax(rec_pred, axis=1)
|
||||
targets = paddle.reshape(targets, [-1, 1])
|
||||
mask = paddle.reshape(mask, [-1, 1])
|
||||
# print("input:", input)
|
||||
output = -paddle.gather(input, index=targets, axis=1) * mask
|
||||
output = -paddle.index_sample(input, index=targets) * mask
|
||||
output = paddle.sum(output)
|
||||
if self.sequence_normalize:
|
||||
output = output / paddle.sum(mask)
|
||||
if self.sample_normalize:
|
||||
output = output / batch_size
|
||||
loss = output
|
||||
return {'loss': loss} # , 'sem_loss':sem_loss}
|
||||
|
||||
loss = output + sem_loss * 0.1
|
||||
return {'loss': loss}
|
||||
|
|
|
@ -35,7 +35,5 @@ class AttentionLoss(nn.Layer):
|
|||
|
||||
inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
|
||||
targets = paddle.reshape(targets, [-1])
|
||||
print("input:", paddle.argmax(inputs, axis=1))
|
||||
print("targets:", targets)
|
||||
|
||||
return {'loss': paddle.sum(self.loss_func(inputs, targets))}
|
||||
|
|
|
@ -13,13 +13,20 @@
|
|||
# limitations under the License.
|
||||
|
||||
import Levenshtein
|
||||
import string
|
||||
|
||||
|
||||
class RecMetric(object):
|
||||
def __init__(self, main_indicator='acc', **kwargs):
|
||||
def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.is_filter = is_filter
|
||||
self.reset()
|
||||
|
||||
def _normalize_text(self, text):
|
||||
text = ''.join(
|
||||
filter(lambda x: x in (string.digits + string.ascii_letters), text))
|
||||
return text.lower()
|
||||
|
||||
def __call__(self, pred_label, *args, **kwargs):
|
||||
preds, labels = pred_label
|
||||
correct_num = 0
|
||||
|
@ -28,6 +35,9 @@ class RecMetric(object):
|
|||
for (pred, pred_conf), (target, _) in zip(preds, labels):
|
||||
pred = pred.replace(" ", "")
|
||||
target = target.replace(" ", "")
|
||||
if self.is_filter:
|
||||
pred = self._normalize_text(pred)
|
||||
target = self._normalize_text(target)
|
||||
norm_edit_dis += Levenshtein.distance(pred, target) / max(
|
||||
len(pred), len(target), 1)
|
||||
if pred == target:
|
||||
|
|
|
@ -26,10 +26,8 @@ def build_backbone(config, model_type):
|
|||
from .rec_resnet_vd import ResNet
|
||||
from .rec_resnet_fpn import ResNetFPN
|
||||
from .rec_mv1_enhance import MobileNetV1Enhance
|
||||
from .rec_resnet_aster import ResNet_ASTER
|
||||
support_dict = [
|
||||
"MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN",
|
||||
"ResNet_ASTER"
|
||||
"MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
|
||||
]
|
||||
elif model_type == "e2e":
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
|
@ -38,6 +36,9 @@ def build_backbone(config, model_type):
|
|||
from .table_resnet_vd import ResNet
|
||||
from .table_mobilenet_v3 import MobileNetV3
|
||||
support_dict = ["ResNet", "MobileNetV3"]
|
||||
elif model_type == "seed":
|
||||
from .rec_resnet_aster import ResNet_ASTER
|
||||
support_dict = ["ResNet_ASTER"]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -1,707 +0,0 @@
|
|||
# Copyright (c) 2015-present, Facebook, Inc.
|
||||
# All rights reserved.
|
||||
|
||||
# Modified from
|
||||
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
||||
# Copyright 2020 Ross Wightman, Apache-2.0 License
|
||||
|
||||
import paddle
|
||||
import itertools
|
||||
#import utils
|
||||
import math
|
||||
import warnings
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn.initializer import TruncatedNormal, Constant
|
||||
|
||||
#from timm.models.vision_transformer import trunc_normal_
|
||||
#from timm.models.registry import register_model
|
||||
|
||||
specification = {
|
||||
'LeViT_128S': {
|
||||
'C': '128_256_384',
|
||||
'D': 16,
|
||||
'N': '4_6_8',
|
||||
'X': '2_3_4',
|
||||
'drop_path': 0,
|
||||
'weights':
|
||||
'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'
|
||||
},
|
||||
'LeViT_128': {
|
||||
'C': '128_256_384',
|
||||
'D': 16,
|
||||
'N': '4_8_12',
|
||||
'X': '4_4_4',
|
||||
'drop_path': 0,
|
||||
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'
|
||||
},
|
||||
'LeViT_192': {
|
||||
'C': '192_288_384',
|
||||
'D': 32,
|
||||
'N': '3_5_6',
|
||||
'X': '4_4_4',
|
||||
'drop_path': 0,
|
||||
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'
|
||||
},
|
||||
'LeViT_256': {
|
||||
'C': '256_384_512',
|
||||
'D': 32,
|
||||
'N': '4_6_8',
|
||||
'X': '4_4_4',
|
||||
'drop_path': 0,
|
||||
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'
|
||||
},
|
||||
'LeViT_384': {
|
||||
'C': '384_512_768',
|
||||
'D': 32,
|
||||
'N': '6_9_12',
|
||||
'X': '4_4_4',
|
||||
'drop_path': 0.1,
|
||||
'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'
|
||||
},
|
||||
}
|
||||
|
||||
__all__ = [specification.keys()]
|
||||
|
||||
trunc_normal_ = TruncatedNormal(std=.02)
|
||||
zeros_ = Constant(value=0.)
|
||||
ones_ = Constant(value=1.)
|
||||
|
||||
|
||||
#@register_model
|
||||
def LeViT_128S(class_dim=1000, distillation=True, pretrained=False, fuse=False):
|
||||
return model_factory(
|
||||
**specification['LeViT_128S'],
|
||||
class_dim=class_dim,
|
||||
distillation=distillation,
|
||||
pretrained=pretrained,
|
||||
fuse=fuse)
|
||||
|
||||
|
||||
#@register_model
|
||||
def LeViT_128(class_dim=1000, distillation=True, pretrained=False, fuse=False):
|
||||
return model_factory(
|
||||
**specification['LeViT_128'],
|
||||
class_dim=class_dim,
|
||||
distillation=distillation,
|
||||
pretrained=pretrained,
|
||||
fuse=fuse)
|
||||
|
||||
|
||||
#@register_model
|
||||
def LeViT_192(class_dim=1000, distillation=True, pretrained=False, fuse=False):
|
||||
return model_factory(
|
||||
**specification['LeViT_192'],
|
||||
class_dim=class_dim,
|
||||
distillation=distillation,
|
||||
pretrained=pretrained,
|
||||
fuse=fuse)
|
||||
|
||||
|
||||
#@register_model
|
||||
def LeViT_256(class_dim=1000, distillation=False, pretrained=False, fuse=False):
|
||||
return model_factory(
|
||||
**specification['LeViT_256'],
|
||||
class_dim=class_dim,
|
||||
distillation=distillation,
|
||||
pretrained=pretrained,
|
||||
fuse=fuse)
|
||||
|
||||
|
||||
#@register_model
|
||||
def LeViT_384(class_dim=1000, distillation=True, pretrained=False, fuse=False):
|
||||
return model_factory(
|
||||
**specification['LeViT_384'],
|
||||
class_dim=class_dim,
|
||||
distillation=distillation,
|
||||
pretrained=pretrained,
|
||||
fuse=fuse)
|
||||
|
||||
|
||||
FLOPS_COUNTER = 0
|
||||
|
||||
|
||||
class Conv2d_BN(paddle.nn.Sequential):
|
||||
def __init__(self,
|
||||
a,
|
||||
b,
|
||||
ks=1,
|
||||
stride=1,
|
||||
pad=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bn_weight_init=1,
|
||||
resolution=-10000):
|
||||
super().__init__()
|
||||
self.add_sublayer(
|
||||
'c',
|
||||
paddle.nn.Conv2D(
|
||||
a, b, ks, stride, pad, dilation, groups, bias_attr=False))
|
||||
bn = paddle.nn.BatchNorm2D(b)
|
||||
ones_(bn.weight)
|
||||
zeros_(bn.bias)
|
||||
self.add_sublayer('bn', bn)
|
||||
|
||||
global FLOPS_COUNTER
|
||||
output_points = (
|
||||
(resolution + 2 * pad - dilation * (ks - 1) - 1) // stride + 1)**2
|
||||
FLOPS_COUNTER += a * b * output_points * (ks**2)
|
||||
|
||||
@paddle.no_grad()
|
||||
def fuse(self):
|
||||
c, bn = self._modules.values()
|
||||
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
||||
w = c.weight * w[:, None, None, None]
|
||||
b = bn.bias - bn.running_mean * bn.weight / \
|
||||
(bn.running_var + bn.eps)**0.5
|
||||
m = paddle.nn.Conv2D(
|
||||
w.size(1),
|
||||
w.size(0),
|
||||
w.shape[2:],
|
||||
stride=self.c.stride,
|
||||
padding=self.c.padding,
|
||||
dilation=self.c.dilation,
|
||||
groups=self.c.groups)
|
||||
m.weight.data.copy_(w)
|
||||
m.bias.data.copy_(b)
|
||||
return m
|
||||
|
||||
|
||||
class Linear_BN(paddle.nn.Sequential):
|
||||
def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
|
||||
super().__init__()
|
||||
self.add_sublayer('c', paddle.nn.Linear(a, b, bias_attr=False))
|
||||
bn = paddle.nn.BatchNorm1D(b)
|
||||
ones_(bn.weight)
|
||||
zeros_(bn.bias)
|
||||
self.add_sublayer('bn', bn)
|
||||
|
||||
global FLOPS_COUNTER
|
||||
output_points = resolution**2
|
||||
FLOPS_COUNTER += a * b * output_points
|
||||
|
||||
@paddle.no_grad()
|
||||
def fuse(self):
|
||||
l, bn = self._modules.values()
|
||||
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
||||
w = l.weight * w[:, None]
|
||||
b = bn.bias - bn.running_mean * bn.weight / \
|
||||
(bn.running_var + bn.eps)**0.5
|
||||
m = paddle.nn.Linear(w.size(1), w.size(0))
|
||||
m.weight.data.copy_(w)
|
||||
m.bias.data.copy_(b)
|
||||
return m
|
||||
|
||||
def forward(self, x):
|
||||
l, bn = self._sub_layers.values()
|
||||
x = l(x)
|
||||
return paddle.reshape(bn(x.flatten(0, 1)), x.shape)
|
||||
|
||||
|
||||
class BN_Linear(paddle.nn.Sequential):
|
||||
def __init__(self, a, b, bias=True, std=0.02):
|
||||
super().__init__()
|
||||
self.add_sublayer('bn', paddle.nn.BatchNorm1D(a))
|
||||
l = paddle.nn.Linear(a, b, bias_attr=bias)
|
||||
trunc_normal_(l.weight)
|
||||
if bias:
|
||||
zeros_(l.bias)
|
||||
self.add_sublayer('l', l)
|
||||
global FLOPS_COUNTER
|
||||
FLOPS_COUNTER += a * b
|
||||
|
||||
@paddle.no_grad()
|
||||
def fuse(self):
|
||||
bn, l = self._modules.values()
|
||||
w = bn.weight / (bn.running_var + bn.eps)**0.5
|
||||
b = bn.bias - self.bn.running_mean * \
|
||||
self.bn.weight / (bn.running_var + bn.eps)**0.5
|
||||
w = l.weight * w[None, :]
|
||||
if l.bias is None:
|
||||
b = b @self.l.weight.T
|
||||
else:
|
||||
b = (l.weight @b[:, None]).view(-1) + self.l.bias
|
||||
m = paddle.nn.Linear(w.size(1), w.size(0))
|
||||
m.weight.data.copy_(w)
|
||||
m.bias.data.copy_(b)
|
||||
return m
|
||||
|
||||
|
||||
def b16(n, activation, resolution=224):
|
||||
return paddle.nn.Sequential(
|
||||
Conv2d_BN(
|
||||
3, n // 8, 3, 2, 1, resolution=resolution),
|
||||
activation(),
|
||||
Conv2d_BN(
|
||||
n // 8, n // 4, 3, 2, 1, resolution=resolution // 2),
|
||||
activation(),
|
||||
Conv2d_BN(
|
||||
n // 4, n // 2, 3, 2, 1, resolution=resolution // 4),
|
||||
activation(),
|
||||
Conv2d_BN(
|
||||
n // 2, n, 3, 2, 1, resolution=resolution // 8))
|
||||
|
||||
|
||||
class Residual(paddle.nn.Layer):
|
||||
def __init__(self, m, drop):
|
||||
super().__init__()
|
||||
self.m = m
|
||||
self.drop = drop
|
||||
|
||||
def forward(self, x):
|
||||
if self.training and self.drop > 0:
|
||||
return x + self.m(x) * paddle.rand(
|
||||
x.size(0), 1, 1,
|
||||
device=x.device).ge_(self.drop).div(1 - self.drop).detach()
|
||||
else:
|
||||
return x + self.m(x)
|
||||
|
||||
|
||||
class Attention(paddle.nn.Layer):
|
||||
def __init__(self,
|
||||
dim,
|
||||
key_dim,
|
||||
num_heads=8,
|
||||
attn_ratio=4,
|
||||
activation=None,
|
||||
resolution=14):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.scale = key_dim**-0.5
|
||||
self.key_dim = key_dim
|
||||
self.nh_kd = nh_kd = key_dim * num_heads
|
||||
self.d = int(attn_ratio * key_dim)
|
||||
self.dh = int(attn_ratio * key_dim) * num_heads
|
||||
self.attn_ratio = attn_ratio
|
||||
self.h = self.dh + nh_kd * 2
|
||||
self.qkv = Linear_BN(dim, self.h, resolution=resolution)
|
||||
self.proj = paddle.nn.Sequential(
|
||||
activation(),
|
||||
Linear_BN(
|
||||
self.dh, dim, bn_weight_init=0, resolution=resolution))
|
||||
points = list(itertools.product(range(resolution), range(resolution)))
|
||||
N = len(points)
|
||||
attention_offsets = {}
|
||||
idxs = []
|
||||
for p1 in points:
|
||||
for p2 in points:
|
||||
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
|
||||
if offset not in attention_offsets:
|
||||
attention_offsets[offset] = len(attention_offsets)
|
||||
idxs.append(attention_offsets[offset])
|
||||
self.attention_biases = self.create_parameter(
|
||||
shape=(num_heads, len(attention_offsets)),
|
||||
default_initializer=zeros_)
|
||||
tensor_idxs = paddle.to_tensor(idxs, dtype='int64')
|
||||
self.register_buffer('attention_bias_idxs',
|
||||
paddle.reshape(tensor_idxs, [N, N]))
|
||||
|
||||
global FLOPS_COUNTER
|
||||
#queries * keys
|
||||
FLOPS_COUNTER += num_heads * (resolution**4) * key_dim
|
||||
# softmax
|
||||
FLOPS_COUNTER += num_heads * (resolution**4)
|
||||
#attention * v
|
||||
FLOPS_COUNTER += num_heads * self.d * (resolution**4)
|
||||
|
||||
@paddle.no_grad()
|
||||
def train(self, mode=True):
|
||||
if mode:
|
||||
super().train()
|
||||
else:
|
||||
super().eval()
|
||||
if mode and hasattr(self, 'ab'):
|
||||
del self.ab
|
||||
else:
|
||||
gather_list = []
|
||||
attention_bias_t = paddle.transpose(self.attention_biases, (1, 0))
|
||||
for idx in self.attention_bias_idxs:
|
||||
gather = paddle.gather(attention_bias_t, idx)
|
||||
gather_list.append(gather)
|
||||
attention_biases = paddle.transpose(
|
||||
paddle.concat(gather_list), (1, 0)).reshape(
|
||||
(0, self.attention_bias_idxs.shape[0],
|
||||
self.attention_bias_idxs.shape[1]))
|
||||
self.ab = attention_biases
|
||||
#self.ab = self.attention_biases[:, self.attention_bias_idxs]
|
||||
|
||||
def forward(self, x): # x (B,N,C)
|
||||
self.training = True
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x)
|
||||
qkv = paddle.reshape(qkv,
|
||||
[B, N, self.num_heads, self.h // self.num_heads])
|
||||
q, k, v = paddle.split(
|
||||
qkv, [self.key_dim, self.key_dim, self.d], axis=3)
|
||||
q = paddle.transpose(q, perm=[0, 2, 1, 3])
|
||||
k = paddle.transpose(k, perm=[0, 2, 1, 3])
|
||||
v = paddle.transpose(v, perm=[0, 2, 1, 3])
|
||||
k_transpose = paddle.transpose(k, perm=[0, 1, 3, 2])
|
||||
|
||||
if self.training:
|
||||
gather_list = []
|
||||
attention_bias_t = paddle.transpose(self.attention_biases, (1, 0))
|
||||
for idx in self.attention_bias_idxs:
|
||||
gather = paddle.gather(attention_bias_t, idx)
|
||||
gather_list.append(gather)
|
||||
attention_biases = paddle.transpose(
|
||||
paddle.concat(gather_list), (1, 0)).reshape(
|
||||
(0, self.attention_bias_idxs.shape[0],
|
||||
self.attention_bias_idxs.shape[1]))
|
||||
else:
|
||||
attention_biases = self.ab
|
||||
#np_ = paddle.to_tensor(self.attention_biases.numpy()[:, self.attention_bias_idxs.numpy()])
|
||||
#print(self.attention_bias_idxs.shape)
|
||||
#print(attention_biases.shape)
|
||||
#print(np_.shape)
|
||||
#print(np_.equal(attention_biases))
|
||||
#exit()
|
||||
|
||||
attn = ((q @k_transpose) * self.scale + attention_biases)
|
||||
attn = F.softmax(attn)
|
||||
x = paddle.transpose(attn @v, perm=[0, 2, 1, 3])
|
||||
x = paddle.reshape(x, [B, N, self.dh])
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class Subsample(paddle.nn.Layer):
|
||||
def __init__(self, stride, resolution):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.resolution = resolution
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
x = paddle.reshape(x, [B, self.resolution, self.resolution,
|
||||
C])[:, ::self.stride, ::self.stride]
|
||||
x = paddle.reshape(x, [B, -1, C])
|
||||
return x
|
||||
|
||||
|
||||
class AttentionSubsample(paddle.nn.Layer):
|
||||
def __init__(self,
|
||||
in_dim,
|
||||
out_dim,
|
||||
key_dim,
|
||||
num_heads=8,
|
||||
attn_ratio=2,
|
||||
activation=None,
|
||||
stride=2,
|
||||
resolution=14,
|
||||
resolution_=7):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.scale = key_dim**-0.5
|
||||
self.key_dim = key_dim
|
||||
self.nh_kd = nh_kd = key_dim * num_heads
|
||||
self.d = int(attn_ratio * key_dim)
|
||||
self.dh = int(attn_ratio * key_dim) * self.num_heads
|
||||
self.attn_ratio = attn_ratio
|
||||
self.resolution_ = resolution_
|
||||
self.resolution_2 = resolution_**2
|
||||
self.training = True
|
||||
h = self.dh + nh_kd
|
||||
self.kv = Linear_BN(in_dim, h, resolution=resolution)
|
||||
|
||||
self.q = paddle.nn.Sequential(
|
||||
Subsample(stride, resolution),
|
||||
Linear_BN(
|
||||
in_dim, nh_kd, resolution=resolution_))
|
||||
self.proj = paddle.nn.Sequential(
|
||||
activation(), Linear_BN(
|
||||
self.dh, out_dim, resolution=resolution_))
|
||||
|
||||
self.stride = stride
|
||||
self.resolution = resolution
|
||||
points = list(itertools.product(range(resolution), range(resolution)))
|
||||
points_ = list(
|
||||
itertools.product(range(resolution_), range(resolution_)))
|
||||
|
||||
N = len(points)
|
||||
N_ = len(points_)
|
||||
attention_offsets = {}
|
||||
idxs = []
|
||||
i = 0
|
||||
j = 0
|
||||
for p1 in points_:
|
||||
i += 1
|
||||
for p2 in points:
|
||||
j += 1
|
||||
size = 1
|
||||
offset = (abs(p1[0] * stride - p2[0] + (size - 1) / 2),
|
||||
abs(p1[1] * stride - p2[1] + (size - 1) / 2))
|
||||
if offset not in attention_offsets:
|
||||
attention_offsets[offset] = len(attention_offsets)
|
||||
idxs.append(attention_offsets[offset])
|
||||
self.attention_biases = self.create_parameter(
|
||||
shape=(num_heads, len(attention_offsets)),
|
||||
default_initializer=zeros_)
|
||||
|
||||
tensor_idxs_ = paddle.to_tensor(idxs, dtype='int64')
|
||||
self.register_buffer('attention_bias_idxs',
|
||||
paddle.reshape(tensor_idxs_, [N_, N]))
|
||||
|
||||
global FLOPS_COUNTER
|
||||
#queries * keys
|
||||
FLOPS_COUNTER += num_heads * \
|
||||
(resolution**2) * (resolution_**2) * key_dim
|
||||
# softmax
|
||||
FLOPS_COUNTER += num_heads * (resolution**2) * (resolution_**2)
|
||||
#attention * v
|
||||
FLOPS_COUNTER += num_heads * \
|
||||
(resolution**2) * (resolution_**2) * self.d
|
||||
|
||||
@paddle.no_grad()
|
||||
def train(self, mode=True):
|
||||
if mode:
|
||||
super().train()
|
||||
else:
|
||||
super().eval()
|
||||
if mode and hasattr(self, 'ab'):
|
||||
del self.ab
|
||||
else:
|
||||
gather_list = []
|
||||
attention_bias_t = paddle.transpose(self.attention_biases, (1, 0))
|
||||
for idx in self.attention_bias_idxs:
|
||||
gather = paddle.gather(attention_bias_t, idx)
|
||||
gather_list.append(gather)
|
||||
attention_biases = paddle.transpose(
|
||||
paddle.concat(gather_list), (1, 0)).reshape(
|
||||
(0, self.attention_bias_idxs.shape[0],
|
||||
self.attention_bias_idxs.shape[1]))
|
||||
self.ab = attention_biases
|
||||
#self.ab = self.attention_biases[:, self.attention_bias_idxs]
|
||||
|
||||
def forward(self, x):
|
||||
self.training = True
|
||||
B, N, C = x.shape
|
||||
kv = self.kv(x)
|
||||
kv = paddle.reshape(kv, [B, N, self.num_heads, -1])
|
||||
k, v = paddle.split(kv, [self.key_dim, self.d], axis=3)
|
||||
k = paddle.transpose(k, perm=[0, 2, 1, 3]) # BHNC
|
||||
v = paddle.transpose(v, perm=[0, 2, 1, 3])
|
||||
q = paddle.reshape(
|
||||
self.q(x), [B, self.resolution_2, self.num_heads, self.key_dim])
|
||||
q = paddle.transpose(q, perm=[0, 2, 1, 3])
|
||||
|
||||
if self.training:
|
||||
gather_list = []
|
||||
attention_bias_t = paddle.transpose(self.attention_biases, (1, 0))
|
||||
for idx in self.attention_bias_idxs:
|
||||
gather = paddle.gather(attention_bias_t, idx)
|
||||
gather_list.append(gather)
|
||||
attention_biases = paddle.transpose(
|
||||
paddle.concat(gather_list), (1, 0)).reshape(
|
||||
(0, self.attention_bias_idxs.shape[0],
|
||||
self.attention_bias_idxs.shape[1]))
|
||||
else:
|
||||
attention_biases = self.ab
|
||||
|
||||
attn = (q @paddle.transpose(
|
||||
k, perm=[0, 1, 3, 2])) * self.scale + attention_biases
|
||||
attn = F.softmax(attn)
|
||||
|
||||
x = paddle.reshape(
|
||||
paddle.transpose(
|
||||
(attn @v), perm=[0, 2, 1, 3]), [B, -1, self.dh])
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class LeViT(paddle.nn.Layer):
|
||||
""" Vision Transformer with support for patch or hybrid CNN input stage
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
class_dim=1000,
|
||||
embed_dim=[192],
|
||||
key_dim=[64],
|
||||
depth=[12],
|
||||
num_heads=[3],
|
||||
attn_ratio=[2],
|
||||
mlp_ratio=[2],
|
||||
hybrid_backbone=None,
|
||||
down_ops=[],
|
||||
attention_activation=paddle.nn.Hardswish,
|
||||
mlp_activation=paddle.nn.Hardswish,
|
||||
distillation=True,
|
||||
drop_path=0):
|
||||
super().__init__()
|
||||
global FLOPS_COUNTER
|
||||
|
||||
self.class_dim = class_dim
|
||||
self.num_features = embed_dim[-1]
|
||||
self.embed_dim = embed_dim
|
||||
self.distillation = distillation
|
||||
|
||||
self.patch_embed = hybrid_backbone
|
||||
|
||||
self.blocks = []
|
||||
down_ops.append([''])
|
||||
resolution = img_size // patch_size
|
||||
for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
|
||||
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio,
|
||||
down_ops)):
|
||||
for _ in range(dpth):
|
||||
self.blocks.append(
|
||||
Residual(
|
||||
Attention(
|
||||
ed,
|
||||
kd,
|
||||
nh,
|
||||
attn_ratio=ar,
|
||||
activation=attention_activation,
|
||||
resolution=resolution, ),
|
||||
drop_path))
|
||||
if mr > 0:
|
||||
h = int(ed * mr)
|
||||
self.blocks.append(
|
||||
Residual(
|
||||
paddle.nn.Sequential(
|
||||
Linear_BN(
|
||||
ed, h, resolution=resolution),
|
||||
mlp_activation(),
|
||||
Linear_BN(
|
||||
h,
|
||||
ed,
|
||||
bn_weight_init=0,
|
||||
resolution=resolution), ),
|
||||
drop_path))
|
||||
if do[0] == 'Subsample':
|
||||
#('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
|
||||
resolution_ = (resolution - 1) // do[5] + 1
|
||||
self.blocks.append(
|
||||
AttentionSubsample(
|
||||
*embed_dim[i:i + 2],
|
||||
key_dim=do[1],
|
||||
num_heads=do[2],
|
||||
attn_ratio=do[3],
|
||||
activation=attention_activation,
|
||||
stride=do[5],
|
||||
resolution=resolution,
|
||||
resolution_=resolution_))
|
||||
resolution = resolution_
|
||||
if do[4] > 0: # mlp_ratio
|
||||
h = int(embed_dim[i + 1] * do[4])
|
||||
self.blocks.append(
|
||||
Residual(
|
||||
paddle.nn.Sequential(
|
||||
Linear_BN(
|
||||
embed_dim[i + 1], h, resolution=resolution),
|
||||
mlp_activation(),
|
||||
Linear_BN(
|
||||
h,
|
||||
embed_dim[i + 1],
|
||||
bn_weight_init=0,
|
||||
resolution=resolution), ),
|
||||
drop_path))
|
||||
self.blocks = paddle.nn.Sequential(*self.blocks)
|
||||
|
||||
# Classifier head
|
||||
self.head = BN_Linear(
|
||||
embed_dim[-1], class_dim) if class_dim > 0 else paddle.nn.Identity()
|
||||
if distillation:
|
||||
self.head_dist = BN_Linear(
|
||||
embed_dim[-1],
|
||||
class_dim) if class_dim > 0 else paddle.nn.Identity()
|
||||
|
||||
self.FLOPS = FLOPS_COUNTER
|
||||
FLOPS_COUNTER = 0
|
||||
|
||||
def no_weight_decay(self):
|
||||
return {x for x in self.state_dict().keys() if 'attention_biases' in x}
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = x.flatten(2)
|
||||
x = paddle.transpose(x, perm=[0, 2, 1])
|
||||
x = self.blocks(x)
|
||||
x = x.mean(1)
|
||||
if self.distillation:
|
||||
x = self.head(x), self.head_dist(x)
|
||||
if not self.training:
|
||||
x = (x[0] + x[1]) / 2
|
||||
else:
|
||||
x = self.head(x)
|
||||
return x
|
||||
|
||||
|
||||
def model_factory(C, D, X, N, drop_path, weights, class_dim, distillation,
|
||||
pretrained, fuse):
|
||||
embed_dim = [int(x) for x in C.split('_')]
|
||||
num_heads = [int(x) for x in N.split('_')]
|
||||
depth = [int(x) for x in X.split('_')]
|
||||
act = paddle.nn.Hardswish
|
||||
model = LeViT(
|
||||
patch_size=16,
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
key_dim=[D] * 3,
|
||||
depth=depth,
|
||||
attn_ratio=[2, 2, 2],
|
||||
mlp_ratio=[2, 2, 2],
|
||||
down_ops=[
|
||||
#('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
|
||||
['Subsample', D, embed_dim[0] // D, 4, 2, 2],
|
||||
['Subsample', D, embed_dim[1] // D, 4, 2, 2],
|
||||
],
|
||||
attention_activation=act,
|
||||
mlp_activation=act,
|
||||
hybrid_backbone=b16(embed_dim[0], activation=act),
|
||||
class_dim=class_dim,
|
||||
drop_path=drop_path,
|
||||
distillation=distillation)
|
||||
# if pretrained:
|
||||
# checkpoint = torch.hub.load_state_dict_from_url(
|
||||
# weights, map_location='cpu')
|
||||
# model.load_state_dict(checkpoint['model'])
|
||||
if fuse:
|
||||
utils.replace_batchnorm(model)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
'''
|
||||
import torch
|
||||
checkpoint = torch.load('../LeViT/pretrained256.pth')
|
||||
torch_dict = checkpoint['net']
|
||||
paddle_dict = {}
|
||||
fc_names = ["c.weight", "l.weight", "qkv.weight", "fc1.weight", "fc2.weight", "downsample.reduction.weight", "head.weight", "attn.proj.weight"]
|
||||
rename_dict = {"running_mean": "_mean", "running_var": "_variance"}
|
||||
range_tuple = (0, 502)
|
||||
idx = 0
|
||||
for key in torch_dict:
|
||||
idx += 1
|
||||
weight = torch_dict[key].cpu().numpy()
|
||||
flag = [i in key for i in fc_names]
|
||||
if any(flag):
|
||||
if "emb" not in key:
|
||||
print("weight {} need to be trans".format(key))
|
||||
weight = weight.transpose()
|
||||
key = key.replace("running_mean", "_mean")
|
||||
key = key.replace("running_var", "_variance")
|
||||
paddle_dict[key]=weight
|
||||
'''
|
||||
import numpy as np
|
||||
net = globals()['LeViT_256'](fuse=False,
|
||||
pretrained=False,
|
||||
distillation=False)
|
||||
load_layer_state_dict = paddle.load(
|
||||
"./LeViT_256_official_nodistillation_paddle.pdparams")
|
||||
#net.set_state_dict(paddle_dict)
|
||||
net.set_state_dict(load_layer_state_dict)
|
||||
net.eval()
|
||||
#paddle.save(net.state_dict(), "./LeViT_256_official_paddle.pdparams")
|
||||
#model = paddle.jit.to_static(net,input_spec=[paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype='float32')])
|
||||
#paddle.jit.save(model, "./LeViT_256_official_inference/inference")
|
||||
#exit()
|
||||
np.random.seed(123)
|
||||
img = np.random.rand(1, 3, 224, 224).astype('float32')
|
||||
img = paddle.to_tensor(img)
|
||||
outputs = net(img).numpy()
|
||||
print(outputs[0][:10])
|
||||
#print(outputs.shape)
|
|
@ -42,6 +42,5 @@ def build_head(config):
|
|||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('head only support {}'.format(
|
||||
support_dict))
|
||||
print(config)
|
||||
module_class = eval(module_name)(**config)
|
||||
return module_class
|
||||
|
|
|
@ -43,13 +43,14 @@ class AsterHead(nn.Layer):
|
|||
self.time_step = time_step
|
||||
self.embeder = Embedding(self.time_step, in_channels)
|
||||
self.beam_width = beam_width
|
||||
self.eos = self.num_classes - 1
|
||||
|
||||
def forward(self, x, targets=None, embed=None):
|
||||
return_dict = {}
|
||||
embedding_vectors = self.embeder(x)
|
||||
rec_targets, rec_lengths = targets
|
||||
|
||||
if self.training:
|
||||
rec_targets, rec_lengths, _ = targets
|
||||
rec_pred = self.decoder([x, rec_targets, rec_lengths],
|
||||
embedding_vectors)
|
||||
return_dict['rec_pred'] = rec_pred
|
||||
|
@ -104,14 +105,12 @@ class AttentionRecognitionHead(nn.Layer):
|
|||
# Decoder
|
||||
state = self.decoder.get_initial_state(embed)
|
||||
outputs = []
|
||||
|
||||
for i in range(max(lengths)):
|
||||
if i == 0:
|
||||
y_prev = paddle.full(
|
||||
shape=[batch_size], fill_value=self.num_classes)
|
||||
else:
|
||||
y_prev = targets[:, i - 1]
|
||||
|
||||
output, state = self.decoder(x, state, y_prev)
|
||||
outputs.append(output)
|
||||
outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1)
|
||||
|
@ -142,6 +141,170 @@ class AttentionRecognitionHead(nn.Layer):
|
|||
# return predicted_ids.squeeze(), predicted_scores.squeeze()
|
||||
return predicted_ids, predicted_scores
|
||||
|
||||
def beam_search(self, x, beam_width, eos, embed):
|
||||
def _inflate(tensor, times, dim):
|
||||
repeat_dims = [1] * tensor.dim()
|
||||
repeat_dims[dim] = times
|
||||
output = paddle.tile(tensor, repeat_dims)
|
||||
return output
|
||||
|
||||
# https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py
|
||||
batch_size, l, d = x.shape
|
||||
# inflated_encoder_feats = _inflate(encoder_feats, beam_width, 0) # ABC --> AABBCC -/-> ABCABC
|
||||
x = paddle.tile(
|
||||
paddle.transpose(
|
||||
x.unsqueeze(1), perm=[1, 0, 2, 3]), [beam_width, 1, 1, 1])
|
||||
inflated_encoder_feats = paddle.reshape(
|
||||
paddle.transpose(
|
||||
x, perm=[1, 0, 2, 3]), [-1, l, d])
|
||||
|
||||
# Initialize the decoder
|
||||
state = self.decoder.get_initial_state(embed, tile_times=beam_width)
|
||||
|
||||
pos_index = paddle.reshape(
|
||||
paddle.arange(batch_size) * beam_width, shape=[-1, 1])
|
||||
|
||||
# Initialize the scores
|
||||
sequence_scores = paddle.full(
|
||||
shape=[batch_size * beam_width, 1], fill_value=-float('Inf'))
|
||||
index = [i * beam_width for i in range(0, batch_size)]
|
||||
sequence_scores[index] = 0.0
|
||||
|
||||
# Initialize the input vector
|
||||
y_prev = paddle.full(
|
||||
shape=[batch_size * beam_width], fill_value=self.num_classes)
|
||||
|
||||
# Store decisions for backtracking
|
||||
stored_scores = list()
|
||||
stored_predecessors = list()
|
||||
stored_emitted_symbols = list()
|
||||
|
||||
for i in range(self.max_len_labels):
|
||||
output, state = self.decoder(inflated_encoder_feats, state, y_prev)
|
||||
state = paddle.unsqueeze(state, axis=0)
|
||||
log_softmax_output = paddle.nn.functional.log_softmax(
|
||||
output, axis=1)
|
||||
|
||||
sequence_scores = _inflate(sequence_scores, self.num_classes, 1)
|
||||
sequence_scores += log_softmax_output
|
||||
scores, candidates = paddle.topk(
|
||||
paddle.reshape(sequence_scores, [batch_size, -1]),
|
||||
beam_width,
|
||||
axis=1)
|
||||
|
||||
# Reshape input = (bk, 1) and sequence_scores = (bk, 1)
|
||||
y_prev = paddle.reshape(
|
||||
candidates % self.num_classes, shape=[batch_size * beam_width])
|
||||
sequence_scores = paddle.reshape(
|
||||
scores, shape=[batch_size * beam_width, 1])
|
||||
|
||||
# Update fields for next timestep
|
||||
pos_index = paddle.expand_as(pos_index, candidates)
|
||||
predecessors = paddle.cast(
|
||||
candidates / self.num_classes + pos_index, dtype='int64')
|
||||
predecessors = paddle.reshape(
|
||||
predecessors, shape=[batch_size * beam_width, 1])
|
||||
state = paddle.index_select(
|
||||
state, index=predecessors.squeeze(), axis=1)
|
||||
|
||||
# Update sequence socres and erase scores for <eos> symbol so that they aren't expanded
|
||||
stored_scores.append(sequence_scores.clone())
|
||||
y_prev = paddle.reshape(y_prev, shape=[-1, 1])
|
||||
eos_prev = paddle.full_like(y_prev, fill_value=eos)
|
||||
mask = eos_prev == y_prev
|
||||
mask = paddle.nonzero(mask)
|
||||
if mask.dim() > 0:
|
||||
sequence_scores = sequence_scores.numpy()
|
||||
mask = mask.numpy()
|
||||
sequence_scores[mask] = -float('inf')
|
||||
sequence_scores = paddle.to_tensor(sequence_scores)
|
||||
|
||||
# Cache results for backtracking
|
||||
stored_predecessors.append(predecessors)
|
||||
y_prev = paddle.squeeze(y_prev)
|
||||
stored_emitted_symbols.append(y_prev)
|
||||
|
||||
# Do backtracking to return the optimal values
|
||||
#====== backtrak ======#
|
||||
# Initialize return variables given different types
|
||||
p = list()
|
||||
l = [[self.max_len_labels] * beam_width for _ in range(batch_size)
|
||||
] # Placeholder for lengths of top-k sequences
|
||||
|
||||
# the last step output of the beams are not sorted
|
||||
# thus they are sorted here
|
||||
sorted_score, sorted_idx = paddle.topk(
|
||||
paddle.reshape(
|
||||
stored_scores[-1], shape=[batch_size, beam_width]),
|
||||
beam_width)
|
||||
|
||||
# initialize the sequence scores with the sorted last step beam scores
|
||||
s = sorted_score.clone()
|
||||
|
||||
batch_eos_found = [0] * batch_size # the number of EOS found
|
||||
# in the backward loop below for each batch
|
||||
t = self.max_len_labels - 1
|
||||
# initialize the back pointer with the sorted order of the last step beams.
|
||||
# add pos_index for indexing variable with b*k as the first dimension.
|
||||
t_predecessors = paddle.reshape(
|
||||
sorted_idx + pos_index.expand_as(sorted_idx),
|
||||
shape=[batch_size * beam_width])
|
||||
while t >= 0:
|
||||
# Re-order the variables with the back pointer
|
||||
current_symbol = paddle.index_select(
|
||||
stored_emitted_symbols[t], index=t_predecessors, axis=0)
|
||||
t_predecessors = paddle.index_select(
|
||||
stored_predecessors[t].squeeze(), index=t_predecessors, axis=0)
|
||||
eos_indices = stored_emitted_symbols[t] == eos
|
||||
eos_indices = paddle.nonzero(eos_indices)
|
||||
|
||||
if eos_indices.dim() > 0:
|
||||
for i in range(eos_indices.shape[0] - 1, -1, -1):
|
||||
# Indices of the EOS symbol for both variables
|
||||
# with b*k as the first dimension, and b, k for
|
||||
# the first two dimensions
|
||||
idx = eos_indices[i]
|
||||
b_idx = int(idx[0] / beam_width)
|
||||
# The indices of the replacing position
|
||||
# according to the replacement strategy noted above
|
||||
res_k_idx = beam_width - (batch_eos_found[b_idx] %
|
||||
beam_width) - 1
|
||||
batch_eos_found[b_idx] += 1
|
||||
res_idx = b_idx * beam_width + res_k_idx
|
||||
|
||||
# Replace the old information in return variables
|
||||
# with the new ended sequence information
|
||||
t_predecessors[res_idx] = stored_predecessors[t][idx[0]]
|
||||
current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]]
|
||||
s[b_idx, res_k_idx] = stored_scores[t][idx[0], 0]
|
||||
l[b_idx][res_k_idx] = t + 1
|
||||
|
||||
# record the back tracked results
|
||||
p.append(current_symbol)
|
||||
t -= 1
|
||||
|
||||
# Sort and re-order again as the added ended sequences may change
|
||||
# the order (very unlikely)
|
||||
s, re_sorted_idx = s.topk(beam_width)
|
||||
for b_idx in range(batch_size):
|
||||
l[b_idx] = [
|
||||
l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]
|
||||
]
|
||||
|
||||
re_sorted_idx = paddle.reshape(
|
||||
re_sorted_idx + pos_index.expand_as(re_sorted_idx),
|
||||
[batch_size * beam_width])
|
||||
|
||||
# Reverse the sequences and re-order at the same time
|
||||
# It is reversed because the backtracking happens in reverse time order
|
||||
p = [
|
||||
paddle.reshape(
|
||||
paddle.index_select(step, re_sorted_idx, 0),
|
||||
shape=[batch_size, beam_width, -1]) for step in reversed(p)
|
||||
]
|
||||
p = paddle.concat(p, -1)[:, 0, :]
|
||||
return p, paddle.ones_like(p)
|
||||
|
||||
|
||||
class AttentionUnit(nn.Layer):
|
||||
def __init__(self, sDim, xDim, attDim):
|
||||
|
@ -151,21 +314,9 @@ class AttentionUnit(nn.Layer):
|
|||
self.xDim = xDim
|
||||
self.attDim = attDim
|
||||
|
||||
self.sEmbed = nn.Linear(
|
||||
sDim,
|
||||
attDim,
|
||||
weight_attr=paddle.nn.initializer.Normal(std=0.01),
|
||||
bias_attr=paddle.nn.initializer.Constant(0.0))
|
||||
self.xEmbed = nn.Linear(
|
||||
xDim,
|
||||
attDim,
|
||||
weight_attr=paddle.nn.initializer.Normal(std=0.01),
|
||||
bias_attr=paddle.nn.initializer.Constant(0.0))
|
||||
self.wEmbed = nn.Linear(
|
||||
attDim,
|
||||
1,
|
||||
weight_attr=paddle.nn.initializer.Normal(std=0.01),
|
||||
bias_attr=paddle.nn.initializer.Constant(0.0))
|
||||
self.sEmbed = nn.Linear(sDim, attDim)
|
||||
self.xEmbed = nn.Linear(xDim, attDim)
|
||||
self.wEmbed = nn.Linear(attDim, 1)
|
||||
|
||||
def forward(self, x, sPrev):
|
||||
batch_size, T, _ = x.shape # [b x T x xDim]
|
||||
|
@ -184,10 +335,8 @@ class AttentionUnit(nn.Layer):
|
|||
|
||||
vProj = self.wEmbed(sumTanh) # [(b x T) x 1]
|
||||
vProj = paddle.reshape(vProj, [batch_size, T])
|
||||
|
||||
alpha = F.softmax(
|
||||
vProj, axis=1) # attention weights for each sample in the minibatch
|
||||
|
||||
return alpha
|
||||
|
||||
|
||||
|
@ -238,21 +387,4 @@ class DecoderUnit(nn.Layer):
|
|||
output, state = self.gru(concat_context, sPrev)
|
||||
output = paddle.squeeze(output, axis=1)
|
||||
output = self.fc(output)
|
||||
return output, state
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = AttentionRecognitionHead(
|
||||
num_classes=20,
|
||||
in_channels=30,
|
||||
sDim=512,
|
||||
attDim=512,
|
||||
max_len_labels=25,
|
||||
out_channels=38)
|
||||
|
||||
data = paddle.ones([16, 64, 3])
|
||||
targets = paddle.ones([16, 25])
|
||||
length = paddle.to_tensor(20)
|
||||
x = [data, targets, length]
|
||||
output = model(x)
|
||||
print(output.shape)
|
||||
return output, state
|
|
@ -44,13 +44,10 @@ class AttentionHead(nn.Layer):
|
|||
hidden = paddle.zeros((batch_size, self.hidden_size))
|
||||
output_hiddens = []
|
||||
|
||||
targets = targets[0]
|
||||
print(targets)
|
||||
if targets is not None:
|
||||
for i in range(num_steps):
|
||||
char_onehots = self._char_to_onehot(
|
||||
targets[:, i], onehot_dim=self.num_classes)
|
||||
# print("char_onehots:", char_onehots)
|
||||
(outputs, hidden), alpha = self.attention_cell(hidden, inputs,
|
||||
char_onehots)
|
||||
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
|
||||
|
@ -107,8 +104,6 @@ class AttentionGRUCell(nn.Layer):
|
|||
alpha = paddle.transpose(alpha, [0, 2, 1])
|
||||
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
|
||||
concat_context = paddle.concat([context, char_onehots], 1)
|
||||
# print("concat_context:", concat_context.shape)
|
||||
# print("prev_hidden:", prev_hidden.shape)
|
||||
|
||||
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||
|
||||
|
|
|
@ -106,16 +106,3 @@ class STN(nn.Layer):
|
|||
x = F.sigmoid(x)
|
||||
x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2])
|
||||
return img_feat, x
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
in_planes = 3
|
||||
num_ctrlpoints = 20
|
||||
np.random.seed(100)
|
||||
activation = 'none' # 'sigmoid'
|
||||
stn_head = STN(in_planes, num_ctrlpoints, activation)
|
||||
data = np.random.randn(10, 3, 32, 64).astype("float32")
|
||||
print("data:", np.sum(data))
|
||||
input = paddle.to_tensor(data)
|
||||
#input = paddle.randn([10, 3, 32, 64])
|
||||
control_points = stn_head(input)
|
||||
|
|
|
@ -326,5 +326,6 @@ class STN_ON(nn.Layer):
|
|||
image, self.tps_inputsize, mode="bilinear", align_corners=True)
|
||||
stn_img_feat, ctrl_points = self.stn_head(stn_input)
|
||||
x, _ = self.tps(image, ctrl_points)
|
||||
#print("x:", np.sum(x.numpy()))
|
||||
# print(x.shape)
|
||||
return x
|
||||
|
|
|
@ -136,7 +136,8 @@ class TPSSpatialTransformer(nn.Layer):
|
|||
assert source_control_points.ndimension() == 3
|
||||
assert source_control_points.shape[1] == self.num_control_points
|
||||
assert source_control_points.shape[2] == 2
|
||||
batch_size = source_control_points.shape[0]
|
||||
#batch_size = source_control_points.shape[0]
|
||||
batch_size = paddle.shape(source_control_points)[0]
|
||||
|
||||
self.padding_matrix = paddle.expand(
|
||||
self.padding_matrix, shape=[batch_size, 3, 2])
|
||||
|
@ -151,28 +152,6 @@ class TPSSpatialTransformer(nn.Layer):
|
|||
grid = paddle.clip(grid, 0,
|
||||
1) # the source_control_points may be out of [0, 1].
|
||||
# the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
|
||||
# grid = 2.0 * grid - 1.0
|
||||
grid = 2.0 * grid - 1.0
|
||||
output_maps = grid_sample(input, grid, canvas=None)
|
||||
return output_maps, source_coordinate
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from stn import STN
|
||||
in_planes = 3
|
||||
num_ctrlpoints = 20
|
||||
np.random.seed(100)
|
||||
activation = 'none' # 'sigmoid'
|
||||
stn_head = STN(in_planes, num_ctrlpoints, activation)
|
||||
data = np.random.randn(10, 3, 32, 64).astype("float32")
|
||||
input = paddle.to_tensor(data)
|
||||
#input = paddle.randn([10, 3, 32, 64])
|
||||
control_points = stn_head(input)
|
||||
#print("control points:", control_points)
|
||||
#input = paddle.randn(shape=[10,3,32,100])
|
||||
tps = TPSSpatialTransformer(
|
||||
output_image_size=[32, 320],
|
||||
num_control_points=20,
|
||||
margins=[0.05, 0.05])
|
||||
out = tps(input, control_points[1])
|
||||
print("out 0 :", out[0].shape)
|
||||
print("out 1:", out[1].shape)
|
||||
|
|
|
@ -1,149 +0,0 @@
|
|||
from __future__ import absolute_import
|
||||
|
||||
import numpy as np
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def grid_sample(input, grid, canvas=None):
|
||||
output = F.grid_sample(input, grid)
|
||||
if canvas is None:
|
||||
return output
|
||||
else:
|
||||
input_mask = input.data.new(input.size()).fill_(1)
|
||||
output_mask = F.grid_sample(input_mask, grid)
|
||||
padded_output = output * output_mask + canvas * (1 - output_mask)
|
||||
return padded_output
|
||||
|
||||
|
||||
# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
|
||||
def compute_partial_repr(input_points, control_points):
|
||||
N = input_points.size(0)
|
||||
M = control_points.size(0)
|
||||
pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2)
|
||||
# original implementation, very slow
|
||||
# pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
|
||||
pairwise_diff_square = pairwise_diff * pairwise_diff
|
||||
pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :,
|
||||
1]
|
||||
repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist)
|
||||
# fix numerical error for 0 * log(0), substitute all nan with 0
|
||||
mask = repr_matrix != repr_matrix
|
||||
repr_matrix.masked_fill_(mask, 0)
|
||||
return repr_matrix
|
||||
|
||||
|
||||
# output_ctrl_pts are specified, according to our task.
|
||||
def build_output_control_points(num_control_points, margins):
|
||||
margin_x, margin_y = margins
|
||||
num_ctrl_pts_per_side = num_control_points // 2
|
||||
ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
|
||||
ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
|
||||
ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
|
||||
ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
|
||||
ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
|
||||
# ctrl_pts_top = ctrl_pts_top[1:-1,:]
|
||||
# ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:]
|
||||
output_ctrl_pts_arr = np.concatenate(
|
||||
[ctrl_pts_top, ctrl_pts_bottom], axis=0)
|
||||
output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr)
|
||||
return output_ctrl_pts
|
||||
|
||||
|
||||
# demo: ~/test/models/test_tps_transformation.py
|
||||
class TPSSpatialTransformer(nn.Module):
|
||||
def __init__(self,
|
||||
output_image_size=None,
|
||||
num_control_points=None,
|
||||
margins=None):
|
||||
super(TPSSpatialTransformer, self).__init__()
|
||||
self.output_image_size = output_image_size
|
||||
self.num_control_points = num_control_points
|
||||
self.margins = margins
|
||||
|
||||
self.target_height, self.target_width = output_image_size
|
||||
target_control_points = build_output_control_points(num_control_points,
|
||||
margins)
|
||||
N = num_control_points
|
||||
# N = N - 4
|
||||
|
||||
# create padded kernel matrix
|
||||
forward_kernel = torch.zeros(N + 3, N + 3)
|
||||
target_control_partial_repr = compute_partial_repr(
|
||||
target_control_points, target_control_points)
|
||||
forward_kernel[:N, :N].copy_(target_control_partial_repr)
|
||||
forward_kernel[:N, -3].fill_(1)
|
||||
forward_kernel[-3, :N].fill_(1)
|
||||
forward_kernel[:N, -2:].copy_(target_control_points)
|
||||
forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1))
|
||||
# compute inverse matrix
|
||||
inverse_kernel = torch.inverse(forward_kernel)
|
||||
|
||||
# create target cordinate matrix
|
||||
HW = self.target_height * self.target_width
|
||||
target_coordinate = list(
|
||||
itertools.product(
|
||||
range(self.target_height), range(self.target_width)))
|
||||
target_coordinate = torch.Tensor(target_coordinate) # HW x 2
|
||||
Y, X = target_coordinate.split(1, dim=1)
|
||||
Y = Y / (self.target_height - 1)
|
||||
X = X / (self.target_width - 1)
|
||||
target_coordinate = torch.cat([X, Y],
|
||||
dim=1) # convert from (y, x) to (x, y)
|
||||
target_coordinate_partial_repr = compute_partial_repr(
|
||||
target_coordinate, target_control_points)
|
||||
target_coordinate_repr = torch.cat([
|
||||
target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate
|
||||
],
|
||||
dim=1)
|
||||
|
||||
# register precomputed matrices
|
||||
self.register_buffer('inverse_kernel', inverse_kernel)
|
||||
self.register_buffer('padding_matrix', torch.zeros(3, 2))
|
||||
self.register_buffer('target_coordinate_repr', target_coordinate_repr)
|
||||
self.register_buffer('target_control_points', target_control_points)
|
||||
|
||||
def forward(self, input, source_control_points):
|
||||
assert source_control_points.ndimension() == 3
|
||||
assert source_control_points.size(1) == self.num_control_points
|
||||
assert source_control_points.size(2) == 2
|
||||
batch_size = source_control_points.size(0)
|
||||
|
||||
Y = torch.cat([
|
||||
source_control_points, self.padding_matrix.expand(batch_size, 3, 2)
|
||||
], 1)
|
||||
mapping_matrix = torch.matmul(self.inverse_kernel, Y)
|
||||
source_coordinate = torch.matmul(self.target_coordinate_repr,
|
||||
mapping_matrix)
|
||||
|
||||
grid = source_coordinate.view(-1, self.target_height, self.target_width,
|
||||
2)
|
||||
grid = torch.clamp(grid, 0,
|
||||
1) # the source_control_points may be out of [0, 1].
|
||||
# the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
|
||||
grid = 2.0 * grid - 1.0
|
||||
output_maps = grid_sample(input, grid, canvas=None)
|
||||
return output_maps, source_coordinate
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from stn_torch import STNHead
|
||||
in_planes = 3
|
||||
num_ctrlpoints = 20
|
||||
torch.manual_seed(10)
|
||||
activation = 'none' # 'sigmoid'
|
||||
stn_head = STNHead(in_planes, num_ctrlpoints, activation)
|
||||
np.random.seed(100)
|
||||
data = np.random.randn(10, 3, 32, 64).astype("float32")
|
||||
input = torch.tensor(data)
|
||||
control_points = stn_head(input)
|
||||
tps = TPSSpatialTransformer(
|
||||
output_image_size=[32, 320],
|
||||
num_control_points=20,
|
||||
margins=[0.05, 0.05])
|
||||
out = tps(input, control_points[1])
|
||||
print("out 0 :", out[0].shape)
|
||||
print("out 1:", out[1].shape)
|
|
@ -127,3 +127,34 @@ class RMSProp(object):
|
|||
grad_clip=self.grad_clip,
|
||||
parameters=parameters)
|
||||
return opt
|
||||
|
||||
|
||||
class Adadelta(object):
|
||||
def __init__(self,
|
||||
learning_rate=0.001,
|
||||
epsilon=1e-08,
|
||||
rho=0.95,
|
||||
parameter_list=None,
|
||||
weight_decay=None,
|
||||
grad_clip=None,
|
||||
name=None,
|
||||
**kwargs):
|
||||
self.learning_rate = learning_rate
|
||||
self.epsilon = epsilon
|
||||
self.rho = rho
|
||||
self.parameter_list = parameter_list
|
||||
self.learning_rate = learning_rate
|
||||
self.weight_decay = weight_decay
|
||||
self.grad_clip = grad_clip
|
||||
self.name = name
|
||||
|
||||
def __call__(self, parameters):
|
||||
opt = optim.Adadelta(
|
||||
learning_rate=self.learning_rate,
|
||||
epsilon=self.epsilon,
|
||||
rho=self.rho,
|
||||
weight_decay=self.weight_decay,
|
||||
grad_clip=self.grad_clip,
|
||||
name=self.name,
|
||||
parameters=parameters)
|
||||
return opt
|
||||
|
|
|
@ -25,7 +25,7 @@ from .db_postprocess import DBPostProcess
|
|||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
|
||||
TableLabelDecode
|
||||
TableLabelDecode, SEEDLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
|
||||
|
@ -34,7 +34,7 @@ def build_post_process(config, global_config=None):
|
|||
support_dict = [
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode'
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode', 'SEEDLabelDecode'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -170,10 +170,8 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
|||
def add_special_char(self, dict_character):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
self.unkonwn = "UNKNOWN"
|
||||
dict_character = dict_character
|
||||
dict_character = [self.beg_str] + dict_character + [self.end_str
|
||||
] + [self.unkonwn]
|
||||
dict_character = [self.beg_str] + dict_character + [self.end_str]
|
||||
return dict_character
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
|
@ -214,7 +212,6 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
|||
label = self.decode(label, is_remove_duplicate=False)
|
||||
return text, label
|
||||
"""
|
||||
preds = preds["rec_pred"]
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
|
||||
|
@ -242,6 +239,88 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
|||
return idx
|
||||
|
||||
|
||||
class SEEDLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
character_dict_path=None,
|
||||
character_type='ch',
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
super(SEEDLabelDecode, self).__init__(character_dict_path,
|
||||
character_type, use_space_char)
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
dict_character = dict_character
|
||||
dict_character = dict_character + [self.end_str]
|
||||
return dict_character
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
end_idx = self.get_beg_end_flag_idx("eos")
|
||||
return [end_idx]
|
||||
|
||||
def get_beg_end_flag_idx(self, beg_or_end):
|
||||
if beg_or_end == "sos":
|
||||
idx = np.array(self.dict[self.beg_str])
|
||||
elif beg_or_end == "eos":
|
||||
idx = np.array(self.dict[self.end_str])
|
||||
else:
|
||||
assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
|
||||
return idx
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
result_list = []
|
||||
[end_idx] = self.get_ignored_tokens()
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
char_list = []
|
||||
conf_list = []
|
||||
for idx in range(len(text_index[batch_idx])):
|
||||
if int(text_index[batch_idx][idx]) == int(end_idx):
|
||||
break
|
||||
if is_remove_duplicate:
|
||||
# only for predict
|
||||
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||
batch_idx][idx]:
|
||||
continue
|
||||
char_list.append(self.character[int(text_index[batch_idx][
|
||||
idx])])
|
||||
if text_prob is not None:
|
||||
conf_list.append(text_prob[batch_idx][idx])
|
||||
else:
|
||||
conf_list.append(1)
|
||||
text = ''.join(char_list)
|
||||
result_list.append((text, np.mean(conf_list)))
|
||||
return result_list
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
"""
|
||||
text = self.decode(text)
|
||||
if label is None:
|
||||
return text
|
||||
else:
|
||||
label = self.decode(label, is_remove_duplicate=False)
|
||||
return text, label
|
||||
"""
|
||||
preds_idx = preds["rec_pred"]
|
||||
if isinstance(preds_idx, paddle.Tensor):
|
||||
preds_idx = preds_idx.numpy()
|
||||
if "rec_pred_scores" in preds:
|
||||
preds_idx = preds["rec_pred"]
|
||||
preds_prob = preds["rec_pred_scores"]
|
||||
else:
|
||||
preds_idx = preds["rec_pred"].argmax(axis=2)
|
||||
preds_prob = preds["rec_pred"].max(axis=2)
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label, is_remove_duplicate=False)
|
||||
return text, label
|
||||
|
||||
|
||||
class SRNLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
|
|
@ -105,16 +105,13 @@ def load_dygraph_params(config, model, logger, optimizer):
|
|||
params = paddle.load(pm)
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
# for k1, k2 in zip(state_dict.keys(), params.keys()):
|
||||
for k1 in state_dict.keys():
|
||||
if k1 not in params:
|
||||
continue
|
||||
if list(state_dict[k1].shape) == list(params[k1].shape):
|
||||
new_state_dict[k1] = params[k1]
|
||||
else:
|
||||
logger.info(
|
||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k1} {params[k1].shape} !"
|
||||
)
|
||||
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
||||
if list(state_dict[k1].shape) == list(params[k2].shape):
|
||||
new_state_dict[k1] = params[k2]
|
||||
else:
|
||||
logger.info(
|
||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
||||
)
|
||||
model.set_state_dict(new_state_dict)
|
||||
logger.info(f"loaded pretrained_model successful from {pm}")
|
||||
return {}
|
||||
|
|
|
@ -211,11 +211,10 @@ def train(config,
|
|||
images = batch[0]
|
||||
if use_srn:
|
||||
model_average = True
|
||||
# if use_srn or model_type == 'table' or algorithm == "ASTER":
|
||||
# preds = model(images, data=batch[1:])
|
||||
# else:
|
||||
# preds = model(images)
|
||||
preds = model(images, data=batch[1:])
|
||||
if use_srn or model_type == 'table' or model_type == "seed":
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
state_dict = model.state_dict()
|
||||
# for key in state_dict:
|
||||
# print(key)
|
||||
|
@ -415,6 +414,7 @@ def preprocess(is_train=False):
|
|||
yaml.dump(
|
||||
dict(config), f, default_flow_style=False, sort_keys=False)
|
||||
log_file = '{}/train.log'.format(save_model_dir)
|
||||
print("log has save in {}/train.log".format(save_model_dir))
|
||||
else:
|
||||
log_file = None
|
||||
logger = get_logger(name='root', log_file=log_file)
|
||||
|
|
Loading…
Reference in New Issue