Merge remote-tracking branch 'origin/dygraph' into dygraph
This commit is contained in commit 960f7fcec3.

@@ -0,0 +1,102 @@
Global:
  use_gpu: True
  epoch_num: 21
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec/nrtr/
  save_epoch_step: 1
  # evaluation is run every 2000 iterations
  eval_batch_step: [0, 2000]
  cal_metric_during_train: True
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_words_en/word_10.png
  # for data or label process
  character_dict_path:
  character_type: EN_symbol
  max_text_length: 25
  infer_mode: False
  use_space_char: True
  save_res_path: ./output/rec/predicts_nrtr.txt

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.99
  clip_norm: 5.0
  lr:
    name: Cosine
    learning_rate: 0.0005
    warmup_epoch: 2
  regularizer:
    name: 'L2'
    factor: 0.

Architecture:
  model_type: rec
  algorithm: NRTR
  in_channels: 1
  Transform:
  Backbone:
    name: MTB
    cnn_num: 2
  Head:
    name: Transformer
    d_model: 512
    num_encoder_layers: 6
    beam_size: 10 # When beam_size is greater than 0, beam search is used during evaluation.

Loss:
  name: NRTRLoss
  smoothing: True

PostProcess:
  name: NRTRLabelDecode

Metric:
  name: RecMetric
  main_indicator: acc

Train:
  dataset:
    name: LMDBDataSet
    data_dir: ./train_data/data_lmdb_release/training/
    transforms:
      - NRTRDecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - NRTRLabelEncode: # Class handling label
      - NRTRRecResizeImg:
          image_shape: [100, 32]
          resize_type: PIL # PIL or OpenCV
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 512
    drop_last: True
    num_workers: 8

Eval:
  dataset:
    name: LMDBDataSet
    data_dir: ./train_data/data_lmdb_release/evaluation/
    transforms:
      - NRTRDecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - NRTRLabelEncode: # Class handling label
      - NRTRRecResizeImg:
          image_shape: [100, 32]
          resize_type: PIL # PIL or OpenCV
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 1
    use_shared_memory: False
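For orientation, the config above can be loaded and tweaked programmatically before launching a run. A minimal sketch, assuming PyYAML is installed and the file is saved under the rec_mtb_nrtr.yml name used in the docs below:

import yaml

with open("configs/rec/rec_mtb_nrtr.yml") as f:  # hypothetical path
    cfg = yaml.safe_load(f)

cfg["Global"]["epoch_num"] = 1                      # single-epoch dry run
cfg["Train"]["loader"]["batch_size_per_card"] = 64  # fit a smaller GPU
print(cfg["Architecture"]["algorithm"])             # NRTR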
@@ -75,7 +75,7 @@ def main(config, device, logger, vdl_writer):
    model = build_model(config['Architecture'])

    flops = paddle.flops(model, [1, 3, 640, 640])
-   logger.info(f"FLOPs before pruning: {flops}")
+   logger.info("FLOPs before pruning: {}".format(flops))

    from paddleslim.dygraph import FPGMFilterPruner
    model.train()
@@ -106,8 +106,8 @@ def main(config, device, logger, vdl_writer):

    def eval_fn():
        metric = program.eval(model, valid_dataloader, post_process_class,
-                             eval_class)
-        logger.info(f"metric['hmean']: {metric['hmean']}")
+                             eval_class, False)
+        logger.info("metric['hmean']: {}".format(metric['hmean']))
        return metric['hmean']

    params_sensitive = pruner.sensitive(
@@ -123,16 +123,17 @@ def main(config, device, logger, vdl_writer):
    # calculate pruned params' ratio
    params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
    for key in params_sensitive.keys():
-        logger.info(f"{key}, {params_sensitive[key]}")
+        logger.info("{}, {}".format(key, params_sensitive[key]))

    # params_sensitive = {}
    # for param in model.parameters():
    #     if 'transpose' not in param.name and 'linear' not in param.name:
    #         params_sensitive[param.name] = 0.1

    plan = pruner.prune_vars(params_sensitive, [0])
    for param in model.parameters():
        if ("weights" in param.name and "conv" in param.name) or (
                "w_0" in param.name and "conv2d" in param.name):
            logger.info(f"{param.name}: {param.shape}")

    flops = paddle.flops(model, [1, 3, 640, 640])
-   logger.info(f"FLOPs after pruning: {flops}")
+   logger.info("FLOPs after pruning: {}".format(flops))

    # start train
@@ -44,6 +44,7 @@ List of text recognition algorithms open-sourced in PaddleOCR's dynamic-graph branch:
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
+- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))

Following the [DTRB](https://arxiv.org/abs/1904.01906)[3] text recognition training and evaluation pipeline, the models are trained on the MJSynth and SynthText datasets and evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE. Results:

@@ -58,6 +59,7 @@ List of text recognition algorithms open-sourced in PaddleOCR's dynamic-graph branch:
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn|88.52%|rec_r50fpn_vd_none_srn|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
+|NRTR|NRTR_MTB|84.3%|rec_mtb_nrtr|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar)|

For training and using PaddleOCR text recognition algorithms, please refer to the text recognition section of the tutorial: [Text recognition model training/evaluation](./recognition.md).
@@ -215,6 +215,7 @@ PaddleOCR supports alternating training and evaluation; this can be set in `configs/rec/rec_icdar15_t
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
+| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |

For training on Chinese data, [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml) is recommended. If you want to try other algorithms on a Chinese dataset, please refer to the following instructions to modify the configuration file:
@@ -46,6 +46,7 @@ PaddleOCR open-source text recognition algorithms list:
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
+- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))

Following [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation results of the above text recognition algorithms (trained on MJSynth and SynthText, evaluated on IIIT, SVT, IC03, IC13, IC15, SVTP, and CUTE) are as follows:

@@ -60,5 +61,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn|88.52%|rec_r50fpn_vd_none_srn|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
+|NRTR|NRTR_MTB|84.3%|rec_mtb_nrtr|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar)|

For a training guide and usage of PaddleOCR text recognition algorithms, please refer to [Text recognition model training/evaluation/prediction](./recognition_en.md).
@@ -207,7 +207,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
+| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |

For training Chinese data, it is recommended to use [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the results of other algorithms on a Chinese dataset, please refer to the following instructions to modify the configuration file:
BIN doc/table/1.png (binary file not shown; size before: 263 KiB, after: 758 KiB)
BIN doc/table/1.png (binary file not shown; size before: 24 KiB, after: 58 KiB)
@@ -127,7 +127,7 @@ model_urls = {
}

SUPPORT_DET_MODEL = ['DB']
-VERSION = '2.2'
+VERSION = '2.2.0.1'
SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")
@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop

-from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
+from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg
from .randaugment import RandAugment
from .copy_paste import CopyPaste
from .operators import *
@@ -161,6 +161,34 @@ class BaseRecLabelEncode(object):
        return text_list


class NRTRLabelEncode(BaseRecLabelEncode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 max_text_length,
                 character_dict_path=None,
                 character_type='EN_symbol',
                 use_space_char=False,
                 **kwargs):
        super(NRTRLabelEncode,
              self).__init__(max_text_length, character_dict_path,
                             character_type, use_space_char)

    def __call__(self, data):
        text = data['label']
        text = self.encode(text)
        if text is None:
            return None
        data['length'] = np.array(len(text))
        text.insert(0, 2)  # prepend the <s> token id
        text.append(3)     # append the </s> token id
        text = text + [0] * (self.max_text_len - len(text))  # pad with blank (0)
        data['label'] = np.array(text)
        return data

    def add_special_char(self, dict_character):
        dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
        return dict_character
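For intuition, a standalone sketch of the encoding scheme above. The mapping 'a' -> 4, 'b' -> 5 is made up; real ids come from the character dictionary, with 0-3 reserved for blank/<unk>/<s>/</s>:

import numpy as np

max_text_len = 10
encoded = [4, 5]                                # hypothetical self.encode('ab')
encoded.insert(0, 2)                            # prepend <s>
encoded.append(3)                               # append </s>
encoded += [0] * (max_text_len - len(encoded))  # pad with blank (0)
print(np.array(encoded))                        # [2 4 5 3 0 0 0 0 0 0]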


class CTCLabelEncode(BaseRecLabelEncode):
    """ Convert between text-label and text-index """
@@ -57,6 +57,38 @@ class DecodeImage(object):
        return data


class NRTRDecodeImage(object):
    """ decode image """

    def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
        self.img_mode = img_mode
        self.channel_first = channel_first

    def __call__(self, data):
        img = data['image']
        if six.PY2:
            assert type(img) is str and len(
                img) > 0, "invalid input 'img' in DecodeImage"
        else:
            assert type(img) is bytes and len(
                img) > 0, "invalid input 'img' in DecodeImage"
        img = np.frombuffer(img, dtype='uint8')

        img = cv2.imdecode(img, 1)

        if img is None:
            return None
        if self.img_mode == 'GRAY':
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif self.img_mode == 'RGB':
            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
            img = img[:, :, ::-1]
        # NRTR consumes single-channel input, so the image is converted to gray
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        if self.channel_first:
            img = img.transpose((2, 0, 1))
        data['image'] = img
        return data


class NormalizeImage(object):
    """ normalize image such as subtract mean, divide std
    """
@@ -16,7 +16,7 @@ import math
import cv2
import numpy as np
import random

from PIL import Image
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
@@ -43,6 +43,25 @@ class ClsResizeImg(object):
        return data


class NRTRRecResizeImg(object):
    def __init__(self, image_shape, resize_type, **kwargs):
        self.image_shape = image_shape
        self.resize_type = resize_type

    def __call__(self, data):
        img = data['image']
        if self.resize_type == 'PIL':
            image_pil = Image.fromarray(np.uint8(img))
            img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
            img = np.array(img)
        if self.resize_type == 'OpenCV':
            img = cv2.resize(img, self.image_shape)
        norm_img = np.expand_dims(img, -1)
        norm_img = norm_img.transpose((2, 0, 1))
        # scale pixel values from [0, 255] to [-1, 1]
        data['image'] = norm_img.astype(np.float32) / 128. - 1.
        return data


class RecResizeImg(object):
    def __init__(self,
                 image_shape,
@@ -25,7 +25,7 @@ from .det_sast_loss import SASTLoss
from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss
+from .rec_nrtr_loss import NRTRLoss

# cls loss
from .cls_loss import ClsLoss

@@ -44,8 +44,9 @@ from .table_att_loss import TableAttentionLoss
def build_loss(config):
    support_dict = [
        'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
-        'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
+        'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss'
    ]

    config = copy.deepcopy(config)
    module_name = config.pop('name')
    assert module_name in support_dict, Exception('loss only support {}'.format(
@@ -0,0 +1,30 @@
import paddle
from paddle import nn
import paddle.nn.functional as F


class NRTRLoss(nn.Layer):
    def __init__(self, smoothing=True, **kwargs):
        super(NRTRLoss, self).__init__()
        self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
        self.smoothing = smoothing

    def forward(self, pred, batch):
        pred = pred.reshape([-1, pred.shape[2]])
        max_len = batch[2].max()
        tgt = batch[1][:, 1:2 + max_len]  # drop the leading <s> token
        tgt = tgt.reshape([-1])
        if self.smoothing:
            eps = 0.1
            n_class = pred.shape[1]
            one_hot = F.one_hot(tgt, pred.shape[1])
            one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
            log_prb = F.log_softmax(pred, axis=1)
            # padding positions (id 0) are excluded from the mean
            non_pad_mask = paddle.not_equal(
                tgt, paddle.zeros(
                    tgt.shape, dtype='int64'))
            loss = -(one_hot * log_prb).sum(axis=1)
            loss = loss.masked_select(non_pad_mask).mean()
        else:
            loss = self.loss_func(pred, tgt)
        return {'loss': loss}
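The smoothing branch mixes each one-hot target with a uniform distribution before taking the negative log-likelihood. A numpy sketch of that arithmetic with made-up values:

import numpy as np

eps, n_class = 0.1, 5
pred = np.array([[2.0, 0.1, 0.1, 0.1, 0.1]])  # logits for one target token
tgt = np.array([0])                           # made-up true class index
one_hot = np.eye(n_class)[tgt]
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = pred - np.log(np.exp(pred).sum(axis=1, keepdims=True))
print(-(one_hot * log_prb).sum(axis=1))       # smoothed NLL for this token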
@@ -57,3 +57,4 @@ class RecMetric(object):
        self.correct_num = 0
        self.all_num = 0
        self.norm_edit_dis = 0
@@ -14,7 +14,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
@@ -26,8 +26,9 @@ def build_backbone(config, model_type):
        from .rec_resnet_vd import ResNet
        from .rec_resnet_fpn import ResNetFPN
        from .rec_mv1_enhance import MobileNetV1Enhance
+        from .rec_nrtr_mtb import MTB
        support_dict = [
-            "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
+            'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB'
        ]
    elif model_type == "e2e":
        from .e2e_resnet_vd_pg import ResNet
@@ -0,0 +1,46 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import nn


class MTB(nn.Layer):
    def __init__(self, cnn_num, in_channels):
        super(MTB, self).__init__()
        self.block = nn.Sequential()
        self.out_channels = in_channels
        self.cnn_num = cnn_num
        if self.cnn_num == 2:
            # two stride-2 conv/relu/bn stages: 32 then 64 output channels
            for i in range(self.cnn_num):
                self.block.add_sublayer(
                    'conv_{}'.format(i),
                    nn.Conv2D(
                        in_channels=in_channels
                        if i == 0 else 32 * (2**(i - 1)),
                        out_channels=32 * (2**i),
                        kernel_size=3,
                        stride=2,
                        padding=1))
                self.block.add_sublayer('relu_{}'.format(i), nn.ReLU())
                self.block.add_sublayer('bn_{}'.format(i),
                                        nn.BatchNorm2D(32 * (2**i)))

    def forward(self, images):
        x = self.block(images)
        if self.cnn_num == 2:
            # (b, w, h, c)
            x = x.transpose([0, 3, 2, 1])
            x_shape = x.shape
            x = x.reshape([x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
        return x
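A quick shape sanity check (a sketch; assumes paddle is installed and the MTB class above is in scope). With cnn_num=2 and a 32x100 grayscale input, the two stride-2 convolutions yield 64 channels on an 8x25 grid, which flattens to a width-25 sequence of 512-dim features, matching the head's d_model: 512 setting:

import paddle

mtb = MTB(cnn_num=2, in_channels=1)
x = paddle.randn([4, 1, 32, 100])  # (batch, channel, height, width)
print(mtb(x).shape)                # [4, 25, 512]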
@@ -26,12 +26,14 @@ def build_head(config):
    from .rec_ctc_head import CTCHead
    from .rec_att_head import AttentionHead
    from .rec_srn_head import SRNHead
+    from .rec_nrtr_head import Transformer

    # cls head
    from .cls_head import ClsHead
    support_dict = [
        'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
-        'SRNHead', 'PGHead', 'TableAttentionHead']
+        'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead'
+    ]

    # table head
    from .table_att_head import TableAttentionHead
@@ -0,0 +1,178 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import Linear
from paddle.nn.initializer import XavierUniform as xavier_uniform_
from paddle.nn.initializer import Constant as constant_
from paddle.nn.initializer import XavierNormal as xavier_normal_

zeros_ = constant_(value=0.)
ones_ = constant_(value=1.)


class MultiheadAttention(nn.Layer):
    r"""Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need

    .. math::
        \text{MultiHead}(Q, K, V) = \text{Concat}(head_1, \dots, head_h)W^O
        \text{where } head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

    Args:
        embed_dim: total dimension of the model
        num_heads: parallel attention layers, or heads
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 bias=True,
                 add_bias_kv=False,
                 add_zero_attn=False):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5
        self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
        self._reset_parameters()
        # 1x1 convolutions implement the Q/K/V input projections
        self.conv1 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
        self.conv2 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
        self.conv3 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))

    def _reset_parameters(self):
        xavier_uniform_(self.out_proj.weight)

    def forward(self,
                query,
                key,
                value,
                key_padding_mask=None,
                incremental_state=None,
                need_weights=True,
                static_kv=False,
                attn_mask=None):
        """
        Inputs of forward function
            query: [target length, batch size, embed dim]
            key: [sequence length, batch size, embed dim]
            value: [sequence length, batch size, embed dim]
            key_padding_mask: if True, mask padding based on batch size
            incremental_state: if provided, previous time steps are cached
            need_weights: output attn_output_weights
            static_kv: key and value are static

        Outputs of forward function
            attn_output: [target length, batch size, embed dim]
            attn_output_weights: [batch size, target length, sequence length]
        """
        tgt_len, bsz, embed_dim = query.shape
        assert embed_dim == self.embed_dim
        assert list(query.shape) == [tgt_len, bsz, embed_dim]
        assert key.shape == value.shape

        q = self._in_proj_q(query)
        k = self._in_proj_k(key)
        v = self._in_proj_v(value)
        q *= self.scaling

        # split heads: [len, bsz, dim] -> [bsz * heads, len, head_dim]
        q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose(
            [1, 0, 2])
        k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose(
            [1, 0, 2])
        v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose(
            [1, 0, 2])

        src_len = k.shape[1]

        if key_padding_mask is not None:
            assert key_padding_mask.shape[0] == bsz
            assert key_padding_mask.shape[1] == src_len

        attn_output_weights = paddle.bmm(q, k.transpose([0, 2, 1]))
        assert list(attn_output_weights.
                    shape) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_output_weights += attn_mask
        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.reshape(
                [bsz, self.num_heads, tgt_len, src_len])
            key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32')
            y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf')
            y = paddle.where(key == 0., key, y)
            attn_output_weights += y
            attn_output_weights = attn_output_weights.reshape(
                [bsz * self.num_heads, tgt_len, src_len])

        attn_output_weights = F.softmax(
            attn_output_weights.astype('float32'),
            axis=-1,
            dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16
            else attn_output_weights.dtype)
        attn_output_weights = F.dropout(
            attn_output_weights, p=self.dropout, training=self.training)

        attn_output = paddle.bmm(attn_output_weights, v)
        assert list(attn_output.
                    shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn_output = attn_output.transpose([1, 0, 2]).reshape(
            [tgt_len, bsz, embed_dim])
        attn_output = self.out_proj(attn_output)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.reshape(
                [bsz, self.num_heads, tgt_len, src_len])
            attn_output_weights = attn_output_weights.sum(
                axis=1) / self.num_heads
        else:
            attn_output_weights = None
        return attn_output, attn_output_weights

    def _in_proj_q(self, query):
        query = query.transpose([1, 2, 0])
        query = paddle.unsqueeze(query, axis=2)
        res = self.conv1(query)
        res = paddle.squeeze(res, axis=2)
        res = res.transpose([2, 0, 1])
        return res

    def _in_proj_k(self, key):
        key = key.transpose([1, 2, 0])
        key = paddle.unsqueeze(key, axis=2)
        res = self.conv2(key)
        res = paddle.squeeze(res, axis=2)
        res = res.transpose([2, 0, 1])
        return res

    def _in_proj_v(self, value):
        value = value.transpose([1, 2, 0])
        value = paddle.unsqueeze(value, axis=2)
        res = self.conv3(value)
        res = paddle.squeeze(res, axis=2)
        res = res.transpose([2, 0, 1])
        return res
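A usage sketch (assumes paddle is installed and the class above is in scope); inputs follow the documented [length, batch size, embed dim] layout:

import paddle

attn = MultiheadAttention(embed_dim=512, num_heads=8)
q = paddle.randn([25, 4, 512])      # target length 25, batch 4
k = v = paddle.randn([30, 4, 512])  # source length 30
out, weights = attn(q, k, v)
print(out.shape, weights.shape)     # [25, 4, 512] and [4, 25, 30]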
@@ -0,0 +1,844 @@
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import paddle
import copy
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import LayerList
from paddle.nn.initializer import XavierNormal as xavier_uniform_
from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
import numpy as np
from ppocr.modeling.heads.multiheadAttention import MultiheadAttention
from paddle.nn.initializer import Constant as constant_
from paddle.nn.initializer import XavierNormal as xavier_normal_

zeros_ = constant_(value=0.)
ones_ = constant_(value=1.)


class Transformer(nn.Layer):
    """A transformer model. The user is able to modify the attributes as needed. The architecture
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
    Processing Systems, pages 6000-6010.

    Args:
        d_model: the number of expected features in the encoder/decoder inputs (default=512).
        nhead: the number of heads in the multihead attention models (default=8).
        num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
        num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).
    """

    def __init__(self,
                 d_model=512,
                 nhead=8,
                 num_encoder_layers=6,
                 beam_size=0,
                 num_decoder_layers=6,
                 dim_feedforward=1024,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1,
                 custom_encoder=None,
                 custom_decoder=None,
                 in_channels=0,
                 out_channels=0,
                 dst_vocab_size=99,
                 scale_embedding=True):
        super(Transformer, self).__init__()
        self.embedding = Embeddings(
            d_model=d_model,
            vocab=dst_vocab_size,
            padding_idx=0,
            scale_embedding=scale_embedding)
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate,
            dim=d_model, )
        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            if num_encoder_layers > 0:
                encoder_layer = TransformerEncoderLayer(
                    d_model, nhead, dim_feedforward, attention_dropout_rate,
                    residual_dropout_rate)
                self.encoder = TransformerEncoder(encoder_layer,
                                                  num_encoder_layers)
            else:
                self.encoder = None

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(
                d_model, nhead, dim_feedforward, attention_dropout_rate,
                residual_dropout_rate)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers)

        self._reset_parameters()
        self.beam_size = beam_size
        self.d_model = d_model
        self.nhead = nhead
        self.tgt_word_prj = nn.Linear(d_model, dst_vocab_size, bias_attr=False)
        w0 = np.random.normal(0.0, d_model**-0.5,
                              (d_model, dst_vocab_size)).astype(np.float32)
        self.tgt_word_prj.weight.set_value(w0)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2D):
            xavier_normal_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)

    def forward_train(self, src, tgt):
        tgt = tgt[:, :-1]

        tgt_key_padding_mask = self.generate_padding_mask(tgt)
        tgt = self.embedding(tgt).transpose([1, 0, 2])
        tgt = self.positional_encoding(tgt)
        tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0])

        if self.encoder is not None:
            src = self.positional_encoding(src.transpose([1, 0, 2]))
            memory = self.encoder(src)
        else:
            memory = src.squeeze(2).transpose([2, 0, 1])
        output = self.decoder(
            tgt,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=None,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=None)
        output = output.transpose([1, 0, 2])
        logit = self.tgt_word_prj(output)
        return logit

    def forward(self, src, targets=None):
        """Take in and process masked source/target sequences.
        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
        Shape:
            - src: :math:`(S, N, E)`.
            - tgt: :math:`(T, N, E)`.
        Examples:
            >>> output = transformer_model(src, tgt)
        """

        if self.training:
            max_len = targets[1].max()
            tgt = targets[0][:, :2 + max_len]
            return self.forward_train(src, tgt)
        else:
            if self.beam_size > 0:
                return self.forward_beam(src)
            else:
                return self.forward_test(src)

    def forward_test(self, src):
        bs = src.shape[0]
        if self.encoder is not None:
            src = self.positional_encoding(src.transpose([1, 0, 2]))
            memory = self.encoder(src)
        else:
            memory = src.squeeze(2).transpose([2, 0, 1])
        # greedy decoding: start every sequence with <s> (id 2)
        dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64)
        for len_dec_seq in range(1, 25):
            src_enc = memory.clone()
            tgt_key_padding_mask = self.generate_padding_mask(dec_seq)
            dec_seq_embed = self.embedding(dec_seq).transpose([1, 0, 2])
            dec_seq_embed = self.positional_encoding(dec_seq_embed)
            tgt_mask = self.generate_square_subsequent_mask(dec_seq_embed.shape[
                0])
            output = self.decoder(
                dec_seq_embed,
                src_enc,
                tgt_mask=tgt_mask,
                memory_mask=None,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=None)
            dec_output = output.transpose([1, 0, 2])

            dec_output = dec_output[:,
                                    -1, :]  # Pick the last step: (bh * bm) * d_h
            word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1)
            word_prob = word_prob.reshape([1, bs, -1])
            preds_idx = word_prob.argmax(axis=2)

            # stop once every sequence has emitted </s> (id 3)
            if paddle.equal_all(
                    preds_idx[-1],
                    paddle.full(
                        preds_idx[-1].shape, 3, dtype='int64')):
                break

            preds_prob = word_prob.max(axis=2)
            dec_seq = paddle.concat(
                [dec_seq, preds_idx.reshape([-1, 1])], axis=1)

        return dec_seq
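A toy illustration of how dec_seq grows by one token per iteration (ids are made up; 2 is <s>, 3 is </s>):

import paddle

dec_seq = paddle.full((2, 1), 2, dtype=paddle.int64)  # both rows start with <s>
preds_idx = paddle.to_tensor([[7], [3]])              # next token per sequence
dec_seq = paddle.concat([dec_seq, preds_idx], axis=1)
print(dec_seq.numpy())  # [[2 7] [2 3]]; the second row just emitted </s>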

    def forward_beam(self, images):
        ''' Translation work in one batch '''

        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            ''' Indicate the position of an instance in a tensor. '''
            return {
                inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)
            }

        def collect_active_part(beamed_tensor, curr_active_inst_idx,
                                n_prev_active_inst, n_bm):
            ''' Collect tensor parts associated to active instances. '''

            _, *d_hs = beamed_tensor.shape
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = (n_curr_active_inst * n_bm, *d_hs)

            beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
            beamed_tensor = beamed_tensor.index_select(
                paddle.to_tensor(curr_active_inst_idx), axis=0)
            beamed_tensor = beamed_tensor.reshape([*new_shape])

            return beamed_tensor

        def collate_active_info(src_enc, inst_idx_to_position_map,
                                active_inst_idx_list):
            # Sentences which are still active are collected,
            # so the decoder will not run on completed sentences.

            n_prev_active_inst = len(inst_idx_to_position_map)
            active_inst_idx = [
                inst_idx_to_position_map[k] for k in active_inst_idx_list
            ]
            active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64')
            active_src_enc = collect_active_part(
                src_enc.transpose([1, 0, 2]), active_inst_idx,
                n_prev_active_inst, n_bm).transpose([1, 0, 2])
            active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)
            return active_src_enc, active_inst_idx_to_position_map

        def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
                             inst_idx_to_position_map, n_bm,
                             memory_key_padding_mask):
            ''' Decode and update beam status, and then return active beam idx '''

            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
                dec_partial_seq = [
                    b.get_current_state() for b in inst_dec_beams if not b.done
                ]
                dec_partial_seq = paddle.stack(dec_partial_seq)
                dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq])
                return dec_partial_seq

            def prepare_beam_memory_key_padding_mask(
                    inst_dec_beams, memory_key_padding_mask, n_bm):
                keep = []
                for idx in (memory_key_padding_mask):
                    if not inst_dec_beams[idx].done:
                        keep.append(idx)
                memory_key_padding_mask = memory_key_padding_mask[
                    paddle.to_tensor(keep)]
                len_s = memory_key_padding_mask.shape[-1]
                n_inst = memory_key_padding_mask.shape[0]
                memory_key_padding_mask = paddle.concat(
                    [memory_key_padding_mask for i in range(n_bm)], axis=1)
                memory_key_padding_mask = memory_key_padding_mask.reshape(
                    [n_inst * n_bm, len_s])  # repeat(1, n_bm)
                return memory_key_padding_mask

            def predict_word(dec_seq, enc_output, n_active_inst, n_bm,
                             memory_key_padding_mask):
                tgt_key_padding_mask = self.generate_padding_mask(dec_seq)
                dec_seq = self.embedding(dec_seq).transpose([1, 0, 2])
                dec_seq = self.positional_encoding(dec_seq)
                tgt_mask = self.generate_square_subsequent_mask(dec_seq.shape[
                    0])
                dec_output = self.decoder(
                    dec_seq,
                    enc_output,
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                    memory_key_padding_mask=memory_key_padding_mask,
                ).transpose([1, 0, 2])
                dec_output = dec_output[:,
                                        -1, :]  # Pick the last step: (bh * bm) * d_h
                word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1)
                word_prob = word_prob.reshape([n_active_inst, n_bm, -1])
                return word_prob

            def collect_active_inst_idx_list(inst_beams, word_prob,
                                             inst_idx_to_position_map):
                active_inst_idx_list = []
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    is_inst_complete = inst_beams[inst_idx].advance(word_prob[
                        inst_position])
                    if not is_inst_complete:
                        active_inst_idx_list += [inst_idx]

                return active_inst_idx_list

            n_active_inst = len(inst_idx_to_position_map)
            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            memory_key_padding_mask = None
            word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm,
                                     memory_key_padding_mask)
            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)
            return active_inst_idx_list

        def collect_hypothesis_and_scores(inst_dec_beams, n_best):
            all_hyp, all_scores = [], []
            for inst_idx in range(len(inst_dec_beams)):
                scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
                all_scores += [scores[:n_best]]
                hyps = [
                    inst_dec_beams[inst_idx].get_hypothesis(i)
                    for i in tail_idxs[:n_best]
                ]
                all_hyp += [hyps]
            return all_hyp, all_scores

        with paddle.no_grad():
            # -- Encode
            if self.encoder is not None:
                src = self.positional_encoding(images.transpose([1, 0, 2]))
                src_enc = self.encoder(src).transpose([1, 0, 2])
            else:
                src_enc = images.squeeze(2).transpose([0, 2, 1])

            # -- Repeat data for beam search
            n_bm = self.beam_size
            n_inst, len_s, d_h = src_enc.shape
            src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1)
            src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose(
                [1, 0, 2])
            # -- Prepare beams
            inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)]

            # -- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)
            # -- Decode
            for len_dec_seq in range(1, 25):
                src_enc_copy = src_enc.clone()
                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, src_enc_copy,
                    inst_idx_to_position_map, n_bm, None)
                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>
                src_enc, inst_idx_to_position_map = collate_active_info(
                    src_enc_copy, inst_idx_to_position_map,
                    active_inst_idx_list)
            batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams,
                                                                    1)
        result_hyp = []
        for bs_hyp in batch_hyp:
            bs_hyp_pad = bs_hyp[0] + [3] * (25 - len(bs_hyp[0]))  # pad with </s>
            result_hyp.append(bs_hyp_pad)
        return paddle.to_tensor(np.array(result_hyp), dtype=paddle.int64)

    def generate_square_subsequent_mask(self, sz):
        """Generate a square mask for the sequence. The masked positions are filled with float('-inf').
        Unmasked positions are filled with float(0.0).
        """
        mask = paddle.zeros([sz, sz], dtype='float32')
        mask_inf = paddle.triu(
            paddle.full(
                shape=[sz, sz], dtype='float32', fill_value='-inf'),
            diagonal=1)
        mask = mask + mask_inf
        return mask
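For example, with sz=3 the mask lets position t attend only to positions <= t (runnable with paddle):

import paddle

mask = paddle.triu(
    paddle.full(shape=[3, 3], dtype='float32', fill_value='-inf'), diagonal=1)
print(mask.numpy())
# [[  0. -inf -inf]
#  [  0.   0. -inf]
#  [  0.   0.   0.]]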

    def generate_padding_mask(self, x):
        padding_mask = x.equal(paddle.to_tensor(0, dtype=x.dtype))
        return padding_mask

    def _reset_parameters(self):
        """Initiate parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)


class TransformerEncoder(nn.Layer):
    """TransformerEncoder is a stack of N encoder layers
    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
    """

    def __init__(self, encoder_layer, num_layers):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, src):
        """Pass the input through the encoder layers in turn.
        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
        """
        output = src

        for i in range(self.num_layers):
            output = self.layers[i](output,
                                    src_mask=None,
                                    src_key_padding_mask=None)

        return output


class TransformerDecoder(nn.Layer):
    """TransformerDecoder is a stack of N decoder layers

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).
    """

    def __init__(self, decoder_layer, num_layers):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self,
                tgt,
                memory,
                tgt_mask=None,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Pass the inputs (and mask) through the decoder layer in turn.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
        """
        output = tgt
        for i in range(self.num_layers):
            output = self.layers[i](
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask)

        return output


class TransformerEncoderLayer(nn.Layer):
    """TransformerEncoderLayer is made up of self-attn and a feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multihead attention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model, nhead, dropout=attention_dropout_rate)

        # 1x1 convolutions implement the position-wise feedforward network
        self.conv1 = Conv2D(
            in_channels=d_model,
            out_channels=dim_feedforward,
            kernel_size=(1, 1))
        self.conv2 = Conv2D(
            in_channels=dim_feedforward,
            out_channels=d_model,
            kernel_size=(1, 1))

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(residual_dropout_rate)
        self.dropout2 = Dropout(residual_dropout_rate)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Pass the input through the encoder layer.
        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
        """
        src2 = self.self_attn(
            src,
            src,
            src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)

        src = src.transpose([1, 2, 0])
        src = paddle.unsqueeze(src, 2)
        src2 = self.conv2(F.relu(self.conv1(src)))
        src2 = paddle.squeeze(src2, 2)
        src2 = src2.transpose([2, 0, 1])
        src = paddle.squeeze(src, 2)
        src = src.transpose([2, 0, 1])

        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


class TransformerDecoderLayer(nn.Layer):
    """TransformerDecoderLayer is made up of self-attn, multi-head-attn and a feedforward network.
    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    it in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multihead attention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model, nhead, dropout=attention_dropout_rate)
        self.multihead_attn = MultiheadAttention(
            d_model, nhead, dropout=attention_dropout_rate)

        self.conv1 = Conv2D(
            in_channels=d_model,
            out_channels=dim_feedforward,
            kernel_size=(1, 1))
        self.conv2 = Conv2D(
            in_channels=dim_feedforward,
            out_channels=d_model,
            kernel_size=(1, 1))

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(residual_dropout_rate)
        self.dropout2 = Dropout(residual_dropout_rate)
        self.dropout3 = Dropout(residual_dropout_rate)

    def forward(self,
                tgt,
                memory,
                tgt_mask=None,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
        """
        tgt2 = self.self_attn(
            tgt,
            tgt,
            tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(
            tgt,
            memory,
            memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)

        # default
        tgt = tgt.transpose([1, 2, 0])
        tgt = paddle.unsqueeze(tgt, 2)
        tgt2 = self.conv2(F.relu(self.conv1(tgt)))
        tgt2 = paddle.squeeze(tgt2, 2)
        tgt2 = tgt2.transpose([2, 0, 1])
        tgt = paddle.squeeze(tgt, 2)
        tgt = tgt.transpose([2, 0, 1])

        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt


def _get_clones(module, N):
    return LayerList([copy.deepcopy(module) for i in range(N)])


class PositionalEncoding(nn.Layer):
    r"""Inject some information about the relative or absolute position of the tokens
    in the sequence. The positional encodings have the same dimension as
    the embeddings, so that the two can be summed. Here, we use sine and cosine
    functions of different frequencies.
    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos/10000^{2i/d\_model})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos/10000^{2i/d\_model})
        \text{where pos is the word position and i is the embed idx}
    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    Examples:
        >>> pos_encoder = PositionalEncoding(d_model)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = paddle.zeros([max_len, dim])
        position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
        div_term = paddle.exp(
            paddle.arange(0, dim, 2).astype('float32') *
            (-math.log(10000.0) / dim))
        pe[:, 0::2] = paddle.sin(position * div_term)
        pe[:, 1::2] = paddle.cos(position * div_term)
        pe = pe.unsqueeze(0)
        pe = pe.transpose([1, 0, 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)
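A standalone numpy rendering of the sinusoidal table built in the constructor (a sketch that mirrors the arithmetic, not the paddle buffer itself):

import math
import numpy as np

max_len, dim = 6, 8
position = np.arange(max_len, dtype=np.float32)[:, None]
div_term = np.exp(
    np.arange(0, dim, 2).astype(np.float32) * (-math.log(10000.0) / dim))
pe = np.zeros((max_len, dim), dtype=np.float32)
pe[:, 0::2] = np.sin(position * div_term)
pe[:, 1::2] = np.cos(position * div_term)
print(pe.shape)  # (6, 8): one row of encodings per position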
|
||||
|
||||
|
||||
class PositionalEncoding_2d(nn.Layer):
|
||||
"""Inject some information about the relative or absolute position of the tokens
|
||||
in the sequence. The positional encodings have the same dimension as
|
||||
the embeddings, so that the two can be summed. Here, we use sine and cosine
|
||||
functions of different frequencies.
|
||||
.. math::
|
||||
\text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
|
||||
\text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
|
||||
\text{where pos is the word position and i is the embed idx)
|
||||
Args:
|
||||
d_model: the embed dim (required).
|
||||
dropout: the dropout value (default=0.1).
|
||||
max_len: the max. length of the incoming sequence (default=5000).
|
||||
Examples:
|
||||
>>> pos_encoder = PositionalEncoding(d_model)
|
||||
"""
|
||||
|
||||
def __init__(self, dropout, dim, max_len=5000):
|
||||
super(PositionalEncoding_2d, self).__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
|
||||
pe = paddle.zeros([max_len, dim])
|
||||
position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
|
||||
div_term = paddle.exp(
|
||||
paddle.arange(0, dim, 2).astype('float32') *
|
||||
(-math.log(10000.0) / dim))
|
||||
pe[:, 0::2] = paddle.sin(position * div_term)
|
||||
pe[:, 1::2] = paddle.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0).transpose([1, 0, 2])
|
||||
self.register_buffer('pe', pe)
|
||||
|
||||
self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1))
|
||||
self.linear1 = nn.Linear(dim, dim)
|
||||
self.linear1.weight.data.fill_(1.)
|
||||
self.avg_pool_2 = nn.AdaptiveAvgPool2D((1, 1))
|
||||
self.linear2 = nn.Linear(dim, dim)
|
||||
self.linear2.weight.data.fill_(1.)
|
||||
|
||||
def forward(self, x):
|
||||
"""Inputs of forward function
|
||||
Args:
|
||||
x: the sequence fed to the positional encoder model (required).
|
||||
Shape:
|
||||
x: [sequence length, batch size, embed dim]
|
||||
output: [sequence length, batch size, embed dim]
|
||||
Examples:
|
||||
>>> output = pos_encoder(x)
|
||||
"""
|
||||
w_pe = self.pe[:x.shape[-1], :]
|
||||
w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
|
||||
w_pe = w_pe * w1
|
||||
w_pe = w_pe.transpose([1, 2, 0])
|
||||
w_pe = w_pe.unsqueeze(2)
|
||||
|
||||
h_pe = self.pe[:x.shape[-2], :]
|
||||
w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
|
||||
h_pe = h_pe * w2
|
||||
h_pe = h_pe.transpose([1, 2, 0])
|
||||
h_pe = h_pe.unsqueeze(3)
|
||||
|
||||
x = x + w_pe + h_pe
|
||||
x = x.reshape(
|
||||
[x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).transpose(
|
||||
[2, 0, 1])
|
||||
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class Embeddings(nn.Layer):
|
||||
def __init__(self, d_model, vocab, padding_idx, scale_embedding):
|
||||
super(Embeddings, self).__init__()
|
||||
self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
|
||||
w0 = np.random.normal(0.0, d_model**-0.5,
|
||||
(vocab, d_model)).astype(np.float32)
|
||||
self.embedding.weight.set_value(w0)
|
||||
self.d_model = d_model
|
||||
self.scale_embedding = scale_embedding
|
||||
|
||||
def forward(self, x):
|
||||
if self.scale_embedding:
|
||||
x = self.embedding(x)
|
||||
return x * math.sqrt(self.d_model)
|
||||
return self.embedding(x)
|
||||
|
||||
|
||||
class Beam():
    ''' Beam search '''

    def __init__(self, size, device=False):

        self.size = size
        self._done = False
        # The score for each translation on the beam.
        self.scores = paddle.zeros((size, ), dtype=paddle.float32)
        self.all_scores = []
        # The backpointers at each time-step.
        self.prev_ks = []
        # The outputs at each time-step.
        self.next_ys = [paddle.full((size, ), 0, dtype=paddle.int64)]
        self.next_ys[0][0] = 2

    def get_current_state(self):
        "Get the outputs for the current timestep."
        return self.get_tentative_hypothesis()

    def get_current_origin(self):
        "Get the backpointers for the current timestep."
        return self.prev_ks[-1]

    @property
    def done(self):
        return self._done

    def advance(self, word_prob):
        "Update beam status and check if finished or not."
        num_words = word_prob.shape[1]

        # Sum the previous scores.
        if len(self.prev_ks) > 0:
            beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
        else:
            beam_lk = word_prob[0]

        flat_beam_lk = beam_lk.reshape([-1])
        best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True,
                                                        True)  # 1st sort
        self.all_scores.append(self.scores)
        self.scores = best_scores
        # bestScoresId is flattened as a (beam x word) array,
        # so we need to calculate which word and beam each score came from
        prev_k = best_scores_id // num_words
        self.prev_ks.append(prev_k)
        self.next_ys.append(best_scores_id - prev_k * num_words)
        # End condition is when top-of-beam is EOS.
        if self.next_ys[-1][0] == 3:
            self._done = True
            self.all_scores.append(self.scores)

        return self._done

    def sort_scores(self):
        "Sort the scores."
        return self.scores, paddle.to_tensor(
            [i for i in range(self.scores.shape[0])], dtype='int32')

    def get_the_best_score_and_idx(self):
        "Get the score of the best in the beam."
        scores, ids = self.sort_scores()
        return scores[1], ids[1]

    def get_tentative_hypothesis(self):
        "Get the decoded sequence for the current timestep."
        if len(self.next_ys) == 1:
            dec_seq = self.next_ys[0].unsqueeze(1)
        else:
            _, keys = self.sort_scores()
            hyps = [self.get_hypothesis(k) for k in keys]
            hyps = [[2] + h for h in hyps]
            dec_seq = paddle.to_tensor(hyps, dtype='int64')
        return dec_seq

    def get_hypothesis(self, k):
        """ Walk back to construct the full hypothesis. """
        hyp = []
        for j in range(len(self.prev_ks) - 1, -1, -1):
            hyp.append(self.next_ys[j + 1][k])
            k = self.prev_ks[j][k]
        return list(map(lambda x: x.item(), hyp[::-1]))

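For orientation, the `Beam` above tracks per-beam `scores`, backpointers (`prev_ks`), and emitted ids (`next_ys`), with id 2 as `<s>` and id 3 as `</s>`. A minimal driving loop, with random log-probabilities standing in for the Transformer decoder output; the beam size, vocabulary, and step cap are illustrative:

```python
import paddle
import paddle.nn.functional as F

beam = Beam(size=4)   # Beam class from above
vocab_size = 10       # illustrative vocabulary

step = 0
while not beam.done and step < 5:
    # Stand-in for the decoder: log-probabilities of shape [beam, vocab].
    word_lprob = paddle.log(F.softmax(paddle.randn([4, vocab_size]), axis=-1))
    beam.advance(word_lprob)
    step += 1

print(beam.scores)             # log-prob of each surviving beam
print(beam.get_current_state())  # ids along each path, via the backpointers
```
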
@ -24,18 +24,16 @@ __all__ = ['build_post_process']
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode, \
    TableLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess


def build_post_process(config, global_config=None):
    support_dict = [
        'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
        'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
        'DistillationCTCLabelDecode', 'TableLabelDecode',
        'DistillationDBPostProcess'
        'DistillationCTCLabelDecode', 'NRTRLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess'
    ]

    config = copy.deepcopy(config)

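With `NRTRLabelDecode` registered here, the decoder can be constructed straight from the `PostProcess` section of the YAML config. A hedged usage sketch, assuming the package import path `ppocr.postprocess`:

```python
from ppocr.postprocess import build_post_process

# Mirrors the PostProcess block of the NRTR config (name: NRTRLabelDecode).
post_process = build_post_process({'name': 'NRTRLabelDecode'})
# The instance is callable: post_process(preds) -> [(text, confidence), ...]
```
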
@ -156,6 +156,69 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
        return output


class NRTRLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 character_dict_path=None,
                 character_type='EN_symbol',
                 use_space_char=True,
                 **kwargs):
        super(NRTRLabelDecode, self).__init__(character_dict_path,
                                              character_type, use_space_char)

    def __call__(self, preds, label=None, *args, **kwargs):
        if preds.dtype == paddle.int64:
            # Beam-search output: token ids, possibly prefixed with <s> (id 2).
            if isinstance(preds, paddle.Tensor):
                preds = preds.numpy()
            if preds[0][0] == 2:
                preds_idx = preds[:, 1:]
            else:
                preds_idx = preds

            text = self.decode(preds_idx)
            if label is None:
                return text
            label = self.decode(label[:, 1:])
        else:
            # Greedy output: per-step probability distributions.
            if isinstance(preds, paddle.Tensor):
                preds = preds.numpy()
            preds_idx = preds.argmax(axis=2)
            preds_prob = preds.max(axis=2)
            text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
            if label is None:
                return text
            label = self.decode(label[:, 1:])
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
        return dict_character

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """ convert text-index into text-label. """
        result_list = []
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            char_list = []
            conf_list = []
            for idx in range(len(text_index[batch_idx])):
                if text_index[batch_idx][idx] == 3:  # end
                    break
                try:
                    char_list.append(self.character[int(text_index[batch_idx][
                        idx])])
                except:
                    continue
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:
                    conf_list.append(1)
            text = ''.join(char_list)
            result_list.append((text.lower(), np.mean(conf_list)))
        return result_list

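The ids reserved by `add_special_char` are 0 = blank, 1 = `<unk>`, 2 = `<s>`, 3 = `</s>`; decoding therefore strips a leading 2 and stops at the first 3. A small sketch with fabricated greedy logits, where the shapes and the forced end token are illustrative:

```python
import numpy as np
import paddle

decoder = NRTRLabelDecode()  # defaults: EN_symbol charset, space enabled

vocab = len(decoder.character)
logits = np.random.rand(1, 6, vocab).astype('float32')  # batch=1, 6 steps
logits[0, -1, 3] = 10.0  # force </s> at the final step
print(decoder(paddle.to_tensor(logits)))  # [(decoded_text, mean_confidence)]
```
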
class AttnLabelDecode(BaseRecLabelDecode):
    """ Convert between text-label and text-index """


@ -193,8 +256,7 @@ class AttnLabelDecode(BaseRecLabelDecode):
                if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
                        batch_idx][idx]:
                    continue
                char_list.append(self.character[int(text_index[batch_idx][
                    idx])])
                char_list.append(self.character[int(text_index[batch_idx][idx])])
                if text_prob is not None:
                    conf_list.append(text_prob[batch_idx][idx])
                else:

@ -30,13 +30,13 @@ python3 -m pip install paddlepaddle-gpu==2.1.1 -i https://mirror.baidu.com/pypi/
# CPU
python3 -m pip install paddlepaddle==2.1.1 -i https://mirror.baidu.com/pypi/simple

# For more,refer[Installation](https://www.paddlepaddle.org.cn/install/quick)。
```
For more, refer to [Installation](https://www.paddlepaddle.org.cn/install/quick).

- **(2) Install Layout-Parser**

```bash
pip3 install -U premailer paddleocr https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
```

### 2.2 Install PaddleOCR (including PP-OCR and PP-Structure)

@ -180,10 +180,10 @@ OCR and table recognition model

|model name|description|model size|download|
| --- | --- | --- | --- |
|ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, multilingual text detection|2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
|ch_ppocr_mobile_slim_v2.0_rec|Slim pruned and quantized lightweight model, supporting Chinese, English and number recognition|6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) |
|en_ppocr_mobile_v2.0_table_det|Text detection of English table scenes trained on PubLayNet dataset|4.7M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) |
|en_ppocr_mobile_v2.0_table_rec|Text recognition of English table scenes trained on PubLayNet dataset|6.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) |
|en_ppocr_mobile_v2.0_table_structure|Table structure prediction of English table scenes trained on PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) |
|ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, multilingual text detection|2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
|ch_ppocr_mobile_slim_v2.0_rec|Slim pruned and quantized lightweight model, supporting Chinese, English and number recognition|6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
|en_ppocr_mobile_v2.0_table_det|Text detection of English table scenes trained on PubLayNet dataset|4.7M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_det_train.tar) |
|en_ppocr_mobile_v2.0_table_rec|Text recognition of English table scenes trained on PubLayNet dataset|6.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) |
|en_ppocr_mobile_v2.0_table_structure|Table structure prediction of English table scenes trained on PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |

If you need to use other models, you can download them from [model_list](../doc/doc_en/models_list_en.md) or use your own trained model, set via the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.

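For example, the three table-pipeline inference models above can be fetched and unpacked with a short script. A minimal sketch using only the standard library; the `inference/` target directory matches the layout the predict scripts expect, but is an assumption here:

```python
import tarfile
import urllib.request
from pathlib import Path

# Inference-model URLs taken from the table above.
URLS = [
    "https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar",
    "https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar",
    "https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar",
]

dest = Path("inference")  # assumed directory layout
dest.mkdir(exist_ok=True)
for url in URLS:
    tar_path = dest / url.rsplit("/", 1)[-1]
    urllib.request.urlretrieve(url, tar_path)  # download the archive
    with tarfile.open(tar_path) as tf:
        tf.extractall(dest)                    # one folder per model
```
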
@ -30,13 +30,13 @@ python3 -m pip install paddlepaddle-gpu==2.1.1 -i https://mirror.baidu.com/pypi/
# CPU installation
python3 -m pip install paddlepaddle==2.1.1 -i https://mirror.baidu.com/pypi/simple

# For more requirements, follow the instructions in the [installation documentation](https://www.paddlepaddle.org.cn/install/quick).
```
For more requirements, follow the instructions in the [installation documentation](https://www.paddlepaddle.org.cn/install/quick).

- **(2) Install Layout-Parser**

```bash
pip3 install -U premailer paddleocr https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
```

### 2.2 Install PaddleOCR (including PP-OCR and PP-Structure)

@ -179,10 +179,10 @@ OCR and table recognition models

|model name|description|inference model size|download|
| --- | --- | --- | --- |
|ch_ppocr_mobile_slim_v2.0_det|slim-pruned ultra-lightweight model, supporting Chinese, English and multilingual text detection|2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
|ch_ppocr_mobile_slim_v2.0_rec|slim-pruned and quantized ultra-lightweight model, supporting Chinese, English and digit recognition|6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) |
|en_ppocr_mobile_v2.0_table_det|text detection for English table scenes, trained on the PubLayNet dataset|4.7M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) |
|en_ppocr_mobile_v2.0_table_rec|text recognition for English table scenes, trained on the PubLayNet dataset|6.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) |
|en_ppocr_mobile_v2.0_table_structure|table structure prediction for English table scenes, trained on the PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) |
|ch_ppocr_mobile_slim_v2.0_det|slim-pruned ultra-lightweight model, supporting Chinese, English and multilingual text detection|2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
|ch_ppocr_mobile_slim_v2.0_rec|slim-pruned and quantized ultra-lightweight model, supporting Chinese, English and digit recognition|6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
|en_ppocr_mobile_v2.0_table_det|text detection for English table scenes, trained on the PubLayNet dataset|4.7M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_det_train.tar) |
|en_ppocr_mobile_v2.0_table_rec|text recognition for English table scenes, trained on the PubLayNet dataset|6.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) |
|en_ppocr_mobile_v2.0_table_structure|table structure prediction for English table scenes, trained on the PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |

If you need to use other models, download them from [model_list](../doc/doc_ch/models_list.md) or use your own trained model, set via the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.

@ -41,7 +41,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# run
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --det_limit_side_len=736 --det_limit_type=min --output ../output/table
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
Note: The above model is trained on the PubLayNet dataset and only supports English scanning scenarios. To recognize other scenarios, you need to train the model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.

@ -43,7 +43,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# run prediction
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --det_limit_side_len=736 --det_limit_type=min --output ../output/table
python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
After the run completes, the Excel table for each image is saved to the directory specified by the `output` field.

@ -7,4 +7,7 @@ tqdm
numpy
visualdl
python-Levenshtein
opencv-contrib-python==4.4.0.46
lxml
premailer
openpyxl

@ -4,7 +4,7 @@ python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_infer=2|whole_train_infer=300
Global.epoch_num:lite_train_infer=1|whole_train_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_infer=2|whole_train_infer=4
Global.pretrained_model:null

@ -15,7 +15,7 @@ null:null
trainer:norm_train|pact_train
norm_train:tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
pact_train:deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o
fpgm_train:null
fpgm_train:deploy/slim/prune/sensitivity_anal.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/det_mv3_db_v2.0_train/best_accuracy
distill_train:null
null:null
null:null

@ -29,7 +29,7 @@ Global.save_inference_dir:./output/
Global.pretrained_model:
norm_export:tools/export_model.py -c configs/det/det_mv3_db.yml -o
quant_export:deploy/slim/quantization/export_model.py -c configs/det/det_mv3_db.yml -o
fpgm_export:deploy/slim/prune/export_prune_model.py
fpgm_export:deploy/slim/prune/export_prune_model.py -c configs/det/det_mv3_db.yml -o
distill_export:null
export1:null
export2:null

@ -0,0 +1,52 @@
===========================train_params===========================
model_name:ocr_server_det
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_infer=2|whole_train_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_infer=2|whole_train_infer=4
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train|pact_train
norm_train:tools/train.py -c configs/det/det_r50_vd_db.yml -o Global.pretrained_model=""
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c configs/det/det_mv3_db.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.pretrained_model:
norm_export:tools/export_model.py -c configs/det/det_r50_vd_db.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
infer_model:./inference/ch_ppocr_server_v2.0_det_infer/
infer_export:null
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
--use_tensorrt:False|True
--precision:fp32|fp16|int8
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
--save_log_path:null
--benchmark:True
null:null

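These `key:value` files drive the train/infer test scripts, with `|` inside a value separating the variants to sweep (e.g. `gpu_list:0|0,1`). A minimal sketch of a parser for this format; the function name and return shape are illustrative, not the project's actual loader:

```python
def parse_test_params(path):
    """Parse a key:value params file; '|' in a value lists sweep variants."""
    params = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            # Skip blank lines, '##' section rules, and '===' banner lines.
            if not line or line.startswith('##') or line.startswith('==='):
                continue
            key, _, value = line.partition(':')
            params.setdefault(key, []).append(value.split('|'))
    return params

# Example: parse_test_params(path)['gpu_list'] -> [['0', '0,1']]
```
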
@ -34,11 +34,14 @@ MODE=$2
if [ ${MODE} = "lite_train_infer" ];then
    # pretrain lite train data
    wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
    wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar
    cd ./pretrain_models/ && tar xf det_mv3_db_v2.0_train.tar && cd ../
    rm -rf ./train_data/icdar2015
    rm -rf ./train_data/ic15_data
    wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar
    wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar  # todo change to bcebos

    wget -nc -P ./deploy/slim/prune https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/sen.pickle

    cd ./train_data/ && tar xf icdar2015_lite.tar && tar xf ic15_data.tar
    ln -s ./icdar2015_lite ./icdar2015
    cd ../

@ -65,6 +68,10 @@ elif [ ${MODE} = "infer" ] || [ ${MODE} = "cpp_infer" ];then
    wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
    wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
    cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../
elif [ ${model_name} = "ocr_server_det" ]; then
    wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar
    wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
    cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
else
    rm -rf ./train_data/ic15_data
    eval_model_name="ch_ppocr_mobile_v2.0_rec_infer"

@ -88,8 +88,8 @@ class TextRecognizer(object):
    def resize_norm_img(self, img, max_wh_ratio):
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        if self.character_type == "ch":
            imgW = int((32 * max_wh_ratio))
        max_wh_ratio = max(max_wh_ratio, imgW / imgH)
        imgW = int((32 * max_wh_ratio))
        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:

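The replacement lines clamp the ratio so the target width can grow for very wide crops but never drops below the configured `imgW`. The same arithmetic in isolation; the values are illustrative:

```python
import math

def target_width(imgH, imgW, max_wh_ratio):
    """Width logic from resize_norm_img after the change above."""
    # Never go below the configured aspect ratio imgW / imgH.
    max_wh_ratio = max(max_wh_ratio, imgW / imgH)
    return int(32 * max_wh_ratio)

print(target_width(32, 320, 5.0))   # 320: ratio clamped up to 10
print(target_width(32, 320, 15.0))  # 480: wide crops still widen the canvas
```
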
@ -186,9 +186,11 @@ def train(config,
    model.train()

    use_srn = config['Architecture']['algorithm'] == "SRN"
    use_nrtr = config['Architecture']['algorithm'] == "NRTR"

    try:
        model_type = config['Architecture']['model_type']
    except:
        model_type = None

    if 'start_epoch' in best_model_dict:

@ -213,7 +215,7 @@ def train(config,
            images = batch[0]
            if use_srn:
                model_average = True
            if use_srn or model_type == 'table':
            if use_srn or model_type == 'table' or use_nrtr:
                preds = model(images, data=batch[1:])
            else:
                preds = model(images)

@ -398,7 +400,7 @@ def preprocess(is_train=False):
    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'TableAttn'
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn'
    ]

    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'