commit
5c664bf4f5
|
@ -0,0 +1,102 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: True
|
||||||
|
epoch_num: 21
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 10
|
||||||
|
save_model_dir: ./output/rec/nrtr/
|
||||||
|
save_epoch_step: 1
|
||||||
|
# evaluation is run every 2000 iterations
|
||||||
|
eval_batch_step: [0, 2000]
|
||||||
|
cal_metric_during_train: True
|
||||||
|
pretrained_model:
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/imgs_words_en/word_10.png
|
||||||
|
# for data or label process
|
||||||
|
character_dict_path:
|
||||||
|
character_type: EN_symbol
|
||||||
|
max_text_length: 25
|
||||||
|
infer_mode: False
|
||||||
|
use_space_char: True
|
||||||
|
save_res_path: ./output/rec/predicts_nrtr.txt
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.99
|
||||||
|
clip_norm: 5.0
|
||||||
|
lr:
|
||||||
|
name: Cosine
|
||||||
|
learning_rate: 0.0005
|
||||||
|
warmup_epoch: 2
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
factor: 0.
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: rec
|
||||||
|
algorithm: NRTR
|
||||||
|
in_channels: 1
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: MTB
|
||||||
|
cnn_num: 2
|
||||||
|
Head:
|
||||||
|
name: Transformer
|
||||||
|
d_model: 512
|
||||||
|
num_encoder_layers: 6
|
||||||
|
beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation.
|
||||||
|
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: NRTRLoss
|
||||||
|
smoothing: True
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: NRTRLabelDecode
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: RecMetric
|
||||||
|
main_indicator: acc
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDataSet
|
||||||
|
data_dir: ./train_data/data_lmdb_release/training/
|
||||||
|
transforms:
|
||||||
|
- NRTRDecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- NRTRLabelEncode: # Class handling label
|
||||||
|
- NRTRRecResizeImg:
|
||||||
|
image_shape: [100, 32]
|
||||||
|
resize_type: PIL # PIL or OpenCV
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
batch_size_per_card: 512
|
||||||
|
drop_last: True
|
||||||
|
num_workers: 8
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDataSet
|
||||||
|
data_dir: ./train_data/data_lmdb_release/evaluation/
|
||||||
|
transforms:
|
||||||
|
- NRTRDecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- NRTRLabelEncode: # Class handling label
|
||||||
|
- NRTRRecResizeImg:
|
||||||
|
image_shape: [100, 32]
|
||||||
|
resize_type: PIL # PIL or OpenCV
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 256
|
||||||
|
num_workers: 1
|
||||||
|
use_shared_memory: False
|
|
@ -44,6 +44,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
||||||
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
||||||
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
|
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
|
||||||
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
||||||
|
- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))
|
||||||
|
|
||||||
参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
||||||
|
|
||||||
|
@ -58,6 +59,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
||||||
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|
||||||
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
||||||
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|
||||||
|
|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
|
||||||
|
|
||||||
|
|
||||||
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。
|
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。
|
||||||
|
|
|
@ -215,6 +215,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
|
||||||
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
|
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
|
||||||
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
||||||
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
||||||
|
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
|
||||||
|
|
||||||
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
|
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
|
||||||
|
|
||||||
|
|
|
@ -46,6 +46,7 @@ PaddleOCR open-source text recognition algorithms list:
|
||||||
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
||||||
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
|
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
|
||||||
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
||||||
|
- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))
|
||||||
|
|
||||||
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
|
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
|
||||||
|
|
||||||
|
@ -60,5 +61,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|
||||||
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|
||||||
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
||||||
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|
||||||
|
|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
|
||||||
|
|
||||||
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
|
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
|
||||||
|
|
|
@ -207,7 +207,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
|
||||||
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
|
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
|
||||||
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
||||||
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
||||||
|
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
|
||||||
|
|
||||||
For training Chinese data, it is recommended to use
|
For training Chinese data, it is recommended to use
|
||||||
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
|
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
|
||||||
|
|
|
@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
|
||||||
from .make_shrink_map import MakeShrinkMap
|
from .make_shrink_map import MakeShrinkMap
|
||||||
from .random_crop_data import EastRandomCropData, PSERandomCrop
|
from .random_crop_data import EastRandomCropData, PSERandomCrop
|
||||||
|
|
||||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
|
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg
|
||||||
from .randaugment import RandAugment
|
from .randaugment import RandAugment
|
||||||
from .copy_paste import CopyPaste
|
from .copy_paste import CopyPaste
|
||||||
from .operators import *
|
from .operators import *
|
||||||
|
|
|
@ -161,6 +161,34 @@ class BaseRecLabelEncode(object):
|
||||||
return text_list
|
return text_list
|
||||||
|
|
||||||
|
|
||||||
|
class NRTRLabelEncode(BaseRecLabelEncode):
|
||||||
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
max_text_length,
|
||||||
|
character_dict_path=None,
|
||||||
|
character_type='EN_symbol',
|
||||||
|
use_space_char=False,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
super(NRTRLabelEncode,
|
||||||
|
self).__init__(max_text_length, character_dict_path,
|
||||||
|
character_type, use_space_char)
|
||||||
|
def __call__(self, data):
|
||||||
|
text = data['label']
|
||||||
|
text = self.encode(text)
|
||||||
|
if text is None:
|
||||||
|
return None
|
||||||
|
data['length'] = np.array(len(text))
|
||||||
|
text.insert(0, 2)
|
||||||
|
text.append(3)
|
||||||
|
text = text + [0] * (self.max_text_len - len(text))
|
||||||
|
data['label'] = np.array(text)
|
||||||
|
return data
|
||||||
|
def add_special_char(self, dict_character):
|
||||||
|
dict_character = ['blank','<unk>','<s>','</s>'] + dict_character
|
||||||
|
return dict_character
|
||||||
|
|
||||||
class CTCLabelEncode(BaseRecLabelEncode):
|
class CTCLabelEncode(BaseRecLabelEncode):
|
||||||
""" Convert between text-label and text-index """
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
|
|
|
@ -57,6 +57,38 @@ class DecodeImage(object):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class NRTRDecodeImage(object):
|
||||||
|
""" decode image """
|
||||||
|
|
||||||
|
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
|
||||||
|
self.img_mode = img_mode
|
||||||
|
self.channel_first = channel_first
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
img = data['image']
|
||||||
|
if six.PY2:
|
||||||
|
assert type(img) is str and len(
|
||||||
|
img) > 0, "invalid input 'img' in DecodeImage"
|
||||||
|
else:
|
||||||
|
assert type(img) is bytes and len(
|
||||||
|
img) > 0, "invalid input 'img' in DecodeImage"
|
||||||
|
img = np.frombuffer(img, dtype='uint8')
|
||||||
|
|
||||||
|
img = cv2.imdecode(img, 1)
|
||||||
|
|
||||||
|
if img is None:
|
||||||
|
return None
|
||||||
|
if self.img_mode == 'GRAY':
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
|
elif self.img_mode == 'RGB':
|
||||||
|
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
|
||||||
|
img = img[:, :, ::-1]
|
||||||
|
img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
|
||||||
|
if self.channel_first:
|
||||||
|
img = img.transpose((2, 0, 1))
|
||||||
|
data['image'] = img
|
||||||
|
return data
|
||||||
|
|
||||||
class NormalizeImage(object):
|
class NormalizeImage(object):
|
||||||
""" normalize image such as substract mean, divide std
|
""" normalize image such as substract mean, divide std
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -16,7 +16,7 @@ import math
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
|
from PIL import Image
|
||||||
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
|
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
|
||||||
|
|
||||||
|
|
||||||
|
@ -43,6 +43,25 @@ class ClsResizeImg(object):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class NRTRRecResizeImg(object):
|
||||||
|
def __init__(self, image_shape, resize_type, **kwargs):
|
||||||
|
self.image_shape = image_shape
|
||||||
|
self.resize_type = resize_type
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
img = data['image']
|
||||||
|
if self.resize_type == 'PIL':
|
||||||
|
image_pil = Image.fromarray(np.uint8(img))
|
||||||
|
img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
|
||||||
|
img = np.array(img)
|
||||||
|
if self.resize_type == 'OpenCV':
|
||||||
|
img = cv2.resize(img, self.image_shape)
|
||||||
|
norm_img = np.expand_dims(img, -1)
|
||||||
|
norm_img = norm_img.transpose((2, 0, 1))
|
||||||
|
data['image'] = norm_img.astype(np.float32) / 128. - 1.
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class RecResizeImg(object):
|
class RecResizeImg(object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
image_shape,
|
image_shape,
|
||||||
|
|
|
@ -25,7 +25,7 @@ from .det_sast_loss import SASTLoss
|
||||||
from .rec_ctc_loss import CTCLoss
|
from .rec_ctc_loss import CTCLoss
|
||||||
from .rec_att_loss import AttentionLoss
|
from .rec_att_loss import AttentionLoss
|
||||||
from .rec_srn_loss import SRNLoss
|
from .rec_srn_loss import SRNLoss
|
||||||
|
from .rec_nrtr_loss import NRTRLoss
|
||||||
# cls loss
|
# cls loss
|
||||||
from .cls_loss import ClsLoss
|
from .cls_loss import ClsLoss
|
||||||
|
|
||||||
|
@ -44,8 +44,9 @@ from .table_att_loss import TableAttentionLoss
|
||||||
def build_loss(config):
|
def build_loss(config):
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||||
'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
|
'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss'
|
||||||
]
|
]
|
||||||
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
assert module_name in support_dict, Exception('loss only support {}'.format(
|
assert module_name in support_dict, Exception('loss only support {}'.format(
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
class NRTRLoss(nn.Layer):
|
||||||
|
def __init__(self, smoothing=True, **kwargs):
|
||||||
|
super(NRTRLoss, self).__init__()
|
||||||
|
self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
|
||||||
|
self.smoothing = smoothing
|
||||||
|
|
||||||
|
def forward(self, pred, batch):
|
||||||
|
pred = pred.reshape([-1, pred.shape[2]])
|
||||||
|
max_len = batch[2].max()
|
||||||
|
tgt = batch[1][:, 1:2 + max_len]
|
||||||
|
tgt = tgt.reshape([-1])
|
||||||
|
if self.smoothing:
|
||||||
|
eps = 0.1
|
||||||
|
n_class = pred.shape[1]
|
||||||
|
one_hot = F.one_hot(tgt, pred.shape[1])
|
||||||
|
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
|
||||||
|
log_prb = F.log_softmax(pred, axis=1)
|
||||||
|
non_pad_mask = paddle.not_equal(
|
||||||
|
tgt, paddle.zeros(
|
||||||
|
tgt.shape, dtype='int64'))
|
||||||
|
loss = -(one_hot * log_prb).sum(axis=1)
|
||||||
|
loss = loss.masked_select(non_pad_mask).mean()
|
||||||
|
else:
|
||||||
|
loss = self.loss_func(pred, tgt)
|
||||||
|
return {'loss': loss}
|
|
@ -57,3 +57,4 @@ class RecMetric(object):
|
||||||
self.correct_num = 0
|
self.correct_num = 0
|
||||||
self.all_num = 0
|
self.all_num = 0
|
||||||
self.norm_edit_dis = 0
|
self.norm_edit_dis = 0
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from paddle import nn
|
from paddle import nn
|
||||||
from ppocr.modeling.transforms import build_transform
|
from ppocr.modeling.transforms import build_transform
|
||||||
from ppocr.modeling.backbones import build_backbone
|
from ppocr.modeling.backbones import build_backbone
|
||||||
|
|
|
@ -26,8 +26,9 @@ def build_backbone(config, model_type):
|
||||||
from .rec_resnet_vd import ResNet
|
from .rec_resnet_vd import ResNet
|
||||||
from .rec_resnet_fpn import ResNetFPN
|
from .rec_resnet_fpn import ResNetFPN
|
||||||
from .rec_mv1_enhance import MobileNetV1Enhance
|
from .rec_mv1_enhance import MobileNetV1Enhance
|
||||||
|
from .rec_nrtr_mtb import MTB
|
||||||
support_dict = [
|
support_dict = [
|
||||||
"MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
|
'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB'
|
||||||
]
|
]
|
||||||
elif model_type == "e2e":
|
elif model_type == "e2e":
|
||||||
from .e2e_resnet_vd_pg import ResNet
|
from .e2e_resnet_vd_pg import ResNet
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
|
||||||
|
class MTB(nn.Layer):
|
||||||
|
def __init__(self, cnn_num, in_channels):
|
||||||
|
super(MTB, self).__init__()
|
||||||
|
self.block = nn.Sequential()
|
||||||
|
self.out_channels = in_channels
|
||||||
|
self.cnn_num = cnn_num
|
||||||
|
if self.cnn_num == 2:
|
||||||
|
for i in range(self.cnn_num):
|
||||||
|
self.block.add_sublayer(
|
||||||
|
'conv_{}'.format(i),
|
||||||
|
nn.Conv2D(
|
||||||
|
in_channels=in_channels
|
||||||
|
if i == 0 else 32 * (2**(i - 1)),
|
||||||
|
out_channels=32 * (2**i),
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1))
|
||||||
|
self.block.add_sublayer('relu_{}'.format(i), nn.ReLU())
|
||||||
|
self.block.add_sublayer('bn_{}'.format(i),
|
||||||
|
nn.BatchNorm2D(32 * (2**i)))
|
||||||
|
|
||||||
|
def forward(self, images):
|
||||||
|
x = self.block(images)
|
||||||
|
if self.cnn_num == 2:
|
||||||
|
# (b, w, h, c)
|
||||||
|
x = x.transpose([0, 3, 2, 1])
|
||||||
|
x_shape = x.shape
|
||||||
|
x = x.reshape([x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
|
||||||
|
return x
|
|
@ -26,12 +26,14 @@ def build_head(config):
|
||||||
from .rec_ctc_head import CTCHead
|
from .rec_ctc_head import CTCHead
|
||||||
from .rec_att_head import AttentionHead
|
from .rec_att_head import AttentionHead
|
||||||
from .rec_srn_head import SRNHead
|
from .rec_srn_head import SRNHead
|
||||||
|
from .rec_nrtr_head import Transformer
|
||||||
|
|
||||||
# cls head
|
# cls head
|
||||||
from .cls_head import ClsHead
|
from .cls_head import ClsHead
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
||||||
'SRNHead', 'PGHead', 'TableAttentionHead']
|
'SRNHead', 'PGHead', 'Transformer', 'TableAttentionHead'
|
||||||
|
]
|
||||||
|
|
||||||
#table head
|
#table head
|
||||||
from .table_att_head import TableAttentionHead
|
from .table_att_head import TableAttentionHead
|
||||||
|
|
|
@ -0,0 +1,178 @@
|
||||||
|
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle.nn import Linear
|
||||||
|
from paddle.nn.initializer import XavierUniform as xavier_uniform_
|
||||||
|
from paddle.nn.initializer import Constant as constant_
|
||||||
|
from paddle.nn.initializer import XavierNormal as xavier_normal_
|
||||||
|
|
||||||
|
zeros_ = constant_(value=0.)
|
||||||
|
ones_ = constant_(value=1.)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiheadAttention(nn.Layer):
|
||||||
|
"""Allows the model to jointly attend to information
|
||||||
|
from different representation subspaces.
|
||||||
|
See reference: Attention Is All You Need
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
||||||
|
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embed_dim: total dimension of the model
|
||||||
|
num_heads: parallel attention layers, or heads
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
embed_dim,
|
||||||
|
num_heads,
|
||||||
|
dropout=0.,
|
||||||
|
bias=True,
|
||||||
|
add_bias_kv=False,
|
||||||
|
add_zero_attn=False):
|
||||||
|
super(MultiheadAttention, self).__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.dropout = dropout
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
|
||||||
|
self._reset_parameters()
|
||||||
|
self.conv1 = paddle.nn.Conv2D(
|
||||||
|
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
|
||||||
|
self.conv2 = paddle.nn.Conv2D(
|
||||||
|
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
|
||||||
|
self.conv3 = paddle.nn.Conv2D(
|
||||||
|
in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
|
||||||
|
|
||||||
|
def _reset_parameters(self):
|
||||||
|
xavier_uniform_(self.out_proj.weight)
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_padding_mask=None,
|
||||||
|
incremental_state=None,
|
||||||
|
need_weights=True,
|
||||||
|
static_kv=False,
|
||||||
|
attn_mask=None):
|
||||||
|
"""
|
||||||
|
Inputs of forward function
|
||||||
|
query: [target length, batch size, embed dim]
|
||||||
|
key: [sequence length, batch size, embed dim]
|
||||||
|
value: [sequence length, batch size, embed dim]
|
||||||
|
key_padding_mask: if True, mask padding based on batch size
|
||||||
|
incremental_state: if provided, previous time steps are cashed
|
||||||
|
need_weights: output attn_output_weights
|
||||||
|
static_kv: key and value are static
|
||||||
|
|
||||||
|
Outputs of forward function
|
||||||
|
attn_output: [target length, batch size, embed dim]
|
||||||
|
attn_output_weights: [batch size, target length, sequence length]
|
||||||
|
"""
|
||||||
|
tgt_len, bsz, embed_dim = query.shape
|
||||||
|
assert embed_dim == self.embed_dim
|
||||||
|
assert list(query.shape) == [tgt_len, bsz, embed_dim]
|
||||||
|
assert key.shape == value.shape
|
||||||
|
|
||||||
|
q = self._in_proj_q(query)
|
||||||
|
k = self._in_proj_k(key)
|
||||||
|
v = self._in_proj_v(value)
|
||||||
|
q *= self.scaling
|
||||||
|
|
||||||
|
q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose(
|
||||||
|
[1, 0, 2])
|
||||||
|
k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose(
|
||||||
|
[1, 0, 2])
|
||||||
|
v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose(
|
||||||
|
[1, 0, 2])
|
||||||
|
|
||||||
|
src_len = k.shape[1]
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
assert key_padding_mask.shape[0] == bsz
|
||||||
|
assert key_padding_mask.shape[1] == src_len
|
||||||
|
|
||||||
|
attn_output_weights = paddle.bmm(q, k.transpose([0, 2, 1]))
|
||||||
|
assert list(attn_output_weights.
|
||||||
|
shape) == [bsz * self.num_heads, tgt_len, src_len]
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
|
attn_output_weights += attn_mask
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
attn_output_weights = attn_output_weights.reshape(
|
||||||
|
[bsz, self.num_heads, tgt_len, src_len])
|
||||||
|
key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32')
|
||||||
|
y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf')
|
||||||
|
y = paddle.where(key == 0., key, y)
|
||||||
|
attn_output_weights += y
|
||||||
|
attn_output_weights = attn_output_weights.reshape(
|
||||||
|
[bsz * self.num_heads, tgt_len, src_len])
|
||||||
|
|
||||||
|
attn_output_weights = F.softmax(
|
||||||
|
attn_output_weights.astype('float32'),
|
||||||
|
axis=-1,
|
||||||
|
dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16
|
||||||
|
else attn_output_weights.dtype)
|
||||||
|
attn_output_weights = F.dropout(
|
||||||
|
attn_output_weights, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
attn_output = paddle.bmm(attn_output_weights, v)
|
||||||
|
assert list(attn_output.
|
||||||
|
shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
||||||
|
attn_output = attn_output.transpose([1, 0, 2]).reshape(
|
||||||
|
[tgt_len, bsz, embed_dim])
|
||||||
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
if need_weights:
|
||||||
|
# average attention weights over heads
|
||||||
|
attn_output_weights = attn_output_weights.reshape(
|
||||||
|
[bsz, self.num_heads, tgt_len, src_len])
|
||||||
|
attn_output_weights = attn_output_weights.sum(
|
||||||
|
axis=1) / self.num_heads
|
||||||
|
else:
|
||||||
|
attn_output_weights = None
|
||||||
|
return attn_output, attn_output_weights
|
||||||
|
|
||||||
|
def _in_proj_q(self, query):
|
||||||
|
query = query.transpose([1, 2, 0])
|
||||||
|
query = paddle.unsqueeze(query, axis=2)
|
||||||
|
res = self.conv1(query)
|
||||||
|
res = paddle.squeeze(res, axis=2)
|
||||||
|
res = res.transpose([2, 0, 1])
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _in_proj_k(self, key):
|
||||||
|
key = key.transpose([1, 2, 0])
|
||||||
|
key = paddle.unsqueeze(key, axis=2)
|
||||||
|
res = self.conv2(key)
|
||||||
|
res = paddle.squeeze(res, axis=2)
|
||||||
|
res = res.transpose([2, 0, 1])
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _in_proj_v(self, value):
|
||||||
|
value = value.transpose([1, 2, 0]) #(1, 2, 0)
|
||||||
|
value = paddle.unsqueeze(value, axis=2)
|
||||||
|
res = self.conv3(value)
|
||||||
|
res = paddle.squeeze(res, axis=2)
|
||||||
|
res = res.transpose([2, 0, 1])
|
||||||
|
return res
|
|
@ -0,0 +1,844 @@
|
||||||
|
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import paddle
|
||||||
|
import copy
|
||||||
|
from paddle import nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle.nn import LayerList
|
||||||
|
from paddle.nn.initializer import XavierNormal as xavier_uniform_
|
||||||
|
from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
|
||||||
|
import numpy as np
|
||||||
|
from ppocr.modeling.heads.multiheadAttention import MultiheadAttention
|
||||||
|
from paddle.nn.initializer import Constant as constant_
|
||||||
|
from paddle.nn.initializer import XavierNormal as xavier_normal_
|
||||||
|
|
||||||
|
zeros_ = constant_(value=0.)
|
||||||
|
ones_ = constant_(value=1.)
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(nn.Layer):
|
||||||
|
"""A transformer model. User is able to modify the attributes as needed. The architechture
|
||||||
|
is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
|
||||||
|
Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
|
||||||
|
Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
|
||||||
|
Processing Systems, pages 6000-6010.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
d_model: the number of expected features in the encoder/decoder inputs (default=512).
|
||||||
|
nhead: the number of heads in the multiheadattention models (default=8).
|
||||||
|
num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
|
||||||
|
num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
|
||||||
|
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||||
|
dropout: the dropout value (default=0.1).
|
||||||
|
custom_encoder: custom encoder (default=None).
|
||||||
|
custom_decoder: custom decoder (default=None).
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
d_model=512,
|
||||||
|
nhead=8,
|
||||||
|
num_encoder_layers=6,
|
||||||
|
beam_size=0,
|
||||||
|
num_decoder_layers=6,
|
||||||
|
dim_feedforward=1024,
|
||||||
|
attention_dropout_rate=0.0,
|
||||||
|
residual_dropout_rate=0.1,
|
||||||
|
custom_encoder=None,
|
||||||
|
custom_decoder=None,
|
||||||
|
in_channels=0,
|
||||||
|
out_channels=0,
|
||||||
|
dst_vocab_size=99,
|
||||||
|
scale_embedding=True):
|
||||||
|
super(Transformer, self).__init__()
|
||||||
|
self.embedding = Embeddings(
|
||||||
|
d_model=d_model,
|
||||||
|
vocab=dst_vocab_size,
|
||||||
|
padding_idx=0,
|
||||||
|
scale_embedding=scale_embedding)
|
||||||
|
self.positional_encoding = PositionalEncoding(
|
||||||
|
dropout=residual_dropout_rate,
|
||||||
|
dim=d_model, )
|
||||||
|
if custom_encoder is not None:
|
||||||
|
self.encoder = custom_encoder
|
||||||
|
else:
|
||||||
|
if num_encoder_layers > 0:
|
||||||
|
encoder_layer = TransformerEncoderLayer(
|
||||||
|
d_model, nhead, dim_feedforward, attention_dropout_rate,
|
||||||
|
residual_dropout_rate)
|
||||||
|
self.encoder = TransformerEncoder(encoder_layer,
|
||||||
|
num_encoder_layers)
|
||||||
|
else:
|
||||||
|
self.encoder = None
|
||||||
|
|
||||||
|
if custom_decoder is not None:
|
||||||
|
self.decoder = custom_decoder
|
||||||
|
else:
|
||||||
|
decoder_layer = TransformerDecoderLayer(
|
||||||
|
d_model, nhead, dim_feedforward, attention_dropout_rate,
|
||||||
|
residual_dropout_rate)
|
||||||
|
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers)
|
||||||
|
|
||||||
|
self._reset_parameters()
|
||||||
|
self.beam_size = beam_size
|
||||||
|
self.d_model = d_model
|
||||||
|
self.nhead = nhead
|
||||||
|
self.tgt_word_prj = nn.Linear(d_model, dst_vocab_size, bias_attr=False)
|
||||||
|
w0 = np.random.normal(0.0, d_model**-0.5,
|
||||||
|
(d_model, dst_vocab_size)).astype(np.float32)
|
||||||
|
self.tgt_word_prj.weight.set_value(w0)
|
||||||
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
|
def _init_weights(self, m):
|
||||||
|
|
||||||
|
if isinstance(m, nn.Conv2D):
|
||||||
|
xavier_normal_(m.weight)
|
||||||
|
if m.bias is not None:
|
||||||
|
zeros_(m.bias)
|
||||||
|
|
||||||
|
def forward_train(self, src, tgt):
|
||||||
|
tgt = tgt[:, :-1]
|
||||||
|
|
||||||
|
tgt_key_padding_mask = self.generate_padding_mask(tgt)
|
||||||
|
tgt = self.embedding(tgt).transpose([1, 0, 2])
|
||||||
|
tgt = self.positional_encoding(tgt)
|
||||||
|
tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0])
|
||||||
|
|
||||||
|
if self.encoder is not None:
|
||||||
|
src = self.positional_encoding(src.transpose([1, 0, 2]))
|
||||||
|
memory = self.encoder(src)
|
||||||
|
else:
|
||||||
|
memory = src.squeeze(2).transpose([2, 0, 1])
|
||||||
|
output = self.decoder(
|
||||||
|
tgt,
|
||||||
|
memory,
|
||||||
|
tgt_mask=tgt_mask,
|
||||||
|
memory_mask=None,
|
||||||
|
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||||
|
memory_key_padding_mask=None)
|
||||||
|
output = output.transpose([1, 0, 2])
|
||||||
|
logit = self.tgt_word_prj(output)
|
||||||
|
return logit
|
||||||
|
|
||||||
|
def forward(self, src, targets=None):
|
||||||
|
"""Take in and process masked source/target sequences.
|
||||||
|
Args:
|
||||||
|
src: the sequence to the encoder (required).
|
||||||
|
tgt: the sequence to the decoder (required).
|
||||||
|
Shape:
|
||||||
|
- src: :math:`(S, N, E)`.
|
||||||
|
- tgt: :math:`(T, N, E)`.
|
||||||
|
Examples:
|
||||||
|
>>> output = transformer_model(src, tgt)
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.training:
|
||||||
|
max_len = targets[1].max()
|
||||||
|
tgt = targets[0][:, :2 + max_len]
|
||||||
|
return self.forward_train(src, tgt)
|
||||||
|
else:
|
||||||
|
if self.beam_size > 0:
|
||||||
|
return self.forward_beam(src)
|
||||||
|
else:
|
||||||
|
return self.forward_test(src)
|
||||||
|
|
||||||
|
def forward_test(self, src):
|
||||||
|
bs = src.shape[0]
|
||||||
|
if self.encoder is not None:
|
||||||
|
src = self.positional_encoding(src.transpose([1, 0, 2]))
|
||||||
|
memory = self.encoder(src)
|
||||||
|
else:
|
||||||
|
memory = src.squeeze(2).transpose([2, 0, 1])
|
||||||
|
dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64)
|
||||||
|
for len_dec_seq in range(1, 25):
|
||||||
|
src_enc = memory.clone()
|
||||||
|
tgt_key_padding_mask = self.generate_padding_mask(dec_seq)
|
||||||
|
dec_seq_embed = self.embedding(dec_seq).transpose([1, 0, 2])
|
||||||
|
dec_seq_embed = self.positional_encoding(dec_seq_embed)
|
||||||
|
tgt_mask = self.generate_square_subsequent_mask(dec_seq_embed.shape[
|
||||||
|
0])
|
||||||
|
output = self.decoder(
|
||||||
|
dec_seq_embed,
|
||||||
|
src_enc,
|
||||||
|
tgt_mask=tgt_mask,
|
||||||
|
memory_mask=None,
|
||||||
|
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||||
|
memory_key_padding_mask=None)
|
||||||
|
dec_output = output.transpose([1, 0, 2])
|
||||||
|
|
||||||
|
dec_output = dec_output[:,
|
||||||
|
-1, :] # Pick the last step: (bh * bm) * d_h
|
||||||
|
word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1)
|
||||||
|
word_prob = word_prob.reshape([1, bs, -1])
|
||||||
|
preds_idx = word_prob.argmax(axis=2)
|
||||||
|
|
||||||
|
if paddle.equal_all(
|
||||||
|
preds_idx[-1],
|
||||||
|
paddle.full(
|
||||||
|
preds_idx[-1].shape, 3, dtype='int64')):
|
||||||
|
break
|
||||||
|
|
||||||
|
preds_prob = word_prob.max(axis=2)
|
||||||
|
dec_seq = paddle.concat(
|
||||||
|
[dec_seq, preds_idx.reshape([-1, 1])], axis=1)
|
||||||
|
|
||||||
|
return dec_seq
|
||||||
|
|
||||||
|
def forward_beam(self, images):
|
||||||
|
''' Translation work in one batch '''
|
||||||
|
|
||||||
|
def get_inst_idx_to_tensor_position_map(inst_idx_list):
|
||||||
|
''' Indicate the position of an instance in a tensor. '''
|
||||||
|
return {
|
||||||
|
inst_idx: tensor_position
|
||||||
|
for tensor_position, inst_idx in enumerate(inst_idx_list)
|
||||||
|
}
|
||||||
|
|
||||||
|
def collect_active_part(beamed_tensor, curr_active_inst_idx,
|
||||||
|
n_prev_active_inst, n_bm):
|
||||||
|
''' Collect tensor parts associated to active instances. '''
|
||||||
|
|
||||||
|
_, *d_hs = beamed_tensor.shape
|
||||||
|
n_curr_active_inst = len(curr_active_inst_idx)
|
||||||
|
new_shape = (n_curr_active_inst * n_bm, *d_hs)
|
||||||
|
|
||||||
|
beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
|
||||||
|
beamed_tensor = beamed_tensor.index_select(
|
||||||
|
paddle.to_tensor(curr_active_inst_idx), axis=0)
|
||||||
|
beamed_tensor = beamed_tensor.reshape([*new_shape])
|
||||||
|
|
||||||
|
return beamed_tensor
|
||||||
|
|
||||||
|
def collate_active_info(src_enc, inst_idx_to_position_map,
|
||||||
|
active_inst_idx_list):
|
||||||
|
# Sentences which are still active are collected,
|
||||||
|
# so the decoder will not run on completed sentences.
|
||||||
|
|
||||||
|
n_prev_active_inst = len(inst_idx_to_position_map)
|
||||||
|
active_inst_idx = [
|
||||||
|
inst_idx_to_position_map[k] for k in active_inst_idx_list
|
||||||
|
]
|
||||||
|
active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64')
|
||||||
|
active_src_enc = collect_active_part(
|
||||||
|
src_enc.transpose([1, 0, 2]), active_inst_idx,
|
||||||
|
n_prev_active_inst, n_bm).transpose([1, 0, 2])
|
||||||
|
active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
|
||||||
|
active_inst_idx_list)
|
||||||
|
return active_src_enc, active_inst_idx_to_position_map
|
||||||
|
|
||||||
|
def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
|
||||||
|
inst_idx_to_position_map, n_bm,
|
||||||
|
memory_key_padding_mask):
|
||||||
|
''' Decode and update beam status, and then return active beam idx '''
|
||||||
|
|
||||||
|
def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
|
||||||
|
dec_partial_seq = [
|
||||||
|
b.get_current_state() for b in inst_dec_beams if not b.done
|
||||||
|
]
|
||||||
|
dec_partial_seq = paddle.stack(dec_partial_seq)
|
||||||
|
|
||||||
|
dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq])
|
||||||
|
return dec_partial_seq
|
||||||
|
|
||||||
|
def prepare_beam_memory_key_padding_mask(
|
||||||
|
inst_dec_beams, memory_key_padding_mask, n_bm):
|
||||||
|
keep = []
|
||||||
|
for idx in (memory_key_padding_mask):
|
||||||
|
if not inst_dec_beams[idx].done:
|
||||||
|
keep.append(idx)
|
||||||
|
memory_key_padding_mask = memory_key_padding_mask[
|
||||||
|
paddle.to_tensor(keep)]
|
||||||
|
len_s = memory_key_padding_mask.shape[-1]
|
||||||
|
n_inst = memory_key_padding_mask.shape[0]
|
||||||
|
memory_key_padding_mask = paddle.concat(
|
||||||
|
[memory_key_padding_mask for i in range(n_bm)], axis=1)
|
||||||
|
memory_key_padding_mask = memory_key_padding_mask.reshape(
|
||||||
|
[n_inst * n_bm, len_s]) #repeat(1, n_bm)
|
||||||
|
return memory_key_padding_mask
|
||||||
|
|
||||||
|
def predict_word(dec_seq, enc_output, n_active_inst, n_bm,
|
||||||
|
memory_key_padding_mask):
|
||||||
|
tgt_key_padding_mask = self.generate_padding_mask(dec_seq)
|
||||||
|
dec_seq = self.embedding(dec_seq).transpose([1, 0, 2])
|
||||||
|
dec_seq = self.positional_encoding(dec_seq)
|
||||||
|
tgt_mask = self.generate_square_subsequent_mask(dec_seq.shape[
|
||||||
|
0])
|
||||||
|
dec_output = self.decoder(
|
||||||
|
dec_seq,
|
||||||
|
enc_output,
|
||||||
|
tgt_mask=tgt_mask,
|
||||||
|
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||||
|
memory_key_padding_mask=memory_key_padding_mask,
|
||||||
|
).transpose([1, 0, 2])
|
||||||
|
dec_output = dec_output[:,
|
||||||
|
-1, :] # Pick the last step: (bh * bm) * d_h
|
||||||
|
word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1)
|
||||||
|
word_prob = word_prob.reshape([n_active_inst, n_bm, -1])
|
||||||
|
return word_prob
|
||||||
|
|
||||||
|
def collect_active_inst_idx_list(inst_beams, word_prob,
|
||||||
|
inst_idx_to_position_map):
|
||||||
|
active_inst_idx_list = []
|
||||||
|
for inst_idx, inst_position in inst_idx_to_position_map.items():
|
||||||
|
is_inst_complete = inst_beams[inst_idx].advance(word_prob[
|
||||||
|
inst_position])
|
||||||
|
if not is_inst_complete:
|
||||||
|
active_inst_idx_list += [inst_idx]
|
||||||
|
|
||||||
|
return active_inst_idx_list
|
||||||
|
|
||||||
|
n_active_inst = len(inst_idx_to_position_map)
|
||||||
|
dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
|
||||||
|
memory_key_padding_mask = None
|
||||||
|
word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm,
|
||||||
|
memory_key_padding_mask)
|
||||||
|
# Update the beam with predicted word prob information and collect incomplete instances
|
||||||
|
active_inst_idx_list = collect_active_inst_idx_list(
|
||||||
|
inst_dec_beams, word_prob, inst_idx_to_position_map)
|
||||||
|
return active_inst_idx_list
|
||||||
|
|
||||||
|
def collect_hypothesis_and_scores(inst_dec_beams, n_best):
|
||||||
|
all_hyp, all_scores = [], []
|
||||||
|
for inst_idx in range(len(inst_dec_beams)):
|
||||||
|
scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
|
||||||
|
all_scores += [scores[:n_best]]
|
||||||
|
hyps = [
|
||||||
|
inst_dec_beams[inst_idx].get_hypothesis(i)
|
||||||
|
for i in tail_idxs[:n_best]
|
||||||
|
]
|
||||||
|
all_hyp += [hyps]
|
||||||
|
return all_hyp, all_scores
|
||||||
|
|
||||||
|
with paddle.no_grad():
|
||||||
|
#-- Encode
|
||||||
|
|
||||||
|
if self.encoder is not None:
|
||||||
|
src = self.positional_encoding(images.transpose([1, 0, 2]))
|
||||||
|
src_enc = self.encoder(src).transpose([1, 0, 2])
|
||||||
|
else:
|
||||||
|
src_enc = images.squeeze(2).transpose([0, 2, 1])
|
||||||
|
|
||||||
|
#-- Repeat data for beam search
|
||||||
|
n_bm = self.beam_size
|
||||||
|
n_inst, len_s, d_h = src_enc.shape
|
||||||
|
src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1)
|
||||||
|
src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose(
|
||||||
|
[1, 0, 2])
|
||||||
|
#-- Prepare beams
|
||||||
|
inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)]
|
||||||
|
|
||||||
|
#-- Bookkeeping for active or not
|
||||||
|
active_inst_idx_list = list(range(n_inst))
|
||||||
|
inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
|
||||||
|
active_inst_idx_list)
|
||||||
|
#-- Decode
|
||||||
|
for len_dec_seq in range(1, 25):
|
||||||
|
src_enc_copy = src_enc.clone()
|
||||||
|
active_inst_idx_list = beam_decode_step(
|
||||||
|
inst_dec_beams, len_dec_seq, src_enc_copy,
|
||||||
|
inst_idx_to_position_map, n_bm, None)
|
||||||
|
if not active_inst_idx_list:
|
||||||
|
break # all instances have finished their path to <EOS>
|
||||||
|
src_enc, inst_idx_to_position_map = collate_active_info(
|
||||||
|
src_enc_copy, inst_idx_to_position_map,
|
||||||
|
active_inst_idx_list)
|
||||||
|
batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams,
|
||||||
|
1)
|
||||||
|
result_hyp = []
|
||||||
|
for bs_hyp in batch_hyp:
|
||||||
|
bs_hyp_pad = bs_hyp[0] + [3] * (25 - len(bs_hyp[0]))
|
||||||
|
result_hyp.append(bs_hyp_pad)
|
||||||
|
return paddle.to_tensor(np.array(result_hyp), dtype=paddle.int64)
|
||||||
|
|
||||||
|
def generate_square_subsequent_mask(self, sz):
|
||||||
|
"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
|
||||||
|
Unmasked positions are filled with float(0.0).
|
||||||
|
"""
|
||||||
|
mask = paddle.zeros([sz, sz], dtype='float32')
|
||||||
|
mask_inf = paddle.triu(
|
||||||
|
paddle.full(
|
||||||
|
shape=[sz, sz], dtype='float32', fill_value='-inf'),
|
||||||
|
diagonal=1)
|
||||||
|
mask = mask + mask_inf
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def generate_padding_mask(self, x):
|
||||||
|
padding_mask = x.equal(paddle.to_tensor(0, dtype=x.dtype))
|
||||||
|
return padding_mask
|
||||||
|
|
||||||
|
def _reset_parameters(self):
|
||||||
|
"""Initiate parameters in the transformer model."""
|
||||||
|
|
||||||
|
for p in self.parameters():
|
||||||
|
if p.dim() > 1:
|
||||||
|
xavier_uniform_(p)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Layer):
|
||||||
|
"""TransformerEncoder is a stack of N encoder layers
|
||||||
|
Args:
|
||||||
|
encoder_layer: an instance of the TransformerEncoderLayer() class (required).
|
||||||
|
num_layers: the number of sub-encoder-layers in the encoder (required).
|
||||||
|
norm: the layer normalization component (optional).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, encoder_layer, num_layers):
|
||||||
|
super(TransformerEncoder, self).__init__()
|
||||||
|
self.layers = _get_clones(encoder_layer, num_layers)
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
def forward(self, src):
|
||||||
|
"""Pass the input through the endocder layers in turn.
|
||||||
|
Args:
|
||||||
|
src: the sequnce to the encoder (required).
|
||||||
|
mask: the mask for the src sequence (optional).
|
||||||
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
|
"""
|
||||||
|
output = src
|
||||||
|
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
output = self.layers[i](output,
|
||||||
|
src_mask=None,
|
||||||
|
src_key_padding_mask=None)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoder(nn.Layer):
|
||||||
|
"""TransformerDecoder is a stack of N decoder layers
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decoder_layer: an instance of the TransformerDecoderLayer() class (required).
|
||||||
|
num_layers: the number of sub-decoder-layers in the decoder (required).
|
||||||
|
norm: the layer normalization component (optional).
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, decoder_layer, num_layers):
|
||||||
|
super(TransformerDecoder, self).__init__()
|
||||||
|
self.layers = _get_clones(decoder_layer, num_layers)
|
||||||
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
tgt,
|
||||||
|
memory,
|
||||||
|
tgt_mask=None,
|
||||||
|
memory_mask=None,
|
||||||
|
tgt_key_padding_mask=None,
|
||||||
|
memory_key_padding_mask=None):
|
||||||
|
"""Pass the inputs (and mask) through the decoder layer in turn.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tgt: the sequence to the decoder (required).
|
||||||
|
memory: the sequnce from the last layer of the encoder (required).
|
||||||
|
tgt_mask: the mask for the tgt sequence (optional).
|
||||||
|
memory_mask: the mask for the memory sequence (optional).
|
||||||
|
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
|
||||||
|
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
||||||
|
"""
|
||||||
|
output = tgt
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
output = self.layers[i](
|
||||||
|
output,
|
||||||
|
memory,
|
||||||
|
tgt_mask=tgt_mask,
|
||||||
|
memory_mask=memory_mask,
|
||||||
|
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||||
|
memory_key_padding_mask=memory_key_padding_mask)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoderLayer(nn.Layer):
|
||||||
|
"""TransformerEncoderLayer is made up of self-attn and feedforward network.
|
||||||
|
This standard encoder layer is based on the paper "Attention Is All You Need".
|
||||||
|
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
|
||||||
|
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
|
||||||
|
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
|
||||||
|
in a different way during application.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
d_model: the number of expected features in the input (required).
|
||||||
|
nhead: the number of heads in the multiheadattention models (required).
|
||||||
|
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||||
|
dropout: the dropout value (default=0.1).
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
d_model,
|
||||||
|
nhead,
|
||||||
|
dim_feedforward=2048,
|
||||||
|
attention_dropout_rate=0.0,
|
||||||
|
residual_dropout_rate=0.1):
|
||||||
|
super(TransformerEncoderLayer, self).__init__()
|
||||||
|
self.self_attn = MultiheadAttention(
|
||||||
|
d_model, nhead, dropout=attention_dropout_rate)
|
||||||
|
|
||||||
|
self.conv1 = Conv2D(
|
||||||
|
in_channels=d_model,
|
||||||
|
out_channels=dim_feedforward,
|
||||||
|
kernel_size=(1, 1))
|
||||||
|
self.conv2 = Conv2D(
|
||||||
|
in_channels=dim_feedforward,
|
||||||
|
out_channels=d_model,
|
||||||
|
kernel_size=(1, 1))
|
||||||
|
|
||||||
|
self.norm1 = LayerNorm(d_model)
|
||||||
|
self.norm2 = LayerNorm(d_model)
|
||||||
|
self.dropout1 = Dropout(residual_dropout_rate)
|
||||||
|
self.dropout2 = Dropout(residual_dropout_rate)
|
||||||
|
|
||||||
|
def forward(self, src, src_mask=None, src_key_padding_mask=None):
|
||||||
|
"""Pass the input through the endocder layer.
|
||||||
|
Args:
|
||||||
|
src: the sequnce to the encoder layer (required).
|
||||||
|
src_mask: the mask for the src sequence (optional).
|
||||||
|
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||||
|
"""
|
||||||
|
src2 = self.self_attn(
|
||||||
|
src,
|
||||||
|
src,
|
||||||
|
src,
|
||||||
|
attn_mask=src_mask,
|
||||||
|
key_padding_mask=src_key_padding_mask)[0]
|
||||||
|
src = src + self.dropout1(src2)
|
||||||
|
src = self.norm1(src)
|
||||||
|
|
||||||
|
src = src.transpose([1, 2, 0])
|
||||||
|
src = paddle.unsqueeze(src, 2)
|
||||||
|
src2 = self.conv2(F.relu(self.conv1(src)))
|
||||||
|
src2 = paddle.squeeze(src2, 2)
|
||||||
|
src2 = src2.transpose([2, 0, 1])
|
||||||
|
src = paddle.squeeze(src, 2)
|
||||||
|
src = src.transpose([2, 0, 1])
|
||||||
|
|
||||||
|
src = src + self.dropout2(src2)
|
||||||
|
src = self.norm2(src)
|
||||||
|
return src
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoderLayer(nn.Layer):
|
||||||
|
"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
|
||||||
|
This standard decoder layer is based on the paper "Attention Is All You Need".
|
||||||
|
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
|
||||||
|
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
|
||||||
|
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
|
||||||
|
in a different way during application.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
d_model: the number of expected features in the input (required).
|
||||||
|
nhead: the number of heads in the multiheadattention models (required).
|
||||||
|
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||||
|
dropout: the dropout value (default=0.1).
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
d_model,
|
||||||
|
nhead,
|
||||||
|
dim_feedforward=2048,
|
||||||
|
attention_dropout_rate=0.0,
|
||||||
|
residual_dropout_rate=0.1):
|
||||||
|
super(TransformerDecoderLayer, self).__init__()
|
||||||
|
self.self_attn = MultiheadAttention(
|
||||||
|
d_model, nhead, dropout=attention_dropout_rate)
|
||||||
|
self.multihead_attn = MultiheadAttention(
|
||||||
|
d_model, nhead, dropout=attention_dropout_rate)
|
||||||
|
|
||||||
|
self.conv1 = Conv2D(
|
||||||
|
in_channels=d_model,
|
||||||
|
out_channels=dim_feedforward,
|
||||||
|
kernel_size=(1, 1))
|
||||||
|
self.conv2 = Conv2D(
|
||||||
|
in_channels=dim_feedforward,
|
||||||
|
out_channels=d_model,
|
||||||
|
kernel_size=(1, 1))
|
||||||
|
|
||||||
|
self.norm1 = LayerNorm(d_model)
|
||||||
|
self.norm2 = LayerNorm(d_model)
|
||||||
|
self.norm3 = LayerNorm(d_model)
|
||||||
|
self.dropout1 = Dropout(residual_dropout_rate)
|
||||||
|
self.dropout2 = Dropout(residual_dropout_rate)
|
||||||
|
self.dropout3 = Dropout(residual_dropout_rate)
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
tgt,
|
||||||
|
memory,
|
||||||
|
tgt_mask=None,
|
||||||
|
memory_mask=None,
|
||||||
|
tgt_key_padding_mask=None,
|
||||||
|
memory_key_padding_mask=None):
|
||||||
|
"""Pass the inputs (and mask) through the decoder layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tgt: the sequence to the decoder layer (required).
|
||||||
|
memory: the sequnce from the last layer of the encoder (required).
|
||||||
|
tgt_mask: the mask for the tgt sequence (optional).
|
||||||
|
memory_mask: the mask for the memory sequence (optional).
|
||||||
|
tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
|
||||||
|
memory_key_padding_mask: the mask for the memory keys per batch (optional).
|
||||||
|
|
||||||
|
"""
|
||||||
|
tgt2 = self.self_attn(
|
||||||
|
tgt,
|
||||||
|
tgt,
|
||||||
|
tgt,
|
||||||
|
attn_mask=tgt_mask,
|
||||||
|
key_padding_mask=tgt_key_padding_mask)[0]
|
||||||
|
tgt = tgt + self.dropout1(tgt2)
|
||||||
|
tgt = self.norm1(tgt)
|
||||||
|
tgt2 = self.multihead_attn(
|
||||||
|
tgt,
|
||||||
|
memory,
|
||||||
|
memory,
|
||||||
|
attn_mask=memory_mask,
|
||||||
|
key_padding_mask=memory_key_padding_mask)[0]
|
||||||
|
tgt = tgt + self.dropout2(tgt2)
|
||||||
|
tgt = self.norm2(tgt)
|
||||||
|
|
||||||
|
# default
|
||||||
|
tgt = tgt.transpose([1, 2, 0])
|
||||||
|
tgt = paddle.unsqueeze(tgt, 2)
|
||||||
|
tgt2 = self.conv2(F.relu(self.conv1(tgt)))
|
||||||
|
tgt2 = paddle.squeeze(tgt2, 2)
|
||||||
|
tgt2 = tgt2.transpose([2, 0, 1])
|
||||||
|
tgt = paddle.squeeze(tgt, 2)
|
||||||
|
tgt = tgt.transpose([2, 0, 1])
|
||||||
|
|
||||||
|
tgt = tgt + self.dropout3(tgt2)
|
||||||
|
tgt = self.norm3(tgt)
|
||||||
|
return tgt
|
||||||
|
|
||||||
|
|
||||||
|
def _get_clones(module, N):
|
||||||
|
return LayerList([copy.deepcopy(module) for i in range(N)])
|
||||||
|
|
||||||
|
|
||||||
|
class PositionalEncoding(nn.Layer):
|
||||||
|
"""Inject some information about the relative or absolute position of the tokens
|
||||||
|
in the sequence. The positional encodings have the same dimension as
|
||||||
|
the embeddings, so that the two can be summed. Here, we use sine and cosine
|
||||||
|
functions of different frequencies.
|
||||||
|
.. math::
|
||||||
|
\text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
|
||||||
|
\text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
|
||||||
|
\text{where pos is the word position and i is the embed idx)
|
||||||
|
Args:
|
||||||
|
d_model: the embed dim (required).
|
||||||
|
dropout: the dropout value (default=0.1).
|
||||||
|
max_len: the max. length of the incoming sequence (default=5000).
|
||||||
|
Examples:
|
||||||
|
>>> pos_encoder = PositionalEncoding(d_model)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dropout, dim, max_len=5000):
|
||||||
|
super(PositionalEncoding, self).__init__()
|
||||||
|
self.dropout = nn.Dropout(p=dropout)
|
||||||
|
|
||||||
|
pe = paddle.zeros([max_len, dim])
|
||||||
|
position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
|
||||||
|
div_term = paddle.exp(
|
||||||
|
paddle.arange(0, dim, 2).astype('float32') *
|
||||||
|
(-math.log(10000.0) / dim))
|
||||||
|
pe[:, 0::2] = paddle.sin(position * div_term)
|
||||||
|
pe[:, 1::2] = paddle.cos(position * div_term)
|
||||||
|
pe = pe.unsqueeze(0)
|
||||||
|
pe = pe.transpose([1, 0, 2])
|
||||||
|
self.register_buffer('pe', pe)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Inputs of forward function
|
||||||
|
Args:
|
||||||
|
x: the sequence fed to the positional encoder model (required).
|
||||||
|
Shape:
|
||||||
|
x: [sequence length, batch size, embed dim]
|
||||||
|
output: [sequence length, batch size, embed dim]
|
||||||
|
Examples:
|
||||||
|
>>> output = pos_encoder(x)
|
||||||
|
"""
|
||||||
|
x = x + self.pe[:x.shape[0], :]
|
||||||
|
return self.dropout(x)
|
||||||
|
|
||||||
|
|
||||||
|
class PositionalEncoding_2d(nn.Layer):
|
||||||
|
"""Inject some information about the relative or absolute position of the tokens
|
||||||
|
in the sequence. The positional encodings have the same dimension as
|
||||||
|
the embeddings, so that the two can be summed. Here, we use sine and cosine
|
||||||
|
functions of different frequencies.
|
||||||
|
.. math::
|
||||||
|
\text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
|
||||||
|
\text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
|
||||||
|
\text{where pos is the word position and i is the embed idx)
|
||||||
|
Args:
|
||||||
|
d_model: the embed dim (required).
|
||||||
|
dropout: the dropout value (default=0.1).
|
||||||
|
max_len: the max. length of the incoming sequence (default=5000).
|
||||||
|
Examples:
|
||||||
|
>>> pos_encoder = PositionalEncoding(d_model)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dropout, dim, max_len=5000):
|
||||||
|
super(PositionalEncoding_2d, self).__init__()
|
||||||
|
self.dropout = nn.Dropout(p=dropout)
|
||||||
|
|
||||||
|
pe = paddle.zeros([max_len, dim])
|
||||||
|
position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
|
||||||
|
div_term = paddle.exp(
|
||||||
|
paddle.arange(0, dim, 2).astype('float32') *
|
||||||
|
(-math.log(10000.0) / dim))
|
||||||
|
pe[:, 0::2] = paddle.sin(position * div_term)
|
||||||
|
pe[:, 1::2] = paddle.cos(position * div_term)
|
||||||
|
pe = pe.unsqueeze(0).transpose([1, 0, 2])
|
||||||
|
self.register_buffer('pe', pe)
|
||||||
|
|
||||||
|
self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1))
|
||||||
|
self.linear1 = nn.Linear(dim, dim)
|
||||||
|
self.linear1.weight.data.fill_(1.)
|
||||||
|
self.avg_pool_2 = nn.AdaptiveAvgPool2D((1, 1))
|
||||||
|
self.linear2 = nn.Linear(dim, dim)
|
||||||
|
self.linear2.weight.data.fill_(1.)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Inputs of forward function
|
||||||
|
Args:
|
||||||
|
x: the sequence fed to the positional encoder model (required).
|
||||||
|
Shape:
|
||||||
|
x: [sequence length, batch size, embed dim]
|
||||||
|
output: [sequence length, batch size, embed dim]
|
||||||
|
Examples:
|
||||||
|
>>> output = pos_encoder(x)
|
||||||
|
"""
|
||||||
|
w_pe = self.pe[:x.shape[-1], :]
|
||||||
|
w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
|
||||||
|
w_pe = w_pe * w1
|
||||||
|
w_pe = w_pe.transpose([1, 2, 0])
|
||||||
|
w_pe = w_pe.unsqueeze(2)
|
||||||
|
|
||||||
|
h_pe = self.pe[:x.shape[-2], :]
|
||||||
|
w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
|
||||||
|
h_pe = h_pe * w2
|
||||||
|
h_pe = h_pe.transpose([1, 2, 0])
|
||||||
|
h_pe = h_pe.unsqueeze(3)
|
||||||
|
|
||||||
|
x = x + w_pe + h_pe
|
||||||
|
x = x.reshape(
|
||||||
|
[x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).transpose(
|
||||||
|
[2, 0, 1])
|
||||||
|
|
||||||
|
return self.dropout(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Embeddings(nn.Layer):
|
||||||
|
def __init__(self, d_model, vocab, padding_idx, scale_embedding):
|
||||||
|
super(Embeddings, self).__init__()
|
||||||
|
self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
|
||||||
|
w0 = np.random.normal(0.0, d_model**-0.5,
|
||||||
|
(vocab, d_model)).astype(np.float32)
|
||||||
|
self.embedding.weight.set_value(w0)
|
||||||
|
self.d_model = d_model
|
||||||
|
self.scale_embedding = scale_embedding
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.scale_embedding:
|
||||||
|
x = self.embedding(x)
|
||||||
|
return x * math.sqrt(self.d_model)
|
||||||
|
return self.embedding(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Beam():
|
||||||
|
''' Beam search '''
|
||||||
|
|
||||||
|
def __init__(self, size, device=False):
|
||||||
|
|
||||||
|
self.size = size
|
||||||
|
self._done = False
|
||||||
|
# The score for each translation on the beam.
|
||||||
|
self.scores = paddle.zeros((size, ), dtype=paddle.float32)
|
||||||
|
self.all_scores = []
|
||||||
|
# The backpointers at each time-step.
|
||||||
|
self.prev_ks = []
|
||||||
|
# The outputs at each time-step.
|
||||||
|
self.next_ys = [paddle.full((size, ), 0, dtype=paddle.int64)]
|
||||||
|
self.next_ys[0][0] = 2
|
||||||
|
|
||||||
|
def get_current_state(self):
|
||||||
|
"Get the outputs for the current timestep."
|
||||||
|
return self.get_tentative_hypothesis()
|
||||||
|
|
||||||
|
def get_current_origin(self):
|
||||||
|
"Get the backpointers for the current timestep."
|
||||||
|
return self.prev_ks[-1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def done(self):
|
||||||
|
return self._done
|
||||||
|
|
||||||
|
def advance(self, word_prob):
|
||||||
|
"Update beam status and check if finished or not."
|
||||||
|
num_words = word_prob.shape[1]
|
||||||
|
|
||||||
|
# Sum the previous scores.
|
||||||
|
if len(self.prev_ks) > 0:
|
||||||
|
beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
|
||||||
|
else:
|
||||||
|
beam_lk = word_prob[0]
|
||||||
|
|
||||||
|
flat_beam_lk = beam_lk.reshape([-1])
|
||||||
|
best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True,
|
||||||
|
True) # 1st sort
|
||||||
|
self.all_scores.append(self.scores)
|
||||||
|
self.scores = best_scores
|
||||||
|
# bestScoresId is flattened as a (beam x word) array,
|
||||||
|
# so we need to calculate which word and beam each score came from
|
||||||
|
prev_k = best_scores_id // num_words
|
||||||
|
self.prev_ks.append(prev_k)
|
||||||
|
self.next_ys.append(best_scores_id - prev_k * num_words)
|
||||||
|
# End condition is when top-of-beam is EOS.
|
||||||
|
if self.next_ys[-1][0] == 3:
|
||||||
|
self._done = True
|
||||||
|
self.all_scores.append(self.scores)
|
||||||
|
|
||||||
|
return self._done
|
||||||
|
|
||||||
|
def sort_scores(self):
|
||||||
|
"Sort the scores."
|
||||||
|
return self.scores, paddle.to_tensor(
|
||||||
|
[i for i in range(self.scores.shape[0])], dtype='int32')
|
||||||
|
|
||||||
|
def get_the_best_score_and_idx(self):
|
||||||
|
"Get the score of the best in the beam."
|
||||||
|
scores, ids = self.sort_scores()
|
||||||
|
return scores[1], ids[1]
|
||||||
|
|
||||||
|
def get_tentative_hypothesis(self):
|
||||||
|
"Get the decoded sequence for the current timestep."
|
||||||
|
if len(self.next_ys) == 1:
|
||||||
|
dec_seq = self.next_ys[0].unsqueeze(1)
|
||||||
|
else:
|
||||||
|
_, keys = self.sort_scores()
|
||||||
|
hyps = [self.get_hypothesis(k) for k in keys]
|
||||||
|
hyps = [[2] + h for h in hyps]
|
||||||
|
dec_seq = paddle.to_tensor(hyps, dtype='int64')
|
||||||
|
return dec_seq
|
||||||
|
|
||||||
|
def get_hypothesis(self, k):
|
||||||
|
""" Walk back to construct the full hypothesis. """
|
||||||
|
hyp = []
|
||||||
|
for j in range(len(self.prev_ks) - 1, -1, -1):
|
||||||
|
hyp.append(self.next_ys[j + 1][k])
|
||||||
|
k = self.prev_ks[j][k]
|
||||||
|
return list(map(lambda x: x.item(), hyp[::-1]))
|
|
@ -24,18 +24,16 @@ __all__ = ['build_post_process']
|
||||||
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
|
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
|
||||||
from .east_postprocess import EASTPostProcess
|
from .east_postprocess import EASTPostProcess
|
||||||
from .sast_postprocess import SASTPostProcess
|
from .sast_postprocess import SASTPostProcess
|
||||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
|
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode, \
|
||||||
TableLabelDecode
|
TableLabelDecode
|
||||||
from .cls_postprocess import ClsPostProcess
|
from .cls_postprocess import ClsPostProcess
|
||||||
from .pg_postprocess import PGPostProcess
|
from .pg_postprocess import PGPostProcess
|
||||||
|
|
||||||
|
|
||||||
def build_post_process(config, global_config=None):
|
def build_post_process(config, global_config=None):
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||||
'DistillationCTCLabelDecode', 'TableLabelDecode',
|
'DistillationCTCLabelDecode', 'NRTRLabelDecode', 'TableLabelDecode', 'DistillationDBPostProcess'
|
||||||
'DistillationDBPostProcess'
|
|
||||||
]
|
]
|
||||||
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|
|
@ -156,6 +156,69 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class NRTRLabelDecode(BaseRecLabelDecode):
|
||||||
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
character_dict_path=None,
|
||||||
|
character_type='EN_symbol',
|
||||||
|
use_space_char=True,
|
||||||
|
**kwargs):
|
||||||
|
super(NRTRLabelDecode, self).__init__(character_dict_path,
|
||||||
|
character_type, use_space_char)
|
||||||
|
|
||||||
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
|
if preds.dtype == paddle.int64:
|
||||||
|
if isinstance(preds, paddle.Tensor):
|
||||||
|
preds = preds.numpy()
|
||||||
|
if preds[0][0]==2:
|
||||||
|
preds_idx = preds[:,1:]
|
||||||
|
else:
|
||||||
|
preds_idx = preds
|
||||||
|
|
||||||
|
text = self.decode(preds_idx)
|
||||||
|
if label is None:
|
||||||
|
return text
|
||||||
|
label = self.decode(label[:,1:])
|
||||||
|
else:
|
||||||
|
if isinstance(preds, paddle.Tensor):
|
||||||
|
preds = preds.numpy()
|
||||||
|
preds_idx = preds.argmax(axis=2)
|
||||||
|
preds_prob = preds.max(axis=2)
|
||||||
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||||
|
if label is None:
|
||||||
|
return text
|
||||||
|
label = self.decode(label[:,1:])
|
||||||
|
return text, label
|
||||||
|
|
||||||
|
def add_special_char(self, dict_character):
|
||||||
|
dict_character = ['blank','<unk>','<s>','</s>'] + dict_character
|
||||||
|
return dict_character
|
||||||
|
|
||||||
|
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||||
|
""" convert text-index into text-label. """
|
||||||
|
result_list = []
|
||||||
|
batch_size = len(text_index)
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
char_list = []
|
||||||
|
conf_list = []
|
||||||
|
for idx in range(len(text_index[batch_idx])):
|
||||||
|
if text_index[batch_idx][idx] == 3: # end
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
char_list.append(self.character[int(text_index[batch_idx][idx])])
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
if text_prob is not None:
|
||||||
|
conf_list.append(text_prob[batch_idx][idx])
|
||||||
|
else:
|
||||||
|
conf_list.append(1)
|
||||||
|
text = ''.join(char_list)
|
||||||
|
result_list.append((text.lower(), np.mean(conf_list)))
|
||||||
|
return result_list
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AttnLabelDecode(BaseRecLabelDecode):
|
class AttnLabelDecode(BaseRecLabelDecode):
|
||||||
""" Convert between text-label and text-index """
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
|
@ -193,8 +256,7 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
||||||
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||||
batch_idx][idx]:
|
batch_idx][idx]:
|
||||||
continue
|
continue
|
||||||
char_list.append(self.character[int(text_index[batch_idx][
|
char_list.append(self.character[int(text_index[batch_idx][idx])])
|
||||||
idx])])
|
|
||||||
if text_prob is not None:
|
if text_prob is not None:
|
||||||
conf_list.append(text_prob[batch_idx][idx])
|
conf_list.append(text_prob[batch_idx][idx])
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -186,6 +186,8 @@ def train(config,
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||||
|
use_nrtr = config['Architecture']['algorithm'] == "NRTR"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_type = config['Architecture']['model_type']
|
model_type = config['Architecture']['model_type']
|
||||||
except:
|
except:
|
||||||
|
@ -213,7 +215,7 @@ def train(config,
|
||||||
images = batch[0]
|
images = batch[0]
|
||||||
if use_srn:
|
if use_srn:
|
||||||
model_average = True
|
model_average = True
|
||||||
if use_srn or model_type == 'table':
|
if use_srn or model_type == 'table' or use_nrtr:
|
||||||
preds = model(images, data=batch[1:])
|
preds = model(images, data=batch[1:])
|
||||||
else:
|
else:
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
|
@ -398,7 +400,7 @@ def preprocess(is_train=False):
|
||||||
alg = config['Architecture']['algorithm']
|
alg = config['Architecture']['algorithm']
|
||||||
assert alg in [
|
assert alg in [
|
||||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||||
'CLS', 'PGNet', 'Distillation', 'TableAttn'
|
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn'
|
||||||
]
|
]
|
||||||
|
|
||||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||||
|
|
Loading…
Reference in New Issue