diff --git a/module/Attention.py b/module/Attention.py
index 1493d09..851be48 100644
--- a/module/Attention.py
+++ b/module/Attention.py
@@ -11,7 +11,7 @@ class DotAttention(nn.Module):
         super(DotAttention, self).__init__()
         self.dropout = dropout
 
-    def forward(self, Q, K, V, mask_out=None,head_mask=None):
+    def forward(self, Q, K, V, mask_out=None, head_mask=None):
         """
         In the typical case where the input is X, assume K = V = X
 
@@ -70,7 +70,7 @@ class MultiHeadAttention(nn.Module):
         self.attention = DotAttention(dropout=dropout)
         self.out = nn.Linear(self.all_head_dim, embed_dim)
 
-    def forward(self, Q, K, V, key_padding_mask=None,attention_mask=None, head_mask=None):
+    def forward(self, Q, K, V, key_padding_mask=None, attention_mask=None, head_mask=None):
         """
         :param Q: [B, L, Hs]
         :param K: [B, S, Hs]
@@ -81,7 +81,7 @@ class MultiHeadAttention(nn.Module):
         """
         B, L, Hs = Q.shape
         S = V.size(1)
-        N,H = self.num_heads, self.head_dim
+        N, H = self.num_heads, self.head_dim
 
         q = self.q_in(Q).view(B, L, N, H).transpose(1, 2)  # [B, N, L, H]
         k = self.k_in(K).view(B, S, N, H).transpose(1, 2)  # [B, N, S, H]
@@ -96,7 +96,7 @@ class MultiHeadAttention(nn.Module):
             if attention_mask.dim() == 1:
                 attention_mask = attention_mask.unsqueeze(0)
             elif attention_mask.dim() == 2:
-                attention_mask = attention_mask.unsqueeze(0).unsqueeze(0).expand(B,-1,-1,-1)
+                attention_mask = attention_mask.unsqueeze(0).unsqueeze(0).expand(B, -1, -1, -1)
             else:
                 raise ValueError(f'attention_mask dim must be 1 or 2, can not be {attention_mask.dim()}')
 
@@ -109,7 +109,7 @@ class MultiHeadAttention(nn.Module):
             head_mask = head_mask.eq(0)
             head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
 
-        attention_out, attention_weight = self.attention(q, k, v, mask_out=mask_out,head_mask=head_mask)
+        attention_out, attention_weight = self.attention(q, k, v, mask_out=mask_out, head_mask=head_mask)
 
         attention_out = attention_out.transpose(1, 2).reshape(B, L, N * H)  # [B, N, L, H] -> [B, L, N * H]
 
@@ -127,12 +127,12 @@ if __name__ == '__main__':
     q = torch.randn(4, 6, 20)  # [B, L, H]
     k = v = torch.randn(4, 5, 20)  # [B, S, H]
 
-    key_padding_mask = seq_len_to_mask([5,4,3,2], max_len=5)
-    attention_mask = torch.tensor([1,0,0,1,0])  # positions with value 1 are masked out
-    head_mask = torch.tensor([0,1])  # positions with value 1 are masked out
+    key_padding_mask = seq_len_to_mask([5, 4, 3, 2], max_len=5)
+    attention_mask = torch.tensor([1, 0, 0, 1, 0])  # positions with value 1 are masked out
+    head_mask = torch.tensor([0, 1])  # positions with value 1 are masked out
 
-    m = MultiHeadAttention(embed_dim=20, num_heads=2, dropout=0.0,output_attentions=True)
-    ao, aw = m(q, k, v, key_padding_mask=key_padding_mask, attention_mask=attention_mask,head_mask=head_mask)
+    m = MultiHeadAttention(embed_dim=20, num_heads=2, dropout=0.0, output_attentions=True)
+    ao, aw = m(q, k, v, key_padding_mask=key_padding_mask, attention_mask=attention_mask, head_mask=head_mask)
     print(ao.shape, aw.shape)  # [B, L, H] [B, N, L, S]
     print(ao)
     print(aw.unbind(1))
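
The `__main__` block touched above passes three kinds of masks, all with the convention that a value of 1 means "masked out". As a minimal sketch of that convention (not this repo's actual `DotAttention` code; `masked_dot_attention` and its shapes are illustrative assumptions), masked score positions are typically filled with -inf before the softmax so they get zero attention weight:

```python
import torch
import torch.nn.functional as F

def masked_dot_attention(q, k, v, mask_out=None):
    # q: [B, N, L, H], k/v: [B, N, S, H]; mask_out broadcasts to [B, N, L, S]
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)    # [B, N, L, S]
    if mask_out is not None:
        scores = scores.masked_fill(mask_out, float('-inf'))  # 1/True -> masked
    weights = F.softmax(scores, dim=-1)                       # zero weight on masked keys
    return weights @ v, weights                               # [B, N, L, H], [B, N, L, S]

B, N, L, S, H = 4, 2, 6, 5, 10
q = torch.randn(B, N, L, H)
k = v = torch.randn(B, N, S, H)
# Key-padding-style mask: hide the last two key positions for every query.
pad = torch.tensor([0, 0, 0, 1, 1], dtype=torch.bool).view(1, 1, 1, S)
out, w = masked_dot_attention(q, k, v, mask_out=pad)
print(out.shape, w.shape)  # torch.Size([4, 2, 6, 10]) torch.Size([4, 2, 6, 5])
```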