
attention: add `AttentionState` abstraction (#863)

AlpinDale 3 months ago
parent
commit
1405051912
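
This change moves the backend-specific runner state (the FlashInfer workspace buffer, prefill/decode wrappers, and the CUDA graph capture buffers) out of the model runners and behind a new per-backend `AttentionState` object obtained via `AttentionBackend.get_state_cls()`. A hedged sketch of how a GPU model runner drives the new interface (the `runner`, `model`, and batch-size arguments are stand-ins for the real runner attributes; see the model_runner.py hunks below for the actual wiring):

    import weakref

    from aphrodite.task_handler.model_runner import CUDAGraphRunner


    def capture_and_run(runner, model, model_input,
                        batch_size_capture_list, max_batch_size):
        # Each backend now supplies a state class instead of ad-hoc runner fields.
        attn_state = runner.attn_backend.get_state_cls()(weakref.proxy(runner))

        # Capture-time buffers live inside the state for the duration of capture.
        with attn_state.graph_capture(max_batch_size):
            for batch_size in reversed(batch_size_capture_list):
                attn_metadata = attn_state.graph_capture_get_metadata_for_batch(
                    batch_size)
                graph_runner = CUDAGraphRunner(
                    model, runner.attn_backend.get_name(),
                    attn_state.graph_clone(batch_size))
                # ... capture the CUDA graph for this batch size, as in
                # GPUModelRunnerBase below ...

        # Per step: replaces the old flashinfer-only branches in the runners.
        attn_state.begin_forward(model_input)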

+ 1 - 1
Dockerfile

@@ -118,7 +118,7 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/aphrodite-workspace
     python3 -m pip install dist/*.whl --verbose
 
 RUN --mount=type=cache,target=/root/.cache/pip \
-    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.3/flashinfer-0.1.3+cu124torch2.4-cp310-cp310-linux_x86_64.whl
+    python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp310-cp310-linux_x86_64.whl
 #################### Aphrodite installation IMAGE ####################
 
 

+ 2 - 0
aphrodite/attention/__init__.py

@@ -1,6 +1,7 @@
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionMetadata,
                                                    AttentionMetadataBuilder,
+                                                   AttentionState,
                                                    AttentionType)
 from aphrodite.attention.layer import Attention
 from aphrodite.attention.selector import get_attn_backend
@@ -11,5 +12,6 @@ __all__ = [
     "AttentionType",
     "AttentionMetadataBuilder",
     "Attention",
+    "AttentionState",
     "get_attn_backend",
 ]

+ 48 - 1
aphrodite/attention/backends/abstract.py

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from contextlib import contextmanager
 from dataclasses import dataclass, fields
 from enum import Enum, auto
 from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
@@ -8,7 +9,7 @@ import torch
 
 if TYPE_CHECKING:
     from aphrodite.task_handler.model_runner_base import (
-        ModelRunnerInputBuilderBase)
+        ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase)
 
 
 class AttentionType(Enum):
@@ -35,6 +36,10 @@ class AttentionBackend(ABC):
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         raise NotImplementedError
 
+    @staticmethod
+    def get_state_cls() -> Type["AttentionState"]:
+        raise NotImplementedError
+
     @classmethod
     def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
         return cls.get_metadata_cls()(*args, **kwargs)
@@ -127,6 +132,48 @@ class AttentionMetadata:
 T = TypeVar("T", bound=AttentionMetadata)
 
 
+class AttentionState(ABC, Generic[T]):
+    """Holds attention backend specific objects reused during the
+    lifetime of the model runner.
+    """
+
+    @abstractmethod
+    def __init__(self, runner: "ModelRunnerBase"):
+        ...
+
+    @abstractmethod
+    @contextmanager
+    def graph_capture(self, max_batch_size: int):
+        """Context manager used when capturing a CUDA graph."""
+        yield
+
+    @abstractmethod
+    def graph_clone(self, batch_size: int) -> "AttentionState[T]":
+        """Clone attention state to save in CUDA graph metadata."""
+        ...
+
+    @abstractmethod
+    def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T:
+        """Get attention metadata for CUDA graph capture of batch_size."""
+        ...
+
+    @abstractmethod
+    def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]:
+        """Get attention-specific input buffers for CUDA graph capture."""
+        ...
+
+    @abstractmethod
+    def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any],
+                                    attn_metadata: T) -> None:
+        """In-place modify input buffers dict for CUDA graph replay."""
+        ...
+
+    @abstractmethod
+    def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
+        """Prepare state for forward pass."""
+        ...
+
+
 class AttentionMetadataBuilder(ABC, Generic[T]):
     """Abstract class for attention metadata builders."""
 

+ 6 - 1
aphrodite/attention/backends/blocksparse_attn.py

@@ -7,7 +7,8 @@ from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionType)
-from aphrodite.attention.backends.utils import CommonMetadataBuilder
+from aphrodite.attention.backends.utils import (CommonAttentionState,
+                                                CommonMetadataBuilder)
 from aphrodite.attention.ops.blocksparse_attention.interface import (
     LocalStridedBlockSparseAttn, get_head_sliding_step)
 from aphrodite.attention.ops.paged_attn import PagedAttention
@@ -100,6 +101,10 @@ class BlocksparseFlashAttentionBackend(AttentionBackend):
     def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
         return BlocksparseFlashAttentionMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

+ 5 - 0
aphrodite/attention/backends/flash_attn.py

@@ -11,6 +11,7 @@ from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionMetadataBuilder,
                                                    AttentionType)
 from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
+                                                CommonAttentionState,
                                                 compute_slot_mapping,
                                                 compute_slot_mapping_start_idx,
                                                 is_block_tables_empty)
@@ -145,6 +146,10 @@ class FlashAttentionBackend(AttentionBackend):
     def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
         return FlashAttentionMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

+ 160 - 0
aphrodite/attention/backends/flashinfer.py

@@ -1,14 +1,19 @@
+from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type
 
 try:
     from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
     from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
 
     import aphrodite.attention.backends.flash_attn  # noqa
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
 except ImportError:
     BatchDecodeWithPagedKVCacheWrapper = None
+    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
     BatchPrefillWithPagedKVCacheWrapper = None
+    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
 
 import torch
 
@@ -17,6 +22,7 @@ from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionMetadataBuilder,
+                                                   AttentionState,
                                                    AttentionType)
 from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
                                                 compute_slot_mapping,
@@ -48,6 +54,10 @@ class FlashInferBackend(AttentionBackend):
     def get_builder_cls() -> Type["FlashInferMetadataBuilder"]:
         return FlashInferMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["FlashInferState"]:
+        return FlashInferState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
@@ -77,6 +87,156 @@ class FlashInferBackend(AttentionBackend):
         return [64, 128, 256]
 
 
+class FlashInferState(AttentionState):
+    def __init__(self, runner):
+        self.runner = runner
+        self._is_graph_capturing = False
+        self._workspace_buffer = None
+        self._decode_wrapper = None
+        self._prefill_wrapper = None
+
+    def _get_workspace_buffer(self):
+        if self._workspace_buffer is None:
+            self._workspace_buffer = torch.empty(
+                FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                dtype=torch.uint8,
+                device=self.runner.device)
+        return self._workspace_buffer
+
+    def _get_prefill_wrapper(self):
+        if self._prefill_wrapper is None:
+            self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                self._get_workspace_buffer(), "NHD")
+        return self._prefill_wrapper
+
+    def _get_decode_wrapper(self):
+        if self._decode_wrapper is None:
+            num_qo_heads = (self.runner.model_config.get_num_attention_heads(
+                self.runner.parallel_config))
+            num_kv_heads = self.runner.model_config.get_num_kv_heads(
+                self.runner.parallel_config)
+            use_tensor_cores = num_qo_heads // num_kv_heads >= 4
+            self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                self._get_workspace_buffer(),
+                "NHD",
+                use_tensor_cores=use_tensor_cores)
+        return self._decode_wrapper
+
+    @contextmanager
+    def graph_capture(self, max_batch_size: int):
+        self._is_graph_capturing = True
+        self._graph_decode_wrapper = None
+        self._graph_slot_mapping = torch.full((max_batch_size, ),
+                                              PAD_SLOT_ID,
+                                              dtype=torch.long,
+                                              device=self.runner.device)
+        self._graph_seq_lens = torch.ones(max_batch_size,
+                                          dtype=torch.int32,
+                                          device=self.runner.device)
+        self._graph_block_tables = torch.from_numpy(
+            self.runner.graph_block_tables).to(device=self.runner.device)
+        self._graph_decode_workspace_buffer = self._get_workspace_buffer()
+        self._graph_indices_buffer = torch.empty(
+            max_batch_size * self.runner.cache_config.num_gpu_blocks,
+            dtype=torch.int32,
+            device=self.runner.device)
+        self._graph_indptr_buffer = torch.empty(max_batch_size + 1,
+                                                dtype=torch.int32,
+                                                device=self.runner.device)
+        self._graph_last_page_len_buffer = torch.empty(
+            max_batch_size, dtype=torch.int32, device=self.runner.device)
+        yield
+        self._is_graph_capturing = False
+        del self._graph_slot_mapping
+        del self._graph_seq_lens
+        del self._graph_block_tables
+        del self._graph_decode_workspace_buffer
+        del self._graph_indices_buffer
+        del self._graph_indptr_buffer
+        del self._graph_last_page_len_buffer
+        del self._graph_decode_wrapper
+
+    def graph_clone(self, batch_size: int):
+        assert self._is_graph_capturing
+        state = self.__class__(self.runner)
+        state._workspace_buffer = self._graph_decode_workspace_buffer
+        state._decode_wrapper = self._graph_decode_wrapper
+        state._prefill_wrapper = self._get_prefill_wrapper()
+        return state
+
+    def graph_capture_get_metadata_for_batch(self, batch_size: int):
+        assert self._is_graph_capturing
+        _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
+        _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]
+        num_qo_heads = (self.runner.model_config.get_num_attention_heads(
+            self.runner.parallel_config))
+        num_kv_heads = self.runner.model_config.get_num_kv_heads(
+            self.runner.parallel_config)
+        use_tensor_cores = num_qo_heads // num_kv_heads >= 4
+        self._graph_decode_wrapper = \
+            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
+            self._graph_decode_workspace_buffer, _indptr_buffer,
+            self._graph_indices_buffer, _last_page_len_buffer, "NHD",
+            use_tensor_cores)
+        kv_cache_dtype = get_kv_cache_torch_dtype(
+            self.runner.kv_cache_dtype, self.runner.model_config.dtype)
+        paged_kv_indptr_tensor_host = torch.arange(0,
+                                                   batch_size + 1,
+                                                   dtype=torch.int32)
+        paged_kv_indices_tensor_host = torch.arange(0,
+                                                    batch_size,
+                                                    dtype=torch.int32)
+        paged_kv_last_page_len_tensor_host = torch.full((batch_size, ),
+                                                        self.runner.block_size,
+                                                        dtype=torch.int32)
+        query_start_loc_host = torch.arange(0,
+                                            batch_size + 1,
+                                            dtype=torch.int32)
+        attn_metadata = self.runner.attn_backend.make_metadata(
+            num_prefills=0,
+            slot_mapping=self._graph_slot_mapping[:batch_size],
+            num_prefill_tokens=0,
+            num_decode_tokens=batch_size,
+            max_prefill_seq_len=0,
+            block_tables=self._graph_block_tables,
+            paged_kv_indptr=paged_kv_indptr_tensor_host,
+            paged_kv_indices=paged_kv_indices_tensor_host,
+            paged_kv_last_page_len=paged_kv_last_page_len_tensor_host,
+            num_qo_heads=num_qo_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=self.runner.model_config.get_head_size(),
+            page_size=self.runner.block_size,
+            seq_start_loc=None,
+            query_start_loc=query_start_loc_host,
+            device=self.runner.device,
+            data_type=kv_cache_dtype,
+            use_cuda_graph=True,
+            decode_wrapper=self._graph_decode_wrapper,
+            prefill_wrapper=None)
+        attn_metadata.begin_forward()
+        return attn_metadata
+
+    def get_graph_input_buffers(self, attn_metadata):
+        return {
+            "slot_mapping": attn_metadata.slot_mapping,
+        }
+
+    def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
+        return
+
+    def begin_forward(self, model_input):
+        assert not self._is_graph_capturing
+        state = self
+        if model_input.attn_metadata.use_cuda_graph:
+            batch_size = model_input.input_tokens.shape[0]
+            state = (self.runner.graph_runners[model_input.virtual_engine]
+                     [batch_size].attn_state)
+        model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
+        )
+        model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
+        model_input.attn_metadata.begin_forward()
+
+
 @dataclass
 class FlashInferMetadata(AttentionMetadata):
     # Maximum sequence length among prefill batch. 0 if there are decoding
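
Note the tensor-core heuristic used by both `_get_decode_wrapper()` and `graph_capture_get_metadata_for_batch()` above: the decode wrapper enables tensor cores when the ratio of query heads to KV heads is at least 4. A worked example with hypothetical head counts:

    # Hypothetical head counts; the runner reads the real values from the
    # model and parallel configs.
    num_qo_heads, num_kv_heads = 32, 8                    # GQA ratio of 4
    use_tensor_cores = num_qo_heads // num_kv_heads >= 4  # True here
    # A plain MHA model (ratio 1) would keep use_tensor_cores == False.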

+ 6 - 1
aphrodite/attention/backends/ipex_attn.py

@@ -10,7 +10,8 @@ from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionType)
-from aphrodite.attention.backends.utils import CommonMetadataBuilder
+from aphrodite.attention.backends.utils import (CommonAttentionState,
+                                                CommonMetadataBuilder)
 from aphrodite.attention.ops.paged_attn import (PagedAttention,
                                                 PagedAttentionMetadata)
 
@@ -35,6 +36,10 @@ class IpexAttnBackend(AttentionBackend):
     def get_builder_cls() -> Type["IpexAttnMetadataBuilder"]:
         return IpexAttnMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

+ 6 - 1
aphrodite/attention/backends/openvino.py

@@ -1,11 +1,12 @@
 from dataclasses import dataclass
-from typing import List, Tuple
+from typing import List, Tuple, Type
 
 import openvino as ov
 import torch
 
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionMetadata)
+from aphrodite.attention.backends.utils import CommonAttentionState
 
 
 class OpenVINOAttentionBackend(AttentionBackend):
@@ -24,6 +25,10 @@ class OpenVINOAttentionBackend(AttentionBackend):
     def make_metadata(*args, **kwargs) -> "AttentionMetadata":
         raise NotImplementedError
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata":
         return OpenVINOAttentionMetadata(*args, **kwargs)

+ 5 - 0
aphrodite/attention/backends/pallas.py

@@ -8,6 +8,7 @@ from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionType)
+from aphrodite.attention.backends.utils import CommonAttentionState
 
 
 class PallasAttentionBackend(AttentionBackend):
@@ -20,6 +21,10 @@ class PallasAttentionBackend(AttentionBackend):
     def get_metadata_cls() -> Type["PallasMetadata"]:
         return PallasMetadata
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

+ 5 - 0
aphrodite/attention/backends/placeholder_attn.py

@@ -7,6 +7,7 @@ from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionMetadataBuilder)
+from aphrodite.attention.backends.utils import CommonAttentionState
 
 if TYPE_CHECKING:
     from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
@@ -34,6 +35,10 @@ class PlaceholderAttentionBackend(AttentionBackend):
     def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
         return PlaceholderAttentionMetadata
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

+ 6 - 1
aphrodite/attention/backends/rocm_flash_attn.py

@@ -10,7 +10,8 @@ from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionType)
-from aphrodite.attention.backends.utils import CommonMetadataBuilder
+from aphrodite.attention.backends.utils import (CommonAttentionState,
+                                                CommonMetadataBuilder)
 from aphrodite.attention.ops.paged_attn import (PagedAttention,
                                                 PagedAttentionMetadata)
 
@@ -29,6 +30,10 @@ class ROCmFlashAttentionBackend(AttentionBackend):
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         return ROCmFlashAttentionMetadata
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_builder_cls() -> Type["ROCmFlashAttentionMetadataBuilder"]:
         return ROCmFlashAttentionMetadataBuilder

+ 7 - 2
aphrodite/attention/backends/torch_sdpa.py

@@ -10,7 +10,8 @@ from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionType)
-from aphrodite.attention.backends.utils import CommonMetadataBuilder
+from aphrodite.attention.backends.utils import (CommonAttentionState,
+                                                CommonMetadataBuilder)
 from aphrodite.attention.ops.paged_attn import PagedAttentionMetadata
 from aphrodite.common.utils import is_cpu
 
@@ -36,7 +37,11 @@ class TorchSDPABackend(AttentionBackend):
     @staticmethod
     def get_metadata_cls() -> Type["AttentionMetadata"]:
         return TorchSDPAMetadata
-    
+
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
         return TorchSDPAMetadataBuilder

+ 71 - 2
aphrodite/attention/backends/utils.py

@@ -1,11 +1,16 @@
-from typing import TYPE_CHECKING, Dict, List, Type, TypeVar, Union
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union
 
 import numpy as np
 import torch
 
-from aphrodite.attention import AttentionMetadata, AttentionMetadataBuilder
+from aphrodite.attention import (AttentionMetadata, AttentionMetadataBuilder,
+                                 AttentionState)
 from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad
 
+if TYPE_CHECKING:
+    from aphrodite.task_handler.model_runner_base import ModelRunnerBase
+
 # Error string(s) for encoder/decoder
 # unsupported attention scenarios
 STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported "
@@ -266,4 +271,68 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
             context_lens_tensor=context_lens_tensor,
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
+        )
+
+
+class CommonAttentionState(AttentionState):
+    def __init__(self, runner: "ModelRunnerBase"):
+        self.runner = runner
+        self._is_graph_capturing = False
+
+    @contextmanager
+    def graph_capture(self, max_batch_size: int):
+        self._is_graph_capturing = True
+        self._graph_slot_mapping = torch.full((max_batch_size, ),
+                                              PAD_SLOT_ID,
+                                              dtype=torch.long,
+                                              device=self.runner.device)
+        self._graph_seq_lens = torch.ones(max_batch_size,
+                                          dtype=torch.int32,
+                                          device=self.runner.device)
+        self._graph_block_tables = torch.from_numpy(
+            self.runner.graph_block_tables).to(device=self.runner.device)
+        yield
+        self._is_graph_capturing = False
+        del self._graph_slot_mapping
+        del self._graph_seq_lens
+        del self._graph_block_tables
+
+    def graph_clone(self, batch_size: int) -> "CommonAttentionState":
+        assert self._is_graph_capturing
+        return self.__class__(self.runner)
+
+    def graph_capture_get_metadata_for_batch(self, batch_size: int):
+        assert self._is_graph_capturing
+        attn_metadata = self.runner.attn_backend.make_metadata(
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_decode_tokens=batch_size,
+            slot_mapping=self._graph_slot_mapping[:batch_size],
+            seq_lens=None,
+            seq_lens_tensor=self._graph_seq_lens[:batch_size],
+            max_query_len=None,
+            max_prefill_seq_len=0,
+            max_decode_seq_len=self.runner.max_seq_len_to_capture,
+            query_start_loc=None,
+            seq_start_loc=None,
+            context_lens_tensor=None,
+            block_tables=self._graph_block_tables[:batch_size],
+            use_cuda_graph=True,
         )
+        return attn_metadata
+
+    def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]:
+        return {
+            "slot_mapping": attn_metadata.slot_mapping,
+            "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
+            "block_tables": attn_metadata.decode_metadata.block_tables,
+        }
+
+    def prepare_graph_input_buffers(self, input_buffers,
+                                    attn_metadata) -> None:
+        input_buffers["seq_lens_tensor"].copy_(
+            attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
+        input_buffers["block_tables"].copy_(
+            attn_metadata.decode_metadata.block_tables, non_blocking=True)
+
+    def begin_forward(self, model_input) -> None:
+        return
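
`get_graph_input_buffers` and `prepare_graph_input_buffers` replace the old flashinfer/non-flashinfer branches in `CUDAGraphRunner`: the first contributes backend-specific entries to the persistent graph input buffers at capture time, the second refreshes them in place before each replay. A hedged sketch of that pairing (hypothetical helper functions; `CUDAGraphRunner` below does this inline):

    def build_graph_input_buffers(attn_state, attn_metadata, input_ids,
                                  positions, kv_caches):
        # Capture time: the state decides which attention tensors become
        # persistent CUDA graph inputs.
        return {
            "input_ids": input_ids,
            "positions": positions,
            "kv_caches": kv_caches,
            **attn_state.get_graph_input_buffers(attn_metadata),
        }


    def refresh_graph_input_buffers(attn_state, input_buffers, attn_metadata):
        # Replay time: copy this step's values into the captured buffers
        # in place before the graph is replayed.
        attn_state.prepare_graph_input_buffers(input_buffers, attn_metadata)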

+ 6 - 1
aphrodite/attention/backends/xformers.py

@@ -13,7 +13,8 @@ from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionImpl,
                                                    AttentionMetadata,
                                                    AttentionType)
-from aphrodite.attention.backends.utils import CommonMetadataBuilder
+from aphrodite.attention.backends.utils import (CommonAttentionState,
+                                                CommonMetadataBuilder)
 from aphrodite.attention.ops.paged_attn import (PagedAttention,
                                                 PagedAttentionMetadata)
 
@@ -36,6 +37,10 @@ class XFormersBackend(AttentionBackend):
     def get_builder_cls() -> Type["XFormersMetadataBuilder"]:
         return XFormersMetadataBuilder
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,

+ 1 - 46
aphrodite/spec_decode/draft_model_runner.py

@@ -12,17 +12,6 @@ except ModuleNotFoundError:
     from aphrodite.attention.backends.rocm_flash_attn import (
         ROCmFlashAttentionMetadata as FlashAttentionMetadata)
 
-try:
-    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
-    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
-    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
-except ImportError:
-    BatchDecodeWithPagedKVCacheWrapper = None
-    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
-    BatchPrefillWithPagedKVCacheWrapper = None
-    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
-
 from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig)
@@ -88,11 +77,6 @@ class TP1DraftModelRunner(ModelRunner):
             **kwargs,
         )
 
-        self.flashinfer_decode_workspace_buffer = None
-        self.flashinfer_decode_wrapper = None
-        self.flashinfer_prefill_workspace_buffer = None
-        self.flashinfer_prefill_wrapper = None
-
     def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                   num_queries):
 
@@ -268,36 +252,7 @@ class TP1DraftModelRunner(ModelRunner):
                     model_input.prompt_adapter_requests,
                     model_input.prompt_adapter_mapping)
 
-            if self.attn_backend.get_name() == "flashinfer":
-                assert model_input.attn_metadata is not None
-                assert model_input.input_tokens is not None
-                if self.flashinfer_decode_workspace_buffer is None:
-                    self.flashinfer_decode_workspace_buffer = torch.empty(
-                        FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                        dtype=torch.uint8,
-                        device=self.device)
-                    self.flashinfer_decode_wrapper = \
-                        BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_decode_workspace_buffer, "NHD")
-                    self.flashinfer_prefill_workspace_buffer = torch.empty(
-                        FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                        dtype=torch.uint8,
-                        device=self.device)
-                    self.flashinfer_prefill_wrapper = \
-                        BatchPrefillWithPagedKVCacheWrapper(
-                        self.flashinfer_prefill_workspace_buffer, "NHD")
-
-                model_input.attn_metadata.prefill_wrapper = \
-                    self.flashinfer_prefill_wrapper
-                if model_input.attn_metadata.use_cuda_graph:
-                    batch_size = model_input.input_tokens.shape[0]
-                    model_input.attn_metadata.decode_wrapper = \
-                        self.graph_runners[model_input.
-                        virtual_engine][batch_size].flashinfer_decode_wrapper
-                else:
-                    model_input.attn_metadata.decode_wrapper = \
-                        self.flashinfer_decode_wrapper
-                model_input.attn_metadata.begin_forward()
+            self.attn_state.begin_forward(model_input)
 
         # Detect exec mode
         assert model_input.attn_metadata is not None

+ 3 - 2
aphrodite/task_handler/enc_dec_model_runner.py

@@ -7,6 +7,7 @@ from loguru import logger
 
 from aphrodite.attention.backends.abstract import (AttentionBackend,
                                                    AttentionMetadata)
+from aphrodite.attention.backends.utils import PAD_SLOT_ID
 from aphrodite.attention.selector import (_Backend,
                                           get_env_variable_attn_backend,
                                           get_global_forced_attn_backend,
@@ -23,7 +24,7 @@ from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
 from aphrodite.modeling import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from aphrodite.task_handler.model_runner import (
-    _PAD_SLOT_ID, GPUModelRunnerBase, ModelInputForGPUBuilder,
+    GPUModelRunnerBase, ModelInputForGPUBuilder,
     ModelInputForGPUWithSamplingMetadata)
 from aphrodite.task_handler.model_runner_base import (
     _add_attn_metadata_broadcastable_dict,
@@ -387,7 +388,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
                     # initialized yet. In this case, we just use a dummy
                     # slot mapping.
                     # In embeddings, the block tables are {seq_id: None}.
-                    cross_slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
+                    cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len)
                 else:
                     for i in range(0, seq_len):
                         block_number = seq_group_metadata.cross_block_table[

+ 30 - 188
aphrodite/task_handler/model_runner.py

@@ -15,18 +15,9 @@ import torch.distributed
 import torch.nn as nn
 from loguru import logger
 
-try:
-    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-    from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper
-    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
-    FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
-except ImportError:
-    BatchDecodeWithPagedKVCacheWrapper = None
-    CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None
-    BatchPrefillWithPagedKVCacheWrapper = None
-    FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
-
 from aphrodite.attention import AttentionMetadata, get_attn_backend
+from aphrodite.attention.backends.abstract import AttentionState
+from aphrodite.attention.backends.utils import CommonAttentionState
 from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig)
@@ -34,8 +25,7 @@ from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sequence import (IntermediateTensors, SamplerOutput,
                                        SequenceGroupMetadata)
 from aphrodite.common.utils import (CudaMemoryProfiler, PyObjectCache,
-                                    async_tensor_h2d, flatten_2d_lists,
-                                    get_kv_cache_torch_dtype, is_hip,
+                                    async_tensor_h2d, flatten_2d_lists, is_hip,
                                     is_pin_memory_available)
 from aphrodite.distributed import get_pp_group
 from aphrodite.distributed.parallel_state import (
@@ -67,7 +57,6 @@ from aphrodite.task_handler.model_runner_base import (
 if TYPE_CHECKING:
     from aphrodite.attention.backends.abstract import AttentionBackend
 
-_PAD_SLOT_ID = -1
 LORA_WARMUP_RANK = 8
 _BATCH_SIZE_ALIGNMENT = 8
 # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
@@ -852,6 +841,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             self.block_size,
             self.model_config.is_attention_free(),
         )
+        if self.attn_backend:
+            self.attn_state = self.attn_backend.get_state_cls()(
+                weakref.proxy(self))
+        else:
+            self.attn_state = CommonAttentionState(weakref.proxy(self))
 
         # Multi-modal data support
         self.input_registry = input_registry
@@ -866,11 +860,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
         self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None
 
-        self.flashinfer_decode_workspace_buffer = None
-        self.flashinfer_decode_wrapper = None
-        self.flashinfer_prefill_workspace_buffer = None
-        self.flashinfer_prefill_wrapper = None
-
         set_cpu_offload_max_bytes(
             int(self.cache_config.cpu_offload_gb * 1024**3))
 
@@ -1227,10 +1216,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
         input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
-        slot_mapping = torch.empty(max_batch_size, dtype=torch.long).cuda()
-        slot_mapping.fill_(_PAD_SLOT_ID)
-        seq_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
-        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
         intermediate_inputs = None
         if not get_pp_group().is_first_rank:
             intermediate_inputs = self.model.make_empty_intermediate_tensors(
@@ -1250,102 +1235,18 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
 
-        if self.attn_backend.get_name() == "flashinfer":
-            # For flashinfer, different batch sizes will share the
-            # same workspace buffer.
-            decode_workspace_buffer = \
-            torch.empty(FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                                                dtype=torch.uint8,
-                                              device=self.device)
-            indices_buffer = torch.empty(max_batch_size *
-                                         self.cache_config.num_gpu_blocks,
-                                         dtype=torch.int32,
-                                         device=self.device)
-            indptr_buffer = torch.empty(max_batch_size + 1,
-                                        dtype=torch.int32,
-                                        device=self.device)
-            last_page_len_buffer = torch.empty(max_batch_size,
-                                               dtype=torch.int32,
-                                               device=self.device)
-
-        with graph_capture() as graph_capture_context:
+        with self.attn_state.graph_capture(
+                max_batch_size), graph_capture() as graph_capture_context:
+
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for virtual_engine in range(
                     self.parallel_config.pipeline_parallel_size):
                 for batch_size in reversed(batch_size_capture_list):
-                    if self.attn_backend.get_name() == "flashinfer":
-                        _indptr_buffer = indptr_buffer[:batch_size + 1]
-                        _last_page_len_buffer = last_page_len_buffer[:
-                                                                     batch_size]
-
-                        num_qo_heads = (
-                            self.model_config.get_num_attention_heads(
-                                self.parallel_config, self.tp_rank))
-                        num_kv_heads = self.model_config.get_num_kv_heads(
-                            self.parallel_config, self.tp_rank)
-                        if num_qo_heads // num_kv_heads >= 4:
-                            use_tensor_cores = True
-                        else:
-                            use_tensor_cores = False
-                        decode_wrapper = \
-                            CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
-                            decode_workspace_buffer, _indptr_buffer,
-                            indices_buffer, _last_page_len_buffer, "NHD",
-                            use_tensor_cores)
-                        kv_cache_dtype = get_kv_cache_torch_dtype(
-                            self.kv_cache_dtype, self.model_config.dtype)
-
-                        paged_kv_indptr_tensor_host = torch.arange(
-                            0, batch_size + 1, dtype=torch.int32)
-                        paged_kv_indices_tensor_host = torch.arange(
-                            0, batch_size, dtype=torch.int32)
-                        paged_kv_last_page_len_tensor_host = torch.full(
-                            (batch_size, ), self.block_size, dtype=torch.int32)
-                        query_start_loc_host = torch.arange(0,
-                                                            batch_size + 1,
-                                                            dtype=torch.int32)
-
-                        attn_metadata = self.attn_backend.make_metadata(
-                            num_prefills=0,
-                            slot_mapping=slot_mapping[:batch_size],
-                            num_prefill_tokens=0,
-                            num_decode_tokens=batch_size,
-                            max_prefill_seq_len=0,
-                            block_tables=block_tables,
-                            paged_kv_indptr=paged_kv_indptr_tensor_host,
-                            paged_kv_indices=paged_kv_indices_tensor_host,
-                            paged_kv_last_page_len=
-                            paged_kv_last_page_len_tensor_host,
-                            num_qo_heads=num_qo_heads,
-                            num_kv_heads=num_kv_heads,
-                            head_dim=self.model_config.get_head_size(),
-                            page_size=self.block_size,
-                            seq_start_loc=None,
-                            query_start_loc=query_start_loc_host,
-                            device=self.device,
-                            data_type=kv_cache_dtype,
-                            use_cuda_graph=True,
-                            decode_wrapper=decode_wrapper,
-                            prefill_wrapper=None)
-                        attn_metadata.begin_forward()
-                    else:
-                        attn_metadata = self.attn_backend.make_metadata(
-                            num_prefills=0,
-                            num_prefill_tokens=0,
-                            num_decode_tokens=batch_size,
-                            slot_mapping=slot_mapping[:batch_size],
-                            seq_lens=None,
-                            seq_lens_tensor=seq_lens[:batch_size],
-                            max_query_len=None,
-                            max_prefill_seq_len=0,
-                            max_decode_seq_len=self.max_seq_len_to_capture,
-                            query_start_loc=None,
-                            seq_start_loc=None,
-                            context_lens_tensor=None,
-                            block_tables=block_tables[:batch_size],
-                            use_cuda_graph=True,
-                        )
+                    attn_metadata = (
+                        self.attn_state.graph_capture_get_metadata_for_batch(
+                            batch_size))
 
                     if self.lora_config:
                         lora_mapping = LoRAMapping(
@@ -1363,17 +1264,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                             set(), prompt_adapter_mapping)
 
                     graph_runner = CUDAGraphRunner(
-                        self.model, self.attn_backend.get_name())
-
-                    if self.attn_backend.get_name() == "flashinfer":
-                        graph_runner.flashinfer_indptr_buffer = _indptr_buffer
-                        graph_runner.flashinfer_indices_buffer = indices_buffer
-                        graph_runner.flashinfer_last_page_len_buffer = \
-                            _last_page_len_buffer
-                        graph_runner.flashinfer_decode_workspace_buffer = \
-                                decode_workspace_buffer
-                        graph_runner.flashinfer_decode_wrapper = \
-                            decode_wrapper
+                        self.model, self.attn_backend.get_name(),
+                        self.attn_state.graph_clone(batch_size))
 
                     capture_inputs = {
                         "input_ids":
@@ -1501,36 +1393,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
                 model_input.prompt_adapter_requests,
                 model_input.prompt_adapter_mapping)
 
-        if self.attn_backend.get_name() == "flashinfer":
-            assert model_input.attn_metadata is not None
-            assert model_input.input_tokens is not None
-            if self.flashinfer_decode_workspace_buffer is None:
-                self.flashinfer_decode_workspace_buffer = torch.empty(
-                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                    dtype=torch.uint8,
-                    device=self.device)
-                self.flashinfer_decode_wrapper = \
-                    BatchDecodeWithPagedKVCacheWrapper(
-                    self.flashinfer_decode_workspace_buffer, "NHD")
-                self.flashinfer_prefill_workspace_buffer = torch.empty(
-                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
-                    dtype=torch.uint8,
-                    device=self.device)
-                self.flashinfer_prefill_wrapper = \
-                    BatchPrefillWithPagedKVCacheWrapper(
-                    self.flashinfer_prefill_workspace_buffer, "NHD")
-
-            model_input.attn_metadata.prefill_wrapper = \
-                self.flashinfer_prefill_wrapper
-            if model_input.attn_metadata.use_cuda_graph:
-                batch_size = model_input.input_tokens.shape[0]
-                model_input.attn_metadata.decode_wrapper = self.graph_runners[
-                    model_input.
-                    virtual_engine][batch_size].flashinfer_decode_wrapper
-            else:
-                model_input.attn_metadata.decode_wrapper = \
-                    self.flashinfer_decode_wrapper
-            model_input.attn_metadata.begin_forward()
+        self.attn_state.begin_forward(model_input)
 
         # Currently cuda graph is only supported by the decode phase.
         assert model_input.attn_metadata is not None
@@ -1598,22 +1461,17 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
 
 class CUDAGraphRunner:
 
-    def __init__(self, model: nn.Module, backend_name: str):
+    def __init__(self, model: nn.Module, backend_name: str,
+                 attn_state: AttentionState):
         self.model = model
         self.backend_name = backend_name
+        self.attn_state = attn_state
 
         self.input_buffers: Dict[str, torch.Tensor] = {}
         self.output_buffers: Dict[str, torch.Tensor] = {}
 
         self._graph: Optional[torch.cuda.CUDAGraph] = None
 
-        self.flashinfer_decode_workspace_buffer: Optional[torch.Tensor] = None
-        self.flashinfer_indptr_buffer: Optional[torch.Tensor] = None
-        self.flashinfer_indices_buffer: Optional[torch.Tensor] = None
-        self.flashinfer_last_page_len_buffer: Optional[torch.Tensor] = None
-        self.flashinfer_decode_wrapper: Optional[
-            CUDAGraphBatchDecodeWithPagedKVCacheWrapper] = None
-
     @property
     def graph(self):
         assert self._graph is not None
@@ -1678,25 +1536,13 @@ class CUDAGraphRunner:
         torch.cuda.synchronize()
 
         # Save the input and output buffers.
-        if self.backend_name == "flashinfer":
-            self.input_buffers = {
-                "input_ids": input_ids,
-                "positions": positions,
-                "kv_caches": kv_caches,
-                "slot_mapping": attn_metadata.slot_mapping,
-                **kwargs,
-            }
-        else:
-            self.input_buffers = {
-                "input_ids": input_ids,
-                "positions": positions,
-                "kv_caches": kv_caches,
-                "slot_mapping": attn_metadata.slot_mapping,
-                "seq_lens_tensor":
-                attn_metadata.decode_metadata.seq_lens_tensor,
-                "block_tables": attn_metadata.decode_metadata.block_tables,
-                **kwargs,
-            }
+        self.input_buffers = {
+            "input_ids": input_ids,
+            "positions": positions,
+            "kv_caches": kv_caches,
+            **self.attn_state.get_graph_input_buffers(attn_metadata),
+            **kwargs,
+        }
         if intermediate_inputs is not None:
             self.input_buffers.update(intermediate_inputs.tensors)
         if get_pp_group().is_last_rank:
@@ -1725,12 +1571,8 @@ class CUDAGraphRunner:
         if self.backend_name != "No attention":
             self.input_buffers["slot_mapping"].copy_(
                 attn_metadata.slot_mapping, non_blocking=True)
-        if self.backend_name != "flashinfer":
-            self.input_buffers["seq_lens_tensor"].copy_(
-                attn_metadata.decode_metadata.seq_lens_tensor,
-                non_blocking=True)
-            self.input_buffers["block_tables"].copy_(
-                attn_metadata.decode_metadata.block_tables, non_blocking=True)
+        self.attn_state.prepare_graph_input_buffers(self.input_buffers,
+                                                    attn_metadata)
         if "seqlen_agnostic_capture_inputs" in self.input_buffers:
             self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
                                                       **kwargs)