Browse Source

chore: use pytorch sdpa backend to do naive attention for rocm

AlpinDale 7 months ago
parent
commit
71a26f0998
1 changed files with 30 additions and 34 deletions
  1. 30 34
      aphrodite/attention/backends/rocm_flash_attn.py

+ 30 - 34
aphrodite/attention/backends/rocm_flash_attn.py

@@ -203,7 +203,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         assert blocksparse_params is None, ValueError(
-            "ROCm FlashAttention does not support block-sparse attention.")
+            "ROCmFlashAttention does not support blocksparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -248,7 +248,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     self.use_naive_attn = True
 
             if self.use_naive_attn:
-                self.attn_func = _naive_attention
+                self.attn_func = _sdpa_attention
                 logger.debug("Using naive attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -343,11 +343,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         # Interleave for MQA workaround.
                         key = self.repeat_kv(key, self.num_queries_per_kv)
                         value = self.repeat_kv(value, self.num_queries_per_kv)
+                    query = query.movedim(0, query.dim() - 2)
+                    key = key.movedim(0, key.dim() - 2)
+                    value = value.movedim(0, value.dim() - 2)
+                    # sdpa math backend attention
                     out = self.attn_func(
                         query,
                         key,
                         value,
                         prefill_meta.seq_lens,
+                        num_tokens,
+                        self.num_heads,
+                        self.head_size,
                         self.scale,
                     )
                 else:
@@ -403,45 +410,34 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         return output.view(num_tokens, hidden_size)
 
 
-def _naive_attention(
+def _sdpa_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     seq_lens: List[int],
+    num_tokens: int,
+    num_heads: int,
+    head_size: int,
     scale: float,
 ) -> torch.Tensor:
-    output = torch.empty_like(query)
     start = 0
-    for _, seq_len in enumerate(seq_lens):
+    output = torch.empty((num_tokens, num_heads, head_size),
+                         dtype=query.dtype,
+                         device=query.device)
+
+    for seq_len in seq_lens:
         end = start + seq_len
-        out = _naive_masked_attention(
-            query[start:end],
-            key[start:end],
-            value[start:end],
-            scale,
-        )
-        # TODO: Unnecessary copy. Optimize.
-        output[start:end].copy_(out)
-        start += seq_len
+        with torch.backends.cuda.sdp_kernel(enable_math=True,
+                                            enable_flash=False,
+                                            enable_mem_efficient=False):
+            sub_out = torch.nn.functional.scaled_dot_product_attention(
+                query[:, start:end, :],
+                key[:, start:end, :],
+                value[:, start:end, :],
+                dropout_p=0.0,
+                is_causal=True,
+                scale=scale).movedim(query.dim() - 2, 0)
+            output[start:end, :, :] = sub_out
+            start = end
 
     return output
-
-
-def _naive_masked_attention(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    scale: float,
-) -> torch.Tensor:
-    seq_len, head_size, head_dim = query.shape
-    attn_mask = torch.triu(torch.ones(seq_len,
-                                      seq_len,
-                                      dtype=query.dtype,
-                                      device=query.device),
-                           diagonal=1)
-    attn_mask = attn_mask * torch.finfo(query.dtype).min
-    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
-    attn_weights = attn_weights + attn_mask.float()
-    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
-    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
-    return out