ParakeetRebeccaRosario/parakeet/modules/attention.py

293 lines
10 KiB
Python
Raw Normal View History

2020-12-09 17:08:17 +08:00
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2020-10-10 15:51:54 +08:00
import math
import numpy as np
import paddle
from paddle import nn
from paddle.nn import functional as F
2020-12-09 17:08:17 +08:00
def scaled_dot_product_attention(q,
                                 k,
                                 v,
                                 mask=None,
                                 dropout=0.0,
                                 training=True):
    """Masked, scaled dot-product attention.

    ``q``, ``k`` and ``v`` are expected to share their leading dimensions
    (written ``*`` below). Dropout acts on the attention weights before they
    are used to average the values.

    Parameters
    -----------
    q: Tensor [shape=(*, T_q, d)]
        the query tensor.

    k: Tensor [shape=(*, T_k, d)]
        the key tensor.

    v: Tensor [shape=(*, T_k, d_v)]
        the value tensor.

    mask: Tensor, [shape=(*, T_q, T_k) or broadcastable shape], optional
        the mask tensor, zeros correspond to paddings. Defaults to None.

    dropout: float, optional
        dropout probability applied to the attention weights.
        Defaults to 0.0.

    training: bool, optional
        forwarded to ``F.dropout`` as its ``training`` flag.
        Defaults to True.

    Returns
    ----------
    out: Tensor [shape(*, T_q, d_v)]
        the context vector.

    attn_weights [Tensor shape(*, T_q, T_k)]
        the attention weights, as used for the weighted sum (i.e. after
        dropout).
    """
    # The feature size is read from the static shape: only imperative
    # execution is supported.
    feature_size = q.shape[-1]
    inv_sqrt_d = 1.0 / math.sqrt(feature_size)
    logits = paddle.scale(paddle.matmul(q, k, transpose_y=True), inv_sqrt_d)

    if mask is not None:
        # Positions where mask == 0 get a large negative bias (hard coded
        # -1e9) so that softmax assigns them ~0 weight.
        logits += paddle.scale((1.0 - mask), -1e9)

    attn_weights = F.dropout(
        F.softmax(logits, axis=-1), dropout, training=training)
    return paddle.matmul(attn_weights, v), attn_weights
2020-12-09 17:08:17 +08:00
2020-10-10 15:51:54 +08:00
def drop_head(x, drop_n_heads, training):
    """Randomly zero out whole attention heads and rescale the survivors.

    For every example in the batch, ``drop_n_heads`` heads are picked at
    random and set to zero. The remaining heads are multiplied by
    ``num_heads / (num_heads - drop_n_heads)``.

    Args:
        x (Tensor): shape(batch_size, num_heads, time_steps, channels), the input.
        drop_n_heads (int): number of heads to drop for each example.
        training (bool): when False, ``x`` is returned unchanged.

    Returns:
        Tensor: same shape as ``x``; the input with heads dropped.

    Raises:
        ValueError: if ``drop_n_heads`` is negative or larger than the
            number of heads in ``x``.
    """
    if not training or (drop_n_heads == 0):
        return x

    batch_size, num_heads, _, _ = x.shape
    if not 0 < drop_n_heads <= num_heads:
        # Out-of-range values would otherwise yield a negative (or
        # meaningless) rescaling factor below.
        raise ValueError(
            "drop_n_heads must be in [0, num_heads], got {} with {} heads".
            format(drop_n_heads, num_heads))

    # drop all heads
    if num_heads == drop_n_heads:
        return paddle.zeros_like(x)

    # Start from a mask with the first `drop_n_heads` entries zeroed, then
    # shuffle each row independently so every example drops its own heads.
    mask = np.ones([batch_size, num_heads])
    mask[:, :drop_n_heads] = 0
    for row in mask:
        np.random.shuffle(row)

    scale = float(num_heads) / (num_heads - drop_n_heads)
    mask = scale * np.reshape(mask, [batch_size, num_heads, 1, 1])
    # np.ones yields float64; build the mask tensor in x's dtype so the
    # product does not mix precisions.
    return x * paddle.to_tensor(mask, dtype=x.dtype)
2020-12-09 17:08:17 +08:00
2020-10-10 15:51:54 +08:00
def _split_heads(x, num_heads):
    """Split the last axis of a 3-D tensor into ``num_heads`` heads.

    The input is reshaped from (batch, time, features) to
    (batch, time, num_heads, features // num_heads) and then transposed to
    (batch, num_heads, time, features // num_heads).
    """
    n_batch, n_steps, _ = x.shape
    per_head = paddle.reshape(x, [n_batch, n_steps, num_heads, -1])
    return paddle.transpose(per_head, [0, 2, 1, 3])
2020-12-09 17:08:17 +08:00
2020-10-10 15:51:54 +08:00
def _concat_heads(x):
    """Merge the heads axis of a 4-D tensor back into the feature axis.

    The input is transposed from (batch, heads, time, channels) to
    (batch, time, heads, channels) and then flattened to
    (batch, time, heads * channels).
    """
    n_batch, _, n_steps, _ = x.shape
    time_major = paddle.transpose(x, [0, 2, 1, 3])
    return paddle.reshape(time_major, [n_batch, n_steps, -1])
2020-12-09 17:08:17 +08:00
2020-10-10 15:51:54 +08:00
# Standard implementations of Monohead Attention & Multihead Attention
class MonoheadAttention(nn.Layer):
    """Single-head scaled dot product attention with linear projections."""

    def __init__(self, model_dim, dropout=0.0, k_dim=None, v_dim=None):
        """
        Monohead Attention module.
        Args:
            model_dim (int): the feature size of query.
            dropout (float, optional): dropout probability applied to the
                attention weights of the scaled dot product attention.
                Defaults to 0.0.
            k_dim (int, optional): feature size of the projected query and
                key. If not provided, it is set to model_dim.
                Defaults to None.
            v_dim (int, optional): feature size of the projected value.
                If not provided, it is set to model_dim. Defaults to None.
        """
        super(MonoheadAttention, self).__init__()
        k_dim = k_dim or model_dim
        v_dim = v_dim or model_dim
        self.affine_q = nn.Linear(model_dim, k_dim)
        self.affine_k = nn.Linear(model_dim, k_dim)
        self.affine_v = nn.Linear(model_dim, v_dim)
        self.affine_o = nn.Linear(v_dim, model_dim)

        self.model_dim = model_dim
        self.dropout = dropout

    def forward(self, q, k, v, mask=None):
        """
        Compute context vector and attention weights.
        Args:
            q (Tensor): shape(batch_size, time_steps_q, model_dim), the queries.
            k (Tensor): shape(batch_size, time_steps_k, model_dim), the keys.
            v (Tensor): shape(batch_size, time_steps_k, model_dim), the values.
            mask (Tensor, optional): shape(batch_size, times_steps_q, time_steps_k) or
                broadcastable shape, dtype: float32 or float64, the mask.
                Zeros mark positions to be ignored. Defaults to None, in
                which case no masking is applied.
        Returns:
            (out, attention_weights)
            out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
            attention_weights (Tensor): shape(batch_size, times_steps_q, time_steps_k), the attention weights.
        """
        q = self.affine_q(q)  # (B, T, C)
        k = self.affine_k(k)
        v = self.affine_v(v)

        # A None mask is handled by scaled_dot_product_attention itself.
        context_vectors, attention_weights = scaled_dot_product_attention(
            q, k, v, mask, self.dropout, self.training)

        out = self.affine_o(context_vectors)
        return out, attention_weights
2020-12-09 17:08:17 +08:00
2020-10-10 15:51:54 +08:00
class MultiheadAttention(nn.Layer):
    """
    Multihead scaled dot product attention.
    """

    def __init__(self,
                 model_dim,
                 num_heads,
                 dropout=0.0,
                 k_dim=None,
                 v_dim=None):
        """
        Multihead Attention module.
        Args:
            model_dim (int): the feature size of query.
            num_heads (int): the number of attention heads.
            dropout (float, optional): dropout probability applied to the
                attention weights of the scaled dot product attention.
                Defaults to 0.0.
            k_dim (int, optional): feature size of the key of each scaled dot
                product attention. If not provided, it is set to
                model_dim / num_heads. Defaults to None.
            v_dim (int, optional): feature size of the value of each scaled
                dot product attention. If not provided, it is set to
                model_dim / num_heads. Defaults to None.
        Raises:
            ValueError: if model_dim is not divisible by num_heads
        """
        super(MultiheadAttention, self).__init__()
        if model_dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        depth = model_dim // num_heads
        k_dim = k_dim or depth
        v_dim = v_dim or depth
        self.affine_q = nn.Linear(model_dim, num_heads * k_dim)
        self.affine_k = nn.Linear(model_dim, num_heads * k_dim)
        self.affine_v = nn.Linear(model_dim, num_heads * v_dim)
        self.affine_o = nn.Linear(num_heads * v_dim, model_dim)

        self.num_heads = num_heads
        self.model_dim = model_dim
        self.dropout = dropout

    def forward(self, q, k, v, mask=None):
        """
        Compute context vector and attention weights.
        Args:
            q (Tensor): shape(batch_size, time_steps_q, model_dim), the queries.
            k (Tensor): shape(batch_size, time_steps_k, model_dim), the keys.
            v (Tensor): shape(batch_size, time_steps_k, model_dim), the values.
            mask (Tensor, optional): shape(batch_size, times_steps_q, time_steps_k) or
                broadcastable shape, dtype: float32 or float64, the mask.
                Zeros mark positions to be ignored. Defaults to None, in
                which case no masking is applied.
        Returns:
            (out, attention_weights)
            out (Tensor), shape(batch_size, time_steps_q, model_dim), the context vector.
            attention_weights (Tensor): shape(batch_size, num_heads, times_steps_q, time_steps_k),
                the attention weights of every head.
        """
        q = _split_heads(self.affine_q(q), self.num_heads)  # (B, h, T, C)
        k = _split_heads(self.affine_k(k), self.num_heads)
        v = _split_heads(self.affine_v(v), self.num_heads)

        if mask is not None:
            # unsqueeze for the h dim so the same mask is shared by all heads
            mask = paddle.unsqueeze(mask, 1)

        context_vectors, attention_weights = scaled_dot_product_attention(
            q, k, v, mask, self.dropout, self.training)
        # NOTE: there is more sophisticated implementation: Scheduled DropHead
        context_vectors = _concat_heads(context_vectors)  # (B, T, h*C)
        out = self.affine_o(context_vectors)
        return out, attention_weights
2020-12-09 17:08:17 +08:00
class LocationSensitiveAttention(nn.Layer):
    """Additive attention whose scores also depend on earlier alignments.

    The score for each key position is computed from the sum of three terms:
    a projection of the query, an already projected key supplied by the
    caller, and features extracted from ``attention_weights_cat`` by a 1-D
    convolution followed by a linear layer.
    """

    def __init__(self,
                 d_query: int,
                 d_key: int,
                 d_attention: int,
                 location_filters: int,
                 location_kernel_size: int):
        """
        Args:
            d_query (int): feature size of the query.
            d_key (int): input feature size of ``key_layer``.
            d_attention (int): feature size of the shared attention space.
            location_filters (int): number of output channels of the
                location convolution.
            location_kernel_size (int): kernel size of the location
                convolution.
        """
        super().__init__()
        self.query_layer = nn.Linear(d_query, d_attention, bias_attr=False)
        # NOTE: key_layer is not called in forward(); presumably the caller
        # uses it to precompute `processed_key` -- confirm against callers.
        self.key_layer = nn.Linear(d_key, d_attention, bias_attr=False)
        # Maps each position's attention-space feature to a scalar score.
        self.value = nn.Linear(d_attention, 1, bias_attr=False)

        # Location layer: the convolution reads 2 input channels in
        # channel-last (NLC) layout.
        self.location_conv = nn.Conv1D(
            2,
            location_filters,
            location_kernel_size,
            stride=1,
            padding=int((location_kernel_size - 1) / 2),
            dilation=1,
            bias_attr=False,
            data_format='NLC')
        self.location_layer = nn.Linear(
            location_filters, d_attention, bias_attr=False)

    def forward(self,
                query,
                processed_key,
                value,
                attention_weights_cat,
                mask=None):
        """Compute the attention context and the attention weights.

        Args:
            query (Tensor): the query; a time axis is inserted at axis 1
                before projection (presumably shape (B, d_query) -- confirm).
            processed_key (Tensor): keys already projected to the attention
                space (presumably shape (B, T, d_attention) -- confirm).
            value (Tensor): values to be averaged (presumably shape
                (B, T, d_value) -- confirm).
            attention_weights_cat (Tensor): input of the location
                convolution, channel-last with 2 channels (presumably the
                previous and cumulative weights -- confirm).
            mask (Tensor, optional): positions where it is 0 receive a bias
                of -1e9 before the softmax. Defaults to None.

        Returns:
            (attention_context, attention_weights)
            attention_context (Tensor): weighted sum of ``value`` with
                axis 1 squeezed.
            attention_weights (Tensor): softmax over axis 1 with the last
                axis squeezed.
        """
        query_term = self.query_layer(paddle.unsqueeze(query, axis=[1]))
        location_features = self.location_conv(attention_weights_cat)
        location_term = self.location_layer(location_features)

        energies = self.value(
            paddle.tanh(location_term + processed_key + query_term))
        if mask is not None:
            energies = energies + (1.0 - mask) * -1e9

        # Normalise over axis 1 (the key positions).
        weights = F.softmax(energies, axis=1)
        context = paddle.matmul(weights, value, transpose_x=True)

        attention_weights = paddle.squeeze(weights, axis=[-1])
        attention_context = paddle.squeeze(context, axis=[1])
        return attention_context, attention_weights