AlpinDale преди 7 месеца
родител
ревизия
93cffaf446
променени са 3 файла, в които са добавени 81 реда и са изтрити 55 реда
  1. 75 54
      aphrodite/attention/backends/flash_attn.py
  2. 5 0
      aphrodite/attention/selector.py
  3. 1 1
      requirements-cuda.txt

+ 75 - 54
aphrodite/attention/backends/flash_attn.py

@@ -1,20 +1,16 @@
-"""Attention layer with Flash and PagedAttention.
-
-NOTE: At the moment, this file includes a lot of duplicated code from
-XFormers backend. The duplicated code will be removed once we use flash-attn or
-flashinfer for all the attention operations.
-"""
+"""Attention layer with FlashAttention."""
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Type
 
 import torch
-from vllm_flash_attn import flash_attn_varlen_func
+from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
 
+from aphrodite._C import cache_ops
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
                                                    AttentionMetadata)
-from aphrodite.attention.ops.paged_attn import (PagedAttention,
-                                                PagedAttentionMetadata)
+
+_SUPPORTED_HEAD_SIZES = [32, 64, 96, 128, 160, 192, 224, 256]
 
 
 class FlashAttentionBackend(AttentionBackend):
@@ -38,8 +34,9 @@ class FlashAttentionBackend(AttentionBackend):
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
-        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
-                                                 num_kv_heads, head_size)
+        if block_size % 16 != 0:
+            raise ValueError("Block size must be a multiple of 16.")
+        return (2, num_blocks, block_size, num_kv_heads, head_size)
 
     @staticmethod
     def swap_blocks(
@@ -47,18 +44,26 @@ class FlashAttentionBackend(AttentionBackend):
         dst_kv_cache: torch.Tensor,
         src_to_dst: torch.Tensor,
     ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        src_key_cache = src_kv_cache[0]
+        dst_key_cache = dst_kv_cache[0]
+        cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
+
+        src_value_cache = src_kv_cache[1]
+        dst_value_cache = dst_kv_cache[1]
+        cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
 
     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
         src_to_dists: torch.Tensor,
     ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+        key_caches = [kv_cache[0] for kv_cache in kv_caches]
+        value_caches = [kv_cache[1] for kv_cache in kv_caches]
+        cache_ops.copy_blocks(key_caches, value_caches, src_to_dists)
 
 
 @dataclass
-class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
+class FlashAttentionMetadata(AttentionMetadata):
     """Metadata for FlashAttentionBackend.
 
     NOTE: Any python object stored here is not updated when it is
@@ -100,6 +105,14 @@ class FlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
     # so far).
     context_lens_tensor: Optional[torch.Tensor]
 
+    # (batch_size, max_blocks_per_seq).
+    # Block addresses per sequence. (Seq id -> list of physical block)
+    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
+    # in the kv cache. Each block can contain up to block_size tokens.
+    # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
+    # captured.
+    block_tables: Optional[torch.Tensor]
+
     # Whether or not if cuda graph is enabled.
     # Cuda-graph is currently enabled for decoding only.
     # TODO: Move `use_cuda_graph` out since it's unrelated to attention.
@@ -220,11 +233,15 @@ class FlashAttentionImpl(AttentionImpl):
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        suppored_head_sizes = PagedAttention.get_supported_head_sizes()
-        if head_size not in suppored_head_sizes:
+        if sliding_window is not None:
+            # NOTE: flash-attn's sliding window does not work with
+            # paged KV cache.
+            raise ValueError(
+                "Sliding window is not supported in FlashAttention.")
+        if head_size not in _SUPPORTED_HEAD_SIZES:
             raise ValueError(
-                f"Head size {head_size} is not supported by PagedAttention. "
-                f"Supported head sizes are: {suppored_head_sizes}.")
+                f"Head size {head_size} is not supported by FlashAttention. "
+                f"Supported head sizes are: {_SUPPORTED_HEAD_SIZES}.")
 
     def forward(
         self,
@@ -235,17 +252,20 @@ class FlashAttentionImpl(AttentionImpl):
         attn_metadata: FlashAttentionMetadata,
         kv_scale: float = 1.0,
     ) -> torch.Tensor:
-        """Forward pass with FlashAttention and PagedAttention.
+        """Forward pass with FlashAttention.
 
         Args:
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
             value: shape = [num_tokens, num_kv_heads * head_size]
-            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        # NOTE: FlashAttention does not support FP8 KV cache.
+        assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention."
+
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
@@ -253,16 +273,20 @@ class FlashAttentionImpl(AttentionImpl):
         value = value.view(-1, self.num_kv_heads, self.head_size)
 
         if kv_cache is not None:
-            key_cache, value_cache = PagedAttention.split_kv_cache(
-                kv_cache, self.num_kv_heads, self.head_size)
+            key_cache = kv_cache[0]
+            value_cache = kv_cache[1]
 
             # Reshape the input keys and values and store them in the cache.
             # If kv_cache is not provided, the new key and value tensors are
             # not cached. This happens during the initial memory profiling run.
-            PagedAttention.write_to_paged_cache(key, value, key_cache,
-                                                value_cache,
-                                                attn_metadata.slot_mapping,
-                                                self.kv_cache_dtype, kv_scale)
+            cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping.flatten(),
+                self.kv_cache_dtype,
+            )
 
         num_prefill_tokens = attn_metadata.num_prefill_tokens
         num_decode_tokens = attn_metadata.num_decode_tokens
@@ -282,7 +306,8 @@ class FlashAttentionImpl(AttentionImpl):
 
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
-            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+            if (kv_cache is None or prefill_meta.block_tables is None
+                    or prefill_meta.block_tables.numel() == 0):
                 # normal attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
@@ -303,38 +328,34 @@ class FlashAttentionImpl(AttentionImpl):
                 output[:num_prefill_tokens] = out
             else:
                 # prefix-enabled attention
-                # TODO: this triton kernel has regression issue (broke) to
-                # deal with different data types between KV and FP8 KV cache,
-                # to be addressed separately.
-                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
-                    query,
-                    key,
-                    value,
-                    key_cache,
-                    value_cache,
-                    prefill_meta.block_tables,
-                    prefill_meta.query_start_loc,
-                    prefill_meta.seq_lens_tensor,
-                    prefill_meta.context_lens_tensor,
-                    prefill_meta.max_query_len,
-                    self.alibi_slopes,
-                    self.sliding_window[0],
+                assert prefill_meta.seq_lens is not None
+                max_seq_len = max(prefill_meta.seq_lens)
+                output[:num_prefill_tokens] = flash_attn_varlen_func(
+                    q=query,
+                    k=key_cache,
+                    v=value_cache,
+                    cu_seqlens_q=prefill_meta.query_start_loc,
+                    max_seqlen_q=prefill_meta.max_query_len,
+                    cu_seqlens_k=prefill_meta.seq_start_loc,
+                    max_seqlen_k=max_seq_len,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    alibi_slopes=self.alibi_slopes,
+                    block_table=prefill_meta.block_tables,
                 )
+
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
-            output[num_prefill_tokens:] = PagedAttention.forward_decode(
-                decode_query,
+            output[num_prefill_tokens:] = flash_attn_with_kvcache(
+                decode_query.unsqueeze(1),
                 key_cache,
                 value_cache,
-                decode_meta.block_tables,
-                decode_meta.seq_lens_tensor,
-                decode_meta.max_decode_seq_len,
-                self.kv_cache_dtype,
-                self.num_kv_heads,
-                self.scale,
-                self.alibi_slopes,
-                kv_scale,
-            )
+                block_table=decode_meta.block_tables,
+                cache_seqlens=decode_meta.seq_lens_tensor,
+                softmax_scale=self.scale,
+                causal=True,
+                alibi_slopes=self.alibi_slopes,
+            ).squeeze(1)
 
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)

+ 5 - 0
aphrodite/attention/selector.py

@@ -92,6 +92,11 @@ def _which_attn_to_use(
                     "torch.float16 or torch.bfloat16.")
         return _Backend.XFORMERS
 
+    if block_size % 16 != 0:
+        logger.info("Cannot use FlashAttention-2 backend for block size not "
+                    "divisible by 16.")
+        return _Backend.XFORMERS
+
     if kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
         logger.info("Cannot use FlashAttention-2 backend for FP8 KV cache.")
         return _Backend.XFORMERS

+ 1 - 1
requirements-cuda.txt

@@ -6,4 +6,4 @@ nvidia-ml-py == 12.555.43
 torch == 2.3.0
 xformers == 0.0.26.post1  # Requires torch 2.3.0
 triton >= 2.2.0
-vllm-flash-attn == 2.5.8.post1 # Requires PyTorch 2.3.0
+vllm-flash-attn == 2.5.8.post2 # Requires PyTorch 2.3.0