
feat: mamba model support (#674)

* feat: mamba model support

* ruff
AlpinDale 6 months ago
parent
commit
bf88c8567e

+ 304 - 0
aphrodite/attention/backends/placeholder_attn.py

@@ -0,0 +1,304 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type
+
+import torch
+
+from aphrodite.attention.backends.abstract import (AttentionBackend,
+                                                   AttentionImpl,
+                                                   AttentionMetadata,
+                                                   AttentionMetadataBuilder)
+
+if TYPE_CHECKING:
+    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
+
+# Placeholder attention backend for models like Mamba and embedding models that
+# lack attention.
+
+
+class PlaceholderAttentionBackend(AttentionBackend):
+    """Placeholder backend for when no attention is needed."""
+
+    @staticmethod
+    def get_name() -> str:
+        return "No attention"
+
+    @staticmethod
+    def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
+        return PlaceholderAttentionImpl
+
+    @staticmethod
+    def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
+        return PlaceholderAttentionMetadataBuilder
+
+    @staticmethod
+    def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
+        return PlaceholderAttentionMetadata
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        return (1, 1, 1, 1, 1)
+
+    @staticmethod
+    def swap_blocks(
+        src_kv_cache: torch.Tensor,
+        dst_kv_cache: torch.Tensor,
+        src_to_dst: torch.Tensor,
+    ) -> None:
+        return
+
+    @staticmethod
+    def copy_blocks(
+        kv_caches: List[torch.Tensor],
+        src_to_dists: torch.Tensor,
+    ) -> None:
+        return
+
+
+@dataclass
+class PlaceholderAttentionMetadata(AttentionMetadata):
+    """Attention metadata for prefill and decode batched together."""
+    # (batch_size,). The sequence length per sequence. Sequence length means
+    # the computed tokens + new tokens. None if it is a decoding batch.
+    seq_lens: Optional[List[int]]
+    # seq_lens stored as a tensor.
+    seq_lens_tensor: Optional[torch.Tensor]
+
+    # Maximum query length in the batch. None for decoding.
+    max_query_len: Optional[int]
+    # Maximum sequence length among prefill batch. 0 if there are decoding
+    # requests only.
+    max_prefill_seq_len: int
+    # Maximum sequence length among decode batch. 0 if there are prefill
+    # requests only.
+    max_decode_seq_len: int
+    # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+    # the batch, used to index into subquery. E.g., if the subquery length
+    # is [4, 6], it is [0, 4, 10].
+    query_start_loc: Optional[torch.Tensor]
+    # (batch_size + 1,). The cumulative sequence lengths of the sequences in
+    # the batch, used to index into sequence. E.g., if the sequence length is
+    # [4, 6], it is [0, 4, 10].
+    seq_start_loc: Optional[torch.Tensor]
+    # (batch_size,) A tensor of context lengths (tokens that are computed
+    # so far).
+    context_lens_tensor: Optional[torch.Tensor]
+
+    # (batch_size, max_blocks_per_seq).
+    # Block addresses per sequence. (Seq id -> list of physical block)
+    # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
+    # in the kv cache. Each block can contain up to block_size tokens.
+    # The 2nd dimension is padded up to max_blocks_per_seq if it is
+    # CUDA-graph captured.
+    block_tables: Optional[torch.Tensor]
+
+    # Whether or not CUDA graph is enabled.
+    # CUDA graph is currently enabled for decoding only.
+    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
+    use_cuda_graph: bool
+
+    _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None
+    _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None
+
+    @property
+    def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
+        if self.num_prefills == 0:
+            return None
+
+        if self._cached_prefill_metadata is not None:
+            return self._cached_prefill_metadata
+
+        assert self.seq_lens is not None
+        assert self.seq_lens_tensor is not None
+        assert self.query_start_loc is not None
+        assert self.context_lens_tensor is not None
+        assert self.seq_start_loc is not None
+
+        # Placeholders
+        slot_mapping = torch.empty(0)
+        block_tables = torch.empty(0)
+
+        self._cached_prefill_metadata = PlaceholderAttentionMetadata(
+            num_prefills=self.num_prefills,
+            num_prefill_tokens=self.num_prefill_tokens,
+            num_decode_tokens=0,
+            slot_mapping=slot_mapping,
+            seq_lens=self.seq_lens[:self.num_prefills],
+            seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
+            max_query_len=self.max_query_len,
+            max_prefill_seq_len=self.max_prefill_seq_len,
+            max_decode_seq_len=0,
+            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
+            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
+            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
+            block_tables=block_tables,
+            use_cuda_graph=False,
+        )
+        return self._cached_prefill_metadata
+
+    @property
+    def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
+        if self.num_decode_tokens == 0:
+            return None
+
+        if self._cached_decode_metadata is not None:
+            return self._cached_decode_metadata
+        assert self.seq_lens_tensor is not None
+
+        # Placeholders
+        slot_mapping = torch.empty(0)
+        block_tables = torch.empty(0)
+
+        self._cached_decode_metadata = PlaceholderAttentionMetadata(
+            num_prefills=0,
+            num_prefill_tokens=0,
+            num_decode_tokens=self.num_decode_tokens,
+            slot_mapping=slot_mapping,
+            seq_lens=None,
+            seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
+            max_query_len=None,
+            max_prefill_seq_len=0,
+            max_decode_seq_len=self.max_decode_seq_len,
+            query_start_loc=None,
+            seq_start_loc=None,
+            context_lens_tensor=None,
+            block_tables=block_tables,
+            use_cuda_graph=self.use_cuda_graph,
+        )
+        return self._cached_decode_metadata
+
+
+class PlaceholderAttentionMetadataBuilder(
+        AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
+
+    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+        self.prefill_seq_lens: List[int] = []
+        self.context_lens: List[int] = []
+        self.curr_seq_lens: List[int] = []
+        self.num_prefills = 0
+        self.num_prefill_tokens = 0
+        self.num_decode_tokens = 0
+
+        self.input_builder = input_builder
+        self.runner = input_builder.runner
+
+    def _add_seq_group(
+            self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
+            chunked_prefill_enabled: bool):
+        """Add a sequence group to the metadata. Specifically update/append
+        1. context length.
+        """
+        is_prompt = inter_data.is_prompt
+
+        for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
+             curr_sliding_window_block) in zip(
+                 inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
+                 inter_data.orig_seq_lens, inter_data.seq_lens,
+                 inter_data.query_lens, inter_data.context_lens,
+                 inter_data.curr_sliding_window_blocks):
+            self.context_lens.append(context_len)
+
+            if is_prompt:
+                self.num_prefills += 1
+                self.num_prefill_tokens += token_len
+                self.prefill_seq_lens.append(seq_len)
+            else:
+                assert query_len == 1, (
+                    "seq_len: {}, context_len: {}, query_len: {}".format(
+                        seq_len, context_len, query_len))
+                self.num_decode_tokens += query_len
+                self.curr_seq_lens.append(curr_seq_len)
+
+    def build(self, seq_lens: List[int], query_lens: List[int],
+              cuda_graph_pad_size: int, batch_size: int):
+        """Build attention metadata with on-device tensors.
+        Args:
+            seq_lens: The maybe padded sequence lengths of the input sequences.
+            query_lens: The query lengths of the input sequences.
+            cuda_graph_pad_size: The padding size for cuda graph.
+                                 -1 if cuda graph is not used.
+            batch_size: The maybe padded batch size.
+        """
+        for inter_data in self.input_builder.inter_data_list:
+            self._add_seq_group(inter_data,
+                                self.input_builder.chunked_prefill_enabled)
+
+        device = self.runner.device
+        use_captured_graph = cuda_graph_pad_size != -1
+
+        logits_soft_cap = getattr(self.runner.model_config.hf_config,
+                                  "attn_logit_softcapping", None)
+        if logits_soft_cap is not None:
+            raise ValueError(
+                "Please use Flashinfer backend for models with logits_soft_cap"
+                " (i.e., Gemma-2). Otherwise, the output might be wrong."
+                " Set Flashinfer backend by "
+                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")
+
+        max_query_len = max(query_lens)
+        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
+        max_decode_seq_len = max(self.curr_seq_lens, default=0)
+        num_decode_tokens = self.num_decode_tokens
+
+        if use_captured_graph:
+            num_decode_tokens = batch_size
+
+        assert max_query_len > 0, ("query_lens: {}".format(query_lens))
+
+        context_lens_tensor = torch.tensor(self.context_lens,
+                                           dtype=torch.int,
+                                           device=device)
+        seq_lens_tensor = torch.tensor(seq_lens,
+                                       dtype=torch.int,
+                                       device=device)
+        query_lens_tensor = torch.tensor(query_lens,
+                                         dtype=torch.long,
+                                         device=device)
+        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
+                                      dtype=torch.int32,
+                                      device=device)
+        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
+                                    dtype=torch.int32,
+                                    device=device)
+        torch.cumsum(seq_lens_tensor,
+                     dim=0,
+                     dtype=seq_start_loc.dtype,
+                     out=seq_start_loc[1:])
+        torch.cumsum(query_lens_tensor,
+                     dim=0,
+                     dtype=query_start_loc.dtype,
+                     out=query_start_loc[1:])
+
+        # Placeholders
+        slot_mapping = torch.empty(0)
+        block_tables = torch.empty(0)
+
+        return PlaceholderAttentionMetadata(
+            num_prefills=self.num_prefills,
+            slot_mapping=slot_mapping,
+            num_prefill_tokens=self.num_prefill_tokens,
+            num_decode_tokens=num_decode_tokens,
+            seq_lens=seq_lens,
+            seq_lens_tensor=seq_lens_tensor,
+            max_query_len=max_query_len,
+            max_prefill_seq_len=max_prefill_seq_len,
+            max_decode_seq_len=max_decode_seq_len,
+            query_start_loc=query_start_loc,
+            seq_start_loc=seq_start_loc,
+            context_lens_tensor=context_lens_tensor,
+            block_tables=block_tables,
+            use_cuda_graph=use_captured_graph,
+        )
+
+
+class PlaceholderAttentionImpl(AttentionImpl):
+
+    def __init__(self, *args, **kwargs) -> None:
+        return
+
+    def forward(self, *args, **kwargs) -> torch.Tensor:
+        raise NotImplementedError
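
The builder above derives `query_start_loc` and `seq_start_loc` as exclusive cumulative sums of the per-sequence lengths, matching the "[4, 6] -> [0, 4, 10]" comment in the metadata class. A minimal sketch of that step in isolation (plain PyTorch; the tensor values are illustrative only):

import torch

# Hypothetical query lengths for a batch of two prefill sequences.
query_lens = torch.tensor([4, 6], dtype=torch.long)

# (batch_size + 1,) cumulative offsets used to index into the flattened tokens.
query_start_loc = torch.zeros(query_lens.shape[0] + 1, dtype=torch.int32)
torch.cumsum(query_lens, dim=0, dtype=query_start_loc.dtype,
             out=query_start_loc[1:])
print(query_start_loc)  # tensor([ 0,  4, 10], dtype=torch.int32)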

+ 5 - 3
aphrodite/attention/layer.py

@@ -41,10 +41,12 @@ class Attention(nn.Module):
             kv_cache_dtype = cache_config.cache_dtype
             block_size = cache_config.block_size
             sliding_window = cache_config.sliding_window
+            is_attention_free = cache_config.is_attention_free
         else:
             kv_cache_dtype = "auto"
             block_size = 16
             sliding_window = None
+            is_attention_free = False
         if num_kv_heads is None:
             num_kv_heads = num_heads
 
@@ -75,9 +77,9 @@ class Attention(nn.Module):
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
-                                        sliding_window, dtype, kv_cache_dtype,
-                                        block_size, blocksparse_params
+        attn_backend = get_attn_backend(head_size, sliding_window, dtype,
+                                        kv_cache_dtype, block_size,
+                                        is_attention_free, blocksparse_params
                                         is not None)
         impl_cls = attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,

+ 14 - 7
aphrodite/attention/selector.py

@@ -24,6 +24,7 @@ class _Backend(enum.Enum):
     FLASHINFER = enum.auto()
     PALLAS = enum.auto()
     IPEX = enum.auto()
+    NO_ATTENTION = enum.auto()
 
 
 def backend_name_to_enum(backend_name: str) -> _Backend:
@@ -88,13 +89,12 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
 
 @lru_cache(maxsize=None)
 def get_attn_backend(
-    num_heads: int,
     head_size: int,
-    num_kv_heads: int,
     sliding_window: Optional[int],
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
+    is_attention_free: bool,
     is_blocksparse: bool = False,
 ) -> Type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
@@ -105,9 +105,8 @@ def get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend
 
-    backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
-                                sliding_window, dtype, kv_cache_dtype,
-                                block_size)
+    backend = which_attn_to_use(head_size, sliding_window, dtype,
+                                kv_cache_dtype, block_size, is_attention_free)
     if backend == _Backend.FLASH_ATTN:
         from aphrodite.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
@@ -147,23 +146,31 @@ def get_attn_backend(
         logger.info("Using Pallas backend.")
         from aphrodite.attention.backends.pallas import PallasAttentionBackend
         return PallasAttentionBackend
+    elif backend == _Backend.NO_ATTENTION:
+        from aphrodite.attention.backends.placeholder_attn import (
+            PlaceholderAttentionBackend)
+        return PlaceholderAttentionBackend
     else:
         raise ValueError("Invalid attention backend.")
 
 
 def which_attn_to_use(
-    num_heads: int,
     head_size: int,
-    num_kv_heads: int,
     sliding_window: Optional[int],
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
+    is_attention_free: bool,
 ) -> _Backend:
     """Returns which flash attention backend to use."""
     # Default case.
     selected_backend = _Backend.FLASH_ATTN
 
+    # If there are no attention layers (e.g. we are running Mamba),
+    # use the placeholder NO_ATTENTION backend.
+    if is_attention_free:
+        return _Backend.NO_ATTENTION
+
     # Check whether a particular choice of backend was
     # previously forced.
     #
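
With the new `is_attention_free` flag, the selector short-circuits to the placeholder backend before any forced-backend or hardware checks run. A hedged usage sketch, assuming aphrodite is importable and the names match the diff above:

import torch

from aphrodite.attention.selector import _Backend, which_attn_to_use

# For an attention-free model such as Mamba, the selector returns the
# placeholder NO_ATTENTION backend regardless of the other arguments.
backend = which_attn_to_use(head_size=0,
                            sliding_window=None,
                            dtype=torch.float16,
                            kv_cache_dtype="auto",
                            block_size=16,
                            is_attention_free=True)
assert backend == _Backend.NO_ATTENTION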

+ 27 - 0
aphrodite/common/config.py

@@ -363,6 +363,19 @@ class ModelConfig:
             raise ValueError(
                 "BitsAndBytes with enforce_eager=False is not supported yet.")
 
+    def is_attention_free(self) -> bool:
+        """Returns True if the model has no attention, i.e. the model has no
+        state that grows with the size of the context.
+        """
+
+        # Return true if the model is mamba.
+        # This check should be augmented with more models in the future,
+        # and made more robust if possible.
+        if hasattr(self.hf_text_config,
+                   "model_type") and self.hf_text_config.model_type == 'mamba':
+            return True
+        return False
+
     def get_hf_config_sliding_window(self) -> Optional[int]:
         """Get the sliding window size, or None if disabled.
         """
@@ -397,6 +410,8 @@ class ModelConfig:
             # FlashAttention supports only head_size 32, 64, 128, 256,
             # we need to pad head_size 192 to 256
             return 256
+        if self.is_attention_free():
+            return 0
         if hasattr(self.hf_text_config, "head_dim"):
             return self.hf_text_config.head_dim
         # FIXME: This may not be true for all models.
@@ -427,6 +442,9 @@ class ModelConfig:
         if self.hf_config.model_type == "dbrx":
             return getattr(self.hf_config.attn_config, "kv_n_heads",
                            self.hf_config.num_attention_heads)
+
+        if self.is_attention_free():
+            return 0
 
         attributes = [
             # For Falcon:
@@ -490,6 +508,9 @@ class ModelConfig:
     def get_layers_block_type(self,
                               parallel_config: "ParallelConfig") -> List[str]:
         num_layers = self.get_num_layers(parallel_config)
+        if self.is_attention_free():
+            assert (self.hf_config.model_type == "mamba")
+            return ["mamba"] * num_layers
         # Transformers supports layers_block_type @property
         return getattr(self.hf_config, "layers_block_type",
                        ["attention"] * num_layers)
@@ -538,6 +559,7 @@ class CacheConfig:
         gpu_memory_utilization: float,
         swap_space: int,
         cache_dtype: str,
+        is_attention_free: bool,
         num_gpu_blocks_override: Optional[int] = None,
         sliding_window: Optional[int] = None,
         enable_prefix_caching: bool = False,
@@ -548,6 +570,7 @@ class CacheConfig:
         self.swap_space_bytes = swap_space * _GB
         self.num_gpu_blocks_override = num_gpu_blocks_override
         self.cache_dtype = cache_dtype
+        self.is_attention_free = is_attention_free
         self.sliding_window = sliding_window
         self.enable_prefix_caching = enable_prefix_caching
         self.cpu_offload_gb = cpu_offload_gb
@@ -871,6 +894,8 @@ class SchedulerConfig:
             iteration.
         max_model_len: Maximum length of a sequence (including prompt
             and generated text).
+        is_attention_free: True if the running model does not have state that
+            grows as the context size increases.
         use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not.
         num_lookahead_slots: The number of slots to allocate per sequence per
             step, beyond the known token ids. This is used in speculative
@@ -893,6 +918,7 @@ class SchedulerConfig:
                  max_num_batched_tokens: Optional[int],
                  max_num_seqs: int,
                  max_model_len: int,
+                 is_attention_free: bool,
                  use_v2_block_manager: bool = False,
                  num_lookahead_slots: int = 0,
                  delay_factor: float = 0.0,
@@ -920,6 +946,7 @@ class SchedulerConfig:
 
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
+        self.is_attention_free = is_attention_free
         self.use_v2_block_manager = use_v2_block_manager
         self.num_lookahead_slots = num_lookahead_slots
         self.delay_factor = delay_factor
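
The `is_attention_free()` check reduces to inspecting `model_type` on the HF text config. A standalone sketch of the same logic against a mock config (the `SimpleNamespace` stand-in and free function are illustrative, not part of the diff):

from types import SimpleNamespace

def is_attention_free(hf_text_config) -> bool:
    # Mirrors ModelConfig.is_attention_free(): only 'mamba' is recognized
    # today; the check is expected to cover more models in the future.
    return getattr(hf_text_config, "model_type", None) == "mamba"

assert is_attention_free(SimpleNamespace(model_type="mamba"))
assert not is_attention_free(SimpleNamespace(model_type="llama"))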

+ 8 - 2
aphrodite/engine/aphrodite_engine.py

@@ -1264,7 +1264,10 @@ class AphroditeEngine:
             num_free_gpu = sum(
                 scheduler.block_manager.get_num_free_gpu_blocks()
                 for scheduler in self.scheduler)
-            gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
+            if not self.model_config.is_attention_free():
+                gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu)
+            else:
+                gpu_cache_usage_sys = 0.0
 
         num_total_cpu = self.cache_config.num_cpu_blocks
         cpu_cache_usage_sys = 0.
@@ -1272,7 +1275,10 @@ class AphroditeEngine:
             num_free_cpu = sum(
                 scheduler.block_manager.get_num_free_cpu_blocks()
                 for scheduler in self.scheduler)
-            cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
+            if not self.model_config.is_attention_free():
+                cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
+            else:
+                cpu_cache_usage_sys = 0.0
 
         # Iteration stats
         num_prompt_tokens_iter = 0

+ 2 - 0
aphrodite/engine/args_tools.py

@@ -853,6 +853,7 @@ class EngineArgs:
             gpu_memory_utilization=self.gpu_memory_utilization,
             swap_space=self.swap_space,
             cache_dtype=self.kv_cache_dtype,
+            is_attention_free=model_config.is_attention_free(),
             num_gpu_blocks_override=self.num_gpu_blocks_override,
             sliding_window=model_config.get_sliding_window(),
             enable_prefix_caching=self.enable_prefix_caching,
@@ -939,6 +940,7 @@ class EngineArgs:
             max_num_batched_tokens=self.max_num_batched_tokens,
             max_num_seqs=self.max_num_seqs,
             max_model_len=model_config.max_model_len,
+            is_attention_free=model_config.is_attention_free(),
             use_v2_block_manager=self.use_v2_block_manager,
             num_lookahead_slots=(self.num_lookahead_slots
                                  if speculative_config is None else

+ 1 - 0
aphrodite/modeling/models/__init__.py

@@ -70,6 +70,7 @@ _GENERATION_MODELS = {
     "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
     "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
     "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
+    "MambaForCausalLM": ("mamba", "MambaForCausalLM"),
     "MedusaModel": ("medusa", "Medusa"),
     "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
     "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),

+ 1 - 1
aphrodite/modeling/models/interfaces.py

@@ -149,7 +149,7 @@ class HasInnerState(Protocol):
     """
         A flag that indicates this model has inner state.
         Models that has inner state usually need access to the scheduler_config
-        for max_num_seqs ,etc... (Currently only used by Jamba)
+        for max_num_seqs, etc. (Currently only used by Jamba and Mamba)
     """
 
     def __init__(self,

+ 25 - 205
aphrodite/modeling/models/jamba.py

@@ -1,7 +1,7 @@
 # coding=utf-8
 """Inference-only Jamba model."""
 from dataclasses import dataclass
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -34,6 +34,7 @@ from aphrodite.modeling.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.interfaces import HasInnerState
+from aphrodite.modeling.models.mamba_cache import MambaCacheManager
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
@@ -612,10 +613,7 @@ class JambaForCausalLM(nn.Module, HasInnerState):
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
         # Used to track and store by the Mamba cache between steps.
-        self.mamba_cache: Tuple[torch.Tensor, torch.Tensor] = tuple()
-        # Maps between the request id and a dict that maps between the seq_id
-        # and its index inside the self.mamba_cache
-        self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
+        self.mamba_cache: Optional[MambaCacheManager] = None
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
         self.sampler = Sampler()
@@ -627,8 +625,18 @@ class JambaForCausalLM(nn.Module, HasInnerState):
                 attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
                 **kwargs):
-        if not self.mamba_cache:
-            self._prepare_mamba_cache()
+        if self.mamba_cache is None:
+            max_batch_size = (_get_graph_batch_size(
+                self.scheduler_config.max_num_seqs) if self.scheduler_config
+                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)
+
+            layers_type = self.config.layers_block_type
+            num_mamba_layers = sum(
+                [layer_type == "mamba" for layer_type in layers_type])
+
+            self.mamba_cache = MambaCacheManager(
+                self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
+                *self._get_mamba_cache_shape())
 
         if "seqlen_agnostic_capture_inputs" not in kwargs:
             # We get here only on Prefill/Eager mode runs
@@ -638,199 +646,31 @@ class JambaForCausalLM(nn.Module, HasInnerState):
 
             request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
             finished_requests_ids = kwargs["finished_requests_ids"]
-            self._release_mamba_cache(finished_requests_ids)
+            self.mamba_cache.release_finished_requests(finished_requests_ids)
             batch_size = input_ids.shape[0]
             if attn_metadata.prefill_metadata:
                 batch_size = len(request_ids_to_seq_ids)
-            mamba_cache = self._prepare_current_run_mamba_cache(
+            mamba_cache_tensors = self.mamba_cache.prepare_current_run_state(
                 request_ids_to_seq_ids, batch_size, finished_requests_ids)
         else:
             # CUDA graph capturing runs
-            mamba_cache = kwargs["seqlen_agnostic_capture_inputs"]
+            mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"]
 
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, mamba_cache[0],
-                                   mamba_cache[1])
+                                   attn_metadata, mamba_cache_tensors[0],
+                                   mamba_cache_tensors[1])
         return hidden_states
 
-    def _swap_mamba_cache(self, from_index: int, to_index: int):
-        assert len(self.mamba_cache) > 0
-        for cache_t in self.mamba_cache:
-            cache_t[:, [to_index,from_index]] = \
-             cache_t[:, [from_index,to_index]]
-
-    def _copy_mamba_cache(self, from_index: int, to_index: int):
-        assert len(self.mamba_cache) > 0
-        for cache_t in self.mamba_cache:
-            cache_t[:, to_index].copy_(cache_t[:, from_index],
-                                       non_blocking=True)
-
-    def _move_out_if_already_occupied(self, index: int,
-                                      all_occupied_indices: List[int]):
-        if index in all_occupied_indices:
-            first_free_index = self._first_free_index_in_mamba_cache()
-            # In case occupied, move the occupied to a new empty block
-            self._move_cache_index_and_mappings(from_index=index,
-                                                to_index=first_free_index)
-
-    def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
-                                                       seq_id: int,
-                                                       destination_index: int):
-        """
-        Assign (req_id,seq_id) pair to a `destination_index` index, if
-        already occupied, move the occupying index to a free index.
-        """
-        all_occupied_indices = self._get_all_occupied_indices()
-        if cur_rid not in self.mamba_cache_indices_mapping:
-            self._move_out_if_already_occupied(
-                index=destination_index,
-                all_occupied_indices=all_occupied_indices)
-            self.mamba_cache_indices_mapping[cur_rid] = {
-                seq_id: destination_index
-            }
-        elif seq_id not in (seq_ids2indices :=
-                            self.mamba_cache_indices_mapping[cur_rid]):
-            # parallel sampling , where n > 1, assume prefill have
-            # already happened now we only need to copy the already
-            # existing cache into the siblings seq_ids caches
-            self._move_out_if_already_occupied(
-                index=destination_index,
-                all_occupied_indices=all_occupied_indices)
-            index_exists = list(seq_ids2indices.values())[0]
-            # case of decoding n>1, copy prefill cache to decoding indices
-            self._copy_mamba_cache(from_index=index_exists,
-                                   to_index=destination_index)
-            self.mamba_cache_indices_mapping[cur_rid][
-                seq_id] = destination_index
-        else:
-            # already exists
-            cache_index_already_exists = self.mamba_cache_indices_mapping[
-                cur_rid][seq_id]
-            if cache_index_already_exists != destination_index:
-                # In case the seq id already exists but not in
-                # the right destination, swap it with what's occupying it
-                self._swap_pair_indices_and_mappings(
-                    from_index=cache_index_already_exists,
-                    to_index=destination_index)
-
-    def _prepare_current_run_mamba_cache(
-            self, request_ids_to_seq_ids: Dict[str, list[int]],
-            batch_size: int, finished_requests_ids: List[str]):
-        running_indices = []
-        request_ids_to_seq_ids_flatten = [
-            (req_id, seq_id)
-            for req_id, seq_ids in request_ids_to_seq_ids.items()
-            for seq_id in seq_ids
-        ]
-        for dest_index, (request_id,
-                         seq_id) in enumerate(request_ids_to_seq_ids_flatten):
-            if request_id in finished_requests_ids:
-                # Do not allocate cache index for requests that run
-                # and finish right after
-                continue
-            self._assign_seq_id_to_mamba_cache_in_specific_dest(
-                request_id, seq_id, dest_index)
-            running_indices.append(dest_index)
-
-        self._clean_up_first_bs_blocks(batch_size, running_indices)
-        conv_state = self.mamba_cache[0][:, :batch_size]
-        temporal_state = self.mamba_cache[1][:, :batch_size]
-
-        return (conv_state, temporal_state)
-
-    def _get_all_occupied_indices(self):
-        return [
-            cache_idx
-            for seq_ids2indices in self.mamba_cache_indices_mapping.values()
-            for cache_idx in seq_ids2indices.values()
-        ]
-
-    def _clean_up_first_bs_blocks(self, batch_size: int,
-                                  indices_for_current_run: List[int]):
-        # move out all of the occupied but currently not running blocks
-        # outside of the first n blocks
-        destination_indices = set([range(batch_size)])
-        max_possible_batch_size = self.mamba_cache[0].shape[1]
-        for destination_index in destination_indices:
-            if destination_index in self._get_all_occupied_indices() and  \
-               destination_index not in indices_for_current_run:
-                # move not running indices outside of the batch
-                all_other_indices = list(
-                    range(batch_size, max_possible_batch_size))
-                first_avail_index = self._first_free_index_in_mamba_cache(
-                    all_other_indices)
-                self._swap_indices(from_index=destination_index,
-                                   to_index=first_avail_index)
-
-    def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
-        self._copy_mamba_cache(from_index=from_index, to_index=to_index)
-        self._update_mapping_index(from_index=from_index, to_index=to_index)
-
-    def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
-        self._swap_mamba_cache(from_index=from_index, to_index=to_index)
-        self._swap_mapping_index(from_index=from_index, to_index=to_index)
-
-    def _swap_mapping_index(self, from_index: int, to_index: int):
-        for seq_ids2index in self.mamba_cache_indices_mapping.values():
-            for seq_id, index in seq_ids2index.items():
-                if from_index == index:
-                    seq_ids2index.update({seq_id: to_index})
-                elif to_index == index:
-                    seq_ids2index.update({seq_id: from_index})
-
-    def _update_mapping_index(self, from_index: int, to_index: int):
-        for seq_ids2index in self.mamba_cache_indices_mapping.values():
-            for seq_id, index in seq_ids2index.items():
-                if from_index == index:
-                    seq_ids2index.update({seq_id: to_index})
-                    return
 
     def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
-        """
-        Copy the relevant Mamba cache into the CUDA graph input buffer 
-        that was provided during the capture runs 
-        (JambaForCausalLM.mamba_gc_cache_buffer). 
-        """
-        assert all(
-            key in kwargs
-            for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
-        finished_requests_ids = kwargs["finished_requests_ids"]
-        self._release_mamba_cache(finished_requests_ids)
-        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
-        cg_batch_size = input_buffers['input_ids'].shape[0]
-        self._prepare_current_run_mamba_cache(request_ids_to_seq_ids,
-                                              cg_batch_size,
-                                              finished_requests_ids)
+        return self.mamba_cache.copy_inputs_before_cuda_graphs(
+            input_buffers, **kwargs)
 
     def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
-        """
-        Provide the CUDA graph capture runs with a buffer in adjusted size.
-        The buffer is used to maintain the Mamba Cache during the CUDA graph 
-        replay runs.
-        """
-        return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
-
-    def _release_mamba_cache(self, finished_seq_groups_req_ids: List[str]):
-        for req_id in finished_seq_groups_req_ids:
-            if req_id in self.mamba_cache_indices_mapping:
-                self.mamba_cache_indices_mapping.pop(req_id)
-
-    def _first_free_index_in_mamba_cache(
-            self, indices_range: Optional[List[int]] = None) -> int:
-        assert self.mamba_cache is not None
-        if indices_range is None:
-            max_possible_batch_size = self.mamba_cache[0].shape[1]
-            indices_range = list(range(max_possible_batch_size))
-        all_occupied_indices = self._get_all_occupied_indices()
-        for i in indices_range:
-            if i not in all_occupied_indices:
-                return i
-        raise Exception("Couldn't find a free spot in the mamba cache! This"
-                        "should never happen")
+        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
 
     def _get_mamba_cache_shape(
-            self
-    ) -> Tuple[Optional[Tuple[int, int]], Optional[Tuple[int, int]]]:
+            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
         world_size = get_tensor_model_parallel_world_size()
         hidden_size = self.config.hidden_size
         conv_state_shape = (
@@ -838,31 +678,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
             self.config.mamba_d_conv,
         )
         temporal_state_shape = (
-            self.config.mamba_expand * self.config.hidden_size // world_size,
+            self.config.mamba_expand * hidden_size // world_size,
             self.config.mamba_d_state,
         )
         return conv_state_shape, temporal_state_shape
 
-    def _prepare_mamba_cache(self):
-        dtype = self.lm_head.weight.dtype
-        layers_type = self.config.layers_block_type
-        mamba_layers = sum(
-            [layer_type == "mamba" for layer_type in layers_type])
-        max_batch_size = (_get_graph_batch_size(
-            self.scheduler_config.max_num_seqs) if self.scheduler_config else
-                          max(_BATCH_SIZES_TO_CAPTURE) + 2)
-        conv_state_shape, temporal_state_shape = self._get_mamba_cache_shape()
-        assert conv_state_shape is not None and temporal_state_shape is not None
-
-        self.mamba_cache = (torch.empty(size=(mamba_layers, max_batch_size) +
-                                        conv_state_shape,
-                                        dtype=dtype,
-                                        device="cuda"),
-                            torch.empty(size=(mamba_layers, max_batch_size) +
-                                        temporal_state_shape,
-                                        dtype=dtype,
-                                        device="cuda"))
-
     def compute_logits(self, hidden_states: torch.Tensor,
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
         logits = self.logits_processor(self.lm_head, hidden_states,

+ 547 - 0
aphrodite/modeling/models/mamba.py

@@ -0,0 +1,547 @@
+# coding=utf-8
+"""PyTorch MAMBA model."""
+from dataclasses import dataclass
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from torch.nn.parameter import Parameter
+from transformers import MambaConfig
+
+from aphrodite.attention.backends.abstract import AttentionMetadata
+from aphrodite.common.config import CacheConfig, LoRAConfig, SchedulerConfig
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.distributed import (get_tensor_model_parallel_rank,
+                                   get_tensor_model_parallel_world_size)
+from aphrodite.modeling.layers.activation import SiluAndMul
+from aphrodite.modeling.layers.layernorm import RMSNorm
+from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
+                                              MergedColumnParallelLinear,
+                                              RowParallelLinear)
+from aphrodite.modeling.layers.logits_processor import LogitsProcessor
+from aphrodite.modeling.layers.mamba.ops.causal_conv1d import (
+    causal_conv1d_fn, causal_conv1d_update)
+from aphrodite.modeling.layers.mamba.ops.mamba_ssm import (
+    selective_scan_fn, selective_state_update)
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
+from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
+from aphrodite.modeling.models.interfaces import HasInnerState
+from aphrodite.modeling.models.mamba_cache import MambaCacheManager
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.quantization.base_config import QuantizationConfig
+from aphrodite.task_handler.model_runner import (_BATCH_SIZES_TO_CAPTURE,
+                                                 _get_graph_batch_size)
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+@dataclass
+class MambaCacheParams:
+    is_prompt: bool = False
+    conv_state: torch.Tensor = torch.Tensor()
+    ssm_state: torch.Tensor = torch.Tensor()
+
+
+# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
+class MambaMixer(nn.Module):
+    """
+    Compute ∆, A, B, C, and D, the state space parameters, and compute
+    the `contextualized_states`. A and D are input-independent
+    (see Mamba paper [1], Section 3.5.2, "Interpretation of A",
+    for why A isn't selective); ∆, B, and C are input-dependent
+    (this is a key difference between Mamba and the linear
+    time-invariant S4, and is why Mamba is called a
+    **selective** state space model).
+    """
+
+    def __init__(self, config: MambaConfig, layer_idx):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        self.hidden_size = config.hidden_size
+        self.ssm_state_size = config.state_size
+        self.conv_kernel_size = config.conv_kernel
+        self.intermediate_size = config.intermediate_size
+        self.time_step_rank = int(config.time_step_rank)
+        self.use_conv_bias = config.use_conv_bias
+
+        # TODO: ??
+        #self.use_bias = config.mamba_proj_bias
+        self.use_bias = False
+
+        self.conv1d = ColumnParallelLinear(
+            input_size=self.conv_kernel_size,
+            output_size=self.intermediate_size,
+            bias=self.use_conv_bias,
+        )
+        # Unsqueeze to fit the conv1d weight shape into the linear weight
+        # shape. Can't do this in `weight_loader` since it already exists in
+        # `ColumnParallelLinear`, and `set_weight_attrs` doesn't allow
+        # overriding it.
+        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+
+        self.in_proj = MergedColumnParallelLinear(self.hidden_size,
+                                                  [self.intermediate_size] * 2,
+                                                  bias=self.use_bias)
+        # selective projection used to make dt, B and C input dependent
+        self.x_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.time_step_rank + self.ssm_state_size * 2,
+            bias=False,
+        )
+        # time step projection (discretization) -
+        # In the forward we need to apply dt_proj without the bias,
+        # as the bias is added in the selective scan kernel.
+        self.dt_proj = ColumnParallelLinear(self.time_step_rank,
+                                            self.intermediate_size,
+                                            bias=True,
+                                            skip_bias_add=True)
+
+        def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
+            tp_rank = get_tensor_model_parallel_rank()
+            tp_size = get_tensor_model_parallel_world_size()
+            param.data.copy_(
+                loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
+                                         dim=0)[tp_rank])
+
+        def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
+            weight_loader(param, -torch.exp(loaded_weight.float()))
+
+        tp_size = get_tensor_model_parallel_world_size()
+        self.A = nn.Parameter(
+            torch.empty(
+                self.intermediate_size // tp_size,
+                self.ssm_state_size,
+                dtype=torch.float32,
+            ))
+        self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))
+
+        set_weight_attrs(self.D, {"weight_loader": weight_loader})
+        set_weight_attrs(self.A, {"weight_loader": A_weight_loader})
+
+        self.out_proj = RowParallelLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            bias=self.use_bias,
+            input_is_parallel=True,
+        )
+        self.activation = config.hidden_act
+
+    def mamba_forward(self,
+                      hidden_states: torch.Tensor,
+                      cache_params: MambaCacheParams = None):
+
+        # 1. Gated MLP's linear projection
+        projected_states = self.in_proj(hidden_states)[0].transpose(1, 2)
+        hidden_states, gate = projected_states.chunk(2, dim=1)
+
+        # 2. Convolution sequence transformation
+        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
+                                               self.conv1d.weight.size(2))
+        if cache_params is not None and not cache_params.is_prompt:
+            hidden_states = causal_conv1d_update(
+                hidden_states.squeeze(-1),
+                cache_params.conv_state,
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+            )
+            hidden_states = hidden_states.unsqueeze(-1)
+        else:
+            if cache_params is not None:
+                conv_states = nn.functional.pad(
+                    hidden_states,
+                    (self.conv_kernel_size - hidden_states.shape[-1], 0))
+                cache_params.conv_state.copy_(conv_states)
+
+            hidden_states, _ = causal_conv1d_fn(
+                hidden_states,
+                conv_weights,
+                self.conv1d.bias,
+                activation=self.activation,
+            )
+
+        # 3. State Space Model sequence transformation
+        # 3.a. input varying initialization of time_step, B and C
+        ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))[0]
+
+        time_step, B, C = torch.split(
+            ssm_parameters,
+            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
+            dim=-1,
+        )
+
+        # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't.
+
+        discrete_time_step = self.dt_proj(time_step)[0].transpose(1, 2)
+        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
+        time_proj_bias = (self.dt_proj.bias.float() if hasattr(
+            self.dt_proj, "bias") else None)
+        if cache_params is not None and not cache_params.is_prompt:
+            scan_outputs = selective_state_update(
+                cache_params.ssm_state,
+                hidden_states[..., 0],
+                discrete_time_step[..., 0],
+                self.A,
+                B[:, 0],
+                C[:, 0],
+                self.D,
+                gate[..., 0],
+                time_proj_bias,
+                dt_softplus=True,
+            ).unsqueeze(-1)
+        else:
+            scan_outputs, ssm_state = selective_scan_fn(
+                hidden_states,
+                discrete_time_step,
+                self.A,
+                B.transpose(1, 2),
+                C.transpose(1, 2),
+                self.D.float(),
+                gate,
+                time_proj_bias,
+                delta_softplus=True,
+                return_last_state=True,
+            )
+            if ssm_state is not None and cache_params is not None:
+                cache_params.ssm_state.copy_(ssm_state)
+
+        # 4. Final linear projection
+        contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))[0]
+        return contextualized_states
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        conv_state: torch.Tensor,
+        ssm_state: torch.Tensor,
+    ):
+        if attn_metadata.prefill_metadata is not None:
+            offset = 0
+            for i, prompt_len in enumerate(
+                    attn_metadata.prefill_metadata.seq_lens):
+                cache = MambaCacheParams(True,
+                                         conv_state=conv_state[i].unsqueeze(0),
+                                         ssm_state=ssm_state[i].unsqueeze(0))
+                hidden_states[offset:offset + prompt_len].copy_(
+                    self.mamba_forward(hidden_states[offset:offset +
+                                                     prompt_len].unsqueeze(0),
+                                       cache_params=cache)[0])
+                offset += prompt_len
+        else:
+            cache = MambaCacheParams(False,
+                                     conv_state=conv_state,
+                                     ssm_state=ssm_state)
+            hidden_states = self.mamba_forward(hidden_states.unsqueeze(1),
+                                               cache_params=cache)
+            hidden_states = hidden_states.squeeze(1)
+
+        return hidden_states
+
+
+class MambaMLP(nn.Module):
+
+    def __init__(
+        self,
+        config: MambaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        hidden_size = config.hidden_size
+        intermediate_size = config.intermediate_size
+        hidden_act = config.hidden_act
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size, [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           quant_config=quant_config)
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class MambaDecoderLayer(nn.Module):
+
+    def __init__(self,
+                 config: MambaConfig,
+                 layer_idx: int,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None) -> None:
+        super().__init__()
+        self.layer_idx = layer_idx
+        self.config = config
+        self.mixer = MambaMixer(config, layer_idx)
+
+        self.feed_forward = MambaMLP(config, quant_config=quant_config)
+        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
+        self.pre_ff_layernorm = RMSNorm(config.hidden_size,
+                                        eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor],
+        conv_state: torch.Tensor,
+        ssm_state: torch.Tensor,
+        **kwargs,
+    ):
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.norm(hidden_states)
+        else:
+            hidden_states, residual = self.norm(hidden_states, residual)
+
+        hidden_states = self.mixer(hidden_states, attn_metadata, conv_state,
+                                   ssm_state)
+        # Fully Connected
+        hidden_states, residual = self.pre_ff_layernorm(
+            hidden_states, residual)
+        hidden_states = self.feed_forward(hidden_states)
+        return hidden_states, residual
+
+
+class MambaModel(nn.Module):
+
+    def __init__(
+        self,
+        config: MambaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        lora_vocab = ((lora_config.lora_extra_vocab_size *
+                       (lora_config.max_loras or 1)) if lora_config else 0)
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.org_vocab_size = config.vocab_size
+
+        self.embeddings = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+
+        decoder_layers = []
+        for i in range(config.num_hidden_layers):
+            decoder_layers.append(
+                MambaDecoderLayer(config,
+                                  layer_idx=i,
+                                  cache_config=cache_config,
+                                  quant_config=quant_config))
+        self.layers = nn.ModuleList(decoder_layers)
+        self.norm_f = RMSNorm(config.hidden_size,
+                              eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        conv_state: torch.Tensor,
+        ssm_state: torch.Tensor,
+    ) -> torch.Tensor:
+        hidden_states = self.embeddings(input_ids)
+        residual = None
+
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            current_ssm_state = ssm_state[i]
+            current_conv_state = conv_state[i]
+
+            hidden_states, residual = layer(
+                positions=positions,
+                hidden_states=hidden_states,
+                attn_metadata=attn_metadata,
+                residual=residual,
+                conv_state=current_conv_state,
+                ssm_state=current_ssm_state,
+            )
+        hidden_states, _ = self.norm_f(hidden_states, residual)
+
+        return hidden_states
+
+
+class MambaForCausalLM(nn.Module, HasInnerState):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    embedding_modules = {
+        "embeddings": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
+    def __init__(
+        self,
+        config: MambaConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        lora_config: Optional[LoRAConfig] = None,
+        scheduler_config: Optional[SchedulerConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.scheduler_config = scheduler_config
+        self.backbone = MambaModel(config,
+                                   cache_config=cache_config,
+                                   quant_config=quant_config,
+                                   lora_config=lora_config)
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+
+        self.lm_head = self.backbone.embeddings
+
+        # Used to track and store the Mamba cache between steps.
+        self.mamba_cache: Optional[MambaCacheManager] = None
+
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.sampler = Sampler()
+
+    def forward(self,
+                input_ids: torch.Tensor,
+                positions: torch.Tensor,
+                kv_caches: List[KVCache],
+                attn_metadata: AttentionMetadata,
+                intermediate_tensors: Optional[IntermediateTensors] = None,
+                **kwargs):
+        if self.mamba_cache is None:
+            max_batch_size = (_get_graph_batch_size(
+                self.scheduler_config.max_num_seqs) if self.scheduler_config
+                              else max(_BATCH_SIZES_TO_CAPTURE) + 2)
+            self.mamba_cache = MambaCacheManager(
+                self.lm_head.weight.dtype, self.config.num_hidden_layers,
+                max_batch_size, *self._get_mamba_cache_shape())
+
+        if "seqlen_agnostic_capture_inputs" not in kwargs:
+            # We get here only on Prefill/Eager mode runs
+            assert all(
+                key in kwargs
+                for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
+
+            request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
+            finished_requests_ids = kwargs["finished_requests_ids"]
+            self.mamba_cache.release_finished_requests(finished_requests_ids)
+
+            batch_size = input_ids.shape[0]
+            if attn_metadata.prefill_metadata:
+                batch_size = len(request_ids_to_seq_ids)
+            mamba_cache_tensors = self.mamba_cache.prepare_current_run_state(
+                request_ids_to_seq_ids, batch_size, finished_requests_ids)
+
+        else:
+            # CUDA graph capturing runs
+            mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"]
+
+        hidden_states = self.backbone(input_ids, positions, kv_caches,
+                                      attn_metadata, mamba_cache_tensors[0],
+                                      mamba_cache_tensors[1])
+
+        return hidden_states
+
+    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
+        return self.mamba_cache.copy_inputs_before_cuda_graphs(
+            input_buffers, **kwargs)
+
+    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
+        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
+
+    def _get_mamba_cache_shape(
+            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+        world_size = get_tensor_model_parallel_world_size()
+        conv_state_shape = (
+            self.config.intermediate_size // world_size,
+            self.config.conv_kernel,
+        )
+        temporal_state_shape = (
+            self.config.intermediate_size // world_size,
+            self.config.state_size,
+        )
+        return conv_state_shape, temporal_state_shape
+
+    def compute_logits(self, hidden_states: torch.Tensor,
+                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            if "A_log" in name:
+                name = name.replace("A_log", "A")
+
+            if ".self_attn." in name:
+                name = name.replace(".self_attn", "")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
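
For reference, a minimal sketch of how _get_mamba_cache_shape above turns the
HF Mamba config into per-layer state shapes; the config values here are
hypothetical, not taken from any particular checkpoint:

    # Hypothetical config values, for illustration only.
    intermediate_size = 5120
    conv_kernel = 4
    state_size = 16
    tp_world_size = 2  # tensor-parallel world size

    conv_state_shape = (intermediate_size // tp_world_size, conv_kernel)     # (2560, 4)
    temporal_state_shape = (intermediate_size // tp_world_size, state_size)  # (2560, 16)
    # MambaCacheManager then allocates each state buffer with shape
    # (num_hidden_layers, max_batch_size) + state_shape.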

+ 215 - 0
aphrodite/modeling/models/mamba_cache.py

@@ -0,0 +1,215 @@
+from typing import Dict, List, Optional
+
+import torch
+
+
+class MambaCacheManager:
+
+    def __init__(self, dtype, num_mamba_layers, max_batch_size,
+                 conv_state_shape, temporal_state_shape):
+
+        conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
+                                 conv_state_shape,
+                                 dtype=dtype,
+                                 device="cuda")
+        temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
+                                     temporal_state_shape,
+                                     dtype=dtype,
+                                     device="cuda")
+
+        self.mamba_cache = (conv_state, temporal_state)
+
+        # Maps between the request id and a dict that maps between the seq_id
+        # and its index inside the self.mamba_cache
+        self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {}
+
+    def _swap_mamba_cache(self, from_index: int, to_index: int):
+        assert len(self.mamba_cache) > 0
+        for cache_t in self.mamba_cache:
+            cache_t[:, [to_index, from_index]] = \
+                cache_t[:, [from_index, to_index]]
+
+    def _copy_mamba_cache(self, from_index: int, to_index: int):
+        assert len(self.mamba_cache) > 0
+        for cache_t in self.mamba_cache:
+            cache_t[:, to_index].copy_(cache_t[:, from_index],
+                                       non_blocking=True)
+
+    def _move_out_if_already_occupied(self, index: int,
+                                      all_occupied_indices: List[int]):
+        if index in all_occupied_indices:
+            first_free_index = self._first_free_index_in_mamba_cache()
+            # In case occupied, move the occupied to a new empty block
+            self._move_cache_index_and_mappings(from_index=index,
+                                                to_index=first_free_index)
+
+    def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str,
+                                                       seq_id: int,
+                                                       destination_index: int):
+        """
+        Assign the (req_id, seq_id) pair to `destination_index`; if that
+        index is already occupied, move its current occupant to a free index.
+        """
+        all_occupied_indices = self._get_all_occupied_indices()
+        if cur_rid not in self.mamba_cache_indices_mapping:
+            self._move_out_if_already_occupied(
+                index=destination_index,
+                all_occupied_indices=all_occupied_indices)
+            self.mamba_cache_indices_mapping[cur_rid] = {
+                seq_id: destination_index
+            }
+        elif seq_id not in (seq_ids2indices :=
+                            self.mamba_cache_indices_mapping[cur_rid]):
+            # Parallel sampling (n > 1): prefill is assumed to have already
+            # happened, so we only need to copy the existing cache into the
+            # sibling seq_ids' cache slots.
+            self._move_out_if_already_occupied(
+                index=destination_index,
+                all_occupied_indices=all_occupied_indices)
+            index_exists = list(seq_ids2indices.values())[0]
+            # case of decoding n>1, copy prefill cache to decoding indices
+            self._copy_mamba_cache(from_index=index_exists,
+                                   to_index=destination_index)
+            self.mamba_cache_indices_mapping[cur_rid][
+                seq_id] = destination_index
+        else:
+            # already exists
+            cache_index_already_exists = self.mamba_cache_indices_mapping[
+                cur_rid][seq_id]
+            if cache_index_already_exists != destination_index:
+                # The seq_id already exists but is not at the right
+                # destination; swap it with whatever occupies that slot.
+                self._swap_pair_indices_and_mappings(
+                    from_index=cache_index_already_exists,
+                    to_index=destination_index)
+
+    def prepare_current_run_state(self,
+                                  request_ids_to_seq_ids: Dict[str, List[int]],
+                                  batch_size: int,
+                                  finished_requests_ids: List[str]):
+        running_indices = []
+        request_ids_to_seq_ids_flatten = [
+            (req_id, seq_id)
+            for req_id, seq_ids in request_ids_to_seq_ids.items()
+            for seq_id in seq_ids
+        ]
+        for dest_index, (request_id,
+                         seq_id) in enumerate(request_ids_to_seq_ids_flatten):
+            if request_id in finished_requests_ids:
+                # Do not allocate a cache index for requests that finish in
+                # the same step they run.
+                continue
+            self._assign_seq_id_to_mamba_cache_in_specific_dest(
+                request_id, seq_id, dest_index)
+            running_indices.append(dest_index)
+
+        self._clean_up_first_bs_blocks(batch_size, running_indices)
+        conv_state = self.mamba_cache[0][:, :batch_size]
+        temporal_state = self.mamba_cache[1][:, :batch_size]
+
+        return (conv_state, temporal_state)
+
+    def _get_all_occupied_indices(self):
+        return [
+            cache_idx
+            for seq_ids2indices in self.mamba_cache_indices_mapping.values()
+            for cache_idx in seq_ids2indices.values()
+        ]
+
+    def _clean_up_first_bs_blocks(self, batch_size: int,
+                                  indices_for_current_run: List[int]):
+        # Move any block that is occupied but not running in this batch out
+        # of the first `batch_size` slots.
+        destination_indices = set(range(batch_size))
+        max_possible_batch_size = self.mamba_cache[0].shape[1]
+        for destination_index in destination_indices:
+            if destination_index in self._get_all_occupied_indices() and  \
+               destination_index not in indices_for_current_run:
+                # move not running indices outside of the batch
+                all_other_indices = list(
+                    range(batch_size, max_possible_batch_size))
+                first_avail_index = self._first_free_index_in_mamba_cache(
+                    all_other_indices)
+                self._swap_pair_indices_and_mappings(
+                    from_index=destination_index,
+                    to_index=first_avail_index)
+
+    def _move_cache_index_and_mappings(self, from_index: int, to_index: int):
+        self._copy_mamba_cache(from_index=from_index, to_index=to_index)
+        self._update_mapping_index(from_index=from_index, to_index=to_index)
+
+    def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int):
+        self._swap_mamba_cache(from_index=from_index, to_index=to_index)
+        self._swap_mapping_index(from_index=from_index, to_index=to_index)
+
+    def _swap_mapping_index(self, from_index: int, to_index: int):
+        for seq_ids2index in self.mamba_cache_indices_mapping.values():
+            for seq_id, index in seq_ids2index.items():
+                if from_index == index:
+                    seq_ids2index.update({seq_id: to_index})
+                elif to_index == index:
+                    seq_ids2index.update({seq_id: from_index})
+
+    def _update_mapping_index(self, from_index: int, to_index: int):
+        for seq_ids2index in self.mamba_cache_indices_mapping.values():
+            for seq_id, index in seq_ids2index.items():
+                if from_index == index:
+                    seq_ids2index.update({seq_id: to_index})
+                    return
+
+    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
+        """
+        Copy the relevant Mamba cache into the CUDA graph input buffer
+        that was provided during the capture runs
+        (see get_seqlen_agnostic_capture_inputs).
+        """
+        assert all(
+            key in kwargs
+            for key in ["request_ids_to_seq_ids", "finished_requests_ids"])
+        finished_requests_ids = kwargs["finished_requests_ids"]
+        self.release_finished_requests(finished_requests_ids)
+        request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"]
+        cg_batch_size = input_buffers['input_ids'].shape[0]
+        self.prepare_current_run_state(request_ids_to_seq_ids, cg_batch_size,
+                                       finished_requests_ids)
+
+    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
+        """
+        Provide the CUDA graph capture runs with a buffer sliced to the
+        requested batch size. The buffer is used to maintain the Mamba cache
+        during the CUDA graph replay runs.
+        """
+        return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache)
+
+    def release_finished_requests(self,
+                                  finished_seq_groups_req_ids: List[str]):
+        for req_id in finished_seq_groups_req_ids:
+            if req_id in self.mamba_cache_indices_mapping:
+                self.mamba_cache_indices_mapping.pop(req_id)
+
+    def _first_free_index_in_mamba_cache(
+            self, indices_range: Optional[List[int]] = None) -> int:
+        assert self.mamba_cache is not None
+        if indices_range is None:
+            max_possible_batch_size = self.mamba_cache[0].shape[1]
+            indices_range = list(range(max_possible_batch_size))
+        all_occupied_indices = self._get_all_occupied_indices()
+        for i in indices_range:
+            if i not in all_occupied_indices:
+                return i
+        raise Exception("Couldn't find a free spot in the mamba cache! "
+                        "This should never happen.")
+
+    def initialize_tensors(self, dtype, num_mamba_layers, max_batch_size,
+                           conv_state_shape, temporal_state_shape):
+
+        conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
+                                 conv_state_shape,
+                                 dtype=dtype,
+                                 device="cuda")
+        temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
+                                     temporal_state_shape,
+                                     dtype=dtype,
+                                     device="cuda")
+
+        self.mamba_cache = (conv_state, temporal_state)
+        self.initialized = True
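
A minimal usage sketch of MambaCacheManager, mirroring the calls made from
MambaForCausalLM.forward above; the sizes and request/sequence ids are
illustrative, and the allocation requires a CUDA device:

    import torch

    # Illustrative sizes; the real values come from the model config and the
    # scheduler config.
    manager = MambaCacheManager(dtype=torch.float16,
                                num_mamba_layers=64,
                                max_batch_size=256,
                                conv_state_shape=(2560, 4),
                                temporal_state_shape=(2560, 16))

    # Eager/prefill path: assign (request_id, seq_id) pairs to cache slots and
    # get views over the first batch_size slots of both state buffers.
    conv_state, ssm_state = manager.prepare_current_run_state(
        request_ids_to_seq_ids={"req-0": [0]},
        batch_size=1,
        finished_requests_ids=[])

    # Once a request completes, its slot mapping is dropped and can be reused.
    manager.release_finished_requests(["req-0"])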

+ 4 - 4
aphrodite/processing/interfaces.py

@@ -37,10 +37,10 @@ class BlockSpaceManager(ABC):
                 BlockSpaceManagerV2)
             return BlockSpaceManagerV2
 
-        if version == "embedding":
-            from aphrodite.processing.embedding_model_block_manager import (
-                EmbeddingModelBlockSpaceManager)
-            return EmbeddingModelBlockSpaceManager
+        if version == "placeholder":
+            from aphrodite.processing.placeholder_block_space_manager import (
+                PlaceholderBlockSpaceManager)
+            return PlaceholderBlockSpaceManager
 
         raise ValueError(f"Unknown version {version=}")
 

+ 5 - 4
aphrodite/processing/embedding_model_block_manager.py → aphrodite/processing/placeholder_block_space_manager.py

@@ -4,9 +4,10 @@ from aphrodite.common.sequence import Sequence, SequenceGroup
 from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager
 
 
-class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
-    """An embedding version of BlockSpaceManager for use in environments
-    with embedding models where block management is not required.
+class PlaceholderBlockSpaceManager(BlockSpaceManager):
+    """A version of BlockSpaceManager for use in environments
+    where block management is not required. 
+    For example: embedding models or attention-free models like Mamba.
     This class provides the same interface as BlockSpaceManager, but its
     methods perform no actions or return simple values like True in specific
     actions. It's designed to be used in scenarios where the overhead of
@@ -36,7 +37,7 @@ class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
         seq: Sequence,
         num_lookahead_slots: int,
     ) -> List[Tuple[int, int]]:
-        return None  # type: ignore
+        return []
 
     def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
         pass

+ 3 - 2
aphrodite/processing/scheduler.py

@@ -308,8 +308,9 @@ class Scheduler:
         version = "v1"
         if self.scheduler_config.use_v2_block_manager:
             version = "v2"
-        if self.scheduler_config.embedding_mode:
-            version = "embedding"
+        if (self.scheduler_config.embedding_mode
+                or self.scheduler_config.is_attention_free):
+            version = "placeholder"
 
         BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class(
             version)
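
Taken together with the interfaces.py change above, attention-free models now
reuse the same no-op path that embedding models already used. A simplified
sketch of the dispatch, not the literal implementation:

    def select_block_manager(scheduler_config):
        # Simplified sketch of the version selection in the hunk above.
        version = "v2" if scheduler_config.use_v2_block_manager else "v1"
        if (scheduler_config.embedding_mode
                or scheduler_config.is_attention_free):
            version = "placeholder"  # no KV blocks to manage
        return BlockSpaceManager.get_block_space_manager_class(version)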

+ 6 - 9
aphrodite/task_handler/cache_engine.py

@@ -52,15 +52,12 @@ class CacheEngine:
             self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
 
         # Get attention backend.
-        self.attn_backend = get_attn_backend(
-            model_config.get_num_attention_heads(parallel_config, tp_rank),
-            self.head_size,
-            self.num_kv_heads,
-            model_config.get_sliding_window(),
-            model_config.dtype,
-            cache_config.cache_dtype,
-            self.block_size,
-        )
+        self.attn_backend = get_attn_backend(self.head_size,
+                                             model_config.get_sliding_window(),
+                                             model_config.dtype,
+                                             cache_config.cache_dtype,
+                                             self.block_size,
+                                             model_config.is_attention_free())
 
         # Initialize the cache.
         self.gpu_cache = self._allocate_kv_cache(
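
The same signature change is repeated in the CPU, OpenVINO, TPU, and XPU
runners below. Sketched once here with placeholder variable names (the actual
arguments come from the model and cache configs):

    attn_backend = get_attn_backend(
        head_size,          # model_config.get_head_size()
        sliding_window,     # model_config.get_sliding_window()
        dtype,              # model_config.dtype
        kv_cache_dtype,     # cache_config.cache_dtype
        block_size,         # cache_config.block_size
        is_attention_free,  # model_config.is_attention_free(); routes
    )                       # attention-free models to the placeholder backend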

+ 1 - 2
aphrodite/task_handler/cpu_model_runner.py

@@ -103,13 +103,12 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.attn_backend = get_attn_backend(
-            self.model_config.get_num_attention_heads(self.parallel_config),
             self.model_config.get_head_size(),
-            self.model_config.get_num_kv_heads(self.parallel_config),
             self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
+            self.model_config.is_attention_free(),
         )
 
         # Multi-modal data support

+ 1 - 2
aphrodite/task_handler/cpu_worker.py

@@ -57,13 +57,12 @@ class CPUCacheEngine:
 
         # Get attention backend.
         self.attn_backend = get_attn_backend(
-            self.model_config.get_num_attention_heads(self.parallel_config),
             self.model_config.get_head_size(),
-            self.model_config.get_num_kv_heads(self.parallel_config),
             self.model_config.get_sliding_window(),
             self.model_config.dtype,
             cache_config.cache_dtype,
             self.block_size,
+            self.model_config.is_attention_free(),
         )
 
         # Initialize the cache.

+ 5 - 8
aphrodite/task_handler/model_runner.py

@@ -840,18 +840,14 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.graph_block_tables = np.zeros(
             (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
             dtype=np.int32)
-        num_attn_heads = self.model_config.get_num_attention_heads(
-            self.parallel_config, self.tp_rank)
         self.attn_backend = get_attn_backend(
-            num_attn_heads,
             self.model_config.get_head_size(),
-            self.model_config.get_num_kv_heads(self.parallel_config,
-                                               self.tp_rank),
             self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
-        ) if num_attn_heads else None
+            self.model_config.is_attention_free(),
+        )
 
         # Multi-modal data support
         self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \
@@ -1723,8 +1719,9 @@ class CUDAGraphRunner:
         # Copy the input tensors to the input buffers.
         self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
         self.input_buffers["positions"].copy_(positions, non_blocking=True)
-        self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
-                                                 non_blocking=True)
+        if self.backend_name != "No attention":
+            self.input_buffers["slot_mapping"].copy_(
+                attn_metadata.slot_mapping, non_blocking=True)
         if self.backend_name != "flashinfer":
             self.input_buffers["seq_lens_tensor"].copy_(
                 attn_metadata.decode_metadata.seq_lens_tensor,

+ 0 - 2
aphrodite/task_handler/openvino_model_runner.py

@@ -68,9 +68,7 @@ class OpenVINOModelRunner:
         self.block_size = cache_config.block_size
 
         self.attn_backend = get_attn_backend(
-            self.model_config.get_num_attention_heads(self.parallel_config),
             self.model_config.get_head_size(),
-            self.model_config.get_num_kv_heads(self.parallel_config),
             self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,

+ 1 - 2
aphrodite/task_handler/openvino_worker.py

@@ -57,13 +57,12 @@ class OpenVINOCacheEngine:
 
         # Get attention backend.
         self.attn_backend = get_attn_backend(
-            self.model_config.get_num_attention_heads(self.parallel_config),
             self.head_size,
-            self.model_config.get_num_kv_heads(self.parallel_config),
             self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.cache_config.cache_dtype,
             self.block_size,
+            self.model_config.is_attention_free(),
         )
 
         # Initialize the cache.

+ 1 - 2
aphrodite/task_handler/tpu_model_runner.py

@@ -107,13 +107,12 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
             (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
             dtype=np.int32)
         self.attn_backend = get_attn_backend(
-            self.model_config.get_num_attention_heads(self.parallel_config),
             self.model_config.get_head_size(),
-            self.model_config.get_num_kv_heads(self.parallel_config),
             self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.cache_config.cache_dtype,
             self.block_size,
+            self.model_config.is_attention_free(),
             False,
         )
 

+ 18 - 9
aphrodite/task_handler/worker.py

@@ -224,11 +224,15 @@ class Worker(LocalOrDistributedWorkerBase):
             "not properly cleaned up before initializing Aphrodite.")
 
         cache_block_size = self.get_cache_block_size_bytes()
-        num_gpu_blocks = int(
-            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
-             peak_memory) // cache_block_size)
-        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
-                             cache_block_size)
+        if cache_block_size == 0:
+            num_gpu_blocks = 0
+            num_cpu_blocks = 0
+        else:
+            num_gpu_blocks = int(
+                (total_gpu_memory * self.cache_config.gpu_memory_utilization -
+                 peak_memory) // cache_block_size)
+            num_cpu_blocks = int(self.cache_config.swap_space_bytes //
+                                 cache_block_size)
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
         if self.model_runner.lora_manager:
@@ -245,6 +249,7 @@ class Worker(LocalOrDistributedWorkerBase):
         """
         raise_if_cache_size_invalid(num_gpu_blocks,
                                     self.cache_config.block_size,
+                                    self.cache_config.is_attention_free,
                                     self.model_config.max_model_len)
 
         self.cache_config.num_gpu_blocks = num_gpu_blocks
@@ -393,18 +398,22 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
                 "`dtype` flag in CLI, for example: --dtype=half.")
 
 
-def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
+def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
                                 max_model_len) -> None:
-    rank = get_tensor_model_parallel_rank()
-    if num_gpu_blocks <= 0:
+    if is_attention_free and num_gpu_blocks != 0:
+        raise ValueError("No memory should be allocated for the cache blocks "
+                         f"for an attention-free model, but {num_gpu_blocks}"
+                         "blocks are allocated.")
+    if not is_attention_free and num_gpu_blocks <= 0:
         raise ValueError("No available memory for the cache blocks. "
                          "Try increasing `gpu_memory_utilization` when "
                          "initializing the engine.")
     max_seq_len = block_size * num_gpu_blocks
+    rank = get_tensor_model_parallel_rank()
     if rank == 0:
         logger.info(f"Maximum sequence length allowed in the cache: "
                     f"{max_seq_len}")
-    if max_model_len > max_seq_len:
+    if not is_attention_free and max_model_len > max_seq_len:
         original_max_model_len = max_model_len
         max_model_len = max_seq_len
         # raise ValueError(
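
A worked example of the block-count logic in the first worker.py hunk above,
with made-up numbers; an attention-free model reports a cache block size of
zero, so no KV cache blocks are allocated at all:

    def compute_num_gpu_blocks(total_gpu_memory, gpu_memory_utilization,
                               peak_memory, cache_block_size):
        # Mirrors the sizing logic above (simplified).
        if cache_block_size == 0:  # attention-free model, e.g. Mamba
            return 0
        return max(
            int((total_gpu_memory * gpu_memory_utilization - peak_memory)
                // cache_block_size), 0)

    GiB = 1024 ** 3
    # Attention model: 24 GiB GPU, 90% utilization, 6 GiB peak usage,
    # 2 MiB per cache block -> 7987 blocks.
    compute_num_gpu_blocks(24 * GiB, 0.90, 6 * GiB, 2 * 1024 * 1024)
    # Attention-free model: cache_block_size == 0 -> 0 blocks.
    compute_num_gpu_blocks(24 * GiB, 0.90, 6 * GiB, 0)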

+ 1 - 2
aphrodite/task_handler/xpu_model_runner.py

@@ -113,13 +113,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
             if self.model_config is not None else 0)
 
         self.attn_backend = get_attn_backend(
-            self.model_config.get_num_attention_heads(self.parallel_config),
             self.model_config.get_head_size(),
-            self.model_config.get_num_kv_heads(self.parallel_config),
             self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
+            model_config.is_attention_free(),
         )
 
         # Multi-modal data support