modify transformeroptim, resize

Topdu 2021-08-24 07:46:43 +00:00
parent 73058cc082
commit 2bf8ad9b7d
7 changed files with 35 additions and 51 deletions

configs/rec/rec_mtb_nrtr.yml

@@ -43,7 +43,7 @@ Architecture:
     name: MTB
     cnn_num: 2
   Head:
-    name: TransformerOptim
+    name: Transformer
     d_model: 512
     num_encoder_layers: 6
     beam_size: 10 # When beam_size is greater than 0, beam search is used during evaluation.
@@ -69,8 +69,9 @@ Train:
           img_mode: BGR
           channel_first: False
       - NRTRLabelEncode: # Class handling label
-      - PILResize:
+      - NRTRRecResizeImg:
           image_shape: [100, 32]
+          resize_type: PIL # PIL or OpenCV
       - KeepKeys:
           keep_keys: ['image', 'label', 'length'] # dataloader will return a list in this order
   loader:
@@ -88,8 +89,9 @@ Eval:
           img_mode: BGR
           channel_first: False
       - NRTRLabelEncode: # Class handling label
-      - PILResize:
+      - NRTRRecResizeImg:
           image_shape: [100, 32]
+          resize_type: PIL # PIL or OpenCV
       - KeepKeys:
           keep_keys: ['image', 'label', 'length'] # dataloader will return a list in this order
   loader:
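
Note: the NRTRRecResizeImg entry above is consumed by a transform factory that instantiates each op from its YAML name and keyword arguments. Below is a minimal sketch of that pattern, not PaddleOCR's exact factory; op_registry is a hypothetical stand-in for the imaug module namespace.

def create_operators(op_param_list, op_registry):
    """Turn [{'OpName': {...kwargs...}}, ...] into operator instances."""
    ops = []
    for op_def in op_param_list:
        (name, params), = op_def.items()  # e.g. ('NRTRRecResizeImg', {...})
        ops.append(op_registry[name](**(params or {})))  # params is None for bare ops
    return ops

# Mirrors the Train pipeline above:
# ops = create_operators(
#     [{'NRTRRecResizeImg': {'image_shape': [100, 32], 'resize_type': 'PIL'}}],
#     op_registry={'NRTRRecResizeImg': NRTRRecResizeImg})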

ppocr/data/imaug/__init__.py

@@ -21,7 +21,7 @@
 from .make_border_map import MakeBorderMap
 from .make_shrink_map import MakeShrinkMap
 from .random_crop_data import EastRandomCropData, PSERandomCrop
-from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, PILResize, CVResize
+from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg
 from .randaugment import RandAugment
 from .copy_paste import CopyPaste
 from .operators import *

ppocr/data/imaug/rec_img_aug.py

@@ -42,30 +42,21 @@ class ClsResizeImg(object):
         data['image'] = norm_img
         return data
 
 
-class PILResize(object):
-    def __init__(self, image_shape, **kwargs):
+class NRTRRecResizeImg(object):
+    def __init__(self, image_shape, resize_type, **kwargs):
         self.image_shape = image_shape
+        self.resize_type = resize_type
 
     def __call__(self, data):
         img = data['image']
-        image_pil = Image.fromarray(np.uint8(img))
-        norm_img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
-        norm_img = np.array(norm_img)
-        norm_img = np.expand_dims(norm_img, -1)
-        norm_img = norm_img.transpose((2, 0, 1))
-        data['image'] = norm_img.astype(np.float32) / 128. - 1.
-        return data
-
-
-class CVResize(object):
-    def __init__(self, image_shape, **kwargs):
-        self.image_shape = image_shape
-
-    def __call__(self, data):
-        img = data['image']
-        #print(img)
-        norm_img = cv2.resize(img,self.image_shape)
-        norm_img = np.expand_dims(norm_img, -1)
+        if self.resize_type == 'PIL':
+            image_pil = Image.fromarray(np.uint8(img))
+            img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
+            img = np.array(img)
+        if self.resize_type == 'OpenCV':
+            img = cv2.resize(img, self.image_shape)
+        norm_img = np.expand_dims(img, -1)
         norm_img = norm_img.transpose((2, 0, 1))
         data['image'] = norm_img.astype(np.float32) / 128. - 1.
         return data
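
A quick sanity check of the merged op. It assumes a 2-D grayscale input, since np.expand_dims(img, -1) presumes there is no channel axis yet; note also that Image.ANTIALIAS was removed in Pillow 10, where Image.LANCZOS is the equivalent filter.

import numpy as np
from ppocr.data.imaug import NRTRRecResizeImg

op = NRTRRecResizeImg(image_shape=[100, 32], resize_type='OpenCV')
data = {'image': np.random.randint(0, 256, (48, 160), dtype=np.uint8)}
out = op(data)['image']

print(out.shape)             # (1, 32, 100): cv2.resize takes (width, height),
                             # then expand_dims/transpose yield CHW
print(out.min(), out.max())  # within [-1.0, 0.9922): uint8 / 128. - 1.

One design note: any resize_type other than 'PIL' or 'OpenCV' silently skips resizing; an elif/else that raises on unknown values would fail faster.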

ppocr/losses/rec_nrtr_loss.py

@@ -3,34 +3,26 @@
 from paddle import nn
 import paddle.nn.functional as F
 
 
-def cal_performance(pred, tgt):
-    pred = pred.max(1)[1]
-    tgt = tgt.contiguous().view(-1)
-    non_pad_mask = tgt.ne(0)
-    n_correct = pred.eq(tgt)
-    n_correct = n_correct.masked_select(non_pad_mask).sum().item()
-    return n_correct
-
-
 class NRTRLoss(nn.Layer):
-    def __init__(self,smoothing=True, **kwargs):
+    def __init__(self, smoothing=True, **kwargs):
         super(NRTRLoss, self).__init__()
-        self.loss_func = nn.CrossEntropyLoss(reduction='mean',ignore_index=0)
+        self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
         self.smoothing = smoothing
 
     def forward(self, pred, batch):
         pred = pred.reshape([-1, pred.shape[2]])
         max_len = batch[2].max()
-        tgt = batch[1][:,1:2+max_len]
-        tgt = tgt.reshape([-1] )
+        tgt = batch[1][:, 1:2 + max_len]
+        tgt = tgt.reshape([-1])
         if self.smoothing:
             eps = 0.1
             n_class = pred.shape[1]
             one_hot = F.one_hot(tgt, pred.shape[1])
             one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
             log_prb = F.log_softmax(pred, axis=1)
-            non_pad_mask = paddle.not_equal(tgt, paddle.zeros(tgt.shape,dtype='int64'))
+            non_pad_mask = paddle.not_equal(
+                tgt, paddle.zeros(
+                    tgt.shape, dtype='int64'))
             loss = -(one_hot * log_prb).sum(axis=1)
             loss = loss.masked_select(non_pad_mask).mean()
         else:
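
Two observations on this hunk. The deleted cal_performance helper was unused and appears to rely on torch-style tensor methods (.contiguous().view, .ne) rather than paddle APIs. The surviving smoothing branch is standard label smoothing: the target keeps 1 - eps on the true class and spreads eps uniformly over the remaining n_class - 1 classes, with padding positions (index 0) masked out. A NumPy stand-in checking the arithmetic for eps = 0.1, n_class = 5:

import numpy as np

eps, n_class, tgt = 0.1, 5, 2
one_hot = np.eye(n_class)[tgt]
smoothed = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
print(smoothed)  # [0.025 0.025 0.9 0.025 0.025] -- sums to 1.0

logits = np.array([0.2, 0.1, 1.5, 0.3, -0.4])
log_prb = logits - np.log(np.exp(logits).sum())  # log_softmax
print(-(smoothed * log_prb).sum())               # same form as the paddle branch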

ppocr/modeling/heads/__init__.py

@@ -26,13 +26,13 @@ def build_head(config):
     from .rec_ctc_head import CTCHead
     from .rec_att_head import AttentionHead
     from .rec_srn_head import SRNHead
-    from .rec_nrtr_optim_head import TransformerOptim
+    from .rec_nrtr_head import Transformer
 
     # cls head
     from .cls_head import ClsHead
     support_dict = [
         'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
-        'SRNHead', 'PGHead', 'TransformerOptim', 'TableAttentionHead'
+        'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead'
     ]
 
     #table head
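
build_head dispatches on the config's name field against support_dict. A simplified sketch of the pattern (upstream resolves the class from the imported names; head_classes here is a hypothetical explicit mapping):

def build_head(config, head_classes):
    config = dict(config)              # avoid mutating the caller's dict
    name = config.pop('name')
    assert name in head_classes, 'head only supports {}'.format(list(head_classes))
    return head_classes[name](**config)  # remaining keys become ctor kwargs

# e.g. build_head({'name': 'Transformer', 'd_model': 512, 'num_encoder_layers': 6},
#                 {'Transformer': Transformer})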

ppocr/modeling/heads/multiheadAttention.py

@@ -24,7 +24,7 @@
 zeros_ = constant_(value=0.)
 ones_ = constant_(value=1.)
 
 
-class MultiheadAttentionOptim(nn.Layer):
+class MultiheadAttention(nn.Layer):
     """Allows the model to jointly attend to information
     from different representation subspaces.
     See reference: Attention Is All You Need
@@ -46,7 +46,7 @@ class MultiheadAttentionOptim(nn.Layer):
                  bias=True,
                  add_bias_kv=False,
                  add_zero_attn=False):
-        super(MultiheadAttentionOptim, self).__init__()
+        super(MultiheadAttention, self).__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.dropout = dropout
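
The renamed class is the multi-head attention of "Attention Is All You Need"; its core is scaled dot-product attention. An illustrative paddle sketch of just that core (the full class adds input/output projections, head splitting, and the optional bias_kv / zero-attention paths):

import paddle
import paddle.nn.functional as F

def scaled_dot_product_attention(q, k, v):
    # softmax(Q K^T / sqrt(d_k)) V, applied per head in the full class
    d_k = q.shape[-1]
    scores = paddle.matmul(q, k, transpose_y=True) / (d_k ** 0.5)
    return paddle.matmul(F.softmax(scores, axis=-1), v)

# q = k = v = paddle.rand([2, 8, 10, 64])        # (batch, heads, seq, d_k)
# scaled_dot_product_attention(q, k, v).shape    # [2, 8, 10, 64]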

ppocr/modeling/heads/{rec_nrtr_optim_head.py → rec_nrtr_head.py}

@@ -21,7 +21,7 @@
 from paddle.nn import LayerList
 from paddle.nn.initializer import XavierNormal as xavier_uniform_
 from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
 import numpy as np
-from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim
+from ppocr.modeling.heads.multiheadAttention import MultiheadAttention
 from paddle.nn.initializer import Constant as constant_
 from paddle.nn.initializer import XavierNormal as xavier_normal_
@@ -29,7 +29,7 @@
 zeros_ = constant_(value=0.)
 ones_ = constant_(value=1.)
 
 
-class TransformerOptim(nn.Layer):
+class Transformer(nn.Layer):
     """A transformer model. User is able to modify the attributes as needed. The architecture
     is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
     Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
@@ -63,7 +63,7 @@ class TransformerOptim(nn.Layer):
                  out_channels=0,
                  dst_vocab_size=99,
                  scale_embedding=True):
-        super(TransformerOptim, self).__init__()
+        super(Transformer, self).__init__()
         self.embedding = Embeddings(
             d_model=d_model,
             vocab=dst_vocab_size,
@@ -215,8 +215,7 @@ class TransformerOptim(nn.Layer):
         n_curr_active_inst = len(curr_active_inst_idx)
         new_shape = (n_curr_active_inst * n_bm, *d_hs)
 
-        beamed_tensor = beamed_tensor.reshape(
-            [n_prev_active_inst, -1])
+        beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
         beamed_tensor = beamed_tensor.index_select(
             paddle.to_tensor(curr_active_inst_idx), axis=0)
         beamed_tensor = beamed_tensor.reshape([*new_shape])
@@ -486,7 +485,7 @@ class TransformerEncoderLayer(nn.Layer):
                  attention_dropout_rate=0.0,
                  residual_dropout_rate=0.1):
         super(TransformerEncoderLayer, self).__init__()
-        self.self_attn = MultiheadAttentionOptim(
+        self.self_attn = MultiheadAttention(
             d_model, nhead, dropout=attention_dropout_rate)
 
         self.conv1 = Conv2D(
@@ -555,9 +554,9 @@ class TransformerDecoderLayer(nn.Layer):
                  attention_dropout_rate=0.0,
                  residual_dropout_rate=0.1):
         super(TransformerDecoderLayer, self).__init__()
-        self.self_attn = MultiheadAttentionOptim(
+        self.self_attn = MultiheadAttention(
             d_model, nhead, dropout=attention_dropout_rate)
-        self.multihead_attn = MultiheadAttentionOptim(
+        self.multihead_attn = MultiheadAttention(
             d_model, nhead, dropout=attention_dropout_rate)
 
         self.conv1 = Conv2D(
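
The reshape consolidation in the -215,8 hunk belongs to the beam-search bookkeeping: tensors are laid out as (n_inst * n_bm, ...), and finished instances are dropped by viewing the tensor per instance, gathering the still-active rows, and restoring the flat layout. A small illustration with made-up shapes (n_bm = 2):

import paddle

n_bm = 2                                        # beam width
t = paddle.arange(12, dtype='float32').reshape([3 * n_bm, 2])  # 3 instances
active = paddle.to_tensor([0, 2])               # instance 1 has finished
t = t.reshape([3, -1])                          # one row per instance
t = t.index_select(active, axis=0)              # keep active instances
t = t.reshape([2 * n_bm, 2])                    # back to (n_inst * n_bm, *d_hs)
print(t.shape)                                  # [4, 2]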