commit
9a44e279fe
|
@ -0,0 +1,99 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 5
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 20
|
||||
save_model_dir: ./sar_rec
|
||||
save_epoch_step: 1
|
||||
# evaluation is run every 2000 iterations
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img:
|
||||
# for data or label process
|
||||
character_dict_path: ppocr/utils/dict90.txt
|
||||
character_type: EN_symbol
|
||||
max_text_length: 30
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
rm_symbol: True
|
||||
save_res_path: ./output/rec/predicts_sar.txt
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Piecewise
|
||||
decay_epochs: [3, 4]
|
||||
values: [0.001, 0.0001, 0.00001]
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: SAR
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet31
|
||||
Head:
|
||||
name: SARHead
|
||||
|
||||
Loss:
|
||||
name: SARLoss
|
||||
|
||||
PostProcess:
|
||||
name: SARLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
label_file_list: ['./train_data/train_list.txt']
|
||||
data_dir: ./train_data/
|
||||
ratio_list: 1.0
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SARLabelEncode: # Class handling label
|
||||
- SARRecResizeImg:
|
||||
image_shape: [3, 48, 48, 160] # h:48 w:[48,160]
|
||||
width_downsample_ratio: 0.25
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 64
|
||||
drop_last: True
|
||||
num_workers: 8
|
||||
use_shared_memory: False
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./eval_data/evaluation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SARLabelEncode: # Class handling label
|
||||
- SARRecResizeImg:
|
||||
image_shape: [3, 48, 48, 160]
|
||||
width_downsample_ratio: 0.25
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'valid_ratio'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 64
|
||||
num_workers: 4
|
||||
use_shared_memory: False
|
||||
|
|
@ -45,6 +45,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
|||
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
|
||||
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
||||
- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))
|
||||
- [x] SAR([paper](https://arxiv.org/abs/1811.00751v2))
|
||||
|
||||
参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
||||
|
||||
|
@ -60,6 +61,6 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
|||
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
||||
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|
||||
|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
|
||||
|
||||
|SAR|Resnet31| 87.2% | rec_r31_sar | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|
||||
|
||||
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。
|
||||
|
|
|
@ -86,7 +86,10 @@ train_data/rec/train/word_002.jpg 用科技让复杂的世界更简单
|
|||
|
||||
若您本地没有数据集,可以在官网下载 [ICDAR2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) 数据,用于快速验证。也可以参考[DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here) ,下载 benchmark 所需的lmdb格式数据集。
|
||||
|
||||
如果希望复现SAR的论文指标,需要下载[SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg), 提取码:627x。此外,真实数据集icdar2013, icdar2015, cocotext, IIIT5也作为训练数据的一部分。具体数据细节可以参考论文SAR。
|
||||
|
||||
如果你使用的是icdar2015的公开数据集,PaddleOCR 提供了一份用于训练 ICDAR2015 数据集的标签文件,通过以下方式下载:
|
||||
|
||||
```
|
||||
# 训练集标签
|
||||
wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_train.txt
|
||||
|
@ -230,6 +233,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
|
|||
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
||||
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
||||
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
|
||||
| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder |
|
||||
|
||||
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
|
||||
|
||||
|
|
|
@ -47,6 +47,7 @@ PaddleOCR open-source text recognition algorithms list:
|
|||
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
|
||||
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
||||
- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))
|
||||
- [x] SAR([paper](https://arxiv.org/abs/1811.00751v2))
|
||||
|
||||
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
|
||||
|
||||
|
@ -62,5 +63,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|
|||
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
||||
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|
||||
|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
|
||||
|SAR|Resnet31| 87.2% | rec_r31_sar | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|
||||
|
||||
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
|
||||
|
|
|
@ -92,6 +92,8 @@ Similar to the training set, the test set also needs to be provided a folder con
|
|||
If you do not have a dataset locally, you can download it on the official website [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads).
|
||||
Also refer to [DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here) ,download the lmdb format dataset required for benchmark
|
||||
|
||||
If you want to reproduce the paper SAR, you need to download extra dataset [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg), extraction code: 627x. Besides, icdar2013, icdar2015, cocotext, IIIT5k datasets are also used to train. For specific details, please refer to the paper SAR.
|
||||
|
||||
PaddleOCR provides label files for training the icdar2015 dataset, which can be downloaded in the following ways:
|
||||
|
||||
```
|
||||
|
@ -236,6 +238,8 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
|
|||
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
||||
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
||||
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
|
||||
| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder |
|
||||
|
||||
|
||||
For training Chinese data, it is recommended to use
|
||||
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
|
||||
|
|
|
@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
|
|||
from .make_shrink_map import MakeShrinkMap
|
||||
from .random_crop_data import EastRandomCropData, PSERandomCrop
|
||||
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg
|
||||
from .randaugment import RandAugment
|
||||
from .copy_paste import CopyPaste
|
||||
from .operators import *
|
||||
|
|
|
@ -549,3 +549,49 @@ class TableLabelEncode(object):
|
|||
assert False, "Unsupport type %s in char_or_elem" \
|
||||
% char_or_elem
|
||||
return idx
|
||||
|
||||
|
||||
class SARLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
character_type='ch',
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
super(SARLabelEncode,
|
||||
self).__init__(max_text_length, character_dict_path,
|
||||
character_type, use_space_char)
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
beg_end_str = "<BOS/EOS>"
|
||||
unknown_str = "<UKN>"
|
||||
padding_str = "<PAD>"
|
||||
dict_character = dict_character + [unknown_str]
|
||||
self.unknown_idx = len(dict_character) - 1
|
||||
dict_character = dict_character + [beg_end_str]
|
||||
self.start_idx = len(dict_character) - 1
|
||||
self.end_idx = len(dict_character) - 1
|
||||
dict_character = dict_character + [padding_str]
|
||||
self.padding_idx = len(dict_character) - 1
|
||||
|
||||
return dict_character
|
||||
|
||||
def __call__(self, data):
|
||||
text = data['label']
|
||||
text = self.encode(text)
|
||||
if text is None:
|
||||
return None
|
||||
if len(text) >= self.max_text_len - 1:
|
||||
return None
|
||||
data['length'] = np.array(len(text))
|
||||
target = [self.start_idx] + text + [self.end_idx]
|
||||
padded_text = [self.padding_idx for _ in range(self.max_text_len)]
|
||||
|
||||
padded_text[:len(target)] = target
|
||||
data['label'] = np.array(padded_text)
|
||||
return data
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
return [self.padding_idx]
|
||||
|
|
|
@ -102,6 +102,56 @@ class SRNRecResizeImg(object):
|
|||
return data
|
||||
|
||||
|
||||
class SARRecResizeImg(object):
|
||||
def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.width_downsample_ratio = width_downsample_ratio
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(img, self.image_shape, self.width_downsample_ratio)
|
||||
data['image'] = norm_img
|
||||
data['resized_shape'] = resize_shape
|
||||
data['pad_shape'] = pad_shape
|
||||
data['valid_ratio'] = valid_ratio
|
||||
return data
|
||||
|
||||
|
||||
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
|
||||
imgC, imgH, imgW_min, imgW_max = image_shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
valid_ratio = 1.0
|
||||
# make sure new_width is an integral multiple of width_divisor.
|
||||
width_divisor = int(1 / width_downsample_ratio)
|
||||
# resize
|
||||
ratio = w / float(h)
|
||||
resize_w = math.ceil(imgH * ratio)
|
||||
if resize_w % width_divisor != 0:
|
||||
resize_w = round(resize_w / width_divisor) * width_divisor
|
||||
if imgW_min is not None:
|
||||
resize_w = max(imgW_min, resize_w)
|
||||
if imgW_max is not None:
|
||||
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
|
||||
resize_w = min(imgW_max, resize_w)
|
||||
resized_image = cv2.resize(img, (resize_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
# norm
|
||||
if image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
resized_image = resized_image[np.newaxis, :]
|
||||
else:
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
resize_shape = resized_image.shape
|
||||
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
|
||||
padding_im[:, :, 0:resize_w] = resized_image
|
||||
pad_shape = padding_im.shape
|
||||
|
||||
return padding_im, resize_shape, pad_shape, valid_ratio
|
||||
|
||||
|
||||
def resize_norm_img(img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
h = img.shape[0]
|
||||
|
|
|
@ -26,6 +26,7 @@ from .rec_ctc_loss import CTCLoss
|
|||
from .rec_att_loss import AttentionLoss
|
||||
from .rec_srn_loss import SRNLoss
|
||||
from .rec_nrtr_loss import NRTRLoss
|
||||
from .rec_sar_loss import SARLoss
|
||||
# cls loss
|
||||
from .cls_loss import ClsLoss
|
||||
|
||||
|
@ -44,7 +45,7 @@ from .table_att_loss import TableAttentionLoss
|
|||
def build_loss(config):
|
||||
support_dict = [
|
||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||
'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss'
|
||||
'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss', 'SARLoss'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
|
||||
|
||||
class SARLoss(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super(SARLoss, self).__init__()
|
||||
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="mean", ignore_index=96)
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
predict = predicts[:, :-1, :] # ignore last index of outputs to be in same seq_len with targets
|
||||
label = batch[1].astype("int64")[:, 1:] # ignore first index of target in loss calculation
|
||||
batch_size, num_steps, num_classes = predict.shape[0], predict.shape[
|
||||
1], predict.shape[2]
|
||||
assert len(label.shape) == len(list(predict.shape)) - 1, \
|
||||
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
|
||||
|
||||
inputs = paddle.reshape(predict, [-1, num_classes])
|
||||
targets = paddle.reshape(label, [-1])
|
||||
loss = self.loss_func(inputs, targets)
|
||||
return {'loss': loss}
|
|
@ -27,8 +27,9 @@ def build_backbone(config, model_type):
|
|||
from .rec_resnet_fpn import ResNetFPN
|
||||
from .rec_mv1_enhance import MobileNetV1Enhance
|
||||
from .rec_nrtr_mtb import MTB
|
||||
from .rec_resnet_31 import ResNet31
|
||||
support_dict = [
|
||||
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB'
|
||||
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', "ResNet31"
|
||||
]
|
||||
elif model_type == "e2e":
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
|
|
|
@ -0,0 +1,176 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
__all__ = ["ResNet31"]
|
||||
|
||||
|
||||
def conv3x3(in_channel, out_channel, stride=1):
|
||||
return nn.Conv2D(
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias_attr=False
|
||||
)
|
||||
|
||||
|
||||
class BasicBlock(nn.Layer):
|
||||
expansion = 1
|
||||
def __init__(self, in_channels, channels, stride=1, downsample=False):
|
||||
super().__init__()
|
||||
self.conv1 = conv3x3(in_channels, channels, stride)
|
||||
self.bn1 = nn.BatchNorm2D(channels)
|
||||
self.relu = nn.ReLU()
|
||||
self.conv2 = conv3x3(channels, channels)
|
||||
self.bn2 = nn.BatchNorm2D(channels)
|
||||
self.downsample = downsample
|
||||
if downsample:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2D(in_channels, channels * self.expansion, 1, stride, bias_attr=False),
|
||||
nn.BatchNorm2D(channels * self.expansion),
|
||||
)
|
||||
else:
|
||||
self.downsample = nn.Sequential()
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet31(nn.Layer):
|
||||
'''
|
||||
Args:
|
||||
in_channels (int): Number of channels of input image tensor.
|
||||
layers (list[int]): List of BasicBlock number for each stage.
|
||||
channels (list[int]): List of out_channels of Conv2d layer.
|
||||
out_indices (None | Sequence[int]): Indices of output stages.
|
||||
last_stage_pool (bool): If True, add `MaxPool2d` layer to last stage.
|
||||
'''
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
layers=[1, 2, 5, 3],
|
||||
channels=[64, 128, 256, 256, 512, 512, 512],
|
||||
out_indices=None,
|
||||
last_stage_pool=False):
|
||||
super(ResNet31, self).__init__()
|
||||
assert isinstance(in_channels, int)
|
||||
assert isinstance(last_stage_pool, bool)
|
||||
|
||||
self.out_indices = out_indices
|
||||
self.last_stage_pool = last_stage_pool
|
||||
|
||||
# conv 1 (Conv Conv)
|
||||
self.conv1_1 = nn.Conv2D(in_channels, channels[0], kernel_size=3, stride=1, padding=1)
|
||||
self.bn1_1 = nn.BatchNorm2D(channels[0])
|
||||
self.relu1_1 = nn.ReLU()
|
||||
|
||||
self.conv1_2 = nn.Conv2D(channels[0], channels[1], kernel_size=3, stride=1, padding=1)
|
||||
self.bn1_2 = nn.BatchNorm2D(channels[1])
|
||||
self.relu1_2 = nn.ReLU()
|
||||
|
||||
# conv 2 (Max-pooling, Residual block, Conv)
|
||||
self.pool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||
self.block2 = self._make_layer(channels[1], channels[2], layers[0])
|
||||
self.conv2 = nn.Conv2D(channels[2], channels[2], kernel_size=3, stride=1, padding=1)
|
||||
self.bn2 = nn.BatchNorm2D(channels[2])
|
||||
self.relu2 = nn.ReLU()
|
||||
|
||||
# conv 3 (Max-pooling, Residual block, Conv)
|
||||
self.pool3 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||
self.block3 = self._make_layer(channels[2], channels[3], layers[1])
|
||||
self.conv3 = nn.Conv2D(channels[3], channels[3], kernel_size=3, stride=1, padding=1)
|
||||
self.bn3 = nn.BatchNorm2D(channels[3])
|
||||
self.relu3 = nn.ReLU()
|
||||
|
||||
# conv 4 (Max-pooling, Residual block, Conv)
|
||||
self.pool4 = nn.MaxPool2D(kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
|
||||
self.block4 = self._make_layer(channels[3], channels[4], layers[2])
|
||||
self.conv4 = nn.Conv2D(channels[4], channels[4], kernel_size=3, stride=1, padding=1)
|
||||
self.bn4 = nn.BatchNorm2D(channels[4])
|
||||
self.relu4 = nn.ReLU()
|
||||
|
||||
# conv 5 ((Max-pooling), Residual block, Conv)
|
||||
self.pool5 = None
|
||||
if self.last_stage_pool:
|
||||
self.pool5 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||
self.block5 = self._make_layer(channels[4], channels[5], layers[3])
|
||||
self.conv5 = nn.Conv2D(channels[5], channels[5], kernel_size=3, stride=1, padding=1)
|
||||
self.bn5 = nn.BatchNorm2D(channels[5])
|
||||
self.relu5 = nn.ReLU()
|
||||
|
||||
self.out_channels = channels[-1]
|
||||
|
||||
def _make_layer(self, input_channels, output_channels, blocks):
|
||||
layers = []
|
||||
for _ in range(blocks):
|
||||
downsample = None
|
||||
if input_channels != output_channels:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2D(
|
||||
input_channels,
|
||||
output_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
bias_attr=False),
|
||||
nn.BatchNorm2D(output_channels),
|
||||
)
|
||||
|
||||
layers.append(BasicBlock(input_channels, output_channels, downsample=downsample))
|
||||
input_channels = output_channels
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1_1(x)
|
||||
x = self.bn1_1(x)
|
||||
x = self.relu1_1(x)
|
||||
|
||||
x = self.conv1_2(x)
|
||||
x = self.bn1_2(x)
|
||||
x = self.relu1_2(x)
|
||||
|
||||
outs = []
|
||||
for i in range(4):
|
||||
layer_index = i + 2
|
||||
pool_layer = getattr(self, f'pool{layer_index}')
|
||||
block_layer = getattr(self, f'block{layer_index}')
|
||||
conv_layer = getattr(self, f'conv{layer_index}')
|
||||
bn_layer = getattr(self, f'bn{layer_index}')
|
||||
relu_layer = getattr(self, f'relu{layer_index}')
|
||||
|
||||
if pool_layer is not None:
|
||||
x = pool_layer(x)
|
||||
x = block_layer(x)
|
||||
x = conv_layer(x)
|
||||
x = bn_layer(x)
|
||||
x= relu_layer(x)
|
||||
|
||||
outs.append(x)
|
||||
|
||||
if self.out_indices is not None:
|
||||
return tuple([outs[i] for i in self.out_indices])
|
||||
|
||||
return x
|
|
@ -27,12 +27,13 @@ def build_head(config):
|
|||
from .rec_att_head import AttentionHead
|
||||
from .rec_srn_head import SRNHead
|
||||
from .rec_nrtr_head import Transformer
|
||||
from .rec_sar_head import SARHead
|
||||
|
||||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
support_dict = [
|
||||
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
||||
'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead'
|
||||
'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead', 'SARHead'
|
||||
]
|
||||
|
||||
#table head
|
||||
|
|
|
@ -0,0 +1,383 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
|
||||
class SAREncoder(nn.Layer):
|
||||
"""
|
||||
Args:
|
||||
enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
|
||||
enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
|
||||
enc_gru (bool): If True, use GRU, else LSTM in encoder.
|
||||
d_model (int): Dim of channels from backbone.
|
||||
d_enc (int): Dim of encoder RNN layer.
|
||||
mask (bool): If True, mask padding in RNN sequence.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
enc_bi_rnn=False,
|
||||
enc_drop_rnn=0.1,
|
||||
enc_gru=False,
|
||||
d_model=512,
|
||||
d_enc=512,
|
||||
mask=True,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
assert isinstance(enc_bi_rnn, bool)
|
||||
assert isinstance(enc_drop_rnn, (int, float))
|
||||
assert 0 <= enc_drop_rnn < 1.0
|
||||
assert isinstance(enc_gru, bool)
|
||||
assert isinstance(d_model, int)
|
||||
assert isinstance(d_enc, int)
|
||||
assert isinstance(mask, bool)
|
||||
|
||||
self.enc_bi_rnn = enc_bi_rnn
|
||||
self.enc_drop_rnn = enc_drop_rnn
|
||||
self.mask = mask
|
||||
|
||||
# LSTM Encoder
|
||||
if enc_bi_rnn:
|
||||
direction = 'bidirectional'
|
||||
else:
|
||||
direction = 'forward'
|
||||
kwargs = dict(
|
||||
input_size=d_model,
|
||||
hidden_size=d_enc,
|
||||
num_layers=2,
|
||||
time_major=False,
|
||||
dropout=enc_drop_rnn,
|
||||
direction=direction)
|
||||
if enc_gru:
|
||||
self.rnn_encoder = nn.GRU(**kwargs)
|
||||
else:
|
||||
self.rnn_encoder = nn.LSTM(**kwargs)
|
||||
|
||||
# global feature transformation
|
||||
encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
|
||||
self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)
|
||||
|
||||
def forward(self, feat, img_metas=None):
|
||||
if img_metas is not None:
|
||||
assert len(img_metas[0]) == feat.shape[0]
|
||||
|
||||
valid_ratios = None
|
||||
if img_metas is not None and self.mask:
|
||||
valid_ratios = img_metas[-1]
|
||||
|
||||
h_feat = feat.shape[2] # bsz c h w
|
||||
feat_v = F.max_pool2d(
|
||||
feat, kernel_size=(h_feat, 1), stride=1, padding=0)
|
||||
feat_v = feat_v.squeeze(2) # bsz * C * W
|
||||
feat_v = paddle.transpose(feat_v, perm=[0, 2, 1]) # bsz * W * C
|
||||
holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C
|
||||
|
||||
if valid_ratios is not None:
|
||||
valid_hf = []
|
||||
T = holistic_feat.shape[1]
|
||||
for i, valid_ratio in enumerate(valid_ratios):
|
||||
valid_step = min(T, math.ceil(T * valid_ratio)) - 1
|
||||
valid_hf.append(holistic_feat[i, valid_step, :])
|
||||
valid_hf = paddle.stack(valid_hf, axis=0)
|
||||
else:
|
||||
valid_hf = holistic_feat[:, -1, :] # bsz * C
|
||||
holistic_feat = self.linear(valid_hf) # bsz * C
|
||||
|
||||
return holistic_feat
|
||||
|
||||
|
||||
class BaseDecoder(nn.Layer):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
def forward_train(self, feat, out_enc, targets, img_metas):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward_test(self, feat, out_enc, img_metas):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self,
|
||||
feat,
|
||||
out_enc,
|
||||
label=None,
|
||||
img_metas=None,
|
||||
train_mode=True):
|
||||
self.train_mode = train_mode
|
||||
|
||||
if train_mode:
|
||||
return self.forward_train(feat, out_enc, label, img_metas)
|
||||
return self.forward_test(feat, out_enc, img_metas)
|
||||
|
||||
|
||||
class ParallelSARDecoder(BaseDecoder):
|
||||
"""
|
||||
Args:
|
||||
out_channels (int): Output class number.
|
||||
enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
|
||||
dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
|
||||
dec_drop_rnn (float): Dropout of RNN layer in decoder.
|
||||
dec_gru (bool): If True, use GRU, else LSTM in decoder.
|
||||
d_model (int): Dim of channels from backbone.
|
||||
d_enc (int): Dim of encoder RNN layer.
|
||||
d_k (int): Dim of channels of attention module.
|
||||
pred_dropout (float): Dropout probability of prediction layer.
|
||||
max_seq_len (int): Maximum sequence length for decoding.
|
||||
mask (bool): If True, mask padding in feature map.
|
||||
start_idx (int): Index of start token.
|
||||
padding_idx (int): Index of padding token.
|
||||
pred_concat (bool): If True, concat glimpse feature from
|
||||
attention with holistic feature and hidden state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
out_channels, # 90 + unknown + start + padding
|
||||
enc_bi_rnn=False,
|
||||
dec_bi_rnn=False,
|
||||
dec_drop_rnn=0.0,
|
||||
dec_gru=False,
|
||||
d_model=512,
|
||||
d_enc=512,
|
||||
d_k=64,
|
||||
pred_dropout=0.1,
|
||||
max_text_length=30,
|
||||
mask=True,
|
||||
pred_concat=True,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.num_classes = out_channels
|
||||
self.enc_bi_rnn = enc_bi_rnn
|
||||
self.d_k = d_k
|
||||
self.start_idx = out_channels - 2
|
||||
self.padding_idx = out_channels - 1
|
||||
self.max_seq_len = max_text_length
|
||||
self.mask = mask
|
||||
self.pred_concat = pred_concat
|
||||
|
||||
encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
|
||||
decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)
|
||||
|
||||
# 2D attention layer
|
||||
self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
|
||||
self.conv3x3_1 = nn.Conv2D(
|
||||
d_model, d_k, kernel_size=3, stride=1, padding=1)
|
||||
self.conv1x1_2 = nn.Linear(d_k, 1)
|
||||
|
||||
# Decoder RNN layer
|
||||
if dec_bi_rnn:
|
||||
direction = 'bidirectional'
|
||||
else:
|
||||
direction = 'forward'
|
||||
|
||||
kwargs = dict(
|
||||
input_size=encoder_rnn_out_size,
|
||||
hidden_size=encoder_rnn_out_size,
|
||||
num_layers=2,
|
||||
time_major=False,
|
||||
dropout=dec_drop_rnn,
|
||||
direction=direction)
|
||||
if dec_gru:
|
||||
self.rnn_decoder = nn.GRU(**kwargs)
|
||||
else:
|
||||
self.rnn_decoder = nn.LSTM(**kwargs)
|
||||
|
||||
# Decoder input embedding
|
||||
self.embedding = nn.Embedding(
|
||||
self.num_classes,
|
||||
encoder_rnn_out_size,
|
||||
padding_idx=self.padding_idx)
|
||||
|
||||
# Prediction layer
|
||||
self.pred_dropout = nn.Dropout(pred_dropout)
|
||||
pred_num_classes = self.num_classes - 1
|
||||
if pred_concat:
|
||||
fc_in_channel = decoder_rnn_out_size + d_model + d_enc
|
||||
else:
|
||||
fc_in_channel = d_model
|
||||
self.prediction = nn.Linear(fc_in_channel, pred_num_classes)
|
||||
|
||||
def _2d_attention(self,
|
||||
decoder_input,
|
||||
feat,
|
||||
holistic_feat,
|
||||
valid_ratios=None):
|
||||
|
||||
y = self.rnn_decoder(decoder_input)[0]
|
||||
# y: bsz * (seq_len + 1) * hidden_size
|
||||
|
||||
attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size
|
||||
bsz, seq_len, attn_size = attn_query.shape
|
||||
attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
|
||||
# (bsz, seq_len + 1, attn_size, 1, 1)
|
||||
|
||||
attn_key = self.conv3x3_1(feat)
|
||||
# bsz * attn_size * h * w
|
||||
attn_key = attn_key.unsqueeze(1)
|
||||
# bsz * 1 * attn_size * h * w
|
||||
|
||||
attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
|
||||
|
||||
# bsz * (seq_len + 1) * attn_size * h * w
|
||||
attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
|
||||
# bsz * (seq_len + 1) * h * w * attn_size
|
||||
attn_weight = self.conv1x1_2(attn_weight)
|
||||
# bsz * (seq_len + 1) * h * w * 1
|
||||
bsz, T, h, w, c = attn_weight.shape
|
||||
assert c == 1
|
||||
|
||||
if valid_ratios is not None:
|
||||
# cal mask of attention weight
|
||||
for i, valid_ratio in enumerate(valid_ratios):
|
||||
valid_width = min(w, math.ceil(w * valid_ratio))
|
||||
attn_weight[i, :, :, valid_width:, :] = float('-inf')
|
||||
|
||||
attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
|
||||
attn_weight = F.softmax(attn_weight, axis=-1)
|
||||
|
||||
attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
|
||||
attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
|
||||
# attn_weight: bsz * T * c * h * w
|
||||
# feat: bsz * c * h * w
|
||||
attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
|
||||
(3, 4),
|
||||
keepdim=False)
|
||||
# bsz * (seq_len + 1) * C
|
||||
|
||||
# Linear transformation
|
||||
if self.pred_concat:
|
||||
hf_c = holistic_feat.shape[-1]
|
||||
holistic_feat = paddle.expand(
|
||||
holistic_feat, shape=[bsz, seq_len, hf_c])
|
||||
y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
|
||||
else:
|
||||
y = self.prediction(attn_feat)
|
||||
# bsz * (seq_len + 1) * num_classes
|
||||
if self.train_mode:
|
||||
y = self.pred_dropout(y)
|
||||
|
||||
return y
|
||||
|
||||
def forward_train(self, feat, out_enc, label, img_metas):
|
||||
'''
|
||||
img_metas: [label, valid_ratio]
|
||||
'''
|
||||
if img_metas is not None:
|
||||
assert len(img_metas[0]) == feat.shape[0]
|
||||
|
||||
valid_ratios = None
|
||||
if img_metas is not None and self.mask:
|
||||
valid_ratios = img_metas[-1]
|
||||
|
||||
label = label.cuda()
|
||||
lab_embedding = self.embedding(label)
|
||||
# bsz * seq_len * emb_dim
|
||||
out_enc = out_enc.unsqueeze(1)
|
||||
# bsz * 1 * emb_dim
|
||||
in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
|
||||
# bsz * (seq_len + 1) * C
|
||||
out_dec = self._2d_attention(
|
||||
in_dec, feat, out_enc, valid_ratios=valid_ratios)
|
||||
# bsz * (seq_len + 1) * num_classes
|
||||
|
||||
return out_dec[:, 1:, :] # bsz * seq_len * num_classes
|
||||
|
||||
def forward_test(self, feat, out_enc, img_metas):
|
||||
if img_metas is not None:
|
||||
assert len(img_metas[0]) == feat.shape[0]
|
||||
|
||||
valid_ratios = None
|
||||
if img_metas is not None and self.mask:
|
||||
valid_ratios = img_metas[-1]
|
||||
|
||||
seq_len = self.max_seq_len
|
||||
bsz = feat.shape[0]
|
||||
start_token = paddle.full(
|
||||
(bsz, ), fill_value=self.start_idx, dtype='int64')
|
||||
# bsz
|
||||
start_token = self.embedding(start_token)
|
||||
# bsz * emb_dim
|
||||
emb_dim = start_token.shape[1]
|
||||
start_token = start_token.unsqueeze(1)
|
||||
start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
|
||||
# bsz * seq_len * emb_dim
|
||||
out_enc = out_enc.unsqueeze(1)
|
||||
# bsz * 1 * emb_dim
|
||||
decoder_input = paddle.concat((out_enc, start_token), axis=1)
|
||||
# bsz * (seq_len + 1) * emb_dim
|
||||
|
||||
outputs = []
|
||||
for i in range(1, seq_len + 1):
|
||||
decoder_output = self._2d_attention(
|
||||
decoder_input, feat, out_enc, valid_ratios=valid_ratios)
|
||||
char_output = decoder_output[:, i, :] # bsz * num_classes
|
||||
char_output = F.softmax(char_output, -1)
|
||||
outputs.append(char_output)
|
||||
max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
|
||||
char_embedding = self.embedding(max_idx) # bsz * emb_dim
|
||||
if i < seq_len:
|
||||
decoder_input[:, i + 1, :] = char_embedding
|
||||
|
||||
outputs = paddle.stack(outputs, 1) # bsz * seq_len * num_classes
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class SARHead(nn.Layer):
|
||||
def __init__(self,
|
||||
out_channels,
|
||||
enc_bi_rnn=False,
|
||||
enc_drop_rnn=0.1,
|
||||
enc_gru=False,
|
||||
dec_bi_rnn=False,
|
||||
dec_drop_rnn=0.0,
|
||||
dec_gru=False,
|
||||
d_k=512,
|
||||
pred_dropout=0.1,
|
||||
max_text_length=30,
|
||||
pred_concat=True,
|
||||
**kwargs):
|
||||
super(SARHead, self).__init__()
|
||||
|
||||
# encoder module
|
||||
self.encoder = SAREncoder(
|
||||
enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru)
|
||||
|
||||
# decoder module
|
||||
self.decoder = ParallelSARDecoder(
|
||||
out_channels=out_channels,
|
||||
enc_bi_rnn=enc_bi_rnn,
|
||||
dec_bi_rnn=dec_bi_rnn,
|
||||
dec_drop_rnn=dec_drop_rnn,
|
||||
dec_gru=dec_gru,
|
||||
d_k=d_k,
|
||||
pred_dropout=pred_dropout,
|
||||
max_text_length=max_text_length,
|
||||
pred_concat=pred_concat)
|
||||
|
||||
def forward(self, feat, targets=None):
|
||||
'''
|
||||
img_metas: [label, valid_ratio]
|
||||
'''
|
||||
holistic_feat = self.encoder(feat, targets) # bsz c
|
||||
|
||||
if self.training:
|
||||
label = targets[0] # label
|
||||
label = paddle.to_tensor(label, dtype='int64')
|
||||
final_out = self.decoder(
|
||||
feat, holistic_feat, label, img_metas=targets)
|
||||
if not self.training:
|
||||
final_out = self.decoder(
|
||||
feat,
|
||||
holistic_feat,
|
||||
label=None,
|
||||
img_metas=targets,
|
||||
train_mode=False)
|
||||
# (bsz, seq_len, num_classes)
|
||||
|
||||
return final_out
|
|
@ -25,7 +25,7 @@ from .db_postprocess import DBPostProcess, DistillationDBPostProcess
|
|||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode, \
|
||||
TableLabelDecode
|
||||
TableLabelDecode, SARLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
|
||||
|
@ -33,7 +33,8 @@ def build_post_process(config, global_config=None):
|
|||
support_dict = [
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||
'DistillationCTCLabelDecode', 'NRTRLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess'
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode',
|
||||
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -15,6 +15,7 @@ import numpy as np
|
|||
import string
|
||||
import paddle
|
||||
from paddle.nn import functional as F
|
||||
import re
|
||||
|
||||
|
||||
class BaseRecLabelDecode(object):
|
||||
|
@ -165,21 +166,21 @@ class NRTRLabelDecode(BaseRecLabelDecode):
|
|||
use_space_char=True,
|
||||
**kwargs):
|
||||
super(NRTRLabelDecode, self).__init__(character_dict_path,
|
||||
character_type, use_space_char)
|
||||
character_type, use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if preds.dtype == paddle.int64:
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
if preds[0][0]==2:
|
||||
preds_idx = preds[:,1:]
|
||||
if preds[0][0] == 2:
|
||||
preds_idx = preds[:, 1:]
|
||||
else:
|
||||
preds_idx = preds
|
||||
|
||||
text = self.decode(preds_idx)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label[:,1:])
|
||||
label = self.decode(label[:, 1:])
|
||||
else:
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
|
@ -188,13 +189,13 @@ class NRTRLabelDecode(BaseRecLabelDecode):
|
|||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label[:,1:])
|
||||
label = self.decode(label[:, 1:])
|
||||
return text, label
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ['blank','<unk>','<s>','</s>'] + dict_character
|
||||
dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
|
||||
return dict_character
|
||||
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
result_list = []
|
||||
|
@ -203,10 +204,11 @@ class NRTRLabelDecode(BaseRecLabelDecode):
|
|||
char_list = []
|
||||
conf_list = []
|
||||
for idx in range(len(text_index[batch_idx])):
|
||||
if text_index[batch_idx][idx] == 3: # end
|
||||
if text_index[batch_idx][idx] == 3: # end
|
||||
break
|
||||
try:
|
||||
char_list.append(self.character[int(text_index[batch_idx][idx])])
|
||||
char_list.append(self.character[int(text_index[batch_idx][
|
||||
idx])])
|
||||
except:
|
||||
continue
|
||||
if text_prob is not None:
|
||||
|
@ -218,7 +220,6 @@ class NRTRLabelDecode(BaseRecLabelDecode):
|
|||
return result_list
|
||||
|
||||
|
||||
|
||||
class AttnLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
@ -256,7 +257,8 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
|||
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||
batch_idx][idx]:
|
||||
continue
|
||||
char_list.append(self.character[int(text_index[batch_idx][idx])])
|
||||
char_list.append(self.character[int(text_index[batch_idx][
|
||||
idx])])
|
||||
if text_prob is not None:
|
||||
conf_list.append(text_prob[batch_idx][idx])
|
||||
else:
|
||||
|
@ -386,10 +388,9 @@ class SRNLabelDecode(BaseRecLabelDecode):
|
|||
class TableLabelDecode(object):
|
||||
""" """
|
||||
|
||||
def __init__(self,
|
||||
character_dict_path,
|
||||
**kwargs):
|
||||
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
|
||||
def __init__(self, character_dict_path, **kwargs):
|
||||
list_character, list_elem = self.load_char_elem_dict(
|
||||
character_dict_path)
|
||||
list_character = self.add_special_char(list_character)
|
||||
list_elem = self.add_special_char(list_elem)
|
||||
self.dict_character = {}
|
||||
|
@ -408,7 +409,8 @@ class TableLabelDecode(object):
|
|||
list_elem = []
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t")
|
||||
substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split(
|
||||
"\t")
|
||||
character_num = int(substr[0])
|
||||
elem_num = int(substr[1])
|
||||
for cno in range(1, 1 + character_num):
|
||||
|
@ -428,14 +430,14 @@ class TableLabelDecode(object):
|
|||
def __call__(self, preds):
|
||||
structure_probs = preds['structure_probs']
|
||||
loc_preds = preds['loc_preds']
|
||||
if isinstance(structure_probs,paddle.Tensor):
|
||||
if isinstance(structure_probs, paddle.Tensor):
|
||||
structure_probs = structure_probs.numpy()
|
||||
if isinstance(loc_preds,paddle.Tensor):
|
||||
if isinstance(loc_preds, paddle.Tensor):
|
||||
loc_preds = loc_preds.numpy()
|
||||
structure_idx = structure_probs.argmax(axis=2)
|
||||
structure_probs = structure_probs.max(axis=2)
|
||||
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
|
||||
structure_probs, 'elem')
|
||||
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
|
||||
structure_idx, structure_probs, 'elem')
|
||||
res_html_code_list = []
|
||||
res_loc_list = []
|
||||
batch_num = len(structure_str)
|
||||
|
@ -450,8 +452,13 @@ class TableLabelDecode(object):
|
|||
res_loc = np.array(res_loc)
|
||||
res_html_code_list.append(res_html_code)
|
||||
res_loc_list.append(res_loc)
|
||||
return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
|
||||
'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str}
|
||||
return {
|
||||
'res_html_code': res_html_code_list,
|
||||
'res_loc': res_loc_list,
|
||||
'res_score_list': result_score_list,
|
||||
'res_elem_idx_list': result_elem_idx_list,
|
||||
'structure_str_list': structure_str
|
||||
}
|
||||
|
||||
def decode(self, text_index, structure_probs, char_or_elem):
|
||||
"""convert text-label into text-index.
|
||||
|
@ -516,3 +523,82 @@ class TableLabelDecode(object):
|
|||
assert False, "Unsupport type %s in char_or_elem" \
|
||||
% char_or_elem
|
||||
return idx
|
||||
|
||||
|
||||
class SARLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
character_dict_path=None,
|
||||
character_type='ch',
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
super(SARLabelDecode, self).__init__(character_dict_path,
|
||||
character_type, use_space_char)
|
||||
|
||||
self.rm_symbol = kwargs.get('rm_symbol', False)
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
beg_end_str = "<BOS/EOS>"
|
||||
unknown_str = "<UKN>"
|
||||
padding_str = "<PAD>"
|
||||
dict_character = dict_character + [unknown_str]
|
||||
self.unknown_idx = len(dict_character) - 1
|
||||
dict_character = dict_character + [beg_end_str]
|
||||
self.start_idx = len(dict_character) - 1
|
||||
self.end_idx = len(dict_character) - 1
|
||||
dict_character = dict_character + [padding_str]
|
||||
self.padding_idx = len(dict_character) - 1
|
||||
return dict_character
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
result_list = []
|
||||
ignored_tokens = self.get_ignored_tokens()
|
||||
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
char_list = []
|
||||
conf_list = []
|
||||
for idx in range(len(text_index[batch_idx])):
|
||||
if text_index[batch_idx][idx] in ignored_tokens:
|
||||
continue
|
||||
if int(text_index[batch_idx][idx]) == int(self.end_idx):
|
||||
if text_prob is None and idx == 0:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
if is_remove_duplicate:
|
||||
# only for predict
|
||||
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||
batch_idx][idx]:
|
||||
continue
|
||||
char_list.append(self.character[int(text_index[batch_idx][
|
||||
idx])])
|
||||
if text_prob is not None:
|
||||
conf_list.append(text_prob[batch_idx][idx])
|
||||
else:
|
||||
conf_list.append(1)
|
||||
text = ''.join(char_list)
|
||||
if self.rm_symbol:
|
||||
comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
|
||||
text = text.lower()
|
||||
text = comp.sub('', text)
|
||||
result_list.append((text, np.mean(conf_list)))
|
||||
return result_list
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label, is_remove_duplicate=False)
|
||||
return text, label
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
return [self.padding_idx]
|
||||
|
|
|
@ -55,6 +55,7 @@ def main():
|
|||
|
||||
model = build_model(config['Architecture'])
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
use_sar = config['Architecture']['algorithm'] == "SAR"
|
||||
if "model_type" in config['Architecture'].keys():
|
||||
model_type = config['Architecture']['model_type']
|
||||
else:
|
||||
|
@ -71,7 +72,7 @@ def main():
|
|||
|
||||
# start eval
|
||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class, model_type, use_srn)
|
||||
eval_class, model_type, use_srn, use_sar)
|
||||
logger.info('metric eval ***************')
|
||||
for k, v in metric.items():
|
||||
logger.info('{}:{}'.format(k, v))
|
||||
|
|
|
@ -74,6 +74,10 @@ def main():
|
|||
'image', 'encoder_word_pos', 'gsrm_word_pos',
|
||||
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
|
||||
]
|
||||
elif config['Architecture']['algorithm'] == "SAR":
|
||||
op[op_name]['keep_keys'] = [
|
||||
'image', 'valid_ratio'
|
||||
]
|
||||
else:
|
||||
op[op_name]['keep_keys'] = ['image']
|
||||
transforms.append(op)
|
||||
|
@ -106,11 +110,16 @@ def main():
|
|||
paddle.to_tensor(gsrm_slf_attn_bias1_list),
|
||||
paddle.to_tensor(gsrm_slf_attn_bias2_list)
|
||||
]
|
||||
if config['Architecture']['algorithm'] == "SAR":
|
||||
valid_ratio = np.expand_dims(batch[-1], axis=0)
|
||||
img_metas = [paddle.to_tensor(valid_ratio)]
|
||||
|
||||
images = np.expand_dims(batch[0], axis=0)
|
||||
images = paddle.to_tensor(images)
|
||||
if config['Architecture']['algorithm'] == "SRN":
|
||||
preds = model(images, others)
|
||||
elif config['Architecture']['algorithm'] == "SAR":
|
||||
preds = model(images, img_metas)
|
||||
else:
|
||||
preds = model(images)
|
||||
post_result = post_process_class(preds)
|
||||
|
|
|
@ -187,7 +187,7 @@ def train(config,
|
|||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
use_nrtr = config['Architecture']['algorithm'] == "NRTR"
|
||||
|
||||
use_sar = config['Architecture']['algorithm'] == 'SAR'
|
||||
try:
|
||||
model_type = config['Architecture']['model_type']
|
||||
except:
|
||||
|
@ -215,7 +215,7 @@ def train(config,
|
|||
images = batch[0]
|
||||
if use_srn:
|
||||
model_average = True
|
||||
if use_srn or model_type == 'table' or use_nrtr:
|
||||
if use_srn or model_type == 'table' or use_nrtr or use_sar:
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
|
@ -279,7 +279,8 @@ def train(config,
|
|||
post_process_class,
|
||||
eval_class,
|
||||
model_type,
|
||||
use_srn=use_srn)
|
||||
use_srn=use_srn,
|
||||
use_sar=use_sar)
|
||||
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
||||
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
|
||||
logger.info(cur_metric_str)
|
||||
|
@ -351,7 +352,8 @@ def eval(model,
|
|||
post_process_class,
|
||||
eval_class,
|
||||
model_type,
|
||||
use_srn=False):
|
||||
use_srn=False,
|
||||
use_sar=False):
|
||||
model.eval()
|
||||
with paddle.no_grad():
|
||||
total_frame = 0.0
|
||||
|
@ -364,7 +366,7 @@ def eval(model,
|
|||
break
|
||||
images = batch[0]
|
||||
start = time.time()
|
||||
if use_srn or model_type == 'table':
|
||||
if use_srn or model_type == 'table' or use_sar:
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
|
@ -400,7 +402,7 @@ def preprocess(is_train=False):
|
|||
alg = config['Architecture']['algorithm']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn'
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR'
|
||||
]
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
|
|
Loading…
Reference in New Issue