# parakeet/modules/multihead_attention.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.dygraph as dg
import paddle.fluid.layers as layers


class Linear(dg.Layer):
def __init__(self,
in_features,
out_features,
is_bias=True,
dtype="float32"):
super(Linear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.dtype = dtype
self.weight = fluid.ParamAttr(
initializer=fluid.initializer.XavierInitializer())
self.bias = is_bias
if is_bias is not False:
k = math.sqrt(1 / in_features)
self.bias = fluid.ParamAttr(initializer=fluid.initializer.Uniform(
low=-k, high=k))
self.linear = dg.Linear(
in_features,
out_features,
param_attr=self.weight,
bias_attr=self.bias, )
def forward(self, x):
x = self.linear(x)
return x
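
# Usage sketch (illustrative, not part of the original module): Linear wraps
# dg.Linear with Xavier-initialized weights and, unless is_bias=False, a bias
# initialized uniformly in [-k, k] where k = sqrt(1 / in_features).
#   proj = Linear(256, 512)   # maps (B, T, 256) -> (B, T, 512)
#   y = proj(x)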


class ScaledDotProductAttention(dg.Layer):
def __init__(self, d_key):
super(ScaledDotProductAttention, self).__init__()
self.d_key = d_key

    # Note: the mask convention here differs from PyTorch; the mask is added
    # directly to the attention scores rather than applied as a boolean mask.
def forward(self,
key,
value,
query,
mask=None,
query_mask=None,
dropout=0.1):
"""
Scaled Dot Product Attention.
Args:
key (Variable): The input key of scaled dot product attention.
Shape: (B, T, C), dtype: float32.
value (Variable): The input value of scaled dot product attention.
Shape: (B, T, C), dtype: float32.
query (Variable): The input query of scaled dot product attention.
Shape: (B, T, C), dtype: float32.
            mask (Variable, optional): The mask of key. Defaults to None.
                Shape: (B, T_q, T_k), dtype: float32.
            query_mask (Variable, optional): The mask of query. Defaults to None.
                Shape: (B, T_q, T_q), dtype: float32.
dropout (float32, optional): The probability of dropout. Defaults to 0.1.
        Returns:
            result (Variable): The result of scaled dot product attention. Shape: (B, T_q, C).
            attention (Variable): The attention weights. Shape: (B, T_q, T_k).
        """

        # Compute attention scores, scaled by 1 / sqrt(d_key);
        # transpose_y transposes the last two dims of key.
        attention = layers.matmul(
            query, key, transpose_y=True, alpha=self.d_key**-0.5)
# Mask key to ignore padding
if mask is not None:
attention = attention + mask
attention = layers.softmax(attention)
attention = layers.dropout(
attention, dropout, dropout_implementation='upscale_in_train')
# Mask query to ignore padding
if query_mask is not None:
attention = attention * query_mask
result = layers.matmul(attention, value)
return result, attention
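
# Usage sketch (illustrative, not part of the original module); assumes a
# dygraph guard is active and q, k, v are float32 Variables of shape (B, T, C):
#   attn_layer = ScaledDotProductAttention(d_key=64)
#   out, weights = attn_layer(k, v, q)
#   # out: (B, T_q, C), weights: (B, T_q, T_k)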


class MultiheadAttention(dg.Layer):
def __init__(self,
num_hidden,
d_k,
d_q,
num_head=4,
is_bias=False,
dropout=0.1,
is_concat=True):
super(MultiheadAttention, self).__init__()
self.num_hidden = num_hidden
self.num_head = num_head
self.d_k = d_k
self.d_q = d_q
self.dropout = dropout
self.is_concat = is_concat
self.key = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
self.value = Linear(num_hidden, num_head * d_k, is_bias=is_bias)
self.query = Linear(num_hidden, num_head * d_q, is_bias=is_bias)
self.scal_attn = ScaledDotProductAttention(d_k)
if self.is_concat:
self.fc = Linear(num_head * d_q * 2, num_hidden)
else:
self.fc = Linear(num_head * d_q, num_hidden)
        self.layer_norm = dg.LayerNorm(num_hidden)

    def forward(self, key, value, query_input, mask=None, query_mask=None):
"""
Multihead Attention.
Args:
key (Variable): The input key of attention.
Shape: (B, T, C), dtype: float32.
value (Variable): The input value of attention.
Shape: (B, T, C), dtype: float32.
query_input (Variable): The input query of attention.
Shape: (B, T, C), dtype: float32.
mask (Variable, optional): The mask of key. Defaults to None.
Shape: (B, T_query, T_key), dtype: float32.
query_mask (Variable, optional): The mask of query. Defaults to None.
Shape: (B, T_query, T_key), dtype: float32.
        Returns:
            result (Variable): The result of multihead attention. Shape: (B, T, C).
            attention (Variable): The attention weights of key and query.
                Shape: (num_head * B, T_query, T_key).
        """
batch_size = key.shape[0]
seq_len_key = key.shape[1]
seq_len_query = query_input.shape[1]

        # Split into multiple heads: project, then reshape/transpose so that
        # the heads form the batch dimension: (num_head * B, T, d).
key = layers.reshape(
self.key(key), [batch_size, seq_len_key, self.num_head, self.d_k])
value = layers.reshape(
self.value(value),
[batch_size, seq_len_key, self.num_head, self.d_k])
query = layers.reshape(
self.query(query_input),
[batch_size, seq_len_query, self.num_head, self.d_q])
key = layers.reshape(
layers.transpose(key, [2, 0, 1, 3]), [-1, seq_len_key, self.d_k])
value = layers.reshape(
layers.transpose(value, [2, 0, 1, 3]),
[-1, seq_len_key, self.d_k])
query = layers.reshape(
layers.transpose(query, [2, 0, 1, 3]),
[-1, seq_len_query, self.d_q])
result, attention = self.scal_attn(
key, value, query, mask=mask, query_mask=query_mask)

        # Concatenate the heads back together:
        # (num_head * B, T_q, d_q) -> (B, T_q, num_head * d_q).
result = layers.reshape(
result, [self.num_head, batch_size, seq_len_query, self.d_q])
result = layers.reshape(
layers.transpose(result, [1, 2, 0, 3]),
[batch_size, seq_len_query, -1])
if self.is_concat:
result = layers.concat([query_input, result], axis=-1)
result = layers.dropout(
self.fc(result),
self.dropout,
dropout_implementation='upscale_in_train')

        # Residual connection followed by layer normalization.
        result = result + query_input
        result = self.layer_norm(result)
return result, attention
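

# A minimal smoke-test sketch, not part of the original module: it assumes a
# Paddle 1.x dygraph environment and only checks output shapes on random data.
if __name__ == "__main__":
    with dg.guard():
        batch_size, seq_len, num_hidden = 2, 16, 256
        net = MultiheadAttention(num_hidden, d_k=64, d_q=64, num_head=4)
        x = dg.to_variable(
            np.random.randn(batch_size, seq_len, num_hidden).astype("float32"))
        # Self-attention: key, value and query all come from the same input.
        out, attn = net(x, x, x)
        print(out.shape)   # expected: [2, 16, 256]
        print(attn.shape)  # expected: [8, 16, 16], i.e. (num_head * B, T_q, T_k)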