
chore: attention refactor and upstream sync apr01 (#365)

* add speculative workers

* sync core processor logic with upstream

* add compiled dag api, sync logic

* do not log prompt token IDs

* add cache config to metrics

* add compiled dag in ray utils

* do not import cache ops if neuron backend is detected

* hope i didn't break anything

* modify worker

* add neuron worker

* add file lock

* remove lora loading from loader.py

* add neuronx model loader

* pick the correct loader based on device

* add separate attention backends

* migrate all models to the new attention backend

* neuron patch for llama

* modify rejection sampling

* add optimized moe configs

* add script for benchmarking local GPU for optimal moe config

* fix sampler output return

* logits as hidden state for lora layer

* health check in api server

* correct torch version in pyproject

* move config to where it belongs

* missing imports

* temporarily remove int8 kv cache

* logger fixes, set max log len to 0

* gracefully exit with keyboard interrupt

* explicitly abort a request when a cancellederror is raised

* remove unfinished configs

* update rocm's requirements

* formatting

* yapf

* parse version correctly for ray

* yapf again

* set enforce eager to true by default

* do not enable memory pinning when neuron
AlpinDale 11 months ago
parent
commit f8dfac6372
100 changed files with 10,036 additions and 2,571 deletions
  1. .gitignore (+0 -1)
  2. aphrodite/common/config.py (+98 -33)
  3. aphrodite/common/logger.py (+10 -8)
  4. aphrodite/common/outputs.py (+47 -30)
  5. aphrodite/common/sampling_params.py (+13 -0)
  6. aphrodite/common/sequence.py (+71 -21)
  7. aphrodite/common/utils.py (+67 -27)
  8. aphrodite/endpoints/llm.py (+1 -1)
  9. aphrodite/endpoints/openai/api_server.py (+104 -95)
  10. aphrodite/engine/aphrodite_engine.py (+275 -97)
  11. aphrodite/engine/args_tools.py (+373 -240)
  12. aphrodite/engine/async_aphrodite.py (+45 -28)
  13. aphrodite/engine/metrics.py (+79 -21)
  14. aphrodite/engine/ray_tools.py (+21 -5)
  15. aphrodite/executor/__init__.py (+0 -0)
  16. aphrodite/executor/executor_base.py (+76 -0)
  17. aphrodite/executor/gpu_executor.py (+153 -0)
  18. aphrodite/executor/neuron_executor.py (+78 -0)
  19. aphrodite/executor/ray_gpu_executor.py (+452 -0)
  20. aphrodite/executor/utils.py (+13 -0)
  21. aphrodite/lora/layers.py (+4 -0)
  22. aphrodite/modeling/hf_downloader.py (+8 -3)
  23. aphrodite/modeling/layers/attention.py (+0 -354)
  24. aphrodite/modeling/layers/attention/__init__.py (+93 -0)
  25. aphrodite/modeling/layers/attention/backends/__init__.py (+0 -0)
  26. aphrodite/modeling/layers/attention/backends/flash_attn.py (+121 -0)
  27. aphrodite/modeling/layers/attention/backends/xformers.py (+255 -0)
  28. aphrodite/modeling/layers/attention/ops/__init__.py (+0 -0)
  29. aphrodite/modeling/layers/attention/ops/paged_attn.py (+138 -0)
  30. aphrodite/modeling/layers/attention/ops/prefix_prefill.py (+28 -12)
  31. aphrodite/modeling/layers/fused_moe/__init__.py (+8 -0)
  32. aphrodite/modeling/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json (+146 -0)
  33. aphrodite/modeling/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json (+146 -0)
  34. aphrodite/modeling/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json (+146 -0)
  35. aphrodite/modeling/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json (+146 -0)
  36. aphrodite/modeling/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json (+146 -0)
  37. aphrodite/modeling/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json (+146 -0)
  38. aphrodite/modeling/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json (+146 -0)
  39. aphrodite/modeling/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json (+146 -0)
  40. aphrodite/modeling/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json (+146 -0)
  41. aphrodite/modeling/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json (+146 -0)
  42. aphrodite/modeling/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json (+146 -0)
  43. aphrodite/modeling/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json (+146 -0)
  44. aphrodite/modeling/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json (+146 -0)
  45. aphrodite/modeling/layers/fused_moe/configs/README (+9 -0)
  46. aphrodite/modeling/layers/fused_moe/fused_moe.py (+120 -58)
  47. aphrodite/modeling/layers/rejection.py (+57 -18)
  48. aphrodite/modeling/layers/sampler.py (+1 -2)
  49. aphrodite/modeling/loader.py (+62 -24)
  50. aphrodite/modeling/metadata.py (+5 -4)
  51. aphrodite/modeling/models/baichuan.py (+3 -4)
  52. aphrodite/modeling/models/bloom.py (+5 -5)
  53. aphrodite/modeling/models/chatglm.py (+2 -2)
  54. aphrodite/modeling/models/cohere.py (+1 -1)
  55. aphrodite/modeling/models/deepseek.py (+3 -3)
  56. aphrodite/modeling/models/falcon.py (+14 -14)
  57. aphrodite/modeling/models/gemma.py (+2 -2)
  58. aphrodite/modeling/models/gpt2.py (+2 -4)
  59. aphrodite/modeling/models/gpt_bigcode.py (+5 -5)
  60. aphrodite/modeling/models/gpt_j.py (+2 -2)
  61. aphrodite/modeling/models/gpt_neox.py (+2 -2)
  62. aphrodite/modeling/models/internlm2.py (+2 -2)
  63. aphrodite/modeling/models/llama.py (+16 -9)
  64. aphrodite/modeling/models/mixtral.py (+3 -3)
  65. aphrodite/modeling/models/mixtral_quant.py (+2 -2)
  66. aphrodite/modeling/models/mpt.py (+6 -6)
  67. aphrodite/modeling/models/neuron/llama.py (+79 -0)
  68. aphrodite/modeling/models/olmo.py (+4 -4)
  69. aphrodite/modeling/models/opt.py (+4 -4)
  70. aphrodite/modeling/models/phi.py (+2 -2)
  71. aphrodite/modeling/models/qwen.py (+2 -2)
  72. aphrodite/modeling/models/qwen2.py (+2 -2)
  73. aphrodite/modeling/models/stablelm.py (+2 -2)
  74. aphrodite/modeling/neuron_loader.py (+70 -0)
  75. aphrodite/modeling/sampling_metadata.py (+2 -2)
  76. aphrodite/modeling/utils.py (+17 -0)
  77. aphrodite/processing/block_manager.py (+39 -48)
  78. aphrodite/processing/evictor.py (+2 -3)
  79. aphrodite/processing/scheduler.py (+1 -4)
  80. aphrodite/spec_decode/batch_expansion.py (+398 -0)
  81. aphrodite/spec_decode/interfaces.py (+77 -0)
  82. aphrodite/spec_decode/metrics.py (+175 -0)
  83. aphrodite/spec_decode/multi_step_worker.py (+392 -0)
  84. aphrodite/spec_decode/spec_decode_worker.py (+394 -0)
  85. aphrodite/spec_decode/util.py (+101 -0)
  86. aphrodite/task_handler/cache_engine.py (+9 -2)
  87. aphrodite/task_handler/model_runner.py (+74 -56)
  88. aphrodite/task_handler/neuron_worker.py (+204 -0)
  89. aphrodite/task_handler/worker.py (+30 -14)
  90. kernels/attention/attention_dtypes.h (+1 -2)
  91. kernels/attention/attention_kernels.cu (+936 -1014)
  92. kernels/attention/dtype_float32.cuh (+255 -262)
  93. kernels/backup/README (+1 -0)
  94. kernels/backup/attention_dtypes.h (+8 -0)
  95. kernels/backup/attention_kernels.cu (+1032 -0)
  96. kernels/backup/cache.h (+39 -0)
  97. kernels/backup/cache_kernels.cu (+512 -0)
  98. kernels/backup/dispatch_utils.h (+39 -0)
  99. kernels/backup/dtype_float32.cuh (+280 -0)
  100. kernels/backup/dtype_int8.cuh (+0 -0)

+ 0 - 1
.gitignore

@@ -6,7 +6,6 @@ repos
 *.so
 .conda
 build
-*.json
 dist*
 .VSCodeCounter
 conda/

+ 98 - 33
aphrodite/common/config.py

@@ -8,7 +8,7 @@ import torch
 from transformers import PretrainedConfig
 
 from aphrodite.transformers_utils.config import get_config
-from aphrodite.common.utils import (get_cpu_memory, is_hip,
+from aphrodite.common.utils import (get_cpu_memory, is_hip, is_neuron,
                                     get_nvcc_cuda_version)
 
 _GB = 1 << 30
@@ -43,6 +43,9 @@ class ModelConfig:
         revision: The specific model version to use. It can be a branch name,
             a tag name, or a commit id. If unspecified, will use the default
             version.
+        code_revision: The specific revision to use for the model code on
+            Hugging Face Hub. It can be a branch name, a tag name, or a 
+            commit id. If unspecified, will use the default version.
         tokenizer_revision: The specific tokenizer version to use. It can be a
             branch name, a tag name, or a commit id. If unspecified, will use
             the default version.
@@ -71,16 +74,18 @@ class ModelConfig:
         trust_remote_code: bool,
         download_dir: Optional[str],
         load_format: str,
-        dtype: str,
+        # dtype: str,
+        dtype: Union[str, torch.dtype],
         seed: int,
         revision: Optional[str] = None,
+        code_revision: Optional[str] = None,
         tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
         load_in_4bit: bool = False,
         load_in_8bit: bool = False,
         load_in_smooth: bool = False,
-        enforce_eager: bool = False,
+        enforce_eager: bool = True,
         max_context_len_to_capture: Optional[int] = None,
         max_log_probs: int = 10,
     ) -> None:
@@ -92,6 +97,7 @@ class ModelConfig:
         self.load_format = load_format
         self.seed = seed
         self.revision = revision
+        self.code_revision = code_revision
         self.tokenizer_revision = tokenizer_revision
         self.quantization = quantization
         self.load_in_4bit = load_in_4bit
@@ -106,14 +112,18 @@ class ModelConfig:
             # download model from ModelScope hub,
             # lazy import so that modelscope is not required for normal use.
             from modelscope.hub.snapshot_download import snapshot_download  # pylint: disable=C
-            model_path = snapshot_download(model_id=model,
-                                           cache_dir=download_dir,
-                                           revision=revision)
+            if not os.path.exists(model):
+                model_path = snapshot_download(model_id=model,
+                                               cache_dir=download_dir,
+                                               revision=revision)
+            else:
+                model_path = model
             self.model = model_path
             self.download_dir = model_path
             self.tokenizer = model_path
 
-        self.hf_config = get_config(self.model, trust_remote_code, revision)
+        self.hf_config = get_config(self.model, trust_remote_code, revision,
+                                    code_revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self.max_model_len = _get_and_verify_max_len(self.hf_config,
                                                      max_model_len)
@@ -177,6 +187,7 @@ class ModelConfig:
         # Parse quantization method from the HF model config, if available.
         hf_quant_config = getattr(self.hf_config, "quantization_config", None)
         if hf_quant_config is not None:
+
             hf_quant_method = str(hf_quant_config["quant_method"]).lower()
             # If the GPTQ model is serialized in marlin format, use marlin.
             if (hf_quant_method == "gptq"
@@ -375,7 +386,7 @@ class CacheConfig:
         gpu_memory_utilization: float,
         swap_space: int,
         cache_dtype: str,
-        cache_quant_params_path: Optional[str] = None,
+        # cache_quant_params_path: Optional[str] = None,
         sliding_window: Optional[int] = None,
         context_shift: bool = False,
     ) -> None:
@@ -384,7 +395,7 @@ class CacheConfig:
         self.swap_space_bytes = swap_space * _GB
         self.cache_dtype = cache_dtype
         self.sliding_window = sliding_window
-        self.cache_quant_params_path = cache_quant_params_path
+        # self.cache_quant_params_path = cache_quant_params_path
         self.context_shift = context_shift
         self._verify_args()
         self._verify_cache_dtype()
@@ -393,6 +404,11 @@ class CacheConfig:
         self.num_gpu_blocks = None
         self.num_cpu_blocks = None
 
+    def metrics_info(self):
+        # convert cache_config to dict(key: str, value: str) for prometheus
+        # metrics info
+        return {key: str(value) for key, value in self.__dict__.items()}
+
     def _verify_args(self) -> None:
         if self.gpu_memory_utilization > 1.0:
             raise ValueError(
@@ -400,25 +416,24 @@ class CacheConfig:
                 f"{self.gpu_memory_utilization}.")
 
     def _verify_cache_dtype(self) -> None:
-        if self.cache_dtype in ["auto", "int8"]:
+        if self.cache_dtype == "auto":
+            # if self.cache_dtype in ["auto", "int8"]:
             pass
         elif self.cache_dtype == "fp8_e5m2":
-            nvcc_cuda_version = get_nvcc_cuda_version()
-            if nvcc_cuda_version < Version("11.8"):
-                raise ValueError(
-                    "FP8 is not supported when cuda version is lower than "
-                    "11.8. If you think you have the correct cuda version, "
-                    "please make sure you've properly exported CUDA_HOME.")
-            device_name = torch.cuda.get_device_name()
-            if "AMD" in device_name:
+            if is_hip():
                 raise NotImplementedError(
                     "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
+            nvcc_cuda_version = get_nvcc_cuda_version()
+            if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
+                raise ValueError(
+                    "FP8 is not supported when cuda version is lower than 11.8."
+                )
             logger.info(
                 "Using fp8_e5m2 data type to store kv cache. It reduces "
                 "the GPU memory footprint and boosts the performance. "
                 "But it may cause slight accuracy drop. "
                 "Currently we only support fp8 without scaling factors and "
-                "make e5m2 as a default format.")
+                "use e5m2 as a default format.")
         else:
             raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
 
@@ -450,8 +465,13 @@ class ParallelConfig:
         worker_use_ray: Whether to use Ray for model workers. Will be set to
             True if either pipeline_parallel_size or tensor_parallel_size is
             greater than 1.
+        max_parallel_loading_workers: Maximum number of multiple batches
+            when load model sequentially. To avoid RAM OOM when using tensor
+            parallel and large models.
         disable_custom_all_reduce: Disable the custom all-reduce kernel and
             fall back to NCCL.
+        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
+            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
     """
 
     def __init__(
@@ -461,15 +481,26 @@ class ParallelConfig:
         worker_use_ray: bool,
         max_parallel_loading_workers: Optional[int] = None,
         disable_custom_all_reduce: bool = False,
+        ray_workers_use_nsight: bool = False,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
-        self.tensor_parallel_size = tensor_parallel_size
+        if is_neuron():
+            # For Neuron device support, here we assign TP=1 to avoid sharding
+            # within Aphrodite directly.
+            # Transformer-neuronx would take neuron_tp_degree attribute, and
+            # distribute the workload to multiple NeuronCores.
+            self.tensor_parallel_size = 1
+            self.neuron_tp_degree = tensor_parallel_size
+        else:
+            self.tensor_parallel_size = tensor_parallel_size
         self.worker_use_ray = worker_use_ray
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce
+        self.ray_workers_use_nsight = ray_workers_use_nsight
 
-        self.world_size = pipeline_parallel_size * tensor_parallel_size
-        if self.world_size > 1:
+        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
+        # Ray worker is not supported for Neuron backend.
+        if self.world_size > 1 and not is_neuron():
             self.worker_use_ray = True
         self._verify_args()
 
@@ -477,16 +508,29 @@ class ParallelConfig:
         if self.pipeline_parallel_size > 1:
             raise NotImplementedError(
                 "Pipeline parallelism is not supported yet.")
-        if is_hip():
+        if not self.disable_custom_all_reduce and self.world_size > 1:
+            if is_hip():
+                self.disable_custom_all_reduce = True
+                logger.info(
+                    "Disabled the custom all-reduce kernel because it is not "
+                    "supported on AMD GPUs.")
+            elif self.pipeline_parallel_size > 1:
+                self.disable_custom_all_reduce = True
+                logger.info(
+                    "Disabled the custom all-reduce kernel because it is not "
+                    "supported with pipeline parallelism.")
+        if self.ray_workers_use_nsight and not self.worker_use_ray:
+            raise ValueError("Unable to use nsight profiling unless workers "
+                             "run with Ray.")
+
+        # FIXME: Fix the stability issues and re-enable the custom
+        # all-reduce kernel.
+        if not self.disable_custom_all_reduce and self.world_size > 1:
             self.disable_custom_all_reduce = True
             logger.info(
-                "Disabled the custom all-reduce kernel because it is not "
-                "supported on AMD GPUs.")
-        elif self.pipeline_parallel_size > 1:
-            self.disable_custom_all_reduce = True
-            logger.info(
-                "Disabled the custom all-reduce kernel because it is not "
-                "supported with pipeline parallelism.")
+                "Custom all-reduce kernels are temporarily disabled due to "
+                "stability issues. We will re-enable them once the issues are "
+                "resolved.")
 
 
 class SchedulerConfig:
@@ -538,8 +582,29 @@ class SchedulerConfig:
 
 class DeviceConfig:
 
-    def __init__(self, device: str = "cuda") -> None:
-        self.device = torch.device(device)
+    def __init__(self, device: str = "auto") -> None:
+        if device == "auto":
+            # Automated device type detection
+            if torch.cuda.is_available():
+                self.device_type = "cuda"
+            elif is_neuron():
+                self.device_type = "neuron"
+            else:
+                raise RuntimeError("No supported device detected.")
+        else:
+            # Device type is assigned explicitly
+            self.device_type = device
+
+        # Some device types require processing inputs on CPU
+        if self.device_type in ["neuron"]:
+            self.device = torch.device("cpu")
+        else:
+            # Set device with device type
+            self.device = torch.device(self.device_type)
+
+    @property
+    def is_neuron(self):
+        return self.device_type == "neuron"
 
 
 @dataclass
@@ -571,7 +636,7 @@ class LoRAConfig:
         elif self.max_cpu_loras < self.max_loras:
             raise ValueError(
                 f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
-                f"max_num_seqs ({self.max_loras})")
+                f"max_loras ({self.max_loras})")
 
     def verify_with_model_config(self, model_config: ModelConfig):
         if self.lora_dtype in (None, "auto"):
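
The new CacheConfig.metrics_info() flattens the cache configuration into a {str: str} dict so the engine can publish it ("add cache config to metrics" above). A minimal sketch of how such a dict could back a Prometheus Info metric follows; the actual wiring goes through StatLogger in aphrodite/engine/metrics.py, which is not shown in this hunk, so treat the metric name and helper below as illustrative.

    # Illustrative only: publishing CacheConfig.metrics_info() as a Prometheus
    # Info metric. The real integration lives in StatLogger (metrics.py).
    from prometheus_client import Info

    cache_config_info = Info("aphrodite_cache_config", "KV cache configuration")

    def publish_cache_config(cache_config) -> None:
        # metrics_info() returns {key: str(value)}, which is what Info.info() expects
        cache_config_info.info(cache_config.metrics_info())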

+ 10 - 8
aphrodite/common/logger.py

@@ -1,9 +1,10 @@
 """
-Internal logging utility. Adapted from
-https://github.com/theroyallab/tabbyAPI/blob/4cc0b59bdc94e6342b6d1d7acadbadc63c740ed9/common/logger.py
+Internal logging utility.
 """
 
 import logging
+import os
+
 from loguru import logger
 from rich.console import Console
 from rich.markup import escape
@@ -17,6 +18,7 @@ from rich.progress import (
 )
 
 RICH_CONSOLE = Console()
+LOG_LEVEL = os.getenv("APHRODITE_LOG_LEVEL", "INFO").upper()
 
 
 def unwrap(wrapped, default=None):
@@ -60,9 +62,9 @@ def _log_formatter(record: dict):
     message = unwrap(record.get("message"), "")
 
     # Replace once loguru allows for turning off str.format
-    message = message.replace("{{", "{{").replace("}}", "}}")
-    # Manually escape < and > characters
-    message = message.replace("<", "\\<").replace(">", "\\>")
+    message = message.replace("{", "{{").replace("}", "}}").replace("<", "\<")
+
+    # Escape markup tags from Rich
     message = escape(message)
     lines = message.splitlines()
 
@@ -86,7 +88,7 @@ class UvicornLoggingHandler(logging.Handler):
 
 
 # Uvicorn config for logging. Passed into run when creating all loggers in
-#server
+# server
 UVICORN_LOG_CONFIG = {
     "version": 1,
     "disable_existing_loggers": False,
@@ -99,7 +101,7 @@ UVICORN_LOG_CONFIG = {
     "root": {
         "handlers": ["uvicorn"],
         "propagate": False,
-        "level": "INFO"
+        "level": LOG_LEVEL
     },
 }
 
@@ -111,7 +113,7 @@ def setup_logger():
 
     logger.add(
         RICH_CONSOLE.print,
-        level="INFO",
+        level=LOG_LEVEL,
         format=_log_formatter,
         colorize=True,
     )
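
logger.py now reads APHRODITE_LOG_LEVEL once at import time and applies it to both the loguru sink and the uvicorn log config. A small usage sketch, assuming the variable is set before any aphrodite module is imported:

    import os
    os.environ["APHRODITE_LOG_LEVEL"] = "DEBUG"  # read once, at module import

    from aphrodite.common.logger import setup_logger
    setup_logger()  # reinstalls the loguru sink at the requested level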

+ 47 - 30
aphrodite/common/outputs.py

@@ -1,7 +1,13 @@
 from typing import List, Optional
+import time
 
-from aphrodite.common.sequence import (PromptLogprobs, SampleLogprobs,
-                                       SequenceGroup, SequenceStatus)
+from aphrodite.common.sequence import (
+    PromptLogprobs,
+    SampleLogprobs,
+    SequenceGroup,
+    SequenceStatus,
+    RequestMetrics,
+)
 from aphrodite.lora.request import LoRARequest
 
 
@@ -60,6 +66,7 @@ class RequestOutput:
         prompt_logprobs: The log probabilities to return per prompt token.
         outputs: The output sequences of the request.
         finished: Whether the whole request is finished.
+        metrics: Metrics associated with the request.
         lora_request: The LoRA request that was used to generate the output.
     """
 
@@ -71,6 +78,7 @@ class RequestOutput:
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
         finished: bool,
+        metrics: Optional[RequestMetrics] = None,
         lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.request_id = request_id
@@ -79,6 +87,7 @@ class RequestOutput:
         self.prompt_logprobs = prompt_logprobs
         self.outputs = outputs
         self.finished = finished
+        self.metrics = metrics
         self.lora_request = lora_request
 
     @classmethod
@@ -86,43 +95,50 @@ class RequestOutput:
         # Get the top-n sequences.
         n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
-        if seq_group.sampling_params.use_beam_search:
-            sorting_key = lambda seq: seq.get_beam_search_score(
-                seq_group.sampling_params.length_penalty)
+        if n == 1:
+            top_n_seqs = seqs
         else:
-            # ruff: noqa: E731
-            sorting_key = lambda seq: seq.get_cumulative_logprob()
-        sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
-        top_n_seqs = sorted_seqs[:n]
+            if seq_group.sampling_params.use_beam_search:
+                sorting_key = lambda seq: seq.get_beam_search_score(
+                    seq_group.sampling_params.length_penalty)
+            else:
+                sorting_key = lambda seq: seq.get_cumulative_logprob()
+            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
+            top_n_seqs = sorted_seqs[:n]
 
         # Create the outputs.
-        outputs: List[CompletionOutput] = []
-        for seq in top_n_seqs:
-            logprobs = seq.output_logprobs
-            if seq_group.sampling_params.logprobs is None:
-                # NOTE: We need to take care of this case because the sequence
-                # always has the logprobs of the sampled tokens even if the
-                # logprobs are not requested.
-                logprobs = None
-            finshed_reason = SequenceStatus.get_finished_reason(seq.status)
-            output = CompletionOutput(seqs.index(seq), seq.output_text,
-                                      seq.get_output_token_ids(),
-                                      seq.get_cumulative_logprob(), logprobs,
-                                      finshed_reason)
-            outputs.append(output)
+        # NOTE: We need omit logprobs here explicitly because the sequence
+        # always has the logprobs of the sampled tokens even if the
+        # logprobs are not requested.
+        include_logprobs = seq_group.sampling_params.logprobs
+        outputs = [
+            CompletionOutput(
+                seqs.index(seq),
+                seq.output_text,
+                seq.get_output_token_ids(),
+                seq.get_cumulative_logprob(),
+                seq.output_logprobs if include_logprobs else None,
+                SequenceStatus.get_finished_reason(seq.status),
+            ) for seq in top_n_seqs
+        ]
 
         # Every sequence in the sequence group should have the same prompt.
         prompt = seq_group.prompt
         prompt_token_ids = seq_group.prompt_token_ids
         prompt_logprobs = seq_group.prompt_logprobs
         finished = seq_group.is_finished()
-        return cls(seq_group.request_id,
-                   prompt,
-                   prompt_token_ids,
-                   prompt_logprobs,
-                   outputs,
-                   finished,
-                   lora_request=seq_group.lora_request)
+        finished_time = time.time() if finished else None
+        seq_group.set_finished_time(finished_time)
+        return cls(
+            seq_group.request_id,
+            prompt,
+            prompt_token_ids,
+            prompt_logprobs,
+            outputs,
+            finished,
+            seq_group.metrics,
+            lora_request=seq_group.lora_request,
+        )
 
     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
@@ -131,4 +147,5 @@ class RequestOutput:
                 f"prompt_logprobs={self.prompt_logprobs}, "
                 f"outputs={self.outputs}, "
                 f"finished={self.finished}, "
+                f"metrics={self.metrics}, "
                 f"lora_request={self.lora_request})")

+ 13 - 0
aphrodite/common/sampling_params.py

@@ -1,4 +1,5 @@
 """Sampling parameters for text generation."""
+import copy
 from enum import IntEnum
 from functools import cached_property
 from typing import Callable, List, Optional, Union
@@ -375,6 +376,18 @@ class SamplingParams:
             return SamplingType.RANDOM_SEED
         return SamplingType.RANDOM
 
+    def clone(self) -> "SamplingParams":
+        """Deep copy excluding LogitsProcessor objects.
+        LogitsProcessor objects are excluded because they may contain an
+        arbitrary, nontrivial amount of data.
+        """
+
+        logit_processor_refs = None if self.logits_processors is None else {
+            id(lp): lp
+            for lp in self.logits_processors
+        }
+        return copy.deepcopy(self, memo=logit_processor_refs)
+
     def __repr__(self) -> str:
         repr_str = "SamplingParams("
         for param, default_value in self.default_values.items():
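
clone() deep-copies the sampling parameters while sharing the logits processors by reference: seeding the deepcopy memo with id(lp) -> lp makes deepcopy return the original processor objects instead of copying them. A minimal sketch; the processor below is a hypothetical stand-in:

    def my_processor(token_ids, logits):  # hypothetical logits processor
        return logits

    params = SamplingParams(temperature=0.7, logits_processors=[my_processor])
    cloned = params.clone()
    assert cloned is not params
    # the processor object itself is shared, not copied
    assert cloned.logits_processors[0] is params.logits_processors[0]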

+ 71 - 21
aphrodite/common/sequence.py

@@ -2,16 +2,21 @@
 import copy
 import enum
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, TYPE_CHECKING
 
 from aphrodite.common.block import LogicalTokenBlock
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.lora.request import LoRARequest
 
+if TYPE_CHECKING:
+    import torch
+    from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
+
 
 @dataclass
 class Logprob:
     """Infos for supporting OpenAI compatible logprobs."""
+
     logprob: float
     decoded_token: Optional[str] = None
 
@@ -22,6 +27,7 @@ SampleLogprobs = List[Dict[int, Logprob]]
 
 class SequenceStatus(enum.Enum):
     """Status of a sequence."""
+
     WAITING = enum.auto()
     RUNNING = enum.auto()
     SWAPPED = enum.auto()
@@ -68,6 +74,7 @@ class RequestMetrics:
         time_in_queue: The time the request spent in the queue.
         finished_time: The time when the request was finished.
     """
+
     arrival_time: float
     last_token_time: float
     first_scheduled_time: Optional[float]
@@ -81,6 +88,8 @@ class SequenceData:
 
     Args:
         prompt_token_ids: The token IDs of the prompt.
+        output_token_ids: The token IDs of the output. Set to an empty list if
+            None.
 
     Attributes:
         prompt_token_ids: The token IDs of the prompt.
@@ -91,9 +100,13 @@ class SequenceData:
     def __init__(
         self,
         prompt_token_ids: List[int],
+        output_token_ids: Optional[List[int]] = None,
     ) -> None:
+        if output_token_ids is None:
+            output_token_ids = []
+
         self.prompt_token_ids = prompt_token_ids
-        self.output_token_ids: List[int] = []
+        self.output_token_ids = output_token_ids
         self.cumulative_logprob = 0.0
 
     def append_token_id(self, token_id: int, logprob: float) -> None:
@@ -117,6 +130,12 @@ class SequenceData:
             return self.prompt_token_ids[-1]
         return self.output_token_ids[-1]
 
+    def get_prompt_token_ids(self) -> int:
+        return self.prompt_token_ids
+
+    def get_output_token_ids(self) -> int:
+        return self.output_token_ids
+
     def __repr__(self) -> str:
         return (f"SequenceData("
                 f"prompt_token_ids={self.prompt_token_ids}, "
@@ -142,11 +161,13 @@ class Sequence:
         prompt: str,
         prompt_token_ids: List[int],
         block_size: int,
+        eos_token_id: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.seq_id = seq_id
         self.prompt = prompt
         self.block_size = block_size
+        self.eos_token_id = eos_token_id
         self.lora_request = lora_request
 
         self.data = SequenceData(prompt_token_ids)
@@ -164,7 +185,6 @@ class Sequence:
         # Input + output tokens
         self.tokens: Optional[List[str]] = None
         self.persistent_data = {}
-        self.persistent_data = {}
 
     @property
     def lora_int_id(self) -> int:
@@ -235,10 +255,12 @@ class Sequence:
     def get_cumulative_logprob(self) -> float:
         return self.data.cumulative_logprob
 
-    def get_beam_search_score(self,
-                              length_penalty: float = 0.0,
-                              seq_len: Optional[int] = None,
-                              eos_token_id: Optional[int] = None) -> float:
+    def get_beam_search_score(
+        self,
+        length_penalty: float = 1.0,
+        seq_len: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+    ) -> float:
         """Calculate the beam search score with length penalty.
 
         Adapted from
@@ -298,11 +320,13 @@ class SequenceGroup:
         self.request_id = request_id
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
-        self.metrics = RequestMetrics(arrival_time=arrival_time,
-                                      last_token_time=arrival_time,
-                                      first_scheduled_time=None,
-                                      first_token_time=None,
-                                      time_in_queue=None)
+        self.metrics = RequestMetrics(
+            arrival_time=arrival_time,
+            last_token_time=arrival_time,
+            first_scheduled_time=None,
+            first_token_time=None,
+            time_in_queue=None,
+        )
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
         self.state = SequenceGroupState()
@@ -366,12 +390,9 @@ class SequenceGroup:
         self,
         status: Optional[SequenceStatus] = None,
     ) -> List[Sequence]:
-        if status is None:
-            return list(self.seqs_dict.values())
-        else:
-            return [
-                seq for seq in self.seqs_dict.values() if seq.status == status
-            ]
+        return (list(self.seqs_dict.values()) if status is None else [
+            seq for seq in self.seqs_dict.values() if seq.status == status
+        ])
 
     def get_unfinished_seqs(self) -> List[Sequence]:
         return [
@@ -517,6 +538,35 @@ class SequenceGroupOutput:
                 and self.prompt_logprobs == other.prompt_logprobs)
 
 
-# For each sequence group, we generate a list of SequenceOutput object,
-# each of which contains one possible candidate for the next token.
-SamplerOutput = List[SequenceGroupOutput]
+@dataclass
+class SamplerOutput:
+    """For each sequence group, we generate a list of SequenceOutput object,
+    each of which contains one possible candidate for the next token.
+
+    This datastructure implements methods so it can be used like a list, but
+    also has optional fields for device tensors.
+    """
+
+    outputs: List[SequenceGroupOutput]
+
+    # On-device tensor containing probabilities of each token.
+    sampled_token_probs: Optional["torch.Tensor"] = None
+
+    # On-device tensor containing the sampled token ids.
+    sampled_token_ids: Optional["torch.Tensor"] = None
+
+    # Spec decode metrics populated by workers.
+    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
+
+    def __getitem__(self, idx: int):
+        return self.outputs[idx]
+
+    def __setitem__(self, idx: int, value):
+        self.outputs[idx] = value
+
+    def __len__(self):
+        return len(self.outputs)
+
+    def __eq__(self, other: object):
+        return (isinstance(other, self.__class__)
+                and self.outputs == other.outputs)
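
SamplerOutput changes from a plain list alias to a dataclass, but keeps __getitem__, __setitem__, and __len__, so existing call sites that index it keep working while the new optional fields (device tensors, spec-decode metrics) ride along. A small sketch, where seq_group_output stands for a SequenceGroupOutput built elsewhere:

    sampler_output = SamplerOutput(outputs=[seq_group_output])
    assert len(sampler_output) == 1
    first = sampler_output[0]                   # same object as sampler_output.outputs[0]
    probs = sampler_output.sampled_token_probs  # optional on-device tensor, may be None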

+ 67 - 27
aphrodite/common/utils.py

@@ -5,16 +5,21 @@ import subprocess
 import uuid
 import gc
 from platform import uname
-from loguru import logger
+from typing import List, Tuple, Union
+from packaging.version import parse, Version
 
 import psutil
 import torch
 import asyncio
 from functools import partial
-from typing import (Any, Awaitable, Callable, Hashable, Optional, TypeVar,
-                    List, Tuple, Union)
+from typing import (
+    Awaitable,
+    Callable,
+    TypeVar,
+)
 from collections import OrderedDict
-from packaging.version import parse, Version
+from typing import Any, Hashable, Optional
+from loguru import logger
 
 T = TypeVar("T")
 
@@ -23,7 +28,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
     "bfloat16": torch.bfloat16,
     "float": torch.float,
     "fp8_e5m2": torch.uint8,
-    "int8": torch.int8,
+    # "int8": torch.int8,
 }
 
 
@@ -113,12 +118,24 @@ def is_hip() -> bool:
     return torch.version.hip is not None
 
 
+def is_neuron() -> bool:
+    try:
+        import transformers_neuronx
+    except ImportError:
+        transformers_neuronx = None
+    return transformers_neuronx is not None
+
+
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
+    # NOTE: This import statement should be executed lazily since
+    # the Neuron-X backend does not have the `cuda_utils` module.
     from aphrodite._C import cuda_utils
-    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
+
     max_shared_mem = (
         cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu))
+    # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
+    # will fail
     assert max_shared_mem > 0, "max_shared_mem can not be zero"
     return int(max_shared_mem)
 
@@ -139,6 +156,7 @@ def in_wsl() -> bool:
 
 def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
     """Take a blocking function, and run it on in an executor thread.
+
     This function prevents the blocking function from blocking the
     asyncio event loop.
     The code in this function needs to be thread safe.
@@ -153,15 +171,33 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
 
 
 def get_ip() -> str:
+    # try ipv4
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
-    return s.getsockname()[0]
+    try:
+        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except OSError:
+        # try ipv6
+        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+        s.connect(("dns.google", 80))
+        return s.getsockname()[0]
+
+
+def get_distributed_init_method(ip: str, port: int) -> str:
+    return f"tcp://{ip}:{port}"
 
 
 def get_open_port() -> int:
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind(("", 0))
-        return s.getsockname()[1]
+    # try ipv4
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+    except OSError:
+        # try ipv6
+        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
 
 
 def set_cuda_visible_devices(device_ids: List[int]) -> None:
@@ -170,18 +206,22 @@ def set_cuda_visible_devices(device_ids: List[int]) -> None:
 
 def get_nvcc_cuda_version() -> Optional[Version]:
     cuda_home = os.environ.get('CUDA_HOME')
-    nvcc_path = os.path.join(cuda_home, 'bin', 'nvcc') if cuda_home else 'nvcc'
-
-    try:
-        nvcc_output = subprocess.check_output([nvcc_path, "-V"],
-                                              universal_newlines=True)
-        output = nvcc_output.split()
-        release_idx = output.index("release") + 1
-        nvcc_cuda_version = parse(output[release_idx].split(",")[0])
-        return nvcc_cuda_version
-    except (FileNotFoundError, subprocess.CalledProcessError):
-        logger.warning("nvcc not found. Skipping CUDA version check!")
-        return None
+    if not cuda_home:
+        cuda_home = '/usr/local/cuda'
+        if os.path.isfile(cuda_home + '/bin/nvcc'):
+            logger.info(
+                f'CUDA_HOME is not found in the environment. Using {cuda_home} '
+                'as CUDA_HOME.')
+        else:
+            logger.warning(
+                f'Not found nvcc in {cuda_home}. Skip cuda version check!')
+            return None
+    nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
+                                          universal_newlines=True)
+    output = nvcc_output.split()
+    release_idx = output.index("release") + 1
+    nvcc_cuda_version = parse(output[release_idx].split(",")[0])
+    return nvcc_cuda_version
 
 
 def _generate_random_fp8_e5m2(
@@ -248,8 +288,8 @@ def create_kv_caches_with_random(
                                 device=device)
         if cache_dtype == 'fp8_e5m2':
             _generate_random_fp8_e5m2(key_cache, -scale, scale)
-        elif cache_dtype == 'int8':
-            torch.randint(-128, 127, key_cache.size(), out=key_cache)
+        # elif cache_dtype == 'int8':
+        #     torch.randint(-128, 127, key_cache.size(), out=key_cache)
         elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
             key_cache.uniform_(-scale, scale)
         else:
@@ -265,8 +305,8 @@ def create_kv_caches_with_random(
                                   device=device)
         if cache_dtype == 'fp8_e5m2':
             _generate_random_fp8_e5m2(value_cache, -scale, scale)
-        elif cache_dtype == 'int8':
-            torch.randint(-128, 127, value_cache.size(), out=value_cache)
+        # elif cache_dtype == 'int8':
+        #     torch.randint(-128, 127, value_cache.size(), out=value_cache)
         elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
             value_cache.uniform_(-scale, scale)
         else:
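
utils.py gains Neuron detection, IPv6 fallbacks for the networking helpers, and a get_distributed_init_method() helper that formats the torch.distributed init string. Usage sketch built only from the functions added in this diff:

    from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                        get_open_port, is_neuron)

    if not is_neuron():
        # e.g. "tcp://10.0.0.5:43211"; the helpers fall back to IPv6 when IPv4 fails
        init_method = get_distributed_init_method(get_ip(), get_open_port())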

+ 1 - 1
aphrodite/endpoints/llm.py

@@ -78,7 +78,7 @@ class LLM:
         seed: int = 0,
         gpu_memory_utilization: float = 0.9,
         swap_space: int = 4,
-        enforce_eager: bool = False,
+        enforce_eager: bool = True,
         max_context_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         **kwargs,
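
With enforce_eager now defaulting to True, CUDA graph capture becomes opt-in. A sketch of opting back in; the top-level import and the model name are placeholders, not part of this diff:

    from aphrodite import LLM  # assuming the usual top-level re-export

    llm = LLM(model="some-org/some-model", enforce_eager=False)  # re-enable CUDA graphs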

+ 104 - 95
aphrodite/endpoints/openai/api_server.py

@@ -191,6 +191,7 @@ async def validation_exception_handler(_, exc):
 @app.get("/health")
 async def health() -> Response:
     """Health check."""
+    await openai_serving_chat.engine.check_health()
     return Response(status_code=200)
 
 
@@ -526,104 +527,112 @@ async def get_kobold_lite_ui():
 # ============ KoboldAI API ============ #
 
 if __name__ == "__main__":
-    args = parse_args()
-
-    if args.launch_kobold_api:
-        logger.warning("Launching Kobold API server in addition to OpenAI. "
-                       "Keep in mind that the Kobold API routes are NOT "
-                       "protected via the API key.")
-        app.include_router(kai_api, prefix="/api/v1")
-        app.include_router(kai_api,
-                           prefix="/api/latest",
-                           include_in_schema=False)
-        app.include_router(extra_api, prefix="/api/extra")
-
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=args.allowed_origins,
-        allow_credentials=args.allow_credentials,
-        allow_methods=args.allowed_methods,
-        allow_headers=args.allowed_headers,
-    )
+    try:
+        args = parse_args()
+
+        if args.launch_kobold_api:
+            logger.warning(
+                "Launching Kobold API server in addition to OpenAI. "
+                "Keep in mind that the Kobold API routes are NOT "
+                "protected via the API key.")
+            app.include_router(kai_api, prefix="/api/v1")
+            app.include_router(kai_api,
+                               prefix="/api/latest",
+                               include_in_schema=False)
+            app.include_router(extra_api, prefix="/api/extra")
+
+        app.add_middleware(
+            CORSMiddleware,
+            allow_origins=args.allowed_origins,
+            allow_credentials=args.allow_credentials,
+            allow_methods=args.allowed_methods,
+            allow_headers=args.allowed_headers,
+        )
+
+        if token := os.environ.get("APHRODITE_API_KEY") or args.api_keys:
+            admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key
+
+            if admin_key is None:
+                logger.warning("Admin key not provided. Admin operations will "
+                               "be disabled.")
+
+            @app.middleware("http")
+            async def authentication(request: Request, call_next):
+                excluded_paths = ["/api"]
+                if any(
+                        request.url.path.startswith(path)
+                        for path in excluded_paths):
+                    return await call_next(request)
+                if not request.url.path.startswith("/v1"):
+                    return await call_next(request)
 
-    if token := os.environ.get("APHRODITE_API_KEY") or args.api_keys:
-        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key
+                auth_header = request.headers.get("Authorization")
+                api_key_header = request.headers.get("x-api-key")
 
-        if admin_key is None:
-            logger.warning("Admin key not provided. Admin operations will "
-                           "be disabled.")
+                if request.url.path.startswith("/v1/lora"):
+                    if admin_key is not None and api_key_header == admin_key:
+                        return await call_next(request)
+                    return JSONResponse(content={"error": "Unauthorized"},
+                                        status_code=401)
 
-        @app.middleware("http")
-        async def authentication(request: Request, call_next):
-            excluded_paths = ["/api"]
-            if any(
-                    request.url.path.startswith(path)
-                    for path in excluded_paths):
-                return await call_next(request)
-            if not request.url.path.startswith("/v1"):
+                if auth_header != "Bearer " + token and api_key_header != token:
+                    return JSONResponse(content={"error": "Unauthorized"},
+                                        status_code=401)
                 return await call_next(request)
 
-            auth_header = request.headers.get("Authorization")
-            api_key_header = request.headers.get("x-api-key")
-
-            if request.url.path.startswith("/v1/lora"):
-                if admin_key is not None and api_key_header == admin_key:
-                    return await call_next(request)
-                return JSONResponse(content={"error": "Unauthorized"},
-                                    status_code=401)
-
-            if auth_header != "Bearer " + token and api_key_header != token:
-                return JSONResponse(content={"error": "Unauthorized"},
-                                    status_code=401)
-            return await call_next(request)
-
-    for middleware in args.middleware:
-        module_path, object_name = middleware.rsplit(".", 1)
-        imported = getattr(importlib.import_module(module_path), object_name)
-        if inspect.isclass(imported):
-            app.add_middleware(imported)
-        elif inspect.iscoroutinefunction(imported):
-            app.middleware("http")(imported)
+        for middleware in args.middleware:
+            module_path, object_name = middleware.rsplit(".", 1)
+            imported = getattr(importlib.import_module(module_path),
+                               object_name)
+            if inspect.isclass(imported):
+                app.add_middleware(imported)
+            elif inspect.iscoroutinefunction(imported):
+                app.middleware("http")(imported)
+            else:
+                raise ValueError(f"Invalid middleware {middleware}. Must be a "
+                                 "function or a class.")
+
+        logger.debug(f"args: {args}")
+
+        if args.served_model_name is not None:
+            served_model = args.served_model_name
         else:
-            raise ValueError(f"Invalid middleware {middleware}. Must be a "
-                             "function or a class.")
-
-    logger.debug(f"args: {args}")
-
-    if args.served_model_name is not None:
-        served_model = args.served_model_name
-    else:
-        served_model = args.model
-
-    engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncAphrodite.from_engine_args(engine_args)
-    tokenizer = get_tokenizer(
-        engine_args.tokenizer,
-        tokenizer_mode=engine_args.tokenizer_mode,
-        trust_remote_code=engine_args.trust_remote_code,
-    )
-
-    chat_template = args.chat_template
-    if chat_template is None and tokenizer.chat_template is not None:
-        chat_template = tokenizer.chat_template
-
-    openai_serving_chat = OpenAIServingChat(engine, served_model,
-                                            args.response_role,
-                                            args.lora_modules,
-                                            args.chat_template)
-    openai_serving_completion = OpenAIServingCompletion(
-        engine, served_model, args.lora_modules)
-    engine_model_config = asyncio.run(engine.get_model_config())
-
-    if args.launch_kobold_api:
-        _set_badwords(tokenizer, engine_model_config.hf_config)
-
-    app.root_path = args.root_path
-    uvicorn.run(app,
-                host=args.host,
-                port=args.port,
-                log_level="info",
-                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
-                ssl_keyfile=args.ssl_keyfile,
-                ssl_certfile=args.ssl_certfile,
-                log_config=UVICORN_LOG_CONFIG)
+            served_model = args.model
+
+        engine_args = AsyncEngineArgs.from_cli_args(args)
+        engine = AsyncAphrodite.from_engine_args(engine_args)
+        tokenizer = get_tokenizer(
+            engine_args.tokenizer,
+            tokenizer_mode=engine_args.tokenizer_mode,
+            trust_remote_code=engine_args.trust_remote_code,
+        )
+
+        chat_template = args.chat_template
+        if chat_template is None and tokenizer.chat_template is not None:
+            chat_template = tokenizer.chat_template
+
+        openai_serving_chat = OpenAIServingChat(engine, served_model,
+                                                args.response_role,
+                                                args.lora_modules,
+                                                args.chat_template)
+        openai_serving_completion = OpenAIServingCompletion(
+            engine, served_model, args.lora_modules)
+        engine_model_config = asyncio.run(engine.get_model_config())
+
+        if args.launch_kobold_api:
+            _set_badwords(tokenizer, engine_model_config.hf_config)
+
+        app.root_path = args.root_path
+        uvicorn.run(app,
+                    host=args.host,
+                    port=args.port,
+                    log_level="info",
+                    timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+                    ssl_keyfile=args.ssl_keyfile,
+                    ssl_certfile=args.ssl_certfile,
+                    log_config=UVICORN_LOG_CONFIG)
+    except KeyboardInterrupt:
+        logger.info("API server stopped by user. Exiting gracefully.")
+    except asyncio.exceptions.CancelledError:
+        logger.info("API server stopped due to a cancelled request. "
+                    "Exiting gracefully.")

+ 275 - 97
aphrodite/engine/aphrodite_engine.py

@@ -2,29 +2,61 @@ import copy
 from collections import defaultdict
 import os
 import time
-from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
-                    Union)
+import pickle
+import importlib
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 from loguru import logger
 
 import aphrodite
 from aphrodite.lora.request import LoRARequest
-from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
-                                     SchedulerConfig, LoRAConfig, DeviceConfig)
+from aphrodite.common.config import (
+    CacheConfig,
+    DeviceConfig,
+    ModelConfig,
+    ParallelConfig,
+    SchedulerConfig,
+    LoRAConfig,
+)
 from aphrodite.processing.scheduler import Scheduler, SchedulerOutputs
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.engine.metrics import StatLogger, Stats
-from aphrodite.engine.ray_tools import (RayWorkerAphrodite, initialize_cluster,
-                                        ray)
-from aphrodite.common.logger import setup_logger
+from aphrodite.engine.ray_tools import (
+    RayWorkerAphrodite,
+    initialize_cluster,
+    ray,
+)
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import (SamplerOutput, Sequence, SequenceGroup,
-                                       SequenceGroupOutput, SequenceOutput,
-                                       SequenceStatus, Logprob)
-from aphrodite.transformers_utils.tokenizer import (detokenize_incrementally,
-                                                    TokenizerGroup)
-from aphrodite.common.utils import (Counter, set_cuda_visible_devices, get_ip,
-                                    get_open_port)
+from aphrodite.common.sequence import (
+    Logprob,
+    SamplerOutput,
+    Sequence,
+    SequenceGroup,
+    SequenceGroupOutput,
+    SequenceOutput,
+    SequenceStatus,
+)
+from aphrodite.transformers_utils.tokenizer import (
+    detokenize_incrementally,
+    TokenizerGroup,
+)
+from aphrodite.common.utils import (
+    Counter,
+    set_cuda_visible_devices,
+    get_ip,
+    get_open_port,
+    get_distributed_init_method,
+)
+from aphrodite.common.logger import setup_logger
 
 if ray:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -34,6 +66,17 @@ if TYPE_CHECKING:
 
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
+# A map between the device type (in device config) to its worker module.
+DEVICE_TO_WORKER_MODULE_MAP = {
+    "cuda": "aphrodite.task_handler.worker",
+    "neuron": "aphrodite.task_handler.neuron_worker",
+}
+
+# If the env var is set, it uses the Ray's compiled DAG API
+# which optimizes the control plane overhead.
+# Run APHRODITE with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
+USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
+
 
 class AphroditeEngine:
     """An LLM engine that receives requests and generates texts.
@@ -88,7 +131,7 @@ class AphroditeEngine:
             f"Context Length = {model_config.max_model_len}\n"
             f"Enforce Eager Mode = {model_config.enforce_eager}\n"
             f"KV Cache Data Type = {cache_config.cache_dtype}\n"
-            f"KV Cache Params Path = {cache_config.cache_quant_params_path}\n"
+            # f"KV Cache Params Path = {cache_config.cache_quant_params_path}\n"
             f"Device = {device_config.device}")
         # TODO: Print more configs in debug mode.
 
@@ -110,7 +153,20 @@ class AphroditeEngine:
             ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
             if ray_usage != "1":
                 os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
-            self._init_workers_ray(placement_group)
+            # Pass additional arguments to initialize the worker
+            additional_ray_args = {}
+            if self.parallel_config.ray_workers_use_nsight:
+                logger.info("Configuring Ray workers to use nsight.")
+                additional_ray_args = {
+                    "runtime_env": {
+                        "nsight": {
+                            "t": "cuda,cudnn,cublas",
+                            "o": "'worker_process_%p'",
+                            "cuda-graph-trace": "node",
+                        }
+                    }
+                }
+            self._init_workers_ray(placement_group, **additional_ray_args)
         else:
             self._init_workers()
 
@@ -124,22 +180,40 @@ class AphroditeEngine:
         if self.log_stats:
             self.stat_logger = StatLogger(
                 local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
-                labels=dict(model_name=model_config.model))
+                labels=dict(model_name=model_config.model),
+            )
+            self.stat_logger.info("cache_config", self.cache_config)
+
+        self.forward_dag = None
+        if USE_RAY_COMPILED_DAG:
+            self.forward_dag = self._compiled_ray_dag()
+
+    def __reduce__(self):
+        # This is to ensure that the AphroditeEngine is not referenced in
+        # the closure used to initialize Ray worker actors
+        raise RuntimeError("AphroditeEngine should not be pickled!")
 
     def get_tokenizer_for_seq(self, sequence: Sequence):
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
 
+    def _dispatch_worker(self):
+        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
+            self.device_config.device_type]
+        imported_worker = importlib.import_module(worker_module)
+        Worker = imported_worker.Worker
+        return Worker
+
     def _init_workers(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        # pylint: disable=import-outside-toplevel
-        from aphrodite.task_handler.worker import Worker
+        Worker = self._dispatch_worker()
 
-        assert self.parallel_config.world_size == 1, (
-            "Ray is required if parallel_config.world_size > 1.")
+        assert (self.parallel_config.world_size == 1
+                ), "Ray is required if parallel_config.world_size > 1."
 
         self.workers: List[Worker] = []
-        distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
         self.driver_worker = Worker(
             self.model_config,
             self.parallel_config,
@@ -150,7 +224,7 @@ class AphroditeEngine:
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
-            kv_quant_params_path=(self.cache_config.cache_quant_params_path),
+            # kv_quant_params_path=(self.cache_config.cache_quant_params_path),
             is_driver_worker=True,
         )
         self._run_workers("init_model")
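
Aside (not in the diff): the inline `f"tcp://{get_ip()}:{get_open_port()}"` string is replaced here by a `get_distributed_init_method` helper. A plausible sketch of that helper, assuming it simply rebuilds the same TCP init URI (its real definition is not shown in this section):

def get_distributed_init_method(ip: str, port: int) -> str:
    # Equivalent to the inline f-string this hunk removes.
    return f"tcp://{ip}:{port}"
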
@@ -163,7 +237,8 @@ class AphroditeEngine:
             max_input_length=None,
             tokenizer_mode=self.model_config.tokenizer_mode,
             trust_remote_code=self.model_config.trust_remote_code,
-            revision=self.model_config.tokenizer_revision)
+            revision=self.model_config.tokenizer_revision,
+        )
         init_kwargs.update(tokenizer_init_kwargs)
         self.tokenizer: TokenizerGroup = TokenizerGroup(
             self.model_config.tokenizer, **init_kwargs)
@@ -230,18 +305,21 @@ class AphroditeEngine:
         for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
             worker.set_cuda_visible_devices.remote(node_gpus[node_id])
 
-        distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"
+        distributed_init_method = get_distributed_init_method(
+            driver_ip, get_open_port())
 
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        # pylint: disable=import-outside-toplevel
-        from aphrodite.task_handler.worker import Worker
+        Worker = self._dispatch_worker()
 
         # Initialize torch distributed process group for the workers.
         model_config = copy.deepcopy(self.model_config)
         parallel_config = copy.deepcopy(self.parallel_config)
         scheduler_config = copy.deepcopy(self.scheduler_config)
         device_config = copy.deepcopy(self.device_config)
+        lora_config = copy.deepcopy(self.lora_config)
+        kv_cache_dtype = self.cache_config.cache_dtype
+        # kv_quant_params_path = self.cache_config.cache_quant_params_path
 
         for rank, (worker, (node_id,
                             _)) in enumerate(zip(self.workers,
@@ -257,29 +335,33 @@ class AphroditeEngine:
                     local_rank,
                     rank,
                     distributed_init_method,
-                    lora_config=self.lora_config,
-                    kv_cache_dtype=self.cache_config.cache_dtype,
-                    kv_quant_params_path=
-                    (self.cache_config.cache_quant_params_path),
+                    lora_config=lora_config,
+                    kv_cache_dtype=kv_cache_dtype,
+                    # kv_quant_params_path=kv_quant_params_path,
                 ))
 
         driver_rank = 0
         driver_local_rank = node_workers[driver_node_id].index(driver_rank)
         self.driver_worker = Worker(
-            model_config,
-            parallel_config,
-            scheduler_config,
-            device_config,
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            self.device_config,
             driver_local_rank,
             driver_rank,
             distributed_init_method,
             lora_config=self.lora_config,
-            kv_cache_dtype=self.cache_config.cache_dtype,
-            kv_quant_params_path=(self.cache_config.cache_quant_params_path),
+            kv_cache_dtype=kv_cache_dtype,
+            # kv_quant_params_path=kv_quant_params_path,
             is_driver_worker=True,
         )
 
-        self._run_workers("init_model", cupy_port=get_open_port())
+        # don't use cupy for eager mode
+        self._run_workers(
+            "init_model",
+            cupy_port=get_open_port()
+            if not model_config.enforce_eager else None,
+        )
         self._run_workers(
             "load_model",
             max_concurrent_workers=self.parallel_config.
@@ -302,7 +384,6 @@ class AphroditeEngine:
         Then, it calculates the maximum possible number of GPU and CPU blocks
         that can be allocated with the remaining free memory.
         More details can be found in the
-        # pylint: disable=line-too-long
         :meth:`~aphrodite.task_handler.worker.Worker.profile_num_available_blocks` method
         from class :class:`~aphrodite.task_handler.Worker`.
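
Aside: a hedged sketch of the block-count arithmetic the docstring describes. The symbol names and the exact accounting are assumptions for illustration; the real bookkeeping lives in the worker's profiling method:

def num_gpu_blocks(total_gpu_mem: int, peak_activation_mem: int,
                   gpu_memory_utilization: float, block_bytes: int) -> int:
    # Memory left for the KV cache after reserving headroom and activations.
    usable = total_gpu_mem * gpu_memory_utilization - peak_activation_mem
    return max(int(usable // block_bytes), 0)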
 
@@ -372,9 +453,11 @@ class AphroditeEngine:
         # Initialize the cluster.
         placement_group = initialize_cluster(parallel_config)
         # Create the LLM engine.
-        engine = cls(*engine_configs,
-                     placement_group,
-                     log_stats=not engine_args.disable_log_stats)
+        engine = cls(
+            *engine_configs,
+            placement_group,
+            log_stats=not engine_args.disable_log_stats,
+        )
         return engine
 
     def encode_request(
@@ -449,20 +532,34 @@ class AphroditeEngine:
                     sampling_params.prompt_logprobs
                     and sampling_params.prompt_logprobs > max_log_probs):
             raise ValueError(f"Cannot request more than "
-                             f"{max_log_probs} logprobs.")
+                             f"{max_log_probs} logprobs. "
+                             "Please increase the max_log_probs.")
         if arrival_time is None:
             arrival_time = time.monotonic()
         prompt_token_ids = self.encode_request(
             request_id=request_id,
             prompt=prompt,
             prompt_token_ids=prompt_token_ids,
-            lora_request=lora_request)
+            lora_request=lora_request,
+        )
 
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
-        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
-                       lora_request)
+        eos_token_id = self.tokenizer.get_lora_tokenizer(
+            lora_request).eos_token_id
+        seq = Sequence(
+            seq_id,
+            prompt,
+            prompt_token_ids,
+            block_size,
+            eos_token_id,
+            lora_request,
+        )
+
+        # Defensive copy of SamplingParams, which are used by the sampler,
+        # this doesn't deep-copy LogitsProcessor objects
+        sampling_params = sampling_params.clone()
 
         # Create the sequence group.
         seq_group = SequenceGroup(request_id, [seq], sampling_params,
@@ -514,15 +611,15 @@ class AphroditeEngine:
         if early_stopping is True:
             return True
 
-        current_worst_score = (current_worst_seq.get_beam_search_score(
+        current_worst_score = current_worst_seq.get_beam_search_score(
             length_penalty=length_penalty,
-            eos_token_id=self.get_tokenizer_for_seq(
-                current_worst_seq).eos_token_id))
+            eos_token_id=current_worst_seq.eos_token_id,
+        )
         if early_stopping is False:
-            highest_attainable_score = (best_running_seq.get_beam_search_score(
+            highest_attainable_score = best_running_seq.get_beam_search_score(
                 length_penalty=length_penalty,
-                eos_token_id=self.get_tokenizer_for_seq(
-                    best_running_seq).eos_token_id))
+                eos_token_id=best_running_seq.eos_token_id,
+            )
         else:
             assert early_stopping == "never"
             if length_penalty > 0.0:
@@ -532,13 +629,14 @@ class AphroditeEngine:
                 max_possible_length = max(
                     best_running_seq.get_prompt_len() +
                     sampling_params.max_tokens,
-                    self.scheduler_config.max_model_len)
+                    self.scheduler_config.max_model_len,
+                )
                 highest_attainable_score = (
                     best_running_seq.get_beam_search_score(
                         length_penalty=length_penalty,
-                        eos_token_id=self.get_tokenizer_for_seq(
-                            best_running_seq).eos_token_id,
-                        seq_len=max_possible_length))
+                        eos_token_id=best_running_seq.eos_token_id,
+                        seq_len=max_possible_length,
+                    ))
             else:
                 # Otherwise, beam search will prefer shorter sequences. The
                 # highest attainable score calculation is based on the current
@@ -546,8 +644,8 @@ class AphroditeEngine:
                 highest_attainable_score = (
                     best_running_seq.get_beam_search_score(
                         length_penalty=length_penalty,
-                        eos_token_id=self.get_tokenizer_for_seq(
-                            best_running_seq).eos_token_id))
+                        eos_token_id=best_running_seq.eos_token_id,
+                    ))
         return current_worst_score >= highest_attainable_score
 
     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
@@ -555,6 +653,16 @@ class AphroditeEngine:
         # Process prompt logprobs
         prompt_logprobs = outputs.prompt_logprobs
         if prompt_logprobs is not None:
+            # We can pick any sequence for the prompt.
+            seq = next(iter(seq_group.seqs_dict.values()))
+            all_token_ids = seq.get_token_ids()
+            for i, prompt_logprobs_for_token in enumerate(prompt_logprobs):
+                self._decode_logprobs(
+                    seq,
+                    seq_group.sampling_params,
+                    prompt_logprobs_for_token,
+                    all_token_ids[:i],
+                )
             seq_group.prompt_logprobs = prompt_logprobs
 
         # Process samples
@@ -638,10 +746,11 @@ class AphroditeEngine:
                              if seq.is_finished()]
         all_finished_seqs = existing_finished_seqs + new_finished_seqs
         # Sort the finished sequences by their scores.
-        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
-            length_penalty=length_penalty,
-            eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
-                               reverse=True)
+        all_finished_seqs.sort(
+            key=lambda x: x[0].get_beam_search_score(
+                length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
+            reverse=True,
+        )
         for seq, parent, is_new in all_finished_seqs[:beam_width]:
             if is_new:
                 # A newly generated child sequence finishes and has a high
@@ -666,10 +775,11 @@ class AphroditeEngine:
         running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                               if not seq.is_finished()]
         # Sort the running sequences by their scores.
-        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
-            length_penalty=length_penalty,
-            eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
-                                reverse=True)
+        running_child_seqs.sort(
+            key=lambda x: x[0].get_beam_search_score(
+                length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
+            reverse=True,
+        )
 
         # Check if we can stop the beam search.
         if len(running_child_seqs) == 0:
@@ -684,7 +794,10 @@ class AphroditeEngine:
             current_worst_seq = all_finished_seqs[beam_width - 1][0]
             stop_beam_search = self._check_beam_search_early_stopping(
                 seq_group.sampling_params.early_stopping,
-                seq_group.sampling_params, best_running_seq, current_worst_seq)
+                seq_group.sampling_params,
+                best_running_seq,
+                current_worst_seq,
+            )
 
         if stop_beam_search:
             # Stop the beam search and remove all the running sequences from
@@ -726,13 +839,16 @@ class AphroditeEngine:
     def _process_model_outputs(
             self, output: SamplerOutput,
             scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
+        now = time.time()
         # Update the scheduled sequence groups with the model outputs.
         scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
+
         # If prefix caching is enabled, mark all blocks in the sequence groups
         # as completed so that future requests don't attempt to recompute them
         if self.cache_config.context_shift:
             for seq_group in scheduled_seq_groups:
                 self.scheduler.mark_blocks_as_computed(seq_group)
+
         for seq_group, outputs in zip(scheduled_seq_groups, output):
             self._process_sequence_group_outputs(seq_group, outputs)
 
@@ -742,6 +858,7 @@ class AphroditeEngine:
         # Create the outputs.
         request_outputs: List[RequestOutput] = []
         for seq_group in scheduled_seq_groups:
+            seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
         for seq_group in scheduler_outputs.ignored_seq_groups:
@@ -751,6 +868,7 @@ class AphroditeEngine:
         # Log stats.
         if self.log_stats:
             self.stat_logger.log(self._get_stats(scheduler_outputs))
+
         return request_outputs
 
     def step(self) -> List[RequestOutput]:
@@ -815,7 +933,9 @@ class AphroditeEngine:
                     "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
                     "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
                     "blocks_to_copy": scheduler_outputs.blocks_to_copy,
-                })
+                },
+                use_ray_compiled_dag=USE_RAY_COMPILED_DAG,
+            )
 
             # Only the driver worker returns the sampling results.
             output = all_outputs[0]
@@ -840,10 +960,10 @@ class AphroditeEngine:
         gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)
 
         num_total_cpu = self.cache_config.num_cpu_blocks
-        cpu_cache_usage = 0.
+        cpu_cache_usage = 0.0
         if num_total_cpu > 0:
-            num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
-            )
+            num_free_cpu = (
+                self.scheduler.block_manager.get_num_free_cpu_blocks())
             cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)
 
         # Scheduler State
@@ -898,16 +1018,24 @@ class AphroditeEngine:
             time_e2e_requests=time_e2e_requests,
         )
 
-    def _decode_logprobs(self, seq: Sequence, prms: SamplingParams,
-                         logprobs: Dict[int, Logprob],
-                         all_input_ids: List[int]) -> None:
+    def _decode_logprobs(
+        self,
+        seq: Sequence,
+        prms: SamplingParams,
+        logprobs: Dict[int, Logprob],
+        all_input_ids: List[int],
+    ) -> None:
         if not logprobs:
             return
         for token_id, sample_logprob in logprobs.items():
-            if (sample_logprob.decoded_token is None and token_id != -1):
+            if sample_logprob.decoded_token is None and token_id != -1:
                 all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
-                # pylint: disable=unused-variable
-                _, new_text, prefix_offset, read_offset = detokenize_incrementally(
+                (
+                    _,
+                    new_text,
+                    prefix_offset,
+                    read_offset,
+                ) = detokenize_incrementally(
                     self.get_tokenizer_for_seq(seq),
                     all_input_ids=all_input_ids_with_logprob,
                     prev_tokens=seq.tokens,
@@ -924,16 +1052,21 @@ class AphroditeEngine:
         all_input_ids = seq.get_token_ids()
         self._decode_logprobs(seq, prms, seq.output_logprobs[-1],
                               all_input_ids)
-        (new_tokens, new_output_text, prefix_offset,
-         read_offset) = detokenize_incrementally(
-             self.get_tokenizer_for_seq(seq),
-             all_input_ids=all_input_ids,
-             prev_tokens=seq.tokens,
-             prefix_offset=seq.prefix_offset,
-             read_offset=seq.read_offset,
-             skip_special_tokens=prms.skip_special_tokens,
-             spaces_between_special_tokens=prms.spaces_between_special_tokens,
-         )
+
+        (
+            new_tokens,
+            new_output_text,
+            prefix_offset,
+            read_offset,
+        ) = detokenize_incrementally(
+            self.get_tokenizer_for_seq(seq),
+            all_input_ids=all_input_ids,
+            prev_tokens=seq.tokens,
+            prefix_offset=seq.prefix_offset,
+            read_offset=seq.read_offset,
+            skip_special_tokens=prms.skip_special_tokens,
+            spaces_between_special_tokens=prms.spaces_between_special_tokens,
+        )
         if seq.tokens is None:
             seq.tokens = new_tokens
         else:
@@ -968,15 +1101,18 @@ class AphroditeEngine:
             return
 
         # Check if the sequence has generated the EOS token.
-        if ((not sampling_params.ignore_eos) and seq.get_last_token_id()
-                == self.get_tokenizer_for_seq(seq).eos_token_id):
+        if (not sampling_params.ignore_eos
+            ) and seq.get_last_token_id() == seq.eos_token_id:
             seq.status = SequenceStatus.FINISHED_STOPPED
             return
 
     def _finalize_sequence(self, seq: Sequence,
                            sampling_params: SamplingParams,
                            stop_string: str) -> None:
-        if not sampling_params.include_stop_str_in_output and stop_string:
+        if sampling_params.include_stop_str_in_output:
+            return
+
+        if stop_string and seq.output_text.endswith(stop_string):
             # Truncate the output text so that the stop string is
             # not included in the output.
             seq.output_text = seq.output_text[:-len(stop_string)]
@@ -1005,6 +1141,7 @@ class AphroditeEngine:
         driver_args: Optional[List[Any]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
         max_concurrent_workers: Optional[int] = None,
+        use_ray_compiled_dag: bool = False,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
@@ -1013,11 +1150,17 @@ class AphroditeEngine:
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
 
-        # Start the ray workers first.
-        ray_worker_outputs = [
-            worker.execute_method.remote(method, *args, **kwargs)
-            for worker in self.workers
-        ]
+        if use_ray_compiled_dag:
+            # Right now, compiled DAG can only accept a single
+            # input.
+            # TODO: Fix it.
+            output_channels = self.forward_dag.execute(1)
+        else:
+            # Start the ray workers first.
+            ray_worker_outputs = [
+                worker.execute_method.remote(method, *args, **kwargs)
+                for worker in self.workers
+            ]
 
         if driver_args is None:
             driver_args = args
@@ -1030,10 +1173,45 @@ class AphroditeEngine:
 
         # Get the results of the ray workers.
         if self.workers:
-            ray_worker_outputs = ray.get(ray_worker_outputs)
+            if use_ray_compiled_dag:
+                try:
+                    ray_worker_outputs = [
+                        pickle.loads(chan.begin_read())
+                        for chan in output_channels
+                    ]
+                finally:
+                    # Has to call end_read in order to reuse the DAG.
+                    for chan in output_channels:
+                        chan.end_read()
+            else:
+                ray_worker_outputs = ray.get(ray_worker_outputs)
 
         return [driver_worker_output] + ray_worker_outputs
 
+    def _compiled_ray_dag(self):
+        from packaging import version
+        import pkg_resources
+
+        required_version = "2.9"
+        current_version = pkg_resources.get_distribution("ray").version
+
+        if version.parse(current_version) < version.parse(required_version):
+            raise ValueError(f"Ray version {required_version} or greater is "
+                             f"required, but found {current_version}")
+
+        from ray.dag import MultiOutputNode, InputNode
+
+        assert self.parallel_config.worker_use_ray
+
+        # Right now, compiled DAG requires at least 1 arg. We send
+        # a dummy value for now. It will be fixed soon.
+        with InputNode() as input_data:
+            forward_dag = MultiOutputNode([
+                worker.execute_model_compiled_dag_remote.bind(input_data)
+                for worker in self.workers
+            ])
+        return forward_dag.experimental_compile()
+
     def check_health(self) -> None:
         """Raises an error if engine is unhealthy."""
         self._check_if_any_actor_is_dead()
@@ -1052,7 +1230,7 @@ class AphroditeEngine:
                 dead_actors.append(actor)
         if dead_actors:
             raise RuntimeError("At least one Worker is dead. "
-                               f"Dead workers: {dead_actors}")
+                               f"Dead Workers: {dead_actors}. ")
 
 
 setup_logger()
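
Editorial sketch of the device dispatch that `_dispatch_worker` above relies on. The CUDA entry matches the `aphrodite.task_handler.worker` import this diff removes; the Neuron module path is a guess for illustration only:

import importlib

DEVICE_TO_WORKER_MODULE_MAP = {
    "cuda": "aphrodite.task_handler.worker",
    "neuron": "aphrodite.task_handler.neuron_worker",  # assumed path
}

def dispatch_worker(device_type: str):
    # Import lazily so CUDA libraries are not pulled in for non-CUDA backends.
    worker_module = importlib.import_module(
        DEVICE_TO_WORKER_MODULE_MAP[device_type])
    return worker_module.Worker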

+ 373 - 240
aphrodite/engine/args_tools.py

@@ -3,22 +3,29 @@ import dataclasses
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
-from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
-                                     SchedulerConfig, LoRAConfig, DeviceConfig)
+from aphrodite.common.config import (
+    CacheConfig,
+    ModelConfig,
+    ParallelConfig,
+    SchedulerConfig,
+    LoRAConfig,
+    DeviceConfig,
+)
 
 
 @dataclass
 class EngineArgs:
-    """Arguments for the Aphrodite engine."""
+    """Arguments for Aphrodite engine."""
+
     model: str
     tokenizer: Optional[str] = None
-    tokenizer_mode: str = 'auto'
+    tokenizer_mode: str = "auto"
     trust_remote_code: bool = False
     download_dir: Optional[str] = None
-    load_format: str = 'auto'
-    dtype: str = 'auto'
-    kv_cache_dtype: str = 'auto'
-    kv_quant_params_path: str = None
+    load_format: str = "auto"
+    dtype: str = "auto"
+    kv_cache_dtype: str = "auto"
+    # kv_quant_params_path: str = None
     seed: int = 0
     max_model_len: Optional[int] = None
     worker_use_ray: bool = False
@@ -32,24 +39,26 @@ class EngineArgs:
     max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
     max_paddings: int = 256
-    max_log_probs: int = 10
+    max_log_probs: int = 10  # OpenAI default is 5, setting to 10 because ST
     disable_log_stats: bool = False
     revision: Optional[str] = None
+    code_revision: Optional[str] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     load_in_4bit: bool = False
     load_in_8bit: bool = False
     load_in_smooth: bool = False
-    enforce_eager: bool = False
+    enforce_eager: bool = True
     max_context_len_to_capture: int = 8192
     disable_custom_all_reduce: bool = False
     enable_lora: bool = False
     max_loras: int = 1
     max_lora_rank: int = 16
     lora_extra_vocab_size: int = 256
-    lora_dtype = 'auto'
+    lora_dtype = "auto"
     max_cpu_loras: Optional[int] = None
-    device: str = 'cuda'
+    device: str = "auto"
+    ray_workers_use_nsight: bool = False
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -65,245 +74,333 @@ class EngineArgs:
 
         # Model arguments
         parser.add_argument(
-            '--model',
+            "--model",
             type=str,
-            default='EleutherAI/pythia-70m-deduped',
-            help='name or path of the huggingface model to use')
+            default="EleutherAI/pythia-70m-deduped",
+            help="name or path of the huggingface model to use",
+        )
         parser.add_argument(
-            '--tokenizer',
+            "--tokenizer",
             type=str,
             default=EngineArgs.tokenizer,
-            help='name or path of the huggingface tokenizer to use')
+            help="name or path of the huggingface tokenizer to use",
+        )
         parser.add_argument(
-            '--revision',
+            "--revision",
             type=str,
             default=None,
-            help='the specific model version to use. It can be a branch '
-            'name, a tag name, or a commit id. If unspecified, will use '
-            'the default version.')
+            help="the specific model version to use. It can be a branch "
+            "name, a tag name, or a commit id. If unspecified, will use "
+            "the default version.",
+        )
         parser.add_argument(
-            '--tokenizer-revision',
+            "--code-revision",
             type=str,
             default=None,
-            help='the specific tokenizer version to use. It can be a branch '
-            'name, a tag name, or a commit id. If unspecified, will use '
-            'the default version.')
-        parser.add_argument('--tokenizer-mode',
-                            type=str,
-                            default=EngineArgs.tokenizer_mode,
-                            choices=['auto', 'slow'],
-                            help='tokenizer mode. "auto" will use the fast '
-                            'tokenizer if available, and "slow" will '
-                            'always use the slow tokenizer.')
-        parser.add_argument('--trust-remote-code',
-                            action='store_true',
-                            help='trust remote code from huggingface')
-        parser.add_argument('--download-dir',
-                            type=str,
-                            default=EngineArgs.download_dir,
-                            help='directory to download and load the weights, '
-                            'default to the default cache dir of '
-                            'huggingface')
-        parser.add_argument(
-            '--load-format',
+            help="the specific revision to use for the model code on "
+            "Hugging Face Hub. It can be a branch name, a tag name, or a "
+            "commit id. If unspecified, will use the default version.",
+        )
+        parser.add_argument(
+            "--tokenizer-revision",
+            type=str,
+            default=None,
+            help="the specific tokenizer version to use. It can be a branch "
+            "name, a tag name, or a commit id. If unspecified, will use "
+            "the default version.",
+        )
+        parser.add_argument(
+            "--tokenizer-mode",
+            type=str,
+            default=EngineArgs.tokenizer_mode,
+            choices=["auto", "slow"],
+            help='tokenizer mode. "auto" will use the fast '
+            'tokenizer if available, and "slow" will '
+            "always use the slow tokenizer.",
+        )
+        parser.add_argument(
+            "--trust-remote-code",
+            action="store_true",
+            help="trust remote code from huggingface",
+        )
+        parser.add_argument(
+            "--download-dir",
+            type=str,
+            default=EngineArgs.download_dir,
+            help="directory to download and load the weights, "
+            "default to the default cache dir of "
+            "huggingface",
+        )
+        parser.add_argument(
+            "--load-format",
             type=str,
             default=EngineArgs.load_format,
-            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
-            help='The format of the model weights to load. '
+            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
+            help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
-            'and fall back to the pytorch bin format if safetensors format '
-            'is not available. '
+            "and fall back to the pytorch bin format if safetensors format "
+            "is not available. "
             '"pt" will load the weights in the pytorch bin format. '
             '"safetensors" will load the weights in the safetensors format. '
             '"npcache" will load the weights in pytorch format and store '
-            'a numpy cache to speed up the loading. '
+            "a numpy cache to speed up the loading. "
             '"dummy" will initialize the weights with random values, '
-            'which is mainly for profiling.')
+            "which is mainly for profiling.",
+        )
         parser.add_argument(
-            '--dtype',
+            "--dtype",
             type=str,
             default=EngineArgs.dtype,
             choices=[
-                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
+                "auto", "half", "float16", "bfloat16", "float", "float32"
             ],
-            help='data type for model weights and activations. '
+            help="data type for model weights and activations. "
             'The "auto" option will use FP16 precision '
-            'for FP32 and FP16 models, and BF16 precision '
-            'for BF16 models.')
+            "for FP32 and FP16 models, and BF16 precision "
+            "for BF16 models.",
+        )
         parser.add_argument(
-            '--kv-cache-dtype',
+            "--kv-cache-dtype",
             type=str,
-            choices=['auto', 'fp8_e5m2', 'int8'],
+            # choices=["auto", "fp8_e5m2", "int8"],
+            choices=['auto', 'fp8_e5m2'],
             default=EngineArgs.kv_cache_dtype,
             help='Data type for kv cache storage. If "auto", will use model '
-            'data type. Note FP8 is not supported when cuda version is '
-            'lower than 11.8.')
+            "data type. Note FP8 is not supported when cuda version is "
+            "lower than 11.8.",
+        )
+        # parser.add_argument(
+        #     "--kv-quant-params-path",
+        #     type=str,
+        #     default=EngineArgs.kv_quant_params_path,
+        #     help="Path to scales and zero points of KV cache "
+        #     "quantization. Only applicable when kv-cache-dtype "
+        #     "is int8.",
+        # )
         parser.add_argument(
-            '--kv-quant-params-path',
-            type=str,
-            default=EngineArgs.kv_quant_params_path,
-            help='Path to scales and zero points of KV cache '
-            'quantization. Only applicable when kv-cache-dtype '
-            'is int8.')
-        parser.add_argument('--max-model-len',
-                            type=int,
-                            default=EngineArgs.max_model_len,
-                            help='model context length. If unspecified, '
-                            'will be automatically derived from the model.')
+            "--max-model-len",
+            type=int,
+            default=EngineArgs.max_model_len,
+            help="model context length. If unspecified, "
+            "will be automatically derived from the model.",
+        )
         # Parallel arguments
-        parser.add_argument('--worker-use-ray',
-                            action='store_true',
-                            help='use Ray for distributed serving, will be '
-                            'automatically set when using more than 1 GPU')
-        parser.add_argument('--pipeline-parallel-size',
-                            '-pp',
-                            type=int,
-                            default=EngineArgs.pipeline_parallel_size,
-                            help='number of pipeline stages')
-        parser.add_argument('--tensor-parallel-size',
-                            '-tp',
-                            type=int,
-                            default=EngineArgs.tensor_parallel_size,
-                            help='number of tensor parallel replicas')
         parser.add_argument(
-            '--max-parallel-loading-workers',
+            "--worker-use-ray",
+            action="store_true",
+            help="use Ray for distributed serving, will be "
+            "automatically set when using more than 1 GPU",
+        )
+        parser.add_argument(
+            "--pipeline-parallel-size",
+            "-pp",
+            type=int,
+            default=EngineArgs.pipeline_parallel_size,
+            help="number of pipeline stages",
+        )
+        parser.add_argument(
+            "--tensor-parallel-size",
+            "-tp",
+            type=int,
+            default=EngineArgs.tensor_parallel_size,
+            help="number of tensor parallel replicas",
+        )
+        parser.add_argument(
+            "--max-parallel-loading-workers",
             type=int,
             default=EngineArgs.max_parallel_loading_workers,
-            help='load model sequentially in multiple batches, '
-            'to avoid RAM OOM when using tensor '
-            'parallel and large models')
+            help="load model sequentially in multiple batches, "
+            "to avoid RAM OOM when using tensor "
+            "parallel and large models",
+        )
+        parser.add_argument(
+            "--ray-workers-use-nsight",
+            action="store_true",
+            help="If specified, use nsight to profile ray workers",
+        )
         # KV cache arguments
-        parser.add_argument('--block-size',
-                            type=int,
-                            default=EngineArgs.block_size,
-                            choices=[8, 16, 32],
-                            help='token block size')
-        parser.add_argument('--context-shift',
-                            action='store_true',
-                            help='Enable context shifting.')
-        parser.add_argument('--seed',
+        parser.add_argument(
+            "--block-size",
+            type=int,
+            default=EngineArgs.block_size,
+            choices=[8, 16, 32, 128],
+            help="token block size",
+        )
+        parser.add_argument(
+            "--context-shift",
+            action="store_true",
+            help="Enable context shifting.",
+        )
+        parser.add_argument("--seed",
                             type=int,
                             default=EngineArgs.seed,
-                            help='random seed')
-        parser.add_argument('--swap-space',
-                            type=int,
-                            default=EngineArgs.swap_space,
-                            help='CPU swap space size (GiB) per GPU')
+                            help="random seed")
         parser.add_argument(
-            '--gpu-memory-utilization',
-            '-gmu',
+            "--swap-space",
+            type=int,
+            default=EngineArgs.swap_space,
+            help="CPU swap space size (GiB) per GPU",
+        )
+        parser.add_argument(
+            "--gpu-memory-utilization",
+            "-gmu",
             type=float,
             default=EngineArgs.gpu_memory_utilization,
-            help='the fraction of GPU memory to be used for '
-            'the model executor, which can range from 0 to 1.'
-            'If unspecified, will use the default value of 0.9.')
-        parser.add_argument('--max-num-batched-tokens',
-                            type=int,
-                            default=EngineArgs.max_num_batched_tokens,
-                            help='maximum number of batched tokens per '
-                            'iteration')
-        parser.add_argument('--max-num-seqs',
-                            type=int,
-                            default=EngineArgs.max_num_seqs,
-                            help='maximum number of sequences per iteration')
-        parser.add_argument('--max-paddings',
-                            type=int,
-                            default=EngineArgs.max_paddings,
-                            help='maximum number of paddings in a batch')
-        parser.add_argument('--max-log-probs',
-                            type=int,
-                            default=EngineArgs.max_log_probs,
-                            help='maximum number of log probabilities to '
-                            'return.')
-        parser.add_argument('--disable-log-stats',
-                            action='store_true',
-                            help='disable logging statistics')
+            help="the fraction of GPU memory to be used for "
+            "the model executor, which can range from 0 to 1. "
+            "If unspecified, will use the default value of 0.9.",
+        )
+        parser.add_argument(
+            "--max-num-batched-tokens",
+            type=int,
+            default=EngineArgs.max_num_batched_tokens,
+            help="maximum number of batched tokens per "
+            "iteration",
+        )
+        parser.add_argument(
+            "--max-num-seqs",
+            type=int,
+            default=EngineArgs.max_num_seqs,
+            help="maximum number of sequences per iteration",
+        )
+        parser.add_argument(
+            "--max-paddings",
+            type=int,
+            default=EngineArgs.max_paddings,
+            help="maximum number of paddings in a batch",
+        )
+        parser.add_argument(
+            "--max-log-probs",
+            type=int,
+            default=EngineArgs.max_log_probs,
+            help="maximum number of log probabilities to "
+            "return.",
+        )
+        parser.add_argument(
+            "--disable-log-stats",
+            action="store_true",
+            help="disable logging statistics",
+        )
         # Quantization settings.
-        parser.add_argument('--quantization',
-                            '-q',
-                            type=str,
-                            choices=[
-                                'aqlm', 'awq', 'bnb', 'exl2', 'gguf', 'gptq',
-                                'quip', 'squeezellm', 'marlin', None
-                            ],
-                            default=EngineArgs.quantization,
-                            help='Method used to quantize the weights. If '
-                            'None, we first check the `quantization_config` '
-                            'attribute in the model config file. If that is '
-                            'None, we assume the model weights are not '
-                            'quantized and use `dtype` to determine the data '
-                            'type of the weights.')
-        parser.add_argument('--load-in-4bit',
-                            action='store_true',
-                            help='Load the FP16 model in 4-bit format. Also '
-                            'works with AWQ models. Throughput at 2.5x of '
-                            'FP16.')
-        parser.add_argument('--load-in-8bit',
-                            action='store_true',
-                            help='Load the FP16 model in 8-bit format. '
-                            'Throughput at 0.3x of FP16.')
-        parser.add_argument('--load-in-smooth',
-                            action='store_true',
-                            help='Load the FP16 model in smoothquant '
-                            '8bit format. Throughput at 0.7x of FP16. ')
-        parser.add_argument('--enforce-eager',
-                            action='store_true',
-                            help='Always use eager-mode PyTorch. If False, '
-                            'will use eager mode and CUDA graph in hybrid '
-                            'for maximal performance and flexibility.')
-        parser.add_argument('--max-context-len-to-capture',
-                            type=int,
-                            default=EngineArgs.max_context_len_to_capture,
-                            help='maximum context length covered by CUDA '
-                            'graphs. When a sequence has context length '
-                            'larger than this, we fall back to eager mode.')
-        parser.add_argument('--disable-custom-all-reduce',
-                            action='store_true',
-                            default=EngineArgs.disable_custom_all_reduce,
-                            help='See ParallelConfig')
+        parser.add_argument(
+            "--quantization",
+            "-q",
+            type=str,
+            choices=[
+                "aqlm",
+                "awq",
+                "bnb",
+                "exl2",
+                "gguf",
+                "gptq",
+                "quip",
+                "squeezellm",
+                "marlin",
+                None,
+            ],
+            default=EngineArgs.quantization,
+            help="Method used to quantize the weights. If "
+            "None, we first check the `quantization_config` "
+            "attribute in the model config file. If that is "
+            "None, we assume the model weights are not "
+            "quantized and use `dtype` to determine the data "
+            "type of the weights.",
+        )
+        parser.add_argument(
+            "--load-in-4bit",
+            action="store_true",
+            help="Load the FP16 model in 4-bit format. Also "
+            "works with AWQ models. Throughput at 2.5x of "
+            "FP16.",
+        )
+        parser.add_argument(
+            "--load-in-8bit",
+            action="store_true",
+            help="Load the FP16 model in 8-bit format. "
+            "Throughput at 0.3x of FP16.",
+        )
+        parser.add_argument(
+            "--load-in-smooth",
+            action="store_true",
+            help="Load the FP16 model in smoothquant "
+            "8bit format. Throughput at 0.7x of FP16. ",
+        )
+        parser.add_argument(
+            "--enforce-eager",
+            type=lambda x: (str(x).lower() == 'true'),
+            default=EngineArgs.enforce_eager,
+            help="Always use eager-mode PyTorch. If False, "
+            "will use eager mode and CUDA graph in hybrid "
+            "for maximal performance and flexibility.",
+        )
+        parser.add_argument(
+            "--max-context-len-to-capture",
+            type=int,
+            default=EngineArgs.max_context_len_to_capture,
+            help="maximum context length covered by CUDA "
+            "graphs. When a sequence has context length "
+            "larger than this, we fall back to eager mode.",
+        )
+        parser.add_argument(
+            "--disable-custom-all-reduce",
+            action="store_true",
+            default=EngineArgs.disable_custom_all_reduce,
+            help="See ParallelConfig",
+        )
         # LoRA related configs
-        parser.add_argument('--enable-lora',
-                            action='store_true',
-                            help='If True, enable handling of LoRA adapters.')
-        parser.add_argument('--max-loras',
-                            type=int,
-                            default=EngineArgs.max_loras,
-                            help='Max number of LoRAs in a single batch.')
-        parser.add_argument('--max-lora-rank',
-                            type=int,
-                            default=EngineArgs.max_lora_rank,
-                            help='Max LoRA rank.')
         parser.add_argument(
-            '--lora-extra-vocab-size',
+            "--enable-lora",
+            action="store_true",
+            help="If True, enable handling of LoRA adapters.",
+        )
+        parser.add_argument(
+            "--max-loras",
+            type=int,
+            default=EngineArgs.max_loras,
+            help="Max number of LoRAs in a single batch.",
+        )
+        parser.add_argument(
+            "--max-lora-rank",
+            type=int,
+            default=EngineArgs.max_lora_rank,
+            help="Max LoRA rank.",
+        )
+        parser.add_argument(
+            "--lora-extra-vocab-size",
             type=int,
             default=EngineArgs.lora_extra_vocab_size,
-            help=('Maximum size of extra vocabulary that can be '
-                  'present in a LoRA adapter (added to the base '
-                  'model vocabulary).'))
+            help=("Maximum size of extra vocabulary that can be "
+                  "present in a LoRA adapter (added to the base "
+                  "model vocabulary)."),
+        )
         parser.add_argument(
-            '--lora-dtype',
+            "--lora-dtype",
             type=str,
             default=EngineArgs.lora_dtype,
-            choices=['auto', 'float16', 'bfloat16', 'float32'],
-            help=('Data type for LoRA. If auto, will default to '
-                  'base model dtype.'))
+            choices=["auto", "float16", "bfloat16", "float32"],
+            help=("Data type for LoRA. If auto, will default to "
+                  "base model dtype."),
+        )
         parser.add_argument(
-            '--max-cpu-loras',
+            "--max-cpu-loras",
             type=int,
             default=EngineArgs.max_cpu_loras,
-            help=('Maximum number of LoRAs to store in CPU memory. '
-                  'Must be >= than max_num_seqs. '
-                  'Defaults to max_num_seqs.'))
-        parser.add_argument('--device',
-                            type=str,
-                            default=EngineArgs.device,
-                            choices=['cuda'],
-                            help=('Device to use for model execution. '
-                                  'Currently, only "cuda" is supported.'))
+            help=("Maximum number of LoRAs to store in CPU memory. "
+                  "Must be >= max_num_seqs. "
+                  "Defaults to max_num_seqs."),
+        )
+        parser.add_argument(
+            "--device",
+            type=str,
+            default=EngineArgs.device,
+            choices=["cuda"],
+            help=("Device to use for model execution. "
+                  'Currently, only "cuda" is supported.'),
+        )
         return parser
 
     @classmethod
-    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
+    def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
         # Get the list of attributes of this dataclass.
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         # Set the attributes from the parsed arguments.
@@ -313,63 +410,99 @@ class EngineArgs:
     def create_engine_configs(
         self,
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
-               DeviceConfig, Optional[LoRAConfig]]:
+               DeviceConfig, Optional[LoRAConfig], ]:
         device_config = DeviceConfig(self.device)
         model_config = ModelConfig(
-            self.model, self.tokenizer, self.tokenizer_mode,
-            self.trust_remote_code, self.download_dir, self.load_format,
-            self.dtype, self.seed, self.revision, self.tokenizer_revision,
-            self.max_model_len, self.quantization, self.load_in_4bit,
-            self.load_in_8bit, self.load_in_smooth, self.enforce_eager,
-            self.max_context_len_to_capture, self.max_log_probs)
-        cache_config = CacheConfig(self.block_size,
-                                   self.gpu_memory_utilization,
-                                   self.swap_space, self.kv_cache_dtype,
-                                   self.kv_quant_params_path,
-                                   model_config.get_sliding_window(),
-                                   self.context_shift)
-        parallel_config = ParallelConfig(self.pipeline_parallel_size,
-                                         self.tensor_parallel_size,
-                                         self.worker_use_ray,
-                                         self.max_parallel_loading_workers,
-                                         self.disable_custom_all_reduce)
-        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs,
-                                           model_config.max_model_len,
-                                           self.max_paddings)
-        lora_config = LoRAConfig(
+            self.model,
+            self.tokenizer,
+            self.tokenizer_mode,
+            self.trust_remote_code,
+            self.download_dir,
+            self.load_format,
+            self.dtype,
+            self.seed,
+            self.revision,
+            self.code_revision,
+            self.tokenizer_revision,
+            self.max_model_len,
+            self.quantization,
+            self.load_in_4bit,
+            self.load_in_8bit,
+            self.load_in_smooth,
+            self.enforce_eager,
+            self.max_context_len_to_capture,
+            self.max_log_probs,
+        )
+        cache_config = CacheConfig(
+            self.block_size,
+            self.gpu_memory_utilization,
+            self.swap_space,
+            self.kv_cache_dtype,
+            # self.kv_quant_params_path,
+            model_config.get_sliding_window(),
+            self.context_shift,
+        )
+        parallel_config = ParallelConfig(
+            self.pipeline_parallel_size,
+            self.tensor_parallel_size,
+            self.worker_use_ray,
+            self.max_parallel_loading_workers,
+            self.disable_custom_all_reduce,
+            self.ray_workers_use_nsight,
+        )
+        scheduler_config = SchedulerConfig(
+            self.max_num_batched_tokens,
+            self.max_num_seqs,
+            model_config.max_model_len,
+            self.max_paddings,
+        )
+        lora_config = (LoRAConfig(
             max_lora_rank=self.max_lora_rank,
             max_loras=self.max_loras,
             lora_extra_vocab_size=self.lora_extra_vocab_size,
             lora_dtype=self.lora_dtype,
-            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
-            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
-        return (model_config, cache_config, parallel_config, scheduler_config,
-                device_config, lora_config)
+            max_cpu_loras=self.max_cpu_loras
+            if self.max_cpu_loras and self.max_cpu_loras > 0 else None,
+        ) if self.enable_lora else None)
+        return (
+            model_config,
+            cache_config,
+            parallel_config,
+            scheduler_config,
+            device_config,
+            lora_config,
+        )
 
 
 @dataclass
 class AsyncEngineArgs(EngineArgs):
     """Arguments for asynchronous Aphrodite engine."""
+
     engine_use_ray: bool = False
     disable_log_requests: bool = False
-    max_log_len: Optional[int] = None
+    max_log_len: int = 0
 
     @staticmethod
     def add_cli_args(
             parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         parser = EngineArgs.add_cli_args(parser)
-        parser.add_argument('--engine-use-ray',
-                            action='store_true',
-                            help='use Ray to start the LLM engine in a '
-                            'separate process as the server process.')
-        parser.add_argument('--disable-log-requests',
-                            action='store_true',
-                            help='disable logging requests')
-        parser.add_argument('--max-log-len',
-                            type=int,
-                            default=None,
-                            help='max number of prompt characters or prompt '
-                            'ID numbers being printed in log. '
-                            'Default: unlimited.')
+        parser.add_argument(
+            "--engine-use-ray",
+            action="store_true",
+            help="use Ray to start the LLM engine in a "
+            "separate process as the server process.",
+        )
+        parser.add_argument(
+            "--disable-log-requests",
+            action="store_true",
+            help="disable logging requests",
+        )
+        parser.add_argument(
+            "--max-log-len",
+            type=int,
+            default=0,
+            help="max number of prompt characters or prompt "
+            "ID numbers being printed in log. "
+            "Default: unlimited.",
+        )
         return parser
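
Aside (not part of the diff): `--enforce-eager` changes from a `store_true` flag to an explicit true/false option that now defaults to True. A minimal standalone sketch of how the lambda-based `type` behaves:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enforce-eager",
    type=lambda x: str(x).lower() == "true",
    default=True,
)

print(parser.parse_args([]).enforce_eager)                             # True
print(parser.parse_args(["--enforce-eager", "false"]).enforce_eager)   # False
print(parser.parse_args(["--enforce-eager", "True"]).enforce_eager)    # True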

+ 45 - 28
aphrodite/engine/async_aphrodite.py

@@ -15,7 +15,7 @@ from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
 
 ENGINE_ITERATION_TIMEOUT_S = int(
-    os.environ.get("APHRODITE_ENGINE_ITERATION_TIMEOUT_S", 60))
+    os.environ.get("APHRODITE_ENGINE_ITERATION_TIMEOUT_S", "120"))
 
 
 class AsyncEngineDeadError(RuntimeError):
@@ -26,13 +26,18 @@ def _raise_exception_on_finish(
         task: asyncio.Task, error_callback: Callable[[Exception],
                                                      None]) -> None:
     msg = ("Task finished unexpectedly. This should never happen! "
-           "Please open an issue on Github.")
+           "Please open an issue on Github. Include your full error "
+           "log after killing the process with Ctrl+C.")
 
     exception = None
     try:
         task.result()
         # NOTE: This will be thrown if task exits normally (which it should not)
         raise AsyncEngineDeadError(msg)
+    except asyncio.exceptions.CancelledError:
+        pass
+    except KeyboardInterrupt:
+        raise
     except Exception as e:
         exception = e
         logger.error("Engine background task failed", exc_info=e)
@@ -318,6 +323,8 @@ class AsyncAphrodite:
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
+        max_log_len: Maximum number of prompt characters or prompt ID numbers
+            being printed in log.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
         *args: Arguments for AphroditeEngine.
@@ -331,7 +338,7 @@ class AsyncAphrodite:
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
-                 max_log_len: Optional[int] = None,
+                 max_log_len: int = 0,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
@@ -456,23 +463,27 @@ class AsyncAphrodite:
 
     async def run_engine_loop(self):
         has_requests_in_progress = False
-        while True:
-            if not has_requests_in_progress:
-                logger.debug("Waiting for new requests...")
-                await self._request_tracker.wait_for_new_requests()
-                logger.debug("Got new requests!")
-
-            # Abort if iteration takes too long due to unrecoverable errors
-            # (eg. NCCL timeouts).
-            try:
-                has_requests_in_progress = await asyncio.wait_for(
-                    self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
-            except asyncio.TimeoutError as exc:
-                logger.error(
-                    "Engine iteration timed out. This should never happen!")
-                self.set_errored(exc)
-                raise
-            await asyncio.sleep(0)
+        try:
+            while True:
+                if not has_requests_in_progress:
+                    logger.debug("Waiting for new requests...")
+                    await self._request_tracker.wait_for_new_requests()
+                    logger.debug("Got new requests!")
+
+                # Abort if iteration takes too long due to unrecoverable errors
+                # (eg. NCCL timeouts).
+                try:
+                    has_requests_in_progress = await asyncio.wait_for(
+                        self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
+                except asyncio.TimeoutError as exc:
+                    logger.error(
+                        "Engine iteration timed out. This should never happen!"
+                    )
+                    self.set_errored(exc)
+                    raise
+                await asyncio.sleep(0)
+        except KeyboardInterrupt:
+            logger.info("Engine loop interrupted. Exiting gracefully.")
 
     async def add_request(
         self,
@@ -494,8 +505,7 @@ class AsyncAphrodite:
                                                               max_log_len]
             logger.info(f"Received request {request_id}: "
                         f"prompt: {shortened_prompt!r}, "
-                        f"sampling params: {sampling_params}, "
-                        f"prompt token ids: {shortened_token_ids}, "
+                        f"sampling_params: {sampling_params}, "
                         f"lora_request: {lora_request}.")
 
         if not self.is_running:
@@ -510,6 +520,7 @@ class AsyncAphrodite:
 
         if arrival_time is None:
             arrival_time = time.time()
+
         if self.engine_use_ray:
             prompt_token_ids = await self.engine.encode_request_async.remote(
                 request_id=request_id,
@@ -609,15 +620,21 @@ class AsyncAphrodite:
         arrival_time = time.monotonic()
 
         try:
-            stream = await self.add_request(request_id,
-                                            prompt,
-                                            sampling_params,
-                                            prompt_token_ids=prompt_token_ids,
-                                            arrival_time=arrival_time,
-                                            lora_request=lora_request)
+            stream = await self.add_request(
+                request_id,
+                prompt,
+                sampling_params,
+                prompt_token_ids=prompt_token_ids,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+            )
 
             async for request_output in stream:
                 yield request_output
+        except asyncio.exceptions.CancelledError:
+            logger.info(f"Request {request_id} cancelled.")
+            self._abort(request_id)
+            raise
         except (Exception, asyncio.CancelledError) as e:
             # If there is an exception or coroutine is cancelled, abort the
             # request.
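
The cancellation handling added above follows a common asyncio pattern: the streaming coroutine cleans up the server-side request and then re-raises so the cancellation keeps propagating. A self-contained sketch of that pattern under illustrative names (this is not the Aphrodite API itself):

    import asyncio

    async def stream_tokens(request_id, abort):
        # Stand-in for a streaming generate() call: yield until cancelled,
        # then abort the server-side request and re-raise.
        try:
            i = 0
            while True:
                await asyncio.sleep(0.01)   # stands in for one engine step
                yield f"token-{i}"
                i += 1
        except asyncio.CancelledError:
            abort(request_id)
            raise

    async def main():
        aborted = []

        async def consume():
            async for _ in stream_tokens("req-0", aborted.append):
                pass

        task = asyncio.create_task(consume())
        await asyncio.sleep(0.05)
        task.cancel()                       # simulate a client disconnect
        try:
            await task
        except asyncio.CancelledError:
            pass
        print("aborted:", aborted)          # -> aborted: ['req-0']

    asyncio.run(main())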

+ 79 - 21
aphrodite/engine/metrics.py

@@ -1,11 +1,18 @@
+from loguru import logger
+from prometheus_client import (
+    Counter,
+    Gauge,
+    Histogram,
+    Info,
+    REGISTRY,
+    disable_created_metrics,
+)
+
 import time
 import numpy as np
 from typing import Dict, List
 from dataclasses import dataclass
 
-from prometheus_client import Counter, Gauge, Histogram, disable_created_metrics
-from loguru import logger
-
 disable_created_metrics()
 
 # The begin-* and end* here are used by the documentation generator
@@ -16,58 +23,104 @@ disable_created_metrics()
 class Metrics:
 
     def __init__(self, labelnames: List[str]):
+        # Unregister any existing Aphrodite collectors
+        for collector in list(REGISTRY._collector_to_names):
+            if hasattr(collector, "_name") and "aphrodite" in collector._name:
+                REGISTRY.unregister(collector)
+
+        # Config Information
+        self.info_cache_config = Info(
+            name="aphrodite:cache_config",
+            documentation="information of cache_config",
+        )
+
         # System stats
         self.gauge_scheduler_running = Gauge(
             name="aphrodite:num_requests_running",
             documentation="Number of requests currently running on GPU.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+        )
         self.gauge_scheduler_swapped = Gauge(
             name="aphrodite:num_requests_swapped",
             documentation="Number of requests swapped to CPU.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+        )
         self.gauge_scheduler_waiting = Gauge(
             name="aphrodite:num_requests_waiting",
             documentation="Number of requests waiting to be processed.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+        )
         self.gauge_gpu_cache_usage = Gauge(
             name="aphrodite:gpu_cache_usage_perc",
             documentation="GPU KV-cache usage. 1 means 100 percent usage.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+        )
         self.gauge_cpu_cache_usage = Gauge(
             name="aphrodite:cpu_cache_usage_perc",
             documentation="CPU KV-cache usage. 1 means 100 percent usage.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+        )
 
         # Raw stats from last model iteration
         self.counter_prompt_tokens = Counter(
             name="aphrodite:prompt_tokens_total",
             documentation="Number of prefill tokens processed.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+        )
         self.counter_generation_tokens = Counter(
             name="aphrodite:generation_tokens_total",
             documentation="Number of generation tokens processed.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+        )
         self.histogram_time_to_first_token = Histogram(
             name="aphrodite:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
             labelnames=labelnames,
             buckets=[
-                0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
-                0.75, 1.0, 2.5, 5.0, 7.5, 10.0
-            ])
+                0.001,
+                0.005,
+                0.01,
+                0.02,
+                0.04,
+                0.06,
+                0.08,
+                0.1,
+                0.25,
+                0.5,
+                0.75,
+                1.0,
+                2.5,
+                5.0,
+                7.5,
+                10.0,
+            ],
+        )
         self.histogram_time_per_output_token = Histogram(
             name="aphrodite:time_per_output_token_seconds",
             documentation="Histogram of time per output token in seconds.",
             labelnames=labelnames,
             buckets=[
-                0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
-                1.0, 2.5
-            ])
+                0.01,
+                0.025,
+                0.05,
+                0.075,
+                0.1,
+                0.15,
+                0.2,
+                0.3,
+                0.4,
+                0.5,
+                0.75,
+                1.0,
+                2.5,
+            ],
+        )
         self.histogram_e2e_request_latency = Histogram(
             name="aphrodite:e2e_request_latency_seconds",
             documentation="Histogram of end to end request latency in seconds.",
             labelnames=labelnames,
-            buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
+            buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0],
+        )
 
         # Legacy metrics
         self.gauge_avg_prompt_throughput = Gauge(
@@ -88,6 +141,7 @@ class Metrics:
 @dataclass
 class Stats:
     """Created by AphroditeEngine for use by StatLogger."""
+
     now: float
 
     # System stats.
@@ -121,6 +175,10 @@ class StatLogger:
         self.labels = labels
         self.metrics = Metrics(labelnames=list(labels.keys()))
 
+    def info(self, type: str, obj: object) -> None:
+        if type == "cache_config":
+            self.metrics.info_cache_config.info(obj.metrics_info())
+
     def _get_throughput(self, tracked_stats: List[int], now: float) -> float:
         return float(np.sum(tracked_stats) / (now - self.last_local_log))
 
@@ -174,8 +232,8 @@ class StatLogger:
 
     def log(self, stats: Stats) -> None:
         """Called by AphroditeEngine.
-           Logs to prometheus and tracked stats every iteration.
-           Logs to Stdout every self.local_interval seconds."""
+        Logs to prometheus and tracked stats every iteration.
+        Logs to Stdout every self.local_interval seconds."""
 
         # Log to prometheus.
         self._log_prometheus(stats)
@@ -186,7 +244,6 @@ class StatLogger:
 
         # Log locally every local_interval seconds.
         if self._local_interval_elapsed(stats.now):
-
             # Compute summary metrics for tracked stats (and log them to
             # prometheus if applicable).
             prompt_throughput = self._get_throughput(self.num_prompt_tokens,
@@ -195,7 +252,8 @@ class StatLogger:
                 self.num_generation_tokens, now=stats.now)
             self._log_prometheus_interval(
                 prompt_throughput=prompt_throughput,
-                generation_throughput=generation_throughput)
+                generation_throughput=generation_throughput,
+            )
 
             # Log to stdout.
             logger.info(
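
The new Info metric publishes the cache configuration as constant key/value labels. A minimal prometheus_client sketch of that mechanism; the metrics_info() dict below is a hypothetical stand-in for whatever CacheConfig.metrics_info() actually returns:

    from prometheus_client import Info, generate_latest

    def metrics_info():
        # Hypothetical stand-in for CacheConfig.metrics_info(); Info values
        # must all be strings since they become label values.
        return {"block_size": "16",
                "cache_dtype": "auto",
                "swap_space_bytes": "4294967296"}

    info_cache_config = Info(name="demo:cache_config",
                             documentation="information of cache_config")
    info_cache_config.info(metrics_info())

    # Exported roughly as: demo:cache_config_info{block_size="16",...} 1.0
    print(generate_latest().decode())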

+ 21 - 5
aphrodite/engine/ray_tools.py

@@ -1,3 +1,5 @@
+import pickle
+
 from typing import Optional, List, Tuple, TYPE_CHECKING
 from loguru import logger
 
@@ -13,10 +15,14 @@ try:
 
         def __init__(self, init_cached_hf_modules=False) -> None:
             if init_cached_hf_modules:
-                # pylint: disable=import-outside-toplevel
                 from transformers.dynamic_module_utils import init_hf_modules
                 init_hf_modules()
             self.worker = None
+            # The compiled DAG runs its main execution in a
+            # different thread, which must call cuda.set_device.
+            # This flag tracks whether set_device has already
+            # been called on that thread.
+            self.compiled_dag_cuda_device_set = False
 
         def init_worker(self, worker_init_fn):
             self.worker = worker_init_fn()
@@ -39,6 +45,17 @@ try:
         def set_cuda_visible_devices(self, device_ids) -> None:
             set_cuda_visible_devices(device_ids)
 
+        def execute_model_compiled_dag_remote(self, ignored):
+            """Used only when compiled DAG is enabled."""
+            import torch
+            if not self.compiled_dag_cuda_device_set:
+                torch.cuda.set_device(self.worker.device)
+                self.compiled_dag_cuda_device_set = True
+
+            output = self.worker.execute_model()
+            output = pickle.dumps(output)
+            return output
+
 except ImportError as e:
     logger.warning(f"Failed to import Ray with {e!r}. "
                    "For distributed inference, please install Ray with "
@@ -64,10 +81,9 @@ def initialize_cluster(
             the default Ray cluster address.
 
     Returns:
-        A tuple of (`distributed_init_method`, `placement_group`). The
-        `distributed_init_method` is the address for initializing the
-        distributed backend. `placement_group` includes the specification
-        of the resources for each distributed worker.
+        An optional `PlacementGroup`. It includes the specification
+        of the resources for each distributed worker. None if Ray is
+        not used.
     """
     if parallel_config.worker_use_ray or engine_use_ray:
         if ray is None:

+ 0 - 0
aphrodite/modeling/layers/triton_kernel/__init__.py → aphrodite/executor/__init__.py


+ 76 - 0
aphrodite/executor/executor_base.py

@@ -0,0 +1,76 @@
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional
+
+from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
+                                     ParallelConfig, SchedulerConfig,
+                                     LoRAConfig)
+from aphrodite.lora.request import LoRARequest
+from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+
+
+class ExecutorBase(ABC):
+    """Base class for all executors.
+
+    An executor is responsible for executing the model on a specific device
+    type (e.g. CPU, GPU, or Neuron), or it can be a distributed executor
+    that runs the model across multiple devices.
+    """
+
+    @abstractmethod
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+    ) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def execute_model(self,
+                      seq_group_metadata_list: List[SequenceGroupMetadata],
+                      blocks_to_swap_in: Dict[int, int],
+                      blocks_to_swap_out: Dict[int, int],
+                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+        """Executes one model step on the given sequences."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def remove_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_loras(self) -> List[int]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def check_health(self) -> None:
+        """Checks if the executor is healthy. If not, it should raise an
+        exception."""
+        raise NotImplementedError
+
+
+class ExecutorAsyncBase(ExecutorBase):
+
+    @abstractmethod
+    async def execute_model_async(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
+        """Executes one model step on the given sequences."""
+        raise NotImplementedError
+
+    @abstractmethod
+    async def check_health_async(self) -> None:
+        """Checks if the executor is healthy. If not, it should raise an
+        exception."""
+        raise NotImplementedError
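
Any new backend plugs in by subclassing ExecutorBase and overriding every abstract method. A hypothetical stub, not part of this change, just to show the surface a backend has to provide:

    from typing import List

    from aphrodite.executor.executor_base import ExecutorBase


    class StubExecutor(ExecutorBase):
        """Hypothetical no-op backend; only illustrates the ABC contract."""

        def __init__(self, model_config, cache_config, parallel_config,
                     scheduler_config, device_config, lora_config) -> None:
            self.model_config = model_config
            self.cache_config = cache_config

        def execute_model(self, seq_group_metadata_list, blocks_to_swap_in,
                          blocks_to_swap_out, blocks_to_copy):
            raise NotImplementedError("no device behind this stub")

        def add_lora(self, lora_request) -> bool:
            return False

        def remove_lora(self, lora_id: int) -> bool:
            return False

        def list_loras(self) -> List[int]:
            return []

        def check_health(self) -> None:
            return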

+ 153 - 0
aphrodite/executor/gpu_executor.py

@@ -0,0 +1,153 @@
+from typing import Dict, List, Optional
+
+from loguru import logger
+
+from aphrodite.lora.request import LoRARequest
+from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
+                                     ParallelConfig, SchedulerConfig,
+                                     LoRAConfig)
+from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
+from aphrodite.executor.utils import check_block_size_valid
+from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.utils import (get_ip, get_open_port,
+                                    get_distributed_init_method, make_async)
+
+
+class GPUExecutor(ExecutorBase):
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+    ) -> None:
+        self.model_config = model_config
+        self.cache_config = cache_config
+        self.lora_config = lora_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+
+        # Instantiate the worker and load the model to GPU.
+        self._init_worker()
+
+        # Profile the memory usage and initialize the cache.
+        self._init_cache()
+
+    def _init_worker(self):
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from aphrodite.task_handler.worker import Worker
+
+        assert self.parallel_config.world_size == 1, (
+            "GPUExecutor only supports single GPU.")
+
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
+        self.driver_worker = Worker(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            self.device_config,
+            local_rank=0,
+            rank=0,
+            distributed_init_method=distributed_init_method,
+            lora_config=self.lora_config,
+            kv_cache_dtype=self.cache_config.cache_dtype,
+            is_driver_worker=True,
+        )
+        self.driver_worker.init_device()
+        self.driver_worker.load_model()
+
+    def _init_cache(self) -> None:
+        """Profiles the memory usage and initializes the KV cache.
+
+        The engine first profiles the existing memory usage.
+        Then, it allocates the remaining memory for KV blocks.
+
+        .. tip::
+            You may limit the usage of GPU memory
+            by adjusting the `gpu_memory_utilization` parameter.
+        """
+        # Get the maximum number of blocks that can be allocated on GPU and CPU.
+        num_gpu_blocks, num_cpu_blocks = (
+            self.driver_worker.profile_num_available_blocks(
+                block_size=self.cache_config.block_size,
+                gpu_memory_utilization=self.cache_config.
+                gpu_memory_utilization,
+                cpu_swap_space=self.cache_config.swap_space_bytes,
+                cache_dtype=self.cache_config.cache_dtype,
+            ))
+
+        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
+                    f"# CPU blocks: {num_cpu_blocks}")
+
+        logger.info(
+            f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"  # noqa: E501
+        )
+
+        check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
+                               self.model_config.max_model_len)
+
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+        # Initialize the cache.
+        self.driver_worker.init_cache_engine(cache_config=self.cache_config)
+        # Warm up the model. This includes capturing the model into CUDA graph
+        # if enforce_eager is False.
+        self.driver_worker.warm_up_model()
+
+    def execute_model(self,
+                      seq_group_metadata_list: List[SequenceGroupMetadata],
+                      blocks_to_swap_in: Dict[int, int],
+                      blocks_to_swap_out: Dict[int, int],
+                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+        output = self.driver_worker.execute_model(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+        )
+        return output
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
+        return self.driver_worker.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self.driver_worker.remove_lora(lora_id)
+
+    def list_loras(self) -> List[int]:
+        return self.driver_worker.list_loras()
+
+    def check_health(self) -> None:
+        # GPUExecutor will always be healthy as long as
+        # it's running.
+        return
+
+
+class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
+
+    async def execute_model_async(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
+        output = await make_async(self.driver_worker.execute_model)(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy)
+        return output
+
+    async def check_health_async(self) -> None:
+        # GPUExecutor will always be healthy as long as
+        # it's running.
+        return
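
The two log lines in _init_cache follow directly from the block arithmetic: the KV cache holds num_gpu_blocks * block_size tokens, and dividing by max_model_len gives the minimum number of full-length sequences that can be served concurrently. With illustrative numbers:

    # Illustrative values, not measured on real hardware.
    num_gpu_blocks = 7000
    block_size = 16
    max_model_len = 4096

    max_cached_tokens = num_gpu_blocks * block_size        # 112000 tokens
    min_concurrency = max_cached_tokens / max_model_len    # ~27.34

    print(f"# GPU blocks: {num_gpu_blocks}")
    print(f"Minimum concurrency: {min_concurrency:.2f}x")  # 27.34x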

+ 78 - 0
aphrodite/executor/neuron_executor.py

@@ -0,0 +1,78 @@
+from typing import Dict, List, Optional
+
+from aphrodite.lora.request import LoRARequest
+from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
+                                     ParallelConfig, SchedulerConfig,
+                                     LoRAConfig)
+from aphrodite.executor.executor_base import ExecutorBase
+from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+
+
+class NeuronExecutor(ExecutorBase):
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+    ) -> None:
+        self.model_config = model_config
+        self.cache_config = cache_config
+        assert lora_config is None, "LoRA is not supported for Neuron backend."
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+
+        # Set the number of GPU blocks to be the same as the maximum number of
+        # sequences that can be processed in a single batch. This is equivalent
+        # to scheduling without PagedAttention.
+        self.cache_config.num_gpu_blocks = self.scheduler_config.max_num_seqs
+        self.cache_config.num_cpu_blocks = 0
+
+        # Instantiate the worker and load the model to the device.
+        self._init_worker()
+
+    def _init_worker(self):
+        from aphrodite.task_handler.neuron_worker import NeuronWorker
+
+        self.driver_worker = NeuronWorker(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            self.device_config,
+        )
+        self.driver_worker.init_device()
+        self.driver_worker.load_model()
+
+    def execute_model(self,
+                      seq_group_metadata_list: List[SequenceGroupMetadata],
+                      blocks_to_swap_in: Dict[int, int],
+                      blocks_to_swap_out: Dict[int, int],
+                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+        assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
+                and blocks_to_copy == {}), (
+                    "Cache operations are not supported for Neuron backend.")
+
+        output = self.driver_worker.execute_model(
+            seq_group_metadata_list=seq_group_metadata_list)
+        return output
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        raise NotImplementedError(
+            "LoRA is not implemented for neuron backend.")
+
+    def remove_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError(
+            "LoRA is not implemented for neuron backend.")
+
+    def list_loras(self) -> List[int]:
+        raise NotImplementedError(
+            "LoRA is not implemented for neuron backend.")
+
+    def check_health(self) -> None:
+        # NeuronExecutor will always be healthy as long as
+        # it's running.
+        return
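
Since the Neuron path has no paged KV cache, the block budget is simply one block per schedulable sequence, which is also why execute_model rejects any swap or copy operations. In illustrative numbers:

    max_num_seqs = 256              # hypothetical SchedulerConfig.max_num_seqs

    num_gpu_blocks = max_num_seqs   # one cache slot per sequence
    num_cpu_blocks = 0              # no CPU swap space on Neuron

    # Any non-empty cache-op map would trip the assertion in execute_model.
    blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy = {}, {}, {}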

+ 452 - 0
aphrodite/executor/ray_gpu_executor.py

@@ -0,0 +1,452 @@
+import asyncio
+import copy
+from collections import defaultdict
+import os
+import pickle
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from loguru import logger
+
+from aphrodite.common.config import (
+    CacheConfig,
+    DeviceConfig,
+    ModelConfig,
+    ParallelConfig,
+    SchedulerConfig,
+    LoRAConfig,
+)
+from aphrodite.engine.ray_tools import RayWorkerAphrodite, ray
+from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
+from aphrodite.executor.utils import check_block_size_valid
+from aphrodite.lora.request import LoRARequest
+from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.utils import (
+    set_cuda_visible_devices,
+    get_ip,
+    get_open_port,
+    get_distributed_init_method,
+    make_async,
+)
+
+if ray is not None:
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
+
+# If the env var is set, Aphrodite uses Ray's compiled DAG API,
+# which reduces control plane overhead.
+# Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
+USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
+
+
+class RayGPUExecutor(ExecutorBase):
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+    ) -> None:
+        self.model_config = model_config
+        self.cache_config = cache_config
+        self.lora_config = lora_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+
+        assert self.parallel_config.worker_use_ray
+        placement_group = self.parallel_config.placement_group
+
+        # Disable Ray usage stats collection.
+        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
+        if ray_usage != "1":
+            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
+
+        # Create the parallel GPU workers.
+        self._init_workers_ray(placement_group)
+
+        # Profile the memory usage and initialize the cache.
+        self._init_cache()
+
+        self.forward_dag = None
+        if USE_RAY_COMPILED_DAG:
+            self.forward_dag = self._compiled_ray_dag()
+
+    def _init_workers_ray(self, placement_group: "PlacementGroup",
+                          **ray_remote_kwargs):
+        if self.parallel_config.tensor_parallel_size == 1:
+            # For the single-GPU case, we use a Ray worker with constrained memory.
+            num_gpus = self.cache_config.gpu_memory_utilization
+        else:
+            # Otherwise, the ray workers are allocated with a full GPU.
+            num_gpus = 1
+
+        # The driver dummy worker does not actually use any resources.
+        # It holds the resource for the driver worker.
+        self.driver_dummy_worker: RayWorkerAphrodite = None
+        # The remaining workers are the actual ray actors.
+        self.workers: List[RayWorkerAphrodite] = []
+
+        # Create the workers.
+        driver_ip = get_ip()
+        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
+            if not bundle.get("GPU", 0):
+                continue
+            scheduling_strategy = PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_capture_child_tasks=True,
+                placement_group_bundle_index=bundle_id,
+            )
+            worker = ray.remote(
+                num_cpus=0,
+                num_gpus=num_gpus,
+                scheduling_strategy=scheduling_strategy,
+                **ray_remote_kwargs,
+            )(RayWorkerAphrodite).remote(self.model_config.trust_remote_code)
+
+            worker_ip = ray.get(worker.get_node_ip.remote())
+            if worker_ip == driver_ip and self.driver_dummy_worker is None:
+                # If the worker is on the same node as the driver, we use it
+                # as the resource holder for the driver process.
+                self.driver_dummy_worker = worker
+            else:
+                # Otherwise, add it to the list of workers.
+                self.workers.append(worker)
+
+        if self.driver_dummy_worker is None:
+            raise ValueError(
+                "Ray does not allocate any GPUs on the driver node. Consider "
+                "adjusting the Ray placement group or running the driver on a "
+                "GPU node.")
+
+        # Get the set of GPU IDs used on each node.
+        driver_node_id, driver_gpu_ids = ray.get(
+            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
+        worker_node_and_gpu_ids = ray.get(
+            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
+
+        node_workers = defaultdict(list)
+        node_gpus = defaultdict(list)
+
+        node_workers[driver_node_id].append(0)
+        node_gpus[driver_node_id].extend(driver_gpu_ids)
+        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
+                                               start=1):
+            node_workers[node_id].append(i)
+            node_gpus[node_id].extend(gpu_ids)
+        for node_id, gpu_ids in node_gpus.items():
+            node_gpus[node_id] = sorted(gpu_ids)
+
+        # Set CUDA_VISIBLE_DEVICES for the driver and workers.
+        set_cuda_visible_devices(node_gpus[driver_node_id])
+        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
+            worker.set_cuda_visible_devices.remote(node_gpus[node_id])
+
+        distributed_init_method = get_distributed_init_method(
+            driver_ip, get_open_port())
+
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from aphrodite.task_handler.worker import Worker
+
+        model_config = copy.deepcopy(self.model_config)
+        parallel_config = copy.deepcopy(self.parallel_config)
+        scheduler_config = copy.deepcopy(self.scheduler_config)
+        device_config = copy.deepcopy(self.device_config)
+        lora_config = copy.deepcopy(self.lora_config)
+        kv_cache_dtype = self.cache_config.cache_dtype
+
+        # Initialize the actual workers with the Worker class.
+        for rank, (worker, (node_id, _)) in enumerate(
+                zip(self.workers, worker_node_and_gpu_ids),
+                start=1,
+        ):
+            local_rank = node_workers[node_id].index(rank)
+            worker.init_worker.remote(
+                lambda rank=rank, local_rank=local_rank: Worker(
+                    model_config,
+                    parallel_config,
+                    scheduler_config,
+                    device_config,
+                    local_rank,
+                    rank,
+                    distributed_init_method,
+                    lora_config=lora_config,
+                    kv_cache_dtype=kv_cache_dtype,
+                ))
+
+        # Initialize the driver worker with the Worker class.
+        driver_rank = 0
+        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
+        self.driver_worker = Worker(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            self.device_config,
+            driver_local_rank,
+            driver_rank,
+            distributed_init_method,
+            lora_config=self.lora_config,
+            kv_cache_dtype=kv_cache_dtype,
+            is_driver_worker=True,
+        )
+
+        # FIXME(woosuk): We are not properly initializing cupy NCCL when
+        # we have multiple nodes.
+        self._run_workers(
+            "init_device",
+            cupy_port=get_open_port()
+            if not model_config.enforce_eager else None,
+        )
+        self._run_workers(
+            "load_model",
+            max_concurrent_workers=self.parallel_config.
+            max_parallel_loading_workers,
+        )
+
+    def _init_cache(self) -> None:
+        """Profiles the memory usage and initializes the KV cache.
+
+        The engine first profiles the existing memory usage.
+        Then, it calculates the maximum possible number of GPU and CPU blocks
+        that can be allocated with the remaining free memory.
+        More details can be found in the
+        :meth:`~aphrodite.task_handler.worker.Worker.profile_num_available_blocks` method
+        of :class:`~aphrodite.task_handler.worker.Worker`.
+
+        Afterwards, as there may be multiple workers,
+        we take the minimum number of blocks across all workers
+        to ensure this can be applied to all of them.
+
+        Finally, the engine will initialize the KV cache
+        with the calculated number of blocks.
+
+        .. tip::
+            You may limit the usage of GPU memory
+            by adjusting the `gpu_memory_utilization` parameter.
+        """  # noqa: E501
+        # Get the maximum number of blocks that can be allocated on GPU and CPU.
+        num_blocks = self._run_workers(
+            "profile_num_available_blocks",
+            block_size=self.cache_config.block_size,
+            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
+            cpu_swap_space=self.cache_config.swap_space_bytes,
+            cache_dtype=self.cache_config.cache_dtype,
+        )
+
+        # Since we use a shared centralized controller, we take the minimum
+        # number of blocks across all workers to make sure all the memory
+        # operators can be applied to all workers.
+        num_gpu_blocks = min(b[0] for b in num_blocks)
+        num_cpu_blocks = min(b[1] for b in num_blocks)
+        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
+                    f"# CPU blocks: {num_cpu_blocks}")
+
+        logger.info(
+            f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"  # noqa: E501
+        )
+
+        check_block_size_valid(
+            num_gpu_blocks,
+            self.cache_config.block_size,
+            self.model_config.max_model_len,
+        )
+
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+        # Initialize the cache.
+        self._run_workers("init_cache_engine", cache_config=self.cache_config)
+        # Warm up the model. This includes capturing the model into CUDA graph
+        # if enforce_eager is False.
+        self._run_workers("warm_up_model")
+
+    def execute_model(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
+        all_outputs = self._run_workers(
+            "execute_model",
+            driver_kwargs={
+                "seq_group_metadata_list": seq_group_metadata_list,
+                "blocks_to_swap_in": blocks_to_swap_in,
+                "blocks_to_swap_out": blocks_to_swap_out,
+                "blocks_to_copy": blocks_to_copy,
+            },
+            use_ray_compiled_dag=USE_RAY_COMPILED_DAG,
+        )
+
+        # Only the driver worker returns the sampling results.
+        output = all_outputs[0]
+        return output
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "add_lora",
+            lora_request=lora_request,
+        )
+
+    def remove_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "remove_lora",
+            lora_id=lora_id,
+        )
+
+    def list_loras(self) -> List[int]:
+        return self._run_workers("list_loras")
+
+    def _run_workers(
+        self,
+        method: str,
+        *args,
+        driver_args: Optional[List[Any]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
+        max_concurrent_workers: Optional[int] = None,
+        use_ray_compiled_dag: bool = False,
+        **kwargs,
+    ) -> Any:
+        """Runs the given method on all workers."""
+
+        if max_concurrent_workers:
+            raise NotImplementedError(
+                "max_concurrent_workers is not supported yet.")
+
+        if use_ray_compiled_dag:
+            # Right now, compiled DAG can only accept a single
+            # input. TODO(sang): Fix it.
+            output_channels = self.forward_dag.execute(1)
+        else:
+            # Start the ray workers first.
+            ray_worker_outputs = [
+                worker.execute_method.remote(method, *args, **kwargs)
+                for worker in self.workers
+            ]
+
+        if driver_args is None:
+            driver_args = args
+        if driver_kwargs is None:
+            driver_kwargs = kwargs
+
+        # Start the driver worker after all the ray workers.
+        driver_worker_output = getattr(self.driver_worker,
+                                       method)(*driver_args, **driver_kwargs)
+
+        # Get the results of the ray workers.
+        if self.workers:
+            if use_ray_compiled_dag:
+                try:
+                    ray_worker_outputs = [
+                        pickle.loads(chan.begin_read())
+                        for chan in output_channels
+                    ]
+                finally:
+                    # Has to call end_read in order to reuse the DAG.
+                    for chan in output_channels:
+                        chan.end_read()
+            else:
+                ray_worker_outputs = ray.get(ray_worker_outputs)
+
+        return [driver_worker_output] + ray_worker_outputs
+
+    def _compiled_ray_dag(self):
+        import pkg_resources
+
+        required_version = "2.9"
+        current_version = pkg_resources.get_distribution("ray").version
+        if current_version < required_version:
+            raise ValueError(f"Ray version {required_version} or greater is "
+                             f"required, but found {current_version}")
+
+        from ray.dag import MultiOutputNode, InputNode
+
+        assert self.parallel_config.worker_use_ray
+
+        # Right now, compiled DAG requires at least 1 arg. We send
+        # a dummy value for now. It will be fixed soon.
+        with InputNode() as input_data:
+            forward_dag = MultiOutputNode([
+                worker.execute_model_compiled_dag_remote.bind(input_data)
+                for worker in self.workers
+            ])
+        return forward_dag.experimental_compile()
+
+    def check_health(self) -> None:
+        """Raises an error if engine is unhealthy."""
+        self._check_if_any_actor_is_dead()
+
+    def _check_if_any_actor_is_dead(self):
+        if not self.workers:
+            return
+
+        dead_actors = []
+        for actor in self.workers:
+            actor_state = ray.state.actors(actor._ray_actor_id.hex())  # pylint: disable=protected-access
+            if actor_state["State"] == "DEAD":
+                dead_actors.append(actor)
+        if dead_actors:
+            raise RuntimeError("At least one Worker is dead. "
+                               f"Dead Workers: {dead_actors}. ")
+
+
+class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
+
+    async def _run_workers_async(
+        self,
+        method: str,
+        *args,
+        driver_args: Optional[List[Any]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> Any:
+        """Runs the given method on all workers."""
+        coros = []
+
+        if driver_args is None:
+            driver_args = args
+        if driver_kwargs is None:
+            driver_kwargs = kwargs
+
+        # Run the driver worker asynchronously.
+        driver_executor = make_async(getattr(self.driver_worker, method))
+        coros.append(driver_executor(*driver_args, **driver_kwargs))
+
+        # Run the ray workers asynchronously.
+        for worker in self.workers:
+            coros.append(worker.execute_method.remote(method, *args, **kwargs))
+
+        all_outputs = await asyncio.gather(*coros)
+        return all_outputs
+
+    async def execute_model_async(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
+        all_outputs = await self._run_workers_async(
+            "execute_model",
+            driver_kwargs={
+                "seq_group_metadata_list": seq_group_metadata_list,
+                "blocks_to_swap_in": blocks_to_swap_in,
+                "blocks_to_swap_out": blocks_to_swap_out,
+                "blocks_to_copy": blocks_to_copy,
+            },
+        )
+
+        # Only the driver worker returns the sampling results.
+        output = all_outputs[0]
+        return output
+
+    async def check_health_async(self) -> None:
+        """Raises an error if engine is unhealthy."""
+        self._check_if_any_actor_is_dead()
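
The compiled DAG path is opt-in via an environment variable and additionally requires ray>=2.9 (checked in _compiled_ray_dag). A small sketch of how the gate behaves and how one would typically enable it; the exact launch command is illustrative:

    import os

    # Mirrors the gate at the top of ray_gpu_executor.py: bool() is applied to
    # the raw string, so any non-empty value (even "0") enables the DAG path,
    # while leaving the variable unset keeps it disabled.
    USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))

    # Typical way to opt in for a run (shell; entry point shown is illustrative):
    #   APHRODITE_USE_RAY_COMPILED_DAG=1 python -m aphrodite.endpoints.openai.api_server ...
    print(USE_RAY_COMPILED_DAG)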

+ 13 - 0
aphrodite/executor/utils.py

@@ -0,0 +1,13 @@
+def check_block_size_valid(num_gpu_blocks, block_size, max_model_len) -> None:
+    if num_gpu_blocks <= 0:
+        raise ValueError("No available memory for the cache blocks. "
+                         "Try increasing `gpu_memory_utilization` when "
+                         "initializing the engine.")
+    max_seq_len = block_size * num_gpu_blocks
+    if max_model_len > max_seq_len:
+        raise ValueError(
+            f"The model's max seq len ({max_model_len}) "
+            "is larger than the maximum number of tokens that can be "
+            f"stored in KV cache ({max_seq_len}). Try increasing "
+            "`gpu_memory_utilization` or decreasing `max_model_len` when "
+            "initializing the engine.")
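
check_block_size_valid guards the same token arithmetic the executors use above. For example, assuming the package is installed:

    from aphrodite.executor.utils import check_block_size_valid

    # 2048 blocks of 16 tokens hold 32768 tokens, so an 8k-context model fits.
    check_block_size_valid(num_gpu_blocks=2048, block_size=16,
                           max_model_len=8192)

    # A 40960-token context does not fit and raises the ValueError that
    # suggests raising gpu_memory_utilization or lowering max_model_len.
    try:
        check_block_size_valid(num_gpu_blocks=2048, block_size=16,
                               max_model_len=40960)
    except ValueError as exc:
        print(exc)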

+ 4 - 0
aphrodite/lora/layers.py

@@ -824,6 +824,10 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
         self.dtype = dtype
         self.device = device
 
+    @property
+    def logits_as_hidden_states(self):
+        return self.base_layer.logits_as_hidden_states
+
     @property
     def vocab_size(self):
         return self.base_layer.vocab_size

+ 8 - 3
aphrodite/modeling/hf_downloader.py

@@ -21,15 +21,19 @@ from aphrodite.common.gguf import GGUFReader
 from aphrodite.modeling.layers.quantization import (get_quantization_config,
                                                     QuantizationConfig)
 
+_xdg_cache_home = os.getenv('XDG_CACHE_HOME', os.path.expanduser('~/.cache'))
+_aphrodite_filelocks_path = os.path.join(_xdg_cache_home, 'aphrodite/locks/')
 
-class Disabledtqdm(tqdm):  # pylint: disable=inconsistent-mro
+
+class Disabledtqdm(tqdm):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs, disable=True)
 
 
 def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
-    lock_dir = cache_dir if cache_dir is not None else "/tmp"
+    lock_dir = cache_dir if cache_dir is not None else _aphrodite_filelocks_path
+    os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
     lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
     lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
     return lock
@@ -164,7 +168,7 @@ def prepare_hf_model_weights(
                 allow_patterns = [pattern]
                 break
 
-        logger.info(f"Downloading model weights {allow_patterns}")
+        logger.info(f"Using model weights format {allow_patterns}")
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
         with get_lock(model_name_or_path, cache_dir):
@@ -192,6 +196,7 @@ def prepare_hf_model_weights(
             "scheduler.pt",
             "scaler.pt",
             "trainer_state.json",
+            "hidden_states.safetensors",  # exllamav2
         ]
         hf_weights_files = [
             f for f in hf_weights_files
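
With this change, concurrent processes serialize downloads through per-model lock files under the XDG cache directory instead of /tmp. A rough standalone sketch of that locking scheme; the model name is illustrative:

    import os

    import filelock

    xdg_cache_home = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
    lock_dir = os.path.join(xdg_cache_home, "aphrodite/locks/")
    os.makedirs(lock_dir, exist_ok=True)

    # One lock file per model: "org/name" becomes "org-name.lock".
    lock_path = os.path.join(lock_dir,
                             "EleutherAI/pythia-70m".replace("/", "-") + ".lock")

    with filelock.FileLock(lock_path):
        # Download or convert weights here; a second process blocks on the
        # same lock until this block exits.
        pass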

+ 0 - 354
aphrodite/modeling/layers/attention.py

@@ -1,354 +0,0 @@
-"""Multi-head attention."""
-from typing import List, Optional
-
-import importlib
-import torch
-import torch.nn as nn
-from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
-                                         LowerTriangularMaskWithTensorBias)
-
-from aphrodite._C import ops
-from aphrodite._C import cache_ops
-from aphrodite.modeling.metadata import InputMetadata
-from aphrodite.modeling.layers.triton_kernel.prefix_prefill import (
-    context_attention_fwd)
-from aphrodite.common.utils import is_hip
-
-_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
-# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
-_PARTITION_SIZE = 512
-
-
-class PagedAttention(nn.Module):
-    """MHA/MQA/GQA layer with PagedAttention.
-
-    This class takes query, key, and value tensors as input. The input tensors
-    can either contain prompt tokens or generation tokens.
-    The class does the following:
-
-    1. Reshape and store the input key and value tensors in the KV cache.
-    2. Perform (multi-head/multi-query/grouped-query) attention using either
-        xformers or the PagedAttention custom op.
-    3. Return the output tensor.
-    """
-
-    def __init__(
-        self,
-        num_heads: int,
-        head_size: int,
-        scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
-    ) -> None:
-        super().__init__()
-        self.num_heads = num_heads
-        self.head_size = head_size
-        self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
-        self.sliding_window = sliding_window
-        if alibi_slopes is not None:
-            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
-        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
-
-        assert self.num_heads % self.num_kv_heads == 0
-        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-
-        if self.head_size not in _SUPPORTED_HEAD_SIZES:
-            raise ValueError(f"head_size ({self.head_size}) is not supported. "
-                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
-
-        self.use_ref_attention = self.check_use_ref_attention()
-
-    def check_use_ref_attention(self) -> bool:
-        if not is_hip():
-            return False
-        # For ROCm, check whether flash attention is installed or not.
-        # if not, use_ref_attention needs to be True
-        return importlib.util.find_spec("flash_attn") is None
-
-    def ref_masked_attention(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-    ) -> torch.Tensor:
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
-
-        seq_len, _, _ = query.shape
-        attn_mask = torch.triu(torch.ones(seq_len,
-                                          seq_len,
-                                          dtype=query.dtype,
-                                          device=query.device),
-                               diagonal=1)
-        attn_mask = attn_mask * torch.finfo(query.dtype).min
-
-        attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query,
-                                                 key).float()
-        attn_weights = attn_weights + attn_mask.float()
-        attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
-        out = torch.einsum("hqk,khd->qhd", attn_weights, value)
-        return out
-
-    def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: Optional[torch.Tensor],
-        value_cache: Optional[torch.Tensor],
-        input_metadata: InputMetadata,
-        kv_quant_param: List[float] = None,
-    ) -> torch.Tensor:
-        """PagedAttention forward pass.
-
-        Args:
-            query: shape = [batch_size, seq_len, num_heads * head_size]
-            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
-                block_size, x]
-            value_cache: shape = [num_blocks, num_kv_heads, head_size,
-                block_size]
-            input_metadata: metadata for the inputs.
-        Returns:
-            shape = [batch_size, seq_len, num_heads * head_size]
-        """
-        batch_size, seq_len, hidden_size = query.shape
-        # Reshape the query, key, and value tensors.
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
-        # FIXME: Remove this when all models support int8 kv cache
-        kv_quant_param = [1.0, 0.0, 1.0, 0.0
-                          ] if kv_quant_param is None else kv_quant_param
-
-        # Reshape the keys and values and store them in the cache.
-        # If key_cache and value_cache are not provided, the new key and value
-        # vectors will not be cached. This happens during the initial memory
-        # profiling run.
-        if key_cache is not None and value_cache is not None:
-            cache_ops.reshape_and_cache(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                input_metadata.slot_mapping.flatten(),
-                input_metadata.kv_cache_dtype,
-                *kv_quant_param,
-            )
-
-        if input_metadata.is_prompt:
-            # Prompt run.
-            if self.num_kv_heads != self.num_heads:
-                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
-                # project the key and value tensors to the desired number of
-                # heads.
-                # TODO: Use MQA/GQA kernels for higher performance.
-                query = query.view(query.shape[0], self.num_kv_heads,
-                                   self.num_queries_per_kv, query.shape[-1])
-                key = key[:, :,
-                          None, :].expand(key.shape[0], self.num_kv_heads,
-                                          self.num_queries_per_kv,
-                                          key.shape[-1])
-                value = value[:, :, None, :].expand(value.shape[0],
-                                                    self.num_kv_heads,
-                                                    self.num_queries_per_kv,
-                                                    value.shape[-1])
-            # normal attention
-            if (key_cache is None or value_cache is None
-                    or input_metadata.block_tables.numel() == 0):
-                # Set attention bias if not provided. This typically happens at
-                # the very attention layer of every iteration.
-                # FIXME: This is a hack.
-                if input_metadata.attn_bias is None:
-                    if self.alibi_slopes is None:
-                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                            [seq_len] * batch_size)
-                        if self.sliding_window is not None:
-                            attn_bias = attn_bias.make_local_attention(
-                                self.sliding_window)
-                        input_metadata.attn_bias = attn_bias
-                    else:
-                        input_metadata.attn_bias = _make_alibi_bias(
-                            self.alibi_slopes, self.num_kv_heads, batch_size,
-                            seq_len, query.dtype)
-
-                if self.use_ref_attention:
-                    output = self.ref_masked_attention(
-                        query,
-                        key,
-                        value,
-                    )
-                    return output.reshape(batch_size, seq_len, hidden_size)
-
-                # TODO: Too many view operations. Let's try to reduce
-                # them in the future for code readability.
-                if self.alibi_slopes is None:
-                    query = query.unsqueeze(0)
-                    key = key.unsqueeze(0)
-                    value = value.unsqueeze(0)
-                else:
-                    query = query.unflatten(0, (batch_size, seq_len))
-                    key = key.unflatten(0, (batch_size, seq_len))
-                    value = value.unflatten(0, (batch_size, seq_len))
-
-                out = xops.memory_efficient_attention_forward(
-                    query,
-                    key,
-                    value,
-                    attn_bias=input_metadata.attn_bias,
-                    p=0.0,
-                    scale=self.scale,
-                    op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
-                    (is_hip()) else None,
-                )
-                output = out.view_as(query)
-            else:
-                # prefix-enabled attention
-                output = torch.empty_like(query)
-                context_attention_fwd(
-                    query,
-                    key,
-                    value,
-                    output,
-                    key_cache,
-                    value_cache,
-                    input_metadata.block_tables,  # [BS, max_block_per_request]
-                    input_metadata.start_loc,
-                    input_metadata.prompt_lens,
-                    input_metadata.context_lens,
-                    input_metadata.max_seq_len,
-                    getattr(self, "alibi_slopes", None),
-                )
-
-        else:
-            # Decoding run.
-            output = _paged_attention(
-                query,
-                key_cache,
-                value_cache,
-                input_metadata,
-                self.num_kv_heads,
-                self.scale,
-                self.alibi_slopes,
-                kv_quant_param,
-            )
-
-        # Reshape the output tensor.
-        return output.view(batch_size, seq_len, hidden_size)
-
-
-def _make_alibi_bias(
-    alibi_slopes: torch.Tensor,
-    num_kv_heads: int,
-    batch_size: int,
-    seq_len: int,
-    dtype: torch.dtype,
-) -> LowerTriangularMaskWithTensorBias:
-    bias = torch.arange(seq_len, dtype=dtype)
-    # NOTE: HF uses
-    #     `bias = bias[None, :].repeat(prompt_len, 1)`
-    # here. We find that both biases give the same results, but
-    # the bias below more accurately follows the original ALiBi
-    # paper.
-    bias = bias[None, :] - bias[:, None]
-
-    # When using custom attention bias, xformers requires the bias to
-    # be sliced from a tensor whose length is a multiple of 8.
-    padded_len = (seq_len + 7) // 8 * 8
-    num_heads = alibi_slopes.shape[0]
-    bias = torch.empty(
-        batch_size,
-        num_heads,
-        seq_len,
-        padded_len,
-        device=alibi_slopes.device,
-        dtype=dtype,
-    )[:, :, :, :seq_len].copy_(bias)
-    bias.mul_(alibi_slopes[:, None, None])
-    if num_heads != num_kv_heads:
-        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
-    attn_bias = LowerTriangularMaskWithTensorBias(bias)
-    return attn_bias
-
-
-def _paged_attention(
-    query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    input_metadata: InputMetadata,
-    num_kv_heads: int,
-    scale: float,
-    alibi_slopes: Optional[torch.Tensor],
-    kv_quant_param: List[float],
-) -> torch.Tensor:
-    output = torch.empty_like(query)
-
-    block_size = value_cache.shape[3]
-    num_seqs, num_heads, head_size = query.shape
-    max_num_partitions = (
-        (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
-        _PARTITION_SIZE)
-    # NOTE: We use a simple heuristic to decide whether to use
-    # PagedAttention V1 or V2. If the number of partitions is 1, we use
-    # V1 to avoid the overhead of reduction. Also, if the number of
-    # sequences or heads is large, we use V1 since there is enough work
-    # to parallelize.
-    # TODO: Tune this heuristic.
-    # For context len > 8192, use V2 kernel to avoid shared memory shortage.
-    use_v1 = input_metadata.max_context_len <= 8192 and (
-        max_num_partitions == 1 or num_seqs * num_heads > 512)
-    if use_v1:
-        # Run PagedAttention V1.
-        ops.paged_attention_v1(
-            output,
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            scale,
-            input_metadata.block_tables,
-            input_metadata.context_lens,
-            block_size,
-            input_metadata.max_context_len,
-            alibi_slopes,
-            input_metadata.kv_cache_dtype,
-            *kv_quant_param,
-        )
-    else:
-        # Run PagedAttention V2.
-        assert _PARTITION_SIZE % block_size == 0
-        tmp_output = torch.empty(
-            size=(num_seqs, num_heads, max_num_partitions, head_size),
-            dtype=output.dtype,
-            device=output.device,
-        )
-        exp_sums = torch.empty(
-            size=(num_seqs, num_heads, max_num_partitions),
-            dtype=torch.float32,
-            device=output.device,
-        )
-        max_logits = torch.empty_like(exp_sums)
-        ops.paged_attention_v2(
-            output,
-            exp_sums,
-            max_logits,
-            tmp_output,
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            scale,
-            input_metadata.block_tables,
-            input_metadata.context_lens,
-            block_size,
-            input_metadata.max_context_len,
-            alibi_slopes,
-            input_metadata.kv_cache_dtype,
-            *kv_quant_param,
-        )
-    return output

+ 93 - 0
aphrodite/modeling/layers/attention/__init__.py

@@ -0,0 +1,93 @@
+"""Attention layer."""
+from functools import lru_cache
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+from loguru import logger
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.common.utils import is_hip
+
+
+class Attention(nn.Module):
+    """Attention layer.
+
+    This class takes query, key, and value tensors as input. The input tensors
+    can either contain prompt tokens or generation tokens.
+    The class does the following:
+
+    1. Store the input key and value tensors in the KV cache.
+    2. Perform (multi-head/multi-query/grouped-query) attention.
+    3. Return the output tensor.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        if _use_flash_attn():
+            from aphrodite.modeling.layers.attention.backends.flash_attn import FlashAttentionBackend  # noqa: E501
+
+            self.backend = FlashAttentionBackend(
+                num_heads,
+                head_size,
+                scale,
+                num_kv_heads,
+                alibi_slopes,
+                sliding_window,
+            )
+        else:
+            from aphrodite.modeling.layers.attention.backends.xformers import XFormersBackend  # noqa: E501
+
+            self.backend = XFormersBackend(
+                num_heads,
+                head_size,
+                scale,
+                num_kv_heads,
+                alibi_slopes,
+                sliding_window,
+            )
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: Optional[torch.Tensor],
+        value_cache: Optional[torch.Tensor],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        return self.backend.forward(query, key, value, key_cache, value_cache,
+                                    input_metadata)
+
+
+@lru_cache(maxsize=1)
+def _use_flash_attn() -> bool:
+    try:
+        import flash_attn  # noqa: F401
+    except ImportError:
+        logger.info("flash_attn is not found. Using xformers backend.")
+        return False
+
+    if is_hip():
+        # AMD GPUs.
+        return False
+    if torch.cuda.get_device_capability()[0] < 8:
+        logger.info("flash_attn is not supported on Turing or older GPUs. "
+                    "Using xformers backend.")
+        return False
+    if torch.get_default_dtype() not in (torch.float16, torch.bfloat16):
+        logger.info(
+            "flash_attn only supports torch.float16 or torch.bfloat16. "
+            "Using xformers backend.")
+        return False
+
+    logger.info("Using Flash Attention backend.")
+    return True
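
The checks above run once per process thanks to lru_cache, so the backend is fixed at the first Attention construction. A minimal construction sketch for a grouped-query layer (the sizes below are illustrative, not taken from this diff):

    from aphrodite.modeling.layers.attention import Attention

    # Hypothetical GQA shape: 32 query heads sharing 8 KV heads of size 128.
    attn = Attention(
        num_heads=32,
        head_size=128,
        scale=128 ** -0.5,
        num_kv_heads=8,        # grouped-query attention
        sliding_window=None,   # e.g. 4096 for Mistral-style local attention
    )
    # attn.backend is either FlashAttentionBackend or XFormersBackend,
    # chosen by _use_flash_attn() above; forward() simply delegates to it.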

+ 0 - 0
aphrodite/modeling/layers/attention/backends/__init__.py


+ 121 - 0
aphrodite/modeling/layers/attention/backends/flash_attn.py

@@ -0,0 +1,121 @@
+"""Attention layer with Flash and PagedAttention."""
+from typing import List, Optional
+
+from flash_attn import flash_attn_func
+import torch
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.attention.ops.paged_attn import (
+    PagedAttentionImpl)
+
+
+class FlashAttentionBackend:
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.sliding_window = sliding_window
+        if alibi_slopes is not None:
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+        self.alibi_slopes = alibi_slopes
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        supported_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            raise ValueError(
+                f"Head size {head_size} is not supported by PagedAttention. "
+                f"Supported head sizes are: {supported_head_sizes}.")
+
+        self.sliding_window = ((self.sliding_window, self.sliding_window) if
+                               self.sliding_window is not None else (-1, -1))
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: Optional[torch.Tensor],
+        value_cache: Optional[torch.Tensor],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention and PagedAttention.
+
+        Args:
+            query: shape = [batch_size, seq_len, num_heads * head_size]
+            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
+            input_metadata: metadata for the inputs.
+        Returns:
+            shape = [batch_size, seq_len, num_heads * head_size]
+        """
+        batch_size, seq_len, hidden_size = query.shape
+        # Reshape the query, key, and value tensors.
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        # Reshape the keys and values and store them in the cache.
+        # If key_cache and value_cache are not provided, the new key and value
+        # vectors will not be cached. This happens during the initial memory
+        # profiling run.
+        if key_cache is not None and value_cache is not None:
+            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
+                                                 value_cache, input_metadata)
+
+        if input_metadata.is_prompt:
+            # Prompt run.
+            if (key_cache is None or value_cache is None
+                    or input_metadata.block_tables.numel() == 0):
+                # normal attention
+                query = query.unflatten(0, (batch_size, seq_len))
+                key = key.unflatten(0, (batch_size, seq_len))
+                value = value.unflatten(0, (batch_size, seq_len))
+                output = flash_attn_func(
+                    query,
+                    key,
+                    value,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    window_size=self.sliding_window,
+                    alibi_slopes=self.alibi_slopes,
+                )
+            else:
+                # prefix-enabled attention
+                output = PagedAttentionImpl.forward_prefix(
+                    query,
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    input_metadata,
+                    self.alibi_slopes,
+                )
+        else:
+            # Decoding run.
+            output = PagedAttentionImpl.forward_decode(
+                query,
+                key_cache,
+                value_cache,
+                input_metadata,
+                self.num_kv_heads,
+                self.scale,
+                self.alibi_slopes,
+            )
+
+        # Reshape the output tensor.
+        return output.view(batch_size, seq_len, hidden_size)
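
The prompt path above relies on flash-attn's (left, right) window convention, where (-1, -1) disables the sliding window, and on its native GQA support (key/value may carry fewer heads than the query). A standalone sketch of that call with made-up shapes, assuming a flash-attn build recent enough to accept the window_size argument:

    import torch
    from flash_attn import flash_attn_func

    # Toy shapes: 2 sequences of 16 prompt tokens, 32 query heads over 8 KV heads.
    q = torch.randn(2, 16, 32, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(2, 16, 8, 128, dtype=torch.float16, device="cuda")
    v = torch.randn_like(k)

    out = flash_attn_func(
        q, k, v,
        softmax_scale=128 ** -0.5,
        causal=True,
        window_size=(-1, -1),  # (-1, -1) == no sliding window, as in the backend
    )
    # out: [2, 16, 32, 128], matching the query layout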

+ 255 - 0
aphrodite/modeling/layers/attention/backends/xformers.py

@@ -0,0 +1,255 @@
+"""Attention layer with xFormers and PagedAttention."""
+import importlib
+from typing import List, Optional
+
+import torch
+from xformers import ops as xops
+from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
+                                         LowerTriangularMaskWithTensorBias)
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.attention.ops.paged_attn import (
+    PagedAttentionImpl)
+from aphrodite.common.utils import is_hip
+
+
+class XFormersBackend:
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.sliding_window = sliding_window
+        if alibi_slopes is not None:
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+        self.alibi_slopes = alibi_slopes
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        supported_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            raise ValueError(
+                f"Head size {head_size} is not supported by PagedAttention. "
+                f"Supported head sizes are: {supported_head_sizes}.")
+
+        self.use_ref_attention = _check_use_ref_attention()
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: Optional[torch.Tensor],
+        value_cache: Optional[torch.Tensor],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        """Forward pass with xFormers and PagedAttention.
+
+        Args:
+            query: shape = [batch_size, seq_len, num_heads * head_size]
+            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
+            input_metadata: metadata for the inputs.
+        Returns:
+            shape = [batch_size, seq_len, num_heads * head_size]
+        """
+        batch_size, seq_len, hidden_size = query.shape
+        # Reshape the query, key, and value tensors.
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        # Reshape the keys and values and store them in the cache.
+        # If key_cache and value_cache are not provided, the new key and value
+        # vectors will not be cached. This happens during the initial memory
+        # profiling run.
+        if key_cache is not None and value_cache is not None:
+            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
+                                                 value_cache, input_metadata)
+
+        if input_metadata.is_prompt:
+            # Prompt run.
+            if (key_cache is None or value_cache is None
+                    or input_metadata.block_tables.numel() == 0):
+                # normal attention
+                if self.num_kv_heads != self.num_heads:
+                    # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
+                    # expand the key and value heads to match the number of
+                    # query heads.
+                    # TODO: Use MQA/GQA kernels for higher performance.
+                    query = query.view(query.shape[0], self.num_kv_heads,
+                                       self.num_queries_per_kv,
+                                       query.shape[-1])
+                    key = key[:, :,
+                              None, :].expand(key.shape[0], self.num_kv_heads,
+                                              self.num_queries_per_kv,
+                                              key.shape[-1])
+                    value = value[:, :,
+                                  None, :].expand(value.shape[0],
+                                                  self.num_kv_heads,
+                                                  self.num_queries_per_kv,
+                                                  value.shape[-1])
+
+                # Set attention bias if not provided. This typically happens at
+                # the very first attention layer of every iteration.
+                # FIXME: This is a hack.
+                if input_metadata.attn_bias is None:
+                    if self.alibi_slopes is None:
+                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
+                            [seq_len] * batch_size)
+                        if self.sliding_window is not None:
+                            attn_bias = attn_bias.make_local_attention(
+                                self.sliding_window)
+                        input_metadata.attn_bias = attn_bias
+                    else:
+                        input_metadata.attn_bias = _make_alibi_bias(
+                            self.alibi_slopes, self.num_kv_heads, batch_size,
+                            seq_len, query.dtype)
+
+                if self.use_ref_attention:
+                    output = _ref_masked_attention(
+                        query,
+                        key,
+                        value,
+                        self.num_heads,
+                        self.num_kv_heads,
+                        self.head_size,
+                        self.scale,
+                    )
+                    # Using .view() here raised "RuntimeError: view size is
+                    # not compatible with input tensor's size and stride (at
+                    # least one dimension spans across two contiguous
+                    # subspaces)", so use .reshape() instead.
+                    return output.reshape(batch_size, seq_len, hidden_size)
+
+                # TODO: Too many view operations. Let's try to reduce
+                # them in the future for code readability.
+                if self.alibi_slopes is None:
+                    query = query.unsqueeze(0)
+                    key = key.unsqueeze(0)
+                    value = value.unsqueeze(0)
+                else:
+                    query = query.unflatten(0, (batch_size, seq_len))
+                    key = key.unflatten(0, (batch_size, seq_len))
+                    value = value.unflatten(0, (batch_size, seq_len))
+
+                out = xops.memory_efficient_attention_forward(
+                    query,
+                    key,
+                    value,
+                    attn_bias=input_metadata.attn_bias,
+                    p=0.0,
+                    scale=self.scale,
+                    op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
+                    (is_hip()) else None,
+                )
+                output = out.view_as(query)
+
+            else:
+                # prefix-enabled attention
+                output = PagedAttentionImpl.forward_prefix(
+                    query,
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    input_metadata,
+                    self.alibi_slopes,
+                )
+        else:
+            # Decoding run.
+            output = PagedAttentionImpl.forward_decode(
+                query,
+                key_cache,
+                value_cache,
+                input_metadata,
+                self.num_kv_heads,
+                self.scale,
+                self.alibi_slopes,
+            )
+
+        # Reshape the output tensor.
+        return output.view(batch_size, seq_len, hidden_size)
+
+
+def _make_alibi_bias(
+    alibi_slopes: torch.Tensor,
+    num_kv_heads: int,
+    batch_size: int,
+    seq_len: int,
+    dtype: torch.dtype,
+) -> LowerTriangularMaskWithTensorBias:
+    bias = torch.arange(seq_len, dtype=dtype)
+    # NOTE: HF uses
+    #     `bias = bias[None, :].repeat(prompt_len, 1)`
+    # here. We find that both biases give the same results, but
+    # the bias below more accurately follows the original ALiBi
+    # paper.
+    bias = bias[None, :] - bias[:, None]
+
+    # When using custom attention bias, xformers requires the bias to
+    # be sliced from a tensor whose length is a multiple of 8.
+    padded_len = (seq_len + 7) // 8 * 8
+    num_heads = alibi_slopes.shape[0]
+    bias = torch.empty(
+        batch_size,
+        num_heads,
+        seq_len,
+        padded_len,
+        device=alibi_slopes.device,
+        dtype=dtype,
+    )[:, :, :, :seq_len].copy_(bias)
+    bias.mul_(alibi_slopes[:, None, None])
+    if num_heads != num_kv_heads:
+        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
+    attn_bias = LowerTriangularMaskWithTensorBias(bias)
+    return attn_bias
+
+
+def _check_use_ref_attention() -> bool:
+    if not is_hip():
+        return False
+    # For ROCm, check whether flash attention is installed or not.
+    # if not, use_ref_attention needs to be True
+    return importlib.util.find_spec("flash_attn") is None
+
+
+def _ref_masked_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    num_heads: int,
+    num_kv_heads: int,
+    head_size: int,
+    scale: float,
+) -> torch.Tensor:
+    query = query.view(-1, num_heads, head_size)
+    key = key.view(-1, num_kv_heads, head_size)
+    value = value.view(-1, num_kv_heads, head_size)
+
+    seq_len, _, _ = query.shape
+    attn_mask = torch.triu(torch.ones(seq_len,
+                                      seq_len,
+                                      dtype=query.dtype,
+                                      device=query.device),
+                           diagonal=1)
+    attn_mask = attn_mask * torch.finfo(query.dtype).min
+
+    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
+    attn_weights = attn_weights + attn_mask.float()
+    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
+    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
+    return out
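
The head reshaping in the prompt path is the whole GQA trick: query heads are regrouped by the KV head they share, and the key/value heads are broadcast across each group, so xformers sees matching layouts without copying data. A toy-sized sketch of just that step:

    import torch

    tokens, num_heads, num_kv_heads, head_size = 8, 32, 8, 64
    num_queries_per_kv = num_heads // num_kv_heads

    q = torch.randn(tokens, num_heads, head_size)
    k = torch.randn(tokens, num_kv_heads, head_size)

    # Group query heads by the KV head they attend with...
    q = q.view(tokens, num_kv_heads, num_queries_per_kv, head_size)
    # ...and broadcast each KV head over its group (a view, no copy).
    k = k[:, :, None, :].expand(tokens, num_kv_heads, num_queries_per_kv,
                                head_size)

    assert q.shape == k.shape == (tokens, num_kv_heads, num_queries_per_kv,
                                  head_size)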

+ 0 - 0
aphrodite/modeling/layers/attention/ops/__init__.py


+ 138 - 0
aphrodite/modeling/layers/attention/ops/paged_attn.py

@@ -0,0 +1,138 @@
+from typing import List, Optional
+
+import torch
+
+from aphrodite._C import cache_ops
+from aphrodite._C import ops
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.attention.ops.prefix_prefill import (
+    context_attention_fwd)
+
+# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
+_PARTITION_SIZE = 512
+
+
+class PagedAttentionImpl:
+
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [64, 80, 96, 112, 128, 256]
+
+    @staticmethod
+    def reshape_and_cache(
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> None:
+        cache_ops.reshape_and_cache(
+            key,
+            value,
+            key_cache,
+            value_cache,
+            input_metadata.slot_mapping.flatten(),
+            input_metadata.kv_cache_dtype,
+        )
+
+    @staticmethod
+    def forward_decode(
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        input_metadata: InputMetadata,
+        num_kv_heads: int,
+        scale: float,
+        alibi_slopes: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        output = torch.empty_like(query)
+
+        block_size = value_cache.shape[3]
+        num_seqs, num_heads, head_size = query.shape
+        max_num_partitions = (
+            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
+            _PARTITION_SIZE)
+        # NOTE: We use a simple heuristic to decide whether to use
+        # PagedAttention V1 or V2. If the number of partitions is 1, we use
+        # V1 to avoid the overhead of reduction. Also, if the number of
+        # sequences or heads is large, we use V1 since there is enough work
+        # to parallelize.
+        # TODO: Tune this heuristic.
+        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
+        use_v1 = input_metadata.max_context_len <= 8192 and (
+            max_num_partitions == 1 or num_seqs * num_heads > 512)
+        if use_v1:
+            # Run PagedAttention V1.
+            ops.paged_attention_v1(
+                output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                scale,
+                input_metadata.block_tables,
+                input_metadata.context_lens,
+                block_size,
+                input_metadata.max_context_len,
+                alibi_slopes,
+                input_metadata.kv_cache_dtype,
+            )
+        else:
+            # Run PagedAttention V2.
+            assert _PARTITION_SIZE % block_size == 0
+            tmp_output = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions, head_size),
+                dtype=output.dtype,
+                device=output.device,
+            )
+            exp_sums = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions),
+                dtype=torch.float32,
+                device=output.device,
+            )
+            max_logits = torch.empty_like(exp_sums)
+            ops.paged_attention_v2(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                scale,
+                input_metadata.block_tables,
+                input_metadata.context_lens,
+                block_size,
+                input_metadata.max_context_len,
+                alibi_slopes,
+                input_metadata.kv_cache_dtype,
+            )
+        return output
+
+    @staticmethod
+    def forward_prefix(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        input_metadata: InputMetadata,
+        alibi_slopes: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        output = torch.empty_like(query)
+        context_attention_fwd(
+            query,
+            key,
+            value,
+            output,
+            key_cache,
+            value_cache,
+            input_metadata.block_tables,  # [BS, max_block_per_request]
+            input_metadata.start_loc,
+            input_metadata.prompt_lens,
+            input_metadata.context_lens,
+            input_metadata.max_seq_len,
+            alibi_slopes,
+        )
+        return output
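
The V1/V2 choice in forward_decode is worth restating on its own: V1 skips the cross-partition reduction, so it wins when there is a single partition or already enough sequences times heads to saturate the GPU, while long contexts fall back to V2 to stay within shared memory. A plain-Python restatement of that heuristic:

    _PARTITION_SIZE = 512

    def prefers_v1(max_context_len: int, num_seqs: int, num_heads: int) -> bool:
        max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1)
                              // _PARTITION_SIZE)
        return max_context_len <= 8192 and (max_num_partitions == 1
                                            or num_seqs * num_heads > 512)

    assert prefers_v1(256, num_seqs=1, num_heads=32)        # one partition -> V1
    assert prefers_v1(4096, num_seqs=64, num_heads=32)      # lots of work -> V1
    assert not prefers_v1(16384, num_seqs=1, num_heads=32)  # long context -> V2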

+ 28 - 12
aphrodite/modeling/layers/triton_kernel/prefix_prefill.py → aphrodite/modeling/layers/attention/ops/prefix_prefill.py

@@ -45,6 +45,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_h,
         stride_v_cache_d,
         stride_v_cache_bl,
+        num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,
@@ -53,6 +54,8 @@ if triton.__version__ >= "2.1.0":
         cur_head = tl.program_id(1)
         start_m = tl.program_id(2)
 
+        cur_kv_head = cur_head // num_queries_per_kv
+
         cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
         cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
@@ -85,13 +88,14 @@ if triton.__version__ >= "2.1.0":
                          mask=(start_n + offs_n) < cur_batch_ctx_len,
                          other=0)
             off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_head * stride_k_cache_h +
+                     cur_kv_head * stride_k_cache_h +
                      (offs_d[:, None] // x) * stride_k_cache_d +
                      ((start_n + offs_n[None, :]) % block_size) *
                      stride_k_cache_bl +
                      (offs_d[:, None] % x) * stride_k_cache_x)
             off_v = (
-                bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
+                bn[:, None] * stride_v_cache_bs +
+                cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
@@ -131,9 +135,9 @@ if triton.__version__ >= "2.1.0":
             l_i = l_i_new
             m_i = m_i_new
 
-        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                  offs_d[:, None] * stride_kd)
-        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                  offs_d[None, :] * stride_vd)
         k_ptrs = K + off_k
         v_ptrs = V + off_v
@@ -232,6 +236,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_h,
         stride_v_cache_d,
         stride_v_cache_bl,
+        num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,
@@ -240,6 +245,8 @@ if triton.__version__ >= "2.1.0":
         cur_head = tl.program_id(1)
         start_m = tl.program_id(2)
 
+        cur_kv_head = cur_head // num_queries_per_kv
+
         cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
         cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
@@ -272,13 +279,14 @@ if triton.__version__ >= "2.1.0":
                          mask=(start_n + offs_n) < cur_batch_ctx_len,
                          other=0)
             off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_head * stride_k_cache_h +
+                     cur_kv_head * stride_k_cache_h +
                      (offs_d[:, None] // x) * stride_k_cache_d +
                      ((start_n + offs_n[None, :]) % block_size) *
                      stride_k_cache_bl +
                      (offs_d[:, None] % x) * stride_k_cache_x)
             off_v = (
-                bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
+                bn[:, None] * stride_v_cache_bs +
+                cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
@@ -317,9 +325,9 @@ if triton.__version__ >= "2.1.0":
             l_i = l_i_new
             m_i = m_i_new
 
-        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                  offs_d[:, None] * stride_kd)
-        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                  offs_d[None, :] * stride_vd)
         k_ptrs = K + off_k
         v_ptrs = V + off_v
@@ -420,6 +428,7 @@ if triton.__version__ >= "2.1.0":
         stride_v_cache_h,
         stride_v_cache_d,
         stride_v_cache_bl,
+        num_queries_per_kv: int,
         BLOCK_M: tl.constexpr,
         BLOCK_DMODEL: tl.constexpr,
         BLOCK_N: tl.constexpr,
@@ -429,6 +438,8 @@ if triton.__version__ >= "2.1.0":
         cur_head = tl.program_id(1)
         start_m = tl.program_id(2)
 
+        cur_kv_head = cur_head // num_queries_per_kv
+
         # cur_batch_seq_len: the length of prompts
         # cur_batch_ctx_len: the length of prefix
         # cur_batch_in_all_start_index: the start id of the dim=0
@@ -468,13 +479,14 @@ if triton.__version__ >= "2.1.0":
                          mask=(start_n + offs_n) < cur_batch_ctx_len,
                          other=0)
             off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_head * stride_k_cache_h +
+                     cur_kv_head * stride_k_cache_h +
                      (offs_d[:, None] // x) * stride_k_cache_d +
                      ((start_n + offs_n[None, :]) % block_size) *
                      stride_k_cache_bl +
                      (offs_d[:, None] % x) * stride_k_cache_x)
             off_v = (
-                bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
+                bn[:, None] * stride_v_cache_bs +
+                cur_kv_head * stride_v_cache_h +
                 offs_d[None, :] * stride_v_cache_d +
                 (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
             k = tl.load(K_cache + off_k,
@@ -522,9 +534,9 @@ if triton.__version__ >= "2.1.0":
             l_i = l_i_new
             m_i = m_i_new
 
-        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                  offs_d[:, None] * stride_kd)
-        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                  offs_d[None, :] * stride_vd)
         k_ptrs = K + off_k
         v_ptrs = V + off_v
@@ -618,6 +630,7 @@ if triton.__version__ >= "2.1.0":
                               b_ctx_len,
                               max_input_len,
                               alibi_slopes=None):
+
         cap = torch.cuda.get_device_capability()
         BLOCK = 128 if cap[0] >= 8 else 64
         # shape constraints
@@ -627,6 +640,7 @@ if triton.__version__ >= "2.1.0":
 
         sm_scale = 1.0 / (Lq**0.5)
         batch, head = b_seq_len.shape[0], q.shape[1]
+        num_queries_per_kv = q.shape[1] // k.shape[1]
 
         grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,
 
@@ -673,6 +687,7 @@ if triton.__version__ >= "2.1.0":
                 v_cache.stride(2),
                 v_cache.stride(
                     3),  #[num_blocks, num_kv_heads, head_size, block_size]
+                num_queries_per_kv=num_queries_per_kv,
                 BLOCK_M=BLOCK,
                 BLOCK_DMODEL=Lk,
                 BLOCK_N=BLOCK,
@@ -720,6 +735,7 @@ if triton.__version__ >= "2.1.0":
             v_cache.stride(2),
             v_cache.stride(
                 3),  #[num_blocks, num_kv_heads, head_size, block_size]
+            num_queries_per_kv=num_queries_per_kv,
             BLOCK_M=BLOCK,
             BLOCK_DMODEL=Lk,
             BLOCK_N=BLOCK,
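
The kernel changes above all implement the same idea: each Triton program still handles one query head (cur_head), but the K/V cache pointers are now offset by the KV head that query head maps to. In plain Python the mapping is simply:

    # Toy sizes, not from the diff: 32 query heads sharing 8 KV heads.
    num_query_heads, num_kv_heads = 32, 8
    num_queries_per_kv = num_query_heads // num_kv_heads

    kv_head_for = [cur_head // num_queries_per_kv
                   for cur_head in range(num_query_heads)]
    # Query heads 0-3 read KV head 0, heads 4-7 read KV head 1, and so on.
    assert kv_head_for[:5] == [0, 0, 0, 0, 1]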

+ 8 - 0
aphrodite/modeling/layers/fused_moe/__init__.py

@@ -0,0 +1,8 @@
+from aphrodite.modeling.layers.fused_moe.fused_moe import (fused_moe,
+                                                           get_config_file_name
+                                                           )
+
+__all__ = [
+    "fused_moe",
+    "get_config_file_name",
+]
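
Judging from the JSON files added below, get_config_file_name encodes the expert count E, the expert intermediate size N, and the GPU name into the filename. A hedged sketch of that convention (the helper name and the exact meaning of N are assumptions, not taken from fused_moe.py itself):

    import torch

    def guess_config_file_name(num_experts: int, shard_n: int) -> str:
        # Mirrors the "E=...,N=...,device_name=....json" pattern of the
        # config files added below (assumption: N is the per-shard size).
        device_name = torch.cuda.get_device_name().replace(" ", "_")
        return f"E={num_experts},N={shard_n},device_name={device_name}.json"

    # On an A100-40GB with 8 experts and N=1792 this yields
    # "E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json".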

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}
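
Each key in these tables is a token count M and each value is a set of Triton launch parameters tuned for it. A hedged sketch of how such a file could be consumed (the real lookup lives in fused_moe.py; falling back to the nearest key is an assumption here):

    import json

    def pick_moe_config(path: str, num_tokens: int) -> dict:
        with open(path) as f:
            configs = {int(m): cfg for m, cfg in json.load(f).items()}
        if num_tokens in configs:
            return configs[num_tokens]
        # Fall back to the entry tuned for the closest token count.
        nearest = min(configs, key=lambda m: abs(m - num_tokens))
        return configs[nearest]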

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 9 - 0
aphrodite/modeling/layers/fused_moe/configs/README

@@ -0,0 +1,9 @@
+This directory contains tuned configurations for different settings of the fused_moe kernel.
+For each combination of
+- E (number of experts)
+- N (intermediate size)
+- device_name (torch.cuda.get_device_name())
+the JSON file contains a mapping from M (batch size) to the chosen configuration.
+
+Mixtral has intermediate size N = 14336, i.e. for TP2 we have
+N = 7168 and for TP4 we have N = 3584.
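
A minimal sketch of how these files are consumed at runtime, mirroring get_config_file_name and get_moe_configs added to fused_moe.py further down in this diff; the helper name load_tuned_config and the example values are illustrative only, not part of the commit:

import json
import os

import torch


def load_tuned_config(E, N, M, config_dir="configs"):
    # Same naming scheme as get_config_file_name below:
    # E=<num experts>,N=<intermediate size>,device_name=<GPU name>.json
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    path = os.path.join(config_dir,
                        f"E={E},N={N},device_name={device_name}.json")
    if not os.path.exists(path):
        return None  # caller falls back to the hard-coded default config
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    # Pick the tuned batch size closest to the actual number of tokens M.
    return configs[min(configs, key=lambda m: abs(m - M))]


# Mixtral with TP4: E = 8, N = 14336 // 4 = 3584. With M = 200 tokens the
# closest tuned batch size in the grid is 256, so that entry is returned.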

+ 120 - 58
aphrodite/modeling/layers/triton_kernel/fused_moe.py → aphrodite/modeling/layers/fused_moe/fused_moe.py

@@ -1,7 +1,13 @@
 """Fused MoE kernel."""
+import functools
+import json
+import os
+from typing import Any, Dict, Optional, Tuple
+
 import torch
 import triton
 import triton.language as tl
+from loguru import logger
 
 from aphrodite._C import ops
 from aphrodite.common.utils import is_hip
@@ -22,9 +28,10 @@ def fused_moe_kernel(
     K,
     EM,
     num_valid_tokens,
-    # The stride variables represent how much to increase the ptr by when moving
-    # by 1 element in a particular dimension. E.g. `stride_am` is how much to
-    # increase `a_ptr` by to get the element one row down (A has M rows).
+    # The stride variables represent how much to increase the ptr by when
+    # moving by 1 element in a particular dimension. E.g. `stride_am` is
+    # how much to increase `a_ptr` by to get the element one row down
+    # (A has M rows).
     stride_am,
     stride_ak,
     stride_be,
@@ -44,22 +51,23 @@ def fused_moe_kernel(
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
     token and expert matrices.
+
     Key Parameters:
-    - A: The input tensor representing tokens with shape (*, K), where '*'
-        can be any shape representing batches and K is the feature dimension
-        of each token.
-    - B: The stacked MOE weight tensor with shape (E, N, K), where E is the
-        number of experts, K is the input feature dimension, and N is the
-        output feature dimension.
+    - A: The input tensor representing tokens with shape (*, K), where '*' can
+        be any shape representing batches and K is the feature dimension of
+        each token.
+    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
+        the number of experts, K is the input feature dimension, and N is
+        the output feature dimension.
     - C: The output cache tensor with shape (M, topk, N), where M is the
         total number of tokens post padding, topk is the number of times
         each token is repeated, and N is the output feature dimension.
     - sorted_token_ids: A tensor containing the sorted indices of tokens,
-        repeated topk times and arranged by the expert index they are assigned
-        to.
-    - expert_ids: A tensor containing the indices of the expert for each block.
-        It determines which expert matrix from B should be used for each block
-        in A.
+        repeated topk times and arranged by the expert index they are
+        assigned to.
+    - expert_ids: A tensor containing the indices of the expert for each
+        block. It determines which expert matrix from B should be used for
+        each block in A.
     This kernel performs the multiplication of a token by its corresponding
     expert matrix as determined by `expert_ids`. The sorting of
     `sorted_token_ids` by expert index and padding ensures divisibility by
@@ -142,39 +150,43 @@ def fused_moe_kernel(
 
 def moe_align_block_size(
         topk_ids: torch.Tensor, block_size: int,
-        num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.
+
     Parameters:
-    - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k
-        expert indices for each token.
+    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
+        top-k expert indices for each token.
     - block_size: The block size used in block matrix multiplication.
     - num_experts: The total number of experts.
+
     Returns:
     - sorted_token_ids: A tensor containing the sorted token indices according
         to their allocated expert.
     - expert_ids: A tensor indicating the assigned expert index for each block.
     - num_tokens_post_padded: The total number of tokens after padding,
         ensuring divisibility by block_size.
+
     This function pads the number of tokens that each expert needs to process
-    so that it is divisible by block_size. Padding ensures that during block
-    matrix multiplication, the dimensions align correctly.
+    so that it is divisible by block_size.
+    Padding ensures that during block matrix multiplication, the dimensions
+    align correctly.
+
     Example:
     Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
     block_size = 4, and num_experts = 4:
-    - We initially have 12 tokens (after repeating 'top_k' times) and 4
-        experts, with each expert needing to process 3 tokens.
+    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
+        with each expert needing to process 3 tokens.
     - As block_size is 4, we pad 1 token for each expert.
     - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
     - Then append padding tokens [12, 12, 12, 12] for each block.
-    - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4,
-                                                          10, 12, 1, 7, 11,
-                                                          12, 2, 5, 8, 12].
-        Tokens 12 are non-existent (padding) and are ignored in the subsequent
-        matrix multiplication.
-    - The padding ensures that the total number of tokens is now divisible by
-        block_size for proper block matrix operations.
+    - After sorting by expert index, we obtain token_ids
+        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
+        Tokens 12 are non-existent (padding) and are ignored in
+        the subsequent matrix multiplication.
+    - The padding ensures that the total number of tokens is now divisible
+        by block_size for proper block matrix operations.
     """
     sorted_ids = torch.empty(
         (topk_ids.numel() + num_experts * (block_size - 1), ),
@@ -197,12 +209,11 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                             sorted_token_ids: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
-                            mul_routed_weight: bool, top_k: int, config: dict):
-
+                            mul_routed_weight: bool, top_k: int,
+                            config: Dict[str, Any]) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
-    # ruff: noqa: E731
     grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
         'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
 
@@ -232,6 +243,40 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
     )
 
 
+def get_config_file_name(E: int, N: int) -> str:
+    device_name = torch.cuda.get_device_name().replace(" ", "_")
+    return f"E={E},N={N},device_name={device_name}.json"
+
+
+@functools.lru_cache
+def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
+    """
+    Return optimized configurations for the fused MoE kernel.
+
+    The return value will be a dictionary that maps an irregular grid of
+    batch sizes to configurations of the fused_moe kernel. To evaluate the
+    kernel on a given batch size bs, the closest batch size in the grid should
+    be picked and the associated configuration chosen to invoke the kernel.
+    """
+
+    # First look up if an optimized configuration is available in the configs
+    # directory
+    json_file_name = get_config_file_name(E, N)
+
+    config_file_path = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
+    if os.path.exists(config_file_path):
+        with open(config_file_path) as f:
+            logger.info(
+                f"Using configuration from {config_file_path} for MoE layer.")
+            # If a configuration has been found, return it
+            return {int(key): val for key, val in json.load(f).items()}
+
+    # If no optimized configuration is available, we will use the default
+    # configuration
+    return None
+
+
 def fused_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -240,6 +285,7 @@ def fused_moe(
     topk: int,
     renormalize: bool,
     inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -249,25 +295,29 @@ def fused_moe(
     - hidden_states (torch.Tensor): The input tensor to the MoE layer.
     - w1 (torch.Tensor): The first set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation (before
-        softmax).
+    - gating_output (torch.Tensor): The output of the gating operation
+        (before softmax).
     - topk (int): The number of top-k experts to select.
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - inplace (bool): If True, perform the operation in-place. Defaults to
-        False.
+    - inplace (bool): If True, perform the operation in-place.
+        Defaults to False.
+    - override_config (Optional[Dict[str, Any]]): Optional override
+        for the kernel configuration.
 
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
     # Check constraints.
     assert hidden_states.shape[0] == gating_output.shape[0], (
-        'Number of tokens mismatch')
-    assert hidden_states.shape[1] == w1.shape[2], 'Hidden size mismatch'
-    assert gating_output.shape[1] == w1.shape[0], 'Number of experts mismatch'
-    assert hidden_states.is_contiguous(), 'Hidden_states must be contiguous'
-    assert w1.is_contiguous(), 'Expert weights1 must be contiguous'
-    assert w2.is_contiguous(), 'Expert weights2 must be contiguous'
-    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
+        "Number of tokens mismatch")
+    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert hidden_states.dtype in [
+        torch.float32, torch.float16, torch.bfloat16
+    ]
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
 
@@ -302,20 +352,32 @@ def fused_moe(
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
-    config = {
-        'BLOCK_SIZE_M': 64,
-        'BLOCK_SIZE_N': 64,
-        'BLOCK_SIZE_K': 32,
-        'GROUP_SIZE_M': 8
-    }
-
-    if topk_ids.numel() <= w1.shape[0]:
-        config = {
-            'BLOCK_SIZE_M': 16,
-            'BLOCK_SIZE_N': 32,
-            'BLOCK_SIZE_K': 64,
-            'GROUP_SIZE_M': 1
-        }
+    if override_config:
+        config = override_config
+    else:
+        # First try to load optimal config from the file
+        configs = get_moe_configs(E, w2.shape[2])
+
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32,
+                'GROUP_SIZE_M': 8
+            }
+
+            if M <= E:
+                config = {
+                    'BLOCK_SIZE_M': 16,
+                    'BLOCK_SIZE_N': 32,
+                    'BLOCK_SIZE_K': 64,
+                    'GROUP_SIZE_M': 1
+                }
 
     intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
                                       device=hidden_states.device,
@@ -327,8 +389,8 @@ def fused_moe(
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
 
-    sorted_token_ids, expert_ids, num_tokens_post_padded = (
-        moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], E))
+    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+        topk_ids, config['BLOCK_SIZE_M'], E)
 
     invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
                             topk_weights, topk_ids, sorted_token_ids,

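The worked example in the moe_align_block_size docstring above can be reproduced in a few lines of plain Python. This is an illustrative re-derivation only; the real function allocates tensors and dispatches the work to a CUDA op rather than looping in Python:

def align_block_size(topk_ids, block_size):
    flat = [e for row in topk_ids for e in row]
    pad_id = len(flat)  # non-existent token id used as padding
    sorted_token_ids, expert_ids = [], []
    for expert in sorted(set(flat)):
        idxs = [i for i, e in enumerate(flat) if e == expert]
        idxs += [pad_id] * (-len(idxs) % block_size)  # pad to a block multiple
        sorted_token_ids += idxs
        expert_ids += [expert] * (len(idxs) // block_size)
    return sorted_token_ids, expert_ids, len(sorted_token_ids)


topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]]
print(align_block_size(topk_ids, block_size=4))
# ([3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12], [1, 2, 3, 4], 16)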
+ 57 - 18
aphrodite/modeling/layers/rejection.py

@@ -8,23 +8,26 @@ import torch.jit
 
 class RejectionSampler(nn.Module):
     """Apply modified rejection sampling as described in "Accelerating Large
-        Language Model Decoding with Speculative Sampling"
-        https://arxiv.org/pdf/2302.01318.pdf.
+    Language Model Decoding with Speculative Sampling"
+    https://arxiv.org/pdf/2302.01318.pdf.
     """
 
     def __init__(self, strict_mode: bool = False):
         """Create a rejection sampler.
+
         Args:
             strict_mode: Whether or not to perform shape/device/dtype checks
                 during sampling. This catches correctness issues but adds
                 nontrivial latency.
         """
         super().__init__()
-        self.probs_dtype = torch.float32
-        self.token_id_dtype = torch.int64
-        self._num_bonus_tokens = 1
         self._strict_mode = strict_mode
 
+        # NOTE: A "bonus token" is accepted iff all proposal tokens are
+        # accepted. There is always only one possible bonus token. We store this
+        # value in a variable for readability.
+        self._num_bonus_tokens = 1
+
         self.num_accepted_tokens: Optional[torch.Tensor] = None
         self.num_emitted_tokens: Optional[torch.Tensor] = None
         self.num_draft_tokens: int = 0
@@ -39,6 +42,14 @@ class RejectionSampler(nn.Module):
                                                dtype=torch.long,
                                                device=device)
 
+    @property
+    def probs_dtype(self):
+        return torch.float32
+
+    @property
+    def token_id_dtype(self):
+        return torch.int64
+
     def forward(
         self,
         target_probs: torch.Tensor,
@@ -49,24 +60,31 @@ class RejectionSampler(nn.Module):
         """Sample token ids using rejection sampling. This accepts or rejects
         tokens proposed by the draft model using the probability of each token
         according to the draft and target models.
+
         In the worst case where all draft tokens are rejected, it is guaranteed
         one correct token will be emitted.
+
         In the case where all draft tokens are accepted, a bonus token will be
         accepted as it's cheap to have the target model score this speculative
         sequence.
+
         Args:
             target_probs: The probability distribution over token ids given
                 context according to the target model.
             shape = [batch_size, num_speculative_tokens, vocab_size]
+
             bonus_token_ids: The "bonus" token ids that are accepted iff all
                 speculative tokens in a sequence are accepted.
             shape = [batch_size, num_bonus_tokens]
+
             draft_probs: The probability distribution over token ids given
                 context according to the draft model.
             shape = [batch_size, num_speculative_tokens, vocab_size]
+
             draft_token_ids: The token ids that were sampled from the draft
                 probabilities.
             shape = [batch_size, num_speculative_tokens]
+
         Returns:
             output_token_ids: The token ids sampled via rejection sampling,
                 or -1 if unable to sample a token because the previous token
@@ -107,6 +125,7 @@ class RejectionSampler(nn.Module):
             draft_token_ids: torch.Tensor,  # [batch_size, k]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Perform modified rejection sampling on each sequence.
+
         Returns:
             A tuple of two tensors:
             0: A bool tensor of which tokens in each sequence is accepted.
@@ -139,16 +158,20 @@ class RejectionSampler(nn.Module):
         r"""Create bool matrix over the proposed draft tokens. If
         True, then a token can be accepted, else it should be
         rejected.
+
         Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
         :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
         to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
         same conditional probability according to the draft model, the token
         is accepted with probability:
+
         .. math::
             \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
                            {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
+
         This implementation does not apply causality. When using the output,
         if a token is rejected, subsequent tokens should not be used.
+
         Returns a bool tensor of shape [batch_size, k] specifying which tokens
         are accepted.
         """
@@ -171,7 +194,8 @@ class RejectionSampler(nn.Module):
                                   device=target_probs.device)
         capped_ratio = torch.minimum(
             selected_target_probs / selected_draft_probs,
-            torch.full((1, ), 1, device=target_probs.device))
+            torch.full((1, ), 1, device=target_probs.device),
+        )
         accepted = uniform_rand < capped_ratio
 
         return accepted
@@ -183,21 +207,26 @@ class RejectionSampler(nn.Module):
     ) -> torch.Tensor:
         r"""Create a probability distribution for each proposed token which can
         be sampled if the proposed token is rejected.
+
         When this routine is applied sequentially, the true distribution of the
         target model is recovered (within hardware numerics).
+
         The probability distribution used in this rejection case is constructed
         as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
         :math:`x` given context :math:`x_1, \dots, x_n` according to the target
         model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
         according to the draft model:
+
         .. math::
             x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
+
         where :math:`(f(x))_+` is defined as:
+
         .. math::
             (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
-        See https://github.com/vllm-project/vllm/pull/2336 for a visualization
-        of the draft, target, and recovered probability distributions.
+
         Returns a tensor of shape [batch_size, k, vocab_size].
+
         Note: This batches operations on GPU and thus constructs the recovered
         distribution for all tokens, even if they are accepted. This causes
         division-by-zero errors, so we use self._smallest_positive_value to
@@ -208,7 +237,7 @@ class RejectionSampler(nn.Module):
         # shape [batch_size, k, vocab_size]
         difference = target_probs - draft_probs
 
-        # TODO(cade): Can we use logprobs instead of probs, and avoid the
+        # TODO: Can we use logprobs instead of probs, and avoid the
         # division-by-zero errors without introducing distribution drift?
 
         # shape [batch_size, k, vocab_size]
@@ -224,7 +253,9 @@ class RejectionSampler(nn.Module):
         """Return the smallest positive value representable by the probs dtype.
         This value is used when constructing a distribution from which to sample
         recovered tokens in the first rejection case.
+
         See _get_recovered_probs for more details
+
         Note that this isn't actually the smallest positive value representable
         by float32, but the smallest positive normal value.
         See https://en.wikipedia.org/wiki/Subnormal_number for more information.
@@ -241,6 +272,7 @@ class RejectionSampler(nn.Module):
         """Format output. Returns a matrix of token ids. When
         a token is rejected via rejection sampling, all subsequent
         token ids are set to -1 for the sequence.
+
         shape = [batch_size, k + num_bonus_tokens]
         """
         bonus_token_ids = bonus_token_ids.squeeze()
@@ -259,7 +291,8 @@ class RejectionSampler(nn.Module):
         output_with_bonus_tokens = -torch.ones(
             (batch_size, k + self._num_bonus_tokens),
             dtype=self.token_id_dtype,
-            device=accepted.device)
+            device=accepted.device,
+        )
         output = output_with_bonus_tokens[:, :k]
 
         # Fill in the first k columns of the output tensor using masks and data
@@ -290,8 +323,11 @@ class RejectionSampler(nn.Module):
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
     ) -> None:
-        (target_batch_size, num_target_probs,
-         target_vocab_size) = target_probs.shape
+        (
+            target_batch_size,
+            num_target_probs,
+            target_vocab_size,
+        ) = target_probs.shape
         bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
         draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
         draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
@@ -327,10 +363,13 @@ class RejectionSampler(nn.Module):
         draft_token_ids: torch.Tensor,
     ) -> None:
         devices = [
-            t.device for t in
-            [target_probs, bonus_token_ids, draft_probs, draft_token_ids]
+            t.device for t in [
+                target_probs,
+                bonus_token_ids,
+                draft_probs,
+                draft_token_ids,
+            ]
         ]
-        # pylint: disable=use-a-generator
         assert all([devices[0] == device for device in devices])
 
     def _raise_if_out_of_bounds_vocab(
@@ -358,8 +397,8 @@ def _multinomial(
     if num_samples > 1:
         # This is equivalent to torch.repeat_interleaved (which also
         # forces a GPU<->CPU sync).
-        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
-                                         probs.shape[1]).contiguous().view(
-                                             -1, probs.shape[1])
+        probs = (probs[:, None, :].expand(probs.shape[0], num_samples,
+                                          probs.shape[1]).contiguous().view(
+                                              -1, probs.shape[1]))
     q = torch.empty_like(probs).exponential_(1.0)
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
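
For a single proposed token, the two rules documented in RejectionSampler above reduce to a few lines. The sketch below is illustrative only: the module operates on whole batches, caps the q/p ratio with torch.minimum, and guards the division and renormalization against zeros using the smallest positive normal float32 value:

import torch


def accept_or_resample(target_probs, draft_probs, draft_token_id):
    q = target_probs[draft_token_id]  # target probability of the draft token
    p = draft_probs[draft_token_id]   # draft probability of the same token
    # Accept the draft token with probability min(1, q / p).
    if torch.rand(1).item() < min(1.0, (q / p).item()):
        return draft_token_id
    # Otherwise resample from the normalized positive part of (q - p).
    recovered = torch.clamp(target_probs - draft_probs, min=0)
    recovered = recovered / recovered.sum()
    return torch.multinomial(recovered, num_samples=1).item()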

+ 1 - 2
aphrodite/modeling/layers/sampler.py

@@ -802,7 +802,6 @@ def _get_logprobs(
         if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             num_logprobs = sampling_params.prompt_logprobs
-            prompt_len = sampling_metadata.prompt_lens[i]
             prompt_tokens = sampling_metadata.seq_data[
                 seq_ids[0]].prompt_token_ids
             group_prompt_logprobs: PromptLogprobs = [None]
@@ -876,7 +875,7 @@ def _build_sampler_output(
                                output_metadata.get(seq_ids[parent_id])))
         sampler_output.append(
             SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
-    return sampler_output
+    return SamplerOutput(outputs=sampler_output)
 
 
 def _miro_store_args(seqids: List[int], mus: List[float],

+ 62 - 24
aphrodite/modeling/loader.py

@@ -2,18 +2,24 @@
 import contextlib
 import gc
 from contextlib import nullcontext
-from typing import Optional, Type
+from typing import Type
 from loguru import logger
 
 import torch
 import torch.nn as nn
 
-from aphrodite.common.config import DeviceConfig, ModelConfig, LoRAConfig
+from aphrodite.common.config import DeviceConfig, ModelConfig
 from aphrodite.modeling.models import ModelRegistry
-from aphrodite.modeling.hf_downloader import (get_quant_config,
-                                              initialize_dummy_weights)
+from aphrodite.modeling.hf_downloader import (
+    get_quant_config,
+    initialize_dummy_weights,
+)
 from aphrodite.modeling.layers.quantization.bitsandbytes import (
-    BNBLinearMethod, replace_quant_params)
+    BNBLinearMethod,
+    replace_quant_params,
+)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size, )
 
 
 @contextlib.contextmanager
@@ -32,6 +38,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
     if (model_config.quantization is not None
             and "MixtralForCausalLM" in architectures):
         architectures = ["QuantMixtralForCausalLM"]
+
     for arch in architectures:
         model_cls = ModelRegistry.load_model_cls(arch)
         if model_cls is not None:
@@ -41,9 +48,9 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
         f"Supported architectures: {ModelRegistry.get_supported_archs()}")
 
 
-def get_model(model_config: ModelConfig,
-              device_config: DeviceConfig,
-              lora_config: Optional[LoRAConfig] = None) -> nn.Module:
+def get_model(model_config: ModelConfig, device_config: DeviceConfig,
+              **kwargs) -> nn.Module:
+    lora_config = kwargs.get("lora_config", None)
     model_class = _get_model_architecture(model_config)
 
     # Get the (maybe quantized) linear method.
@@ -68,9 +75,9 @@ def get_model(model_config: ModelConfig,
     with _set_default_torch_dtype(model_config.dtype):
         # Create a model instance.
         # The weights will be initialized as empty tensors.
-        with torch.device(device_config.device) if not \
-            (isinstance(linear_method, BNBLinearMethod) and
-             linear_method.quant_config.from_float) else nullcontext():
+        with torch.device(device_config.device) if not (
+                isinstance(linear_method, BNBLinearMethod)
+                and linear_method.quant_config.from_float) else nullcontext():
             if hasattr(model_class, "supported_lora_modules"):
                 model = model_class(model_config.hf_config, linear_method,
                                     lora_config)
@@ -88,23 +95,54 @@ def get_model(model_config: ModelConfig,
             initialize_dummy_weights(model)
         else:
             # Load the weights from the cached or downloaded files.
-            model.load_weights(model_config.model, model_config.download_dir,
-                               model_config.load_format, model_config.revision)
+            model.load_weights(
+                model_config.model,
+                model_config.download_dir,
+                model_config.load_format,
+                model_config.revision,
+            )
         if isinstance(linear_method, BNBLinearMethod):
-            replace_quant_params(model,
-                                 quant_config=linear_method.quant_config,
-                                 modules_to_not_convert="lm_head")
+            replace_quant_params(
+                model,
+                quant_config=linear_method.quant_config,
+                modules_to_not_convert="lm_head",
+            )
             torch.cuda.synchronize()
             if linear_method.quant_config.from_float:
                 model = model.cuda()
             gc.collect()
             torch.cuda.empty_cache()
-            logger.info("Memory allocated for converted model: {} GiB".format(
-                round(
-                    torch.cuda.memory_allocated(torch.cuda.current_device()) /
-                    (1024 * 1024 * 1024), 2)))
-            logger.info("Memory reserved for converted model: {} GiB".format(
-                round(
-                    torch.cuda.memory_reserved(torch.cuda.current_device()) /
-                    (1024 * 1024 * 1024), 2)))
+            tp = get_tensor_model_parallel_world_size()
+            logger.info(
+                "Memory allocated for converted model: {} GiB x {} = {} "
+                "GiB".format(
+                    round(
+                        torch.cuda.memory_allocated(
+                            torch.cuda.current_device()) /
+                        (1024 * 1024 * 1024),
+                        2,
+                    ),
+                    tp,
+                    round(
+                        torch.cuda.memory_allocated(
+                            torch.cuda.current_device()) * tp /
+                        (1024 * 1024 * 1024),
+                        2,
+                    ),
+                ))
+            logger.info(
+                "Memory reserved for converted model: {} GiB x {} = {} "
+                "GiB".format(
+                    round(
+                        torch.cuda.memory_reserved(torch.cuda.current_device())
+                        / (1024 * 1024 * 1024),
+                        2,
+                    ),
+                    tp,
+                    round(
+                        torch.cuda.memory_reserved(torch.cuda.current_device())
+                        * tp / (1024 * 1024 * 1024),
+                        2,
+                    ),
+                ))
     return model.eval()

+ 5 - 4
aphrodite/modeling/metadata.py

@@ -1,4 +1,4 @@
-from typing import Optional, List
+from typing import Optional
 
 import torch
 
@@ -28,7 +28,7 @@ class InputMetadata:
         block_tables: Optional[torch.Tensor],
         use_cuda_graph: bool,
         kv_cache_dtype: str,
-        kv_quant_params: List[List[float]],
+        # kv_quant_params: List[List[float]],
     ) -> None:
         self.is_prompt = is_prompt
         self.prompt_lens = prompt_lens
@@ -40,7 +40,7 @@ class InputMetadata:
         self.block_tables = block_tables
         self.use_cuda_graph = use_cuda_graph
         self.kv_cache_dtype = kv_cache_dtype
-        self.kv_quant_params = kv_quant_params
+        # self.kv_quant_params = kv_quant_params
 
         # Set during the execution of the first attention op.
         # FIXME: This is a hack.
@@ -55,4 +55,5 @@ class InputMetadata:
                 f"block_tables={self.block_tables}, "
                 f"use_cuda_graph={self.use_cuda_graph}, "
                 f"kv_cache_dtype={self.kv_cache_dtype}, "
-                f"kv_quant_params={self.kv_quant_params})")
+                # f"kv_quant_params={self.kv_quant_params})"
+                )

+ 3 - 4
aphrodite/modeling/models/baichuan.py

@@ -27,7 +27,7 @@ from torch import nn
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -187,7 +187,7 @@ class BaiChuanAttention(nn.Module):
             alibi_slopes = alibi_slopes[head_start:head_end].tolist()
 
             scaling = self.head_dim**-0.5
-            self.attn = PagedAttention(
+            self.attn = Attention(
                 self.num_heads,
                 self.head_dim,
                 scaling,
@@ -205,8 +205,7 @@ class BaiChuanAttention(nn.Module):
                 is_neox_style=is_neox_style,
             )
             self.scaling = self.head_dim**-0.5
-            self.attn = PagedAttention(self.num_heads, self.head_dim,
-                                       self.scaling)
+            self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
 
     def forward(
         self,

+ 5 - 5
aphrodite/modeling/models/bloom.py

@@ -26,7 +26,7 @@ from transformers import BloomConfig
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@@ -108,10 +108,10 @@ class BloomAttention(nn.Module):
         alibi_slopes = alibi_slopes[head_start:head_end].tolist()
 
         scaling = self.head_dim**-0.5
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   scaling,
-                                   alibi_slopes=alibi_slopes)
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scaling,
+                              alibi_slopes=alibi_slopes)
 
     def forward(
         self,

+ 2 - 2
aphrodite/modeling/models/chatglm.py

@@ -10,7 +10,7 @@ from torch.nn import LayerNorm
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
@@ -87,7 +87,7 @@ class GLMAttention(nn.Module):
             base=10000 * rope_ratio,
             is_neox_style=False,
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,

+ 1 - 1
aphrodite/modeling/models/cohere.py

@@ -30,7 +30,7 @@ from transformers import CohereConfig
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention as Attention
+from aphrodite.modeling.layers.attention import Attention as Attention
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
     MergedColumnParallelLinear,

+ 3 - 3
aphrodite/modeling/models/deepseek.py

@@ -31,8 +31,8 @@ from transformers import PretrainedConfig
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
-from aphrodite.modeling.layers.triton_kernel.fused_moe import fused_moe
+from aphrodite.modeling.layers.attention import Attention
+from aphrodite.modeling.layers.fused_moe.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -249,7 +249,7 @@ class DeepseekAttention(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,

+ 14 - 14
aphrodite/modeling/models/falcon.py

@@ -29,7 +29,7 @@ from transformers import FalconConfig as HF_FalconConfig
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@@ -151,10 +151,10 @@ class FalconAttention(nn.Module):
                 max_position=max_position_embeddings,
                 base=rope_theta,
             )
-            self.attn = PagedAttention(self.num_heads,
-                                       self.head_dim,
-                                       self.inv_norm_factor,
-                                       num_kv_heads=self.num_kv_heads)
+            self.attn = Attention(self.num_heads,
+                                  self.head_dim,
+                                  self.inv_norm_factor,
+                                  num_kv_heads=self.num_kv_heads)
         elif self.use_alibi:
             tp_rank = get_tensor_model_parallel_rank()
             head_start = tp_rank * self.num_heads
@@ -162,16 +162,16 @@ class FalconAttention(nn.Module):
             alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
                             self.inv_norm_factor)
             alibi_slopes = alibi_slopes[head_start:head_end].tolist()
-            self.attn = PagedAttention(self.num_heads,
-                                       self.head_dim,
-                                       self.inv_norm_factor,
-                                       num_kv_heads=self.num_kv_heads,
-                                       alibi_slopes=alibi_slopes)
+            self.attn = Attention(self.num_heads,
+                                  self.head_dim,
+                                  self.inv_norm_factor,
+                                  num_kv_heads=self.num_kv_heads,
+                                  alibi_slopes=alibi_slopes)
         else:
-            self.attn = PagedAttention(self.num_heads,
-                                       self.head_dim,
-                                       scale=self.inv_norm_factor,
-                                       num_kv_heads=self.num_kv_heads)
+            self.attn = Attention(self.num_heads,
+                                  self.head_dim,
+                                  scale=self.inv_norm_factor,
+                                  num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,

+ 2 - 2
aphrodite/modeling/models/gemma.py

@@ -24,7 +24,7 @@ from transformers import GemmaConfig
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import GeluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -181,7 +181,7 @@ class GemmaAttention(nn.Module):
             base=self.rope_theta,
             is_neox_style=True,
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,

+ 2 - 4
aphrodite/modeling/models/gpt2.py

@@ -26,7 +26,7 @@ from transformers import GPT2Config
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@@ -74,9 +74,7 @@ class GPT2Attention(nn.Module):
             bias=True,
             linear_method=linear_method,
         )
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   scale=self.scale)
+        self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)
 
     def forward(
         self,

+ 5 - 5
aphrodite/modeling/models/gpt_bigcode.py

@@ -27,7 +27,7 @@ from transformers import GPTBigCodeConfig
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@@ -86,10 +86,10 @@ class GPTBigCodeAttention(nn.Module):
             bias=True,
             linear_method=linear_method,
         )
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   scale=self.scale,
-                                   num_kv_heads=self.num_kv_heads)
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scale=self.scale,
+                              num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,

+ 2 - 2
aphrodite/modeling/models/gpt_j.py

@@ -26,7 +26,7 @@ from transformers import GPTJConfig
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (
     ColumnParallelLinear,
     LinearMethodBase,
@@ -117,7 +117,7 @@ class GPTJAttention(nn.Module):
             base=rope_theta,
             is_neox_style=False,
         )
-        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
+        self.attn = Attention(self.num_heads, self.head_size, scaling)
 
     def forward(
         self,

+ 2 - 2
aphrodite/modeling/models/gpt_neox.py

@@ -26,7 +26,7 @@ from transformers import GPTNeoXConfig
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (
     ColumnParallelLinear,
     LinearMethodBase,
@@ -99,7 +99,7 @@ class GPTNeoXAttention(nn.Module):
             base=rope_theta,
             is_neox_style=is_neox_style,
         )
-        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
+        self.attn = Attention(self.num_heads, self.head_size, scaling)
 
     def forward(
         self,

+ 2 - 2
aphrodite/modeling/models/internlm2.py

@@ -7,7 +7,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -147,7 +147,7 @@ class InternLM2Attention(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,

+ 16 - 9
aphrodite/modeling/models/llama.py

@@ -30,7 +30,7 @@ from transformers import LlamaConfig
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -199,7 +199,7 @@ class LlamaAttention(nn.Module):
             rope_scaling=rope_scaling,
             is_neox_style=is_neox_style,
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,
@@ -213,7 +213,7 @@ class LlamaAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        kv_quant_param: List[float],
+        # kv_quant_param: List[float],
     ) -> torch.Tensor:
         if self.merge_weight:
             qkv, _ = self.qkv_proj(hidden_states)
@@ -225,8 +225,15 @@ class LlamaAttention(nn.Module):
             v, _ = self.v_proj(hidden_states)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                kv_quant_param)
+        attn_output = self.attn(
+            q,
+            k,
+            v,
+            k_cache,
+            v_cache,
+            input_metadata,
+            # kv_quant_param
+        )
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -279,7 +286,7 @@ class LlamaDecoderLayer(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         residual: Optional[torch.Tensor],
-        kv_quant_param: List[float],
+        # kv_quant_param: List[float],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         if residual is None:
@@ -293,7 +300,7 @@ class LlamaDecoderLayer(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            kv_quant_param=kv_quant_param,
+            # kv_quant_param=kv_quant_param,
         )
 
         # Fully Connected
@@ -347,8 +354,8 @@ class LlamaModel(nn.Module):
                 kv_caches[i],
                 input_metadata,
                 residual,
-                input_metadata.kv_quant_params[i]
-                if input_metadata.kv_quant_params is not None else None,
+                # input_metadata.kv_quant_params[i]
+                # if input_metadata.kv_quant_params is not None else None,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

+ 3 - 3
aphrodite/modeling/models/mixtral.py

@@ -31,8 +31,8 @@ from transformers import MixtralConfig
 
 from aphrodite.common.config import LoRAConfig
 from aphrodite.modeling.metadata import InputMetadata
-from aphrodite.modeling.layers.attention import PagedAttention
-from aphrodite.modeling.layers.triton_kernel.fused_moe import fused_moe
+from aphrodite.modeling.layers.attention import Attention
+from aphrodite.modeling.layers.fused_moe.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -256,7 +256,7 @@ class MixtralAttention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=is_neox_style,
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,

+ 2 - 2
aphrodite/modeling/models/mixtral_quant.py

@@ -34,7 +34,7 @@ from torch import nn
 from transformers import MixtralConfig
 
 from aphrodite.modeling.metadata import InputMetadata
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -259,7 +259,7 @@ class MixtralAttention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=True,
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,

+ 6 - 6
aphrodite/modeling/models/mpt.py

@@ -8,7 +8,7 @@ import torch.nn as nn
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@@ -105,11 +105,11 @@ class MPTAttention(nn.Module):
 
         self.head_dim = self.d_model // self.total_num_heads
         scaling = self.head_dim**-0.5
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   scaling,
-                                   alibi_slopes=alibi_slopes,
-                                   num_kv_heads=self.num_kv_heads)
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scaling,
+                              alibi_slopes=alibi_slopes,
+                              num_kv_heads=self.num_kv_heads)
 
     def forward(
         self,

+ 79 - 0
aphrodite/modeling/models/neuron/llama.py

@@ -0,0 +1,79 @@
+"""Inference-only LLaMA model compatible with HuggingFace weights."""
+import os
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.common.sequence import SamplerOutput
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class LlamaForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        linear_method=None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = None
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        with torch.inference_mode():
+            block_size = self.model.context_buckets[-1]
+            if input_metadata.is_prompt:
+                seq_ids = input_metadata.slot_mapping[:, 0] // block_size
+            else:
+                seq_ids = input_metadata.block_tables
+            logits = self.model(input_ids,
+                                cache_ids=positions,
+                                start_ids=seq_ids.flatten())
+        return logits
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.model.chkpt_model.lm_head,
+                                   hidden_states, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None,
+                     **kwargs):
+        from transformers_neuronx.llama.model import LlamaForSampling
+
+        split_model_dir = f"{model_name_or_path}-split"
+        if os.path.isdir(os.path.join(model_name_or_path,
+                                      "pytorch_model.bin")):
+            split_model_dir = model_name_or_path
+        elif not os.path.exists(f"{model_name_or_path}-split"):
+            from transformers.models.llama import LlamaForCausalLM
+            from transformers_neuronx.module import save_pretrained_split
+
+            hf_model = LlamaForCausalLM.from_pretrained(model_name_or_path,
+                                                        low_cpu_mem_usage=True)
+            save_pretrained_split(hf_model, f"{model_name_or_path}-split")
+
+        self.model = LlamaForSampling.from_pretrained(split_model_dir,
+                                                      **kwargs)
+        self.model.to_neuron()

+ 4 - 4
aphrodite/modeling/models/olmo.py

@@ -45,7 +45,7 @@ import torch.nn.functional as F
 from torch import nn
 
 from aphrodite.modeling.metadata import InputMetadata
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (
     ColumnParallelLinear,
     LinearMethodBase,
@@ -131,9 +131,9 @@ class OlmoAttention(nn.Module):
                 base=rope_theta,
             )
         self.scaling = self.head_dim**-0.5
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   scale=self.scaling)
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scale=self.scaling)
 
         # Attention output projection.
         self.attn_out = RowParallelLinear(

+ 4 - 4
aphrodite/modeling/models/opt.py

@@ -27,7 +27,7 @@ from transformers import OPTConfig
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (
     ColumnParallelLinear,
     LinearMethodBase,
@@ -114,9 +114,9 @@ class OPTAttention(nn.Module):
             bias=bias,
             linear_method=linear_method,
         )
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   scale=self.scaling)
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scale=self.scaling)
 
     def forward(
         self,

+ 2 - 2
aphrodite/modeling/models/phi.py

@@ -45,7 +45,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (
     ColumnParallelLinear,
     LinearMethodBase,
@@ -145,7 +145,7 @@ class PhiAttention(nn.Module):
             base=rope_theta,
             is_neox_style=is_neox_style,
         )
-        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
+        self.attn = Attention(self.num_heads, self.head_size, scaling)
 
     def forward(
         self,

+ 2 - 2
aphrodite/modeling/models/qwen.py

@@ -12,7 +12,7 @@ from torch import nn
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -142,7 +142,7 @@ class QWenAttention(nn.Module):
             rope_scaling=rope_scaling,
             is_neox_style=is_neox_style,
         )
-        self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)
+        self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
 
     def forward(
         self,

+ 2 - 2
aphrodite/modeling/models/qwen2.py

@@ -32,7 +32,7 @@ from transformers import Qwen2Config
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -193,7 +193,7 @@ class Qwen2Attention(nn.Module):
             max_position=max_position,
             base=self.rope_theta,
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,

+ 2 - 2
aphrodite/modeling/models/stablelm.py

@@ -28,7 +28,7 @@ from transformers import PretrainedConfig
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
     MergedColumnParallelLinear,
@@ -188,7 +188,7 @@ class StablelmAttention(nn.Module):
             max_position=self.config.max_position_embeddings,
             base=self.config.rope_theta,
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,

+ 70 - 0
aphrodite/modeling/neuron_loader.py

@@ -0,0 +1,70 @@
+"""Utilities for selecting and loading models."""
+from typing import Type
+
+import torch
+import torch.nn as nn
+from transformers import PretrainedConfig
+
+from aphrodite.common.config import ModelConfig, DeviceConfig
+from aphrodite.modeling.models import ModelRegistry
+
+TORCH_DTYPE_TO_NEURON_AMP = {
+    "auto": "f32",
+    "half": "f16",
+    "float16": "f16",
+    "bfloat16": "bf16",
+    "float": "f32",
+    "float32": "f32",
+    torch.float16: "f16",
+    torch.bfloat16: "bf16",
+    torch.float32: "f32",
+}
+
+
+def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        model_cls = ModelRegistry.load_model_cls(arch)
+        if model_cls is not None:
+            return model_cls
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {ModelRegistry.get_supported_archs()}")
+
+
+def get_model(model_config: ModelConfig, device_config: DeviceConfig,
+              **kwargs) -> nn.Module:
+    from transformers_neuronx.config import (
+        NeuronConfig,
+        ContinuousBatchingConfig,
+    )
+
+    parallel_config = kwargs.get("parallel_config")
+    scheduler_config = kwargs.get("scheduler_config")
+
+    model_class = _get_model_architecture(model_config.hf_config)
+    linear_method = None
+
+    # Create a model instance.
+    model = model_class(model_config.hf_config, linear_method)
+
+    continuous_batching_config = ContinuousBatchingConfig(
+        batch_size_for_shared_caches=scheduler_config.max_num_seqs)
+    neuron_config = NeuronConfig(
+        continuous_batching=continuous_batching_config)
+
+    # Load the weights from the cached or downloaded files.
+    model.load_weights(
+        model_config.model,
+        model_config.download_dir,
+        model_config.load_format,
+        model_config.revision,
+        tp_degree=parallel_config.neuron_tp_degree,
+        amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
+        neuron_config=neuron_config,
+        context_length_estimate=[scheduler_config.max_model_len],
+        n_positions=[scheduler_config.max_model_len],
+        batch_size=scheduler_config.max_num_seqs,
+    )
+
+    return model.eval()

+ 2 - 2
aphrodite/modeling/sampling_metadata.py

@@ -5,7 +5,7 @@ import torch
 
 from aphrodite.common.sampling_params import SamplingParams, SamplingType
 from aphrodite.common.sequence import SequenceData
-from aphrodite.common.utils import in_wsl
+from aphrodite.common.utils import in_wsl, is_neuron
 
 _SAMPLING_EPS = 1e-5
 
@@ -292,7 +292,7 @@ class SamplingTensors:
                    dtype: torch.dtype) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.
-        pin_memory = not in_wsl()
+        pin_memory = not in_wsl() and not is_neuron()
         prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
         prompt_padded_tokens = [
             tokens + [vocab_size] * (prompt_max_len - len(tokens))

+ 17 - 0
aphrodite/modeling/utils.py

@@ -1,10 +1,18 @@
 """Utils for model executor."""
 import random
+import importlib
 from typing import Any, Dict, Optional
 
 import numpy as np
 import torch
 
+from aphrodite.common.config import DeviceConfig, ModelConfig
+
+DEVICE_TO_MODEL_LOADER_MAP = {
+    "cuda": "loader",
+    "neuron": "neuron_loader",
+}
+
 
 def set_random_seed(seed: int) -> None:
     random.seed(seed)
@@ -33,3 +41,12 @@ def set_weight_attrs(
         assert not hasattr(
             weight, key), (f"Overwriting existing tensor attribute: {key}")
         setattr(weight, key, value)
+
+
+def get_model(model_config: ModelConfig, device_config: DeviceConfig,
+              **kwargs) -> torch.nn.Module:
+    model_loader_module = DEVICE_TO_MODEL_LOADER_MAP[device_config.device_type]
+    imported_model_loader = importlib.import_module(
+        f"aphrodite.modeling.{model_loader_module}")
+    get_model_fn = imported_model_loader.get_model
+    return get_model_fn(model_config, device_config, **kwargs)

+ 39 - 48
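Note: the new aphrodite/modeling/utils.py wrapper resolves the loader module from the device type and imports it lazily, so a CUDA run never imports Neuron-only dependencies and vice versa. A self-contained sketch of the same lazy-dispatch pattern, using hypothetical module names rather than the real ones above:

    import importlib

    # Hypothetical device-to-module map, mirroring DEVICE_TO_MODEL_LOADER_MAP.
    _LOADER_MODULES = {
        "cuda": "mypkg.gpu_loader",
        "neuron": "mypkg.neuron_loader",
    }

    def load_model(device_type: str, *args, **kwargs):
        # Import the device-specific loader only when it is actually needed.
        loader = importlib.import_module(_LOADER_MODULES[device_type])
        return loader.get_model(*args, **kwargs)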
aphrodite/processing/block_manager.py

@@ -1,7 +1,6 @@
 """A block manager that manages token blocks."""
-
 import enum
-from itertools import count
+from itertools import count, takewhile
 from os.path import commonprefix
 from typing import Dict, List, Optional, Set, Tuple
 
@@ -19,14 +18,12 @@ class BlockAllocator:
     the reference count becomes zero, the block is added back to the free list.
     """
 
-    def __init__(
-        self,
-        device: Device,
-        block_size: int,
-        num_blocks: int,
-        eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
-        enable_caching: bool = False,
-    ) -> None:
+    def __init__(self,
+                 device: Device,
+                 block_size: int,
+                 num_blocks: int,
+                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
+                 enable_caching: bool = False) -> None:
         self.device = device
         self.block_size = block_size
         self.num_blocks = num_blocks
@@ -49,13 +46,11 @@ class BlockAllocator:
             block.block_hash = block_hash
             block.num_hashed_tokens = num_hashed_tokens
             return block
-        block = PhysicalTokenBlock(
-            device=self.device,
-            block_number=self.current_num_blocks,
-            block_size=self.block_size,
-            block_hash=block_hash,
-            num_hashed_tokens=num_hashed_tokens,
-        )
+        block = PhysicalTokenBlock(device=self.device,
+                                   block_number=self.current_num_blocks,
+                                   block_size=self.block_size,
+                                   block_hash=block_hash,
+                                   num_hashed_tokens=num_hashed_tokens)
         self.current_num_blocks += 1
         return block
 
@@ -126,7 +121,6 @@ class AllocStatus(enum.Enum):
     3. Never: seq_group can never be allocated.
       The seq_group is too large to be allocated in GPU.
     """
-
     OK = enum.auto()
     LATER = enum.auto()
     NEVER = enum.auto()
@@ -150,10 +144,8 @@ class BlockSpaceManager:
 
         self.block_sliding_window = None
         if sliding_window is not None:
-            assert sliding_window % block_size == 0, (
-                sliding_window,
-                block_size,
-            )
+            assert sliding_window % block_size == 0, (sliding_window,
+                                                      block_size)
             self.block_sliding_window = sliding_window // block_size
 
         self.watermark = watermark
@@ -162,23 +154,19 @@ class BlockSpaceManager:
         self.enable_caching = enable_caching
 
         self.watermark_blocks = int(watermark * num_gpu_blocks)
-        self.gpu_allocator = BlockAllocator(
-            Device.GPU,
-            block_size,
-            num_gpu_blocks,
-            enable_caching=enable_caching,
-        )
-        self.cpu_allocator = BlockAllocator(
-            Device.CPU,
-            block_size,
-            num_cpu_blocks,
-            enable_caching=enable_caching,
-        )
+        self.gpu_allocator = BlockAllocator(Device.GPU,
+                                            block_size,
+                                            num_gpu_blocks,
+                                            enable_caching=enable_caching)
+        self.cpu_allocator = BlockAllocator(Device.CPU,
+                                            block_size,
+                                            num_cpu_blocks,
+                                            enable_caching=enable_caching)
         # Mapping: seq_id -> BlockTable.
         self.block_tables: Dict[int, BlockTable] = {}
 
     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
-        # FIXME(woosuk): Here we assume that all sequences in the group share
+        # FIXME: Here we assume that all sequences in the group share
         # the same prompt. This may not be true for preempted sequences.
         seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
         num_required_blocks = len(seq.logical_token_blocks)
@@ -213,8 +201,7 @@ class BlockSpaceManager:
             else:
                 block = self.gpu_allocator.allocate(
                     seq.hash_of_block(logical_idx),
-                    seq.num_hashed_tokens_of_block(logical_idx),
-                )
+                    seq.num_hashed_tokens_of_block(logical_idx))
             block_table.append(block)
 
         # Assign the block table for each sequence.
@@ -444,23 +431,29 @@ class BlockSpaceManager:
         for block in block_table:
             block.last_accessed = access_time
 
-    def compute_last_full_block_in_seq(self, seq: Sequence):
+    def compute_full_blocks_in_seq(self, seq: Sequence):
         if seq.seq_id not in self.block_tables:
             return
         max_full_block = seq.get_len() // self.block_size - 1
         block_table = self.block_tables[seq.seq_id]
         if max_full_block == -1:
             return
-        block_table[max_full_block].computed = True
+        for i in reversed(range(max_full_block)):
+            if block_table[i].computed:
+                break
+            block_table[i].computed = True
 
-    def get_all_block_ids_till_computed(self, seq: Sequence) -> List[int]:
+    def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
         if seq.seq_id not in self.block_tables:
             return []
         block_table = self.block_tables[seq.seq_id]
-        for block_idx in reversed(range(len(block_table))):
-            if block_table[block_idx].computed:
-                return [b.block_number for b in block_table[:block_idx + 1]]
-        return []
+        # NOTE We exclude the last block to avoid the case where the entire
+        # prompt is cached. This would cause erroneous behavior in model
+        # runner.
+        return [
+            b.block_number
+            for b in takewhile(lambda b: b.computed, block_table[:-1])
+        ]
 
     def get_common_computed_block_ids(self,
                                       seq_group: SequenceGroup) -> List[int]:
@@ -469,14 +462,12 @@ class BlockSpaceManager:
             return []
 
         ids_list = [
-            self.get_all_block_ids_till_computed(seq)
+            self.get_all_computed_blocks(seq)
             for seq in iter(seq_group.seqs_dict.values())
         ]
         return commonprefix([ids for ids in ids_list if ids != []])
 
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
-        # NOTE: We only mark the last full block because with prefix caching,
-        # all blocks until the marked one are guaranteed to be computed.
         if self.enable_caching:
             for seq in seq_group.seqs_dict.values():
-                self.compute_last_full_block_in_seq(seq)
+                self.compute_full_blocks_in_seq(seq)

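Note: with the block-manager change above, every full block of a sequence can now be marked as computed, and the cached prefix is collected as the leading run of computed blocks, always excluding the table's last block so a fully cached prompt still leaves work for the model runner. A small standalone illustration of the takewhile behaviour, using a stand-in block type:

    from dataclasses import dataclass
    from itertools import takewhile

    @dataclass
    class Block:  # stand-in for PhysicalTokenBlock
        block_number: int
        computed: bool

    block_table = [Block(0, True), Block(1, True), Block(2, False), Block(3, True)]

    # Leading computed blocks, excluding the last block in the table.
    cached = [
        b.block_number for b in takewhile(lambda b: b.computed, block_table[:-1])
    ]
    print(cached)  # [0, 1] -- stops at the first non-computed block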
+ 2 - 3
aphrodite/processing/evictor.py

@@ -7,8 +7,9 @@ from aphrodite.common.block import PhysicalTokenBlock
 
 class EvictionPolicy(enum.Enum):
     """Enum for eviction policy used by make_evictor to instantiate the correct
-       Evictor subclass.
+    Evictor subclass.
     """
+
     LRU = enum.auto()
     FIFO = enum.auto()
 
@@ -115,7 +116,6 @@ class LRUEvictor(Evictor):
         return block
 
     @property
-    # pylint: disable=invalid-overridden-method
     def num_blocks(self) -> int:
         return len(self.free_table)
 
@@ -149,7 +149,6 @@ class RandomEvictor(Evictor):
         return block
 
     @property
-    # pylint: disable=invalid-overridden-method
     def num_blocks(self) -> int:
         return len(self.free_table)
 

+ 1 - 4
aphrodite/processing/scheduler.py

@@ -65,10 +65,7 @@ class SchedulerOutputs:
     def _sort_by_lora_ids(self) -> bool:
         self.scheduled_seq_groups = sorted(
             self.scheduled_seq_groups,
-            key=lambda g: (
-                g.lora_request.lora_int_id if g.lora_request else 0,
-                g.request_id,
-            ),
+            key=lambda g: (g.lora_int_id, g.request_id),
         )
 
     @property

+ 398 - 0
aphrodite/spec_decode/batch_expansion.py

@@ -0,0 +1,398 @@
+from typing import Iterator, List, Tuple, Optional, Dict
+from itertools import chain, count
+
+import torch
+
+from aphrodite.common.sequence import (
+    SamplerOutput,
+    SequenceGroupMetadata,
+    SequenceData,
+)
+from aphrodite.task_handler.worker import Worker
+from aphrodite.spec_decode.util import (
+    nvtx_range,
+    sampler_output_to_torch,
+    get_all_seq_ids,
+    split_batch_by_proposal_len,
+)
+from aphrodite.spec_decode.interfaces import (
+    SpeculativeScorer,
+    SpeculativeProposals,
+    SpeculativeScores,
+)
+
+SeqId = int
+TargetSeqId = int
+TokenId = int
+
+
+class BatchExpansionTop1Scorer(SpeculativeScorer):
+    """Implements a speculative scorer that uses batch expansion to get
+    probabilities of speculative tokens according to the scoring model.
+
+    Batch expansion converts a list of sequences and multiple query positions
+    to a new batch of sequences, each with a single query position. This allows
+    for MQA-like scoring in speculative decoding without requiring an MQA
+    kernel.
+
+    It is strictly less efficient than MQA scoring.
+
+    It only supports scoring the top1 proposal tokens of the proposer, instead
+    of topk/tree.
+    """
+
+    def __init__(self, scorer_worker: Worker, device: str, vocab_size: int):
+        self._scorer_worker = scorer_worker
+        self._device = device
+        self._vocab_size = vocab_size
+
+    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
+    def score_proposals(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Optional[Dict[int, int]],
+        blocks_to_swap_out: Optional[Dict[int, int]],
+        blocks_to_copy: Optional[Dict[int, List[int]]],
+        k: int,
+        proposals: SpeculativeProposals,
+    ) -> SpeculativeScores:
+        """Score the proposed tokens via the scorer model.
+
+        This converts each input sequence to a set of k+1 target sequences. The
+        target sequences have the unique continuations to be scored and a
+        unique sequence ID that is different from all input sequence ids.
+
+        If a speculative sequence length would exceed the max model length, then
+        no speculation is produced for that sequence.
+
+        Args:
+            seq_group_metadata_list: The input sequence group metadata.
+            blocks_to_swap_in: This is passed to the worker during scoring.
+            blocks_to_swap_out: This is passed to the worker during scoring.
+            blocks_to_copy: This is passed to the worker during scoring.
+            k: The fixed proposal length.
+            proposals: The speculative proposals to score.
+        Returns:
+            SpeculativeScores: The scores of each speculative token, along with
+                which sequences were ignored during scoring.
+        """
+
+        # TODO: perform this on GPU to remove blocking call.
+        proposal_lens_list = proposals.proposal_lens.tolist()
+        proposal_token_ids_list = proposals.proposal_token_ids.tolist()
+
+        (
+            spec_indices,
+            non_spec_indices,
+            target_seq_group_metadata_list,
+            num_scoring_tokens,
+        ) = self._expand_batch(
+            seq_group_metadata_list=seq_group_metadata_list,
+            proposal_token_ids_list=proposal_token_ids_list,
+            proposal_lens_list=proposal_lens_list,
+        )
+
+        target_sampler_output = self._scorer_worker.execute_model(
+            seq_group_metadata_list=target_seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+            return_python_output=False,
+        )
+
+        all_tokens, all_probs = self._contract_batch(
+            original_bs=len(seq_group_metadata_list),
+            target_sampler_output=target_sampler_output,
+            proposals=proposals,
+            num_scoring_tokens=num_scoring_tokens,
+            non_spec_indices=non_spec_indices,
+            spec_indices=spec_indices,
+            k=k,
+        )
+
+        return SpeculativeScores(
+            probs=all_probs,
+            token_ids=all_tokens,
+        )
+
+    def _expand_batch(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        proposal_token_ids_list: List[TokenId],
+        proposal_lens_list: List[int],
+    ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
+        """Given the input sequences and potentially multiple corresponding
+        proposal tokens, create a new batch where each sequence has a single
+        query token.
+        """
+
+        # Aphrodite currently only supports proposal lens equal to zero or the
+        # batch proposal len. This adds some complexity (splitting the batch
+        # into spec and non spec sequences) and should be removed in the
+        # future. It can be done by supporting per-sequence proposal lens.
+        spec_seqs, spec_indices = split_batch_by_proposal_len(
+            seq_group_metadata_list,
+            proposal_lens_list,
+            select_proposal_len_zero=False,
+        )
+        non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
+            seq_group_metadata_list,
+            proposal_lens_list,
+            select_proposal_len_zero=True,
+        )
+
+        target_seq_group_metadata_list = self._create_scoring_model_input(
+            spec_seqs, proposal_token_ids_list)
+        num_scoring_tokens = len(target_seq_group_metadata_list)
+        target_seq_group_metadata_list.extend(non_spec_seqs)
+
+        return (
+            spec_indices,
+            non_spec_indices,
+            target_seq_group_metadata_list,
+            num_scoring_tokens,
+        )
+
+    def _contract_batch(
+        self,
+        original_bs: int,
+        target_sampler_output: List[SamplerOutput],
+        proposals: SpeculativeProposals,
+        num_scoring_tokens: int,
+        non_spec_indices: List[int],
+        spec_indices: List[int],
+        k: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Contract the expanded batch back into its original size.
+        This maps the scores of speculative tokens back to their original
+        sequences.
+        """
+        (
+            target_token_ids,
+            target_probs,
+            non_spec_target_token_ids,
+            non_spec_target_probs,
+        ) = self._split_scoring_output(target_sampler_output,
+                                       num_scoring_tokens)
+
+        # Map distinct sequences used to score each token
+        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
+        batch_size, k = proposals.proposal_token_ids.shape
+
+        target_token_ids = target_token_ids.squeeze().reshape(
+            batch_size, k + 1)
+        target_probs = target_probs.squeeze().reshape(batch_size, k + 1,
+                                                      self._vocab_size)
+
+        all_tokens = torch.full(
+            size=(original_bs, k + 1),
+            fill_value=-1,
+            device=self._device,
+            dtype=torch.long,
+        )
+        all_probs = torch.zeros(
+            original_bs,
+            k + 1,
+            self._vocab_size,
+            device=self._device,
+            dtype=torch.float32,
+        )
+
+        if non_spec_indices:
+            all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
+            all_probs[non_spec_indices, :1, :] = non_spec_target_probs
+
+        if spec_indices:
+            all_tokens[spec_indices] = target_token_ids
+            all_probs[spec_indices] = target_probs
+
+        return all_tokens, all_probs
+
+    def _create_scoring_model_input(
+            self,
+            seq_group_metadata_list: List[SequenceGroupMetadata],
+            proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
+    ) -> List[SequenceGroupMetadata]:
+        """Given the original input sequences and proposed tokens from the draft
+        model, create a list of target sequences that can be used for scoring.
+        """
+
+        if not seq_group_metadata_list:
+            return []
+
+        target_seq_ids_iter = self._create_target_seq_id_iterator(
+            get_all_seq_ids(seq_group_metadata_list))
+
+        target_seq_group_metadata = list(
+            chain.from_iterable(
+                self._create_target_seq_group_metadata(
+                    seq_group_metadata,
+                    proposal_token_ids,
+                    i,
+                    target_seq_ids_iter,
+                ) for i, seq_group_metadata in enumerate(
+                    seq_group_metadata_list)))
+
+        return target_seq_group_metadata
+
+    def _create_target_seq_group_metadata(
+        self,
+        input_seq_group_metadata: SequenceGroupMetadata,
+        proposal_token_ids: List[TokenId],  # shape: [batch_size, k]
+        batch_index: int,
+        target_seq_ids_iter: Iterator[TargetSeqId],
+    ) -> List[SequenceGroupMetadata]:
+        """Given an input sequence group metadata and a list of draft tokens,
+        create a list of target SequenceGroupMetadata, one for each
+        token id that needs to be scored.
+
+        Naive speculative decoding requires K target model scores, one for each
+        draft model token. However one can add a bonus token such that if each
+        token is accepted, then a final token may be sampled from the model.
+        This function creates K+1 target SequenceGroupMetadata to take
+        advantage of the bonus token.
+        """
+        assert not input_seq_group_metadata.is_prompt, (
+            "Speculating on "
+            "prompts not yet supported")
+        assert len(input_seq_group_metadata.seq_data) == 1, (
+            "Beam search "
+            "not supported in speculative decoding")
+        input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))
+
+        token_ids_to_score = self._get_token_ids_to_score(
+            proposal_token_ids[batch_index])
+
+        target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
+        for token_ids in token_ids_to_score:
+            target_seq_group_metadata_list.append(
+                self._create_single_target_seq_group_metadata(
+                    input_seq_group_metadata,
+                    input_seq_id,
+                    next(target_seq_ids_iter),
+                    token_ids,
+                ))
+
+        return target_seq_group_metadata_list
+
+    def _create_single_target_seq_group_metadata(
+        self,
+        seq_group_metadata: SequenceGroupMetadata,
+        seq_id: SeqId,
+        target_seq_id: TargetSeqId,
+        token_ids: List[TokenId],
+    ) -> SequenceGroupMetadata:
+        """Create a single target SequenceGroupMetadata.
+
+        Args:
+            seq_group_metadata: The metadata for the input sequence.
+            seq_id: The input sequence ID.
+            target_seq_id: The corresponding target sequence ID.
+            token_ids: The list of token ids that are to be appended to the
+                input sequence.
+        """
+        seq_data = seq_group_metadata.seq_data[seq_id]
+        prompt_token_ids = seq_data.get_prompt_token_ids()
+        new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
+
+        return SequenceGroupMetadata(
+            request_id=seq_group_metadata.request_id,
+            is_prompt=seq_group_metadata.is_prompt,
+            seq_data={
+                target_seq_id:
+                SequenceData(
+                    prompt_token_ids=prompt_token_ids,
+                    output_token_ids=new_output_token_ids,
+                ),
+            },
+            sampling_params=seq_group_metadata.sampling_params,
+            block_tables={
+                target_seq_id: seq_group_metadata.block_tables[seq_id],
+            },
+            lora_request=None,
+            persistent_data={},
+        )
+
+    def _split_scoring_output(
+        self, sampler_output: SamplerOutput, num_scoring_tokens: int
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Split the target model output into speculative and non-speculative
+        output.
+        """
+
+        # Aphrodite currently only supports proposal lens equal to zero or the
+        # batch proposal len. This adds some complexity (splitting the batch
+        # into spec and non spec sequences) and should be removed in the
+        # future. It can be done by supporting per-sequence proposal lens.
+        # First samples are from speculative scoring, latter samples are non-
+        # speculative samples.
+        split_sizes = [
+            num_scoring_tokens,
+            sampler_output.sampled_token_ids.numel() - num_scoring_tokens,
+        ]
+        (spec_probs, non_spec_probs
+         ) = sampler_output.sampled_token_probs.split(split_sizes)
+        (
+            spec_sampled_tokens,
+            non_spec_sampled_tokens,
+        ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
+
+        # Convert scores to tensors.
+        sampler_output.sampled_token_probs = spec_probs
+        sampler_output.sampled_token_ids = spec_sampled_tokens
+        target_token_ids, target_probs = sampler_output_to_torch(
+            [sampler_output])
+
+        # Convert non-speculative output tokens to tensors.
+        sampler_output.sampled_token_probs = non_spec_probs
+        sampler_output.sampled_token_ids = non_spec_sampled_tokens
+        (
+            non_spec_target_token_ids,
+            non_spec_target_probs,
+        ) = sampler_output_to_torch([sampler_output])
+
+        return (
+            target_token_ids,
+            target_probs,
+            non_spec_target_token_ids,
+            non_spec_target_probs,
+        )
+
+    def _create_target_seq_id_iterator(
+            self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
+        """Create an iterator for creating target sequence ids.
+        Target sequence ids are distinct from sequence ids because we create a
+        distinct target sequence id for each proposal token to be scored.
+
+        This implementation increments a counter starting at 1 + max of all
+        provided input sequence ids.
+        """
+        return count(start=max(seq_ids) + 1)
+
+    def _get_token_ids_to_score(
+            self,
+            full_spec_token_ids: List[TokenId],  # shape: [k]
+    ) -> List[List[TokenId]]:
+        """Given an int tensor of proposal token ids, return a list of
+        token ids that should be scored.
+
+        Returns k+1 output lists. The additional one is used for generating the
+        bonus token.
+
+        Example:
+            Input: [0, 1, 2, 3] (k=4)
+            Output: (k+1 lists)
+                []
+                [0]
+                [0, 1]
+                [0, 1, 2]
+                [0, 1, 2, 3]
+        """
+        empty_token_ids = []
+
+        token_ids_to_score = [empty_token_ids]
+        token_ids_to_score.extend([
+            full_spec_token_ids[:i + 1]
+            for i in range(len(full_spec_token_ids))
+        ])
+        return token_ids_to_score

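Note: batch expansion turns each speculative sequence with k proposal tokens into k + 1 single-query target sequences (one per proposal prefix, plus one for the bonus token), so a batch of B speculative sequences produces B * (k + 1) scoring rows that _contract_batch reshapes back to [B, k + 1]. A toy shape check of that contraction step, assuming only standard PyTorch:

    import torch

    B, k, vocab = 2, 3, 11
    # Pretend the scorer returned one row per expanded target sequence.
    flat_token_ids = torch.arange(B * (k + 1))
    flat_probs = torch.rand(B * (k + 1), vocab)

    # The contraction is essentially a reshape back to the original batch layout.
    token_ids = flat_token_ids.reshape(B, k + 1)
    probs = flat_probs.reshape(B, k + 1, vocab)
    print(token_ids.shape, probs.shape)  # torch.Size([2, 4]) torch.Size([2, 4, 11])

The prefix lists built by _get_token_ids_to_score follow the same k + 1 pattern: for proposals [0, 1, 2] they are [], [0], [0, 1], [0, 1, 2].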
+ 77 - 0
aphrodite/spec_decode/interfaces.py

@@ -0,0 +1,77 @@
+from typing import List, Tuple, Optional, Dict
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+
+import torch
+
+from aphrodite.common.sequence import SequenceGroupMetadata
+
+
+@dataclass
+class SpeculativeProposals:
+    """Datastructure used to represent proposal tokens from some proposer. It
+    also tracks how many speculative tokens each sequence has.
+    """
+
+    # Speculative proposal tokens.
+    proposal_token_ids: torch.Tensor
+
+    # Probabilities of the proposal tokens according to the proposer.
+    proposal_probs: torch.Tensor
+
+    # The valid length of each proposal; can be zero.
+    proposal_lens: torch.Tensor
+
+    def __repr__(self):
+        return (f"SpeculativeProposals("
+                f"proposal_token_ids={self.proposal_token_ids.shape}, "
+                f"proposal_probs={self.proposal_probs.shape}, "
+                f"proposal_lens={self.proposal_lens.shape})")
+
+
+@dataclass
+class SpeculativeScores:
+    """Datastructure used to represent the scores of speculative tokens
+    according to the scoring model.
+    """
+
+    # Probabilities of the speculative tokens according to the scoring model.
+    probs: torch.Tensor
+
+    # Token ids sampled from the scoring model. Used for speculative bonus
+    # tokens and also non-speculative normal decoding.
+    token_ids: torch.Tensor
+
+    def __repr__(self):
+        return (f"SpeculativeScores("
+                f"probs={self.probs.shape}, "
+                f"token_ids={self.token_ids.shape})")
+
+
+class SpeculativeProposer(ABC):
+
+    @abstractmethod
+    def get_proposals(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+        max_proposal_len: int,
+    ) -> SpeculativeProposals:
+        raise NotImplementedError
+
+
+class SpeculativeScorer(ABC):
+
+    @abstractmethod
+    def score_proposals(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Optional[Dict[int, int]],
+        blocks_to_swap_out: Optional[Dict[int, int]],
+        blocks_to_copy: Optional[Dict[int, List[int]]],
+        k: int,
+        proposals: SpeculativeProposals,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError

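Note: the two dataclasses above carry batched tensors; proposal_token_ids is indexed as [batch_size, k] by the scorer, while the probability tensor layout below is an assumption for illustration only. A minimal construction sketch:

    import torch
    from aphrodite.spec_decode.interfaces import SpeculativeProposals

    B, k, vocab = 2, 3, 11
    proposals = SpeculativeProposals(
        proposal_token_ids=torch.zeros(B, k, dtype=torch.long),
        proposal_probs=torch.zeros(B, k, vocab),  # assumed [B, k, vocab] layout
        proposal_lens=torch.tensor([k, 0]),  # a sequence may skip speculation (len 0)
    )
    print(proposals)  # the __repr__ reports the tensor shapes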
+ 175 - 0
aphrodite/spec_decode/metrics.py

@@ -0,0 +1,175 @@
+import torch
+from dataclasses import dataclass
+from typing import Optional
+import time
+from typing import Callable
+
+from aphrodite.modeling.layers.rejection import RejectionSampler
+from aphrodite.common.utils import in_wsl
+
+
+@dataclass
+class SpecDecodeWorkerMetrics:
+    """Dataclass holding metrics emitted from the spec decode worker."""
+
+    # The empirical acceptance rate of the proposal method on a per-token basis.
+    # This is useful for evaluating how well the proposal method aligns with the
+    # scoring method.
+    draft_acceptance_rate: float
+
+    # The empirical efficiency, measured as the number of tokens emitted by the
+    # system divided by the number of tokens that could be emitted by the system
+    # if the proposal method were perfect.
+    system_efficiency: float
+
+    # The number of speculative tokens produced by the proposal method.
+    draft_tokens: int
+
+    # The number of tokens emitted by the entire system.
+    emitted_tokens: int
+
+    # The number of tokens accepted by the scoring model and verification
+    # routine, e.g. Llama2-70B and lossless rejection sampling.
+    #
+    # NOTE: Any token accepted by the verification routine is considered
+    # accepted (regardless of whether the speculative prefix is also accepted).
+    # The user will usually see fewer accepted tokens. This metric is helpful
+    # when evaluating alignment of the proposal method with the scoring model.
+    accepted_tokens: int
+
+    # The number of speculative tokens per sequence.
+    num_spec_tokens: int
+
+
+Timer = Callable[[], float]
+
+
+class AsyncMetricsCollector:
+    """Class which copies rejection sampler metrics from the device to CPU on a
+    non-default Torch stream.
+    """
+
+    def __init__(
+        self,
+        rejection_sampler: RejectionSampler,
+        timer: Optional[Timer] = None,
+        collect_interval_s: float = 5.0,
+    ):
+        self._rejection_sampler = rejection_sampler
+        self._timer = time.time if timer is None else timer
+
+        self._rank: Optional[int] = None
+
+        # We don't have a device set yet.
+        self._copy_stream: Optional[torch.cuda.Stream] = None
+
+        self._in_flight_copy: Optional[torch.cuda.Event] = None
+
+        pin_memory = not in_wsl()
+        self._aggregate_num_accepted_tokens = torch.tensor(
+            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
+        self._aggregate_num_emitted_tokens = torch.tensor(
+            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
+        self._aggregate_num_draft_tokens = 0
+
+        self._rejsample_metrics_collect_interval_s = collect_interval_s
+        self._last_metrics_collect_time = self._timer()
+
+    def init_gpu_tensors(self, rank: int) -> None:
+        self._rank = rank
+        self._copy_stream = torch.cuda.Stream()
+
+    def maybe_collect_rejsample_metrics(
+            self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
+        # If a copy was initiated in the previous call, collect and return.
+        if self._in_flight_copy is not None:
+            ready_event = self._in_flight_copy
+            self._in_flight_copy = None
+            return self._collect_rejsample_metrics(k, ready_event)
+
+        # Otherwise, check if we should start a new copy.
+        if self._should_collect_rejsample_metrics(self._timer()):
+            assert self._in_flight_copy is None
+            self._in_flight_copy = self._copy_rejsample_metrics_async()
+
+        return None
+
+    def _should_collect_rejsample_metrics(self, now: float) -> bool:
+        """Return whether or not this iteration should print rejection sampling
+        metrics.
+        """
+        if self._rank != 0:
+            return False
+
+        if (now - self._last_metrics_collect_time <
+                self._rejsample_metrics_collect_interval_s):
+            return False
+        return True
+
+    def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
+        """Copy rejection sampling metrics (number of accepted tokens, etc) to
+        CPU asynchronously.
+
+        Returns a CUDA event recording when the copy is complete.
+        """
+        self._copy_stream.wait_stream(torch.cuda.current_stream())
+
+        with torch.cuda.stream(self._copy_stream):
+            self._aggregate_num_accepted_tokens.copy_(
+                self._rejection_sampler.num_accepted_tokens, non_blocking=True)
+            self._aggregate_num_emitted_tokens.copy_(
+                self._rejection_sampler.num_emitted_tokens, non_blocking=True)
+            # Number of draft tokens is calculated on CPU, so no copy is
+            # required.
+            self._aggregate_num_draft_tokens = (
+                self._rejection_sampler.num_draft_tokens)
+
+        aggregate_metrics_ready = torch.cuda.Event()
+        aggregate_metrics_ready.record(self._copy_stream)
+
+        return aggregate_metrics_ready
+
+    def _collect_rejsample_metrics(
+            self, k: int,
+            ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics:
+        """Create metrics object from statistics copied asynchronously.
+
+        Args:
+            k: int. The number of speculative tokens; used to determine system
+                efficiency.
+            ready_event: torch.cuda.Event. The CUDA event recording when the
+                async GPU->CPU copy is complete.
+        """
+
+        ready_event.synchronize()
+        accepted_tokens = self._aggregate_num_accepted_tokens.item()
+        emitted_tokens = self._aggregate_num_emitted_tokens.item()
+        draft_tokens = self._aggregate_num_draft_tokens
+
+        num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k)
+
+        if draft_tokens > 0:
+            draft_acceptance_rate = accepted_tokens / draft_tokens
+        else:
+            draft_acceptance_rate = float("nan")
+
+        if num_possible_tokens > 0:
+            system_efficiency = emitted_tokens / num_possible_tokens
+        else:
+            system_efficiency = float("nan")
+
+        return SpecDecodeWorkerMetrics(
+            num_spec_tokens=k,
+            draft_acceptance_rate=draft_acceptance_rate,
+            system_efficiency=system_efficiency,
+            accepted_tokens=accepted_tokens,
+            draft_tokens=draft_tokens,
+            emitted_tokens=emitted_tokens,
+        )
+
+    @staticmethod
+    def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int:
+        # Divide by k since batch size can be variable.
+        total_num_spec_seqs = draft_tokens / k
+        num_accepted_per_seq_if_all_accepted = k + 1
+        return int(total_num_spec_seqs * num_accepted_per_seq_if_all_accepted)
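For reference, the acceptance-rate and system-efficiency arithmetic used above can be reproduced on plain Python ints. A minimal sketch with made-up counts, not part of the diff:

def spec_decode_metrics(accepted_tokens: int, emitted_tokens: int,
                        draft_tokens: int, k: int):
    # draft_tokens / k speculative sequences; if every draft token were
    # accepted, each sequence would emit k accepted tokens plus one bonus
    # token from the target model, i.e. k + 1 tokens in total.
    max_emitted_tokens = int(draft_tokens / k * (k + 1))
    draft_acceptance_rate = accepted_tokens / draft_tokens
    system_efficiency = emitted_tokens / max_emitted_tokens
    return draft_acceptance_rate, system_efficiency

# e.g. k=4 and ten speculative sequences' worth of draft tokens:
print(spec_decode_metrics(accepted_tokens=30, emitted_tokens=38,
                          draft_tokens=40, k=4))  # (0.75, 0.76)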

+ 392 - 0
aphrodite/spec_decode/multi_step_worker.py

@@ -0,0 +1,392 @@
+from typing import List, Dict, Optional, Tuple
+import copy
+
+import torch
+
+from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.task_handler.worker import Worker
+from aphrodite.spec_decode.interfaces import (
+    SpeculativeProposals,
+    SpeculativeProposer,
+)
+from aphrodite.spec_decode.util import sampler_output_to_torch
+
+
+class MultiStepWorker(Worker):
+    """The MultiStepWorker is equivalent to a Worker except that it allows
+    multiple forward passes in a single call, assuming the scheduler has
+    allocated enough space to store the additional KV. This reduces overhead
+    by invoking the scheduler less.
+
+    The MultiStepWorker does not support cache swap operations or beam search.
+    Cache swap operations would require only small modifications, but beam
+    search requires memory allocations during sequence forks and thus needs
+    more careful treatment before it can be supported.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self._proposer: Optional[DraftModelTop1Proposer] = None
+
+    def init_model(self):
+        super().init_model()
+
+        self._proposer = DraftModelTop1Proposer(
+            self,
+            self.device,
+            self.max_model_len,
+            self.vocab_size,
+        )
+
+    @torch.inference_mode()
+    def execute_model_multi_step(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+        num_steps: int,
+    ) -> List[SamplerOutput]:
+        """Run the model forward pass num_steps times. Returns the list of
+        sampler output, one per model forward pass.
+        """
+        self._raise_if_unsupported(
+            seq_group_metadata_list,
+            blocks_to_swap_in,
+            blocks_to_swap_out,
+            blocks_to_copy,
+        )
+
+        # Shallow copy input data so modifications (such as appending tokens)
+        # do not cause side-effects.
+        copied_seq_group_metadata_list = self._shallow_copy_inputs(
+            seq_group_metadata_list)
+
+        # Assert enough KV space for num_steps tokens per sequence.
+        self._assert_enough_kv_space(seq_group_metadata_list, num_steps)
+
+        # Run model num_steps times.
+        model_outputs = []
+        for _ in range(num_steps):
+            model_output = super().execute_model(
+                seq_group_metadata_list=copied_seq_group_metadata_list,
+                blocks_to_swap_in=blocks_to_swap_in,
+                blocks_to_swap_out=blocks_to_swap_out,
+                blocks_to_copy=blocks_to_copy,
+            )
+
+            self._append_new_tokens(model_output,
+                                    copied_seq_group_metadata_list)
+            model_outputs.append(model_output)
+
+        return model_outputs
+
+    def get_spec_proposals(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+        max_proposal_len: int,
+    ) -> SpeculativeProposals:
+        """Produce speculations given an input batch of sequences. The number of
+        speculative tokens per sequence is determined by max_proposal_len.
+        """
+
+        return self._proposer.get_proposals(
+            seq_group_metadata_list,
+            blocks_to_swap_in,
+            blocks_to_swap_out,
+            blocks_to_copy,
+            max_proposal_len,
+        )
+
+    def _append_new_tokens(
+        self,
+        model_output: SamplerOutput,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> None:
+        """Given model output from a single run, append the tokens to the
+        sequences. This is normally done outside of the worker, but it is
+        required if the worker is to perform multiple forward passes.
+        """
+        for seq_group_metadata, sequence_group_outputs in zip(
+                seq_group_metadata_list, model_output):
+            seq_group_metadata.is_prompt = False
+
+            for seq_output in sequence_group_outputs.samples:
+                # NOTE: Beam search is not supported, so we can assume that
+                # parent_seq_id == seq_id.
+                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
+
+                token_id = seq_output.output_token
+                token_logprob = seq_output.logprobs[token_id]
+
+                seq.append_token_id(token_id, token_logprob.logprob)
+
+    def _shallow_copy_inputs(
+        self, seq_group_metadata_list: List[SequenceGroupMetadata]
+    ) -> List[SequenceGroupMetadata]:
+        """Copy input data structures to remove side-effects when input data
+        structures are shared with other modules.
+
+        Helpful when the Aphrodite scheduler runs in the same process as the
+        worker. The alternative, deep-copying, avoids sharing entirely but
+        has performance downsides.
+        """
+
+        # Shallow-copy the list of SequenceGroupMetadata. This allows us to
+        # append tokens and change is_prompt without external side-effects.
+        new_seq_group_metadata_list = []
+
+        for old_seq_group_metadata in seq_group_metadata_list:
+            # We must shallow-copy seq_group_metadata as is_prompt could change.
+            seq_group_metadata = copy.copy(old_seq_group_metadata)
+            new_seq_group_metadata_list.append(seq_group_metadata)
+
+            # We must shallow-copy seq_data as we will append token ids
+            new_seq_data = {}
+            for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
+                new_seq_data[seq_id] = copy.copy(old_seq_data)
+                new_seq_data[
+                    seq_id].output_token_ids = old_seq_data.output_token_ids[:]
+
+            seq_group_metadata.seq_data = new_seq_data
+
+        return new_seq_group_metadata_list
+
+    def _assert_enough_kv_space(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        num_steps: int,
+    ) -> None:
+        """Assert there are enough physical blocks per sequence to store the
+        current KV plus additional KV from num_steps tokens.
+        """
+        assert self.model_runner.block_size is not None
+        for seq_group_metadata in seq_group_metadata_list:
+            # Only one seq_id is guaranteed because there is no beam search.
+            seq_id = list(seq_group_metadata.seq_data.keys())[0]
+            seq = seq_group_metadata.seq_data[seq_id]
+
+            # After num_steps, the seq len will be the current seq len
+            # plus one token per step.
+            final_seq_len = seq.get_len() + num_steps
+
+            # We will have final_seq_len - 1 KV entries because Aphrodite saves
+            # the KV of a token in the iteration after that token was generated.
+            required_num_kv_slots = final_seq_len - 1
+
+            # The allocated number of kv slots is the number of allocated blocks
+            # times the number of slots per block.
+            number_physical_blocks = len(
+                seq_group_metadata.block_tables[seq_id])
+            allocated_kv_slots = (number_physical_blocks *
+                                  self.model_runner.block_size)
+
+            if required_num_kv_slots > allocated_kv_slots:
+                request_id = seq_group_metadata.request_id
+                raise ValueError(
+                    "The worker attempted to run "
+                    f"{num_steps} times but found insufficient KV space for "
+                    f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
+                    f"{required_num_kv_slots=}).")
+
+    def _raise_if_unsupported(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> None:
+        """MultiStepWorker does not yet implement support for cache swap
+        operations or beam search.
+        """
+        if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
+            raise NotImplementedError(
+                "MultiStepWorker does not support cache operations")
+
+        if any(
+                len(seq_group_metadata.seq_data.keys()) != 1
+                for seq_group_metadata in seq_group_metadata_list):
+            raise NotImplementedError(
+                "MultiStepWorker does not support beam search.")
+
+
+class DraftModelTop1Proposer(SpeculativeProposer):
+    """Helper class which separates out sequences which would exceed the max
+    model length when speculated upon.
+
+    This allows combinations of models such as JackFram/llama-68m draft with
+    meta-llama/Llama-2-13b-chat-hf, as llama-68m has max_position_embeddings
+    of 2048 while Llama-2-13b has max_position_embeddings of 4096.
+
+    We treat the sequences which exceed the proposal draft model length as
+    "non-spec sequences". Essentially they skip the draft model and go through
+    normal decoding in the target model.
+
+    Currently, only proposal_lens of 0 and k are supported, where k is a global
+    batch proposal length. In the future Aphrodite should support per-sequence
+    proposal lengths.
+    """
+
+    def __init__(
+        self,
+        draft_worker: MultiStepWorker,
+        device: str,
+        max_model_len: int,
+        vocab_size: int,
+    ):
+        self._draft_worker = draft_worker
+        self._device = device
+        self._max_model_len = max_model_len
+        self._vocab_size = vocab_size
+
+    def get_proposals(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+        max_proposal_len: int,
+    ) -> SpeculativeProposals:
+        """Get speculative proposals given the input batch.
+
+        Sequences which would exceed the max model length are skipped during
+        speculation.
+        """
+
+        # Split speculative and non-speculative sequences.
+        (
+            proposal_lens,
+            nonzero_proposal_len_seqs,
+            nonzero_proposal_len_indices,
+        ) = self._split_by_max_model_len(seq_group_metadata_list,
+                                         max_proposal_len)
+
+        if nonzero_proposal_len_seqs:
+            # Speculate tokens using the draft worker for the speculative
+            # sequences.
+            maybe_sampler_output = self._draft_worker.execute_model_multi_step(
+                seq_group_metadata_list=nonzero_proposal_len_seqs,
+                blocks_to_swap_in=blocks_to_swap_in,
+                blocks_to_swap_out=blocks_to_swap_out,
+                blocks_to_copy=blocks_to_copy,
+                num_steps=max_proposal_len,
+            )
+        else:
+            # If no sequences can be speculated, set sampler output to None.
+            maybe_sampler_output = None
+
+        # Combine speculative and non-speculative sequences into the same
+        # representation.
+        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
+            batch_size=len(seq_group_metadata_list),
+            max_proposal_len=max_proposal_len,
+            maybe_sampler_output=maybe_sampler_output,
+            proposal_lens=proposal_lens,
+            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
+        )
+
+        proposals = SpeculativeProposals(
+            proposal_token_ids=proposal_tokens,
+            proposal_probs=proposal_probs,
+            proposal_lens=proposal_lens,
+        )
+
+        return proposals
+
+    def _split_by_max_model_len(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        max_proposal_len: int,
+    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
+        """Determine which sequences would exceed the max model length."""
+
+        proposal_lens: List[int] = []
+        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
+        nonzero_proposal_len_indices: List[int] = []
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+            seq_data = next(iter(seq_group_metadata.seq_data.values()))
+            seq_len = seq_data.get_len()
+
+            # Currently only proposal lens of 0 or the global batch proposal len
+            # are supported.
+            if seq_len + max_proposal_len < self._max_model_len:
+                proposal_lens.append(max_proposal_len)
+                nonzero_proposal_len_seqs.append(seq_group_metadata)
+                nonzero_proposal_len_indices.append(i)
+            else:
+                proposal_lens.append(0)
+
+        return (
+            proposal_lens,
+            nonzero_proposal_len_seqs,
+            nonzero_proposal_len_indices,
+        )
+
+    def _merge_outputs(
+        self,
+        batch_size: int,
+        max_proposal_len: int,
+        maybe_sampler_output: Optional[SamplerOutput],
+        proposal_lens: List[int],
+        nonzero_proposal_len_indices: List[int],
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """After speculations are produced, merge the speculation results with
+        the skipped sequences.
+        """
+        if maybe_sampler_output is None:
+            # If no speculative tokens, the sampler output will be None.
+            # In this case we return empty tensors.
+            proposal_tokens = torch.zeros(0,
+                                          max_proposal_len,
+                                          dtype=torch.long,
+                                          device=self._device)
+            proposal_probs = torch.zeros(
+                0,
+                max_proposal_len,
+                self._vocab_size,
+                dtype=torch.float32,
+                device=self._device,
+            )
+            proposal_lens = torch.zeros(len(proposal_lens),
+                                        dtype=torch.long,
+                                        device=self._device)
+            return proposal_tokens, proposal_probs, proposal_lens
+
+        sampler_output = maybe_sampler_output
+
+        proposal_tokens, proposal_probs = sampler_output_to_torch(
+            sampler_output)
+
+        # Now, reformat the output GPU tensors such that each sequence has
+        # a proposal. The proposal can be empty, e.g. [-1, -1, -1].
+
+        entire_proposal_tokens = torch.full(
+            size=(batch_size, *proposal_tokens.shape[1:]),
+            fill_value=-1,
+            dtype=torch.long,
+            device=self._device,
+        )
+        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
+        entire_proposal_probs = torch.zeros(
+            batch_size,
+            *proposal_probs.shape[1:],
+            dtype=torch.float32,
+            device=self._device,
+        )
+        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
+
+        proposal_tokens, proposal_probs = (
+            entire_proposal_tokens,
+            entire_proposal_probs,
+        )
+
+        proposal_lens = torch.zeros(batch_size,
+                                    dtype=torch.long,
+                                    device=self._device)
+        proposal_lens[nonzero_proposal_len_indices] = max_proposal_len
+
+        return proposal_tokens, proposal_probs, proposal_lens
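As a rough illustration of the merge step in _merge_outputs, with made-up sizes (not the diff's code): sequences that skipped speculation get an all -1 dummy proposal and a proposal length of zero, while the speculated sequences are scattered back to their original batch positions.

import torch

batch_size, k, vocab_size = 4, 3, 8
# Suppose only sequences 0 and 2 were short enough to speculate on.
nonzero_proposal_len_indices = torch.tensor([0, 2])
proposal_tokens = torch.randint(0, vocab_size, (2, k))

# Non-speculated rows keep the -1 placeholder tokens.
entire_proposal_tokens = torch.full((batch_size, k), -1, dtype=torch.long)
entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens

proposal_lens = torch.zeros(batch_size, dtype=torch.long)
proposal_lens[nonzero_proposal_len_indices] = k

print(entire_proposal_tokens)  # rows 1 and 3 are all -1
print(proposal_lens)           # tensor([3, 0, 3, 0])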

+ 394 - 0
aphrodite/spec_decode/spec_decode_worker.py

@@ -0,0 +1,394 @@
+from typing import List, Tuple, Optional, Dict
+from functools import cached_property
+
+import torch
+
+from aphrodite.spec_decode.metrics import AsyncMetricsCollector
+from aphrodite.common.sequence import (
+    SamplerOutput,
+    SequenceGroupMetadata,
+    SequenceGroupOutput,
+    SequenceOutput,
+)
+from aphrodite.task_handler.worker import Worker
+from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
+from aphrodite.modeling.layers.rejection import RejectionSampler
+from aphrodite.common.config import CacheConfig
+from aphrodite.spec_decode.util import (
+    nvtx_range,
+    get_all_seq_ids,
+    split_batch_by_proposal_len,
+)
+from aphrodite.spec_decode.interfaces import (
+    SpeculativeProposals,
+    SpeculativeScorer,
+    SpeculativeScores,
+)
+from aphrodite.spec_decode.batch_expansion import BatchExpansionTop1Scorer
+
+
+class SpecDecodeWorker:
+    """Worker which implements speculative decoding.
+
+    Speculative decoding reduces decoding per-token latency by using a proposal
+    method, such as a small draft model, to speculate ahead of a larger LLM. The
+    probabilities of the speculative tokens are then determined by the larger
+    LLM, after which some verification routine determines which (if any) of the
+    speculative tokens are accepted by the larger LLM.
+
+    The current implementation has the following limitations:
+    * Only draft-model proposal is implemented (contributions for more forms are
+        welcome!).
+    * Only top-1 proposal and scoring are implemented. Tree-attention is left as
+        future work.
+    * Only lossless rejection sampling is supported. Contributions adding lossy
+        verification routines are welcome (e.g. Medusa's typical acceptance).
+    * All sequences in a batch must have the same proposal length, or zero. This
+        can be improved by having per-sequence speculation in the future.
+    * The scoring forward pass is done without an MQA kernel, which is
+        suboptimal especially as the batch size, proposal length, and sequence
+        lengths grow. Contributions adding an MQA scoring kernel are welcome once
+        correctness tests pass.
+    """
+
+    def __init__(
+        self,
+        proposer_worker: MultiStepWorker,
+        scorer_worker: Worker,
+        rejection_sampler: RejectionSampler,
+        metrics_collector: Optional[AsyncMetricsCollector] = None,
+    ):
+        """
+        Create a SpecDecodeWorker.
+
+        Args:
+            proposer_worker: A worker that can produce speculative tokens for
+                sequences.
+            scorer_worker: A worker that produces probabilities of speculative
+                tokens according to some base model. Typically a vanilla
+                Aphrodite Worker.
+            rejection_sampler: A Torch module used to perform modified rejection
+                sampling for speculative decoding.
+            metrics_collector: Helper class for collecting metrics; can be set
+                for testing purposes.
+        """
+        self.proposer_worker = proposer_worker
+        self.scorer_worker = scorer_worker
+        self.rejection_sampler = rejection_sampler
+
+        self._metrics = (AsyncMetricsCollector(rejection_sampler)
+                         if metrics_collector is None else metrics_collector)
+
+        self.probs_dtype = self.rejection_sampler.probs_dtype
+        self.token_id_dtype = self.rejection_sampler.token_id_dtype
+
+        self.scorer: Optional[SpeculativeScorer] = None
+
+    def init_model(self) -> None:
+        """Initialize both scorer and proposer models."""
+        # The scorer worker model is initialized first in case the proposer
+        # model has a smaller TP degree than the target worker.
+        self.scorer_worker.init_model()
+        self.proposer_worker.init_model()
+
+        self._metrics.init_gpu_tensors(self.rank)
+        self.rejection_sampler.init_gpu_tensors(self.rank)
+        self.scorer = BatchExpansionTop1Scorer(
+            scorer_worker=self.scorer_worker,
+            device=self.device,
+            vocab_size=self._vocab_size,
+        )
+
+    def profile_num_available_blocks(
+        self,
+        block_size: int,
+        gpu_memory_utilization: float,
+        cpu_swap_space: int,
+        cache_dtype: str,
+    ) -> Tuple[int, int]:
+        """Determine the number of cache blocks to use.
+
+        This is done by profiling the scorer model (which is typically the
+        larger of the two). The total memory that the scorer cache alone would
+        use is then split between the proposer and scorer KV caches such that
+        both caches hold the same number of blocks.
+        """
+        (
+            num_gpu_blocks,
+            num_cpu_blocks,
+        ) = self.scorer_worker.profile_num_available_blocks(
+            block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype)
+
+        scorer_cache_block_size_bytes = (
+            self.scorer_worker.get_cache_block_size_bytes(
+                block_size, cache_dtype))
+        proposer_cache_block_size_bytes = (
+            self.proposer_worker.get_cache_block_size_bytes(
+                block_size, cache_dtype))
+
+        new_num_gpu_blocks = split_num_cache_blocks_evenly(
+            scorer_cache_block_size_bytes,
+            proposer_cache_block_size_bytes,
+            num_gpu_blocks,
+        )
+        return new_num_gpu_blocks, num_cpu_blocks
+
+    def init_cache_engine(self, cache_config: CacheConfig):
+        """Initialize the cache engine of the scorer and proposer workers."""
+        self.scorer_worker.init_cache_engine(cache_config)
+        self.proposer_worker.init_cache_engine(cache_config)
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Optional[Dict[int, int]],
+        blocks_to_swap_out: Optional[Dict[int, int]],
+        blocks_to_copy: Optional[Dict[int, List[int]]],
+        num_spec_tokens: int,
+    ) -> List[SamplerOutput]:
+        """Perform speculative decoding on the input batch."""
+
+        assert seq_group_metadata_list is not None, (
+            "speculative decoding "
+            "requires non-None seq_group_metadata_list")
+
+        # If no spec tokens, call the proposer and scorer workers normally.
+        # Used for prefill.
+        if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0:
+            return self._run_no_spec(
+                seq_group_metadata_list=seq_group_metadata_list,
+                blocks_to_swap_in=blocks_to_swap_in,
+                blocks_to_swap_out=blocks_to_swap_out,
+                blocks_to_copy=blocks_to_copy,
+            )
+
+        return self._run_speculative_decoding_step(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+            k=num_spec_tokens,
+        )
+
+    @nvtx_range("spec_decode_worker._run_no_spec")
+    def _run_no_spec(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Optional[Dict[int, int]],
+        blocks_to_swap_out: Optional[Dict[int, int]],
+        blocks_to_copy: Optional[Dict[int, List[int]]],
+    ) -> List[SamplerOutput]:
+        """Run a prefill step, without any speculation. The input is sent to the
+        proposer and scorer model so that the KV cache is consistent between the
+        two.
+        """
+
+        self.proposer_worker.execute_model(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+            return_python_output=False,
+        )
+
+        sampler_output = self.scorer_worker.execute_model(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+        )
+
+        # Clear device tensors from sampler output. This reduces communication
+        # overhead when the engine runs in a different process than the workers.
+        sampler_output.probs = None
+        sampler_output.sampled_tokens = None
+        return [sampler_output]
+
+    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
+    def _run_speculative_decoding_step(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Optional[Dict[int, int]],
+        blocks_to_swap_out: Optional[Dict[int, int]],
+        blocks_to_copy: Optional[Dict[int, List[int]]],
+        k: int,
+    ) -> List[SamplerOutput]:
+        """Execute a single step of speculative decoding.
+
+        This invokes the proposer worker to get k speculative tokens for each
+        sequence, then scores each speculative token using the scoring worker.
+
+        Returns a list of SamplerOutput, each containing a single token per
+        sequence.
+        """
+
+        # Generate proposals using draft worker.
+        proposals = self.proposer_worker.get_spec_proposals(
+            seq_group_metadata_list,
+            blocks_to_swap_in,
+            blocks_to_swap_out,
+            blocks_to_copy,
+            k,
+        )
+
+        proposal_scores = self.scorer.score_proposals(
+            seq_group_metadata_list,
+            blocks_to_swap_in,
+            blocks_to_swap_out,
+            blocks_to_copy,
+            k,
+            proposals,
+        )
+
+        accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
+                                                 proposal_scores, proposals, k)
+
+        return self._create_output_sampler_list(seq_group_metadata_list,
+                                                accepted_token_ids, k)
+
+    @nvtx_range("spec_decode_worker._verify_tokens")
+    def _verify_tokens(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        proposal_scores: SpeculativeScores,
+        proposals: SpeculativeProposals,
+        max_proposal_len: int,
+    ) -> torch.Tensor:
+        """Determine which speculative tokens are accepted using the
+        probabilities of each token according to the proposer and scorer models.
+        """
+        proposal_lens_list = proposals.proposal_lens.tolist()
+
+        # Aphrodite currently only supports proposal lens equal to zero or the
+        # batch proposal len. This adds some complexity (splitting the batch
+        # into spec and non spec sequences) and should be removed in the
+        # future. It can be done by supporting per-sequence proposal lens.
+        _, spec_indices = split_batch_by_proposal_len(
+            seq_group_metadata_list,
+            proposal_lens_list,
+            select_proposal_len_zero=False,
+        )
+        _, non_spec_indices = split_batch_by_proposal_len(
+            seq_group_metadata_list,
+            proposal_lens_list,
+            select_proposal_len_zero=True,
+        )
+        original_indices = spec_indices + non_spec_indices
+
+        proposal_probs = proposal_scores.probs[spec_indices, :-1]
+        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
+        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
+
+        accepted_token_ids = self.rejection_sampler(
+            proposal_probs,
+            bonus_token_ids,
+            proposals.proposal_probs,
+            proposals.proposal_token_ids,
+        )
+
+        # Append output tokens from non-speculative sequences to
+        # the accepted token ids tensor.
+        non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
+                                                       1).clone()
+        non_spec_token_ids[:, 1:] = -1
+        accepted_token_ids = torch.cat(
+            [accepted_token_ids, non_spec_token_ids])
+
+        # Rearrange so that results are in the order of the original seq group
+        # metadata.
+        accepted_token_ids[original_indices] = accepted_token_ids.clone()
+
+        return accepted_token_ids
+
+    def _create_output_sampler_list(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
+        k: int,
+    ) -> List[SamplerOutput]:
+        """Given the accepted token ids, create a list of SamplerOutput.
+
+        The output is padded with -1 tokens such that each sequence has
+        the same number of outputs.
+        """
+        seq_ids = get_all_seq_ids(seq_group_metadata_list)
+
+        # shape: [k+1, batch_size]
+        accepted_token_ids_by_step = accepted_token_ids.transpose(0,
+                                                                  1).tolist()
+        sampler_output_list = []
+        for token_ids_by_step in accepted_token_ids_by_step:
+            if all(token_id == -1 for token_id in token_ids_by_step):
+                break
+
+            step_output_token_ids = []
+            for token_id, seq_id in zip(token_ids_by_step, seq_ids):
+                step_output_token_ids.append(
+                    SequenceGroupOutput(
+                        samples=[
+                            SequenceOutput(
+                                parent_seq_id=seq_id,
+                                output_token=token_id,
+                                # TODO Add verifier logprobs.
+                                logprobs={token_id: 0.0},
+                                persistent_data={},
+                            )
+                        ],
+                        prompt_logprobs=None,
+                    ))
+            sampler_output_list.append(
+                SamplerOutput(outputs=step_output_token_ids))
+
+        maybe_rejsample_metrics = self._metrics.maybe_collect_rejsample_metrics(
+            k)
+        if maybe_rejsample_metrics is not None:
+            sampler_output_list[
+                0].spec_decode_worker_metrics = maybe_rejsample_metrics
+
+        return sampler_output_list
+
+    @cached_property
+    def _vocab_size(self) -> int:
+        """Get the vocab size of the model and make sure it's consistent between
+        draft and target workers.
+        """
+        vocab_sizes = [
+            worker.vocab_size
+            for worker in [self.proposer_worker, self.scorer_worker]
+        ]
+        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
+        return vocab_sizes[0]
+
+    @property
+    def rank(self):
+        return self.scorer_worker.rank
+
+    @property
+    def device(self):
+        return self.scorer_worker.device
+
+
+def split_num_cache_blocks_evenly(
+    scorer_cache_block_size_bytes: int,
+    proposer_cache_block_size_bytes: int,
+    total_num_gpu_blocks: int,
+) -> int:
+    """Given total_num_gpu_blocks, the number of GPU blocks that could be
+    allocated to the target model, this function calculates how many blocks
+    should be given to the draft and target model.
+
+    Note that the block size, in bytes, usually differs between the two models,
+    as it is a function of the number of layers, the number of KV heads, and
+    the head size.
+
+    Since the target and draft models allocate the same number of blocks, we
+    simply calculate the largest block count such that, if both models allocate
+    that many blocks, the total KV cache memory is no larger than what the
+    target model alone could have allocated.
+    """
+    new_num_gpu_blocks = int(
+        total_num_gpu_blocks * scorer_cache_block_size_bytes /
+        (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
+
+    return new_num_gpu_blocks
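A quick sanity check of the block-splitting formula, with hypothetical block sizes (a large scorer block versus a small draft block); this is only an illustration, not code from the change:

def split_blocks(scorer_block_bytes: int, proposer_block_bytes: int,
                 total_num_gpu_blocks: int) -> int:
    # Same formula as split_num_cache_blocks_evenly above.
    return int(total_num_gpu_blocks * scorer_block_bytes /
               (proposer_block_bytes + scorer_block_bytes))

scorer_block_bytes, proposer_block_bytes = 800_000, 50_000
shared = split_blocks(scorer_block_bytes, proposer_block_bytes, 1000)
# Both models allocate `shared` blocks; the combined KV memory must not
# exceed what the scorer alone could have used with all 1000 blocks.
assert (shared * (scorer_block_bytes + proposer_block_bytes)
        <= 1000 * scorer_block_bytes)
print(shared)  # 941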

+ 101 - 0
aphrodite/spec_decode/util.py

@@ -0,0 +1,101 @@
+import torch
+from typing import List, Tuple
+from contextlib import contextmanager
+from itertools import chain
+
+from aphrodite.common.sequence import SequenceGroupMetadata, SamplerOutput
+
+SeqId = int
+
+
+def get_all_seq_ids(
+        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[SeqId]:
+    """Given a list of SequenceGroupMetadata, create a list of all
+    sequence ids.
+    """
+    return list(
+        chain.from_iterable([
+            seq_group_metadata.seq_data.keys()
+            for seq_group_metadata in seq_group_metadata_list
+        ]))
+
+
+def split_batch_by_proposal_len(
+    seq_group_metadata_list: List[SequenceGroupMetadata],
+    proposal_lens: List[int],
+    select_proposal_len_zero: bool,
+) -> Tuple[List[SequenceGroupMetadata], List[int]]:
+    """Utility function that splits a batch based on whether the proposal len is
+    zero or not. We should remove this once Aphrodite supports per-sequence
+    proposal lens in a batch.
+    """
+
+    if select_proposal_len_zero:
+        predicate = lambda proposal_len: proposal_len == 0
+    else:
+        predicate = lambda proposal_len: proposal_len != 0
+
+    indices = [
+        i for i, (_, proposal_len
+                  ) in enumerate(zip(seq_group_metadata_list, proposal_lens))
+        if predicate(proposal_len)
+    ]
+    seq_groups = [
+        seq_group for seq_group, proposal_len in zip(
+            seq_group_metadata_list, proposal_lens) if predicate(proposal_len)
+    ]
+
+    return seq_groups, indices
+
+
+def sampler_output_to_torch(
+    sampler_output_list: List[SamplerOutput],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Utility function which converts a list of SamplerOutput to tensors.
+
+    Returns:
+        sampled_token_ids: torch.Tensor
+            shape: [batch_size, len(sampler_output_list)]
+
+        sampled_token_probs: torch.Tensor
+            shape: [batch_size, len(sampler_output_list), vocab_size]
+    """
+
+    # shape: [batch_size, num_sampler_output, vocab_size]
+    sampled_token_probs = torch.stack(
+        [
+            sampler_output.sampled_token_probs
+            for sampler_output in sampler_output_list
+        ],
+        dim=0,
+    ).transpose(0, 1)
+
+    # shape: [batch_size, num_sampler_output]
+    sampled_token_ids = torch.stack(
+        [
+            sampler_output.sampled_token_ids.flatten()
+            for sampler_output in sampler_output_list
+        ],
+        dim=0,
+    ).transpose(0, 1)
+
+    return sampled_token_ids, sampled_token_probs
+
+
+@contextmanager
+def nvtx_range(msg, *args, **kwargs):
+    """
+    Context manager / decorator that pushes an NVTX range at the beginning
+    of its scope, and pops it at the end. If extra arguments are given,
+    they are passed as arguments to msg.format().
+
+    If running with cuda graphs, you must enable nsys cuda graph profiling.
+
+    Arguments:
+        msg (string): message to associate with the range
+    """
+    torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
+    try:
+        yield
+    finally:
+        torch.cuda.nvtx.range_pop()
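To illustrate the stack-then-transpose shape convention of sampler_output_to_torch, here is a sketch with toy token ids (plain tensors rather than real SamplerOutput objects):

import torch

# Three decode steps for a batch of two sequences, one token id per
# sequence per step.
per_step_token_ids = [
    torch.tensor([11, 21]),
    torch.tensor([12, 22]),
    torch.tensor([13, 23]),
]

# stack -> [num_steps, batch_size]; transpose -> [batch_size, num_steps]
sampled_token_ids = torch.stack(per_step_token_ids, dim=0).transpose(0, 1)
print(sampled_token_ids)  # tensor([[11, 12, 13], [21, 22, 23]])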

+ 9 - 2
aphrodite/task_handler/cache_engine.py

@@ -4,9 +4,8 @@ from typing import Dict, List, Tuple
 import torch
 from loguru import logger
 
-from aphrodite._C import cache_ops
 from aphrodite.common.config import CacheConfig, ModelConfig, ParallelConfig
-from aphrodite.common.utils import in_wsl, STR_DTYPE_TO_TORCH_DTYPE
+from aphrodite.common.utils import in_wsl, is_neuron, STR_DTYPE_TO_TORCH_DTYPE
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -37,6 +36,10 @@ class CacheEngine:
         self.num_gpu_blocks = cache_config.num_gpu_blocks
         self.num_cpu_blocks = cache_config.num_cpu_blocks
 
+        # Skip initializing CUDA stream and buffer for Neuron backend.
+        if is_neuron():
+            return
+
         if cache_config.cache_dtype == "auto":
             self.dtype = model_config.dtype
         else:
@@ -119,6 +122,8 @@ class CacheEngine:
         dst: List[KVCache],
         src_to_dst: Dict[int, int],
     ) -> None:
+        from aphrodite._C import cache_ops
+
         with torch.cuda.stream(self.cache_stream):
             for i in range(self.num_layers):
                 src_key_cache, src_value_cache = src[i]
@@ -138,6 +143,8 @@ class CacheEngine:
         self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)
 
     def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
+        from aphrodite._C import cache_ops
+
         key_caches = [key_cache for key_cache, _ in self.gpu_cache]
         value_caches = [value_cache for _, value_cache in self.gpu_cache]
         # NOTE: This operation implicitly synchronizes the CPU and GPU.
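The hunk above moves the cache_ops import inside the methods that use it; a minimal sketch of the same deferred-import pattern (copy_cache_blocks is a hypothetical wrapper, and the call assumes the compiled aphrodite._C extension is present at call time):

from typing import Dict, List

import torch


def copy_cache_blocks(key_caches: List[torch.Tensor],
                      value_caches: List[torch.Tensor],
                      src_to_dsts: Dict[int, List[int]]) -> None:
    # Importing inside the function means a Neuron host, which never calls
    # this, can import the module without the compiled CUDA kernels.
    from aphrodite._C import cache_ops

    cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)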

+ 74 - 56
aphrodite/task_handler/model_runner.py

@@ -53,7 +53,7 @@ class ModelRunner:
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
         kv_cache_dtype: Optional[str] = "auto",
-        kv_quant_params_path: Optional[str] = None,
+        # kv_quant_params_path: Optional[str] = None,
         is_driver_worker: bool = False,
     ):
         self.model_config = model_config
@@ -69,6 +69,7 @@ class ModelRunner:
         self.device_config = (device_config
                               if device_config is not None else DeviceConfig())
         self.device = self.device_config.device
+
         self.model = None
         self.block_size = None  # Set after initial profiling.
         self.lora_manager = None
@@ -89,37 +90,52 @@ class ModelRunner:
         # cache in_wsl result
         self.in_wsl = in_wsl()
         self.kv_cache_dtype = kv_cache_dtype
-        self.kv_quant_params = (self.load_kv_quant_params(
-            model_config, kv_quant_params_path)
-                                if self.kv_cache_dtype == "int8" else None)
-
-    def load_kv_quant_params(self, model_config: ModelConfig,
-                             kv_quant_params_path: str) -> List[List[float]]:
-        if model_config is None:
-            return None
-        # Remove it when all models support kv cache int8.
-        architectures = model_config.hf_config.architectures
-        for arch in architectures:
-            if arch not in ["LlamaForCausalLM", "LLaMAForCausalLM"]:
-                raise ValueError(
-                    "KV CACHE INT8 is not supported for model architectures "
-                    f"{arch} for now. "
-                    "Supported architectures: LlamaForCausalLM and "
-                    "LLaMAForCausalLM.")
-        num_layers = model_config.hf_config.num_hidden_layers
-        kv_quant_params = []
-        for i in range(num_layers):
-            if kv_quant_params_path is not None:
-                path = (kv_quant_params_path +
-                        f"/layers.{i}.past_kv_scale.0.weight")
-                kv_quant_param = list(np.fromfile(path, dtype=np.float32))
-            kv_quant_params.append(kv_quant_param)
-        return kv_quant_params
+        # self.kv_quant_params = (
+        #     self.load_kv_quant_params(model_config, kv_quant_params_path)
+        #     if self.kv_cache_dtype == "int8"
+        #     else None
+        # )
+
+        # Force eager mode on the Neuron backend to avoid CUDA graph capture.
+        if self.device_config.is_neuron:
+            self.model_config.enforce_eager = True
+
+    # def load_kv_quant_params(
+    #     self, model_config: ModelConfig, kv_quant_params_path: str
+    # ) -> List[List[float]]:
+    #     if model_config is None:
+    #         return None
+    #     # Remove it when all models support kv cache int8.
+    #     architectures = model_config.hf_config.architectures
+    #     for arch in architectures:
+    #         if arch not in ["LlamaForCausalLM", "LLaMAForCausalLM"]:
+    #             raise ValueError(
+    #                 "KV CACHE INT8 is not supported for model architectures "
+    #                 f"{arch} for now. "
+    #                 "Supported architectures: LlamaForCausalLM and "
+    #                 "LLaMAForCausalLM."
+    #             )
+    #     num_layers = model_config.hf_config.num_hidden_layers
+    #     kv_quant_params = []
+    #     for i in range(num_layers):
+    #         if kv_quant_params_path is not None:
+    #             path = (
+    #                 kv_quant_params_path + f"/layers.{i}.past_kv_scale.0.weight"  # noqa: E501
+    #             )
+    #             kv_quant_param = list(np.fromfile(path, dtype=np.float32))
+    #         kv_quant_params.append(kv_quant_param)
+    #     return kv_quant_params
 
     def load_model(self) -> None:
         with measure_cuda_memory() as m:
-            self.model = get_model(self.model_config, self.device_config,
-                                   self.lora_config)
+            self.model = get_model(
+                self.model_config,
+                self.device_config,
+                lora_config=self.lora_config,
+                parallel_config=self.parallel_config,
+                scheduler_config=self.scheduler_config,
+            )
+
         self.model_memory_usage = m.consumed_memory
         tp = get_tensor_model_parallel_world_size()
         logger.info(
@@ -127,8 +143,6 @@ class ModelRunner:
             f"{self.model_memory_usage / float(2**30):.2f} GiB x {tp} = "
             f"{self.model_memory_usage * tp / float(2**30):.2f} GiB")
 
-        vocab_size = self.model.config.vocab_size
-
         if self.lora_config:
             assert (hasattr(self.model, "supported_lora_modules")
                     and self.model.supported_lora_modules
@@ -142,7 +156,7 @@ class ModelRunner:
                 self.scheduler_config.max_num_seqs,
                 self.scheduler_config.max_num_batched_tokens +
                 self.scheduler_config.max_paddings,
-                vocab_size,
+                self.vocab_size,
                 self.lora_config,
                 self.device,
                 self.model.embedding_modules,
@@ -250,6 +264,7 @@ class ModelRunner:
                 slot_mapping[-1].append(slot)
 
         max_prompt_len = max(subquery_lens)
+        assert max_prompt_len > 0
         input_tokens = _make_tensor_with_pad(
             input_tokens,
             max_prompt_len,
@@ -309,7 +324,7 @@ class ModelRunner:
             block_tables=block_tables,
             use_cuda_graph=False,
             kv_cache_dtype=self.kv_cache_dtype,
-            kv_quant_params=self.kv_quant_params,
+            # kv_quant_params=self.kv_quant_params,
         )
         return (
             input_tokens,
@@ -449,7 +464,7 @@ class ModelRunner:
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
             kv_cache_dtype=self.kv_cache_dtype,
-            kv_quant_params=self.kv_quant_params,
+            # kv_quant_params=self.kv_quant_params,
         )
         return (
             input_tokens,
@@ -472,6 +487,7 @@ class ModelRunner:
         selected_token_start_idx = 0
         categorized_sample_indices = {t: [] for t in SamplingType}
         categorized_sample_indices_start_idx = 0
+        pin_memory = not self.in_wsl and not self.device_config.is_neuron
 
         max_subquery_len = max(subquery_lens) if subquery_lens else 1
         for i, seq_group_metadata in enumerate(seq_group_metadata_list):
@@ -501,8 +517,8 @@ class ModelRunner:
                 selected_token_indices.append(selected_token_start_idx +
                                               subquery_len - 1)
                 selected_token_start_idx += max_subquery_len
-                if (sampling_params.sampling_type == SamplingType.RANDOM_SEED):
-                    assert sampling_params.seed is not None
+
+                if sampling_params.seed is not None:
                     seq_group_metadata.state.generator = torch.Generator(
                         device="cuda").manual_seed(sampling_params.seed)
             else:
@@ -522,21 +538,21 @@ class ModelRunner:
                         ))
                 categorized_sample_indices_start_idx += num_seqs
 
-            if (seq_group_metadata.state.generator is not None):
+            if sampling_params.seed is not None:
                 generators.append(seq_group_metadata.state.generator)
 
         selected_token_indices = _async_h2d(
             selected_token_indices,
             dtype=torch.long,
             target_device=self.device,
-            pin_memory=not self.in_wsl,
+            pin_memory=pin_memory,
         )
         categorized_sample_indices = {
             t: _async_h2d(
                 seq_ids,
                 dtype=torch.int,
                 target_device=self.device,
-                pin_memory=not self.in_wsl,
+                pin_memory=pin_memory,
             )
             for t, seq_ids in categorized_sample_indices.items()
         }
@@ -621,9 +637,9 @@ class ModelRunner:
                 "block_tables": input_metadata.block_tables,
                 "use_cuda_graph": input_metadata.use_cuda_graph,
                 "kv_cache_dtype": input_metadata.kv_cache_dtype,
-                "kv_quant_params": input_metadata.kv_quant_params,
+                # "kv_quant_params": input_metadata.kv_quant_params,
                 "selected_token_indices":
-                sampling_metadata.selected_token_indices,  # noqa
+                sampling_metadata.selected_token_indices,
                 "lora_requests": lora_requests,
                 "lora_mapping": lora_mapping,
             }
@@ -645,7 +661,7 @@ class ModelRunner:
                 block_tables=metadata_dict["block_tables"],
                 use_cuda_graph=metadata_dict["use_cuda_graph"],
                 kv_cache_dtype=metadata_dict["kv_cache_dtype"],
-                kv_quant_params=metadata_dict["kv_quant_params"],
+                # kv_quant_params=metadata_dict["kv_quant_params"],
             )
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
@@ -707,8 +723,7 @@ class ModelRunner:
     @torch.inference_mode()
     def profile_run(self) -> None:
         # Enable top-k sampling to reflect the accurate memory usage.
-        vocab_size = self.model_config.get_vocab_size()
-        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
+        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
         max_num_seqs = self.scheduler_config.max_num_seqs
 
@@ -789,8 +804,9 @@ class ModelRunner:
     @torch.inference_mode()
     def capture_model(self, kv_caches: List[KVCache]) -> None:
         # NOTE: This is a hack to ensure that the NCCL backend is never
-        # deleted before the CUDA graph
+        # deleted before the CUDA graphs.
         self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
+
         assert not self.model_config.enforce_eager
         logger.info("Capturing the model for CUDA graphs. This may lead to "
                     "unexpected consequences if the model is not static. To "
@@ -818,8 +834,6 @@ class ModelRunner:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
 
-        # NOTE: Capturing the largest batch size first may help reduce the
-        # memory usage of CUDA graph.
         # NOTE: There are 3 backends for all-reduce: custom all-reduce
         # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
         # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
@@ -847,7 +861,7 @@ class ModelRunner:
                     block_tables=block_tables[:batch_size],
                     use_cuda_graph=True,
                     kv_cache_dtype=self.kv_cache_dtype,
-                    kv_quant_params=self.kv_quant_params,
+                    # kv_quant_params=self.kv_quant_params,
                 )
 
                 if self.lora_config:
@@ -882,6 +896,10 @@ class ModelRunner:
         self.graph_runners.clear()
         self.cupy_nccl_backend = None
 
+    @property
+    def vocab_size(self) -> int:
+        return self.model_config.get_vocab_size()
+
 
 class CUDAGraphRunner:
 
@@ -916,14 +934,14 @@ class CUDAGraphRunner:
         # NOTE: Python 3.8 does not support multi-line with statements.
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self.graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self.graph,
-                              pool=memory_pool), _maybe_cupy_nccl():
-            hidden_states = self.model(
-                input_ids,
-                positions,
-                kv_caches,
-                input_metadata,
-            )
+        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
+            with _maybe_cupy_nccl():
+                hidden_states = self.model(
+                    input_ids,
+                    positions,
+                    kv_caches,
+                    input_metadata,
+                )
         torch.cuda.synchronize()
 
         # Save the input and output buffers.
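The sampling change above creates a per-request torch.Generator whenever sampling_params.seed is set; a small standalone sketch of why a dedicated generator gives reproducible sampling (CPU tensors, and sample_with_seed is a hypothetical helper, not part of the change):

import torch


def sample_with_seed(probs: torch.Tensor, seed: int) -> torch.Tensor:
    # A per-request generator isolates this request's randomness from the
    # global RNG and from other requests in the batch.
    generator = torch.Generator(device=probs.device).manual_seed(seed)
    return torch.multinomial(probs, num_samples=1, generator=generator)


probs = torch.tensor([[0.1, 0.2, 0.7]])
assert torch.equal(sample_with_seed(probs, seed=42),
                   sample_with_seed(probs, seed=42))  # same seed, same sample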

+ 204 - 0
aphrodite/task_handler/neuron_worker.py

@@ -0,0 +1,204 @@
+"""A Neuron worker class."""
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.distributed
+
+from aphrodite.common.config import (
+    CacheConfig,
+    DeviceConfig,
+    ModelConfig,
+    ParallelConfig,
+    SchedulerConfig,
+    LoRAConfig,
+)
+from aphrodite.modeling import set_random_seed
+from aphrodite.modeling.megatron.communication_op import broadcast_tensor_dict
+from aphrodite.modeling.megatron.parallel_state import (
+    ensure_model_parallel_initialized, )
+from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.task_handler.cache_engine import CacheEngine
+from aphrodite.task_handler.model_runner import ModelRunner
+
+
+class Worker:
+    """A worker class that executes the model on a group of neuron cores."""
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        lora_config: Optional[LoRAConfig] = None,
+        kv_cache_dtype: Optional[str] = "auto",
+        # kv_quant_params_path: Optional[str] = None,
+        is_driver_worker: bool = False,
+    ) -> None:
+        self.model_config = model_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+        self.local_rank = local_rank
+        self.rank = rank
+        self.distributed_init_method = distributed_init_method
+        self.lora_config = lora_config
+        self.is_driver_worker = is_driver_worker
+        if self.is_driver_worker:
+            assert self.rank == 0, "The driver worker must have rank 0."
+
+        self.model_runner = ModelRunner(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            device_config,
+            lora_config=self.lora_config,
+            is_driver_worker=is_driver_worker,
+        )
+        # Uninitialized cache engine. Will be initialized by
+        # self.init_cache_engine().
+        self.cache_config = None
+        self.cache_engine = None
+        self.cache_events = None
+        self.gpu_cache = None
+
+    def init_model(self) -> None:
+        # Initialize the distributed environment.
+        _init_distributed_environment(
+            self.parallel_config,
+            self.rank,
+            self.distributed_init_method,
+            distributed_backend="gloo",
+        )
+
+        # Initialize the model.
+        set_random_seed(self.model_config.seed)
+
+    def load_model(self):
+        self.model_runner.load_model()
+
+    @torch.inference_mode()
+    def profile_num_available_blocks(
+        self,
+        block_size: int = 128,
+        gpu_memory_utilization: float = 0.9,
+        cpu_swap_space: int = 0,
+        cache_dtype: str = "float16",
+    ) -> Tuple[int, int]:
+        """Simply returns max_num_seqs as num_gpu_blocks, 0 as num_cpu_blocks.
+        """
+        num_gpu_blocks = self.scheduler_config.max_num_seqs
+        num_cpu_blocks = 0
+        return num_gpu_blocks, num_cpu_blocks
+
+    def init_cache_engine(self, cache_config: CacheConfig) -> None:
+        self.cache_config = cache_config
+        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
+                                        self.parallel_config)
+        self.model_runner.set_block_size(self.cache_engine.block_size)
+
+    def warm_up_model(self) -> None:
+        # Warm-up is handled inside transformers-neuronx.
+        pass
+
+    def cache_swap(
+        self,
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> None:
+        # Issue cache operations.
+        issued_cache_op = False
+        if blocks_to_swap_in:
+            self.cache_engine.swap_in(blocks_to_swap_in)
+            issued_cache_op = True
+        if blocks_to_swap_out:
+            self.cache_engine.swap_out(blocks_to_swap_out)
+            issued_cache_op = True
+        if blocks_to_copy:
+            self.cache_engine.copy(blocks_to_copy)
+            issued_cache_op = True
+
+        cache_events = self.cache_events if issued_cache_op else None
+
+        # Wait for cache operations to finish.
+        if cache_events is not None:
+            raise NotImplementedError(
+                "cache operations are not implemented for neuron backend.")
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
+        blocks_to_swap_in: Optional[Dict[int, int]] = None,
+        blocks_to_swap_out: Optional[Dict[int, int]] = None,
+        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
+    ) -> Optional[SamplerOutput]:
+        if self.is_driver_worker:
+            assert seq_group_metadata_list is not None
+            num_seq_groups = len(seq_group_metadata_list)
+            assert blocks_to_swap_in is not None
+            assert blocks_to_swap_out is not None
+            assert blocks_to_copy is not None
+            data = {
+                "num_seq_groups": num_seq_groups,
+                "blocks_to_swap_in": blocks_to_swap_in,
+                "blocks_to_swap_out": blocks_to_swap_out,
+                "blocks_to_copy": blocks_to_copy,
+            }
+            broadcast_tensor_dict(data, src=0)
+        else:
+            data = broadcast_tensor_dict(src=0)
+            num_seq_groups = data["num_seq_groups"]
+            blocks_to_swap_in = data["blocks_to_swap_in"]
+            blocks_to_swap_out = data["blocks_to_swap_out"]
+            blocks_to_copy = data["blocks_to_copy"]
+
+        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
+
+        # If there is no input, we don't need to execute the model.
+        if num_seq_groups == 0:
+            return {}
+
+        output = self.model_runner.execute_model(seq_group_metadata_list,
+                                                 self.gpu_cache)
+        return output
+
+
+def _init_distributed_environment(
+    parallel_config: ParallelConfig,
+    rank: int,
+    distributed_init_method: Optional[str] = None,
+    distributed_backend: Optional[str] = None,
+) -> None:
+    """Initialize the distributed environment."""
+    if torch.distributed.is_initialized():
+        torch_world_size = torch.distributed.get_world_size()
+        if torch_world_size != parallel_config.world_size:
+            raise RuntimeError(
+                "torch.distributed is already initialized but the torch world "
+                "size does not match parallel_config.world_size "
+                f"({torch_world_size} vs. {parallel_config.world_size}).")
+    elif not distributed_init_method:
+        raise ValueError(
+            "distributed_init_method must be set if torch.distributed "
+            "is not already initialized")
+    else:
+        distributed_backend = (distributed_backend
+                               if distributed_backend else "nccl")
+        torch.distributed.init_process_group(
+            backend=distributed_backend,
+            world_size=parallel_config.world_size,
+            rank=rank,
+            init_method=distributed_init_method,
+        )
+
+    # A small all_reduce for warmup.
+    torch.distributed.all_reduce(torch.zeros(1))
+    ensure_model_parallel_initialized(
+        parallel_config.tensor_parallel_size,
+        parallel_config.pipeline_parallel_size,
+    )
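
Note on the execute_model path above: the driver rank packs the per-step scheduling metadata (number of sequence groups plus the block swap/copy maps) into a dict and broadcasts it from rank 0; every other rank receives the same dict before issuing cache ops and running the model. The sketch below is a minimal illustration of that pattern, not the project's code: it uses plain torch.distributed.broadcast_object_list instead of the internal broadcast_tensor_dict helper, and broadcast_step_metadata is a hypothetical name.

    # Illustrative only: mimic the driver/non-driver metadata broadcast with
    # torch.distributed. Assumes the process group is already initialized
    # (e.g. by _init_distributed_environment above) and rank 0 is the driver.
    from typing import Dict, List, Optional
    import torch.distributed as dist

    def broadcast_step_metadata(
        is_driver: bool,
        num_seq_groups: int = 0,
        blocks_to_swap_in: Optional[Dict[int, int]] = None,
        blocks_to_swap_out: Optional[Dict[int, int]] = None,
        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
    ) -> dict:
        if is_driver:
            payload = [{
                "num_seq_groups": num_seq_groups,
                "blocks_to_swap_in": blocks_to_swap_in or {},
                "blocks_to_swap_out": blocks_to_swap_out or {},
                "blocks_to_copy": blocks_to_copy or {},
            }]
        else:
            payload = [None]
        # After the collective, every rank holds the driver's dict.
        dist.broadcast_object_list(payload, src=0)
        return payload[0]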

+ 30 - 14
aphrodite/task_handler/worker.py

@@ -6,9 +6,9 @@ from typing import Dict, List, Tuple, Set, Optional
 import torch
 import torch.distributed
 
-from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
-                                     SchedulerConfig, LoRAConfig, DeviceConfig)
-from aphrodite.common.utils import in_wsl
+from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
+                                     ParallelConfig, SchedulerConfig,
+                                     LoRAConfig)
 from aphrodite.modeling import set_random_seed
 from aphrodite.modeling.megatron import cupy_utils
 from aphrodite.modeling.megatron.communication_op import (broadcast_tensor_dict
@@ -20,7 +20,7 @@ from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
 from aphrodite.task_handler.cache_engine import CacheEngine
 from aphrodite.task_handler.model_runner import ModelRunner
 from aphrodite.lora.request import LoRARequest
-from aphrodite.common.utils import is_hip
+from aphrodite.common.utils import in_wsl
 
 
 class Worker:
@@ -42,7 +42,7 @@ class Worker:
         distributed_init_method: str,
         lora_config: Optional[LoRAConfig] = None,
         kv_cache_dtype: Optional[str] = "auto",
-        kv_quant_params_path: Optional[str] = None,
+        # kv_quant_params_path: Optional[str] = None,
         is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config
@@ -64,7 +64,7 @@ class Worker:
             device_config,
             lora_config=self.lora_config,
             kv_cache_dtype=kv_cache_dtype,
-            kv_quant_params_path=kv_quant_params_path,
+            # kv_quant_params_path=kv_quant_params_path,
             is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
@@ -99,12 +99,9 @@ class Worker:
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
-
         # Initialize the distributed environment.
         init_distributed_environment(self.parallel_config, self.rank,
                                      cupy_port, self.distributed_init_method)
-        if not self.parallel_config.disable_custom_all_reduce:
-            init_custom_ar()
         # Initialize the model.
         set_random_seed(self.model_config.seed)
 
@@ -143,8 +140,8 @@ class Worker:
         # GPU did not change their memory usage during the profiling.
         peak_memory = self.init_gpu_memory - free_gpu_memory
 
-        cache_block_size = CacheEngine.get_cache_block_size(
-            block_size, cache_dtype, self.model_config, self.parallel_config)
+        cache_block_size = self.get_cache_block_size_bytes(
+            block_size, cache_dtype)
         num_gpu_blocks = int(
             (total_gpu_memory * gpu_memory_utilization - peak_memory) //
             cache_block_size)
@@ -195,7 +192,7 @@ class Worker:
         # Wait for cache operations to finish.
         # TODO: Profile swapping overhead and optimize if needed.
         if cache_events is not None:
-            for event in cache_events:  # pylint: disable=not-an-iterable
+            for event in cache_events:
                 event.wait()
 
     @torch.inference_mode()
@@ -245,6 +242,22 @@ class Worker:
     def list_loras(self) -> Set[int]:
         return self.model_runner.list_loras()
 
+    @property
+    def max_model_len(self) -> int:
+        return self.model_config.max_model_len
+
+    @property
+    def vocab_size(self) -> int:
+        return self.model_runner.vocab_size
+
+    def get_cache_block_size_bytes(self, block_size: int,
+                                   cache_dtype: str) -> int:
+        """Get the size of the KV cache block size in bytes.
+        """
+        return CacheEngine.get_cache_block_size(block_size, cache_dtype,
+                                                self.model_config,
+                                                self.parallel_config)
+
 
 def init_distributed_environment(
     parallel_config: ParallelConfig,
@@ -279,8 +292,7 @@ def init_distributed_environment(
                 "cupy.distributed is already initialized but the cupy world "
                 "size does not match parallel_config.world_size "
                 f"({cupy_world_size} vs. {parallel_config.world_size}).")
-    elif (parallel_config.world_size > 1 and cupy_port is not None
-          and not is_hip()):
+    elif (parallel_config.world_size > 1 and cupy_port is not None):
         # NOTE: We don't initialize CuPy process group when world size
         # is 1.
         # TODO: Support multi-node connection.
@@ -298,6 +310,10 @@ def init_distributed_environment(
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
 
+    # Initialize a custom fast all-reduce implementation.
+    if not parallel_config.disable_custom_all_reduce:
+        init_custom_ar()
+
 
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
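
Arithmetic check for the num_gpu_blocks hunk in worker.py above: the number of GPU KV-cache blocks is the memory budget (total GPU memory times the utilization fraction, minus the peak usage measured during profiling) integer-divided by the per-block size returned by get_cache_block_size_bytes. A quick back-of-the-envelope with made-up numbers (all byte counts below are assumptions for illustration, not measurements):

    # Hypothetical figures, for illustration only.
    total_gpu_memory = 80 * 1024**3        # an 80 GiB device
    gpu_memory_utilization = 0.9           # fraction the engine may claim
    peak_memory = 18 * 1024**3             # peak observed during the profile run
    cache_block_size = 2 * 1024**2         # assumed bytes per KV-cache block

    num_gpu_blocks = int(
        (total_gpu_memory * gpu_memory_utilization - peak_memory)
        // cache_block_size)
    num_gpu_blocks = max(num_gpu_blocks, 0)  # guard against a negative budget
    print(num_gpu_blocks)                    # 27648 with these numbers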

+ 1 - 2
kernels/attention/attention_dtypes.h

@@ -4,5 +4,4 @@
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
 #include "dtype_bfloat16.cuh"
-#include "dtype_fp8_e5m2.cuh"
-#include "dtype_int8.cuh"
+#include "dtype_fp8_e5m2.cuh"

+ 936 - 1014
kernels/attention/attention_kernels.cu

@@ -16,1017 +16,939 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifdef USE_ROCM
-#include <hip/hip_runtime.h>
-#endif
-
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-
-#include "attention_dtypes.h"
-#include "attention_utils.cuh"
-#include "../quantization/int8_kvcache/quant_utils.cuh"
-#ifdef ENABLE_FP8_E5M2
-#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
-#endif
-
-#include <algorithm>
-
-#ifndef USE_ROCM
-#define WARP_SIZE 32
-#else
-#define WARP_SIZE warpSize
-#endif
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-#define MIN(a, b) ((a) < (b) ? (a) : (b))
-#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
-
-enum kv_cache_dtype {
-  AUTO,
-#ifdef ENABLE_FP8_E5M2
-  FP8_E5M2,
-#endif
-  INT8};
-
-namespace aphrodite {
-
-// Utility function for attention softmax.
-template<int NUM_WARPS>
-inline __device__ float block_sum(float* red_smem, float sum) {
-  // Decompose the thread index into warp / lane.
-  int warp = threadIdx.x / WARP_SIZE;
-  int lane = threadIdx.x % WARP_SIZE;
-
-  // Compute the sum per warp.
-#pragma unroll
-  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
-    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
-  }
-
-  // Warp leaders store the data to shared memory.
-  if (lane == 0) {
-    red_smem[warp] = sum;
-  }
-
-  // Make sure the data is in shared memory.
-  __syncthreads();
-
-  // The warps compute the final sums.
-  if (lane < NUM_WARPS) {
-    sum = red_smem[lane];
-  }
-
-  // Parallel reduction inside the warp.
-#pragma unroll
-  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
-    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
-  }
-
-  // Broadcast to other threads.
-  return APHRODITE_SHFL_SYNC(sum, 0);
-}
-
-// TODO: Merge the last two dimensions of the grid.
-// Grid: (num_heads, num_seqs, max_num_partitions).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int PARTITION_SIZE = 0> // Zero means no partitioning.
-__device__ void paged_attention_kernel(
-  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
-  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float k_scale = 1.0f,
-  const float k_zp = 0.0f,
-  const float v_scale = 1.0f,
-  const float v_zp = 0.0f) {
-  const int seq_idx = blockIdx.y;
-  const int partition_idx = blockIdx.z;
-  const int max_num_partitions = gridDim.z;
-  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
-  const int context_len = context_lens[seq_idx];
-  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
-    // No work to do. Terminate the thread block.
-    return;
-  }
-
-  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
-  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
-
-  // [start_block_idx, end_block_idx) is the range of blocks to process.
-  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
-  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
-  const int num_blocks = end_block_idx - start_block_idx;
-
-  // [start_token_idx, end_token_idx) is the range of tokens to process.
-  const int start_token_idx = start_block_idx * BLOCK_SIZE;
-  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
-  const int num_tokens = end_token_idx - start_token_idx;
-
-  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
-  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
-  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
-  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  const int thread_idx = threadIdx.x;
-  const int warp_idx = thread_idx / WARP_SIZE;
-  const int lane = thread_idx % WARP_SIZE;
-
-  const int head_idx = blockIdx.x;
-  const int num_heads = gridDim.x;
-  const int num_queries_per_kv = num_heads / num_kv_heads;
-  const int kv_head_idx = head_idx / num_queries_per_kv;
-  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
-
-  // A vector type to store a part of a key or a query.
-  // The vector size is configured in such a way that the threads in a thread group
-  // fetch or compute 16 bytes at a time.
-  // For example, if the size of a thread group is 4 and the data type is half,
-  // then the vector size is 16 / (4 * sizeof(half)) == 2.
-  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
-  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
-  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
-  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
-
-  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
-  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
-
-  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
-  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
-
-  // Load the query to registers.
-  // Each thread in a thread group has a different part of the query.
-  // For example, if the thread group size is 4, then the first thread in the group
-  // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
-  // th vectors of the query, and so on.
-  // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
-  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
-  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
-#pragma unroll
-  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
-    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
-    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
-  }
-  __syncthreads(); // TODO: possible speedup if this is replaced with a memory wall right before we use q_vecs
-
-  // Memory planning.
-  extern __shared__ char shared_mem[];
-  // NOTE: We use FP32 for the softmax logits for better accuracy.
-  float* logits = reinterpret_cast<float*>(shared_mem);
-  // Workspace for reduction.
-  __shared__ float red_smem[2 * NUM_WARPS];
-
-  // x == THREAD_GROUP_SIZE * VEC_SIZE
-  // Each thread group fetches x elements from the key at a time.
-  constexpr int x = 16 / sizeof(cache_t);
-  float qk_max = -FLT_MAX;
-
-  // Iterate over the key blocks.
-  // Each warp fetches a block of keys for each iteration.
-  // Each thread group in a warp fetches a key from the block, and computes
-  // dot product with the query.
-  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
-  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
-    // NOTE: The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by large numbers
-    // (e.g., kv_block_stride).
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
-
-    // Load a key to registers.
-    // Each thread in a thread group has a different part of the key.
-    // For example, if the thread group size is 4, then the first thread in the group
-    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
-    // vectors of the key, and so on.
-    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
-      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
-      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
-      K_vec k_vecs[NUM_VECS_PER_THREAD];
-
-#pragma unroll
-      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-        const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
-                                       + kv_head_idx * kv_head_stride
-                                       + physical_block_offset * x;
-        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
-        const int offset1 = (vec_idx * VEC_SIZE) / x;
-        const int offset2 = (vec_idx * VEC_SIZE) % x;
-        if constexpr (KV_CACHE_DTYPE == INT8) {
-          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
-          using Dequant_vec = typename FloatVec<Quant_vec>::Type;
-          Dequant_vec k_vec_dequant = int8::dequant(k_vec_quant, k_scale, k_zp);
-          k_vecs[j] = int8::vec_conversion<K_vec, Dequant_vec>(k_vec_dequant);
-#ifdef ENABLE_FP8_E5M2
-        } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
-          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
-          // Vector conversion from Quant_vec to K_vec.
-          k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
-#endif
-        } else {
-          k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
-        }
-      }
-
-      // Compute dot product.
-      // This includes a reduction across the threads in the same thread group.
-      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
-      // Add the ALiBi bias if slopes are given.
-      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
-
-      if (thread_group_offset == 0) {
-        // Store the partial reductions to shared memory.
-        // NOTE: It is required to zero out the masked logits.
-        const bool mask = token_idx >= context_len;
-        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
-        // Update the max value.
-        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
-      }
-    }
-  }
-
-  // Perform reduction across the threads in the same warp to get the
-  // max qk value for each "warp" (not across the thread block yet).
-  // The 0-th thread of each thread group already has its max qk value.
-#pragma unroll
-  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
-    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
-  }
-  if (lane == 0) {
-    red_smem[warp_idx] = qk_max;
-  }
-  __syncthreads();
-
-  // TODO: Refactor this part.
-  // Get the max qk value for the sequence.
-  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
-#pragma unroll
-  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
-    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
-  }
-  // Broadcast the max qk value to all threads.
-  qk_max = APHRODITE_SHFL_SYNC(qk_max, 0);
-
-  // Get the sum of the exp values.
-  float exp_sum = 0.f;
-  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
-    float val = __expf(logits[i] - qk_max);
-    logits[i] = val;
-    exp_sum += val;
-  }
-  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
-
-  // Compute softmax.
-  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
-  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
-    logits[i] *= inv_sum;
-  }
-  __syncthreads();
-
-  // If partitioning is enabled, store the max logit and exp_sum.
-  if (USE_PARTITIONING && thread_idx == 0) {
-    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
-                                       + head_idx * max_num_partitions
-                                       + partition_idx;
-    *max_logits_ptr = qk_max;
-    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
-                                   + head_idx * max_num_partitions
-                                   + partition_idx;
-    *exp_sums_ptr = exp_sum;
-  }
-
-  // Each thread will fetch 16 bytes from the value cache at a time.
-  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
-  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
-  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
-  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
-  using Float_L_vec = typename FloatVec<L_vec>::Type;
-
-  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
-  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
-  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
-
-  // NOTE: We use FP32 for the accumulator for better accuracy.
-  float accs[NUM_ROWS_PER_THREAD];
-#pragma unroll
-  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
-    accs[i] = 0.f;
-  }
-
-  scalar_t zero_value;
-  zero(zero_value);
-  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
-    // NOTE: The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by large numbers
-    // (e.g., kv_block_stride).
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
-    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
-    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
-    L_vec logits_vec;
-    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
-
-    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
-                                   + kv_head_idx * kv_head_stride;
-#pragma unroll
-    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
-      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
-      if (row_idx < HEAD_SIZE) {
-        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
-        V_vec v_vec;
-        if constexpr (KV_CACHE_DTYPE == INT8) {
-          // dequant and conversion
-          V_quant_vec v_vec_quant = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
-          using V_dequant_vec = typename FloatVec<V_quant_vec>::Type;
-          V_dequant_vec v_vec_dequant = int8::dequant(v_vec_quant, v_scale, v_zp);
-          v_vec = int8::vec_conversion<V_vec, V_dequant_vec>(v_vec_dequant);
-#ifdef ENABLE_FP8_E5M2
-        } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
-          V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
-          // Vector conversion from V_quant_vec to V_vec.
-          v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
-#endif
-        } else {
-          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
-        }
-        if (block_idx == num_context_blocks - 1) {
-          // NOTE: When v_vec contains the tokens that are out of the context,
-          // we should explicitly zero out the values since they may contain NaNs.
-          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
-#pragma unroll
-          for (int j = 0; j < V_VEC_SIZE; j++) {
-            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
-          }
-        }
-        accs[i] += dot(logits_vec, v_vec);
-      }
-    }
-  }
-
-  // Perform reduction within each warp.
-#pragma unroll
-  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
-    float acc = accs[i];
-#pragma unroll
-    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
-      acc += APHRODITE_SHFL_XOR_SYNC(acc, mask);
-    }
-    accs[i] = acc;
-  }
-
-  // NOTE: A barrier is required because the shared memory space for logits
-  // is reused for the output.
-  __syncthreads();
-
-  // Perform reduction across warps.
-  float* out_smem = reinterpret_cast<float*>(shared_mem);
-#pragma unroll
-  for (int i = NUM_WARPS; i > 1; i /= 2) {
-    int mid = i / 2;
-    // Upper warps write to shared memory.
-    if (warp_idx >= mid && warp_idx < i) {
-      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
-#pragma unroll
-      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
-        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
-        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
-          dst[row_idx] = accs[i];
-        }
-      }
-    }
-    __syncthreads();
-
-    // Lower warps update the output.
-    if (warp_idx < mid) {
-      const float* src = &out_smem[warp_idx * HEAD_SIZE];
-#pragma unroll
-      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
-        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
-        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
-          accs[i] += src[row_idx];
-        }
-      }
-    }
-    __syncthreads();
-  }
-
-  // Write the final output.
-  if (warp_idx == 0) {
-    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                            + head_idx * max_num_partitions * HEAD_SIZE
-                            + partition_idx * HEAD_SIZE;
-#pragma unroll
-    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
-      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
-      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
-        from_float(*(out_ptr + row_idx), accs[i]);
-      }
-    }
-  }
-}
-
-// Grid: (num_heads, num_seqs, 1).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  kv_cache_dtype KV_CACHE_DTYPE>
-__global__ void paged_attention_v1_kernel(
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE>(
-    /* exp_sums */ nullptr, /* max_logits */ nullptr,
-    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
-    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
-}
-
-// Grid: (num_heads, num_seqs, max_num_partitions).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int PARTITION_SIZE>
-__global__ void paged_attention_v2_kernel(
-  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
-  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
-  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE, PARTITION_SIZE>(
-    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
-    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
-    q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
-}
-
-// Grid: (num_heads, num_seqs).
-template<
-  typename scalar_t,
-  int HEAD_SIZE,
-  int NUM_THREADS,
-  int PARTITION_SIZE>
-__global__ void paged_attention_v2_reduce_kernel(
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
-  const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
-  const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
-  const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_partitions) {
-  const int num_heads = gridDim.x;
-  const int head_idx = blockIdx.x;
-  const int seq_idx = blockIdx.y;
-  const int context_len = context_lens[seq_idx];
-  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
-  if (num_partitions == 1) {
-    // No need to reduce. Only copy tmp_out to out.
-    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
-    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                                          + head_idx * max_num_partitions * HEAD_SIZE;
-    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
-      out_ptr[i] = tmp_out_ptr[i];
-    }
-    // Terminate the thread block.
-    return;
-  }
-
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  const int warp_idx = threadIdx.x / WARP_SIZE;
-  const int lane = threadIdx.x % WARP_SIZE;
-
-  // Size: 2 * num_partitions.
-  extern __shared__ char shared_mem[];
-  // Workspace for reduction.
-  __shared__ float red_smem[2 * NUM_WARPS];
-
-  // Load max logits to shared memory.
-  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
-  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
-                                           + head_idx * max_num_partitions;
-  float max_logit = -FLT_MAX;
-  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
-    const float l = max_logits_ptr[i];
-    shared_max_logits[i] = l;
-    max_logit = fmaxf(max_logit, l);
-  }
-  __syncthreads();
-
-  // Get the global max logit.
-  // Reduce within the warp.
-#pragma unroll
-  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
-    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
-  }
-  if (lane == 0) {
-    red_smem[warp_idx] = max_logit;
-  }
-  __syncthreads();
-  // Reduce across warps.
-  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
-#pragma unroll
-  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
-    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
-  }
-  // Broadcast the max value to all threads.
-  max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);
-
-  // Load rescaled exp sums to shared memory.
-  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
-  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
-                                       + head_idx * max_num_partitions;
-  float global_exp_sum = 0.0f;
-  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
-    float l = shared_max_logits[i];
-    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
-    global_exp_sum += rescaled_exp_sum;
-    shared_exp_sums[i] = rescaled_exp_sum;
-  }
-  __syncthreads();
-  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
-  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
-
-  // Aggregate tmp_out to out.
-  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                                        + head_idx * max_num_partitions * HEAD_SIZE;
-  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
-#pragma unroll
-  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
-    float acc = 0.0f;
-    for (int j = 0; j < num_partitions; ++j) {
-      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
-    }
-    from_float(out_ptr[i], acc);
-  }
-}
-
-} // namespace aphrodite
-
-#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                        \
-  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                        \
-    ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,    \
-      KV_CACHE_DTYPE>), shared_mem_size);                                                           \
-  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,              \
-  KV_CACHE_DTYPE><<<grid, block, shared_mem_size, stream>>>(                                        \
-    out_ptr,                                                                                        \
-    query_ptr,                                                                                      \
-    key_cache_ptr,                                                                                  \
-    value_cache_ptr,                                                                                \
-    num_kv_heads,                                                                                   \
-    scale,                                                                                          \
-    block_tables_ptr,                                                                               \
-    context_lens_ptr,                                                                               \
-    max_num_blocks_per_seq,                                                                         \
-    alibi_slopes_ptr,                                                                               \
-    q_stride,                                                                                       \
-    kv_block_stride,                                                                                \
-    kv_head_stride,                                                                                 \
-    k_scale,                                                                                        \
-    k_zp,                                                                                           \
-    v_scale,                                                                                        \
-    v_zp);
-
-// TODO: Tune NUM_THREADS.
-template<
-  typename T,
-  typename CACHE_T,
-  int BLOCK_SIZE,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int NUM_THREADS = 128>
-void paged_attention_v1_launcher(
-  torch::Tensor& out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  int num_kv_heads,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
-  int num_seqs = query.size(0);
-  int num_heads = query.size(1);
-  int head_size = query.size(2);
-  int max_num_blocks_per_seq = block_tables.size(1);
-  int q_stride = query.stride(0);
-  int kv_block_stride = key_cache.stride(0);
-  int kv_head_stride = key_cache.stride(1);
-
-  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
-  assert(head_size % thread_group_size == 0);
-
-  // NOTE: alibi_slopes is optional.
-  const float* alibi_slopes_ptr = alibi_slopes ?
-    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-    : nullptr;
-
-  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
-  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
-  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
-  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
-  int* block_tables_ptr = block_tables.data_ptr<int>();
-  int* context_lens_ptr = context_lens.data_ptr<int>();
-
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
-  int logits_size = padded_max_context_len * sizeof(float);
-  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
-  // Python-side check in aphrodite.task_handler.worker._check_if_can_support_max_seq_len
-  // Keep that in sync with the logic here!
-  int shared_mem_size = std::max(logits_size, outputs_size);
-
-  dim3 grid(num_heads, num_seqs, 1);
-  dim3 block(NUM_THREADS);
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  switch (head_size) {
-    // NOTE: To reduce the compilation time, we only compile for the
-    // head sizes that we use in the model. However, we can easily extend this
-    // to support any head size which is a multiple of 16.
-    case 64:
-      LAUNCH_PAGED_ATTENTION_V1(64);
-      break;
-    case 80:
-      LAUNCH_PAGED_ATTENTION_V1(80);
-      break;
-    case 96:
-      LAUNCH_PAGED_ATTENTION_V1(96);
-      break;
-    case 112:
-      LAUNCH_PAGED_ATTENTION_V1(112);
-      break;
-    case 128:
-      LAUNCH_PAGED_ATTENTION_V1(128);
-      break;
-    case 256:
-      LAUNCH_PAGED_ATTENTION_V1(256);
-      break;
-    default:
-      TORCH_CHECK(false, "Unsupported head size: ", head_size);
-      break;
-  }
-}
-
-#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)             \
-  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(       \
-    out,                                                                     \
-    query,                                                                   \
-    key_cache,                                                               \
-    value_cache,                                                             \
-    num_kv_heads,                                                            \
-    scale,                                                                   \
-    block_tables,                                                            \
-    context_lens,                                                            \
-    max_context_len,                                                         \
-    alibi_slopes,                                                            \
-    k_scale,                                                                 \
-    k_zp,                                                                    \
-    v_scale,                                                                 \
-    v_zp);
-
-// NOTE: To reduce the compilation time, we omitted block sizes
-// 1, 2, 4, 64, 128, 256.
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)       \
-  switch (block_size) {                                               \
-    case 8:                                                           \
-      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);                \
-      break;                                                          \
-    case 16:                                                          \
-      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);               \
-      break;                                                          \
-    case 32:                                                          \
-      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);               \
-      break;                                                          \
-    default:                                                          \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
-      break;                                                          \
-  }
-
-void paged_attention_v1(
-  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
-  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
-  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  int num_kv_heads,               // [num_heads]
-  float scale,
-  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
-  torch::Tensor& context_lens,    // [num_seqs]
-  int block_size,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const std::string& kv_cache_dtype,
-  const float k_scale = 1.0f,
-  const float k_zp = 0.0f,
-  const float v_scale = 1.0f,
-  const float v_zp = 0.0f) {
-  if (kv_cache_dtype == "auto") {
-    if (query.dtype() == at::ScalarType::Float) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
-    } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, AUTO);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, AUTO);
-    } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-    }
-#ifdef ENABLE_FP8_E5M2
-  } else if (kv_cache_dtype == "fp8_e5m2") {
-    if (query.dtype() == at::ScalarType::Float) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, FP8_E5M2);
-    } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, FP8_E5M2);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, FP8_E5M2);
-    } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-    }
-#endif
-  } else if (kv_cache_dtype == "int8") {
-    if (query.dtype() == at::ScalarType::Float) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(float, int8_t, INT8);
-    } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, int8_t, INT8);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, int8_t, INT8);
-    } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-    }
-  } else {
-    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
-  }
-}
-
-#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \
-  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,        \
-  KV_CACHE_DTYPE, PARTITION_SIZE>                                                             \
-  <<<grid, block, shared_mem_size, stream>>>(                                                 \
-    exp_sums_ptr,                                                                             \
-    max_logits_ptr,                                                                           \
-    tmp_out_ptr,                                                                              \
-    query_ptr,                                                                                \
-    key_cache_ptr,                                                                            \
-    value_cache_ptr,                                                                          \
-    num_kv_heads,                                                                             \
-    scale,                                                                                    \
-    block_tables_ptr,                                                                         \
-    context_lens_ptr,                                                                         \
-    max_num_blocks_per_seq,                                                                   \
-    alibi_slopes_ptr,                                                                         \
-    q_stride,                                                                                 \
-    kv_block_stride,                                                                          \
-    kv_head_stride,                                                                           \
-    k_scale,                                                                                  \
-    k_zp,                                                                                     \
-    v_scale,                                                                                  \
-    v_zp);                                                                                    \
-  aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>           \
-  <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                                   \
-    out_ptr,                                                                                  \
-    exp_sums_ptr,                                                                             \
-    max_logits_ptr,                                                                           \
-    tmp_out_ptr,                                                                              \
-    context_lens_ptr,                                                                         \
-    max_num_partitions);
-
-template<
-  typename T,
-  typename CACHE_T,
-  int BLOCK_SIZE,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int NUM_THREADS = 128,
-  int PARTITION_SIZE = 512>
-void paged_attention_v2_launcher(
-  torch::Tensor& out,
-  torch::Tensor& exp_sums,
-  torch::Tensor& max_logits,
-  torch::Tensor& tmp_out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  int num_kv_heads,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
-  int num_seqs = query.size(0);
-  int num_heads = query.size(1);
-  int head_size = query.size(2);
-  int max_num_blocks_per_seq = block_tables.size(1);
-  int q_stride = query.stride(0);
-  int kv_block_stride = key_cache.stride(0);
-  int kv_head_stride = key_cache.stride(1);
-
-  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
-  assert(head_size % thread_group_size == 0);
-
-  // NOTE: alibi_slopes is optional.
-  const float* alibi_slopes_ptr = alibi_slopes ?
-    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-    : nullptr;
-
-  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
-  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
-  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
-  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
-  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
-  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
-  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
-  int* block_tables_ptr = block_tables.data_ptr<int>();
-  int* context_lens_ptr = context_lens.data_ptr<int>();
-
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
-  int logits_size = PARTITION_SIZE * sizeof(float);
-  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
-
-  // For paged attention v2 kernel.
-  dim3 grid(num_heads, num_seqs, max_num_partitions);
-  int shared_mem_size = std::max(logits_size, outputs_size);
-  // For paged attention v2 reduce kernel.
-  dim3 reduce_grid(num_heads, num_seqs);
-  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
-
-  dim3 block(NUM_THREADS);
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  switch (head_size) {
-    // NOTE: To reduce the compilation time, we only compile for the
-    // head sizes that we use in the model. However, we can easily extend this
-    // to support any head size which is a multiple of 16.
-    case 64:
-      LAUNCH_PAGED_ATTENTION_V2(64);
-      break;
-    case 80:
-      LAUNCH_PAGED_ATTENTION_V2(80);
-      break;
-    case 96:
-      LAUNCH_PAGED_ATTENTION_V2(96);
-      break;
-    case 112:
-      LAUNCH_PAGED_ATTENTION_V2(112);
-      break;
-    case 128:
-      LAUNCH_PAGED_ATTENTION_V2(128);
-      break;
-    case 256:
-      LAUNCH_PAGED_ATTENTION_V2(256);
-      break;
-    default:
-      TORCH_CHECK(false, "Unsupported head size: ", head_size);
-      break;
-  }
-}
-
-#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)                 \
-  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(           \
-    out,                                                                         \
-    exp_sums,                                                                    \
-    max_logits,                                                                  \
-    tmp_out,                                                                     \
-    query,                                                                       \
-    key_cache,                                                                   \
-    value_cache,                                                                 \
-    num_kv_heads,                                                                \
-    scale,                                                                       \
-    block_tables,                                                                \
-    context_lens,                                                                \
-    max_context_len,                                                             \
-    alibi_slopes,                                                                \
-    k_scale,                                                                     \
-    k_zp,                                                                        \
-    v_scale,                                                                     \
-    v_zp);
-
-// NOTE: To reduce the compilation time, we omitted block sizes
-// 1, 2, 4, 64, 128, 256.
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)             \
-  switch (block_size) {                                                     \
-    case 8:                                                                 \
-      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);                      \
-      break;                                                                \
-    case 16:                                                                \
-      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);                     \
-      break;                                                                \
-    case 32:                                                                \
-      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);                     \
-      break;                                                                \
-    default:                                                                \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size);           \
-      break;                                                                \
-  }
-
-void paged_attention_v2(
-  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
-  torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
-  torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
-  torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
-  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
-  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  int num_kv_heads,               // [num_heads]
-  float scale,
-  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
-  torch::Tensor& context_lens,    // [num_seqs]
-  int block_size,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const std::string& kv_cache_dtype,
-  const float k_scale = 1.0f,
-  const float k_zp = 0.0f,
-  const float v_scale = 1.0f,
-  const float v_zp = 0.0f) {
-  if (kv_cache_dtype == "auto") {
-    if (query.dtype() == at::ScalarType::Float) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
-    } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, AUTO);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, AUTO);
-    } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-    }
-#ifdef ENABLE_FP8_E5M2
-  } else if (kv_cache_dtype == "fp8_e5m2") {
-    if (query.dtype() == at::ScalarType::Float) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, FP8_E5M2);
-    } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, FP8_E5M2);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, FP8_E5M2);
-    } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-    }
-#endif
-  } else if (kv_cache_dtype == "int8") {
-    if (query.dtype() == at::ScalarType::Float) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(float, int8_t, INT8);
-    } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, int8_t, INT8);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, int8_t, INT8);
-    } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-    }
-  } else {
-    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
-  }
-}
-
-#undef WARP_SIZE
-#undef MAX
-#undef MIN
-#undef DIVIDE_ROUND_UP
+ #ifdef USE_ROCM
+ #include <hip/hip_runtime.h>
+ #endif
+ 
+ #include <torch/extension.h>
+ #include <ATen/cuda/CUDAContext.h>
+ #include <c10/cuda/CUDAGuard.h>
+ 
+ #include "attention_dtypes.h"
+ #include "attention_utils.cuh"
+ #ifdef ENABLE_FP8_E5M2
+ #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
+ #endif
+ 
+ #include <algorithm>
+ 
+ #ifndef USE_ROCM
+ #define WARP_SIZE 32
+ #else
+ #define WARP_SIZE warpSize
+ #endif
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
+ #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
+ 
+ namespace aphrodite {
+ 
+ // Utility function for attention softmax.
+ template<int NUM_WARPS>
+ inline __device__ float block_sum(float* red_smem, float sum) {
+   // Decompose the thread index into warp / lane.
+   int warp = threadIdx.x / WARP_SIZE;
+   int lane = threadIdx.x % WARP_SIZE;
+ 
+   // Compute the sum per warp.
+ #pragma unroll
+   for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+     sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
+   }
+ 
+   // Warp leaders store the data to shared memory.
+   if (lane == 0) {
+     red_smem[warp] = sum;
+   }
+ 
+   // Make sure the data is in shared memory.
+   __syncthreads();
+ 
+   // The warps compute the final sums.
+   if (lane < NUM_WARPS) {
+     sum = red_smem[lane];
+   }
+ 
+   // Parallel reduction inside the warp.
+ #pragma unroll
+   for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+     sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
+   }
+ 
+   // Broadcast to other threads.
+   return APHRODITE_SHFL_SYNC(sum, 0);
+ }
+ 
+ // TODO: Merge the last two dimensions of the grid.
+ // Grid: (num_heads, num_seqs, max_num_partitions).
+ template<
+   typename scalar_t,
+   typename cache_t,
+   int HEAD_SIZE,
+   int BLOCK_SIZE,
+   int NUM_THREADS,
+   bool IS_FP8_E5M2_KV_CACHE,
+   int PARTITION_SIZE = 0> // Zero means no partitioning.
+ __device__ void paged_attention_kernel(
+   float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
+   float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
+   scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
+   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+   const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+   const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
+   const int num_kv_heads,                 // [num_heads]
+   const float scale,
+   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+   const int* __restrict__ context_lens,   // [num_seqs]
+   const int max_num_blocks_per_seq,
+   const float* __restrict__ alibi_slopes, // [num_heads]
+   const int q_stride,
+   const int kv_block_stride,
+   const int kv_head_stride) {
+   const int seq_idx = blockIdx.y;
+   const int partition_idx = blockIdx.z;
+   const int max_num_partitions = gridDim.z;
+   constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
+   const int context_len = context_lens[seq_idx];
+   if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
+     // No work to do. Terminate the thread block.
+     return;
+   }
+ 
+   const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
+   const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
+ 
+   // [start_block_idx, end_block_idx) is the range of blocks to process.
+   const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+   const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
+   const int num_blocks = end_block_idx - start_block_idx;
+ 
+   // [start_token_idx, end_token_idx) is the range of tokens to process.
+   const int start_token_idx = start_block_idx * BLOCK_SIZE;
+   const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
+   const int num_tokens = end_token_idx - start_token_idx;
+ 
+   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+   constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
+   assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
+   constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
+   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+   const int thread_idx = threadIdx.x;
+   const int warp_idx = thread_idx / WARP_SIZE;
+   const int lane = thread_idx % WARP_SIZE;
+ 
+   const int head_idx = blockIdx.x;
+   const int num_heads = gridDim.x;
+   const int num_queries_per_kv = num_heads / num_kv_heads;
+   const int kv_head_idx = head_idx / num_queries_per_kv;
+   const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
+ 
+   // A vector type to store a part of a key or a query.
+   // The vector size is configured in such a way that the threads in a thread group
+   // fetch or compute 16 bytes at a time.
+   // For example, if the size of a thread group is 4 and the data type is half,
+   // then the vector size is 16 / (4 * sizeof(half)) == 2.
+   constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
+   using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+   using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+ #ifdef ENABLE_FP8_E5M2
+   using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
+ #endif
+ 
+   constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
+   constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
+ 
+   const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
+   const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
+ 
+   // Load the query to registers.
+   // Each thread in a thread group has a different part of the query.
+   // For example, if the thread group size is 4, then the first thread in the group
+   // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
+   // th vectors of the query, and so on.
+   // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
+   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+   __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+ #pragma unroll
+   for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
+     const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
+     q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+   }
+   __syncthreads(); // TODO: possible speedup if this is replaced with a memory fence right before we use q_vecs
+ 
+   // Memory planning.
+   extern __shared__ char shared_mem[];
+   // NOTE: We use FP32 for the softmax logits for better accuracy.
+   float* logits = reinterpret_cast<float*>(shared_mem);
+   // Workspace for reduction.
+   __shared__ float red_smem[2 * NUM_WARPS];
+ 
+   // x == THREAD_GROUP_SIZE * VEC_SIZE
+   // Each thread group fetches x elements from the key at a time.
+   constexpr int x = 16 / sizeof(cache_t);
+   float qk_max = -FLT_MAX;
+ 
+   // Iterate over the key blocks.
+   // Each warp fetches a block of keys for each iteration.
+   // Each thread group in a warp fetches a key from the block, and computes
+   // dot product with the query.
+   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+     // NOTE: The block number is stored in int32. However, we cast it to int64
+     // because int32 can lead to overflow when this variable is multiplied by large numbers
+     // (e.g., kv_block_stride).
+     const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+ 
+     // Load a key to registers.
+     // Each thread in a thread group has a different part of the key.
+     // For example, if the thread group size is 4, then the first thread in the group
+     // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
+     // vectors of the key, and so on.
+     for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+       const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+       const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+       K_vec k_vecs[NUM_VECS_PER_THREAD];
+ 
+ #pragma unroll
+       for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
+         const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+                                        + kv_head_idx * kv_head_stride
+                                        + physical_block_offset * x;
+         const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
+         const int offset1 = (vec_idx * VEC_SIZE) / x;
+         const int offset2 = (vec_idx * VEC_SIZE) % x;
+         if constexpr (IS_FP8_E5M2_KV_CACHE) {
+ #ifdef ENABLE_FP8_E5M2
+           Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+           // Vector conversion from Quant_vec to K_vec.
+           k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
+ #else
+           assert(false);
+ #endif
+         } else {
+           k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+         }
+       }
+ 
+       // Compute dot product.
+       // This includes a reduction across the threads in the same thread group.
+       float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
+       // Add the ALiBi bias if slopes are given.
+       qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
+ 
+       if (thread_group_offset == 0) {
+         // Store the partial reductions to shared memory.
+         // NOTE: It is required to zero out the masked logits.
+         const bool mask = token_idx >= context_len;
+         logits[token_idx - start_token_idx] = mask ? 0.f : qk;
+         // Update the max value.
+         qk_max = mask ? qk_max : fmaxf(qk_max, qk);
+       }
+     }
+   }
+ 
+   // Perform reduction across the threads in the same warp to get the
+   // max qk value for each "warp" (not across the thread block yet).
+   // The 0-th thread of each thread group already has its max qk value.
+ #pragma unroll
+   for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
+     qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
+   }
+   if (lane == 0) {
+     red_smem[warp_idx] = qk_max;
+   }
+   __syncthreads();
+ 
+   // TODO: Refactor this part.
+   // Get the max qk value for the sequence.
+   qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+ #pragma unroll
+   for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+     qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
+   }
+   // Broadcast the max qk value to all threads.
+   qk_max = APHRODITE_SHFL_SYNC(qk_max, 0);
+ 
+   // Get the sum of the exp values.
+   float exp_sum = 0.f;
+   for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+     float val = __expf(logits[i] - qk_max);
+     logits[i] = val;
+     exp_sum += val;
+   }
+   exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
+ 
+   // Compute softmax.
+   const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
+   for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+     logits[i] *= inv_sum;
+   }
+   __syncthreads();
+ 
+   // If partitioning is enabled, store the max logit and exp_sum.
+   if (USE_PARTITIONING && thread_idx == 0) {
+     float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+                                        + head_idx * max_num_partitions
+                                        + partition_idx;
+     *max_logits_ptr = qk_max;
+     float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+                                    + head_idx * max_num_partitions
+                                    + partition_idx;
+     *exp_sums_ptr = exp_sum;
+   }
+ 
+   // Each thread will fetch 16 bytes from the value cache at a time.
+   constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
+   using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+   using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+ #ifdef ENABLE_FP8_E5M2
+   using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
+ #endif
+   using Float_L_vec = typename FloatVec<L_vec>::Type;
+ 
+   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
+   constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
+   constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
+ 
+   // NOTE: We use FP32 for the accumulator for better accuracy.
+   float accs[NUM_ROWS_PER_THREAD];
+ #pragma unroll
+   for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+     accs[i] = 0.f;
+   }
+ 
+   scalar_t zero_value;
+   zero(zero_value);
+   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+     // NOTE: The block number is stored in int32. However, we cast it to int64
+     // because int32 can lead to overflow when this variable is multiplied by large numbers
+     // (e.g., kv_block_stride).
+     const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
+     const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+     L_vec logits_vec;
+     from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
+ 
+     const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+                                    + kv_head_idx * kv_head_stride;
+ #pragma unroll
+     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+       if (row_idx < HEAD_SIZE) {
+         const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
+         V_vec v_vec;
+         if constexpr (IS_FP8_E5M2_KV_CACHE) {
+ #ifdef ENABLE_FP8_E5M2
+           V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+           // Vector conversion from V_quant_vec to V_vec.
+           v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
+ #else
+           assert(false);
+ #endif
+         } else {
+           v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
+         }
+         if (block_idx == num_context_blocks - 1) {
+           // NOTE: When v_vec contains the tokens that are out of the context,
+           // we should explicitly zero out the values since they may contain NaNs.
+           scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
+ #pragma unroll
+           for (int j = 0; j < V_VEC_SIZE; j++) {
+             v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
+           }
+         }
+         accs[i] += dot(logits_vec, v_vec);
+       }
+     }
+   }
+ 
+   // Perform reduction within each warp.
+ #pragma unroll
+   for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+     float acc = accs[i];
+ #pragma unroll
+     for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
+       acc += APHRODITE_SHFL_XOR_SYNC(acc, mask);
+     }
+     accs[i] = acc;
+   }
+ 
+   // NOTE: A barrier is required because the shared memory space for logits
+   // is reused for the output.
+   __syncthreads();
+ 
+   // Perform reduction across warps.
+   float* out_smem = reinterpret_cast<float*>(shared_mem);
+ #pragma unroll
+   for (int i = NUM_WARPS; i > 1; i /= 2) {
+     int mid = i / 2;
+     // Upper warps write to shared memory.
+     if (warp_idx >= mid && warp_idx < i) {
+       float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
+ #pragma unroll
+       for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+         const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+         if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+           dst[row_idx] = accs[i];
+         }
+       }
+     }
+     __syncthreads();
+ 
+     // Lower warps update the output.
+     if (warp_idx < mid) {
+       const float* src = &out_smem[warp_idx * HEAD_SIZE];
+ #pragma unroll
+       for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+         const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+         if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+           accs[i] += src[row_idx];
+         }
+       }
+     }
+     __syncthreads();
+   }
+ 
+   // Write the final output.
+   if (warp_idx == 0) {
+     scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                             + head_idx * max_num_partitions * HEAD_SIZE
+                             + partition_idx * HEAD_SIZE;
+ #pragma unroll
+     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+       if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+         from_float(*(out_ptr + row_idx), accs[i]);
+       }
+     }
+   }
+ }
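
Everything between the QK loop and the value aggregation above is a max-subtracted softmax evaluated in FP32, with a 1e-6 guard on the denominator. The same arithmetic, written as a host-side sketch with the warp and partition bookkeeping stripped away (function name is illustrative):

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <vector>

// Host-side sketch of the kernel's softmax: subtract the block-wide max
// before exponentiating (for numerical stability), then normalize with
// the same 1e-6 guard against a zero sum.
static void stable_softmax(std::vector<float>& logits) {
  float qk_max = -FLT_MAX;
  for (float v : logits) qk_max = std::max(qk_max, v);
  float exp_sum = 0.f;
  for (float& v : logits) {
    v = std::exp(v - qk_max);
    exp_sum += v;
  }
  const float inv_sum = 1.f / (exp_sum + 1e-6f);
  for (float& v : logits) v *= inv_sum;
}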
+ 
+ // Grid: (num_heads, num_seqs, 1).
+ template<
+   typename scalar_t,
+   typename cache_t,
+   int HEAD_SIZE,
+   int BLOCK_SIZE,
+   int NUM_THREADS,
+   bool IS_FP8_E5M2_KV_CACHE>
+ __global__ void paged_attention_v1_kernel(
+   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
+   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+   const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+   const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
+   const int num_kv_heads,                 // [num_heads]
+   const float scale,
+   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+   const int* __restrict__ context_lens,   // [num_seqs]
+   const int max_num_blocks_per_seq,
+   const float* __restrict__ alibi_slopes, // [num_heads]
+   const int q_stride,
+   const int kv_block_stride,
+   const int kv_head_stride) {
+   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
+     /* exp_sums */ nullptr, /* max_logits */ nullptr,
+     out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
+     max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
+ }
+ 
+ // Grid: (num_heads, num_seqs, max_num_partitions).
+ template<
+   typename scalar_t,
+   typename cache_t,
+   int HEAD_SIZE,
+   int BLOCK_SIZE,
+   int NUM_THREADS,
+   bool IS_FP8_E5M2_KV_CACHE,
+   int PARTITION_SIZE>
+ __global__ void paged_attention_v2_kernel(
+   float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
+   float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
+   scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
+   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+   const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+   const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
+   const int num_kv_heads,                 // [num_heads]
+   const float scale,
+   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+   const int* __restrict__ context_lens,   // [num_seqs]
+   const int max_num_blocks_per_seq,
+   const float* __restrict__ alibi_slopes, // [num_heads]
+   const int q_stride,
+   const int kv_block_stride,
+   const int kv_head_stride) {
+   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
+     exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
+     block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
+     q_stride, kv_block_stride, kv_head_stride);
+ }
+ 
+ // Grid: (num_heads, num_seqs).
+ template<
+   typename scalar_t,
+   int HEAD_SIZE,
+   int NUM_THREADS,
+   int PARTITION_SIZE>
+ __global__ void paged_attention_v2_reduce_kernel(
+   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
+   const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
+   const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
+   const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
+   const int* __restrict__ context_lens,   // [num_seqs]
+   const int max_num_partitions) {
+   const int num_heads = gridDim.x;
+   const int head_idx = blockIdx.x;
+   const int seq_idx = blockIdx.y;
+   const int context_len = context_lens[seq_idx];
+   const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+   if (num_partitions == 1) {
+     // No need to reduce. Only copy tmp_out to out.
+     scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+     const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                                           + head_idx * max_num_partitions * HEAD_SIZE;
+     for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
+       out_ptr[i] = tmp_out_ptr[i];
+     }
+     // Terminate the thread block.
+     return;
+   }
+ 
+   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+   const int warp_idx = threadIdx.x / WARP_SIZE;
+   const int lane = threadIdx.x % WARP_SIZE;
+ 
+   // Size: 2 * num_partitions.
+   extern __shared__ char shared_mem[];
+   // Workspace for reduction.
+   __shared__ float red_smem[2 * NUM_WARPS];
+ 
+   // Load max logits to shared memory.
+   float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
+   const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+                                            + head_idx * max_num_partitions;
+   float max_logit = -FLT_MAX;
+   for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+     const float l = max_logits_ptr[i];
+     shared_max_logits[i] = l;
+     max_logit = fmaxf(max_logit, l);
+   }
+   __syncthreads();
+ 
+   // Get the global max logit.
+   // Reduce within the warp.
+ #pragma unroll
+   for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+     max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
+   }
+   if (lane == 0) {
+     red_smem[warp_idx] = max_logit;
+   }
+   __syncthreads();
+   // Reduce across warps.
+   max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+ #pragma unroll
+   for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+     max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
+   }
+   // Broadcast the max value to all threads.
+   max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);
+ 
+   // Load rescaled exp sums to shared memory.
+   float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+   const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+                                        + head_idx * max_num_partitions;
+   float global_exp_sum = 0.0f;
+   for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+     float l = shared_max_logits[i];
+     float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
+     global_exp_sum += rescaled_exp_sum;
+     shared_exp_sums[i] = rescaled_exp_sum;
+   }
+   __syncthreads();
+   global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
+   const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
+ 
+   // Aggregate tmp_out to out.
+   const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                                         + head_idx * max_num_partitions * HEAD_SIZE;
+   scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+ #pragma unroll
+   for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
+     float acc = 0.0f;
+     for (int j = 0; j < num_partitions; ++j) {
+       acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
+     }
+     from_float(out_ptr[i], acc);
+   }
+ }
+ 
+ } // namespace aphrodite
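
With partitioning enabled, each partition leaves behind its max logit, its exp sum, and an output normalized by its own local softmax; the reduce kernel rescales those pieces to a common max so the combined result equals one global softmax over the whole context. A host-side sketch of that combination (the Partition struct and function name are illustrative, not types from the kernels):

#include <algorithm>
#include <cfloat>
#include <cmath>
#include <vector>

struct Partition {
  float max_logit;             // max score seen inside the partition
  float exp_sum;               // sum of exp(score - max_logit) in the partition
  std::vector<float> tmp_out;  // output normalized by the partition's own softmax [head_size]
};

// Host-side sketch of paged_attention_v2_reduce_kernel: rescale each
// partition's exp sum to the global max, then blend the partial outputs
// with the rescaled weights.
static std::vector<float> reduce_partitions(const std::vector<Partition>& parts) {
  float global_max = -FLT_MAX;
  for (const auto& p : parts) global_max = std::max(global_max, p.max_logit);

  std::vector<float> weights(parts.size());
  float global_exp_sum = 0.f;
  for (size_t j = 0; j < parts.size(); ++j) {
    weights[j] = parts[j].exp_sum * std::exp(parts[j].max_logit - global_max);
    global_exp_sum += weights[j];
  }
  const float inv = 1.f / (global_exp_sum + 1e-6f);

  std::vector<float> out(parts.front().tmp_out.size(), 0.f);
  for (size_t j = 0; j < parts.size(); ++j)
    for (size_t i = 0; i < out.size(); ++i)
      out[i] += parts[j].tmp_out[i] * weights[j] * inv;
  return out;
}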
+ 
+ #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                  \
+   APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                       \
+     ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,   \
+       IS_FP8_E5M2_KV_CACHE>), shared_mem_size);                                               \
+   aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,             \
+   IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>(                            \
+     out_ptr,                                                                                  \
+     query_ptr,                                                                                \
+     key_cache_ptr,                                                                            \
+     value_cache_ptr,                                                                          \
+     num_kv_heads,                                                                             \
+     scale,                                                                                    \
+     block_tables_ptr,                                                                         \
+     context_lens_ptr,                                                                         \
+     max_num_blocks_per_seq,                                                                   \
+     alibi_slopes_ptr,                                                                         \
+     q_stride,                                                                                 \
+     kv_block_stride,                                                                          \
+     kv_head_stride);
+ 
+ // TODO: Tune NUM_THREADS.
+ template<
+   typename T,
+   typename CACHE_T,
+   int BLOCK_SIZE,
+   bool IS_FP8_E5M2_KV_CACHE,
+   int NUM_THREADS = 128>
+ void paged_attention_v1_launcher(
+   torch::Tensor& out,
+   torch::Tensor& query,
+   torch::Tensor& key_cache,
+   torch::Tensor& value_cache,
+   int num_kv_heads,
+   float scale,
+   torch::Tensor& block_tables,
+   torch::Tensor& context_lens,
+   int max_context_len,
+   const c10::optional<torch::Tensor>& alibi_slopes) {
+   int num_seqs = query.size(0);
+   int num_heads = query.size(1);
+   int head_size = query.size(2);
+   int max_num_blocks_per_seq = block_tables.size(1);
+   int q_stride = query.stride(0);
+   int kv_block_stride = key_cache.stride(0);
+   int kv_head_stride = key_cache.stride(1);
+ 
+   int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+   assert(head_size % thread_group_size == 0);
+ 
+   // NOTE: alibi_slopes is optional.
+   const float* alibi_slopes_ptr = alibi_slopes ?
+     reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+     : nullptr;
+ 
+   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+   CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+   CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
+   int* block_tables_ptr = block_tables.data_ptr<int>();
+   int* context_lens_ptr = context_lens.data_ptr<int>();
+ 
+   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+   int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
+   int logits_size = padded_max_context_len * sizeof(float);
+   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+   // Python-side check in aphrodite.task_handler.worker._check_if_can_support_max_seq_len
+   // Keep that in sync with the logic here!
+   int shared_mem_size = std::max(logits_size, outputs_size);
+ 
+   dim3 grid(num_heads, num_seqs, 1);
+   dim3 block(NUM_THREADS);
+   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+   switch (head_size) {
+     // NOTE: To reduce the compilation time, we only compile for the
+     // head sizes that we use in the model. However, we can easily extend this
+     // to support any head size which is a multiple of 16.
+     case 64:
+       LAUNCH_PAGED_ATTENTION_V1(64);
+       break;
+     case 80:
+       LAUNCH_PAGED_ATTENTION_V1(80);
+       break;
+     case 96:
+       LAUNCH_PAGED_ATTENTION_V1(96);
+       break;
+     case 112:
+       LAUNCH_PAGED_ATTENTION_V1(112);
+       break;
+     case 128:
+       LAUNCH_PAGED_ATTENTION_V1(128);
+       break;
+     case 256:
+       LAUNCH_PAGED_ATTENTION_V1(256);
+       break;
+     default:
+       TORCH_CHECK(false, "Unsupported head size: ", head_size);
+       break;
+   }
+ }
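
The v1 launcher requests dynamic shared memory equal to the larger of the FP32 logits buffer (one float per padded context token) and the cross-warp output buffer, and the Python-side check referenced above mirrors this bound. A quick arithmetic sketch with example shapes (the numbers are illustrative, not defaults):

#include <algorithm>
#include <cstdio>

// Example shared-memory budget for paged_attention_v1, assuming
// NUM_THREADS = 128, WARP_SIZE = 32, block_size = 16, head_size = 128,
// max_context_len = 4096.
int main() {
  const long num_warps = 128 / 32;                                  // 4
  const long padded_ctx = (4096 + 16 - 1) / 16 * 16;                // 4096
  const long logits_size = padded_ctx * sizeof(float);              // 16384 bytes
  const long outputs_size = (num_warps / 2) * 128 * sizeof(float);  //  1024 bytes
  std::printf("shared_mem_size = %ld bytes\n",
              std::max(logits_size, outputs_size));                 // 16384
  return 0;
}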
+ 
+ #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE)       \
+   paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
+     out,                                                                     \
+     query,                                                                   \
+     key_cache,                                                               \
+     value_cache,                                                             \
+     num_kv_heads,                                                            \
+     scale,                                                                   \
+     block_tables,                                                            \
+     context_lens,                                                            \
+     max_context_len,                                                         \
+     alibi_slopes);
+ 
+ // NOTE: To reduce the compilation time, we omitted block sizes
+ // 1, 2, 4, 64, 128, 256.
+ #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
+   switch (block_size) {                                               \
+     case 8:                                                           \
+       CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE);          \
+       break;                                                          \
+     case 16:                                                          \
+       CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE);         \
+       break;                                                          \
+     case 32:                                                          \
+       CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE);         \
+       break;                                                          \
+     default:                                                          \
+       TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
+       break;                                                          \
+   }
+ 
+ void paged_attention_v1(
+   torch::Tensor& out,             // [num_seqs, num_heads, head_size]
+   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
+   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+   int num_kv_heads,               // [num_heads]
+   float scale,
+   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
+   torch::Tensor& context_lens,    // [num_seqs]
+   int block_size,
+   int max_context_len,
+   const c10::optional<torch::Tensor>& alibi_slopes,
+   const std::string& kv_cache_dtype) {
+   if (kv_cache_dtype == "auto") {
+     if (query.dtype() == at::ScalarType::Float) {
+       CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
+     } else if (query.dtype() == at::ScalarType::Half) {
+       CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
+     } else if (query.dtype() == at::ScalarType::BFloat16) {
+       CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
+     } else {
+       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+     }
+   } else if (kv_cache_dtype == "fp8_e5m2") {
+     if (query.dtype() == at::ScalarType::Float) {
+       CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
+     } else if (query.dtype() == at::ScalarType::Half) {
+       CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
+     } else if (query.dtype() == at::ScalarType::BFloat16) {
+       CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
+     } else {
+       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+     }
+   } else {
+     TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+   }
+ }
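
For reference, a host-side sketch of how this entry point might be exercised with the tensor layouts documented in the comments above. Shapes, head counts, and the context length are illustrative; the forward declaration simply repeats the definition above so the sketch is self-contained.

#include <cmath>
#include <torch/extension.h>

// Forward declaration matching the definition above (normally provided by
// the extension's own headers).
void paged_attention_v1(
  torch::Tensor& out, torch::Tensor& query,
  torch::Tensor& key_cache, torch::Tensor& value_cache,
  int num_kv_heads, float scale,
  torch::Tensor& block_tables, torch::Tensor& context_lens,
  int block_size, int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype);

void example_call() {
  const int num_seqs = 1, num_heads = 32, num_kv_heads = 32;
  const int head_size = 128, block_size = 16, num_blocks = 64;
  const int x = 16 / 2;  // 16 bytes / sizeof(half) for an fp16 cache
  auto fp16 = torch::dtype(torch::kFloat16).device(torch::kCUDA);
  auto i32  = torch::dtype(torch::kInt32).device(torch::kCUDA);

  auto out   = torch::empty({num_seqs, num_heads, head_size}, fp16);
  auto query = torch::randn({num_seqs, num_heads, head_size}, fp16);
  auto key_cache   = torch::randn({num_blocks, num_kv_heads, head_size / x, block_size, x}, fp16);
  auto value_cache = torch::randn({num_blocks, num_kv_heads, head_size, block_size}, fp16);
  auto block_tables = torch::zeros({num_seqs, 4}, i32);  // 4 blocks cover the 42 tokens below
  auto context_lens = torch::full({num_seqs}, 42, i32);

  paged_attention_v1(out, query, key_cache, value_cache, num_kv_heads,
                     1.0f / std::sqrt(static_cast<float>(head_size)),
                     block_tables, context_lens, block_size,
                     /*max_context_len=*/42, c10::nullopt, "auto");
}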
+ 
+ #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \
+   aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,             \
+   IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>                                                       \
+   <<<grid, block, shared_mem_size, stream>>>(                                                 \
+     exp_sums_ptr,                                                                             \
+     max_logits_ptr,                                                                           \
+     tmp_out_ptr,                                                                              \
+     query_ptr,                                                                                \
+     key_cache_ptr,                                                                            \
+     value_cache_ptr,                                                                          \
+     num_kv_heads,                                                                             \
+     scale,                                                                                    \
+     block_tables_ptr,                                                                         \
+     context_lens_ptr,                                                                         \
+     max_num_blocks_per_seq,                                                                   \
+     alibi_slopes_ptr,                                                                         \
+     q_stride,                                                                                 \
+     kv_block_stride,                                                                          \
+     kv_head_stride);                                                                          \
+   aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>           \
+   <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                                   \
+     out_ptr,                                                                                  \
+     exp_sums_ptr,                                                                             \
+     max_logits_ptr,                                                                           \
+     tmp_out_ptr,                                                                              \
+     context_lens_ptr,                                                                         \
+     max_num_partitions);
+ 
+ template<
+   typename T,
+   typename CACHE_T,
+   int BLOCK_SIZE,
+   bool IS_FP8_E5M2_KV_CACHE,
+   int NUM_THREADS = 128,
+   int PARTITION_SIZE = 512>
+ void paged_attention_v2_launcher(
+   torch::Tensor& out,
+   torch::Tensor& exp_sums,
+   torch::Tensor& max_logits,
+   torch::Tensor& tmp_out,
+   torch::Tensor& query,
+   torch::Tensor& key_cache,
+   torch::Tensor& value_cache,
+   int num_kv_heads,
+   float scale,
+   torch::Tensor& block_tables,
+   torch::Tensor& context_lens,
+   int max_context_len,
+   const c10::optional<torch::Tensor>& alibi_slopes) {
+   int num_seqs = query.size(0);
+   int num_heads = query.size(1);
+   int head_size = query.size(2);
+   int max_num_blocks_per_seq = block_tables.size(1);
+   int q_stride = query.stride(0);
+   int kv_block_stride = key_cache.stride(0);
+   int kv_head_stride = key_cache.stride(1);
+ 
+   int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+   assert(head_size % thread_group_size == 0);
+ 
+   // NOTE: alibi_slopes is optional.
+   const float* alibi_slopes_ptr = alibi_slopes ?
+     reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+     : nullptr;
+ 
+   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+   float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
+   float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
+   T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
+   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+   CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+   CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
+   int* block_tables_ptr = block_tables.data_ptr<int>();
+   int* context_lens_ptr = context_lens.data_ptr<int>();
+ 
+   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+   int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
+   int logits_size = PARTITION_SIZE * sizeof(float);
+   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+ 
+   // For paged attention v2 kernel.
+   dim3 grid(num_heads, num_seqs, max_num_partitions);
+   int shared_mem_size = std::max(logits_size, outputs_size);
+   // For paged attention v2 reduce kernel.
+   dim3 reduce_grid(num_heads, num_seqs);
+   int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
+ 
+   dim3 block(NUM_THREADS);
+   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+   switch (head_size) {
+     // NOTE: To reduce the compilation time, we only compile for the
+     // head sizes that we use in the model. However, we can easily extend this
+     // to support any head size which is a multiple of 16.
+     case 64:
+       LAUNCH_PAGED_ATTENTION_V2(64);
+       break;
+     case 80:
+       LAUNCH_PAGED_ATTENTION_V2(80);
+       break;
+     case 96:
+       LAUNCH_PAGED_ATTENTION_V2(96);
+       break;
+     case 112:
+       LAUNCH_PAGED_ATTENTION_V2(112);
+       break;
+     case 128:
+       LAUNCH_PAGED_ATTENTION_V2(128);
+       break;
+     case 256:
+       LAUNCH_PAGED_ATTENTION_V2(256);
+       break;
+     default:
+       TORCH_CHECK(false, "Unsupported head size: ", head_size);
+       break;
+   }
+ }
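
The v2 launcher derives its grid and scratch sizes from the fixed 512-token PARTITION_SIZE. A quick sketch of that arithmetic for an example batch (head count, sequence count, and context length are illustrative):

#include <cstdio>

// Example v2 launch geometry, assuming PARTITION_SIZE = 512, NUM_THREADS = 128,
// num_heads = 32, num_seqs = 8, max_context_len = 4096.
int main() {
  const int partition_size = 512;
  const int max_context_len = 4096;
  const int max_num_partitions = (max_context_len + partition_size - 1) / partition_size;  // 8
  const int logits_size = partition_size * sizeof(float);  // 2048 bytes per thread block
  // 2 floats per partition: one max logit plus one rescaled exp sum.
  const long reduce_shared_mem_size = 2L * max_num_partitions * sizeof(float);  // 64 bytes
  std::printf("grid = (32, 8, %d), logits smem = %d bytes, reduce smem = %ld bytes\n",
              max_num_partitions, logits_size, reduce_shared_mem_size);
  return 0;
}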
+ 
+ #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE)           \
+   paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>(     \
+     out,                                                                         \
+     exp_sums,                                                                    \
+     max_logits,                                                                  \
+     tmp_out,                                                                     \
+     query,                                                                       \
+     key_cache,                                                                   \
+     value_cache,                                                                 \
+     num_kv_heads,                                                                \
+     scale,                                                                       \
+     block_tables,                                                                \
+     context_lens,                                                                \
+     max_context_len,                                                             \
+     alibi_slopes);
+ 
+ // NOTE: To reduce the compilation time, we omitted block sizes
+ // 1, 2, 4, 64, 128, 256.
+ #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE)       \
+   switch (block_size) {                                                     \
+     case 8:                                                                 \
+       CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE);                \
+       break;                                                                \
+     case 16:                                                                \
+       CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE);               \
+       break;                                                                \
+     case 32:                                                                \
+       CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE);               \
+       break;                                                                \
+     default:                                                                \
+       TORCH_CHECK(false, "Unsupported block size: ", block_size);           \
+       break;                                                                \
+   }
+ 
+ void paged_attention_v2(
+   torch::Tensor& out,             // [num_seqs, num_heads, head_size]
+   torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
+   torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
+   torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
+   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
+   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+   int num_kv_heads,               // [num_heads]
+   float scale,
+   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
+   torch::Tensor& context_lens,    // [num_seqs]
+   int block_size,
+   int max_context_len,
+   const c10::optional<torch::Tensor>& alibi_slopes,
+   const std::string& kv_cache_dtype) {
+   if (kv_cache_dtype == "auto") {
+     if (query.dtype() == at::ScalarType::Float) {
+       CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
+     } else if (query.dtype() == at::ScalarType::Half) {
+       CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
+     } else if (query.dtype() == at::ScalarType::BFloat16) {
+       CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
+     } else {
+       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+     }
+   } else if (kv_cache_dtype == "fp8_e5m2") {
+     if (query.dtype() == at::ScalarType::Float) {
+       CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
+     } else if (query.dtype() == at::ScalarType::Half) {
+       CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
+     } else if (query.dtype() == at::ScalarType::BFloat16) {
+       CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
+     } else {
+       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+     }
+   } else {
+     TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+   }
+ }
+ 
+ #undef WARP_SIZE
+ #undef MAX
+ #undef MIN
+ #undef DIVIDE_ROUND_UP
+ 

+ 255 - 262
kernels/attention/dtype_float32.cuh

@@ -16,265 +16,258 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#pragma once
-
-#include "attention_generic.cuh"
-
-#include <stdint.h>
-
-namespace aphrodite {
-
-// Define custom FP32 vector data types.
-struct Float4_ {
-  float2 x;
-  float2 y;
-};
-
-struct Float8_ {
-  float2 x;
-  float2 y;
-  float2 z;
-  float2 w;
-};
-
-// FP32 vector types for Q, K, V.
-template<>
-struct Vec<float, 1> {
-  using Type = float;
-};
-template<>
-struct Vec<float, 2> {
-  using Type = float2;
-};
-template<>
-struct Vec<float, 4> {
-  using Type = float4;
-};
-
-// FP32 accumulator vector types corresponding to Vec.
-template<>
-struct FloatVec<float> {
-  using Type = float;
-};
-template<>
-struct FloatVec<float2> {
-  using Type = float2;
-};
-template<>
-struct FloatVec<float4> {
-  using Type = float4;
-};
-
-// Vector addition.
-inline __device__ float add(float a, float b) {
-  return a + b;
-}
-
-inline __device__ float2 add(float2 a, float2 b) {
-  float2 c;
-  c.x = add(a.x, b.x);
-  c.y = add(a.y, b.y);
-  return c;
-}
-
-inline __device__ float4 add(float4 a, float4 b) {
-  float4 c;
-  c.x = add(a.x, b.x);
-  c.y = add(a.y, b.y);
-  c.z = add(a.z, b.z);
-  c.w = add(a.w, b.w);
-  return c;
-}
-
-inline __device__ Float4_ add(Float4_ a, Float4_ b) {
-  Float4_ c;
-  c.x = add(a.x, b.x);
-  c.y = add(a.y, b.y);
-  return c;
-}
-
-// Vector multiplication.
-template<>
-inline __device__ float mul<float, float>(float a, float b) {
-  return a * b;
-}
-
-template<>
-inline __device__ float2 mul(float2 a, float2 b) {
-  float2 c;
-  c.x = a.x * b.x;
-  c.y = a.y * b.y;
-  return c;
-}
-
-template<>
-inline __device__ float2 mul(float a, float2 b) {
-  float2 c;
-  c.x = a * b.x;
-  c.y = a * b.y;
-  return c;
-}
-
-template<>
-inline __device__ float4 mul(float4 a, float4 b) {
-  float4 c;
-  c.x = a.x * b.x;
-  c.y = a.y * b.y;
-  c.z = a.z * b.z;
-  c.w = a.w * b.w;
-  return c;
-}
-
-template<>
-inline __device__ float4 mul(float a, float4 b) {
-  float4 c;
-  c.x = a * b.x;
-  c.y = a * b.y;
-  c.z = a * b.z;
-  c.w = a * b.w;
-  return c;
-}
-
-// Vector fused multiply-add.
-inline __device__ float fma(float a, float b, float c) {
-  return a * b + c;
-}
-
-inline __device__ float2 fma(float2 a, float2 b, float2 c) {
-  float2 d;
-  d.x = fma(a.x, b.x, c.x);
-  d.y = fma(a.y, b.y, c.y);
-  return d;
-}
-
-inline __device__ float2 fma(float a, float2 b, float2 c) {
-  float2 d;
-  d.x = fma(a, b.x, c.x);
-  d.y = fma(a, b.y, c.y);
-  return d;
-}
-
-inline __device__ float4 fma(float4 a, float4 b, float4 c) {
-  float4 d;
-  d.x = fma(a.x, b.x, c.x);
-  d.y = fma(a.y, b.y, c.y);
-  d.z = fma(a.z, b.z, c.z);
-  d.w = fma(a.w, b.w, c.w);
-  return d;
-}
-
-inline __device__ float4 fma(float a, float4 b, float4 c) {
-  float4 d;
-  d.x = fma(a, b.x, c.x);
-  d.y = fma(a, b.y, c.y);
-  d.z = fma(a, b.z, c.z);
-  d.w = fma(a, b.w, c.w);
-  return d;
-}
-
-inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
-  Float4_ d;
-  d.x = fma(a, b.x, c.x);
-  d.y = fma(a, b.y, c.y);
-  return d;
-}
-
-inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
-  Float8_ d;
-  d.x = fma(a, b.x, c.x);
-  d.y = fma(a, b.y, c.y);
-  d.z = fma(a, b.z, c.z);
-  d.w = fma(a, b.w, c.w);
-  return d;
-}
-
-// Vector sum.
-template<>
-inline __device__ float sum(float v) {
-  return v;
-}
-
-template<>
-inline __device__ float sum(float2 v) {
-  return v.x + v.y;
-}
-
-template<>
-inline __device__ float sum(float4 v) {
-  return v.x + v.y + v.z + v.w;
-}
-
-template<>
-inline __device__ float sum(Float4_ v) {
-  return v.x.x + v.x.y + v.y.x + v.y.y;
-}
-
-template<>
-inline __device__ float sum(Float8_ v) {
-  return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
-}
-
-// Vector dot product.
-inline __device__ float dot(float a, float b) {
-  return a * b;
-}
-
-inline __device__ float dot(float2 a, float2 b) {
-  float2 c = mul<float2, float2, float2>(a, b);
-  return c.x + c.y;
-}
-
-inline __device__ float dot(Float4_ a, Float4_ b) {
-  float2 acc = mul<float2, float2, float2>(a.x, b.x);
-  acc = fma(a.y, b.y, acc);
-  return acc.x + acc.y;
-}
-
-inline __device__ float dot(Float8_ a, Float8_ b) {
-  float2 acc = mul<float2, float2, float2>(a.x, b.x);
-  acc = fma(a.y, b.y, acc);
-  acc = fma(a.z, b.z, acc);
-  acc = fma(a.w, b.w, acc);
-  return acc.x + acc.y;
-}
-
-// From float to float.
-inline __device__ void from_float(float& dst, float src) {
-  dst = src;
-}
-
-inline __device__ void from_float(float2& dst, float2 src) {
-  dst = src;
-}
-
-inline __device__ void from_float(float4& dst, float4 src) {
-  dst = src;
-}
-
-// From float to float.
-inline __device__ float to_float(float u) {
-  return u;
-}
-
-inline __device__ float2 to_float(float2 u) {
-  return u;
-}
-
-inline __device__ float4 to_float(float4 u) {
-  return u;
-}
-
-inline __device__ Float4_ to_float(Float4_ u) {
-  return u;
-}
-
-inline __device__ Float8_ to_float(Float8_ u) {
-  return u;
-}
-
-// Zero-out a variable.
-inline __device__ void zero(float& dst) {
-  dst = 0.f;
-}
-
-} // namespace aphrodite
+ #pragma once
+
+ #include "attention_generic.cuh"
+ 
+ #include <stdint.h>
+ 
+ namespace aphrodite {
+ 
+ // Define custom FP32 vector data types.
+ struct Float4_ {
+   float2 x;
+   float2 y;
+ };
+ 
+ struct Float8_ {
+   float2 x;
+   float2 y;
+   float2 z;
+   float2 w;
+ };
+ 
+ // FP32 vector types for Q, K, V.
+ template<>
+ struct Vec<float, 1> {
+   using Type = float;
+ };
+ template<>
+ struct Vec<float, 2> {
+   using Type = float2;
+ };
+ template<>
+ struct Vec<float, 4> {
+   using Type = float4;
+ };
+ 
+ // FP32 accumulator vector types corresponding to Vec.
+ template<>
+ struct FloatVec<float> {
+   using Type = float;
+ };
+ template<>
+ struct FloatVec<float2> {
+   using Type = float2;
+ };
+ template<>
+ struct FloatVec<float4> {
+   using Type = float4;
+ };
+ 
+ // Vector addition.
+ inline __device__ float add(float a, float b) {
+   return a + b;
+ }
+ 
+ inline __device__ float2 add(float2 a, float2 b) {
+   float2 c;
+   c.x = add(a.x, b.x);
+   c.y = add(a.y, b.y);
+   return c;
+ }
+ 
+ inline __device__ float4 add(float4 a, float4 b) {
+   float4 c;
+   c.x = add(a.x, b.x);
+   c.y = add(a.y, b.y);
+   c.z = add(a.z, b.z);
+   c.w = add(a.w, b.w);
+   return c;
+ }
+ 
+ // Vector multiplication.
+ template<>
+ inline __device__ float mul<float, float>(float a, float b) {
+   return a * b;
+ }
+ 
+ template<>
+ inline __device__ float2 mul(float2 a, float2 b) {
+   float2 c;
+   c.x = a.x * b.x;
+   c.y = a.y * b.y;
+   return c;
+ }
+ 
+ template<>
+ inline __device__ float2 mul(float a, float2 b) {
+   float2 c;
+   c.x = a * b.x;
+   c.y = a * b.y;
+   return c;
+ }
+ 
+ template<>
+ inline __device__ float4 mul(float4 a, float4 b) {
+   float4 c;
+   c.x = a.x * b.x;
+   c.y = a.y * b.y;
+   c.z = a.z * b.z;
+   c.w = a.w * b.w;
+   return c;
+ }
+ 
+ template<>
+ inline __device__ float4 mul(float a, float4 b) {
+   float4 c;
+   c.x = a * b.x;
+   c.y = a * b.y;
+   c.z = a * b.z;
+   c.w = a * b.w;
+   return c;
+ }
+ 
+ // Vector fused multiply-add.
+ inline __device__ float fma(float a, float b, float c) {
+   return a * b + c;
+ }
+ 
+ inline __device__ float2 fma(float2 a, float2 b, float2 c) {
+   float2 d;
+   d.x = fma(a.x, b.x, c.x);
+   d.y = fma(a.y, b.y, c.y);
+   return d;
+ }
+ 
+ inline __device__ float2 fma(float a, float2 b, float2 c) {
+   float2 d;
+   d.x = fma(a, b.x, c.x);
+   d.y = fma(a, b.y, c.y);
+   return d;
+ }
+ 
+ inline __device__ float4 fma(float4 a, float4 b, float4 c) {
+   float4 d;
+   d.x = fma(a.x, b.x, c.x);
+   d.y = fma(a.y, b.y, c.y);
+   d.z = fma(a.z, b.z, c.z);
+   d.w = fma(a.w, b.w, c.w);
+   return d;
+ }
+ 
+ inline __device__ float4 fma(float a, float4 b, float4 c) {
+   float4 d;
+   d.x = fma(a, b.x, c.x);
+   d.y = fma(a, b.y, c.y);
+   d.z = fma(a, b.z, c.z);
+   d.w = fma(a, b.w, c.w);
+   return d;
+ }
+ 
+ inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
+   Float4_ d;
+   d.x = fma(a, b.x, c.x);
+   d.y = fma(a, b.y, c.y);
+   return d;
+ }
+ 
+ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
+   Float8_ d;
+   d.x = fma(a, b.x, c.x);
+   d.y = fma(a, b.y, c.y);
+   d.z = fma(a, b.z, c.z);
+   d.w = fma(a, b.w, c.w);
+   return d;
+ }
+ 
+ // Vector sum.
+ template<>
+ inline __device__ float sum(float v) {
+   return v;
+ }
+ 
+ template<>
+ inline __device__ float sum(float2 v) {
+   return v.x + v.y;
+ }
+ 
+ template<>
+ inline __device__ float sum(float4 v) {
+   return v.x + v.y + v.z + v.w;
+ }
+ 
+ template<>
+ inline __device__ float sum(Float4_ v) {
+   return v.x.x + v.x.y + v.y.x + v.y.y;
+ }
+ 
+ template<>
+ inline __device__ float sum(Float8_ v) {
+   return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
+ }
+ 
+ // Vector dot product.
+ inline __device__ float dot(float a, float b) {
+   return a * b;
+ }
+ 
+ inline __device__ float dot(float2 a, float2 b) {
+   float2 c = mul<float2, float2, float2>(a, b);
+   return c.x + c.y;
+ }
+ 
+ inline __device__ float dot(Float4_ a, Float4_ b) {
+   float2 acc = mul<float2, float2, float2>(a.x, b.x);
+   acc = fma(a.y, b.y, acc);
+   return acc.x + acc.y;
+ }
+ 
+ inline __device__ float dot(Float8_ a, Float8_ b) {
+   float2 acc = mul<float2, float2, float2>(a.x, b.x);
+   acc = fma(a.y, b.y, acc);
+   acc = fma(a.z, b.z, acc);
+   acc = fma(a.w, b.w, acc);
+   return acc.x + acc.y;
+ }
+ 
+ // From float to float.
+ inline __device__ void from_float(float& dst, float src) {
+   dst = src;
+ }
+ 
+ inline __device__ void from_float(float2& dst, float2 src) {
+   dst = src;
+ }
+ 
+ inline __device__ void from_float(float4& dst, float4 src) {
+   dst = src;
+ }
+ 
+ // From float to float.
+ inline __device__ float to_float(float u) {
+   return u;
+ }
+ 
+ inline __device__ float2 to_float(float2 u) {
+   return u;
+ }
+ 
+ inline __device__ float4 to_float(float4 u) {
+   return u;
+ }
+ 
+ inline __device__ Float4_ to_float(Float4_ u) {
+   return u;
+ }
+ 
+ inline __device__ Float8_ to_float(Float8_ u) {
+   return u;
+ }
+ 
+ // Zero-out a variable.
+ inline __device__ void zero(float& dst) {
+   dst = 0.f;
+ }
+ 
+ } // namespace aphrodite
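
All of these helpers keep the reduction in FP32: dot products over the wider Float4_/Float8_ types accumulate element pairs with fma and only collapse to a scalar at the end. A scalar host analogue of the eight-lane case (function name is illustrative):

// Host analogue of dot(Float8_, Float8_): multiply-accumulate eight floats
// and return the FP32 scalar the attention kernel adds into its accumulator.
static float dot8(const float a[8], const float b[8]) {
  float acc = 0.f;
  for (int i = 0; i < 8; ++i) {
    acc += a[i] * b[i];
  }
  return acc;
}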

+ 1 - 0
kernels/backup/README

@@ -0,0 +1 @@
+Backup of the attention and cache kernels for the INT8 KV cache. They will be restored soon.

+ 8 - 0
kernels/backup/attention_dtypes.h

@@ -0,0 +1,8 @@
+#pragma once
+
+#include "attention_generic.cuh"
+#include "dtype_float16.cuh"
+#include "dtype_float32.cuh"
+#include "dtype_bfloat16.cuh"
+#include "dtype_fp8_e5m2.cuh"
+#include "dtype_int8.cuh"

+ 1032 - 0
kernels/backup/attention_kernels.cu

@@ -0,0 +1,1032 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The PygmalionAI team.
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef USE_ROCM
+#include <hip/hip_runtime.h>
+#endif
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include "attention_dtypes.h"
+#include "attention_utils.cuh"
+#include "../quantization/int8_kvcache/quant_utils.cuh"
+#ifdef ENABLE_FP8_E5M2
+#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
+#endif
+
+#include <algorithm>
+
+#ifndef USE_ROCM
+#define WARP_SIZE 32
+#else
+#define WARP_SIZE warpSize
+#endif
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
+
+enum kv_cache_dtype {
+  AUTO,
+#ifdef ENABLE_FP8_E5M2
+  FP8_E5M2,
+#endif
+  INT8};
+
+namespace aphrodite {
+
+// Utility function for attention softmax.
+template<int NUM_WARPS>
+inline __device__ float block_sum(float* red_smem, float sum) {
+  // Decompose the thread index into warp / lane.
+  int warp = threadIdx.x / WARP_SIZE;
+  int lane = threadIdx.x % WARP_SIZE;
+
+  // Compute the sum per warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
+  }
+
+  // Warp leaders store the data to shared memory.
+  if (lane == 0) {
+    red_smem[warp] = sum;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The warps compute the final sums.
+  if (lane < NUM_WARPS) {
+    sum = red_smem[lane];
+  }
+
+  // Parallel reduction inside the warp.
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
+  }
+
+  // Broadcast to other threads.
+  return APHRODITE_SHFL_SYNC(sum, 0);
+}
+
+// TODO: Merge the last two dimensions of the grid.
+// Grid: (num_heads, num_seqs, max_num_partitions).
+template<
+  typename scalar_t,
+  typename cache_t,
+  int HEAD_SIZE,
+  int BLOCK_SIZE,
+  int NUM_THREADS,
+  kv_cache_dtype KV_CACHE_DTYPE,
+  int PARTITION_SIZE = 0> // Zero means no partitioning.
+__device__ void paged_attention_kernel(
+  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
+  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
+  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
+  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
+  const int num_kv_heads,                 // [num_heads]
+  const float scale,
+  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ context_lens,   // [num_seqs]
+  const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes, // [num_heads]
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f) {
+  const int seq_idx = blockIdx.y;
+  const int partition_idx = blockIdx.z;
+  const int max_num_partitions = gridDim.z;
+  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
+  const int context_len = context_lens[seq_idx];
+  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
+    // No work to do. Terminate the thread block.
+    return;
+  }
+
+  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
+  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
+
+  // [start_block_idx, end_block_idx) is the range of blocks to process.
+  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
+  const int num_blocks = end_block_idx - start_block_idx;
+
+  // [start_token_idx, end_token_idx) is the range of tokens to process.
+  const int start_token_idx = start_block_idx * BLOCK_SIZE;
+  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
+  const int num_tokens = end_token_idx - start_token_idx;
+
+  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
+  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
+  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int thread_idx = threadIdx.x;
+  const int warp_idx = thread_idx / WARP_SIZE;
+  const int lane = thread_idx % WARP_SIZE;
+
+  const int head_idx = blockIdx.x;
+  const int num_heads = gridDim.x;
+  const int num_queries_per_kv = num_heads / num_kv_heads;
+  const int kv_head_idx = head_idx / num_queries_per_kv;
+  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
+
+  // A vector type to store a part of a key or a query.
+  // The vector size is configured in such a way that the threads in a thread group
+  // fetch or compute 16 bytes at a time.
+  // For example, if the size of a thread group is 4 and the data type is half,
+  // then the vector size is 16 / (4 * sizeof(half)) == 2.
+  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
+  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
+
+  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
+  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
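+  // Worked example (illustrative values, half precision, THREAD_GROUP_SIZE = 2,
+  // HEAD_SIZE = 128): VEC_SIZE = 16 / (2 * 2) = 4, NUM_ELEMS_PER_THREAD = 128 / 2 = 64
+  // and NUM_VECS_PER_THREAD = 64 / 4 = 16.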
+
+  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
+  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
+
+  // Load the query to registers.
+  // Each thread in a thread group has a different part of the query.
+  // For example, if the thread group size is 4, then the first thread in the group
+  // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
+  // th vectors of the query, and so on.
+  // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
+  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+#pragma unroll
+  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
+    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
+    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+  }
+  __syncthreads(); // TODO: possible speedup if this is replaced with a memory fence right before we use q_vecs
+
+  // Memory planning.
+  extern __shared__ char shared_mem[];
+  // NOTE: We use FP32 for the softmax logits for better accuracy.
+  float* logits = reinterpret_cast<float*>(shared_mem);
+  // Workspace for reduction.
+  __shared__ float red_smem[2 * NUM_WARPS];
+
+  // x == THREAD_GROUP_SIZE * VEC_SIZE when the cache dtype has the same size as the
+  // compute dtype; with a quantized cache, sizeof(cache_t) is smaller and x is larger.
+  // Each thread group fetches x elements from the key at a time.
+  constexpr int x = 16 / sizeof(cache_t);
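+  // e.g. x = 8 for a 2-byte (half/bfloat16) cache and x = 16 for a 1-byte
+  // (int8 or fp8_e5m2) cache.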
+  float qk_max = -FLT_MAX;
+
+  // Iterate over the key blocks.
+  // Each warp fetches a block of keys for each iteration.
+  // Each thread group in a warp fetches a key from the block, and computes
+  // dot product with the query.
+  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+    // NOTE: The block number is stored in int32. However, we cast it to int64
+    // because int32 can lead to overflow when this variable is multiplied by large numbers
+    // (e.g., kv_block_stride).
+    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+
+    // Load a key to registers.
+    // Each thread in a thread group has a different part of the key.
+    // For example, if the thread group size is 4, then the first thread in the group
+    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
+    // vectors of the key, and so on.
+    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+      K_vec k_vecs[NUM_VECS_PER_THREAD];
+
+#pragma unroll
+      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
+        const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+                                       + kv_head_idx * kv_head_stride
+                                       + physical_block_offset * x;
+        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
+        const int offset1 = (vec_idx * VEC_SIZE) / x;
+        const int offset2 = (vec_idx * VEC_SIZE) % x;
+        if constexpr (KV_CACHE_DTYPE == INT8) {
+          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          using Dequant_vec = typename FloatVec<Quant_vec>::Type;
+          Dequant_vec k_vec_dequant = int8::dequant(k_vec_quant, k_scale, k_zp);
+          k_vecs[j] = int8::vec_conversion<K_vec, Dequant_vec>(k_vec_dequant);
+#ifdef ENABLE_FP8_E5M2
+        } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
+          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          // Vector conversion from Quant_vec to K_vec.
+          k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
+#endif
+        } else {
+          k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+        }
+      }
+
+      // Compute dot product.
+      // This includes a reduction across the threads in the same thread group.
+      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
+      // Add the ALiBi bias if slopes are given.
+      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
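+      // With the usual positive ALiBi slopes the bias is 0 for the most recent key
+      // token (token_idx == context_len - 1) and grows more negative for older tokens.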
+
+      if (thread_group_offset == 0) {
+        // Store the partial reductions to shared memory.
+        // NOTE: It is required to zero out the masked logits.
+        const bool mask = token_idx >= context_len;
+        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
+        // Update the max value.
+        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
+      }
+    }
+  }
+
+  // Perform reduction across the threads in the same warp to get the
+  // max qk value for each "warp" (not across the thread block yet).
+  // The 0-th thread of each thread group already has its max qk value.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
+    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = qk_max;
+  }
+  __syncthreads();
+
+  // TODO: Refactor this part.
+  // Get the max qk value for the sequence.
+  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
+  }
+  // Broadcast the max qk value to all threads.
+  qk_max = APHRODITE_SHFL_SYNC(qk_max, 0);
+
+  // Get the sum of the exp values.
+  float exp_sum = 0.f;
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+    float val = __expf(logits[i] - qk_max);
+    logits[i] = val;
+    exp_sum += val;
+  }
+  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
+
+  // Compute softmax.
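+  // logits[i] currently holds exp(qk_i - qk_max); dividing by (exp_sum + 1e-6) turns
+  // it into a numerically stable softmax, computed in place.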
+  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+    logits[i] *= inv_sum;
+  }
+  __syncthreads();
+
+  // If partitioning is enabled, store the max logit and exp_sum.
+  if (USE_PARTITIONING && thread_idx == 0) {
+    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+                                       + head_idx * max_num_partitions
+                                       + partition_idx;
+    *max_logits_ptr = qk_max;
+    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+                                   + head_idx * max_num_partitions
+                                   + partition_idx;
+    *exp_sums_ptr = exp_sum;
+  }
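+  // These per-partition statistics are consumed by paged_attention_v2_reduce_kernel,
+  // which merges the partitions into the final softmax over the whole context.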
+
+  // Each thread will fetch 16 bytes from the value cache at a time.
+  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
+  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
+  using Float_L_vec = typename FloatVec<L_vec>::Type;
+
+  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
+  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
+  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
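+  // Worked example (illustrative values, half precision, WARP_SIZE = 32,
+  // BLOCK_SIZE = 16, HEAD_SIZE = 128): V_VEC_SIZE = 8, NUM_V_VECS_PER_ROW = 2,
+  // NUM_ROWS_PER_ITER = 16 and NUM_ROWS_PER_THREAD = 8.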
+
+  // NOTE: We use FP32 for the accumulator for better accuracy.
+  float accs[NUM_ROWS_PER_THREAD];
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    accs[i] = 0.f;
+  }
+
+  scalar_t zero_value;
+  zero(zero_value);
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+    // NOTE: The block number is stored in int32. However, we cast it to int64
+    // because int32 can lead to overflow when this variable is multiplied by large numbers
+    // (e.g., kv_block_stride).
+    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
+    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+    L_vec logits_vec;
+    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
+
+    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+                                   + kv_head_idx * kv_head_stride;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE) {
+        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
+        V_vec v_vec;
+        if constexpr (KV_CACHE_DTYPE == INT8) {
+          // dequant and conversion
+          V_quant_vec v_vec_quant = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+          using V_dequant_vec = typename FloatVec<V_quant_vec>::Type;
+          V_dequant_vec v_vec_dequant = int8::dequant(v_vec_quant, v_scale, v_zp);
+          v_vec = int8::vec_conversion<V_vec, V_dequant_vec>(v_vec_dequant);
+#ifdef ENABLE_FP8_E5M2
+        } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
+          V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+          // Vector conversion from V_quant_vec to V_vec.
+          v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
+#endif
+        } else {
+          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
+        }
+        if (block_idx == num_context_blocks - 1) {
+          // NOTE: When v_vec contains the tokens that are out of the context,
+          // we should explicitly zero out the values since they may contain NaNs.
+          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
+#pragma unroll
+          for (int j = 0; j < V_VEC_SIZE; j++) {
+            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
+          }
+        }
+        accs[i] += dot(logits_vec, v_vec);
+      }
+    }
+  }
+
+  // Perform reduction within each warp.
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    float acc = accs[i];
+#pragma unroll
+    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
+      acc += APHRODITE_SHFL_XOR_SYNC(acc, mask);
+    }
+    accs[i] = acc;
+  }
+
+  // NOTE: A barrier is required because the shared memory space for logits
+  // is reused for the output.
+  __syncthreads();
+
+  // Perform reduction across warps.
+  float* out_smem = reinterpret_cast<float*>(shared_mem);
+#pragma unroll
+  for (int i = NUM_WARPS; i > 1; i /= 2) {
+    int mid = i / 2;
+    // Upper warps write to shared memory.
+    if (warp_idx >= mid && warp_idx < i) {
+      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
+#pragma unroll
+      for (int r = 0; r < NUM_ROWS_PER_THREAD; r++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + r * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          dst[row_idx] = accs[r];
+        }
+      }
+    }
+    __syncthreads();
+
+    // Lower warps update the output.
+    if (warp_idx < mid) {
+      const float* src = &out_smem[warp_idx * HEAD_SIZE];
+#pragma unroll
+      for (int r = 0; r < NUM_ROWS_PER_THREAD; r++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + r * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          accs[r] += src[row_idx];
+        }
+      }
+    }
+    __syncthreads();
+  }
+
+  // Write the final output.
+  if (warp_idx == 0) {
+    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                            + head_idx * max_num_partitions * HEAD_SIZE
+                            + partition_idx * HEAD_SIZE;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+        from_float(*(out_ptr + row_idx), accs[i]);
+      }
+    }
+  }
+}
+
+// Grid: (num_heads, num_seqs, 1).
+template<
+  typename scalar_t,
+  typename cache_t,
+  int HEAD_SIZE,
+  int BLOCK_SIZE,
+  int NUM_THREADS,
+  kv_cache_dtype KV_CACHE_DTYPE>
+__global__ void paged_attention_v1_kernel(
+  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
+  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
+  const int num_kv_heads,                 // [num_heads]
+  const float scale,
+  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ context_lens,   // [num_seqs]
+  const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes, // [num_heads]
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE>(
+    /* exp_sums */ nullptr, /* max_logits */ nullptr,
+    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
+    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
+}
+
+// Grid: (num_heads, num_seqs, max_num_partitions).
+template<
+  typename scalar_t,
+  typename cache_t,
+  int HEAD_SIZE,
+  int BLOCK_SIZE,
+  int NUM_THREADS,
+  kv_cache_dtype KV_CACHE_DTYPE,
+  int PARTITION_SIZE>
+__global__ void paged_attention_v2_kernel(
+  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
+  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
+  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
+  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
+  const int num_kv_heads,                 // [num_heads]
+  const float scale,
+  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ context_lens,   // [num_seqs]
+  const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes, // [num_heads]
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE, PARTITION_SIZE>(
+    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
+    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
+    q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
+}
+
+// Grid: (num_heads, num_seqs).
+template<
+  typename scalar_t,
+  int HEAD_SIZE,
+  int NUM_THREADS,
+  int PARTITION_SIZE>
+__global__ void paged_attention_v2_reduce_kernel(
+  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
+  const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
+  const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
+  const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
+  const int* __restrict__ context_lens,   // [num_seqs]
+  const int max_num_partitions) {
+  const int num_heads = gridDim.x;
+  const int head_idx = blockIdx.x;
+  const int seq_idx = blockIdx.y;
+  const int context_len = context_lens[seq_idx];
+  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+  if (num_partitions == 1) {
+    // No need to reduce. Only copy tmp_out to out.
+    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                                          + head_idx * max_num_partitions * HEAD_SIZE;
+    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
+      out_ptr[i] = tmp_out_ptr[i];
+    }
+    // Terminate the thread block.
+    return;
+  }
+
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int warp_idx = threadIdx.x / WARP_SIZE;
+  const int lane = threadIdx.x % WARP_SIZE;
+
+  // Dynamically allocated with size 2 * max_num_partitions * sizeof(float) (see the launcher).
+  extern __shared__ char shared_mem[];
+  // Workspace for reduction.
+  __shared__ float red_smem[2 * NUM_WARPS];
+
+  // Load max logits to shared memory.
+  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
+  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+                                           + head_idx * max_num_partitions;
+  float max_logit = -FLT_MAX;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    const float l = max_logits_ptr[i];
+    shared_max_logits[i] = l;
+    max_logit = fmaxf(max_logit, l);
+  }
+  __syncthreads();
+
+  // Get the global max logit.
+  // Reduce within the warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = max_logit;
+  }
+  __syncthreads();
+  // Reduce across warps.
+  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
+  }
+  // Broadcast the max value to all threads.
+  max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);
+
+  // Load rescaled exp sums to shared memory.
+  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+                                       + head_idx * max_num_partitions;
+  float global_exp_sum = 0.0f;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    float l = shared_max_logits[i];
+    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
+    global_exp_sum += rescaled_exp_sum;
+    shared_exp_sums[i] = rescaled_exp_sum;
+  }
+  __syncthreads();
+  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
+  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
+
+  // Aggregate tmp_out to out.
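+  // Each partition's partial output is re-weighted by its share of the softmax mass:
+  //   out[i] = sum_j tmp_out[j][i] * exp_sums[j] * exp(max_logits[j] - max_logit) / global_exp_sum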
+  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                                        + head_idx * max_num_partitions * HEAD_SIZE;
+  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+#pragma unroll
+  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
+    float acc = 0.0f;
+    for (int j = 0; j < num_partitions; ++j) {
+      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
+    }
+    from_float(out_ptr[i], acc);
+  }
+}
+
+} // namespace aphrodite
+
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                        \
+  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                        \
+    ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,    \
+      KV_CACHE_DTYPE>), shared_mem_size);                                                           \
+  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,              \
+  KV_CACHE_DTYPE><<<grid, block, shared_mem_size, stream>>>(                                        \
+    out_ptr,                                                                                        \
+    query_ptr,                                                                                      \
+    key_cache_ptr,                                                                                  \
+    value_cache_ptr,                                                                                \
+    num_kv_heads,                                                                                   \
+    scale,                                                                                          \
+    block_tables_ptr,                                                                               \
+    context_lens_ptr,                                                                               \
+    max_num_blocks_per_seq,                                                                         \
+    alibi_slopes_ptr,                                                                               \
+    q_stride,                                                                                       \
+    kv_block_stride,                                                                                \
+    kv_head_stride,                                                                                 \
+    k_scale,                                                                                        \
+    k_zp,                                                                                           \
+    v_scale,                                                                                        \
+    v_zp);
+
+// TODO: Tune NUM_THREADS.
+template<
+  typename T,
+  typename CACHE_T,
+  int BLOCK_SIZE,
+  kv_cache_dtype KV_CACHE_DTYPE,
+  int NUM_THREADS = 128>
+void paged_attention_v1_launcher(
+  torch::Tensor& out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  int num_kv_heads,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
+  int num_seqs = query.size(0);
+  int num_heads = query.size(1);
+  int head_size = query.size(2);
+  int max_num_blocks_per_seq = block_tables.size(1);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);
+
+  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  assert(head_size % thread_group_size == 0);
+
+  // NOTE: alibi_slopes is optional.
+  const float* alibi_slopes_ptr = alibi_slopes ?
+    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+    : nullptr;
+
+  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
+  int* block_tables_ptr = block_tables.data_ptr<int>();
+  int* context_lens_ptr = context_lens.data_ptr<int>();
+
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
+  int logits_size = padded_max_context_len * sizeof(float);
+  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  // Keep this in sync with the Python-side check in
+  // aphrodite.task_handler.worker._check_if_can_support_max_seq_len.
+  int shared_mem_size = std::max(logits_size, outputs_size);
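+  // The kernel reuses this dynamic shared memory: first as the logits buffer
+  // (logits_size), then as the cross-warp output reduction workspace (outputs_size),
+  // hence the max of the two.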
+
+  dim3 grid(num_heads, num_seqs, 1);
+  dim3 block(NUM_THREADS);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  switch (head_size) {
+    // NOTE: To reduce the compilation time, we only compile for the
+    // head sizes that we use in the model. However, we can easily extend this
+    // to support any head size which is a multiple of 16.
+    case 64:
+      LAUNCH_PAGED_ATTENTION_V1(64);
+      break;
+    case 80:
+      LAUNCH_PAGED_ATTENTION_V1(80);
+      break;
+    case 96:
+      LAUNCH_PAGED_ATTENTION_V1(96);
+      break;
+    case 112:
+      LAUNCH_PAGED_ATTENTION_V1(112);
+      break;
+    case 128:
+      LAUNCH_PAGED_ATTENTION_V1(128);
+      break;
+    case 256:
+      LAUNCH_PAGED_ATTENTION_V1(256);
+      break;
+    default:
+      TORCH_CHECK(false, "Unsupported head size: ", head_size);
+      break;
+  }
+}
+
+#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)             \
+  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(       \
+    out,                                                                     \
+    query,                                                                   \
+    key_cache,                                                               \
+    value_cache,                                                             \
+    num_kv_heads,                                                            \
+    scale,                                                                   \
+    block_tables,                                                            \
+    context_lens,                                                            \
+    max_context_len,                                                         \
+    alibi_slopes,                                                            \
+    k_scale,                                                                 \
+    k_zp,                                                                    \
+    v_scale,                                                                 \
+    v_zp);
+
+// NOTE: To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)       \
+  switch (block_size) {                                               \
+    case 8:                                                           \
+      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);                \
+      break;                                                          \
+    case 16:                                                          \
+      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);               \
+      break;                                                          \
+    case 32:                                                          \
+      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);               \
+      break;                                                          \
+    default:                                                          \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
+      break;                                                          \
+  }
+
+void paged_attention_v1(
+  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
+  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
+  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+  int num_kv_heads,               // [num_heads]
+  float scale,
+  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
+  torch::Tensor& context_lens,    // [num_seqs]
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const std::string& kv_cache_dtype,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f) {
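+  // Illustrative shapes: for a half-precision model with head_size 128, block_size 16
+  // and kv_cache_dtype "auto", x == 8, so key_cache is
+  // [num_blocks, num_heads, 16, 16, 8], value_cache is [num_blocks, num_heads, 128, 16]
+  // and out/query are [num_seqs, num_heads, 128].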
+  if (kv_cache_dtype == "auto") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, AUTO);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, AUTO);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+#ifdef ENABLE_FP8_E5M2
+  } else if (kv_cache_dtype == "fp8_e5m2") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, FP8_E5M2);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, FP8_E5M2);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, FP8_E5M2);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+#endif
+  } else if (kv_cache_dtype == "int8") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(float, int8_t, INT8);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, int8_t, INT8);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, int8_t, INT8);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
+}
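+
+// paged_attention_v2 (below) computes the same attention but splits the context into
+// PARTITION_SIZE-token partitions; each partition writes tmp_out, exp_sums and
+// max_logits, which paged_attention_v2_reduce_kernel then merges.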
+
+#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \
+  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,        \
+  KV_CACHE_DTYPE, PARTITION_SIZE>                                                             \
+  <<<grid, block, shared_mem_size, stream>>>(                                                 \
+    exp_sums_ptr,                                                                             \
+    max_logits_ptr,                                                                           \
+    tmp_out_ptr,                                                                              \
+    query_ptr,                                                                                \
+    key_cache_ptr,                                                                            \
+    value_cache_ptr,                                                                          \
+    num_kv_heads,                                                                             \
+    scale,                                                                                    \
+    block_tables_ptr,                                                                         \
+    context_lens_ptr,                                                                         \
+    max_num_blocks_per_seq,                                                                   \
+    alibi_slopes_ptr,                                                                         \
+    q_stride,                                                                                 \
+    kv_block_stride,                                                                          \
+    kv_head_stride,                                                                           \
+    k_scale,                                                                                  \
+    k_zp,                                                                                     \
+    v_scale,                                                                                  \
+    v_zp);                                                                                    \
+  aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>           \
+  <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                                   \
+    out_ptr,                                                                                  \
+    exp_sums_ptr,                                                                             \
+    max_logits_ptr,                                                                           \
+    tmp_out_ptr,                                                                              \
+    context_lens_ptr,                                                                         \
+    max_num_partitions);
+
+template<
+  typename T,
+  typename CACHE_T,
+  int BLOCK_SIZE,
+  kv_cache_dtype KV_CACHE_DTYPE,
+  int NUM_THREADS = 128,
+  int PARTITION_SIZE = 512>
+void paged_attention_v2_launcher(
+  torch::Tensor& out,
+  torch::Tensor& exp_sums,
+  torch::Tensor& max_logits,
+  torch::Tensor& tmp_out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  int num_kv_heads,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
+  int num_seqs = query.size(0);
+  int num_heads = query.size(1);
+  int head_size = query.size(2);
+  int max_num_blocks_per_seq = block_tables.size(1);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);
+
+  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  assert(head_size % thread_group_size == 0);
+
+  // NOTE: alibi_slopes is optional.
+  const float* alibi_slopes_ptr = alibi_slopes ?
+    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+    : nullptr;
+
+  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
+  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
+  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
+  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
+  int* block_tables_ptr = block_tables.data_ptr<int>();
+  int* context_lens_ptr = context_lens.data_ptr<int>();
+
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
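+  // e.g. max_context_len = 4096 with the default PARTITION_SIZE = 512 gives
+  // max_num_partitions = 8.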
+  int logits_size = PARTITION_SIZE * sizeof(float);
+  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+
+  // For paged attention v2 kernel.
+  dim3 grid(num_heads, num_seqs, max_num_partitions);
+  int shared_mem_size = std::max(logits_size, outputs_size);
+  // For paged attention v2 reduce kernel.
+  dim3 reduce_grid(num_heads, num_seqs);
+  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
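+  // The reduce kernel stages the per-partition max logits and rescaled exp sums in
+  // shared memory, hence 2 * max_num_partitions floats.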
+
+  dim3 block(NUM_THREADS);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  switch (head_size) {
+    // NOTE: To reduce the compilation time, we only compile for the
+    // head sizes that we use in the model. However, we can easily extend this
+    // to support any head size which is a multiple of 16.
+    case 64:
+      LAUNCH_PAGED_ATTENTION_V2(64);
+      break;
+    case 80:
+      LAUNCH_PAGED_ATTENTION_V2(80);
+      break;
+    case 96:
+      LAUNCH_PAGED_ATTENTION_V2(96);
+      break;
+    case 112:
+      LAUNCH_PAGED_ATTENTION_V2(112);
+      break;
+    case 128:
+      LAUNCH_PAGED_ATTENTION_V2(128);
+      break;
+    case 256:
+      LAUNCH_PAGED_ATTENTION_V2(256);
+      break;
+    default:
+      TORCH_CHECK(false, "Unsupported head size: ", head_size);
+      break;
+  }
+}
+
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)                 \
+  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(           \
+    out,                                                                         \
+    exp_sums,                                                                    \
+    max_logits,                                                                  \
+    tmp_out,                                                                     \
+    query,                                                                       \
+    key_cache,                                                                   \
+    value_cache,                                                                 \
+    num_kv_heads,                                                                \
+    scale,                                                                       \
+    block_tables,                                                                \
+    context_lens,                                                                \
+    max_context_len,                                                             \
+    alibi_slopes,                                                                \
+    k_scale,                                                                     \
+    k_zp,                                                                        \
+    v_scale,                                                                     \
+    v_zp);
+
+// NOTE: To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)             \
+  switch (block_size) {                                                     \
+    case 8:                                                                 \
+      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);                      \
+      break;                                                                \
+    case 16:                                                                \
+      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);                     \
+      break;                                                                \
+    case 32:                                                                \
+      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);                     \
+      break;                                                                \
+    default:                                                                \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size);           \
+      break;                                                                \
+  }
+
+void paged_attention_v2(
+  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
+  torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
+  torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
+  torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
+  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
+  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+  int num_kv_heads,               // [num_heads]
+  float scale,
+  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
+  torch::Tensor& context_lens,    // [num_seqs]
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const std::string& kv_cache_dtype,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f) {
+  if (kv_cache_dtype == "auto") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, AUTO);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, AUTO);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+#ifdef ENABLE_FP8_E5M2
+  } else if (kv_cache_dtype == "fp8_e5m2") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, FP8_E5M2);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, FP8_E5M2);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, FP8_E5M2);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+#endif
+  } else if (kv_cache_dtype == "int8") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, int8_t, INT8);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, int8_t, INT8);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, int8_t, INT8);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
+}
+
+#undef WARP_SIZE
+#undef MAX
+#undef MIN
+#undef DIVIDE_ROUND_UP

+ 39 - 0
kernels/backup/cache.h

@@ -0,0 +1,39 @@
+#pragma once
+
+#include <torch/extension.h>
+
+#include <map>
+#include <vector>
+
+void swap_blocks(
+  torch::Tensor& src,
+  torch::Tensor& dst,
+  const std::map<int64_t, int64_t>& block_mapping);
+
+void copy_blocks(
+  std::vector<torch::Tensor>& key_caches,
+  std::vector<torch::Tensor>& value_caches,
+  const std::map<int64_t, std::vector<int64_t>>& block_mapping);
+
+void reshape_and_cache(
+  torch::Tensor& key,
+  torch::Tensor& value,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& slot_mapping,
+  const std::string& kv_cache_dtype,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f);
+
+void gather_cached_kv(
+  torch::Tensor& key,
+  torch::Tensor& value,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& slot_mapping);
+
+void convert_fp8_e5m2(
+  torch::Tensor& src_cache,
+  torch::Tensor& dst_cache);

+ 512 - 0
kernels/backup/cache_kernels.cu

@@ -0,0 +1,512 @@
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include "cuda_compat.h"
+#include "dispatch_utils.h"
+#include "quantization/int8_kvcache/quant_utils.cuh"
+#ifdef ENABLE_FP8_E5M2
+#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
+#endif
+
+#include <algorithm>
+#include <cassert>
+#include <map>
+#include <vector>
+
+enum kv_cache_dtype {
+  AUTO,
+#ifdef ENABLE_FP8_E5M2
+  FP8_E5M2,
+#endif
+  INT8
+};
+
+#ifdef USE_ROCM
+  #include <hip/hip_bf16.h>
+  typedef __hip_bfloat16 __nv_bfloat16;
+#endif
+
+void swap_blocks(
+  torch::Tensor& src,
+  torch::Tensor& dst,
+  const std::map<int64_t, int64_t>& block_mapping) {
+  torch::Device src_device = src.device();
+  torch::Device dst_device = dst.device();
+  cudaMemcpyKind memcpy_type;
+  if (src_device.is_cuda() && dst_device.is_cuda()) {
+    TORCH_CHECK(
+      src_device.index() == dst_device.index(),
+      "src and dst must be on the same GPU");
+    memcpy_type = cudaMemcpyDeviceToDevice;
+  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
+    memcpy_type = cudaMemcpyDeviceToHost;
+  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
+    memcpy_type = cudaMemcpyHostToDevice;
+  } else {
+    TORCH_CHECK(false, "Invalid device combination");
+  }
+
+  char *src_ptr = static_cast<char*>(src.data_ptr());
+  char *dst_ptr = static_cast<char*>(dst.data_ptr());
+
+  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // NOTE: This can be slow if the number of blocks is large.
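+  // Each mapped block becomes its own cudaMemcpyAsync call on the current stream.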
+  for (const auto& pair : block_mapping) {
+    int64_t src_block_number = pair.first;
+    int64_t dst_block_number = pair.second;
+    int64_t src_offset = src_block_number * block_size_in_bytes;
+    int64_t dst_offset = dst_block_number * block_size_in_bytes;
+    cudaMemcpyAsync(
+      dst_ptr + dst_offset,
+      src_ptr + src_offset,
+      block_size_in_bytes,
+      memcpy_type,
+      stream);
+  }
+}
+
+namespace aphrodite {
+
+// Grid: (num_layers, num_pairs)
+template<typename scalar_t>
+__global__ void copy_blocks_kernel(
+  int64_t* key_cache_ptrs,
+  int64_t* value_cache_ptrs,
+  const int64_t* __restrict__ block_mapping,
+  const int numel_per_block) {
+  const int layer_idx = blockIdx.x;
+  const int pair_idx = blockIdx.y;
+
+  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
+  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
+  int64_t src_block_number = block_mapping[2 * pair_idx];
+  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
+
+  const int64_t src_block_offset = src_block_number * numel_per_block;
+  const int64_t dst_block_offset = dst_block_number * numel_per_block;
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    int64_t src_offset = src_block_offset + i;
+    int64_t dst_offset = dst_block_offset + i;
+    key_cache[dst_offset] = key_cache[src_offset];
+  }
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    int64_t src_offset = src_block_offset + i;
+    int64_t dst_offset = dst_block_offset + i;
+    value_cache[dst_offset] = value_cache[src_offset];
+  }
+}
+
+} // namespace aphrodite
+
+void copy_blocks(
+  std::vector<torch::Tensor>& key_caches,
+  std::vector<torch::Tensor>& value_caches,
+  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
+  int num_layers = key_caches.size();
+  TORCH_CHECK(num_layers == value_caches.size());
+  if (num_layers == 0) {
+    return;
+  }
+  torch::Device cache_device = key_caches[0].device();
+  TORCH_CHECK(cache_device.is_cuda());
+
+  // Create data structures for the kernel.
+  // Create an array of pointers to the key and value caches.
+  int64_t key_cache_ptrs[num_layers];
+  int64_t value_cache_ptrs[num_layers];
+  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
+    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
+  }
+  // Create block mapping array.
+  std::vector<int64_t> block_mapping_vec;
+  for (const auto& pair : block_mapping) {
+    int64_t src_block_number = pair.first;
+    for (int64_t dst_block_number : pair.second) {
+      block_mapping_vec.push_back(src_block_number);
+      block_mapping_vec.push_back(dst_block_number);
+    }
+  }
+  int64_t* block_mapping_array = block_mapping_vec.data();
+  int num_pairs = block_mapping_vec.size() / 2;
+
+  // Move the data structures to the GPU.
+  // NOTE: This synchronizes the CPU and GPU.
+  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
+    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
+    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+  torch::Tensor block_mapping_tensor = torch::from_blob(
+    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
+
+  // Launch the kernel.
+  const int numel_per_block = key_caches[0][0].numel();
+  dim3 grid(num_layers, num_pairs);
+  dim3 block(std::min(1024, numel_per_block));
+  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
+    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
+      aphrodite::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key_cache_ptrs_tensor.data_ptr<int64_t>(),
+        value_cache_ptrs_tensor.data_ptr<int64_t>(),
+        block_mapping_tensor.data_ptr<int64_t>(),
+        numel_per_block);
+    }));
+}
+
+namespace aphrodite {
+
+template<typename scalar_t, typename cache_t, kv_cache_dtype KV_CACHE_DTYPE>
+__global__ void reshape_and_cache_kernel(
+  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
+  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
+  cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x]
+  cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size]
+  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
+  const int key_stride,
+  const int value_stride,
+  const int num_heads,
+  const int head_size,
+  const int block_size,
+  const int x,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
+  const int64_t token_idx = blockIdx.x;
+  const int64_t slot_idx = slot_mapping[token_idx];
+  if (slot_idx < 0) {
+    // Padding token that should be ignored.
+    return;
+  }
+
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;
+
+  const int n = num_heads * head_size;
+  for (int i = threadIdx.x; i < n; i += blockDim.x) {
+    const int64_t src_key_idx = token_idx * key_stride + i;
+    const int64_t src_value_idx = token_idx * value_stride + i;
+
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+    const int x_idx = head_offset / x;
+    const int x_offset = head_offset % x;
+
+    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+                                + head_idx * (head_size / x) * block_size * x
+                                + x_idx * block_size * x
+                                + block_offset * x
+                                + x_offset;
+    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
+                                  + head_idx * head_size * block_size
+                                  + head_offset * block_size
+                                  + block_offset;
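+    // The key cache groups x consecutive elements of the head dimension per token so
+    // the attention kernel can read them with one 16-byte load; the value cache keeps
+    // each head element contiguous across the tokens of a block.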
+    scalar_t tgt_key = key[src_key_idx];
+    scalar_t tgt_value = value[src_value_idx];
+    if constexpr (KV_CACHE_DTYPE == INT8) {
+      key_cache[tgt_key_idx] = int8::quant(tgt_key, k_scale, k_zp);
+      value_cache[tgt_value_idx] = int8::quant(tgt_value, v_scale, v_zp);
+#ifdef ENABLE_FP8_E5M2
+    } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
+      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
+      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
+#endif
+    } else {
+      key_cache[tgt_key_idx] = tgt_key;
+      value_cache[tgt_value_idx] = tgt_value;
+    }
+  }
+}
+
+} // namespace aphrodite
+
+#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_CACHE_DTYPE)                                      \
+  aphrodite::reshape_and_cache_kernel<KV_T, CACHE_T, KV_CACHE_DTYPE><<<grid, block, 0, stream>>>(  \
+    reinterpret_cast<KV_T*>(key.data_ptr()),                                                       \
+    reinterpret_cast<KV_T*>(value.data_ptr()),                                                     \
+    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                              \
+    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                            \
+    slot_mapping.data_ptr<int64_t>(),                                                              \
+    key_stride,                                                                                    \
+    value_stride,                                                                                  \
+    num_heads,                                                                                     \
+    head_size,                                                                                     \
+    block_size,                                                                                    \
+    x,                                                                                             \
+    k_scale,                                                                                       \
+    k_zp,                                                                                          \
+    v_scale,                                                                                       \
+    v_zp);
+
+void reshape_and_cache(
+  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
+  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
+  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& slot_mapping,  // [num_tokens]
+  const std::string& kv_cache_dtype,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f)
+{
+  int num_tokens = key.size(0);
+  int num_heads = key.size(1);
+  int head_size = key.size(2);
+  int block_size = key_cache.size(3);
+  int x = key_cache.size(4);
+
+  int key_stride = key.stride(0);
+  int value_stride = value.stride(0);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  if (kv_cache_dtype == "auto") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, float, AUTO);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, AUTO);
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, AUTO);
+    }
+#ifdef ENABLE_FP8_E5M2
+  } else if (kv_cache_dtype == "fp8_e5m2") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, uint8_t, FP8_E5M2);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, FP8_E5M2);
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, FP8_E5M2);
+    }
+#endif
+  } else if (kv_cache_dtype == "int8") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, int8_t, INT8);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      CALL_RESHAPE_AND_CACHE(uint16_t, int8_t, INT8);
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, int8_t, INT8);
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
+}
+
+namespace aphrodite {
+
+// Grid: (num_blocks, block_size).
+template<typename scalar_t>
+__global__ void gather_cached_kv_kernel(
+  scalar_t* __restrict__ key,             // [num_tokens, [stride], num_heads, head_size]
+  scalar_t* __restrict__ value,           // [num_tokens, [stride], num_heads, head_size]
+  const scalar_t* __restrict__ key_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size, block_size]
+  const int* __restrict__ slot_mapping,   // [num_tokens]
+  const int key_stride,
+  const int value_stride,
+  const int num_heads,
+  const int head_size,
+  const int block_size,
+  const int x) {
+    const int token_idx = blockIdx.x;
+    const int slot_idx = slot_mapping[token_idx];
+    const int block_idx = slot_idx / block_size;
+    const int block_offset = slot_idx % block_size;
+
+    // Number of elements per token (num_heads * head_size), not the number of tokens.
+    const int n = num_heads * head_size;
+    for (int i = threadIdx.x; i < n; i += blockDim.x) {
+      const int tgt_key_idx = token_idx * key_stride + i;
+      const int tgt_value_idx = token_idx * value_stride + i;
+
+      const int head_idx = i / head_size;
+      const int head_offset = i % head_size;
+      const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension
+      const int x_offset = head_offset % x;
+
+      const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+                              + head_idx * (head_size / x) * block_size * x
+                              + x_idx * block_size * x
+                              + block_offset * x
+                              + x_offset;
+      const int src_value_idx = block_idx * num_heads * head_size * block_size
+                                + head_idx * head_size * block_size
+                                + head_offset * block_size
+                                + block_offset;
+
+      key[tgt_key_idx] = APHRODITE_LDG(&key_cache[src_key_idx]);
+      value[tgt_value_idx] = APHRODITE_LDG(&value_cache[src_value_idx]);
+    }
+}
+
+template <typename scalar_t>
+__global__ void gather_cached_kv_kernel_optimized(
+    scalar_t *__restrict__ key,             // [num_tokens, [stride], num_heads, head_size]
+    scalar_t *__restrict__ value,           // [num_tokens, [stride], num_heads, head_size]
+    const scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
+    const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
+    const int *__restrict__ slot_mapping,   // [num_tokens]
+    const int key_stride,
+    const int value_stride,
+    const int num_heads,
+    const int head_size,
+    const int block_size,
+    const int x)
+{
+    const int token_idx = blockIdx.x;
+    const int slot_idx = slot_mapping[token_idx];
+    const int block_idx = slot_idx / block_size;
+    const int block_offset = slot_idx % block_size;
+
+    const int dim = num_heads * head_size;
+    assert(dim % 4 == 0);  // this is true for known use cases
+    const int unroll_factor = 4;
+    const int unrolled_dim = dim / unroll_factor;
+
+    for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
+    {
+        int tgt_key_indices[unroll_factor];
+        int tgt_value_indices[unroll_factor];
+        int src_key_indices[unroll_factor];
+        int src_value_indices[unroll_factor];
+        scalar_t keys_to_store[unroll_factor];
+        scalar_t values_to_store[unroll_factor];
+
+        #pragma unroll
+        for (int j = 0; j < unroll_factor; ++j)
+        {
+            int index = i + j * unrolled_dim;
+
+            const int tgt_key_idx = token_idx * key_stride + index;
+            const int tgt_value_idx = token_idx * value_stride + index;
+
+            const int head_idx = index / head_size;
+            const int head_offset = index % head_size;
+            const int x_idx = head_offset / x;
+            const int x_offset = head_offset % x;
+
+            const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+                                    + head_idx * (head_size / x) * block_size * x
+                                    + x_idx * block_size * x
+                                    + block_offset * x
+                                    + x_offset;
+            const int src_value_idx = block_idx * num_heads * head_size * block_size
+                                      + head_idx * head_size * block_size
+                                      + head_offset * block_size
+                                      + block_offset;
+
+            tgt_key_indices[j] = tgt_key_idx;
+            tgt_value_indices[j] = tgt_value_idx;
+            src_key_indices[j] = src_key_idx;
+            src_value_indices[j] = src_value_idx;
+
+            keys_to_store[j] = APHRODITE_LDG(&key_cache[src_key_idx]);
+            values_to_store[j] = APHRODITE_LDG(&value_cache[src_value_idx]);
+        }
+
+        #pragma unroll
+        for (int j = 0; j < unroll_factor; ++j)
+        {
+            key[tgt_key_indices[j]] = keys_to_store[j];
+            value[tgt_value_indices[j]] = values_to_store[j];
+        }
+    }
+}
+
+} // namespace aphrodite
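The address arithmetic in both kernels encodes the paged cache layout. As a standalone, host-side mirror of the expressions above (illustrative helper names, same formulas), the flat offsets for a given (slot, head, element) are:

  // Key cache layout:   [num_blocks, num_heads, head_size/x, block_size, x]
  // Value cache layout: [num_blocks, num_heads, head_size,   block_size]
  int key_cache_offset(int slot_idx, int head_idx, int head_offset,
                       int num_heads, int head_size, int block_size, int x) {
    const int block_idx    = slot_idx / block_size;
    const int block_offset = slot_idx % block_size;
    const int x_idx        = head_offset / x;  // which x-sized chunk of the head dim
    const int x_offset     = head_offset % x;  // position inside that chunk
    return block_idx * num_heads * (head_size / x) * block_size * x
           + head_idx * (head_size / x) * block_size * x
           + x_idx * block_size * x
           + block_offset * x
           + x_offset;
  }

  int value_cache_offset(int slot_idx, int head_idx, int head_offset,
                         int num_heads, int head_size, int block_size) {
    const int block_idx    = slot_idx / block_size;
    const int block_offset = slot_idx % block_size;
    return block_idx * num_heads * head_size * block_size
           + head_idx * head_size * block_size
           + head_offset * block_size
           + block_offset;
  }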
+
+void gather_cached_kv(
+  torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
+  torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
+  torch::Tensor& key_cache,     // [in]  [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,   // [in]  [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& slot_mapping)  // [in]  [num_tokens]
+{
+  int num_tokens = key.size(0);
+  int num_heads = key.size(1);
+  int head_size = key.size(2);
+  int block_size = key_cache.size(3);
+  int x = key_cache.size(4);
+
+  int key_stride = key.stride(0);
+  int value_stride = value.stride(0);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
+    key.scalar_type(),
+    "gather_cached_kv_kernel_optimized",
+    [&] {
+      aphrodite::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
+        key.data_ptr<scalar_t>(),
+        value.data_ptr<scalar_t>(),
+        key_cache.data_ptr<scalar_t>(),
+        value_cache.data_ptr<scalar_t>(),
+        slot_mapping.data_ptr<int>(),
+        key_stride,
+        value_stride,
+        num_heads,
+        head_size,
+        block_size,
+        x);
+    });
+}
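A minimal, hypothetical invocation of this wrapper: the output tensors are pre-allocated by the caller and filled from the paged caches according to slot_mapping (key_cache, value_cache and slot_mapping assumed to exist already).

  auto opts      = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto key_out   = torch::empty({num_tokens, num_heads, head_size}, opts);
  auto value_out = torch::empty_like(key_out);
  gather_cached_kv(key_out, value_out, key_cache, value_cache, slot_mapping);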
+
+namespace aphrodite {
+
+template<typename Tout, typename Tin>
+__global__ void convert_fp8_e5m2_kernel(
+  const Tin* __restrict__ src_cache,
+  Tout* __restrict__ dst_cache,
+  const int64_t block_stride) {
+  const int64_t block_idx = blockIdx.x;
+  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
+    int64_t idx = block_idx * block_stride + i;
+#ifdef ENABLE_FP8_E5M2
+    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
+#else
+    assert(false);
+#endif
+  }
+}
+
+} // namespace aphrodite
+
+#define CALL_CONVERT_FP8_E5M2(Tout, Tin)                                 \
+  aphrodite::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>(  \
+    reinterpret_cast<Tin*>(src_cache.data_ptr()),                        \
+    reinterpret_cast<Tout*>(dst_cache.data_ptr()),                       \
+    block_stride);
+
+void convert_fp8_e5m2(
+  torch::Tensor& src_cache,
+  torch::Tensor& dst_cache)
+{
+  int64_t num_blocks = src_cache.size(0);
+  int64_t block_stride = src_cache.stride(0);
+
+  dim3 grid(num_blocks);
+  dim3 block(std::min(block_stride, int64_t(512)));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  if (src_cache.dtype() == at::ScalarType::Float) {
+    CALL_CONVERT_FP8_E5M2(uint8_t, float);
+  } else if (src_cache.dtype() == at::ScalarType::Half) {
+    CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
+  } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
+    CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
+  } else if (dst_cache.dtype() == at::ScalarType::Float) {
+    CALL_CONVERT_FP8_E5M2(float, uint8_t);
+  } else if (dst_cache.dtype() == at::ScalarType::Half) {
+    CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
+  } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
+    CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
+  } else {
+    TORCH_CHECK(false, "Unsupported data type combination for convert_fp8_e5m2: ",
+                src_cache.dtype(), " -> ", dst_cache.dtype());
+  }
+}
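Because the dispatch keys off whichever side is a floating-point tensor, the same entry point handles both directions. A round-trip sketch, assuming the extension was built with ENABLE_FP8_E5M2 and using the cache dimensions from the wrappers above as placeholders:

  auto fp16_cache = torch::randn({num_blocks, num_heads, head_size / x, block_size, x},
                                 torch::dtype(torch::kHalf).device(torch::kCUDA));
  auto fp8_cache  = torch::empty_like(fp16_cache, torch::dtype(torch::kUInt8));
  convert_fp8_e5m2(fp16_cache, fp8_cache);  // fp16 -> fp8: src dtype selects the branch
  auto restored   = torch::empty_like(fp16_cache);
  convert_fp8_e5m2(fp8_cache, restored);    // fp8 -> fp16: dst dtype selects the branch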

+ 39 - 0
kernels/backup/dispatch_utils.h

@@ -0,0 +1,39 @@
+/*
+ * Adapted from
+ * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
+ */
+#pragma once
+
+#include <torch/extension.h>
+
+#define APHRODITE_DISPATCH_CASE_FLOATING_TYPES(...)              \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define APHRODITE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)             \
+  AT_DISPATCH_SWITCH(                                             \
+    TYPE, NAME, APHRODITE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
+
+#define APHRODITE_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)     \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)   \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
+
+#define APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...)           \
+  AT_DISPATCH_SWITCH(                                                    \
+    TYPE, NAME, APHRODITE_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
+
+#define APHRODITE_DISPATCH_CASE_INTEGRAL_TYPES(...)             \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
+
+#define APHRODITE_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...)             \
+  AT_DISPATCH_SWITCH(                                             \
+    TYPE, NAME, APHRODITE_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
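These macros are thin wrappers over AT_DISPATCH_SWITCH/AT_DISPATCH_CASE; the *_AND_BYTE variant is what lets a single kernel template cover both floating-point caches and the quantized uint8/int8 caches. A minimal usage sketch (my_cache_kernel, tensor, numel, grid, block and stream are placeholders):

  APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
    tensor.scalar_type(), "my_cache_kernel", [&] {
      // scalar_t is float, at::Half, at::BFloat16, uint8_t or int8_t here.
      my_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
        tensor.data_ptr<scalar_t>(), numel);
    });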

+ 280 - 0
kernels/backup/dtype_float32.cuh

@@ -0,0 +1,280 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "attention_generic.cuh"
+
+#include <stdint.h>
+
+namespace aphrodite {
+
+// Define custom FP32 vector data types.
+struct Float4_ {
+  float2 x;
+  float2 y;
+};
+
+struct Float8_ {
+  float2 x;
+  float2 y;
+  float2 z;
+  float2 w;
+};
+
+// FP32 vector types for Q, K, V.
+template<>
+struct Vec<float, 1> {
+  using Type = float;
+};
+template<>
+struct Vec<float, 2> {
+  using Type = float2;
+};
+template<>
+struct Vec<float, 4> {
+  using Type = float4;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+template<>
+struct FloatVec<float> {
+  using Type = float;
+};
+template<>
+struct FloatVec<float2> {
+  using Type = float2;
+};
+template<>
+struct FloatVec<float4> {
+  using Type = float4;
+};
+
+// Vector addition.
+inline __device__ float add(float a, float b) {
+  return a + b;
+}
+
+inline __device__ float2 add(float2 a, float2 b) {
+  float2 c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  return c;
+}
+
+inline __device__ float4 add(float4 a, float4 b) {
+  float4 c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  c.z = add(a.z, b.z);
+  c.w = add(a.w, b.w);
+  return c;
+}
+
+inline __device__ Float4_ add(Float4_ a, Float4_ b) {
+  Float4_ c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  return c;
+}
+
+// Vector multiplication.
+template<>
+inline __device__ float mul<float, float>(float a, float b) {
+  return a * b;
+}
+
+template<>
+inline __device__ float2 mul(float2 a, float2 b) {
+  float2 c;
+  c.x = a.x * b.x;
+  c.y = a.y * b.y;
+  return c;
+}
+
+template<>
+inline __device__ float2 mul(float a, float2 b) {
+  float2 c;
+  c.x = a * b.x;
+  c.y = a * b.y;
+  return c;
+}
+
+template<>
+inline __device__ float4 mul(float4 a, float4 b) {
+  float4 c;
+  c.x = a.x * b.x;
+  c.y = a.y * b.y;
+  c.z = a.z * b.z;
+  c.w = a.w * b.w;
+  return c;
+}
+
+template<>
+inline __device__ float4 mul(float a, float4 b) {
+  float4 c;
+  c.x = a * b.x;
+  c.y = a * b.y;
+  c.z = a * b.z;
+  c.w = a * b.w;
+  return c;
+}
+
+// Vector fused multiply-add.
+inline __device__ float fma(float a, float b, float c) {
+  return a * b + c;
+}
+
+inline __device__ float2 fma(float2 a, float2 b, float2 c) {
+  float2 d;
+  d.x = fma(a.x, b.x, c.x);
+  d.y = fma(a.y, b.y, c.y);
+  return d;
+}
+
+inline __device__ float2 fma(float a, float2 b, float2 c) {
+  float2 d;
+  d.x = fma(a, b.x, c.x);
+  d.y = fma(a, b.y, c.y);
+  return d;
+}
+
+inline __device__ float4 fma(float4 a, float4 b, float4 c) {
+  float4 d;
+  d.x = fma(a.x, b.x, c.x);
+  d.y = fma(a.y, b.y, c.y);
+  d.z = fma(a.z, b.z, c.z);
+  d.w = fma(a.w, b.w, c.w);
+  return d;
+}
+
+inline __device__ float4 fma(float a, float4 b, float4 c) {
+  float4 d;
+  d.x = fma(a, b.x, c.x);
+  d.y = fma(a, b.y, c.y);
+  d.z = fma(a, b.z, c.z);
+  d.w = fma(a, b.w, c.w);
+  return d;
+}
+
+inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
+  Float4_ d;
+  d.x = fma(a, b.x, c.x);
+  d.y = fma(a, b.y, c.y);
+  return d;
+}
+
+inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
+  Float8_ d;
+  d.x = fma(a, b.x, c.x);
+  d.y = fma(a, b.y, c.y);
+  d.z = fma(a, b.z, c.z);
+  d.w = fma(a, b.w, c.w);
+  return d;
+}
+
+// Vector sum.
+template<>
+inline __device__ float sum(float v) {
+  return v;
+}
+
+template<>
+inline __device__ float sum(float2 v) {
+  return v.x + v.y;
+}
+
+template<>
+inline __device__ float sum(float4 v) {
+  return v.x + v.y + v.z + v.w;
+}
+
+template<>
+inline __device__ float sum(Float4_ v) {
+  return v.x.x + v.x.y + v.y.x + v.y.y;
+}
+
+template<>
+inline __device__ float sum(Float8_ v) {
+  return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
+}
+
+// Vector dot product.
+inline __device__ float dot(float a, float b) {
+  return a * b;
+}
+
+inline __device__ float dot(float2 a, float2 b) {
+  float2 c = mul<float2, float2, float2>(a, b);
+  return c.x + c.y;
+}
+
+inline __device__ float dot(Float4_ a, Float4_ b) {
+  float2 acc = mul<float2, float2, float2>(a.x, b.x);
+  acc = fma(a.y, b.y, acc);
+  return acc.x + acc.y;
+}
+
+inline __device__ float dot(Float8_ a, Float8_ b) {
+  float2 acc = mul<float2, float2, float2>(a.x, b.x);
+  acc = fma(a.y, b.y, acc);
+  acc = fma(a.z, b.z, acc);
+  acc = fma(a.w, b.w, acc);
+  return acc.x + acc.y;
+}
+
+// From float to float.
+inline __device__ void from_float(float& dst, float src) {
+  dst = src;
+}
+
+inline __device__ void from_float(float2& dst, float2 src) {
+  dst = src;
+}
+
+inline __device__ void from_float(float4& dst, float4 src) {
+  dst = src;
+}
+
+// To float (identity conversions for FP32 vector types).
+inline __device__ float to_float(float u) {
+  return u;
+}
+
+inline __device__ float2 to_float(float2 u) {
+  return u;
+}
+
+inline __device__ float4 to_float(float4 u) {
+  return u;
+}
+
+inline __device__ Float4_ to_float(Float4_ u) {
+  return u;
+}
+
+inline __device__ Float8_ to_float(Float8_ u) {
+  return u;
+}
+
+// Zero-out a variable.
+inline __device__ void zero(float& dst) {
+  dst = 0.f;
+}
+
+} // namespace aphrodite
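A minimal sketch of how these helpers compose into a dot product inside a device function. The primary templates (Vec, FloatVec, mul, sum) are assumed to come from the included attention_generic.cuh, and qk_dot is an illustrative name, not code from this commit:

  using K_vec = aphrodite::Vec<float, 4>::Type;    // float4
  using A_vec = aphrodite::FloatVec<K_vec>::Type;  // float4 accumulator

  __device__ float qk_dot(const K_vec& q, const K_vec& k) {
    A_vec prod = aphrodite::mul<A_vec, K_vec, K_vec>(q, k);  // element-wise multiply
    return aphrodite::sum(prod);                             // horizontal reduction
  }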

+ 0 - 0
kernels/attention/dtype_int8.cuh → kernels/backup/dtype_int8.cuh


Some files were not shown because too many files changed in this diff