Update Attention.py

This commit is contained in:
leo 2019-12-03 22:43:09 +08:00
parent 8e7c15d914
commit db19659742
1 changed file with 10 additions and 10 deletions

@@ -11,7 +11,7 @@ class DotAttention(nn.Module):
         super(DotAttention, self).__init__()
         self.dropout = dropout
-    def forward(self, Q, K, V, mask_out=None,head_mask=None):
+    def forward(self, Q, K, V, mask_out=None, head_mask=None):
         """
         In general, for input X, it is assumed that K = V = X
@@ -70,7 +70,7 @@ class MultiHeadAttention(nn.Module):
         self.attention = DotAttention(dropout=dropout)
         self.out = nn.Linear(self.all_head_dim, embed_dim)
-    def forward(self, Q, K, V, key_padding_mask=None,attention_mask=None, head_mask=None):
+    def forward(self, Q, K, V, key_padding_mask=None, attention_mask=None, head_mask=None):
         """
         :param Q: [B, L, Hs]
         :param K: [B, S, Hs]
@@ -81,7 +81,7 @@ class MultiHeadAttention(nn.Module):
         """
         B, L, Hs = Q.shape
         S = V.size(1)
-        N,H = self.num_heads, self.head_dim
+        N, H = self.num_heads, self.head_dim
         q = self.q_in(Q).view(B, L, N, H).transpose(1, 2) # [B, N, L, H]
         k = self.k_in(K).view(B, S, N, H).transpose(1, 2) # [B, N, S, H]
@@ -96,7 +96,7 @@ class MultiHeadAttention(nn.Module):
             if attention_mask.dim() == 1:
                 attention_mask = attention_mask.unsqueeze(0)
             elif attention_mask.dim() == 2:
-                attention_mask = attention_mask.unsqueeze(0).unsqueeze(0).expand(B,-1,-1,-1)
+                attention_mask = attention_mask.unsqueeze(0).unsqueeze(0).expand(B, -1, -1, -1)
             else:
                 raise ValueError(f'attention_mask dim must be 1 or 2, can not be {attention_mask.dim()}')
@@ -109,7 +109,7 @@ class MultiHeadAttention(nn.Module):
             head_mask = head_mask.eq(0)
             head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-        attention_out, attention_weight = self.attention(q, k, v, mask_out=mask_out,head_mask=head_mask)
+        attention_out, attention_weight = self.attention(q, k, v, mask_out=mask_out, head_mask=head_mask)
         attention_out = attention_out.transpose(1, 2).reshape(B, L, N * H) # [B, N, L, H] -> [B, L, N * H]
@@ -127,12 +127,12 @@ if __name__ == '__main__':
     q = torch.randn(4, 6, 20) # [B, L, H]
     k = v = torch.randn(4, 5, 20) # [B, S, H]
-    key_padding_mask = seq_len_to_mask([5,4,3,2], max_len=5)
-    attention_mask = torch.tensor([1,0,0,1,0]) # positions where the value is 1 are masked out
-    head_mask = torch.tensor([0,1]) # positions where the value is 1 are masked out
+    key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
+    attention_mask = torch.tensor([1, 0, 0, 1, 0]) # positions where the value is 1 are masked out
+    head_mask = torch.tensor([0, 1]) # positions where the value is 1 are masked out
-    m = MultiHeadAttention(embed_dim=20, num_heads=2, dropout=0.0,output_attentions=True)
-    ao, aw = m(q, k, v, key_padding_mask=key_padding_mask, attention_mask=attention_mask,head_mask=head_mask)
+    m = MultiHeadAttention(embed_dim=20, num_heads=2, dropout=0.0, output_attentions=True)
+    ao, aw = m(q, k, v, key_padding_mask=key_padding_mask, attention_mask=attention_mask, head_mask=head_mask)
     print(ao.shape, aw.shape) # [B, L, H] [B, N, L, S]
     print(ao)
     print(aw.unbind(1))
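
Note: the hunks above are whitespace-only (spacing after commas). For reference, below is a minimal standalone sketch of the pattern they touch: the [B, L, N*H] <-> [B, N, L, H] head split/merge and the "1 means masked out" convention from the test block. This is not the repository's DotAttention/MultiHeadAttention code; the sqrt(head_dim) scaling and the way head_mask is applied to the attention weights are assumptions.

# Illustrative sketch only (assumptions: standard scaled dot-product scoring,
# head_mask applied by zeroing the attention weights of masked heads).
import torch
import torch.nn.functional as F

B, N, L, S, H = 4, 2, 6, 5, 10                      # batch, heads, query len, key len, head dim

q = torch.randn(B, N, L, H)                         # shaped like q/k/v after the view/transpose in forward
k = torch.randn(B, N, S, H)
v = torch.randn(B, N, S, H)

attention_mask = torch.tensor([1, 0, 0, 1, 0])      # 1 = mask out, same convention as the test block
head_mask = torch.tensor([0, 1])                    # 1 = mask out this head

scores = q @ k.transpose(-1, -2) / H ** 0.5         # [B, N, L, S]
scores = scores.masked_fill(attention_mask.view(1, 1, 1, S).bool(), float('-inf'))
weight = F.softmax(scores, dim=-1)                  # [B, N, L, S]
weight = weight * head_mask.eq(0).view(1, N, 1, 1)  # keep only heads where head_mask == 0

out = (weight @ v).transpose(1, 2).reshape(B, L, N * H)  # [B, N, L, H] -> [B, L, N * H]
print(out.shape, weight.shape)                      # torch.Size([4, 6, 20]) torch.Size([4, 2, 6, 5])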