add srn for dygraph
This commit is contained in:
parent
de3e2e7cd3
commit
c1fd46641e
|
@ -1,5 +1,5 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
use_gpu: True
|
||||
epoch_num: 72
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
|
@ -59,7 +59,7 @@ Metric:
|
|||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
|
@ -78,7 +78,7 @@ Train:
|
|||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
|
|
|
@ -58,7 +58,7 @@ Metric:
|
|||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
|
@ -77,7 +77,7 @@ Train:
|
|||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
|
|
|
@ -63,7 +63,7 @@ Metric:
|
|||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
|
@ -82,7 +82,7 @@ Train:
|
|||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
|
|
|
@ -58,7 +58,7 @@ Metric:
|
|||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
|
@ -77,7 +77,7 @@ Train:
|
|||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
|
|
|
@ -56,7 +56,7 @@ Metric:
|
|||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
|
@ -75,7 +75,7 @@ Train:
|
|||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
|
|
|
@ -62,7 +62,7 @@ Metric:
|
|||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
|
@ -81,7 +81,7 @@ Train:
|
|||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDateSet
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 72
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 5
|
||||
save_model_dir: ./output/rec/srn
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [0, 5000]
|
||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
character_type: en
|
||||
max_text_length: 25
|
||||
num_heads: 8
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.0001
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: SRN
|
||||
in_channels: 1
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNetFPN
|
||||
Head:
|
||||
name: SRNHead
|
||||
max_text_length: 25
|
||||
num_heads: 8
|
||||
num_encoder_TUs: 2
|
||||
num_decoder_TUs: 4
|
||||
hidden_dims: 512
|
||||
|
||||
Loss:
|
||||
name: SRNLoss
|
||||
|
||||
PostProcess:
|
||||
name: SRNLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/srn_train_data_duiqi
|
||||
#label_file_list: ["./train_data/ic15_data/1.txt"]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SRNLabelEncode: # Class handling label
|
||||
- SRNRecResizeImg:
|
||||
image_shape: [1, 64, 256]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image',
|
||||
'label',
|
||||
'length',
|
||||
'encoder_word_pos',
|
||||
'gsrm_word_pos',
|
||||
'gsrm_slf_attn_bias1',
|
||||
'gsrm_slf_attn_bias2'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
batch_size_per_card: 64
|
||||
drop_last: True
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/evaluation
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SRNLabelEncode: # Class handling label
|
||||
- SRNRecResizeImg:
|
||||
image_shape: [1, 64, 256]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image',
|
||||
'label',
|
||||
'length',
|
||||
'encoder_word_pos',
|
||||
'gsrm_word_pos',
|
||||
'gsrm_slf_attn_bias1',
|
||||
'gsrm_slf_attn_bias2']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 32
|
||||
num_workers: 4
|
|
@ -33,7 +33,7 @@ import paddle.distributed as dist
|
|||
|
||||
from ppocr.data.imaug import transform, create_operators
|
||||
from ppocr.data.simple_dataset import SimpleDataSet
|
||||
from ppocr.data.lmdb_dataset import LMDBDateSet
|
||||
from ppocr.data.lmdb_dataset import LMDBDataSet
|
||||
|
||||
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
||||
|
||||
|
@ -54,7 +54,7 @@ signal.signal(signal.SIGTERM, term_mp)
|
|||
def build_dataloader(config, mode, device, logger):
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
support_dict = ['SimpleDataSet', 'LMDBDateSet']
|
||||
support_dict = ['SimpleDataSet', 'LMDBDataSet']
|
||||
module_name = config[mode]['dataset']['name']
|
||||
assert module_name in support_dict, Exception(
|
||||
'DataSet only support {}'.format(support_dict))
|
||||
|
|
|
@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
|
|||
from .make_shrink_map import MakeShrinkMap
|
||||
from .random_crop_data import EastRandomCropData, PSERandomCrop
|
||||
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
|
||||
from .randaugment import RandAugment
|
||||
from .operators import *
|
||||
from .label_ops import *
|
||||
|
|
|
@ -98,6 +98,8 @@ class BaseRecLabelEncode(object):
|
|||
support_character_type, character_type)
|
||||
|
||||
self.max_text_len = max_text_length
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
if character_type == "en":
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
|
@ -213,3 +215,49 @@ class AttnLabelEncode(BaseRecLabelEncode):
|
|||
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
|
||||
% beg_or_end
|
||||
return idx
|
||||
|
||||
|
||||
class SRNLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
max_text_length=25,
|
||||
character_dict_path=None,
|
||||
character_type='en',
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
super(SRNLabelEncode,
|
||||
self).__init__(max_text_length, character_dict_path,
|
||||
character_type, use_space_char)
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = dict_character + [self.beg_str, self.end_str]
|
||||
return dict_character
|
||||
|
||||
def __call__(self, data):
|
||||
text = data['label']
|
||||
text = self.encode(text)
|
||||
char_num = len(self.character_str)
|
||||
if text is None:
|
||||
return None
|
||||
if len(text) > self.max_text_len:
|
||||
return None
|
||||
data['length'] = np.array(len(text))
|
||||
text = text + [char_num] * (self.max_text_len - len(text))
|
||||
data['label'] = np.array(text)
|
||||
return data
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||
end_idx = self.get_beg_end_flag_idx("end")
|
||||
return [beg_idx, end_idx]
|
||||
|
||||
def get_beg_end_flag_idx(self, beg_or_end):
|
||||
if beg_or_end == "beg":
|
||||
idx = np.array(self.dict[self.beg_str])
|
||||
elif beg_or_end == "end":
|
||||
idx = np.array(self.dict[self.end_str])
|
||||
else:
|
||||
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
|
||||
% beg_or_end
|
||||
return idx
|
||||
|
|
|
@ -12,20 +12,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
@ -77,6 +63,26 @@ class RecResizeImg(object):
|
|||
return data
|
||||
|
||||
|
||||
class SRNRecResizeImg(object):
|
||||
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.num_heads = num_heads
|
||||
self.max_text_length = max_text_length
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
norm_img = resize_norm_img_srn(img, self.image_shape)
|
||||
data['image'] = norm_img
|
||||
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
||||
srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
|
||||
|
||||
data['encoder_word_pos'] = encoder_word_pos
|
||||
data['gsrm_word_pos'] = gsrm_word_pos
|
||||
data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
|
||||
data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
|
||||
return data
|
||||
|
||||
|
||||
def resize_norm_img(img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
h = img.shape[0]
|
||||
|
@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape):
|
|||
def resize_norm_img_chinese(img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
# todo: change to 0 and modified image shape
|
||||
max_wh_ratio = 0
|
||||
max_wh_ratio = imgW * 1.0 / imgH
|
||||
h, w = img.shape[0], img.shape[1]
|
||||
ratio = w * 1.0 / h
|
||||
max_wh_ratio = max(max_wh_ratio, ratio)
|
||||
|
@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape):
|
|||
return padding_im
|
||||
|
||||
|
||||
def resize_norm_img_srn(img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
img_black = np.zeros((imgH, imgW))
|
||||
im_hei = img.shape[0]
|
||||
im_wid = img.shape[1]
|
||||
|
||||
if im_wid <= im_hei * 1:
|
||||
img_new = cv2.resize(img, (imgH * 1, imgH))
|
||||
elif im_wid <= im_hei * 2:
|
||||
img_new = cv2.resize(img, (imgH * 2, imgH))
|
||||
elif im_wid <= im_hei * 3:
|
||||
img_new = cv2.resize(img, (imgH * 3, imgH))
|
||||
else:
|
||||
img_new = cv2.resize(img, (imgW, imgH))
|
||||
|
||||
img_np = np.asarray(img_new)
|
||||
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
|
||||
img_black[:, 0:img_np.shape[1]] = img_np
|
||||
img_black = img_black[:, :, np.newaxis]
|
||||
|
||||
row, col, c = img_black.shape
|
||||
c = 1
|
||||
|
||||
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||
|
||||
|
||||
def srn_other_inputs(image_shape, num_heads, max_text_length):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||
|
||||
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
|
||||
(feature_dim, 1)).astype('int64')
|
||||
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
|
||||
(max_text_length, 1)).astype('int64')
|
||||
|
||||
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
|
||||
[1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
|
||||
[num_heads, 1, 1]) * [-1e9]
|
||||
|
||||
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
|
||||
[1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
|
||||
[num_heads, 1, 1]) * [-1e9]
|
||||
|
||||
return [
|
||||
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2
|
||||
]
|
||||
|
||||
|
||||
def flag():
|
||||
"""
|
||||
flag
|
||||
|
|
|
@ -20,9 +20,9 @@ import cv2
|
|||
from .imaug import transform, create_operators
|
||||
|
||||
|
||||
class LMDBDateSet(Dataset):
|
||||
class LMDBDataSet(Dataset):
|
||||
def __init__(self, config, mode, logger):
|
||||
super(LMDBDateSet, self).__init__()
|
||||
super(LMDBDataSet, self).__init__()
|
||||
|
||||
global_config = config['Global']
|
||||
dataset_config = config[mode]['dataset']
|
||||
|
|
|
@ -23,11 +23,14 @@ def build_loss(config):
|
|||
|
||||
# rec loss
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
from .rec_srn_loss import SRNLoss
|
||||
|
||||
# cls loss
|
||||
from .cls_loss import ClsLoss
|
||||
|
||||
support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss']
|
||||
support_dict = [
|
||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'SRNLoss'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class SRNLoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(SRNLoss, self).__init__()
|
||||
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum")
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
predict = predicts['predict']
|
||||
word_predict = predicts['word_out']
|
||||
gsrm_predict = predicts['gsrm_out']
|
||||
label = batch[1]
|
||||
|
||||
casted_label = paddle.cast(x=label, dtype='int64')
|
||||
casted_label = paddle.reshape(x=casted_label, shape=[-1, 1])
|
||||
|
||||
cost_word = self.loss_func(word_predict, label=casted_label)
|
||||
cost_gsrm = self.loss_func(gsrm_predict, label=casted_label)
|
||||
cost_vsfd = self.loss_func(predict, label=casted_label)
|
||||
|
||||
cost_word = paddle.reshape(x=paddle.sum(cost_word), shape=[1])
|
||||
cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1])
|
||||
cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1])
|
||||
|
||||
sum_cost = cost_word + cost_vsfd * 2.0 + cost_gsrm * 0.15
|
||||
|
||||
return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}
|
|
@ -26,6 +26,7 @@ def build_metric(config):
|
|||
from .det_metric import DetMetric
|
||||
from .rec_metric import RecMetric
|
||||
from .cls_metric import ClsMetric
|
||||
from .rec_metric import RecMetric
|
||||
|
||||
support_dict = ['DetMetric', 'RecMetric', 'ClsMetric']
|
||||
|
||||
|
|
|
@ -31,8 +31,6 @@ class RecMetric(object):
|
|||
if pred == target:
|
||||
correct_num += 1
|
||||
all_num += 1
|
||||
# if all_num < 10 and kwargs.get('show_str', False):
|
||||
# print('{} -> {}'.format(pred, target))
|
||||
self.correct_num += correct_num
|
||||
self.all_num += all_num
|
||||
self.norm_edit_dis += norm_edit_dis
|
||||
|
@ -48,7 +46,7 @@ class RecMetric(object):
|
|||
'norm_edit_dis': 0,
|
||||
}
|
||||
"""
|
||||
acc = self.correct_num / self.all_num
|
||||
acc = 1.0 * self.correct_num / self.all_num
|
||||
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
|
||||
self.reset()
|
||||
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
|
||||
|
|
|
@ -68,11 +68,14 @@ class BaseModel(nn.Layer):
|
|||
config["Head"]['in_channels'] = in_channels
|
||||
self.head = build_head(config["Head"])
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, data=None):
|
||||
if self.use_transform:
|
||||
x = self.transform(x)
|
||||
x = self.backbone(x)
|
||||
if self.use_neck:
|
||||
x = self.neck(x)
|
||||
x = self.head(x)
|
||||
if data is None:
|
||||
x = self.head(x)
|
||||
else:
|
||||
x = self.head(x, data)
|
||||
return x
|
||||
|
|
|
@ -24,7 +24,8 @@ def build_backbone(config, model_type):
|
|||
elif model_type == 'rec' or model_type == 'cls':
|
||||
from .rec_mobilenet_v3 import MobileNetV3
|
||||
from .rec_resnet_vd import ResNet
|
||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_FPN']
|
||||
from .rec_resnet_fpn import ResNetFPN
|
||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -0,0 +1,307 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle import nn, ParamAttr
|
||||
from paddle.nn import functional as F
|
||||
import paddle.fluid as fluid
|
||||
import paddle
|
||||
import numpy as np
|
||||
|
||||
__all__ = ["ResNetFPN"]
|
||||
|
||||
|
||||
class ResNetFPN(nn.Layer):
|
||||
def __init__(self, in_channels=1, layers=50, **kwargs):
|
||||
super(ResNetFPN, self).__init__()
|
||||
supported_layers = {
|
||||
18: {
|
||||
'depth': [2, 2, 2, 2],
|
||||
'block_class': BasicBlock
|
||||
},
|
||||
34: {
|
||||
'depth': [3, 4, 6, 3],
|
||||
'block_class': BasicBlock
|
||||
},
|
||||
50: {
|
||||
'depth': [3, 4, 6, 3],
|
||||
'block_class': BottleneckBlock
|
||||
},
|
||||
101: {
|
||||
'depth': [3, 4, 23, 3],
|
||||
'block_class': BottleneckBlock
|
||||
},
|
||||
152: {
|
||||
'depth': [3, 8, 36, 3],
|
||||
'block_class': BottleneckBlock
|
||||
}
|
||||
}
|
||||
stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)]
|
||||
num_filters = [64, 128, 256, 512]
|
||||
self.depth = supported_layers[layers]['depth']
|
||||
self.F = []
|
||||
self.conv = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=64,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
act="relu",
|
||||
name="conv1")
|
||||
self.block_list = []
|
||||
in_ch = 64
|
||||
if layers >= 50:
|
||||
for block in range(len(self.depth)):
|
||||
for i in range(self.depth[block]):
|
||||
if layers in [101, 152] and block == 2:
|
||||
if i == 0:
|
||||
conv_name = "res" + str(block + 2) + "a"
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + "b" + str(i)
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
block_list = self.add_sublayer(
|
||||
"bottleneckBlock_{}_{}".format(block, i),
|
||||
BottleneckBlock(
|
||||
in_channels=in_ch,
|
||||
out_channels=num_filters[block],
|
||||
stride=stride_list[block] if i == 0 else 1,
|
||||
name=conv_name))
|
||||
in_ch = num_filters[block] * 4
|
||||
self.block_list.append(block_list)
|
||||
self.F.append(block_list)
|
||||
else:
|
||||
for block in range(len(self.depth)):
|
||||
for i in range(self.depth[block]):
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
if i == 0 and block != 0:
|
||||
stride = (2, 1)
|
||||
else:
|
||||
stride = (1, 1)
|
||||
basic_block = self.add_sublayer(
|
||||
conv_name,
|
||||
BasicBlock(
|
||||
in_channels=in_ch,
|
||||
out_channels=num_filters[block],
|
||||
stride=stride_list[block] if i == 0 else 1,
|
||||
is_first=block == i == 0,
|
||||
name=conv_name))
|
||||
in_ch = basic_block.out_channels
|
||||
self.block_list.append(basic_block)
|
||||
out_ch_list = [in_ch // 4, in_ch // 2, in_ch]
|
||||
self.base_block = []
|
||||
self.conv_trans = []
|
||||
self.bn_block = []
|
||||
for i in [-2, -3]:
|
||||
in_channels = out_ch_list[i + 1] + out_ch_list[i]
|
||||
|
||||
self.base_block.append(
|
||||
self.add_sublayer(
|
||||
"F_{}_base_block_0".format(i),
|
||||
nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_ch_list[i],
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(trainable=True),
|
||||
bias_attr=ParamAttr(trainable=True))))
|
||||
self.base_block.append(
|
||||
self.add_sublayer(
|
||||
"F_{}_base_block_1".format(i),
|
||||
nn.Conv2D(
|
||||
in_channels=out_ch_list[i],
|
||||
out_channels=out_ch_list[i],
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(trainable=True),
|
||||
bias_attr=ParamAttr(trainable=True))))
|
||||
self.base_block.append(
|
||||
self.add_sublayer(
|
||||
"F_{}_base_block_2".format(i),
|
||||
nn.BatchNorm(
|
||||
num_channels=out_ch_list[i],
|
||||
act="relu",
|
||||
param_attr=ParamAttr(trainable=True),
|
||||
bias_attr=ParamAttr(trainable=True))))
|
||||
self.base_block.append(
|
||||
self.add_sublayer(
|
||||
"F_{}_base_block_3".format(i),
|
||||
nn.Conv2D(
|
||||
in_channels=out_ch_list[i],
|
||||
out_channels=512,
|
||||
kernel_size=1,
|
||||
bias_attr=ParamAttr(trainable=True),
|
||||
weight_attr=ParamAttr(trainable=True))))
|
||||
self.out_channels = 512
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv(x)
|
||||
fpn_list = []
|
||||
F = []
|
||||
for i in range(len(self.depth)):
|
||||
fpn_list.append(np.sum(self.depth[:i + 1]))
|
||||
|
||||
for i, block in enumerate(self.block_list):
|
||||
x = block(x)
|
||||
for number in fpn_list:
|
||||
if i + 1 == number:
|
||||
F.append(x)
|
||||
base = F[-1]
|
||||
|
||||
j = 0
|
||||
for i, block in enumerate(self.base_block):
|
||||
if i % 3 == 0 and i < 6:
|
||||
j = j + 1
|
||||
b, c, w, h = F[-j - 1].shape
|
||||
if [w, h] == list(base.shape[2:]):
|
||||
base = base
|
||||
else:
|
||||
base = self.conv_trans[j - 1](base)
|
||||
base = self.bn_block[j - 1](base)
|
||||
base = paddle.concat([base, F[-j - 1]], axis=1)
|
||||
base = block(base)
|
||||
return base
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=2 if stride == (1, 1) else kernel_size,
|
||||
dilation=2 if stride == (1, 1) else 1,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + '.conv2d.output.1.w_0'),
|
||||
bias_attr=False, )
|
||||
|
||||
if name == "conv1":
|
||||
bn_name = "bn_" + name
|
||||
else:
|
||||
bn_name = "bn" + name[3:]
|
||||
self.bn = nn.BatchNorm(
|
||||
num_channels=out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=name + '.output.1.w_0'),
|
||||
bias_attr=ParamAttr(name=name + '.output.1.b_0'),
|
||||
moving_mean_name=bn_name + "_mean",
|
||||
moving_variance_name=bn_name + "_variance")
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
class ShortCut(nn.Layer):
|
||||
def __init__(self, in_channels, out_channels, stride, name, is_first=False):
|
||||
super(ShortCut, self).__init__()
|
||||
self.use_conv = True
|
||||
|
||||
if in_channels != out_channels or stride != 1 or is_first == True:
|
||||
if stride == (1, 1):
|
||||
self.conv = ConvBNLayer(
|
||||
in_channels, out_channels, 1, 1, name=name)
|
||||
else: # stride==(2,2)
|
||||
self.conv = ConvBNLayer(
|
||||
in_channels, out_channels, 1, stride, name=name)
|
||||
else:
|
||||
self.use_conv = False
|
||||
|
||||
def forward(self, x):
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class BottleneckBlock(nn.Layer):
|
||||
def __init__(self, in_channels, out_channels, stride, name):
|
||||
super(BottleneckBlock, self).__init__()
|
||||
self.conv0 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
act='relu',
|
||||
name=name + "_branch2a")
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
act='relu',
|
||||
name=name + "_branch2b")
|
||||
|
||||
self.conv2 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels * 4,
|
||||
kernel_size=1,
|
||||
act=None,
|
||||
name=name + "_branch2c")
|
||||
|
||||
self.short = ShortCut(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels * 4,
|
||||
stride=stride,
|
||||
is_first=False,
|
||||
name=name + "_branch1")
|
||||
self.out_channels = out_channels * 4
|
||||
|
||||
def forward(self, x):
|
||||
y = self.conv0(x)
|
||||
y = self.conv1(y)
|
||||
y = self.conv2(y)
|
||||
y = y + self.short(x)
|
||||
y = F.relu(y)
|
||||
return y
|
||||
|
||||
|
||||
class BasicBlock(nn.Layer):
|
||||
def __init__(self, in_channels, out_channels, stride, name, is_first):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv0 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
act='relu',
|
||||
stride=stride,
|
||||
name=name + "_branch2a")
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
act=None,
|
||||
name=name + "_branch2b")
|
||||
self.short = ShortCut(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=stride,
|
||||
is_first=is_first,
|
||||
name=name + "_branch1")
|
||||
self.out_channels = out_channels
|
||||
|
||||
def forward(self, x):
|
||||
y = self.conv0(x)
|
||||
y = self.conv1(y)
|
||||
y = y + self.short(x)
|
||||
return F.relu(y)
|
|
@ -23,10 +23,13 @@ def build_head(config):
|
|||
|
||||
# rec head
|
||||
from .rec_ctc_head import CTCHead
|
||||
from .rec_srn_head import SRNHead
|
||||
|
||||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead']
|
||||
support_dict = [
|
||||
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'SRNHead'
|
||||
]
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('head only support {}'.format(
|
||||
|
|
|
@ -0,0 +1,279 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import nn, ParamAttr
|
||||
from paddle.nn import functional as F
|
||||
import paddle.fluid as fluid
|
||||
import numpy as np
|
||||
from .self_attention import WrapEncoderForFeature
|
||||
from .self_attention import WrapEncoder
|
||||
from paddle.static import Program
|
||||
from ppocr.modeling.backbones.rec_resnet_fpn import ResNetFPN
|
||||
import paddle.fluid.framework as framework
|
||||
|
||||
from collections import OrderedDict
|
||||
gradient_clip = 10
|
||||
|
||||
|
||||
class PVAM(nn.Layer):
|
||||
def __init__(self, in_channels, char_num, max_text_length, num_heads,
|
||||
num_encoder_tus, hidden_dims):
|
||||
super(PVAM, self).__init__()
|
||||
self.char_num = char_num
|
||||
self.max_length = max_text_length
|
||||
self.num_heads = num_heads
|
||||
self.num_encoder_TUs = num_encoder_tus
|
||||
self.hidden_dims = hidden_dims
|
||||
# Transformer encoder
|
||||
t = 256
|
||||
c = 512
|
||||
self.wrap_encoder_for_feature = WrapEncoderForFeature(
|
||||
src_vocab_size=1,
|
||||
max_length=t,
|
||||
n_layer=self.num_encoder_TUs,
|
||||
n_head=self.num_heads,
|
||||
d_key=int(self.hidden_dims / self.num_heads),
|
||||
d_value=int(self.hidden_dims / self.num_heads),
|
||||
d_model=self.hidden_dims,
|
||||
d_inner_hid=self.hidden_dims,
|
||||
prepostprocess_dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
relu_dropout=0.1,
|
||||
preprocess_cmd="n",
|
||||
postprocess_cmd="da",
|
||||
weight_sharing=True)
|
||||
|
||||
# PVAM
|
||||
self.flatten0 = paddle.nn.Flatten(start_axis=0, stop_axis=1)
|
||||
self.fc0 = paddle.nn.Linear(
|
||||
in_features=in_channels,
|
||||
out_features=in_channels, )
|
||||
self.emb = paddle.nn.Embedding(
|
||||
num_embeddings=self.max_length, embedding_dim=in_channels)
|
||||
self.flatten1 = paddle.nn.Flatten(start_axis=0, stop_axis=2)
|
||||
self.fc1 = paddle.nn.Linear(
|
||||
in_features=in_channels, out_features=1, bias_attr=False)
|
||||
|
||||
def forward(self, inputs, encoder_word_pos, gsrm_word_pos):
|
||||
b, c, h, w = inputs.shape
|
||||
conv_features = paddle.reshape(inputs, shape=[-1, c, h * w])
|
||||
conv_features = paddle.transpose(conv_features, perm=[0, 2, 1])
|
||||
# transformer encoder
|
||||
b, t, c = conv_features.shape
|
||||
|
||||
enc_inputs = [conv_features, encoder_word_pos, None]
|
||||
word_features = self.wrap_encoder_for_feature(enc_inputs)
|
||||
|
||||
# pvam
|
||||
b, t, c = word_features.shape
|
||||
word_features = self.fc0(word_features)
|
||||
word_features_ = paddle.reshape(word_features, [-1, 1, t, c])
|
||||
word_features_ = paddle.tile(word_features_, [1, self.max_length, 1, 1])
|
||||
word_pos_feature = self.emb(gsrm_word_pos)
|
||||
word_pos_feature_ = paddle.reshape(word_pos_feature,
|
||||
[-1, self.max_length, 1, c])
|
||||
word_pos_feature_ = paddle.tile(word_pos_feature_, [1, 1, t, 1])
|
||||
y = word_pos_feature_ + word_features_
|
||||
y = F.tanh(y)
|
||||
attention_weight = self.fc1(y)
|
||||
attention_weight = paddle.reshape(
|
||||
attention_weight, shape=[-1, self.max_length, t])
|
||||
attention_weight = F.softmax(attention_weight, axis=-1)
|
||||
pvam_features = paddle.matmul(attention_weight,
|
||||
word_features) #[b, max_length, c]
|
||||
return pvam_features
|
||||
|
||||
|
||||
class GSRM(nn.Layer):
|
||||
def __init__(self, in_channels, char_num, max_text_length, num_heads,
|
||||
num_encoder_tus, num_decoder_tus, hidden_dims):
|
||||
super(GSRM, self).__init__()
|
||||
self.char_num = char_num
|
||||
self.max_length = max_text_length
|
||||
self.num_heads = num_heads
|
||||
self.num_encoder_TUs = num_encoder_tus
|
||||
self.num_decoder_TUs = num_decoder_tus
|
||||
self.hidden_dims = hidden_dims
|
||||
|
||||
self.fc0 = paddle.nn.Linear(
|
||||
in_features=in_channels, out_features=self.char_num)
|
||||
self.wrap_encoder0 = WrapEncoder(
|
||||
src_vocab_size=self.char_num + 1,
|
||||
max_length=self.max_length,
|
||||
n_layer=self.num_decoder_TUs,
|
||||
n_head=self.num_heads,
|
||||
d_key=int(self.hidden_dims / self.num_heads),
|
||||
d_value=int(self.hidden_dims / self.num_heads),
|
||||
d_model=self.hidden_dims,
|
||||
d_inner_hid=self.hidden_dims,
|
||||
prepostprocess_dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
relu_dropout=0.1,
|
||||
preprocess_cmd="n",
|
||||
postprocess_cmd="da",
|
||||
weight_sharing=True)
|
||||
|
||||
self.wrap_encoder1 = WrapEncoder(
|
||||
src_vocab_size=self.char_num + 1,
|
||||
max_length=self.max_length,
|
||||
n_layer=self.num_decoder_TUs,
|
||||
n_head=self.num_heads,
|
||||
d_key=int(self.hidden_dims / self.num_heads),
|
||||
d_value=int(self.hidden_dims / self.num_heads),
|
||||
d_model=self.hidden_dims,
|
||||
d_inner_hid=self.hidden_dims,
|
||||
prepostprocess_dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
relu_dropout=0.1,
|
||||
preprocess_cmd="n",
|
||||
postprocess_cmd="da",
|
||||
weight_sharing=True)
|
||||
|
||||
self.mul = lambda x: paddle.matmul(x=x,
|
||||
y=self.wrap_encoder0.prepare_decoder.emb0.weight,
|
||||
transpose_y=True)
|
||||
|
||||
def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2):
|
||||
# ===== GSRM Visual-to-semantic embedding block =====
|
||||
b, t, c = inputs.shape
|
||||
pvam_features = paddle.reshape(inputs, [-1, c])
|
||||
word_out = self.fc0(pvam_features)
|
||||
word_ids = paddle.argmax(F.softmax(word_out), axis=1)
|
||||
word_ids = paddle.reshape(x=word_ids, shape=[-1, t, 1])
|
||||
|
||||
#===== GSRM Semantic reasoning block =====
|
||||
"""
|
||||
This module is achieved through bi-transformers,
|
||||
ngram_feature1 is the froward one, ngram_fetaure2 is the backward one
|
||||
"""
|
||||
pad_idx = self.char_num
|
||||
|
||||
word1 = paddle.cast(word_ids, "float32")
|
||||
word1 = F.pad(word1, [1, 0], value=1.0 * pad_idx, data_format="NLC")
|
||||
word1 = paddle.cast(word1, "int64")
|
||||
word1 = word1[:, :-1, :]
|
||||
word2 = word_ids
|
||||
|
||||
enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1]
|
||||
enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2]
|
||||
|
||||
gsrm_feature1 = self.wrap_encoder0(enc_inputs_1)
|
||||
gsrm_feature2 = self.wrap_encoder1(enc_inputs_2)
|
||||
|
||||
gsrm_feature2 = F.pad(gsrm_feature2, [0, 1],
|
||||
value=0.,
|
||||
data_format="NLC")
|
||||
gsrm_feature2 = gsrm_feature2[:, 1:, ]
|
||||
gsrm_features = gsrm_feature1 + gsrm_feature2
|
||||
|
||||
gsrm_out = self.mul(gsrm_features)
|
||||
|
||||
b, t, c = gsrm_out.shape
|
||||
gsrm_out = paddle.reshape(gsrm_out, [-1, c])
|
||||
|
||||
return gsrm_features, word_out, gsrm_out
|
||||
|
||||
|
||||
class VSFD(nn.Layer):
|
||||
def __init__(self, in_channels=512, pvam_ch=512, char_num=38):
|
||||
super(VSFD, self).__init__()
|
||||
self.char_num = char_num
|
||||
self.fc0 = paddle.nn.Linear(
|
||||
in_features=in_channels * 2, out_features=pvam_ch)
|
||||
self.fc1 = paddle.nn.Linear(
|
||||
in_features=pvam_ch, out_features=self.char_num)
|
||||
|
||||
def forward(self, pvam_feature, gsrm_feature):
|
||||
b, t, c1 = pvam_feature.shape
|
||||
b, t, c2 = gsrm_feature.shape
|
||||
combine_feature_ = paddle.concat([pvam_feature, gsrm_feature], axis=2)
|
||||
img_comb_feature_ = paddle.reshape(
|
||||
combine_feature_, shape=[-1, c1 + c2])
|
||||
img_comb_feature_map = self.fc0(img_comb_feature_)
|
||||
img_comb_feature_map = F.sigmoid(img_comb_feature_map)
|
||||
img_comb_feature_map = paddle.reshape(
|
||||
img_comb_feature_map, shape=[-1, t, c1])
|
||||
combine_feature = img_comb_feature_map * pvam_feature + (
|
||||
1.0 - img_comb_feature_map) * gsrm_feature
|
||||
img_comb_feature = paddle.reshape(combine_feature, shape=[-1, c1])
|
||||
|
||||
out = self.fc1(img_comb_feature)
|
||||
return out
|
||||
|
||||
|
||||
class SRNHead(nn.Layer):
|
||||
def __init__(self, in_channels, out_channels, max_text_length, num_heads,
|
||||
num_encoder_TUs, num_decoder_TUs, hidden_dims, **kwargs):
|
||||
super(SRNHead, self).__init__()
|
||||
self.char_num = out_channels
|
||||
self.max_length = max_text_length
|
||||
self.num_heads = num_heads
|
||||
self.num_encoder_TUs = num_encoder_TUs
|
||||
self.num_decoder_TUs = num_decoder_TUs
|
||||
self.hidden_dims = hidden_dims
|
||||
|
||||
self.pvam = PVAM(
|
||||
in_channels=in_channels,
|
||||
char_num=self.char_num,
|
||||
max_text_length=self.max_length,
|
||||
num_heads=self.num_heads,
|
||||
num_encoder_tus=self.num_encoder_TUs,
|
||||
hidden_dims=self.hidden_dims)
|
||||
|
||||
self.gsrm = GSRM(
|
||||
in_channels=in_channels,
|
||||
char_num=self.char_num,
|
||||
max_text_length=self.max_length,
|
||||
num_heads=self.num_heads,
|
||||
num_encoder_tus=self.num_encoder_TUs,
|
||||
num_decoder_tus=self.num_decoder_TUs,
|
||||
hidden_dims=self.hidden_dims)
|
||||
self.vsfd = VSFD(in_channels=in_channels)
|
||||
|
||||
self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
|
||||
|
||||
def forward(self, inputs, others):
|
||||
encoder_word_pos = others[0]
|
||||
gsrm_word_pos = others[1]
|
||||
gsrm_slf_attn_bias1 = others[2]
|
||||
gsrm_slf_attn_bias2 = others[3]
|
||||
|
||||
pvam_feature = self.pvam(inputs, encoder_word_pos, gsrm_word_pos)
|
||||
|
||||
gsrm_feature, word_out, gsrm_out = self.gsrm(
|
||||
pvam_feature, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2)
|
||||
|
||||
final_out = self.vsfd(pvam_feature, gsrm_feature)
|
||||
if not self.training:
|
||||
final_out = F.softmax(final_out, axis=1)
|
||||
|
||||
_, decoded_out = paddle.topk(final_out, k=1)
|
||||
|
||||
predicts = OrderedDict([
|
||||
('predict', final_out),
|
||||
('pvam_feature', pvam_feature),
|
||||
('decoded_out', decoded_out),
|
||||
('word_out', word_out),
|
||||
('gsrm_out', gsrm_out),
|
||||
])
|
||||
|
||||
return predicts
|
|
@ -0,0 +1,408 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import paddle
|
||||
from paddle import ParamAttr, nn
|
||||
from paddle import nn, ParamAttr
|
||||
from paddle.nn import functional as F
|
||||
import paddle.fluid as fluid
|
||||
import numpy as np
|
||||
gradient_clip = 10
|
||||
|
||||
|
||||
class WrapEncoderForFeature(nn.Layer):
|
||||
def __init__(self,
|
||||
src_vocab_size,
|
||||
max_length,
|
||||
n_layer,
|
||||
n_head,
|
||||
d_key,
|
||||
d_value,
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
prepostprocess_dropout,
|
||||
attention_dropout,
|
||||
relu_dropout,
|
||||
preprocess_cmd,
|
||||
postprocess_cmd,
|
||||
weight_sharing,
|
||||
bos_idx=0):
|
||||
super(WrapEncoderForFeature, self).__init__()
|
||||
|
||||
self.prepare_encoder = PrepareEncoder(
|
||||
src_vocab_size,
|
||||
d_model,
|
||||
max_length,
|
||||
prepostprocess_dropout,
|
||||
bos_idx=bos_idx,
|
||||
word_emb_param_name="src_word_emb_table")
|
||||
self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
|
||||
d_inner_hid, prepostprocess_dropout,
|
||||
attention_dropout, relu_dropout, preprocess_cmd,
|
||||
postprocess_cmd)
|
||||
|
||||
def forward(self, enc_inputs):
|
||||
conv_features, src_pos, src_slf_attn_bias = enc_inputs
|
||||
enc_input = self.prepare_encoder(conv_features, src_pos)
|
||||
enc_output = self.encoder(enc_input, src_slf_attn_bias)
|
||||
return enc_output
|
||||
|
||||
|
||||
class WrapEncoder(nn.Layer):
|
||||
"""
|
||||
embedder + encoder
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
src_vocab_size,
|
||||
max_length,
|
||||
n_layer,
|
||||
n_head,
|
||||
d_key,
|
||||
d_value,
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
prepostprocess_dropout,
|
||||
attention_dropout,
|
||||
relu_dropout,
|
||||
preprocess_cmd,
|
||||
postprocess_cmd,
|
||||
weight_sharing,
|
||||
bos_idx=0):
|
||||
super(WrapEncoder, self).__init__()
|
||||
|
||||
self.prepare_decoder = PrepareDecoder(
|
||||
src_vocab_size,
|
||||
d_model,
|
||||
max_length,
|
||||
prepostprocess_dropout,
|
||||
bos_idx=bos_idx)
|
||||
self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
|
||||
d_inner_hid, prepostprocess_dropout,
|
||||
attention_dropout, relu_dropout, preprocess_cmd,
|
||||
postprocess_cmd)
|
||||
|
||||
def forward(self, enc_inputs):
|
||||
src_word, src_pos, src_slf_attn_bias = enc_inputs
|
||||
enc_input = self.prepare_decoder(src_word, src_pos)
|
||||
enc_output = self.encoder(enc_input, src_slf_attn_bias)
|
||||
return enc_output
|
||||
|
||||
|
||||
class Encoder(nn.Layer):
|
||||
"""
|
||||
encoder
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
n_layer,
|
||||
n_head,
|
||||
d_key,
|
||||
d_value,
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
prepostprocess_dropout,
|
||||
attention_dropout,
|
||||
relu_dropout,
|
||||
preprocess_cmd="n",
|
||||
postprocess_cmd="da"):
|
||||
|
||||
super(Encoder, self).__init__()
|
||||
|
||||
self.encoder_layers = list()
|
||||
for i in range(n_layer):
|
||||
self.encoder_layers.append(
|
||||
self.add_sublayer(
|
||||
"layer_%d" % i,
|
||||
EncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
|
||||
prepostprocess_dropout, attention_dropout,
|
||||
relu_dropout, preprocess_cmd,
|
||||
postprocess_cmd)))
|
||||
self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
|
||||
prepostprocess_dropout)
|
||||
|
||||
def forward(self, enc_input, attn_bias):
|
||||
for encoder_layer in self.encoder_layers:
|
||||
enc_output = encoder_layer(enc_input, attn_bias)
|
||||
enc_input = enc_output
|
||||
enc_output = self.processer(enc_output)
|
||||
return enc_output
|
||||
|
||||
|
||||
class EncoderLayer(nn.Layer):
|
||||
"""
|
||||
EncoderLayer
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
n_head,
|
||||
d_key,
|
||||
d_value,
|
||||
d_model,
|
||||
d_inner_hid,
|
||||
prepostprocess_dropout,
|
||||
attention_dropout,
|
||||
relu_dropout,
|
||||
preprocess_cmd="n",
|
||||
postprocess_cmd="da"):
|
||||
|
||||
super(EncoderLayer, self).__init__()
|
||||
self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
|
||||
prepostprocess_dropout)
|
||||
self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
|
||||
attention_dropout)
|
||||
self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
|
||||
prepostprocess_dropout)
|
||||
|
||||
self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
|
||||
prepostprocess_dropout)
|
||||
self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
|
||||
self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
|
||||
prepostprocess_dropout)
|
||||
|
||||
def forward(self, enc_input, attn_bias):
|
||||
attn_output = self.self_attn(
|
||||
self.preprocesser1(enc_input), None, None, attn_bias)
|
||||
attn_output = self.postprocesser1(attn_output, enc_input)
|
||||
ffn_output = self.ffn(self.preprocesser2(attn_output))
|
||||
ffn_output = self.postprocesser2(ffn_output, attn_output)
|
||||
return ffn_output
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Layer):
|
||||
"""
|
||||
Multi-Head Attention
|
||||
"""
|
||||
|
||||
def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
self.n_head = n_head
|
||||
self.d_key = d_key
|
||||
self.d_value = d_value
|
||||
self.d_model = d_model
|
||||
self.dropout_rate = dropout_rate
|
||||
self.q_fc = paddle.nn.Linear(
|
||||
in_features=d_model, out_features=d_key * n_head, bias_attr=False)
|
||||
self.k_fc = paddle.nn.Linear(
|
||||
in_features=d_model, out_features=d_key * n_head, bias_attr=False)
|
||||
self.v_fc = paddle.nn.Linear(
|
||||
in_features=d_model, out_features=d_value * n_head, bias_attr=False)
|
||||
self.proj_fc = paddle.nn.Linear(
|
||||
in_features=d_value * n_head, out_features=d_model, bias_attr=False)
|
||||
|
||||
def _prepare_qkv(self, queries, keys, values, cache=None):
|
||||
if keys is None: # self-attention
|
||||
keys, values = queries, queries
|
||||
static_kv = False
|
||||
else: # cross-attention
|
||||
static_kv = True
|
||||
|
||||
q = self.q_fc(queries)
|
||||
q = paddle.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
|
||||
q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
|
||||
|
||||
if cache is not None and static_kv and "static_k" in cache:
|
||||
# for encoder-decoder attention in inference and has cached
|
||||
k = cache["static_k"]
|
||||
v = cache["static_v"]
|
||||
else:
|
||||
k = self.k_fc(keys)
|
||||
v = self.v_fc(values)
|
||||
k = paddle.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
|
||||
k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
|
||||
v = paddle.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
|
||||
v = paddle.transpose(x=v, perm=[0, 2, 1, 3])
|
||||
|
||||
if cache is not None:
|
||||
if static_kv and not "static_k" in cache:
|
||||
# for encoder-decoder attention in inference and has not cached
|
||||
cache["static_k"], cache["static_v"] = k, v
|
||||
elif not static_kv:
|
||||
# for decoder self-attention in inference
|
||||
cache_k, cache_v = cache["k"], cache["v"]
|
||||
k = paddle.concat([cache_k, k], axis=2)
|
||||
v = paddle.concat([cache_v, v], axis=2)
|
||||
cache["k"], cache["v"] = k, v
|
||||
|
||||
return q, k, v
|
||||
|
||||
def forward(self, queries, keys, values, attn_bias, cache=None):
|
||||
# compute q ,k ,v
|
||||
keys = queries if keys is None else keys
|
||||
values = keys if values is None else values
|
||||
q, k, v = self._prepare_qkv(queries, keys, values, cache)
|
||||
|
||||
# scale dot product attention
|
||||
product = paddle.matmul(x=q, y=k, transpose_y=True)
|
||||
product = product * self.d_model**-0.5
|
||||
if attn_bias is not None:
|
||||
product += attn_bias
|
||||
weights = F.softmax(product)
|
||||
if self.dropout_rate:
|
||||
weights = F.dropout(
|
||||
weights, p=self.dropout_rate, mode="downscale_in_infer")
|
||||
out = paddle.matmul(weights, v)
|
||||
|
||||
# combine heads
|
||||
out = paddle.transpose(out, perm=[0, 2, 1, 3])
|
||||
out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
|
||||
|
||||
# project to output
|
||||
out = self.proj_fc(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class PrePostProcessLayer(nn.Layer):
|
||||
"""
|
||||
PrePostProcessLayer
|
||||
"""
|
||||
|
||||
def __init__(self, process_cmd, d_model, dropout_rate):
|
||||
super(PrePostProcessLayer, self).__init__()
|
||||
self.process_cmd = process_cmd
|
||||
self.functors = []
|
||||
for cmd in self.process_cmd:
|
||||
if cmd == "a": # add residual connection
|
||||
self.functors.append(lambda x, y: x + y if y is not None else x)
|
||||
elif cmd == "n": # add layer normalization
|
||||
self.functors.append(
|
||||
self.add_sublayer(
|
||||
"layer_norm_%d" % len(
|
||||
self.sublayers(include_sublayers=False)),
|
||||
paddle.nn.LayerNorm(
|
||||
normalized_shape=d_model,
|
||||
weight_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.Constant(1.)),
|
||||
bias_attr=fluid.ParamAttr(
|
||||
initializer=fluid.initializer.Constant(0.)))))
|
||||
elif cmd == "d": # add dropout
|
||||
self.functors.append(lambda x: F.dropout(
|
||||
x, p=dropout_rate, mode="downscale_in_infer")
|
||||
if dropout_rate else x)
|
||||
|
||||
def forward(self, x, residual=None):
|
||||
for i, cmd in enumerate(self.process_cmd):
|
||||
if cmd == "a":
|
||||
x = self.functors[i](x, residual)
|
||||
else:
|
||||
x = self.functors[i](x)
|
||||
return x
|
||||
|
||||
|
||||
class PrepareEncoder(nn.Layer):
|
||||
def __init__(self,
|
||||
src_vocab_size,
|
||||
src_emb_dim,
|
||||
src_max_len,
|
||||
dropout_rate=0,
|
||||
bos_idx=0,
|
||||
word_emb_param_name=None,
|
||||
pos_enc_param_name=None):
|
||||
super(PrepareEncoder, self).__init__()
|
||||
self.src_emb_dim = src_emb_dim
|
||||
self.src_max_len = src_max_len
|
||||
self.emb = paddle.nn.Embedding(
|
||||
num_embeddings=self.src_max_len,
|
||||
embedding_dim=self.src_emb_dim,
|
||||
sparse=True)
|
||||
self.dropout_rate = dropout_rate
|
||||
|
||||
def forward(self, src_word, src_pos):
|
||||
src_word_emb = src_word
|
||||
src_word_emb = fluid.layers.cast(src_word_emb, 'float32')
|
||||
src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
|
||||
src_pos = paddle.squeeze(src_pos, axis=-1)
|
||||
src_pos_enc = self.emb(src_pos)
|
||||
src_pos_enc.stop_gradient = True
|
||||
enc_input = src_word_emb + src_pos_enc
|
||||
if self.dropout_rate:
|
||||
out = F.dropout(
|
||||
x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
|
||||
else:
|
||||
out = enc_input
|
||||
return out
|
||||
|
||||
|
||||
class PrepareDecoder(nn.Layer):
|
||||
def __init__(self,
|
||||
src_vocab_size,
|
||||
src_emb_dim,
|
||||
src_max_len,
|
||||
dropout_rate=0,
|
||||
bos_idx=0,
|
||||
word_emb_param_name=None,
|
||||
pos_enc_param_name=None):
|
||||
super(PrepareDecoder, self).__init__()
|
||||
self.src_emb_dim = src_emb_dim
|
||||
"""
|
||||
self.emb0 = Embedding(num_embeddings=src_vocab_size,
|
||||
embedding_dim=src_emb_dim)
|
||||
"""
|
||||
self.emb0 = paddle.nn.Embedding(
|
||||
num_embeddings=src_vocab_size,
|
||||
embedding_dim=self.src_emb_dim,
|
||||
weight_attr=paddle.ParamAttr(
|
||||
name=word_emb_param_name,
|
||||
initializer=nn.initializer.Normal(0., src_emb_dim**-0.5)))
|
||||
self.emb1 = paddle.nn.Embedding(
|
||||
num_embeddings=src_max_len,
|
||||
embedding_dim=self.src_emb_dim,
|
||||
weight_attr=paddle.ParamAttr(name=pos_enc_param_name))
|
||||
self.dropout_rate = dropout_rate
|
||||
|
||||
def forward(self, src_word, src_pos):
|
||||
src_word = fluid.layers.cast(src_word, 'int64')
|
||||
src_word = paddle.squeeze(src_word, axis=-1)
|
||||
src_word_emb = self.emb0(src_word)
|
||||
src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
|
||||
src_pos = paddle.squeeze(src_pos, axis=-1)
|
||||
src_pos_enc = self.emb1(src_pos)
|
||||
src_pos_enc.stop_gradient = True
|
||||
enc_input = src_word_emb + src_pos_enc
|
||||
if self.dropout_rate:
|
||||
out = F.dropout(
|
||||
x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
|
||||
else:
|
||||
out = enc_input
|
||||
return out
|
||||
|
||||
|
||||
class FFN(nn.Layer):
|
||||
"""
|
||||
Feed-Forward Network
|
||||
"""
|
||||
|
||||
def __init__(self, d_inner_hid, d_model, dropout_rate):
|
||||
super(FFN, self).__init__()
|
||||
self.dropout_rate = dropout_rate
|
||||
self.fc1 = paddle.nn.Linear(
|
||||
in_features=d_model, out_features=d_inner_hid)
|
||||
self.fc2 = paddle.nn.Linear(
|
||||
in_features=d_inner_hid, out_features=d_model)
|
||||
|
||||
def forward(self, x):
|
||||
hidden = self.fc1(x)
|
||||
hidden = F.relu(hidden)
|
||||
if self.dropout_rate:
|
||||
hidden = F.dropout(
|
||||
hidden, p=self.dropout_rate, mode="downscale_in_infer")
|
||||
out = self.fc2(hidden)
|
||||
return out
|
|
@ -26,11 +26,12 @@ def build_post_process(config, global_config=None):
|
|||
from .db_postprocess import DBPostProcess
|
||||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
|
||||
support_dict = [
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess'
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -29,6 +29,9 @@ class BaseRecLabelDecode(object):
|
|||
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
||||
support_character_type, character_type)
|
||||
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
|
||||
if character_type == "en":
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
|
@ -104,7 +107,6 @@ class CTCLabelDecode(BaseRecLabelDecode):
|
|||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
text = self.decode(preds_idx, preds_prob)
|
||||
|
@ -153,3 +155,83 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
|||
assert False, "unsupport type %s in get_beg_end_flag_idx" \
|
||||
% beg_or_end
|
||||
return idx
|
||||
|
||||
|
||||
class SRNLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
character_dict_path=None,
|
||||
character_type='en',
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
super(SRNLabelDecode, self).__init__(character_dict_path,
|
||||
character_type, use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
pred = preds['predict']
|
||||
char_num = len(self.character_str) + 2
|
||||
if isinstance(pred, paddle.Tensor):
|
||||
pred = pred.numpy()
|
||||
pred = np.reshape(pred, [-1, char_num])
|
||||
|
||||
preds_idx = np.argmax(pred, axis=1)
|
||||
preds_prob = np.max(pred, axis=1)
|
||||
|
||||
preds_idx = np.reshape(preds_idx, [-1, 25])
|
||||
|
||||
preds_prob = np.reshape(preds_prob, [-1, 25])
|
||||
|
||||
text = self.decode(preds_idx, preds_prob)
|
||||
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label, is_remove_duplicate=False)
|
||||
return text, label
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=True):
|
||||
""" convert text-index into text-label. """
|
||||
result_list = []
|
||||
ignored_tokens = self.get_ignored_tokens()
|
||||
batch_size = len(text_index)
|
||||
|
||||
for batch_idx in range(batch_size):
|
||||
char_list = []
|
||||
conf_list = []
|
||||
for idx in range(len(text_index[batch_idx])):
|
||||
if text_index[batch_idx][idx] in ignored_tokens:
|
||||
continue
|
||||
if is_remove_duplicate:
|
||||
# only for predict
|
||||
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||
batch_idx][idx]:
|
||||
continue
|
||||
char_list.append(self.character[int(text_index[batch_idx][
|
||||
idx])])
|
||||
if text_prob is not None:
|
||||
conf_list.append(text_prob[batch_idx][idx])
|
||||
else:
|
||||
conf_list.append(1)
|
||||
|
||||
text = ''.join(char_list)
|
||||
result_list.append((text, np.mean(conf_list)))
|
||||
return result_list
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = dict_character + [self.beg_str, self.end_str]
|
||||
return dict_character
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||
end_idx = self.get_beg_end_flag_idx("end")
|
||||
return [beg_idx, end_idx]
|
||||
|
||||
def get_beg_end_flag_idx(self, beg_or_end):
|
||||
if beg_or_end == "beg":
|
||||
idx = np.array(self.dict[self.beg_str])
|
||||
elif beg_or_end == "end":
|
||||
idx = np.array(self.dict[self.end_str])
|
||||
else:
|
||||
assert False, "unsupport type %s in get_beg_end_flag_idx" \
|
||||
% beg_or_end
|
||||
return idx
|
||||
|
|
|
@ -31,6 +31,14 @@ from ppocr.utils.logging import get_logger
|
|||
from tools.program import load_config, merge_config, ArgsParser
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-c", "--config", help="configuration file to use")
|
||||
parser.add_argument(
|
||||
"-o", "--output_path", type=str, default='./output/infer/')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
FLAGS = ArgsParser().parse_args()
|
||||
config = load_config(FLAGS.config)
|
||||
|
@ -51,14 +59,33 @@ def main():
|
|||
model.eval()
|
||||
|
||||
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
|
||||
infer_shape = [3, 32, 100] if config['Architecture'][
|
||||
'model_type'] != "det" else [3, 640, 640]
|
||||
model = to_static(
|
||||
model,
|
||||
input_spec=[
|
||||
|
||||
if config['Architecture']['algorithm'] == "SRN":
|
||||
other_shape = [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None] + infer_shape, dtype='float32')
|
||||
])
|
||||
shape=[None, 1, 64, 256], dtype='float32'), [
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 256, 1],
|
||||
dtype="int64"), paddle.static.InputSpec(
|
||||
shape=[None, 25, 1],
|
||||
dtype="int64"), paddle.static.InputSpec(
|
||||
shape=[None, 8, 25, 25], dtype="int64"),
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 8, 25, 25], dtype="int64")
|
||||
]
|
||||
]
|
||||
model = to_static(model, input_spec=other_shape)
|
||||
|
||||
else:
|
||||
infer_shape = [3, 32, 100] if config['Architecture'][
|
||||
'model_type'] != "det" else [3, 640, 640]
|
||||
model = to_static(
|
||||
model,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec(
|
||||
shape=[None] + infer_shape, dtype='float32')
|
||||
])
|
||||
|
||||
paddle.jit.save(model, save_path)
|
||||
logger.info('inference model is saved to {}'.format(save_path))
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ import numpy as np
|
|||
import math
|
||||
import time
|
||||
import traceback
|
||||
import paddle
|
||||
|
||||
import tools.infer.utility as utility
|
||||
from ppocr.postprocess import build_post_process
|
||||
|
@ -46,6 +47,13 @@ class TextRecognizer(object):
|
|||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char
|
||||
}
|
||||
if self.rec_algorithm == "SRN":
|
||||
postprocess_params = {
|
||||
'name': 'SRNLabelDecode',
|
||||
"character_type": args.rec_char_type,
|
||||
"character_dict_path": args.rec_char_dict_path,
|
||||
"use_space_char": args.use_space_char
|
||||
}
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor, self.output_tensors = \
|
||||
utility.create_predictor(args, 'rec', logger)
|
||||
|
@ -70,6 +78,78 @@ class TextRecognizer(object):
|
|||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
return padding_im
|
||||
|
||||
def resize_norm_img_srn(self, img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
img_black = np.zeros((imgH, imgW))
|
||||
im_hei = img.shape[0]
|
||||
im_wid = img.shape[1]
|
||||
|
||||
if im_wid <= im_hei * 1:
|
||||
img_new = cv2.resize(img, (imgH * 1, imgH))
|
||||
elif im_wid <= im_hei * 2:
|
||||
img_new = cv2.resize(img, (imgH * 2, imgH))
|
||||
elif im_wid <= im_hei * 3:
|
||||
img_new = cv2.resize(img, (imgH * 3, imgH))
|
||||
else:
|
||||
img_new = cv2.resize(img, (imgW, imgH))
|
||||
|
||||
img_np = np.asarray(img_new)
|
||||
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
|
||||
img_black[:, 0:img_np.shape[1]] = img_np
|
||||
img_black = img_black[:, :, np.newaxis]
|
||||
|
||||
row, col, c = img_black.shape
|
||||
c = 1
|
||||
|
||||
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||
|
||||
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||
|
||||
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
|
||||
(feature_dim, 1)).astype('int64')
|
||||
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
|
||||
(max_text_length, 1)).astype('int64')
|
||||
|
||||
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
|
||||
[-1, 1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias1 = np.tile(
|
||||
gsrm_slf_attn_bias1,
|
||||
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
||||
|
||||
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
|
||||
[-1, 1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias2 = np.tile(
|
||||
gsrm_slf_attn_bias2,
|
||||
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
||||
|
||||
encoder_word_pos = encoder_word_pos[np.newaxis, :]
|
||||
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
|
||||
|
||||
return [
|
||||
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2
|
||||
]
|
||||
|
||||
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
|
||||
norm_img = self.resize_norm_img_srn(img, image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
|
||||
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
||||
self.srn_other_inputs(image_shape, num_heads, max_text_length)
|
||||
|
||||
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
|
||||
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
|
||||
encoder_word_pos = encoder_word_pos.astype(np.int64)
|
||||
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
|
||||
|
||||
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2)
|
||||
|
||||
def __call__(self, img_list):
|
||||
img_num = len(img_list)
|
||||
# Calculate the aspect ratio of all text bars
|
||||
|
@ -93,21 +173,64 @@ class TextRecognizer(object):
|
|||
wh_ratio = w * 1.0 / h
|
||||
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
# norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
|
||||
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
||||
max_wh_ratio)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
if self.rec_algorithm != "SRN":
|
||||
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
||||
max_wh_ratio)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
else:
|
||||
norm_img = self.process_image_srn(
|
||||
img_list[indices[ino]], self.rec_image_shape, 8, 25)
|
||||
encoder_word_pos_list = []
|
||||
gsrm_word_pos_list = []
|
||||
gsrm_slf_attn_bias1_list = []
|
||||
gsrm_slf_attn_bias2_list = []
|
||||
encoder_word_pos_list.append(norm_img[1])
|
||||
gsrm_word_pos_list.append(norm_img[2])
|
||||
gsrm_slf_attn_bias1_list.append(norm_img[3])
|
||||
gsrm_slf_attn_bias2_list.append(norm_img[4])
|
||||
norm_img_batch.append(norm_img[0])
|
||||
norm_img_batch = np.concatenate(norm_img_batch)
|
||||
norm_img_batch = norm_img_batch.copy()
|
||||
starttime = time.time()
|
||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||
self.predictor.run()
|
||||
outputs = []
|
||||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
preds = outputs[0]
|
||||
|
||||
if self.rec_algorithm == "SRN":
|
||||
starttime = time.time()
|
||||
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
|
||||
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
|
||||
gsrm_slf_attn_bias1_list = np.concatenate(
|
||||
gsrm_slf_attn_bias1_list)
|
||||
gsrm_slf_attn_bias2_list = np.concatenate(
|
||||
gsrm_slf_attn_bias2_list)
|
||||
|
||||
inputs = [
|
||||
norm_img_batch,
|
||||
encoder_word_pos_list,
|
||||
gsrm_word_pos_list,
|
||||
gsrm_slf_attn_bias1_list,
|
||||
gsrm_slf_attn_bias2_list,
|
||||
]
|
||||
input_names = self.predictor.get_input_names()
|
||||
for i in range(len(input_names)):
|
||||
input_tensor = self.predictor.get_input_handle(input_names[
|
||||
i])
|
||||
input_tensor.copy_from_cpu(inputs[i])
|
||||
self.predictor.run()
|
||||
outputs = []
|
||||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
preds = {"predict": outputs[2]}
|
||||
else:
|
||||
starttime = time.time()
|
||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||
self.predictor.run()
|
||||
|
||||
outputs = []
|
||||
for output_tensor in self.output_tensors:
|
||||
output = output_tensor.copy_to_cpu()
|
||||
outputs.append(output)
|
||||
preds = outputs[0]
|
||||
|
||||
rec_result = self.postprocess_op(preds)
|
||||
for rno in range(len(rec_result)):
|
||||
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
|
||||
|
|
|
@ -62,7 +62,13 @@ def main():
|
|||
elif op_name in ['RecResizeImg']:
|
||||
op[op_name]['infer_mode'] = True
|
||||
elif op_name == 'KeepKeys':
|
||||
op[op_name]['keep_keys'] = ['image']
|
||||
if config['Architecture']['algorithm'] == "SRN":
|
||||
op[op_name]['keep_keys'] = [
|
||||
'image', 'encoder_word_pos', 'gsrm_word_pos',
|
||||
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
|
||||
]
|
||||
else:
|
||||
op[op_name]['keep_keys'] = ['image']
|
||||
transforms.append(op)
|
||||
global_config['infer_mode'] = True
|
||||
ops = create_operators(transforms, global_config)
|
||||
|
@ -74,10 +80,25 @@ def main():
|
|||
img = f.read()
|
||||
data = {'image': img}
|
||||
batch = transform(data, ops)
|
||||
if config['Architecture']['algorithm'] == "SRN":
|
||||
encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
|
||||
gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
|
||||
gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
|
||||
gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
|
||||
|
||||
others = [
|
||||
paddle.to_tensor(encoder_word_pos_list),
|
||||
paddle.to_tensor(gsrm_word_pos_list),
|
||||
paddle.to_tensor(gsrm_slf_attn_bias1_list),
|
||||
paddle.to_tensor(gsrm_slf_attn_bias2_list)
|
||||
]
|
||||
|
||||
images = np.expand_dims(batch[0], axis=0)
|
||||
images = paddle.to_tensor(images)
|
||||
preds = model(images)
|
||||
if config['Architecture']['algorithm'] == "SRN":
|
||||
preds = model(images, others)
|
||||
else:
|
||||
preds = model(images)
|
||||
post_result = post_process_class(preds)
|
||||
for rec_reuslt in post_result:
|
||||
logger.info('\t result: {}'.format(rec_reuslt))
|
||||
|
|
|
@ -179,9 +179,9 @@ def train(config,
|
|||
if 'start_epoch' in best_model_dict:
|
||||
start_epoch = best_model_dict['start_epoch']
|
||||
else:
|
||||
start_epoch = 1
|
||||
start_epoch = 0
|
||||
|
||||
for epoch in range(start_epoch, epoch_num + 1):
|
||||
for epoch in range(start_epoch, epoch_num):
|
||||
if epoch > 0:
|
||||
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
||||
train_batch_cost = 0.0
|
||||
|
@ -194,7 +194,11 @@ def train(config,
|
|||
break
|
||||
lr = optimizer.get_lr()
|
||||
images = batch[0]
|
||||
preds = model(images)
|
||||
if config['Architecture']['algorithm'] == "SRN":
|
||||
others = batch[-4:]
|
||||
preds = model(images, others)
|
||||
else:
|
||||
preds = model(images)
|
||||
loss = loss_class(preds, batch)
|
||||
avg_loss = loss['loss']
|
||||
avg_loss.backward()
|
||||
|
@ -212,6 +216,7 @@ def train(config,
|
|||
stats['lr'] = lr
|
||||
train_stats.update(stats)
|
||||
|
||||
#cal_metric_during_train = False
|
||||
if cal_metric_during_train: # onlt rec and cls need
|
||||
batch = [item.numpy() for item in batch]
|
||||
post_result = post_process_class(preds, batch[1])
|
||||
|
@ -312,8 +317,9 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
|
|||
if idx >= len(valid_dataloader):
|
||||
break
|
||||
images = batch[0]
|
||||
others = batch[-4:]
|
||||
start = time.time()
|
||||
preds = model(images)
|
||||
preds = model(images, others)
|
||||
|
||||
batch = [item.numpy() for item in batch]
|
||||
# Obtain usable results from post-processing methods
|
||||
|
|
Loading…
Reference in New Issue