
refactor: attention selector (#552)

* refactor the attention part

* refactor loader

* worker

* use cache_config in all models

* fixes

* .upper() the selector env variable
AlpinDale, 7 months ago
Parent commit 50b7c13db0
51 files changed with 595 additions and 267 deletions
  1. aphrodite/attention/backends/abstract.py (+2 -3)
  2. aphrodite/attention/backends/flash_attn.py (+7 -6)
  3. aphrodite/attention/backends/flashinfer.py (+23 -10)
  4. aphrodite/attention/backends/rocm_flash_attn.py (+9 -7)
  5. aphrodite/attention/backends/torch_sdpa.py (+17 -10)
  6. aphrodite/attention/backends/xformers.py (+6 -6)
  7. aphrodite/attention/layer.py (+17 -2)
  8. aphrodite/attention/selector.py (+23 -5)
  9. aphrodite/engine/aphrodite_engine.py (+6 -4)
  10. aphrodite/engine/async_aphrodite.py (+1 -1)
  11. aphrodite/modeling/model_loader/__init__.py (+10 -8)
  12. aphrodite/modeling/model_loader/loader.py (+44 -51)
  13. aphrodite/modeling/models/__init__.py (+1 -1)
  14. aphrodite/modeling/models/arctic.py (+13 -3)
  15. aphrodite/modeling/models/baichuan.py (+22 -7)
  16. aphrodite/modeling/models/bloom.py (+11 -4)
  17. aphrodite/modeling/models/chatglm.py (+14 -6)
  18. aphrodite/modeling/models/commandr.py (+11 -3)
  19. aphrodite/modeling/models/dbrx.py (+13 -4)
  20. aphrodite/modeling/models/decilm.py (+3 -1)
  21. aphrodite/modeling/models/deepseek.py (+13 -3)
  22. aphrodite/modeling/models/falcon.py (+11 -4)
  23. aphrodite/modeling/models/gemma.py (+10 -4)
  24. aphrodite/modeling/models/gpt2.py (+12 -4)
  25. aphrodite/modeling/models/gpt_bigcode.py (+10 -4)
  26. aphrodite/modeling/models/gpt_j.py (+15 -5)
  27. aphrodite/modeling/models/gpt_neox.py (+12 -4)
  28. aphrodite/modeling/models/internlm2.py (+10 -3)
  29. aphrodite/modeling/models/jais.py (+9 -3)
  30. aphrodite/modeling/models/llama.py (+22 -14)
  31. aphrodite/modeling/models/llama_embedding.py (+2 -0)
  32. aphrodite/modeling/models/llava.py (+4 -2)
  33. aphrodite/modeling/models/minicpm.py (+10 -3)
  34. aphrodite/modeling/models/mixtral.py (+12 -4)
  35. aphrodite/modeling/models/mixtral_quant.py (+21 -10)
  36. aphrodite/modeling/models/mpt.py (+13 -5)
  37. aphrodite/modeling/models/olmo.py (+10 -4)
  38. aphrodite/modeling/models/opt.py (+12 -4)
  39. aphrodite/modeling/models/orion.py (+10 -3)
  40. aphrodite/modeling/models/phi.py (+12 -4)
  41. aphrodite/modeling/models/qwen.py (+12 -3)
  42. aphrodite/modeling/models/qwen2.py (+10 -4)
  43. aphrodite/modeling/models/qwen2_moe.py (+13 -3)
  44. aphrodite/modeling/models/stablelm.py (+10 -4)
  45. aphrodite/modeling/models/starcoder2.py (+15 -3)
  46. aphrodite/modeling/models/xverse.py (+10 -4)
  47. aphrodite/task_handler/cache_engine.py (+11 -7)
  48. aphrodite/task_handler/cpu_model_runner.py (+11 -4)
  49. aphrodite/task_handler/cpu_worker.py (+9 -1)
  50. aphrodite/task_handler/embedding_model_runner.py (+0 -1)
  51. aphrodite/task_handler/model_runner.py (+11 -4)

+ 2 - 3
aphrodite/attention/backends/abstract.py

@@ -94,8 +94,6 @@ class AttentionMetadata(Generic[T]):
     # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
     # in block 0, and 1st slot in block 1, respectively.
     slot_mapping: torch.Tensor
-    # The kv cache's data type.
-    kv_cache_dtype: str
 
     def __post_init__(self):
         if self.num_prefill_tokens > 0:
@@ -116,6 +114,7 @@ class AttentionImpl(ABC):
         num_kv_heads: Optional[int] = None,
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
+        kv_cache_dtype: str = "auto",
     ) -> None:
         raise NotImplementedError
 
@@ -127,6 +126,6 @@ class AttentionImpl(ABC):
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        kv_scale: float,
+        kv_scale: float = 1.0,
     ) -> torch.Tensor:
         raise NotImplementedError
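
For orientation (this sketch is not part of the commit), here is what the reworked interface asks of a backend implementation: `kv_cache_dtype` is now handed to the constructor and kept on the impl object instead of being read from `AttentionMetadata`, and `forward` treats `kv_scale` as defaulting to `1.0`. The class below is a hypothetical no-op impl for illustration only, not one of the real backends.

```python
from typing import List, Optional

import torch

from aphrodite.attention.backends.abstract import (AttentionImpl,
                                                   AttentionMetadata)


class NoOpAttentionImpl(AttentionImpl):
    """Hypothetical backend impl following the updated interface."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        kv_cache_dtype: str = "auto",
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        # The KV cache dtype now lives on the impl, not on AttentionMetadata.
        self.kv_cache_dtype = kv_cache_dtype

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        # Not a real attention kernel: it just echoes the query.
        return query
```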

+ 7 - 6
aphrodite/attention/backends/flash_attn.py

@@ -141,16 +141,18 @@ class FlashAttentionImpl(AttentionImpl):
         num_kv_heads: Optional[int] = None,
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
+        kv_cache_dtype: str = "auto",
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
-        self.sliding_window = ((sliding_window, sliding_window)
-                               if sliding_window is not None else (-1, -1))
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
+        self.sliding_window = ((sliding_window, sliding_window)
+                               if sliding_window is not None else (-1, -1))
+        self.kv_cache_dtype = kv_cache_dtype
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -168,7 +170,7 @@ class FlashAttentionImpl(AttentionImpl):
         value: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata[FlashAttentionMetadata],
-        kv_scale: float,
+        kv_scale: float = 1.0,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
 
@@ -197,8 +199,7 @@ class FlashAttentionImpl(AttentionImpl):
             PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                 value_cache,
                                                 attn_metadata.slot_mapping,
-                                                attn_metadata.kv_cache_dtype,
-                                                kv_scale)
+                                                self.kv_cache_dtype, kv_scale)
 
         num_prefill_tokens = attn_metadata.num_prefill_tokens
         num_decode_tokens = attn_metadata.num_decode_tokens
@@ -265,7 +266,7 @@ class FlashAttentionImpl(AttentionImpl):
                 decode_meta.block_tables,
                 decode_meta.seq_lens_tensor,
                 decode_meta.max_seq_len,
-                attn_metadata.kv_cache_dtype,
+                self.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,

+ 23 - 10
aphrodite/attention/backends/flashinfer.py

@@ -150,20 +150,33 @@ class FlashInferImpl(AttentionImpl):
         num_kv_heads: Optional[int] = None,
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
+        kv_cache_dtype: str = "auto",
     ) -> None:
-        if sliding_window is not None:
-            raise ValueError("Sliding window is not supported in FlashInfer.")
-        self.sliding_window = (-1, -1)
-        self.alibi_slopes = alibi_slopes
-        self.scale = scale
         self.num_heads = num_heads
         self.head_size = head_size
+        self.scale = float(scale)
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        if alibi_slopes is not None:
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+        self.alibi_slopes = alibi_slopes
+        if sliding_window is not None:
+            raise ValueError("Sliding window is not supported in FlashInfer.")
+        self.sliding_window = (-1, -1)
+        self.kv_cache_dtype = kv_cache_dtype
 
-    def forward(self, query: torch.Tensor, key: torch.Tensor,
-                value: torch.Tensor, kv_cache: Optional[torch.Tensor],
-                attn_metadata: AttentionMetadata[FlashInferMetadata],
-                kv_scale: float):
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: Optional[torch.Tensor],
+        attn_metadata: AttentionMetadata[FlashInferMetadata],
+        kv_scale: float = 1.0,
+    ) -> torch.Tensor:
+        assert kv_scale == 1.0
         num_tokens, hidden_size = query.shape
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
@@ -184,7 +197,7 @@ class FlashInferImpl(AttentionImpl):
                 kv_cache[:, 0],
                 kv_cache[:, 1],
                 attn_metadata.slot_mapping.flatten(),
-                attn_metadata.kv_cache_dtype,
+                self.kv_cache_dtype,
             )
 
         if prefill_meta := attn_metadata.prefill_metadata:

+ 9 - 7
aphrodite/attention/backends/rocm_flash_attn.py

@@ -137,25 +137,27 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         num_kv_heads: Optional[int] = None,
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
+        kv_cache_dtype: str = "auto",
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
-        self.sliding_window = ((sliding_window, sliding_window)
-                               if sliding_window is not None else (-1, -1))
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
+        self.sliding_window = ((sliding_window, sliding_window)
+                               if sliding_window is not None else (-1, -1))
+        self.kv_cache_dtype = kv_cache_dtype
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        suppored_head_sizes = PagedAttention.get_supported_head_sizes()
-        if head_size not in suppored_head_sizes:
+        supported_head_sizes = PagedAttention.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
-                f"Supported head sizes are: {suppored_head_sizes}.")
+                f"Supported head sizes are: {supported_head_sizes}.")
 
         self.use_naive_attn = False
         # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
@@ -230,7 +232,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 key_cache,
                 value_cache,
                 attn_metadata.slot_mapping,
-                attn_metadata.kv_cache_dtype,
+                self.kv_cache_dtype,
                 kv_scale,
             )
 
@@ -324,7 +326,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 decode_meta.block_tables,
                 decode_meta.seq_lens_tensor,
                 decode_meta.max_seq_len,
-                attn_metadata.kv_cache_dtype,
+                self.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,

+ 17 - 10
aphrodite/attention/backends/torch_sdpa.py

@@ -84,26 +84,32 @@ class TorchSDPABackendImpl(AttentionImpl):
         num_kv_heads: Optional[int] = None,
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
+        kv_cache_dtype: str = "auto",
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
-        self.sliding_window = sliding_window
         if alibi_slopes is not None:
-            assert len(alibi_slopes) == num_heads
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
-        self.need_mask = (self.alibi_slopes is not None
-                          or self.sliding_window is not None)
+        self.sliding_window = sliding_window
+        self.kv_cache_dtype = kv_cache_dtype
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-        suppored_head_sizes = PagedAttention.get_supported_head_sizes()
-        if head_size not in suppored_head_sizes:
+        self.need_mask = (self.alibi_slopes is not None
+                          or self.sliding_window is not None)
+
+        supported_head_sizes = PagedAttention.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
-                f"Supported head sizes are: {suppored_head_sizes}.")
+                f"Supported head sizes are: {supported_head_sizes}.")
+        if kv_cache_dtype != "auto":
+            raise NotImplementedError(
+                "Torch SDPA backend does not support FP8 KV cache. "
+                "Please use xFormers backend instead.")
 
     def forward(
         self,
@@ -112,7 +118,7 @@ class TorchSDPABackendImpl(AttentionImpl):
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
         attn_metadata: TorchSDPAMetadata,  # type: ignore
-        kv_scale: float,
+        kv_scale: float = 1.0,
     ) -> torch.Tensor:
         """Forward pass with torch SDPA and PagedAttention.
 
@@ -125,6 +131,7 @@ class TorchSDPABackendImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        assert kv_scale == 1.0
         num_tokens, hidden_size = query.shape
         # Reshape the query, key, and value tensors.
         query = query.view(-1, self.num_heads, self.head_size)
@@ -137,7 +144,7 @@ class TorchSDPABackendImpl(AttentionImpl):
             PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                 value_cache,
                                                 attn_metadata.slot_mapping,
-                                                attn_metadata.kv_cache_dtype,
+                                                self.kv_cache_dtype,
                                                 kv_scale)
 
         if attn_metadata.is_prompt:
@@ -196,7 +203,7 @@ class TorchSDPABackendImpl(AttentionImpl):
                 attn_metadata.block_tables,
                 attn_metadata.seq_lens_tensor,
                 attn_metadata.max_seq_len,
-                attn_metadata.kv_cache_dtype,
+                self.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,

+ 6 - 6
aphrodite/attention/backends/xformers.py

@@ -147,15 +147,17 @@ class XFormersImpl(AttentionImpl):
         num_kv_heads: Optional[int] = None,
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
+        kv_cache_dtype: str = "auto",
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
-        self.sliding_window = sliding_window
         if alibi_slopes is not None:
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
+        self.sliding_window = sliding_window
+        self.kv_cache_dtype = kv_cache_dtype
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -173,7 +175,7 @@ class XFormersImpl(AttentionImpl):
         value: torch.Tensor,
         kv_cache: Optional[torch.Tensor],
         attn_metadata: AttentionMetadata[XFormersMetadata],
-        kv_scale: float,
+        kv_scale: float = 1.0,
     ) -> torch.Tensor:
         """Forward pass with xFormers and PagedAttention.
 
@@ -186,7 +188,6 @@ class XFormersImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        num_tokens, hidden_size = query.shape
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
@@ -201,8 +202,7 @@ class XFormersImpl(AttentionImpl):
             PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                 value_cache,
                                                 attn_metadata.slot_mapping,
-                                                attn_metadata.kv_cache_dtype,
-                                                kv_scale)
+                                                self.kv_cache_dtype, kv_scale)
 
         num_prefill_tokens = attn_metadata.num_prefill_tokens
         num_decode_tokens = attn_metadata.num_decode_tokens
@@ -260,7 +260,7 @@ class XFormersImpl(AttentionImpl):
                 decode_meta.block_tables,
                 decode_meta.seq_lens_tensor,
                 decode_meta.max_seq_len,
-                attn_metadata.kv_cache_dtype,
+                self.kv_cache_dtype,
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,

+ 17 - 2
aphrodite/attention/layer.py

@@ -7,6 +7,7 @@ import torch.nn as nn
 from aphrodite.attention.backends.abstract import (AttentionMetadata,
                                                    AttentionMetadataPerStage)
 from aphrodite.attention.selector import get_attn_backend
+from aphrodite.common.config import CacheConfig
 
 
 class Attention(nn.Module):
@@ -27,10 +28,24 @@ class Attention(nn.Module):
         num_kv_heads: Optional[int] = None,
         alibi_slopes: Optional[List[float]] = None,
         sliding_window: Optional[int] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
-        self.backend = get_attn_backend(torch.get_default_dtype())
-        impl_cls = self.backend.get_impl_cls()
+        if cache_config is not None:
+            kv_cache_dtype = cache_config.cache_dtype
+            block_size = cache_config.block_size
+        else:
+            kv_cache_dtype = "auto"
+            block_size = 16
+        if num_kv_heads is None:
+            num_kv_heads = num_heads
+        # During model initialization, the default dtype is set as the model
+        # weight and activation dtype.
+        dtype = torch.get_default_dtype()
+        attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
+                                        sliding_window, dtype, kv_cache_dtype,
+                                        block_size)
+        impl_cls = attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window)
 

+ 23 - 5
aphrodite/attention/selector.py

@@ -1,7 +1,7 @@
 import enum
 import os
 from functools import lru_cache
-from typing import Type
+from typing import Optional, Type
 
 import torch
 from loguru import logger
@@ -20,8 +20,18 @@ class _Backend(enum.Enum):
 
 
 @lru_cache(maxsize=None)
-def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
-    backend = _which_attn_to_use(dtype)
+def get_attn_backend(
+    num_heads: int,
+    head_size: int,
+    num_kv_heads: int,
+    sliding_window: Optional[int],
+    dtype: torch.dtype,
+    kv_cache_dtype: Optional[str],
+    block_size: int,
+) -> Type[AttentionBackend]:
+    backend = _which_attn_to_use(num_heads, head_size, num_kv_heads,
+                                 sliding_window, dtype, kv_cache_dtype,
+                                 block_size)
     if backend == _Backend.FLASH_ATTN:
         logger.info("Using FlashAttention backend.")
         from aphrodite.attention.backends.flash_attn import \
@@ -50,7 +60,15 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
         raise ValueError("Invalid attention backend.")
 
 
-def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
+def _which_attn_to_use(
+    num_heads: int,
+    head_size: int,
+    num_kv_heads: int,
+    sliding_window: Optional[int],
+    dtype: torch.dtype,
+    kv_cache_dtype: Optional[str],
+    block_size: int,
+) -> _Backend:
     """Returns which flash attention backend to use."""
     if is_cpu():
         return _Backend.TORCH_SDPA
@@ -85,7 +103,7 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
 
     backend_by_env_var = os.getenv(APHRODITE_ATTENTION_BACKEND)
     if backend_by_env_var is not None:
-        return _Backend[backend_by_env_var]
+        return _Backend[backend_by_env_var.upper()]
 
     # Default case.
     return _Backend.FLASH_ATTN
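
A sketch of the selector's new call shape and of the now case-insensitive env-var override (illustrative values; it assumes the `APHRODITE_ATTENTION_BACKEND` constant holds an environment-variable name with the same spelling).

```python
import os

import torch

from aphrodite.attention.selector import get_attn_backend

# Thanks to .upper(), a lowercase value now resolves to _Backend.XFORMERS.
os.environ["APHRODITE_ATTENTION_BACKEND"] = "xformers"

# The selector is keyed (and lru_cached) on the full model/cache description.
backend = get_attn_backend(
    num_heads=32,
    head_size=128,
    num_kv_heads=8,
    sliding_window=None,
    dtype=torch.float16,
    kv_cache_dtype="auto",
    block_size=16,
)
impl_cls = backend.get_impl_cls()
```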

+ 6 - 4
aphrodite/engine/aphrodite_engine.py

@@ -326,10 +326,12 @@ class AphroditeEngine:
         Details:
             - Set arrival_time to the current time if it is None.
             - Set prompt_token_ids to the encoded prompt if it is None.
-            - Create `best_of` number of :class:`~aphrodite.Sequence` objects.
-            - Create a :class:`~aphrodite.SequenceGroup` object
-              from the list of :class:`~aphrodite.Sequence`.
-            - Add the :class:`~aphrodite.SequenceGroup` object to the scheduler.
+            - Create `best_of` number of :class:`~aphrodite.common.sequence`
+                objects.
+            - Create a :class:`~aphrodite.common.sequenceGroup` object
+              from the list of :class:`~aphrodite.common.sequence`.
+            - Add the :class:`~aphrodite.common.sequenceGroup` object to the
+                scheduler.
 
         Example:
             >>> # initialize engine

+ 1 - 1
aphrodite/engine/async_aphrodite.py

@@ -687,7 +687,7 @@ class AsyncAphrodite:
         Details:
             - If the engine is not running, start the background loop,
               which iteratively invokes
-              :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
+              :meth:`~aphrodite.engine.async_aphrodite.AsyncAphrodite.engine_step`
               to process the waiting requests.
             - Add the request to the engine's `RequestTracker`.
               On the next background loop, this request will be sent to

+ 10 - 8
aphrodite/modeling/model_loader/__init__.py

@@ -2,8 +2,8 @@ from typing import Optional
 
 from torch import nn
 
-from aphrodite.common.config import (DeviceConfig, LoadConfig, LoRAConfig,
-                                     ModelConfig, ParallelConfig,
+from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
+                                     LoRAConfig, ModelConfig, ParallelConfig,
                                      SchedulerConfig, VisionLanguageConfig)
 from aphrodite.modeling.model_loader.loader import (BaseModelLoader,
                                                     get_model_loader)
@@ -11,18 +11,20 @@ from aphrodite.modeling.model_loader.utils import (get_architecture_class_name,
                                                    get_model_architecture)
 
 
-def get_model(
-        *, model_config: ModelConfig, load_config: LoadConfig,
-        device_config: DeviceConfig, parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig],
-        vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
+def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
+              device_config: DeviceConfig, parallel_config: ParallelConfig,
+              scheduler_config: SchedulerConfig,
+              lora_config: Optional[LoRAConfig],
+              vision_language_config: Optional[VisionLanguageConfig],
+              cache_config: CacheConfig) -> nn.Module:
     loader = get_model_loader(load_config)
     return loader.load_model(model_config=model_config,
                              device_config=device_config,
                              lora_config=lora_config,
                              vision_language_config=vision_language_config,
                              parallel_config=parallel_config,
-                             scheduler_config=scheduler_config)
+                             scheduler_config=scheduler_config,
+                             cache_config=cache_config)
 
 
 __all__ = [
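
For reference, a hedged sketch of the new entry point: `cache_config` becomes a required keyword argument of `get_model` and is forwarded through `loader.load_model()` into each model constructor. The wrapper name below is made up, and the config objects are assumed to come from the engine's existing configuration.

```python
from torch import nn

from aphrodite.modeling.model_loader import get_model


def build_model(model_config, load_config, device_config, parallel_config,
                scheduler_config, cache_config) -> nn.Module:
    """Hypothetical wrapper mirroring the updated get_model() signature."""
    return get_model(model_config=model_config,
                     load_config=load_config,
                     device_config=device_config,
                     parallel_config=parallel_config,
                     scheduler_config=scheduler_config,
                     lora_config=None,
                     vision_language_config=None,
                     cache_config=cache_config)
```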

+ 44 - 51
aphrodite/modeling/model_loader/loader.py

@@ -1,19 +1,17 @@
 # ruff: noqa: SIM117
 import copy
-import gc
 import glob
 import os
 from abc import ABC, abstractmethod
-from contextlib import nullcontext
-from typing import (Any, Dict, Generator, List, Optional, Tuple, Type)
+from typing import Any, Dict, Generator, List, Optional, Tuple, Type
 
 import huggingface_hub
 import torch
 from torch import nn
 
-from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, DeviceConfig,
-                                     LoadConfig, LoadFormat, LoRAConfig,
-                                     ModelConfig, ParallelConfig,
+from aphrodite.common.config import (APHRODITE_USE_MODELSCOPE, CacheConfig,
+                                     DeviceConfig, LoadConfig, LoadFormat,
+                                     LoRAConfig, ModelConfig, ParallelConfig,
                                      SchedulerConfig, VisionLanguageConfig)
 from aphrodite.modeling.model_loader.tensorizer import (
     TensorizerConfig, is_aphrodite_serialized_tensorizer, load_with_tensorizer,
@@ -26,8 +24,6 @@ from aphrodite.modeling.model_loader.weight_utils import (
     pt_weights_iterator, safetensors_weights_iterator)
 from aphrodite.modeling.models.llava import LlavaForConditionalGeneration
 from aphrodite.quantization.base_config import QuantizationConfig
-from aphrodite.quantization.bitsandbytes import (BNBLinearMethod,
-                                                 replace_quant_params)
 
 _VISION_MODEL_CLASSES = [
     LlavaForConditionalGeneration,
@@ -78,15 +74,16 @@ def _get_model_initialization_kwargs(
     return extra_kwargs
 
 
-def _initialize_model(
-        model_config: ModelConfig, load_config: LoadConfig,
-        lora_config: Optional[LoRAConfig],
-        vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
+def _initialize_model(model_config: ModelConfig, load_config: LoadConfig,
+                      lora_config: Optional[LoRAConfig],
+                      vision_language_config: Optional[VisionLanguageConfig],
+                      cache_config: CacheConfig) -> nn.Module:
     """Initialize a model with the given configurations."""
     model_class = get_model_architecture(model_config)[0]
     quant_config = _get_quantization_config(model_config, load_config)
 
     return model_class(config=model_config.hf_config,
+                       cache_config=cache_config,
                        quant_config=quant_config,
                        **_get_model_initialization_kwargs(
                            model_class, lora_config, vision_language_config))
@@ -104,7 +101,8 @@ class BaseModelLoader(ABC):
                    lora_config: Optional[LoRAConfig],
                    vision_language_config: Optional[VisionLanguageConfig],
                    parallel_config: ParallelConfig,
-                   scheduler_config: SchedulerConfig) -> nn.Module:
+                   scheduler_config: SchedulerConfig,
+                   cache_config: CacheConfig) -> nn.Module:
         """Load a model with the given configurations."""
         ...
 
@@ -218,18 +216,13 @@ class DefaultModelLoader(BaseModelLoader):
                    lora_config: Optional[LoRAConfig],
                    vision_language_config: Optional[VisionLanguageConfig],
                    parallel_config: ParallelConfig,
-                   scheduler_config: SchedulerConfig) -> nn.Module:
+                   scheduler_config: SchedulerConfig,
+                   cache_config: CacheConfig) -> nn.Module:
         with set_default_torch_dtype(model_config.dtype):
-            linear_method = _get_quantization_config(model_config,
-                                                     self.load_config)
-
-            context = torch.device(device_config.device) if not (
-                isinstance(linear_method, BNBLinearMethod)
-                and linear_method.quant_config.from_float) else nullcontext()
-
-            with context:
+            with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config,
-                                          lora_config, vision_language_config)
+                                          lora_config, vision_language_config,
+                                          cache_config)
             model.load_weights(
                 self._get_weights_iterator(model_config.model,
                                            model_config.revision,
@@ -246,18 +239,6 @@ class DefaultModelLoader(BaseModelLoader):
                 if hasattr(module, "process_weights_after_loading"):
                     module.process_weights_after_loading()
 
-            if isinstance(linear_method, BNBLinearMethod):
-                replace_quant_params(
-                    model,
-                    quant_config=linear_method.quant_config,
-                    modules_to_not_convert="lm_head",
-                )
-                torch.cuda.synchronize()
-                if linear_method.quant_config.from_float:
-                    model = model.cuda()
-                gc.collect()
-                torch.cuda.empty_cache()
-
         return model.eval()
 
 
@@ -275,12 +256,14 @@ class DummyModelLoader(BaseModelLoader):
                    lora_config: Optional[LoRAConfig],
                    vision_language_config: Optional[VisionLanguageConfig],
                    parallel_config: ParallelConfig,
-                   scheduler_config: SchedulerConfig) -> nn.Module:
+                   scheduler_config: SchedulerConfig,
+                   cache_config: CacheConfig) -> nn.Module:
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config,
-                                          lora_config, vision_language_config)
-            # NOTE(woosuk): For accurate performance evaluation, we assign
+                                          lora_config, vision_language_config,
+                                          cache_config)
+            # NOTE: For accurate performance evaluation, we assign
             # random values to the weights.
             initialize_dummy_weights(model)
         return model.eval()
@@ -308,9 +291,12 @@ class TensorizerLoader(BaseModelLoader):
         return tensorizer_weights_iterator(tensorizer_args)
 
     def _load_model_unserialized(
-            self, model_config: ModelConfig, device_config: DeviceConfig,
-            lora_config: Optional[LoRAConfig],
-            vision_language_config: Optional[VisionLanguageConfig]
+        self,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig],
+        cache_config: CacheConfig,
     ) -> nn.Module:
         """Load an unserialized model with tensorizer.
 
@@ -321,20 +307,23 @@ class TensorizerLoader(BaseModelLoader):
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(model_config, self.load_config,
-                                          lora_config, vision_language_config)
+                                          lora_config, vision_language_config,
+                                          cache_config)
 
             model.load_weights(self._get_weights_iterator())
         return model.eval()
 
     def _load_model_serialized(
-            self, model_config: ModelConfig, device_config: DeviceConfig,
-            lora_config: Optional[LoRAConfig],
-            vision_language_config: Optional[VisionLanguageConfig]
+        self,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+        vision_language_config: Optional[VisionLanguageConfig],
+        cache_config: CacheConfig,
     ) -> nn.Module:
         """Load a serialized model with tensorizer.
-
-        See the examples/tensorize_aphrodite_model.py example "
-        script for serializing Aphrodite models."""
+        See the examples/tensorize_vllm_model.py example "
+        script for serializing vLLM models."""
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model_class = get_model_architecture(model_config)[0]
@@ -343,6 +332,7 @@ class TensorizerLoader(BaseModelLoader):
                 extra_kwargs = _get_model_initialization_kwargs(
                     model_class, lora_config, vision_language_config)
                 extra_kwargs["quant_config"] = quant_config
+                extra_kwargs["cache_config"] = cache_config
 
                 tensorizer_config = copy.copy(self.tensorizer_config)
                 tensorizer_config.model_class = model_class
@@ -357,16 +347,19 @@ class TensorizerLoader(BaseModelLoader):
                    lora_config: Optional[LoRAConfig],
                    vision_language_config: Optional[VisionLanguageConfig],
                    parallel_config: ParallelConfig,
-                   scheduler_config: SchedulerConfig) -> nn.Module:
+                   scheduler_config: SchedulerConfig,
+                   cache_config: CacheConfig) -> nn.Module:
         self._verify_config(model_config, parallel_config)
 
         if is_aphrodite_serialized_tensorizer(self.tensorizer_config):
             return self._load_model_serialized(model_config, device_config,
                                                lora_config,
-                                               vision_language_config)
+                                               vision_language_config,
+                                               cache_config)
         return self._load_model_unserialized(model_config, device_config,
                                              lora_config,
-                                             vision_language_config)
+                                             vision_language_config,
+                                             cache_config)
 
 
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:

+ 1 - 1
aphrodite/modeling/models/__init__.py

@@ -2,8 +2,8 @@ import importlib
 from typing import Dict, List, Optional, Type
 
 import torch.nn as nn
-from loguru import logger
 
+from loguru import logger
 from aphrodite.common.utils import is_hip
 
 # Architecture -> (module, class).

+ 13 - 3
aphrodite/modeling/models/arctic.py

@@ -6,6 +6,7 @@ from loguru import logger
 from torch import nn
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
@@ -212,6 +213,7 @@ class ArcticAttention(nn.Module):
         self,
         config: ArcticConfig,
         layer_idx: Optional[int] = None,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -262,7 +264,8 @@ class ArcticAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
-                              num_kv_heads=self.num_kv_heads)
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -285,6 +288,7 @@ class ArcticDecoderLayer(nn.Module):
         self,
         config: ArcticConfig,
         layer_idx: int,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -294,6 +298,7 @@ class ArcticDecoderLayer(nn.Module):
         self.use_residual = config.use_residual and is_moe_layer
         self.self_attn = ArcticAttention(config,
                                          layer_idx,
+                                         cache_config,
                                          quant_config=quant_config)
         self.block_sparse_moe = ArcticMoE(
             config,
@@ -353,6 +358,7 @@ class ArcticModel(nn.Module):
     def __init__(
         self,
         config: ArcticConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -363,7 +369,10 @@ class ArcticModel(nn.Module):
             config.hidden_size,
             org_num_embeddings=self.vocab_size)
         self.layers = nn.ModuleList([
-            ArcticDecoderLayer(config, layer_idx, quant_config=quant_config)
+            ArcticDecoderLayer(config,
+                               layer_idx,
+                               cache_config,
+                               quant_config=quant_config)
             for layer_idx in range(config.num_hidden_layers)
         ])
         self._attn_implementation = config._attn_implementation
@@ -389,11 +398,12 @@ class ArcticForCausalLM(nn.Module):
 
     def __init__(self,
                  config: ArcticConfig,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  **kwargs) -> None:
         super().__init__()
         self.config = config
-        self.model = ArcticModel(config, quant_config)
+        self.model = ArcticModel(config, cache_config, quant_config)
         self.vocab_size = config.vocab_size
         self.lm_head = ParallelLMHead(
             self.vocab_size,
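
The remaining model files follow the same pattern as arctic.py: each module gains an optional `cache_config` parameter and threads it down to the `Attention` layer. A hypothetical, stripped-down block showing just that pattern (the class and its parameters are illustrative, not from the commit):

```python
from typing import Optional

from torch import nn

from aphrodite.attention import Attention
from aphrodite.common.config import CacheConfig
from aphrodite.quantization.base_config import QuantizationConfig


class ToyAttentionBlock(nn.Module):
    """Illustrates the recurring change: accept cache_config in __init__
    and hand it to the Attention layer."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        head_dim = hidden_size // num_heads
        self.attn = Attention(num_heads,
                              head_dim,
                              head_dim**-0.5,
                              cache_config=cache_config)
```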

+ 22 - 7
aphrodite/modeling/models/baichuan.py

@@ -26,7 +26,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
-from aphrodite.common.config import LoRAConfig
+from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
@@ -110,6 +110,7 @@ class BaiChuanAttention(nn.Module):
         position_embedding: str,
         rope_theta: float = 10000,
         max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -161,7 +162,10 @@ class BaiChuanAttention(nn.Module):
                 base=self.rope_theta,
             )
             self.scaling = self.head_dim**-0.5
-            self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
+            self.attn = Attention(self.num_heads,
+                                  self.head_dim,
+                                  self.scaling,
+                                  cache_config=cache_config)
 
     def forward(
         self,
@@ -184,6 +188,7 @@ class BaiChuanDecoderLayer(nn.Module):
     def __init__(self,
                  config: PretrainedConfig,
                  position_embedding: str,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -196,6 +201,7 @@ class BaiChuanDecoderLayer(nn.Module):
             position_embedding=position_embedding,
             rope_theta=rope_theta,
             max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
             quant_config=quant_config,
         )
         self.mlp = BaiChuanMLP(
@@ -243,6 +249,7 @@ class BaiChuanModel(nn.Module):
     def __init__(self,
                  config: PretrainedConfig,
                  position_embedding: str,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
@@ -254,7 +261,8 @@ class BaiChuanModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            BaiChuanDecoderLayer(config, position_embedding, quant_config)
+            BaiChuanDecoderLayer(config, position_embedding, cache_config,
+                                 quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -303,13 +311,15 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self,
         config,
         position_embedding: str,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = BaiChuanModel(config, position_embedding, quant_config)
+        self.model = BaiChuanModel(config, position_embedding, cache_config,
+                                   quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
@@ -388,13 +398,16 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
     def __init__(
         self,
         config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ):
         if config.hidden_size == 4096:  # baichuan2 7b
-            super().__init__(config, "ROPE", quant_config, lora_config)
+            super().__init__(config, "ROPE", cache_config, quant_config,
+                             lora_config)
         else:  # baichuan 13b, baichuan2 13b
-            super().__init__(config, "ALIBI", quant_config, lora_config)
+            super().__init__(config, "ALIBI", cache_config, quant_config,
+                             lora_config)
 
 
 class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
@@ -403,7 +416,9 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
     def __init__(
         self,
         config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ):
-        super().__init__(config, "ROPE", quant_config, lora_config)
+        super().__init__(config, "ROPE", cache_config, quant_config,
+                         lora_config)

+ 11 - 4
aphrodite/modeling/models/bloom.py

@@ -24,6 +24,7 @@ from torch import nn
 from transformers import BloomConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
@@ -70,6 +71,7 @@ class BloomAttention(nn.Module):
     def __init__(
         self,
         config: BloomConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -107,7 +109,8 @@ class BloomAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               scaling,
-                              alibi_slopes=alibi_slopes)
+                              alibi_slopes=alibi_slopes,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -157,6 +160,7 @@ class BloomBlock(nn.Module):
     def __init__(
         self,
         config: BloomConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -164,7 +168,8 @@ class BloomBlock(nn.Module):
 
         self.input_layernorm = nn.LayerNorm(hidden_size,
                                             eps=config.layer_norm_epsilon)
-        self.self_attention = BloomAttention(config, quant_config)
+        self.self_attention = BloomAttention(config, cache_config,
+                                             quant_config)
         self.post_attention_layernorm = nn.LayerNorm(
             hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = BloomMLP(config, quant_config)
@@ -213,6 +218,7 @@ class BloomModel(nn.Module):
     def __init__(
         self,
         config: BloomConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -228,7 +234,7 @@ class BloomModel(nn.Module):
 
         # Transformer blocks
         self.h = nn.ModuleList([
-            BloomBlock(config, quant_config)
+            BloomBlock(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
 
@@ -261,12 +267,13 @@ class BloomForCausalLM(nn.Module):
     def __init__(
         self,
         config: BloomConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = BloomModel(config, quant_config)
+        self.transformer = BloomModel(config, cache_config, quant_config)
         self.lm_head_weight = self.transformer.word_embeddings.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 14 - 6
aphrodite/modeling/models/chatglm.py

@@ -9,7 +9,7 @@ from torch import nn
 from torch.nn import LayerNorm
 
 from aphrodite.attention import Attention, AttentionMetadata
-from aphrodite.common.config import LoRAConfig
+from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -33,6 +33,7 @@ class GLMAttention(nn.Module):
     def __init__(
         self,
         config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -89,6 +90,7 @@ class GLMAttention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
         )
 
     def forward(
@@ -166,6 +168,7 @@ class GLMBlock(nn.Module):
     def __init__(
         self,
         config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -180,7 +183,7 @@ class GLMBlock(nn.Module):
                                                eps=config.layernorm_epsilon)
 
         # Self attention.
-        self.self_attention = GLMAttention(config, quant_config)
+        self.self_attention = GLMAttention(config, cache_config, quant_config)
         self.hidden_dropout = config.hidden_dropout
 
         # Layernorm on the attention output
@@ -236,6 +239,7 @@ class GLMTransformer(nn.Module):
     def __init__(
         self,
         config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -245,8 +249,10 @@ class GLMTransformer(nn.Module):
         self.num_layers = config.num_layers
 
         # Transformer layers.
-        self.layers = nn.ModuleList(
-            [GLMBlock(config, quant_config) for i in range(self.num_layers)])
+        self.layers = nn.ModuleList([
+            GLMBlock(config, cache_config, quant_config)
+            for i in range(self.num_layers)
+        ])
 
         if self.post_layer_norm:
             layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
@@ -281,6 +287,7 @@ class ChatGLMModel(nn.Module):
     def __init__(
         self,
         config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -291,7 +298,7 @@ class ChatGLMModel(nn.Module):
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
         self.kv_channels = config.kv_channels
-        self.encoder = GLMTransformer(config, quant_config)
+        self.encoder = GLMTransformer(config, cache_config, quant_config)
 
         self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                            config.hidden_size)
@@ -333,13 +340,14 @@ class ChatGLMForCausalLM(nn.Module):
     def __init__(
         self,
         config: ChatGLMConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config: ChatGLMConfig = config
         self.quant_config = quant_config
-        self.transformer = ChatGLMModel(config, quant_config)
+        self.transformer = ChatGLMModel(config, cache_config, quant_config)
         self.lm_head_weight = self.transformer.output_layer.weight
         self.logits_processor = LogitsProcessor(config.padded_vocab_size)
         self.sampler = Sampler()

+ 11 - 3
aphrodite/modeling/models/commandr.py

@@ -29,6 +29,7 @@ from torch.nn.parameter import Parameter
 from transformers import CohereConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
@@ -123,6 +124,7 @@ class CohereAttention(nn.Module):
     def __init__(
         self,
         config: CohereConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -179,6 +181,7 @@ class CohereAttention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
         )
         if self.use_qk_norm:
             self.q_norm = LayerNorm(param_shape=(self.num_heads,
@@ -218,11 +221,14 @@ class CohereDecoderLayer(nn.Module):
 
     def __init__(self,
                  config: CohereConfig,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.hidden_size = config.hidden_size
 
-        self.self_attn = CohereAttention(config, quant_config=quant_config)
+        self.self_attn = CohereAttention(config,
+                                         cache_config,
+                                         quant_config=quant_config)
 
         self.mlp = CohereMLP(config, quant_config=quant_config)
         self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
@@ -257,6 +263,7 @@ class CohereModel(nn.Module):
     def __init__(
         self,
         config: CohereConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -265,7 +272,7 @@ class CohereModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                    config.hidden_size)
         self.layers = nn.ModuleList([
-            CohereDecoderLayer(config, quant_config=quant_config)
+            CohereDecoderLayer(config, cache_config, quant_config=quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = LayerNorm(param_shape=(config.hidden_size),
@@ -298,6 +305,7 @@ class CohereForCausalLM(nn.Module):
     def __init__(
         self,
         config: CohereConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -305,7 +313,7 @@ class CohereForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config.vocab_size,
                                                 scale=config.logit_scale)
-        self.model = CohereModel(config, quant_config)
+        self.model = CohereModel(config, cache_config, quant_config)
         self.sampler = Sampler()
 
     @torch.no_grad()

+ 13 - 4
aphrodite/modeling/models/dbrx.py

@@ -5,6 +5,7 @@ import torch
 import torch.nn as nn
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
@@ -165,6 +166,7 @@ class DbrxAttention(nn.Module):
     def __init__(
         self,
         config: DbrxConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -220,6 +222,7 @@ class DbrxAttention(nn.Module):
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
         )
 
     def forward(
@@ -278,10 +281,12 @@ class DbrxBlock(nn.Module):
     def __init__(
         self,
         config: DbrxConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config)
+        self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config,
+                                                     quant_config)
         self.ffn = DbrxExperts(config, quant_config)
 
     def forward(
@@ -307,6 +312,7 @@ class DbrxModel(nn.Module):
     def __init__(
         self,
         config: DbrxConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -314,8 +320,10 @@ class DbrxModel(nn.Module):
             config.vocab_size,
             config.d_model,
         )
-        self.blocks = nn.ModuleList(
-            [DbrxBlock(config, quant_config) for _ in range(config.n_layers)])
+        self.blocks = nn.ModuleList([
+            DbrxBlock(config, cache_config, quant_config)
+            for _ in range(config.n_layers)
+        ])
         self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
         for module in self.modules():
             if hasattr(module, "bias") and isinstance(module.bias,
@@ -348,13 +356,14 @@ class DbrxForCausalLM(nn.Module):
     def __init__(
         self,
         config: DbrxConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.unpadded_vocab_size = config.vocab_size
-        self.transformer = DbrxModel(config, quant_config)
+        self.transformer = DbrxModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size,
             config.d_model,

+ 3 - 1
aphrodite/modeling/models/decilm.py

@@ -28,7 +28,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from transformers import PretrainedConfig
 
-from aphrodite.common.config import LoRAConfig
+from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.modeling.model_loader.weight_utils import default_weight_loader
 from aphrodite.modeling.models.llama import LlamaForCausalLM
 from aphrodite.quantization.base_config import QuantizationConfig
@@ -55,12 +55,14 @@ class DeciLMForCausalLM(LlamaForCausalLM):
     def __init__(
         self,
         config: Optional[PretrainedConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
         delattr(config, "num_key_value_heads_per_layer")
         super().__init__(config=config,
+                         cache_config=cache_config,
                          quant_config=quant_config,
                          lora_config=lora_config)
 

+ 13 - 3
aphrodite/modeling/models/deepseek.py

@@ -28,6 +28,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
@@ -177,6 +178,7 @@ class DeepseekAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -228,7 +230,8 @@ class DeepseekAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
-                              num_kv_heads=self.num_kv_heads)
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -251,6 +254,7 @@ class DeepseekDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_idx: int,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -266,6 +270,7 @@ class DeepseekDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
             quant_config=quant_config,
         )
         if (config.n_routed_experts is not None
@@ -320,6 +325,7 @@ class DeepseekModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -331,7 +337,10 @@ class DeepseekModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config)
+            DeepseekDecoderLayer(config,
+                                 layer_idx,
+                                 cache_config,
+                                 quant_config=quant_config)
             for layer_idx in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -359,12 +368,13 @@ class DeepseekForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekModel(config, quant_config)
+        self.model = DeepseekModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 11 - 4
aphrodite/modeling/models/falcon.py

@@ -27,6 +27,7 @@ from torch.nn import LayerNorm
 from transformers import FalconConfig as HF_FalconConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
@@ -76,6 +77,7 @@ class FalconAttention(nn.Module):
     def __init__(
         self,
         config: FalconConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -167,7 +169,8 @@ class FalconAttention(nn.Module):
             self.attn = Attention(self.num_heads,
                                   self.head_dim,
                                   scale=self.inv_norm_factor,
-                                  num_kv_heads=self.num_kv_heads)
+                                  num_kv_heads=self.num_kv_heads,
+                                  cache_config=cache_config)
 
     def forward(
         self,
@@ -228,12 +231,14 @@ class FalconDecoderLayer(nn.Module):
     def __init__(
         self,
         config: FalconConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
-        self.self_attention = FalconAttention(config, quant_config)
+        self.self_attention = FalconAttention(config, cache_config,
+                                              quant_config)
         self.mlp = FalconMLP(config, quant_config)
         self.config = config
 
@@ -310,6 +315,7 @@ class FalconModel(nn.Module):
     def __init__(
         self,
         config: FalconConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -326,7 +332,7 @@ class FalconModel(nn.Module):
 
         # Transformer blocks
         self.h = nn.ModuleList([
-            FalconDecoderLayer(config, quant_config)
+            FalconDecoderLayer(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
 
@@ -358,12 +364,13 @@ class FalconForCausalLM(nn.Module):
     def __init__(
         self,
         config: FalconConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = FalconModel(config, quant_config)
+        self.transformer = FalconModel(config, cache_config, quant_config)
         self.lm_head_weight = self.transformer.word_embeddings.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 10 - 4
aphrodite/modeling/models/gemma.py

@@ -23,7 +23,7 @@ from torch import nn
 from transformers import GemmaConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
-from aphrodite.common.config import LoRAConfig
+from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import GeluAndMul
@@ -104,6 +104,7 @@ class GemmaAttention(nn.Module):
                  head_dim: int,
                  max_position_embeddings: int = 8192,
                  rope_theta: float = 10000,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -152,7 +153,8 @@ class GemmaAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
-                              num_kv_heads=self.num_kv_heads)
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -174,6 +176,7 @@ class GemmaDecoderLayer(nn.Module):
     def __init__(
         self,
         config: GemmaConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -185,6 +188,7 @@ class GemmaDecoderLayer(nn.Module):
             head_dim=config.head_dim,
             max_position_embeddings=config.max_position_embeddings,
             rope_theta=config.rope_theta,
+            cache_config=cache_config,
             quant_config=quant_config,
         )
         self.mlp = GemmaMLP(
@@ -233,6 +237,7 @@ class GemmaModel(nn.Module):
     def __init__(
         self,
         config: GemmaConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -243,7 +248,7 @@ class GemmaModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            GemmaDecoderLayer(config, quant_config)
+            GemmaDecoderLayer(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -306,6 +311,7 @@ class GemmaForCausalLM(nn.Module):
     def __init__(
         self,
         config: GemmaConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
@@ -313,7 +319,7 @@ class GemmaForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = GemmaModel(config, quant_config)
+        self.model = GemmaModel(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
 

+ 12 - 4
aphrodite/modeling/models/gpt2.py

@@ -24,6 +24,7 @@ from torch import nn
 from transformers import GPT2Config
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -44,6 +45,7 @@ class GPT2Attention(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -69,7 +71,10 @@ class GPT2Attention(nn.Module):
             bias=True,
             quant_config=quant_config,
         )
-        self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scale=self.scale,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -121,6 +126,7 @@ class GPT2Block(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -129,7 +135,7 @@ class GPT2Block(nn.Module):
                      hidden_size)
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPT2Attention(config, quant_config)
+        self.attn = GPT2Attention(config, cache_config, quant_config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPT2MLP(inner_dim, config, quant_config)
 
@@ -162,6 +168,7 @@ class GPT2Model(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -173,7 +180,7 @@ class GPT2Model(nn.Module):
         self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList([
-            GPT2Block(config, quant_config)
+            GPT2Block(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -202,12 +209,13 @@ class GPT2LMHeadModel(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = GPT2Model(config, quant_config)
+        self.transformer = GPT2Model(config, cache_config, quant_config)
         self.lm_head_weight = self.transformer.wte.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 10 - 4
aphrodite/modeling/models/gpt_bigcode.py

@@ -25,6 +25,7 @@ from torch import nn
 from transformers import GPTBigCodeConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -45,6 +46,7 @@ class GPTBigCodeAttention(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -84,7 +86,8 @@ class GPTBigCodeAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               scale=self.scale,
-                              num_kv_heads=self.num_kv_heads)
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -142,6 +145,7 @@ class GPTBigCodeBlock(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -150,7 +154,7 @@ class GPTBigCodeBlock(nn.Module):
                      hidden_size)
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPTBigCodeAttention(config, quant_config)
+        self.attn = GPTBigCodeAttention(config, cache_config, quant_config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPTBigMLP(inner_dim, config, quant_config)
 
@@ -183,6 +187,7 @@ class GPTBigCodeModel(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -194,7 +199,7 @@ class GPTBigCodeModel(nn.Module):
         self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList([
-            GPTBigCodeBlock(config, quant_config)
+            GPTBigCodeBlock(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -223,12 +228,13 @@ class GPTBigCodeForCausalLM(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = GPTBigCodeModel(config, quant_config)
+        self.transformer = GPTBigCodeModel(config, cache_config, quant_config)
         self.lm_head_weight = self.transformer.wte.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 15 - 5
aphrodite/modeling/models/gpt_j.py

@@ -23,6 +23,7 @@ from torch import nn
 from transformers import GPTJConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -44,6 +45,7 @@ class GPTJAttention(nn.Module):
     def __init__(
         self,
         config: GPTJConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -82,7 +84,10 @@ class GPTJAttention(nn.Module):
             base=rope_theta,
             is_neox_style=False,
         )
-        self.attn = Attention(self.num_heads, self.head_size, scaling)
+        self.attn = Attention(self.num_heads,
+                              self.head_size,
+                              scaling,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -134,13 +139,14 @@ class GPTJBlock(nn.Module):
     def __init__(
         self,
         config: GPTJConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         inner_dim = (4 * config.n_embd
                      if config.n_inner is None else config.n_inner)
         self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
-        self.attn = GPTJAttention(config, quant_config)
+        self.attn = GPTJAttention(config, cache_config, quant_config)
         self.mlp = GPTJMLP(inner_dim, config, quant_config)
 
     def forward(
@@ -168,6 +174,7 @@ class GPTJModel(nn.Module):
     def __init__(
         self,
         config: GPTJConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -177,8 +184,10 @@ class GPTJModel(nn.Module):
             config.vocab_size,
             self.embed_dim,
         )
-        self.h = nn.ModuleList(
-            [GPTJBlock(config, quant_config) for _ in range(config.n_layer)])
+        self.h = nn.ModuleList([
+            GPTJBlock(config, cache_config, quant_config)
+            for _ in range(config.n_layer)
+        ])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
     def forward(
@@ -206,13 +215,14 @@ class GPTJForCausalLM(nn.Module):
     def __init__(
         self,
         config: GPTJConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         assert not config.tie_word_embeddings
-        self.transformer = GPTJModel(config, quant_config)
+        self.transformer = GPTJModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size,
             config.n_embd,

+ 12 - 4
aphrodite/modeling/models/gpt_neox.py

@@ -23,6 +23,7 @@ from torch import nn
 from transformers import GPTNeoXConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -44,6 +45,7 @@ class GPTNeoXAttention(nn.Module):
     def __init__(
         self,
         config: GPTNeoXConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -83,7 +85,10 @@ class GPTNeoXAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
         )
-        self.attn = Attention(self.num_heads, self.head_size, scaling)
+        self.attn = Attention(self.num_heads,
+                              self.head_size,
+                              scaling,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -133,6 +138,7 @@ class GPTNeoXLayer(nn.Module):
     def __init__(
         self,
         config: GPTNeoXConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -141,7 +147,7 @@ class GPTNeoXLayer(nn.Module):
                                             eps=config.layer_norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                      eps=config.layer_norm_eps)
-        self.attention = GPTNeoXAttention(config, quant_config)
+        self.attention = GPTNeoXAttention(config, cache_config, quant_config)
         self.mlp = GPTNeoXMLP(config, quant_config)
 
     def forward(
@@ -181,6 +187,7 @@ class GPTNeoXModel(nn.Module):
     def __init__(
         self,
         config: GPTNeoXConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -191,7 +198,7 @@ class GPTNeoXModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            GPTNeoXLayer(config, quant_config)
+            GPTNeoXLayer(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size,
@@ -222,12 +229,13 @@ class GPTNeoXForCausalLM(nn.Module):
     def __init__(
         self,
         config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.gpt_neox = GPTNeoXModel(config, quant_config)
+        self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config)
         self.embed_out = ParallelLMHead(
             config.vocab_size,
             config.hidden_size,

+ 10 - 3
aphrodite/modeling/models/internlm2.py

@@ -6,6 +6,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -63,6 +64,7 @@ class InternLM2Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -113,7 +115,8 @@ class InternLM2Attention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
-                              num_kv_heads=self.num_kv_heads)
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -135,6 +138,7 @@ class InternLMDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -150,6 +154,7 @@ class InternLMDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
             quant_config=quant_config,
         )
         self.feed_forward = InternLM2MLP(
@@ -195,6 +200,7 @@ class InternLM2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -206,7 +212,7 @@ class InternLM2Model(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            InternLMDecoderLayer(config, quant_config)
+            InternLMDecoderLayer(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -238,12 +244,13 @@ class InternLM2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = InternLM2Model(config, quant_config)
+        self.model = InternLM2Model(config, cache_config, quant_config)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 9 - 3
aphrodite/modeling/models/jais.py

@@ -26,6 +26,7 @@ import torch
 from torch import nn
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
@@ -68,6 +69,7 @@ class JAISAttention(nn.Module):
     def __init__(
         self,
         config: JAISConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -107,6 +109,7 @@ class JAISAttention(nn.Module):
             self.head_dim,
             scale=self.scale,
             alibi_slopes=alibi_slopes,
+            cache_config=cache_config,
         )
 
     def forward(
@@ -169,6 +172,7 @@ class JAISBlock(nn.Module):
     def __init__(
         self,
         config: JAISConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -177,7 +181,7 @@ class JAISBlock(nn.Module):
                      hidden_size)
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = JAISAttention(config, quant_config)
+        self.attn = JAISAttention(config, cache_config, quant_config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = JAISMLP(inner_dim, config, quant_config)
 
@@ -210,6 +214,7 @@ class JAISModel(nn.Module):
     def __init__(
         self,
         config: JAISConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -227,7 +232,7 @@ class JAISModel(nn.Module):
         else:
             self.embeddings_scale = config.mup_embeddings_scale
         self.h = nn.ModuleList([
-            JAISBlock(config, quant_config)
+            JAISBlock(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -261,12 +266,13 @@ class JAISLMHeadModel(nn.Module):
     def __init__(
         self,
         config: JAISConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = JAISModel(config, quant_config)
+        self.transformer = JAISModel(config, cache_config, quant_config)
         self.lm_head_weight = self.transformer.wte.weight
         if hasattr(config, "width_scale"):
             self.output_logits_scale = config.width_scale

+ 22 - 14
aphrodite/modeling/models/llama.py

@@ -28,9 +28,8 @@ from torch import nn
 from transformers import LlamaConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
-from aphrodite.common.config import LoRAConfig
+from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import SamplerOutput
-from aphrodite.common.utils import is_hip
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -47,6 +46,7 @@ from aphrodite.modeling.model_loader.weight_utils import (
     default_weight_loader, kv_cache_scales_loader)
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.quantization.base_config import QuantizationConfig
+from aphrodite.common.utils import is_hip
 
 
 class LlamaMLP(nn.Module):
@@ -84,7 +84,6 @@ class LlamaAttention(nn.Module):
 
     def __init__(
         self,
-        config: LlamaConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -94,6 +93,7 @@ class LlamaAttention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         sliding_window: Optional[int] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -111,8 +111,7 @@ class LlamaAttention(nn.Module):
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = getattr(config, "head_dim",
-                                hidden_size // self.total_num_heads)
+        self.head_dim = hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -154,7 +153,8 @@ class LlamaAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              sliding_window=sliding_window)
+                              sliding_window=sliding_window,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -177,6 +177,7 @@ class LlamaDecoderLayer(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -195,7 +196,6 @@ class LlamaDecoderLayer(nn.Module):
         attention_bias = getattr(config, "attention_bias", False) or getattr(
             config, "bias", False)
         self.self_attn = LlamaAttention(
-            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=getattr(config, "num_key_value_heads",
@@ -206,12 +206,15 @@ class LlamaDecoderLayer(nn.Module):
             quant_config=quant_config,
             bias=attention_bias,
             sliding_window=sliding_window,
+            cache_config=cache_config,
+        )
+        self.mlp = LlamaMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            bias=getattr(config, "mlp_bias", False),
         )
-        self.mlp = LlamaMLP(hidden_size=self.hidden_size,
-                            intermediate_size=config.intermediate_size,
-                            hidden_act=config.hidden_act,
-                            quant_config=quant_config,
-                            bias=getattr(config, "mlp_bias", False))
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -251,6 +254,7 @@ class LlamaModel(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
@@ -267,7 +271,7 @@ class LlamaModel(nn.Module):
             org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
-            LlamaDecoderLayer(config, quant_config)
+            LlamaDecoderLayer(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -332,12 +336,16 @@ class LlamaForCausalLM(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.model = LlamaModel(config, quant_config, lora_config=lora_config)
+        self.model = LlamaModel(config,
+                                cache_config,
+                                quant_config,
+                                lora_config=lora_config)
         self.unpadded_vocab_size = config.vocab_size
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
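
Note (not part of the diff): because cache_config now sits between config and quant_config in these signatures, every positional call site had to be updated throughout the commit. A small, hypothetical illustration of the pitfall the updated calls avoid:

    class Model:
        def __init__(self, config, cache_config=None, quant_config=None):
            self.cache_config = cache_config
            self.quant_config = quant_config


    # An old-style two-argument positional call would now silently bind the
    # quantization config to the cache_config slot:
    stale = Model({"hidden_size": 64}, "gptq")
    assert stale.cache_config == "gptq" and stale.quant_config is None

    # Passing by keyword (or updating the positional order, as the diff does)
    # keeps each argument where it belongs:
    fixed = Model({"hidden_size": 64}, quant_config="gptq")
    assert fixed.cache_config is None and fixed.quant_config == "gptq"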

+ 2 - 0
aphrodite/modeling/models/llama_embedding.py

@@ -13,8 +13,10 @@ from aphrodite.modeling.pooling_metadata import PoolingMetadata
 
 class LlamaEmbeddingModel(nn.Module):
     """A model that uses Llama with additional embedding functionalities.
+
    This class encapsulates the LlamaModel and provides an interface for
    embedding operations and customized pooling functions.
+
    Attributes:
        model: An instance of LlamaModel used for forward operations.
        _pooler: An instance of Pooler used for pooling operations.

+ 4 - 2
aphrodite/modeling/models/llava.py

@@ -7,7 +7,7 @@ from torch import nn
 from transformers import CLIPVisionModel, LlavaConfig
 
 from aphrodite.attention import AttentionMetadata
-from aphrodite.common.config import VisionLanguageConfig
+from aphrodite.common.config import CacheConfig, VisionLanguageConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.logits_processor import LogitsProcessor
@@ -61,6 +61,7 @@ class LlavaForConditionalGeneration(nn.Module):
     def __init__(self,
                  config: "LlavaConfig",
                  vision_language_config: VisionLanguageConfig,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional["QuantizationConfig"] = None) -> None:
         super().__init__()
         self.config = config
@@ -84,7 +85,8 @@ class LlavaForConditionalGeneration(nn.Module):
             projector_hidden_act=config.projector_hidden_act)
 
         self.quant_config = quant_config
-        self.language_model = LlamaModel(config.text_config, quant_config)
+        self.language_model = LlamaModel(config.text_config, cache_config,
+                                         quant_config)
         self.unpadded_vocab_size = config.text_config.vocab_size
         self.lm_head = ParallelLMHead(
             self.unpadded_vocab_size,
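
Note (not part of the diff): in composite models like the one above, cache_config only matters to the text decoder's attention layers, so the wrapper simply relays it. A stub sketch of that shape:

    class TextDecoder:
        def __init__(self, text_config, cache_config=None, quant_config=None) -> None:
            self.cache_config = cache_config  # reaches the attention layers from here


    class VisionLanguageModel:
        def __init__(self, config, vision_language_config,
                     cache_config=None, quant_config=None) -> None:
            # The vision encoder keeps no KV cache, so cache_config is passed
            # straight through to the language model only.
            self.vision_tower = object()  # stand-in for the vision encoder
            self.language_model = TextDecoder(config, cache_config, quant_config)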

+ 10 - 3
aphrodite/modeling/models/minicpm.py

@@ -28,7 +28,7 @@ import torch
 from torch import nn
 
 from aphrodite.attention import Attention, AttentionMetadata
-from aphrodite.common.config import LoRAConfig
+from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
@@ -180,6 +180,7 @@ class MiniCPMAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -233,7 +234,8 @@ class MiniCPMAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
-                              num_kv_heads=self.num_kv_heads)
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -258,6 +260,7 @@ class MiniCPMDecoderLayer(nn.Module):
     def __init__(
         self,
         config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -274,6 +277,7 @@ class MiniCPMDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
             quant_config=quant_config,
         )
         self.num_experts = getattr(self.config, "num_experts", 0)
@@ -329,6 +333,7 @@ class MiniCPMModel(nn.Module):
     def __init__(
         self,
         config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
@@ -345,7 +350,7 @@ class MiniCPMModel(nn.Module):
             org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
-            MiniCPMDecoderLayer(config, quant_config)
+            MiniCPMDecoderLayer(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -412,6 +417,7 @@ class MiniCPMForCausalLM(nn.Module):
     def __init__(
         self,
         config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
@@ -420,6 +426,7 @@ class MiniCPMForCausalLM(nn.Module):
         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
         self.model = MiniCPMModel(config,
+                                  cache_config,
                                   quant_config,
                                   lora_config=lora_config)
         unpadded_vocab_size = config.vocab_size

+ 12 - 4
aphrodite/modeling/models/mixtral.py

@@ -1,7 +1,6 @@
 # coding=utf-8
 # Adapted from
 # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
-# Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
@@ -29,7 +28,7 @@ from torch import nn
 from transformers import MixtralConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
-from aphrodite.common.config import LoRAConfig
+from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.common.utils import print_warning_once
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
@@ -79,7 +78,7 @@ class MixtralMoE(nn.Module):
         self.intermediate_size = intermediate_size // self.tp_size
         self.quant_config = quant_config
 
-        # FIXME: Make this more general to support different
+        # FIXME(pcmoritz): Make this more general to support different
         # quantization schemes
         self.use_fp8 = isinstance(quant_config, Fp8Config)
 
@@ -251,6 +250,7 @@ class MixtralAttention(nn.Module):
                  num_kv_heads: int,
                  max_position: int = 4096 * 32,
                  rope_theta: float = 10000,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  sliding_window: Optional[int] = None) -> None:
         super().__init__()
@@ -312,6 +312,7 @@ class MixtralAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             sliding_window=self.sliding_window,
+            cache_config=cache_config,
         )
 
     def forward(
@@ -334,6 +335,7 @@ class MixtralDecoderLayer(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -347,6 +349,7 @@ class MixtralDecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
             sliding_window=config.sliding_window,
+            cache_config=cache_config,
             quant_config=quant_config)
         self.block_sparse_moe = MixtralMoE(
             num_experts=config.num_local_experts,
@@ -393,6 +396,7 @@ class MixtralModel(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
@@ -409,7 +413,9 @@ class MixtralModel(nn.Module):
             org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
-            MixtralDecoderLayer(config, quant_config=quant_config)
+            MixtralDecoderLayer(config,
+                                cache_config,
+                                quant_config=quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -459,12 +465,14 @@ class MixtralForCausalLM(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.model = MixtralModel(config,
+                                  cache_config,
                                   quant_config,
                                   lora_config=lora_config)
         self.unpadded_vocab_size = config.vocab_size

+ 21 - 10
aphrodite/modeling/models/mixtral_quant.py

@@ -30,6 +30,7 @@ from torch import nn
 from transformers import MixtralConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
@@ -156,14 +157,17 @@ class MixtralMoE(nn.Module):
 
 class MixtralAttention(nn.Module):
 
-    def __init__(self,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 max_position: int = 4096 * 32,
-                 rope_theta: float = 10000,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 sliding_window: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        quant_config: Optional[QuantizationConfig] = None,
+        sliding_window: Optional[int] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -214,6 +218,7 @@ class MixtralAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             sliding_window=self.sliding_window,
+            cache_config=cache_config,
         )
 
     def forward(
@@ -236,6 +241,7 @@ class MixtralDecoderLayer(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -249,6 +255,7 @@ class MixtralDecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
             sliding_window=config.sliding_window,
+            cache_config=cache_config,
             quant_config=quant_config)
         self.block_sparse_moe = MixtralMoE(config=config,
                                            quant_config=quant_config)
@@ -291,6 +298,7 @@ class MixtralModel(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -302,7 +310,9 @@ class MixtralModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            MixtralDecoderLayer(config, quant_config=quant_config)
+            MixtralDecoderLayer(config,
+                                cache_config,
+                                quant_config=quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -331,12 +341,13 @@ class MixtralForCausalLM(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = MixtralModel(config, quant_config)
+        self.model = MixtralModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 13 - 5
aphrodite/modeling/models/mpt.py

@@ -7,6 +7,7 @@ import torch
 import torch.nn as nn
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size)
@@ -42,6 +43,7 @@ class MPTAttention(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -106,7 +108,8 @@ class MPTAttention(nn.Module):
                               self.head_dim,
                               scaling,
                               alibi_slopes=alibi_slopes,
-                              num_kv_heads=self.num_kv_heads)
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -165,12 +168,13 @@ class MPTBlock(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         hidden_size = config.d_model
         self.norm_1 = nn.LayerNorm(hidden_size)
-        self.attn = MPTAttention(config, quant_config)
+        self.attn = MPTAttention(config, cache_config, quant_config)
         self.norm_2 = nn.LayerNorm(hidden_size)
         self.ffn = MPTMLP(config, quant_config)
 
@@ -200,6 +204,7 @@ class MPTModel(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -210,8 +215,10 @@ class MPTModel(nn.Module):
             config.vocab_size,
             config.d_model,
         )
-        self.blocks = nn.ModuleList(
-            [MPTBlock(config, quant_config) for _ in range(config.n_layers)])
+        self.blocks = nn.ModuleList([
+            MPTBlock(config, cache_config, quant_config)
+            for _ in range(config.n_layers)
+        ])
         self.norm_f = nn.LayerNorm(config.d_model)
         if config.no_bias:
             for module in self.modules():
@@ -245,6 +252,7 @@ class MPTForCausalLM(nn.Module):
     def __init__(
         self,
         config: MPTConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -252,7 +260,7 @@ class MPTForCausalLM(nn.Module):
         assert config.tie_word_embeddings
         self.quant_config = quant_config
 
-        self.transformer = MPTModel(config, quant_config)
+        self.transformer = MPTModel(config, cache_config, quant_config)
         self.lm_head_weight = self.transformer.wte.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 10 - 4
aphrodite/modeling/models/olmo.py

@@ -28,6 +28,7 @@ from torch import nn
 from transformers import OlmoConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -54,6 +55,7 @@ class OlmoAttention(nn.Module):
     def __init__(
         self,
         config: OlmoConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -92,7 +94,8 @@ class OlmoAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.attn = Attention(self.num_heads,
                               self.head_dim,
-                              scale=self.scaling)
+                              scale=self.scaling,
+                              cache_config=cache_config)
 
         # Attention output projection.
         self.o_proj = RowParallelLinear(
@@ -174,10 +177,11 @@ class OlmoDecoderLayer(nn.Module):
 
     def __init__(self,
                  config: OlmoConfig,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         # Attention block.
-        self.self_attn = OlmoAttention(config, quant_config)
+        self.self_attn = OlmoAttention(config, cache_config, quant_config)
 
         # MLP block.
         self.mlp = OlmoMLP(config, quant_config)
@@ -216,6 +220,7 @@ class OlmoModel(nn.Module):
 
     def __init__(self,
                  config: OlmoConfig,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
@@ -223,7 +228,7 @@ class OlmoModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                    config.hidden_size)
         self.layers = nn.ModuleList([
-            OlmoDecoderLayer(config, quant_config)
+            OlmoDecoderLayer(config, cache_config, quant_config)
             for layer_idx in range(config.num_hidden_layers)
         ])
         self.norm = nn.LayerNorm(config.hidden_size,
@@ -270,10 +275,11 @@ class OlmoForCausalLM(nn.Module):
 
     def __init__(self,
                  config: OlmoConfig,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
-        self.model = OlmoModel(config, quant_config)
+        self.model = OlmoModel(config, cache_config, quant_config)
         if config.tie_word_embeddings:
             self.lm_head_weight = self.model.embed_tokens.weight
         else:

+ 12 - 4
aphrodite/modeling/models/opt.py

@@ -24,6 +24,7 @@ from torch import nn
 from transformers import OPTConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -60,6 +61,7 @@ class OPTAttention(nn.Module):
         embed_dim: int,
         num_heads: int,
         bias: bool = True,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -87,7 +89,8 @@ class OPTAttention(nn.Module):
         )
         self.attn = Attention(self.num_heads,
                               self.head_dim,
-                              scale=self.scaling)
+                              scale=self.scaling,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -107,6 +110,7 @@ class OPTDecoderLayer(nn.Module):
     def __init__(
         self,
         config: OPTConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -116,6 +120,7 @@ class OPTDecoderLayer(nn.Module):
             embed_dim=self.embed_dim,
             num_heads=config.num_attention_heads,
             bias=config.enable_bias,
+            cache_config=cache_config,
             quant_config=quant_config,
         )
         self.do_layer_norm_before = config.do_layer_norm_before
@@ -180,6 +185,7 @@ class OPTDecoder(nn.Module):
     def __init__(
         self,
         config: OPTConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -225,7 +231,7 @@ class OPTDecoder(nn.Module):
             self.final_layer_norm = None
 
         self.layers = nn.ModuleList([
-            OPTDecoderLayer(config, quant_config)
+            OPTDecoderLayer(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
 
@@ -258,10 +264,11 @@ class OPTModel(nn.Module):
     def __init__(
         self,
         config: OPTConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
-        self.decoder = OPTDecoder(config, quant_config)
+        self.decoder = OPTDecoder(config, cache_config, quant_config)
 
     def forward(
         self,
@@ -278,12 +285,13 @@ class OPTForCausalLM(nn.Module):
     def __init__(
         self,
         config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = OPTModel(config, quant_config)
+        self.model = OPTModel(config, cache_config, quant_config)
         self.lm_head_weight = self.model.decoder.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 10 - 3
aphrodite/modeling/models/orion.py

@@ -11,6 +11,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -67,6 +68,7 @@ class OrionAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -117,7 +119,8 @@ class OrionAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
-                              num_kv_heads=self.num_kv_heads)
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -139,6 +142,7 @@ class OrionDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -154,6 +158,7 @@ class OrionDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
             quant_config=quant_config,
         )
         self.mlp = OrionMLP(
@@ -201,6 +206,7 @@ class OrionModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -212,7 +218,7 @@ class OrionModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            OrionDecoderLayer(config, quant_config)
+            OrionDecoderLayer(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -244,12 +250,13 @@ class OrionForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = OrionModel(config, quant_config)
+        self.model = OrionModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 12 - 4
aphrodite/modeling/models/phi.py

@@ -42,6 +42,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -62,6 +63,7 @@ class PhiAttention(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.total_num_heads = config.num_attention_heads
@@ -104,7 +106,10 @@ class PhiAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
         )
-        self.attn = Attention(self.num_heads, self.head_size, scaling)
+        self.attn = Attention(self.num_heads,
+                              self.head_size,
+                              scaling,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -154,11 +159,12 @@ class PhiLayer(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                             eps=config.layer_norm_eps)
-        self.self_attn = PhiAttention(config, quant_config)
+        self.self_attn = PhiAttention(config, cache_config, quant_config)
         self.mlp = PhiMLP(config, quant_config)
 
     def forward(
@@ -185,6 +191,7 @@ class PhiModel(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
@@ -192,7 +199,7 @@ class PhiModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                    config.hidden_size)
         self.layers = nn.ModuleList([
-            PhiLayer(config, quant_config)
+            PhiLayer(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.final_layernorm = nn.LayerNorm(config.hidden_size,
@@ -224,12 +231,13 @@ class PhiForCausalLM(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
 
-        self.model = PhiModel(config, quant_config)
+        self.model = PhiModel(config, cache_config, quant_config)
 
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,

+ 12 - 3
aphrodite/modeling/models/qwen.py

@@ -11,6 +11,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -67,6 +68,7 @@ class QWenAttention(nn.Module):
         max_position_embeddings: int,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -100,7 +102,10 @@ class QWenAttention(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              self.scaling,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -122,6 +127,7 @@ class QWenBlock(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -134,6 +140,7 @@ class QWenBlock(nn.Module):
                                   config.max_position_embeddings,
                                   rope_theta=rope_theta,
                                   rope_scaling=rope_scaling,
+                                  cache_config=cache_config,
                                   quant_config=quant_config)
 
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -174,6 +181,7 @@ class QWenModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -185,7 +193,7 @@ class QWenModel(nn.Module):
             config.hidden_size,
         )
         self.h = nn.ModuleList([
-            QWenBlock(config, quant_config)
+            QWenBlock(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -217,12 +225,13 @@ class QWenLMHeadModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = QWenModel(config, quant_config)
+        self.transformer = QWenModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 10 - 4
aphrodite/modeling/models/qwen2.py

@@ -29,7 +29,7 @@ from torch import nn
 from transformers import Qwen2Config
 
 from aphrodite.attention import Attention, AttentionMetadata
-from aphrodite.common.config import LoRAConfig
+from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -86,6 +86,7 @@ class Qwen2Attention(nn.Module):
                  max_position: int = 4096 * 32,
                  rope_theta: float = 10000,
                  use_sliding_window: bool = False,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  sliding_window: Optional[int] = None) -> None:
         super().__init__()
@@ -136,7 +137,8 @@ class Qwen2Attention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              sliding_window=self.sliding_window)
+                              sliding_window=self.sliding_window,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -159,6 +161,7 @@ class Qwen2DecoderLayer(nn.Module):
         self,
         config: Qwen2Config,
         layer_idx: int,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -174,6 +177,7 @@ class Qwen2DecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
             use_sliding_window=use_sliding_window,
+            cache_config=cache_config,
             quant_config=quant_config,
             sliding_window=config.sliding_window)
         self.mlp = Qwen2MLP(
@@ -221,6 +225,7 @@ class Qwen2Model(nn.Module):
     def __init__(
         self,
         config: Qwen2Config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -233,7 +238,7 @@ class Qwen2Model(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            Qwen2DecoderLayer(config, layer_idx, quant_config)
+            Qwen2DecoderLayer(config, layer_idx, cache_config, quant_config)
             for layer_idx in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -286,6 +291,7 @@ class Qwen2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: Qwen2Config,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
@@ -293,7 +299,7 @@ class Qwen2ForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Qwen2Model(config, quant_config)
+        self.model = Qwen2Model(config, cache_config, quant_config)
 
         if config.tie_word_embeddings:
             self.lm_head_weight = self.model.embed_tokens.weight

+ 13 - 3
aphrodite/modeling/models/qwen2_moe.py

@@ -30,6 +30,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import (get_tensor_model_parallel_rank,
                                    get_tensor_model_parallel_world_size,
@@ -186,6 +187,7 @@ class Qwen2MoeAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -237,7 +239,8 @@ class Qwen2MoeAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
-                              num_kv_heads=self.num_kv_heads)
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -260,6 +263,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_idx: int,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -275,6 +279,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
             quant_config=quant_config,
         )
         if (config.num_experts is not None
@@ -327,6 +332,7 @@ class Qwen2MoeModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -338,7 +344,10 @@ class Qwen2MoeModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config)
+            Qwen2MoeDecoderLayer(config,
+                                 layer_idx,
+                                 cache_config,
+                                 quant_config=quant_config)
             for layer_idx in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -368,12 +377,13 @@ class Qwen2MoeForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Qwen2MoeModel(config, quant_config)
+        self.model = Qwen2MoeModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 10 - 4
aphrodite/modeling/models/stablelm.py

@@ -26,6 +26,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -71,6 +72,7 @@ class StablelmAttention(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.config = config
@@ -123,7 +125,8 @@ class StablelmAttention(nn.Module):
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
-                              num_kv_heads=self.num_key_value_heads)
+                              num_kv_heads=self.num_key_value_heads,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -145,10 +148,11 @@ class StablelmDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.self_attn = StablelmAttention(config)
+        self.self_attn = StablelmAttention(config, cache_config, quant_config)
         self.mlp = StablelmMLP(config, quant_config)
         norm_eps = getattr(config, "norm_eps",
                            getattr(config, "layer_norm_eps", 1e-05))
@@ -187,6 +191,7 @@ class StableLMEpochModel(nn.Module):
 
     def __init__(self,
                  config: PretrainedConfig,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(
@@ -194,7 +199,7 @@ class StableLMEpochModel(nn.Module):
             config.hidden_size,
         )
         self.layers = nn.ModuleList([
-            StablelmDecoderLayer(config, quant_config)
+            StablelmDecoderLayer(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         norm_eps = getattr(config, "norm_eps",
@@ -226,12 +231,13 @@ class StablelmForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = StableLMEpochModel(config, quant_config)
+        self.model = StableLMEpochModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 15 - 3
aphrodite/modeling/models/starcoder2.py

@@ -25,6 +25,7 @@ from torch import nn
 from transformers import Starcoder2Config
 
 from aphrodite.attention import Attention, AttentionMetadata
+from aphrodite.common.config import CacheConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import get_act_fn
@@ -45,6 +46,7 @@ class Starcoder2Attention(nn.Module):
 
     def __init__(self,
                  config: Starcoder2Config,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
@@ -100,6 +102,7 @@ class Starcoder2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             sliding_window=self.sliding_window,
+            cache_config=cache_config,
         )
 
     def forward(
@@ -149,10 +152,13 @@ class Starcoder2DecoderLayer(nn.Module):
 
     def __init__(self,
                  config: Starcoder2Config,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.self_attn = Starcoder2Attention(config, quant_config=quant_config)
+        self.self_attn = Starcoder2Attention(config,
+                                             cache_config,
+                                             quant_config=quant_config)
         self.mlp = Starcoder2MLP(config, quant_config=quant_config)
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                             eps=config.norm_epsilon)
@@ -190,6 +196,7 @@ class Starcoder2Model(nn.Module):
 
     def __init__(self,
                  config: Starcoder2Config,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
@@ -200,7 +207,9 @@ class Starcoder2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                    config.hidden_size)
         self.layers = nn.ModuleList([
-            Starcoder2DecoderLayer(config, quant_config=quant_config)
+            Starcoder2DecoderLayer(config,
+                                   cache_config,
+                                   quant_config=quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
@@ -225,10 +234,13 @@ class Starcoder2ForCausalLM(nn.Module):
 
     def __init__(self,
                  config: Starcoder2Config,
+                 cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
         self.config = config
-        self.model = Starcoder2Model(config, quant_config=quant_config)
+        self.model = Starcoder2Model(config,
+                                     cache_config,
+                                     quant_config=quant_config)
         self.vocab_size = config.vocab_size
         self.unpadded_vocab_size = config.vocab_size
         if config.tie_word_embeddings:

+ 10 - 4
aphrodite/modeling/models/xverse.py

@@ -27,7 +27,7 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from aphrodite.attention import Attention, AttentionMetadata
-from aphrodite.common.config import LoRAConfig
+from aphrodite.common.config import CacheConfig, LoRAConfig
 from aphrodite.common.sequence import SamplerOutput
 from aphrodite.distributed import get_tensor_model_parallel_world_size
 from aphrodite.modeling.layers.activation import SiluAndMul
@@ -88,6 +88,7 @@ class XverseAttention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
         sliding_window: Optional[int] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -132,7 +133,8 @@ class XverseAttention(nn.Module):
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
-                              sliding_window=sliding_window)
+                              sliding_window=sliding_window,
+                              cache_config=cache_config)
 
     def forward(
         self,
@@ -154,6 +156,7 @@ class XverseDecoderLayer(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -174,6 +177,7 @@ class XverseDecoderLayer(nn.Module):
             quant_config=quant_config,
             bias=getattr(config, "bias", False),
             sliding_window=sliding_window,
+            cache_config=cache_config,
         )
         self.mlp = XverseMLP(
             hidden_size=self.hidden_size,
@@ -220,6 +224,7 @@ class XverseModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
@@ -236,7 +241,7 @@ class XverseModel(nn.Module):
             org_num_embeddings=config.vocab_size,
         )
         self.layers = nn.ModuleList([
-            XverseDecoderLayer(config, quant_config)
+            XverseDecoderLayer(config, cache_config, quant_config)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -294,13 +299,14 @@ class XverseForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config=None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = XverseModel(config, quant_config)
+        self.model = XverseModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()

+ 11 - 7
aphrodite/task_handler/cache_engine.py

@@ -31,7 +31,7 @@ class CacheEngine:
 
         self.head_size = model_config.get_head_size()
         self.num_layers = model_config.get_num_layers(parallel_config)
-        self.num_heads = model_config.get_num_kv_heads(parallel_config)
+        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
 
         self.block_size = cache_config.block_size
         self.num_gpu_blocks = cache_config.num_gpu_blocks
@@ -43,11 +43,15 @@ class CacheEngine:
             self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
 
         # Get attention backend.
-        self.attn_backend = get_attn_backend(model_config.dtype)
-
-        # Initialize the cache.
-        # Get attention backend.
-        self.attn_backend = get_attn_backend(model_config.dtype)
+        self.attn_backend = get_attn_backend(
+            model_config.get_num_attention_heads(parallel_config),
+            self.head_size,
+            self.num_kv_heads,
+            model_config.get_sliding_window(),
+            model_config.dtype,
+            cache_config.cache_dtype,
+            self.block_size,
+        )
 
         # Initialize the cache.
         self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
@@ -60,7 +64,7 @@ class CacheEngine:
     ) -> List[torch.Tensor]:
         """Allocates KV cache on the specified device."""
         kv_cache_shape = self.attn_backend.get_kv_cache_shape(
-            num_blocks, self.block_size, self.num_heads, self.head_size)
+            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
         pin_memory = is_pin_memory_available() if device == "cpu" else False
         kv_cache: List[torch.Tensor] = []
         for _ in range(self.num_layers):
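
Taken together, the CacheEngine hunks above give the full cache-side flow: the backend is resolved once from the model and cache configs, and its get_kv_cache_shape() then drives allocation using the now correctly named num_kv_heads. A condensed sketch of that flow, assuming get_attn_backend is exported from aphrodite/attention/selector.py (the selector module touched by this commit); the torch dtype is taken as a parameter because the str-to-dtype mapping used above lives elsewhere, and torch.zeros stands in for the real allocation helper:

import torch

from aphrodite.attention.selector import get_attn_backend  # assumed export


def build_gpu_cache(model_config, parallel_config, cache_config,
                    torch_dtype: torch.dtype):
    """Sketch of CacheEngine's new flow: select the backend once, then let
    it dictate the KV-cache tensor shape."""
    num_kv_heads = model_config.get_num_kv_heads(parallel_config)
    head_size = model_config.get_head_size()

    attn_backend = get_attn_backend(
        model_config.get_num_attention_heads(parallel_config),
        head_size,
        num_kv_heads,
        model_config.get_sliding_window(),
        model_config.dtype,
        cache_config.cache_dtype,
        cache_config.block_size,
    )

    kv_cache_shape = attn_backend.get_kv_cache_shape(
        cache_config.num_gpu_blocks, cache_config.block_size,
        num_kv_heads, head_size)

    # One tensor per transformer layer, mirroring _allocate_kv_cache above.
    return [
        torch.zeros(kv_cache_shape, dtype=torch_dtype, device="cuda")
        for _ in range(model_config.get_num_layers(parallel_config))
    ]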

+ 11 - 4
aphrodite/task_handler/cpu_model_runner.py

@@ -50,7 +50,15 @@ class CPUModelRunner:
         self.kv_cache_dtype = kv_cache_dtype
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
-        self.attn_backend = get_attn_backend(self.model_config.dtype)
+        self.attn_backend = get_attn_backend(
+            self.model_config.get_num_attention_heads(self.parallel_config),
+            self.model_config.get_head_size(),
+            self.model_config.get_num_kv_heads(self.parallel_config),
+            self.model_config.get_sliding_window(),
+            self.model_config.dtype,
+            self.kv_cache_dtype,
+            self.block_size,
+        )
 
         # Lazy initialization.
         self.model: nn.Module  # Set after init_Model
@@ -63,7 +71,8 @@ class CPUModelRunner:
             vision_language_config=self.vision_language_config,
             lora_config=self.lora_config,
             parallel_config=self.parallel_config,
-            scheduler_config=self.scheduler_config)
+            scheduler_config=self.scheduler_config,
+            cache_config=self.cache_config)
 
     def _prepare_prompt(
         self,
@@ -155,7 +164,6 @@ class CPUModelRunner:
             decode_metadata=None,
             block_tables=torch.tensor([]),
             slot_mapping=slot_mapping,
-            kv_cache_dtype=self.kv_cache_dtype,
         )
         return (input_tokens, input_positions, attn_metadata, seq_lens,
                 multi_modal_input)
@@ -239,7 +247,6 @@ class CPUModelRunner:
             prefill_metadata=None,
             decode_metadata=None,
             block_tables=block_tables,
-            kv_cache_dtype=self.kv_cache_dtype,
         )
         return (
             input_tokens,

+ 9 - 1
aphrodite/task_handler/cpu_worker.py

@@ -50,7 +50,15 @@ class CPUCacheEngine:
             self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
 
         # Get attention backend.
-        self.attn_backend = get_attn_backend(model_config.dtype)
+        self.attn_backend = get_attn_backend(
+            self.model_config.get_num_attention_heads(self.parallel_config),
+            self.model_config.get_head_size(),
+            self.model_config.get_num_kv_heads(self.parallel_config),
+            self.model_config.get_sliding_window(),
+            self.model_config.dtype,
+            cache_config.cache_dtype,
+            self.block_size,
+        )
 
         # Initialize the cache.
         self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)

+ 0 - 1
aphrodite/task_handler/embedding_model_runner.py

@@ -233,7 +233,6 @@ class EmbeddingModelRunner(ModelRunner):
             num_decode_tokens=num_decode_tokens,
             prefill_metadata=prefill_attn_metadata,
             decode_metadata=decode_attn_metadata,
-            kv_cache_dtype=self.kv_cache_dtype,
         )
 
         return (input_tokens, input_positions, attn_metadata, pooling_metadata,

+ 11 - 4
aphrodite/task_handler/model_runner.py

@@ -143,10 +143,18 @@ class ModelRunner:
         self.graph_block_tables = np.zeros(
             (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
             dtype=np.int32)
-        self.attn_backend = get_attn_backend(self.model_config.dtype)
+        self.attn_backend = get_attn_backend(
+            self.model_config.get_num_attention_heads(self.parallel_config),
+            self.model_config.get_head_size(),
+            self.model_config.get_num_kv_heads(self.parallel_config),
+            self.model_config.get_sliding_window(),
+            self.model_config.dtype,
+            self.kv_cache_dtype,
+            self.block_size,
+        )
 
         # Lazy initialization
-        self.model: torch.nn.Module  # Set after load_model
+        self.model: nn.Module  # Set after load_model
 
         # Set if the backend is flashinfer.
         self.flashinfer_workspace_buffer: torch.Tensor
@@ -163,6 +171,7 @@ class ModelRunner:
                 vision_language_config=self.vision_language_config,
                 parallel_config=self.parallel_config,
                 scheduler_config=self.scheduler_config,
+                cache_config=self.cache_config,
             )
 
         self.model_memory_usage = m.consumed_memory
@@ -759,7 +768,6 @@ class ModelRunner:
             num_decode_tokens=num_decode_tokens,
             prefill_metadata=prefill_attn_metadata,
             decode_metadata=decode_attn_metadata,
-            kv_cache_dtype=self.kv_cache_dtype,
         )
 
         return (input_tokens, input_positions, attn_metadata,
@@ -973,7 +981,6 @@ class ModelRunner:
                     slot_mapping=slot_mapping[:batch_size],
                     prefill_metadata=None,
                     decode_metadata=decode_metadata,
-                    kv_cache_dtype=self.kv_cache_dtype,
                 )
 
                 if self.lora_config:
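
One consequence worth spelling out, since the same deletion appears in cpu_model_runner.py, embedding_model_runner.py and here: the kv_cache_dtype= keyword is dropped from every attention-metadata construction site, because the cache dtype now reaches the backend through get_attn_backend (it is the sixth argument in the expanded call above) rather than travelling with each batch. A migration sketch for downstream call sites; the make_metadata factory name and the field subset are assumptions, only the removed keyword is taken from the hunks:

def build_attn_metadata(attn_backend, block_tables, slot_mapping,
                        num_decode_tokens):
    """Sketch only: real call sites pass more fields than shown here."""
    return attn_backend.make_metadata(  # factory name assumed
        num_decode_tokens=num_decode_tokens,
        prefill_metadata=None,
        decode_metadata=None,
        block_tables=block_tables,
        slot_mapping=slot_mapping,
        # kv_cache_dtype is gone from these call sites; the dtype was
        # already supplied to get_attn_backend when the backend was chosen.
    )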