add rec_nrtr
This commit is contained in:
parent
b6f0a90366
commit
1623c17cdc
|
@ -3,22 +3,38 @@ Global:
|
|||
epoch_num: 21
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
<<<<<<< HEAD
|
||||
save_model_dir: ./output/rec/nrtr_final/
|
||||
save_epoch_step: 1
|
||||
# evaluation is run every 2000 iterations
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: True
|
||||
=======
|
||||
save_model_dir: ./output/rec/piloptimnrtr/
|
||||
save_epoch_step: 1
|
||||
# evaluation is run every 2000 iterations
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: False
|
||||
>>>>>>> 9c67a7f... add rec_nrtr
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
<<<<<<< HEAD
|
||||
character_dict_path:
|
||||
character_type: EN_symbol
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: True
|
||||
=======
|
||||
character_dict_path: ppocr/utils/dict_99.txt
|
||||
character_type: dict_99
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
>>>>>>> 9c67a7f... add rec_nrtr
|
||||
save_res_path: ./output/rec/predicts_nrtr.txt
|
||||
|
||||
Optimizer:
|
||||
|
|
|
@ -44,6 +44,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
|||
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
||||
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
|
||||
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
||||
- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2)
|
||||
|
||||
参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
||||
|
||||
|
@ -58,6 +59,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
|||
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|
||||
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
||||
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|
||||
|NRTR|NRTR_MTB| 84.1% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
|
||||
|
||||
|
||||
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。
|
||||
|
|
|
@ -215,6 +215,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
|
|||
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
|
||||
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
||||
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
||||
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
|
||||
|
||||
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
|
||||
|
||||
|
|
|
@ -60,5 +60,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|
|||
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|
||||
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
||||
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|
||||
|NRTR|NRTR_MTB| 84.1% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
|
||||
|
||||
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
|
||||
|
|
|
@ -207,7 +207,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
|
|||
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
|
||||
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
||||
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
||||
|
||||
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
|
||||
|
||||
For training Chinese data, it is recommended to use
|
||||
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
|
||||
|
|
|
@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
|
|||
from .make_shrink_map import MakeShrinkMap
|
||||
from .random_crop_data import EastRandomCropData, PSERandomCrop
|
||||
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, PILResize, CVResize
|
||||
from .randaugment import RandAugment
|
||||
from .operators import *
|
||||
from .label_ops import *
|
||||
|
|
|
@ -96,7 +96,7 @@ class BaseRecLabelEncode(object):
|
|||
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
|
||||
'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
|
||||
'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
|
||||
'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari'
|
||||
'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari','dict_99'
|
||||
]
|
||||
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
||||
support_character_type, character_type)
|
||||
|
|
|
@ -57,6 +57,38 @@ class DecodeImage(object):
|
|||
return data
|
||||
|
||||
|
||||
class NRTRDecodeImage(object):
|
||||
""" decode image """
|
||||
|
||||
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
|
||||
self.img_mode = img_mode
|
||||
self.channel_first = channel_first
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if six.PY2:
|
||||
assert type(img) is str and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
else:
|
||||
assert type(img) is bytes and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
img = np.frombuffer(img, dtype='uint8')
|
||||
|
||||
img = cv2.imdecode(img, 1)
|
||||
|
||||
if img is None:
|
||||
return None
|
||||
if self.img_mode == 'GRAY':
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
elif self.img_mode == 'RGB':
|
||||
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
|
||||
img = img[:, :, ::-1]
|
||||
img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
|
||||
if self.channel_first:
|
||||
img = img.transpose((2, 0, 1))
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
class NormalizeImage(object):
|
||||
""" normalize image such as substract mean, divide std
|
||||
"""
|
||||
|
|
|
@ -16,7 +16,7 @@ import math
|
|||
import cv2
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
from PIL import Image
|
||||
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
|
||||
|
||||
|
||||
|
@ -42,6 +42,34 @@ class ClsResizeImg(object):
|
|||
data['image'] = norm_img
|
||||
return data
|
||||
|
||||
class PILResize(object):
|
||||
def __init__(self, image_shape, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
image_pil = Image.fromarray(np.uint8(img))
|
||||
norm_img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
|
||||
norm_img = np.array(norm_img)
|
||||
norm_img = np.expand_dims(norm_img, -1)
|
||||
norm_img = norm_img.transpose((2, 0, 1))
|
||||
data['image'] = norm_img.astype(np.float32) / 128. - 1.
|
||||
return data
|
||||
|
||||
|
||||
class CVResize(object):
|
||||
def __init__(self, image_shape, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
#print(img)
|
||||
norm_img = cv2.resize(img,self.image_shape)
|
||||
norm_img = np.expand_dims(norm_img, -1)
|
||||
norm_img = norm_img.transpose((2, 0, 1))
|
||||
data['image'] = norm_img.astype(np.float32) / 128. - 1.
|
||||
return data
|
||||
|
||||
|
||||
class RecResizeImg(object):
|
||||
def __init__(self,
|
||||
|
|
|
@ -25,7 +25,7 @@ from .det_sast_loss import SASTLoss
|
|||
from .rec_ctc_loss import CTCLoss
|
||||
from .rec_att_loss import AttentionLoss
|
||||
from .rec_srn_loss import SRNLoss
|
||||
|
||||
from .rec_nrtr_loss import NRTRLoss
|
||||
# cls loss
|
||||
from .cls_loss import ClsLoss
|
||||
|
||||
|
@ -42,8 +42,8 @@ from .combined_loss import CombinedLoss
|
|||
def build_loss(config):
|
||||
support_dict = [
|
||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||
'SRNLoss', 'PGLoss', 'CombinedLoss'
|
||||
]
|
||||
'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss']
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('loss only support {}'.format(
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
def cal_performance(pred, tgt):
|
||||
|
||||
pred = pred.max(1)[1]
|
||||
tgt = tgt.contiguous().view(-1)
|
||||
non_pad_mask = tgt.ne(0)
|
||||
n_correct = pred.eq(tgt)
|
||||
n_correct = n_correct.masked_select(non_pad_mask).sum().item()
|
||||
return n_correct
|
||||
|
||||
|
||||
class NRTRLoss(nn.Layer):
|
||||
def __init__(self,smoothing=True, **kwargs):
|
||||
super(NRTRLoss, self).__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(reduction='mean',ignore_index=0)
|
||||
self.smoothing = smoothing
|
||||
|
||||
def forward(self, pred, batch):
|
||||
pred = pred.reshape([-1, pred.shape[2]])
|
||||
max_len = batch[2].max()
|
||||
tgt = batch[1][:,1:2+max_len]
|
||||
tgt = tgt.reshape([-1] )
|
||||
if self.smoothing:
|
||||
eps = 0.1
|
||||
n_class = pred.shape[1]
|
||||
one_hot = F.one_hot(tgt, pred.shape[1])
|
||||
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
|
||||
log_prb = F.log_softmax(pred, axis=1)
|
||||
non_pad_mask = paddle.not_equal(tgt, paddle.zeros(tgt.shape,dtype='int64'))
|
||||
loss = -(one_hot * log_prb).sum(axis=1)
|
||||
loss = loss.masked_select(non_pad_mask).mean()
|
||||
else:
|
||||
loss = self.loss_func(pred, tgt)
|
||||
return {'loss': loss}
|
|
@ -30,7 +30,7 @@ class RecMetric(object):
|
|||
target = target.replace(" ", "")
|
||||
norm_edit_dis += Levenshtein.distance(pred, target) / max(
|
||||
len(pred), len(target), 1)
|
||||
if pred == target:
|
||||
if pred.lower() == target.lower():
|
||||
correct_num += 1
|
||||
all_num += 1
|
||||
self.correct_num += correct_num
|
||||
|
@ -48,8 +48,8 @@ class RecMetric(object):
|
|||
'norm_edit_dis': 0,
|
||||
}
|
||||
"""
|
||||
acc = 1.0 * self.correct_num / self.all_num
|
||||
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
|
||||
acc = 1.0 * self.correct_num / (self.all_num)
|
||||
norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num)
|
||||
self.reset()
|
||||
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
|
||||
|
||||
|
@ -57,3 +57,4 @@ class RecMetric(object):
|
|||
self.correct_num = 0
|
||||
self.all_num = 0
|
||||
self.norm_edit_dis = 0
|
||||
|
|
@ -14,7 +14,7 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from ppocr.modeling.transforms import build_transform
|
||||
from ppocr.modeling.backbones import build_backbone
|
||||
|
|
|
@ -25,7 +25,10 @@ def build_backbone(config, model_type):
|
|||
from .rec_mobilenet_v3 import MobileNetV3
|
||||
from .rec_resnet_vd import ResNet
|
||||
from .rec_resnet_fpn import ResNetFPN
|
||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
|
||||
from .rec_nrtr_mtb import MTB
|
||||
from .rec_swin import SwinTransformer
|
||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN','MTB','SwinTransformer']
|
||||
|
||||
elif model_type == 'e2e':
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
support_dict = ['ResNet']
|
||||
|
|
|
@ -0,0 +1,365 @@
|
|||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn import Linear
|
||||
from paddle.nn.initializer import XavierUniform as xavier_uniform_
|
||||
from paddle.nn.initializer import Constant as constant_
|
||||
from paddle.nn.initializer import XavierNormal as xavier_normal_
|
||||
|
||||
zeros_ = constant_(value=0.)
|
||||
ones_ = constant_(value=1.)
|
||||
|
||||
class MultiheadAttention(nn.Layer):
|
||||
r"""Allows the model to jointly attend to information
|
||||
from different representation subspaces.
|
||||
See reference: Attention Is All You Need
|
||||
|
||||
.. math::
|
||||
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
||||
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
|
||||
|
||||
Args:
|
||||
embed_dim: total dimension of the model
|
||||
num_heads: parallel attention layers, or heads
|
||||
|
||||
Examples::
|
||||
|
||||
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
||||
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
||||
"""
|
||||
|
||||
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
|
||||
super(MultiheadAttention, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
|
||||
|
||||
if add_bias_kv:
|
||||
self.bias_k = self.create_parameter(
|
||||
shape=(1, 1, embed_dim), default_initializer=zeros_)
|
||||
self.add_parameter("bias_k", self.bias_k)
|
||||
self.bias_v = self.create_parameter(
|
||||
shape=(1, 1, embed_dim), default_initializer=zeros_)
|
||||
self.add_parameter("bias_v", self.bias_v)
|
||||
else:
|
||||
self.bias_k = self.bias_v = None
|
||||
|
||||
self.add_zero_attn = add_zero_attn
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
self.conv1 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
|
||||
self.conv2 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim * 2, kernel_size=(1, 1))
|
||||
self.conv3 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim * 3, kernel_size=(1, 1))
|
||||
|
||||
def _reset_parameters(self):
|
||||
|
||||
|
||||
xavier_uniform_(self.out_proj.weight)
|
||||
if self.bias_k is not None:
|
||||
xavier_normal_(self.bias_k)
|
||||
if self.bias_v is not None:
|
||||
xavier_normal_(self.bias_v)
|
||||
|
||||
def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
|
||||
need_weights=True, static_kv=False, attn_mask=None, qkv_ = [False,False,False]):
|
||||
"""
|
||||
Inputs of forward function
|
||||
query: [target length, batch size, embed dim]
|
||||
key: [sequence length, batch size, embed dim]
|
||||
value: [sequence length, batch size, embed dim]
|
||||
key_padding_mask: if True, mask padding based on batch size
|
||||
incremental_state: if provided, previous time steps are cashed
|
||||
need_weights: output attn_output_weights
|
||||
static_kv: key and value are static
|
||||
|
||||
Outputs of forward function
|
||||
attn_output: [target length, batch size, embed dim]
|
||||
attn_output_weights: [batch size, target length, sequence length]
|
||||
"""
|
||||
qkv_same = qkv_[0]
|
||||
kv_same = qkv_[1]
|
||||
|
||||
tgt_len, bsz, embed_dim = query.shape
|
||||
assert embed_dim == self.embed_dim
|
||||
assert list(query.shape) == [tgt_len, bsz, embed_dim]
|
||||
assert key.shape == value.shape
|
||||
|
||||
if qkv_same:
|
||||
# self-attention
|
||||
q, k, v = self._in_proj_qkv(query)
|
||||
elif kv_same:
|
||||
# encoder-decoder attention
|
||||
q = self._in_proj_q(query)
|
||||
if key is None:
|
||||
assert value is None
|
||||
k = v = None
|
||||
else:
|
||||
k, v = self._in_proj_kv(key)
|
||||
else:
|
||||
q = self._in_proj_q(query)
|
||||
k = self._in_proj_k(key)
|
||||
v = self._in_proj_v(value)
|
||||
q *= self.scaling
|
||||
|
||||
if self.bias_k is not None:
|
||||
assert self.bias_v is not None
|
||||
self.bias_k = paddle.concat([self.bias_k for i in range(bsz)],axis=1)
|
||||
self.bias_v = paddle.concat([self.bias_v for i in range(bsz)],axis=1)
|
||||
k = paddle.concat([k, self.bias_k])
|
||||
v = paddle.concat([v, self.bias_v])
|
||||
if attn_mask is not None:
|
||||
attn_mask = paddle.concat([attn_mask, paddle.zeros([attn_mask.shape[0], 1],dtype=attn_mask.dtype)], axis=1)
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = paddle.concat(
|
||||
[key_padding_mask,paddle.zeros([key_padding_mask.shape[0], 1],dtype=key_padding_mask.dtype)], axis=1)
|
||||
|
||||
q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
|
||||
if k is not None:
|
||||
k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
|
||||
if v is not None:
|
||||
v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
|
||||
|
||||
|
||||
|
||||
src_len = k.shape[1]
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.shape[0] == bsz
|
||||
assert key_padding_mask.shape[1] == src_len
|
||||
|
||||
if self.add_zero_attn:
|
||||
src_len += 1
|
||||
k = paddle.concat([k, paddle.zeros((k.shape[0], 1) + k.shape[2:],dtype=k.dtype)], axis=1)
|
||||
v = paddle.concat([v, paddle.zeros((v.shape[0], 1) + v.shape[2:],dtype=v.dtype)], axis=1)
|
||||
if attn_mask is not None:
|
||||
attn_mask = paddle.concat([attn_mask, paddle.zeros([attn_mask.shape[0], 1],dtype=attn_mask.dtype)], axis=1)
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = paddle.concat(
|
||||
[key_padding_mask, paddle.zeros([key_padding_mask.shape[0], 1],dtype=key_padding_mask.dtype)], axis=1)
|
||||
attn_output_weights = paddle.bmm(q, k.transpose([0,2,1]))
|
||||
assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
attn_output_weights += attn_mask
|
||||
if key_padding_mask is not None:
|
||||
attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
|
||||
key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32')
|
||||
y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf')
|
||||
y = paddle.where(key==0.,key, y)
|
||||
attn_output_weights += y
|
||||
attn_output_weights = attn_output_weights.reshape([bsz*self.num_heads, tgt_len, src_len])
|
||||
|
||||
attn_output_weights = F.softmax(
|
||||
attn_output_weights.astype('float32'), axis=-1,
|
||||
dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 else attn_output_weights.dtype)
|
||||
attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)
|
||||
|
||||
attn_output = paddle.bmm(attn_output_weights, v)
|
||||
assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
||||
attn_output = attn_output.transpose([1, 0,2]).reshape([tgt_len, bsz, embed_dim])
|
||||
attn_output = self.out_proj(attn_output)
|
||||
if need_weights:
|
||||
# average attention weights over heads
|
||||
attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
|
||||
attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads
|
||||
else:
|
||||
attn_output_weights = None
|
||||
|
||||
return attn_output, attn_output_weights
|
||||
|
||||
def _in_proj_qkv(self, query):
|
||||
query = query.transpose([1, 2, 0])
|
||||
query = paddle.unsqueeze(query, axis=2)
|
||||
res = self.conv3(query)
|
||||
res = paddle.squeeze(res, axis=2)
|
||||
res = res.transpose([2, 0, 1])
|
||||
return res.chunk(3, axis=-1)
|
||||
|
||||
def _in_proj_kv(self, key):
|
||||
key = key.transpose([1, 2, 0])
|
||||
key = paddle.unsqueeze(key, axis=2)
|
||||
res = self.conv2(key)
|
||||
res = paddle.squeeze(res, axis=2)
|
||||
res = res.transpose([2, 0, 1])
|
||||
return res.chunk(2, axis=-1)
|
||||
|
||||
def _in_proj_q(self, query):
|
||||
query = query.transpose([1, 2, 0])
|
||||
query = paddle.unsqueeze(query, axis=2)
|
||||
res = self.conv1(query)
|
||||
res = paddle.squeeze(res, axis=2)
|
||||
res = res.transpose([2, 0, 1])
|
||||
return res
|
||||
|
||||
def _in_proj_k(self, key):
|
||||
|
||||
key = key.transpose([1, 2, 0])
|
||||
key = paddle.unsqueeze(key, axis=2)
|
||||
res = self.conv1(key)
|
||||
res = paddle.squeeze(res, axis=2)
|
||||
res = res.transpose([2, 0, 1])
|
||||
return res
|
||||
|
||||
def _in_proj_v(self, value):
|
||||
|
||||
value = value.transpose([1,2,0])#(1, 2, 0)
|
||||
value = paddle.unsqueeze(value, axis=2)
|
||||
res = self.conv1(value)
|
||||
res = paddle.squeeze(res, axis=2)
|
||||
res = res.transpose([2, 0, 1])
|
||||
return res
|
||||
|
||||
|
||||
|
||||
class MultiheadAttentionOptim(nn.Layer):
|
||||
r"""Allows the model to jointly attend to information
|
||||
from different representation subspaces.
|
||||
See reference: Attention Is All You Need
|
||||
|
||||
.. math::
|
||||
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
||||
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
|
||||
|
||||
Args:
|
||||
embed_dim: total dimension of the model
|
||||
num_heads: parallel attention layers, or heads
|
||||
|
||||
Examples::
|
||||
|
||||
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
||||
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
||||
"""
|
||||
|
||||
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
|
||||
super(MultiheadAttentionOptim, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
self.conv1 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
|
||||
self.conv2 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
|
||||
self.conv3 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
|
||||
|
||||
def _reset_parameters(self):
|
||||
|
||||
|
||||
xavier_uniform_(self.out_proj.weight)
|
||||
|
||||
|
||||
def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
|
||||
need_weights=True, static_kv=False, attn_mask=None):
|
||||
"""
|
||||
Inputs of forward function
|
||||
query: [target length, batch size, embed dim]
|
||||
key: [sequence length, batch size, embed dim]
|
||||
value: [sequence length, batch size, embed dim]
|
||||
key_padding_mask: if True, mask padding based on batch size
|
||||
incremental_state: if provided, previous time steps are cashed
|
||||
need_weights: output attn_output_weights
|
||||
static_kv: key and value are static
|
||||
|
||||
Outputs of forward function
|
||||
attn_output: [target length, batch size, embed dim]
|
||||
attn_output_weights: [batch size, target length, sequence length]
|
||||
"""
|
||||
|
||||
|
||||
tgt_len, bsz, embed_dim = query.shape
|
||||
assert embed_dim == self.embed_dim
|
||||
assert list(query.shape) == [tgt_len, bsz, embed_dim]
|
||||
assert key.shape == value.shape
|
||||
|
||||
q = self._in_proj_q(query)
|
||||
k = self._in_proj_k(key)
|
||||
v = self._in_proj_v(value)
|
||||
q *= self.scaling
|
||||
|
||||
|
||||
q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
|
||||
k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
|
||||
v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
|
||||
|
||||
|
||||
src_len = k.shape[1]
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.shape[0] == bsz
|
||||
assert key_padding_mask.shape[1] == src_len
|
||||
|
||||
|
||||
attn_output_weights = paddle.bmm(q, k.transpose([0,2,1]))
|
||||
assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
attn_output_weights += attn_mask
|
||||
if key_padding_mask is not None:
|
||||
attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
|
||||
key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32')
|
||||
|
||||
y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf')
|
||||
|
||||
y = paddle.where(key==0.,key, y)
|
||||
|
||||
attn_output_weights += y
|
||||
attn_output_weights = attn_output_weights.reshape([bsz*self.num_heads, tgt_len, src_len])
|
||||
|
||||
attn_output_weights = F.softmax(
|
||||
attn_output_weights.astype('float32'), axis=-1,
|
||||
dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 else attn_output_weights.dtype)
|
||||
attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)
|
||||
|
||||
attn_output = paddle.bmm(attn_output_weights, v)
|
||||
assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
||||
attn_output = attn_output.transpose([1, 0,2]).reshape([tgt_len, bsz, embed_dim])
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
if need_weights:
|
||||
# average attention weights over heads
|
||||
attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
|
||||
attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads
|
||||
else:
|
||||
attn_output_weights = None
|
||||
|
||||
return attn_output, attn_output_weights
|
||||
|
||||
|
||||
def _in_proj_q(self, query):
|
||||
query = query.transpose([1, 2, 0])
|
||||
query = paddle.unsqueeze(query, axis=2)
|
||||
res = self.conv1(query)
|
||||
res = paddle.squeeze(res, axis=2)
|
||||
res = res.transpose([2, 0, 1])
|
||||
return res
|
||||
|
||||
def _in_proj_k(self, key):
|
||||
|
||||
key = key.transpose([1, 2, 0])
|
||||
key = paddle.unsqueeze(key, axis=2)
|
||||
res = self.conv2(key)
|
||||
res = paddle.squeeze(res, axis=2)
|
||||
res = res.transpose([2, 0, 1])
|
||||
return res
|
||||
|
||||
def _in_proj_v(self, value):
|
||||
|
||||
value = value.transpose([1,2,0])#(1, 2, 0)
|
||||
value = paddle.unsqueeze(value, axis=2)
|
||||
res = self.conv3(value)
|
||||
res = paddle.squeeze(res, axis=2)
|
||||
res = res.transpose([2, 0, 1])
|
||||
return res
|
|
@ -0,0 +1,28 @@
|
|||
from paddle import nn
|
||||
|
||||
class MTB(nn.Layer):
|
||||
def __init__(self, cnn_num, in_channels):
|
||||
super(MTB, self).__init__()
|
||||
self.block = nn.Sequential()
|
||||
self.out_channels = in_channels
|
||||
self.cnn_num = cnn_num
|
||||
if self.cnn_num == 2:
|
||||
for i in range(self.cnn_num):
|
||||
self.block.add_sublayer('conv_{}'.format(i), nn.Conv2D(
|
||||
in_channels = in_channels if i == 0 else 32*(2**(i-1)),
|
||||
out_channels = 32*(2**i),
|
||||
kernel_size = 3,
|
||||
stride = 2,
|
||||
padding=1))
|
||||
self.block.add_sublayer('relu_{}'.format(i), nn.ReLU())
|
||||
self.block.add_sublayer('bn_{}'.format(i), nn.BatchNorm2D(32*(2**i)))
|
||||
|
||||
def forward(self, images):
|
||||
|
||||
x = self.block(images)
|
||||
if self.cnn_num == 2:
|
||||
# (b, w, h, c)
|
||||
x = x.transpose([0, 3, 2, 1])
|
||||
x_shape = x.shape
|
||||
x = x.reshape([x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
|
||||
return x
|
|
@ -7,7 +7,11 @@ from paddle.nn import LayerList
|
|||
from paddle.nn.initializer import XavierNormal as xavier_uniform_
|
||||
from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
|
||||
import numpy as np
|
||||
<<<<<<< HEAD
|
||||
from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim
|
||||
=======
|
||||
from ppocr.modeling.backbones.multiheadAttention import MultiheadAttentionOptim
|
||||
>>>>>>> 9c67a7f... add rec_nrtr
|
||||
from paddle.nn.initializer import Constant as constant_
|
||||
from paddle.nn.initializer import XavierNormal as xavier_normal_
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ def build_neck(config):
|
|||
from .sast_fpn import SASTFPN
|
||||
from .rnn import SequenceEncoder
|
||||
from .pg_fpn import PGFPN
|
||||
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
|
||||
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN','TFEncoder']
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('neck only support {}'.format(
|
||||
|
|
|
@ -24,16 +24,15 @@ __all__ = ['build_post_process']
|
|||
from .db_postprocess import DBPostProcess
|
||||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
|
||||
|
||||
def build_post_process(config, global_config=None):
|
||||
support_dict = [
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||
'DistillationCTCLabelDecode'
|
||||
'DistillationCTCLabelDecode', 'NRTRLabelDecode'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -28,7 +28,7 @@ class BaseRecLabelDecode(object):
|
|||
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
|
||||
'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
|
||||
'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
|
||||
'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari'
|
||||
'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari','dict_99'
|
||||
]
|
||||
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
||||
support_character_type, character_type)
|
||||
|
@ -256,8 +256,7 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
|||
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||
batch_idx][idx]:
|
||||
continue
|
||||
char_list.append(self.character[int(text_index[batch_idx][
|
||||
idx])])
|
||||
char_list.append(self.character[int(text_index[batch_idx][idx])])
|
||||
if text_prob is not None:
|
||||
conf_list.append(text_prob[batch_idx][idx])
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
!
|
||||
"
|
||||
#
|
||||
$
|
||||
%
|
||||
&
|
||||
'
|
||||
(
|
||||
)
|
||||
*
|
||||
+
|
||||
,
|
||||
-
|
||||
.
|
||||
/
|
||||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
:
|
||||
;
|
||||
<
|
||||
=
|
||||
>
|
||||
?
|
||||
@
|
||||
A
|
||||
B
|
||||
C
|
||||
D
|
||||
E
|
||||
F
|
||||
G
|
||||
H
|
||||
I
|
||||
J
|
||||
K
|
||||
L
|
||||
M
|
||||
N
|
||||
O
|
||||
P
|
||||
Q
|
||||
R
|
||||
S
|
||||
T
|
||||
U
|
||||
V
|
||||
W
|
||||
X
|
||||
Y
|
||||
Z
|
||||
[
|
||||
\
|
||||
]
|
||||
^
|
||||
_
|
||||
`
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
||||
{
|
||||
|
|
||||
}
|
||||
~
|
||||
|
|
@ -22,6 +22,7 @@ import sys
|
|||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
from ppocr.data import build_dataloader
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
|
@ -30,6 +31,7 @@ from ppocr.utils.save_load import init_model
|
|||
from ppocr.utils.utility import print_dict
|
||||
import tools.program as program
|
||||
|
||||
|
||||
def main():
|
||||
global_config = config['Global']
|
||||
# build dataloader
|
||||
|
|
|
@ -186,7 +186,7 @@ def train(config,
|
|||
model.train()
|
||||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
|
||||
use_nrtr = config['Architecture']['algorithm'] == "NRTR"
|
||||
if 'start_epoch' in best_model_dict:
|
||||
start_epoch = best_model_dict['start_epoch']
|
||||
else:
|
||||
|
@ -211,6 +211,9 @@ def train(config,
|
|||
others = batch[-4:]
|
||||
preds = model(images, others)
|
||||
model_average = True
|
||||
elif use_nrtr:
|
||||
max_len = batch[2].max()
|
||||
preds = model(images, batch[1][:,:2+max_len])
|
||||
else:
|
||||
preds = model(images)
|
||||
loss = loss_class(preds, batch)
|
||||
|
@ -350,13 +353,11 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
|
|||
break
|
||||
images = batch[0]
|
||||
start = time.time()
|
||||
|
||||
if use_srn:
|
||||
others = batch[-4:]
|
||||
preds = model(images, others)
|
||||
else:
|
||||
preds = model(images)
|
||||
|
||||
batch = [item.numpy() for item in batch]
|
||||
# Obtain usable results from post-processing methods
|
||||
post_result = post_process_class(preds, batch[1])
|
||||
|
@ -386,7 +387,7 @@ def preprocess(is_train=False):
|
|||
alg = config['Architecture']['algorithm']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation'
|
||||
'CLS', 'PGNet', 'Distillation','NRTR'
|
||||
]
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
|
|
Loading…
Reference in New Issue