
chore: attention refactor and upstream sync apr01 (#365)

* add speculative workers

* sync core processor logic with upstream

* add compiled dag api, sync logic

* do not log prompt token IDs

* add cache config to metrics

* add compiled dag in ray utils

* do not import cache ops if neuron backend is detected

* hope i didn't break anything

* modify worker

* add neuron worker

* add file lock

* remove lora loading from loader.py

* add neuronx model loader

* pick the correct loader based on device

* add separate attention backends

* migrate all models to the new attention backend

* neuron patch for llama

* modify rejection sampling

* add optimized moe configs

* add script for benchmarking local GPU for optimal moe config

* fix sampler output return

* logits as hidden state for lora layer

* health check in api server

* correct torch version in pyproject

* move config to where it belongs

* missing imports

* temporarily remove int8 kv cache

* logger fixes, set max log len to 0

* gracefully exit with keyboard interrupt

* explicitly abort a request when a cancellederror is raised

* remove unfinished configs

* update rocm's requirements

* formatting

* yapf

* parse version correctly for ray

* yapf again

* set enforce eager to true by default

* do not enable memory pinning when neuron
AlpinDale committed 11 months ago (commit f8dfac6372)

100 changed files with 10036 additions and 2571 deletions
  1. .gitignore (+0 -1)
  2. aphrodite/common/config.py (+98 -33)
  3. aphrodite/common/logger.py (+10 -8)
  4. aphrodite/common/outputs.py (+47 -30)
  5. aphrodite/common/sampling_params.py (+13 -0)
  6. aphrodite/common/sequence.py (+71 -21)
  7. aphrodite/common/utils.py (+67 -27)
  8. aphrodite/endpoints/llm.py (+1 -1)
  9. aphrodite/endpoints/openai/api_server.py (+104 -95)
  10. aphrodite/engine/aphrodite_engine.py (+275 -97)
  11. aphrodite/engine/args_tools.py (+373 -240)
  12. aphrodite/engine/async_aphrodite.py (+45 -28)
  13. aphrodite/engine/metrics.py (+79 -21)
  14. aphrodite/engine/ray_tools.py (+21 -5)
  15. aphrodite/executor/__init__.py (+0 -0)
  16. aphrodite/executor/executor_base.py (+76 -0)
  17. aphrodite/executor/gpu_executor.py (+153 -0)
  18. aphrodite/executor/neuron_executor.py (+78 -0)
  19. aphrodite/executor/ray_gpu_executor.py (+452 -0)
  20. aphrodite/executor/utils.py (+13 -0)
  21. aphrodite/lora/layers.py (+4 -0)
  22. aphrodite/modeling/hf_downloader.py (+8 -3)
  23. aphrodite/modeling/layers/attention.py (+0 -354)
  24. aphrodite/modeling/layers/attention/__init__.py (+93 -0)
  25. aphrodite/modeling/layers/attention/backends/__init__.py (+0 -0)
  26. aphrodite/modeling/layers/attention/backends/flash_attn.py (+121 -0)
  27. aphrodite/modeling/layers/attention/backends/xformers.py (+255 -0)
  28. aphrodite/modeling/layers/attention/ops/__init__.py (+0 -0)
  29. aphrodite/modeling/layers/attention/ops/paged_attn.py (+138 -0)
  30. aphrodite/modeling/layers/attention/ops/prefix_prefill.py (+28 -12)
  31. aphrodite/modeling/layers/fused_moe/__init__.py (+8 -0)
  32. aphrodite/modeling/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json (+146 -0)
  33. aphrodite/modeling/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json (+146 -0)
  34. aphrodite/modeling/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json (+146 -0)
  35. aphrodite/modeling/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json (+146 -0)
  36. aphrodite/modeling/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json (+146 -0)
  37. aphrodite/modeling/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json (+146 -0)
  38. aphrodite/modeling/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json (+146 -0)
  39. aphrodite/modeling/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json (+146 -0)
  40. aphrodite/modeling/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json (+146 -0)
  41. aphrodite/modeling/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json (+146 -0)
  42. aphrodite/modeling/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json (+146 -0)
  43. aphrodite/modeling/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json (+146 -0)
  44. aphrodite/modeling/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json (+146 -0)
  45. aphrodite/modeling/layers/fused_moe/configs/README (+9 -0)
  46. aphrodite/modeling/layers/fused_moe/fused_moe.py (+120 -58)
  47. aphrodite/modeling/layers/rejection.py (+57 -18)
  48. aphrodite/modeling/layers/sampler.py (+1 -2)
  49. aphrodite/modeling/loader.py (+62 -24)
  50. aphrodite/modeling/metadata.py (+5 -4)
  51. aphrodite/modeling/models/baichuan.py (+3 -4)
  52. aphrodite/modeling/models/bloom.py (+5 -5)
  53. aphrodite/modeling/models/chatglm.py (+2 -2)
  54. aphrodite/modeling/models/cohere.py (+1 -1)
  55. aphrodite/modeling/models/deepseek.py (+3 -3)
  56. aphrodite/modeling/models/falcon.py (+14 -14)
  57. aphrodite/modeling/models/gemma.py (+2 -2)
  58. aphrodite/modeling/models/gpt2.py (+2 -4)
  59. aphrodite/modeling/models/gpt_bigcode.py (+5 -5)
  60. aphrodite/modeling/models/gpt_j.py (+2 -2)
  61. aphrodite/modeling/models/gpt_neox.py (+2 -2)
  62. aphrodite/modeling/models/internlm2.py (+2 -2)
  63. aphrodite/modeling/models/llama.py (+16 -9)
  64. aphrodite/modeling/models/mixtral.py (+3 -3)
  65. aphrodite/modeling/models/mixtral_quant.py (+2 -2)
  66. aphrodite/modeling/models/mpt.py (+6 -6)
  67. aphrodite/modeling/models/neuron/llama.py (+79 -0)
  68. aphrodite/modeling/models/olmo.py (+4 -4)
  69. aphrodite/modeling/models/opt.py (+4 -4)
  70. aphrodite/modeling/models/phi.py (+2 -2)
  71. aphrodite/modeling/models/qwen.py (+2 -2)
  72. aphrodite/modeling/models/qwen2.py (+2 -2)
  73. aphrodite/modeling/models/stablelm.py (+2 -2)
  74. aphrodite/modeling/neuron_loader.py (+70 -0)
  75. aphrodite/modeling/sampling_metadata.py (+2 -2)
  76. aphrodite/modeling/utils.py (+17 -0)
  77. aphrodite/processing/block_manager.py (+39 -48)
  78. aphrodite/processing/evictor.py (+2 -3)
  79. aphrodite/processing/scheduler.py (+1 -4)
  80. aphrodite/spec_decode/batch_expansion.py (+398 -0)
  81. aphrodite/spec_decode/interfaces.py (+77 -0)
  82. aphrodite/spec_decode/metrics.py (+175 -0)
  83. aphrodite/spec_decode/multi_step_worker.py (+392 -0)
  84. aphrodite/spec_decode/spec_decode_worker.py (+394 -0)
  85. aphrodite/spec_decode/util.py (+101 -0)
  86. aphrodite/task_handler/cache_engine.py (+9 -2)
  87. aphrodite/task_handler/model_runner.py (+74 -56)
  88. aphrodite/task_handler/neuron_worker.py (+204 -0)
  89. aphrodite/task_handler/worker.py (+30 -14)
  90. kernels/attention/attention_dtypes.h (+1 -2)
  91. kernels/attention/attention_kernels.cu (+936 -1014)
  92. kernels/attention/dtype_float32.cuh (+255 -262)
  93. kernels/backup/README (+1 -0)
  94. kernels/backup/attention_dtypes.h (+8 -0)
  95. kernels/backup/attention_kernels.cu (+1032 -0)
  96. kernels/backup/cache.h (+39 -0)
  97. kernels/backup/cache_kernels.cu (+512 -0)
  98. kernels/backup/dispatch_utils.h (+39 -0)
  99. kernels/backup/dtype_float32.cuh (+280 -0)
  100. kernels/backup/dtype_int8.cuh (+0 -0)

+ 0 - 1
.gitignore

@@ -6,7 +6,6 @@ repos
 *.so
 .conda
 build
-*.json
 dist*
 .VSCodeCounter
 conda/

+ 98 - 33
aphrodite/common/config.py

@@ -8,7 +8,7 @@ import torch
 from transformers import PretrainedConfig
 
 from aphrodite.transformers_utils.config import get_config
-from aphrodite.common.utils import (get_cpu_memory, is_hip,
+from aphrodite.common.utils import (get_cpu_memory, is_hip, is_neuron,
                                     get_nvcc_cuda_version)
 
 _GB = 1 << 30
@@ -43,6 +43,9 @@ class ModelConfig:
         revision: The specific model version to use. It can be a branch name,
             a tag name, or a commit id. If unspecified, will use the default
             version.
+        code_revision: The specific revision to use for the model code on
+            Hugging Face Hub. It can be a branch name, a tag name, or a 
+            commit id. If unspecified, will use the default version.
         tokenizer_revision: The specific tokenizer version to use. It can be a
             branch name, a tag name, or a commit id. If unspecified, will use
             the default version.
@@ -71,16 +74,18 @@ class ModelConfig:
         trust_remote_code: bool,
         download_dir: Optional[str],
         load_format: str,
-        dtype: str,
+        # dtype: str,
+        dtype: Union[str, torch.dtype],
         seed: int,
         revision: Optional[str] = None,
+        code_revision: Optional[str] = None,
         tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
         load_in_4bit: bool = False,
         load_in_8bit: bool = False,
         load_in_smooth: bool = False,
-        enforce_eager: bool = False,
+        enforce_eager: bool = True,
         max_context_len_to_capture: Optional[int] = None,
         max_log_probs: int = 10,
     ) -> None:
@@ -92,6 +97,7 @@ class ModelConfig:
         self.load_format = load_format
         self.seed = seed
         self.revision = revision
+        self.code_revision = code_revision
         self.tokenizer_revision = tokenizer_revision
         self.quantization = quantization
         self.load_in_4bit = load_in_4bit
@@ -106,14 +112,18 @@ class ModelConfig:
             # download model from ModelScope hub,
             # lazy import so that modelscope is not required for normal use.
             from modelscope.hub.snapshot_download import snapshot_download  # pylint: disable=C
-            model_path = snapshot_download(model_id=model,
-                                           cache_dir=download_dir,
-                                           revision=revision)
+            if not os.path.exists(model):
+                model_path = snapshot_download(model_id=model,
+                                               cache_dir=download_dir,
+                                               revision=revision)
+            else:
+                model_path = model
             self.model = model_path
             self.download_dir = model_path
             self.tokenizer = model_path
 
-        self.hf_config = get_config(self.model, trust_remote_code, revision)
+        self.hf_config = get_config(self.model, trust_remote_code, revision,
+                                    code_revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self.max_model_len = _get_and_verify_max_len(self.hf_config,
                                                      max_model_len)
@@ -177,6 +187,7 @@ class ModelConfig:
         # Parse quantization method from the HF model config, if available.
         hf_quant_config = getattr(self.hf_config, "quantization_config", None)
         if hf_quant_config is not None:
+
             hf_quant_method = str(hf_quant_config["quant_method"]).lower()
             hf_quant_method = str(hf_quant_config["quant_method"]).lower()
             # If the GPTQ model is serialized in marlin format, use marlin.
             # If the GPTQ model is serialized in marlin format, use marlin.
             if (hf_quant_method == "gptq"
             if (hf_quant_method == "gptq"
@@ -375,7 +386,7 @@ class CacheConfig:
         gpu_memory_utilization: float,
         swap_space: int,
         cache_dtype: str,
-        cache_quant_params_path: Optional[str] = None,
+        # cache_quant_params_path: Optional[str] = None,
         sliding_window: Optional[int] = None,
         context_shift: bool = False,
     ) -> None:
@@ -384,7 +395,7 @@ class CacheConfig:
         self.swap_space_bytes = swap_space * _GB
         self.cache_dtype = cache_dtype
         self.sliding_window = sliding_window
-        self.cache_quant_params_path = cache_quant_params_path
+        # self.cache_quant_params_path = cache_quant_params_path
         self.context_shift = context_shift
         self._verify_args()
         self._verify_cache_dtype()
@@ -393,6 +404,11 @@ class CacheConfig:
         self.num_gpu_blocks = None
         self.num_cpu_blocks = None
 
+    def metrics_info(self):
+        # convert cache_config to dict(key: str, value: str) for prometheus
+        # metrics info
+        return {key: str(value) for key, value in self.__dict__.items()}
+
     def _verify_args(self) -> None:
         if self.gpu_memory_utilization > 1.0:
             raise ValueError(
@@ -400,25 +416,24 @@ class CacheConfig:
                 f"{self.gpu_memory_utilization}.")
                 f"{self.gpu_memory_utilization}.")
 
 
     def _verify_cache_dtype(self) -> None:
     def _verify_cache_dtype(self) -> None:
-        if self.cache_dtype in ["auto", "int8"]:
+        if self.cache_dtype == "auto":
+            # if self.cache_dtype in ["auto", "int8"]:
             pass
         elif self.cache_dtype == "fp8_e5m2":
-            nvcc_cuda_version = get_nvcc_cuda_version()
-            if nvcc_cuda_version < Version("11.8"):
-                raise ValueError(
-                    "FP8 is not supported when cuda version is lower than "
-                    "11.8. If you think you have the correct cuda version, "
-                    "please make sure you've properly exported CUDA_HOME.")
-            device_name = torch.cuda.get_device_name()
-            if "AMD" in device_name:
+            if is_hip():
                 raise NotImplementedError(
                     "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
+            nvcc_cuda_version = get_nvcc_cuda_version()
+            if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
+                raise ValueError(
+                    "FP8 is not supported when cuda version is lower than 11.8."
+                )
             logger.info(
                 "Using fp8_e5m2 data type to store kv cache. It reduces "
                 "the GPU memory footprint and boosts the performance. "
                 "But it may cause slight accuracy drop. "
                 "Currently we only support fp8 without scaling factors and "
-                "make e5m2 as a default format.")
+                "use e5m2 as a default format.")
         else:
             raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")
 
@@ -450,8 +465,13 @@ class ParallelConfig:
         worker_use_ray: Whether to use Ray for model workers. Will be set to
             True if either pipeline_parallel_size or tensor_parallel_size is
             greater than 1.
+        max_parallel_loading_workers: Maximum number of multiple batches
+            when load model sequentially. To avoid RAM OOM when using tensor
+            parallel and large models.
         disable_custom_all_reduce: Disable the custom all-reduce kernel and
             fall back to NCCL.
+        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
+            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
     """
     """
 
 
     def __init__(
     def __init__(
@@ -461,15 +481,26 @@ class ParallelConfig:
         worker_use_ray: bool,
         max_parallel_loading_workers: Optional[int] = None,
         disable_custom_all_reduce: bool = False,
+        ray_workers_use_nsight: bool = False,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
-        self.tensor_parallel_size = tensor_parallel_size
+        if is_neuron():
+            # For Neuron device support, here we assign TP=1 to avoid sharding
+            # within Aphrodite directly.
+            # Transformer-neuronx would take neuron_tp_degree attribute, and
+            # distribute the workload to multiple NeuronCores.
+            self.tensor_parallel_size = 1
+            self.neuron_tp_degree = tensor_parallel_size
+        else:
+            self.tensor_parallel_size = tensor_parallel_size
         self.worker_use_ray = worker_use_ray
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce
+        self.ray_workers_use_nsight = ray_workers_use_nsight
 
-        self.world_size = pipeline_parallel_size * tensor_parallel_size
-        if self.world_size > 1:
+        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
+        # Ray worker is not supported for Neuron backend.
+        if self.world_size > 1 and not is_neuron():
             self.worker_use_ray = True
         self._verify_args()
 
@@ -477,16 +508,29 @@ class ParallelConfig:
         if self.pipeline_parallel_size > 1:
             raise NotImplementedError(
                 "Pipeline parallelism is not supported yet.")
-        if is_hip():
+        if not self.disable_custom_all_reduce and self.world_size > 1:
+            if is_hip():
+                self.disable_custom_all_reduce = True
+                logger.info(
+                    "Disabled the custom all-reduce kernel because it is not "
+                    "supported on AMD GPUs.")
+            elif self.pipeline_parallel_size > 1:
+                self.disable_custom_all_reduce = True
+                logger.info(
+                    "Disabled the custom all-reduce kernel because it is not "
+                    "supported with pipeline parallelism.")
+        if self.ray_workers_use_nsight and not self.worker_use_ray:
+            raise ValueError("Unable to use nsight profiling unless workers "
+                             "run with Ray.")
+
+        # FIXME: Fix the stability issues and re-enable the custom
+        # all-reduce kernel.
+        if not self.disable_custom_all_reduce and self.world_size > 1:
             self.disable_custom_all_reduce = True
             logger.info(
-                "Disabled the custom all-reduce kernel because it is not "
-                "supported on AMD GPUs.")
-        elif self.pipeline_parallel_size > 1:
-            self.disable_custom_all_reduce = True
-            logger.info(
-                "Disabled the custom all-reduce kernel because it is not "
-                "supported with pipeline parallelism.")
+                "Custom all-reduce kernels are temporarily disabled due to "
+                "stability issues. We will re-enable them once the issues are "
+                "resolved.")
 
 
 class SchedulerConfig:
@@ -538,8 +582,29 @@ class SchedulerConfig:
 
 class DeviceConfig:
 
-    def __init__(self, device: str = "cuda") -> None:
-        self.device = torch.device(device)
+    def __init__(self, device: str = "auto") -> None:
+        if device == "auto":
+            # Automated device type detection
+            if torch.cuda.is_available():
+                self.device_type = "cuda"
+            elif is_neuron():
+                self.device_type = "neuron"
+            else:
+                raise RuntimeError("No supported device detected.")
+        else:
+            # Device type is assigned explicitly
+            self.device_type = device
+
+        # Some device types require processing inputs on CPU
+        if self.device_type in ["neuron"]:
+            self.device = torch.device("cpu")
+        else:
+            # Set device with device type
+            self.device = torch.device(self.device_type)
+
+    @property
+    def is_neuron(self):
+        return self.device_type == "neuron"
 
 
 @dataclass
@@ -571,7 +636,7 @@ class LoRAConfig:
         elif self.max_cpu_loras < self.max_loras:
             raise ValueError(
                 f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
-                f"max_num_seqs ({self.max_loras})")
+                f"max_loras ({self.max_loras})")
 
     def verify_with_model_config(self, model_config: ModelConfig):
         if self.lora_dtype in (None, "auto"):

+ 10 - 8
aphrodite/common/logger.py

@@ -1,9 +1,10 @@
 """
 """
-Internal logging utility. Adapted from
-https://github.com/theroyallab/tabbyAPI/blob/4cc0b59bdc94e6342b6d1d7acadbadc63c740ed9/common/logger.py
+Internal logging utility.
 """
 """
 
 
 import logging
 import logging
+import os
+
 from loguru import logger
 from rich.console import Console
 from rich.markup import escape
@@ -17,6 +18,7 @@ from rich.progress import (
 )
 
 RICH_CONSOLE = Console()
+LOG_LEVEL = os.getenv("APHRODITE_LOG_LEVEL", "INFO").upper()
 
 
 def unwrap(wrapped, default=None):
@@ -60,9 +62,9 @@ def _log_formatter(record: dict):
     message = unwrap(record.get("message"), "")
     message = unwrap(record.get("message"), "")
 
 
     # Replace once loguru allows for turning off str.format
     # Replace once loguru allows for turning off str.format
-    message = message.replace("{{", "{{").replace("}}", "}}")
-    # Manually escape < and > characters
-    message = message.replace("<", "\\<").replace(">", "\\>")
+    message = message.replace("{", "{{").replace("}", "}}").replace("<", "\<")
+
+    # Escape markup tags from Rich
     message = escape(message)
     lines = message.splitlines()
 
@@ -86,7 +88,7 @@ class UvicornLoggingHandler(logging.Handler):
 
 
 # Uvicorn config for logging. Passed into run when creating all loggers in
-#server
+# server
 UVICORN_LOG_CONFIG = {
     "version": 1,
     "disable_existing_loggers": False,
@@ -99,7 +101,7 @@ UVICORN_LOG_CONFIG = {
     "root": {
     "root": {
         "handlers": ["uvicorn"],
         "handlers": ["uvicorn"],
         "propagate": False,
         "propagate": False,
-        "level": "INFO"
+        "level": LOG_LEVEL
     },
 }
 
@@ -111,7 +113,7 @@ def setup_logger():
 
     logger.add(
         RICH_CONSOLE.print,
-        level="INFO",
+        level=LOG_LEVEL,
         format=_log_formatter,
         colorize=True,
     )

+ 47 - 30
aphrodite/common/outputs.py

@@ -1,7 +1,13 @@
 from typing import List, Optional
+import time
 
-from aphrodite.common.sequence import (PromptLogprobs, SampleLogprobs,
-                                       SequenceGroup, SequenceStatus)
+from aphrodite.common.sequence import (
+    PromptLogprobs,
+    SampleLogprobs,
+    SequenceGroup,
+    SequenceStatus,
+    RequestMetrics,
+)
 from aphrodite.lora.request import LoRARequest
 
 
@@ -60,6 +66,7 @@ class RequestOutput:
         prompt_logprobs: The log probabilities to return per prompt token.
         outputs: The output sequences of the request.
         finished: Whether the whole request is finished.
+        metrics: Metrics associated with the request.
         lora_request: The LoRA request that was used to generate the output.
     """
 
@@ -71,6 +78,7 @@ class RequestOutput:
         prompt_logprobs: Optional[PromptLogprobs],
         outputs: List[CompletionOutput],
         finished: bool,
+        metrics: Optional[RequestMetrics] = None,
         lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.request_id = request_id
@@ -79,6 +87,7 @@ class RequestOutput:
         self.prompt_logprobs = prompt_logprobs
         self.outputs = outputs
         self.finished = finished
+        self.metrics = metrics
         self.lora_request = lora_request
 
     @classmethod
@@ -86,43 +95,50 @@ class RequestOutput:
         # Get the top-n sequences.
         n = seq_group.sampling_params.n
         seqs = seq_group.get_seqs()
-        if seq_group.sampling_params.use_beam_search:
-            sorting_key = lambda seq: seq.get_beam_search_score(
-                seq_group.sampling_params.length_penalty)
+        if n == 1:
+            top_n_seqs = seqs
         else:
-            # ruff: noqa: E731
-            sorting_key = lambda seq: seq.get_cumulative_logprob()
-        sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
-        top_n_seqs = sorted_seqs[:n]
+            if seq_group.sampling_params.use_beam_search:
+                sorting_key = lambda seq: seq.get_beam_search_score(
+                    seq_group.sampling_params.length_penalty)
+            else:
+                sorting_key = lambda seq: seq.get_cumulative_logprob()
+            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
+            top_n_seqs = sorted_seqs[:n]
 
         # Create the outputs.
-        outputs: List[CompletionOutput] = []
-        for seq in top_n_seqs:
-            logprobs = seq.output_logprobs
-            if seq_group.sampling_params.logprobs is None:
-                # NOTE: We need to take care of this case because the sequence
-                # always has the logprobs of the sampled tokens even if the
-                # logprobs are not requested.
-                logprobs = None
-            finshed_reason = SequenceStatus.get_finished_reason(seq.status)
-            output = CompletionOutput(seqs.index(seq), seq.output_text,
-                                      seq.get_output_token_ids(),
-                                      seq.get_cumulative_logprob(), logprobs,
-                                      finshed_reason)
-            outputs.append(output)
+        # NOTE: We need omit logprobs here explicitly because the sequence
+        # always has the logprobs of the sampled tokens even if the
+        # logprobs are not requested.
+        include_logprobs = seq_group.sampling_params.logprobs
+        outputs = [
+            CompletionOutput(
+                seqs.index(seq),
+                seq.output_text,
+                seq.get_output_token_ids(),
+                seq.get_cumulative_logprob(),
+                seq.output_logprobs if include_logprobs else None,
+                SequenceStatus.get_finished_reason(seq.status),
+            ) for seq in top_n_seqs
+        ]
 
         # Every sequence in the sequence group should have the same prompt.
         prompt = seq_group.prompt
         prompt_token_ids = seq_group.prompt_token_ids
         prompt_logprobs = seq_group.prompt_logprobs
         finished = seq_group.is_finished()
-        return cls(seq_group.request_id,
-                   prompt,
-                   prompt_token_ids,
-                   prompt_logprobs,
-                   outputs,
-                   finished,
-                   lora_request=seq_group.lora_request)
+        finished_time = time.time() if finished else None
+        seq_group.set_finished_time(finished_time)
+        return cls(
+            seq_group.request_id,
+            prompt,
+            prompt_token_ids,
+            prompt_logprobs,
+            outputs,
+            finished,
+            seq_group.metrics,
+            lora_request=seq_group.lora_request,
+        )
 
     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
@@ -131,4 +147,5 @@ class RequestOutput:
                 f"prompt_logprobs={self.prompt_logprobs}, "
                 f"prompt_logprobs={self.prompt_logprobs}, "
                 f"outputs={self.outputs}, "
                 f"outputs={self.outputs}, "
                 f"finished={self.finished}, "
                 f"finished={self.finished}, "
+                f"metrics={self.metrics}, "
                 f"lora_request={self.lora_request})")
                 f"lora_request={self.lora_request})")

+ 13 - 0
aphrodite/common/sampling_params.py

@@ -1,4 +1,5 @@
 """Sampling parameters for text generation."""
 """Sampling parameters for text generation."""
+import copy
 from enum import IntEnum
 from functools import cached_property
 from typing import Callable, List, Optional, Union
@@ -375,6 +376,18 @@ class SamplingParams:
             return SamplingType.RANDOM_SEED
         return SamplingType.RANDOM
 
+    def clone(self) -> "SamplingParams":
+        """Deep copy excluding LogitsProcessor objects.
+        LogitsProcessor objects are excluded because they may contain an
+        arbitrary, nontrivial amount of data.
+        """
+
+        logit_processor_refs = None if self.logits_processors is None else {
+            id(lp): lp
+            for lp in self.logits_processors
+        }
+        return copy.deepcopy(self, memo=logit_processor_refs)
+
     def __repr__(self) -> str:
         repr_str = "SamplingParams("
         for param, default_value in self.default_values.items():

+ 71 - 21
aphrodite/common/sequence.py

@@ -2,16 +2,21 @@
 import copy
 import enum
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union, TYPE_CHECKING
 
 from aphrodite.common.block import LogicalTokenBlock
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.lora.request import LoRARequest
 
+if TYPE_CHECKING:
+    import torch
+    from aphrodite.spec_decode.metrics import SpecDecodeWorkerMetrics
+
 
 @dataclass
 class Logprob:
     """Infos for supporting OpenAI compatible logprobs."""
+
     logprob: float
     decoded_token: Optional[str] = None
 
@@ -22,6 +27,7 @@ SampleLogprobs = List[Dict[int, Logprob]]
 
 class SequenceStatus(enum.Enum):
     """Status of a sequence."""
+
     WAITING = enum.auto()
     RUNNING = enum.auto()
     SWAPPED = enum.auto()
@@ -68,6 +74,7 @@ class RequestMetrics:
         time_in_queue: The time the request spent in the queue.
         finished_time: The time when the request was finished.
     """
+
     arrival_time: float
     last_token_time: float
     first_scheduled_time: Optional[float]
@@ -81,6 +88,8 @@ class SequenceData:
 
     Args:
         prompt_token_ids: The token IDs of the prompt.
+        output_token_ids: The token IDs of the output. Set to an empty list if
+            None.
 
     Attributes:
         prompt_token_ids: The token IDs of the prompt.
@@ -91,9 +100,13 @@ class SequenceData:
     def __init__(
         self,
         prompt_token_ids: List[int],
+        output_token_ids: Optional[List[int]] = None,
     ) -> None:
+        if output_token_ids is None:
+            output_token_ids = []
+
         self.prompt_token_ids = prompt_token_ids
-        self.output_token_ids: List[int] = []
+        self.output_token_ids = output_token_ids
         self.cumulative_logprob = 0.0
 
     def append_token_id(self, token_id: int, logprob: float) -> None:
@@ -117,6 +130,12 @@ class SequenceData:
             return self.prompt_token_ids[-1]
         return self.output_token_ids[-1]
 
+    def get_prompt_token_ids(self) -> int:
+        return self.prompt_token_ids
+
+    def get_output_token_ids(self) -> int:
+        return self.output_token_ids
+
     def __repr__(self) -> str:
         return (f"SequenceData("
                 f"prompt_token_ids={self.prompt_token_ids}, "
@@ -142,11 +161,13 @@ class Sequence:
         prompt: str,
         prompt_token_ids: List[int],
         block_size: int,
+        eos_token_id: Optional[int] = None,
         lora_request: Optional[LoRARequest] = None,
     ) -> None:
         self.seq_id = seq_id
         self.prompt = prompt
         self.block_size = block_size
+        self.eos_token_id = eos_token_id
         self.lora_request = lora_request
 
         self.data = SequenceData(prompt_token_ids)
@@ -164,7 +185,6 @@ class Sequence:
         # Input + output tokens
         self.tokens: Optional[List[str]] = None
         self.persistent_data = {}
-        self.persistent_data = {}
 
     @property
     def lora_int_id(self) -> int:
@@ -235,10 +255,12 @@ class Sequence:
     def get_cumulative_logprob(self) -> float:
         return self.data.cumulative_logprob
 
-    def get_beam_search_score(self,
-                              length_penalty: float = 0.0,
-                              seq_len: Optional[int] = None,
-                              eos_token_id: Optional[int] = None) -> float:
+    def get_beam_search_score(
+        self,
+        length_penalty: float = 1.0,
+        seq_len: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+    ) -> float:
         """Calculate the beam search score with length penalty.
         """Calculate the beam search score with length penalty.
 
 
         Adapted from
         Adapted from
@@ -298,11 +320,13 @@ class SequenceGroup:
         self.request_id = request_id
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
-        self.metrics = RequestMetrics(arrival_time=arrival_time,
-                                      last_token_time=arrival_time,
-                                      first_scheduled_time=None,
-                                      first_token_time=None,
-                                      time_in_queue=None)
+        self.metrics = RequestMetrics(
+            arrival_time=arrival_time,
+            last_token_time=arrival_time,
+            first_scheduled_time=None,
+            first_token_time=None,
+            time_in_queue=None,
+        )
         self.lora_request = lora_request
         self.prompt_logprobs: Optional[PromptLogprobs] = None
         self.state = SequenceGroupState()
@@ -366,12 +390,9 @@ class SequenceGroup:
         self,
         status: Optional[SequenceStatus] = None,
     ) -> List[Sequence]:
-        if status is None:
-            return list(self.seqs_dict.values())
-        else:
-            return [
-                seq for seq in self.seqs_dict.values() if seq.status == status
-            ]
+        return (list(self.seqs_dict.values()) if status is None else [
+            seq for seq in self.seqs_dict.values() if seq.status == status
+        ])
 
     def get_unfinished_seqs(self) -> List[Sequence]:
         return [
@@ -517,6 +538,35 @@ class SequenceGroupOutput:
                 and self.prompt_logprobs == other.prompt_logprobs)
 
 
-# For each sequence group, we generate a list of SequenceOutput object,
-# each of which contains one possible candidate for the next token.
-SamplerOutput = List[SequenceGroupOutput]
+@dataclass
+class SamplerOutput:
+    """For each sequence group, we generate a list of SequenceOutput object,
+    each of which contains one possible candidate for the next token.
+
+    This datastructure implements methods so it can be used like a list, but
+    also has optional fields for device tensors.
+    """
+
+    outputs: List[SequenceGroupOutput]
+
+    # On-device tensor containing probabilities of each token.
+    sampled_token_probs: Optional["torch.Tensor"] = None
+
+    # On-device tensor containing the sampled token ids.
+    sampled_token_ids: Optional["torch.Tensor"] = None
+
+    # Spec decode metrics populated by workers.
+    spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
+
+    def __getitem__(self, idx: int):
+        return self.outputs[idx]
+
+    def __setitem__(self, idx: int, value):
+        self.outputs[idx] = value
+
+    def __len__(self):
+        return len(self.outputs)
+
+    def __eq__(self, other: object):
+        return (isinstance(other, self.__class__)
+                and self.outputs == other.outputs)
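
`SamplerOutput` keeps the list-like behaviour of the old `List[SequenceGroupOutput]` alias while adding the optional device tensors used by the speculative-decoding workers. A small sketch; `step_outputs` stands in for whatever the sampler produced:

```python
from typing import List

from aphrodite.common.sequence import SamplerOutput, SequenceGroupOutput

step_outputs: List[SequenceGroupOutput] = []  # normally filled by the sampler

sampler_output = SamplerOutput(outputs=step_outputs)
assert len(sampler_output) == len(step_outputs)   # __len__ delegates to .outputs
for seq_group_output in sampler_output:           # iteration via __getitem__
    print(seq_group_output)

# The extra fields stay None unless a worker populates them.
assert sampler_output.sampled_token_ids is None
```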

+ 67 - 27
aphrodite/common/utils.py

@@ -5,16 +5,21 @@ import subprocess
 import uuid
 import gc
 from platform import uname
-from loguru import logger
+from typing import List, Tuple, Union
+from packaging.version import parse, Version
 
 import psutil
 import torch
 import asyncio
 from functools import partial
-from typing import (Any, Awaitable, Callable, Hashable, Optional, TypeVar,
-                    List, Tuple, Union)
+from typing import (
+    Awaitable,
+    Callable,
+    TypeVar,
+)
 from collections import OrderedDict
-from packaging.version import parse, Version
+from typing import Any, Hashable, Optional
+from loguru import logger
 
 T = TypeVar("T")
 
@@ -23,7 +28,7 @@ STR_DTYPE_TO_TORCH_DTYPE = {
     "bfloat16": torch.bfloat16,
     "bfloat16": torch.bfloat16,
     "float": torch.float,
     "float": torch.float,
     "fp8_e5m2": torch.uint8,
     "fp8_e5m2": torch.uint8,
-    "int8": torch.int8,
+    # "int8": torch.int8,
 }
 
 
@@ -113,12 +118,24 @@ def is_hip() -> bool:
     return torch.version.hip is not None
 
 
+def is_neuron() -> bool:
+    try:
+        import transformers_neuronx
+    except ImportError:
+        transformers_neuronx = None
+    return transformers_neuronx is not None
+
+
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
+    # NOTE: This import statement should be executed lazily since
+    # the Neuron-X backend does not have the `cuda_utils` module.
     from aphrodite._C import cuda_utils
-    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
+
     max_shared_mem = (
         cuda_utils.get_max_shared_memory_per_block_device_attribute(gpu))
+    # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py
+    # will fail
     assert max_shared_mem > 0, "max_shared_mem can not be zero"
     assert max_shared_mem > 0, "max_shared_mem can not be zero"
     return int(max_shared_mem)
     return int(max_shared_mem)
 
 
@@ -139,6 +156,7 @@ def in_wsl() -> bool:
 
 def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
     """Take a blocking function, and run it on in an executor thread.
+
     This function prevents the blocking function from blocking the
     asyncio event loop.
     The code in this function needs to be thread safe.
@@ -153,15 +171,33 @@ def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]:
 
 
 def get_ip() -> str:
+    # try ipv4
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
-    return s.getsockname()[0]
+    try:
+        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except OSError:
+        # try ipv6
+        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+        s.connect(("dns.google", 80))
+        return s.getsockname()[0]
+
+
+def get_distributed_init_method(ip: str, port: int) -> str:
+    return f"tcp://{ip}:{port}"
 
 
 def get_open_port() -> int:
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind(("", 0))
-        return s.getsockname()[1]
+    # try ipv4
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+    except OSError:
+        # try ipv6
+        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
 
 
 def set_cuda_visible_devices(device_ids: List[int]) -> None:
@@ -170,18 +206,22 @@ def set_cuda_visible_devices(device_ids: List[int]) -> None:
 
 def get_nvcc_cuda_version() -> Optional[Version]:
     cuda_home = os.environ.get('CUDA_HOME')
-    nvcc_path = os.path.join(cuda_home, 'bin', 'nvcc') if cuda_home else 'nvcc'
-
-    try:
-        nvcc_output = subprocess.check_output([nvcc_path, "-V"],
-                                              universal_newlines=True)
-        output = nvcc_output.split()
-        release_idx = output.index("release") + 1
-        nvcc_cuda_version = parse(output[release_idx].split(",")[0])
-        return nvcc_cuda_version
-    except (FileNotFoundError, subprocess.CalledProcessError):
-        logger.warning("nvcc not found. Skipping CUDA version check!")
-        return None
+    if not cuda_home:
+        cuda_home = '/usr/local/cuda'
+        if os.path.isfile(cuda_home + '/bin/nvcc'):
+            logger.info(
+                f'CUDA_HOME is not found in the environment. Using {cuda_home} '
+                'as CUDA_HOME.')
+        else:
+            logger.warning(
+                f'Not found nvcc in {cuda_home}. Skip cuda version check!')
+            return None
+    nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"],
+                                          universal_newlines=True)
+    output = nvcc_output.split()
+    release_idx = output.index("release") + 1
+    nvcc_cuda_version = parse(output[release_idx].split(",")[0])
+    return nvcc_cuda_version
 
 
 def _generate_random_fp8_e5m2(
@@ -248,8 +288,8 @@ def create_kv_caches_with_random(
                                 device=device)
         if cache_dtype == 'fp8_e5m2':
             _generate_random_fp8_e5m2(key_cache, -scale, scale)
-        elif cache_dtype == 'int8':
-            torch.randint(-128, 127, key_cache.size(), out=key_cache)
+        # elif cache_dtype == 'int8':
+        #     torch.randint(-128, 127, key_cache.size(), out=key_cache)
         elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
             key_cache.uniform_(-scale, scale)
         else:
@@ -265,8 +305,8 @@ def create_kv_caches_with_random(
                                   device=device)
         if cache_dtype == 'fp8_e5m2':
             _generate_random_fp8_e5m2(value_cache, -scale, scale)
-        elif cache_dtype == 'int8':
-            torch.randint(-128, 127, value_cache.size(), out=value_cache)
+        # elif cache_dtype == 'int8':
+        #     torch.randint(-128, 127, value_cache.size(), out=value_cache)
         elif torch_dtype in [torch.half, torch.bfloat16, torch.float]:
             value_cache.uniform_(-scale, scale)
         else:
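
The new networking helpers compose into the init-method string that gets handed to the workers; `is_neuron()` is the same probe the config code now uses. A quick sketch:

```python
from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                    get_open_port, is_neuron)

# IPv4 first, IPv6 fallback, then a free port on this host.
init_method = get_distributed_init_method(get_ip(), get_open_port())
print(init_method)   # e.g. "tcp://192.168.1.23:51515"
print(is_neuron())   # True only when transformers-neuronx is importable
```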

+ 1 - 1
aphrodite/endpoints/llm.py

@@ -78,7 +78,7 @@ class LLM:
         seed: int = 0,
         gpu_memory_utilization: float = 0.9,
         swap_space: int = 4,
-        enforce_eager: bool = False,
+        enforce_eager: bool = True,
         max_context_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         **kwargs,
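
Because `enforce_eager` now defaults to `True`, CUDA graph capture is opt-in again. A minimal sketch (the model name is illustrative):

```python
from aphrodite import LLM

llm = LLM(
    model="mistralai/Mistral-7B-v0.1",
    enforce_eager=False,             # opt back in to CUDA graph capture
    max_context_len_to_capture=8192,
)
```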

+ 104 - 95
aphrodite/endpoints/openai/api_server.py

@@ -191,6 +191,7 @@ async def validation_exception_handler(_, exc):
 @app.get("/health")
 @app.get("/health")
 async def health() -> Response:
 async def health() -> Response:
     """Health check."""
     """Health check."""
+    await openai_serving_chat.engine.check_health()
     return Response(status_code=200)
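
`/health` now awaits `engine.check_health()` instead of returning 200 unconditionally, so a broken engine surfaces as a failed probe. A sketch; host and port are whatever the server was started with:

```python
import requests

resp = requests.get("http://localhost:2242/health", timeout=5)
resp.raise_for_status()  # raises once the engine stops reporting healthy
```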
 
 
@@ -526,104 +527,112 @@ async def get_kobold_lite_ui():
 # ============ KoboldAI API ============ #
 
 if __name__ == "__main__":
-    args = parse_args()
-
-    if args.launch_kobold_api:
-        logger.warning("Launching Kobold API server in addition to OpenAI. "
-                       "Keep in mind that the Kobold API routes are NOT "
-                       "protected via the API key.")
-        app.include_router(kai_api, prefix="/api/v1")
-        app.include_router(kai_api,
-                           prefix="/api/latest",
-                           include_in_schema=False)
-        app.include_router(extra_api, prefix="/api/extra")
-
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=args.allowed_origins,
-        allow_credentials=args.allow_credentials,
-        allow_methods=args.allowed_methods,
-        allow_headers=args.allowed_headers,
-    )
+    try:
+        args = parse_args()
+
+        if args.launch_kobold_api:
+            logger.warning(
+                "Launching Kobold API server in addition to OpenAI. "
+                "Keep in mind that the Kobold API routes are NOT "
+                "protected via the API key.")
+            app.include_router(kai_api, prefix="/api/v1")
+            app.include_router(kai_api,
+                               prefix="/api/latest",
+                               include_in_schema=False)
+            app.include_router(extra_api, prefix="/api/extra")
+
+        app.add_middleware(
+            CORSMiddleware,
+            allow_origins=args.allowed_origins,
+            allow_credentials=args.allow_credentials,
+            allow_methods=args.allowed_methods,
+            allow_headers=args.allowed_headers,
+        )
+
+        if token := os.environ.get("APHRODITE_API_KEY") or args.api_keys:
+            admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key
+
+            if admin_key is None:
+                logger.warning("Admin key not provided. Admin operations will "
+                               "be disabled.")
+
+            @app.middleware("http")
+            async def authentication(request: Request, call_next):
+                excluded_paths = ["/api"]
+                if any(
+                        request.url.path.startswith(path)
+                        for path in excluded_paths):
+                    return await call_next(request)
+                if not request.url.path.startswith("/v1"):
+                    return await call_next(request)
 
-    if token := os.environ.get("APHRODITE_API_KEY") or args.api_keys:
-        admin_key = os.environ.get("APHRODITE_ADMIN_KEY") or args.admin_key
+                auth_header = request.headers.get("Authorization")
+                api_key_header = request.headers.get("x-api-key")
 
-        if admin_key is None:
-            logger.warning("Admin key not provided. Admin operations will "
-                           "be disabled.")
+                if request.url.path.startswith("/v1/lora"):
+                    if admin_key is not None and api_key_header == admin_key:
+                        return await call_next(request)
+                    return JSONResponse(content={"error": "Unauthorized"},
+                                        status_code=401)
 
 
-        @app.middleware("http")
-        async def authentication(request: Request, call_next):
-            excluded_paths = ["/api"]
-            if any(
-                    request.url.path.startswith(path)
-                    for path in excluded_paths):
-                return await call_next(request)
-            if not request.url.path.startswith("/v1"):
+                if auth_header != "Bearer " + token and api_key_header != token:
+                    return JSONResponse(content={"error": "Unauthorized"},
+                                        status_code=401)
                 return await call_next(request)
                 return await call_next(request)
 
 
-            auth_header = request.headers.get("Authorization")
-            api_key_header = request.headers.get("x-api-key")
-
-            if request.url.path.startswith("/v1/lora"):
-                if admin_key is not None and api_key_header == admin_key:
-                    return await call_next(request)
-                return JSONResponse(content={"error": "Unauthorized"},
-                                    status_code=401)
-
-            if auth_header != "Bearer " + token and api_key_header != token:
-                return JSONResponse(content={"error": "Unauthorized"},
-                                    status_code=401)
-            return await call_next(request)
-
-    for middleware in args.middleware:
-        module_path, object_name = middleware.rsplit(".", 1)
-        imported = getattr(importlib.import_module(module_path), object_name)
-        if inspect.isclass(imported):
-            app.add_middleware(imported)
-        elif inspect.iscoroutinefunction(imported):
-            app.middleware("http")(imported)
+        for middleware in args.middleware:
+            module_path, object_name = middleware.rsplit(".", 1)
+            imported = getattr(importlib.import_module(module_path),
+                               object_name)
+            if inspect.isclass(imported):
+                app.add_middleware(imported)
+            elif inspect.iscoroutinefunction(imported):
+                app.middleware("http")(imported)
+            else:
+                raise ValueError(f"Invalid middleware {middleware}. Must be a "
+                                 "function or a class.")
+
+        logger.debug(f"args: {args}")
+
+        if args.served_model_name is not None:
+            served_model = args.served_model_name
         else:
         else:
-            raise ValueError(f"Invalid middleware {middleware}. Must be a "
-                             "function or a class.")
-
-    logger.debug(f"args: {args}")
-
-    if args.served_model_name is not None:
-        served_model = args.served_model_name
-    else:
-        served_model = args.model
-
-    engine_args = AsyncEngineArgs.from_cli_args(args)
-    engine = AsyncAphrodite.from_engine_args(engine_args)
-    tokenizer = get_tokenizer(
-        engine_args.tokenizer,
-        tokenizer_mode=engine_args.tokenizer_mode,
-        trust_remote_code=engine_args.trust_remote_code,
-    )
-
-    chat_template = args.chat_template
-    if chat_template is None and tokenizer.chat_template is not None:
-        chat_template = tokenizer.chat_template
-
-    openai_serving_chat = OpenAIServingChat(engine, served_model,
-                                            args.response_role,
-                                            args.lora_modules,
-                                            args.chat_template)
-    openai_serving_completion = OpenAIServingCompletion(
-        engine, served_model, args.lora_modules)
-    engine_model_config = asyncio.run(engine.get_model_config())
-
-    if args.launch_kobold_api:
-        _set_badwords(tokenizer, engine_model_config.hf_config)
-
-    app.root_path = args.root_path
-    uvicorn.run(app,
-                host=args.host,
-                port=args.port,
-                log_level="info",
-                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
-                ssl_keyfile=args.ssl_keyfile,
-                ssl_certfile=args.ssl_certfile,
-                log_config=UVICORN_LOG_CONFIG)
+            served_model = args.model
+
+        engine_args = AsyncEngineArgs.from_cli_args(args)
+        engine = AsyncAphrodite.from_engine_args(engine_args)
+        tokenizer = get_tokenizer(
+            engine_args.tokenizer,
+            tokenizer_mode=engine_args.tokenizer_mode,
+            trust_remote_code=engine_args.trust_remote_code,
+        )
+
+        chat_template = args.chat_template
+        if chat_template is None and tokenizer.chat_template is not None:
+            chat_template = tokenizer.chat_template
+
+        openai_serving_chat = OpenAIServingChat(engine, served_model,
+                                                args.response_role,
+                                                args.lora_modules,
+                                                args.chat_template)
+        openai_serving_completion = OpenAIServingCompletion(
+            engine, served_model, args.lora_modules)
+        engine_model_config = asyncio.run(engine.get_model_config())
+
+        if args.launch_kobold_api:
+            _set_badwords(tokenizer, engine_model_config.hf_config)
+
+        app.root_path = args.root_path
+        uvicorn.run(app,
+                    host=args.host,
+                    port=args.port,
+                    log_level="info",
+                    timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
+                    ssl_keyfile=args.ssl_keyfile,
+                    ssl_certfile=args.ssl_certfile,
+                    log_config=UVICORN_LOG_CONFIG)
+    except KeyboardInterrupt:
+        logger.info("API server stopped by user. Exiting gracefully.")
+    except asyncio.exceptions.CancelledError:
+        logger.info("API server stopped due to a cancelled request. "
+                    "Exiting gracefully.")

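Editor's note: after this restructuring, /v1 routes accept either an Authorization bearer token or an x-api-key header, /api (Kobold) routes bypass the check entirely, and /v1/lora additionally requires the admin key. A minimal client sketch, assuming the server runs on localhost:2242 with --api-keys sk-example (hypothetical key) and exposes the usual /v1/models route:

# Editorial example, not part of the diff. Both header styles below are
# accepted by the authentication middleware for /v1 routes.
import requests

BASE = "http://localhost:2242"   # assumed host/port
KEY = "sk-example"               # hypothetical API key

r1 = requests.get(f"{BASE}/v1/models",
                  headers={"Authorization": f"Bearer {KEY}"})
r2 = requests.get(f"{BASE}/v1/models",
                  headers={"x-api-key": KEY})
print(r1.status_code, r2.status_code)   # both should be 200 with a valid key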
+ 275 - 97
aphrodite/engine/aphrodite_engine.py

@@ -2,29 +2,61 @@ import copy
 from collections import defaultdict
 from collections import defaultdict
 import os
 import os
 import time
 import time
-from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
-                    Union)
+import pickle
+import importlib
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 from loguru import logger
 from loguru import logger
 
 
 import aphrodite
 import aphrodite
 from aphrodite.lora.request import LoRARequest
 from aphrodite.lora.request import LoRARequest
-from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
-                                     SchedulerConfig, LoRAConfig, DeviceConfig)
+from aphrodite.common.config import (
+    CacheConfig,
+    DeviceConfig,
+    ModelConfig,
+    ParallelConfig,
+    SchedulerConfig,
+    LoRAConfig,
+)
 from aphrodite.processing.scheduler import Scheduler, SchedulerOutputs
 from aphrodite.processing.scheduler import Scheduler, SchedulerOutputs
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.engine.metrics import StatLogger, Stats
 from aphrodite.engine.metrics import StatLogger, Stats
-from aphrodite.engine.ray_tools import (RayWorkerAphrodite, initialize_cluster,
-                                        ray)
-from aphrodite.common.logger import setup_logger
+from aphrodite.engine.ray_tools import (
+    RayWorkerAphrodite,
+    initialize_cluster,
+    ray,
+)
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
 from aphrodite.common.sampling_params import SamplingParams
-from aphrodite.common.sequence import (SamplerOutput, Sequence, SequenceGroup,
-                                       SequenceGroupOutput, SequenceOutput,
-                                       SequenceStatus, Logprob)
-from aphrodite.transformers_utils.tokenizer import (detokenize_incrementally,
-                                                    TokenizerGroup)
-from aphrodite.common.utils import (Counter, set_cuda_visible_devices, get_ip,
-                                    get_open_port)
+from aphrodite.common.sequence import (
+    Logprob,
+    SamplerOutput,
+    Sequence,
+    SequenceGroup,
+    SequenceGroupOutput,
+    SequenceOutput,
+    SequenceStatus,
+)
+from aphrodite.transformers_utils.tokenizer import (
+    detokenize_incrementally,
+    TokenizerGroup,
+)
+from aphrodite.common.utils import (
+    Counter,
+    set_cuda_visible_devices,
+    get_ip,
+    get_open_port,
+    get_distributed_init_method,
+)
+from aphrodite.common.logger import setup_logger
 
 
 if ray:
 if ray:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@@ -34,6 +66,17 @@ if TYPE_CHECKING:
 
 
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
 
+# A map from the device type (in device config) to its worker module.
+DEVICE_TO_WORKER_MODULE_MAP = {
+    "cuda": "aphrodite.task_handler.worker",
+    "neuron": "aphrodite.task_handler.neuron_worker",
+}
+
+# If the env var is set, it uses Ray's compiled DAG API
+# which optimizes the control plane overhead.
+# Run APHRODITE with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
+USE_RAY_COMPILED_DAG = bool(os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", 0))
+
 
 
 class AphroditeEngine:
 class AphroditeEngine:
     """An LLM engine that receives requests and generates texts.
     """An LLM engine that receives requests and generates texts.
@@ -88,7 +131,7 @@ class AphroditeEngine:
             f"Context Length = {model_config.max_model_len}\n"
             f"Context Length = {model_config.max_model_len}\n"
             f"Enforce Eager Mode = {model_config.enforce_eager}\n"
             f"Enforce Eager Mode = {model_config.enforce_eager}\n"
             f"KV Cache Data Type = {cache_config.cache_dtype}\n"
             f"KV Cache Data Type = {cache_config.cache_dtype}\n"
-            f"KV Cache Params Path = {cache_config.cache_quant_params_path}\n"
+            # f"KV Cache Params Path = {cache_config.cache_quant_params_path}\n"
             f"Device = {device_config.device}")
             f"Device = {device_config.device}")
         # TODO: Print more configs in debug mode.
         # TODO: Print more configs in debug mode.
 
 
@@ -110,7 +153,20 @@ class AphroditeEngine:
             ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
             ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
             if ray_usage != "1":
             if ray_usage != "1":
                 os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
                 os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
-            self._init_workers_ray(placement_group)
+            # Pass additional arguments to initialize the worker
+            additional_ray_args = {}
+            if self.parallel_config.ray_workers_use_nsight:
+                logger.info("Configuring Ray workers to use nsight.")
+                additional_ray_args = {
+                    "runtime_env": {
+                        "nsight": {
+                            "t": "cuda,cudnn,cublas",
+                            "o": "'worker_process_%p'",
+                            "cuda-graph-trace": "node",
+                        }
+                    }
+                }
+            self._init_workers_ray(placement_group, **additional_ray_args)
         else:
         else:
             self._init_workers()
             self._init_workers()
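Editor's note: the nsight runtime_env above is only attached when the parallel config asks for it; this patch exposes that switch as the --ray-workers-use-nsight CLI flag further down. A sketch of setting it programmatically, assuming AphroditeEngine offers a from_engine_args classmethod mirroring AsyncAphrodite's and that Nsight Systems (nsys) is available on the Ray worker nodes:

# Editorial sketch, not part of the diff.
from aphrodite.engine.args_tools import EngineArgs
from aphrodite.engine.aphrodite_engine import AphroditeEngine

args = EngineArgs(
    model="EleutherAI/pythia-70m-deduped",
    tensor_parallel_size=2,        # more than one GPU implies Ray workers
    worker_use_ray=True,
    ray_workers_use_nsight=True,   # new field introduced by this patch
)
engine = AphroditeEngine.from_engine_args(args)  # assumed constructor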
 
 
@@ -124,22 +180,40 @@ class AphroditeEngine:
         if self.log_stats:
         if self.log_stats:
             self.stat_logger = StatLogger(
             self.stat_logger = StatLogger(
                 local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
                 local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
-                labels=dict(model_name=model_config.model))
+                labels=dict(model_name=model_config.model),
+            )
+            self.stat_logger.info("cache_config", self.cache_config)
+
+        self.forward_dag = None
+        if USE_RAY_COMPILED_DAG:
+            self.forward_dag = self._compiled_ray_dag()
+
+    def __reduce__(self):
+        # This is to ensure that the AphroditeEngine is not referenced in
+        # the closure used to initialize Ray worker actors
+        raise RuntimeError("AphroditeEngine should not be pickled!")
 
 
     def get_tokenizer_for_seq(self, sequence: Sequence):
     def get_tokenizer_for_seq(self, sequence: Sequence):
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
 
 
+    def _dispatch_worker(self):
+        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
+            self.device_config.device_type]
+        imported_worker = importlib.import_module(worker_module)
+        Worker = imported_worker.Worker
+        return Worker
+
     def _init_workers(self):
     def _init_workers(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        # pylint: disable=import-outside-toplevel
-        from aphrodite.task_handler.worker import Worker
+        Worker = self._dispatch_worker()
 
 
-        assert self.parallel_config.world_size == 1, (
-            "Ray is required if parallel_config.world_size > 1.")
+        assert (self.parallel_config.world_size == 1
+                ), "Ray is required if parallel_config.world_size > 1."
 
 
         self.workers: List[Worker] = []
         self.workers: List[Worker] = []
-        distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
         self.driver_worker = Worker(
         self.driver_worker = Worker(
             self.model_config,
             self.model_config,
             self.parallel_config,
             self.parallel_config,
@@ -150,7 +224,7 @@ class AphroditeEngine:
             distributed_init_method=distributed_init_method,
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
             lora_config=self.lora_config,
             kv_cache_dtype=self.cache_config.cache_dtype,
             kv_cache_dtype=self.cache_config.cache_dtype,
-            kv_quant_params_path=(self.cache_config.cache_quant_params_path),
+            # kv_quant_params_path=(self.cache_config.cache_quant_params_path),
             is_driver_worker=True,
             is_driver_worker=True,
         )
         )
         self._run_workers("init_model")
         self._run_workers("init_model")
@@ -163,7 +237,8 @@ class AphroditeEngine:
             max_input_length=None,
             max_input_length=None,
             tokenizer_mode=self.model_config.tokenizer_mode,
             tokenizer_mode=self.model_config.tokenizer_mode,
             trust_remote_code=self.model_config.trust_remote_code,
             trust_remote_code=self.model_config.trust_remote_code,
-            revision=self.model_config.tokenizer_revision)
+            revision=self.model_config.tokenizer_revision,
+        )
         init_kwargs.update(tokenizer_init_kwargs)
         init_kwargs.update(tokenizer_init_kwargs)
         self.tokenizer: TokenizerGroup = TokenizerGroup(
         self.tokenizer: TokenizerGroup = TokenizerGroup(
             self.model_config.tokenizer, **init_kwargs)
             self.model_config.tokenizer, **init_kwargs)
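Editor's note: _dispatch_worker replaces the hard-coded CUDA worker import, so a Neuron deployment never imports torch.cuda/xformers. A standalone sketch of the same lazy dispatch (the module is only resolved when the function is called):

# Editorial sketch, not part of the diff: mirror of the worker dispatch.
import importlib

DEVICE_TO_WORKER_MODULE_MAP = {
    "cuda": "aphrodite.task_handler.worker",
    "neuron": "aphrodite.task_handler.neuron_worker",
}

def resolve_worker_cls(device_type: str):
    # Import the device-specific module lazily and hand back its Worker.
    module = importlib.import_module(DEVICE_TO_WORKER_MODULE_MAP[device_type])
    return module.Worker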
@@ -230,18 +305,21 @@ class AphroditeEngine:
         for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
         for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
             worker.set_cuda_visible_devices.remote(node_gpus[node_id])
             worker.set_cuda_visible_devices.remote(node_gpus[node_id])
 
 
-        distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"
+        distributed_init_method = get_distributed_init_method(
+            driver_ip, get_open_port())
 
 
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        # pylint: disable=import-outside-toplevel
-        from aphrodite.task_handler.worker import Worker
+        Worker = self._dispatch_worker()
 
 
         # Initialize torch distributed process group for the workers.
         # Initialize torch distributed process group for the workers.
         model_config = copy.deepcopy(self.model_config)
         model_config = copy.deepcopy(self.model_config)
         parallel_config = copy.deepcopy(self.parallel_config)
         parallel_config = copy.deepcopy(self.parallel_config)
         scheduler_config = copy.deepcopy(self.scheduler_config)
         scheduler_config = copy.deepcopy(self.scheduler_config)
         device_config = copy.deepcopy(self.device_config)
         device_config = copy.deepcopy(self.device_config)
+        lora_config = copy.deepcopy(self.lora_config)
+        kv_cache_dtype = self.cache_config.cache_dtype
+        # kv_quant_params_path = self.cache_config.cache_quant_params_path
 
 
         for rank, (worker, (node_id,
         for rank, (worker, (node_id,
                             _)) in enumerate(zip(self.workers,
                             _)) in enumerate(zip(self.workers,
@@ -257,29 +335,33 @@ class AphroditeEngine:
                     local_rank,
                     local_rank,
                     rank,
                     rank,
                     distributed_init_method,
                     distributed_init_method,
-                    lora_config=self.lora_config,
-                    kv_cache_dtype=self.cache_config.cache_dtype,
-                    kv_quant_params_path=
-                    (self.cache_config.cache_quant_params_path),
+                    lora_config=lora_config,
+                    kv_cache_dtype=kv_cache_dtype,
+                    # kv_quant_params_path=kv_quant_params_path,
                 ))
                 ))
 
 
         driver_rank = 0
         driver_rank = 0
         driver_local_rank = node_workers[driver_node_id].index(driver_rank)
         driver_local_rank = node_workers[driver_node_id].index(driver_rank)
         self.driver_worker = Worker(
         self.driver_worker = Worker(
-            model_config,
-            parallel_config,
-            scheduler_config,
-            device_config,
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            self.device_config,
             driver_local_rank,
             driver_local_rank,
             driver_rank,
             driver_rank,
             distributed_init_method,
             distributed_init_method,
             lora_config=self.lora_config,
             lora_config=self.lora_config,
-            kv_cache_dtype=self.cache_config.cache_dtype,
-            kv_quant_params_path=(self.cache_config.cache_quant_params_path),
+            kv_cache_dtype=kv_cache_dtype,
+            # kv_quant_params_path=kv_quant_params_path,
             is_driver_worker=True,
             is_driver_worker=True,
         )
         )
 
 
-        self._run_workers("init_model", cupy_port=get_open_port())
+        # don't use cupy for eager mode
+        self._run_workers(
+            "init_model",
+            cupy_port=get_open_port()
+            if not model_config.enforce_eager else None,
+        )
         self._run_workers(
         self._run_workers(
             "load_model",
             "load_model",
             max_concurrent_workers=self.parallel_config.
             max_concurrent_workers=self.parallel_config.
@@ -302,7 +384,6 @@ class AphroditeEngine:
         Then, it calculates the maximum possible number of GPU and CPU blocks
         that can be allocated with the remaining free memory.
         More details can be found in the
-        # pylint: disable=line-too-long
         :meth:`~aphrodite.task_handler.worker.Worker.profile_num_available_blocks` method
         from class :class:`~aphrodite.task_handler.Worker`.
 
 
@@ -372,9 +453,11 @@ class AphroditeEngine:
         # Initialize the cluster.
         # Initialize the cluster.
         placement_group = initialize_cluster(parallel_config)
         placement_group = initialize_cluster(parallel_config)
         # Create the LLM engine.
         # Create the LLM engine.
-        engine = cls(*engine_configs,
-                     placement_group,
-                     log_stats=not engine_args.disable_log_stats)
+        engine = cls(
+            *engine_configs,
+            placement_group,
+            log_stats=not engine_args.disable_log_stats,
+        )
         return engine
         return engine
 
 
     def encode_request(
     def encode_request(
@@ -449,20 +532,34 @@ class AphroditeEngine:
                     sampling_params.prompt_logprobs
                     sampling_params.prompt_logprobs
                     and sampling_params.prompt_logprobs > max_log_probs):
                     and sampling_params.prompt_logprobs > max_log_probs):
             raise ValueError(f"Cannot request more than "
             raise ValueError(f"Cannot request more than "
-                             f"{max_log_probs} logprobs.")
+                             f"{max_log_probs} logprobs. "
+                             "Please increase the max_log_probs.")
         if arrival_time is None:
         if arrival_time is None:
             arrival_time = time.monotonic()
             arrival_time = time.monotonic()
         prompt_token_ids = self.encode_request(
         prompt_token_ids = self.encode_request(
             request_id=request_id,
             request_id=request_id,
             prompt=prompt,
             prompt=prompt,
             prompt_token_ids=prompt_token_ids,
             prompt_token_ids=prompt_token_ids,
-            lora_request=lora_request)
+            lora_request=lora_request,
+        )
 
 
         # Create the sequences.
         # Create the sequences.
         block_size = self.cache_config.block_size
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
         seq_id = next(self.seq_counter)
-        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
-                       lora_request)
+        eos_token_id = self.tokenizer.get_lora_tokenizer(
+            lora_request).eos_token_id
+        seq = Sequence(
+            seq_id,
+            prompt,
+            prompt_token_ids,
+            block_size,
+            eos_token_id,
+            lora_request,
+        )
+
+        # Defensive copy of SamplingParams, which are used by the sampler;
+        # note that this does not deep-copy LogitsProcessor objects.
+        sampling_params = sampling_params.clone()
 
 
         # Create the sequence group.
         # Create the sequence group.
         seq_group = SequenceGroup(request_id, [seq], sampling_params,
         seq_group = SequenceGroup(request_id, [seq], sampling_params,
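Editor's note: two behavioral changes land in add_request here: the sequence now carries its own eos_token_id (looked up from the per-LoRA tokenizer), and the SamplingParams are cloned, so a caller mutating its object after submission cannot affect a queued request. A sketch of the second point, assuming clone() is the shallow copy the comment describes:

# Editorial sketch, not part of the diff.
from aphrodite.common.sampling_params import SamplingParams

params = SamplingParams(temperature=0.7, max_tokens=128)
queued = params.clone()          # what the engine keeps after add_request
params.max_tokens = 1            # caller-side mutation after submission
assert queued.max_tokens == 128  # the queued copy is unaffected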
@@ -514,15 +611,15 @@ class AphroditeEngine:
         if early_stopping is True:
         if early_stopping is True:
             return True
             return True
 
 
-        current_worst_score = (current_worst_seq.get_beam_search_score(
+        current_worst_score = current_worst_seq.get_beam_search_score(
             length_penalty=length_penalty,
             length_penalty=length_penalty,
-            eos_token_id=self.get_tokenizer_for_seq(
-                current_worst_seq).eos_token_id))
+            eos_token_id=current_worst_seq.eos_token_id,
+        )
         if early_stopping is False:
         if early_stopping is False:
-            highest_attainable_score = (best_running_seq.get_beam_search_score(
+            highest_attainable_score = best_running_seq.get_beam_search_score(
                 length_penalty=length_penalty,
                 length_penalty=length_penalty,
-                eos_token_id=self.get_tokenizer_for_seq(
-                    best_running_seq).eos_token_id))
+                eos_token_id=best_running_seq.eos_token_id,
+            )
         else:
         else:
             assert early_stopping == "never"
             assert early_stopping == "never"
             if length_penalty > 0.0:
             if length_penalty > 0.0:
@@ -532,13 +629,14 @@ class AphroditeEngine:
                 max_possible_length = max(
                 max_possible_length = max(
                     best_running_seq.get_prompt_len() +
                     best_running_seq.get_prompt_len() +
                     sampling_params.max_tokens,
                     sampling_params.max_tokens,
-                    self.scheduler_config.max_model_len)
+                    self.scheduler_config.max_model_len,
+                )
                 highest_attainable_score = (
                 highest_attainable_score = (
                     best_running_seq.get_beam_search_score(
                     best_running_seq.get_beam_search_score(
                         length_penalty=length_penalty,
                         length_penalty=length_penalty,
-                        eos_token_id=self.get_tokenizer_for_seq(
-                            best_running_seq).eos_token_id,
-                        seq_len=max_possible_length))
+                        eos_token_id=best_running_seq.eos_token_id,
+                        seq_len=max_possible_length,
+                    ))
             else:
             else:
                 # Otherwise, beam search will prefer shorter sequences. The
                 # Otherwise, beam search will prefer shorter sequences. The
                 # highest attainable score calculation is based on the current
                 # highest attainable score calculation is based on the current
@@ -546,8 +644,8 @@ class AphroditeEngine:
                 highest_attainable_score = (
                 highest_attainable_score = (
                     best_running_seq.get_beam_search_score(
                     best_running_seq.get_beam_search_score(
                         length_penalty=length_penalty,
                         length_penalty=length_penalty,
-                        eos_token_id=self.get_tokenizer_for_seq(
-                            best_running_seq).eos_token_id))
+                        eos_token_id=best_running_seq.eos_token_id,
+                    ))
         return current_worst_score >= highest_attainable_score
         return current_worst_score >= highest_attainable_score
 
 
     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
     def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
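Editor's note: the early-stopping check now reads eos_token_id straight off the sequence instead of re-fetching the tokenizer; the comparison itself is unchanged. A sketch of the score being compared, assuming get_beam_search_score uses the conventional length-penalized normalization (cumulative log-probability divided by seq_len ** length_penalty):

# Editorial sketch, not part of the diff; the formula is an assumption.
def beam_search_score(cumulative_logprob: float,
                      seq_len: int,
                      length_penalty: float = 1.0) -> float:
    return cumulative_logprob / (seq_len ** length_penalty)

# With early_stopping=False, the search stops once the worst kept finished
# sequence already scores at least as well as the best running one:
worst_finished = beam_search_score(-8.0, 20)   # -0.40
best_running = beam_search_score(-9.0, 18)     # -0.50
stop_beam_search = worst_finished >= best_running  # True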
@@ -555,6 +653,16 @@ class AphroditeEngine:
         # Process prompt logprobs
         # Process prompt logprobs
         prompt_logprobs = outputs.prompt_logprobs
         prompt_logprobs = outputs.prompt_logprobs
         if prompt_logprobs is not None:
         if prompt_logprobs is not None:
+            # We can pick any sequence for the prompt.
+            seq = next(iter(seq_group.seqs_dict.values()))
+            all_token_ids = seq.get_token_ids()
+            for i, prompt_logprobs_for_token in enumerate(prompt_logprobs):
+                self._decode_logprobs(
+                    seq,
+                    seq_group.sampling_params,
+                    prompt_logprobs_for_token,
+                    all_token_ids[:i],
+                )
             seq_group.prompt_logprobs = prompt_logprobs
             seq_group.prompt_logprobs = prompt_logprobs
 
 
         # Process samples
         # Process samples
@@ -638,10 +746,11 @@ class AphroditeEngine:
                              if seq.is_finished()]
                              if seq.is_finished()]
         all_finished_seqs = existing_finished_seqs + new_finished_seqs
         all_finished_seqs = existing_finished_seqs + new_finished_seqs
         # Sort the finished sequences by their scores.
         # Sort the finished sequences by their scores.
-        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
-            length_penalty=length_penalty,
-            eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
-                               reverse=True)
+        all_finished_seqs.sort(
+            key=lambda x: x[0].get_beam_search_score(
+                length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
+            reverse=True,
+        )
         for seq, parent, is_new in all_finished_seqs[:beam_width]:
         for seq, parent, is_new in all_finished_seqs[:beam_width]:
             if is_new:
             if is_new:
                 # A newly generated child sequence finishes and has a high
                 # A newly generated child sequence finishes and has a high
@@ -666,10 +775,11 @@ class AphroditeEngine:
         running_child_seqs = [(seq, parent) for seq, parent in child_seqs
         running_child_seqs = [(seq, parent) for seq, parent in child_seqs
                               if not seq.is_finished()]
                               if not seq.is_finished()]
         # Sort the running sequences by their scores.
         # Sort the running sequences by their scores.
-        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
-            length_penalty=length_penalty,
-            eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id),
-                                reverse=True)
+        running_child_seqs.sort(
+            key=lambda x: x[0].get_beam_search_score(
+                length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
+            reverse=True,
+        )
 
 
         # Check if we can stop the beam search.
         # Check if we can stop the beam search.
         if len(running_child_seqs) == 0:
         if len(running_child_seqs) == 0:
@@ -684,7 +794,10 @@ class AphroditeEngine:
             current_worst_seq = all_finished_seqs[beam_width - 1][0]
             current_worst_seq = all_finished_seqs[beam_width - 1][0]
             stop_beam_search = self._check_beam_search_early_stopping(
             stop_beam_search = self._check_beam_search_early_stopping(
                 seq_group.sampling_params.early_stopping,
                 seq_group.sampling_params.early_stopping,
-                seq_group.sampling_params, best_running_seq, current_worst_seq)
+                seq_group.sampling_params,
+                best_running_seq,
+                current_worst_seq,
+            )
 
 
         if stop_beam_search:
         if stop_beam_search:
             # Stop the beam search and remove all the running sequences from
             # Stop the beam search and remove all the running sequences from
@@ -726,13 +839,16 @@ class AphroditeEngine:
     def _process_model_outputs(
     def _process_model_outputs(
             self, output: SamplerOutput,
             self, output: SamplerOutput,
             scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
             scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
+        now = time.time()
         # Update the scheduled sequence groups with the model outputs.
         # Update the scheduled sequence groups with the model outputs.
         scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
         scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
+
         # If prefix caching is enabled, mark all blocks in the sequence groups
         # If prefix caching is enabled, mark all blocks in the sequence groups
         # as completed so that future requests don't attempt to recompute them
         # as completed so that future requests don't attempt to recompute them
         if self.cache_config.context_shift:
         if self.cache_config.context_shift:
             for seq_group in scheduled_seq_groups:
             for seq_group in scheduled_seq_groups:
                 self.scheduler.mark_blocks_as_computed(seq_group)
                 self.scheduler.mark_blocks_as_computed(seq_group)
+
         for seq_group, outputs in zip(scheduled_seq_groups, output):
         for seq_group, outputs in zip(scheduled_seq_groups, output):
             self._process_sequence_group_outputs(seq_group, outputs)
             self._process_sequence_group_outputs(seq_group, outputs)
 
 
@@ -742,6 +858,7 @@ class AphroditeEngine:
         # Create the outputs.
         # Create the outputs.
         request_outputs: List[RequestOutput] = []
         request_outputs: List[RequestOutput] = []
         for seq_group in scheduled_seq_groups:
         for seq_group in scheduled_seq_groups:
+            seq_group.maybe_set_first_token_time(now)
             request_output = RequestOutput.from_seq_group(seq_group)
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
             request_outputs.append(request_output)
         for seq_group in scheduler_outputs.ignored_seq_groups:
         for seq_group in scheduler_outputs.ignored_seq_groups:
@@ -751,6 +868,7 @@ class AphroditeEngine:
         # Log stats.
         # Log stats.
         if self.log_stats:
         if self.log_stats:
             self.stat_logger.log(self._get_stats(scheduler_outputs))
             self.stat_logger.log(self._get_stats(scheduler_outputs))
+
         return request_outputs
         return request_outputs
 
 
     def step(self) -> List[RequestOutput]:
     def step(self) -> List[RequestOutput]:
@@ -815,7 +933,9 @@ class AphroditeEngine:
                     "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
                     "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
                     "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
                     "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
                     "blocks_to_copy": scheduler_outputs.blocks_to_copy,
                     "blocks_to_copy": scheduler_outputs.blocks_to_copy,
-                })
+                },
+                use_ray_compiled_dag=USE_RAY_COMPILED_DAG,
+            )
 
 
             # Only the driver worker returns the sampling results.
             # Only the driver worker returns the sampling results.
             output = all_outputs[0]
             output = all_outputs[0]
@@ -840,10 +960,10 @@ class AphroditeEngine:
         gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)
         gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)
 
 
         num_total_cpu = self.cache_config.num_cpu_blocks
         num_total_cpu = self.cache_config.num_cpu_blocks
-        cpu_cache_usage = 0.
+        cpu_cache_usage = 0.0
         if num_total_cpu > 0:
         if num_total_cpu > 0:
-            num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
-            )
+            num_free_cpu = (
+                self.scheduler.block_manager.get_num_free_cpu_blocks())
             cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)
             cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)
 
 
         # Scheduler State
         # Scheduler State
@@ -898,16 +1018,24 @@ class AphroditeEngine:
             time_e2e_requests=time_e2e_requests,
             time_e2e_requests=time_e2e_requests,
         )
         )
 
 
-    def _decode_logprobs(self, seq: Sequence, prms: SamplingParams,
-                         logprobs: Dict[int, Logprob],
-                         all_input_ids: List[int]) -> None:
+    def _decode_logprobs(
+        self,
+        seq: Sequence,
+        prms: SamplingParams,
+        logprobs: Dict[int, Logprob],
+        all_input_ids: List[int],
+    ) -> None:
         if not logprobs:
         if not logprobs:
             return
             return
         for token_id, sample_logprob in logprobs.items():
         for token_id, sample_logprob in logprobs.items():
-            if (sample_logprob.decoded_token is None and token_id != -1):
+            if sample_logprob.decoded_token is None and token_id != -1:
                 all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
                 all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
-                # pylint: disable=unused-variable
-                _, new_text, prefix_offset, read_offset = detokenize_incrementally(
+                (
+                    _,
+                    new_text,
+                    prefix_offset,
+                    read_offset,
+                ) = detokenize_incrementally(
                     self.get_tokenizer_for_seq(seq),
                     self.get_tokenizer_for_seq(seq),
                     all_input_ids=all_input_ids_with_logprob,
                     all_input_ids=all_input_ids_with_logprob,
                     prev_tokens=seq.tokens,
                     prev_tokens=seq.tokens,
@@ -924,16 +1052,21 @@ class AphroditeEngine:
         all_input_ids = seq.get_token_ids()
         all_input_ids = seq.get_token_ids()
         self._decode_logprobs(seq, prms, seq.output_logprobs[-1],
         self._decode_logprobs(seq, prms, seq.output_logprobs[-1],
                               all_input_ids)
                               all_input_ids)
-        (new_tokens, new_output_text, prefix_offset,
-         read_offset) = detokenize_incrementally(
-             self.get_tokenizer_for_seq(seq),
-             all_input_ids=all_input_ids,
-             prev_tokens=seq.tokens,
-             prefix_offset=seq.prefix_offset,
-             read_offset=seq.read_offset,
-             skip_special_tokens=prms.skip_special_tokens,
-             spaces_between_special_tokens=prms.spaces_between_special_tokens,
-         )
+
+        (
+            new_tokens,
+            new_output_text,
+            prefix_offset,
+            read_offset,
+        ) = detokenize_incrementally(
+            self.get_tokenizer_for_seq(seq),
+            all_input_ids=all_input_ids,
+            prev_tokens=seq.tokens,
+            prefix_offset=seq.prefix_offset,
+            read_offset=seq.read_offset,
+            skip_special_tokens=prms.skip_special_tokens,
+            spaces_between_special_tokens=prms.spaces_between_special_tokens,
+        )
         if seq.tokens is None:
         if seq.tokens is None:
             seq.tokens = new_tokens
             seq.tokens = new_tokens
         else:
         else:
@@ -968,15 +1101,18 @@ class AphroditeEngine:
             return
             return
 
 
         # Check if the sequence has generated the EOS token.
         # Check if the sequence has generated the EOS token.
-        if ((not sampling_params.ignore_eos) and seq.get_last_token_id()
-                == self.get_tokenizer_for_seq(seq).eos_token_id):
+        if (not sampling_params.ignore_eos
+            ) and seq.get_last_token_id() == seq.eos_token_id:
             seq.status = SequenceStatus.FINISHED_STOPPED
             seq.status = SequenceStatus.FINISHED_STOPPED
             return
             return
 
 
     def _finalize_sequence(self, seq: Sequence,
     def _finalize_sequence(self, seq: Sequence,
                            sampling_params: SamplingParams,
                            sampling_params: SamplingParams,
                            stop_string: str) -> None:
                            stop_string: str) -> None:
-        if not sampling_params.include_stop_str_in_output and stop_string:
+        if sampling_params.include_stop_str_in_output:
+            return
+
+        if stop_string and seq.output_text.endswith(stop_string):
             # Truncate the output text so that the stop string is
             # Truncate the output text so that the stop string is
             # not included in the output.
             # not included in the output.
             seq.output_text = seq.output_text[:-len(stop_string)]
             seq.output_text = seq.output_text[:-len(stop_string)]
@@ -1005,6 +1141,7 @@ class AphroditeEngine:
         driver_args: Optional[List[Any]] = None,
         driver_args: Optional[List[Any]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
         max_concurrent_workers: Optional[int] = None,
         max_concurrent_workers: Optional[int] = None,
+        use_ray_compiled_dag: bool = False,
         **kwargs,
         **kwargs,
     ) -> Any:
     ) -> Any:
         """Runs the given method on all workers."""
         """Runs the given method on all workers."""
@@ -1013,11 +1150,17 @@ class AphroditeEngine:
             raise NotImplementedError(
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
                 "max_concurrent_workers is not supported yet.")
 
 
-        # Start the ray workers first.
-        ray_worker_outputs = [
-            worker.execute_method.remote(method, *args, **kwargs)
-            for worker in self.workers
-        ]
+        if use_ray_compiled_dag:
+            # Right now, compiled DAG can only accept a single
+            # input.
+            # TODO: Fix it.
+            output_channels = self.forward_dag.execute(1)
+        else:
+            # Start the ray workers first.
+            ray_worker_outputs = [
+                worker.execute_method.remote(method, *args, **kwargs)
+                for worker in self.workers
+            ]
 
 
         if driver_args is None:
         if driver_args is None:
             driver_args = args
             driver_args = args
@@ -1030,10 +1173,45 @@ class AphroditeEngine:
 
 
         # Get the results of the ray workers.
         # Get the results of the ray workers.
         if self.workers:
         if self.workers:
-            ray_worker_outputs = ray.get(ray_worker_outputs)
+            if use_ray_compiled_dag:
+                try:
+                    ray_worker_outputs = [
+                        pickle.loads(chan.begin_read())
+                        for chan in output_channels
+                    ]
+                finally:
+                    # Has to call end_read in order to reuse the DAG.
+                    for chan in output_channels:
+                        chan.end_read()
+            else:
+                ray_worker_outputs = ray.get(ray_worker_outputs)
 
 
         return [driver_worker_output] + ray_worker_outputs
         return [driver_worker_output] + ray_worker_outputs
 
 
+    def _compiled_ray_dag(self):
+        from packaging import version
+        import pkg_resources
+
+        required_version = "2.9"
+        current_version = pkg_resources.get_distribution("ray").version
+
+        if version.parse(current_version) < version.parse(required_version):
+            raise ValueError(f"Ray version {required_version} or greater is "
+                             f"required, but found {current_version}")
+
+        from ray.dag import MultiOutputNode, InputNode
+
+        assert self.parallel_config.worker_use_ray
+
+        # Right now, compiled DAG requires at least 1 arg. We send
+        # a dummy value for now. It will be fixed soon.
+        with InputNode() as input_data:
+            forward_dag = MultiOutputNode([
+                worker.execute_model_compiled_dag_remote.bind(input_data)
+                for worker in self.workers
+            ])
+        return forward_dag.experimental_compile()
+
     def check_health(self) -> None:
     def check_health(self) -> None:
         """Raises an error if engine is unhealthy."""
         """Raises an error if engine is unhealthy."""
         self._check_if_any_actor_is_dead()
         self._check_if_any_actor_is_dead()
@@ -1052,7 +1230,7 @@ class AphroditeEngine:
                 dead_actors.append(actor)
                 dead_actors.append(actor)
         if dead_actors:
         if dead_actors:
             raise RuntimeError("At least one Worker is dead. "
             raise RuntimeError("At least one Worker is dead. "
-                               f"Dead workers: {dead_actors}")
+                               f"Dead Workers: {dead_actors}. ")
 
 
 
 
 setup_logger()
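Editor's note: check_health is what the new /health route reaches through the async engine; on the synchronous engine it raises RuntimeError when any Ray worker actor has died. A small wrapper sketch:

# Editorial sketch, not part of the diff.
def is_healthy(engine) -> bool:
    # check_health() returns None when healthy and raises RuntimeError
    # (e.g. when at least one Worker is dead) otherwise.
    try:
        engine.check_health()
        return True
    except RuntimeError:
        return False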

+ 373 - 240
aphrodite/engine/args_tools.py

@@ -3,22 +3,29 @@ import dataclasses
 from dataclasses import dataclass
 from dataclasses import dataclass
 from typing import Optional, Tuple
 from typing import Optional, Tuple
 
 
-from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
-                                     SchedulerConfig, LoRAConfig, DeviceConfig)
+from aphrodite.common.config import (
+    CacheConfig,
+    ModelConfig,
+    ParallelConfig,
+    SchedulerConfig,
+    LoRAConfig,
+    DeviceConfig,
+)
 
 
 
 
 @dataclass
 @dataclass
 class EngineArgs:
 class EngineArgs:
-    """Arguments for the Aphrodite engine."""
+    """Arguments for Aphrodite engine."""
+
     model: str
     model: str
     tokenizer: Optional[str] = None
     tokenizer: Optional[str] = None
-    tokenizer_mode: str = 'auto'
+    tokenizer_mode: str = "auto"
     trust_remote_code: bool = False
     trust_remote_code: bool = False
     download_dir: Optional[str] = None
     download_dir: Optional[str] = None
-    load_format: str = 'auto'
-    dtype: str = 'auto'
-    kv_cache_dtype: str = 'auto'
-    kv_quant_params_path: str = None
+    load_format: str = "auto"
+    dtype: str = "auto"
+    kv_cache_dtype: str = "auto"
+    # kv_quant_params_path: str = None
     seed: int = 0
     seed: int = 0
     max_model_len: Optional[int] = None
     max_model_len: Optional[int] = None
     worker_use_ray: bool = False
     worker_use_ray: bool = False
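Editor's note: the most consequential default changes here are enforce_eager flipping to True, which makes CUDA graph capture opt-in, and device moving from "cuda" to "auto". A quick sketch of the new defaults:

# Editorial sketch, not part of the diff.
from aphrodite.engine.args_tools import EngineArgs

eager = EngineArgs(model="EleutherAI/pythia-70m-deduped")
assert eager.enforce_eager is True        # new default after this patch
assert eager.device == "auto"             # was "cuda" before

graphs = EngineArgs(model="EleutherAI/pythia-70m-deduped",
                    enforce_eager=False)  # explicitly re-enable CUDA graphs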
@@ -32,24 +39,26 @@ class EngineArgs:
     max_num_batched_tokens: Optional[int] = None
     max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
     max_num_seqs: int = 256
     max_paddings: int = 256
     max_paddings: int = 256
-    max_log_probs: int = 10
+    max_log_probs: int = 10  # OpenAI default is 5, setting to 10 because ST
     disable_log_stats: bool = False
     disable_log_stats: bool = False
     revision: Optional[str] = None
     revision: Optional[str] = None
+    code_revision: Optional[str] = None
     tokenizer_revision: Optional[str] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     quantization: Optional[str] = None
     load_in_4bit: bool = False
     load_in_4bit: bool = False
     load_in_8bit: bool = False
     load_in_8bit: bool = False
     load_in_smooth: bool = False
     load_in_smooth: bool = False
-    enforce_eager: bool = False
+    enforce_eager: bool = True
     max_context_len_to_capture: int = 8192
     max_context_len_to_capture: int = 8192
     disable_custom_all_reduce: bool = False
     disable_custom_all_reduce: bool = False
     enable_lora: bool = False
     enable_lora: bool = False
     max_loras: int = 1
     max_loras: int = 1
     max_lora_rank: int = 16
     max_lora_rank: int = 16
     lora_extra_vocab_size: int = 256
     lora_extra_vocab_size: int = 256
-    lora_dtype = 'auto'
+    lora_dtype = "auto"
     max_cpu_loras: Optional[int] = None
     max_cpu_loras: Optional[int] = None
-    device: str = 'cuda'
+    device: str = "auto"
+    ray_workers_use_nsight: bool = False
 
 
     def __post_init__(self):
     def __post_init__(self):
         if self.tokenizer is None:
         if self.tokenizer is None:
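Editor's note: the remaining hunks mostly reflow the argparse definitions to double quotes and multi-line calls, while adding --code-revision, --ray-workers-use-nsight, and a 128 block-size choice, and dropping the int8 KV-cache option for now. A launch sketch using a few of these flags, with the module path of the OpenAI-compatible entry point assumed:

# Editorial example, not part of the diff.
import subprocess

subprocess.run([
    "python", "-m", "aphrodite.endpoints.openai.api_server",
    "--model", "EleutherAI/pythia-70m-deduped",
    "--kv-cache-dtype", "fp8_e5m2",   # the int8 choice is disabled by this patch
    "-tp", "2",                       # implies Ray workers
    "--ray-workers-use-nsight",       # new profiling flag
])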
@@ -65,245 +74,333 @@ class EngineArgs:
 
 
         # Model arguments
         # Model arguments
         parser.add_argument(
         parser.add_argument(
-            '--model',
+            "--model",
             type=str,
             type=str,
-            default='EleutherAI/pythia-70m-deduped',
-            help='name or path of the huggingface model to use')
+            default="EleutherAI/pythia-70m-deduped",
+            help="name or path of the huggingface model to use",
+        )
         parser.add_argument(
         parser.add_argument(
-            '--tokenizer',
+            "--tokenizer",
             type=str,
             type=str,
             default=EngineArgs.tokenizer,
             default=EngineArgs.tokenizer,
-            help='name or path of the huggingface tokenizer to use')
+            help="name or path of the huggingface tokenizer to use",
+        )
         parser.add_argument(
         parser.add_argument(
-            '--revision',
+            "--revision",
             type=str,
             type=str,
             default=None,
             default=None,
-            help='the specific model version to use. It can be a branch '
-            'name, a tag name, or a commit id. If unspecified, will use '
-            'the default version.')
+            help="the specific model version to use. It can be a branch "
+            "name, a tag name, or a commit id. If unspecified, will use "
+            "the default version.",
+        )
         parser.add_argument(
         parser.add_argument(
-            '--tokenizer-revision',
+            "--code-revision",
             type=str,
             type=str,
             default=None,
             default=None,
-            help='the specific tokenizer version to use. It can be a branch '
-            'name, a tag name, or a commit id. If unspecified, will use '
-            'the default version.')
-        parser.add_argument('--tokenizer-mode',
-                            type=str,
-                            default=EngineArgs.tokenizer_mode,
-                            choices=['auto', 'slow'],
-                            help='tokenizer mode. "auto" will use the fast '
-                            'tokenizer if available, and "slow" will '
-                            'always use the slow tokenizer.')
-        parser.add_argument('--trust-remote-code',
-                            action='store_true',
-                            help='trust remote code from huggingface')
-        parser.add_argument('--download-dir',
-                            type=str,
-                            default=EngineArgs.download_dir,
-                            help='directory to download and load the weights, '
-                            'default to the default cache dir of '
-                            'huggingface')
-        parser.add_argument(
-            '--load-format',
+            help="the specific revision to use for the model code on "
+            "Hugging Face Hub. It can be a branch name, a tag name, or a "
+            "commit id. If unspecified, will use the default version.",
+        )
+        parser.add_argument(
+            "--tokenizer-revision",
+            type=str,
+            default=None,
+            help="the specific tokenizer version to use. It can be a branch "
+            "name, a tag name, or a commit id. If unspecified, will use "
+            "the default version.",
+        )
+        parser.add_argument(
+            "--tokenizer-mode",
+            type=str,
+            default=EngineArgs.tokenizer_mode,
+            choices=["auto", "slow"],
+            help='tokenizer mode. "auto" will use the fast '
+            'tokenizer if available, and "slow" will '
+            "always use the slow tokenizer.",
+        )
+        parser.add_argument(
+            "--trust-remote-code",
+            action="store_true",
+            help="trust remote code from huggingface",
+        )
+        parser.add_argument(
+            "--download-dir",
+            type=str,
+            default=EngineArgs.download_dir,
+            help="directory to download and load the weights, "
+            "default to the default cache dir of "
+            "huggingface",
+        )
+        parser.add_argument(
+            "--load-format",
             type=str,
             type=str,
             default=EngineArgs.load_format,
             default=EngineArgs.load_format,
-            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
-            help='The format of the model weights to load. '
+            choices=["auto", "pt", "safetensors", "npcache", "dummy"],
+            help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
             '"auto" will try to load the weights in the safetensors format '
-            'and fall back to the pytorch bin format if safetensors format '
-            'is not available. '
+            "and fall back to the pytorch bin format if safetensors format "
+            "is not available. "
             '"pt" will load the weights in the pytorch bin format. '
             '"pt" will load the weights in the pytorch bin format. '
             '"safetensors" will load the weights in the safetensors format. '
             '"safetensors" will load the weights in the safetensors format. '
             '"npcache" will load the weights in pytorch format and store '
             '"npcache" will load the weights in pytorch format and store '
-            'a numpy cache to speed up the loading. '
+            "a numpy cache to speed up the loading. "
             '"dummy" will initialize the weights with random values, '
             '"dummy" will initialize the weights with random values, '
-            'which is mainly for profiling.')
+            "which is mainly for profiling.",
+        )
         parser.add_argument(
         parser.add_argument(
-            '--dtype',
+            "--dtype",
             type=str,
             type=str,
             default=EngineArgs.dtype,
             default=EngineArgs.dtype,
             choices=[
             choices=[
-                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
+                "auto", "half", "float16", "bfloat16", "float", "float32"
             ],
             ],
-            help='data type for model weights and activations. '
+            help="data type for model weights and activations. "
             'The "auto" option will use FP16 precision '
             'The "auto" option will use FP16 precision '
-            'for FP32 and FP16 models, and BF16 precision '
-            'for BF16 models.')
+            "for FP32 and FP16 models, and BF16 precision "
+            "for BF16 models.",
+        )
         parser.add_argument(
         parser.add_argument(
-            '--kv-cache-dtype',
+            "--kv-cache-dtype",
             type=str,
             type=str,
-            choices=['auto', 'fp8_e5m2', 'int8'],
+            # choices=["auto", "fp8_e5m2", "int8"],
+            choices=['auto', 'fp8_e5m2'],
             default=EngineArgs.kv_cache_dtype,
             default=EngineArgs.kv_cache_dtype,
             help='Data type for kv cache storage. If "auto", will use model '
             help='Data type for kv cache storage. If "auto", will use model '
-            'data type. Note FP8 is not supported when cuda version is '
-            'lower than 11.8.')
+            "data type. Note FP8 is not supported when cuda version is "
+            "lower than 11.8.",
+        )
+        # parser.add_argument(
+        #     "--kv-quant-params-path",
+        #     type=str,
+        #     default=EngineArgs.kv_quant_params_path,
+        #     help="Path to scales and zero points of KV cache "
+        #     "quantization. Only applicable when kv-cache-dtype "
+        #     "is int8.",
+        # )
         parser.add_argument(
         parser.add_argument(
-            '--kv-quant-params-path',
-            type=str,
-            default=EngineArgs.kv_quant_params_path,
-            help='Path to scales and zero points of KV cache '
-            'quantization. Only applicable when kv-cache-dtype '
-            'is int8.')
-        parser.add_argument('--max-model-len',
-                            type=int,
-                            default=EngineArgs.max_model_len,
-                            help='model context length. If unspecified, '
-                            'will be automatically derived from the model.')
+            "--max-model-len",
+            type=int,
+            default=EngineArgs.max_model_len,
+            help="model context length. If unspecified, "
+            "will be automatically derived from the model.",
+        )
         # Parallel arguments
-        parser.add_argument('--worker-use-ray',
-                            action='store_true',
-                            help='use Ray for distributed serving, will be '
-                            'automatically set when using more than 1 GPU')
-        parser.add_argument('--pipeline-parallel-size',
-                            '-pp',
-                            type=int,
-                            default=EngineArgs.pipeline_parallel_size,
-                            help='number of pipeline stages')
-        parser.add_argument('--tensor-parallel-size',
-                            '-tp',
-                            type=int,
-                            default=EngineArgs.tensor_parallel_size,
-                            help='number of tensor parallel replicas')
         parser.add_argument(
-            '--max-parallel-loading-workers',
+            "--worker-use-ray",
+            action="store_true",
+            help="use Ray for distributed serving, will be "
+            "automatically set when using more than 1 GPU",
+        )
+        parser.add_argument(
+            "--pipeline-parallel-size",
+            "-pp",
+            type=int,
+            default=EngineArgs.pipeline_parallel_size,
+            help="number of pipeline stages",
+        )
+        parser.add_argument(
+            "--tensor-parallel-size",
+            "-tp",
+            type=int,
+            default=EngineArgs.tensor_parallel_size,
+            help="number of tensor parallel replicas",
+        )
+        parser.add_argument(
+            "--max-parallel-loading-workers",
             type=int,
             default=EngineArgs.max_parallel_loading_workers,
-            help='load model sequentially in multiple batches, '
-            'to avoid RAM OOM when using tensor '
-            'parallel and large models')
+            help="load model sequentially in multiple batches, "
+            "to avoid RAM OOM when using tensor "
+            "parallel and large models",
+        )
+        parser.add_argument(
+            "--ray-workers-use-nsight",
+            action="store_true",
+            help="If specified, use nsight to profile ray workers",
+        )
         # KV cache arguments
-        parser.add_argument('--block-size',
-                            type=int,
-                            default=EngineArgs.block_size,
-                            choices=[8, 16, 32],
-                            help='token block size')
-        parser.add_argument('--context-shift',
-                            action='store_true',
-                            help='Enable context shifting.')
-        parser.add_argument('--seed',
+        parser.add_argument(
+            "--block-size",
+            type=int,
+            default=EngineArgs.block_size,
+            choices=[8, 16, 32, 128],
+            help="token block size",
+        )
+        parser.add_argument(
+            "--context-shift",
+            action="store_true",
+            help="Enable context shifting.",
+        )
+        parser.add_argument("--seed",
                             type=int,
                             default=EngineArgs.seed,
-                            help='random seed')
-        parser.add_argument('--swap-space',
-                            type=int,
-                            default=EngineArgs.swap_space,
-                            help='CPU swap space size (GiB) per GPU')
+                            help="random seed")
         parser.add_argument(
-            '--gpu-memory-utilization',
-            '-gmu',
+            "--swap-space",
+            type=int,
+            default=EngineArgs.swap_space,
+            help="CPU swap space size (GiB) per GPU",
+        )
+        parser.add_argument(
+            "--gpu-memory-utilization",
+            "-gmu",
             type=float,
             default=EngineArgs.gpu_memory_utilization,
-            help='the fraction of GPU memory to be used for '
-            'the model executor, which can range from 0 to 1.'
-            'If unspecified, will use the default value of 0.9.')
-        parser.add_argument('--max-num-batched-tokens',
-                            type=int,
-                            default=EngineArgs.max_num_batched_tokens,
-                            help='maximum number of batched tokens per '
-                            'iteration')
-        parser.add_argument('--max-num-seqs',
-                            type=int,
-                            default=EngineArgs.max_num_seqs,
-                            help='maximum number of sequences per iteration')
-        parser.add_argument('--max-paddings',
-                            type=int,
-                            default=EngineArgs.max_paddings,
-                            help='maximum number of paddings in a batch')
-        parser.add_argument('--max-log-probs',
-                            type=int,
-                            default=EngineArgs.max_log_probs,
-                            help='maximum number of log probabilities to '
-                            'return.')
-        parser.add_argument('--disable-log-stats',
-                            action='store_true',
-                            help='disable logging statistics')
+            help="the fraction of GPU memory to be used for "
+            "the model executor, which can range from 0 to 1."
+            "If unspecified, will use the default value of 0.9.",
+        )
+        parser.add_argument(
+            "--max-num-batched-tokens",
+            type=int,
+            default=EngineArgs.max_num_batched_tokens,
+            help="maximum number of batched tokens per "
+            "iteration",
+        )
+        parser.add_argument(
+            "--max-num-seqs",
+            type=int,
+            default=EngineArgs.max_num_seqs,
+            help="maximum number of sequences per iteration",
+        )
+        parser.add_argument(
+            "--max-paddings",
+            type=int,
+            default=EngineArgs.max_paddings,
+            help="maximum number of paddings in a batch",
+        )
+        parser.add_argument(
+            "--max-log-probs",
+            type=int,
+            default=EngineArgs.max_log_probs,
+            help="maximum number of log probabilities to "
+            "return.",
+        )
+        parser.add_argument(
+            "--disable-log-stats",
+            action="store_true",
+            help="disable logging statistics",
+        )
         # Quantization settings.
-        parser.add_argument('--quantization',
-                            '-q',
-                            type=str,
-                            choices=[
-                                'aqlm', 'awq', 'bnb', 'exl2', 'gguf', 'gptq',
-                                'quip', 'squeezellm', 'marlin', None
-                            ],
-                            default=EngineArgs.quantization,
-                            help='Method used to quantize the weights. If '
-                            'None, we first check the `quantization_config` '
-                            'attribute in the model config file. If that is '
-                            'None, we assume the model weights are not '
-                            'quantized and use `dtype` to determine the data '
-                            'type of the weights.')
-        parser.add_argument('--load-in-4bit',
-                            action='store_true',
-                            help='Load the FP16 model in 4-bit format. Also '
-                            'works with AWQ models. Throughput at 2.5x of '
-                            'FP16.')
-        parser.add_argument('--load-in-8bit',
-                            action='store_true',
-                            help='Load the FP16 model in 8-bit format. '
-                            'Throughput at 0.3x of FP16.')
-        parser.add_argument('--load-in-smooth',
-                            action='store_true',
-                            help='Load the FP16 model in smoothquant '
-                            '8bit format. Throughput at 0.7x of FP16. ')
-        parser.add_argument('--enforce-eager',
-                            action='store_true',
-                            help='Always use eager-mode PyTorch. If False, '
-                            'will use eager mode and CUDA graph in hybrid '
-                            'for maximal performance and flexibility.')
-        parser.add_argument('--max-context-len-to-capture',
-                            type=int,
-                            default=EngineArgs.max_context_len_to_capture,
-                            help='maximum context length covered by CUDA '
-                            'graphs. When a sequence has context length '
-                            'larger than this, we fall back to eager mode.')
-        parser.add_argument('--disable-custom-all-reduce',
-                            action='store_true',
-                            default=EngineArgs.disable_custom_all_reduce,
-                            help='See ParallelConfig')
+        parser.add_argument(
+            "--quantization",
+            "-q",
+            type=str,
+            choices=[
+                "aqlm",
+                "awq",
+                "bnb",
+                "exl2",
+                "gguf",
+                "gptq",
+                "quip",
+                "squeezellm",
+                "marlin",
+                None,
+            ],
+            default=EngineArgs.quantization,
+            help="Method used to quantize the weights. If "
+            "None, we first check the `quantization_config` "
+            "attribute in the model config file. If that is "
+            "None, we assume the model weights are not "
+            "quantized and use `dtype` to determine the data "
+            "type of the weights.",
+        )
+        parser.add_argument(
+            "--load-in-4bit",
+            action="store_true",
+            help="Load the FP16 model in 4-bit format. Also "
+            "works with AWQ models. Throughput at 2.5x of "
+            "FP16.",
+        )
+        parser.add_argument(
+            "--load-in-8bit",
+            action="store_true",
+            help="Load the FP16 model in 8-bit format. "
+            "Throughput at 0.3x of FP16.",
+        )
+        parser.add_argument(
+            "--load-in-smooth",
+            action="store_true",
+            help="Load the FP16 model in smoothquant "
+            "8bit format. Throughput at 0.7x of FP16. ",
+        )
+        parser.add_argument(
+            "--enforce-eager",
+            type=lambda x: (str(x).lower() == 'true'),
+            default=EngineArgs.enforce_eager,
+            help="Always use eager-mode PyTorch. If False, "
+            "will use eager mode and CUDA graph in hybrid "
+            "for maximal performance and flexibility.",
+        )
+        parser.add_argument(
+            "--max-context-len-to-capture",
+            type=int,
+            default=EngineArgs.max_context_len_to_capture,
+            help="maximum context length covered by CUDA "
+            "graphs. When a sequence has context length "
+            "larger than this, we fall back to eager mode.",
+        )
+        parser.add_argument(
+            "--disable-custom-all-reduce",
+            action="store_true",
+            default=EngineArgs.disable_custom_all_reduce,
+            help="See ParallelConfig",
+        )
         # LoRA related configs
-        parser.add_argument('--enable-lora',
-                            action='store_true',
-                            help='If True, enable handling of LoRA adapters.')
-        parser.add_argument('--max-loras',
-                            type=int,
-                            default=EngineArgs.max_loras,
-                            help='Max number of LoRAs in a single batch.')
-        parser.add_argument('--max-lora-rank',
-                            type=int,
-                            default=EngineArgs.max_lora_rank,
-                            help='Max LoRA rank.')
         parser.add_argument(
-            '--lora-extra-vocab-size',
+            "--enable-lora",
+            action="store_true",
+            help="If True, enable handling of LoRA adapters.",
+        )
+        parser.add_argument(
+            "--max-loras",
+            type=int,
+            default=EngineArgs.max_loras,
+            help="Max number of LoRAs in a single batch.",
+        )
+        parser.add_argument(
+            "--max-lora-rank",
+            type=int,
+            default=EngineArgs.max_lora_rank,
+            help="Max LoRA rank.",
+        )
+        parser.add_argument(
+            "--lora-extra-vocab-size",
             type=int,
             default=EngineArgs.lora_extra_vocab_size,
-            help=('Maximum size of extra vocabulary that can be '
-                  'present in a LoRA adapter (added to the base '
-                  'model vocabulary).'))
+            help=("Maximum size of extra vocabulary that can be "
+                  "present in a LoRA adapter (added to the base "
+                  "model vocabulary)."),
+        )
         parser.add_argument(
-            '--lora-dtype',
+            "--lora-dtype",
             type=str,
             default=EngineArgs.lora_dtype,
-            choices=['auto', 'float16', 'bfloat16', 'float32'],
-            help=('Data type for LoRA. If auto, will default to '
-                  'base model dtype.'))
+            choices=["auto", "float16", "bfloat16", "float32"],
+            help=("Data type for LoRA. If auto, will default to "
+                  "base model dtype."),
+        )
         parser.add_argument(
-            '--max-cpu-loras',
+            "--max-cpu-loras",
             type=int,
             default=EngineArgs.max_cpu_loras,
-            help=('Maximum number of LoRAs to store in CPU memory. '
-                  'Must be >= than max_num_seqs. '
-                  'Defaults to max_num_seqs.'))
-        parser.add_argument('--device',
-                            type=str,
-                            default=EngineArgs.device,
-                            choices=['cuda'],
-                            help=('Device to use for model execution. '
-                                  'Currently, only "cuda" is supported.'))
+            help=("Maximum number of LoRAs to store in CPU memory. "
+                  "Must be >= than max_num_seqs. "
+                  "Defaults to max_num_seqs."),
+        )
+        parser.add_argument(
+            "--device",
+            type=str,
+            default=EngineArgs.device,
+            choices=["cuda"],
+            help=("Device to use for model execution. "
+                  'Currently, only "cuda" is supported.'),
+        )
         return parser

     @classmethod
-    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
+    def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
         # Get the list of attributes of this dataclass.
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         # Set the attributes from the parsed arguments.
@@ -313,63 +410,99 @@ class EngineArgs:
     def create_engine_configs(
         self,
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
-               DeviceConfig, Optional[LoRAConfig]]:
+               DeviceConfig, Optional[LoRAConfig], ]:
         device_config = DeviceConfig(self.device)
         model_config = ModelConfig(
-            self.model, self.tokenizer, self.tokenizer_mode,
-            self.trust_remote_code, self.download_dir, self.load_format,
-            self.dtype, self.seed, self.revision, self.tokenizer_revision,
-            self.max_model_len, self.quantization, self.load_in_4bit,
-            self.load_in_8bit, self.load_in_smooth, self.enforce_eager,
-            self.max_context_len_to_capture, self.max_log_probs)
-        cache_config = CacheConfig(self.block_size,
-                                   self.gpu_memory_utilization,
-                                   self.swap_space, self.kv_cache_dtype,
-                                   self.kv_quant_params_path,
-                                   model_config.get_sliding_window(),
-                                   self.context_shift)
-        parallel_config = ParallelConfig(self.pipeline_parallel_size,
-                                         self.tensor_parallel_size,
-                                         self.worker_use_ray,
-                                         self.max_parallel_loading_workers,
-                                         self.disable_custom_all_reduce)
-        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
-                                           self.max_num_seqs,
-                                           model_config.max_model_len,
-                                           self.max_paddings)
-        lora_config = LoRAConfig(
+            self.model,
+            self.tokenizer,
+            self.tokenizer_mode,
+            self.trust_remote_code,
+            self.download_dir,
+            self.load_format,
+            self.dtype,
+            self.seed,
+            self.revision,
+            self.code_revision,
+            self.tokenizer_revision,
+            self.max_model_len,
+            self.quantization,
+            self.load_in_4bit,
+            self.load_in_8bit,
+            self.load_in_smooth,
+            self.enforce_eager,
+            self.max_context_len_to_capture,
+            self.max_log_probs,
+        )
+        cache_config = CacheConfig(
+            self.block_size,
+            self.gpu_memory_utilization,
+            self.swap_space,
+            self.kv_cache_dtype,
+            # self.kv_quant_params_path,
+            model_config.get_sliding_window(),
+            self.context_shift,
+        )
+        parallel_config = ParallelConfig(
+            self.pipeline_parallel_size,
+            self.tensor_parallel_size,
+            self.worker_use_ray,
+            self.max_parallel_loading_workers,
+            self.disable_custom_all_reduce,
+            self.ray_workers_use_nsight,
+        )
+        scheduler_config = SchedulerConfig(
+            self.max_num_batched_tokens,
+            self.max_num_seqs,
+            model_config.max_model_len,
+            self.max_paddings,
+        )
+        lora_config = (LoRAConfig(
             max_lora_rank=self.max_lora_rank,
             max_loras=self.max_loras,
             lora_extra_vocab_size=self.lora_extra_vocab_size,
             lora_dtype=self.lora_dtype,
-            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
-            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
-        return (model_config, cache_config, parallel_config, scheduler_config,
-                device_config, lora_config)
+            max_cpu_loras=self.max_cpu_loras
+            if self.max_cpu_loras and self.max_cpu_loras > 0 else None,
+        ) if self.enable_lora else None)
+        return (
+            model_config,
+            cache_config,
+            parallel_config,
+            scheduler_config,
+            device_config,
+            lora_config,
+        )


 @dataclass
 class AsyncEngineArgs(EngineArgs):
     """Arguments for asynchronous Aphrodite engine."""
+
     engine_use_ray: bool = False
     disable_log_requests: bool = False
-    max_log_len: Optional[int] = None
+    max_log_len: int = 0

     @staticmethod
     def add_cli_args(
             parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         parser = EngineArgs.add_cli_args(parser)
-        parser.add_argument('--engine-use-ray',
-                            action='store_true',
-                            help='use Ray to start the LLM engine in a '
-                            'separate process as the server process.')
-        parser.add_argument('--disable-log-requests',
-                            action='store_true',
-                            help='disable logging requests')
-        parser.add_argument('--max-log-len',
-                            type=int,
-                            default=None,
-                            help='max number of prompt characters or prompt '
-                            'ID numbers being printed in log. '
-                            'Default: unlimited.')
+        parser.add_argument(
+            "--engine-use-ray",
+            action="store_true",
+            help="use Ray to start the LLM engine in a "
+            "separate process as the server process.",
+        )
+        parser.add_argument(
+            "--disable-log-requests",
+            action="store_true",
+            help="disable logging requests",
+        )
+        parser.add_argument(
+            "--max-log-len",
+            type=int,
+            default=0,
+            help="max number of prompt characters or prompt "
+            "ID numbers being printed in log. "
+            "Default: unlimited.",
+        )
         return parser
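
A minimal usage sketch (not part of this diff) of how the CLI helpers above fit together: add_cli_args registers every flag shown in this file, from_cli_args copies the parsed namespace back into the dataclass, and create_engine_configs expands it into the config tuple. The entry-point function below is an assumption for illustration only.

import argparse

from aphrodite.engine.args_tools import AsyncEngineArgs


def parse_engine_args() -> AsyncEngineArgs:
    parser = argparse.ArgumentParser(description="Aphrodite engine arguments")
    # Registers --dtype, --kv-cache-dtype, --max-model-len, etc. (see above).
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    return AsyncEngineArgs.from_cli_args(args)


engine_args = parse_engine_args()
(model_config, cache_config, parallel_config, scheduler_config,
 device_config, lora_config) = engine_args.create_engine_configs()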

+ 45 - 28
aphrodite/engine/async_aphrodite.py

@@ -15,7 +15,7 @@ from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams

 ENGINE_ITERATION_TIMEOUT_S = int(
-    os.environ.get("APHRODITE_ENGINE_ITERATION_TIMEOUT_S", 60))
+    os.environ.get("APHRODITE_ENGINE_ITERATION_TIMEOUT_S", "120"))
 
 
 
 
 class AsyncEngineDeadError(RuntimeError):
 class AsyncEngineDeadError(RuntimeError):
@@ -26,13 +26,18 @@ def _raise_exception_on_finish(
         task: asyncio.Task, error_callback: Callable[[Exception],
                                                      None]) -> None:
     msg = ("Task finished unexpectedly. This should never happen! "
-           "Please open an issue on Github.")
+           "Please open an issue on Github. Include your full error "
+           "log after killing the process with Ctrl+C.")
 
 
     exception = None
     exception = None
     try:
     try:
         task.result()
         task.result()
         # NOTE: This will be thrown if task exits normally (which it should not)
         # NOTE: This will be thrown if task exits normally (which it should not)
         raise AsyncEngineDeadError(msg)
         raise AsyncEngineDeadError(msg)
+    except asyncio.exceptions.CancelledError:
+        pass
+    except KeyboardInterrupt:
+        raise
     except Exception as e:
         exception = e
         logger.error("Engine background task failed", exc_info=e)
@@ -318,6 +323,8 @@ class AsyncAphrodite:
             async frontend will be executed in a separate process as the
             model workers.
         log_requests: Whether to log the requests.
+        max_log_len: Maximum number of prompt characters or prompt ID numbers
+            being printed in log.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
         *args: Arguments for AphroditeEngine.
@@ -331,7 +338,7 @@ class AsyncAphrodite:
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
-                 max_log_len: Optional[int] = None,
+                 max_log_len: int = 0,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
@@ -456,23 +463,27 @@ class AsyncAphrodite:

     async def run_engine_loop(self):
         has_requests_in_progress = False
-        while True:
-            if not has_requests_in_progress:
-                logger.debug("Waiting for new requests...")
-                await self._request_tracker.wait_for_new_requests()
-                logger.debug("Got new requests!")
-
-            # Abort if iteration takes too long due to unrecoverable errors
-            # (eg. NCCL timeouts).
-            try:
-                has_requests_in_progress = await asyncio.wait_for(
-                    self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
-            except asyncio.TimeoutError as exc:
-                logger.error(
-                    "Engine iteration timed out. This should never happen!")
-                self.set_errored(exc)
-                raise
-            await asyncio.sleep(0)
+        try:
+            while True:
+                if not has_requests_in_progress:
+                    logger.debug("Waiting for new requests...")
+                    await self._request_tracker.wait_for_new_requests()
+                    logger.debug("Got new requests!")
+
+                # Abort if iteration takes too long due to unrecoverable errors
+                # (eg. NCCL timeouts).
+                try:
+                    has_requests_in_progress = await asyncio.wait_for(
+                        self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
+                except asyncio.TimeoutError as exc:
+                    logger.error(
+                        "Engine iteration timed out. This should never happen!"
+                    )
+                    self.set_errored(exc)
+                    raise
+                await asyncio.sleep(0)
+        except KeyboardInterrupt:
+            logger.info("Engine loop interrupted. Exiting gracefully.")
 
 
     async def add_request(
     async def add_request(
         self,
         self,
@@ -494,8 +505,7 @@ class AsyncAphrodite:
                                                               max_log_len]
             logger.info(f"Received request {request_id}: "
                         f"prompt: {shortened_prompt!r}, "
-                        f"sampling params: {sampling_params}, "
-                        f"prompt token ids: {shortened_token_ids}, "
+                        f"sampling_params: {sampling_params}, "
                         f"lora_request: {lora_request}.")
                         f"lora_request: {lora_request}.")
 
 
         if not self.is_running:
         if not self.is_running:
@@ -510,6 +520,7 @@ class AsyncAphrodite:

         if arrival_time is None:
             arrival_time = time.time()
+
         if self.engine_use_ray:
             prompt_token_ids = await self.engine.encode_request_async.remote(
                 request_id=request_id,
@@ -609,15 +620,21 @@ class AsyncAphrodite:
         arrival_time = time.monotonic()

         try:
-            stream = await self.add_request(request_id,
-                                            prompt,
-                                            sampling_params,
-                                            prompt_token_ids=prompt_token_ids,
-                                            arrival_time=arrival_time,
-                                            lora_request=lora_request)
+            stream = await self.add_request(
+                request_id,
+                prompt,
+                sampling_params,
+                prompt_token_ids=prompt_token_ids,
+                arrival_time=arrival_time,
+                lora_request=lora_request,
+            )

             async for request_output in stream:
                 yield request_output
+        except asyncio.exceptions.CancelledError:
+            logger.info(f"Request {request_id} cancelled.")
+            self._abort(request_id)
+            raise
         except (Exception, asyncio.CancelledError) as e:
             # If there is an exception or coroutine is cancelled, abort the
             # request.
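
A condensed sketch of the watchdog pattern introduced above: each engine step is bounded by asyncio.wait_for, cancellation aborts the in-flight work, and KeyboardInterrupt exits the loop gracefully. The step() and abort() callables here are placeholders, not Aphrodite APIs.

import asyncio

ITERATION_TIMEOUT_S = 120  # mirrors APHRODITE_ENGINE_ITERATION_TIMEOUT_S


async def watchdog_loop(step, abort):
    try:
        while True:
            try:
                # Bound each iteration so unrecoverable hangs surface as errors.
                await asyncio.wait_for(step(), ITERATION_TIMEOUT_S)
            except asyncio.TimeoutError:
                abort()
                raise
            await asyncio.sleep(0)  # yield to other tasks between iterations
    except asyncio.CancelledError:
        abort()  # explicitly abort when the caller cancels us
        raise
    except KeyboardInterrupt:
        pass  # exit gracefully, as run_engine_loop does above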

+ 79 - 21
aphrodite/engine/metrics.py

@@ -1,11 +1,18 @@
+from loguru import logger
+from prometheus_client import (
+    Counter,
+    Gauge,
+    Histogram,
+    Info,
+    REGISTRY,
+    disable_created_metrics,
+)
+
 import time
 import numpy as np
 from typing import Dict, List
 from dataclasses import dataclass

-from prometheus_client import Counter, Gauge, Histogram, disable_created_metrics
-from loguru import logger
-
 disable_created_metrics()

 # The begin-* and end* here are used by the documentation generator
@@ -16,58 +23,104 @@ disable_created_metrics()
 class Metrics:

     def __init__(self, labelnames: List[str]):
+        # Unregister any existing Aphrodite collectors
+        for collector in list(REGISTRY._collector_to_names):
+            if hasattr(collector, "_name") and "aphrodite" in collector._name:
+                REGISTRY.unregister(collector)
+
+        # Config Information
+        self.info_cache_config = Info(
+            name="aphrodite:cache_config",
+            documentation="information of cache_config",
+        )
+
         # System stats
         self.gauge_scheduler_running = Gauge(
             name="aphrodite:num_requests_running",
             documentation="Number of requests currently running on GPU.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+        )
         self.gauge_scheduler_swapped = Gauge(
             name="aphrodite:num_requests_swapped",
             documentation="Number of requests swapped to CPU.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+        )
         self.gauge_scheduler_waiting = Gauge(
             name="aphrodite:num_requests_waiting",
             documentation="Number of requests waiting to be processed.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+        )
         self.gauge_gpu_cache_usage = Gauge(
             name="aphrodite:gpu_cache_usage_perc",
             documentation="GPU KV-cache usage. 1 means 100 percent usage.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+        )
         self.gauge_cpu_cache_usage = Gauge(
             name="aphrodite:cpu_cache_usage_perc",
             documentation="CPU KV-cache usage. 1 means 100 percent usage.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+        )

         # Raw stats from last model iteration
         self.counter_prompt_tokens = Counter(
             name="aphrodite:prompt_tokens_total",
             documentation="Number of prefill tokens processed.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+        )
         self.counter_generation_tokens = Counter(
             name="aphrodite:generation_tokens_total",
             documentation="Number of generation tokens processed.",
-            labelnames=labelnames)
+            labelnames=labelnames,
+        )
         self.histogram_time_to_first_token = Histogram(
             name="aphrodite:time_to_first_token_seconds",
             documentation="Histogram of time to first token in seconds.",
             labelnames=labelnames,
             buckets=[
-                0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
-                0.75, 1.0, 2.5, 5.0, 7.5, 10.0
-            ])
+                0.001,
+                0.005,
+                0.01,
+                0.02,
+                0.04,
+                0.06,
+                0.08,
+                0.1,
+                0.25,
+                0.5,
+                0.75,
+                1.0,
+                2.5,
+                5.0,
+                7.5,
+                10.0,
+            ],
+        )
         self.histogram_time_per_output_token = Histogram(
             name="aphrodite:time_per_output_token_seconds",
             documentation="Histogram of time per output token in seconds.",
             labelnames=labelnames,
             buckets=[
-                0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
-                1.0, 2.5
-            ])
+                0.01,
+                0.025,
+                0.05,
+                0.075,
+                0.1,
+                0.15,
+                0.2,
+                0.3,
+                0.4,
+                0.5,
+                0.75,
+                1.0,
+                2.5,
+            ],
+        )
         self.histogram_e2e_request_latency = Histogram(
             name="aphrodite:e2e_request_latency_seconds",
             documentation="Histogram of end to end request latency in seconds.",
             labelnames=labelnames,
-            buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
+            buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0],
+        )

         # Legacy metrics
         self.gauge_avg_prompt_throughput = Gauge(
@@ -88,6 +141,7 @@ class Metrics:
 @dataclass
 class Stats:
     """Created by AphroditeEngine for use by StatLogger."""
+
     now: float

     # System stats.
@@ -121,6 +175,10 @@ class StatLogger:
         self.labels = labels
         self.metrics = Metrics(labelnames=list(labels.keys()))

+    def info(self, type: str, obj: object) -> None:
+        if type == "cache_config":
+            self.metrics.info_cache_config.info(obj.metrics_info())
+
     def _get_throughput(self, tracked_stats: List[int], now: float) -> float:
         return float(np.sum(tracked_stats) / (now - self.last_local_log))

@@ -174,8 +232,8 @@

     def log(self, stats: Stats) -> None:
         """Called by AphroditeEngine.
-           Logs to prometheus and tracked stats every iteration.
-           Logs to Stdout every self.local_interval seconds."""
+        Logs to prometheus and tracked stats every iteration.
+        Logs to Stdout every self.local_interval seconds."""

         # Log to prometheus.
         self._log_prometheus(stats)
@@ -186,7 +244,6 @@

         # Log locally every local_interval seconds.
         if self._local_interval_elapsed(stats.now):
-
             # Compute summary metrics for tracked stats (and log them to
             # prometheus if applicable).
             prompt_throughput = self._get_throughput(self.num_prompt_tokens,
@@ -195,7 +252,8 @@ class StatLogger:
                 self.num_generation_tokens, now=stats.now)
             self._log_prometheus_interval(
                 prompt_throughput=prompt_throughput,
-                generation_throughput=generation_throughput)
+                generation_throughput=generation_throughput,
+            )

             # Log to stdout.
             logger.info(
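
A small standalone sketch of the prometheus_client pattern used by the Metrics class above: labelled Gauge/Counter series plus an Info metric for static configuration. The metric names and label values below are examples, not the series registered by Aphrodite.

from prometheus_client import Counter, Gauge, Info, disable_created_metrics

disable_created_metrics()  # drop the *_created series, as done above

labels = {"model_name": "example-model"}
info_cache = Info("example:cache_config", "information of cache_config")
gauge_running = Gauge("example:num_requests_running",
                      "Number of requests currently running.",
                      labelnames=list(labels.keys()))
counter_tokens = Counter("example:generation_tokens_total",
                         "Number of generation tokens processed.",
                         labelnames=list(labels.keys()))

info_cache.info({"block_size": "16", "cache_dtype": "auto"})
gauge_running.labels(**labels).set(3)
counter_tokens.labels(**labels).inc(128)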

+ 21 - 5
aphrodite/engine/ray_tools.py

@@ -1,3 +1,5 @@
+import pickle
+
 from typing import Optional, List, Tuple, TYPE_CHECKING
 from loguru import logger

@@ -13,10 +15,14 @@ try:

         def __init__(self, init_cached_hf_modules=False) -> None:
             if init_cached_hf_modules:
-                # pylint: disable=import-outside-toplevel
                 from transformers.dynamic_module_utils import init_hf_modules
                 init_hf_modules()
             self.worker = None
+            # The compiled DAG runs the main execution on a separate
+            # thread, which calls cuda.set_device. This flag records
+            # whether set_device has already been called on that
+            # thread.
+            self.compiled_dag_cuda_device_set = False

         def init_worker(self, worker_init_fn):
             self.worker = worker_init_fn()
@@ -39,6 +45,17 @@ try:
         def set_cuda_visible_devices(self, device_ids) -> None:
             set_cuda_visible_devices(device_ids)

+        def execute_model_compiled_dag_remote(self, ignored):
+            """Used only when compiled DAG is enabled."""
+            import torch
+            if not self.compiled_dag_cuda_device_set:
+                torch.cuda.set_device(self.worker.device)
+                self.compiled_dag_cuda_device_set = True
+
+            output = self.worker.execute_model()
+            output = pickle.dumps(output)
+            return output
+
 except ImportError as e:
     logger.warning(f"Failed to import Ray with {e!r}. "
                    "For distributed inference, please install Ray with "
@@ -64,10 +81,9 @@ def initialize_cluster(
             the default Ray cluster address.

     Returns:
-        A tuple of (`distributed_init_method`, `placement_group`). The
-        `distributed_init_method` is the address for initializing the
-        distributed backend. `placement_group` includes the specification
-        of the resources for each distributed worker.
+        An optional `PlacementGroup`. It includes the specification
+        of the resources for each distributed worker. None if Ray is
+        not used.
     """
     """
     if parallel_config.worker_use_ray or engine_use_ray:
     if parallel_config.worker_use_ray or engine_use_ray:
         if ray is None:
         if ray is None:

+ 0 - 0
aphrodite/modeling/layers/triton_kernel/__init__.py → aphrodite/executor/__init__.py


+ 76 - 0
aphrodite/executor/executor_base.py

@@ -0,0 +1,76 @@
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional
+
+from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
+                                     ParallelConfig, SchedulerConfig,
+                                     LoRAConfig)
+from aphrodite.lora.request import LoRARequest
+from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+
+
+class ExecutorBase(ABC):
+    """Base class for all executors.
+
+    An executor is responsible for executing the model on a specific device
+    type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
+    that can execute the model on multiple devices.
+    """
+
+    @abstractmethod
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+    ) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def execute_model(self,
+                      seq_group_metadata_list: List[SequenceGroupMetadata],
+                      blocks_to_swap_in: Dict[int, int],
+                      blocks_to_swap_out: Dict[int, int],
+                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+        """Executes one model step on the given sequences."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def remove_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError
+
+    @abstractmethod
+    def list_loras(self) -> List[int]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def check_health(self) -> None:
+        """Checks if the executor is healthy. If not, it should raise an
+        exception."""
+        raise NotImplementedError
+
+
+class ExecutorAsyncBase(ExecutorBase):
+
+    @abstractmethod
+    async def execute_model_async(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
+        """Executes one model step on the given sequences."""
+        raise NotImplementedError
+
+    @abstractmethod
+    async def check_health_async(self) -> None:
+        """Checks if the executor is healthy. If not, it should raise an
+        exception."""
+        raise NotImplementedError

+ 153 - 0
aphrodite/executor/gpu_executor.py

@@ -0,0 +1,153 @@
+from typing import Dict, List, Optional
+
+from loguru import logger
+
+from aphrodite.lora.request import LoRARequest
+from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
+                                     ParallelConfig, SchedulerConfig,
+                                     LoRAConfig)
+from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
+from aphrodite.executor.utils import check_block_size_valid
+from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.utils import (get_ip, get_open_port,
+                                    get_distributed_init_method, make_async)
+
+
+class GPUExecutor(ExecutorBase):
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+    ) -> None:
+        self.model_config = model_config
+        self.cache_config = cache_config
+        self.lora_config = lora_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+
+        # Instantiate the worker and load the model to GPU.
+        self._init_worker()
+
+        # Profile the memory usage and initialize the cache.
+        self._init_cache()
+
+    def _init_worker(self):
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from aphrodite.task_handler.worker import Worker
+
+        assert self.parallel_config.world_size == 1, (
+            "GPUExecutor only supports single GPU.")
+
+        distributed_init_method = get_distributed_init_method(
+            get_ip(), get_open_port())
+        self.driver_worker = Worker(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            self.device_config,
+            local_rank=0,
+            rank=0,
+            distributed_init_method=distributed_init_method,
+            lora_config=self.lora_config,
+            kv_cache_dtype=self.cache_config.cache_dtype,
+            is_driver_worker=True,
+        )
+        self.driver_worker.init_device()
+        self.driver_worker.load_model()
+
+    def _init_cache(self) -> None:
+        """Profiles the memory usage and initializes the KV cache.
+
+        The engine first profiles the existing memory usage.
+        Then, it allocates the remaining memory for KV blocks.
+
+        .. tip::
+            You may limit the usage of GPU memory
+            by adjusting the `gpu_memory_utilization` parameter.
+        """
+        # Get the maximum number of blocks that can be allocated on GPU and CPU.
+        num_gpu_blocks, num_cpu_blocks = (
+            self.driver_worker.profile_num_available_blocks(
+                block_size=self.cache_config.block_size,
+                gpu_memory_utilization=self.cache_config.
+                gpu_memory_utilization,
+                cpu_swap_space=self.cache_config.swap_space_bytes,
+                cache_dtype=self.cache_config.cache_dtype,
+            ))
+
+        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
+                    f"# CPU blocks: {num_cpu_blocks}")
+
+        logger.info(
+            f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"  # noqa: E501
+        )
+
+        check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
+                               self.model_config.max_model_len)
+
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+        # Initialize the cache.
+        self.driver_worker.init_cache_engine(cache_config=self.cache_config)
+        # Warm up the model. This includes capturing the model into CUDA graph
+        # if enforce_eager is False.
+        self.driver_worker.warm_up_model()
+
+    def execute_model(self,
+                      seq_group_metadata_list: List[SequenceGroupMetadata],
+                      blocks_to_swap_in: Dict[int, int],
+                      blocks_to_swap_out: Dict[int, int],
+                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+        output = self.driver_worker.execute_model(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+        )
+        return output
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
+        return self.driver_worker.add_lora(lora_request)
+
+    def remove_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self.driver_worker.remove_lora(lora_id)
+
+    def list_loras(self) -> List[int]:
+        return self.driver_worker.list_loras()
+
+    def check_health(self) -> None:
+        # GPUExecutor will always be healthy as long as
+        # it's running.
+        return
+
+
+class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
+
+    async def execute_model_async(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
+        output = await make_async(self.driver_worker.execute_model)(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy)
+        return output
+
+    async def check_health_async(self) -> None:
+        # GPUExecutor will always be healthy as long as
+        # it's running.
+        return

+ 78 - 0
aphrodite/executor/neuron_executor.py

@@ -0,0 +1,78 @@
+from typing import Dict, List, Optional
+
+from aphrodite.lora.request import LoRARequest
+from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
+                                     ParallelConfig, SchedulerConfig,
+                                     LoRAConfig)
+from aphrodite.executor.executor_base import ExecutorBase
+from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+
+
+class NeuronExecutor(ExecutorBase):
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+    ) -> None:
+        self.model_config = model_config
+        self.cache_config = cache_config
+        assert lora_config is None, "LoRA is not supported for Neuron backend."
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+
+        # Set the number of GPU blocks to be the same as the maximum number of
+        # sequences that can be processed in a single batch. This is equivalent
+        # to schedule without PagedAttention.
+        self.cache_config.num_gpu_blocks = self.scheduler_config.max_num_seqs
+        self.cache_config.num_cpu_blocks = 0
+
+        # Instantiate the worker and load the model to the device.
+        self._init_worker()
+
+    def _init_worker(self):
+        from aphrodite.task_handler.neuron_worker import NeuronWorker
+
+        self.driver_worker = NeuronWorker(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            self.device_config,
+        )
+        self.driver_worker.init_device()
+        self.driver_worker.load_model()
+
+    def execute_model(self,
+                      seq_group_metadata_list: List[SequenceGroupMetadata],
+                      blocks_to_swap_in: Dict[int, int],
+                      blocks_to_swap_out: Dict[int, int],
+                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
+        assert (blocks_to_swap_in == {} and blocks_to_swap_out == {}
+                and blocks_to_copy == {}), (
+                    "Cache operations are not supported for Neuron backend.")
+
+        output = self.driver_worker.execute_model(
+            seq_group_metadata_list=seq_group_metadata_list)
+        return output
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        raise NotImplementedError(
+            "LoRA is not implemented for neuron backend.")
+
+    def remove_lora(self, lora_id: int) -> bool:
+        raise NotImplementedError(
+            "LoRA is not implemented for neuron backend.")
+
+    def list_loras(self) -> List[int]:
+        raise NotImplementedError(
+            "LoRA is not implemented for neuron backend.")
+
+    def check_health(self) -> None:
+        # NeuronExecutor will always be healthy as long as
+        # it's running.
+        return
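
How the engine picks between these executors is decided in aphrodite/engine/aphrodite_engine.py, which is not shown in this section; a simplified illustration of that dispatch (an assumption, not the exact upstream logic) is:

def select_executor_class(device: str, worker_use_ray: bool):
    # Neuron devices get the cache-less executor; Ray handles multi-GPU.
    if device == "neuron":
        from aphrodite.executor.neuron_executor import NeuronExecutor
        return NeuronExecutor
    if worker_use_ray:
        from aphrodite.executor.ray_gpu_executor import RayGPUExecutor
        return RayGPUExecutor
    from aphrodite.executor.gpu_executor import GPUExecutor
    return GPUExecutor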

+ 452 - 0
aphrodite/executor/ray_gpu_executor.py

@@ -0,0 +1,452 @@
+import asyncio
+import copy
+from collections import defaultdict
+import os
+import pickle
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from loguru import logger
+
+from aphrodite.common.config import (
+    CacheConfig,
+    DeviceConfig,
+    ModelConfig,
+    ParallelConfig,
+    SchedulerConfig,
+    LoRAConfig,
+)
+from aphrodite.engine.ray_tools import RayWorkerAphrodite, ray
+from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
+from aphrodite.executor.utils import check_block_size_valid
+from aphrodite.lora.request import LoRARequest
+from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.common.utils import (
+    set_cuda_visible_devices,
+    get_ip,
+    get_open_port,
+    get_distributed_init_method,
+    make_async,
+)
+
+if ray is not None:
+    from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+
+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
+
+# If the env var is set to 1, Aphrodite uses Ray's compiled DAG API,
+# which reduces the control-plane overhead.
+# Run Aphrodite with APHRODITE_USE_RAY_COMPILED_DAG=1 to enable it.
+USE_RAY_COMPILED_DAG = os.getenv("APHRODITE_USE_RAY_COMPILED_DAG", "0") == "1"
+
+
+class RayGPUExecutor(ExecutorBase):
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        lora_config: Optional[LoRAConfig],
+    ) -> None:
+        self.model_config = model_config
+        self.cache_config = cache_config
+        self.lora_config = lora_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+
+        assert self.parallel_config.worker_use_ray
+        placement_group = self.parallel_config.placement_group
+
+        # Disable Ray usage stats collection.
+        ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
+        if ray_usage != "1":
+            os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
+
+        # Create the parallel GPU workers.
+        self._init_workers_ray(placement_group)
+
+        # Profile the memory usage and initialize the cache.
+        self._init_cache()
+
+        self.forward_dag = None
+        if USE_RAY_COMPILED_DAG:
+            self.forward_dag = self._compiled_ray_dag()
+
+    def _init_workers_ray(self, placement_group: "PlacementGroup",
+                          **ray_remote_kwargs):
+        if self.parallel_config.tensor_parallel_size == 1:
+            # For single GPU case, we use a ray worker with constrained memory.
+            num_gpus = self.cache_config.gpu_memory_utilization
+        else:
+            # Otherwise, the ray workers are allocated with a full GPU.
+            num_gpus = 1
+
+        # The driver dummy worker does not actually use any resources.
+        # It holds the resource for the driver worker.
+        self.driver_dummy_worker: Optional[RayWorkerAphrodite] = None
+        # The remaining workers are the actual ray actors.
+        self.workers: List[RayWorkerAphrodite] = []
+
+        # Create the workers.
+        driver_ip = get_ip()
+        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
+            if not bundle.get("GPU", 0):
+                continue
+            scheduling_strategy = PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_capture_child_tasks=True,
+                placement_group_bundle_index=bundle_id,
+            )
+            worker = ray.remote(
+                num_cpus=0,
+                num_gpus=num_gpus,
+                scheduling_strategy=scheduling_strategy,
+                **ray_remote_kwargs,
+            )(RayWorkerAphrodite).remote(self.model_config.trust_remote_code)
+
+            worker_ip = ray.get(worker.get_node_ip.remote())
+            if worker_ip == driver_ip and self.driver_dummy_worker is None:
+                # If the worker is on the same node as the driver, we use it
+                # as the resource holder for the driver process.
+                self.driver_dummy_worker = worker
+            else:
+                # Otherwise, add it to the list of workers.
+                self.workers.append(worker)
+
+        if self.driver_dummy_worker is None:
+            raise ValueError(
+                "Ray does not allocate any GPUs on the driver node. Consider "
+                "adjusting the Ray placement group or running the driver on a "
+                "GPU node.")
+
+        # Get the set of GPU IDs used on each node.
+        driver_node_id, driver_gpu_ids = ray.get(
+            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
+        worker_node_and_gpu_ids = ray.get(
+            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
+
+        node_workers = defaultdict(list)
+        node_gpus = defaultdict(list)
+
+        node_workers[driver_node_id].append(0)
+        node_gpus[driver_node_id].extend(driver_gpu_ids)
+        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
+                                               start=1):
+            node_workers[node_id].append(i)
+            node_gpus[node_id].extend(gpu_ids)
+        for node_id, gpu_ids in node_gpus.items():
+            node_gpus[node_id] = sorted(gpu_ids)
+
+        # Set CUDA_VISIBLE_DEVICES for the driver and workers.
+        set_cuda_visible_devices(node_gpus[driver_node_id])
+        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
+            worker.set_cuda_visible_devices.remote(node_gpus[node_id])
+
+        distributed_init_method = get_distributed_init_method(
+            driver_ip, get_open_port())
+
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from aphrodite.task_handler.worker import Worker
+
+        model_config = copy.deepcopy(self.model_config)
+        parallel_config = copy.deepcopy(self.parallel_config)
+        scheduler_config = copy.deepcopy(self.scheduler_config)
+        device_config = copy.deepcopy(self.device_config)
+        lora_config = copy.deepcopy(self.lora_config)
+        kv_cache_dtype = self.cache_config.cache_dtype
+
+        # Initialize the actual workers with the Worker class.
+        for rank, (worker, (node_id, _)) in enumerate(
+                zip(self.workers, worker_node_and_gpu_ids),
+                start=1,
+        ):
+            local_rank = node_workers[node_id].index(rank)
+            worker.init_worker.remote(
+                lambda rank=rank, local_rank=local_rank: Worker(
+                    model_config,
+                    parallel_config,
+                    scheduler_config,
+                    device_config,
+                    local_rank,
+                    rank,
+                    distributed_init_method,
+                    lora_config=lora_config,
+                    kv_cache_dtype=kv_cache_dtype,
+                ))
+
+        # Initialize the driver worker with the Worker class.
+        driver_rank = 0
+        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
+        self.driver_worker = Worker(
+            self.model_config,
+            self.parallel_config,
+            self.scheduler_config,
+            self.device_config,
+            driver_local_rank,
+            driver_rank,
+            distributed_init_method,
+            lora_config=self.lora_config,
+            kv_cache_dtype=kv_cache_dtype,
+            is_driver_worker=True,
+        )
+
+        # FIXME(woosuk): We are not properly initializing cupy NCCL when
+        # we have multiple nodes.
+        self._run_workers(
+            "init_device",
+            cupy_port=get_open_port()
+            if not model_config.enforce_eager else None,
+        )
+        self._run_workers(
+            "load_model",
+            max_concurrent_workers=self.parallel_config.
+            max_parallel_loading_workers,
+        )
+
+    def _init_cache(self) -> None:
+        """Profiles the memory usage and initializes the KV cache.
+
+        The engine will first profile the existing memory usage.
+        Then, it calculates the maximum possible number of GPU and CPU blocks
+        that can be allocated with the remaining free memory.
+        More details can be found in the
+        :meth:`~aphrodite.task_handler.worker.Worker.profile_num_available_blocks` method
+        from class :class:`~aphrodite.task_handler.Worker`.
+
+        Afterwards, as there may be multiple workers,
+        we take the minimum number of blocks across all workers
+        to ensure this can be applied to all of them.
+
+        Finally, the engine will initialize the KV cache
+        with the calculated number of blocks.
+
+        .. tip::
+            You may limit the usage of GPU memory
+            by adjusting the `gpu_memory_utilization` parameter.
+        """  # noqa: E501
+        # Get the maximum number of blocks that can be allocated on GPU and CPU.
+        num_blocks = self._run_workers(
+            "profile_num_available_blocks",
+            block_size=self.cache_config.block_size,
+            gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
+            cpu_swap_space=self.cache_config.swap_space_bytes,
+            cache_dtype=self.cache_config.cache_dtype,
+        )
+
+        # Since we use a shared centralized controller, we take the minimum
+        # number of blocks across all workers to make sure all the memory
+        # operators can be applied to all workers.
+        num_gpu_blocks = min(b[0] for b in num_blocks)
+        num_cpu_blocks = min(b[1] for b in num_blocks)
+        logger.info(f"# GPU blocks: {num_gpu_blocks}, "
+                    f"# CPU blocks: {num_cpu_blocks}")
+
+        logger.info(
+            f"Minimum concurrency: {num_gpu_blocks * self.cache_config.block_size / self.scheduler_config.max_model_len:.2f}x"  # noqa: E501
+        )
+
+        check_block_size_valid(
+            num_gpu_blocks,
+            self.cache_config.block_size,
+            self.model_config.max_model_len,
+        )
+
+        self.cache_config.num_gpu_blocks = num_gpu_blocks
+        self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+        # Initialize the cache.
+        self._run_workers("init_cache_engine", cache_config=self.cache_config)
+        # Warm up the model. This includes capturing the model into CUDA graph
+        # if enforce_eager is False.
+        self._run_workers("warm_up_model")
+
+    def execute_model(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
+        all_outputs = self._run_workers(
+            "execute_model",
+            driver_kwargs={
+                "seq_group_metadata_list": seq_group_metadata_list,
+                "blocks_to_swap_in": blocks_to_swap_in,
+                "blocks_to_swap_out": blocks_to_swap_out,
+                "blocks_to_copy": blocks_to_copy,
+            },
+            use_ray_compiled_dag=USE_RAY_COMPILED_DAG,
+        )
+
+        # Only the driver worker returns the sampling results.
+        output = all_outputs[0]
+        return output
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "add_lora",
+            lora_request=lora_request,
+        )
+
+    def remove_lora(self, lora_id: int) -> bool:
+        assert lora_id > 0, "lora_id must be greater than 0."
+        return self._run_workers(
+            "remove_lora",
+            lora_id=lora_id,
+        )
+
+    def list_loras(self) -> List[int]:
+        return self._run_workers("list_loras")
+
+    def _run_workers(
+        self,
+        method: str,
+        *args,
+        driver_args: Optional[List[Any]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
+        max_concurrent_workers: Optional[int] = None,
+        use_ray_compiled_dag: bool = False,
+        **kwargs,
+    ) -> Any:
+        """Runs the given method on all workers."""
+
+        if max_concurrent_workers:
+            raise NotImplementedError(
+                "max_concurrent_workers is not supported yet.")
+
+        if use_ray_compiled_dag:
+            # Right now, compiled DAG can only accept a single
+            # input. TODO(sang): Fix it.
+            output_channels = self.forward_dag.execute(1)
+        else:
+            # Start the ray workers first.
+            ray_worker_outputs = [
+                worker.execute_method.remote(method, *args, **kwargs)
+                for worker in self.workers
+            ]
+
+        if driver_args is None:
+            driver_args = args
+        if driver_kwargs is None:
+            driver_kwargs = kwargs
+
+        # Start the driver worker after all the ray workers.
+        driver_worker_output = getattr(self.driver_worker,
+                                       method)(*driver_args, **driver_kwargs)
+
+        # Get the results of the ray workers.
+        if self.workers:
+            if use_ray_compiled_dag:
+                try:
+                    ray_worker_outputs = [
+                        pickle.loads(chan.begin_read())
+                        for chan in output_channels
+                    ]
+                finally:
+                    # Has to call end_read in order to reuse the DAG.
+                    for chan in output_channels:
+                        chan.end_read()
+            else:
+                ray_worker_outputs = ray.get(ray_worker_outputs)
+
+        return [driver_worker_output] + ray_worker_outputs
+
+    def _compiled_ray_dag(self):
+        import pkg_resources
+
+        required_version = "2.9"
+        current_version = pkg_resources.get_distribution("ray").version
+        # Compare parsed versions rather than raw strings so that, e.g.,
+        # "2.10" is not treated as older than "2.9".
+        if pkg_resources.parse_version(
+                current_version) < pkg_resources.parse_version(
+                    required_version):
+            raise ValueError(f"Ray version {required_version} or greater is "
+                             f"required, but found {current_version}")
+
+        from ray.dag import MultiOutputNode, InputNode
+
+        assert self.parallel_config.worker_use_ray
+
+        # Right now, compiled DAG requires at least 1 arg. We send
+        # a dummy value for now. It will be fixed soon.
+        with InputNode() as input_data:
+            forward_dag = MultiOutputNode([
+                worker.execute_model_compiled_dag_remote.bind(input_data)
+                for worker in self.workers
+            ])
+        return forward_dag.experimental_compile()
+
+    def check_health(self) -> None:
+        """Raises an error if engine is unhealthy."""
+        self._check_if_any_actor_is_dead()
+
+    def _check_if_any_actor_is_dead(self):
+        if not self.workers:
+            return
+
+        dead_actors = []
+        for actor in self.workers:
+            actor_state = ray.state.actors(actor._ray_actor_id.hex())  # pylint: disable=protected-access
+            if actor_state["State"] == "DEAD":
+                dead_actors.append(actor)
+        if dead_actors:
+            raise RuntimeError("At least one Worker is dead. "
+                               f"Dead Workers: {dead_actors}. ")
+
+
+class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
+
+    async def _run_workers_async(
+        self,
+        method: str,
+        *args,
+        driver_args: Optional[List[Any]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> Any:
+        """Runs the given method on all workers."""
+        coros = []
+
+        if driver_args is None:
+            driver_args = args
+        if driver_kwargs is None:
+            driver_kwargs = kwargs
+
+        # Run the driver worker asynchronously.
+        driver_executor = make_async(getattr(self.driver_worker, method))
+        coros.append(driver_executor(*driver_args, **driver_kwargs))
+
+        # Run the ray workers asynchronously.
+        for worker in self.workers:
+            coros.append(worker.execute_method.remote(method, *args, **kwargs))
+
+        all_outputs = await asyncio.gather(*coros)
+        return all_outputs
+
+    async def execute_model_async(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> SamplerOutput:
+        all_outputs = await self._run_workers_async(
+            "execute_model",
+            driver_kwargs={
+                "seq_group_metadata_list": seq_group_metadata_list,
+                "blocks_to_swap_in": blocks_to_swap_in,
+                "blocks_to_swap_out": blocks_to_swap_out,
+                "blocks_to_copy": blocks_to_copy,
+            },
+        )
+
+        # Only the driver worker returns the sampling results.
+        output = all_outputs[0]
+        return output
+
+    async def check_health_async(self) -> None:
+        """Raises an error if engine is unhealthy."""
+        self._check_if_any_actor_is_dead()

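`_run_workers` fans a method out to the Ray actors, runs the same method on the in-process driver worker, and returns the driver's output first, which is why `execute_model` reads `all_outputs[0]`. A Ray-free sketch of that fan-out, with plain objects standing in for the remote actors (all names here are illustrative):

from typing import Any, Dict, List


def run_workers(driver: Any, workers: List[Any], method: str,
                *args: Any, **kwargs: Any) -> List[Any]:
    # In the real executor these are worker.execute_method.remote(...) futures
    # that run concurrently; here they are plain synchronous calls.
    worker_results = [getattr(w, method)(*args, **kwargs) for w in workers]
    # The driver worker runs the method in-process.
    driver_result = getattr(driver, method)(*args, **kwargs)
    # Driver output always comes first.
    return [driver_result] + worker_results


class Echo:
    def __init__(self, name: str) -> None:
        self.name = name

    def execute_model(self, batch: Dict[str, int]) -> str:
        return f"{self.name}:{batch['size']}"


outs = run_workers(Echo("driver"), [Echo("w1"), Echo("w2")],
                   "execute_model", {"size": 8})
print(outs[0])  # 'driver:8' -- only the driver's result is consumed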
+ 13 - 0
aphrodite/executor/utils.py

@@ -0,0 +1,13 @@
+def check_block_size_valid(num_gpu_blocks, block_size, max_model_len) -> None:
+    if num_gpu_blocks <= 0:
+        raise ValueError("No available memory for the cache blocks. "
+                         "Try increasing `gpu_memory_utilization` when "
+                         "initializing the engine.")
+    max_seq_len = block_size * num_gpu_blocks
+    if max_model_len > max_seq_len:
+        raise ValueError(
+            f"The model's max seq len ({max_model_len}) "
+            "is larger than the maximum number of tokens that can be "
+            f"stored in KV cache ({max_seq_len}). Try increasing "
+            "`gpu_memory_utilization` or decreasing `max_model_len` when "
+            "initializing the engine.")

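The check above compares the model's maximum sequence length with the total token capacity of the KV cache, which is simply `block_size * num_gpu_blocks`. A worked example with illustrative numbers:

# With 16-token blocks and 1,024 free GPU blocks the cache holds 16,384
# tokens, so a max_model_len of 32,768 would trigger the ValueError above.
def kv_cache_capacity(num_gpu_blocks: int, block_size: int) -> int:
    return num_gpu_blocks * block_size


print(kv_cache_capacity(1024, 16))            # 16384
print(kv_cache_capacity(1024, 16) >= 32768)   # False -> rejected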
+ 4 - 0
aphrodite/lora/layers.py

@@ -824,6 +824,10 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
         self.dtype = dtype
         self.device = device

+    @property
+    def logits_as_hidden_states(self):
+        return self.base_layer.logits_as_hidden_states
+
     @property
     def vocab_size(self):
         return self.base_layer.vocab_size

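The added property mirrors the existing `vocab_size` pass-through: the LoRA wrapper exposes read-only attributes of the wrapped sampler so callers never need to know which object they hold. A minimal sketch of the pattern with made-up class names:

class BaseSampler:
    logits_as_hidden_states = True
    vocab_size = 32000


class SamplerWrapper:
    def __init__(self, base_layer: BaseSampler) -> None:
        self.base_layer = base_layer

    @property
    def logits_as_hidden_states(self) -> bool:
        return self.base_layer.logits_as_hidden_states

    @property
    def vocab_size(self) -> int:
        return self.base_layer.vocab_size


print(SamplerWrapper(BaseSampler()).logits_as_hidden_states)  # True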
+ 8 - 3
aphrodite/modeling/hf_downloader.py

@@ -21,15 +21,19 @@ from aphrodite.common.gguf import GGUFReader
 from aphrodite.modeling.layers.quantization import (get_quantization_config,
                                                     QuantizationConfig)

+_xdg_cache_home = os.getenv('XDG_CACHE_HOME', os.path.expanduser('~/.cache'))
+_aphrodite_filelocks_path = os.path.join(_xdg_cache_home, 'aphrodite/locks/')
-class Disabledtqdm(tqdm):  # pylint: disable=inconsistent-mro
+
+class Disabledtqdm(tqdm):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs, disable=True)


 def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
-    lock_dir = cache_dir if cache_dir is not None else "/tmp"
+    lock_dir = cache_dir if cache_dir is not None else _aphrodite_filelocks_path
+    os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
     lock_file_name = model_name_or_path.replace("/", "-") + ".lock"
     lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name))
     return lock
@@ -164,7 +168,7 @@ def prepare_hf_model_weights(
                 allow_patterns = [pattern]
                 break

-        logger.info(f"Downloading model weights {allow_patterns}")
+        logger.info(f"Using model weights format {allow_patterns}")
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
         with get_lock(model_name_or_path, cache_dir):
@@ -192,6 +196,7 @@ def prepare_hf_model_weights(
             "scheduler.pt",
             "scaler.pt",
             "trainer_state.json",
+            "hidden_states.safetensors",  # exllamav2
         ]
         hf_weights_files = [
             f for f in hf_weights_files

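The change above moves the weight-download locks from `/tmp` into a per-user XDG cache directory so that concurrent processes on the same host serialize their downloads without colliding with other users. A stand-alone sketch of the same idea, assuming only the `filelock` package (the model name and exact paths are illustrative):

import os

import filelock

cache_home = os.getenv("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
lock_dir = os.path.join(cache_home, "aphrodite", "locks")
os.makedirs(lock_dir, exist_ok=True)


def model_lock(model_name_or_path: str) -> filelock.FileLock:
    # One lock file per model, slashes flattened as in the diff above.
    lock_file = model_name_or_path.replace("/", "-") + ".lock"
    return filelock.FileLock(os.path.join(lock_dir, lock_file))


with model_lock("meta-llama/Llama-2-7b-hf"):
    pass  # download or convert the weights exactly once per host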
+ 0 - 354
aphrodite/modeling/layers/attention.py

@@ -1,354 +0,0 @@
-"""Multi-head attention."""
-from typing import List, Optional
-
-import importlib
-import torch
-import torch.nn as nn
-from xformers import ops as xops
-from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
-                                         LowerTriangularMaskWithTensorBias)
-
-from aphrodite._C import ops
-from aphrodite._C import cache_ops
-from aphrodite.modeling.metadata import InputMetadata
-from aphrodite.modeling.layers.triton_kernel.prefix_prefill import (
-    context_attention_fwd)
-from aphrodite.common.utils import is_hip
-
-_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
-# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
-_PARTITION_SIZE = 512
-
-
-class PagedAttention(nn.Module):
-    """MHA/MQA/GQA layer with PagedAttention.
-
-    This class takes query, key, and value tensors as input. The input tensors
-    can either contain prompt tokens or generation tokens.
-    The class does the following:
-
-    1. Reshape and store the input key and value tensors in the KV cache.
-    2. Perform (multi-head/multi-query/grouped-query) attention using either
-        xformers or the PagedAttention custom op.
-    3. Return the output tensor.
-    """
-
-    def __init__(
-        self,
-        num_heads: int,
-        head_size: int,
-        scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        sliding_window: Optional[int] = None,
-    ) -> None:
-        super().__init__()
-        self.num_heads = num_heads
-        self.head_size = head_size
-        self.scale = float(scale)
-        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
-        self.sliding_window = sliding_window
-        if alibi_slopes is not None:
-            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
-        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
-
-        assert self.num_heads % self.num_kv_heads == 0
-        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-
-        if self.head_size not in _SUPPORTED_HEAD_SIZES:
-            raise ValueError(f"head_size ({self.head_size}) is not supported. "
-                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
-
-        self.use_ref_attention = self.check_use_ref_attention()
-
-    def check_use_ref_attention(self) -> bool:
-        if not is_hip():
-            return False
-        # For ROCm, check whether flash attention is installed or not.
-        # if not, use_ref_attention needs to be True
-        return importlib.util.find_spec("flash_attn") is None
-
-    def ref_masked_attention(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-    ) -> torch.Tensor:
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
-
-        seq_len, _, _ = query.shape
-        attn_mask = torch.triu(torch.ones(seq_len,
-                                          seq_len,
-                                          dtype=query.dtype,
-                                          device=query.device),
-                               diagonal=1)
-        attn_mask = attn_mask * torch.finfo(query.dtype).min
-
-        attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query,
-                                                 key).float()
-        attn_weights = attn_weights + attn_mask.float()
-        attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
-        out = torch.einsum("hqk,khd->qhd", attn_weights, value)
-        return out
-
-    def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        key_cache: Optional[torch.Tensor],
-        value_cache: Optional[torch.Tensor],
-        input_metadata: InputMetadata,
-        kv_quant_param: List[float] = None,
-    ) -> torch.Tensor:
-        """PagedAttention forward pass.
-
-        Args:
-            query: shape = [batch_size, seq_len, num_heads * head_size]
-            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
-                block_size, x]
-            value_cache: shape = [num_blocks, num_kv_heads, head_size,
-                block_size]
-            input_metadata: metadata for the inputs.
-        Returns:
-            shape = [batch_size, seq_len, num_heads * head_size]
-        """
-        batch_size, seq_len, hidden_size = query.shape
-        # Reshape the query, key, and value tensors.
-        query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
-        # FIXME: Remove this when all models support int8 kv cache
-        kv_quant_param = [1.0, 0.0, 1.0, 0.0
-                          ] if kv_quant_param is None else kv_quant_param
-
-        # Reshape the keys and values and store them in the cache.
-        # If key_cache and value_cache are not provided, the new key and value
-        # vectors will not be cached. This happens during the initial memory
-        # profiling run.
-        if key_cache is not None and value_cache is not None:
-            cache_ops.reshape_and_cache(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                input_metadata.slot_mapping.flatten(),
-                input_metadata.kv_cache_dtype,
-                *kv_quant_param,
-            )
-
-        if input_metadata.is_prompt:
-            # Prompt run.
-            if self.num_kv_heads != self.num_heads:
-                # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
-                # project the key and value tensors to the desired number of
-                # heads.
-                # TODO: Use MQA/GQA kernels for higher performance.
-                query = query.view(query.shape[0], self.num_kv_heads,
-                                   self.num_queries_per_kv, query.shape[-1])
-                key = key[:, :,
-                          None, :].expand(key.shape[0], self.num_kv_heads,
-                                          self.num_queries_per_kv,
-                                          key.shape[-1])
-                value = value[:, :, None, :].expand(value.shape[0],
-                                                    self.num_kv_heads,
-                                                    self.num_queries_per_kv,
-                                                    value.shape[-1])
-            # normal attention
-            if (key_cache is None or value_cache is None
-                    or input_metadata.block_tables.numel() == 0):
-                # Set attention bias if not provided. This typically happens at
-                # the very attention layer of every iteration.
-                # FIXME: This is a hack.
-                if input_metadata.attn_bias is None:
-                    if self.alibi_slopes is None:
-                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                            [seq_len] * batch_size)
-                        if self.sliding_window is not None:
-                            attn_bias = attn_bias.make_local_attention(
-                                self.sliding_window)
-                        input_metadata.attn_bias = attn_bias
-                    else:
-                        input_metadata.attn_bias = _make_alibi_bias(
-                            self.alibi_slopes, self.num_kv_heads, batch_size,
-                            seq_len, query.dtype)
-
-                if self.use_ref_attention:
-                    output = self.ref_masked_attention(
-                        query,
-                        key,
-                        value,
-                    )
-                    return output.reshape(batch_size, seq_len, hidden_size)
-
-                # TODO: Too many view operations. Let's try to reduce
-                # them in the future for code readability.
-                if self.alibi_slopes is None:
-                    query = query.unsqueeze(0)
-                    key = key.unsqueeze(0)
-                    value = value.unsqueeze(0)
-                else:
-                    query = query.unflatten(0, (batch_size, seq_len))
-                    key = key.unflatten(0, (batch_size, seq_len))
-                    value = value.unflatten(0, (batch_size, seq_len))
-
-                out = xops.memory_efficient_attention_forward(
-                    query,
-                    key,
-                    value,
-                    attn_bias=input_metadata.attn_bias,
-                    p=0.0,
-                    scale=self.scale,
-                    op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
-                    (is_hip()) else None,
-                )
-                output = out.view_as(query)
-            else:
-                # prefix-enabled attention
-                output = torch.empty_like(query)
-                context_attention_fwd(
-                    query,
-                    key,
-                    value,
-                    output,
-                    key_cache,
-                    value_cache,
-                    input_metadata.block_tables,  # [BS, max_block_per_request]
-                    input_metadata.start_loc,
-                    input_metadata.prompt_lens,
-                    input_metadata.context_lens,
-                    input_metadata.max_seq_len,
-                    getattr(self, "alibi_slopes", None),
-                )
-
-        else:
-            # Decoding run.
-            output = _paged_attention(
-                query,
-                key_cache,
-                value_cache,
-                input_metadata,
-                self.num_kv_heads,
-                self.scale,
-                self.alibi_slopes,
-                kv_quant_param,
-            )
-
-        # Reshape the output tensor.
-        return output.view(batch_size, seq_len, hidden_size)
-
-
-def _make_alibi_bias(
-    alibi_slopes: torch.Tensor,
-    num_kv_heads: int,
-    batch_size: int,
-    seq_len: int,
-    dtype: torch.dtype,
-) -> LowerTriangularMaskWithTensorBias:
-    bias = torch.arange(seq_len, dtype=dtype)
-    # NOTE: HF uses
-    #     `bias = bias[None, :].repeat(prompt_len, 1)`
-    # here. We find that both biases give the same results, but
-    # the bias below more accurately follows the original ALiBi
-    # paper.
-    bias = bias[None, :] - bias[:, None]
-
-    # When using custom attention bias, xformers requires the bias to
-    # be sliced from a tensor whose length is a multiple of 8.
-    padded_len = (seq_len + 7) // 8 * 8
-    num_heads = alibi_slopes.shape[0]
-    bias = torch.empty(
-        batch_size,
-        num_heads,
-        seq_len,
-        padded_len,
-        device=alibi_slopes.device,
-        dtype=dtype,
-    )[:, :, :, :seq_len].copy_(bias)
-    bias.mul_(alibi_slopes[:, None, None])
-    if num_heads != num_kv_heads:
-        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
-    attn_bias = LowerTriangularMaskWithTensorBias(bias)
-    return attn_bias
-
-
-def _paged_attention(
-    query: torch.Tensor,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    input_metadata: InputMetadata,
-    num_kv_heads: int,
-    scale: float,
-    alibi_slopes: Optional[torch.Tensor],
-    kv_quant_param: List[float],
-) -> torch.Tensor:
-    output = torch.empty_like(query)
-
-    block_size = value_cache.shape[3]
-    num_seqs, num_heads, head_size = query.shape
-    max_num_partitions = (
-        (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
-        _PARTITION_SIZE)
-    # NOTE: We use a simple heuristic to decide whether to use
-    # PagedAttention V1 or V2. If the number of partitions is 1, we use
-    # V1 to avoid the overhead of reduction. Also, if the number of
-    # sequences or heads is large, we use V1 since there is enough work
-    # to parallelize.
-    # TODO: Tune this heuristic.
-    # For context len > 8192, use V2 kernel to avoid shared memory shortage.
-    use_v1 = input_metadata.max_context_len <= 8192 and (
-        max_num_partitions == 1 or num_seqs * num_heads > 512)
-    if use_v1:
-        # Run PagedAttention V1.
-        ops.paged_attention_v1(
-            output,
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            scale,
-            input_metadata.block_tables,
-            input_metadata.context_lens,
-            block_size,
-            input_metadata.max_context_len,
-            alibi_slopes,
-            input_metadata.kv_cache_dtype,
-            *kv_quant_param,
-        )
-    else:
-        # Run PagedAttention V2.
-        assert _PARTITION_SIZE % block_size == 0
-        tmp_output = torch.empty(
-            size=(num_seqs, num_heads, max_num_partitions, head_size),
-            dtype=output.dtype,
-            device=output.device,
-        )
-        exp_sums = torch.empty(
-            size=(num_seqs, num_heads, max_num_partitions),
-            dtype=torch.float32,
-            device=output.device,
-        )
-        max_logits = torch.empty_like(exp_sums)
-        ops.paged_attention_v2(
-            output,
-            exp_sums,
-            max_logits,
-            tmp_output,
-            query,
-            key_cache,
-            value_cache,
-            num_kv_heads,
-            scale,
-            input_metadata.block_tables,
-            input_metadata.context_lens,
-            block_size,
-            input_metadata.max_context_len,
-            alibi_slopes,
-            input_metadata.kv_cache_dtype,
-            *kv_quant_param,
-        )
-    return output

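The reshaping conventions of the removed layer carry over unchanged to the new backends: activations arrive as `[batch_size, seq_len, hidden]`, are flattened to `[tokens, heads, head_size]` for the kernels, and are reshaped back at the end. A small torch illustration with arbitrary shapes:

import torch

batch_size, seq_len, num_heads, head_size = 2, 4, 8, 64
hidden = num_heads * head_size
query = torch.randn(batch_size, seq_len, hidden)

q = query.view(-1, num_heads, head_size)         # [8, 8, 64]
restored = q.view(batch_size, seq_len, hidden)   # back to [2, 4, 512]
print(q.shape, restored.shape)
print(torch.equal(query, restored))  # True -- views share storage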
+ 93 - 0
aphrodite/modeling/layers/attention/__init__.py

@@ -0,0 +1,93 @@
+"""Attention layer."""
+from functools import lru_cache
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+from loguru import logger
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.common.utils import is_hip
+
+
+class Attention(nn.Module):
+    """Attention layer.
+
+    This class takes query, key, and value tensors as input. The input tensors
+    can either contain prompt tokens or generation tokens.
+    The class does the following:
+
+    1. Store the input key and value tensors in the KV cache.
+    2. Perform (multi-head/multi-query/grouped-query) attention.
+    3. Return the output tensor.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        if _use_flash_attn():
+            from aphrodite.modeling.layers.attention.backends.flash_attn import FlashAttentionBackend  # noqa: E501
+
+            self.backend = FlashAttentionBackend(
+                num_heads,
+                head_size,
+                scale,
+                num_kv_heads,
+                alibi_slopes,
+                sliding_window,
+            )
+        else:
+            from aphrodite.modeling.layers.attention.backends.xformers import XFormersBackend  # noqa: E501
+
+            self.backend = XFormersBackend(
+                num_heads,
+                head_size,
+                scale,
+                num_kv_heads,
+                alibi_slopes,
+                sliding_window,
+            )
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: Optional[torch.Tensor],
+        value_cache: Optional[torch.Tensor],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        return self.backend.forward(query, key, value, key_cache, value_cache,
+                                    input_metadata)
+
+
+@lru_cache(maxsize=1)
+def _use_flash_attn() -> bool:
+    try:
+        import flash_attn  # noqa: F401
+    except ImportError:
+        logger.info("flash_attn is not found. Using xformers backend.")
+        return False
+
+    if is_hip():
+        # AMD GPUs.
+        return False
+    if torch.cuda.get_device_capability()[0] < 8:
+        logger.info("flash_attn is not supported on Turing or older GPUs. "
+                    "Using xformers backend.")
+        return False
+    if torch.get_default_dtype() not in (torch.float16, torch.bfloat16):
+        logger.info(
+            "flash_attn only supports torch.float16 or torch.bfloat16. "
+            "Using xformers backend.")
+        return False
+
+    logger.info("Using Flash Attention backend.")
+    return True

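`_use_flash_attn()` probes the environment once (the `lru_cache` memoizes the answer) and falls back to xformers whenever flash-attn cannot be used. The same decision tree distilled into a pure function for illustration; the inputs are assumptions passed in rather than live probes:

def pick_backend(has_flash_attn: bool, is_hip: bool,
                 compute_capability_major: int, dtype: str) -> str:
    if not has_flash_attn:
        return "xformers"        # flash_attn not installed
    if is_hip:
        return "xformers"        # AMD GPUs
    if compute_capability_major < 8:
        return "xformers"        # Turing (sm_75) and older
    if dtype not in ("float16", "bfloat16"):
        return "xformers"        # unsupported default dtype
    return "flash_attn"


print(pick_backend(True, False, 9, "bfloat16"))  # flash_attn (e.g. H100)
print(pick_backend(True, False, 7, "float16"))   # xformers   (e.g. T4)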
+ 0 - 0
aphrodite/modeling/layers/attention/backends/__init__.py


+ 121 - 0
aphrodite/modeling/layers/attention/backends/flash_attn.py

@@ -0,0 +1,121 @@
+"""Attention layer with Flash and PagedAttention."""
+from typing import List, Optional
+
+from flash_attn import flash_attn_func
+import torch
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.attention.ops.paged_attn import (
+    PagedAttentionImpl)
+
+
+class FlashAttentionBackend:
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.sliding_window = sliding_window
+        if alibi_slopes is not None:
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+        self.alibi_slopes = alibi_slopes
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        supported_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            raise ValueError(
+                f"Head size {head_size} is not supported by PagedAttention. "
+                f"Supported head sizes are: {supported_head_sizes}.")
+
+        self.sliding_window = ((self.sliding_window, self.sliding_window) if
+                               self.sliding_window is not None else (-1, -1))
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: Optional[torch.Tensor],
+        value_cache: Optional[torch.Tensor],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention and PagedAttention.
+
+        Args:
+            query: shape = [batch_size, seq_len, num_heads * head_size]
+            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
+            input_metadata: metadata for the inputs.
+        Returns:
+            shape = [batch_size, seq_len, num_heads * head_size]
+        """
+        batch_size, seq_len, hidden_size = query.shape
+        # Reshape the query, key, and value tensors.
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        # Reshape the keys and values and store them in the cache.
+        # If key_cache and value_cache are not provided, the new key and value
+        # vectors will not be cached. This happens during the initial memory
+        # profiling run.
+        if key_cache is not None and value_cache is not None:
+            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
+                                                 value_cache, input_metadata)
+
+        if input_metadata.is_prompt:
+            # Prompt run.
+            if (key_cache is None or value_cache is None
+                    or input_metadata.block_tables.numel() == 0):
+                # normal attention
+                query = query.unflatten(0, (batch_size, seq_len))
+                key = key.unflatten(0, (batch_size, seq_len))
+                value = value.unflatten(0, (batch_size, seq_len))
+                output = flash_attn_func(
+                    query,
+                    key,
+                    value,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    window_size=self.sliding_window,
+                    alibi_slopes=self.alibi_slopes,
+                )
+            else:
+                # prefix-enabled attention
+                output = PagedAttentionImpl.forward_prefix(
+                    query,
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    input_metadata,
+                    self.alibi_slopes,
+                )
+        else:
+            # Decoding run.
+            output = PagedAttentionImpl.forward_decode(
+                query,
+                key_cache,
+                value_cache,
+                input_metadata,
+                self.num_kv_heads,
+                self.scale,
+                self.alibi_slopes,
+            )
+
+        # Reshape the output tensor.
+        return output.view(batch_size, seq_len, hidden_size)

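The constructor above folds the optional `sliding_window` into the two-element tuple later passed as `window_size` to `flash_attn_func`: a symmetric window when set, `(-1, -1)` (unbounded) otherwise. A tiny sketch of that normalization:

from typing import Optional, Tuple


def normalize_window(sliding_window: Optional[int]) -> Tuple[int, int]:
    return ((sliding_window, sliding_window)
            if sliding_window is not None else (-1, -1))


print(normalize_window(4096))  # (4096, 4096)
print(normalize_window(None))  # (-1, -1)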
+ 255 - 0
aphrodite/modeling/layers/attention/backends/xformers.py

@@ -0,0 +1,255 @@
+"""Attention layer with xFormers and PagedAttention."""
+import importlib
+from typing import List, Optional
+
+import torch
+from xformers import ops as xops
+from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
+                                         LowerTriangularMaskWithTensorBias)
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.attention.ops.paged_attn import (
+    PagedAttentionImpl)
+from aphrodite.common.utils import is_hip
+
+
+class XFormersBackend:
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: Optional[int] = None,
+        alibi_slopes: Optional[List[float]] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
+        self.sliding_window = sliding_window
+        if alibi_slopes is not None:
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+        self.alibi_slopes = alibi_slopes
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        supported_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
+        if head_size not in supported_head_sizes:
+            raise ValueError(
+                f"Head size {head_size} is not supported by PagedAttention. "
+                f"Supported head sizes are: {supported_head_sizes}.")
+
+        self.use_ref_attention = _check_use_ref_attention()
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: Optional[torch.Tensor],
+        value_cache: Optional[torch.Tensor],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        """Forward pass with xFormers and PagedAttention.
+
+        Args:
+            query: shape = [batch_size, seq_len, num_heads * head_size]
+            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
+                block_size, x]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
+            input_metadata: metadata for the inputs.
+        Returns:
+            shape = [batch_size, seq_len, num_heads * head_size]
+        """
+        batch_size, seq_len, hidden_size = query.shape
+        # Reshape the query, key, and value tensors.
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        # Reshape the keys and values and store them in the cache.
+        # If key_cache and value_cache are not provided, the new key and value
+        # vectors will not be cached. This happens during the initial memory
+        # profiling run.
+        if key_cache is not None and value_cache is not None:
+            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
+                                                 value_cache, input_metadata)
+
+        if input_metadata.is_prompt:
+            # Prompt run.
+            if (key_cache is None or value_cache is None
+                    or input_metadata.block_tables.numel() == 0):
+                # normal attention
+                if self.num_kv_heads != self.num_heads:
+                    # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
+                    # project the key and value tensors to the desired number of
+                    # heads.
+                    # TODO: Use MQA/GQA kernels for higher performance.
+                    query = query.view(query.shape[0], self.num_kv_heads,
+                                       self.num_queries_per_kv,
+                                       query.shape[-1])
+                    key = key[:, :,
+                              None, :].expand(key.shape[0], self.num_kv_heads,
+                                              self.num_queries_per_kv,
+                                              key.shape[-1])
+                    value = value[:, :,
+                                  None, :].expand(value.shape[0],
+                                                  self.num_kv_heads,
+                                                  self.num_queries_per_kv,
+                                                  value.shape[-1])
+
+                # Set attention bias if not provided. This typically happens
+                # at the very first attention layer of every iteration.
+                # FIXME: This is a hack.
+                if input_metadata.attn_bias is None:
+                    if self.alibi_slopes is None:
+                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
+                            [seq_len] * batch_size)
+                        if self.sliding_window is not None:
+                            attn_bias = attn_bias.make_local_attention(
+                                self.sliding_window)
+                        input_metadata.attn_bias = attn_bias
+                    else:
+                        input_metadata.attn_bias = _make_alibi_bias(
+                            self.alibi_slopes, self.num_kv_heads, batch_size,
+                            seq_len, query.dtype)
+
+                if self.use_ref_attention:
+                    output = _ref_masked_attention(
+                        query,
+                        key,
+                        value,
+                        self.num_heads,
+                        self.num_kv_heads,
+                        self.head_size,
+                        self.scale,
+                    )
+                    # Using view got RuntimeError: view size is not compatible
+                    # with input tensor's size and stride (at least one
+                    # dimension spans across two contiguous subspaces).
+                    # Use reshape instead.
+                    return output.reshape(batch_size, seq_len, hidden_size)
+
+                # TODO: Too many view operations. Let's try to reduce
+                # them in the future for code readability.
+                if self.alibi_slopes is None:
+                    query = query.unsqueeze(0)
+                    key = key.unsqueeze(0)
+                    value = value.unsqueeze(0)
+                else:
+                    query = query.unflatten(0, (batch_size, seq_len))
+                    key = key.unflatten(0, (batch_size, seq_len))
+                    value = value.unflatten(0, (batch_size, seq_len))
+
+                out = xops.memory_efficient_attention_forward(
+                    query,
+                    key,
+                    value,
+                    attn_bias=input_metadata.attn_bias,
+                    p=0.0,
+                    scale=self.scale,
+                    op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
+                    (is_hip()) else None,
+                )
+                output = out.view_as(query)
+
+            else:
+                # prefix-enabled attention
+                output = PagedAttentionImpl.forward_prefix(
+                    query,
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    input_metadata,
+                    self.alibi_slopes,
+                )
+        else:
+            # Decoding run.
+            output = PagedAttentionImpl.forward_decode(
+                query,
+                key_cache,
+                value_cache,
+                input_metadata,
+                self.num_kv_heads,
+                self.scale,
+                self.alibi_slopes,
+            )
+
+        # Reshape the output tensor.
+        return output.view(batch_size, seq_len, hidden_size)
+
+
+def _make_alibi_bias(
+    alibi_slopes: torch.Tensor,
+    num_kv_heads: int,
+    batch_size: int,
+    seq_len: int,
+    dtype: torch.dtype,
+) -> LowerTriangularMaskWithTensorBias:
+    bias = torch.arange(seq_len, dtype=dtype)
+    # NOTE: HF uses
+    #     `bias = bias[None, :].repeat(prompt_len, 1)`
+    # here. We find that both biases give the same results, but
+    # the bias below more accurately follows the original ALiBi
+    # paper.
+    bias = bias[None, :] - bias[:, None]
+
+    # When using custom attention bias, xformers requires the bias to
+    # be sliced from a tensor whose length is a multiple of 8.
+    padded_len = (seq_len + 7) // 8 * 8
+    num_heads = alibi_slopes.shape[0]
+    bias = torch.empty(
+        batch_size,
+        num_heads,
+        seq_len,
+        padded_len,
+        device=alibi_slopes.device,
+        dtype=dtype,
+    )[:, :, :, :seq_len].copy_(bias)
+    bias.mul_(alibi_slopes[:, None, None])
+    if num_heads != num_kv_heads:
+        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
+    attn_bias = LowerTriangularMaskWithTensorBias(bias)
+    return attn_bias
+
+
+def _check_use_ref_attention() -> bool:
+    if not is_hip():
+        return False
+    # For ROCm, check whether flash attention is installed or not.
+    # if not, use_ref_attention needs to be True
+    return importlib.util.find_spec("flash_attn") is None
+
+
+def _ref_masked_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    num_heads: int,
+    num_kv_heads: int,
+    head_size: int,
+    scale: float,
+) -> torch.Tensor:
+    query = query.view(-1, num_heads, head_size)
+    key = key.view(-1, num_kv_heads, head_size)
+    value = value.view(-1, num_kv_heads, head_size)
+
+    seq_len, _, _ = query.shape
+    attn_mask = torch.triu(torch.ones(seq_len,
+                                      seq_len,
+                                      dtype=query.dtype,
+                                      device=query.device),
+                           diagonal=1)
+    attn_mask = attn_mask * torch.finfo(query.dtype).min
+
+    attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
+    attn_weights = attn_weights + attn_mask.float()
+    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
+    out = torch.einsum("hqk,khd->qhd", attn_weights, value)
+    return out

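Because the xformers kernel used here only understands plain MHA, the backend reshapes queries into `[tokens, num_kv_heads, num_queries_per_kv, head_size]` and broadcasts keys and values to match. A small torch sketch of just that reshaping (shapes are arbitrary; `expand` adds no copy):

import torch

tokens, num_heads, num_kv_heads, head_size = 6, 8, 2, 16
num_queries_per_kv = num_heads // num_kv_heads  # 4

query = torch.randn(tokens, num_heads, head_size)
key = torch.randn(tokens, num_kv_heads, head_size)

grouped_q = query.view(tokens, num_kv_heads, num_queries_per_kv, head_size)
expanded_k = key[:, :, None, :].expand(tokens, num_kv_heads,
                                       num_queries_per_kv, head_size)
print(grouped_q.shape, expanded_k.shape)  # both torch.Size([6, 2, 4, 16])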
+ 0 - 0
aphrodite/modeling/layers/attention/ops/__init__.py


+ 138 - 0
aphrodite/modeling/layers/attention/ops/paged_attn.py

@@ -0,0 +1,138 @@
+from typing import List, Optional
+
+import torch
+
+from aphrodite._C import cache_ops
+from aphrodite._C import ops
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.attention.ops.prefix_prefill import (
+    context_attention_fwd)
+
+# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
+_PARTITION_SIZE = 512
+
+
+class PagedAttentionImpl:
+
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [64, 80, 96, 112, 128, 256]
+
+    @staticmethod
+    def reshape_and_cache(
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> None:
+        cache_ops.reshape_and_cache(
+            key,
+            value,
+            key_cache,
+            value_cache,
+            input_metadata.slot_mapping.flatten(),
+            input_metadata.kv_cache_dtype,
+        )
+
+    @staticmethod
+    def forward_decode(
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        input_metadata: InputMetadata,
+        num_kv_heads: int,
+        scale: float,
+        alibi_slopes: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        output = torch.empty_like(query)
+
+        block_size = value_cache.shape[3]
+        num_seqs, num_heads, head_size = query.shape
+        max_num_partitions = (
+            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
+            _PARTITION_SIZE)
+        # NOTE: We use a simple heuristic to decide whether to use
+        # PagedAttention V1 or V2. If the number of partitions is 1, we use
+        # V1 to avoid the overhead of reduction. Also, if the number of
+        # sequences or heads is large, we use V1 since there is enough work
+        # to parallelize.
+        # TODO: Tune this heuristic.
+        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
+        use_v1 = input_metadata.max_context_len <= 8192 and (
+            max_num_partitions == 1 or num_seqs * num_heads > 512)
+        if use_v1:
+            # Run PagedAttention V1.
+            ops.paged_attention_v1(
+                output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                scale,
+                input_metadata.block_tables,
+                input_metadata.context_lens,
+                block_size,
+                input_metadata.max_context_len,
+                alibi_slopes,
+                input_metadata.kv_cache_dtype,
+            )
+        else:
+            # Run PagedAttention V2.
+            assert _PARTITION_SIZE % block_size == 0
+            tmp_output = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions, head_size),
+                dtype=output.dtype,
+                device=output.device,
+            )
+            exp_sums = torch.empty(
+                size=(num_seqs, num_heads, max_num_partitions),
+                dtype=torch.float32,
+                device=output.device,
+            )
+            max_logits = torch.empty_like(exp_sums)
+            ops.paged_attention_v2(
+                output,
+                exp_sums,
+                max_logits,
+                tmp_output,
+                query,
+                key_cache,
+                value_cache,
+                num_kv_heads,
+                scale,
+                input_metadata.block_tables,
+                input_metadata.context_lens,
+                block_size,
+                input_metadata.max_context_len,
+                alibi_slopes,
+                input_metadata.kv_cache_dtype,
+            )
+        return output
+
+    @staticmethod
+    def forward_prefix(
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        input_metadata: InputMetadata,
+        alibi_slopes: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        output = torch.empty_like(query)
+        context_attention_fwd(
+            query,
+            key,
+            value,
+            output,
+            key_cache,
+            value_cache,
+            input_metadata.block_tables,  # [BS, max_block_per_request]
+            input_metadata.start_loc,
+            input_metadata.prompt_lens,
+            input_metadata.context_lens,
+            input_metadata.max_seq_len,
+            alibi_slopes,
+        )
+        return output
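The V1/V2 selection in forward_decode above is driven purely by max_context_len, num_seqs, and num_heads. A small sketch that restates the heuristic with made-up numbers (illustration only, not part of the commit):

_PARTITION_SIZE = 512  # matches the constant at the top of this file

def uses_v1(max_context_len: int, num_seqs: int, num_heads: int) -> bool:
    # Same expression as the use_v1 check in forward_decode.
    max_num_partitions = ((max_context_len + _PARTITION_SIZE - 1) //
                          _PARTITION_SIZE)
    return max_context_len <= 8192 and (max_num_partitions == 1
                                        or num_seqs * num_heads > 512)

print(uses_v1(max_context_len=400, num_seqs=1, num_heads=32))     # True: one partition
print(uses_v1(max_context_len=4096, num_seqs=64, num_heads=32))   # True: enough parallel work
print(uses_v1(max_context_len=4096, num_seqs=2, num_heads=32))    # False: V2 reduction pays off
print(uses_v1(max_context_len=16384, num_seqs=64, num_heads=32))  # False: >8192 always uses V2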

+ 28 - 12
aphrodite/modeling/layers/triton_kernel/prefix_prefill.py → aphrodite/modeling/layers/attention/ops/prefix_prefill.py

@@ -45,6 +45,7 @@ if triton.__version__ >= "2.1.0":
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
+        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
@@ -53,6 +54,8 @@ if triton.__version__ >= "2.1.0":
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

+        cur_kv_head = cur_head // num_queries_per_kv
+
        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
@@ -85,13 +88,14 @@ if triton.__version__ >= "2.1.0":
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_head * stride_k_cache_h +
+                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
-                bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
+                bn[:, None] * stride_v_cache_bs +
+                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
@@ -131,9 +135,9 @@ if triton.__version__ >= "2.1.0":
            l_i = l_i_new
            m_i = m_i_new

-        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
-        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v
@@ -232,6 +236,7 @@ if triton.__version__ >= "2.1.0":
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
+        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
@@ -240,6 +245,8 @@ if triton.__version__ >= "2.1.0":
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

+        cur_kv_head = cur_head // num_queries_per_kv
+
        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
@@ -272,13 +279,14 @@ if triton.__version__ >= "2.1.0":
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_head * stride_k_cache_h +
+                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
-                bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
+                bn[:, None] * stride_v_cache_bs +
+                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
@@ -317,9 +325,9 @@ if triton.__version__ >= "2.1.0":
            l_i = l_i_new
            m_i = m_i_new

-        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
-        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v
@@ -420,6 +428,7 @@ if triton.__version__ >= "2.1.0":
        stride_v_cache_h,
        stride_v_cache_d,
        stride_v_cache_bl,
+        num_queries_per_kv: int,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
@@ -429,6 +438,8 @@ if triton.__version__ >= "2.1.0":
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

+        cur_kv_head = cur_head // num_queries_per_kv
+
        # cur_batch_seq_len: the length of prompts
        # cur_batch_ctx_len: the length of prefix
        # cur_batch_in_all_start_index: the start id of the dim=0
@@ -468,13 +479,14 @@ if triton.__version__ >= "2.1.0":
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
-                     cur_head * stride_k_cache_h +
+                     cur_kv_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
-                bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
+                bn[:, None] * stride_v_cache_bs +
+                cur_kv_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
@@ -522,9 +534,9 @@ if triton.__version__ >= "2.1.0":
            l_i = l_i_new
            m_i = m_i_new

-        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+        off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
                 offs_d[:, None] * stride_kd)
-        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v
@@ -618,6 +630,7 @@ if triton.__version__ >= "2.1.0":
                              b_ctx_len,
                              max_input_len,
                              alibi_slopes=None):
+
        cap = torch.cuda.get_device_capability()
        BLOCK = 128 if cap[0] >= 8 else 64
        # shape constraints
@@ -627,6 +640,7 @@ if triton.__version__ >= "2.1.0":

        sm_scale = 1.0 / (Lq**0.5)
        batch, head = b_seq_len.shape[0], q.shape[1]
+        num_queries_per_kv = q.shape[1] // k.shape[1]

        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,

@@ -673,6 +687,7 @@ if triton.__version__ >= "2.1.0":
                v_cache.stride(2),
                v_cache.stride(
                    3),  #[num_blocks, num_kv_heads, head_size, block_size]
+                num_queries_per_kv=num_queries_per_kv,
                BLOCK_M=BLOCK,
                BLOCK_DMODEL=Lk,
                BLOCK_N=BLOCK,
@@ -720,6 +735,7 @@ if triton.__version__ >= "2.1.0":
            v_cache.stride(2),
            v_cache.stride(
                3),  #[num_blocks, num_kv_heads, head_size, block_size]
+            num_queries_per_kv=num_queries_per_kv,
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_N=BLOCK,
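The num_queries_per_kv argument threaded through the kernels above adds grouped-query attention support: each KV-cache head is shared by a group of query heads, so cache offsets are computed from cur_kv_head = cur_head // num_queries_per_kv rather than cur_head. A tiny illustrative sketch of that mapping (values are hypothetical):

num_heads, num_kv_heads = 32, 8          # e.g. a GQA model with 4:1 grouping
num_queries_per_kv = num_heads // num_kv_heads
# Query heads 0-3 read KV head 0, heads 4-7 read KV head 1, and so on.
assert [h // num_queries_per_kv for h in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]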

+ 8 - 0
aphrodite/modeling/layers/fused_moe/__init__.py

@@ -0,0 +1,8 @@
+from aphrodite.modeling.layers.fused_moe.fused_moe import (fused_moe,
+                                                           get_config_file_name
+                                                           )
+
+__all__ = [
+    "fused_moe",
+    "get_config_file_name",
+]
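The JSON files added below are tuned Triton launch parameters, keyed by device name and MoE shape in the filename (E = expert count, N = intermediate size per shard) and by benchmarked token count inside each file. A hypothetical sketch of how such a file could be located and queried by that naming scheme; the real lookup goes through get_config_file_name and fused_moe, whose exact signatures are not shown in this diff:

import json
import os
import torch

def load_moe_config(E: int, N: int, config_dir: str) -> dict:
    # Filename pattern mirrors the config files added in this commit.
    device_name = torch.cuda.get_device_name().replace(" ", "_")
    path = os.path.join(config_dir, f"E={E},N={N},device_name={device_name}.json")
    if not os.path.exists(path):
        return {}  # fall back to kernel defaults when no tuned config exists
    with open(path) as f:
        # Keys are token counts, values are Triton tile parameters.
        return {int(num_tokens): cfg for num_tokens, cfg in json.load(f).items()}

configs = load_moe_config(
    E=8, N=7168, config_dir="aphrodite/modeling/layers/fused_moe/configs")
if configs:
    # Use the config benchmarked at the token count closest to the actual batch.
    num_tokens = 300
    best = configs[min(configs, key=lambda k: abs(k - num_tokens))]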

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 146 - 0
aphrodite/modeling/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "96": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}

+ 9 - 0
aphrodite/modeling/layers/fused_moe/configs/README

@@ -0,0 +1,9 @@
+This directory contains tuned configurations for different settings of the fused_moe kernel.
+For different settings of
+- E (number of experts)
+- N (intermediate size)
+- device_name (torch.cuda.get_device_name())
+the JSON file contains a mapping from M (batch size) to the chosen configuration.
+
+Mixtral has intermediate size N = 14336, i.e. for TP2 we have
+N = 7168 and for TP4 we have N = 3584.
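
The file names follow the convention mirrored by get_config_file_name() further down in this diff: E, N, and the CUDA device name with spaces replaced by underscores. A minimal sketch, assuming Mixtral's E = 8 experts under TP2 (so N = 7168) on an A100-80GB; the sizes are illustrative:

import torch

E = 8            # number of experts (Mixtral)
N = 14336 // 2   # intermediate size under TP2, per the note above

# Spaces in the device name become underscores, so on an A100-80GB this
# resolves to "E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json" -- one of
# the tuned files added in this commit.
device_name = torch.cuda.get_device_name().replace(" ", "_")
json_file_name = f"E={E},N={N},device_name={device_name}.json"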

+ 120 - 58
aphrodite/modeling/layers/triton_kernel/fused_moe.py → aphrodite/modeling/layers/fused_moe/fused_moe.py

@@ -1,7 +1,13 @@
 """Fused MoE kernel."""
 """Fused MoE kernel."""
+import functools
+import json
+import os
+from typing import Any, Dict, Optional, Tuple
+
 import torch
 import torch
 import triton
 import triton
 import triton.language as tl
 import triton.language as tl
+from loguru import logger
 
 
 from aphrodite._C import ops
 from aphrodite._C import ops
 from aphrodite.common.utils import is_hip
 from aphrodite.common.utils import is_hip
@@ -22,9 +28,10 @@ def fused_moe_kernel(
     K,
     K,
     EM,
     EM,
     num_valid_tokens,
     num_valid_tokens,
-    # The stride variables represent how much to increase the ptr by when moving
-    # by 1 element in a particular dimension. E.g. `stride_am` is how much to
-    # increase `a_ptr` by to get the element one row down (A has M rows).
+    # The stride variables represent how much to increase the ptr by when
+    # moving by 1 element in a particular dimension. E.g. `stride_am` is
+    # how much to increase `a_ptr` by to get the element one row down
+    # (A has M rows).
     stride_am,
     stride_am,
     stride_ak,
     stride_ak,
     stride_be,
     stride_be,
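
The strides passed here are plain element strides of the underlying tensors; a quick sketch of what that means for a contiguous activation matrix (sizes are illustrative):

import torch

M, K = 8, 4096
A = torch.empty(M, K)              # contiguous, row-major
stride_am, stride_ak = A.stride()  # (K, 1): K elements to move one row down,
                                   # 1 element to move one column right
assert (stride_am, stride_ak) == (K, 1)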
@@ -44,22 +51,23 @@ def fused_moe_kernel(
     """
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
     Implements the fused computation for a Mixture of Experts (MOE) using
     token and expert matrices.
     token and expert matrices.
+
     Key Parameters:
     Key Parameters:
-    - A: The input tensor representing tokens with shape (*, K), where '*'
-        can be any shape representing batches and K is the feature dimension
-        of each token.
-    - B: The stacked MOE weight tensor with shape (E, N, K), where E is the
-        number of experts, K is the input feature dimension, and N is the
-        output feature dimension.
+    - A: The input tensor representing tokens with shape (*, K), where '*' can
+        be any shape representing batches and K is the feature dimension of
+        each token.
+    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
+        the number of experts, K is the input feature dimension, and N is
+        the output feature dimension.
     - C: The output cache tensor with shape (M, topk, N), where M is the
     - C: The output cache tensor with shape (M, topk, N), where M is the
         total number of tokens post padding, topk is the number of times
         total number of tokens post padding, topk is the number of times
         each token is repeated, and N is the output feature dimension.
         each token is repeated, and N is the output feature dimension.
     - sorted_token_ids: A tensor containing the sorted indices of tokens,
     - sorted_token_ids: A tensor containing the sorted indices of tokens,
-        repeated topk times and arranged by the expert index they are assigned
-        to.
-    - expert_ids: A tensor containing the indices of the expert for each block.
-        It determines which expert matrix from B should be used for each block
-        in A.
+        repeated topk times and arranged by the expert index they are
+        assigned to.
+    - expert_ids: A tensor containing the indices of the expert for each
+        block. It determines which expert matrix from B should be used for
+        each block in A.
     This kernel performs the multiplication of a token by its corresponding
     This kernel performs the multiplication of a token by its corresponding
     expert matrix as determined by `expert_ids`. The sorting of
     expert matrix as determined by `expert_ids`. The sorting of
     `sorted_token_ids` by expert index and padding ensures divisibility by
     `sorted_token_ids` by expert index and padding ensures divisibility by
@@ -142,39 +150,43 @@ def fused_moe_kernel(
 
 
 def moe_align_block_size(
 def moe_align_block_size(
         topk_ids: torch.Tensor, block_size: int,
         topk_ids: torch.Tensor, block_size: int,
-        num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor):
+        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     """
     Aligns the token distribution across experts to be compatible with block
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.
     size for matrix multiplication.
+
     Parameters:
     Parameters:
-    - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k
-        expert indices for each token.
+    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
+        top-k expert indices for each token.
     - block_size: The block size used in block matrix multiplication.
     - block_size: The block size used in block matrix multiplication.
     - num_experts: The total number of experts.
     - num_experts: The total number of experts.
+
     Returns:
     Returns:
     - sorted_token_ids: A tensor containing the sorted token indices according
     - sorted_token_ids: A tensor containing the sorted token indices according
         to their allocated expert.
         to their allocated expert.
     - expert_ids: A tensor indicating the assigned expert index for each block.
     - expert_ids: A tensor indicating the assigned expert index for each block.
     - num_tokens_post_padded: The total number of tokens after padding,
     - num_tokens_post_padded: The total number of tokens after padding,
         ensuring divisibility by block_size.
         ensuring divisibility by block_size.
+
     This function pads the number of tokens that each expert needs to process
     This function pads the number of tokens that each expert needs to process
-    so that it is divisible by block_size. Padding ensures that during block
-    matrix multiplication, the dimensions align correctly.
+    so that it is divisible by block_size.
+    Padding ensures that during block matrix multiplication, the dimensions
+    align correctly.
+
     Example:
     Example:
     Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
     Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
     block_size = 4, and num_experts = 4:
     block_size = 4, and num_experts = 4:
-    - We initially have 12 tokens (after repeating 'top_k' times) and 4
-        experts, with each expert needing to process 3 tokens.
+    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
+        with each expert needing to process 3 tokens.
     - As block_size is 4, we pad 1 token for each expert.
     - As block_size is 4, we pad 1 token for each expert.
     - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
     - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
     - Then append padding tokens [12, 12, 12, 12] for each block.
     - Then append padding tokens [12, 12, 12, 12] for each block.
-    - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4,
-                                                          10, 12, 1, 7, 11,
-                                                          12, 2, 5, 8, 12].
-        Tokens 12 are non-existent (padding) and are ignored in the subsequent
-        matrix multiplication.
-    - The padding ensures that the total number of tokens is now divisible by
-        block_size for proper block matrix operations.
+    - After sorting by expert index, we obtain token_ids
+        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
+        Tokens 12 are non-existent (padding) and are ignored in
+        the subsequent matrix multiplication.
+    - The padding ensures that the total number of tokens is now divisible
+        by block_size for proper block matrix operations.
     """
     """
     sorted_ids = torch.empty(
     sorted_ids = torch.empty(
         (topk_ids.numel() + num_experts * (block_size - 1), ),
         (topk_ids.numel() + num_experts * (block_size - 1), ),
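
The worked example in the docstring can be reproduced with a small standalone sketch (pure PyTorch, no Triton, and handling only experts that actually appear in topk_ids); the helper name is illustrative and is not the real alignment kernel:

import torch

def align_example(topk_ids: torch.Tensor, block_size: int):
    flat = topk_ids.flatten()
    pad_id = flat.numel()              # non-existent token id used for padding
    sorted_ids, expert_ids = [], []
    for expert in torch.unique(flat).tolist():
        idx = torch.nonzero(flat == expert, as_tuple=True)[0].tolist()
        idx += [pad_id] * ((-len(idx)) % block_size)   # pad to a block multiple
        sorted_ids.extend(idx)
        expert_ids.extend([expert] * (len(idx) // block_size))
    return sorted_ids, expert_ids, len(sorted_ids)

topk_ids = torch.tensor([[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]])
sorted_token_ids, expert_ids, num_tokens_post_padded = align_example(topk_ids, 4)
# sorted_token_ids -> [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]
# expert_ids       -> [1, 2, 3, 4]   (one expert per block of four)
# num_tokens_post_padded -> 16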
@@ -197,12 +209,11 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
                             sorted_token_ids: torch.Tensor,
                             sorted_token_ids: torch.Tensor,
                             expert_ids: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
-                            mul_routed_weight: bool, top_k: int, config: dict):
-
+                            mul_routed_weight: bool, top_k: int,
+                            config: Dict[str, Any]) -> None:
     assert topk_weights.stride(1) == 1
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
     assert sorted_token_ids.stride(0) == 1
 
 
-    # ruff: noqa: E731
     grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
     grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
         'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
         'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
 
 
@@ -232,6 +243,40 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
     )
     )
 
 
 
 
+def get_config_file_name(E: int, N: int) -> str:
+    device_name = torch.cuda.get_device_name().replace(" ", "_")
+    return f"E={E},N={N},device_name={device_name}.json"
+
+
+@functools.lru_cache
+def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
+    """
+    Return optimized configurations for the fused MoE kernel.
+
+    The return value will be a dictionary that maps an irregular grid of
+    batch sizes to configurations of the fused_moe kernel. To evaluate the
+    kernel on a given batch size bs, the closest batch size in the grid should
+    be picked and the associated configuration chosen to invoke the kernel.
+    """
+
+    # First look up if an optimized configuration is available in the configs
+    # directory
+    json_file_name = get_config_file_name(E, N)
+
+    config_file_path = os.path.join(
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
+    if os.path.exists(config_file_path):
+        with open(config_file_path) as f:
+            logger.info(
+                f"Using configuration from {config_file_path} for MoE layer.")
+            # If a configuration has been found, return it
+            return {int(key): val for key, val in json.load(f).items()}
+
+    # If no optimized configuration is available, we will use the default
+    # configuration
+    return None
+
+
 def fused_moe(
 def fused_moe(
     hidden_states: torch.Tensor,
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w1: torch.Tensor,
@@ -240,6 +285,7 @@ def fused_moe(
     topk: int,
     topk: int,
     renormalize: bool,
     renormalize: bool,
     inplace: bool = False,
     inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
 ) -> torch.Tensor:
 ) -> torch.Tensor:
     """
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -249,25 +295,29 @@ def fused_moe(
     - hidden_states (torch.Tensor): The input tensor to the MoE layer.
     - hidden_states (torch.Tensor): The input tensor to the MoE layer.
     - w1 (torch.Tensor): The first set of expert weights.
     - w1 (torch.Tensor): The first set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
     - w2 (torch.Tensor): The second set of expert weights.
-    - gating_output (torch.Tensor): The output of the gating operation (before
-        softmax).
+    - gating_output (torch.Tensor): The output of the gating operation
+        (before softmax).
     - topk (int): The number of top-k experts to select.
     - topk (int): The number of top-k experts to select.
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
     - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
-    - inplace (bool): If True, perform the operation in-place. Defaults to
-        False.
+    - inplace (bool): If True, perform the operation in-place.
+        Defaults to False.
+    - override_config (Optional[Dict[str, Any]]): Optional override
+        for the kernel configuration.
 
 
     Returns:
     Returns:
     - torch.Tensor: The output tensor after applying the MoE layer.
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
     """
     # Check constraints.
     # Check constraints.
     assert hidden_states.shape[0] == gating_output.shape[0], (
     assert hidden_states.shape[0] == gating_output.shape[0], (
-        'Number of tokens mismatch')
-    assert hidden_states.shape[1] == w1.shape[2], 'Hidden size mismatch'
-    assert gating_output.shape[1] == w1.shape[0], 'Number of experts mismatch'
-    assert hidden_states.is_contiguous(), 'Hidden_states must be contiguous'
-    assert w1.is_contiguous(), 'Expert weights1 must be contiguous'
-    assert w2.is_contiguous(), 'Expert weights2 must be contiguous'
-    assert hidden_states.dtype in [torch.float16, torch.bfloat16]
+        "Number of tokens mismatch")
+    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert hidden_states.dtype in [
+        torch.float32, torch.float16, torch.bfloat16
+    ]
     M, _ = hidden_states.shape
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
     E, N, _ = w1.shape
 
 
@@ -302,20 +352,32 @@ def fused_moe(
     if renormalize:
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
 
-    config = {
-        'BLOCK_SIZE_M': 64,
-        'BLOCK_SIZE_N': 64,
-        'BLOCK_SIZE_K': 32,
-        'GROUP_SIZE_M': 8
-    }
-
-    if topk_ids.numel() <= w1.shape[0]:
-        config = {
-            'BLOCK_SIZE_M': 16,
-            'BLOCK_SIZE_N': 32,
-            'BLOCK_SIZE_K': 64,
-            'GROUP_SIZE_M': 1
-        }
+    if override_config:
+        config = override_config
+    else:
+        # First try to load optimal config from the file
+        configs = get_moe_configs(E, w2.shape[2])
+
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = {
+                'BLOCK_SIZE_M': 64,
+                'BLOCK_SIZE_N': 64,
+                'BLOCK_SIZE_K': 32,
+                'GROUP_SIZE_M': 8
+            }
+
+            if M <= E:
+                config = {
+                    'BLOCK_SIZE_M': 16,
+                    'BLOCK_SIZE_N': 32,
+                    'BLOCK_SIZE_K': 64,
+                    'GROUP_SIZE_M': 1
+                }
 
 
     intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
     intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
                                       device=hidden_states.device,
                                       device=hidden_states.device,
@@ -327,8 +389,8 @@ def fused_moe(
                                       device=hidden_states.device,
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
                                       dtype=hidden_states.dtype)
 
 
-    sorted_token_ids, expert_ids, num_tokens_post_padded = (
-        moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], E))
+    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+        topk_ids, config['BLOCK_SIZE_M'], E)
 
 
     invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
     invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1,
                             topk_weights, topk_ids, sorted_token_ids,
                             topk_weights, topk_ids, sorted_token_ids,
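
For context, a minimal GPU-only sketch of calling the relocated entry point end to end (it needs the compiled aphrodite kernels). The shapes assume the usual Mixtral-style layout where w1 packs the gate and up projections, i.e. w1: (E, 2*I, K) and w2: (E, K, I); that layout is an assumption about the callers, not something this file enforces beyond the asserts above:

import torch

from aphrodite.modeling.layers.fused_moe.fused_moe import fused_moe

M, K, I, E, topk = 32, 4096, 7168, 8, 2   # illustrative Mixtral-TP2 sizes
device, dtype = "cuda", torch.float16

hidden_states = torch.randn(M, K, device=device, dtype=dtype)
w1 = torch.randn(E, 2 * I, K, device=device, dtype=dtype)   # gate + up, per expert
w2 = torch.randn(E, K, I, device=device, dtype=dtype)       # down projection
gating_output = torch.randn(M, E, device=device, dtype=dtype)

out = fused_moe(hidden_states, w1, w2, gating_output, topk, renormalize=True)
# out has the same shape as hidden_states: (M, K)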

+ 57 - 18
aphrodite/modeling/layers/rejection.py

@@ -8,23 +8,26 @@ import torch.jit
 
 
 class RejectionSampler(nn.Module):
 class RejectionSampler(nn.Module):
     """Apply modified rejection sampling as described in "Accelerating Large
     """Apply modified rejection sampling as described in "Accelerating Large
-        Language Model Decoding with Speculative Sampling"
-        https://arxiv.org/pdf/2302.01318.pdf.
+    Language Model Decoding with Speculative Sampling"
+    https://arxiv.org/pdf/2302.01318.pdf.
     """
     """
 
 
     def __init__(self, strict_mode: bool = False):
     def __init__(self, strict_mode: bool = False):
         """Create a rejection sampler.
         """Create a rejection sampler.
+
         Args:
         Args:
             strict_mode: Whether or not to perform shape/device/dtype checks
             strict_mode: Whether or not to perform shape/device/dtype checks
                 during sampling. This catches correctness issues but adds
                 during sampling. This catches correctness issues but adds
                 nontrivial latency.
                 nontrivial latency.
         """
         """
         super().__init__()
         super().__init__()
-        self.probs_dtype = torch.float32
-        self.token_id_dtype = torch.int64
-        self._num_bonus_tokens = 1
         self._strict_mode = strict_mode
         self._strict_mode = strict_mode
 
 
+        # NOTE: A "bonus token" is accepted iff all proposal tokens are
+        # accepted. There is always only one possible bonus token. We store this
+        # value in a variable for readability.
+        self._num_bonus_tokens = 1
+
         self.num_accepted_tokens: Optional[torch.Tensor] = None
         self.num_accepted_tokens: Optional[torch.Tensor] = None
         self.num_emitted_tokens: Optional[torch.Tensor] = None
         self.num_emitted_tokens: Optional[torch.Tensor] = None
         self.num_draft_tokens: int = 0
         self.num_draft_tokens: int = 0
@@ -39,6 +42,14 @@ class RejectionSampler(nn.Module):
                                                dtype=torch.long,
                                                dtype=torch.long,
                                                device=device)
                                                device=device)
 
 
+    @property
+    def probs_dtype(self):
+        return torch.float32
+
+    @property
+    def token_id_dtype(self):
+        return torch.int64
+
     def forward(
     def forward(
         self,
         self,
         target_probs: torch.Tensor,
         target_probs: torch.Tensor,
@@ -49,24 +60,31 @@ class RejectionSampler(nn.Module):
         """Sample token ids using rejection sampling. This accepts or rejects
         """Sample token ids using rejection sampling. This accepts or rejects
         tokens proposed by the draft model using the probability of each token
         tokens proposed by the draft model using the probability of each token
         according to the draft and target models.
         according to the draft and target models.
+
         In the worst case where all draft tokens are rejected, it is guaranteed
         In the worst case where all draft tokens are rejected, it is guaranteed
         one correct token will be emitted.
         one correct token will be emitted.
+
         In the case where all draft tokens are accepted, a bonus token will be
         In the case where all draft tokens are accepted, a bonus token will be
         accepted as its cheap to have the target model score this speculative
         accepted as its cheap to have the target model score this speculative
         sequence.
         sequence.
+
         Args:
         Args:
             target_probs: The probability distribution over token ids given
             target_probs: The probability distribution over token ids given
                 context according to the target model.
                 context according to the target model.
             shape = [batch_size, num_speculative_tokens, vocab_size]
             shape = [batch_size, num_speculative_tokens, vocab_size]
+
             bonus_token_ids: The "bonus" token ids that are accepted iff all
             bonus_token_ids: The "bonus" token ids that are accepted iff all
                 speculative tokens in a sequence are accepted.
                 speculative tokens in a sequence are accepted.
             shape = [batch_size, num_bonus_tokens]
             shape = [batch_size, num_bonus_tokens]
+
             draft_probs: The probability distribution over token ids given
             draft_probs: The probability distribution over token ids given
                 context according to the draft model.
                 context according to the draft model.
             shape = [batch_size, num_speculative_tokens, vocab_size]
             shape = [batch_size, num_speculative_tokens, vocab_size]
+
             draft_token_ids: The token ids that were sampled from the draft
             draft_token_ids: The token ids that were sampled from the draft
                 probabilities.
                 probabilities.
             shape = [batch_size, num_speculative_tokens]
             shape = [batch_size, num_speculative_tokens]
+
         Returns:
         Returns:
             output_token_ids: The token ids sampled via rejection sampling,
             output_token_ids: The token ids sampled via rejection sampling,
                 or -1 if unable to sample a token because the previous token
                 or -1 if unable to sample a token because the previous token
@@ -107,6 +125,7 @@ class RejectionSampler(nn.Module):
             draft_token_ids: torch.Tensor,  # [batch_size, k]
             draft_token_ids: torch.Tensor,  # [batch_size, k]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Perform modified rejection sampling on each sequence.
         """Perform modified rejection sampling on each sequence.
+
         Returns:
         Returns:
             A tuple of two tensors:
             A tuple of two tensors:
             0: A bool tensor of which tokens in each sequence is accepted.
             0: A bool tensor of which tokens in each sequence is accepted.
@@ -139,16 +158,20 @@ class RejectionSampler(nn.Module):
         r"""Create bool matrix over the proposed draft tokens. If
         r"""Create bool matrix over the proposed draft tokens. If
         True, then a token can be accepted, else it should be
         True, then a token can be accepted, else it should be
         rejected.
         rejected.
+
         Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
         Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
         :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
         :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
         to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
         to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
         same conditional probability according to the draft model, the token
         same conditional probability according to the draft model, the token
         is accepted with probability:
         is accepted with probability:
+
         .. math::
         .. math::
             \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
             \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
                            {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
                            {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
+
         This implementation does not apply causality. When using the output,
         This implementation does not apply causality. When using the output,
         if a token is rejected, subsequent tokens should not be used.
         if a token is rejected, subsequent tokens should not be used.
+
         Returns a bool tensor of shape [batch_size, k] specifying which tokens
         Returns a bool tensor of shape [batch_size, k] specifying which tokens
         are accepted.
         are accepted.
         """
         """
@@ -171,7 +194,8 @@ class RejectionSampler(nn.Module):
                                   device=target_probs.device)
                                   device=target_probs.device)
         capped_ratio = torch.minimum(
         capped_ratio = torch.minimum(
             selected_target_probs / selected_draft_probs,
             selected_target_probs / selected_draft_probs,
-            torch.full((1, ), 1, device=target_probs.device))
+            torch.full((1, ), 1, device=target_probs.device),
+        )
         accepted = uniform_rand < capped_ratio
         accepted = uniform_rand < capped_ratio
 
 
         return accepted
         return accepted
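
The acceptance test above boils down to a couple of tensor ops; a standalone sketch for a single proposed token, using the docstring's notation with q for the target probability and p for the draft probability (the numbers are illustrative):

import torch

q = torch.tensor(0.30)   # target-model probability of the proposed token
p = torch.tensor(0.50)   # draft-model probability of the same token

capped_ratio = torch.minimum(q / p, torch.tensor(1.0))  # min(1, q/p) = 0.6
accepted = torch.rand(()) < capped_ratio                # accept with probability 0.6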
@@ -183,21 +207,26 @@ class RejectionSampler(nn.Module):
     ) -> torch.Tensor:
     ) -> torch.Tensor:
         r"""Create a probability distribution for each proposed token which can
         r"""Create a probability distribution for each proposed token which can
         be sampled if the proposed token is rejected.
         be sampled if the proposed token is rejected.
+
         When this routine is applied sequentially, the true distribution of the
         When this routine is applied sequentially, the true distribution of the
         target model is recovered (within hardware numerics).
         target model is recovered (within hardware numerics).
+
         The probability distribution used in this rejection case is constructed
         The probability distribution used in this rejection case is constructed
         as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
         as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
         :math:`x` given context :math:`x_1, \dots, x_n` according to the target
         :math:`x` given context :math:`x_1, \dots, x_n` according to the target
         model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
         model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
         according to the draft model:
         according to the draft model:
+
         .. math::
         .. math::
             x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
             x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
+
         where :math:`(f(x))_+` is defined as:
         where :math:`(f(x))_+` is defined as:
+
         .. math::
         .. math::
             (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
             (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
-        See https://github.com/vllm-project/vllm/pull/2336 for a visualization
-        of the draft, target, and recovered probability distributions.
+
         Returns a tensor of shape [batch_size, k, vocab_size].
         Returns a tensor of shape [batch_size, k, vocab_size].
+
         Note: This batches operations on GPU and thus constructs the recovered
         Note: This batches operations on GPU and thus constructs the recovered
         distribution for all tokens, even if they are accepted. This causes
         distribution for all tokens, even if they are accepted. This causes
         division-by-zero errors, so we use self._smallest_positive_value to
         division-by-zero errors, so we use self._smallest_positive_value to
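
A toy illustration of the recovered distribution described above, clamping to the smallest positive normal float32 (the value _smallest_positive_value refers to) so the normalization never divides by zero; the three-token vocabulary is illustrative:

import torch

target = torch.tensor([0.5, 0.3, 0.2])   # q(x | context)
draft = torch.tensor([0.7, 0.2, 0.1])    # p(x | context)

eps = torch.finfo(torch.float32).tiny    # smallest positive normal float32
diff = torch.clamp(target - draft, min=eps)
recovered = diff / diff.sum()            # (q - p)_+ renormalized
# recovered ~ [0.0, 0.5, 0.5]: mass only where the target exceeds the draft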
@@ -208,7 +237,7 @@ class RejectionSampler(nn.Module):
         # shape [batch_size, k, vocab_size]
         # shape [batch_size, k, vocab_size]
         difference = target_probs - draft_probs
         difference = target_probs - draft_probs
 
 
-        # TODO(cade): Can we use logprobs instead of probs, and avoid the
+        # TODO: Can we use logprobs instead of probs, and avoid the
         # division-by-zero errors without introducing distribution drift?
         # division-by-zero errors without introducing distribution drift?
 
 
         # shape [batch_size, k, vocab_size]
         # shape [batch_size, k, vocab_size]
@@ -224,7 +253,9 @@ class RejectionSampler(nn.Module):
         """Return the smallest positive value representable by the probs dtype.
         """Return the smallest positive value representable by the probs dtype.
         This value is used when constructing a distribution from which to sample
         This value is used when constructing a distribution from which to sample
         recovered tokens in the first rejection case.
         recovered tokens in the first rejection case.
+
         See _get_recovered_probs for more details
         See _get_recovered_probs for more details
+
         Note that this isn't actually the smallest positive value representable
         Note that this isn't actually the smallest positive value representable
         by float32, but the smallest positive normal value.
         by float32, but the smallest positive normal value.
         See https://en.wikipedia.org/wiki/Subnormal_number for more information.
         See https://en.wikipedia.org/wiki/Subnormal_number for more information.
@@ -241,6 +272,7 @@ class RejectionSampler(nn.Module):
         """Format output. Returns a matrix of token ids. When
         """Format output. Returns a matrix of token ids. When
         a token is rejected via rejection sampling, all subsequent
         a token is rejected via rejection sampling, all subsequent
         token ids are set to -1 for the sequence.
         token ids are set to -1 for the sequence.
+
         shape = [batch_size, k + num_bonus_tokens]
         shape = [batch_size, k + num_bonus_tokens]
         """
         """
         bonus_token_ids = bonus_token_ids.squeeze()
         bonus_token_ids = bonus_token_ids.squeeze()
@@ -259,7 +291,8 @@ class RejectionSampler(nn.Module):
         output_with_bonus_tokens = -torch.ones(
         output_with_bonus_tokens = -torch.ones(
             (batch_size, k + self._num_bonus_tokens),
             (batch_size, k + self._num_bonus_tokens),
             dtype=self.token_id_dtype,
             dtype=self.token_id_dtype,
-            device=accepted.device)
+            device=accepted.device,
+        )
         output = output_with_bonus_tokens[:, :k]
         output = output_with_bonus_tokens[:, :k]
 
 
         # Fill in the first k columns of the output tensor using masks and data
         # Fill in the first k columns of the output tensor using masks and data
@@ -290,8 +323,11 @@ class RejectionSampler(nn.Module):
         draft_probs: torch.Tensor,
         draft_probs: torch.Tensor,
         draft_token_ids: torch.Tensor,
         draft_token_ids: torch.Tensor,
     ) -> None:
     ) -> None:
-        (target_batch_size, num_target_probs,
-         target_vocab_size) = target_probs.shape
+        (
+            target_batch_size,
+            num_target_probs,
+            target_vocab_size,
+        ) = target_probs.shape
         bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
         bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
         draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
         draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
         draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
         draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
@@ -327,10 +363,13 @@ class RejectionSampler(nn.Module):
         draft_token_ids: torch.Tensor,
         draft_token_ids: torch.Tensor,
     ) -> None:
     ) -> None:
         devices = [
         devices = [
-            t.device for t in
-            [target_probs, bonus_token_ids, draft_probs, draft_token_ids]
+            t.device for t in [
+                target_probs,
+                bonus_token_ids,
+                draft_probs,
+                draft_token_ids,
+            ]
         ]
         ]
-        # pylint: disable=use-a-generator
         assert all([devices[0] == device for device in devices])
         assert all([devices[0] == device for device in devices])
 
 
     def _raise_if_out_of_bounds_vocab(
     def _raise_if_out_of_bounds_vocab(
@@ -358,8 +397,8 @@ def _multinomial(
     if num_samples > 1:
     if num_samples > 1:
         # This is equivalent to torch.repeat_interleaved (which also
         # This is equivalent to torch.repeat_interleaved (which also
         # forces a GPU<->CPU sync).
         # forces a GPU<->CPU sync).
-        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
-                                         probs.shape[1]).contiguous().view(
-                                             -1, probs.shape[1])
+        probs = (probs[:, None, :].expand(probs.shape[0], num_samples,
+                                          probs.shape[1]).contiguous().view(
+                                              -1, probs.shape[1]))
     q = torch.empty_like(probs).exponential_(1.0)
     q = torch.empty_like(probs).exponential_(1.0)
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)
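
The division-by-an-Exponential(1)-draw trick at the end is the exponential-race (Gumbel-max style) form of categorical sampling: taking the argmax of probs / q selects index i with probability probs[i]. A quick sanity check of that claim:

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7]).expand(100_000, 3).contiguous()

q = torch.empty_like(probs).exponential_(1.0)
samples = probs.div(q).argmax(dim=1)     # same trick as _multinomial above

freqs = torch.bincount(samples, minlength=3).float() / samples.shape[0]
# freqs ~ tensor([0.10, 0.20, 0.70]) up to Monte Carlo noise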

+ 1 - 2
aphrodite/modeling/layers/sampler.py

@@ -802,7 +802,6 @@ def _get_logprobs(
         if (i < sampling_metadata.num_prompts
         if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
                 and sampling_params.prompt_logprobs is not None):
             num_logprobs = sampling_params.prompt_logprobs
             num_logprobs = sampling_params.prompt_logprobs
-            prompt_len = sampling_metadata.prompt_lens[i]
             prompt_tokens = sampling_metadata.seq_data[
             prompt_tokens = sampling_metadata.seq_data[
                 seq_ids[0]].prompt_token_ids
                 seq_ids[0]].prompt_token_ids
             group_prompt_logprobs: PromptLogprobs = [None]
             group_prompt_logprobs: PromptLogprobs = [None]
@@ -876,7 +875,7 @@ def _build_sampler_output(
                                output_metadata.get(seq_ids[parent_id])))
                                output_metadata.get(seq_ids[parent_id])))
         sampler_output.append(
         sampler_output.append(
             SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
             SequenceGroupOutput(seq_outputs, group_prompt_logprobs))
-    return sampler_output
+    return SamplerOutput(outputs=sampler_output)
 
 
 
 
 def _miro_store_args(seqids: List[int], mus: List[float],
 def _miro_store_args(seqids: List[int], mus: List[float],

+ 62 - 24
aphrodite/modeling/loader.py

@@ -2,18 +2,24 @@
 import contextlib
 import contextlib
 import gc
 import gc
 from contextlib import nullcontext
 from contextlib import nullcontext
-from typing import Optional, Type
+from typing import Type
 from loguru import logger
 from loguru import logger
 
 
 import torch
 import torch
 import torch.nn as nn
 import torch.nn as nn
 
 
-from aphrodite.common.config import DeviceConfig, ModelConfig, LoRAConfig
+from aphrodite.common.config import DeviceConfig, ModelConfig
 from aphrodite.modeling.models import ModelRegistry
 from aphrodite.modeling.models import ModelRegistry
-from aphrodite.modeling.hf_downloader import (get_quant_config,
-                                              initialize_dummy_weights)
+from aphrodite.modeling.hf_downloader import (
+    get_quant_config,
+    initialize_dummy_weights,
+)
 from aphrodite.modeling.layers.quantization.bitsandbytes import (
 from aphrodite.modeling.layers.quantization.bitsandbytes import (
-    BNBLinearMethod, replace_quant_params)
+    BNBLinearMethod,
+    replace_quant_params,
+)
+from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_world_size, )
 
 
 
 
 @contextlib.contextmanager
 @contextlib.contextmanager
@@ -32,6 +38,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
     if (model_config.quantization is not None
     if (model_config.quantization is not None
             and "MixtralForCausalLM" in architectures):
             and "MixtralForCausalLM" in architectures):
         architectures = ["QuantMixtralForCausalLM"]
         architectures = ["QuantMixtralForCausalLM"]
+
     for arch in architectures:
     for arch in architectures:
         model_cls = ModelRegistry.load_model_cls(arch)
         model_cls = ModelRegistry.load_model_cls(arch)
         if model_cls is not None:
         if model_cls is not None:
@@ -41,9 +48,9 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
         f"Supported architectures: {ModelRegistry.get_supported_archs()}")
         f"Supported architectures: {ModelRegistry.get_supported_archs()}")
 
 
 
 
-def get_model(model_config: ModelConfig,
-              device_config: DeviceConfig,
-              lora_config: Optional[LoRAConfig] = None) -> nn.Module:
+def get_model(model_config: ModelConfig, device_config: DeviceConfig,
+              **kwargs) -> nn.Module:
+    lora_config = kwargs.get("lora_config", None)
     model_class = _get_model_architecture(model_config)
     model_class = _get_model_architecture(model_config)
 
 
     # Get the (maybe quantized) linear method.
     # Get the (maybe quantized) linear method.
@@ -68,9 +75,9 @@ def get_model(model_config: ModelConfig,
     with _set_default_torch_dtype(model_config.dtype):
     with _set_default_torch_dtype(model_config.dtype):
         # Create a model instance.
         # Create a model instance.
         # The weights will be initialized as empty tensors.
         # The weights will be initialized as empty tensors.
-        with torch.device(device_config.device) if not \
-            (isinstance(linear_method, BNBLinearMethod) and
-             linear_method.quant_config.from_float) else nullcontext():
+        with torch.device(device_config.device) if not (
+                isinstance(linear_method, BNBLinearMethod)
+                and linear_method.quant_config.from_float) else nullcontext():
             if hasattr(model_class, "supported_lora_modules"):
             if hasattr(model_class, "supported_lora_modules"):
                 model = model_class(model_config.hf_config, linear_method,
                 model = model_class(model_config.hf_config, linear_method,
                                     lora_config)
                                     lora_config)
@@ -88,23 +95,54 @@ def get_model(model_config: ModelConfig,
             initialize_dummy_weights(model)
             initialize_dummy_weights(model)
         else:
         else:
             # Load the weights from the cached or downloaded files.
             # Load the weights from the cached or downloaded files.
-            model.load_weights(model_config.model, model_config.download_dir,
-                               model_config.load_format, model_config.revision)
+            model.load_weights(
+                model_config.model,
+                model_config.download_dir,
+                model_config.load_format,
+                model_config.revision,
+            )
         if isinstance(linear_method, BNBLinearMethod):
         if isinstance(linear_method, BNBLinearMethod):
-            replace_quant_params(model,
-                                 quant_config=linear_method.quant_config,
-                                 modules_to_not_convert="lm_head")
+            replace_quant_params(
+                model,
+                quant_config=linear_method.quant_config,
+                modules_to_not_convert="lm_head",
+            )
             torch.cuda.synchronize()
             torch.cuda.synchronize()
             if linear_method.quant_config.from_float:
             if linear_method.quant_config.from_float:
                 model = model.cuda()
                 model = model.cuda()
             gc.collect()
             gc.collect()
             torch.cuda.empty_cache()
             torch.cuda.empty_cache()
-            logger.info("Memory allocated for converted model: {} GiB".format(
-                round(
-                    torch.cuda.memory_allocated(torch.cuda.current_device()) /
-                    (1024 * 1024 * 1024), 2)))
-            logger.info("Memory reserved for converted model: {} GiB".format(
-                round(
-                    torch.cuda.memory_reserved(torch.cuda.current_device()) /
-                    (1024 * 1024 * 1024), 2)))
+            tp = get_tensor_model_parallel_world_size()
+            logger.info(
+                "Memory allocated for converted model: {} GiB x {} = {} "
+                "GiB".format(
+                    round(
+                        torch.cuda.memory_allocated(
+                            torch.cuda.current_device()) /
+                        (1024 * 1024 * 1024),
+                        2,
+                    ),
+                    tp,
+                    round(
+                        torch.cuda.memory_allocated(
+                            torch.cuda.current_device()) * tp /
+                        (1024 * 1024 * 1024),
+                        2,
+                    ),
+                ))
+            logger.info(
+                "Memory reserved for converted model: {} GiB x {} = {} "
+                "GiB".format(
+                    round(
+                        torch.cuda.memory_reserved(torch.cuda.current_device())
+                        / (1024 * 1024 * 1024),
+                        2,
+                    ),
+                    tp,
+                    round(
+                        torch.cuda.memory_reserved(torch.cuda.current_device())
+                        * tp / (1024 * 1024 * 1024),
+                        2,
+                    ),
+                ))
     return model.eval()
     return model.eval()
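
The new log lines estimate the total footprint of the converted model by scaling the per-rank figure by the tensor-parallel world size. A compact sketch of the same arithmetic; the helper name is illustrative, and the total is only an approximation since ranks need not allocate identically:

import torch

def converted_model_memory_gib(tp_world_size: int):
    per_gpu = torch.cuda.memory_allocated(
        torch.cuda.current_device()) / (1024 * 1024 * 1024)
    return round(per_gpu, 2), round(per_gpu * tp_world_size, 2)

# e.g. 11.07 GiB per GPU x 2 ranks = 22.14 GiB reported total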

+ 5 - 4
aphrodite/modeling/metadata.py

@@ -1,4 +1,4 @@
-from typing import Optional, List
+from typing import Optional
 
 
 import torch
 import torch
 
 
@@ -28,7 +28,7 @@ class InputMetadata:
         block_tables: Optional[torch.Tensor],
         block_tables: Optional[torch.Tensor],
         use_cuda_graph: bool,
         use_cuda_graph: bool,
         kv_cache_dtype: str,
         kv_cache_dtype: str,
-        kv_quant_params: List[List[float]],
+        # kv_quant_params: List[List[float]],
     ) -> None:
     ) -> None:
         self.is_prompt = is_prompt
         self.is_prompt = is_prompt
         self.prompt_lens = prompt_lens
         self.prompt_lens = prompt_lens
@@ -40,7 +40,7 @@ class InputMetadata:
         self.block_tables = block_tables
         self.block_tables = block_tables
         self.use_cuda_graph = use_cuda_graph
         self.use_cuda_graph = use_cuda_graph
         self.kv_cache_dtype = kv_cache_dtype
         self.kv_cache_dtype = kv_cache_dtype
-        self.kv_quant_params = kv_quant_params
+        # self.kv_quant_params = kv_quant_params
 
 
         # Set during the execution of the first attention op.
         # Set during the execution of the first attention op.
         # FIXME: This is a hack.
         # FIXME: This is a hack.
@@ -55,4 +55,5 @@ class InputMetadata:
                 f"block_tables={self.block_tables}, "
                 f"block_tables={self.block_tables}, "
                 f"use_cuda_graph={self.use_cuda_graph}, "
                 f"use_cuda_graph={self.use_cuda_graph}, "
                 f"kv_cache_dtype={self.kv_cache_dtype}, "
                 f"kv_cache_dtype={self.kv_cache_dtype}, "
-                f"kv_quant_params={self.kv_quant_params})")
+                # f"kv_quant_params={self.kv_quant_params})"
+                )

+ 3 - 4
aphrodite/modeling/models/baichuan.py

@@ -27,7 +27,7 @@ from torch import nn
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
     LinearMethodBase,
@@ -187,7 +187,7 @@ class BaiChuanAttention(nn.Module):
             alibi_slopes = alibi_slopes[head_start:head_end].tolist()
             alibi_slopes = alibi_slopes[head_start:head_end].tolist()
 
 
             scaling = self.head_dim**-0.5
             scaling = self.head_dim**-0.5
-            self.attn = PagedAttention(
+            self.attn = Attention(
                 self.num_heads,
                 self.num_heads,
                 self.head_dim,
                 self.head_dim,
                 scaling,
                 scaling,
@@ -205,8 +205,7 @@ class BaiChuanAttention(nn.Module):
                 is_neox_style=is_neox_style,
                 is_neox_style=is_neox_style,
             )
             )
             self.scaling = self.head_dim**-0.5
             self.scaling = self.head_dim**-0.5
-            self.attn = PagedAttention(self.num_heads, self.head_dim,
-                                       self.scaling)
+            self.attn = Attention(self.num_heads, self.head_dim, self.scaling)
 
 
     def forward(
     def forward(
         self,
         self,

+ 5 - 5
aphrodite/modeling/models/bloom.py

@@ -26,7 +26,7 @@ from transformers import BloomConfig
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               QKVParallelLinear,
@@ -108,10 +108,10 @@ class BloomAttention(nn.Module):
         alibi_slopes = alibi_slopes[head_start:head_end].tolist()
         alibi_slopes = alibi_slopes[head_start:head_end].tolist()
 
 
         scaling = self.head_dim**-0.5
         scaling = self.head_dim**-0.5
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   scaling,
-                                   alibi_slopes=alibi_slopes)
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scaling,
+                              alibi_slopes=alibi_slopes)
 
 
     def forward(
     def forward(
         self,
         self,

+ 2 - 2
aphrodite/modeling/models/chatglm.py

@@ -10,7 +10,7 @@ from torch.nn import LayerNorm
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               MergedColumnParallelLinear,
@@ -87,7 +87,7 @@ class GLMAttention(nn.Module):
             base=10000 * rope_ratio,
             base=10000 * rope_ratio,
             is_neox_style=False,
             is_neox_style=False,
         )
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.num_heads,
             self.head_dim,
             self.head_dim,
             self.scaling,
             self.scaling,

+ 1 - 1
aphrodite/modeling/models/cohere.py

@@ -30,7 +30,7 @@ from transformers import CohereConfig
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention as Attention
+from aphrodite.modeling.layers.attention import Attention as Attention
 from aphrodite.modeling.layers.linear import (
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
     LinearMethodBase,
     MergedColumnParallelLinear,
     MergedColumnParallelLinear,

+ 3 - 3
aphrodite/modeling/models/deepseek.py

@@ -31,8 +31,8 @@ from transformers import PretrainedConfig
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
-from aphrodite.modeling.layers.triton_kernel.fused_moe import fused_moe
+from aphrodite.modeling.layers.attention import Attention
+from aphrodite.modeling.layers.fused_moe.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
     LinearMethodBase,
@@ -249,7 +249,7 @@ class DeepseekAttention(nn.Module):
             base=rope_theta,
             base=rope_theta,
             rope_scaling=rope_scaling,
             rope_scaling=rope_scaling,
         )
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.num_heads,
             self.head_dim,
             self.head_dim,
             self.scaling,
             self.scaling,

+ 14 - 14
aphrodite/modeling/models/falcon.py

@@ -29,7 +29,7 @@ from transformers import FalconConfig as HF_FalconConfig
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               QKVParallelLinear,
@@ -151,10 +151,10 @@ class FalconAttention(nn.Module):
                 max_position=max_position_embeddings,
                 max_position=max_position_embeddings,
                 base=rope_theta,
                 base=rope_theta,
             )
             )
-            self.attn = PagedAttention(self.num_heads,
-                                       self.head_dim,
-                                       self.inv_norm_factor,
-                                       num_kv_heads=self.num_kv_heads)
+            self.attn = Attention(self.num_heads,
+                                  self.head_dim,
+                                  self.inv_norm_factor,
+                                  num_kv_heads=self.num_kv_heads)
         elif self.use_alibi:
             tp_rank = get_tensor_model_parallel_rank()
             head_start = tp_rank * self.num_heads
@@ -162,16 +162,16 @@ class FalconAttention(nn.Module):
             alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
                             self.inv_norm_factor)
             alibi_slopes = alibi_slopes[head_start:head_end].tolist()
-            self.attn = PagedAttention(self.num_heads,
-                                       self.head_dim,
-                                       self.inv_norm_factor,
-                                       num_kv_heads=self.num_kv_heads,
-                                       alibi_slopes=alibi_slopes)
+            self.attn = Attention(self.num_heads,
+                                  self.head_dim,
+                                  self.inv_norm_factor,
+                                  num_kv_heads=self.num_kv_heads,
+                                  alibi_slopes=alibi_slopes)
         else:
-            self.attn = PagedAttention(self.num_heads,
-                                       self.head_dim,
-                                       scale=self.inv_norm_factor,
-                                       num_kv_heads=self.num_kv_heads)
+            self.attn = Attention(self.num_heads,
+                                  self.head_dim,
+                                  scale=self.inv_norm_factor,
+                                  num_kv_heads=self.num_kv_heads)
 
 
     def forward(
         self,

+ 2 - 2
aphrodite/modeling/models/gemma.py

@@ -24,7 +24,7 @@ from transformers import GemmaConfig
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import GeluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -181,7 +181,7 @@ class GemmaAttention(nn.Module):
             base=self.rope_theta,
             is_neox_style=True,
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,

+ 2 - 4
aphrodite/modeling/models/gpt2.py

@@ -26,7 +26,7 @@ from transformers import GPT2Config
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@@ -74,9 +74,7 @@ class GPT2Attention(nn.Module):
             bias=True,
             linear_method=linear_method,
         )
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   scale=self.scale)
+        self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)
 
 
     def forward(
         self,

+ 5 - 5
aphrodite/modeling/models/gpt_bigcode.py

@@ -27,7 +27,7 @@ from transformers import GPTBigCodeConfig
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@@ -86,10 +86,10 @@ class GPTBigCodeAttention(nn.Module):
             bias=True,
             linear_method=linear_method,
         )
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   scale=self.scale,
-                                   num_kv_heads=self.num_kv_heads)
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scale=self.scale,
+                              num_kv_heads=self.num_kv_heads)
 
 
     def forward(
         self,

+ 2 - 2
aphrodite/modeling/models/gpt_j.py

@@ -26,7 +26,7 @@ from transformers import GPTJConfig
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (
     ColumnParallelLinear,
     LinearMethodBase,
@@ -117,7 +117,7 @@ class GPTJAttention(nn.Module):
             base=rope_theta,
             is_neox_style=False,
         )
-        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
+        self.attn = Attention(self.num_heads, self.head_size, scaling)

     def forward(
         self,

+ 2 - 2
aphrodite/modeling/models/gpt_neox.py

@@ -26,7 +26,7 @@ from transformers import GPTNeoXConfig
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (
     ColumnParallelLinear,
     LinearMethodBase,
@@ -99,7 +99,7 @@ class GPTNeoXAttention(nn.Module):
             base=rope_theta,
             is_neox_style=is_neox_style,
         )
-        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
+        self.attn = Attention(self.num_heads, self.head_size, scaling)

     def forward(
         self,

+ 2 - 2
aphrodite/modeling/models/internlm2.py

@@ -7,7 +7,7 @@ from transformers import PretrainedConfig
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -147,7 +147,7 @@ class InternLM2Attention(nn.Module):
             base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,

+ 16 - 9
aphrodite/modeling/models/llama.py

@@ -30,7 +30,7 @@ from transformers import LlamaConfig
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -199,7 +199,7 @@ class LlamaAttention(nn.Module):
             rope_scaling=rope_scaling,
             is_neox_style=is_neox_style,
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,
@@ -213,7 +213,7 @@ class LlamaAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        kv_quant_param: List[float],
+        # kv_quant_param: List[float],
     ) -> torch.Tensor:
         if self.merge_weight:
             qkv, _ = self.qkv_proj(hidden_states)
@@ -225,8 +225,15 @@ class LlamaAttention(nn.Module):
             v, _ = self.v_proj(hidden_states)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                kv_quant_param)
+        attn_output = self.attn(
+            q,
+            k,
+            v,
+            k_cache,
+            v_cache,
+            input_metadata,
+            # kv_quant_param
+        )
         output, _ = self.o_proj(attn_output)
         return output

@@ -279,7 +286,7 @@ class LlamaDecoderLayer(nn.Module):
         kv_cache: KVCache,
         input_metadata: InputMetadata,
         residual: Optional[torch.Tensor],
-        kv_quant_param: List[float],
+        # kv_quant_param: List[float],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         if residual is None:
@@ -293,7 +300,7 @@ class LlamaDecoderLayer(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            kv_quant_param=kv_quant_param,
+            # kv_quant_param=kv_quant_param,
         )

         # Fully Connected
@@ -347,8 +354,8 @@ class LlamaModel(nn.Module):
                 kv_caches[i],
                 input_metadata,
                 residual,
-                input_metadata.kv_quant_params[i]
-                if input_metadata.kv_quant_params is not None else None,
+                # input_metadata.kv_quant_params[i]
+                # if input_metadata.kv_quant_params is not None else None,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

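Taken together, the model-file hunks above reduce to one mechanical pattern: swap PagedAttention for the backend-agnostic Attention wrapper and stop threading kv_quant_param through the call chain (it is commented out in this sync). Below is a minimal sketch of the resulting attention block, assuming only the constructor and call signature shown in the llama.py hunk; the class and variable names are illustrative, not an excerpt from the diff.

import torch
from torch import nn

from aphrodite.modeling.layers.attention import Attention
from aphrodite.modeling.metadata import InputMetadata


class ToyAttentionBlock(nn.Module):
    """Illustrative module mirroring the migrated pattern (not part of the diff)."""

    def __init__(self, num_heads: int, head_dim: int, num_kv_heads: int) -> None:
        super().__init__()
        scaling = head_dim**-0.5
        # The Attention wrapper now selects a backend internally; model code
        # no longer imports PagedAttention directly.
        self.attn = Attention(num_heads,
                              head_dim,
                              scaling,
                              num_kv_heads=num_kv_heads)

    def forward(self, q, k, v, kv_cache, input_metadata: InputMetadata):
        k_cache, v_cache = kv_cache
        # kv_quant_param is no longer passed through the call chain.
        return self.attn(q, k, v, k_cache, v_cache, input_metadata)
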
+ 3 - 3
aphrodite/modeling/models/mixtral.py

@@ -31,8 +31,8 @@ from transformers import MixtralConfig
 
 
 from aphrodite.common.config import LoRAConfig
 from aphrodite.modeling.metadata import InputMetadata
-from aphrodite.modeling.layers.attention import PagedAttention
-from aphrodite.modeling.layers.triton_kernel.fused_moe import fused_moe
+from aphrodite.modeling.layers.attention import Attention
+from aphrodite.modeling.layers.fused_moe.fused_moe import fused_moe
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -256,7 +256,7 @@ class MixtralAttention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=is_neox_style,
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,

+ 2 - 2
aphrodite/modeling/models/mixtral_quant.py

@@ -34,7 +34,7 @@ from torch import nn
 from transformers import MixtralConfig

 from aphrodite.modeling.metadata import InputMetadata
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -259,7 +259,7 @@ class MixtralAttention(nn.Module):
             base=int(self.rope_theta),
             is_neox_style=True,
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,

+ 6 - 6
aphrodite/modeling/models/mpt.py

@@ -8,7 +8,7 @@ import torch.nn as nn
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
@@ -105,11 +105,11 @@ class MPTAttention(nn.Module):
 
 
         self.head_dim = self.d_model // self.total_num_heads
         scaling = self.head_dim**-0.5
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   scaling,
-                                   alibi_slopes=alibi_slopes,
-                                   num_kv_heads=self.num_kv_heads)
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scaling,
+                              alibi_slopes=alibi_slopes,
+                              num_kv_heads=self.num_kv_heads)
 
 
     def forward(
         self,

+ 79 - 0
aphrodite/modeling/models/neuron/llama.py

@@ -0,0 +1,79 @@
+"""Inference-only LLaMA model compatible with HuggingFace weights."""
+import os
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from aphrodite.modeling.metadata import InputMetadata
+from aphrodite.modeling.layers.sampler import Sampler
+from aphrodite.modeling.sampling_metadata import SamplingMetadata
+from aphrodite.common.sequence import SamplerOutput
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class LlamaForCausalLM(nn.Module):
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        linear_method=None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.linear_method = linear_method
+        self.model = None
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        with torch.inference_mode():
+            block_size = self.model.context_buckets[-1]
+            if input_metadata.is_prompt:
+                seq_ids = input_metadata.slot_mapping[:, 0] // block_size
+            else:
+                seq_ids = input_metadata.block_tables
+            logits = self.model(input_ids,
+                                cache_ids=positions,
+                                start_ids=seq_ids.flatten())
+        return logits
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(self.model.chkpt_model.lm_head,
+                                   hidden_states, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None,
+                     **kwargs):
+        from transformers_neuronx.llama.model import LlamaForSampling
+
+        split_model_dir = f"{model_name_or_path}-split"
+        if os.path.isdir(os.path.join(model_name_or_path,
+                                      "pytorch_model.bin")):
+            split_model_dir = model_name_or_path
+        elif not os.path.exists(f"{model_name_or_path}-split"):
+            from transformers.models.llama import LlamaForCausalLM
+            from transformers_neuronx.module import save_pretrained_split
+
+            hf_model = LlamaForCausalLM.from_pretrained(model_name_or_path,
+                                                        low_cpu_mem_usage=True)
+            save_pretrained_split(hf_model, f"{model_name_or_path}-split")
+
+        self.model = LlamaForSampling.from_pretrained(split_model_dir,
+                                                      **kwargs)
+        self.model.to_neuron()

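The load path above leans on transformers_neuronx's split-checkpoint convention: if no "<model>-split" directory exists, the HF checkpoint is first written out with save_pretrained_split and then compiled for the NeuronCores. A condensed sketch of that flow under the same assumptions, with a placeholder model path and example kwargs:

import os

from transformers.models.llama import LlamaForCausalLM as HFLlamaForCausalLM
from transformers_neuronx.llama.model import LlamaForSampling
from transformers_neuronx.module import save_pretrained_split

model_path = "/models/llama-7b"      # placeholder path
split_dir = f"{model_path}-split"

if not os.path.exists(split_dir):
    # One-time conversion of the HF checkpoint into the split layout.
    hf_model = HFLlamaForCausalLM.from_pretrained(model_path,
                                                  low_cpu_mem_usage=True)
    save_pretrained_split(hf_model, split_dir)

# Compile for Neuron; kwargs such as tp_degree/amp are supplied by the loader
# (see neuron_loader.py below) and are placeholders here.
neuron_model = LlamaForSampling.from_pretrained(split_dir, tp_degree=2, amp="f16")
neuron_model.to_neuron()
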
+ 4 - 4
aphrodite/modeling/models/olmo.py

@@ -45,7 +45,7 @@ import torch.nn.functional as F
 from torch import nn

 from aphrodite.modeling.metadata import InputMetadata
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (
     ColumnParallelLinear,
     LinearMethodBase,
@@ -131,9 +131,9 @@ class OlmoAttention(nn.Module):
                 base=rope_theta,
             )
         self.scaling = self.head_dim**-0.5
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   scale=self.scaling)
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scale=self.scaling)
 
 
         # Attention output projection.
         self.attn_out = RowParallelLinear(

+ 4 - 4
aphrodite/modeling/models/opt.py

@@ -27,7 +27,7 @@ from transformers import OPTConfig
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (
     ColumnParallelLinear,
     LinearMethodBase,
@@ -114,9 +114,9 @@ class OPTAttention(nn.Module):
             bias=bias,
             linear_method=linear_method,
         )
-        self.attn = PagedAttention(self.num_heads,
-                                   self.head_dim,
-                                   scale=self.scaling)
+        self.attn = Attention(self.num_heads,
+                              self.head_dim,
+                              scale=self.scaling)
 
 
     def forward(
         self,

+ 2 - 2
aphrodite/modeling/models/phi.py

@@ -45,7 +45,7 @@ from transformers import PretrainedConfig
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (
     ColumnParallelLinear,
     LinearMethodBase,
@@ -145,7 +145,7 @@ class PhiAttention(nn.Module):
             base=rope_theta,
             is_neox_style=is_neox_style,
         )
-        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
+        self.attn = Attention(self.num_heads, self.head_size, scaling)

     def forward(
         self,

+ 2 - 2
aphrodite/modeling/models/qwen.py

@@ -12,7 +12,7 @@ from torch import nn
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -142,7 +142,7 @@ class QWenAttention(nn.Module):
             rope_scaling=rope_scaling,
             is_neox_style=is_neox_style,
         )
-        self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)
+        self.attn = Attention(self.num_heads, self.head_dim, self.scaling)

     def forward(
         self,

+ 2 - 2
aphrodite/modeling/models/qwen2.py

@@ -32,7 +32,7 @@ from transformers import Qwen2Config
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.layernorm import RMSNorm
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
@@ -193,7 +193,7 @@ class Qwen2Attention(nn.Module):
             max_position=max_position,
             base=self.rope_theta,
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,

+ 2 - 2
aphrodite/modeling/models/stablelm.py

@@ -28,7 +28,7 @@ from transformers import PretrainedConfig
 
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.activation import SiluAndMul
-from aphrodite.modeling.layers.attention import PagedAttention
+from aphrodite.modeling.layers.attention import Attention
 from aphrodite.modeling.layers.linear import (
     LinearMethodBase,
     MergedColumnParallelLinear,
@@ -188,7 +188,7 @@ class StablelmAttention(nn.Module):
             max_position=self.config.max_position_embeddings,
             base=self.config.rope_theta,
         )
-        self.attn = PagedAttention(
+        self.attn = Attention(
             self.num_heads,
             self.head_dim,
             self.scaling,

+ 70 - 0
aphrodite/modeling/neuron_loader.py

@@ -0,0 +1,70 @@
+"""Utilities for selecting and loading models."""
+from typing import Type
+
+import torch
+import torch.nn as nn
+from transformers import PretrainedConfig
+
+from aphrodite.common.config import ModelConfig, DeviceConfig
+from aphrodite.modeling.models import ModelRegistry
+
+TORCH_DTYPE_TO_NEURON_AMP = {
+    "auto": "f32",
+    "half": "f16",
+    "float16": "f16",
+    "bfloat16": "bf16",
+    "float": "f32",
+    "float32": "f32",
+    torch.float16: "f16",
+    torch.bfloat16: "bf16",
+    torch.float32: "f32",
+}
+
+
+def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
+    architectures = getattr(config, "architectures", [])
+    for arch in architectures:
+        model_cls = ModelRegistry.load_model_cls(arch)
+        if model_cls is not None:
+            return model_cls
+    raise ValueError(
+        f"Model architectures {architectures} are not supported for now. "
+        f"Supported architectures: {ModelRegistry.get_supported_archs()}")
+
+
+def get_model(model_config: ModelConfig, device_config: DeviceConfig,
+              **kwargs) -> nn.Module:
+    from transformers_neuronx.config import (
+        NeuronConfig,
+        ContinuousBatchingConfig,
+    )
+
+    parallel_config = kwargs.get("parallel_config")
+    scheduler_config = kwargs.get("scheduler_config")
+
+    model_class = _get_model_architecture(model_config.hf_config)
+    linear_method = None
+
+    # Create a model instance.
+    model = model_class(model_config.hf_config, linear_method)
+
+    continuous_batching_config = ContinuousBatchingConfig(
+        batch_size_for_shared_caches=scheduler_config.max_num_seqs)
+    neuron_config = NeuronConfig(
+        continuous_batching=continuous_batching_config)
+
+    # Load the weights from the cached or downloaded files.
+    model.load_weights(
+        model_config.model,
+        model_config.download_dir,
+        model_config.load_format,
+        model_config.revision,
+        tp_degree=parallel_config.neuron_tp_degree,
+        amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
+        neuron_config=neuron_config,
+        context_length_estimate=[scheduler_config.max_model_len],
+        n_positions=[scheduler_config.max_model_len],
+        batch_size=scheduler_config.max_num_seqs,
+    )
+
+    return model.eval()

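Two details in this loader are easy to miss: requested dtypes are translated through TORCH_DTYPE_TO_NEURON_AMP, and the continuous-batching cache size is pinned to scheduler_config.max_num_seqs. A rough illustration of the keyword set the loader forwards to load_weights, with placeholder values standing in for the config objects above:

from transformers_neuronx.config import ContinuousBatchingConfig, NeuronConfig

max_num_seqs = 8        # placeholder for scheduler_config.max_num_seqs
max_model_len = 4096    # placeholder for scheduler_config.max_model_len

neuron_config = NeuronConfig(
    continuous_batching=ContinuousBatchingConfig(
        batch_size_for_shared_caches=max_num_seqs))

load_kwargs = dict(
    tp_degree=2,                           # parallel_config.neuron_tp_degree
    amp="f16",                             # TORCH_DTYPE_TO_NEURON_AMP[dtype]
    neuron_config=neuron_config,
    context_length_estimate=[max_model_len],
    n_positions=[max_model_len],
    batch_size=max_num_seqs,
)
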
+ 2 - 2
aphrodite/modeling/sampling_metadata.py

@@ -5,7 +5,7 @@ import torch
 
 
 from aphrodite.common.sampling_params import SamplingParams, SamplingType
 from aphrodite.common.sequence import SequenceData
-from aphrodite.common.utils import in_wsl
+from aphrodite.common.utils import in_wsl, is_neuron

 _SAMPLING_EPS = 1e-5
 
 
@@ -292,7 +292,7 @@ class SamplingTensors:
                    dtype: torch.dtype) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.
-        pin_memory = not in_wsl()
+        pin_memory = not in_wsl() and not is_neuron()
         prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
         prompt_padded_tokens = [
             tokens + [vocab_size] * (prompt_max_len - len(tokens))

+ 17 - 0
aphrodite/modeling/utils.py

@@ -1,10 +1,18 @@
 """Utils for model executor."""
 import random
+import importlib
 from typing import Any, Dict, Optional

 import numpy as np
 import torch

+from aphrodite.common.config import DeviceConfig, ModelConfig
+
+DEVICE_TO_MODEL_LOADER_MAP = {
+    "cuda": "loader",
+    "neuron": "neuron_loader",
+}
+
 
 
 def set_random_seed(seed: int) -> None:
     random.seed(seed)
@@ -33,3 +41,12 @@ def set_weight_attrs(
         assert not hasattr(
             weight, key), (f"Overwriting existing tensor attribute: {key}")
         setattr(weight, key, value)
+
+
+def get_model(model_config: ModelConfig, device_config: DeviceConfig,
+              **kwargs) -> torch.nn.Module:
+    model_loader_module = DEVICE_TO_MODEL_LOADER_MAP[device_config.device_type]
+    imported_model_loader = importlib.import_module(
+        f"aphrodite.modeling.{model_loader_module}")
+    get_model_fn = imported_model_loader.get_model
+    return get_model_fn(model_config, device_config, **kwargs)

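Model loading is now resolved per device type rather than hard-coded: DEVICE_TO_MODEL_LOADER_MAP maps "cuda" to aphrodite.modeling.loader and "neuron" to aphrodite.modeling.neuron_loader, and the new get_model imports the chosen module lazily so CUDA-only installs never touch the Neuron stack. A hedged usage sketch follows; the wrapper and its keyword names are illustrative, and config construction is elided:

from aphrodite.common.config import DeviceConfig, ModelConfig
from aphrodite.modeling.utils import get_model


def load_for_device(model_config: ModelConfig, device_config: DeviceConfig,
                    **loader_kwargs):
    # device_config.device_type ("cuda" or "neuron") picks the loader module;
    # extra kwargs such as parallel_config/scheduler_config are what the
    # Neuron loader reads via kwargs.get(...).
    return get_model(model_config, device_config, **loader_kwargs)
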
+ 39 - 48
aphrodite/processing/block_manager.py

@@ -1,7 +1,6 @@
 """A block manager that manages token blocks."""
-
 import enum
-from itertools import count
+from itertools import count, takewhile
 from os.path import commonprefix
 from typing import Dict, List, Optional, Set, Tuple

@@ -19,14 +18,12 @@ class BlockAllocator:
     the reference count becomes zero, the block is added back to the free list.
     """

-    def __init__(
-        self,
-        device: Device,
-        block_size: int,
-        num_blocks: int,
-        eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
-        enable_caching: bool = False,
-    ) -> None:
+    def __init__(self,
+                 device: Device,
+                 block_size: int,
+                 num_blocks: int,
+                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
+                 enable_caching: bool = False) -> None:
         self.device = device
         self.block_size = block_size
         self.num_blocks = num_blocks
@@ -49,13 +46,11 @@ class BlockAllocator:
             block.block_hash = block_hash
             block.num_hashed_tokens = num_hashed_tokens
             return block
-        block = PhysicalTokenBlock(
-            device=self.device,
-            block_number=self.current_num_blocks,
-            block_size=self.block_size,
-            block_hash=block_hash,
-            num_hashed_tokens=num_hashed_tokens,
-        )
+        block = PhysicalTokenBlock(device=self.device,
+                                   block_number=self.current_num_blocks,
+                                   block_size=self.block_size,
+                                   block_hash=block_hash,
+                                   num_hashed_tokens=num_hashed_tokens)
         self.current_num_blocks += 1
         return block

@@ -126,7 +121,6 @@ class AllocStatus(enum.Enum):
     3. Never: seq_group can never be allocated.
       The seq_group is too large to allocated in GPU.
     """
-
     OK = enum.auto()
     LATER = enum.auto()
     NEVER = enum.auto()
@@ -150,10 +144,8 @@ class BlockSpaceManager:
 
 
         self.block_sliding_window = None
         if sliding_window is not None:
-            assert sliding_window % block_size == 0, (
-                sliding_window,
-                block_size,
-            )
+            assert sliding_window % block_size == 0, (sliding_window,
+                                                      block_size)
             self.block_sliding_window = sliding_window // block_size

         self.watermark = watermark
@@ -162,23 +154,19 @@ class BlockSpaceManager:
         self.enable_caching = enable_caching

         self.watermark_blocks = int(watermark * num_gpu_blocks)
-        self.gpu_allocator = BlockAllocator(
-            Device.GPU,
-            block_size,
-            num_gpu_blocks,
-            enable_caching=enable_caching,
-        )
-        self.cpu_allocator = BlockAllocator(
-            Device.CPU,
-            block_size,
-            num_cpu_blocks,
-            enable_caching=enable_caching,
-        )
+        self.gpu_allocator = BlockAllocator(Device.GPU,
+                                            block_size,
+                                            num_gpu_blocks,
+                                            enable_caching=enable_caching)
+        self.cpu_allocator = BlockAllocator(Device.CPU,
+                                            block_size,
+                                            num_cpu_blocks,
+                                            enable_caching=enable_caching)
         # Mapping: seq_id -> BlockTable.
         self.block_tables: Dict[int, BlockTable] = {}

     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
-        # FIXME(woosuk): Here we assume that all sequences in the group share
+        # FIXME: Here we assume that all sequences in the group share
         # the same prompt. This may not be true for preempted sequences.
         seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
         num_required_blocks = len(seq.logical_token_blocks)
@@ -213,8 +201,7 @@ class BlockSpaceManager:
             else:
                 block = self.gpu_allocator.allocate(
                     seq.hash_of_block(logical_idx),
-                    seq.num_hashed_tokens_of_block(logical_idx),
-                )
+                    seq.num_hashed_tokens_of_block(logical_idx))
                 block_table.append(block)

         # Assign the block table for each sequence.
@@ -444,23 +431,29 @@ class BlockSpaceManager:
         for block in block_table:
             block.last_accessed = access_time

-    def compute_last_full_block_in_seq(self, seq: Sequence):
+    def compute_full_blocks_in_seq(self, seq: Sequence):
         if seq.seq_id not in self.block_tables:
             return
         max_full_block = seq.get_len() // self.block_size - 1
         block_table = self.block_tables[seq.seq_id]
         if max_full_block == -1:
             return
-        block_table[max_full_block].computed = True
+        for i in reversed(range(max_full_block)):
+            if block_table[i].computed:
+                break
+            block_table[i].computed = True
 
 
-    def get_all_block_ids_till_computed(self, seq: Sequence) -> List[int]:
+    def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
         if seq.seq_id not in self.block_tables:
             return []
         block_table = self.block_tables[seq.seq_id]
-        for block_idx in reversed(range(len(block_table))):
-            if block_table[block_idx].computed:
-                return [b.block_number for b in block_table[:block_idx + 1]]
-        return []
+        # NOTE We exclude the last block to avoid the case where the entire
+        # prompt is cached. This would cause erroneous behavior in model
+        # runner.
+        return [
+            b.block_number
+            for b in takewhile(lambda b: b.computed, block_table[:-1])
+        ]
 
 
     def get_common_computed_block_ids(self,
                                       seq_group: SequenceGroup) -> List[int]:
@@ -469,14 +462,12 @@ class BlockSpaceManager:
             return []

         ids_list = [
-            self.get_all_block_ids_till_computed(seq)
+            self.get_all_computed_blocks(seq)
             for seq in iter(seq_group.seqs_dict.values())
         ]
         return commonprefix([ids for ids in ids_list if ids != []])

     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
-        # NOTE: We only mark the last full block because with prefix caching,
-        # all blocks until the marked one are guaranteed to be computed.
         if self.enable_caching:
             for seq in seq_group.seqs_dict.values():
-                self.compute_last_full_block_in_seq(seq)
+                self.compute_full_blocks_in_seq(seq)

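The prefix-caching bookkeeping changes shape here: compute_full_blocks_in_seq now marks every full block (walking backwards and stopping at the first block already marked computed), and get_all_computed_blocks returns the leading run of computed blocks while deliberately excluding the last block, so a fully cached prompt still leaves work for the model runner. A toy illustration of the takewhile behaviour, using a stand-in block class rather than the real PhysicalTokenBlock:

from dataclasses import dataclass
from itertools import takewhile


@dataclass
class FakeBlock:  # stand-in for PhysicalTokenBlock
    block_number: int
    computed: bool


block_table = [FakeBlock(0, True), FakeBlock(1, True),
               FakeBlock(2, False), FakeBlock(3, True)]

# Mirror of get_all_computed_blocks: leading computed blocks, last block excluded.
computed_ids = [
    b.block_number for b in takewhile(lambda b: b.computed, block_table[:-1])
]
assert computed_ids == [0, 1]
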
+ 2 - 3
aphrodite/processing/evictor.py

@@ -7,8 +7,9 @@ from aphrodite.common.block import PhysicalTokenBlock
 
 
 class EvictionPolicy(enum.Enum):
     """Enum for eviction policy used by make_evictor to instantiate the correct
-       Evictor subclass.
+    Evictor subclass.
     """
+
     LRU = enum.auto()
     FIFO = enum.auto()
 
 
@@ -115,7 +116,6 @@ class LRUEvictor(Evictor):
         return block

     @property
-    # pylint: disable=invalid-overridden-method
     def num_blocks(self) -> int:
         return len(self.free_table)

@@ -149,7 +149,6 @@ class RandomEvictor(Evictor):
         return block

     @property
-    # pylint: disable=invalid-overridden-method
     def num_blocks(self) -> int:
         return len(self.free_table)


+ 1 - 4
aphrodite/processing/scheduler.py

@@ -65,10 +65,7 @@ class SchedulerOutputs:
     def _sort_by_lora_ids(self) -> bool:
         self.scheduled_seq_groups = sorted(
             self.scheduled_seq_groups,
-            key=lambda g: (
-                g.lora_request.lora_int_id if g.lora_request else 0,
-                g.request_id,
-            ),
+            key=lambda g: (g.lora_int_id, g.request_id),
         )

     @property

+ 398 - 0
aphrodite/spec_decode/batch_expansion.py

@@ -0,0 +1,398 @@
+from typing import Iterator, List, Tuple, Optional, Dict
+from itertools import chain, count
+
+import torch
+
+from aphrodite.common.sequence import (
+    SamplerOutput,
+    SequenceGroupMetadata,
+    SequenceData,
+)
+from aphrodite.task_handler.worker import Worker
+from aphrodite.spec_decode.util import (
+    nvtx_range,
+    sampler_output_to_torch,
+    get_all_seq_ids,
+    split_batch_by_proposal_len,
+)
+from aphrodite.spec_decode.interfaces import (
+    SpeculativeScorer,
+    SpeculativeProposals,
+    SpeculativeScores,
+)
+
+SeqId = int
+TargetSeqId = int
+TokenId = int
+
+
+class BatchExpansionTop1Scorer(SpeculativeScorer):
+    """Implements a speculative scorer that uses batch expansion to get
+    probabilities of speculative tokens according to the scoring model.
+
+    Batch expansion converts a list of sequences and multiple query positions
+    to a new batch of sequences, each with a single query position. This allows
+    for MQA-like scoring in speculative decoding without requiring an MQA
+    kernel.
+
+    It is strictly less efficient than MQA scoring.
+
+    It only supports scoring the top1 proposal tokens of the proposer, instead
+    of topk/tree.
+    """
+
+    def __init__(self, scorer_worker: Worker, device: str, vocab_size: int):
+        self._scorer_worker = scorer_worker
+        self._device = device
+        self._vocab_size = vocab_size
+
+    @nvtx_range("BatchExpansionTop1Scorer.score_proposals")
+    def score_proposals(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Optional[Dict[int, int]],
+        blocks_to_swap_out: Optional[Dict[int, int]],
+        blocks_to_copy: Optional[Dict[int, List[int]]],
+        k: int,
+        proposals: SpeculativeProposals,
+    ) -> SpeculativeScores:
+        """Score the proposed tokens via the scorer model.
+
+        This converts each input sequence to a set of k+1 target sequences. The
+        target sequences have the unique continuations to be scored and a
+        unique sequence ID that is different from all input sequence ids.
+
+        If a speculative sequence length would exceed the max model length, then
+        no speculation is produced for that sequence.
+
+        Args:
+            seq_group_metadata_list: The input sequence group metadata.
+            blocks_to_swap_in: This is passed to the worker during scoring.
+            blocks_to_swap_out: This is passed to the worker during scoring.
+            blocks_to_copy: This is passed to the worker during scoring.
+            k: The fixed proposal length.
+            proposals: The speculative proposals to score.
+        Returns:
+            SpeculativeScores: The scores of each speculative token, along with
+                which sequences were ignored during scoring.
+        """
+
+        # TODO: perform this on GPU to remove blocking call.
+        proposal_lens_list = proposals.proposal_lens.tolist()
+        proposal_token_ids_list = proposals.proposal_token_ids.tolist()
+
+        (
+            spec_indices,
+            non_spec_indices,
+            target_seq_group_metadata_list,
+            num_scoring_tokens,
+        ) = self._expand_batch(
+            seq_group_metadata_list=seq_group_metadata_list,
+            proposal_token_ids_list=proposal_token_ids_list,
+            proposal_lens_list=proposal_lens_list,
+        )
+
+        target_sampler_output = self._scorer_worker.execute_model(
+            seq_group_metadata_list=target_seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+            return_python_output=False,
+        )
+
+        all_tokens, all_probs = self._contract_batch(
+            original_bs=len(seq_group_metadata_list),
+            target_sampler_output=target_sampler_output,
+            proposals=proposals,
+            num_scoring_tokens=num_scoring_tokens,
+            non_spec_indices=non_spec_indices,
+            spec_indices=spec_indices,
+            k=k,
+        )
+
+        return SpeculativeScores(
+            probs=all_probs,
+            token_ids=all_tokens,
+        )
+
+    def _expand_batch(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        proposal_token_ids_list: List[TokenId],
+        proposal_lens_list: List[int],
+    ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
+        """Given the input sequences and potentially multiple corresponding
+        proposal tokens, create a new batch where each sequence has a single
+        query token.
+        """
+
+        # Aphrodite currently only supports proposal lens equal to zero or the
+        # batch proposal len. This adds some complexity (splitting the batch
+        # into spec and non spec sequences) and should be removed in the
+        # future. It can be done by supporting per-sequence proposal lens.
+        spec_seqs, spec_indices = split_batch_by_proposal_len(
+            seq_group_metadata_list,
+            proposal_lens_list,
+            select_proposal_len_zero=False,
+        )
+        non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
+            seq_group_metadata_list,
+            proposal_lens_list,
+            select_proposal_len_zero=True,
+        )
+
+        target_seq_group_metadata_list = self._create_scoring_model_input(
+            spec_seqs, proposal_token_ids_list)
+        num_scoring_tokens = len(target_seq_group_metadata_list)
+        target_seq_group_metadata_list.extend(non_spec_seqs)
+
+        return (
+            spec_indices,
+            non_spec_indices,
+            target_seq_group_metadata_list,
+            num_scoring_tokens,
+        )
+
+    def _contract_batch(
+        self,
+        original_bs: int,
+        target_sampler_output: List[SamplerOutput],
+        proposals: SpeculativeProposals,
+        num_scoring_tokens: int,
+        non_spec_indices: List[int],
+        spec_indices: List[int],
+        k: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Contract the expanded batch back into its original size.
+        This maps the scores of speculative tokens back to their original
+        sequences.
+        """
+        (
+            target_token_ids,
+            target_probs,
+            non_spec_target_token_ids,
+            non_spec_target_probs,
+        ) = self._split_scoring_output(target_sampler_output,
+                                       num_scoring_tokens)
+
+        # Map distinct sequences used to score each token
+        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
+        batch_size, k = proposals.proposal_token_ids.shape
+
+        target_token_ids = target_token_ids.squeeze().reshape(
+            batch_size, k + 1)
+        target_probs = target_probs.squeeze().reshape(batch_size, k + 1,
+                                                      self._vocab_size)
+
+        all_tokens = torch.full(
+            size=(original_bs, k + 1),
+            fill_value=-1,
+            device=self._device,
+            dtype=torch.long,
+        )
+        all_probs = torch.zeros(
+            original_bs,
+            k + 1,
+            self._vocab_size,
+            device=self._device,
+            dtype=torch.float32,
+        )
+
+        if non_spec_indices:
+            all_tokens[non_spec_indices, 0] = non_spec_target_token_ids
+            all_probs[non_spec_indices, :1, :] = non_spec_target_probs
+
+        if spec_indices:
+            all_tokens[spec_indices] = target_token_ids
+            all_probs[spec_indices] = target_probs
+
+        return all_tokens, all_probs
+
+    def _create_scoring_model_input(
+            self,
+            seq_group_metadata_list: List[SequenceGroupMetadata],
+            proposal_token_ids: List[List[TokenId]],  # shape: [batch_size, k]
+    ) -> List[SequenceGroupMetadata]:
+        """Given the original input sequences and proposed tokens from the draft
+        model, create a list of target sequences that can be used for scoring.
+        """
+
+        if not seq_group_metadata_list:
+            return []
+
+        target_seq_ids_iter = self._create_target_seq_id_iterator(
+            get_all_seq_ids(seq_group_metadata_list))
+
+        target_seq_group_metadata = list(
+            chain.from_iterable(
+                self._create_target_seq_group_metadata(
+                    seq_group_metadata,
+                    proposal_token_ids,
+                    i,
+                    target_seq_ids_iter,
+                ) for i, seq_group_metadata in enumerate(
+                    seq_group_metadata_list)))
+
+        return target_seq_group_metadata
+
+    def _create_target_seq_group_metadata(
+        self,
+        input_seq_group_metadata: SequenceGroupMetadata,
+        proposal_token_ids: List[TokenId],  # shape: [batch_size, k]
+        batch_index: int,
+        target_seq_ids_iter: Iterator[TargetSeqId],
+    ) -> List[SequenceGroupMetadata]:
+        """Given an input sequence group metadata and a list of draft tokens,
+        create a list of target SequenceGroupMetadata, one for each
+        token id that needs to be scored.
+
+        Naive speculative decoding requires K target model scores, one for each
+        draft model token. However one can add a bonus token such that if each
+        token is accepted, then a final token may be sampled from the model.
+        This function creates K+1 target SequenceGroupMetadata to take
+        advantage of the bonus token.
+        """
+        assert not input_seq_group_metadata.is_prompt, (
+            "Speculating on "
+            "prompts not yet supported")
+        assert len(input_seq_group_metadata.seq_data) == 1, (
+            "Beam search "
+            "not supported in speculative decoding")
+        input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys()))
+
+        token_ids_to_score = self._get_token_ids_to_score(
+            proposal_token_ids[batch_index])
+
+        target_seq_group_metadata_list: List[SequenceGroupMetadata] = []
+        for token_ids in token_ids_to_score:
+            target_seq_group_metadata_list.append(
+                self._create_single_target_seq_group_metadata(
+                    input_seq_group_metadata,
+                    input_seq_id,
+                    next(target_seq_ids_iter),
+                    token_ids,
+                ))
+
+        return target_seq_group_metadata_list
+
+    def _create_single_target_seq_group_metadata(
+        self,
+        seq_group_metadata: SequenceGroupMetadata,
+        seq_id: SeqId,
+        target_seq_id: TargetSeqId,
+        token_ids: List[TokenId],
+    ) -> SequenceGroupMetadata:
+        """Create a single target SequenceGroupMetadata.
+
+        Args:
+            seq_group_metadata: The metadata for the input sequence.
+            seq_id: The input sequence ID.
+            target_seq_id: The corresponding target sequence ID.
+            token_ids: The list of token ids that are to be appended to the
+                input sequence.
+        """
+        seq_data = seq_group_metadata.seq_data[seq_id]
+        prompt_token_ids = seq_data.get_prompt_token_ids()
+        new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids]
+
+        return SequenceGroupMetadata(
+            request_id=seq_group_metadata.request_id,
+            is_prompt=seq_group_metadata.is_prompt,
+            seq_data={
+                target_seq_id:
+                SequenceData(
+                    prompt_token_ids=prompt_token_ids,
+                    output_token_ids=new_output_token_ids,
+                ),
+            },
+            sampling_params=seq_group_metadata.sampling_params,
+            block_tables={
+                target_seq_id: seq_group_metadata.block_tables[seq_id],
+            },
+            lora_request=None,
+            persistent_data={},
+        )
+
+    def _split_scoring_output(
+        self, sampler_output: SamplerOutput, num_scoring_tokens: int
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Split the target model output into speculative and non-speculative
+        output.
+        """
+
+        # Aphrodite currently only supports proposal lens equal to zero or the
+        # batch proposal len. This adds some complexity (splitting the batch
+        # into spec and non spec sequences) and should be removed in the
+        # future. It can be done by supporting per-sequence proposal lens.
+        # First samples are from speculative scoring, latter samples are non-
+        # speculative samples.
+        split_sizes = [
+            num_scoring_tokens,
+            sampler_output.sampled_token_ids.numel() - num_scoring_tokens,
+        ]
+        (spec_probs, non_spec_probs
+         ) = sampler_output.sampled_token_probs.split(split_sizes)
+        (
+            spec_sampled_tokens,
+            non_spec_sampled_tokens,
+        ) = sampler_output.sampled_token_ids.flatten().split(split_sizes)
+
+        # Convert scores to tensors.
+        sampler_output.sampled_token_probs = spec_probs
+        sampler_output.sampled_token_ids = spec_sampled_tokens
+        target_token_ids, target_probs = sampler_output_to_torch(
+            [sampler_output])
+
+        # Convert non-speculative output tokens to tensors.
+        sampler_output.sampled_token_probs = non_spec_probs
+        sampler_output.sampled_token_ids = non_spec_sampled_tokens
+        (
+            non_spec_target_token_ids,
+            non_spec_target_probs,
+        ) = sampler_output_to_torch([sampler_output])
+
+        return (
+            target_token_ids,
+            target_probs,
+            non_spec_target_token_ids,
+            non_spec_target_probs,
+        )
+
+    def _create_target_seq_id_iterator(
+            self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
+        """Create an iterator for creating target sequence ids.
+        Target sequence ids are distinct from sequence ids because we create a
+        distinct target sequence id for each proposal token to be scored.
+
+        This implementation increments a counter starting at 1 + max of all
+        provided input sequence ids.
+        """
+        return count(start=max(seq_ids) + 1)
+
+    def _get_token_ids_to_score(
+            self,
+            full_spec_token_ids: List[TokenId],  # shape: [k]
+    ) -> List[List[TokenId]]:
+        """Given the list of proposal token ids for a sequence, return the
+        lists of token ids that should be scored.
+
+        Returns k+1 output lists. The additional one is used for generating the
+        bonus token.
+
+        Example:
+            Input: [0, 1, 2, 3] (k=4)
+            Output: (k+1 lists)
+                []
+                [0]
+                [0, 1]
+                [0, 1, 2]
+                [0, 1, 2, 3]
+        """
+        empty_token_ids = []
+
+        token_ids_to_score = [empty_token_ids]
+        token_ids_to_score.extend([
+            full_spec_token_ids[:i + 1]
+            for i in range(len(full_spec_token_ids))
+        ])
+        return token_ids_to_score
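
A minimal standalone sketch of the expansion above, independent of the Aphrodite types: given K proposal tokens for one sequence, the scorer sees K+1 continuations (the empty prefix plus every proposal prefix), which is what makes the bonus token possible. The helper name expand_for_scoring is hypothetical and only for illustration.

    # Sketch: expand K proposal tokens into the K+1 prefixes that get scored.
    # `expand_for_scoring` is a hypothetical helper, not part of this diff.
    from typing import List

    def expand_for_scoring(output_token_ids: List[int],
                           proposal_token_ids: List[int]) -> List[List[int]]:
        continuations = []
        for i in range(len(proposal_token_ids) + 1):
            # Prefix of length i: empty prefix first, full proposal last.
            continuations.append(output_token_ids + proposal_token_ids[:i])
        return continuations

    # K=3 proposals for a sequence whose decoded tokens so far are [11, 12]:
    print(expand_for_scoring([11, 12], [7, 8, 9]))
    # [[11, 12], [11, 12, 7], [11, 12, 7, 8], [11, 12, 7, 8, 9]]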

+ 77 - 0
aphrodite/spec_decode/interfaces.py

@@ -0,0 +1,77 @@
+from typing import List, Tuple, Optional, Dict
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+
+import torch
+
+from aphrodite.common.sequence import SequenceGroupMetadata
+
+
+@dataclass
+class SpeculativeProposals:
+    """Data structure used to represent proposal tokens from some proposer. It
+    also tracks how many speculative tokens each sequence has.
+    """
+
+    # Speculative proposal tokens.
+    proposal_token_ids: torch.Tensor
+
+    # Probabilities of the proposal tokens according to the proposer.
+    proposal_probs: torch.Tensor
+
+    # The valid length of each proposal; can be zero.
+    proposal_lens: torch.Tensor
+
+    def __repr__(self):
+        return (f"SpeculativeProposals("
+                f"proposal_token_ids={self.proposal_token_ids.shape}, "
+                f"proposal_probs={self.proposal_probs.shape}, "
+                f"proposal_lens={self.proposal_lens.shape})")
+
+
+@dataclass
+class SpeculativeScores:
+    """Data structure used to represent the scores of speculative tokens
+    according to the scoring model.
+    """
+
+    # Probabilities of the speculative tokens according to the scoring model.
+    probs: torch.Tensor
+
+    # Token ids sampled from the scoring model. Used for speculative bonus
+    # tokens and also non-speculative normal decoding.
+    token_ids: torch.Tensor
+
+    def __repr__(self):
+        return (f"SpeculativeScores("
+                f"probs={self.probs.shape}, "
+                f"token_ids={self.token_ids.shape})")
+
+
+class SpeculativeProposer(ABC):
+
+    @abstractmethod
+    def get_proposals(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+        max_proposal_len: int,
+    ) -> SpeculativeProposals:
+        raise NotImplementedError
+
+
+class SpeculativeScorer(ABC):
+
+    @abstractmethod
+    def score_proposals(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Optional[Dict[int, int]],
+        blocks_to_swap_out: Optional[Dict[int, int]],
+        blocks_to_copy: Optional[Dict[int, List[int]]],
+        k: int,
+        proposals: SpeculativeProposals,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError

+ 175 - 0
aphrodite/spec_decode/metrics.py

@@ -0,0 +1,175 @@
+import torch
+from dataclasses import dataclass
+from typing import Optional
+import time
+from typing import Callable
+
+from aphrodite.modeling.layers.rejection import RejectionSampler
+from aphrodite.common.utils import in_wsl
+
+
+@dataclass
+class SpecDecodeWorkerMetrics:
+    """Dataclass holding metrics emitted from the spec decode worker."""
+
+    # The empirical acceptance rate of the proposal method on a per-token basis.
+    # This is useful for evaluating how well the proposal method aligns with the
+    # scoring method.
+    draft_acceptance_rate: float
+
+    # The empirical efficiency, measured as the number of tokens emitted by the
+    # system divided by the number of tokens that could be emitted by the system
+    # if the proposal method were perfect.
+    system_efficiency: float
+
+    # The number of speculative tokens produced by the proposal method.
+    draft_tokens: int
+
+    # The number of tokens emitted by the entire system.
+    emitted_tokens: int
+
+    # The number of tokens accepted by the scoring model and verification
+    # routine, e.g. Llama2-70B and lossless rejection sampling.
+    #
+    # NOTE: Any token accepted by the verification routine is considered
+    # accepted (regardless of if the speculative prefix is also accepted). The
+    # user will usually see fewer accepted tokens. This metric is helpful when
+    # evaluating alignment of the proposal method with the scoring model.
+    accepted_tokens: int
+
+    # The number of speculative tokens per sequence.
+    num_spec_tokens: int
+
+
+Timer = Callable[[], float]
+
+
+class AsyncMetricsCollector:
+    """Class which copies rejection sampler metrics from the device to CPU on a
+    non-default Torch stream.
+    """
+
+    def __init__(
+        self,
+        rejection_sampler: RejectionSampler,
+        timer: Optional[Timer] = None,
+        collect_interval_s: float = 5.0,
+    ):
+        self._rejection_sampler = rejection_sampler
+        self._timer = time.time if timer is None else timer
+
+        self._rank: Optional[int] = None
+
+        # We don't have a device set yet.
+        self._copy_stream: Optional[torch.cuda.Stream] = None
+
+        self._in_flight_copy: Optional[torch.cuda.Event] = None
+
+        pin_memory = not in_wsl()
+        self._aggregate_num_accepted_tokens = torch.tensor(
+            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
+        self._aggregate_num_emitted_tokens = torch.tensor(
+            0, dtype=torch.long, device="cpu", pin_memory=pin_memory)
+        self._aggregate_num_draft_tokens = 0
+
+        self._rejsample_metrics_collect_interval_s = collect_interval_s
+        self._last_metrics_collect_time = self._timer()
+
+    def init_gpu_tensors(self, rank: int) -> None:
+        self._rank = rank
+        self._copy_stream = torch.cuda.Stream()
+
+    def maybe_collect_rejsample_metrics(
+            self, k: int) -> Optional[SpecDecodeWorkerMetrics]:
+        # If a copy was initiated in the previous call, collect and return.
+        if self._in_flight_copy is not None:
+            ready_event = self._in_flight_copy
+            self._in_flight_copy = None
+            return self._collect_rejsample_metrics(k, ready_event)
+
+        # Otherwise, check if we should start a new copy.
+        if self._should_collect_rejsample_metrics(self._timer()):
+            assert self._in_flight_copy is None
+            self._in_flight_copy = self._copy_rejsample_metrics_async()
+
+        return None
+
+    def _should_collect_rejsample_metrics(self, now: float) -> bool:
+        """Return whether or not this iteration should collect rejection
+        sampling metrics.
+        """
+        if self._rank != 0:
+            return False
+
+        if (now - self._last_metrics_collect_time <
+                self._rejsample_metrics_collect_interval_s):
+            return False
+        return True
+
+    def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
+        """Copy rejection sampling metrics (number of accepted tokens, etc) to
+        CPU asynchronously.
+
+        Returns a CUDA event recording when the copy is complete.
+        """
+        self._copy_stream.wait_stream(torch.cuda.current_stream())
+
+        with torch.cuda.stream(self._copy_stream):
+            self._aggregate_num_accepted_tokens.copy_(
+                self._rejection_sampler.num_accepted_tokens, non_blocking=True)
+            self._aggregate_num_emitted_tokens.copy_(
+                self._rejection_sampler.num_emitted_tokens, non_blocking=True)
+            # Number of draft tokens is calculated on CPU, so no copy is
+            # required.
+            self._aggregate_num_draft_tokens = (
+                self._rejection_sampler.num_draft_tokens)
+
+        aggregate_metrics_ready = torch.cuda.Event()
+        aggregate_metrics_ready.record(self._copy_stream)
+
+        return aggregate_metrics_ready
+
+    def _collect_rejsample_metrics(
+            self, k: int,
+            ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics:
+        """Create metrics object from statistics copied asynchronously.
+
+        Args:
+            k: int. The number of speculative tokens; used to determine system
+                efficiency.
+            ready_event: torch.cuda.Event. The CUDA event recording when the
+                async GPU->CPU copy is complete.
+        """
+
+        ready_event.synchronize()
+        accepted_tokens = self._aggregate_num_accepted_tokens.item()
+        emitted_tokens = self._aggregate_num_emitted_tokens.item()
+        draft_tokens = self._aggregate_num_draft_tokens
+
+        num_possible_tokens = self.get_max_num_accepted_tokens(draft_tokens, k)
+
+        if draft_tokens > 0:
+            draft_acceptance_rate = accepted_tokens / draft_tokens
+        else:
+            draft_acceptance_rate = float("nan")
+
+        if num_possible_tokens > 0:
+            system_efficiency = emitted_tokens / num_possible_tokens
+        else:
+            system_efficiency = float("nan")
+
+        return SpecDecodeWorkerMetrics(
+            num_spec_tokens=k,
+            draft_acceptance_rate=draft_acceptance_rate,
+            system_efficiency=system_efficiency,
+            accepted_tokens=accepted_tokens,
+            draft_tokens=draft_tokens,
+            emitted_tokens=emitted_tokens,
+        )
+
+    @staticmethod
+    def get_max_num_accepted_tokens(draft_tokens: int, k: int) -> int:
+        # Divide by k since batch size can be variable.
+        total_num_spec_seqs = draft_tokens / k
+        num_accepted_per_seq_if_all_accepted = k + 1
+        return int(total_num_spec_seqs * num_accepted_per_seq_if_all_accepted)
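
A worked example of the two headline metrics, with illustrative numbers only; the arithmetic mirrors _collect_rejsample_metrics and get_max_num_accepted_tokens above.

    # k=4 proposals per sequence, 10 sequences => 40 draft tokens per step.
    k = 4
    draft_tokens = 40
    accepted_tokens = 30   # accepted by the rejection sampler
    emitted_tokens = 38    # accepted prefixes plus bonus/recovery tokens

    max_possible = int((draft_tokens / k) * (k + 1))        # 50
    draft_acceptance_rate = accepted_tokens / draft_tokens  # 0.75
    system_efficiency = emitted_tokens / max_possible       # 0.76

    print(max_possible, draft_acceptance_rate, system_efficiency)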

+ 392 - 0
aphrodite/spec_decode/multi_step_worker.py

@@ -0,0 +1,392 @@
+from typing import List, Dict, Optional, Tuple
+import copy
+
+import torch
+
+from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.task_handler.worker import Worker
+from aphrodite.spec_decode.interfaces import (
+    SpeculativeProposals,
+    SpeculativeProposer,
+)
+from aphrodite.spec_decode.util import sampler_output_to_torch
+
+
+class MultiStepWorker(Worker):
+    """The MultiStepWorker is equivalent to a Worker except that it allows
+    multiple forward passes in a single call, assuming the scheduler has
+    allocated enough space to store the additional KV. This reduces overhead
+    by invoking the scheduler less.
+
+    The MultiStepWorker does not support cache swap operations, or beam search.
+    Cache swap operations do not require large modifications. On the other hand,
+    beam search requires memory allocations during sequence forks and thus
+    requires more thought for MultiStepWorker support.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self._proposer: Optional[DraftModelTop1Proposer] = None
+
+    def init_model(self):
+        super().init_model()
+
+        self._proposer = DraftModelTop1Proposer(
+            self,
+            self.device,
+            self.max_model_len,
+            self.vocab_size,
+        )
+
+    @torch.inference_mode()
+    def execute_model_multi_step(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+        num_steps: int,
+    ) -> List[SamplerOutput]:
+        """Run the model forward pass num_steps times. Returns the list of
+        sampler output, one per model forward pass.
+        """
+        self._raise_if_unsupported(
+            seq_group_metadata_list,
+            blocks_to_swap_in,
+            blocks_to_swap_out,
+            blocks_to_copy,
+        )
+
+        # Shallow copy input data so modifications (such as appending tokens)
+        # do not cause side-effects.
+        copied_seq_group_metadata_list = self._shallow_copy_inputs(
+            seq_group_metadata_list)
+
+        # Assert enough KV space for num_steps tokens per sequence.
+        self._assert_enough_kv_space(seq_group_metadata_list, num_steps)
+
+        # Run model num_steps times.
+        model_outputs = []
+        for _ in range(num_steps):
+            model_output = super().execute_model(
+                seq_group_metadata_list=copied_seq_group_metadata_list,
+                blocks_to_swap_in=blocks_to_swap_in,
+                blocks_to_swap_out=blocks_to_swap_out,
+                blocks_to_copy=blocks_to_copy,
+            )
+
+            self._append_new_tokens(model_output,
+                                    copied_seq_group_metadata_list)
+            model_outputs.append(model_output)
+
+        return model_outputs
+
+    def get_spec_proposals(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+        max_proposal_len: int,
+    ) -> SpeculativeProposals:
+        """Produce speculations given an input batch of sequences. The number of
+        speculative tokens per sequence is determined by max_proposal_len.
+        """
+
+        return self._proposer.get_proposals(
+            seq_group_metadata_list,
+            blocks_to_swap_in,
+            blocks_to_swap_out,
+            blocks_to_copy,
+            max_proposal_len,
+        )
+
+    def _append_new_tokens(
+        self,
+        model_output: SamplerOutput,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+    ) -> None:
+        """Given model output from a single run, append the tokens to the
+        sequences. This is normally done outside of the worker, but it is
+        required if the worker is to perform multiple forward passes.
+        """
+        for seq_group_metadata, sequence_group_outputs in zip(
+                seq_group_metadata_list, model_output):
+            seq_group_metadata.is_prompt = False
+
+            for seq_output in sequence_group_outputs.samples:
+                # NOTE: Beam search is not supported, so we can assume that
+                # parent_seq_id == seq_id.
+                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
+
+                token_id = seq_output.output_token
+                token_logprob = seq_output.logprobs[token_id]
+
+                seq.append_token_id(token_id, token_logprob.logprob)
+
+    def _shallow_copy_inputs(
+        self, seq_group_metadata_list: List[SequenceGroupMetadata]
+    ) -> List[SequenceGroupMetadata]:
+        """Copy input data structures to remove side-effects when input data
+        structures are shared with other modules.
+
+        Helpful when the Aphrodite scheduler runs in the same process as the
+        worker. The alternative is deep-copying (or other form of deep copy);
+        this has performance downsides.
+        """
+
+        # Shallow-copy the list of SequenceGroupMetadata. This allows us to
+        # append tokens and change is_prompt without external side-effects.
+        new_seq_group_metadata_list = []
+
+        for old_seq_group_metadata in seq_group_metadata_list:
+            # We must shallow-copy seq_group_metadata as is_prompt could change.
+            seq_group_metadata = copy.copy(old_seq_group_metadata)
+            new_seq_group_metadata_list.append(seq_group_metadata)
+
+            # We must shallow-copy seq_data as we will append token ids
+            new_seq_data = {}
+            for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
+                new_seq_data[seq_id] = copy.copy(old_seq_data)
+                new_seq_data[
+                    seq_id].output_token_ids = old_seq_data.output_token_ids[:]
+
+            seq_group_metadata.seq_data = new_seq_data
+
+        return new_seq_group_metadata_list
+
+    def _assert_enough_kv_space(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        num_steps: int,
+    ) -> None:
+        """Assert there are enough physical blocks per sequence to store the
+        current KV plus additional KV from num_steps tokens.
+        """
+        assert self.model_runner.block_size is not None
+        for seq_group_metadata in seq_group_metadata_list:
+            # Only one seq_id is guaranteed because there is no beam search.
+            seq_id = list(seq_group_metadata.seq_data.keys())[0]
+            seq = seq_group_metadata.seq_data[seq_id]
+
+            # After num_steps, the seq len will be the current seq len
+            # plus one token per step.
+            final_seq_len = seq.get_len() + num_steps
+
+            # We will have final_seq_len - 1 KV because Aphrodite saves KV for a
+            # token in the iteration after the token was generated.
+            required_num_kv_slots = final_seq_len - 1
+
+            # The allocated number of kv slots is the number of allocated blocks
+            # times the number of slots per block.
+            number_physical_blocks = len(
+                seq_group_metadata.block_tables[seq_id])
+            allocated_kv_slots = (number_physical_blocks *
+                                  self.model_runner.block_size)
+
+            if required_num_kv_slots > allocated_kv_slots:
+                request_id = seq_group_metadata.request_id
+                raise ValueError(
+                    "The worker attempted to run "
+                    f"{num_steps} steps but found insufficient KV space for "
+                    f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
+                    f"{required_num_kv_slots=}).")
+
+    def _raise_if_unsupported(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> None:
+        """MultiStepWorker does not yet implement support for cache swap
+        operations or beam search.
+        """
+        if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
+            raise NotImplementedError(
+                "MultiStepWorker does not support cache operations")
+
+        if any(
+                len(seq_group_metadata.seq_data.keys()) != 1
+                for seq_group_metadata in seq_group_metadata_list):
+            raise NotImplementedError(
+                "MultiStepWorker does not support beam search.")
+
+
+class DraftModelTop1Proposer(SpeculativeProposer):
+    """Helper class which separates out sequences which would exceed the max
+    model length when speculated upon.
+
+    This allows combinations of models such as JackFram/llama-68m draft with
+    meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of
+    2048 while Llama2-13b has max_position_embeddings of 4096.
+
+    We treat the sequences which exceed the proposal draft model length as
+    "non-spec sequences". Essentially they skip the draft model and go through
+    normal decoding in the target model.
+
+    Currently, only proposal_lens of 0 and k are supported, where k is a global
+    batch proposal length. In the future Aphrodite should support per-sequence
+    proposal lengths.
+    """
+
+    def __init__(
+        self,
+        draft_worker: MultiStepWorker,
+        device: str,
+        max_model_len: int,
+        vocab_size: int,
+    ):
+        self._draft_worker = draft_worker
+        self._device = device
+        self._max_model_len = max_model_len
+        self._vocab_size = vocab_size
+
+    def get_proposals(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+        max_proposal_len: int,
+    ) -> SpeculativeProposals:
+        """Get speculative proposals given the input batch.
+
+        Sequences which would exceed the max model length are skipped during
+        speculation.
+        """
+
+        # Split speculative- and non-speculative- sequences.
+        (
+            proposal_lens,
+            nonzero_proposal_len_seqs,
+            nonzero_proposal_len_indices,
+        ) = self._split_by_max_model_len(seq_group_metadata_list,
+                                         max_proposal_len)
+
+        if nonzero_proposal_len_seqs:
+            # Speculate tokens using the draft worker for the speculative
+            # sequences.
+            maybe_sampler_output = self._draft_worker.execute_model_multi_step(
+                seq_group_metadata_list=nonzero_proposal_len_seqs,
+                blocks_to_swap_in=blocks_to_swap_in,
+                blocks_to_swap_out=blocks_to_swap_out,
+                blocks_to_copy=blocks_to_copy,
+                num_steps=max_proposal_len,
+            )
+        else:
+            # If no sequences can be speculated, set sampler output to None.
+            maybe_sampler_output = None
+
+        # Combine speculative- and non-speculative sequences into the same
+        # representation.
+        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
+            batch_size=len(seq_group_metadata_list),
+            max_proposal_len=max_proposal_len,
+            maybe_sampler_output=maybe_sampler_output,
+            proposal_lens=proposal_lens,
+            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
+        )
+
+        proposals = SpeculativeProposals(
+            proposal_token_ids=proposal_tokens,
+            proposal_probs=proposal_probs,
+            proposal_lens=proposal_lens,
+        )
+
+        return proposals
+
+    def _split_by_max_model_len(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        max_proposal_len: int,
+    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
+        """Determine which sequences would exceed the max model length."""
+
+        proposal_lens: List[int] = []
+        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
+        nonzero_proposal_len_indices: List[int] = []
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+            seq_data = next(iter(seq_group_metadata.seq_data.values()))
+            seq_len = seq_data.get_len()
+
+            # Currently only proposal lens of 0 or the global batch proposal len
+            # are supported.
+            if seq_len + max_proposal_len < self._max_model_len:
+                proposal_lens.append(max_proposal_len)
+                nonzero_proposal_len_seqs.append(seq_group_metadata)
+                nonzero_proposal_len_indices.append(i)
+            else:
+                proposal_lens.append(0)
+
+        return (
+            proposal_lens,
+            nonzero_proposal_len_seqs,
+            nonzero_proposal_len_indices,
+        )
+
+    def _merge_outputs(
+        self,
+        batch_size: int,
+        max_proposal_len: int,
+        maybe_sampler_output: Optional[SamplerOutput],
+        proposal_lens: List[int],
+        nonzero_proposal_len_indices: List[int],
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """After speculations are produced, merge the speculation results with
+        the skipped sequences.
+        """
+        if maybe_sampler_output is None:
+            # If no speculative tokens, the sampler output will be None.
+            # In this case we return empty tensors.
+            proposal_tokens = torch.zeros(0,
+                                          max_proposal_len,
+                                          dtype=torch.long,
+                                          device=self._device)
+            proposal_probs = torch.zeros(
+                0,
+                max_proposal_len,
+                self._vocab_size,
+                dtype=torch.float32,
+                device=self._device,
+            )
+            proposal_lens = torch.zeros(len(proposal_lens),
+                                        dtype=torch.long,
+                                        device=self._device)
+            return proposal_tokens, proposal_probs, proposal_lens
+
+        sampler_output = maybe_sampler_output
+
+        proposal_tokens, proposal_probs = sampler_output_to_torch(
+            sampler_output)
+
+        # Now, reformat the output GPU tensors such that each sequence has
+        # a proposal. The proposal can be empty, e.g. [-1, -1, -1].
+
+        entire_proposal_tokens = torch.full(
+            size=(batch_size, *proposal_tokens.shape[1:]),
+            fill_value=-1,
+            dtype=torch.long,
+            device=self._device,
+        )
+        entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens
+        entire_proposal_probs = torch.zeros(
+            batch_size,
+            *proposal_probs.shape[1:],
+            dtype=torch.float32,
+            device=self._device,
+        )
+        entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs
+
+        proposal_tokens, proposal_probs = (
+            entire_proposal_tokens,
+            entire_proposal_probs,
+        )
+
+        proposal_lens = torch.zeros(batch_size,
+                                    dtype=torch.long,
+                                    device=self._device)
+        proposal_lens[nonzero_proposal_len_indices] = max_proposal_len
+
+        return proposal_tokens, proposal_probs, proposal_lens
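
The split performed by _split_by_max_model_len reduces to a single length check per sequence; a standalone sketch with made-up numbers (a draft model with a 2048-token context and k=5 proposals per step):

    # Sequences whose current length plus k would reach the draft model's
    # max length get a proposal length of 0 and skip speculation entirely.
    max_model_len = 2048
    k = 5
    seq_lens = [100, 2040, 2043, 512]

    proposal_lens = [k if seq_len + k < max_model_len else 0
                     for seq_len in seq_lens]
    nonzero_indices = [i for i, plen in enumerate(proposal_lens) if plen > 0]

    print(proposal_lens)    # [5, 5, 0, 5]
    print(nonzero_indices)  # [0, 1, 3]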

+ 394 - 0
aphrodite/spec_decode/spec_decode_worker.py

@@ -0,0 +1,394 @@
+from typing import List, Tuple, Optional, Dict
+from functools import cached_property
+
+import torch
+
+from aphrodite.spec_decode.metrics import AsyncMetricsCollector
+from aphrodite.common.sequence import (
+    SamplerOutput,
+    SequenceGroupMetadata,
+    SequenceGroupOutput,
+    SequenceOutput,
+)
+from aphrodite.task_handler.worker import Worker
+from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
+from aphrodite.modeling.layers.rejection import RejectionSampler
+from aphrodite.common.config import CacheConfig
+from aphrodite.spec_decode.util import (
+    nvtx_range,
+    get_all_seq_ids,
+    split_batch_by_proposal_len,
+)
+from aphrodite.spec_decode.interfaces import (
+    SpeculativeProposals,
+    SpeculativeScores,
+)
+from aphrodite.spec_decode.batch_expansion import BatchExpansionTop1Scorer
+from aphrodite.spec_decode.interfaces import SpeculativeScorer
+
+
+class SpecDecodeWorker:
+    """Worker which implements speculative decoding.
+
+    Speculative decoding reduces decoding per-token latency by using a proposal
+    method, such as a small draft model, to speculate ahead of a larger LLM. The
+    probabilities of the speculative tokens are then determined by the larger
+    LLM, after which some verification routine determines which (if any) of the
+    speculative tokens are accepted by the larger LLM.
+
+    The current implementation has the following limitations:
+    * Only draft-model proposal is implemented (contributions for more forms are
+        welcome!).
+    * Only top-1 proposal and scoring are implemented. Tree-attention is left as
+        future work.
+    * Only lossless rejection sampling is supported. Contributions adding lossy
+        verification routines are welcome (e.g. Medusa's typical acceptance).
+    * All sequences in a batch must have the same proposal length, or zero. This
+        can be improved by having per-sequence speculation in the future.
+    * The scoring forward pass is done without an MQA kernel, which is
+        suboptimal especially as the batch size, proposal length, and sequence
+        lengths grow. Contributions to add MQA scoring are welcome once
+        correctness tests pass.
+    """
+
+    def __init__(
+        self,
+        proposer_worker: MultiStepWorker,
+        scorer_worker: Worker,
+        rejection_sampler: RejectionSampler,
+        metrics_collector: Optional[AsyncMetricsCollector] = None,
+    ):
+        """
+        Create a SpecDecodeWorker.
+
+        Args:
+            proposer_worker: A worker that can produce speculative tokens for
+                sequences.
+            scorer_worker: A worker that produces probabilities of speculative
+                tokens according to some base model. Typically a vanilla
+                Aphrodite Worker.
+            rejection_sampler: A Torch module used to perform modified rejection
+                sampling for speculative decoding.
+            metrics_collector: Helper class for collecting metrics; can be set
+                for testing purposes.
+        """
+        self.proposer_worker = proposer_worker
+        self.scorer_worker = scorer_worker
+        self.rejection_sampler = rejection_sampler
+
+        self._metrics = (AsyncMetricsCollector(rejection_sampler)
+                         if metrics_collector is None else metrics_collector)
+
+        self.probs_dtype = self.rejection_sampler.probs_dtype
+        self.token_id_dtype = self.rejection_sampler.token_id_dtype
+
+        self.scorer: SpeculativeScorer = None
+
+    def init_model(self) -> None:
+        """Initialize both scorer and proposer models."""
+        # The scorer worker model is initialized first in case the proposer
+        # model has a smaller TP degree than the target worker.
+        self.scorer_worker.init_model()
+        self.proposer_worker.init_model()
+
+        self._metrics.init_gpu_tensors(self.rank)
+        self.rejection_sampler.init_gpu_tensors(self.rank)
+        self.scorer = BatchExpansionTop1Scorer(
+            scorer_worker=self.scorer_worker,
+            device=self.device,
+            vocab_size=self._vocab_size,
+        )
+
+    def profile_num_available_blocks(
+        self,
+        block_size: int,
+        gpu_memory_utilization: float,
+        cpu_swap_space: int,
+        cache_dtype: str,
+    ) -> Tuple[int, int]:
+        """Determine the number of cache blocks to use.
+
+        This is done by profiling the scorer model (which is typically the
+        larger of the two). Then the total memory which would be used by the
+        scorer cache is divided evenly between the proposer and scorer model KV,
+        such that the number of blocks is equal in both KV caches.
+        """
+        (
+            num_gpu_blocks,
+            num_cpu_blocks,
+        ) = self.scorer_worker.profile_num_available_blocks(
+            block_size, gpu_memory_utilization, cpu_swap_space, cache_dtype)
+
+        scorer_cache_block_size_bytes = (
+            self.scorer_worker.get_cache_block_size_bytes(
+                block_size, cache_dtype))
+        proposer_cache_block_size_bytes = (
+            self.proposer_worker.get_cache_block_size_bytes(
+                block_size, cache_dtype))
+
+        new_num_gpu_blocks = split_num_cache_blocks_evenly(
+            scorer_cache_block_size_bytes,
+            proposer_cache_block_size_bytes,
+            num_gpu_blocks,
+        )
+        return new_num_gpu_blocks, num_cpu_blocks
+
+    def init_cache_engine(self, cache_config: CacheConfig):
+        """Initialize the cache engine of the scorer and proposer workers."""
+        self.scorer_worker.init_cache_engine(cache_config)
+        self.proposer_worker.init_cache_engine(cache_config)
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Optional[Dict[int, int]],
+        blocks_to_swap_out: Optional[Dict[int, int]],
+        blocks_to_copy: Optional[Dict[int, List[int]]],
+        num_spec_tokens: int,
+    ) -> List[SamplerOutput]:
+        """Perform speculative decoding on the input batch."""
+
+        assert seq_group_metadata_list is not None, (
+            "speculative decoding "
+            "requires non-None seq_group_metadata_list")
+
+        # If no spec tokens, call the proposer and scorer workers normally.
+        # Used for prefill.
+        if num_spec_tokens == 0 or len(seq_group_metadata_list) == 0:
+            return self._run_no_spec(
+                seq_group_metadata_list=seq_group_metadata_list,
+                blocks_to_swap_in=blocks_to_swap_in,
+                blocks_to_swap_out=blocks_to_swap_out,
+                blocks_to_copy=blocks_to_copy,
+            )
+
+        return self._run_speculative_decoding_step(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+            k=num_spec_tokens,
+        )
+
+    @nvtx_range("spec_decode_worker._run_no_spec")
+    def _run_no_spec(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Optional[Dict[int, int]],
+        blocks_to_swap_out: Optional[Dict[int, int]],
+        blocks_to_copy: Optional[Dict[int, List[int]]],
+    ) -> List[SamplerOutput]:
+        """Run a prefill step, without any speculation. The input is sent to the
+        proposer and scorer model so that the KV cache is consistent between the
+        two.
+        """
+
+        self.proposer_worker.execute_model(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+            return_python_output=False,
+        )
+
+        sampler_output = self.scorer_worker.execute_model(
+            seq_group_metadata_list=seq_group_metadata_list,
+            blocks_to_swap_in=blocks_to_swap_in,
+            blocks_to_swap_out=blocks_to_swap_out,
+            blocks_to_copy=blocks_to_copy,
+        )
+
+        # Clear device tensors from sampler output. This reduces communication
+        # overhead when the engine runs in a different process than the workers.
+        sampler_output.probs = None
+        sampler_output.sampled_tokens = None
+        return [sampler_output]
+
+    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
+    def _run_speculative_decoding_step(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        blocks_to_swap_in: Optional[Dict[int, int]],
+        blocks_to_swap_out: Optional[Dict[int, int]],
+        blocks_to_copy: Optional[Dict[int, List[int]]],
+        k: int,
+    ) -> List[SamplerOutput]:
+        """Execute a single step of speculative decoding.
+
+        This invokes the proposer worker to get k speculative tokens for each
+        sequence, then scores each speculative token using the scoring worker.
+
+        Returns a list of SamplerOutput, each containing a single token per
+        sequence.
+        """
+
+        # Generate proposals using draft worker.
+        proposals = self.proposer_worker.get_spec_proposals(
+            seq_group_metadata_list,
+            blocks_to_swap_in,
+            blocks_to_swap_out,
+            blocks_to_copy,
+            k,
+        )
+
+        proposal_scores = self.scorer.score_proposals(
+            seq_group_metadata_list,
+            blocks_to_swap_in,
+            blocks_to_swap_out,
+            blocks_to_copy,
+            k,
+            proposals,
+        )
+
+        accepted_token_ids = self._verify_tokens(seq_group_metadata_list,
+                                                 proposal_scores, proposals, k)
+
+        return self._create_output_sampler_list(seq_group_metadata_list,
+                                                accepted_token_ids, k)
+
+    @nvtx_range("spec_decode_worker._verify_tokens")
+    def _verify_tokens(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        proposal_scores: SpeculativeScores,
+        proposals: SpeculativeProposals,
+        max_proposal_len: int,
+    ) -> torch.Tensor:
+        """Determine which speculative tokens are accepted using the
+        probabilities of each token according to the proposer and scorer models.
+        """
+        proposal_lens_list = proposals.proposal_lens.tolist()
+
+        # Aphrodite currently only supports proposal lens equal to zero or the
+        # batch proposal len. This adds some complexity (splitting the batch
+        # into spec and non spec sequences) and should be removed in the
+        # future. It can be done by supporting per-sequence proposal lens.
+        _, spec_indices = split_batch_by_proposal_len(
+            seq_group_metadata_list,
+            proposal_lens_list,
+            select_proposal_len_zero=False,
+        )
+        _, non_spec_indices = split_batch_by_proposal_len(
+            seq_group_metadata_list,
+            proposal_lens_list,
+            select_proposal_len_zero=True,
+        )
+        original_indices = spec_indices + non_spec_indices
+
+        proposal_probs = proposal_scores.probs[spec_indices, :-1]
+        bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:]
+        non_spec_token_ids = proposal_scores.token_ids[non_spec_indices]
+
+        accepted_token_ids = self.rejection_sampler(
+            proposal_probs,
+            bonus_token_ids,
+            proposals.proposal_probs,
+            proposals.proposal_token_ids,
+        )
+
+        # Append output tokens from non-speculative sequences to
+        # the accepted token ids tensor.
+        non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len +
+                                                       1).clone()
+        non_spec_token_ids[:, 1:] = -1
+        accepted_token_ids = torch.cat(
+            [accepted_token_ids, non_spec_token_ids])
+
+        # Rearrange so that results are in the order of the original seq group
+        # metadata.
+        accepted_token_ids[original_indices] = accepted_token_ids.clone()
+
+        return accepted_token_ids
+
+    def _create_output_sampler_list(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        accepted_token_ids: torch.Tensor,  # shape: [batch_size, k+1]
+        k: int,
+    ) -> List[SamplerOutput]:
+        """Given the accepted token ids, create a list of SamplerOutput.
+
+        The output is padded with -1 tokens such that each sequence has
+        the same number of outputs.
+        """
+        seq_ids = get_all_seq_ids(seq_group_metadata_list)
+
+        # shape: [k+1, batch_size]
+        accepted_token_ids_by_step = accepted_token_ids.transpose(0,
+                                                                  1).tolist()
+        sampler_output_list = []
+        for token_ids_by_step in accepted_token_ids_by_step:
+            if all(token_id == -1 for token_id in token_ids_by_step):
+                break
+
+            step_output_token_ids = []
+            for token_id, seq_id in zip(token_ids_by_step, seq_ids):
+                step_output_token_ids.append(
+                    SequenceGroupOutput(
+                        samples=[
+                            SequenceOutput(
+                                parent_seq_id=seq_id,
+                                output_token=token_id,
+                                # TODO Add verifier logprobs.
+                                logprobs={token_id: 0.0},
+                                persistent_data={},
+                            )
+                        ],
+                        prompt_logprobs=None,
+                    ))
+            sampler_output_list.append(
+                SamplerOutput(outputs=step_output_token_ids))
+
+        maybe_rejsample_metrics = self._metrics.maybe_collect_rejsample_metrics(
+            k)
+        if maybe_rejsample_metrics is not None:
+            sampler_output_list[
+                0].spec_decode_worker_metrics = maybe_rejsample_metrics
+
+        return sampler_output_list
+
+    @cached_property
+    def _vocab_size(self) -> int:
+        """Get the vocab size of the model and make sure it's consistent between
+        draft and target workers.
+        """
+        vocab_sizes = [
+            worker.vocab_size
+            for worker in [self.proposer_worker, self.scorer_worker]
+        ]
+        assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes)
+        return vocab_sizes[0]
+
+    @property
+    def rank(self):
+        return self.scorer_worker.rank
+
+    @property
+    def device(self):
+        return self.scorer_worker.device
+
+
+def split_num_cache_blocks_evenly(
+    scorer_cache_block_size_bytes: int,
+    proposer_cache_block_size_bytes: int,
+    total_num_gpu_blocks: int,
+) -> int:
+    """Given total_num_gpu_blocks, the number of GPU blocks that could be
+    allocated to the target model, this function calculates how many blocks
+    should be given to the draft and target model.
+
+    Note that usually the block size, in bytes, of each model is different,
+    as it's a function of number of KV/layer, number of heads, and hidden
+    dimension size.
+
+    Since the target and draft models allocate the same number of blocks, we
+    simply calculate the largest block count such that, if allocated by both
+    models, the total KV cache memory usage is no larger than the memory
+    available to the target model alone.
+    """
+    new_num_gpu_blocks = int(
+        total_num_gpu_blocks * scorer_cache_block_size_bytes /
+        (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))
+
+    return new_num_gpu_blocks
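
A quick numerical check of split_num_cache_blocks_evenly, with illustrative sizes: if a scorer cache block is 2 MiB, a proposer cache block is 0.5 MiB, and 1000 blocks fit when only the scorer allocates, both models get 800 blocks, and 800 * 2.5 MiB equals the 2000 MiB the scorer alone would have used.

    # Illustrative numbers; mirrors split_num_cache_blocks_evenly above.
    scorer_block_bytes = 2 * 1024 * 1024    # 2 MiB per scorer cache block
    proposer_block_bytes = 512 * 1024       # 0.5 MiB per proposer cache block
    total_num_gpu_blocks = 1000             # blocks if only the scorer allocated

    new_num_gpu_blocks = int(
        total_num_gpu_blocks * scorer_block_bytes /
        (proposer_block_bytes + scorer_block_bytes))

    print(new_num_gpu_blocks)  # 800
    # Both caches together use no more memory than the scorer alone would.
    assert new_num_gpu_blocks * (scorer_block_bytes + proposer_block_bytes) \
        <= total_num_gpu_blocks * scorer_block_bytes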

+ 101 - 0
aphrodite/spec_decode/util.py

@@ -0,0 +1,101 @@
+import torch
+from typing import List, Tuple
+from contextlib import contextmanager
+from itertools import chain
+
+from aphrodite.common.sequence import SequenceGroupMetadata, SamplerOutput
+
+SeqId = int
+
+
+def get_all_seq_ids(
+        seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[SeqId]:
+    """Given a list of SequenceGroupMetadata, create a list of all
+    sequence ids.
+    """
+    return list(
+        chain.from_iterable([
+            seq_group_metadata.seq_data.keys()
+            for seq_group_metadata in seq_group_metadata_list
+        ]))
+
+
+def split_batch_by_proposal_len(
+    seq_group_metadata_list: List[SequenceGroupMetadata],
+    proposal_lens: List[int],
+    select_proposal_len_zero: bool,
+) -> Tuple[List[SequenceGroupMetadata], List[int]]:
+    """Utility function that splits a batch based on whether the proposal len is
+    zero or not. We should remove this once Aphrodite supports per-sequence
+    proposal lens in a batch.
+    """
+
+    if select_proposal_len_zero:
+        predicate = lambda proposal_len: proposal_len == 0
+    else:
+        predicate = lambda proposal_len: proposal_len != 0
+
+    indices = [
+        i for i, (_, proposal_len
+                  ) in enumerate(zip(seq_group_metadata_list, proposal_lens))
+        if predicate(proposal_len)
+    ]
+    seq_groups = [
+        seq_group for seq_group, proposal_len in zip(
+            seq_group_metadata_list, proposal_lens) if predicate(proposal_len)
+    ]
+
+    return seq_groups, indices
+
+
+def sampler_output_to_torch(
+    sampler_output_list: List[SamplerOutput],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Utility function which converts a list of SamplerOutput to tensors.
+
+    Returns:
+        sampled_token_ids: torch.Tensor
+            shape: [batch_size, len(sampler_output_list)]
+
+        sampled_token_probs: torch.Tensor
+            shape: [batch_size, len(sampler_output_list), vocab_size]
+    """
+
+    # shape: [batch_size, num_sampler_output, vocab_size]
+    sampled_token_probs = torch.stack(
+        [
+            sampler_output.sampled_token_probs
+            for sampler_output in sampler_output_list
+        ],
+        dim=0,
+    ).transpose(0, 1)
+
+    # shape: [batch_size, num_sampler_output]
+    sampled_token_ids = torch.stack(
+        [
+            sampler_output.sampled_token_ids.flatten()
+            for sampler_output in sampler_output_list
+        ],
+        dim=0,
+    ).transpose(0, 1)
+
+    return sampled_token_ids, sampled_token_probs
+
+
+@contextmanager
+def nvtx_range(msg, *args, **kwargs):
+    """
+    Context manager / decorator that pushes an NVTX range at the beginning
+    of its scope, and pops it at the end. If extra arguments are given,
+    they are passed as arguments to msg.format().
+
+    If running with cuda graphs, you must enable nsys cuda graph profiling.
+
+    Arguments:
+        msg (string): message to associate with the range
+    """
+    torch.cuda.nvtx.range_push(msg.format(*args, **kwargs))
+    try:
+        yield
+    finally:
+        torch.cuda.nvtx.range_pop()
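
Typical usage of nvtx_range as a context manager; this assumes a CUDA-enabled PyTorch build so that torch.cuda.nvtx is functional and the range shows up in an Nsight Systems (nsys) trace.

    import torch
    from aphrodite.spec_decode.util import nvtx_range

    with nvtx_range("decode iteration {}", 3):
        x = torch.ones(1024, device="cuda")
        y = x * 2
        torch.cuda.synchronize()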

+ 9 - 2
aphrodite/task_handler/cache_engine.py

@@ -4,9 +4,8 @@ from typing import Dict, List, Tuple
 import torch
 from loguru import logger

-from aphrodite._C import cache_ops
 from aphrodite.common.config import CacheConfig, ModelConfig, ParallelConfig
-from aphrodite.common.utils import in_wsl, STR_DTYPE_TO_TORCH_DTYPE
+from aphrodite.common.utils import in_wsl, is_neuron, STR_DTYPE_TO_TORCH_DTYPE

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -37,6 +36,10 @@ class CacheEngine:
         self.num_gpu_blocks = cache_config.num_gpu_blocks
         self.num_cpu_blocks = cache_config.num_cpu_blocks

+        # Skip initializing CUDA stream and buffer for Neuron backend.
+        if is_neuron():
+            return
+
         if cache_config.cache_dtype == "auto":
             self.dtype = model_config.dtype
         else:
@@ -119,6 +122,8 @@ class CacheEngine:
         dst: List[KVCache],
         src_to_dst: Dict[int, int],
     ) -> None:
+        from aphrodite._C import cache_ops
+
         with torch.cuda.stream(self.cache_stream):
             for i in range(self.num_layers):
                 src_key_cache, src_value_cache = src[i]
@@ -138,6 +143,8 @@ class CacheEngine:
         self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

     def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
+        from aphrodite._C import cache_ops
+
         key_caches = [key_cache for key_cache, _ in self.gpu_cache]
         value_caches = [value_cache for _, value_cache in self.gpu_cache]
         # NOTE: This operation implicitly synchronizes the CPU and GPU.

+ 74 - 56
aphrodite/task_handler/model_runner.py

@@ -53,7 +53,7 @@ class ModelRunner:
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
         kv_cache_dtype: Optional[str] = "auto",
-        kv_quant_params_path: Optional[str] = None,
+        # kv_quant_params_path: Optional[str] = None,
         is_driver_worker: bool = False,
     ):
         self.model_config = model_config
@@ -69,6 +69,7 @@ class ModelRunner:
         self.device_config = (device_config
                               if device_config is not None else DeviceConfig())
         self.device = self.device_config.device
+
         self.model = None
         self.block_size = None  # Set after initial profiling.
         self.lora_manager = None
@@ -89,37 +90,52 @@ class ModelRunner:
         # cache in_wsl result
         self.in_wsl = in_wsl()
         self.kv_cache_dtype = kv_cache_dtype
-        self.kv_quant_params = (self.load_kv_quant_params(
-            model_config, kv_quant_params_path)
-                                if self.kv_cache_dtype == "int8" else None)
-
-    def load_kv_quant_params(self, model_config: ModelConfig,
-                             kv_quant_params_path: str) -> List[List[float]]:
-        if model_config is None:
-            return None
-        # Remove it when all models support kv cache int8.
-        architectures = model_config.hf_config.architectures
-        for arch in architectures:
-            if arch not in ["LlamaForCausalLM", "LLaMAForCausalLM"]:
-                raise ValueError(
-                    "KV CACHE INT8 is not supported for model architectures "
-                    f"{arch} for now. "
-                    "Supported architectures: LlamaForCausalLM and "
-                    "LLaMAForCausalLM.")
-        num_layers = model_config.hf_config.num_hidden_layers
-        kv_quant_params = []
-        for i in range(num_layers):
-            if kv_quant_params_path is not None:
-                path = (kv_quant_params_path +
-                        f"/layers.{i}.past_kv_scale.0.weight")
-                kv_quant_param = list(np.fromfile(path, dtype=np.float32))
-            kv_quant_params.append(kv_quant_param)
-        return kv_quant_params
+        # self.kv_quant_params = (
+        #     self.load_kv_quant_params(model_config, kv_quant_params_path)
+        #     if self.kv_cache_dtype == "int8"
+        #     else None
+        # )
+
+        # Set enforce_eager to True for Neuron backend, to avoid capturing graph
+        if self.device_config.is_neuron:
+            self.model_config.enforce_eager = True
+
+    # def load_kv_quant_params(
+    #     self, model_config: ModelConfig, kv_quant_params_path: str
+    # ) -> List[List[float]]:
+    #     if model_config is None:
+    #         return None
+    #     # Remove it when all models support kv cache int8.
+    #     architectures = model_config.hf_config.architectures
+    #     for arch in architectures:
+    #         if arch not in ["LlamaForCausalLM", "LLaMAForCausalLM"]:
+    #             raise ValueError(
+    #                 "KV CACHE INT8 is not supported for model architectures "
+    #                 f"{arch} for now. "
+    #                 "Supported architectures: LlamaForCausalLM and "
+    #                 "LLaMAForCausalLM."
+    #             )
+    #     num_layers = model_config.hf_config.num_hidden_layers
+    #     kv_quant_params = []
+    #     for i in range(num_layers):
+    #         if kv_quant_params_path is not None:
+    #             path = (
+    #                 kv_quant_params_path + f"/layers.{i}.past_kv_scale.0.weight"  # noqa: E501
+    #             )
+    #             kv_quant_param = list(np.fromfile(path, dtype=np.float32))
+    #         kv_quant_params.append(kv_quant_param)
+    #     return kv_quant_params

     def load_model(self) -> None:
         with measure_cuda_memory() as m:
-            self.model = get_model(self.model_config, self.device_config,
-                                   self.lora_config)
+            self.model = get_model(
+                self.model_config,
+                self.device_config,
+                lora_config=self.lora_config,
+                parallel_config=self.parallel_config,
+                scheduler_config=self.scheduler_config,
+            )
+
         self.model_memory_usage = m.consumed_memory
         tp = get_tensor_model_parallel_world_size()
         logger.info(
@@ -127,8 +143,6 @@ class ModelRunner:
             f"{self.model_memory_usage / float(2**30):.2f} GiB x {tp} = "
             f"{self.model_memory_usage * tp / float(2**30):.2f} GiB")

-        vocab_size = self.model.config.vocab_size
-
         if self.lora_config:
             assert (hasattr(self.model, "supported_lora_modules")
                     and self.model.supported_lora_modules
@@ -142,7 +156,7 @@ class ModelRunner:
                 self.scheduler_config.max_num_seqs,
                 self.scheduler_config.max_num_batched_tokens +
                 self.scheduler_config.max_paddings,
-                vocab_size,
+                self.vocab_size,
                 self.lora_config,
                 self.device,
                 self.model.embedding_modules,
@@ -250,6 +264,7 @@ class ModelRunner:
                 slot_mapping[-1].append(slot)

         max_prompt_len = max(subquery_lens)
+        assert max_prompt_len > 0
         input_tokens = _make_tensor_with_pad(
             input_tokens,
             max_prompt_len,
@@ -309,7 +324,7 @@ class ModelRunner:
             block_tables=block_tables,
             use_cuda_graph=False,
             kv_cache_dtype=self.kv_cache_dtype,
-            kv_quant_params=self.kv_quant_params,
+            # kv_quant_params=self.kv_quant_params,
         )
         return (
             input_tokens,
@@ -449,7 +464,7 @@ class ModelRunner:
             block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
             kv_cache_dtype=self.kv_cache_dtype,
-            kv_quant_params=self.kv_quant_params,
+            # kv_quant_params=self.kv_quant_params,
         )
         return (
             input_tokens,
@@ -472,6 +487,7 @@ class ModelRunner:
         selected_token_start_idx = 0
         categorized_sample_indices = {t: [] for t in SamplingType}
         categorized_sample_indices_start_idx = 0
+        pin_memory = not self.in_wsl and not self.device_config.is_neuron

         max_subquery_len = max(subquery_lens) if subquery_lens else 1
         for i, seq_group_metadata in enumerate(seq_group_metadata_list):
@@ -501,8 +517,8 @@ class ModelRunner:
                 selected_token_indices.append(selected_token_start_idx +
                                               subquery_len - 1)
                 selected_token_start_idx += max_subquery_len
-                if (sampling_params.sampling_type == SamplingType.RANDOM_SEED):
-                    assert sampling_params.seed is not None
+
+                if sampling_params.seed is not None:
                     seq_group_metadata.state.generator = torch.Generator(
                         device="cuda").manual_seed(sampling_params.seed)
             else:
@@ -522,21 +538,21 @@ class ModelRunner:
                         ))
                 categorized_sample_indices_start_idx += num_seqs

-            if (seq_group_metadata.state.generator is not None):
+            if sampling_params.seed is not None:
                 generators.append(seq_group_metadata.state.generator)

         selected_token_indices = _async_h2d(
             selected_token_indices,
             dtype=torch.long,
             target_device=self.device,
-            pin_memory=not self.in_wsl,
+            pin_memory=pin_memory,
         )
         categorized_sample_indices = {
             t: _async_h2d(
                 seq_ids,
                 dtype=torch.int,
                 target_device=self.device,
-                pin_memory=not self.in_wsl,
+                pin_memory=pin_memory,
             )
             for t, seq_ids in categorized_sample_indices.items()
         }
@@ -621,9 +637,9 @@ class ModelRunner:
                 "block_tables": input_metadata.block_tables,
                 "block_tables": input_metadata.block_tables,
                 "use_cuda_graph": input_metadata.use_cuda_graph,
                 "use_cuda_graph": input_metadata.use_cuda_graph,
                 "kv_cache_dtype": input_metadata.kv_cache_dtype,
                 "kv_cache_dtype": input_metadata.kv_cache_dtype,
-                "kv_quant_params": input_metadata.kv_quant_params,
+                # "kv_quant_params": input_metadata.kv_quant_params,
                 "selected_token_indices":
                 "selected_token_indices":
-                sampling_metadata.selected_token_indices,  # noqa
+                sampling_metadata.selected_token_indices,
                 "lora_requests": lora_requests,
                 "lora_requests": lora_requests,
                 "lora_mapping": lora_mapping,
                 "lora_mapping": lora_mapping,
             }
             }
@@ -645,7 +661,7 @@ class ModelRunner:
                 block_tables=metadata_dict["block_tables"],
                 use_cuda_graph=metadata_dict["use_cuda_graph"],
                 kv_cache_dtype=metadata_dict["kv_cache_dtype"],
-                kv_quant_params=metadata_dict["kv_quant_params"],
+                # kv_quant_params=metadata_dict["kv_quant_params"],
             )
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
@@ -707,8 +723,7 @@ class ModelRunner:
     @torch.inference_mode()
     def profile_run(self) -> None:
         # Enable top-k sampling to reflect the accurate memory usage.
-        vocab_size = self.model_config.get_vocab_size()
-        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
+        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
         max_num_seqs = self.scheduler_config.max_num_seqs

@@ -789,8 +804,9 @@ class ModelRunner:
     @torch.inference_mode()
     def capture_model(self, kv_caches: List[KVCache]) -> None:
         # NOTE: This is a hack to ensure that the NCCL backend is never
-        # deleted before the CUDA graph
+        # deleted before the CUDA graphs.
         self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
+
         assert not self.model_config.enforce_eager
         logger.info("Capturing the model for CUDA graphs. This may lead to "
                     "unexpected consequences if the model is not static. To "
@@ -818,8 +834,6 @@ class ModelRunner:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]

-        # NOTE: Capturing the largest batch size first may help reduce the
-        # memory usage of CUDA graph.
         # NOTE: There are 3 backends for all-reduce: custom all-reduce
         # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
         # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
@@ -847,7 +861,7 @@ class ModelRunner:
                     block_tables=block_tables[:batch_size],
                     use_cuda_graph=True,
                     kv_cache_dtype=self.kv_cache_dtype,
-                    kv_quant_params=self.kv_quant_params,
+                    # kv_quant_params=self.kv_quant_params,
                 )

                 if self.lora_config:
@@ -882,6 +896,10 @@ class ModelRunner:
         self.graph_runners.clear()
         self.cupy_nccl_backend = None

+    @property
+    def vocab_size(self) -> int:
+        return self.model_config.get_vocab_size()
+

 class CUDAGraphRunner:

@@ -916,14 +934,14 @@ class CUDAGraphRunner:
         # NOTE: Python 3.8 does not support multi-line with statements.
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self.graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self.graph,
-                              pool=memory_pool), _maybe_cupy_nccl():
-            hidden_states = self.model(
-                input_ids,
-                positions,
-                kv_caches,
-                input_metadata,
-            )
+        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
+            with _maybe_cupy_nccl():
+                hidden_states = self.model(
+                    input_ids,
+                    positions,
+                    kv_caches,
+                    input_metadata,
+                )
         torch.cuda.synchronize()

         # Save the input and output buffers.
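
The CUDAGraphRunner hunk above only splits the `with` statement for Python 3.8 compatibility; the capture/replay pattern itself is standard PyTorch. A minimal sketch of that pattern, with a hypothetical `net` and `static_input` that are not part of this commit:

import torch

net = torch.nn.Linear(32, 32).cuda()
static_input = torch.randn(8, 32, device="cuda")

# Warm up on a side stream before capture, as PyTorch recommends.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output = net(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass into a CUDA graph.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = net(static_input)

# Replay: refill the captured input buffer in place, then launch the graph.
static_input.copy_(torch.randn(8, 32, device="cuda"))
graph.replay()
torch.cuda.synchronize()

Because replay reuses the captured buffers, callers only copy new inputs into those buffers and read results out of the saved output tensor, which is exactly why the runner saves its input and output buffers after capture.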

+ 204 - 0
aphrodite/task_handler/neuron_worker.py

@@ -0,0 +1,204 @@
+"""A Neuron worker class."""
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.distributed
+
+from aphrodite.common.config import (
+    CacheConfig,
+    DeviceConfig,
+    ModelConfig,
+    ParallelConfig,
+    SchedulerConfig,
+    LoRAConfig,
+)
+from aphrodite.modeling import set_random_seed
+from aphrodite.modeling.megatron.communication_op import broadcast_tensor_dict
+from aphrodite.modeling.megatron.parallel_state import (
+    ensure_model_parallel_initialized, )
+from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
+from aphrodite.task_handler.cache_engine import CacheEngine
+from aphrodite.task_handler.model_runner import ModelRunner
+
+
+class Worker:
+    """A worker class that executes the model on a group of neuron cores."""
+
+    def __init__(
+        self,
+        model_config: ModelConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        lora_config: Optional[LoRAConfig] = None,
+        kv_cache_dtype: Optional[str] = "auto",
+        # kv_quant_params_path: Optional[str] = None,
+        is_driver_worker: bool = False,
+    ) -> None:
+        self.model_config = model_config
+        self.parallel_config = parallel_config
+        self.scheduler_config = scheduler_config
+        self.device_config = device_config
+        self.local_rank = local_rank
+        self.rank = rank
+        self.distributed_init_method = distributed_init_method
+        self.lora_config = lora_config
+        self.is_driver_worker = is_driver_worker
+        if self.is_driver_worker:
+            assert self.rank == 0, "The driver worker must have rank 0."
+
+        self.model_runner = ModelRunner(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            device_config,
+            lora_config=self.lora_config,
+            is_driver_worker=is_driver_worker,
+        )
+        # Uninitialized cache engine. Will be initialized by
+        # self.init_cache_engine().
+        self.cache_config = None
+        self.cache_engine = None
+        self.cache_events = None
+        self.gpu_cache = None
+
+    def init_model(self) -> None:
+        # Initialize the distributed environment.
+        _init_distributed_environment(
+            self.parallel_config,
+            self.rank,
+            self.distributed_init_method,
+            distributed_backend="gloo",
+        )
+
+        # Initialize the model.
+        set_random_seed(self.model_config.seed)
+
+    def load_model(self):
+        self.model_runner.load_model()
+
+    @torch.inference_mode()
+    def profile_num_available_blocks(
+        self,
+        block_size: int = 128,
+        gpu_memory_utilization: float = 0.9,
+        cpu_swap_space: int = 0,
+        cache_dtype: str = "float16",
+    ) -> Tuple[int, int]:
+        """Simply returns max_num_seqs as num_gpu_blocks, 0 as num_cpu_blocks.
+        """
+        num_gpu_blocks = self.scheduler_config.max_num_seqs
+        num_cpu_blocks = 0
+        return num_gpu_blocks, num_cpu_blocks
+
+    def init_cache_engine(self, cache_config: CacheConfig) -> None:
+        self.cache_config = cache_config
+        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
+                                        self.parallel_config)
+        self.model_runner.set_block_size(self.cache_engine.block_size)
+
+    def warm_up_model(self) -> None:
+        # Warm up is maintained in transformers-neuronx
+        pass
+
+    def cache_swap(
+        self,
+        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_out: Dict[int, int],
+        blocks_to_copy: Dict[int, List[int]],
+    ) -> None:
+        # Issue cache operations.
+        issued_cache_op = False
+        if blocks_to_swap_in:
+            self.cache_engine.swap_in(blocks_to_swap_in)
+            issued_cache_op = True
+        if blocks_to_swap_out:
+            self.cache_engine.swap_out(blocks_to_swap_out)
+            issued_cache_op = True
+        if blocks_to_copy:
+            self.cache_engine.copy(blocks_to_copy)
+            issued_cache_op = True
+
+        cache_events = self.cache_events if issued_cache_op else None
+
+        # Wait for cache operations to finish.
+        if cache_events is not None:
+            raise NotImplementedError(
+                "cache operations are not implemented for neuron backend.")
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
+        blocks_to_swap_in: Optional[Dict[int, int]] = None,
+        blocks_to_swap_out: Optional[Dict[int, int]] = None,
+        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
+    ) -> Optional[SamplerOutput]:
+        if self.is_driver_worker:
+            assert seq_group_metadata_list is not None
+            num_seq_groups = len(seq_group_metadata_list)
+            assert blocks_to_swap_in is not None
+            assert blocks_to_swap_out is not None
+            assert blocks_to_copy is not None
+            data = {
+                "num_seq_groups": num_seq_groups,
+                "blocks_to_swap_in": blocks_to_swap_in,
+                "blocks_to_swap_out": blocks_to_swap_out,
+                "blocks_to_copy": blocks_to_copy,
+            }
+            broadcast_tensor_dict(data, src=0)
+        else:
+            data = broadcast_tensor_dict(src=0)
+            num_seq_groups = data["num_seq_groups"]
+            blocks_to_swap_in = data["blocks_to_swap_in"]
+            blocks_to_swap_out = data["blocks_to_swap_out"]
+            blocks_to_copy = data["blocks_to_copy"]
+
+        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
+
+        # If there is no input, we don't need to execute the model.
+        if num_seq_groups == 0:
+            return {}
+
+        output = self.model_runner.execute_model(seq_group_metadata_list,
+                                                 self.gpu_cache)
+        return output
+
+
+def _init_distributed_environment(
+    parallel_config: ParallelConfig,
+    rank: int,
+    distributed_init_method: Optional[str] = None,
+    distributed_backend: Optional[str] = None,
+) -> None:
+    """Initialize the distributed environment."""
+    if torch.distributed.is_initialized():
+        torch_world_size = torch.distributed.get_world_size()
+        if torch_world_size != parallel_config.world_size:
+            raise RuntimeError(
+                "torch.distributed is already initialized but the torch world "
+                "size does not match parallel_config.world_size "
+                f"({torch_world_size} vs. {parallel_config.world_size}).")
+    elif not distributed_init_method:
+        raise ValueError(
+            "distributed_init_method must be set if torch.distributed "
+            "is not already initialized")
+    else:
+        distributed_backend = (distributed_backend
+                               if distributed_backend else "nccl")
+        torch.distributed.init_process_group(
+            backend=distributed_backend,
+            world_size=parallel_config.world_size,
+            rank=rank,
+            init_method=distributed_init_method,
+        )
+
+    # A small all_reduce for warmup.
+    torch.distributed.all_reduce(torch.zeros(1))
+    ensure_model_parallel_initialized(
+        parallel_config.tensor_parallel_size,
+        parallel_config.pipeline_parallel_size,
+    )
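
For context on the control-plane pattern in `execute_model` above (the driver rank broadcasts the swap/copy metadata, the other ranks receive it), here is a minimal sketch using only standard `torch.distributed` calls. `broadcast_tensor_dict` is Aphrodite's own helper, so the object-list broadcast below is only an analogous stand-in, and the single-rank TCP rendezvous is a hypothetical example setup:

import torch.distributed as dist

# Gloo is the CPU-side backend the Neuron worker initializes above.
dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:29500",
                        world_size=1,
                        rank=0)

payload = [{"num_seq_groups": 2, "blocks_to_copy": {}}]
# Rank 0 (the driver) fills the list; other ranks receive it in place.
dist.broadcast_object_list(payload, src=0)
print(payload[0]["num_seq_groups"])

dist.destroy_process_group()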

+ 30 - 14
aphrodite/task_handler/worker.py

@@ -6,9 +6,9 @@ from typing import Dict, List, Tuple, Set, Optional
 import torch
 import torch.distributed

-from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
-                                     SchedulerConfig, LoRAConfig, DeviceConfig)
-from aphrodite.common.utils import in_wsl
+from aphrodite.common.config import (CacheConfig, DeviceConfig, ModelConfig,
+                                     ParallelConfig, SchedulerConfig,
+                                     LoRAConfig)
 from aphrodite.modeling import set_random_seed
 from aphrodite.modeling.megatron import cupy_utils
 from aphrodite.modeling.megatron.communication_op import (broadcast_tensor_dict
@@ -20,7 +20,7 @@ from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
 from aphrodite.task_handler.cache_engine import CacheEngine
 from aphrodite.task_handler.model_runner import ModelRunner
 from aphrodite.lora.request import LoRARequest
-from aphrodite.common.utils import is_hip
+from aphrodite.common.utils import in_wsl


 class Worker:
@@ -42,7 +42,7 @@ class Worker:
         distributed_init_method: str,
         lora_config: Optional[LoRAConfig] = None,
         kv_cache_dtype: Optional[str] = "auto",
-        kv_quant_params_path: Optional[str] = None,
+        # kv_quant_params_path: Optional[str] = None,
         is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config
@@ -64,7 +64,7 @@ class Worker:
             device_config,
             lora_config=self.lora_config,
             kv_cache_dtype=kv_cache_dtype,
-            kv_quant_params_path=kv_quant_params_path,
+            # kv_quant_params_path=kv_quant_params_path,
             is_driver_worker=is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
@@ -99,12 +99,9 @@ class Worker:
         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
-
         # Initialize the distributed environment.
         init_distributed_environment(self.parallel_config, self.rank,
                                      cupy_port, self.distributed_init_method)
-        if not self.parallel_config.disable_custom_all_reduce:
-            init_custom_ar()
         # Initialize the model.
         set_random_seed(self.model_config.seed)

@@ -143,8 +140,8 @@ class Worker:
         # GPU did not change their memory usage during the profiling.
         peak_memory = self.init_gpu_memory - free_gpu_memory

-        cache_block_size = CacheEngine.get_cache_block_size(
-            block_size, cache_dtype, self.model_config, self.parallel_config)
+        cache_block_size = self.get_cache_block_size_bytes(
+            block_size, cache_dtype)
         num_gpu_blocks = int(
             (total_gpu_memory * gpu_memory_utilization - peak_memory) //
             cache_block_size)
@@ -195,7 +192,7 @@ class Worker:
         # Wait for cache operations to finish.
         # TODO: Profile swapping overhead and optimize if needed.
         if cache_events is not None:
-            for event in cache_events:  # pylint: disable=not-an-iterable
+            for event in cache_events:
                 event.wait()

     @torch.inference_mode()
@@ -245,6 +242,22 @@ class Worker:
     def list_loras(self) -> Set[int]:
         return self.model_runner.list_loras()

+    @property
+    def max_model_len(self) -> int:
+        return self.model_config.max_model_len
+
+    @property
+    def vocab_size(self) -> int:
+        return self.model_runner.vocab_size
+
+    def get_cache_block_size_bytes(self, block_size: int,
+                                   cache_dtype: str) -> int:
+        """Get the size of the KV cache block size in bytes.
+        """
+        return CacheEngine.get_cache_block_size(block_size, cache_dtype,
+                                                self.model_config,
+                                                self.parallel_config)
+

 def init_distributed_environment(
     parallel_config: ParallelConfig,
@@ -279,8 +292,7 @@ def init_distributed_environment(
                 "cupy.distributed is already initialized but the cupy world "
                 "cupy.distributed is already initialized but the cupy world "
                 "size does not match parallel_config.world_size "
                 "size does not match parallel_config.world_size "
                 f"({cupy_world_size} vs. {parallel_config.world_size}).")
                 f"({cupy_world_size} vs. {parallel_config.world_size}).")
-    elif (parallel_config.world_size > 1 and cupy_port is not None
-          and not is_hip()):
+    elif (parallel_config.world_size > 1 and cupy_port is not None):
         # NOTE: We don't initialize CuPy process group when world size
         # NOTE: We don't initialize CuPy process group when world size
         # is 1.
         # is 1.
         # TODO: Support multi-node connection.
         # TODO: Support multi-node connection.
@@ -298,6 +310,10 @@ def init_distributed_environment(
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)

+    # Initialize a custom fast all-reduce implementation.
+    if not parallel_config.disable_custom_all_reduce:
+        init_custom_ar()
+

 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
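
The new `get_cache_block_size_bytes` helper above simply forwards to `CacheEngine.get_cache_block_size`; the sketch below shows the usual arithmetic behind such a helper and how it feeds the `num_gpu_blocks` computation in the profiling hunk. The shapes, dtype width, and memory figures are illustrative assumptions, not values taken from this repository:

def kv_cache_block_bytes(block_size: int, num_layers: int, num_kv_heads: int,
                         head_size: int, dtype_bytes: int = 2) -> int:
    # One block holds `block_size` tokens of both K and V for every layer.
    per_token = num_kv_heads * head_size * dtype_bytes
    return 2 * num_layers * block_size * per_token

total_gpu_memory = 24 * 2**30          # e.g. a 24 GiB GPU
gpu_memory_utilization = 0.9
peak_memory = 14 * 2**30               # e.g. weights + activations from profiling
block_bytes = kv_cache_block_bytes(block_size=16, num_layers=32,
                                    num_kv_heads=32, head_size=128)
num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization - peak_memory)
                     // block_bytes)
print(f"{block_bytes} bytes/block -> {num_gpu_blocks} GPU blocks")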

+ 1 - 2
kernels/attention/attention_dtypes.h

@@ -4,5 +4,4 @@
 #include "dtype_float16.cuh"
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
 #include "dtype_float32.cuh"
 #include "dtype_bfloat16.cuh"
 #include "dtype_bfloat16.cuh"
-#include "dtype_fp8_e5m2.cuh"
-#include "dtype_int8.cuh"
+#include "dtype_fp8_e5m2.cuh"

+ 936 - 1014
kernels/attention/attention_kernels.cu

@@ -16,1017 +16,939 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifdef USE_ROCM
-#include <hip/hip_runtime.h>
-#endif
-
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-
-#include "attention_dtypes.h"
-#include "attention_utils.cuh"
-#include "../quantization/int8_kvcache/quant_utils.cuh"
-#ifdef ENABLE_FP8_E5M2
-#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
-#endif
-
-#include <algorithm>
-
-#ifndef USE_ROCM
-#define WARP_SIZE 32
-#else
-#define WARP_SIZE warpSize
-#endif
-#define MAX(a, b) ((a) > (b) ? (a) : (b))
-#define MIN(a, b) ((a) < (b) ? (a) : (b))
-#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
-
-enum kv_cache_dtype {
-  AUTO,
-#ifdef ENABLE_FP8_E5M2
-  FP8_E5M2,
-#endif
-  INT8};
-
-namespace aphrodite {
-
-// Utility function for attention softmax.
-template<int NUM_WARPS>
-inline __device__ float block_sum(float* red_smem, float sum) {
-  // Decompose the thread index into warp / lane.
-  int warp = threadIdx.x / WARP_SIZE;
-  int lane = threadIdx.x % WARP_SIZE;
-
-  // Compute the sum per warp.
-#pragma unroll
-  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
-    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
-  }
-
-  // Warp leaders store the data to shared memory.
-  if (lane == 0) {
-    red_smem[warp] = sum;
-  }
-
-  // Make sure the data is in shared memory.
-  __syncthreads();
-
-  // The warps compute the final sums.
-  if (lane < NUM_WARPS) {
-    sum = red_smem[lane];
-  }
-
-  // Parallel reduction inside the warp.
-#pragma unroll
-  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
-    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
-  }
-
-  // Broadcast to other threads.
-  return APHRODITE_SHFL_SYNC(sum, 0);
-}
-
-// TODO: Merge the last two dimensions of the grid.
-// Grid: (num_heads, num_seqs, max_num_partitions).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int PARTITION_SIZE = 0> // Zero means no partitioning.
-__device__ void paged_attention_kernel(
-  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
-  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float k_scale = 1.0f,
-  const float k_zp = 0.0f,
-  const float v_scale = 1.0f,
-  const float v_zp = 0.0f) {
-  const int seq_idx = blockIdx.y;
-  const int partition_idx = blockIdx.z;
-  const int max_num_partitions = gridDim.z;
-  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
-  const int context_len = context_lens[seq_idx];
-  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
-    // No work to do. Terminate the thread block.
-    return;
-  }
-
-  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
-  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
-
-  // [start_block_idx, end_block_idx) is the range of blocks to process.
-  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
-  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
-  const int num_blocks = end_block_idx - start_block_idx;
-
-  // [start_token_idx, end_token_idx) is the range of tokens to process.
-  const int start_token_idx = start_block_idx * BLOCK_SIZE;
-  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
-  const int num_tokens = end_token_idx - start_token_idx;
-
-  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
-  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
-  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
-  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  const int thread_idx = threadIdx.x;
-  const int warp_idx = thread_idx / WARP_SIZE;
-  const int lane = thread_idx % WARP_SIZE;
-
-  const int head_idx = blockIdx.x;
-  const int num_heads = gridDim.x;
-  const int num_queries_per_kv = num_heads / num_kv_heads;
-  const int kv_head_idx = head_idx / num_queries_per_kv;
-  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
-
-  // A vector type to store a part of a key or a query.
-  // The vector size is configured in such a way that the threads in a thread group
-  // fetch or compute 16 bytes at a time.
-  // For example, if the size of a thread group is 4 and the data type is half,
-  // then the vector size is 16 / (4 * sizeof(half)) == 2.
-  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
-  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
-  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
-  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
-
-  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
-  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
-
-  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
-  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
-
-  // Load the query to registers.
-  // Each thread in a thread group has a different part of the query.
-  // For example, if the the thread group size is 4, then the first thread in the group
-  // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
-  // th vectors of the query, and so on.
-  // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
-  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
-  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
-#pragma unroll
-  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
-    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
-    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
-  }
-  __syncthreads(); // TODO: possible speedup if this is replaced with a memory wall right before we use q_vecs
-
-  // Memory planning.
-  extern __shared__ char shared_mem[];
-  // NOTE: We use FP32 for the softmax logits for better accuracy.
-  float* logits = reinterpret_cast<float*>(shared_mem);
-  // Workspace for reduction.
-  __shared__ float red_smem[2 * NUM_WARPS];
-
-  // x == THREAD_GROUP_SIZE * VEC_SIZE
-  // Each thread group fetches x elements from the key at a time.
-  constexpr int x = 16 / sizeof(cache_t);
-  float qk_max = -FLT_MAX;
-
-  // Iterate over the key blocks.
-  // Each warp fetches a block of keys for each iteration.
-  // Each thread group in a warp fetches a key from the block, and computes
-  // dot product with the query.
-  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
-  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
-    // NOTE: The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by large numbers
-    // (e.g., kv_block_stride).
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
-
-    // Load a key to registers.
-    // Each thread in a thread group has a different part of the key.
-    // For example, if the the thread group size is 4, then the first thread in the group
-    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
-    // vectors of the key, and so on.
-    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
-      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
-      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
-      K_vec k_vecs[NUM_VECS_PER_THREAD];
-
-#pragma unroll
-      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-        const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
-                                       + kv_head_idx * kv_head_stride
-                                       + physical_block_offset * x;
-        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
-        const int offset1 = (vec_idx * VEC_SIZE) / x;
-        const int offset2 = (vec_idx * VEC_SIZE) % x;
-        if constexpr (KV_CACHE_DTYPE == INT8) {
-          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
-          using Dequant_vec = typename FloatVec<Quant_vec>::Type;
-          Dequant_vec k_vec_dequant = int8::dequant(k_vec_quant, k_scale, k_zp);
-          k_vecs[j] = int8::vec_conversion<K_vec, Dequant_vec>(k_vec_dequant);
-#ifdef ENABLE_FP8_E5M2
-        } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
-          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
-          // Vector conversion from Quant_vec to K_vec.
-          k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
-#endif
-        } else {
-          k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
-        }
-      }
-
-      // Compute dot product.
-      // This includes a reduction across the threads in the same thread group.
-      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
-      // Add the ALiBi bias if slopes are given.
-      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
-
-      if (thread_group_offset == 0) {
-        // Store the partial reductions to shared memory.
-        // NOTE: It is required to zero out the masked logits.
-        const bool mask = token_idx >= context_len;
-        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
-        // Update the max value.
-        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
-      }
-    }
-  }
-
-  // Perform reduction across the threads in the same warp to get the
-  // max qk value for each "warp" (not across the thread block yet).
-  // The 0-th thread of each thread group already has its max qk value.
-#pragma unroll
-  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
-    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
-  }
-  if (lane == 0) {
-    red_smem[warp_idx] = qk_max;
-  }
-  __syncthreads();
-
-  // TODO: Refactor this part.
-  // Get the max qk value for the sequence.
-  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
-#pragma unroll
-  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
-    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
-  }
-  // Broadcast the max qk value to all threads.
-  qk_max = APHRODITE_SHFL_SYNC(qk_max, 0);
-
-  // Get the sum of the exp values.
-  float exp_sum = 0.f;
-  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
-    float val = __expf(logits[i] - qk_max);
-    logits[i] = val;
-    exp_sum += val;
-  }
-  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
-
-  // Compute softmax.
-  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
-  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
-    logits[i] *= inv_sum;
-  }
-  __syncthreads();
-
-  // If partitioning is enabled, store the max logit and exp_sum.
-  if (USE_PARTITIONING && thread_idx == 0) {
-    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
-                                       + head_idx * max_num_partitions
-                                       + partition_idx;
-    *max_logits_ptr = qk_max;
-    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
-                                   + head_idx * max_num_partitions
-                                   + partition_idx;
-    *exp_sums_ptr = exp_sum;
-  }
-
-  // Each thread will fetch 16 bytes from the value cache at a time.
-  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
-  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
-  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
-  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
-  using Float_L_vec = typename FloatVec<L_vec>::Type;
-
-  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
-  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
-  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
-
-  // NOTE: We use FP32 for the accumulator for better accuracy.
-  float accs[NUM_ROWS_PER_THREAD];
-#pragma unroll
-  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
-    accs[i] = 0.f;
-  }
-
-  scalar_t zero_value;
-  zero(zero_value);
-  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
-    // NOTE: The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by large numbers
-    // (e.g., kv_block_stride).
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
-    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
-    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
-    L_vec logits_vec;
-    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
-
-    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
-                                   + kv_head_idx * kv_head_stride;
-#pragma unroll
-    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
-      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
-      if (row_idx < HEAD_SIZE) {
-        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
-        V_vec v_vec;
-        if constexpr (KV_CACHE_DTYPE == INT8) {
-          // dequant and conversion
-          V_quant_vec v_vec_quant = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
-          using V_dequant_vec = typename FloatVec<V_quant_vec>::Type;
-          V_dequant_vec v_vec_dequant = int8::dequant(v_vec_quant, v_scale, v_zp);
-          v_vec = int8::vec_conversion<V_vec, V_dequant_vec>(v_vec_dequant);
-#ifdef ENABLE_FP8_E5M2
-        } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
-          V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
-          // Vector conversion from V_quant_vec to V_vec.
-          v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
-#endif
-        } else {
-          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
-        }
-        if (block_idx == num_context_blocks - 1) {
-          // NOTE: When v_vec contains the tokens that are out of the context,
-          // we should explicitly zero out the values since they may contain NaNs.
-          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
-#pragma unroll
-          for (int j = 0; j < V_VEC_SIZE; j++) {
-            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
-          }
-        }
-        accs[i] += dot(logits_vec, v_vec);
-      }
-    }
-  }
-
-  // Perform reduction within each warp.
-#pragma unroll
-  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
-    float acc = accs[i];
-#pragma unroll
-    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
-      acc += APHRODITE_SHFL_XOR_SYNC(acc, mask);
-    }
-    accs[i] = acc;
-  }
-
-  // NOTE: A barrier is required because the shared memory space for logits
-  // is reused for the output.
-  __syncthreads();
-
-  // Perform reduction across warps.
-  float* out_smem = reinterpret_cast<float*>(shared_mem);
-#pragma unroll
-  for (int i = NUM_WARPS; i > 1; i /= 2) {
-    int mid = i / 2;
-    // Upper warps write to shared memory.
-    if (warp_idx >= mid && warp_idx < i) {
-      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
-#pragma unroll
-      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
-        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
-        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
-          dst[row_idx] = accs[i];
-        }
-      }
-    }
-    __syncthreads();
-
-    // Lower warps update the output.
-    if (warp_idx < mid) {
-      const float* src = &out_smem[warp_idx * HEAD_SIZE];
-#pragma unroll
-      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
-        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
-        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
-          accs[i] += src[row_idx];
-        }
-      }
-    }
-    __syncthreads();
-  }
-
-  // Write the final output.
-  if (warp_idx == 0) {
-    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                            + head_idx * max_num_partitions * HEAD_SIZE
-                            + partition_idx * HEAD_SIZE;
-#pragma unroll
-    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
-      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
-      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
-        from_float(*(out_ptr + row_idx), accs[i]);
-      }
-    }
-  }
-}
-
-// Grid: (num_heads, num_seqs, 1).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  kv_cache_dtype KV_CACHE_DTYPE>
-__global__ void paged_attention_v1_kernel(
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE>(
-    /* exp_sums */ nullptr, /* max_logits */ nullptr,
-    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
-    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
-}
-
-// Grid: (num_heads, num_seqs, max_num_partitions).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int PARTITION_SIZE>
-__global__ void paged_attention_v2_kernel(
-  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
-  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
-  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE, PARTITION_SIZE>(
-    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
-    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
-    q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
-}
-
-// Grid: (num_heads, num_seqs).
-template<
-  typename scalar_t,
-  int HEAD_SIZE,
-  int NUM_THREADS,
-  int PARTITION_SIZE>
-__global__ void paged_attention_v2_reduce_kernel(
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
-  const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
-  const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
-  const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_partitions) {
-  const int num_heads = gridDim.x;
-  const int head_idx = blockIdx.x;
-  const int seq_idx = blockIdx.y;
-  const int context_len = context_lens[seq_idx];
-  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
-  if (num_partitions == 1) {
-    // No need to reduce. Only copy tmp_out to out.
-    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
-    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                                          + head_idx * max_num_partitions * HEAD_SIZE;
-    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
-      out_ptr[i] = tmp_out_ptr[i];
-    }
-    // Terminate the thread block.
-    return;
-  }
-
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  const int warp_idx = threadIdx.x / WARP_SIZE;
-  const int lane = threadIdx.x % WARP_SIZE;
-
-  // Size: 2 * num_partitions.
-  extern __shared__ char shared_mem[];
-  // Workspace for reduction.
-  __shared__ float red_smem[2 * NUM_WARPS];
-
-  // Load max logits to shared memory.
-  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
-  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
-                                           + head_idx * max_num_partitions;
-  float max_logit = -FLT_MAX;
-  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
-    const float l = max_logits_ptr[i];
-    shared_max_logits[i] = l;
-    max_logit = fmaxf(max_logit, l);
-  }
-  __syncthreads();
-
-  // Get the global max logit.
-  // Reduce within the warp.
-#pragma unroll
-  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
-    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
-  }
-  if (lane == 0) {
-    red_smem[warp_idx] = max_logit;
-  }
-  __syncthreads();
-  // Reduce across warps.
-  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
-#pragma unroll
-  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
-    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
-  }
-  // Broadcast the max value to all threads.
-  max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);
-
-  // Load rescaled exp sums to shared memory.
-  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
-  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
-                                       + head_idx * max_num_partitions;
-  float global_exp_sum = 0.0f;
-  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
-    float l = shared_max_logits[i];
-    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
-    global_exp_sum += rescaled_exp_sum;
-    shared_exp_sums[i] = rescaled_exp_sum;
-  }
-  __syncthreads();
-  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
-  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
-
-  // Aggregate tmp_out to out.
-  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                                        + head_idx * max_num_partitions * HEAD_SIZE;
-  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
-#pragma unroll
-  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
-    float acc = 0.0f;
-    for (int j = 0; j < num_partitions; ++j) {
-      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
-    }
-    from_float(out_ptr[i], acc);
-  }
-}
-
-} // namespace aphrodite
-
-#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                        \
-  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                        \
-    ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,    \
-      KV_CACHE_DTYPE>), shared_mem_size);                                                           \
-  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,              \
-  KV_CACHE_DTYPE><<<grid, block, shared_mem_size, stream>>>(                                        \
-    out_ptr,                                                                                        \
-    query_ptr,                                                                                      \
-    key_cache_ptr,                                                                                  \
-    value_cache_ptr,                                                                                \
-    num_kv_heads,                                                                                   \
-    scale,                                                                                          \
-    block_tables_ptr,                                                                               \
-    context_lens_ptr,                                                                               \
-    max_num_blocks_per_seq,                                                                         \
-    alibi_slopes_ptr,                                                                               \
-    q_stride,                                                                                       \
-    kv_block_stride,                                                                                \
-    kv_head_stride,                                                                                 \
-    k_scale,                                                                                        \
-    k_zp,                                                                                           \
-    v_scale,                                                                                        \
-    v_zp);
-
-// TODO: Tune NUM_THREADS.
-template<
-  typename T,
-  typename CACHE_T,
-  int BLOCK_SIZE,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int NUM_THREADS = 128>
-void paged_attention_v1_launcher(
-  torch::Tensor& out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  int num_kv_heads,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
-  int num_seqs = query.size(0);
-  int num_heads = query.size(1);
-  int head_size = query.size(2);
-  int max_num_blocks_per_seq = block_tables.size(1);
-  int q_stride = query.stride(0);
-  int kv_block_stride = key_cache.stride(0);
-  int kv_head_stride = key_cache.stride(1);
-
-  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
-  assert(head_size % thread_group_size == 0);
-
-  // NOTE: alibi_slopes is optional.
-  const float* alibi_slopes_ptr = alibi_slopes ?
-    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-    : nullptr;
-
-  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
-  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
-  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
-  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
-  int* block_tables_ptr = block_tables.data_ptr<int>();
-  int* context_lens_ptr = context_lens.data_ptr<int>();
-
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
-  int logits_size = padded_max_context_len * sizeof(float);
-  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
-  // Python-side check in aphrodite.task_handler.worker._check_if_can_support_max_seq_len
-  // Keep that in sync with the logic here!
-  int shared_mem_size = std::max(logits_size, outputs_size);
-
-  dim3 grid(num_heads, num_seqs, 1);
-  dim3 block(NUM_THREADS);
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  switch (head_size) {
-    // NOTE: To reduce the compilation time, we only compile for the
-    // head sizes that we use in the model. However, we can easily extend this
-    // to support any head size which is a multiple of 16.
-    case 64:
-      LAUNCH_PAGED_ATTENTION_V1(64);
-      break;
-    case 80:
-      LAUNCH_PAGED_ATTENTION_V1(80);
-      break;
-    case 96:
-      LAUNCH_PAGED_ATTENTION_V1(96);
-      break;
-    case 112:
-      LAUNCH_PAGED_ATTENTION_V1(112);
-      break;
-    case 128:
-      LAUNCH_PAGED_ATTENTION_V1(128);
-      break;
-    case 256:
-      LAUNCH_PAGED_ATTENTION_V1(256);
-      break;
-    default:
-      TORCH_CHECK(false, "Unsupported head size: ", head_size);
-      break;
-  }
-}
-
-#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)             \
-  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(       \
-    out,                                                                     \
-    query,                                                                   \
-    key_cache,                                                               \
-    value_cache,                                                             \
-    num_kv_heads,                                                            \
-    scale,                                                                   \
-    block_tables,                                                            \
-    context_lens,                                                            \
-    max_context_len,                                                         \
-    alibi_slopes,                                                            \
-    k_scale,                                                                 \
-    k_zp,                                                                    \
-    v_scale,                                                                 \
-    v_zp);
-
-// NOTE: To reduce the compilation time, we omitted block sizes
-// 1, 2, 4, 64, 128, 256.
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)       \
-  switch (block_size) {                                               \
-    case 8:                                                           \
-      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);                \
-      break;                                                          \
-    case 16:                                                          \
-      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);               \
-      break;                                                          \
-    case 32:                                                          \
-      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);               \
-      break;                                                          \
-    default:                                                          \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
-      break;                                                          \
-  }
-
-void paged_attention_v1(
-  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
-  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
-  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  int num_kv_heads,               // [num_heads]
-  float scale,
-  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
-  torch::Tensor& context_lens,    // [num_seqs]
-  int block_size,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const std::string& kv_cache_dtype,
-  const float k_scale = 1.0f,
-  const float k_zp = 0.0f,
-  const float v_scale = 1.0f,
-  const float v_zp = 0.0f) {
-  if (kv_cache_dtype == "auto") {
-    if (query.dtype() == at::ScalarType::Float) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
-    } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, AUTO);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, AUTO);
-    } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-    }
-#ifdef ENABLE_FP8_E5M2
-  } else if (kv_cache_dtype == "fp8_e5m2") {
-    if (query.dtype() == at::ScalarType::Float) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, FP8_E5M2);
-    } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, FP8_E5M2);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, FP8_E5M2);
-    } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-    }
-#endif
-  } else if (kv_cache_dtype == "int8") {
-    if (query.dtype() == at::ScalarType::Float) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(float, int8_t, INT8);
-    } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, int8_t, INT8);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, int8_t, INT8);
-    } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-    }
-  } else {
-    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
-  }
-}
-
-#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \
-  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,        \
-  KV_CACHE_DTYPE, PARTITION_SIZE>                                                             \
-  <<<grid, block, shared_mem_size, stream>>>(                                                 \
-    exp_sums_ptr,                                                                             \
-    max_logits_ptr,                                                                           \
-    tmp_out_ptr,                                                                              \
-    query_ptr,                                                                                \
-    key_cache_ptr,                                                                            \
-    value_cache_ptr,                                                                          \
-    num_kv_heads,                                                                             \
-    scale,                                                                                    \
-    block_tables_ptr,                                                                         \
-    context_lens_ptr,                                                                         \
-    max_num_blocks_per_seq,                                                                   \
-    alibi_slopes_ptr,                                                                         \
-    q_stride,                                                                                 \
-    kv_block_stride,                                                                          \
-    kv_head_stride,                                                                           \
-    k_scale,                                                                                  \
-    k_zp,                                                                                     \
-    v_scale,                                                                                  \
-    v_zp);                                                                                    \
-  aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>           \
-  <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                                   \
-    out_ptr,                                                                                  \
-    exp_sums_ptr,                                                                             \
-    max_logits_ptr,                                                                           \
-    tmp_out_ptr,                                                                              \
-    context_lens_ptr,                                                                         \
-    max_num_partitions);
-
-template<
-  typename T,
-  typename CACHE_T,
-  int BLOCK_SIZE,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int NUM_THREADS = 128,
-  int PARTITION_SIZE = 512>
-void paged_attention_v2_launcher(
-  torch::Tensor& out,
-  torch::Tensor& exp_sums,
-  torch::Tensor& max_logits,
-  torch::Tensor& tmp_out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  int num_kv_heads,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
-  int num_seqs = query.size(0);
-  int num_heads = query.size(1);
-  int head_size = query.size(2);
-  int max_num_blocks_per_seq = block_tables.size(1);
-  int q_stride = query.stride(0);
-  int kv_block_stride = key_cache.stride(0);
-  int kv_head_stride = key_cache.stride(1);
-
-  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
-  assert(head_size % thread_group_size == 0);
-
-  // NOTE: alibi_slopes is optional.
-  const float* alibi_slopes_ptr = alibi_slopes ?
-    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-    : nullptr;
-
-  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
-  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
-  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
-  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
-  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
-  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
-  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
-  int* block_tables_ptr = block_tables.data_ptr<int>();
-  int* context_lens_ptr = context_lens.data_ptr<int>();
-
-  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
-  int logits_size = PARTITION_SIZE * sizeof(float);
-  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
-
-  // For paged attention v2 kernel.
-  dim3 grid(num_heads, num_seqs, max_num_partitions);
-  int shared_mem_size = std::max(logits_size, outputs_size);
-  // For paged attention v2 reduce kernel.
-  dim3 reduce_grid(num_heads, num_seqs);
-  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
-
-  dim3 block(NUM_THREADS);
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  switch (head_size) {
-    // NOTE: To reduce the compilation time, we only compile for the
-    // head sizes that we use in the model. However, we can easily extend this
-    // to support any head size which is a multiple of 16.
-    case 64:
-      LAUNCH_PAGED_ATTENTION_V2(64);
-      break;
-    case 80:
-      LAUNCH_PAGED_ATTENTION_V2(80);
-      break;
-    case 96:
-      LAUNCH_PAGED_ATTENTION_V2(96);
-      break;
-    case 112:
-      LAUNCH_PAGED_ATTENTION_V2(112);
-      break;
-    case 128:
-      LAUNCH_PAGED_ATTENTION_V2(128);
-      break;
-    case 256:
-      LAUNCH_PAGED_ATTENTION_V2(256);
-      break;
-    default:
-      TORCH_CHECK(false, "Unsupported head size: ", head_size);
-      break;
-  }
-}
-
-#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)                 \
-  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(           \
-    out,                                                                         \
-    exp_sums,                                                                    \
-    max_logits,                                                                  \
-    tmp_out,                                                                     \
-    query,                                                                       \
-    key_cache,                                                                   \
-    value_cache,                                                                 \
-    num_kv_heads,                                                                \
-    scale,                                                                       \
-    block_tables,                                                                \
-    context_lens,                                                                \
-    max_context_len,                                                             \
-    alibi_slopes,                                                                \
-    k_scale,                                                                     \
-    k_zp,                                                                        \
-    v_scale,                                                                     \
-    v_zp);
-
-// NOTE: To reduce the compilation time, we omitted block sizes
-// 1, 2, 4, 64, 128, 256.
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)             \
-  switch (block_size) {                                                     \
-    case 8:                                                                 \
-      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);                      \
-      break;                                                                \
-    case 16:                                                                \
-      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);                     \
-      break;                                                                \
-    case 32:                                                                \
-      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);                     \
-      break;                                                                \
-    default:                                                                \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size);           \
-      break;                                                                \
-  }
-
-void paged_attention_v2(
-  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
-  torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
-  torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
-  torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
-  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
-  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  int num_kv_heads,               // [num_heads]
-  float scale,
-  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
-  torch::Tensor& context_lens,    // [num_seqs]
-  int block_size,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const std::string& kv_cache_dtype,
-  const float k_scale = 1.0f,
-  const float k_zp = 0.0f,
-  const float v_scale = 1.0f,
-  const float v_zp = 0.0f) {
-  if (kv_cache_dtype == "auto") {
-    if (query.dtype() == at::ScalarType::Float) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
-    } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, AUTO);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, AUTO);
-    } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-    }
-#ifdef ENABLE_FP8_E5M2
-  } else if (kv_cache_dtype == "fp8_e5m2") {
-    if (query.dtype() == at::ScalarType::Float) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, FP8_E5M2);
-    } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, FP8_E5M2);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, FP8_E5M2);
-    } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-    }
-#endif
-  } else if (kv_cache_dtype == "int8") {
-    if (query.dtype() == at::ScalarType::Float) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(float, int8_t, INT8);
-    } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, int8_t, INT8);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, int8_t, INT8);
-    } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-    }
-  } else {
-    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
-  }
-}
-
-#undef WARP_SIZE
-#undef MAX
-#undef MIN
-#undef DIVIDE_ROUND_UP
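For orientation, a minimal sketch (not part of the commit) of how the host entry point changes in the rewritten attention_kernels.cu that follows: the per-tensor quantization arguments (k_scale, k_zp, v_scale, v_zp) and the "int8" KV-cache branch are dropped, and the kv_cache_dtype template enum becomes a boolean IS_FP8_E5M2_KV_CACHE flag. Default argument values are omitted here for brevity.

// Old entry point: cache dtype string plus quantization scale/zero-point args.
void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        int num_kv_heads, float scale,
                        torch::Tensor& block_tables, torch::Tensor& context_lens,
                        int block_size, int max_context_len,
                        const c10::optional<torch::Tensor>& alibi_slopes,
                        const std::string& kv_cache_dtype,
                        float k_scale, float k_zp, float v_scale, float v_zp);

// New entry point: only the cache dtype string remains; "fp8_e5m2" vs. "auto"
// is dispatched through the bool IS_FP8_E5M2_KV_CACHE template parameter.
void paged_attention_v1(torch::Tensor& out, torch::Tensor& query,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        int num_kv_heads, float scale,
                        torch::Tensor& block_tables, torch::Tensor& context_lens,
                        int block_size, int max_context_len,
                        const c10::optional<torch::Tensor>& alibi_slopes,
                        const std::string& kv_cache_dtype);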
+ #ifdef USE_ROCM
+ #include <hip/hip_runtime.h>
+ #endif
+ 
+ #include <torch/extension.h>
+ #include <ATen/cuda/CUDAContext.h>
+ #include <c10/cuda/CUDAGuard.h>
+ 
+ #include "attention_dtypes.h"
+ #include "attention_utils.cuh"
+ #ifdef ENABLE_FP8_E5M2
+ #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
+ #endif
+ 
+ #include <algorithm>
+ 
+ #ifndef USE_ROCM
+ #define WARP_SIZE 32
+ #else
+ #define WARP_SIZE warpSize
+ #endif
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
+ #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
+ 
+ namespace aphrodite {
+ 
+ // Utility function for attention softmax.
+ template<int NUM_WARPS>
+ inline __device__ float block_sum(float* red_smem, float sum) {
+   // Decompose the thread index into warp / lane.
+   int warp = threadIdx.x / WARP_SIZE;
+   int lane = threadIdx.x % WARP_SIZE;
+ 
+   // Compute the sum per warp.
+ #pragma unroll
+   for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+     sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
+   }
+ 
+   // Warp leaders store the data to shared memory.
+   if (lane == 0) {
+     red_smem[warp] = sum;
+   }
+ 
+   // Make sure the data is in shared memory.
+   __syncthreads();
+ 
+   // The warps compute the final sums.
+   if (lane < NUM_WARPS) {
+     sum = red_smem[lane];
+   }
+ 
+   // Parallel reduction inside the warp.
+ #pragma unroll
+   for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+     sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
+   }
+ 
+   // Broadcast to other threads.
+   return APHRODITE_SHFL_SYNC(sum, 0);
+ }
+ 
+ // TODO: Merge the last two dimensions of the grid.
+ // Grid: (num_heads, num_seqs, max_num_partitions).
+ template<
+   typename scalar_t,
+   typename cache_t,
+   int HEAD_SIZE,
+   int BLOCK_SIZE,
+   int NUM_THREADS,
+   bool IS_FP8_E5M2_KV_CACHE,
+   int PARTITION_SIZE = 0> // Zero means no partitioning.
+ __device__ void paged_attention_kernel(
+   float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
+   float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
+   scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
+   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+   const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+   const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
+   const int num_kv_heads,                 // [num_heads]
+   const float scale,
+   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+   const int* __restrict__ context_lens,   // [num_seqs]
+   const int max_num_blocks_per_seq,
+   const float* __restrict__ alibi_slopes, // [num_heads]
+   const int q_stride,
+   const int kv_block_stride,
+   const int kv_head_stride) {
+   const int seq_idx = blockIdx.y;
+   const int partition_idx = blockIdx.z;
+   const int max_num_partitions = gridDim.z;
+   constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
+   const int context_len = context_lens[seq_idx];
+   if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
+     // No work to do. Terminate the thread block.
+     return;
+   }
+ 
+   const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
+   const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
+ 
+   // [start_block_idx, end_block_idx) is the range of blocks to process.
+   const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+   const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
+   const int num_blocks = end_block_idx - start_block_idx;
+ 
+   // [start_token_idx, end_token_idx) is the range of tokens to process.
+   const int start_token_idx = start_block_idx * BLOCK_SIZE;
+   const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
+   const int num_tokens = end_token_idx - start_token_idx;
+ 
+   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+   constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
+   assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
+   constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
+   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+   const int thread_idx = threadIdx.x;
+   const int warp_idx = thread_idx / WARP_SIZE;
+   const int lane = thread_idx % WARP_SIZE;
+ 
+   const int head_idx = blockIdx.x;
+   const int num_heads = gridDim.x;
+   const int num_queries_per_kv = num_heads / num_kv_heads;
+   const int kv_head_idx = head_idx / num_queries_per_kv;
+   const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
+ 
+   // A vector type to store a part of a key or a query.
+   // The vector size is configured in such a way that the threads in a thread group
+   // fetch or compute 16 bytes at a time.
+   // For example, if the size of a thread group is 4 and the data type is half,
+   // then the vector size is 16 / (4 * sizeof(half)) == 2.
+   constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
+   using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+   using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+ #ifdef ENABLE_FP8_E5M2
+   using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
+ #endif
+ 
+   constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
+   constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
+ 
+   const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
+   const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
+ 
+   // Load the query to registers.
+   // Each thread in a thread group has a different part of the query.
+   // For example, if the thread group size is 4, then the first thread in the group
+   // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
+   // th vectors of the query, and so on.
+   // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
+   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+   __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+ #pragma unroll
+   for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
+     const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
+     q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+   }
+   __syncthreads(); // TODO: possible speedup if this is replaced with a memory fence right before we use q_vecs
+ 
+   // Memory planning.
+   extern __shared__ char shared_mem[];
+   // NOTE: We use FP32 for the softmax logits for better accuracy.
+   float* logits = reinterpret_cast<float*>(shared_mem);
+   // Workspace for reduction.
+   __shared__ float red_smem[2 * NUM_WARPS];
+ 
+   // x == THREAD_GROUP_SIZE * VEC_SIZE
+   // Each thread group fetches x elements from the key at a time.
+   constexpr int x = 16 / sizeof(cache_t);
+   float qk_max = -FLT_MAX;
+ 
+   // Iterate over the key blocks.
+   // Each warp fetches a block of keys for each iteration.
+   // Each thread group in a warp fetches a key from the block, and computes
+   // dot product with the query.
+   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+     // NOTE: The block number is stored in int32. However, we cast it to int64
+     // because int32 can lead to overflow when this variable is multiplied by large numbers
+     // (e.g., kv_block_stride).
+     const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+ 
+     // Load a key to registers.
+     // Each thread in a thread group has a different part of the key.
+     // For example, if the thread group size is 4, then the first thread in the group
+     // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
+     // vectors of the key, and so on.
+     for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+       const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+       const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+       K_vec k_vecs[NUM_VECS_PER_THREAD];
+ 
+ #pragma unroll
+       for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
+         const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+                                        + kv_head_idx * kv_head_stride
+                                        + physical_block_offset * x;
+         const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
+         const int offset1 = (vec_idx * VEC_SIZE) / x;
+         const int offset2 = (vec_idx * VEC_SIZE) % x;
+         if constexpr (IS_FP8_E5M2_KV_CACHE) {
+ #ifdef ENABLE_FP8_E5M2
+           Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+           // Vector conversion from Quant_vec to K_vec.
+           k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
+ #else
+           assert(false);
+ #endif
+         } else {
+           k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+         }
+       }
+ 
+       // Compute dot product.
+       // This includes a reduction across the threads in the same thread group.
+       float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
+       // Add the ALiBi bias if slopes are given.
+       qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
+ 
+       if (thread_group_offset == 0) {
+         // Store the partial reductions to shared memory.
+         // NOTE: It is required to zero out the masked logits.
+         const bool mask = token_idx >= context_len;
+         logits[token_idx - start_token_idx] = mask ? 0.f : qk;
+         // Update the max value.
+         qk_max = mask ? qk_max : fmaxf(qk_max, qk);
+       }
+     }
+   }
+ 
+   // Perform reduction across the threads in the same warp to get the
+   // max qk value for each "warp" (not across the thread block yet).
+   // The 0-th thread of each thread group already has its max qk value.
+ #pragma unroll
+   for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
+     qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
+   }
+   if (lane == 0) {
+     red_smem[warp_idx] = qk_max;
+   }
+   __syncthreads();
+ 
+   // TODO: Refactor this part.
+   // Get the max qk value for the sequence.
+   qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+ #pragma unroll
+   for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+     qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
+   }
+   // Broadcast the max qk value to all threads.
+   qk_max = APHRODITE_SHFL_SYNC(qk_max, 0);
+ 
+   // Get the sum of the exp values.
+   float exp_sum = 0.f;
+   for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+     float val = __expf(logits[i] - qk_max);
+     logits[i] = val;
+     exp_sum += val;
+   }
+   exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
+ 
+   // Compute softmax.
+   const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
+   for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+     logits[i] *= inv_sum;
+   }
+   __syncthreads();
+ 
+   // If partitioning is enabled, store the max logit and exp_sum.
+   if (USE_PARTITIONING && thread_idx == 0) {
+     float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+                                        + head_idx * max_num_partitions
+                                        + partition_idx;
+     *max_logits_ptr = qk_max;
+     float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+                                    + head_idx * max_num_partitions
+                                    + partition_idx;
+     *exp_sums_ptr = exp_sum;
+   }
+ 
+   // Each thread will fetch 16 bytes from the value cache at a time.
+   constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
+   using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+   using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+ #ifdef ENABLE_FP8_E5M2
+   using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
+ #endif
+   using Float_L_vec = typename FloatVec<L_vec>::Type;
+ 
+   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
+   constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
+   constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
+ 
+   // NOTE: We use FP32 for the accumulator for better accuracy.
+   float accs[NUM_ROWS_PER_THREAD];
+ #pragma unroll
+   for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+     accs[i] = 0.f;
+   }
+ 
+   scalar_t zero_value;
+   zero(zero_value);
+   for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+     // NOTE: The block number is stored in int32. However, we cast it to int64
+     // because int32 can lead to overflow when this variable is multiplied by large numbers
+     // (e.g., kv_block_stride).
+     const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
+     const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+     L_vec logits_vec;
+     from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
+ 
+     const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+                                    + kv_head_idx * kv_head_stride;
+ #pragma unroll
+     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+       if (row_idx < HEAD_SIZE) {
+         const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
+         V_vec v_vec;
+         if constexpr (IS_FP8_E5M2_KV_CACHE) {
+ #ifdef ENABLE_FP8_E5M2
+           V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+           // Vector conversion from V_quant_vec to V_vec.
+           v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
+ #else
+           assert(false);
+ #endif
+         } else {
+           v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
+         }
+         if (block_idx == num_context_blocks - 1) {
+           // NOTE: When v_vec contains the tokens that are out of the context,
+           // we should explicitly zero out the values since they may contain NaNs.
+           scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
+ #pragma unroll
+           for (int j = 0; j < V_VEC_SIZE; j++) {
+             v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
+           }
+         }
+         accs[i] += dot(logits_vec, v_vec);
+       }
+     }
+   }
+ 
+   // Perform reduction within each warp.
+ #pragma unroll
+   for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+     float acc = accs[i];
+ #pragma unroll
+     for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
+       acc += APHRODITE_SHFL_XOR_SYNC(acc, mask);
+     }
+     accs[i] = acc;
+   }
+ 
+   // NOTE: A barrier is required because the shared memory space for logits
+   // is reused for the output.
+   __syncthreads();
+ 
+   // Perform reduction across warps.
+   float* out_smem = reinterpret_cast<float*>(shared_mem);
+ #pragma unroll
+   for (int i = NUM_WARPS; i > 1; i /= 2) {
+     int mid = i / 2;
+     // Upper warps write to shared memory.
+     if (warp_idx >= mid && warp_idx < i) {
+       float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
+ #pragma unroll
+       for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+         const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+         if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+           dst[row_idx] = accs[i];
+         }
+       }
+     }
+     __syncthreads();
+ 
+     // Lower warps update the output.
+     if (warp_idx < mid) {
+       const float* src = &out_smem[warp_idx * HEAD_SIZE];
+ #pragma unroll
+       for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+         const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+         if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+           accs[i] += src[row_idx];
+         }
+       }
+     }
+     __syncthreads();
+   }
+ 
+   // Write the final output.
+   if (warp_idx == 0) {
+     scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                             + head_idx * max_num_partitions * HEAD_SIZE
+                             + partition_idx * HEAD_SIZE;
+ #pragma unroll
+     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+       if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+         from_float(*(out_ptr + row_idx), accs[i]);
+       }
+     }
+   }
+ }
+ 
+ // Grid: (num_heads, num_seqs, 1).
+ template<
+   typename scalar_t,
+   typename cache_t,
+   int HEAD_SIZE,
+   int BLOCK_SIZE,
+   int NUM_THREADS,
+   bool IS_FP8_E5M2_KV_CACHE>
+ __global__ void paged_attention_v1_kernel(
+   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
+   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+   const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+   const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
+   const int num_kv_heads,                 // [num_heads]
+   const float scale,
+   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+   const int* __restrict__ context_lens,   // [num_seqs]
+   const int max_num_blocks_per_seq,
+   const float* __restrict__ alibi_slopes, // [num_heads]
+   const int q_stride,
+   const int kv_block_stride,
+   const int kv_head_stride) {
+   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
+     /* exp_sums */ nullptr, /* max_logits */ nullptr,
+     out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
+     max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
+ }
+ 
+ // Grid: (num_heads, num_seqs, max_num_partitions).
+ template<
+   typename scalar_t,
+   typename cache_t,
+   int HEAD_SIZE,
+   int BLOCK_SIZE,
+   int NUM_THREADS,
+   bool IS_FP8_E5M2_KV_CACHE,
+   int PARTITION_SIZE>
+ __global__ void paged_attention_v2_kernel(
+   float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
+   float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
+   scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
+   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+   const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+   const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
+   const int num_kv_heads,                 // [num_heads]
+   const float scale,
+   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+   const int* __restrict__ context_lens,   // [num_seqs]
+   const int max_num_blocks_per_seq,
+   const float* __restrict__ alibi_slopes, // [num_heads]
+   const int q_stride,
+   const int kv_block_stride,
+   const int kv_head_stride) {
+   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
+     exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
+     block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
+     q_stride, kv_block_stride, kv_head_stride);
+ }
+ 
+ // Grid: (num_heads, num_seqs).
+ template<
+   typename scalar_t,
+   int HEAD_SIZE,
+   int NUM_THREADS,
+   int PARTITION_SIZE>
+ __global__ void paged_attention_v2_reduce_kernel(
+   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
+   const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
+   const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
+   const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
+   const int* __restrict__ context_lens,   // [num_seqs]
+   const int max_num_partitions) {
+   const int num_heads = gridDim.x;
+   const int head_idx = blockIdx.x;
+   const int seq_idx = blockIdx.y;
+   const int context_len = context_lens[seq_idx];
+   const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+   if (num_partitions == 1) {
+     // No need to reduce. Only copy tmp_out to out.
+     scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+     const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                                           + head_idx * max_num_partitions * HEAD_SIZE;
+     for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
+       out_ptr[i] = tmp_out_ptr[i];
+     }
+     // Terminate the thread block.
+     return;
+   }
+ 
+   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+   const int warp_idx = threadIdx.x / WARP_SIZE;
+   const int lane = threadIdx.x % WARP_SIZE;
+ 
+   // Size: 2 * num_partitions.
+   extern __shared__ char shared_mem[];
+   // Workspace for reduction.
+   __shared__ float red_smem[2 * NUM_WARPS];
+ 
+   // Load max logits to shared memory.
+   float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
+   const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+                                            + head_idx * max_num_partitions;
+   float max_logit = -FLT_MAX;
+   for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+     const float l = max_logits_ptr[i];
+     shared_max_logits[i] = l;
+     max_logit = fmaxf(max_logit, l);
+   }
+   __syncthreads();
+ 
+   // Get the global max logit.
+   // Reduce within the warp.
+ #pragma unroll
+   for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+     max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
+   }
+   if (lane == 0) {
+     red_smem[warp_idx] = max_logit;
+   }
+   __syncthreads();
+   // Reduce across warps.
+   max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+ #pragma unroll
+   for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+     max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
+   }
+   // Broadcast the max value to all threads.
+   max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);
+ 
+   // Load rescaled exp sums to shared memory.
+   float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+   const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+                                        + head_idx * max_num_partitions;
+   float global_exp_sum = 0.0f;
+   for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+     float l = shared_max_logits[i];
+     float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
+     global_exp_sum += rescaled_exp_sum;
+     shared_exp_sums[i] = rescaled_exp_sum;
+   }
+   __syncthreads();
+   global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
+   const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
+ 
+   // Aggregate tmp_out to out.
+   const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                                         + head_idx * max_num_partitions * HEAD_SIZE;
+   scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+ #pragma unroll
+   for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
+     float acc = 0.0f;
+     for (int j = 0; j < num_partitions; ++j) {
+       acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
+     }
+     from_float(out_ptr[i], acc);
+   }
+ }
+ 
+ } // namespace aphrodite
+ 
+ #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                  \
+   APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                       \
+     ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,   \
+       IS_FP8_E5M2_KV_CACHE>), shared_mem_size);                                               \
+   aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,             \
+   IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>(                            \
+     out_ptr,                                                                                  \
+     query_ptr,                                                                                \
+     key_cache_ptr,                                                                            \
+     value_cache_ptr,                                                                          \
+     num_kv_heads,                                                                             \
+     scale,                                                                                    \
+     block_tables_ptr,                                                                         \
+     context_lens_ptr,                                                                         \
+     max_num_blocks_per_seq,                                                                   \
+     alibi_slopes_ptr,                                                                         \
+     q_stride,                                                                                 \
+     kv_block_stride,                                                                          \
+     kv_head_stride);
+ 
+ // TODO: Tune NUM_THREADS.
+ template<
+   typename T,
+   typename CACHE_T,
+   int BLOCK_SIZE,
+   bool IS_FP8_E5M2_KV_CACHE,
+   int NUM_THREADS = 128>
+ void paged_attention_v1_launcher(
+   torch::Tensor& out,
+   torch::Tensor& query,
+   torch::Tensor& key_cache,
+   torch::Tensor& value_cache,
+   int num_kv_heads,
+   float scale,
+   torch::Tensor& block_tables,
+   torch::Tensor& context_lens,
+   int max_context_len,
+   const c10::optional<torch::Tensor>& alibi_slopes) {
+   int num_seqs = query.size(0);
+   int num_heads = query.size(1);
+   int head_size = query.size(2);
+   int max_num_blocks_per_seq = block_tables.size(1);
+   int q_stride = query.stride(0);
+   int kv_block_stride = key_cache.stride(0);
+   int kv_head_stride = key_cache.stride(1);
+ 
+   int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+   assert(head_size % thread_group_size == 0);
+ 
+   // NOTE: alibi_slopes is optional.
+   const float* alibi_slopes_ptr = alibi_slopes ?
+     reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+     : nullptr;
+ 
+   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+   CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+   CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
+   int* block_tables_ptr = block_tables.data_ptr<int>();
+   int* context_lens_ptr = context_lens.data_ptr<int>();
+ 
+   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+   int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
+   int logits_size = padded_max_context_len * sizeof(float);
+   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+   // Python-side check in aphrodite.task_handler.worker._check_if_can_support_max_seq_len
+   // Keep that in sync with the logic here!
+   int shared_mem_size = std::max(logits_size, outputs_size);
+ 
+   dim3 grid(num_heads, num_seqs, 1);
+   dim3 block(NUM_THREADS);
+   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+   switch (head_size) {
+     // NOTE: To reduce the compilation time, we only compile for the
+     // head sizes that we use in the model. However, we can easily extend this
+     // to support any head size which is a multiple of 16.
+     case 64:
+       LAUNCH_PAGED_ATTENTION_V1(64);
+       break;
+     case 80:
+       LAUNCH_PAGED_ATTENTION_V1(80);
+       break;
+     case 96:
+       LAUNCH_PAGED_ATTENTION_V1(96);
+       break;
+     case 112:
+       LAUNCH_PAGED_ATTENTION_V1(112);
+       break;
+     case 128:
+       LAUNCH_PAGED_ATTENTION_V1(128);
+       break;
+     case 256:
+       LAUNCH_PAGED_ATTENTION_V1(256);
+       break;
+     default:
+       TORCH_CHECK(false, "Unsupported head size: ", head_size);
+       break;
+   }
+ }
+ 
+ #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE)       \
+   paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
+     out,                                                                     \
+     query,                                                                   \
+     key_cache,                                                               \
+     value_cache,                                                             \
+     num_kv_heads,                                                            \
+     scale,                                                                   \
+     block_tables,                                                            \
+     context_lens,                                                            \
+     max_context_len,                                                         \
+     alibi_slopes);
+ 
+ // NOTE: To reduce the compilation time, we omitted block sizes
+ // 1, 2, 4, 64, 128, 256.
+ #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
+   switch (block_size) {                                               \
+     case 8:                                                           \
+       CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE);          \
+       break;                                                          \
+     case 16:                                                          \
+       CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE);         \
+       break;                                                          \
+     case 32:                                                          \
+       CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE);         \
+       break;                                                          \
+     default:                                                          \
+       TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
+       break;                                                          \
+   }
+ 
+ void paged_attention_v1(
+   torch::Tensor& out,             // [num_seqs, num_heads, head_size]
+   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
+   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+   int num_kv_heads,               // [num_heads]
+   float scale,
+   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
+   torch::Tensor& context_lens,    // [num_seqs]
+   int block_size,
+   int max_context_len,
+   const c10::optional<torch::Tensor>& alibi_slopes,
+   const std::string& kv_cache_dtype) {
+   if (kv_cache_dtype == "auto") {
+     if (query.dtype() == at::ScalarType::Float) {
+       CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
+     } else if (query.dtype() == at::ScalarType::Half) {
+       CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
+     } else if (query.dtype() == at::ScalarType::BFloat16) {
+       CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
+     } else {
+       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+     }
+   } else if (kv_cache_dtype == "fp8_e5m2") {
+     if (query.dtype() == at::ScalarType::Float) {
+       CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
+     } else if (query.dtype() == at::ScalarType::Half) {
+       CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
+     } else if (query.dtype() == at::ScalarType::BFloat16) {
+       CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
+     } else {
+       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+     }
+   } else {
+     TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+   }
+ }
+ 
+ #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \
+   aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,             \
+   IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>                                                       \
+   <<<grid, block, shared_mem_size, stream>>>(                                                 \
+     exp_sums_ptr,                                                                             \
+     max_logits_ptr,                                                                           \
+     tmp_out_ptr,                                                                              \
+     query_ptr,                                                                                \
+     key_cache_ptr,                                                                            \
+     value_cache_ptr,                                                                          \
+     num_kv_heads,                                                                             \
+     scale,                                                                                    \
+     block_tables_ptr,                                                                         \
+     context_lens_ptr,                                                                         \
+     max_num_blocks_per_seq,                                                                   \
+     alibi_slopes_ptr,                                                                         \
+     q_stride,                                                                                 \
+     kv_block_stride,                                                                          \
+     kv_head_stride);                                                                          \
+   aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>           \
+   <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                                   \
+     out_ptr,                                                                                  \
+     exp_sums_ptr,                                                                             \
+     max_logits_ptr,                                                                           \
+     tmp_out_ptr,                                                                              \
+     context_lens_ptr,                                                                         \
+     max_num_partitions);
+ 
+ template<
+   typename T,
+   typename CACHE_T,
+   int BLOCK_SIZE,
+   bool IS_FP8_E5M2_KV_CACHE,
+   int NUM_THREADS = 128,
+   int PARTITION_SIZE = 512>
+ void paged_attention_v2_launcher(
+   torch::Tensor& out,
+   torch::Tensor& exp_sums,
+   torch::Tensor& max_logits,
+   torch::Tensor& tmp_out,
+   torch::Tensor& query,
+   torch::Tensor& key_cache,
+   torch::Tensor& value_cache,
+   int num_kv_heads,
+   float scale,
+   torch::Tensor& block_tables,
+   torch::Tensor& context_lens,
+   int max_context_len,
+   const c10::optional<torch::Tensor>& alibi_slopes) {
+   int num_seqs = query.size(0);
+   int num_heads = query.size(1);
+   int head_size = query.size(2);
+   int max_num_blocks_per_seq = block_tables.size(1);
+   int q_stride = query.stride(0);
+   int kv_block_stride = key_cache.stride(0);
+   int kv_head_stride = key_cache.stride(1);
+ 
+   int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+   assert(head_size % thread_group_size == 0);
+ 
+   // NOTE: alibi_slopes is optional.
+   const float* alibi_slopes_ptr = alibi_slopes ?
+     reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+     : nullptr;
+ 
+   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+   float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
+   float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
+   T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
+   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+   CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+   CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
+   int* block_tables_ptr = block_tables.data_ptr<int>();
+   int* context_lens_ptr = context_lens.data_ptr<int>();
+ 
+   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+   int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
+   int logits_size = PARTITION_SIZE * sizeof(float);
+   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+ 
+   // For paged attention v2 kernel.
+   dim3 grid(num_heads, num_seqs, max_num_partitions);
+   int shared_mem_size = std::max(logits_size, outputs_size);
+   // For paged attention v2 reduce kernel.
+   dim3 reduce_grid(num_heads, num_seqs);
+   int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
+ 
+   dim3 block(NUM_THREADS);
+   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+   switch (head_size) {
+     // NOTE: To reduce the compilation time, we only compile for the
+     // head sizes that we use in the model. However, we can easily extend this
+     // to support any head size which is a multiple of 16.
+     case 64:
+       LAUNCH_PAGED_ATTENTION_V2(64);
+       break;
+     case 80:
+       LAUNCH_PAGED_ATTENTION_V2(80);
+       break;
+     case 96:
+       LAUNCH_PAGED_ATTENTION_V2(96);
+       break;
+     case 112:
+       LAUNCH_PAGED_ATTENTION_V2(112);
+       break;
+     case 128:
+       LAUNCH_PAGED_ATTENTION_V2(128);
+       break;
+     case 256:
+       LAUNCH_PAGED_ATTENTION_V2(256);
+       break;
+     default:
+       TORCH_CHECK(false, "Unsupported head size: ", head_size);
+       break;
+   }
+ }
+ 
+ #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE)           \
+   paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>(     \
+     out,                                                                         \
+     exp_sums,                                                                    \
+     max_logits,                                                                  \
+     tmp_out,                                                                     \
+     query,                                                                       \
+     key_cache,                                                                   \
+     value_cache,                                                                 \
+     num_kv_heads,                                                                \
+     scale,                                                                       \
+     block_tables,                                                                \
+     context_lens,                                                                \
+     max_context_len,                                                             \
+     alibi_slopes);
+ 
+ // NOTE: To reduce the compilation time, we omitted block sizes
+ // 1, 2, 4, 64, 128, 256.
+ #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE)       \
+   switch (block_size) {                                                     \
+     case 8:                                                                 \
+       CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE);                \
+       break;                                                                \
+     case 16:                                                                \
+       CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE);               \
+       break;                                                                \
+     case 32:                                                                \
+       CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE);               \
+       break;                                                                \
+     default:                                                                \
+       TORCH_CHECK(false, "Unsupported block size: ", block_size);           \
+       break;                                                                \
+   }
+ 
+ void paged_attention_v2(
+   torch::Tensor& out,             // [num_seqs, num_heads, head_size]
+   torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
+   torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
+   torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
+   torch::Tensor& query,           // [num_seqs, num_heads, head_size]
+   torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+   torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+   int num_kv_heads,               // [num_heads]
+   float scale,
+   torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
+   torch::Tensor& context_lens,    // [num_seqs]
+   int block_size,
+   int max_context_len,
+   const c10::optional<torch::Tensor>& alibi_slopes,
+   const std::string& kv_cache_dtype) {
+   if (kv_cache_dtype == "auto") {
+     if (query.dtype() == at::ScalarType::Float) {
+       CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
+     } else if (query.dtype() == at::ScalarType::Half) {
+       CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
+     } else if (query.dtype() == at::ScalarType::BFloat16) {
+       CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
+     } else {
+       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+     }
+   } else if (kv_cache_dtype == "fp8_e5m2") {
+     if (query.dtype() == at::ScalarType::Float) {
+       CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
+     } else if (query.dtype() == at::ScalarType::Half) {
+       CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
+     } else if (query.dtype() == at::ScalarType::BFloat16) {
+       CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
+     } else {
+       TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+     }
+   } else {
+     TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+   }
+ }
+ 
+ #undef WARP_SIZE
+ #undef MAX
+ #undef MIN
+ #undef DIVIDE_ROUND_UP
+ 

+ 255 - 262
kernels/attention/dtype_float32.cuh

@@ -16,265 +16,258 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#pragma once
-
-#include "attention_generic.cuh"
-
-#include <stdint.h>
-
-namespace aphrodite {
-
-// Define custom FP32 vector data types.
-struct Float4_ {
-  float2 x;
-  float2 y;
-};
-
-struct Float8_ {
-  float2 x;
-  float2 y;
-  float2 z;
-  float2 w;
-};
-
-// FP32 vector types for Q, K, V.
-template<>
-struct Vec<float, 1> {
-  using Type = float;
-};
-template<>
-struct Vec<float, 2> {
-  using Type = float2;
-};
-template<>
-struct Vec<float, 4> {
-  using Type = float4;
-};
-
-// FP32 accumulator vector types corresponding to Vec.
-template<>
-struct FloatVec<float> {
-  using Type = float;
-};
-template<>
-struct FloatVec<float2> {
-  using Type = float2;
-};
-template<>
-struct FloatVec<float4> {
-  using Type = float4;
-};
-
-// Vector addition.
-inline __device__ float add(float a, float b) {
-  return a + b;
-}
-
-inline __device__ float2 add(float2 a, float2 b) {
-  float2 c;
-  c.x = add(a.x, b.x);
-  c.y = add(a.y, b.y);
-  return c;
-}
-
-inline __device__ float4 add(float4 a, float4 b) {
-  float4 c;
-  c.x = add(a.x, b.x);
-  c.y = add(a.y, b.y);
-  c.z = add(a.z, b.z);
-  c.w = add(a.w, b.w);
-  return c;
-}
-
-inline __device__ Float4_ add(Float4_ a, Float4_ b) {
-  Float4_ c;
-  c.x = add(a.x, b.x);
-  c.y = add(a.y, b.y);
-  return c;
-}
-
-// Vector multiplication.
-template<>
-inline __device__ float mul<float, float>(float a, float b) {
-  return a * b;
-}
-
-template<>
-inline __device__ float2 mul(float2 a, float2 b) {
-  float2 c;
-  c.x = a.x * b.x;
-  c.y = a.y * b.y;
-  return c;
-}
-
-template<>
-inline __device__ float2 mul(float a, float2 b) {
-  float2 c;
-  c.x = a * b.x;
-  c.y = a * b.y;
-  return c;
-}
-
-template<>
-inline __device__ float4 mul(float4 a, float4 b) {
-  float4 c;
-  c.x = a.x * b.x;
-  c.y = a.y * b.y;
-  c.z = a.z * b.z;
-  c.w = a.w * b.w;
-  return c;
-}
-
-template<>
-inline __device__ float4 mul(float a, float4 b) {
-  float4 c;
-  c.x = a * b.x;
-  c.y = a * b.y;
-  c.z = a * b.z;
-  c.w = a * b.w;
-  return c;
-}
-
-// Vector fused multiply-add.
-inline __device__ float fma(float a, float b, float c) {
-  return a * b + c;
-}
-
-inline __device__ float2 fma(float2 a, float2 b, float2 c) {
-  float2 d;
-  d.x = fma(a.x, b.x, c.x);
-  d.y = fma(a.y, b.y, c.y);
-  return d;
-}
-
-inline __device__ float2 fma(float a, float2 b, float2 c) {
-  float2 d;
-  d.x = fma(a, b.x, c.x);
-  d.y = fma(a, b.y, c.y);
-  return d;
-}
-
-inline __device__ float4 fma(float4 a, float4 b, float4 c) {
-  float4 d;
-  d.x = fma(a.x, b.x, c.x);
-  d.y = fma(a.y, b.y, c.y);
-  d.z = fma(a.z, b.z, c.z);
-  d.w = fma(a.w, b.w, c.w);
-  return d;
-}
-
-inline __device__ float4 fma(float a, float4 b, float4 c) {
-  float4 d;
-  d.x = fma(a, b.x, c.x);
-  d.y = fma(a, b.y, c.y);
-  d.z = fma(a, b.z, c.z);
-  d.w = fma(a, b.w, c.w);
-  return d;
-}
-
-inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
-  Float4_ d;
-  d.x = fma(a, b.x, c.x);
-  d.y = fma(a, b.y, c.y);
-  return d;
-}
-
-inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
-  Float8_ d;
-  d.x = fma(a, b.x, c.x);
-  d.y = fma(a, b.y, c.y);
-  d.z = fma(a, b.z, c.z);
-  d.w = fma(a, b.w, c.w);
-  return d;
-}
-
-// Vector sum.
-template<>
-inline __device__ float sum(float v) {
-  return v;
-}
-
-template<>
-inline __device__ float sum(float2 v) {
-  return v.x + v.y;
-}
-
-template<>
-inline __device__ float sum(float4 v) {
-  return v.x + v.y + v.z + v.w;
-}
-
-template<>
-inline __device__ float sum(Float4_ v) {
-  return v.x.x + v.x.y + v.y.x + v.y.y;
-}
-
-template<>
-inline __device__ float sum(Float8_ v) {
-  return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
-}
-
-// Vector dot product.
-inline __device__ float dot(float a, float b) {
-  return a * b;
-}
-
-inline __device__ float dot(float2 a, float2 b) {
-  float2 c = mul<float2, float2, float2>(a, b);
-  return c.x + c.y;
-}
-
-inline __device__ float dot(Float4_ a, Float4_ b) {
-  float2 acc = mul<float2, float2, float2>(a.x, b.x);
-  acc = fma(a.y, b.y, acc);
-  return acc.x + acc.y;
-}
-
-inline __device__ float dot(Float8_ a, Float8_ b) {
-  float2 acc = mul<float2, float2, float2>(a.x, b.x);
-  acc = fma(a.y, b.y, acc);
-  acc = fma(a.z, b.z, acc);
-  acc = fma(a.w, b.w, acc);
-  return acc.x + acc.y;
-}
-
-// From float to float.
-inline __device__ void from_float(float& dst, float src) {
-  dst = src;
-}
-
-inline __device__ void from_float(float2& dst, float2 src) {
-  dst = src;
-}
-
-inline __device__ void from_float(float4& dst, float4 src) {
-  dst = src;
-}
-
-// From float to float.
-inline __device__ float to_float(float u) {
-  return u;
-}
-
-inline __device__ float2 to_float(float2 u) {
-  return u;
-}
-
-inline __device__ float4 to_float(float4 u) {
-  return u;
-}
-
-inline __device__ Float4_ to_float(Float4_ u) {
-  return u;
-}
-
-inline __device__ Float8_ to_float(Float8_ u) {
-  return u;
-}
-
-// Zero-out a variable.
-inline __device__ void zero(float& dst) {
-  dst = 0.f;
-}
-
-} // namespace aphrodite
+ #pragma once
+
+ #include "attention_generic.cuh"
+ 
+ #include <stdint.h>
+ 
+ namespace aphrodite {
+ 
+ // Define custom FP32 vector data types.
+ struct Float4_ {
+   float2 x;
+   float2 y;
+ };
+ 
+ struct Float8_ {
+   float2 x;
+   float2 y;
+   float2 z;
+   float2 w;
+ };
+ 
+ // FP32 vector types for Q, K, V.
+ template<>
+ struct Vec<float, 1> {
+   using Type = float;
+ };
+ template<>
+ struct Vec<float, 2> {
+   using Type = float2;
+ };
+ template<>
+ struct Vec<float, 4> {
+   using Type = float4;
+ };
+ 
+ // FP32 accumulator vector types corresponding to Vec.
+ template<>
+ struct FloatVec<float> {
+   using Type = float;
+ };
+ template<>
+ struct FloatVec<float2> {
+   using Type = float2;
+ };
+ template<>
+ struct FloatVec<float4> {
+   using Type = float4;
+ };
+ 
+ // Vector addition.
+ inline __device__ float add(float a, float b) {
+   return a + b;
+ }
+ 
+ inline __device__ float2 add(float2 a, float2 b) {
+   float2 c;
+   c.x = add(a.x, b.x);
+   c.y = add(a.y, b.y);
+   return c;
+ }
+ 
+ inline __device__ float4 add(float4 a, float4 b) {
+   float4 c;
+   c.x = add(a.x, b.x);
+   c.y = add(a.y, b.y);
+   c.z = add(a.z, b.z);
+   c.w = add(a.w, b.w);
+   return c;
+ }
+ 
+ // Vector multiplication.
+ template<>
+ inline __device__ float mul<float, float>(float a, float b) {
+   return a * b;
+ }
+ 
+ template<>
+ inline __device__ float2 mul(float2 a, float2 b) {
+   float2 c;
+   c.x = a.x * b.x;
+   c.y = a.y * b.y;
+   return c;
+ }
+ 
+ template<>
+ inline __device__ float2 mul(float a, float2 b) {
+   float2 c;
+   c.x = a * b.x;
+   c.y = a * b.y;
+   return c;
+ }
+ 
+ template<>
+ inline __device__ float4 mul(float4 a, float4 b) {
+   float4 c;
+   c.x = a.x * b.x;
+   c.y = a.y * b.y;
+   c.z = a.z * b.z;
+   c.w = a.w * b.w;
+   return c;
+ }
+ 
+ template<>
+ inline __device__ float4 mul(float a, float4 b) {
+   float4 c;
+   c.x = a * b.x;
+   c.y = a * b.y;
+   c.z = a * b.z;
+   c.w = a * b.w;
+   return c;
+ }
+ 
+ // Vector fused multiply-add.
+ inline __device__ float fma(float a, float b, float c) {
+   return a * b + c;
+ }
+ 
+ inline __device__ float2 fma(float2 a, float2 b, float2 c) {
+   float2 d;
+   d.x = fma(a.x, b.x, c.x);
+   d.y = fma(a.y, b.y, c.y);
+   return d;
+ }
+ 
+ inline __device__ float2 fma(float a, float2 b, float2 c) {
+   float2 d;
+   d.x = fma(a, b.x, c.x);
+   d.y = fma(a, b.y, c.y);
+   return d;
+ }
+ 
+ inline __device__ float4 fma(float4 a, float4 b, float4 c) {
+   float4 d;
+   d.x = fma(a.x, b.x, c.x);
+   d.y = fma(a.y, b.y, c.y);
+   d.z = fma(a.z, b.z, c.z);
+   d.w = fma(a.w, b.w, c.w);
+   return d;
+ }
+ 
+ inline __device__ float4 fma(float a, float4 b, float4 c) {
+   float4 d;
+   d.x = fma(a, b.x, c.x);
+   d.y = fma(a, b.y, c.y);
+   d.z = fma(a, b.z, c.z);
+   d.w = fma(a, b.w, c.w);
+   return d;
+ }
+ 
+ inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
+   Float4_ d;
+   d.x = fma(a, b.x, c.x);
+   d.y = fma(a, b.y, c.y);
+   return d;
+ }
+ 
+ inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
+   Float8_ d;
+   d.x = fma(a, b.x, c.x);
+   d.y = fma(a, b.y, c.y);
+   d.z = fma(a, b.z, c.z);
+   d.w = fma(a, b.w, c.w);
+   return d;
+ }
+ 
+ // Vector sum.
+ template<>
+ inline __device__ float sum(float v) {
+   return v;
+ }
+ 
+ template<>
+ inline __device__ float sum(float2 v) {
+   return v.x + v.y;
+ }
+ 
+ template<>
+ inline __device__ float sum(float4 v) {
+   return v.x + v.y + v.z + v.w;
+ }
+ 
+ template<>
+ inline __device__ float sum(Float4_ v) {
+   return v.x.x + v.x.y + v.y.x + v.y.y;
+ }
+ 
+ template<>
+ inline __device__ float sum(Float8_ v) {
+   return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
+ }
+ 
+ // Vector dot product.
+ inline __device__ float dot(float a, float b) {
+   return a * b;
+ }
+ 
+ inline __device__ float dot(float2 a, float2 b) {
+   float2 c = mul<float2, float2, float2>(a, b);
+   return c.x + c.y;
+ }
+ 
+ inline __device__ float dot(Float4_ a, Float4_ b) {
+   float2 acc = mul<float2, float2, float2>(a.x, b.x);
+   acc = fma(a.y, b.y, acc);
+   return acc.x + acc.y;
+ }
+ 
+ inline __device__ float dot(Float8_ a, Float8_ b) {
+   float2 acc = mul<float2, float2, float2>(a.x, b.x);
+   acc = fma(a.y, b.y, acc);
+   acc = fma(a.z, b.z, acc);
+   acc = fma(a.w, b.w, acc);
+   return acc.x + acc.y;
+ }
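+ 
+ // Illustrative note: Float8_ packs eight floats as four float2 halves, so the
+ // dot() overload above reduces eight lane-wise products; e.g. if every lane of
+ // a is 1.f and every lane of b is 2.f, the result is 16.f.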
+ 
+ // From float to float.
+ inline __device__ void from_float(float& dst, float src) {
+   dst = src;
+ }
+ 
+ inline __device__ void from_float(float2& dst, float2 src) {
+   dst = src;
+ }
+ 
+ inline __device__ void from_float(float4& dst, float4 src) {
+   dst = src;
+ }
+ 
+ // From float to float.
+ inline __device__ float to_float(float u) {
+   return u;
+ }
+ 
+ inline __device__ float2 to_float(float2 u) {
+   return u;
+ }
+ 
+ inline __device__ float4 to_float(float4 u) {
+   return u;
+ }
+ 
+ inline __device__ Float4_ to_float(Float4_ u) {
+   return u;
+ }
+ 
+ inline __device__ Float8_ to_float(Float8_ u) {
+   return u;
+ }
+ 
+ // Zero-out a variable.
+ inline __device__ void zero(float& dst) {
+   dst = 0.f;
+ }
+ 
+ } // namespace aphrodite

+ 1 - 0
kernels/backup/README

@@ -0,0 +1 @@
+Backup of the attention and cache kernels with INT8 KV cache support; these will be restored soon.

+ 8 - 0
kernels/backup/attention_dtypes.h

@@ -0,0 +1,8 @@
+#pragma once
+
+#include "attention_generic.cuh"
+#include "dtype_float16.cuh"
+#include "dtype_float32.cuh"
+#include "dtype_bfloat16.cuh"
+#include "dtype_fp8_e5m2.cuh"
+#include "dtype_int8.cuh"

+ 1032 - 0
kernels/backup/attention_kernels.cu

@@ -0,0 +1,1032 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2023, The PygmalionAI team.
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifdef USE_ROCM
+#include <hip/hip_runtime.h>
+#endif
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include "attention_dtypes.h"
+#include "attention_utils.cuh"
+#include "../quantization/int8_kvcache/quant_utils.cuh"
+#ifdef ENABLE_FP8_E5M2
+#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
+#endif
+
+#include <algorithm>
+
+#ifndef USE_ROCM
+#define WARP_SIZE 32
+#else
+#define WARP_SIZE warpSize
+#endif
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
+
+enum kv_cache_dtype {
+  AUTO,
+#ifdef ENABLE_FP8_E5M2
+  FP8_E5M2,
+#endif
+  INT8
+};
+
+namespace aphrodite {
+
+// Utility function for attention softmax.
+template<int NUM_WARPS>
+inline __device__ float block_sum(float* red_smem, float sum) {
+  // Decompose the thread index into warp / lane.
+  int warp = threadIdx.x / WARP_SIZE;
+  int lane = threadIdx.x % WARP_SIZE;
+
+  // Compute the sum per warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
+  }
+
+  // Warp leaders store the data to shared memory.
+  if (lane == 0) {
+    red_smem[warp] = sum;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The warps compute the final sums.
+  if (lane < NUM_WARPS) {
+    sum = red_smem[lane];
+  }
+
+  // Parallel reduction inside the warp.
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    sum += APHRODITE_SHFL_XOR_SYNC(sum, mask);
+  }
+
+  // Broadcast to other threads.
+  return APHRODITE_SHFL_SYNC(sum, 0);
+}
+
+// TODO: Merge the last two dimensions of the grid.
+// Grid: (num_heads, num_seqs, max_num_partitions).
+template<
+  typename scalar_t,
+  typename cache_t,
+  int HEAD_SIZE,
+  int BLOCK_SIZE,
+  int NUM_THREADS,
+  kv_cache_dtype KV_CACHE_DTYPE,
+  int PARTITION_SIZE = 0> // Zero means no partitioning.
+__device__ void paged_attention_kernel(
+  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
+  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
+  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
+  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
+  const int num_kv_heads,                 // [num_heads]
+  const float scale,
+  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ context_lens,   // [num_seqs]
+  const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes, // [num_heads]
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f) {
+  const int seq_idx = blockIdx.y;
+  const int partition_idx = blockIdx.z;
+  const int max_num_partitions = gridDim.z;
+  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
+  const int context_len = context_lens[seq_idx];
+  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) {
+    // No work to do. Terminate the thread block.
+    return;
+  }
+
+  const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
+  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
+
+  // [start_block_idx, end_block_idx) is the range of blocks to process.
+  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
+  const int num_blocks = end_block_idx - start_block_idx;
+
+  // [start_token_idx, end_token_idx) is the range of tokens to process.
+  const int start_token_idx = start_block_idx * BLOCK_SIZE;
+  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
+  const int num_tokens = end_token_idx - start_token_idx;
+
+  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
+  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
+  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int thread_idx = threadIdx.x;
+  const int warp_idx = thread_idx / WARP_SIZE;
+  const int lane = thread_idx % WARP_SIZE;
+
+  const int head_idx = blockIdx.x;
+  const int num_heads = gridDim.x;
+  const int num_queries_per_kv = num_heads / num_kv_heads;
+  const int kv_head_idx = head_idx / num_queries_per_kv;
+  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
+
+  // A vector type to store a part of a key or a query.
+  // The vector size is configured in such a way that the threads in a thread group
+  // fetch or compute 16 bytes at a time.
+  // For example, if the size of a thread group is 4 and the data type is half,
+  // then the vector size is 16 / (4 * sizeof(half)) == 2.
+  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
+  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
+
+  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
+  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
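+  // Worked example (illustrative): for half precision with BLOCK_SIZE = 8,
+  // THREAD_GROUP_SIZE = 32 / 8 = 4 and VEC_SIZE = 16 / (4 * 2) = 2; with
+  // HEAD_SIZE = 128 each thread then handles NUM_ELEMS_PER_THREAD = 32 elements,
+  // i.e. NUM_VECS_PER_THREAD = 16 vectors.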
+
+  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
+  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
+
+  // Load the query to registers.
+  // Each thread in a thread group has a different part of the query.
+  // For example, if the thread group size is 4, then the first thread in the group
+  // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
+  // th vectors of the query, and so on.
+  // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
+  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+#pragma unroll
+  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
+    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
+    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+  }
+  __syncthreads(); // TODO: possible speedup if this is replaced with a memory wall right before we use q_vecs
+
+  // Memory planning.
+  extern __shared__ char shared_mem[];
+  // NOTE: We use FP32 for the softmax logits for better accuracy.
+  float* logits = reinterpret_cast<float*>(shared_mem);
+  // Workspace for reduction.
+  __shared__ float red_smem[2 * NUM_WARPS];
+
+  // x == THREAD_GROUP_SIZE * VEC_SIZE
+  // Each thread group fetches x elements from the key at a time.
+  constexpr int x = 16 / sizeof(cache_t);
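+  // Illustrative values: x == 8 for 16-bit caches (uint16_t, __nv_bfloat16),
+  // x == 16 for 8-bit caches (int8 / fp8_e5m2) and x == 4 for float.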
+  float qk_max = -FLT_MAX;
+
+  // Iterate over the key blocks.
+  // Each warp fetches a block of keys for each iteration.
+  // Each thread group in a warp fetches a key from the block, and computes
+  // dot product with the query.
+  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+    // NOTE: The block number is stored in int32. However, we cast it to int64
+    // because int32 can lead to overflow when this variable is multiplied by large numbers
+    // (e.g., kv_block_stride).
+    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
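+    // Illustrative overflow case: with 8 KV heads, head_size 128 and block_size 16,
+    // kv_block_stride is 8 * 128 * 16 = 16384 elements, so a 32-bit product would
+    // already overflow for block numbers >= 131072.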
+
+    // Load a key to registers.
+    // Each thread in a thread group has a different part of the key.
+    // For example, if the thread group size is 4, then the first thread in the group
+    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
+    // vectors of the key, and so on.
+    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+      K_vec k_vecs[NUM_VECS_PER_THREAD];
+
+#pragma unroll
+      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
+        const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+                                       + kv_head_idx * kv_head_stride
+                                       + physical_block_offset * x;
+        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
+        const int offset1 = (vec_idx * VEC_SIZE) / x;
+        const int offset2 = (vec_idx * VEC_SIZE) % x;
+        if constexpr (KV_CACHE_DTYPE == INT8) {
+          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          using Dequant_vec = typename FloatVec<Quant_vec>::Type;
+          Dequant_vec k_vec_dequant = int8::dequant(k_vec_quant, k_scale, k_zp);
+          k_vecs[j] = int8::vec_conversion<K_vec, Dequant_vec>(k_vec_dequant);
+#ifdef ENABLE_FP8_E5M2
+        } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
+          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          // Vector conversion from Quant_vec to K_vec.
+          k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
+#endif
+        } else {
+          k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+        }
+      }
+
+      // Compute dot product.
+      // This includes a reduction across the threads in the same thread group.
+      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
+      // Add the ALiBi bias if slopes are given.
+      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
+
+      if (thread_group_offset == 0) {
+        // Store the partial reductions to shared memory.
+        // NOTE: It is required to zero out the masked logits.
+        const bool mask = token_idx >= context_len;
+        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
+        // Update the max value.
+        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
+      }
+    }
+  }
+
+  // Perform reduction across the threads in the same warp to get the
+  // max qk value for each "warp" (not across the thread block yet).
+  // The 0-th thread of each thread group already has its max qk value.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
+    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = qk_max;
+  }
+  __syncthreads();
+
+  // TODO: Refactor this part.
+  // Get the max qk value for the sequence.
+  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    qk_max = fmaxf(qk_max, APHRODITE_SHFL_XOR_SYNC(qk_max, mask));
+  }
+  // Broadcast the max qk value to all threads.
+  qk_max = APHRODITE_SHFL_SYNC(qk_max, 0);
+
+  // Get the sum of the exp values.
+  float exp_sum = 0.f;
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+    float val = __expf(logits[i] - qk_max);
+    logits[i] = val;
+    exp_sum += val;
+  }
+  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
+
+  // Compute softmax.
+  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+    logits[i] *= inv_sum;
+  }
+  __syncthreads();
+
+  // If partitioning is enabled, store the max logit and exp_sum.
+  if (USE_PARTITIONING && thread_idx == 0) {
+    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+                                       + head_idx * max_num_partitions
+                                       + partition_idx;
+    *max_logits_ptr = qk_max;
+    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+                                   + head_idx * max_num_partitions
+                                   + partition_idx;
+    *exp_sums_ptr = exp_sum;
+  }
+
+  // Each thread will fetch 16 bytes from the value cache at a time.
+  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
+  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
+  using Float_L_vec = typename FloatVec<L_vec>::Type;
+
+  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
+  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
+  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
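+  // Worked example (illustrative): for half precision with BLOCK_SIZE = 16,
+  // V_VEC_SIZE = MIN(16 / 2, 16) = 8, NUM_V_VECS_PER_ROW = 2,
+  // NUM_ROWS_PER_ITER = 32 / 2 = 16 and, for HEAD_SIZE = 128,
+  // NUM_ROWS_PER_THREAD = 8.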
+
+  // NOTE: We use FP32 for the accumulator for better accuracy.
+  float accs[NUM_ROWS_PER_THREAD];
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    accs[i] = 0.f;
+  }
+
+  scalar_t zero_value;
+  zero(zero_value);
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+    // NOTE: The block number is stored in int32. However, we cast it to int64
+    // because int32 can lead to overflow when this variable is multiplied by large numbers
+    // (e.g., kv_block_stride).
+    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
+    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+    L_vec logits_vec;
+    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
+
+    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+                                   + kv_head_idx * kv_head_stride;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE) {
+        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
+        V_vec v_vec;
+        if constexpr (KV_CACHE_DTYPE == INT8) {
+          // dequant and conversion
+          V_quant_vec v_vec_quant = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+          using V_dequant_vec = typename FloatVec<V_quant_vec>::Type;
+          V_dequant_vec v_vec_dequant = int8::dequant(v_vec_quant, v_scale, v_zp);
+          v_vec = int8::vec_conversion<V_vec, V_dequant_vec>(v_vec_dequant);
+#ifdef ENABLE_FP8_E5M2
+        } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
+          V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+          // Vector conversion from V_quant_vec to V_vec.
+          v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
+#endif
+        } else {
+          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
+        }
+        if (block_idx == num_context_blocks - 1) {
+          // NOTE: When v_vec contains the tokens that are out of the context,
+          // we should explicitly zero out the values since they may contain NaNs.
+          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
+#pragma unroll
+          for (int j = 0; j < V_VEC_SIZE; j++) {
+            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
+          }
+        }
+        accs[i] += dot(logits_vec, v_vec);
+      }
+    }
+  }
+
+  // Perform reduction within each warp.
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    float acc = accs[i];
+#pragma unroll
+    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
+      acc += APHRODITE_SHFL_XOR_SYNC(acc, mask);
+    }
+    accs[i] = acc;
+  }
+
+  // NOTE: A barrier is required because the shared memory space for logits
+  // is reused for the output.
+  __syncthreads();
+
+  // Perform reduction across warps.
+  float* out_smem = reinterpret_cast<float*>(shared_mem);
+#pragma unroll
+  for (int i = NUM_WARPS; i > 1; i /= 2) {
+    int mid = i / 2;
+    // Upper warps write to shared memory.
+    if (warp_idx >= mid && warp_idx < i) {
+      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
+#pragma unroll
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          dst[row_idx] = accs[i];
+        }
+      }
+    }
+    __syncthreads();
+
+    // Lower warps update the output.
+    if (warp_idx < mid) {
+      const float* src = &out_smem[warp_idx * HEAD_SIZE];
+#pragma unroll
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          accs[i] += src[row_idx];
+        }
+      }
+    }
+    __syncthreads();
+  }
+
+  // Write the final output.
+  if (warp_idx == 0) {
+    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                            + head_idx * max_num_partitions * HEAD_SIZE
+                            + partition_idx * HEAD_SIZE;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+        from_float(*(out_ptr + row_idx), accs[i]);
+      }
+    }
+  }
+}
+
+// Grid: (num_heads, num_seqs, 1).
+template<
+  typename scalar_t,
+  typename cache_t,
+  int HEAD_SIZE,
+  int BLOCK_SIZE,
+  int NUM_THREADS,
+  kv_cache_dtype KV_CACHE_DTYPE>
+__global__ void paged_attention_v1_kernel(
+  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
+  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
+  const int num_kv_heads,                 // [num_heads]
+  const float scale,
+  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ context_lens,   // [num_seqs]
+  const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes, // [num_heads]
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE>(
+    /* exp_sums */ nullptr, /* max_logits */ nullptr,
+    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
+    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
+}
+
+// Grid: (num_heads, num_seqs, max_num_partitions).
+template<
+  typename scalar_t,
+  typename cache_t,
+  int HEAD_SIZE,
+  int BLOCK_SIZE,
+  int NUM_THREADS,
+  kv_cache_dtype KV_CACHE_DTYPE,
+  int PARTITION_SIZE>
+__global__ void paged_attention_v2_kernel(
+  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
+  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
+  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
+  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
+  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
+  const int num_kv_heads,                 // [num_heads]
+  const float scale,
+  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ context_lens,   // [num_seqs]
+  const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes, // [num_heads]
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE, PARTITION_SIZE>(
+    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
+    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
+    q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
+}
+
+// Grid: (num_heads, num_seqs).
+template<
+  typename scalar_t,
+  int HEAD_SIZE,
+  int NUM_THREADS,
+  int PARTITION_SIZE>
+__global__ void paged_attention_v2_reduce_kernel(
+  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
+  const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
+  const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
+  const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
+  const int* __restrict__ context_lens,   // [num_seqs]
+  const int max_num_partitions) {
+  const int num_heads = gridDim.x;
+  const int head_idx = blockIdx.x;
+  const int seq_idx = blockIdx.y;
+  const int context_len = context_lens[seq_idx];
+  const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
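+  // Illustrative: with PARTITION_SIZE = 512, a context of 1500 tokens is split
+  // into 3 partitions; the fast path below only applies when the whole context
+  // fits into a single partition (context_len <= 512).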
+  if (num_partitions == 1) {
+    // No need to reduce. Only copy tmp_out to out.
+    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                                          + head_idx * max_num_partitions * HEAD_SIZE;
+    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
+      out_ptr[i] = tmp_out_ptr[i];
+    }
+    // Terminate the thread block.
+    return;
+  }
+
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int warp_idx = threadIdx.x / WARP_SIZE;
+  const int lane = threadIdx.x % WARP_SIZE;
+
+  // Size: 2 * num_partitions.
+  extern __shared__ char shared_mem[];
+  // Workspace for reduction.
+  __shared__ float red_smem[2 * NUM_WARPS];
+
+  // Load max logits to shared memory.
+  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
+  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+                                           + head_idx * max_num_partitions;
+  float max_logit = -FLT_MAX;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    const float l = max_logits_ptr[i];
+    shared_max_logits[i] = l;
+    max_logit = fmaxf(max_logit, l);
+  }
+  __syncthreads();
+
+  // Get the global max logit.
+  // Reduce within the warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = max_logit;
+  }
+  __syncthreads();
+  // Reduce across warps.
+  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, APHRODITE_SHFL_XOR_SYNC(max_logit, mask));
+  }
+  // Broadcast the max value to all threads.
+  max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);
+
+  // Load rescaled exp sums to shared memory.
+  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+                                       + head_idx * max_num_partitions;
+  float global_exp_sum = 0.0f;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    float l = shared_max_logits[i];
+    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
+    global_exp_sum += rescaled_exp_sum;
+    shared_exp_sums[i] = rescaled_exp_sum;
+  }
+  __syncthreads();
+  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
+  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
+
+  // Aggregate tmp_out to out.
+  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                                        + head_idx * max_num_partitions * HEAD_SIZE;
+  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+#pragma unroll
+  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
+    float acc = 0.0f;
+    for (int j = 0; j < num_partitions; ++j) {
+      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
+    }
+    from_float(out_ptr[i], acc);
+  }
+}
+
+} // namespace aphrodite
+
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                        \
+  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                        \
+    ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,    \
+      KV_CACHE_DTYPE>), shared_mem_size);                                                           \
+  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,              \
+  KV_CACHE_DTYPE><<<grid, block, shared_mem_size, stream>>>(                                        \
+    out_ptr,                                                                                        \
+    query_ptr,                                                                                      \
+    key_cache_ptr,                                                                                  \
+    value_cache_ptr,                                                                                \
+    num_kv_heads,                                                                                   \
+    scale,                                                                                          \
+    block_tables_ptr,                                                                               \
+    context_lens_ptr,                                                                               \
+    max_num_blocks_per_seq,                                                                         \
+    alibi_slopes_ptr,                                                                               \
+    q_stride,                                                                                       \
+    kv_block_stride,                                                                                \
+    kv_head_stride,                                                                                 \
+    k_scale,                                                                                        \
+    k_zp,                                                                                           \
+    v_scale,                                                                                        \
+    v_zp);
+
+// TODO: Tune NUM_THREADS.
+template<
+  typename T,
+  typename CACHE_T,
+  int BLOCK_SIZE,
+  kv_cache_dtype KV_CACHE_DTYPE,
+  int NUM_THREADS = 128>
+void paged_attention_v1_launcher(
+  torch::Tensor& out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  int num_kv_heads,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
+  int num_seqs = query.size(0);
+  int num_heads = query.size(1);
+  int head_size = query.size(2);
+  int max_num_blocks_per_seq = block_tables.size(1);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);
+
+  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  assert(head_size % thread_group_size == 0);
+
+  // NOTE: alibi_slopes is optional.
+  const float* alibi_slopes_ptr = alibi_slopes ?
+    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+    : nullptr;
+
+  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
+  int* block_tables_ptr = block_tables.data_ptr<int>();
+  int* context_lens_ptr = context_lens.data_ptr<int>();
+
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
+  int logits_size = padded_max_context_len * sizeof(float);
+  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  // Python-side check in aphrodite.task_handler.worker._check_if_can_support_max_seq_len
+  // Keep that in sync with the logic here!
+  int shared_mem_size = std::max(logits_size, outputs_size);
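+  // Worked example (illustrative): max_context_len = 2048 and BLOCK_SIZE = 16
+  // give padded_max_context_len = 2048 and logits_size = 8192 bytes; with
+  // NUM_THREADS = 128 and head_size = 128, outputs_size = 2 * 128 * 4 = 1024 bytes,
+  // so shared_mem_size = 8192 bytes.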
+
+  dim3 grid(num_heads, num_seqs, 1);
+  dim3 block(NUM_THREADS);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  switch (head_size) {
+    // NOTE: To reduce the compilation time, we only compile for the
+    // head sizes that we use in the model. However, we can easily extend this
+    // to support any head size which is a multiple of 16.
+    case 64:
+      LAUNCH_PAGED_ATTENTION_V1(64);
+      break;
+    case 80:
+      LAUNCH_PAGED_ATTENTION_V1(80);
+      break;
+    case 96:
+      LAUNCH_PAGED_ATTENTION_V1(96);
+      break;
+    case 112:
+      LAUNCH_PAGED_ATTENTION_V1(112);
+      break;
+    case 128:
+      LAUNCH_PAGED_ATTENTION_V1(128);
+      break;
+    case 256:
+      LAUNCH_PAGED_ATTENTION_V1(256);
+      break;
+    default:
+      TORCH_CHECK(false, "Unsupported head size: ", head_size);
+      break;
+  }
+}
+
+#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)             \
+  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(       \
+    out,                                                                     \
+    query,                                                                   \
+    key_cache,                                                               \
+    value_cache,                                                             \
+    num_kv_heads,                                                            \
+    scale,                                                                   \
+    block_tables,                                                            \
+    context_lens,                                                            \
+    max_context_len,                                                         \
+    alibi_slopes,                                                            \
+    k_scale,                                                                 \
+    k_zp,                                                                    \
+    v_scale,                                                                 \
+    v_zp);
+
+// NOTE: To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)       \
+  switch (block_size) {                                               \
+    case 8:                                                           \
+      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);                \
+      break;                                                          \
+    case 16:                                                          \
+      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);               \
+      break;                                                          \
+    case 32:                                                          \
+      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);               \
+      break;                                                          \
+    default:                                                          \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
+      break;                                                          \
+  }
+
+void paged_attention_v1(
+  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
+  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
+  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+  int num_kv_heads,               // [num_heads]
+  float scale,
+  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
+  torch::Tensor& context_lens,    // [num_seqs]
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const std::string& kv_cache_dtype,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f) {
+  if (kv_cache_dtype == "auto") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, AUTO);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, AUTO);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+#ifdef ENABLE_FP8_E5M2
+  } else if (kv_cache_dtype == "fp8_e5m2") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, FP8_E5M2);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, FP8_E5M2);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, FP8_E5M2);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+#endif
+  } else if (kv_cache_dtype == "int8") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(float, int8_t, INT8);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, int8_t, INT8);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, int8_t, INT8);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
+}
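+
+// Hypothetical call sketch (allocation of the tensors is omitted and all values
+// are only examples, not defaults from this change):
+//   paged_attention_v1(out, query, key_cache, value_cache,
+//                      /*num_kv_heads=*/8, /*scale=*/1.0f / std::sqrt(128.0f),
+//                      block_tables, context_lens,
+//                      /*block_size=*/16, /*max_context_len=*/2048,
+//                      alibi_slopes, /*kv_cache_dtype=*/"int8",
+//                      /*k_scale=*/1.0f, /*k_zp=*/0.0f,
+//                      /*v_scale=*/1.0f, /*v_zp=*/0.0f);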
+
+#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \
+  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,        \
+  KV_CACHE_DTYPE, PARTITION_SIZE>                                                             \
+  <<<grid, block, shared_mem_size, stream>>>(                                                 \
+    exp_sums_ptr,                                                                             \
+    max_logits_ptr,                                                                           \
+    tmp_out_ptr,                                                                              \
+    query_ptr,                                                                                \
+    key_cache_ptr,                                                                            \
+    value_cache_ptr,                                                                          \
+    num_kv_heads,                                                                             \
+    scale,                                                                                    \
+    block_tables_ptr,                                                                         \
+    context_lens_ptr,                                                                         \
+    max_num_blocks_per_seq,                                                                   \
+    alibi_slopes_ptr,                                                                         \
+    q_stride,                                                                                 \
+    kv_block_stride,                                                                          \
+    kv_head_stride,                                                                           \
+    k_scale,                                                                                  \
+    k_zp,                                                                                     \
+    v_scale,                                                                                  \
+    v_zp);                                                                                    \
+  aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>           \
+  <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                                   \
+    out_ptr,                                                                                  \
+    exp_sums_ptr,                                                                             \
+    max_logits_ptr,                                                                           \
+    tmp_out_ptr,                                                                              \
+    context_lens_ptr,                                                                         \
+    max_num_partitions);
+
+template<
+  typename T,
+  typename CACHE_T,
+  int BLOCK_SIZE,
+  kv_cache_dtype KV_CACHE_DTYPE,
+  int NUM_THREADS = 128,
+  int PARTITION_SIZE = 512>
+void paged_attention_v2_launcher(
+  torch::Tensor& out,
+  torch::Tensor& exp_sums,
+  torch::Tensor& max_logits,
+  torch::Tensor& tmp_out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  int num_kv_heads,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& context_lens,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
+  int num_seqs = query.size(0);
+  int num_heads = query.size(1);
+  int head_size = query.size(2);
+  int max_num_blocks_per_seq = block_tables.size(1);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);
+
+  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  assert(head_size % thread_group_size == 0);
+
+  // NOTE: alibi_slopes is optional.
+  const float* alibi_slopes_ptr = alibi_slopes ?
+    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+    : nullptr;
+
+  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
+  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
+  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
+  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
+  int* block_tables_ptr = block_tables.data_ptr<int>();
+  int* context_lens_ptr = context_lens.data_ptr<int>();
+
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
+  int logits_size = PARTITION_SIZE * sizeof(float);
+  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+
+  // For paged attention v2 kernel.
+  dim3 grid(num_heads, num_seqs, max_num_partitions);
+  int shared_mem_size = std::max(logits_size, outputs_size);
+  // For paged attention v2 reduce kernel.
+  dim3 reduce_grid(num_heads, num_seqs);
+  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
+
+  dim3 block(NUM_THREADS);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  switch (head_size) {
+    // NOTE: To reduce the compilation time, we only compile for the
+    // head sizes that we use in the model. However, we can easily extend this
+    // to support any head size which is a multiple of 16.
+    case 64:
+      LAUNCH_PAGED_ATTENTION_V2(64);
+      break;
+    case 80:
+      LAUNCH_PAGED_ATTENTION_V2(80);
+      break;
+    case 96:
+      LAUNCH_PAGED_ATTENTION_V2(96);
+      break;
+    case 112:
+      LAUNCH_PAGED_ATTENTION_V2(112);
+      break;
+    case 128:
+      LAUNCH_PAGED_ATTENTION_V2(128);
+      break;
+    case 256:
+      LAUNCH_PAGED_ATTENTION_V2(256);
+      break;
+    default:
+      TORCH_CHECK(false, "Unsupported head size: ", head_size);
+      break;
+  }
+}
+
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)                 \
+  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(           \
+    out,                                                                         \
+    exp_sums,                                                                    \
+    max_logits,                                                                  \
+    tmp_out,                                                                     \
+    query,                                                                       \
+    key_cache,                                                                   \
+    value_cache,                                                                 \
+    num_kv_heads,                                                                \
+    scale,                                                                       \
+    block_tables,                                                                \
+    context_lens,                                                                \
+    max_context_len,                                                             \
+    alibi_slopes,                                                                \
+    k_scale,                                                                     \
+    k_zp,                                                                        \
+    v_scale,                                                                     \
+    v_zp);
+
+// NOTE: To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)             \
+  switch (block_size) {                                                     \
+    case 8:                                                                 \
+      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);                      \
+      break;                                                                \
+    case 16:                                                                \
+      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);                     \
+      break;                                                                \
+    case 32:                                                                \
+      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);                     \
+      break;                                                                \
+    default:                                                                \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size);           \
+      break;                                                                \
+  }
+
+void paged_attention_v2(
+  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
+  torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
+  torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
+  torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
+  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
+  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+  int num_kv_heads,               // number of KV heads (<= num_heads under GQA/MQA)
+  float scale,
+  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
+  torch::Tensor& context_lens,    // [num_seqs]
+  int block_size,
+  int max_context_len,
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const std::string& kv_cache_dtype,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f) {
+  if (kv_cache_dtype == "auto") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, AUTO);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, AUTO);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+#ifdef ENABLE_FP8_E5M2
+  } else if (kv_cache_dtype == "fp8_e5m2") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, FP8_E5M2);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, FP8_E5M2);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, FP8_E5M2);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+#endif
+  } else if (kv_cache_dtype == "int8") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, int8_t, INT8);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, int8_t, INT8);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, int8_t, INT8);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
+}
+
+#undef WARP_SIZE
+#undef MAX
+#undef MIN
+#undef DIVIDE_ROUND_UP
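
For orientation, here is a minimal host-side sketch of how a caller would size the v2 scratch tensors before invoking paged_attention_v2 as defined above, assuming the default PARTITION_SIZE of 512; the wrapper name and the "auto" cache dtype are illustrative and not part of this diff.

#include <torch/extension.h>

// Hypothetical helper: allocates the scratch buffers whose shapes the v2
// interface comments describe, then dispatches the partitioned kernel.
torch::Tensor paged_attention_v2_sketch(
    torch::Tensor query,         // [num_seqs, num_heads, head_size]
    torch::Tensor key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor value_cache,   // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor block_tables,  // [num_seqs, max_num_blocks_per_seq]
    torch::Tensor context_lens,  // [num_seqs]
    int num_kv_heads, float scale, int block_size, int max_context_len) {
  const int num_seqs = query.size(0);
  const int num_heads = query.size(1);
  const int head_size = query.size(2);
  constexpr int PARTITION_SIZE = 512;  // must match the launcher's template default
  const int max_num_partitions =
      (max_context_len + PARTITION_SIZE - 1) / PARTITION_SIZE;

  auto opts = query.options();
  torch::Tensor out = torch::empty({num_seqs, num_heads, head_size}, opts);
  // Per-partition softmax statistics, combined by the v2 reduce kernel.
  torch::Tensor exp_sums = torch::empty(
      {num_seqs, num_heads, max_num_partitions}, opts.dtype(torch::kFloat32));
  torch::Tensor max_logits = torch::empty_like(exp_sums);
  // Per-partition outputs, reduced into `out`.
  torch::Tensor tmp_out = torch::empty(
      {num_seqs, num_heads, max_num_partitions, head_size}, opts);

  paged_attention_v2(out, exp_sums, max_logits, tmp_out, query, key_cache,
                     value_cache, num_kv_heads, scale, block_tables,
                     context_lens, block_size, max_context_len,
                     /*alibi_slopes=*/c10::nullopt, /*kv_cache_dtype=*/"auto");
  return out;
}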

+ 39 - 0
kernels/backup/cache.h

@@ -0,0 +1,39 @@
+#pragma once
+
+#include <torch/extension.h>
+
+#include <map>
+#include <vector>
+
+void swap_blocks(
+  torch::Tensor& src,
+  torch::Tensor& dst,
+  const std::map<int64_t, int64_t>& block_mapping);
+
+void copy_blocks(
+  std::vector<torch::Tensor>& key_caches,
+  std::vector<torch::Tensor>& value_caches,
+  const std::map<int64_t, std::vector<int64_t>>& block_mapping);
+
+void reshape_and_cache(
+  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
+  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
+  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& slot_mapping,  // [num_tokens]
+  const std::string& kv_cache_dtype,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f);
+
+void gather_cached_kv(
+  torch::Tensor& key,
+  torch::Tensor& value,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  torch::Tensor& slot_mapping);
+
+void convert_fp8_e5m2(
+  torch::Tensor& src_cache,
+  torch::Tensor& dst_cache);

+ 512 - 0
kernels/backup/cache_kernels.cu

@@ -0,0 +1,512 @@
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include "cuda_compat.h"
+#include "dispatch_utils.h"
+#include "quantization/int8_kvcache/quant_utils.cuh"
+#ifdef ENABLE_FP8_E5M2
+#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
+#endif
+
+#include <algorithm>
+#include <cassert>
+#include <map>
+#include <vector>
+
+enum kv_cache_dtype {
+  AUTO,
+#ifdef ENABLE_FP8_E5M2
+  FP8_E5M2,
+#endif
+  INT8};
+
+#ifdef USE_ROCM
+  #include <hip/hip_bf16.h>
+  typedef __hip_bfloat16 __nv_bfloat16;
+#endif
+
+void swap_blocks(
+  torch::Tensor& src,
+  torch::Tensor& dst,
+  const std::map<int64_t, int64_t>& block_mapping) {
+  torch::Device src_device = src.device();
+  torch::Device dst_device = dst.device();
+  cudaMemcpyKind memcpy_type;
+  if (src_device.is_cuda() && dst_device.is_cuda()) {
+    TORCH_CHECK(
+      src_device.index() == dst_device.index(),
+      "src and dst must be on the same GPU");
+    memcpy_type = cudaMemcpyDeviceToDevice;
+  } else if (src_device.is_cuda() && dst_device.is_cpu()) {
+    memcpy_type = cudaMemcpyDeviceToHost;
+  } else if (src_device.is_cpu() && dst_device.is_cuda()) {
+    memcpy_type = cudaMemcpyHostToDevice;
+  } else {
+    TORCH_CHECK(false, "Invalid device combination");
+  }
+
+  char *src_ptr = static_cast<char*>(src.data_ptr());
+  char *dst_ptr = static_cast<char*>(dst.data_ptr());
+
+  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  // NOTE: This can be slow if the number of blocks is large.
+  for (const auto& pair : block_mapping) {
+    int64_t src_block_number = pair.first;
+    int64_t dst_block_number = pair.second;
+    int64_t src_offset = src_block_number * block_size_in_bytes;
+    int64_t dst_offset = dst_block_number * block_size_in_bytes;
+    cudaMemcpyAsync(
+      dst_ptr + dst_offset,
+      src_ptr + src_offset,
+      block_size_in_bytes,
+      memcpy_type,
+      stream);
+  }
+}
+
+namespace aphrodite {
+
+// Grid: (num_layers, num_pairs)
+template<typename scalar_t>
+__global__ void copy_blocks_kernel(
+  int64_t* key_cache_ptrs,
+  int64_t* value_cache_ptrs,
+  const int64_t* __restrict__ block_mapping,
+  const int numel_per_block) {
+  const int layer_idx = blockIdx.x;
+  const int pair_idx = blockIdx.y;
+
+  scalar_t* key_cache = reinterpret_cast<scalar_t*>(key_cache_ptrs[layer_idx]);
+  scalar_t* value_cache = reinterpret_cast<scalar_t*>(value_cache_ptrs[layer_idx]);
+  int64_t src_block_number = block_mapping[2 * pair_idx];
+  int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
+
+  const int64_t src_block_offset = src_block_number * numel_per_block;
+  const int64_t dst_block_offset = dst_block_number * numel_per_block;
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    int64_t src_offset = src_block_offset + i;
+    int64_t dst_offset = dst_block_offset + i;
+    key_cache[dst_offset] = key_cache[src_offset];
+  }
+  for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+    int64_t src_offset = src_block_offset + i;
+    int64_t dst_offset = dst_block_offset + i;
+    value_cache[dst_offset] = value_cache[src_offset];
+  }
+}
+
+} // namespace aphrodite
+
+void copy_blocks(
+  std::vector<torch::Tensor>& key_caches,
+  std::vector<torch::Tensor>& value_caches,
+  const std::map<int64_t, std::vector<int64_t>>& block_mapping) {
+  int num_layers = key_caches.size();
+  TORCH_CHECK(num_layers == value_caches.size());
+  if (num_layers == 0) {
+    return;
+  }
+  torch::Device cache_device = key_caches[0].device();
+  TORCH_CHECK(cache_device.is_cuda());
+
+  // Create data structures for the kernel.
+  // Create an array of pointers to the key and value caches.
+  int64_t key_cache_ptrs[num_layers];
+  int64_t value_cache_ptrs[num_layers];
+  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+    key_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
+    value_cache_ptrs[layer_idx] = reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
+  }
+  // Create block mapping array.
+  std::vector<int64_t> block_mapping_vec;
+  for (const auto& pair : block_mapping) {
+    int64_t src_block_number = pair.first;
+    for (int64_t dst_block_number : pair.second) {
+      block_mapping_vec.push_back(src_block_number);
+      block_mapping_vec.push_back(dst_block_number);
+    }
+  }
+  int64_t* block_mapping_array = block_mapping_vec.data();
+  int num_pairs = block_mapping_vec.size() / 2;
+
+  // Move the data structures to the GPU.
+  // NOTE: This synchronizes the CPU and GPU.
+  torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
+    key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+  torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
+    value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+  torch::Tensor block_mapping_tensor = torch::from_blob(
+    block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
+
+  // Launch the kernel.
+  const int numel_per_block = key_caches[0][0].numel();
+  dim3 grid(num_layers, num_pairs);
+  dim3 block(std::min(1024, numel_per_block));
+  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
+    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
+      aphrodite::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key_cache_ptrs_tensor.data_ptr<int64_t>(),
+        value_cache_ptrs_tensor.data_ptr<int64_t>(),
+        block_mapping_tensor.data_ptr<int64_t>(),
+        numel_per_block);
+    }));
+}
+
+namespace aphrodite {
+
+template<typename scalar_t, typename cache_t, kv_cache_dtype KV_CACHE_DTYPE>
+__global__ void reshape_and_cache_kernel(
+  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
+  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
+  cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x]
+  cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size]
+  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
+  const int key_stride,
+  const int value_stride,
+  const int num_heads,
+  const int head_size,
+  const int block_size,
+  const int x,
+  const float k_scale,
+  const float k_zp,
+  const float v_scale,
+  const float v_zp) {
+  const int64_t token_idx = blockIdx.x;
+  const int64_t slot_idx = slot_mapping[token_idx];
+  if (slot_idx < 0) {
+    // Padding token that should be ignored.
+    return;
+  }
+
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;
+
+  const int n = num_heads * head_size;
+  for (int i = threadIdx.x; i < n; i += blockDim.x) {
+    const int64_t src_key_idx = token_idx * key_stride + i;
+    const int64_t src_value_idx = token_idx * value_stride + i;
+
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+    const int x_idx = head_offset / x;
+    const int x_offset = head_offset % x;
+
+    const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+                                + head_idx * (head_size / x) * block_size * x
+                                + x_idx * block_size * x
+                                + block_offset * x
+                                + x_offset;
+    const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
+                                  + head_idx * head_size * block_size
+                                  + head_offset * block_size
+                                  + block_offset;
+    scalar_t tgt_key = key[src_key_idx];
+    scalar_t tgt_value = value[src_value_idx];
+    if constexpr (KV_CACHE_DTYPE == INT8) {
+      key_cache[tgt_key_idx] = int8::quant(tgt_key, k_scale, k_zp);
+      value_cache[tgt_value_idx] = int8::quant(tgt_value, v_scale, v_zp);
+#ifdef ENABLE_FP8_E5M2
+    } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
+      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
+      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
+#endif
+    } else {
+      key_cache[tgt_key_idx] = tgt_key;
+      value_cache[tgt_value_idx] = tgt_value;
+    }
+  }
+}
+
+} // namespace aphrodite
+
+#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_CACHE_DTYPE)                                      \
+  aphrodite::reshape_and_cache_kernel<KV_T, CACHE_T, KV_CACHE_DTYPE><<<grid, block, 0, stream>>>(  \
+    reinterpret_cast<KV_T*>(key.data_ptr()),                                                       \
+    reinterpret_cast<KV_T*>(value.data_ptr()),                                                     \
+    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                              \
+    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                            \
+    slot_mapping.data_ptr<int64_t>(),                                                              \
+    key_stride,                                                                                    \
+    value_stride,                                                                                  \
+    num_heads,                                                                                     \
+    head_size,                                                                                     \
+    block_size,                                                                                    \
+    x,                                                                                             \
+    k_scale,                                                                                       \
+    k_zp,                                                                                          \
+    v_scale,                                                                                       \
+    v_zp);
+
+void reshape_and_cache(
+  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
+  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
+  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& slot_mapping,  // [num_tokens]
+  const std::string& kv_cache_dtype,
+  const float k_scale = 1.0f,
+  const float k_zp = 0.0f,
+  const float v_scale = 1.0f,
+  const float v_zp = 0.0f)
+{
+  int num_tokens = key.size(0);
+  int num_heads = key.size(1);
+  int head_size = key.size(2);
+  int block_size = key_cache.size(3);
+  int x = key_cache.size(4);
+
+  int key_stride = key.stride(0);
+  int value_stride = value.stride(0);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  if (kv_cache_dtype == "auto") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, float, AUTO);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, AUTO);
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, AUTO);
+    }
+#ifdef ENABLE_FP8_E5M2
+  } else if (kv_cache_dtype == "fp8_e5m2") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, uint8_t, FP8_E5M2);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, FP8_E5M2);
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, FP8_E5M2);
+    }
+#endif
+  } else if (kv_cache_dtype == "int8") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, int8_t, INT8);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      CALL_RESHAPE_AND_CACHE(uint16_t, int8_t, INT8);
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, int8_t, INT8);
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
+}
+
+namespace aphrodite {
+
+// Grid: (num_tokens). Each thread block gathers one token's K/V from the cache.
+template<typename scalar_t>
+__global__ void gather_cached_kv_kernel(
+  scalar_t* __restrict__ key,             // [num_tokens, [stride], num_heads, head_size]
+  scalar_t* __restrict__ value,           // [num_tokens, [stride], num_heads, head_size]
+  const scalar_t* __restrict__ key_cache,   // [num_blocks, num_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size, block_size]
+  const int* __restrict__ slot_mapping,   // [num_tokens]
+  const int key_stride,
+  const int value_stride,
+  const int num_heads,
+  const int head_size,
+  const int block_size,
+  const int x) {
+    const int token_idx = blockIdx.x;
+    const int slot_idx = slot_mapping[token_idx];
+    const int block_idx = slot_idx / block_size;
+    const int block_offset = slot_idx % block_size;
+
+    const int dim = num_heads * head_size;  // elements per token, not a token count
+    for (int i = threadIdx.x; i < dim; i += blockDim.x) {
+      const int tgt_key_idx = token_idx * key_stride + i;
+      const int tgt_value_idx = token_idx * value_stride + i;
+
+      const int head_idx = i / head_size;
+      const int head_offset = i % head_size;
+      const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension
+      const int x_offset = head_offset % x;
+
+      const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+                              + head_idx * (head_size / x) * block_size * x
+                              + x_idx * block_size * x
+                              + block_offset * x
+                              + x_offset;
+      const int src_value_idx = block_idx * num_heads * head_size * block_size
+                                + head_idx * head_size * block_size
+                                + head_offset * block_size
+                                + block_offset;
+
+      key[tgt_key_idx] = APHRODITE_LDG(&key_cache[src_key_idx]);
+      value[tgt_value_idx] = APHRODITE_LDG(&value_cache[src_value_idx]);
+    }
+}
+
+template <typename scalar_t>
+__global__ void gather_cached_kv_kernel_optimized(
+    scalar_t *__restrict__ key,             // [num_tokens, [stride], num_heads, head_size]
+    scalar_t *__restrict__ value,           // [num_tokens, [stride], num_heads, head_size]
+    const scalar_t *__restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
+    const scalar_t *__restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
+    const int *__restrict__ slot_mapping,   // [num_tokens]
+    const int key_stride,
+    const int value_stride,
+    const int num_heads,
+    const int head_size,
+    const int block_size,
+    const int x)
+{
+    const int token_idx = blockIdx.x;
+    const int slot_idx = slot_mapping[token_idx];
+    const int block_idx = slot_idx / block_size;
+    const int block_offset = slot_idx % block_size;
+
+    const int dim = num_heads * head_size;
+    assert(dim % 4 == 0);  // this is true for known use cases
+    const int unroll_factor = 4;
+    const int unrolled_dim = dim / unroll_factor;
+
+    for (int i = threadIdx.x; i < unrolled_dim; i += blockDim.x)
+    {
+        int tgt_key_indices[unroll_factor];
+        int tgt_value_indices[unroll_factor];
+        int src_key_indices[unroll_factor];
+        int src_value_indices[unroll_factor];
+        scalar_t keys_to_store[unroll_factor];
+        scalar_t values_to_store[unroll_factor];
+
+        #pragma unroll
+        for (int j = 0; j < unroll_factor; ++j)
+        {
+            int index = i + j * unrolled_dim;
+
+            const int tgt_key_idx = token_idx * key_stride + index;
+            const int tgt_value_idx = token_idx * value_stride + index;
+
+            const int head_idx = index / head_size;
+            const int head_offset = index % head_size;
+            const int x_idx = head_offset / x;
+            const int x_offset = head_offset % x;
+
+            const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+                                    + head_idx * (head_size / x) * block_size * x
+                                    + x_idx * block_size * x
+                                    + block_offset * x
+                                    + x_offset;
+            const int src_value_idx = block_idx * num_heads * head_size * block_size
+                                      + head_idx * head_size * block_size
+                                      + head_offset * block_size
+                                      + block_offset;
+
+            tgt_key_indices[j] = tgt_key_idx;
+            tgt_value_indices[j] = tgt_value_idx;
+            src_key_indices[j] = src_key_idx;
+            src_value_indices[j] = src_value_idx;
+
+            keys_to_store[j] = APHRODITE_LDG(&key_cache[src_key_idx]);
+            values_to_store[j] = APHRODITE_LDG(&value_cache[src_value_idx]);
+        }
+
+        #pragma unroll
+        for (int j = 0; j < unroll_factor; ++j)
+        {
+            key[tgt_key_indices[j]] = keys_to_store[j];
+            value[tgt_value_indices[j]] = values_to_store[j];
+        }
+    }
+}
+
+} // namespace aphrodite
+
+void gather_cached_kv(
+  torch::Tensor& key,           // [out] [num_tokens, num_heads, head_size]
+  torch::Tensor& value,         // [out] [num_tokens, num_heads, head_size]
+  torch::Tensor& key_cache,     // [in]  [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,   // [in]  [num_blocks, num_heads, head_size, block_size]
+  torch::Tensor& slot_mapping)  // [in]  [num_tokens]
+{
+  int num_tokens = key.size(0);
+  int num_heads = key.size(1);
+  int head_size = key.size(2);
+  int block_size = key_cache.size(3);
+  int x = key_cache.size(4);
+
+  int key_stride = key.stride(0);
+  int value_stride = value.stride(0);
+
+  dim3 grid(num_tokens);
+  dim3 block(std::min(num_heads * head_size, 512));
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
+    key.scalar_type(),
+    "gather_cached_kv_kernel_optimized",
+    [&] {
+      aphrodite::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
+        key.data_ptr<scalar_t>(),
+        value.data_ptr<scalar_t>(),
+        key_cache.data_ptr<scalar_t>(),
+        value_cache.data_ptr<scalar_t>(),
+        slot_mapping.data_ptr<int>(),
+        key_stride,
+        value_stride,
+        num_heads,
+        head_size,
+        block_size,
+        x);
+    });
+}
+
+namespace aphrodite {
+
+template<typename Tout, typename Tin>
+__global__ void convert_fp8_e5m2_kernel(
+  const Tin* __restrict__ src_cache,
+  Tout* __restrict__ dst_cache,
+  const int64_t block_stride) {
+  const int64_t block_idx = blockIdx.x;
+  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
+    int64_t idx = block_idx * block_stride + i;
+#ifdef ENABLE_FP8_E5M2
+    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
+#else
+    assert(false);
+#endif
+  }
+}
+
+} // namespace aphrodite
+
+#define CALL_CONVERT_FP8_E5M2(Tout, Tin)                                 \
+  aphrodite::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>(  \
+    reinterpret_cast<Tin*>(src_cache.data_ptr()),                        \
+    reinterpret_cast<Tout*>(dst_cache.data_ptr()),                       \
+    block_stride);
+
+void convert_fp8_e5m2(
+  torch::Tensor& src_cache,
+  torch::Tensor& dst_cache)
+{
+  int64_t num_blocks = src_cache.size(0);
+  int64_t block_stride = src_cache.stride(0);
+
+  dim3 grid(num_blocks);
+  dim3 block(std::min(block_stride, int64_t(512)));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  if (src_cache.dtype() == at::ScalarType::Float) {
+    CALL_CONVERT_FP8_E5M2(uint8_t, float);
+  } else if (src_cache.dtype() == at::ScalarType::Half) {
+    CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
+  } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
+    CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
+  } else if (dst_cache.dtype() == at::ScalarType::Float) {
+    CALL_CONVERT_FP8_E5M2(float, uint8_t);
+  } else if (dst_cache.dtype() == at::ScalarType::Half) {
+    CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
+  } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
+    CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
+  }
+}
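
The cache kernels above all rely on the same flat-index arithmetic over the paged layouts; pulled out of the kernels for readability, the addressing works as in the standalone sketch below (the helper names are not part of this diff).

#include <cstdint>

// key_cache layout:   [num_blocks, num_heads, head_size/x, block_size, x]
// value_cache layout: [num_blocks, num_heads, head_size,   block_size]
// A token's slot is split as block_idx = slot / block_size and
// block_offset = slot % block_size, exactly as in reshape_and_cache_kernel.
int64_t key_cache_index(int64_t block_idx, int block_offset, int head_idx,
                        int head_offset, int num_heads, int head_size,
                        int block_size, int x) {
  const int x_idx = head_offset / x;     // which x-wide chunk of the head dim
  const int x_offset = head_offset % x;  // position inside that chunk
  return block_idx * num_heads * (head_size / x) * block_size * x
         + head_idx * (head_size / x) * block_size * x
         + x_idx * block_size * x
         + block_offset * x
         + x_offset;
}

int64_t value_cache_index(int64_t block_idx, int block_offset, int head_idx,
                          int head_offset, int num_heads, int head_size,
                          int block_size) {
  return block_idx * num_heads * head_size * block_size
         + head_idx * head_size * block_size
         + head_offset * block_size
         + block_offset;
}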

+ 39 - 0
kernels/backup/dispatch_utils.h

@@ -0,0 +1,39 @@
+/*
+ * Adapted from
+ * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
+ */
+#pragma once
+
+#include <torch/extension.h>
+
+#define APHRODITE_DISPATCH_CASE_FLOATING_TYPES(...)              \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define APHRODITE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)             \
+  AT_DISPATCH_SWITCH(                                             \
+    TYPE, NAME, APHRODITE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
+
+#define APHRODITE_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)     \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)   \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
+
+#define APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...)           \
+  AT_DISPATCH_SWITCH(                                                    \
+    TYPE, NAME, APHRODITE_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
+
+#define APHRODITE_DISPATCH_CASE_INTEGRAL_TYPES(...)             \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
+
+#define APHRODITE_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...)             \
+  AT_DISPATCH_SWITCH(                                             \
+    TYPE, NAME, APHRODITE_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
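
The dispatch macros above follow the stock AT_DISPATCH pattern; a minimal usage sketch follows, where the kernel name and tensor are hypothetical and dispatch_utils.h plus torch/extension.h are assumed to be included.

void launch_example(torch::Tensor& input) {
  APHRODITE_DISPATCH_FLOATING_AND_BYTE_TYPES(
    input.scalar_type(), "example_kernel", ([&] {
      // Inside the lambda, scalar_t is bound to the C++ type matching
      // input.dtype(): float, at::Half, at::BFloat16, uint8_t, or int8_t.
      scalar_t* ptr = input.data_ptr<scalar_t>();
      (void)ptr;  // a real caller would launch example_kernel<scalar_t> here
    }));
}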

+ 280 - 0
kernels/backup/dtype_float32.cuh

@@ -0,0 +1,280 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "attention_generic.cuh"
+
+#include <stdint.h>
+
+namespace aphrodite {
+
+// Define custom FP32 vector data types.
+struct Float4_ {
+  float2 x;
+  float2 y;
+};
+
+struct Float8_ {
+  float2 x;
+  float2 y;
+  float2 z;
+  float2 w;
+};
+
+// FP32 vector types for Q, K, V.
+template<>
+struct Vec<float, 1> {
+  using Type = float;
+};
+template<>
+struct Vec<float, 2> {
+  using Type = float2;
+};
+template<>
+struct Vec<float, 4> {
+  using Type = float4;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+template<>
+struct FloatVec<float> {
+  using Type = float;
+};
+template<>
+struct FloatVec<float2> {
+  using Type = float2;
+};
+template<>
+struct FloatVec<float4> {
+  using Type = float4;
+};
+
+// Vector addition.
+inline __device__ float add(float a, float b) {
+  return a + b;
+}
+
+inline __device__ float2 add(float2 a, float2 b) {
+  float2 c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  return c;
+}
+
+inline __device__ float4 add(float4 a, float4 b) {
+  float4 c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  c.z = add(a.z, b.z);
+  c.w = add(a.w, b.w);
+  return c;
+}
+
+inline __device__ Float4_ add(Float4_ a, Float4_ b) {
+  Float4_ c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  return c;
+}
+
+// Vector multiplication.
+template<>
+inline __device__ float mul<float, float>(float a, float b) {
+  return a * b;
+}
+
+template<>
+inline __device__ float2 mul(float2 a, float2 b) {
+  float2 c;
+  c.x = a.x * b.x;
+  c.y = a.y * b.y;
+  return c;
+}
+
+template<>
+inline __device__ float2 mul(float a, float2 b) {
+  float2 c;
+  c.x = a * b.x;
+  c.y = a * b.y;
+  return c;
+}
+
+template<>
+inline __device__ float4 mul(float4 a, float4 b) {
+  float4 c;
+  c.x = a.x * b.x;
+  c.y = a.y * b.y;
+  c.z = a.z * b.z;
+  c.w = a.w * b.w;
+  return c;
+}
+
+template<>
+inline __device__ float4 mul(float a, float4 b) {
+  float4 c;
+  c.x = a * b.x;
+  c.y = a * b.y;
+  c.z = a * b.z;
+  c.w = a * b.w;
+  return c;
+}
+
+// Vector fused multiply-add.
+inline __device__ float fma(float a, float b, float c) {
+  return a * b + c;
+}
+
+inline __device__ float2 fma(float2 a, float2 b, float2 c) {
+  float2 d;
+  d.x = fma(a.x, b.x, c.x);
+  d.y = fma(a.y, b.y, c.y);
+  return d;
+}
+
+inline __device__ float2 fma(float a, float2 b, float2 c) {
+  float2 d;
+  d.x = fma(a, b.x, c.x);
+  d.y = fma(a, b.y, c.y);
+  return d;
+}
+
+inline __device__ float4 fma(float4 a, float4 b, float4 c) {
+  float4 d;
+  d.x = fma(a.x, b.x, c.x);
+  d.y = fma(a.y, b.y, c.y);
+  d.z = fma(a.z, b.z, c.z);
+  d.w = fma(a.w, b.w, c.w);
+  return d;
+}
+
+inline __device__ float4 fma(float a, float4 b, float4 c) {
+  float4 d;
+  d.x = fma(a, b.x, c.x);
+  d.y = fma(a, b.y, c.y);
+  d.z = fma(a, b.z, c.z);
+  d.w = fma(a, b.w, c.w);
+  return d;
+}
+
+inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
+  Float4_ d;
+  d.x = fma(a, b.x, c.x);
+  d.y = fma(a, b.y, c.y);
+  return d;
+}
+
+inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
+  Float8_ d;
+  d.x = fma(a, b.x, c.x);
+  d.y = fma(a, b.y, c.y);
+  d.z = fma(a, b.z, c.z);
+  d.w = fma(a, b.w, c.w);
+  return d;
+}
+
+// Vector sum.
+template<>
+inline __device__ float sum(float v) {
+  return v;
+}
+
+template<>
+inline __device__ float sum(float2 v) {
+  return v.x + v.y;
+}
+
+template<>
+inline __device__ float sum(float4 v) {
+  return v.x + v.y + v.z + v.w;
+}
+
+template<>
+inline __device__ float sum(Float4_ v) {
+  return v.x.x + v.x.y + v.y.x + v.y.y;
+}
+
+template<>
+inline __device__ float sum(Float8_ v) {
+  return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
+}
+
+// Vector dot product.
+inline __device__ float dot(float a, float b) {
+  return a * b;
+}
+
+inline __device__ float dot(float2 a, float2 b) {
+  float2 c = mul<float2, float2, float2>(a, b);
+  return c.x + c.y;
+}
+
+inline __device__ float dot(Float4_ a, Float4_ b) {
+  float2 acc = mul<float2, float2, float2>(a.x, b.x);
+  acc = fma(a.y, b.y, acc);
+  return acc.x + acc.y;
+}
+
+inline __device__ float dot(Float8_ a, Float8_ b) {
+  float2 acc = mul<float2, float2, float2>(a.x, b.x);
+  acc = fma(a.y, b.y, acc);
+  acc = fma(a.z, b.z, acc);
+  acc = fma(a.w, b.w, acc);
+  return acc.x + acc.y;
+}
+
+// From float to float.
+inline __device__ void from_float(float& dst, float src) {
+  dst = src;
+}
+
+inline __device__ void from_float(float2& dst, float2 src) {
+  dst = src;
+}
+
+inline __device__ void from_float(float4& dst, float4 src) {
+  dst = src;
+}
+
+// To float (identity conversions for FP32 types).
+inline __device__ float to_float(float u) {
+  return u;
+}
+
+inline __device__ float2 to_float(float2 u) {
+  return u;
+}
+
+inline __device__ float4 to_float(float4 u) {
+  return u;
+}
+
+inline __device__ Float4_ to_float(Float4_ u) {
+  return u;
+}
+
+inline __device__ Float8_ to_float(Float8_ u) {
+  return u;
+}
+
+// Zero-out a variable.
+inline __device__ void zero(float& dst) {
+  dst = 0.f;
+}
+
+} // namespace aphrodite
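
To show how these helpers compose, a minimal device-side sketch of a two-element partial Q.K dot product follows; it assumes this header (and the attention_generic.cuh it includes) is visible and that the pointers are 8-byte aligned, and it is not part of this diff.

// Each thread loads two query and two key floats as a float2, multiplies them
// element-wise via mul<>, and reduces to a scalar with dot(); warp shuffles
// (not shown) would then combine the partial sums across a thread group.
__device__ float partial_qk_dot_fp32(const float* q, const float* k) {
  using Q_vec = typename aphrodite::Vec<float, 2>::Type;  // resolves to float2
  Q_vec q_vec = *reinterpret_cast<const Q_vec*>(q);
  Q_vec k_vec = *reinterpret_cast<const Q_vec*>(k);
  return aphrodite::dot(q_vec, k_vec);                    // c.x + c.y
}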

+ 0 - 0
kernels/attention/dtype_int8.cuh → kernels/backup/dtype_int8.cuh


Some files were not shown because too many files changed in this diff