2020-01-03 16:25:17 +08:00
|
|
|
import math
|
|
|
|
import numpy as np
|
2020-02-11 16:57:30 +08:00
|
|
|
import paddle.fluid as fluid
|
2020-01-03 16:25:17 +08:00
|
|
|
import paddle.fluid.dygraph as dg
|
|
|
|
import paddle.fluid.layers as layers
|
2020-02-11 16:57:30 +08:00
|
|
|
|
|
|
|
class Linear(dg.Layer):
    """Fully connected layer wrapping ``dg.Linear`` with fixed initializers.

    The weight is initialized with ``fluid.initializer.XavierInitializer``.
    When a bias is requested it is initialized from ``Uniform(-k, k)`` with
    ``k = sqrt(1 / in_features)``.

    Args:
        in_features (int): input feature size.
        out_features (int): output feature size.
        is_bias (bool): when exactly ``False`` the layer is built without a
            bias; any other value creates a uniformly initialized bias.
        dtype (str): parameter data type, forwarded to ``dg.Linear``.
    """

    def __init__(self, in_features, out_features, is_bias=True, dtype="float32"):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dtype = dtype
        self.weight = fluid.ParamAttr(
            initializer=fluid.initializer.XavierInitializer())
        # ``False`` is handed to ``dg.Linear`` as ``bias_attr`` unchanged,
        # which disables the bias; otherwise it is replaced below.
        self.bias = is_bias

        if is_bias is not False:
            # Use float division: under Python 2, ``1 / in_features`` is integer
            # division and would give k == 0, i.e. an all-zero bias range.
            k = math.sqrt(1.0 / in_features)
            self.bias = fluid.ParamAttr(
                initializer=fluid.initializer.Uniform(low=-k, high=k))

        # Forward ``dtype`` so the stored value actually controls the
        # parameter type instead of being silently ignored.
        self.linear = dg.Linear(
            in_features,
            out_features,
            param_attr=self.weight,
            bias_attr=self.bias,
            dtype=self.dtype)

    def forward(self, x):
        """Apply the wrapped ``dg.Linear`` to ``x`` and return its output."""
        return self.linear(x)
|
2020-01-03 16:25:17 +08:00
|
|
|
|
|
|
|
class ScaledDotProductAttention(dg.Layer):
    """Attention computed as ``softmax(query @ key^T / sqrt(d_key)) @ value``.

    Args:
        d_key (int): key feature size; the raw scores are divided by its
            square root.
    """

    def __init__(self, d_key):
        super(ScaledDotProductAttention, self).__init__()
        self.d_key = d_key

    # Mask convention differs from PyTorch: score positions where
    # ``mask == 0`` are suppressed, and the raw scores are multiplied by
    # ``mask``, so kept positions are expected to hold 1.
    def forward(self, key, value, query, mask=None, query_mask=None, dropout=0.1):
        """Compute scaled dot-product attention.

        Args:
            key (Variable): shape (B, len_k, C), float32. Attention keys.
            value (Variable): shape (B, len_k, C_v), float32. Attention values.
            query (Variable): shape (B, len_q, C), float32. Attention queries.
            mask (Variable, optional): shape (B, len_q, len_k), float32.
                Entries equal to 0 get a large negative bias (-2**32 + 1)
                added to their score before the softmax.
            query_mask (Variable, optional): float32 mask multiplied
                elementwise into the attention weights after softmax and
                dropout; must be shape-compatible with (B, len_q, len_k).
            dropout (float): dropout probability applied to the attention
                weights.

        Returns:
            result (Variable): shape (B, len_q, C_v), the attended values.
            attention (Variable): shape (B, len_q, len_k), the attention weights.
        """
        # Scaled scores of shape (B, len_q, len_k).
        scores = layers.matmul(query, key, transpose_y=True)
        scores = scores / math.sqrt(self.d_key)

        if mask is not None:
            # Large negative bias at masked positions so the softmax assigns
            # them (almost) no weight.
            pad_bias = (mask == 0).astype(np.float32) * (-2 ** 32 + 1)
            scores = scores * mask + pad_bias

        weights = layers.dropout(layers.softmax(scores), dropout)

        if query_mask is not None:
            # Suppress weights wherever query_mask is 0.
            weights = weights * query_mask

        return layers.matmul(weights, value), weights
|
|
|
|
|
|
|
|
class MultiheadAttention(dg.Layer):
    """Multi-head attention followed by a residual connection and layer norm.

    Args:
        num_hidden (int): feature size of the key/value/query inputs and of
            the output.
        d_k (int): per-head size of the projected keys and values.
        d_q (int): per-head size of the projected queries. Queries are
            multiplied with keys in the scaled dot product, so this must
            equal ``d_k``.
        num_head (int): number of attention heads.
        is_bias (bool): whether the key/value/query projections have a bias.
        dropout (float): dropout probability applied after the output
            projection.
        is_concat (bool): if True, ``query_input`` is concatenated with the
            merged heads before the output projection.
    """

    def __init__(self, num_hidden, d_k, d_q, num_head=4, is_bias=False, dropout=0.1, is_concat=True):
        super(MultiheadAttention, self).__init__()
        self.num_hidden = num_hidden
        self.num_head = num_head
        self.d_k = d_k
        self.d_q = d_q
        self.dropout = dropout
        self.is_concat = is_concat

        self.key = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
        self.value = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
        self.query = Linear(num_hidden, num_head * d_q, is_bias=is_bias)

        self.scal_attn = ScaledDotProductAttention(d_k)

        # The merged heads carry num_head * d_q features. With is_concat, the
        # query input (num_hidden features) is prepended, so the projection
        # input is num_hidden + num_head * d_q. The previous size,
        # num_head * d_q * 2, was only correct when num_hidden == num_head * d_q.
        fc_in_features = num_head * d_q
        if self.is_concat:
            fc_in_features += num_hidden
        self.fc = Linear(fc_in_features, num_hidden)

        self.layer_norm = dg.LayerNorm(num_hidden)

    def forward(self, key, value, query_input, mask=None, query_mask=None):
        """Multihead attention.

        Args:
            key (Variable): shape (B, len_k, num_hidden), float32. Input keys.
            value (Variable): shape (B, len_k, num_hidden), float32. Input values.
            query_input (Variable): shape (B, len_q, num_hidden), float32.
                Input queries; also used for the residual connection.
            mask (Variable, optional): shape (B, len_q, len_k), float32. Key
                mask; entries equal to 0 are ignored.
            query_mask (Variable, optional): shape (B, len_q, 1), float32.
                Query mask.

        Returns:
            result (Variable): shape (B, len_q, num_hidden), the output of
                multihead attention.
            attention (Variable): shape (num_head * B, len_q, len_k), the
                attention weights of every head.
        """
        batch_size = key.shape[0]
        seq_len_key = key.shape[1]
        seq_len_query = query_input.shape[1]

        # Tile the masks num_head times along dim 0 so they line up with the
        # head-major (num_head * B, ...) layout built below.
        if query_mask is not None:
            query_mask = layers.expand(query_mask, [self.num_head, 1, seq_len_key])
        if mask is not None:
            mask = layers.expand(mask, (self.num_head, 1, 1))

        # Project and split into heads: (B, T, num_head, d).
        key = layers.reshape(self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k])
        value = layers.reshape(self.value(value), [batch_size, seq_len_key, self.num_head, self.d_k])
        query = layers.reshape(self.query(query_input), [batch_size, seq_len_query, self.num_head, self.d_q])

        # Move heads in front of the batch and merge: (num_head * B, T, d).
        key = layers.reshape(layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
        value = layers.reshape(layers.transpose(value, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
        query = layers.reshape(layers.transpose(query, [2, 0, 1, 3]), [-1, seq_len_query, self.d_q])

        # NOTE(review): dropout on the attention weights uses
        # ScaledDotProductAttention's default rate, not self.dropout;
        # confirm this is intended.
        result, attention = self.scal_attn(key, value, query, mask=mask, query_mask=query_mask)

        # Undo the head split: (num_head, B, len_q, d) -> (B, len_q, num_head * d).
        result = layers.reshape(result, [self.num_head, batch_size, seq_len_query, self.d_q])
        result = layers.reshape(layers.transpose(result, [1, 2, 0, 3]), [batch_size, seq_len_query, -1])
        if self.is_concat:
            result = layers.concat([query_input, result], axis=-1)
        result = layers.dropout(self.fc(result), self.dropout)

        # Residual connection, then layer normalization.
        result = result + query_input
        result = self.layer_norm(result)
        return result, attention
|