delete blank lines and modify forward_train

This commit is contained in:
parent a11e219970
commit 55b76dcaa5
@@ -46,7 +46,7 @@ Architecture:
  name: TransformerOptim
  d_model: 512
  num_encoder_layers: 6
  beam_size: 10 # When Beam size is greater than 0, it means to use beam search when evaluation.


Loss:
@@ -27,8 +27,9 @@ def build_backbone(config, model_type):
        from .rec_resnet_fpn import ResNetFPN
        from .rec_mv1_enhance import MobileNetV1Enhance
        from .rec_nrtr_mtb import MTB
-       from .rec_swin import SwinTransformer
-       support_dict = ['MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'SwinTransformer']
+       support_dict = [
+           'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB'
+       ]
    elif model_type == "e2e":
        from .e2e_resnet_vd_pg import ResNet
        support_dict = ["ResNet"]
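For context, a minimal sketch of the kind of name-based lookup a support list like this typically feeds. The body of build_backbone is not shown in this diff, so the helper and registry below are illustrative assumptions, not the repository's actual code:

# Hypothetical sketch: resolve a backbone class by name against a registry,
# the same role `support_dict` plays above.
def build_from_registry(config, registry):
    params = dict(config)          # config carries a 'name' key plus constructor kwargs
    name = params.pop('name')
    assert name in registry, '{} is not supported, expected one of {}'.format(name, list(registry))
    return registry[name](**params)

# Example usage with a stand-in registry:
# registry = {'MTB': MTB}
# backbone = build_from_registry({'name': 'MTB', 'cnn_num': 2, 'in_channels': 1}, registry)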
@@ -1,5 +1,20 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from paddle import nn


class MTB(nn.Layer):
    def __init__(self, cnn_num, in_channels):
        super(MTB, self).__init__()
@@ -8,17 +23,20 @@ class MTB(nn.Layer):
        self.cnn_num = cnn_num
        if self.cnn_num == 2:
            for i in range(self.cnn_num):
                self.block.add_sublayer(
                    'conv_{}'.format(i),
                    nn.Conv2D(
                        in_channels=in_channels if i == 0 else 32 * (2**(i - 1)),
                        out_channels=32 * (2**i),
                        kernel_size=3,
                        stride=2,
                        padding=1))
                self.block.add_sublayer('relu_{}'.format(i), nn.ReLU())
                self.block.add_sublayer('bn_{}'.format(i),
                                        nn.BatchNorm2D(32 * (2**i)))

    def forward(self, images):
        x = self.block(images)
        if self.cnn_num == 2:
            # (b, w, h, c)
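For reference, a small self-contained sketch (illustrative only, framework-free) of how the stride-2 stem above changes the input shape, assuming the channel schedule 32 * (2**i) and kernel_size=3, stride=2, padding=1 shown in the hunk:

def stem_output_shape(h, w, in_channels, cnn_num=2):
    # Each Conv2D with kernel 3, stride 2, padding 1 halves the spatial dims (rounded up).
    channels = in_channels
    for i in range(cnn_num):
        channels = 32 * (2**i)
        h = (h + 2 * 1 - 3) // 2 + 1
        w = (w + 2 * 1 - 3) // 2 + 1
    return channels, h, w

# e.g. a 1-channel 32x100 crop -> (64, 8, 25) feature map
print(stem_output_shape(32, 100, 1))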
@@ -27,14 +27,13 @@ def build_head(config):
    from .rec_att_head import AttentionHead
    from .rec_srn_head import SRNHead
    from .rec_nrtr_optim_head import TransformerOptim

    # cls head
    from .cls_head import ClsHead
    support_dict = [
        'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
        'SRNHead', 'PGHead', 'TransformerOptim', 'TableAttentionHead'
    ]

    #table head
    from .table_att_head import TableAttentionHead
@@ -1,3 +1,17 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import paddle
from paddle import nn
import paddle.nn.functional as F
@@ -11,7 +25,7 @@ ones_ = constant_(value=1.)


class MultiheadAttentionOptim(nn.Layer):
-    r"""Allows the model to jointly attend to information
+    """Allows the model to jointly attend to information
    from different representation subspaces.
    See reference: Attention Is All You Need
@@ -23,37 +37,43 @@ class MultiheadAttentionOptim(nn.Layer):
        embed_dim: total dimension of the model
        num_heads: parallel attention layers, or heads
-
-    Examples::
-        >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
-        >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True,
                 add_bias_kv=False, add_zero_attn=False):
        super(MultiheadAttentionOptim, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5

        self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)

        self._reset_parameters()
        self.conv1 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
        self.conv2 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
        self.conv3 = paddle.nn.Conv2D(
            in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))

    def _reset_parameters(self):
        xavier_uniform_(self.out_proj.weight)

    def forward(self, query, key, value, key_padding_mask=None,
                incremental_state=None, need_weights=True, static_kv=False,
                attn_mask=None):
        """
        Inputs of forward function
            query: [target length, batch size, embed dim]
@@ -68,8 +88,6 @@ class MultiheadAttentionOptim(nn.Layer):
            attn_output: [target length, batch size, embed dim]
            attn_output_weights: [batch size, target length, sequence length]
        """
        tgt_len, bsz, embed_dim = query.shape
        assert embed_dim == self.embed_dim
        assert list(query.shape) == [tgt_len, bsz, embed_dim]
@@ -80,11 +98,12 @@ class MultiheadAttentionOptim(nn.Layer):
        v = self._in_proj_v(value)
        q *= self.scaling

        q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
        k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
        v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])

        src_len = k.shape[1]
@@ -92,44 +111,48 @@ class MultiheadAttentionOptim(nn.Layer):
            assert key_padding_mask.shape[0] == bsz
            assert key_padding_mask.shape[1] == src_len

        attn_output_weights = paddle.bmm(q, k.transpose([0, 2, 1]))
        assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_output_weights += attn_mask
        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.reshape(
                [bsz, self.num_heads, tgt_len, src_len])
            key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32')

            y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf')
            y = paddle.where(key == 0., key, y)

            attn_output_weights += y
            attn_output_weights = attn_output_weights.reshape(
                [bsz * self.num_heads, tgt_len, src_len])

        attn_output_weights = F.softmax(
            attn_output_weights.astype('float32'),
            axis=-1,
            dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16
            else attn_output_weights.dtype)
        attn_output_weights = F.dropout(
            attn_output_weights, p=self.dropout, training=self.training)

        attn_output = paddle.bmm(attn_output_weights, v)
        assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn_output = attn_output.transpose([1, 0, 2]).reshape([tgt_len, bsz, embed_dim])
        attn_output = self.out_proj(attn_output)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.reshape(
                [bsz, self.num_heads, tgt_len, src_len])
            attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads
        else:
            attn_output_weights = None

        return attn_output, attn_output_weights

    def _in_proj_q(self, query):
        query = query.transpose([1, 2, 0])
        query = paddle.unsqueeze(query, axis=2)
@@ -139,7 +162,6 @@ class MultiheadAttentionOptim(nn.Layer):
        return res

    def _in_proj_k(self, key):
        key = key.transpose([1, 2, 0])
        key = paddle.unsqueeze(key, axis=2)
        res = self.conv2(key)
@@ -148,8 +170,7 @@ class MultiheadAttentionOptim(nn.Layer):
        return res

    def _in_proj_v(self, value):
        value = value.transpose([1, 2, 0])  #(1, 2, 0)
        value = paddle.unsqueeze(value, axis=2)
        res = self.conv3(value)
        res = paddle.squeeze(res, axis=2)
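As a reference for the tensor layout used above (query/key/value arrive as [length, batch, embed_dim] and scores are taken over source positions per head), here is a tiny framework-free sketch of the same scaled dot-product step; it is illustrative only and not part of the commit:

import numpy as np

def scaled_dot_product_attention(q, k, v):
    # q: [tgt_len, d], k/v: [src_len, d]; one head, one batch element.
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)                              # [tgt_len, src_len]
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)    # softmax over source positions
    return weights @ v, weights                                # [tgt_len, d], [tgt_len, src_len]

q = np.random.rand(4, 8); k = np.random.rand(6, 8); v = np.random.rand(6, 8)
out, w = scaled_dot_product_attention(q, k, v)
print(out.shape, w.shape)  # (4, 8) (4, 6)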
@@ -1,7 +1,21 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import math
import paddle
import copy
from paddle import nn
import paddle.nn.functional as F
from paddle.nn import LayerList
from paddle.nn.initializer import XavierNormal as xavier_uniform_
@@ -14,8 +28,9 @@ from paddle.nn.initializer import XavierNormal as xavier_normal_
zeros_ = constant_(value=0.)
ones_ = constant_(value=1.)


class TransformerOptim(nn.Layer):
-    r"""A transformer model. User is able to modify the attributes as needed. The architechture
+    """A transformer model. User is able to modify the attributes as needed. The architechture
    is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
    Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
    Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
@@ -31,39 +46,50 @@ class TransformerOptim(nn.Layer):
        custom_encoder: custom encoder (default=None).
        custom_decoder: custom decoder (default=None).
-
-    Examples::
-        >>> transformer_model = nn.Transformer(src_vocab, tgt_vocab)
-        >>> transformer_model = nn.Transformer(src_vocab, tgt_vocab, nhead=16, num_encoder_layers=12)
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, beam_size=0,
                 num_decoder_layers=6, dim_feedforward=1024,
                 attention_dropout_rate=0.0, residual_dropout_rate=0.1,
                 custom_encoder=None, custom_decoder=None, in_channels=0,
                 out_channels=0, dst_vocab_size=99, scale_embedding=True):
        super(TransformerOptim, self).__init__()
        self.embedding = Embeddings(
            d_model=d_model,
            vocab=dst_vocab_size,
            padding_idx=0,
            scale_embedding=scale_embedding)
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate,
            dim=d_model, )
        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            if num_encoder_layers > 0:
                encoder_layer = TransformerEncoderLayer(
                    d_model, nhead, dim_feedforward, attention_dropout_rate,
                    residual_dropout_rate)
                self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers)
            else:
                self.encoder = None

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(
                d_model, nhead, dim_feedforward, attention_dropout_rate,
                residual_dropout_rate)
            self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers)

        self._reset_parameters()
@@ -71,201 +97,205 @@ class TransformerOptim(nn.Layer):
        self.d_model = d_model
        self.nhead = nhead
        self.tgt_word_prj = nn.Linear(d_model, dst_vocab_size, bias_attr=False)
        w0 = np.random.normal(0.0, d_model**-0.5,
                              (d_model, dst_vocab_size)).astype(np.float32)
        self.tgt_word_prj.weight.set_value(w0)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2D):
            xavier_normal_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)

    def forward_train(self, src, tgt):
        tgt = tgt[:, :-1]

        tgt_key_padding_mask = self.generate_padding_mask(tgt)
        tgt = self.embedding(tgt).transpose([1, 0, 2])
        tgt = self.positional_encoding(tgt)
        tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0])

        if self.encoder is not None:
            src = self.positional_encoding(src.transpose([1, 0, 2]))
            memory = self.encoder(src)
        else:
            memory = src.squeeze(2).transpose([2, 0, 1])
        output = self.decoder(
            tgt,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=None,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=None)
        output = output.transpose([1, 0, 2])
        logit = self.tgt_word_prj(output)
        return logit

-    def forward(self, src, tgt=None):
-        r"""Take in and process masked source/target sequences.
+    def forward(self, src, targets=None):
+        """Take in and process masked source/target sequences.
        Args:
            src: the sequence to the encoder (required).
            tgt: the sequence to the decoder (required).
-            src_mask: the additive mask for the src sequence (optional).
-            tgt_mask: the additive mask for the tgt sequence (optional).
-            memory_mask: the additive mask for the encoder output (optional).
-            src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
-            tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
-            memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).

        Shape:
            - src: :math:`(S, N, E)`.
            - tgt: :math:`(T, N, E)`.
-            - src_mask: :math:`(S, S)`.
-            - tgt_mask: :math:`(T, T)`.
-            - memory_mask: :math:`(T, S)`.
-            - src_key_padding_mask: :math:`(N, S)`.
-            - tgt_key_padding_mask: :math:`(N, T)`.
-            - memory_key_padding_mask: :math:`(N, S)`.
-
-            Note: [src/tgt/memory]_mask should be filled with
-            float('-inf') for the masked positions and float(0.0) else. These masks
-            ensure that predictions for position i depend only on the unmasked positions
-            j and are applied identically for each sequence in a batch.
-            [src/tgt/memory]_key_padding_mask should be a ByteTensor where True values are positions
-            that should be masked with float('-inf') and False values will be unchanged.
-            This mask ensures that no information will be taken from position i if
-            it is masked, and has a separate mask for each sequence in a batch.
-
-            - output: :math:`(T, N, E)`.
-
-            Note: Due to the multi-head attention architecture in the transformer model,
-            the output sequence length of a transformer is same as the input sequence
-            (i.e. target) length of the decode.
-
-            where S is the source sequence length, T is the target sequence length, N is the
-            batch size, E is the feature number

        Examples:
-            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
+            >>> output = transformer_model(src, tgt)
        """
-        if tgt is not None:
+        if self.training:
+            max_len = targets[1].max()
+            tgt = targets[0][:, :2 + max_len]
            return self.forward_train(src, tgt)
        else:
            if self.beam_size > 0:
                return self.forward_beam(src)
            else:
                return self.forward_test(src)

    def forward_test(self, src):
        bs = src.shape[0]
        if self.encoder is not None:
            src = self.positional_encoding(src.transpose([1, 0, 2]))
            memory = self.encoder(src)
        else:
            memory = src.squeeze(2).transpose([2, 0, 1])
        dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64)
        for len_dec_seq in range(1, 25):
            src_enc = memory.clone()
            tgt_key_padding_mask = self.generate_padding_mask(dec_seq)
            dec_seq_embed = self.embedding(dec_seq).transpose([1, 0, 2])
            dec_seq_embed = self.positional_encoding(dec_seq_embed)
            tgt_mask = self.generate_square_subsequent_mask(dec_seq_embed.shape[0])
            output = self.decoder(
                dec_seq_embed,
                src_enc,
                tgt_mask=tgt_mask,
                memory_mask=None,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=None)
            dec_output = output.transpose([1, 0, 2])

            dec_output = dec_output[:, -1, :]  # Pick the last step: (bh * bm) * d_h
            word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1)
            word_prob = word_prob.reshape([1, bs, -1])
            preds_idx = word_prob.argmax(axis=2)

            if paddle.equal_all(
                    preds_idx[-1],
                    paddle.full(preds_idx[-1].shape, 3, dtype='int64')):
                break

            preds_prob = word_prob.max(axis=2)
            dec_seq = paddle.concat([dec_seq, preds_idx.reshape([-1, 1])], axis=1)

        return dec_seq

    def forward_beam(self, images):
        ''' Translation work in one batch '''

        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            ''' Indicate the position of an instance in a tensor. '''
            return {
                inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)
            }

        def collect_active_part(beamed_tensor, curr_active_inst_idx,
                                n_prev_active_inst, n_bm):
            ''' Collect tensor parts associated to active instances. '''

            _, *d_hs = beamed_tensor.shape
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = (n_curr_active_inst * n_bm, *d_hs)

            beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])  #contiguous()
            beamed_tensor = beamed_tensor.index_select(
                paddle.to_tensor(curr_active_inst_idx), axis=0)
            beamed_tensor = beamed_tensor.reshape([*new_shape])

            return beamed_tensor

        def collate_active_info(src_enc, inst_idx_to_position_map,
                                active_inst_idx_list):
            # Sentences which are still active are collected,
            # so the decoder will not run on completed sentences.
            n_prev_active_inst = len(inst_idx_to_position_map)
            active_inst_idx = [
                inst_idx_to_position_map[k] for k in active_inst_idx_list
            ]
            active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64')
            active_src_enc = collect_active_part(
                src_enc.transpose([1, 0, 2]), active_inst_idx,
                n_prev_active_inst, n_bm).transpose([1, 0, 2])
            active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)
            return active_src_enc, active_inst_idx_to_position_map

        def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
                             inst_idx_to_position_map, n_bm,
                             memory_key_padding_mask):
            ''' Decode and update beam status, and then return active beam idx '''

            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
                dec_partial_seq = [
                    b.get_current_state() for b in inst_dec_beams if not b.done
                ]
                dec_partial_seq = paddle.stack(dec_partial_seq)
                dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq])
                return dec_partial_seq

            def prepare_beam_memory_key_padding_mask(
                    inst_dec_beams, memory_key_padding_mask, n_bm):
                keep = []
                for idx in (memory_key_padding_mask):
                    if not inst_dec_beams[idx].done:
                        keep.append(idx)
                memory_key_padding_mask = memory_key_padding_mask[paddle.to_tensor(keep)]
                len_s = memory_key_padding_mask.shape[-1]
                n_inst = memory_key_padding_mask.shape[0]
                memory_key_padding_mask = paddle.concat(
                    [memory_key_padding_mask for i in range(n_bm)], axis=1)
                memory_key_padding_mask = memory_key_padding_mask.reshape(
                    [n_inst * n_bm, len_s])  #repeat(1, n_bm)
                return memory_key_padding_mask

            def predict_word(dec_seq, enc_output, n_active_inst, n_bm,
                             memory_key_padding_mask):
                tgt_key_padding_mask = self.generate_padding_mask(dec_seq)
                dec_seq = self.embedding(dec_seq).transpose([1, 0, 2])
                dec_seq = self.positional_encoding(dec_seq)
                tgt_mask = self.generate_square_subsequent_mask(dec_seq.shape[0])
                dec_output = self.decoder(
                    dec_seq,
                    enc_output,
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                    memory_key_padding_mask=memory_key_padding_mask,
                ).transpose([1, 0, 2])
                dec_output = dec_output[:, -1, :]  # Pick the last step: (bh * bm) * d_h
                word_prob = F.log_softmax(self.tgt_word_prj(dec_output), axis=1)
                word_prob = word_prob.reshape([n_active_inst, n_bm, -1])
                return word_prob

            def collect_active_inst_idx_list(inst_beams, word_prob,
                                             inst_idx_to_position_map):
                active_inst_idx_list = []
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
                    if not is_inst_complete:
                        active_inst_idx_list += [inst_idx]
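The key behavioural change in this hunk is that forward() now branches on self.training and slices the padded labels by the longest text in the batch. A minimal sketch of that slicing step, assuming (as the diff suggests) that targets[0] holds padded label ids of shape [batch, max_text_len] and targets[1] the per-sample lengths; the tensors here are made up for illustration:

import paddle

def slice_targets_for_training(targets):
    # keep start token, body, and end token: 2 + max_len positions
    max_len = targets[1].max()
    return targets[0][:, :2 + max_len]

labels = paddle.to_tensor([[2, 5, 7, 3, 0, 0, 0], [2, 9, 3, 0, 0, 0, 0]])
lengths = paddle.to_tensor([2, 1])
print(slice_targets_for_training((labels, lengths)))  # first 4 columns of each row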
@@ -274,7 +304,8 @@ class TransformerOptim(nn.Layer):
            n_active_inst = len(inst_idx_to_position_map)
            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            memory_key_padding_mask = None
            word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm,
                                     memory_key_padding_mask)
            # Update the beam with predicted word prob information and collect incomplete instances
            active_inst_idx_list = collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)
@@ -285,14 +316,17 @@ class TransformerOptim(nn.Layer):
            for inst_idx in range(len(inst_dec_beams)):
                scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
                all_scores += [scores[:n_best]]
                hyps = [
                    inst_dec_beams[inst_idx].get_hypothesis(i)
                    for i in tail_idxs[:n_best]
                ]
                all_hyp += [hyps]
            return all_hyp, all_scores

        with paddle.no_grad():
            #-- Encode
            if self.encoder is not None:
                src = self.positional_encoding(images.transpose([1, 0, 2]))
                src_enc = self.encoder(src).transpose([1, 0, 2])
            else:
@@ -301,45 +335,53 @@ class TransformerOptim(nn.Layer):
            #-- Repeat data for beam search
            n_bm = self.beam_size
            n_inst, len_s, d_h = src_enc.shape
            src_enc = paddle.concat([src_enc for i in range(n_bm)], axis=1)
            src_enc = src_enc.reshape([n_inst * n_bm, len_s, d_h]).transpose(
                [1, 0, 2])  #repeat(1, n_bm, 1)
            #-- Prepare beams
            inst_dec_beams = [Beam(n_bm) for _ in range(n_inst)]

            #-- Bookkeeping for active or not
            active_inst_idx_list = list(range(n_inst))
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)
            #-- Decode
            for len_dec_seq in range(1, 25):
                src_enc_copy = src_enc.clone()
                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, src_enc_copy,
                    inst_idx_to_position_map, n_bm, None)
                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>
                src_enc, inst_idx_to_position_map = collate_active_info(
                    src_enc_copy, inst_idx_to_position_map, active_inst_idx_list)
        batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams, 1)
        result_hyp = []
        for bs_hyp in batch_hyp:
            bs_hyp_pad = bs_hyp[0] + [3] * (25 - len(bs_hyp[0]))
            result_hyp.append(bs_hyp_pad)
        return paddle.to_tensor(np.array(result_hyp), dtype=paddle.int64)

    def generate_square_subsequent_mask(self, sz):
-        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
+        """Generate a square mask for the sequence. The masked positions are filled with float('-inf').
        Unmasked positions are filled with float(0.0).
        """
        mask = paddle.zeros([sz, sz], dtype='float32')
        mask_inf = paddle.triu(
            paddle.full(shape=[sz, sz], dtype='float32', fill_value='-inf'),
            diagonal=1)
        mask = mask + mask_inf
        return mask

    def generate_padding_mask(self, x):
        padding_mask = x.equal(paddle.to_tensor(0, dtype=x.dtype))
        return padding_mask

    def _reset_parameters(self):
-        r"""Initiate parameters in the transformer model."""
+        """Initiate parameters in the transformer model."""

        for p in self.parameters():
            if p.dim() > 1:
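To make the masking convention concrete, here is a small self-contained sketch (NumPy, illustrative only) of the same "-inf above the diagonal, 0 elsewhere" causal mask that generate_square_subsequent_mask builds with paddle.triu:

import numpy as np

def square_subsequent_mask(sz):
    # 0.0 on and below the diagonal (visible positions), -inf strictly above it (future positions).
    mask = np.zeros((sz, sz), dtype=np.float32)
    mask[np.triu_indices(sz, k=1)] = -np.inf
    return mask

print(square_subsequent_mask(4))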
@@ -347,16 +389,11 @@ class TransformerOptim(nn.Layer):


class TransformerEncoder(nn.Layer):
-    r"""TransformerEncoder is a stack of N encoder layers
+    """TransformerEncoder is a stack of N encoder layers

    Args:
        encoder_layer: an instance of the TransformerEncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: the layer normalization component (optional).
-
-    Examples::
-        >>> encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
-        >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
    """

    def __init__(self, encoder_layer, num_layers):
@@ -364,50 +401,46 @@ class TransformerEncoder(nn.Layer):
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, src):
-        r"""Pass the input through the endocder layers in turn.
+        """Pass the input through the endocder layers in turn.

        Args:
            src: the sequnce to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
-
-        Shape:
-            see the docs in Transformer class.
        """
        output = src

        for i in range(self.num_layers):
            output = self.layers[i](output,
                                    src_mask=None,
                                    src_key_padding_mask=None)

        return output


class TransformerDecoder(nn.Layer):
-    r"""TransformerDecoder is a stack of N decoder layers
+    """TransformerDecoder is a stack of N decoder layers

    Args:
        decoder_layer: an instance of the TransformerDecoderLayer() class (required).
        num_layers: the number of sub-decoder-layers in the decoder (required).
        norm: the layer normalization component (optional).
-
-    Examples::
-        >>> decoder_layer = nn.TransformerDecoderLayer(d_model, nhead)
-        >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
    """

    def __init__(self, decoder_layer, num_layers):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
-        r"""Pass the inputs (and mask) through the decoder layer in turn.
+        """Pass the inputs (and mask) through the decoder layer in turn.

        Args:
            tgt: the sequence to the decoder (required).
@@ -416,21 +449,22 @@ class TransformerDecoder(nn.Layer):
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).
-
-        Shape:
-            see the docs in Transformer class.
        """
        output = tgt
        for i in range(self.num_layers):
            output = self.layers[i](
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask)

        return output


class TransformerEncoderLayer(nn.Layer):
-    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
+    """TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
@@ -443,16 +477,26 @@ class TransformerEncoderLayer(nn.Layer):
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
-
-    Examples::
-        >>> encoder_layer = nn.TransformerEncoderLayer(d_model, nhead)
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048,
                 attention_dropout_rate=0.0, residual_dropout_rate=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttentionOptim(
            d_model, nhead, dropout=attention_dropout_rate)

        self.conv1 = Conv2D(
            in_channels=d_model, out_channels=dim_feedforward, kernel_size=(1, 1))
        self.conv2 = Conv2D(
            in_channels=dim_feedforward, out_channels=d_model, kernel_size=(1, 1))

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
@@ -460,18 +504,18 @@ class TransformerEncoderLayer(nn.Layer):
        self.dropout2 = Dropout(residual_dropout_rate)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
-        r"""Pass the input through the endocder layer.
+        """Pass the input through the endocder layer.

        Args:
            src: the sequnce to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
-
-        Shape:
-            see the docs in Transformer class.
        """
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
@@ -487,8 +531,9 @@ class TransformerEncoderLayer(nn.Layer):
        src = self.norm2(src)
        return src


class TransformerDecoderLayer(nn.Layer):
-    r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
+    """TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
|
@ -501,17 +546,28 @@ class TransformerDecoderLayer(nn.Layer):
|
||||||
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||||
dropout: the dropout value (default=0.1).
|
dropout: the dropout value (default=0.1).
|
||||||
|
|
||||||
Examples::
|
|
||||||
>>> decoder_layer = nn.TransformerDecoderLayer(d_model, nhead)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, d_model, nhead, dim_feedforward=2048, attention_dropout_rate=0.0, residual_dropout_rate=0.1):
|
def __init__(self,
|
||||||
|
d_model,
|
||||||
|
nhead,
|
||||||
|
dim_feedforward=2048,
|
||||||
|
attention_dropout_rate=0.0,
|
||||||
|
residual_dropout_rate=0.1):
|
||||||
super(TransformerDecoderLayer, self).__init__()
|
super(TransformerDecoderLayer, self).__init__()
|
||||||
self.self_attn = MultiheadAttentionOptim(d_model, nhead, dropout=attention_dropout_rate)
|
self.self_attn = MultiheadAttentionOptim(
|
||||||
self.multihead_attn = MultiheadAttentionOptim(d_model, nhead, dropout=attention_dropout_rate)
|
d_model, nhead, dropout=attention_dropout_rate)
|
||||||
|
self.multihead_attn = MultiheadAttentionOptim(
|
||||||
|
d_model, nhead, dropout=attention_dropout_rate)
|
||||||
|
|
||||||
self.conv1 = Conv2D(in_channels=d_model, out_channels=dim_feedforward, kernel_size=(1, 1))
|
self.conv1 = Conv2D(
|
||||||
self.conv2 = Conv2D(in_channels=dim_feedforward, out_channels=d_model, kernel_size=(1, 1))
|
in_channels=d_model,
|
||||||
|
out_channels=dim_feedforward,
|
||||||
|
kernel_size=(1, 1))
|
||||||
|
self.conv2 = Conv2D(
|
||||||
|
in_channels=dim_feedforward,
|
||||||
|
out_channels=d_model,
|
||||||
|
kernel_size=(1, 1))
|
||||||
|
|
||||||
self.norm1 = LayerNorm(d_model)
|
self.norm1 = LayerNorm(d_model)
|
||||||
self.norm2 = LayerNorm(d_model)
|
self.norm2 = LayerNorm(d_model)
|
||||||
|
@@ -520,9 +576,14 @@ class TransformerDecoderLayer(nn.Layer):
         self.dropout2 = Dropout(residual_dropout_rate)
         self.dropout3 = Dropout(residual_dropout_rate)
 
-    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
-                tgt_key_padding_mask=None, memory_key_padding_mask=None):
-        r"""Pass the inputs (and mask) through the decoder layer.
+    def forward(self,
+                tgt,
+                memory,
+                tgt_mask=None,
+                memory_mask=None,
+                tgt_key_padding_mask=None,
+                memory_key_padding_mask=None):
+        """Pass the inputs (and mask) through the decoder layer.
 
         Args:
             tgt: the sequence to the decoder layer (required).
@@ -532,15 +593,21 @@ class TransformerDecoderLayer(nn.Layer):
             tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
             memory_key_padding_mask: the mask for the memory keys per batch (optional).
 
-        Shape:
-            see the docs in Transformer class.
         """
-        tgt2 = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
-                              key_padding_mask=tgt_key_padding_mask)[0]
+        tgt2 = self.self_attn(
+            tgt,
+            tgt,
+            tgt,
+            attn_mask=tgt_mask,
+            key_padding_mask=tgt_key_padding_mask)[0]
         tgt = tgt + self.dropout1(tgt2)
         tgt = self.norm1(tgt)
-        tgt2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
-                                   key_padding_mask=memory_key_padding_mask)[0]
+        tgt2 = self.multihead_attn(
+            tgt,
+            memory,
+            memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask)[0]
         tgt = tgt + self.dropout2(tgt2)
         tgt = self.norm2(tgt)
 
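For orientation, a hypothetical usage sketch of the decoder layer above (not part of this commit). It assumes the layer accepts sequence-first tensors of shape [length, batch, d_model] and an additive attention mask, as in the PyTorch-style transformer this code mirrors:

import paddle

d_model, nhead = 512, 8
layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward=1024)

tgt = paddle.rand([25, 4, d_model])      # decoder input: [tgt_len, batch, d_model]
memory = paddle.rand([160, 4, d_model])  # encoder output: [src_len, batch, d_model]
# Additive causal mask: position i may only attend to positions <= i.
tgt_mask = paddle.triu(
    paddle.full([25, 25], float('-inf'), dtype='float32'), diagonal=1)
out = layer(tgt, memory, tgt_mask=tgt_mask)  # same shape as tgt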
@@ -562,9 +629,8 @@ def _get_clones(module, N):
     return LayerList([copy.deepcopy(module) for i in range(N)])
 
 
-
 class PositionalEncoding(nn.Layer):
-    r"""Inject some information about the relative or absolute position of the tokens
+    """Inject some information about the relative or absolute position of the tokens
     in the sequence. The positional encodings have the same dimension as
     the embeddings, so that the two can be summed. Here, we use sine and cosine
     functions of different frequencies.
@@ -586,7 +652,9 @@ class PositionalEncoding(nn.Layer):
 
         pe = paddle.zeros([max_len, dim])
         position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
-        div_term = paddle.exp(paddle.arange(0, dim, 2).astype('float32') * (-math.log(10000.0) / dim))
+        div_term = paddle.exp(
+            paddle.arange(0, dim, 2).astype('float32') *
+            (-math.log(10000.0) / dim))
         pe[:, 0::2] = paddle.sin(position * div_term)
         pe[:, 1::2] = paddle.cos(position * div_term)
         pe = pe.unsqueeze(0)
@@ -594,7 +662,7 @@ class PositionalEncoding(nn.Layer):
         self.register_buffer('pe', pe)
 
     def forward(self, x):
-        r"""Inputs of forward function
+        """Inputs of forward function
         Args:
             x: the sequence fed to the positional encoder model (required).
         Shape:
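The div_term expression above is the usual exponent trick for the sinusoidal encoding PE(pos, 2i) = sin(pos / 10000^(2i/dim)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/dim)). A standalone NumPy sketch of the same table, for reference only (not part of the commit):

import math
import numpy as np

def sinusoid_table(max_len, dim):
    # Same computation as the paddle code above, on a plain NumPy array.
    position = np.arange(max_len, dtype=np.float32)[:, None]
    div_term = np.exp(
        np.arange(0, dim, 2, dtype=np.float32) * (-math.log(10000.0) / dim))
    table = np.zeros((max_len, dim), dtype=np.float32)
    table[:, 0::2] = np.sin(position * div_term)
    table[:, 1::2] = np.cos(position * div_term)
    return table

print(sinusoid_table(5000, 512).shape)  # (5000, 512)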
@@ -608,7 +676,7 @@ class PositionalEncoding(nn.Layer):
 
 
 class PositionalEncoding_2d(nn.Layer):
-    r"""Inject some information about the relative or absolute position of the tokens
+    """Inject some information about the relative or absolute position of the tokens
     in the sequence. The positional encodings have the same dimension as
     the embeddings, so that the two can be summed. Here, we use sine and cosine
     functions of different frequencies.
@@ -630,7 +698,9 @@ class PositionalEncoding_2d(nn.Layer):
 
         pe = paddle.zeros([max_len, dim])
         position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
-        div_term = paddle.exp(paddle.arange(0, dim, 2).astype('float32') * (-math.log(10000.0) / dim))
+        div_term = paddle.exp(
+            paddle.arange(0, dim, 2).astype('float32') *
+            (-math.log(10000.0) / dim))
         pe[:, 0::2] = paddle.sin(position * div_term)
         pe[:, 1::2] = paddle.cos(position * div_term)
         pe = pe.unsqueeze(0).transpose([1, 0, 2])
@@ -644,7 +714,7 @@ class PositionalEncoding_2d(nn.Layer):
         self.linear2.weight.data.fill_(1.)
 
     def forward(self, x):
-        r"""Inputs of forward function
+        """Inputs of forward function
         Args:
             x: the sequence fed to the positional encoder model (required).
         Shape:
@@ -666,7 +736,9 @@ class PositionalEncoding_2d(nn.Layer):
         h_pe = h_pe.unsqueeze(3)
 
         x = x + w_pe + h_pe
-        x = x.reshape([x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).transpose([2,0,1])
+        x = x.reshape(
+            [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).transpose(
+                [2, 0, 1])
 
         return self.dropout(x)
 
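The reshape/transpose above is pure shape bookkeeping: a feature map laid out as [B, C, H, W] becomes the sequence-first layout [H*W, B, C] that the encoder layers expect. A small sketch with dummy values (not part of the commit):

import paddle

x = paddle.rand([4, 512, 8, 32])  # [B, C, H, W] feature map
seq = x.reshape([x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]).transpose(
    [2, 0, 1])
print(seq.shape)  # [256, 4, 512], i.e. [H*W, B, C]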
@@ -675,8 +747,9 @@ class Embeddings(nn.Layer):
     def __init__(self, d_model, vocab, padding_idx, scale_embedding):
         super(Embeddings, self).__init__()
         self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
-        w0 = np.random.normal(0.0, d_model**-0.5,(vocab, d_model)).astype(np.float32)
+        w0 = np.random.normal(0.0, d_model**-0.5,
+                              (vocab, d_model)).astype(np.float32)
         self.embedding.weight.set_value(w0)
         self.d_model = d_model
         self.scale_embedding = scale_embedding
 
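The init std of d_model**-0.5 pairs with the scale_embedding flag: assuming it follows the usual Transformer convention (the forward is not shown in this hunk), the looked-up vectors are multiplied by sqrt(d_model), which brings them back to roughly unit scale. A quick check of that arithmetic (sketch only):

import math
import numpy as np

d_model, vocab = 512, 100
w0 = np.random.normal(0.0, d_model**-0.5, (vocab, d_model)).astype(np.float32)
print(round(float(w0.std() * math.sqrt(d_model)), 1))  # ~1.0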
@@ -687,9 +760,6 @@ class Embeddings(nn.Layer):
         return self.embedding(x)
 
 
-
-
-
 class Beam():
     ''' Beam search '''
 
@@ -698,12 +768,12 @@ class Beam():
         self.size = size
         self._done = False
         # The score for each translation on the beam.
-        self.scores = paddle.zeros((size,), dtype=paddle.float32)
+        self.scores = paddle.zeros((size, ), dtype=paddle.float32)
         self.all_scores = []
         # The backpointers at each time-step.
         self.prev_ks = []
         # The outputs at each time-step.
-        self.next_ys = [paddle.full((size,), 0, dtype=paddle.int64)]
+        self.next_ys = [paddle.full((size, ), 0, dtype=paddle.int64)]
         self.next_ys[0][0] = 2
 
     def get_current_state(self):
@@ -729,28 +799,26 @@ class Beam():
             beam_lk = word_prob[0]
 
         flat_beam_lk = beam_lk.reshape([-1])
-        best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True, True)  # 1st sort
+        best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True,
+                                                        True)  # 1st sort
         self.all_scores.append(self.scores)
         self.scores = best_scores
 
         # bestScoresId is flattened as a (beam x word) array,
         # so we need to calculate which word and beam each score came from
         prev_k = best_scores_id // num_words
         self.prev_ks.append(prev_k)
-
         self.next_ys.append(best_scores_id - prev_k * num_words)
 
         # End condition is when top-of-beam is EOS.
-        if self.next_ys[-1][0] == 3 :
+        if self.next_ys[-1][0] == 3:
             self._done = True
             self.all_scores.append(self.scores)
 
         return self._done
 
     def sort_scores(self):
         "Sort the scores."
-        return self.scores, paddle.to_tensor([i for i in range(self.scores.shape[0])],dtype='int32')
+        return self.scores, paddle.to_tensor(
+            [i for i in range(self.scores.shape[0])], dtype='int32')
 
     def get_the_best_score_and_idx(self):
         "Get the score of the best in the beam."
@@ -759,7 +827,6 @@ class Beam():
 
     def get_tentative_hypothesis(self):
         "Get the decoded sequence for the current timestep."
-
         if len(self.next_ys) == 1:
             dec_seq = self.next_ys[0].unsqueeze(1)
         else:
@@ -767,13 +834,12 @@ class Beam():
             hyps = [self.get_hypothesis(k) for k in keys]
             hyps = [[2] + h for h in hyps]
             dec_seq = paddle.to_tensor(hyps, dtype='int64')
-
         return dec_seq
 
     def get_hypothesis(self, k):
         """ Walk back to construct the full hypothesis. """
         hyp = []
         for j in range(len(self.prev_ks) - 1, -1, -1):
-            hyp.append(self.next_ys[j+1][k])
+            hyp.append(self.next_ys[j + 1][k])
             k = self.prev_ks[j][k]
         return list(map(lambda x: x.item(), hyp[::-1]))
 
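For reference, a plain-Python sketch (not from the diff) of what advance() and get_hypothesis() above do with the flattened top-k indices: each index splits into a source-beam backpointer and a word id, and the backpointers are later walked in reverse to recover a hypothesis:

num_words = 1000                      # vocabulary size (illustrative)
best_scores_id = [2042, 17, 3999]     # flattened indices into (beam * word)
prev_k = [i // num_words for i in best_scores_id]  # which beam each came from
next_y = [i % num_words for i in best_scores_id]   # which word was appended
print(prev_k, next_y)  # [2, 0, 3] [42, 17, 999]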
@@ -189,9 +189,9 @@ def train(config,
 
     use_nrtr = config['Architecture']['algorithm'] == "NRTR"
 
     try:
         model_type = config['Architecture']['model_type']
     except:
         model_type = None
 
     if 'start_epoch' in best_model_dict:
@@ -216,11 +216,8 @@ def train(config,
             images = batch[0]
             if use_srn:
                 model_average = True
-            if use_srn or model_type == 'table':
+            if use_srn or model_type == 'table' or use_nrtr:
                 preds = model(images, data=batch[1:])
-            elif use_nrtr:
-                max_len = batch[2].max()
-                preds = model(images, batch[1][:,:2+max_len])
             else:
                 preds = model(images)
             loss = loss_class(preds, batch)
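The training-loop change above folds the NRTR case into the generic extra-inputs branch, so the model now receives the full label tensors via data=batch[1:] and any length-dependent slicing (the old max_len logic) is expected to live in the model's own forward_train. A sketch of the resulting control flow, with names as in train() above (sketch only, not additional code in the commit):

if use_srn or model_type == 'table' or use_nrtr:
    preds = model(images, data=batch[1:])  # labels/aux inputs passed through
else:
    preds = model(images)
loss = loss_class(preds, batch)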
@@ -405,9 +402,7 @@ def preprocess(is_train=False):
     alg = config['Architecture']['algorithm']
     assert alg in [
         'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
-
         'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn'
-
     ]
 
     device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'