Update Attention.py
parent 8e7c15d914
commit db19659742
@@ -11,7 +11,7 @@ class DotAttention(nn.Module):
         super(DotAttention, self).__init__()
         self.dropout = dropout

-    def forward(self, Q, K, V, mask_out=None,head_mask=None):
+    def forward(self, Q, K, V, mask_out=None, head_mask=None):
         """
         Generally, when the input information is X, it is assumed that K = V = X
@@ -70,7 +70,7 @@ class MultiHeadAttention(nn.Module):
         self.attention = DotAttention(dropout=dropout)
         self.out = nn.Linear(self.all_head_dim, embed_dim)

-    def forward(self, Q, K, V, key_padding_mask=None,attention_mask=None, head_mask=None):
+    def forward(self, Q, K, V, key_padding_mask=None, attention_mask=None, head_mask=None):
         """
         :param Q: [B, L, Hs]
         :param K: [B, S, Hs]
@@ -81,7 +81,7 @@ class MultiHeadAttention(nn.Module):
         """
         B, L, Hs = Q.shape
         S = V.size(1)
-        N,H = self.num_heads, self.head_dim
+        N, H = self.num_heads, self.head_dim

         q = self.q_in(Q).view(B, L, N, H).transpose(1, 2) # [B, N, L, H]
         k = self.k_in(K).view(B, S, N, H).transpose(1, 2) # [B, N, S, H]
@@ -96,7 +96,7 @@ class MultiHeadAttention(nn.Module):
             if attention_mask.dim() == 1:
                 attention_mask = attention_mask.unsqueeze(0)
             elif attention_mask.dim() == 2:
-                attention_mask = attention_mask.unsqueeze(0).unsqueeze(0).expand(B,-1,-1,-1)
+                attention_mask = attention_mask.unsqueeze(0).unsqueeze(0).expand(B, -1, -1, -1)
             else:
                 raise ValueError(f'attention_mask dim must be 1 or 2, can not be {attention_mask.dim()}')
@@ -109,7 +109,7 @@ class MultiHeadAttention(nn.Module):
             head_mask = head_mask.eq(0)
             head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

-        attention_out, attention_weight = self.attention(q, k, v, mask_out=mask_out,head_mask=head_mask)
+        attention_out, attention_weight = self.attention(q, k, v, mask_out=mask_out, head_mask=head_mask)

         attention_out = attention_out.transpose(1, 2).reshape(B, L, N * H) # [B, N, L, H] -> [B, L, N * H]
@@ -127,12 +127,12 @@ if __name__ == '__main__':

     q = torch.randn(4, 6, 20) # [B, L, H]
     k = v = torch.randn(4, 5, 20) # [B, S, H]
-    key_padding_mask = seq_len_to_mask([5,4,3,2], max_len=5)
-    attention_mask = torch.tensor([1,0,0,1,0]) # positions with value 1 are masked out
-    head_mask = torch.tensor([0,1]) # positions with value 1 are masked out
+    key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
+    attention_mask = torch.tensor([1, 0, 0, 1, 0]) # positions with value 1 are masked out
+    head_mask = torch.tensor([0, 1]) # positions with value 1 are masked out

-    m = MultiHeadAttention(embed_dim=20, num_heads=2, dropout=0.0,output_attentions=True)
-    ao, aw = m(q, k, v, key_padding_mask=key_padding_mask, attention_mask=attention_mask,head_mask=head_mask)
+    m = MultiHeadAttention(embed_dim=20, num_heads=2, dropout=0.0, output_attentions=True)
+    ao, aw = m(q, k, v, key_padding_mask=key_padding_mask, attention_mask=attention_mask, head_mask=head_mask)
     print(ao.shape, aw.shape) # [B, L, H] [B, N, L, S]
     print(ao)
     print(aw.unbind(1))
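For reference, here is a minimal, self-contained sketch (not the repository's code) of how the masks used in the example above are expected to broadcast onto attention scores of shape [B, N, L, S]. It assumes, per the comments in the example, that key_padding_mask marks padded key positions derived from the sequence lengths and that attention_mask and head_mask use 1 to mean "mask out"; the combination logic shown here is illustrative only.

import torch

# Shapes from the example above: B=4 batch, N=2 heads, L=6 query positions, S=5 key positions.
B, N, L, S = 4, 2, 6, 5
scores = torch.randn(B, N, L, S)                                  # raw attention scores

# key_padding_mask: True where a key position is padding (lengths 5, 4, 3, 2).
seq_len = torch.tensor([5, 4, 3, 2])
key_padding_mask = torch.arange(S)[None, :] >= seq_len[:, None]   # [B, S]

# attention_mask: 1 marks key positions to mask out for every batch element.
attention_mask = torch.tensor([1, 0, 0, 1, 0]).bool()             # [S]

# head_mask: 1 marks whole heads to mask out.
head_mask = torch.tensor([0, 1]).bool()                           # [N]

# Broadcast both key-side masks to [B, N, L, S] and blank those scores before softmax.
key_mask = key_padding_mask[:, None, None, :] | attention_mask[None, None, None, :]
weights = scores.masked_fill(key_mask, float('-inf')).softmax(dim=-1)

# Zero out the masked heads after the softmax.
weights = weights.masked_fill(head_mask[None, :, None, None], 0.0)

print(weights.shape)  # torch.Size([4, 2, 6, 5])

In the module itself, MultiHeadAttention.forward performs the broadcasting and passes the result to DotAttention as mask_out and head_mask; the sketch only demonstrates the intended shapes and the meaning of the mask values.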