2021-08-19 17:31:02 +08:00
|
|
|
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2021-08-16 19:33:15 +08:00
|
|
|
import math
|
|
|
|
import paddle
|
|
|
|
import copy
|
2021-08-19 17:31:02 +08:00
|
|
|
from paddle import nn
|
2021-08-16 19:33:15 +08:00
|
|
|
import paddle.nn.functional as F
|
|
|
|
from paddle.nn import LayerList
|
|
|
|
from paddle.nn.initializer import XavierNormal as xavier_uniform_
|
|
|
|
from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
|
|
|
|
import numpy as np
|
2021-08-24 15:46:43 +08:00
|
|
|
from ppocr.modeling.heads.multiheadAttention import MultiheadAttention
|
2021-08-16 19:33:15 +08:00
|
|
|
from paddle.nn.initializer import Constant as constant_
|
|
|
|
from paddle.nn.initializer import XavierNormal as xavier_normal_
|
|
|
|
|
|
|
|
# Shared constant-initializer instances; called as ``zeros_(param)`` elsewhere
# in this file to overwrite a parameter with a fixed value.
zeros_ = constant_(value=0.0)
ones_ = constant_(value=1.0)
|
|
|
|
|
2021-08-19 17:31:02 +08:00
|
|
|
|
2021-08-24 15:46:43 +08:00
|
|
|
class Transformer(nn.Layer):
    """Encoder-decoder transformer used as a sequence prediction head.

    The architecture is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
    Aidan N Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. In Advances in
    Neural Information Processing Systems, pages 6000-6010.

    Args:
        d_model: the number of expected features in the encoder/decoder
            inputs (default=512).
        nhead: the number of heads in the multi-head attention layers
            (default=8).
        num_encoder_layers: the number of encoder layers. When it is not
            positive and no ``custom_encoder`` is given, no encoder is built
            and the input features are used as decoder memory directly
            (default=6).
        beam_size: beam width used in evaluation mode; ``0`` selects greedy
            decoding (default=0).
        num_decoder_layers: the number of decoder layers (default=6).
        dim_feedforward: the hidden width of the feed-forward sub-layers
            (default=1024).
        attention_dropout_rate: dropout passed to the attention layers
            (default=0.0).
        residual_dropout_rate: dropout applied on the residual branches and
            in the positional encoding (default=0.1).
        custom_encoder: encoder to use instead of the built-in one
            (default=None).
        custom_decoder: decoder to use instead of the built-in one
            (default=None).
        in_channels: accepted but not used by this class (default=0).
        out_channels: the vocabulary / projection size is
            ``out_channels + 1`` (default=0).
        scale_embedding: forwarded to ``Embeddings`` (default=True).
        max_len: maximum length of a decoded sequence at inference, counting
            the initial token; decoding runs for at most ``max_len - 1``
            steps (default=25).
    """

    def __init__(self,
                 d_model=512,
                 nhead=8,
                 num_encoder_layers=6,
                 beam_size=0,
                 num_decoder_layers=6,
                 dim_feedforward=1024,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1,
                 custom_encoder=None,
                 custom_decoder=None,
                 in_channels=0,
                 out_channels=0,
                 scale_embedding=True,
                 max_len=25):
        super(Transformer, self).__init__()
        self.out_channels = out_channels + 1
        self.max_len = max_len
        self.embedding = Embeddings(
            d_model=d_model,
            vocab=self.out_channels,
            padding_idx=0,
            scale_embedding=scale_embedding)
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate,
            dim=d_model, )

        if custom_encoder is not None:
            self.encoder = custom_encoder
        elif num_encoder_layers > 0:
            encoder_layer = TransformerEncoderLayer(
                d_model, nhead, dim_feedforward, attention_dropout_rate,
                residual_dropout_rate)
            self.encoder = TransformerEncoder(encoder_layer,
                                              num_encoder_layers)
        else:
            self.encoder = None

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            decoder_layer = TransformerDecoderLayer(
                d_model, nhead, dim_feedforward, attention_dropout_rate,
                residual_dropout_rate)
            self.decoder = TransformerDecoder(decoder_layer,
                                              num_decoder_layers)

        self._reset_parameters()
        self.beam_size = beam_size
        self.d_model = d_model
        self.nhead = nhead

        # The output projection is created after _reset_parameters() and is
        # given its own normal(0, d_model ** -0.5) weights.
        self.tgt_word_prj = nn.Linear(
            d_model, self.out_channels, bias_attr=False)
        w0 = np.random.normal(0.0, d_model**-0.5,
                              (d_model, self.out_channels)).astype(np.float32)
        self.tgt_word_prj.weight.set_value(w0)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-sublayer initialisation hook passed to ``self.apply``."""
        if isinstance(m, nn.Conv2D):
            # NOTE(review): ``xavier_normal_`` is the ``XavierNormal`` class
            # itself (see the module imports), so this call appears to build
            # an initializer object and discard it rather than touch
            # ``m.weight``. Left unchanged to keep current initialisation --
            # confirm the intent. ``zeros_`` below is an initializer
            # *instance* and is applied to the bias.
            xavier_normal_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)

    def _encode(self, src):
        """Build the decoder memory from the input features.

        With an encoder, ``src`` is transposed with ``[1, 0, 2]``, given
        positional encodings and encoded. Without one, axis 2 of ``src`` is
        squeezed away and the result is transposed with ``[2, 0, 1]``.
        (Presumably both branches yield a sequence-first tensor -- confirm
        against the backbone's output layout.)
        """
        if self.encoder is not None:
            src = self.positional_encoding(paddle.transpose(src, [1, 0, 2]))
            return self.encoder(src)
        return paddle.transpose(paddle.squeeze(src, 2), [2, 0, 1])

    def forward_train(self, src, tgt):
        """Run the decoder on the whole target at once and return logits.

        Args:
            src: input features handed to ``_encode``.
            tgt: 2-D integer tensor of target ids; its last column is dropped
                before decoding and id ``0`` is treated as padding.

        Returns:
            The output of ``tgt_word_prj`` applied to the decoder output
            after transposing it with ``[1, 0, 2]``.
        """
        tgt = tgt[:, :-1]

        tgt_key_padding_mask = self.generate_padding_mask(tgt)
        tgt = self.embedding(tgt).transpose([1, 0, 2])
        tgt = self.positional_encoding(tgt)
        tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0])

        memory = self._encode(src)
        output = self.decoder(
            tgt,
            memory,
            tgt_mask=tgt_mask,
            memory_mask=None,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=None)
        output = output.transpose([1, 0, 2])
        return self.tgt_word_prj(output)

    def forward(self, src, targets=None):
        """Dispatch to training, beam-search or greedy decoding.

        Args:
            src: input features.
            targets: required in training mode. ``targets[0]`` is sliced
                along axis 1 to ``2 + targets[1].max()`` columns before being
                passed to ``forward_train`` (presumably ``targets[0]`` holds
                padded label ids and ``targets[1]`` their lengths --
                confirm against the data loader).

        Returns:
            Logits in training mode; otherwise the ``[ids, probabilities]``
            pair produced by ``forward_beam`` (``beam_size > 0``) or
            ``forward_test``.

        Examples:
            >>> output = transformer_model(src, targets)
        """
        if self.training:
            longest = targets[1].max()
            tgt = targets[0][:, :2 + longest]
            return self.forward_train(src, tgt)
        if self.beam_size > 0:
            return self.forward_beam(src)
        return self.forward_test(src)

    def forward_test(self, src):
        """Greedy decoding.

        Every sequence is seeded with id ``2`` and extended by the arg-max
        prediction of each step. Decoding stops after ``max_len - 1`` steps,
        or earlier once every prediction of a step equals id ``3``.

        Returns:
            ``[dec_seq, dec_prob]``: the decoded ids (seed included) and the
            matching per-step probabilities (``1.0`` for the seed).
        """
        bs = paddle.shape(src)[0]
        memory = self._encode(src)

        dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64)
        dec_prob = paddle.full((bs, 1), 1., dtype=paddle.float32)
        for _ in range(1, self.max_len):
            dec_seq_embed = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
            dec_seq_embed = self.positional_encoding(dec_seq_embed)
            tgt_mask = self.generate_square_subsequent_mask(
                paddle.shape(dec_seq_embed)[0])
            output = self.decoder(
                dec_seq_embed,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None)
            dec_output = paddle.transpose(output, [1, 0, 2])
            dec_output = dec_output[:, -1, :]  # keep the last step only
            word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
            preds_idx = paddle.argmax(word_prob, axis=1)
            if paddle.equal_all(
                    preds_idx,
                    paddle.full(
                        paddle.shape(preds_idx), 3, dtype='int64')):
                break
            preds_prob = paddle.max(word_prob, axis=1)
            dec_seq = paddle.concat(
                [dec_seq, paddle.reshape(preds_idx, [-1, 1])], axis=1)
            dec_prob = paddle.concat(
                [dec_prob, paddle.reshape(preds_prob, [-1, 1])], axis=1)
        return [dec_seq, dec_prob]

    def forward_beam(self, images):
        """Beam-search decoding with ``beam_size`` beams.

        Only one instance is tracked: the beam list is built with a single
        ``Beam`` entry, so this path handles one input per call.

        Returns:
            ``[ids, scores]``: the best hypothesis padded with id ``3`` to
            ``max_len`` entries, and its length-normalised score repeated
            ``max_len`` times.
        """

        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            ''' Indicate the position of an instance in a tensor. '''
            return {
                inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)
            }

        def collect_active_part(beamed_tensor, curr_active_inst_idx,
                                n_prev_active_inst, n_bm):
            ''' Collect tensor parts associated to active instances. '''
            beamed_tensor_shape = paddle.shape(beamed_tensor)
            n_curr_active_inst = len(curr_active_inst_idx)
            new_shape = (n_curr_active_inst * n_bm, beamed_tensor_shape[1],
                         beamed_tensor_shape[2])

            # Flatten per instance, keep the active rows, restore the shape.
            beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
            beamed_tensor = beamed_tensor.index_select(
                curr_active_inst_idx, axis=0)
            return beamed_tensor.reshape(new_shape)

        def collate_active_info(src_enc, inst_idx_to_position_map,
                                active_inst_idx_list):
            # Sentences which are still active are collected,
            # so the decoder will not run on completed sentences.
            # ``n_bm`` is read from the enclosing scope.
            n_prev_active_inst = len(inst_idx_to_position_map)
            active_inst_idx = [
                inst_idx_to_position_map[k] for k in active_inst_idx_list
            ]
            active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64')
            active_src_enc = collect_active_part(
                src_enc.transpose([1, 0, 2]), active_inst_idx,
                n_prev_active_inst, n_bm).transpose([1, 0, 2])
            active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)
            return active_src_enc, active_inst_idx_to_position_map

        def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
                             inst_idx_to_position_map, n_bm,
                             memory_key_padding_mask):
            ''' Decode and update beam status, and then return active beam idx '''

            def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
                dec_partial_seq = [
                    b.get_current_state() for b in inst_dec_beams if not b.done
                ]
                dec_partial_seq = paddle.stack(dec_partial_seq)
                return dec_partial_seq.reshape([-1, len_dec_seq])

            def predict_word(dec_seq, enc_output, n_active_inst, n_bm,
                             memory_key_padding_mask):
                dec_seq = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
                dec_seq = self.positional_encoding(dec_seq)
                tgt_mask = self.generate_square_subsequent_mask(
                    paddle.shape(dec_seq)[0])
                dec_output = self.decoder(
                    dec_seq,
                    enc_output,
                    tgt_mask=tgt_mask,
                    tgt_key_padding_mask=None,
                    memory_key_padding_mask=memory_key_padding_mask, )
                dec_output = paddle.transpose(dec_output, [1, 0, 2])
                # Pick the last step: (bh * bm) * d_h
                dec_output = dec_output[:, -1, :]
                word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
                return paddle.reshape(word_prob, [n_active_inst, n_bm, -1])

            def collect_active_inst_idx_list(inst_beams, word_prob,
                                             inst_idx_to_position_map):
                active_inst_idx_list = []
                for inst_idx, inst_position in inst_idx_to_position_map.items():
                    is_inst_complete = inst_beams[inst_idx].advance(word_prob[
                        inst_position])
                    if not is_inst_complete:
                        active_inst_idx_list.append(inst_idx)
                return active_inst_idx_list

            n_active_inst = len(inst_idx_to_position_map)
            dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
            # The mask argument is forwarded instead of a literal None; the
            # only caller passes None, so the result is unchanged.
            word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm,
                                     memory_key_padding_mask)
            # Update the beam with predicted word prob information and collect incomplete instances
            return collect_active_inst_idx_list(
                inst_dec_beams, word_prob, inst_idx_to_position_map)

        def collect_hypothesis_and_scores(inst_dec_beams, n_best):
            all_hyp, all_scores = [], []
            for beam in inst_dec_beams:
                scores, tail_idxs = beam.sort_scores()
                all_scores.append(scores[:n_best])
                all_hyp.append(
                    [beam.get_hypothesis(i) for i in tail_idxs[:n_best]])
            return all_hyp, all_scores

        with paddle.no_grad():
            # -- Encode. Uses the same helper as the other paths; the
            # previous no-encoder branch transposed with [0, 2, 1], which
            # disagreed with forward_train / forward_test.
            src_enc = self._encode(images)

            n_bm = self.beam_size
            inst_dec_beams = [Beam(n_bm) for _ in range(1)]
            active_inst_idx_list = list(range(1))
            # Repeat data for beam search
            src_enc = paddle.tile(src_enc, [1, n_bm, 1])
            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)
            # Decode
            for len_dec_seq in range(1, self.max_len):
                src_enc_copy = src_enc.clone()
                active_inst_idx_list = beam_decode_step(
                    inst_dec_beams, len_dec_seq, src_enc_copy,
                    inst_idx_to_position_map, n_bm, None)
                if not active_inst_idx_list:
                    break  # all instances have finished their path to <EOS>
                src_enc, inst_idx_to_position_map = collate_active_info(
                    src_enc_copy, inst_idx_to_position_map,
                    active_inst_idx_list)

        batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams,
                                                                1)
        result_hyp = []
        hyp_scores = []
        for bs_hyp, score in zip(batch_hyp, batch_scores):
            hyp_len = len(bs_hyp[0])
            result_hyp.append(bs_hyp[0] + [3] * (self.max_len - hyp_len))
            score = float(score) / hyp_len
            hyp_scores.append([score] * self.max_len)
        return [
            paddle.to_tensor(
                np.array(result_hyp), dtype=paddle.int64),
            paddle.to_tensor(hyp_scores)
        ]

    def generate_square_subsequent_mask(self, sz):
        """Generate a square causal mask of size ``sz`` x ``sz``.

        Entries strictly above the main diagonal are ``-inf``; all other
        entries are ``0.0``.
        """
        return paddle.triu(
            paddle.full(
                shape=[sz, sz], dtype='float32', fill_value='-inf'),
            diagonal=1)

    def generate_padding_mask(self, x):
        """Return a boolean tensor that is True where ``x`` equals 0."""
        return paddle.equal(x, paddle.to_tensor(0, dtype=x.dtype))

    def _reset_parameters(self):
        """Initiate parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                # NOTE(review): ``xavier_uniform_`` is an alias of the
                # ``XavierNormal`` *class* (see the module imports), so this
                # appears to construct an initializer and discard it without
                # changing ``p``. Left unchanged to keep current
                # initialisation -- confirm the intent.
                xavier_uniform_(p)
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerEncoder(nn.Layer):
    """A stack of ``num_layers`` encoder layers applied one after another.

    Args:
        encoder_layer: the layer instance to replicate; ``_get_clones``
            deep-copies it ``num_layers`` times (required).
        num_layers: the number of layers in the stack (required).
    """

    def __init__(self, encoder_layer, num_layers):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, src):
        """Pass the input through the encoder layers in turn.

        Every layer is called with ``src_mask=None`` and
        ``src_key_padding_mask=None``; this stack takes no mask arguments.

        Args:
            src: the sequence to the encoder (required).

        Returns:
            The output of the last layer.
        """
        output = src
        for layer in self.layers:
            output = layer(output, src_mask=None, src_key_padding_mask=None)
        return output
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerDecoder(nn.Layer):
    """A stack of ``num_layers`` decoder layers applied one after another.

    Args:
        decoder_layer: the layer instance to replicate; ``_get_clones``
            deep-copies it ``num_layers`` times (required).
        num_layers: the number of layers in the stack (required).
    """

    def __init__(self, decoder_layer, num_layers):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self,
                tgt,
                memory,
                tgt_mask=None,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Pass the inputs (and masks) through the decoder layers in turn.

        The same ``memory`` and masks are handed to every layer; only the
        running output changes between layers.

        Args:
            tgt: the sequence to the decoder (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Returns:
            The output of the last layer.
        """
        output = tgt
        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask)
        return output
|
|
|
|
|
2021-08-19 17:31:02 +08:00
|
|
|
|
2021-08-16 19:33:15 +08:00
|
|
|
class TransformerEncoderLayer(nn.Layer):
    """TransformerEncoderLayer is made up of self-attn and feedforward network.

    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    The feed-forward network is implemented with two 1x1 ``Conv2D`` layers.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        attention_dropout_rate: the dropout passed to the attention layer (default=0.0).
        residual_dropout_rate: the dropout applied to each residual branch (default=0.1).
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model, nhead, dropout=attention_dropout_rate)

        self.conv1 = Conv2D(
            in_channels=d_model,
            out_channels=dim_feedforward,
            kernel_size=(1, 1))
        self.conv2 = Conv2D(
            in_channels=dim_feedforward,
            out_channels=d_model,
            kernel_size=(1, 1))

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout1 = Dropout(residual_dropout_rate)
        self.dropout2 = Dropout(residual_dropout_rate)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Returns:
            A tensor laid out like ``src``: self-attention and feed-forward,
            each followed by dropout, a residual add and layer norm.
        """
        attn_out = self.self_attn(
            src,
            src,
            src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask)
        src = self.norm1(src + self.dropout1(attn_out))

        # Rearrange to 4-D for the 1x1 convolutions, then undo the
        # rearrangement on the result. ``src`` itself is left untouched, so
        # it needs no round trip before the residual add.
        ffn_in = paddle.unsqueeze(paddle.transpose(src, [1, 2, 0]), 2)
        ffn_out = self.conv2(F.relu(self.conv1(ffn_in)))
        ffn_out = paddle.transpose(paddle.squeeze(ffn_out, 2), [2, 0, 1])

        return self.norm2(src + self.dropout2(ffn_out))
|
|
|
|
|
2021-08-19 17:31:02 +08:00
|
|
|
|
2021-08-16 19:33:15 +08:00
|
|
|
class TransformerDecoderLayer(nn.Layer):
    """TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.

    This standard decoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    The feed-forward network is implemented with two 1x1 ``Conv2D`` layers.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        attention_dropout_rate: the dropout passed to both attention layers (default=0.0).
        residual_dropout_rate: the dropout applied to each residual branch (default=0.1).
    """

    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 attention_dropout_rate=0.0,
                 residual_dropout_rate=0.1):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attn = MultiheadAttention(
            d_model, nhead, dropout=attention_dropout_rate)
        self.multihead_attn = MultiheadAttention(
            d_model, nhead, dropout=attention_dropout_rate)

        self.conv1 = Conv2D(
            in_channels=d_model,
            out_channels=dim_feedforward,
            kernel_size=(1, 1))
        self.conv2 = Conv2D(
            in_channels=dim_feedforward,
            out_channels=d_model,
            kernel_size=(1, 1))

        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        self.norm3 = LayerNorm(d_model)
        self.dropout1 = Dropout(residual_dropout_rate)
        self.dropout2 = Dropout(residual_dropout_rate)
        self.dropout3 = Dropout(residual_dropout_rate)

    def forward(self,
                tgt,
                memory,
                tgt_mask=None,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None):
        """Pass the inputs (and mask) through the decoder layer.

        Args:
            tgt: the sequence to the decoder layer (required).
            memory: the sequence from the last layer of the encoder (required).
            tgt_mask: the mask for the tgt sequence (optional).
            memory_mask: the mask for the memory sequence (optional).
            tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
            memory_key_padding_mask: the mask for the memory keys per batch (optional).

        Returns:
            A tensor laid out like ``tgt``: self-attention, attention over
            ``memory`` and feed-forward, each followed by dropout, a residual
            add and layer norm.
        """
        self_attn_out = self.self_attn(
            tgt,
            tgt,
            tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask)
        tgt = self.norm1(tgt + self.dropout1(self_attn_out))

        cross_attn_out = self.multihead_attn(
            tgt,
            memory,
            memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask)
        tgt = self.norm2(tgt + self.dropout2(cross_attn_out))

        # Rearrange to 4-D for the 1x1 convolutions, then undo the
        # rearrangement on the result. ``tgt`` itself is left untouched, so
        # it needs no round trip before the residual add.
        ffn_in = paddle.unsqueeze(paddle.transpose(tgt, [1, 2, 0]), 2)
        ffn_out = self.conv2(F.relu(self.conv1(ffn_in)))
        ffn_out = paddle.transpose(paddle.squeeze(ffn_out, 2), [2, 0, 1])

        return self.norm3(tgt + self.dropout3(ffn_out))
|
|
|
|
|
|
|
|
|
|
|
|
def _get_clones(module, N):
    """Return a ``LayerList`` holding ``N`` independent deep copies of ``module``."""
    clones = []
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return LayerList(clones)
|
|
|
|
|
|
|
|
|
|
|
|
class PositionalEncoding(nn.Layer):
    r"""Inject some information about the relative or absolute position of the tokens
    in the sequence. The positional encodings have the same dimension as
    the embeddings, so that the two can be summed. Here, we use sine and cosine
    functions of different frequencies.

    .. math::
        \text{PosEncoder}(pos, 2i) = \sin(pos / 10000^{2i / dim})
        \text{PosEncoder}(pos, 2i+1) = \cos(pos / 10000^{2i / dim})

    where ``pos`` is the word position and ``i`` is the embed idx.

    Args:
        dropout: the dropout probability applied after adding the encoding (required).
        dim: the embed dim (required).
        max_len: the max. length of the incoming sequence (default=5000).

    Examples:
        >>> pos_encoder = PositionalEncoding(dropout=0.1, dim=d_model)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = paddle.zeros([max_len, dim])
        position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
        div_term = paddle.exp(
            paddle.arange(0, dim, 2).astype('float32') *
            (-math.log(10000.0) / dim))
        # Even columns carry the sine terms, odd columns the cosine terms.
        pe[:, 0::2] = paddle.sin(position * div_term)
        pe[:, 1::2] = paddle.cos(position * div_term)
        # [max_len, dim] -> [max_len, 1, dim] so the table broadcasts over
        # axis 1 of the input (same result as unsqueeze(0) + transpose).
        pe = paddle.unsqueeze(pe, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the positional table to ``x`` and apply dropout.

        Args:
            x: the sequence fed to the positional encoder model (required).

        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]

        Examples:
            >>> output = pos_encoder(x)
        """
        # Only the first ``sequence length`` rows of the table are used.
        x = x + self.pe[:paddle.shape(x)[0], :]
        return self.dropout(x)
|
|
|
|
|
|
|
|
|
|
|
|
class PositionalEncoding_2d(nn.Layer):
    r"""Inject some information about the relative or absolute position of the
    tokens in a 2-D feature map. The positional encodings have the same
    dimension as the embeddings, so that the two can be summed. Here, we use
    sine and cosine functions of different frequencies.

    One shared sinusoidal table is used for both spatial axes; each axis gets
    its own input-dependent scale computed from the globally pooled input.

    .. math::
        \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/dim))
        \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/dim))
        \text{where pos is the position and i is the embed idx}

    Args:
        dropout: the dropout probability (required).
        dim: the embed dim (required).
        max_len: the max. length of either spatial axis (default=5000).
    Examples:
        >>> pos_encoder = PositionalEncoding_2d(dropout, dim)
    """

    def __init__(self, dropout, dim, max_len=5000):
        super(PositionalEncoding_2d, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Sinusoidal table: even columns hold sin, odd columns hold cos.
        pe = paddle.zeros([max_len, dim])
        position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
        div_term = paddle.exp(
            paddle.arange(0, dim, 2).astype('float32') *
            (-math.log(10000.0) / dim))
        pe[:, 0::2] = paddle.sin(position * div_term)
        pe[:, 1::2] = paddle.cos(position * div_term)
        # Stored as [max_len, 1, dim] so it broadcasts over the batch axis.
        pe = paddle.transpose(paddle.unsqueeze(pe, 0), [1, 0, 2])
        self.register_buffer('pe', pe)

        # One pooling + linear pair per spatial axis; the linear weights
        # start at 1.
        self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1))
        self.linear1 = nn.Linear(dim, dim)
        self.linear1.weight.data.fill_(1.)
        self.avg_pool_2 = nn.AdaptiveAvgPool2D((1, 1))
        self.linear2 = nn.Linear(dim, dim)
        self.linear2.weight.data.fill_(1.)

    def forward(self, x):
        """Add scaled width/height encodings to ``x`` and flatten it.

        Args:
            x: the 4-D feature map fed to the positional encoder (required).
        Shape:
            x: [batch size, embed dim, height, width] (assumed layout; the
                code requires rank 4 and indexes the last two axes as
                height and width)
            output: [height * width, batch size, embed dim]
        Examples:
            >>> output = pos_encoder(x)
        """
        # Encoding along the last axis, scaled by a projection of the
        # globally pooled input, then laid out as [batch, dim, 1, width].
        w_pe = self.pe[:paddle.shape(x)[-1], :]
        w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
        w_pe = w_pe * w1
        w_pe = paddle.transpose(w_pe, [1, 2, 0])
        w_pe = paddle.unsqueeze(w_pe, 2)

        # Same for the second-to-last axis. The shape tensor itself is
        # indexed here: ``paddle.shape(x).shape`` is a one-element list, so
        # ``[-2]`` on it raised IndexError.
        h_pe = self.pe[:paddle.shape(x)[-2], :]
        w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
        h_pe = h_pe * w2
        h_pe = paddle.transpose(h_pe, [1, 2, 0])
        h_pe = paddle.unsqueeze(h_pe, 3)

        x = x + w_pe + h_pe
        # Merge the two spatial axes, then move the merged axis to the front.
        x = paddle.transpose(
            paddle.reshape(x,
                           [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
            [2, 0, 1])

        return self.dropout(x)
|
|
|
|
|
|
|
|
|
|
|
|
class Embeddings(nn.Layer):
    """Embedding lookup with optional ``sqrt(d_model)`` output scaling.

    Args:
        d_model: width of each embedding vector.
        vocab: number of rows in the embedding table.
        padding_idx: forwarded unchanged to ``nn.Embedding``.
        scale_embedding: when truthy, ``forward`` multiplies the looked-up
            vectors by ``sqrt(d_model)``.
    """

    def __init__(self, d_model, vocab, padding_idx, scale_embedding):
        super(Embeddings, self).__init__()
        self.d_model = d_model
        self.scale_embedding = scale_embedding
        self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)

        # Replace the default weights with float32 samples drawn from
        # N(0, d_model ** -0.5).
        init_std = d_model**-0.5
        init_weight = np.random.normal(0.0, init_std, (vocab, d_model))
        self.embedding.weight.set_value(init_weight.astype(np.float32))

    def forward(self, x):
        """Look up ``x`` in the table, scaling the result if configured."""
        embedded = self.embedding(x)
        if not self.scale_embedding:
            return embedded
        return embedded * math.sqrt(self.d_model)
|
|
|
|
|
|
|
|
|
|
|
|
class Beam():
    """Beam search state for a single sequence being decoded.

    Tracks, per time-step, the cumulative score of each beam entry, the
    back-pointer to the entry it extends, and the token it emitted.
    """

    # Token written into the first slot of the initial step and prepended to
    # every reconstructed hypothesis (presumably the start-of-sequence id).
    _START_IDX = 2
    # Token that finishes decoding once it reaches the top of the beam (EOS).
    _END_IDX = 3

    def __init__(self, size, device=False):
        """Create an empty beam.

        Args:
            size: number of entries kept on the beam.
            device: unused; kept so existing callers keep working.
        """
        self.size = size
        self._done = False
        # The score for each translation on the beam.
        self.scores = paddle.zeros((size, ), dtype=paddle.float32)
        self.all_scores = []
        # The backpointers at each time-step.
        self.prev_ks = []
        # The outputs at each time-step.
        self.next_ys = [paddle.full((size, ), 0, dtype=paddle.int64)]
        self.next_ys[0][0] = self._START_IDX

    def get_current_state(self):
        "Get the outputs for the current timestep."
        return self.get_tentative_hypothesis()

    def get_current_origin(self):
        "Get the backpointers for the current timestep."
        return self.prev_ks[-1]

    @property
    def done(self):
        """Whether the end token has reached the top of the beam."""
        return self._done

    def advance(self, word_prob):
        """Update beam status and check if finished or not.

        Args:
            word_prob: per-entry scores over the vocabulary; the second axis
                is taken as the vocabulary size.

        Returns:
            True once the top-of-beam token equals the end token.
        """
        num_words = word_prob.shape[1]

        # Sum the previous scores. On the first step only row 0 is used.
        if self.prev_ks:
            beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
        else:
            beam_lk = word_prob[0]

        flat_beam_lk = beam_lk.reshape([-1])
        # Largest ``size`` candidates, returned in descending order.
        best_scores, best_scores_id = flat_beam_lk.topk(
            self.size, axis=0, largest=True, sorted=True)
        self.all_scores.append(self.scores)
        self.scores = best_scores

        # best_scores_id is flattened as a (beam x word) array,
        # so we need to calculate which word and beam each score came from
        prev_k = best_scores_id // num_words
        self.prev_ks.append(prev_k)
        self.next_ys.append(best_scores_id - prev_k * num_words)

        # End condition is when top-of-beam is EOS.
        if self.next_ys[-1][0] == self._END_IDX:
            self._done = True
            self.all_scores.append(self.scores)

        return self._done

    def sort_scores(self):
        """Return the current scores with their beam positions.

        No sorting happens here: ``advance`` stores the ``topk`` output,
        which is already in descending order, so the positions are simply
        ``0 .. len(scores) - 1`` as an int32 tensor.
        """
        num_scores = int(self.scores.shape[0])
        return self.scores, paddle.arange(num_scores, dtype='int32')

    def get_the_best_score_and_idx(self):
        "Get the score of the best in the beam."
        scores, ids = self.sort_scores()
        # Index 0 is the top entry; the earlier index 1 returned the
        # second-best entry and was out of range for a beam of size 1.
        return scores[0], ids[0]

    def get_tentative_hypothesis(self):
        "Get the decoded sequence for the current timestep."
        if len(self.next_ys) == 1:
            # Nothing decoded yet: one column holding the initial tokens.
            return self.next_ys[0].unsqueeze(1)

        _, keys = self.sort_scores()
        hyps = [[self._START_IDX] + self.get_hypothesis(k) for k in keys]
        return paddle.to_tensor(hyps, dtype='int64')

    def get_hypothesis(self, k):
        """Walk back to construct the full hypothesis.

        Args:
            k: position on the beam at the latest time-step.

        Returns:
            The emitted tokens as a list of Python scalars, oldest first.
        """
        hyp = []
        for j in range(len(self.prev_ks) - 1, -1, -1):
            hyp.append(self.next_ys[j + 1][k])
            # Follow the back-pointer to the entry this one extended.
            k = self.prev_ks[j][k]
        return [token.item() for token in reversed(hyp)]
|