modify transformeroptim, resize
This commit is contained in:
parent
73058cc082
commit
2bf8ad9b7d
|
@ -43,7 +43,7 @@ Architecture:
|
|||
name: MTB
|
||||
cnn_num: 2
|
||||
Head:
|
||||
name: TransformerOptim
|
||||
name: Transformer
|
||||
d_model: 512
|
||||
num_encoder_layers: 6
|
||||
beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation.
|
||||
|
@ -69,8 +69,9 @@ Train:
|
|||
img_mode: BGR
|
||||
channel_first: False
|
||||
- NRTRLabelEncode: # Class handling label
|
||||
- PILResize:
|
||||
- NRTRRecResizeImg:
|
||||
image_shape: [100, 32]
|
||||
resize_type: PIL # PIL or OpenCV
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
|
@ -88,8 +89,9 @@ Eval:
|
|||
img_mode: BGR
|
||||
channel_first: False
|
||||
- NRTRLabelEncode: # Class handling label
|
||||
- PILResize:
|
||||
- NRTRRecResizeImg:
|
||||
image_shape: [100, 32]
|
||||
resize_type: PIL # PIL or OpenCV
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
|
|
|
@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
|
|||
from .make_shrink_map import MakeShrinkMap
|
||||
from .random_crop_data import EastRandomCropData, PSERandomCrop
|
||||
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, PILResize, CVResize
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg
|
||||
from .randaugment import RandAugment
|
||||
from .copy_paste import CopyPaste
|
||||
from .operators import *
|
||||
|
|
|
@ -42,30 +42,21 @@ class ClsResizeImg(object):
|
|||
data['image'] = norm_img
|
||||
return data
|
||||
|
||||
class PILResize(object):
|
||||
def __init__(self, image_shape, **kwargs):
|
||||
|
||||
class NRTRRecResizeImg(object):
|
||||
def __init__(self, image_shape, resize_type, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.resize_type = resize_type
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
image_pil = Image.fromarray(np.uint8(img))
|
||||
norm_img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
|
||||
norm_img = np.array(norm_img)
|
||||
norm_img = np.expand_dims(norm_img, -1)
|
||||
norm_img = norm_img.transpose((2, 0, 1))
|
||||
data['image'] = norm_img.astype(np.float32) / 128. - 1.
|
||||
return data
|
||||
|
||||
|
||||
class CVResize(object):
|
||||
def __init__(self, image_shape, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
#print(img)
|
||||
norm_img = cv2.resize(img,self.image_shape)
|
||||
norm_img = np.expand_dims(norm_img, -1)
|
||||
if self.resize_type == 'PIL':
|
||||
image_pil = Image.fromarray(np.uint8(img))
|
||||
img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
|
||||
img = np.array(img)
|
||||
if self.resize_type == 'OpenCV':
|
||||
img = cv2.resize(img, self.image_shape)
|
||||
norm_img = np.expand_dims(img, -1)
|
||||
norm_img = norm_img.transpose((2, 0, 1))
|
||||
data['image'] = norm_img.astype(np.float32) / 128. - 1.
|
||||
return data
|
||||
|
|
|
@ -3,34 +3,26 @@ from paddle import nn
|
|||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
def cal_performance(pred, tgt):
|
||||
|
||||
pred = pred.max(1)[1]
|
||||
tgt = tgt.contiguous().view(-1)
|
||||
non_pad_mask = tgt.ne(0)
|
||||
n_correct = pred.eq(tgt)
|
||||
n_correct = n_correct.masked_select(non_pad_mask).sum().item()
|
||||
return n_correct
|
||||
|
||||
|
||||
class NRTRLoss(nn.Layer):
|
||||
def __init__(self,smoothing=True, **kwargs):
|
||||
def __init__(self, smoothing=True, **kwargs):
|
||||
super(NRTRLoss, self).__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(reduction='mean',ignore_index=0)
|
||||
self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
|
||||
self.smoothing = smoothing
|
||||
|
||||
def forward(self, pred, batch):
|
||||
pred = pred.reshape([-1, pred.shape[2]])
|
||||
max_len = batch[2].max()
|
||||
tgt = batch[1][:,1:2+max_len]
|
||||
tgt = tgt.reshape([-1] )
|
||||
tgt = batch[1][:, 1:2 + max_len]
|
||||
tgt = tgt.reshape([-1])
|
||||
if self.smoothing:
|
||||
eps = 0.1
|
||||
n_class = pred.shape[1]
|
||||
one_hot = F.one_hot(tgt, pred.shape[1])
|
||||
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
|
||||
log_prb = F.log_softmax(pred, axis=1)
|
||||
non_pad_mask = paddle.not_equal(tgt, paddle.zeros(tgt.shape,dtype='int64'))
|
||||
non_pad_mask = paddle.not_equal(
|
||||
tgt, paddle.zeros(
|
||||
tgt.shape, dtype='int64'))
|
||||
loss = -(one_hot * log_prb).sum(axis=1)
|
||||
loss = loss.masked_select(non_pad_mask).mean()
|
||||
else:
|
||||
|
|
|
@ -26,13 +26,13 @@ def build_head(config):
|
|||
from .rec_ctc_head import CTCHead
|
||||
from .rec_att_head import AttentionHead
|
||||
from .rec_srn_head import SRNHead
|
||||
from .rec_nrtr_optim_head import TransformerOptim
|
||||
from .rec_nrtr_head import Transformer
|
||||
|
||||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
support_dict = [
|
||||
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
||||
'SRNHead', 'PGHead', 'TransformerOptim', 'TableAttentionHead'
|
||||
'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead'
|
||||
]
|
||||
|
||||
#table head
|
||||
|
|
|
@ -24,7 +24,7 @@ zeros_ = constant_(value=0.)
|
|||
ones_ = constant_(value=1.)
|
||||
|
||||
|
||||
class MultiheadAttentionOptim(nn.Layer):
|
||||
class MultiheadAttention(nn.Layer):
|
||||
"""Allows the model to jointly attend to information
|
||||
from different representation subspaces.
|
||||
See reference: Attention Is All You Need
|
||||
|
@ -46,7 +46,7 @@ class MultiheadAttentionOptim(nn.Layer):
|
|||
bias=True,
|
||||
add_bias_kv=False,
|
||||
add_zero_attn=False):
|
||||
super(MultiheadAttentionOptim, self).__init__()
|
||||
super(MultiheadAttention, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
|
|
|
@ -21,7 +21,7 @@ from paddle.nn import LayerList
|
|||
from paddle.nn.initializer import XavierNormal as xavier_uniform_
|
||||
from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
|
||||
import numpy as np
|
||||
from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim
|
||||
from ppocr.modeling.heads.multiheadAttention import MultiheadAttention
|
||||
from paddle.nn.initializer import Constant as constant_
|
||||
from paddle.nn.initializer import XavierNormal as xavier_normal_
|
||||
|
||||
|
@ -29,7 +29,7 @@ zeros_ = constant_(value=0.)
|
|||
ones_ = constant_(value=1.)
|
||||
|
||||
|
||||
class TransformerOptim(nn.Layer):
|
||||
class Transformer(nn.Layer):
|
||||
"""A transformer model. User is able to modify the attributes as needed. The architechture
|
||||
is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
|
||||
Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
|
||||
|
@ -63,7 +63,7 @@ class TransformerOptim(nn.Layer):
|
|||
out_channels=0,
|
||||
dst_vocab_size=99,
|
||||
scale_embedding=True):
|
||||
super(TransformerOptim, self).__init__()
|
||||
super(Transformer, self).__init__()
|
||||
self.embedding = Embeddings(
|
||||
d_model=d_model,
|
||||
vocab=dst_vocab_size,
|
||||
|
@ -215,8 +215,7 @@ class TransformerOptim(nn.Layer):
|
|||
n_curr_active_inst = len(curr_active_inst_idx)
|
||||
new_shape = (n_curr_active_inst * n_bm, *d_hs)
|
||||
|
||||
beamed_tensor = beamed_tensor.reshape(
|
||||
[n_prev_active_inst, -1])
|
||||
beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
|
||||
beamed_tensor = beamed_tensor.index_select(
|
||||
paddle.to_tensor(curr_active_inst_idx), axis=0)
|
||||
beamed_tensor = beamed_tensor.reshape([*new_shape])
|
||||
|
@ -486,7 +485,7 @@ class TransformerEncoderLayer(nn.Layer):
|
|||
attention_dropout_rate=0.0,
|
||||
residual_dropout_rate=0.1):
|
||||
super(TransformerEncoderLayer, self).__init__()
|
||||
self.self_attn = MultiheadAttentionOptim(
|
||||
self.self_attn = MultiheadAttention(
|
||||
d_model, nhead, dropout=attention_dropout_rate)
|
||||
|
||||
self.conv1 = Conv2D(
|
||||
|
@ -555,9 +554,9 @@ class TransformerDecoderLayer(nn.Layer):
|
|||
attention_dropout_rate=0.0,
|
||||
residual_dropout_rate=0.1):
|
||||
super(TransformerDecoderLayer, self).__init__()
|
||||
self.self_attn = MultiheadAttentionOptim(
|
||||
self.self_attn = MultiheadAttention(
|
||||
d_model, nhead, dropout=attention_dropout_rate)
|
||||
self.multihead_attn = MultiheadAttentionOptim(
|
||||
self.multihead_attn = MultiheadAttention(
|
||||
d_model, nhead, dropout=attention_dropout_rate)
|
||||
|
||||
self.conv1 = Conv2D(
|
Loading…
Reference in New Issue