feat: replace Ray with NCCL for control plane comms (#221)

* nccl comms: take 4

* Revert "nccl comms: take 4"

This reverts commit 69c04be033ed62918b3465339ce1f22630682d37.

* nccl comms: take 5

* add 2/3/8bit gptq again

* add fp8 e5m2 kv cache kernels

* plumb fp8 kvcache back into engine

* fix cache kernels

* why was all that in cache header?

* attempt fix

* Revert "attempt fix"

This reverts commit 6963c0aeeb000b0cd694ae45dc35d9a3358c7cf8.

* Revert "why was all that in cache header?"

This reverts commit 5a4843e9b1fa3779d7c2f223ef2a841c979301cc.

* Revert "fix cache kernels"

This reverts commit 987f1f910ef6e9e801342d12ceda9398f2bf6f5b.

* Revert "plumb fp8 kvcache back into engine"

This reverts commit 2276aa57e8e9f944aded56aa7fb0bf09f0a68fb9.

* Revert "add fp8 e5m2 kv cache kernels"

This reverts commit 6b61496a355ca7800f45f938d5165304aee0b17a.

* add prefix cache support

* add prefix cache to llm entrypoint

* fix prefix test

* don't download both bin and safetensors

* fix crash in preprare_prompt

* simplify broadcast logic for control messages

* whoops

* refactor add group selector in broadcast

* missed this

* use the correct device when creating OptionalCUDAGuard

* add fp8 kernels

* only build fp8 kernels if cuda 11.8 and above is available

* add fp8 back

* add fp8 to throughput benchmark

* maybe?

* Revert "maybe?"

This reverts commit 667e56321c1db2be34b788cc23e1fc7fe75fc40c.

* Revert "add fp8 to throughput benchmark"

This reverts commit 7d83c677e7baa6f0f661ecb0ace59defe71adae3.

* Revert "add fp8 back"

This reverts commit 3878958bba74095fa28935718c4272759f21ccb3.

* Revert "only build fp8 kernels if cuda 11.8 and above is available"

This reverts commit 39cdca21f5d86ffe240d2d762f56da71f289db11.

* Revert "add fp8 kernels"

This reverts commit fb6c50c29a282c9355f714a9ebff5c8ddcfe0a11.

* add mirostat v2 back in

* persistent data in scheduler

* persistent data in sequence

* formatting

* potential crash if max tokens is None
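
The control-plane change in the title is easier to picture with a small sketch: instead of routing per-step metadata through Ray remote calls, the driver worker broadcasts it to the other workers over an already-initialized torch.distributed group. The commit's own broadcast helpers live in the enlarged aphrodite/modeling/megatron/communication_op.py listed below; the helper name and signature here are a hypothetical stand-in, not that code.

    import torch.distributed as dist

    def broadcast_control_message(payload=None, src=0, group=None):
        """Driver rank passes `payload`; other ranks pass None and receive it.

        Assumes dist.init_process_group(...) has already run and, for the NCCL
        backend, that each rank has set its CUDA device.
        """
        holder = [payload]
        # Serializes the object on `src` and fills `holder` on every other rank.
        dist.broadcast_object_list(holder, src=src, group=group)
        return holder[0]

Each decoding step, the driver would call something like this with the scheduler's output (sequence-group metadata plus the swap-in/swap-out/copy block instructions), while ranks > 0 call it with None and block until the message arrives.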
AlpinDale 1 year ago
parent
commit
8fa608aeb7
74 changed files with 2363 additions and 2144 deletions
  1. .pylintrc (+2 -0)
  2. aphrodite/common/block.py (+4 -0)
  3. aphrodite/common/config.py (+59 -35)
  4. aphrodite/common/logits_processor.py (+2 -0)
  5. aphrodite/common/outputs.py (+1 -0)
  6. aphrodite/common/prefix.py (+82 -0)
  7. aphrodite/common/sampling_params.py (+13 -16)
  8. aphrodite/common/sequence.py (+22 -6)
  9. aphrodite/common/utils.py (+16 -20)
  10. aphrodite/endpoints/llm.py (+18 -11)
  11. aphrodite/endpoints/openai/api_server.py (+4 -4)
  12. aphrodite/engine/aphrodite_engine.py (+282 -148)
  13. aphrodite/engine/args_tools.py (+37 -26)
  14. aphrodite/engine/async_aphrodite.py (+113 -44)
  15. aphrodite/engine/ray_tools.py (+27 -23)
  16. aphrodite/modeling/__init__.py (+1 -0)
  17. aphrodite/modeling/hf_downloader.py (+21 -33)
  18. aphrodite/modeling/layers/activation.py (+9 -8)
  19. aphrodite/modeling/layers/attention.py (+22 -29)
  20. aphrodite/modeling/layers/layernorm.py (+10 -10)
  21. aphrodite/modeling/layers/linear.py (+6 -3)
  22. aphrodite/modeling/layers/moe.py (+0 -368)
  23. aphrodite/modeling/layers/quantization/__init__.py (+5 -9)
  24. aphrodite/modeling/layers/quantization/awq.py (+4 -12)
  25. aphrodite/modeling/layers/quantization/base_config.py (+5 -41)
  26. aphrodite/modeling/layers/quantization/gptq.py (+8 -8)
  27. aphrodite/modeling/layers/quantization/squeezellm.py (+4 -7)
  28. aphrodite/modeling/layers/rotary_embedding.py (+12 -9)
  29. aphrodite/modeling/layers/sampler.py (+47 -68)
  30. aphrodite/modeling/layers/triton_kernel/__init__.py (+0 -0)
  31. aphrodite/modeling/layers/triton_kernel/prefix_prefill.py (+728 -0)
  32. aphrodite/modeling/layers/vocab_parallel_embedding.py (+1 -1)
  33. aphrodite/modeling/megatron/communication_op.py (+144 -3)
  34. aphrodite/modeling/megatron/parallel_state.py (+4 -4)
  35. aphrodite/modeling/megatron/utils.py (+2 -2)
  36. aphrodite/modeling/metadata.py (+12 -10)
  37. aphrodite/modeling/models/__init__.py (+1 -1)
  38. aphrodite/modeling/models/decilm.py (+23 -18)
  39. aphrodite/modeling/models/gpt_j.py (+8 -14)
  40. aphrodite/modeling/models/gpt_neox.py (+2 -7)
  41. aphrodite/modeling/models/llama.py (+12 -14)
  42. aphrodite/modeling/models/mistral.py (+1 -1)
  43. aphrodite/modeling/models/mixtral.py (+21 -48)
  44. aphrodite/modeling/models/phi1_5.py (+0 -269)
  45. aphrodite/modeling/models/yi.py (+19 -20)
  46. aphrodite/modeling/sampling_metadata.py (+15 -14)
  47. aphrodite/processing/block_manager.py (+47 -13)
  48. aphrodite/processing/policy.py (+10 -8)
  49. aphrodite/processing/scheduler.py (+51 -33)
  50. aphrodite/task_handler/cache_engine.py (+3 -7)
  51. aphrodite/task_handler/model_runner.py (+155 -48)
  52. aphrodite/task_handler/worker.py (+63 -26)
  53. examples/prefix_cache_example.py (+59 -0)
  54. kernels/activation_kernels.cu (+1 -1)
  55. kernels/attention/attention_dtypes.h (+0 -1)
  56. kernels/attention/attention_kernels.cu (+85 -138)
  57. kernels/attention/dtype_fp8.cuh (+0 -31)
  58. kernels/cache.h (+3 -5)
  59. kernels/cache_kernels.cu (+27 -98)
  60. kernels/cuda_compat.h (+1 -1)
  61. kernels/cuda_utils.h (+3 -2)
  62. kernels/cuda_utils_kernels.cu (+1 -2)
  63. kernels/dispatch_utils.h (+8 -27)
  64. kernels/layernorm_kernels.cu (+2 -1)
  65. kernels/misc_kernels.cu (+0 -35)
  66. kernels/ops.h (+4 -11)
  67. kernels/pos_encoding_kernels.cu (+1 -1)
  68. kernels/pybind.cpp (+4 -15)
  69. kernels/quantization/gptq/matrix_view.cuh (+1 -1)
  70. kernels/quantization/gptq/q_gemm.cu (+2 -2)
  71. kernels/quantization/gptq/qdq_4.cuh (+1 -1)
  72. kernels/quantization/kvcache/quant_utils.cuh (+0 -270)
  73. kernels/quantization/squeezellm/quant_cuda_kernel.cu (+2 -1)
  74. setup.py (+0 -1)

+ 2 - 0
.pylintrc

@@ -127,6 +127,7 @@ disable=abstract-method,
         signature-differs,
         standarderror-builtin,
         suppressed-message,
+        super-init-not-called,
         sys-max-int,
         too-few-public-methods,
         too-many-ancestors,
@@ -148,6 +149,7 @@ disable=abstract-method,
         useless-else-on-loop,
         useless-object-inheritance,
         useless-suppression,
+        useless-return,
         using-cmp-argument,
         wrong-import-order,
         xrange-builtin,

+ 4 - 0
aphrodite/common/block.py

@@ -65,3 +65,7 @@ class PhysicalTokenBlock:
         return (f'PhysicalTokenBlock(device={self.device}, '
                 f'block_number={self.block_number}, '
                 f'ref_count={self.ref_count})')
+
+
+# Mapping: logical block number -> physical block.
+BlockTable = List[PhysicalTokenBlock]

+ 59 - 35
aphrodite/common/config.py

@@ -1,4 +1,5 @@
 from typing import Optional, Union
+import os
 
 import torch
 from transformers import PretrainedConfig
@@ -41,16 +42,19 @@ class ModelConfig:
         revision: The specific model version to use. It can be a branch name,
             a tag name, or a commit id. If unspecified, will use the default
             version.
+        tokenizer_revision: The specific tokenizer version to use. It can be a
+            branch name, a tag name, or a commit id. If unspecified, will use
+            the default version.
         max_model_len: Maximum length of a sequence (including prompt and
             output). If None, will be derived from the model.
         quantization: Quantization method that was used to quantize the model
-            weights. If None, we assume the model weights are not quantized
+            weights. If None, we assume the model weights are not quantized.
         enforce_eager: Whether to enforce eager execution. If True, we will
             disable CUDA graph and always execute the model in eager mode.
             If False, we will use CUDA graph and eager execution in hybrid.
         max_context_len_to_capture: Maximum context len covered by CUDA graphs.
-            When a sequence has context length larger than this, we will fall
-            back to eager mode.
+            When a sequence has context length larger than this, we fall back
+            to eager mode.
     """
 
     def __init__(
@@ -64,6 +68,7 @@ class ModelConfig:
         dtype: Union[str, torch.dtype],
         seed: int,
         revision: Optional[str] = None,
+        tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
         enforce_eager: bool = False,
@@ -77,11 +82,24 @@ class ModelConfig:
         self.load_format = load_format
         self.seed = seed
         self.revision = revision
+        self.tokenizer_revision = tokenizer_revision
         self.quantization = quantization
         self.enforce_eager = enforce_eager
         self.max_context_len_to_capture = max_context_len_to_capture
 
-        self.hf_config = get_config(model, trust_remote_code, revision)
+        if os.environ.get("APHRODITE_USE_MODELSCOPE",
+                          "False").lower() == "true":
+            # download model from ModelScope hub,
+            # lazy import so that modelscope is not required for normal use.
+            from modelscope.hub.snapshot_download import snapshot_download  # pylint: disable=C
+            model_path = snapshot_download(model_id=model,
+                                           cache_dir=download_dir,
+                                           revision=revision)
+            self.model = model_path
+            self.download_dir = model_path
+            self.tokenizer = model_path
+
+        self.hf_config = get_config(self.model, trust_remote_code, revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self.max_model_len = _get_and_verify_max_len(self.hf_config,
                                                      max_model_len)
@@ -116,7 +134,6 @@ class ModelConfig:
             raise ValueError(
                 "Currently, the 'pt' format is not supported for Mixtral. "
                 "Please use the 'safetensors' format instead. ")
-
         self.load_format = load_format
 
     def _verify_tokenizer_mode(self) -> None:
@@ -128,35 +145,37 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode
 
     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq", "squeezellm", "gptq"]
+        supported_quantization = ["awq", "gptq", "squeezellm"]
         rocm_not_supported_quantization = ["awq"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
 
-        hf_quant_config = getattr(self.hf_config, "quant_config", None)
+        # Parse quantization method from the HF model config, if available.
+        hf_quant_config = getattr(self.hf_config, "quantization_config", None)
         if hf_quant_config is not None:
             hf_quant_method = str(hf_quant_config["quant_method"]).lower()
             if self.quantization is None:
                 self.quantization = hf_quant_method
             elif self.quantization != hf_quant_method:
                 raise ValueError(
-                    f"Model quantization method is {hf_quant_method} "
-                    f"but quantization argument is {self.quantization}. "
-                    "Please use the same quantization method.")
+                    "Quantization method specified in the model config "
+                    f"({hf_quant_method}) does not match the quantization "
+                    f"method specified in the `quantization` argument "
+                    f"({self.quantization}).")
+
         if self.quantization is not None:
             if self.quantization not in supported_quantization:
                 raise ValueError(
-                    f"Unknown quantization method: {self.quantization}. "
-                    f"Must be one of {supported_quantization}.")
+                    f"Unknown quantization method: {self.quantization}. Must "
+                    f"be one of {supported_quantization}.")
             if is_hip(
             ) and self.quantization in rocm_not_supported_quantization:
                 raise ValueError(
-                    f"{self.quantization} quantization method is currently "
-                    "not supported in ROCm.")
-        if self.quantization is not None:
+                    f"{self.quantization} quantization is currently not "
+                    "supported in ROCm.")
             logger.warning(f"{self.quantization} quantization is not fully "
                            "optimized yet. The speed can be slower than "
-                           "non-quantized models (16/32bit).")
+                           "non-quantized models.")
 
     def _verify_cuda_graph(self) -> None:
         if self.max_context_len_to_capture is None:
@@ -199,6 +218,10 @@ class ModelConfig:
 
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
+        # For GPTBigCode & Falcon:
+        # NOTE: for falcon, when new_decoder_architecture is True, the
+        # multi_query flag is ignored and we use n_head_kv for the number of
+        # KV heads.
         falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
         new_decoder_arch_falcon = (
             self.hf_config.model_type in falcon_model_types
@@ -210,9 +233,12 @@ class ModelConfig:
             return 1
 
         attributes = [
+            # For Falcon:
             "n_head_kv",
             "num_kv_heads",
+            # For LLaMA-2:
             "num_key_value_heads",
+            # For ChatGLM:
             "multi_query_group_num",
         ]
         for attr in attributes:
@@ -234,9 +260,6 @@ class ModelConfig:
         return max(1,
                    total_num_kv_heads // parallel_config.tensor_parallel_size)
 
-    def get_max_model_len(self) -> int:
-        return self.max_model_len
-
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_config.num_hidden_layers
         return total_num_hidden_layers // parallel_config.pipeline_parallel_size
@@ -250,7 +273,6 @@ class CacheConfig:
         gpu_memory_utilization: Fraction of GPU memory to use for the
             Aphrodite execution.
         swap_space: Size of the CPU swap space per GPU (in GiB).
-        cache_dtype: Data type fro the KV cache.
     """
 
     def __init__(
@@ -258,17 +280,11 @@ class CacheConfig:
         block_size: int,
         gpu_memory_utilization: float,
         swap_space: int,
-        cache_dtype: str,
         sliding_window: Optional[int] = None,
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
         self.swap_space_bytes = swap_space * _GB
-        self.cache_dtype = cache_dtype
-        if cache_dtype and "fp8" in cache_dtype.lower():
-            # As FP8 is not a formal data type, we use
-            # torch.uint8 instead.
-            self.cache_dtype = torch.uint8
         self.sliding_window = sliding_window
         self._verify_args()
 
@@ -358,6 +374,8 @@ class SchedulerConfig:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
         else:
+            # If max_model_len is too short, use 2048 as the default value for
+            # higher throughput.
             self.max_num_batched_tokens = max(max_model_len, 2048)
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
@@ -369,10 +387,10 @@ class SchedulerConfig:
             raise ValueError(
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                 f"smaller than max_model_len ({self.max_model_len}). "
-                f"This effectively limits the maximum sequence length to "
-                f"max_num_batched_tokens and makes Aphrodite reject longer "
-                f"sequences. Please increase max_num_batched_tokens or "
-                f"decrease max_model_len.")
+                "This effectively limits the maximum sequence length to "
+                "max_num_batched_tokens and makes Aphrodite reject longer "
+                "sequences. Please increase max_num_batched_tokens or "
+                "decrease max_model_len.")
         if self.max_num_batched_tokens < self.max_num_seqs:
             raise ValueError(
                 f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
@@ -405,20 +423,19 @@ def _get_and_verify_dtype(
         dtype = dtype.lower()
         if dtype == "auto":
             if config_dtype == torch.float32:
+                # Following the common practice, we use float16 for float32
+                # models.
                 torch_dtype = torch.float16
             else:
                 torch_dtype = config_dtype
         else:
             if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
-                raise ValueError(f"Unknown dtype: {dtype}. Must be one of "
-                                 f"{list(_STR_DTYPE_TO_TORCH_DTYPE.keys())}.")
+                raise ValueError(f"Unknown dtype: {dtype}")
             torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
     elif isinstance(dtype, torch.dtype):
         torch_dtype = dtype
     else:
-        raise ValueError(
-            f"Unknown dtype: {dtype}. Must be either a string or a torch "
-            "dtype.")
+        raise ValueError(f"Unknown dtype: {dtype}")
 
     if is_hip() and torch_dtype == torch.float32:
         rocm_supported_dtypes = [
@@ -450,9 +467,15 @@ def _get_and_verify_max_len(
     """Get and verify the model's maximum length."""
     derived_max_model_len = float("inf")
     possible_keys = [
+        # OPT
         "max_position_embeddings",
+        # GPT-2
         "n_positions",
+        # MPT
         "max_seq_len",
+        # ChatGLM2
+        "seq_length",
+        # Others
         "max_sequence_length",
         "max_seq_length",
         "seq_len",
@@ -465,6 +488,7 @@ def _get_and_verify_max_len(
         if max_model_len is not None:
             # If max_model_len is specified, we use it.
             return max_model_len
+
         default_max_len = 2048
         logger.warning(
             "The model's config.json does not contain any of the following "

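Two parts of the config.py changes are worth calling out: model resolution can now go through ModelScope when the APHRODITE_USE_MODELSCOPE environment variable is set to "true" (case-insensitive), and the maximum model length is derived from whichever context-length keys the HF config exposes. A standalone sketch of that derivation, consistent with the hunk above (it omits the real code's verification of a user-supplied max_model_len against the derived value and the warning on fallback):

    DEFAULT_MAX_LEN = 2048
    POSSIBLE_KEYS = [
        "max_position_embeddings",  # OPT
        "n_positions",              # GPT-2
        "max_seq_len",              # MPT
        "seq_length",               # ChatGLM2
        "max_sequence_length",      # others
        "max_seq_length",
        "seq_len",
    ]

    def derive_max_model_len(hf_config, max_model_len=None):
        derived = float("inf")
        for key in POSSIBLE_KEYS:
            value = getattr(hf_config, key, None)
            if value is not None:
                derived = min(derived, value)
        if derived == float("inf"):
            # config.json exposes none of the keys above.
            if max_model_len is not None:
                return max_model_len
            derived = DEFAULT_MAX_LEN
        return int(derived)
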
+ 2 - 0
aphrodite/common/logits_processor.py

@@ -23,6 +23,7 @@ class BiasLogitsProcessor(LogitsProcessor):
     """
 
     def __init__(self, biases: Dict[int, float]):
+        super().__init__()
         self.biases = biases
 
         if not biases:
@@ -52,6 +53,7 @@ class BanEOSUntil(LogitsProcessor):
     parameters can be handled gracefully."""
 
     def __init__(self, min_tokens: int, eos_token_id: int):
+        super().__init__()
         self._min_tokens = min_tokens
         self._eos_token_id = eos_token_id
 

+ 1 - 0
aphrodite/common/outputs.py

@@ -53,6 +53,7 @@ class RequestOutput:
         request_id: The unique ID of the request.
         prompt: The prompt string of the request.
         prompt_token_ids: The token IDs of the prompt.
+        prompt_logprobs: The log probabilities to return per prompt token.
         outputs: The output sequences of the request.
         finished: Whether the whole request is finished.
     """

+ 82 - 0
aphrodite/common/prefix.py

@@ -0,0 +1,82 @@
+from typing import Dict, List, Sequence, Tuple, Optional
+
+from aphrodite.common.block import BlockTable
+
+
+class Prefix:
+    """Data and states associated with a prefix of prompt tokens for multiple
+    sequence groups.
+    NOTE: This feature is experimental and may be replaced with automatic
+        prefix caching in the future.
+    Args:
+        prefix_id: The id of the prefix in the prefix pool.
+        token_ids: The token ids of the prefix.
+        block_size: The block size of the executed model.
+    """
+
+    def __init__(
+        self,
+        token_ids: Sequence[int],
+        block_size: int,
+    ) -> None:
+        self.token_ids = tuple(token_ids)
+        self.block_size = block_size
+        self.length = len(token_ids)
+        self.hash = hash(token_ids)
+        assert self.length % block_size == 0
+        self.block_table: Optional[BlockTable] = None
+        self.computed = False
+
+    @property
+    def allocated(self) -> bool:
+        return self.block_table is not None
+
+    def get_num_blocks(self) -> int:
+        return self.length // self.block_size
+
+    def get_block_numbers(self) -> List[int]:
+        return [block.block_number for block in self.block_table]
+
+    def get_length(self) -> int:
+        return self.length
+
+    def __hash__(self) -> int:
+        return self.hash
+
+    def set_block_table(self, block_table: BlockTable) -> None:
+        self.block_table = block_table.copy()
+
+
+class PrefixPool:
+    """Manages all the prompt prefixes.
+    NOTE: This feature is experimental and may be replaced with automatic
+        prefix caching in the future.
+    Args:
+        block_size: The block size of the executed model.
+    Attributes:
+        prefixes: A list of all the prefixes.
+        block_size: The block size of the executed model.
+    """
+
+    def __init__(
+        self,
+        block_size: int,
+    ) -> None:
+        # TODO: Add a capacity limit to the prefix pool.
+        self.prefixes: Dict[int, Prefix] = {}
+        self.block_size = block_size
+
+    def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]:
+        new_length = len(token_ids) // self.block_size * self.block_size
+        return tuple(token_ids[:new_length])
+
+    def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]:
+        token_ids = self._truncate_token_ids(token_ids)
+        if len(token_ids) == 0:
+            # Prefix is empty.
+            return None
+        prefix = Prefix(token_ids, self.block_size)
+        prefix_hash = hash(prefix)
+        if prefix_hash not in self.prefixes:
+            self.prefixes[prefix_hash] = prefix
+        return self.prefixes[prefix_hash]

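The new prefix classes are easiest to read through a short usage sketch (the block size and token ids below are made up): prompts are truncated down to a whole number of blocks, hashed, and deduplicated, so requests sharing a block-aligned prefix resolve to one cached Prefix object whose block table and computed flag the scheduler fills in later.

    from aphrodite.common.prefix import PrefixPool

    pool = PrefixPool(block_size=16)
    system_prompt_ids = list(range(40))   # hypothetical token ids

    prefix_a = pool.add_or_get_prefix(system_prompt_ids)
    prefix_b = pool.add_or_get_prefix(system_prompt_ids)

    assert prefix_a is prefix_b           # same hash -> same cached Prefix
    assert prefix_a.get_length() == 32    # 40 tokens truncated to 2 full blocks
    assert not prefix_a.computed          # KV cache for it not yet filled in
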
+ 13 - 16
aphrodite/common/sampling_params.py

@@ -1,8 +1,9 @@
 """Sampling parameters for text generation."""
 from enum import IntEnum
 from functools import cached_property
-from typing import List, Optional, Union
-from aphrodite.common.logits_processor import LogitsProcessor
+from typing import Callable, List, Optional, Union
+
+import torch
 
 _SAMPLING_EPS = 1e-5
 
@@ -13,6 +14,12 @@ class SamplingType(IntEnum):
     BEAM = 2
 
 
+LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
+"""LogitsProcessor is a function that takes a list of previously generated
+tokens and a tensor of the logits for the next token, and returns a modified
+tensor of logits to sample from."""
+
+
 class SamplingParams:
     """Sampling parameters for text generation.
 
@@ -140,13 +147,13 @@ class SamplingParams:
         stop_token_ids: List[int] = None,
         include_stop_str_in_output: bool = False,
         ignore_eos: bool = False,
-        max_tokens: int = 16,
+        max_tokens: Optional[int] = 16,
         logprobs: Optional[int] = None,
         prompt_logprobs: Optional[int] = None,
         custom_token_bans: Optional[List[int]] = None,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
-        logits_processors: List[LogitsProcessor] = None,
+        logits_processors: Optional[List[LogitsProcessor]] = None,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -190,7 +197,7 @@ class SamplingParams:
         self.logits_processors = logits_processors or []
         self.include_stop_str_in_output = include_stop_str_in_output
 
-        self.verify()
+        self._verify_args()
         if self.use_beam_search:
             self._verify_beam_search()
         else:
@@ -203,16 +210,6 @@ class SamplingParams:
                 self.top_a = 0.0
                 self._verify_greedy_sampling()
 
-    def verify(self) -> None:
-        self._verify_args()
-        if self.use_beam_search:
-            self._verify_beam_search()
-        else:
-            self._verify_non_beam_search()
-            if self.temperature < _SAMPLING_EPS:
-                # Zero temperature means greedy sampling.
-                self._verify_greedy_sampling()
-
     def _verify_args(self) -> None:
         if self.n < 1:
             raise ValueError(f"n must be at least 1, got {self.n}.")
@@ -269,7 +266,7 @@ class SamplingParams:
             if not self.mirostat_tau >= 0:
                 raise ValueError(
                     f"mirostat_tau must be positive, got {self.mirostat_tau}")
-        if self.max_tokens < 1:
+        if self.max_tokens is not None and self.max_tokens < 1:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
         if self.logprobs is not None and self.logprobs < 0:

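Since sampling_params.py now defines the LogitsProcessor contract directly (instead of importing it from aphrodite.common.logits_processor), here is a minimal example of a callable matching that contract; the banned token id is arbitrary:

    from typing import List
    import torch

    def ban_token_42(generated_token_ids: List[int],
                     logits: torch.Tensor) -> torch.Tensor:
        # Receives the ids generated so far and the next-token logits,
        # and returns (possibly modified) logits to sample from.
        logits = logits.clone()
        logits[42] = -float("inf")   # never sample token id 42
        return logits

    # Passed in through the existing `logits_processors` argument:
    # SamplingParams(logits_processors=[ban_token_42])
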
+ 22 - 6
aphrodite/common/sequence.py

@@ -4,6 +4,7 @@ import enum
 from typing import Dict, List, Optional, Union
 
 from aphrodite.common.block import LogicalTokenBlock
+from aphrodite.common.prefix import Prefix
 from aphrodite.common.sampling_params import SamplingParams
 
 PromptLogprobs = List[Optional[Dict[int, float]]]
@@ -38,6 +39,9 @@ class SequenceStatus(enum.Enum):
         elif status == SequenceStatus.FINISHED_ABORTED:
             finish_reason = "abort"
         elif status == SequenceStatus.FINISHED_IGNORED:
+            # The ignored sequences are the sequences whose prompt lengths
+            # are longer than the model's length cap. Therefore, the stop
+            # reason should also be "length" as in OpenAI API.
             finish_reason = "length"
         else:
             finish_reason = None
@@ -197,7 +201,7 @@ class Sequence:
         """
         if seq_len is None:
             seq_len = self.get_len()
-            # Note: HF implementation does not count the EOS token
+            # NOTE: HF implementation does not count the EOS token
             # towards the length, we align with that here for testing.
             if (eos_token_id is not None
                     and self.get_last_token_id() == eos_token_id):
@@ -234,11 +238,13 @@ class SequenceGroup:
         seqs: List[Sequence],
         sampling_params: SamplingParams,
         arrival_time: float,
+        prefix: Optional[Prefix] = None,
     ) -> None:
         self.request_id = request_id
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
         self.arrival_time = arrival_time
+        self.prefix: Optional[Prefix] = prefix
         self.prompt_logprobs: Optional[PromptLogprobs] = None
 
     @property
@@ -333,6 +339,8 @@ class SequenceGroupMetadata:
         sampling_params: The sampling parameters used to generate the outputs.
         block_tables: The block tables. (Seq id -> list of physical block
             numbers)
+        prefix: The prefix of the prompt of the sequence group.
+        persistent_data: The persistent data of the sequence group.
     """
 
     def __init__(
@@ -343,6 +351,7 @@ class SequenceGroupMetadata:
         sampling_params: SamplingParams,
         block_tables: Dict[int, List[int]],
         persistent_data: Dict[int, dict],
+        prefix: Optional[Prefix] = None,
     ) -> None:
         self.request_id = request_id
         self.is_prompt = is_prompt
@@ -350,6 +359,7 @@ class SequenceGroupMetadata:
         self.sampling_params = sampling_params
         self.block_tables = block_tables
         self.persistent_data = persistent_data
+        self.prefix = prefix
 
 
 class SequenceOutput:
@@ -361,10 +371,16 @@ class SequenceOutput:
         output_token: The output token ID.
         logprobs: The logprobs of the output token.
             (Token id -> logP(x_i+1 | x_0, ..., x_i))
+        persistent_data: The persistent data of the sequence.
     """
 
-    def __init__(self, parent_seq_id: int, output_token: int,
-                 logprobs: Dict[int, float], persistent_data: dict) -> None:
+    def __init__(
+        self,
+        parent_seq_id: int,
+        output_token: int,
+        logprobs: Dict[int, float],
+        persistent_data: dict,
+    ) -> None:
         self.parent_seq_id = parent_seq_id
         self.output_token = output_token
         self.logprobs = logprobs
@@ -372,9 +388,9 @@ class SequenceOutput:
 
     def __repr__(self) -> str:
         return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
-                f"output_token={self.output_token}), "
+                f"output_token={self.output_token}, "
                 f"logprobs={self.logprobs}, "
-                f"persistent_data={self.persistent_data}")
+                f"persistent_data={self.persistent_data})")
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceOutput):
@@ -386,7 +402,7 @@ class SequenceOutput:
 
 
 class SequenceGroupOutput:
-    """The model outputs associated with a sequence group."""
+    """The model output associated with a sequence group."""
 
     def __init__(
         self,

+ 16 - 20
aphrodite/common/utils.py

@@ -1,9 +1,9 @@
-"""Utils."""
-from os import path
 import enum
+import os
 import socket
-from platform import uname
 import uuid
+from platform import uname
+from typing import List
 
 import psutil
 import torch
@@ -37,7 +37,6 @@ def is_hip() -> bool:
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
     # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
-    # pylint: disable=invalid-name
     cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
     max_shared_mem = cuda_utils.get_device_attribute(
         cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
@@ -45,21 +44,8 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int:
 
 
 def get_cpu_memory() -> int:
-    """Returns the total CPU memory of the node or container in bytes."""
-
-    memory_limit = psutil.virtual_memory().total
-
-    for limit_file in [
-            "/sys/fs/cgroup/memory/memory.limit_in_bytes",  # v1
-            "/sys/fs/cgroup/memory.max"  # v2
-    ]:
-        if path.exists(limit_file):
-            with open(limit_file) as f:
-                content = f.read().strip()
-                if content.isnumeric():  # v2 can have "max" as limit
-                    memory_limit = min(memory_limit, int(content))
-
-    return memory_limit
+    """Returns the total CPU memory of the node in bytes."""
+    return psutil.virtual_memory().total
 
 
 def random_uuid() -> str:
@@ -71,7 +57,17 @@ def in_wsl() -> bool:
     return "microsoft" in " ".join(uname()).lower()
 
 
-def get_open_port():
+def get_ip() -> str:
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
+    return s.getsockname()[0]
+
+
+def get_open_port() -> int:
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         s.bind(("", 0))
         return s.getsockname()[1]
+
+
+def set_cuda_visible_devices(device_ids: List[int]) -> None:
+    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))

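The new networking helpers are combined later in the engine diff into a torch.distributed rendezvous address; in isolation they can be exercised like this (the GPU ids are illustrative):

    from aphrodite.common.utils import (get_ip, get_open_port,
                                        set_cuda_visible_devices)

    # The driver advertises its own IP and a free TCP port as the init method
    # that every worker passes to torch.distributed.init_process_group.
    distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"

    # Pin the current process to a subset of GPUs before CUDA is initialized.
    set_cuda_visible_devices([0, 1])
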
+ 18 - 11
aphrodite/endpoints/llm.py

@@ -115,6 +115,7 @@ class LLM:
         prompts: Optional[Union[str, List[str]]] = None,
         sampling_params: Optional[SamplingParams] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
+        prefix_pos: Optional[Union[int, List[int]]] = None,
         use_tqdm: bool = True,
     ) -> List[RequestOutput]:
         """Generates the completions for the input prompts.
@@ -129,6 +130,11 @@ class LLM:
                 None, we use the default sampling parameters.
             prompt_token_ids: A list of token IDs for the prompts. If None, we
                 use the tokenizer to convert the prompts to token IDs.
+            prefix_pos: If not None, we use the given position as the prefix
+                position for each prompt. We will cache the prefix's KV
+                cache and reuse it for the next request with the same prefix.
+                This is an experimental feature, and may be replaced with
+                automatic prefix caching in the future.
             use_tqdm: Whether to use tqdm to display the progress bar.
 
         Returns:
@@ -150,17 +156,14 @@ class LLM:
             sampling_params = SamplingParams()
 
         # Add requests to the engine.
-        if prompts is not None:
-            num_requests = len(prompts)
-        else:
-            num_requests = len(prompt_token_ids)
+        num_requests = len(prompts) if prompts is not None else len(
+            prompt_token_ids)
         for i in range(num_requests):
             prompt = prompts[i] if prompts is not None else None
-            if prompt_token_ids is None:
-                token_ids = None
-            else:
-                token_ids = prompt_token_ids[i]
-            self._add_request(prompt, sampling_params, token_ids)
+            prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
+            token_ids = None if prompt_token_ids is None else prompt_token_ids[
+                i]
+            self._add_request(prompt, sampling_params, token_ids, prefix_pos_i)
         return self._run_engine(use_tqdm)
 
     def _add_request(
@@ -168,10 +171,14 @@ class LLM:
         prompt: Optional[str],
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]],
+        prefix_pos: Optional[int] = None,
     ) -> None:
         request_id = str(next(self.request_counter))
-        self.llm_engine.add_request(request_id, prompt, sampling_params,
-                                    prompt_token_ids)
+        self.llm_engine.add_request(request_id,
+                                    prompt,
+                                    sampling_params,
+                                    prompt_token_ids,
+                                    prefix_pos=prefix_pos)
 
     def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
         # Initialize tqdm.

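A short sketch of the new prefix_pos argument to LLM.generate, assuming the usual top-level LLM/SamplingParams exports; the model name, prompts, and token count are illustrative, and the bundled examples/prefix_cache_example.py is the commit's own end-to-end example:

    from aphrodite import LLM, SamplingParams

    llm = LLM(model="mistralai/Mistral-7B-v0.1")
    shared = "You are a terse assistant. Answer in one sentence.\n\n"
    prompts = [shared + "What is NCCL?", shared + "What is a KV cache?"]

    # prefix_pos is a token index: tokens before it are cached as a prefix after
    # the first request and reused by later requests with the same prefix. In
    # practice it is obtained by tokenizing `shared`; 24 is a placeholder here.
    outputs = llm.generate(prompts,
                           SamplingParams(temperature=0.0, max_tokens=64),
                           prefix_pos=[24] * len(prompts))
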
+ 4 - 4
aphrodite/endpoints/openai/api_server.py

@@ -263,10 +263,10 @@ def create_logprobs(
                 for i, p in step_top_logprobs.items()
             } if step_top_logprobs else None)
 
-    logprobs.top_logprobs = [
-        {k: v if v > -1000 else -1000 for k, v in top_logprob.items()}
-        for top_logprob in logprobs.top_logprobs if top_logprob is not None
-    ]
+    logprobs.top_logprobs = [{
+        k: v if v > -1000 else -1000
+        for k, v in top_logprob.items()
+    } for top_logprob in logprobs.top_logprobs if top_logprob is not None]
 
     return logprobs
 

+ 282 - 148
aphrodite/engine/aphrodite_engine.py

@@ -1,17 +1,16 @@
 import copy
+from collections import defaultdict
 import os
 import time
-from functools import partial
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
-
-import psutil
+from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
+                    Union)
 
 from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
                                      SchedulerConfig)
 from aphrodite.processing.scheduler import Scheduler, SchedulerOutputs
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.engine.metrics import record_metrics
-from aphrodite.engine.ray_tools import RayWorker, initialize_cluster, ray
+from aphrodite.engine.ray_tools import RayWorkerAphrodite, initialize_cluster, ray
 from aphrodite.common.logger import init_logger
 from aphrodite.common.outputs import RequestOutput
 from aphrodite.common.sampling_params import SamplingParams
@@ -20,10 +19,10 @@ from aphrodite.common.sequence import (SamplerOutput, Sequence, SequenceGroup,
                                        SequenceStatus)
 from aphrodite.transformers_utils.tokenizer import (detokenize_incrementally,
                                                     get_tokenizer)
-from aphrodite.common.utils import Counter
+from aphrodite.common.utils import (Counter, set_cuda_visible_devices, get_ip,
+                                    get_open_port)
 
 if ray:
-    from ray.air.util.torch_dist import init_torch_dist_process_group
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
 if TYPE_CHECKING:
@@ -56,10 +55,8 @@ class AphroditeEngine:
             management.
         parallel_config: The configuration related to distributed execution.
         scheduler_config: The configuration related to the request scheduler.
-        distributed_init_method: The initialization method for distributed
-            execution. See `torch.distributed.init_process_group` for details.
-        stage_devices: The list of devices for each stage. Each stage is a list
-            of (rank, node_resource, device) tuples.
+        placement_group: Ray placement group for distributed execution.
+            Required for distributed execution.
         log_stats: Whether to log statistics.
     """
 
@@ -69,7 +66,6 @@ class AphroditeEngine:
         cache_config: CacheConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
-        distributed_init_method: str,
         placement_group: Optional["PlacementGroup"],
         log_stats: bool,
     ) -> None:
@@ -88,7 +84,6 @@ class AphroditeEngine:
             f"Sampler Seed = {model_config.seed}\n"
             f"Context Length = {model_config.max_model_len}\n"
             f"Enforce Eager Mode = {model_config.enforce_eager}\n"
-            f"KV Cache DataType = {cache_config.cache_dtype}\n"
             f"Seed = {model_config.seed}")
         # TODO: Print more configs in debug mode.
 
@@ -103,6 +98,7 @@ class AphroditeEngine:
             model_config.tokenizer,
             tokenizer_mode=model_config.tokenizer_mode,
             trust_remote_code=model_config.trust_remote_code,
+            tokenizer_revision=model_config.tokenizer_revision,
             revision=model_config.revision)
         self.seq_counter = Counter()
 
@@ -114,7 +110,7 @@ class AphroditeEngine:
                 os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
             self._init_workers_ray(placement_group)
         else:
-            self._init_workers(distributed_init_method)
+            self._init_workers()
 
         # Profile the memory usage and initialize the cache.
         self._init_cache()
@@ -129,118 +125,169 @@ class AphroditeEngine:
         # List of (timestamp, num_tokens)
         self.num_generation_tokens: List[Tuple[float, int]] = []
 
-    def _init_workers(self, distributed_init_method: str):
+    def _init_workers(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from aphrodite.task_handler.worker import Worker  # pylint: disable=import-outside-toplevel
+        # pylint: disable=import-outside-toplevel
+        from aphrodite.task_handler.worker import Worker
 
         assert self.parallel_config.world_size == 1, (
             "Ray is required if parallel_config.world_size > 1.")
 
         self.workers: List[Worker] = []
-        worker = Worker(
+        distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
+        self.driver_worker = Worker(
             self.model_config,
             self.parallel_config,
             self.scheduler_config,
-            0,
-            distributed_init_method,
-        )
-        self.workers.append(worker)
-        self._run_workers(
-            "init_model",
-            get_all_outputs=True,
-        )
-        self._run_workers(
-            "load_model",
-            get_all_outputs=True,
-            max_concurrent_workers=self.parallel_config.
-            max_parallel_loading_workers,
+            local_rank=0,
+            rank=0,
+            distributed_init_method=distributed_init_method,
+            is_driver_worker=True,
         )
+        self._run_workers("init_model")
+        self._run_workers("load_model")
 
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from aphrodite.task_handler.worker import Worker  # pylint: disable=import-outside-toplevel
+        if self.parallel_config.tensor_parallel_size == 1:
+            num_gpus = self.cache_config.gpu_memory_utilization
+        else:
+            num_gpus = 1
 
-        self.workers: List[Worker] = []
-        for bundle in placement_group.bundle_specs:
+        self.driver_dummy_worker: RayWorkerAphrodite = None
+        self.workers: List[RayWorkerAphrodite] = []
+
+        driver_ip = get_ip()
+        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
             if not bundle.get("GPU", 0):
                 continue
-            if self.parallel_config.tensor_parallel_size == 1:
-                num_gpus = self.cache_config.gpu_memory_utilization
-            else:
-                num_gpus = 1
+            scheduling_strategy = PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_capture_child_tasks=True,
+                placement_group_bundle_index=bundle_id,
+            )
             worker = ray.remote(
                 num_cpus=0,
                 num_gpus=num_gpus,
-                scheduling_strategy=PlacementGroupSchedulingStrategy(
-                    placement_group=placement_group,
-                    placement_group_capture_child_tasks=True),
+                scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
-            )(RayWorker).remote(self.model_config.trust_remote_code)
-            self.workers.append(worker)
+            )(RayWorkerAphrodite).remote(self.model_config.trust_remote_code)
+
+            worker_ip = ray.get(worker.get_node_ip.remote())
+            if worker_ip == driver_ip and self.driver_dummy_worker is None:
+                # If the worker is on the same node as the driver, we use it
+                # as the resource holder for the driver process.
+                self.driver_dummy_worker = worker
+            else:
+                self.workers.append(worker)
+
+        if self.driver_dummy_worker is None:
+            raise ValueError(
+                "Ray does not allocate any GPUs on the driver node. Consider "
+                "adjusting the Ray placement group or running the driver on a "
+                "GPU node.")
+
+        driver_node_id, driver_gpu_ids = ray.get(
+            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
+        worker_node_and_gpu_ids = ray.get(
+            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
+
+        node_workers = defaultdict(list)
+        node_gpus = defaultdict(list)
+
+        node_workers[driver_node_id].append(0)
+        node_gpus[driver_node_id].extend(driver_gpu_ids)
+        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
+                                               start=1):
+            node_workers[node_id].append(i)
+            node_gpus[node_id].extend(gpu_ids)
+        for node_id, gpu_ids in node_gpus.items():
+            node_gpus[node_id] = sorted(gpu_ids)
+
+        # Set CUDA_VISIBLE_DEVICES for the driver.
+        set_cuda_visible_devices(node_gpus[driver_node_id])
+        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
+            worker.set_cuda_visible_devices.remote(node_gpus[node_id])
+
+        distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"
+
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        # pylint: disable=import-outside-toplevel
+        from aphrodite.task_handler.worker import Worker
 
         # Initialize torch distributed process group for the workers.
-        init_torch_dist_process_group(self.workers, backend="nccl")
         model_config = copy.deepcopy(self.model_config)
         parallel_config = copy.deepcopy(self.parallel_config)
         scheduler_config = copy.deepcopy(self.scheduler_config)
-        self._run_workers("init_worker",
-                          get_all_outputs=True,
-                          worker_init_fn=lambda: Worker(
-                              model_config,
-                              parallel_config,
-                              scheduler_config,
-                              None,
-                              None,
-                          ))
-        self._run_workers(
-            "init_model",
-            get_all_outputs=True,
+
+        for rank, (worker, (node_id,
+                            _)) in enumerate(zip(self.workers,
+                                                 worker_node_and_gpu_ids),
+                                             start=1):
+            local_rank = node_workers[node_id].index(rank)
+            worker.init_worker.remote(
+                lambda rank=rank, local_rank=local_rank: Worker(
+                    model_config,
+                    parallel_config,
+                    scheduler_config,
+                    local_rank,
+                    rank,
+                    distributed_init_method,
+                ))
+
+        driver_rank = 0
+        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
+        self.driver_worker = Worker(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            driver_local_rank,
+            driver_rank,
+            distributed_init_method,
+            is_driver_worker=True,
         )
+
+        self._run_workers("init_model")
         self._run_workers(
             "load_model",
-            get_all_outputs=True,
             max_concurrent_workers=self.parallel_config.
             max_parallel_loading_workers,
         )
 
-        # HACK
-        # After running ray.init(), ray processes affinity is set to (0,1).
-        # (or whatever the CPU scheduler fancies)
-        # We however want the actual workers that are being used,
-        # so we call here since calling after ray.init() and everything else.
-        # We reassign each ray process by taking the
-        # modulus of the number of cpu_cores available.
-        # Issue: https://github.com/PygmalionAI/aphrodite-engine/issues/115
-        # The solution is similar to the taskset solution linked above.
-        current_process = psutil.Process()
-        ray_threads = 0
-        logical_cores = psutil.cpu_count(logical=True)
-        physical_cores = psutil.cpu_count(logical=False)
-        ht_scale = physical_cores / logical_cores
-        for process in current_process.children(recursive=True):
-            # process.pid
-            if "ray::" in process.name():
-                process.cpu_affinity([ray_threads])
-                ray_threads += int(1 * ht_scale) if ht_scale > 1.0 else 1
-                ray_threads = ray_threads % logical_cores
-
     def _verify_args(self) -> None:
         self.model_config.verify_with_parallel_config(self.parallel_config)
         self.cache_config.verify_with_parallel_config(self.parallel_config)
 
     def _init_cache(self) -> None:
-        """Profiles the memory usage and initializes the KV cache."""
+        """Profiles the memory usage and initializes the KV cache.
+
+        The engine will first conduct a profiling of the existing memory usage.
+        Then, it calculate the maximum possible number of GPU and CPU blocks
+        that can be allocated with the remaining free memory.
+        More details can be found in the
+        # pylint: disable=line-too-long
+        :meth:`~aphrodite.task_handler.worker.Worker.profile_num_available_blocks` method
+        from class :class:`~aphrodite.task_handler.Worker`.
+
+        Afterwards, as there may be multiple workers,
+        we take the minimum number of blocks across all workers
+        to ensure this can be applied to all of them.
+
+        Finally, the engine will initialize the KV cache
+        with the calculated number of blocks.
+
+        .. tip::
+            You may limit the usage of GPU memory
+            by adjusting the `gpu_memory_utilization` parameters.
+        """
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
         num_blocks = self._run_workers(
             "profile_num_available_blocks",
-            get_all_outputs=True,
             block_size=self.cache_config.block_size,
             gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
             cpu_swap_space=self.cache_config.swap_space_bytes,
-            cache_dtype=self.cache_config.cache_dtype,
         )
 
         # Since we use a shared centralized controller, we take the minimum
@@ -256,7 +303,6 @@ class AphroditeEngine:
             raise ValueError("No available memory for the cache blocks. "
                              "Try increasing `gpu_memory_utilization` when "
                              "initializing the engine.")
-
         max_seq_len = self.cache_config.block_size * num_gpu_blocks
         if self.model_config.max_model_len > max_seq_len:
             raise ValueError(
@@ -272,7 +318,7 @@ class AphroditeEngine:
         # Initialize the cache.
         self._run_workers("init_cache_engine", cache_config=self.cache_config)
         # Warm up the model. This includes capturing the model into CUDA graph
-        # if enforce_eager is set to False.
+        # if enforce_eager is False.
         self._run_workers("warm_up_model")
 
     @classmethod
@@ -282,11 +328,9 @@ class AphroditeEngine:
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, placement_group = initialize_cluster(
-            parallel_config)
+        placement_group = initialize_cluster(parallel_config)
         # Create the LLM engine.
         engine = cls(*engine_configs,
-                     distributed_init_method,
                      placement_group,
                      log_stats=not engine_args.disable_log_stats)
         return engine
@@ -298,6 +342,7 @@ class AphroditeEngine:
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
+        prefix_pos: Optional[int] = None,
     ) -> None:
         """Add a request to the engine's request pool.
 
@@ -313,10 +358,39 @@ class AphroditeEngine:
             prompt_token_ids: The token IDs of the prompt. If None, we
                 use the tokenizer to convert the prompts to token IDs.
             arrival_time: The arrival time of the request. If None, we use
-                the current time.
+                the current monotonic time.
+            prefix_pos: If not None, we use the given position as the prefix
+                position for each prompt. We will cache the prefix's KV
+                cache and reuse it for the next request with the same prefix.
+                This is an experimental feature, and may be replaced with
+                automatic prefix caching in the future.
+
+        Details:
+            - Set arrival_time to the current time if it is None.
+            - Set prompt_token_ids to the encoded prompt if it is None.
+            - Create `best_of` number of :class:`~aphrodite.Sequence` objects.
+            - Create a :class:`~aphrodite.SequenceGroup` object
+              from the list of :class:`~aphrodite.Sequence`.
+            - Add the :class:`~aphrodite.SequenceGroup` object to the scheduler.
+
+        Example:
+            >>> # initialize engine
+            >>> engine = AphroditeEngine.from_engine_args(engine_args)
+            >>> # set request arguments
+            >>> example_prompt = "Who is the president of the United States?"
+            >>> sampling_params = SamplingParams(temperature=0.0)
+            >>> request_id = 0
+            >>>
+            >>> # add the request to the engine
+            >>> engine.add_request(
+            >>>    str(request_id),
+            >>>    example_prompt,
+            >>>    SamplingParams(temperature=0.0))
+            >>> # continue the request processing
+            >>> ...
         """
         if arrival_time is None:
-            arrival_time = time.time()
+            arrival_time = time.monotonic()
         if prompt_token_ids is None:
             assert prompt is not None
             prompt_token_ids = self.tokenizer.encode(prompt)
@@ -326,9 +400,13 @@ class AphroditeEngine:
         seq_id = next(self.seq_counter)
         seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
 
+        # Check whether the input specifies prefix
+        prefix = self.scheduler.prefix_pool.add_or_get_prefix(
+            prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None
+
         # Create the sequence group.
         seq_group = SequenceGroup(request_id, [seq], sampling_params,
-                                  arrival_time)
+                                  arrival_time, prefix)
 
         # Add the sequence group to the scheduler.
         self.scheduler.add_seq_group(seq_group)
@@ -338,6 +416,17 @@ class AphroditeEngine:
 
         Args:
             request_id: The ID(s) of the request to abort.
+
+        Details:
+            - Refer to the
+              :meth:`~aphrodite.processing.scheduler.Scheduler.abort_seq_group`
+              from class :class:`~aphrodite.processing.scheduler.Scheduler`.
+
+        Example:
+            >>> # initialize engine and add a request with request_id
+            >>> request_id = str(0)
+            >>> # abort the request
+            >>> engine.abort_request(request_id)
         """
         self.scheduler.abort_seq_group(request_id)
 
@@ -583,10 +672,18 @@ class AphroditeEngine:
 
         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in (scheduled_seq_groups +
-                          scheduler_outputs.ignored_seq_groups):
+        for seq_group in scheduled_seq_groups:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
+        for seq_group in scheduler_outputs.ignored_seq_groups:
+            request_output = RequestOutput.from_seq_group(seq_group)
+            request_outputs.append(request_output)
+
+        # Update prefix state, now all the uncomputed prefixes are computed.
+        for seq_group in scheduled_seq_groups:
+            if (seq_group.prefix is not None and seq_group.prefix.allocated
+                    and not seq_group.prefix.computed):
+                seq_group.prefix.computed = True
 
         if self.log_stats:
             # Log the system stats.
@@ -597,31 +694,83 @@ class AphroditeEngine:
     def step(self) -> List[RequestOutput]:
         """Performs one decoding iteration and returns newly generated results.
 
-        This function performs one decoding iteration of the engine. It first
-        schedules the sequences to be executed in the next iteration and the
-        token blocks to be swapped in/out/copy. Then, it executes the model
-        and updates the scheduler with the model outputs. Finally, it decodes
-        the sequences and returns the newly generated results.
+        .. figure:: https://i.imgur.com/sv2HssD.png
+            :alt: Overview of the step function
+            :align: center
+
+            Overview of the step function.
+
+        Details:
+            - Step 1: Schedules the sequences to be executed in the next
+              iteration and the token blocks to be swapped in/out/copy.
+
+                - Depending on the scheduling policy,
+                  sequences may be `preempted/reordered`.
+                - A Sequence Group (SG) refer to a group of sequences
+                  that are generated from the same prompt.
+
+            - Step 2: Calls the workers to execute the model.
+            - Step 3: Processes the model output. This mainly includes:
+
+                - Decodes the relevant outputs.
+                - Updates the scheduled sequence groups with model outputs
+                  based on its `sampling parameters` (`use_beam_search` or not).
+                - Frees the finished sequence groups.
+
+            - Finally, it creates and returns the newly generated results.
+
+        Example:
+            >>> # Please see the example/ folder for more detailed examples.
+            >>>
+            >>> # initialize engine and request arguments
+            >>> engine = AphroditeEngine.from_engine_args(engine_args)
+            >>> example_inputs = [(0, "What is LLM?",
+            >>>    SamplingParams(temperature=0.0))]
+            >>>
+            >>> # Start the engine with an event loop
+            >>> while True:
+            >>>     if example_inputs:
+            >>>         req_id, prompt, sampling_params = example_inputs.pop(0)
+            >>>         engine.add_request(str(req_id), prompt, sampling_params)
+            >>>
+            >>>     # continue the request processing
+            >>>     request_outputs = engine.step()
+            >>>     for request_output in request_outputs:
+            >>>         if request_output.finished:
+            >>>             # return or show the request output
+            >>>
+            >>>     if not (engine.has_unfinished_requests() or example_inputs):
+            >>>         break
         """
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
 
-        # Execute the model.
-        output = self._run_workers(
-            "execute_model",
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
-            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
-            blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        ) if not scheduler_outputs.is_empty() else []
+        if not scheduler_outputs.is_empty():
+            # Execute the model.
+            all_outputs = self._run_workers(
+                "execute_model",
+                driver_kwargs={
+                    "seq_group_metadata_list": seq_group_metadata_list,
+                    "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
+                    "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
+                    "blocks_to_copy": scheduler_outputs.blocks_to_copy,
+                })
+
+            # Only the driver worker returns the sampling results.
+            output = all_outputs[0]
+        else:
+            output = []
 
         return self._process_model_outputs(output, scheduler_outputs)
 
+    def do_log_stats(self) -> None:
+        self._log_system_stats(False, 0)
+
     def _log_system_stats(
         self,
         prompt_run: bool,
         num_batched_tokens: int,
     ) -> None:
-        now = time.time()
+        now = time.monotonic()
         # Log the number of batched input tokens.
         if prompt_run:
             self.num_prompt_tokens.append((now, num_batched_tokens))
@@ -699,7 +848,8 @@ class AphroditeEngine:
              prefix_offset=seq.prefix_offset,
              read_offset=seq.read_offset,
              skip_special_tokens=prms.skip_special_tokens,
-             spaces_between_special_tokens=prms.spaces_between_special_tokens)
+             spaces_between_special_tokens=prms.spaces_between_special_tokens,
+         )
         if seq.tokens is None:
             seq.tokens = new_tokens
         else:
@@ -715,7 +865,7 @@ class AphroditeEngine:
             if seq.output_text.endswith(stop_str):
                 if not sampling_params.include_stop_str_in_output:
                     # Truncate the output text so that the stop string is
-                    # not included in the output
+                    # not included in the output.
                     seq.output_text = seq.output_text[:-len(stop_str)]
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return
@@ -739,54 +889,38 @@ class AphroditeEngine:
             seq.status = SequenceStatus.FINISHED_STOPPED
             return
 
-    def _run_workers_in_batch(
-        self,
-        workers,
-        method: str,
-        *args,
-        **kwargs,
-    ):
-        all_outputs = []
-        for worker in workers:
-            if self.parallel_config.worker_use_ray:
-                executor = partial(worker.execute_method.remote, method)
-            else:
-                executor = getattr(worker, method)
-
-            output = executor(*args, **kwargs)
-            all_outputs.append(output)
-
-        if self.parallel_config.worker_use_ray:
-            all_outputs = ray.get(all_outputs)
-        return all_outputs
-
     def _run_workers(
         self,
         method: str,
         *args,
-        get_all_outputs: bool = False,
+        driver_args: Optional[List[Any]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
-        """Runs a method on all workers."""
-        all_outputs = []
+        """Runs the given method on all workers."""
+
         if max_concurrent_workers:
-            work_groups = [
-                self.workers[i:i + max_concurrent_workers]
-                for i in range(0, len(self.workers), max_concurrent_workers)
-            ]
-        else:
-            work_groups = [self.workers]
+            raise NotImplementedError(
+                "max_concurrent_workers is not supported yet.")
+
+        # Start the ray workers first.
+        ray_worker_outputs = [
+            worker.execute_method.remote(method, *args, **kwargs)
+            for worker in self.workers
+        ]
+
+        if driver_args is None:
+            driver_args = args
+        if driver_kwargs is None:
+            driver_kwargs = kwargs
 
-        for workers in work_groups:
-            all_outputs.extend(
-                self._run_workers_in_batch(workers, method, *args, **kwargs))
+        # Start the driver worker after all the ray workers.
+        driver_worker_output = getattr(self.driver_worker,
+                                       method)(*driver_args, **driver_kwargs)
 
-        if get_all_outputs:
-            return all_outputs
+        # Get the results of the ray workers.
+        if self.workers:
+            ray_worker_outputs = ray.get(ray_worker_outputs)
 
-        # Make sure all workers have the same results.
-        output = all_outputs[0]
-        for other_output in all_outputs[1:]:
-            assert output == other_output
-        return output
+        return [driver_worker_output] + ray_worker_outputs
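
The refactored `_run_workers` starts every Ray worker first, runs the method on the in-process driver worker, and then gathers the Ray results, so callers read the sampling results from index 0 of the returned list. A minimal, self-contained sketch of that ordering (the stub classes stand in for the Ray actors and driver worker; they are not engine code):

    # Illustrative sketch of the new fan-out order in _run_workers: remote
    # workers are started first, the driver worker runs in-process, and the
    # driver's output is always element 0 of the returned list.
    from typing import Any, Callable, List


    class FakeRemoteHandle:
        """Stands in for a Ray ObjectRef in this sketch."""

        def __init__(self, fn: Callable[..., Any], *args: Any, **kwargs: Any):
            # Ray would execute this remotely; here we just run it eagerly.
            self._result = fn(*args, **kwargs)

        def get(self) -> Any:
            return self._result


    class StubWorker:

        def __init__(self, rank: int):
            self.rank = rank

        def execute_model(self, num_tokens: int) -> str:
            return f"rank {self.rank} processed {num_tokens} tokens"


    def run_workers(driver_worker: StubWorker, workers: List[StubWorker],
                    method: str, *args: Any, **kwargs: Any) -> List[Any]:
        # Start the "remote" workers first (non-blocking with real Ray).
        handles = [
            FakeRemoteHandle(getattr(w, method), *args, **kwargs)
            for w in workers
        ]
        # Run the driver worker after all the remote workers have been started.
        driver_output = getattr(driver_worker, method)(*args, **kwargs)
        # Gather the remote results; the driver output always comes first.
        return [driver_output] + [h.get() for h in handles]


    if __name__ == "__main__":
        outputs = run_workers(StubWorker(0), [StubWorker(1), StubWorker(2)],
                              "execute_model", num_tokens=16)
        print(outputs[0])  # only the driver's output carries sampling results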

+ 37 - 26
aphrodite/engine/args_tools.py

@@ -31,10 +31,10 @@ class EngineArgs:
     max_paddings: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
+    tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
     enforce_eager: bool = False
     max_context_len_to_capture: int = 8192
-    kv_cache_dtype: Optional[str] = None
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -44,6 +44,10 @@ class EngineArgs:
     def add_cli_args(
             parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
         """Shared CLI arguments for the Aphrodite engine."""
+
+        # NOTE: If you update any of the arguments below, please also
+        # make sure to update docs/source/models/engine_args.rst
+
         # Model arguments
         parser.add_argument(
             '--model',
@@ -62,6 +66,13 @@ class EngineArgs:
             help='the specific model version to use. It can be a branch '
             'name, a tag name, or a commit id. If unspecified, will use '
             'the default version.')
+        parser.add_argument(
+            '--tokenizer-revision',
+            type=str,
+            default=None,
+            help='the specific tokenizer version to use. It can be a branch '
+            'name, a tag name, or a commit id. If unspecified, will use '
+            'the default version.')
         parser.add_argument('--tokenizer-mode',
                             type=str,
                             default=EngineArgs.tokenizer_mode,
@@ -84,11 +95,11 @@ class EngineArgs:
             default=EngineArgs.load_format,
             choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
             help='The format of the model weights to load. '
-            '"auto" will try to load the weights in the safetensors '
-            'and fall back to the pytorch bin format if safetensors '
+            '"auto" will try to load the weights in the safetensors format '
+            'and fall back to the pytorch bin format if safetensors format '
             'is not available. '
             '"pt" will load the weights in the pytorch bin format. '
-            '"safetensors" will load the weights in the safetensors. '
+            '"safetensors" will load the weights in the safetensors format. '
             '"npcache" will load the weights in pytorch format and store '
             'a numpy cache to speed up the loading. '
             '"dummy" will initialize the weights with random values, '
@@ -126,11 +137,10 @@ class EngineArgs:
                             help='number of tensor parallel replicas')
         parser.add_argument(
             '--max-parallel-loading-workers',
-            '-mplw',
             type=int,
             help='load model sequentially in multiple batches, '
-            'to avoid CPU OOM when using tensor parallel '
-            'with large models.')
+            'to avoid RAM OOM when using tensor '
+            'parallel and large models')
         # KV cache arguments
         parser.add_argument('--block-size',
                             type=int,
@@ -146,14 +156,15 @@ class EngineArgs:
                             type=int,
                             default=EngineArgs.swap_space,
                             help='CPU swap space size (GiB) per GPU')
-        parser.add_argument('--gpu-memory-utilization',
-                            '-gmu',
-                            type=float,
-                            default=EngineArgs.gpu_memory_utilization,
-                            help='the percentage of GPU memory to be used for'
-                            'the model executor')
+        parser.add_argument(
+            '--gpu-memory-utilization',
+            '-gmu',
+            type=float,
+            default=EngineArgs.gpu_memory_utilization,
+            help='the fraction of GPU memory to be used for '
+            'the model executor, which can range from 0 to 1. '
+            'If unspecified, will use the default value of 0.9.')
         parser.add_argument('--max-num-batched-tokens',
-                            '-mnbt',
                             type=int,
                             default=EngineArgs.max_num_batched_tokens,
                             help='maximum number of batched tokens per '
@@ -173,25 +184,25 @@ class EngineArgs:
         parser.add_argument('--quantization',
                             '-q',
                             type=str,
-                            choices=['awq', 'squeezellm', 'gptq', None],
+                            choices=['awq', 'gptq', 'squeezellm', None],
                             default=None,
-                            help='Method used to quantize the weights')
+                            help='Method used to quantize the weights. If '
+                            'None, we first check the `quantization_config` '
+                            'attribute in the model config file. If that is '
+                            'None, we assume the model weights are not '
+                            'quantized and use `dtype` to determine the data '
+                            'type of the weights.')
         parser.add_argument('--enforce-eager',
                             action='store_true',
                             help='Always use eager-mode PyTorch. If False, '
                             'will use eager mode and CUDA graph in hybrid '
-                            'for maximum performance and flexibility.')
+                            'for maximal performance and flexibility.')
         parser.add_argument('--max-context-len-to-capture',
                             type=int,
                             default=EngineArgs.max_context_len_to_capture,
                             help='maximum context length covered by CUDA '
                             'graphs. When a sequence has context length '
                             'larger than this, we fall back to eager mode.')
-        parser.add_argument('--kv-cache-dtype',
-                            type=str,
-                            choices=['fp8', None],
-                            default=None,
-                            help='Data type for the KV cache.')
         return parser
 
     @classmethod
@@ -209,12 +220,12 @@ class EngineArgs:
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
                                    self.dtype, self.seed, self.revision,
-                                   self.max_model_len, self.quantization,
-                                   self.enforce_eager,
+                                   self.tokenizer_revision, self.max_model_len,
+                                   self.quantization, self.enforce_eager,
                                    self.max_context_len_to_capture)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
-                                   self.swap_space, self.kv_cache_dtype,
+                                   self.swap_space,
                                    model_config.get_sliding_window())
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
@@ -229,7 +240,7 @@ class EngineArgs:
 
 @dataclass
 class AsyncEngineArgs(EngineArgs):
-    """Arguments for asynchronous Aohrodite engine."""
+    """Arguments for asynchronous Aphrodite engine."""
     engine_use_ray: bool = False
     disable_log_requests: bool = False
     max_log_len: Optional[int] = None
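
A hedged sketch of how a flag like the new `--tokenizer-revision` travels from the command line into a dataclass of engine arguments; the toy parser below mirrors the argparse pattern used in `add_cli_args` but is not the real `EngineArgs`:

    # Toy illustration (not the real EngineArgs): parse a --tokenizer-revision
    # style flag with argparse and carry it into a dataclass.
    import argparse
    from dataclasses import dataclass
    from typing import List, Optional


    @dataclass
    class ToyEngineArgs:
        model: str
        revision: Optional[str] = None
        tokenizer_revision: Optional[str] = None
        gpu_memory_utilization: float = 0.9


    def parse_toy_args(argv: List[str]) -> ToyEngineArgs:
        parser = argparse.ArgumentParser()
        parser.add_argument('--model', type=str, required=True)
        parser.add_argument('--revision', type=str, default=None)
        parser.add_argument('--tokenizer-revision', type=str, default=None)
        parser.add_argument('--gpu-memory-utilization', '-gmu', type=float,
                            default=ToyEngineArgs.gpu_memory_utilization)
        args = parser.parse_args(argv)
        return ToyEngineArgs(model=args.model,
                             revision=args.revision,
                             tokenizer_revision=args.tokenizer_revision,
                             gpu_memory_utilization=args.gpu_memory_utilization)


    if __name__ == "__main__":
        print(parse_toy_args(['--model', 'my-org/my-model',
                              '--tokenizer-revision', 'main', '-gmu', '0.85']))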

+ 113 - 44
aphrodite/engine/async_aphrodite.py

@@ -1,7 +1,8 @@
 import asyncio
 import time
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
+from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
+                    Union, AsyncIterator)
 
 from aphrodite.common.config import ModelConfig
 from aphrodite.engine.args_tools import AsyncEngineArgs
@@ -184,14 +185,21 @@ class _AsyncAphrodite(AphroditeEngine):
         """
         seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
 
-        # Execute the model.
-        output = (await self._run_workers_async(
-            "execute_model",
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
-            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
-            blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )) if not scheduler_outputs.is_empty() else []
+        if not scheduler_outputs.is_empty():
+            # Execute the model.
+            all_outputs = await self._run_workers_async(
+                "execute_model",
+                driver_kwargs={
+                    "seq_group_metadata_list": seq_group_metadata_list,
+                    "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
+                    "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
+                    "blocks_to_copy": scheduler_outputs.blocks_to_copy,
+                })
+
+            # Only the driver worker returns the sampling results.
+            output = all_outputs[0]
+        else:
+            output = []
 
         return self._process_model_outputs(output, scheduler_outputs)
 
@@ -199,42 +207,40 @@ class _AsyncAphrodite(AphroditeEngine):
         self,
         method: str,
         *args,
-        get_all_outputs: bool = False,
+        driver_args: Optional[List[Any]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
-        all_outputs = []
-        for worker in self.workers:
-            if self.parallel_config.worker_use_ray:
-                executor = partial(worker.execute_method.remote, method)
-            else:
-                executor = getattr(worker, method)
+        coros = []
 
-            output = executor(*args, **kwargs)
-            all_outputs.append(output)
+        if driver_args is None:
+            driver_args = args
+        if driver_kwargs is None:
+            driver_kwargs = kwargs
 
-        if self.parallel_config.worker_use_ray:
-            all_outputs = await asyncio.gather(*all_outputs)
+        # Run the driver worker asynchronously.
+        driver_executor = getattr(self.driver_worker, method)
+        coros.append(asyncio.get_event_loop().run_in_executor(
+            None, partial(driver_executor, *driver_args, **driver_kwargs)))
 
-        if get_all_outputs:
-            return all_outputs
+        # Run the ray workers asynchronously.
+        for worker in self.workers:
+            coros.append(worker.execute_method.remote(method, *args, **kwargs))
 
-        # Make sure all workers have the same results.
-        output = all_outputs[0]
-        for other_output in all_outputs[1:]:
-            assert output == other_output
-        return output
+        all_outputs = await asyncio.gather(*coros)
+        return all_outputs
 
 
 class AsyncAphrodite:
     """An asynchronous wrapper for AphroditeEngine.
 
     This class is used to wrap the AphroditeEngine class to make it
-    asynchronous. It uses asyncio to create a background loop that
-    keeps processing incoming requests. The AphroditeEngine is kicked
-    by the generate method when there are requests in the waiting queue.
-    The generate method yields the outputs from the AphroditeEngine to
-    the caller.
+    asynchronous. It uses asyncio to create a background loop that keeps
+    processing incoming requests. The AphroditeEngine is kicked by the
+    generate method when there are requests in the waiting queue.
+    The generate method yields the outputs from the AphroditeEngine
+    to the caller.
 
     NOTE: For the comprehensive list of arguments, see `AphroditeEngine`.
 
@@ -248,7 +254,8 @@ class AsyncAphrodite:
         log_requests: Whether to log the requests.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
-        *args, *kwargs: Arguments for AphroditeEngine.
+        *args: Arguments for AphroditeEngine.
+        **kwargs: Arguments for AphroditeEngine.
     """
 
     _engine_class: Type[_AsyncAphrodite] = _AsyncAphrodite
@@ -365,6 +372,7 @@ class AsyncAphrodite:
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
+        prefix_pos: Optional[int] = None,
     ) -> AsyncStream:
         if self.log_requests:
             shortened_prompt = prompt
@@ -377,6 +385,7 @@ class AsyncAphrodite:
                                                               max_log_len]
             logger.info(f"Received request {request_id}: "
                         f"prompt: {shortened_prompt!r}, "
+                        f"prefix_pos: {prefix_pos}, "
                         f"sampling params: {sampling_params}, "
                         f"prompt token ids: {shortened_token_ids}.")
 
@@ -395,16 +404,20 @@ class AsyncAphrodite:
             prompt=prompt,
             sampling_params=sampling_params,
             prompt_token_ids=prompt_token_ids,
-            arrival_time=arrival_time)
+            arrival_time=arrival_time,
+            prefix_pos=prefix_pos,
+        )
 
         return stream
 
     async def generate(
-            self,
-            prompt: Optional[str],
-            sampling_params: SamplingParams,
-            request_id: str,
-            prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
+        self,
+        prompt: Optional[str],
+        sampling_params: SamplingParams,
+        request_id: str,
+        prompt_token_ids: Optional[List[int]] = None,
+        prefix_pos: Optional[int] = None,
+    ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
 
         Generate outputs for a request. This method is a coroutine. It adds the
@@ -418,20 +431,71 @@ class AsyncAphrodite:
             request_id: The unique id of the request.
             prompt_token_ids: The token IDs of the prompt. If None, we
                 use the tokenizer to convert the prompts to token IDs.
+            prefix_pos: If not None, we use the given position as the prefix
+                position for each prompt. We will cache the prefix's KV
+                cache and reuse it for the next request with the same prefix.
+                This is an experimental feature, and may be replaced with
+                automatic prefix caching in the future.
 
         Yields:
             The output `RequestOutput` objects from the AphroditeEngine for the
             request.
+
+        Details:
+            - If the engine is not running, start the background loop,
+              which iteratively invokes
+              # pylint: disable=line-too-long
+              :meth:`~aphrodite.engine.async_llm_engine.AsyncAphrodite.engine_step`
+              to process the waiting requests.
+            - Add the request to the engine's `RequestTracker`.
+              On the next background loop, this request will be sent to
+              the underlying engine.
+              Also, a corresponding `AsyncStream` will be created.
+            - Wait for the request outputs from `AsyncStream` and yield them.
+
+        Example:
+            >>> # Please refer to entrypoints/api_server.py for
+            >>> # the complete example.
+            >>>
+            >>> # initialize the engine and the example input
+            >>> engine = AsyncAphrodite.from_engine_args(engine_args)
+            >>> example_input = {
+            >>>     "prompt": "What is LLM?",
+            >>>     "stream": False, # assume the non-streaming case
+            >>>     "temperature": 0.0,
+            >>>     "request_id": 0,
+            >>> }
+            >>>
+            >>> # start the generation
+            >>> results_generator = engine.generate(
+            >>>    example_input["prompt"],
+            >>>    SamplingParams(temperature=example_input["temperature"]),
+            >>>    example_input["request_id"])
+            >>>
+            >>> # get the results
+            >>> final_output = None
+            >>> async for request_output in results_generator:
+            >>>     if await request.is_disconnected():
+            >>>         # Abort the request if the client disconnects.
+            >>>         await engine.abort(request_id)
+            >>>         # Return or raise an error
+            >>>         ...
+            >>>     final_output = request_output
+            >>>
+            >>> # Process and return the final output
+            >>> ...
         """
         # Preprocess the request.
-        arrival_time = time.time()
+        # This should not be used for logging, as it is monotonic time.
+        arrival_time = time.monotonic()
 
         try:
             stream = await self.add_request(request_id,
                                             prompt,
                                             sampling_params,
                                             prompt_token_ids=prompt_token_ids,
-                                            arrival_time=arrival_time)
+                                            arrival_time=arrival_time,
+                                            prefix_pos=prefix_pos)
 
             async for request_output in stream:
                 yield request_output
@@ -487,16 +551,21 @@ class AsyncAphrodite:
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, placement_group = initialize_cluster(
-            parallel_config, engine_args.engine_use_ray)
+        placement_group = initialize_cluster(parallel_config,
+                                             engine_args.engine_use_ray)
         # Create the async LLM engine.
         engine = cls(parallel_config.worker_use_ray,
                      engine_args.engine_use_ray,
                      *engine_configs,
-                     distributed_init_method,
                      placement_group,
                      log_requests=not engine_args.disable_log_requests,
                      log_stats=not engine_args.disable_log_stats,
                      max_log_len=engine_args.max_log_len,
                      start_engine_loop=start_engine_loop)
         return engine
+
+    async def do_log_stats(self) -> None:
+        if self.engine_use_ray:
+            await self.engine.do_log_stats.remote()
+        else:
+            self.engine.do_log_stats()
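
A usage sketch for the async API after this change, assuming an already constructed engine: `generate` is consumed as an async iterator, and the new `do_log_stats()` can be driven from a periodic task. The prompt, request id, logging interval, and the `RequestOutput.outputs[0].text` access are illustrative assumptions:

    # Hedged usage sketch (engine construction omitted; prompt and request id
    # are made up). Shows generate() consumed as an async iterator and the new
    # do_log_stats() driven by a periodic background task.
    import asyncio

    from aphrodite.common.sampling_params import SamplingParams


    async def log_stats_forever(engine, interval_s: float = 10.0) -> None:
        while True:
            await asyncio.sleep(interval_s)
            await engine.do_log_stats()


    async def run_one_request(engine) -> str:
        params = SamplingParams(temperature=0.0, max_tokens=64)
        final_output = None
        async for request_output in engine.generate("What is LLM?", params,
                                                    request_id="req-0"):
            final_output = request_output
        # Assumes the usual RequestOutput -> CompletionOutput structure.
        return final_output.outputs[0].text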

+ 27 - 23
aphrodite/engine/ray_tools.py

@@ -1,18 +1,15 @@
-"""Ray for distributed multi-node inference:
-https://github.com/ray-project/ray"""
-from typing import Optional, Tuple, TYPE_CHECKING
+from typing import Optional, List, Tuple, TYPE_CHECKING
 
 from aphrodite.common.config import ParallelConfig
 from aphrodite.common.logger import init_logger
-from aphrodite.common.utils import get_open_port, is_hip
+from aphrodite.common.utils import is_hip, set_cuda_visible_devices, get_ip
 
 logger = init_logger(__name__)
 
 try:
     import ray
-    from ray.air.util.torch_dist import TorchDistributedWorker
 
-    class RayWorker(TorchDistributedWorker):
+    class RayWorkerAphrodite:
         """Ray wrapper for aphrodite.task_handler.Worker, allowing Worker to be
         lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
 
@@ -33,13 +30,23 @@ try:
             executor = getattr(self, method)
             return executor(*args, **kwargs)
 
+        def get_node_ip(self) -> str:
+            return get_ip()
+
+        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
+            node_id = ray.get_runtime_context().get_node_id()
+            gpu_ids = ray.get_gpu_ids()
+            return node_id, gpu_ids
+
+        def set_cuda_visible_devices(self, device_ids) -> None:
+            set_cuda_visible_devices(device_ids)
+
 except ImportError as e:
     logger.warning(f"Failed to import Ray with {e!r}. "
                    "For distributed inference, please install Ray with "
                    "`pip install ray pandas pyarrow`.")
     ray = None
-    TorchDistributedWorker = None
-    RayWorker = None  # pylint: disable=invalid-name
+    RayWorkerAphrodite = None
 
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup
@@ -49,7 +56,7 @@ def initialize_cluster(
     parallel_config: ParallelConfig,
     engine_use_ray: bool = False,
     ray_address: Optional[str] = None,
-) -> Tuple[str, Optional["PlacementGroup"]]:
+) -> Optional["PlacementGroup"]:
     """Initialize the distributed cluster probably with Ray.
 
     Args:
@@ -59,11 +66,10 @@ def initialize_cluster(
             the default Ray cluster address.
 
     Returns:
-        A tuple of (`distributed_init_method`, `all_stage_devices`). The
-        `distributed_init_method` is the address for initializing the
-        distributed backend. `all_stage_devices` includes device IDs for
-        each worker in each pipeline stage. Each device ID is a tuple of
-        (rank, node resource, device id).
+        An optional `PlacementGroup`, which includes the specification
+        of the resources for each distributed worker. None if Ray is
+        not used.
     """
     if parallel_config.worker_use_ray or engine_use_ray:
         if ray is None:
@@ -79,13 +85,11 @@ def initialize_cluster(
             ray.init(address=ray_address, ignore_reinit_error=True)
 
     if not parallel_config.worker_use_ray:
-        # Initialize cluster locally.
-        port = get_open_port()
-        # We need to setup the distributed init method to make sure
-        # the distributed megatron code (e.g., get world size) works correctly.
-        distributed_init_method = f"tcp://localhost:{port}"
-        return distributed_init_method, None
+        assert parallel_config.world_size == 1, (
+            "Ray is required if parallel_config.world_size > 1.")
+        return None
 
+    # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
     if current_placement_group:
         # We are in a placement group
@@ -110,12 +114,12 @@ def initialize_cluster(
                 "The number of required GPUs exceeds the total number of "
                 "available GPUs in the cluster.")
         # Create a new placement group
-        current_placement_group = ray.util.placement_group([{
-            "GPU": 1
-        }] * parallel_config.world_size)
+        placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
+        current_placement_group = ray.util.placement_group(
+            placement_group_specs)
         # Wait until PG is ready - this will block until all
         # requested resources are available, and will timeout
         # if they cannot be provisioned.
         ray.get(current_placement_group.ready(), timeout=1800)
 
-    return None, current_placement_group
+    return current_placement_group
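
With the return value reduced to an optional placement group, the only Ray-independent pieces are the GPU-count guard and the bundle spec. A self-contained sketch of both (helper names are made up):

    # Illustrative helpers mirroring the logic above (names are made up, no
    # Ray required): a guard for insufficient cluster GPUs and the placement
    # group spec of one {"GPU": 1} bundle per distributed worker.
    from typing import Dict, List


    def make_placement_group_specs(world_size: int) -> List[Dict[str, int]]:
        # Same shape as ray.util.placement_group([{"GPU": 1}] * world_size).
        return [{"GPU": 1}] * world_size


    def check_cluster_gpus(world_size: int, num_cluster_gpus: int) -> None:
        # Mirrors the guard in initialize_cluster.
        if world_size > num_cluster_gpus:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the cluster.")


    if __name__ == "__main__":
        check_cluster_gpus(world_size=4, num_cluster_gpus=8)
        print(make_placement_group_specs(4))
        # [{'GPU': 1}, {'GPU': 1}, {'GPU': 1}, {'GPU': 1}]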

+ 1 - 0
aphrodite/modeling/__init__.py

@@ -6,5 +6,6 @@ from aphrodite.modeling.utils import set_random_seed
 __all__ = [
     "InputMetadata",
     "get_model",
+    "SamplingMetadata",
     "set_random_seed",
 ]

+ 21 - 33
aphrodite/modeling/hf_downloader.py

@@ -1,17 +1,18 @@
 """Utilities for downloading and initializing model weights."""
 import filelock
 import glob
+import fnmatch
 import json
 import os
 from collections import defaultdict
 from typing import Any, Iterator, List, Optional, Tuple
 
-from huggingface_hub import snapshot_download
-from safetensors.torch import load_file, save_file, safe_open
+from huggingface_hub import snapshot_download, HfFileSystem
 import numpy as np
+from safetensors.torch import load_file, save_file, safe_open
 import torch
-from tqdm.auto import tqdm
 from transformers import PretrainedConfig
+from tqdm.auto import tqdm
 
 from aphrodite.common.logger import init_logger
 from aphrodite.modeling.layers.quantization import (get_quantization_config,
@@ -89,6 +90,7 @@ def get_quant_config(
     cache_dir: Optional[str] = None,
 ) -> QuantizationConfig:
     quant_cls = get_quantization_config(quantization)
+    # Read the quantization config from the HF model config, if available.
     hf_quant_config = getattr(hf_config, "quantization_config", None)
     if hf_quant_config is not None:
         return quant_cls.from_config(hf_quant_config)
@@ -131,7 +133,7 @@ def prepare_hf_model_weights(
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
     use_safetensors = False
-    # Some quantized models use .pt files for storing the weights
+    # Some quantized models use .pt files for storing the weights.
     if load_format == "auto":
         allow_patterns = ["*.safetensors", "*.bin"]
     elif load_format == "safetensors":
@@ -145,9 +147,21 @@ def prepare_hf_model_weights(
         raise ValueError(f"Unknown load_format: {load_format}")
 
     if fall_back_to_pt:
-        allow_patterns += [".pt"]
+        allow_patterns += ["*.pt"]
 
     if not is_local:
+        # Before we download, we look at what is available:
+        fs = HfFileSystem()
+        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
+
+        # Depending on what is available, we download different things.
+        for pattern in allow_patterns:
+            matching = fnmatch.filter(file_list, pattern)
+            if len(matching) > 0:
+                allow_patterns = [pattern]
+                break
+
+        logger.info(f"Downloading model weights {allow_patterns}")
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
         with get_lock(model_name_or_path, cache_dir):
@@ -195,7 +209,6 @@ def hf_model_weights_iterator(
     revision: Optional[str] = None,
     fall_back_to_pt: Optional[bool] = True,
 ) -> Iterator[Tuple[str, torch.Tensor]]:
-
     hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights(
         model_name_or_path,
         cache_dir=cache_dir,
@@ -238,7 +251,7 @@ def hf_model_weights_iterator(
     elif use_safetensors:
         for st_file in hf_weights_files:
             with safe_open(st_file, framework="pt") as f:
-                for name in f.keys():
+                for name in f.keys():  # noqa: SIM118
                     param = f.get_tensor(name)
                     yield name, param
     else:
@@ -252,6 +265,7 @@ def hf_model_weights_iterator(
 
 def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
     """convert PySafeSlice object from safetensors to torch.Tensor
+
     PySafeSlice object supports indexing, which is done before loading the
     actual tensor and can reduce the amount of data being read into
     memory. However, it does not support more advanced functionalities
@@ -286,29 +300,3 @@ def initialize_dummy_weights(
     for param in model.state_dict().values():
         if torch.is_floating_point(param):
             param.data.uniform_(low, high)
-
-
-def get_parallel_weight(model: torch.nn.Module):
-    if model.quant_config is None:
-        column_weight_suffixes = ["weight", "bias"]
-        row_weight_suffixes = ["weight"]
-    else:
-        column_weight_suffixes = (
-            model.quant_config.get_col_parallel_tensor_names())
-        row_weight_suffixes = (
-            model.quant_config.get_row_parallel_tensor_names())
-
-    column_parallel_weights: List[str] = []
-    for layer in model.column_parallel_layers:
-        for suffix in column_weight_suffixes:
-            column_parallel_weights.append(f"{layer}.{suffix}")
-    row_parallel_weights: List[str] = []
-    for layer in model.row_parallel_layers:
-        for suffix in row_weight_suffixes:
-            row_parallel_weights.append(f"{layer}.{suffix}")
-
-    if hasattr(model, "parallel_vocab_layers"):
-        for layer in model.parallel_vocab_layers:
-            for suffix in ["weight", "bias"]:
-                column_parallel_weights.append(f"{layer}.{suffix}")
-    return column_parallel_weights, row_parallel_weights
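
A standalone sketch of the pattern selection added to `prepare_hf_model_weights` above: keep only the first allow pattern that matches the repository listing, so safetensors and .bin weights are never both downloaded (the file list below is made up):

    # Standalone sketch of the pattern-selection logic: keep only the first
    # allow pattern that matches anything in the repo listing.
    import fnmatch
    from typing import List


    def select_weight_patterns(file_list: List[str],
                               allow_patterns: List[str]) -> List[str]:
        for pattern in allow_patterns:
            if fnmatch.filter(file_list, pattern):
                return [pattern]
        return allow_patterns


    if __name__ == "__main__":
        files = [
            "model-00001-of-00002.safetensors",
            "model-00002-of-00002.safetensors",
            "pytorch_model.bin",
            "config.json",
        ]
        print(select_weight_patterns(files, ["*.safetensors", "*.bin", "*.pt"]))
        # ['*.safetensors']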

+ 9 - 8
aphrodite/modeling/layers/activation.py

@@ -1,3 +1,4 @@
+"""Custom activation functions."""
 import math
 from typing import Optional
 
@@ -5,7 +6,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from aphrodite._C import ops as activation_ops
+from aphrodite._C import ops
 from aphrodite.modeling.layers.quantization import QuantizationConfig
 from aphrodite.modeling.megatron.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -24,7 +25,7 @@ class SiluAndMul(nn.Module):
     """
 
     def _forward(self, x: torch.Tensor) -> torch.Tensor:
-        """PyTorch-native implementation. Equivalent to forward()."""
+        """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
 
@@ -32,34 +33,34 @@ class SiluAndMul(nn.Module):
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        activation_ops.silu_and_mul(out, x)
+        ops.silu_and_mul(out, x)
         return out
 
 
 class NewGELU(nn.Module):
 
     def _forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Pytorch-native implemenation. Equivalent to forward()."""
+        """PyTorch-native implementation equivalent to forward()."""
         c = math.sqrt(2.0 / math.pi)
         return 0.5 * x * (1.0 + torch.tanh(c *
                                            (x + 0.044715 * torch.pow(x, 3.0))))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.empty_like(x)
-        activation_ops.gelu_new(out, x)
+        ops.gelu_new(out, x)
         return out
 
 
 class FastGELU(nn.Module):
 
     def _forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Pytorch-native implemenation. Equivalent to forward()."""
-        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608028654 *
+        """PyTorch-native implementation equivalent to forward()."""
+        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                            (1.0 + 0.044715 * x * x)))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.empty_like(x)
-        activation_ops.gelu_fast(out, x)
+        ops.gelu_fast(out, x)
         return out
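
A quick CPU-only sanity check of the `_forward` formulas above: the tanh-based expression used by `NewGELU` (and, in rearranged form, `FastGELU`) matches PyTorch's approximate GELU. Assumes PyTorch 1.12+ for the `approximate` argument; no custom kernels are involved:

    # CPU-only sanity check (assumes PyTorch >= 1.12, no custom ops): the
    # tanh-based formula in NewGELU._forward matches F.gelu(..., "tanh").
    import math

    import torch
    import torch.nn.functional as F


    def gelu_tanh_reference(x: torch.Tensor) -> torch.Tensor:
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))


    if __name__ == "__main__":
        x = torch.randn(1024, dtype=torch.float32)
        expected = F.gelu(x, approximate="tanh")
        assert torch.allclose(gelu_tanh_reference(x), expected, atol=1e-5)
        print("tanh-approximate GELU matches F.gelu(x, approximate='tanh')")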
 
 

+ 22 - 29
aphrodite/modeling/layers/attention.py

@@ -1,7 +1,4 @@
-"""
-Multi-head Paged Attention by Woosuk et al. (vLLM) Copyright (c) 2023.
-https://vllm.ai/
-"""
+"""Multi-head attention."""
 from typing import List, Optional
 
 import torch
@@ -10,7 +7,7 @@ from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
                                          LowerTriangularMaskWithTensorBias)
 
-from aphrodite._C import ops as attention_ops
+from aphrodite._C import ops
 from aphrodite._C import cache_ops
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.common.utils import is_hip
@@ -132,7 +129,8 @@ class PagedAttention(nn.Module):
                     input_metadata.attn_bias = attn_bias
                 else:
                     input_metadata.attn_bias = _make_alibi_bias(
-                        self.alibi_slopes, batch_size, seq_len, query.dtype)
+                        self.alibi_slopes, self.num_kv_heads, batch_size,
+                        seq_len, query.dtype)
 
             # TODO: Too many view operations. Let's try to reduce them
             # in the future for code readability.
@@ -158,20 +156,15 @@ class PagedAttention(nn.Module):
             output = out.view_as(query)
         else:
             # Decoding run.
-            if key_cache is not None and value_cache is not None:
-                output = _paged_attention(
-                    query,
-                    key_cache,
-                    value_cache,
-                    input_metadata,
-                    self.num_kv_heads,
-                    self.scale,
-                    self.alibi_slopes,
-                )
-            else:
-                # This happens during the initial memory profiling run
-                # for CUDA graphs.
-                output = torch.zeros_like(query)
+            output = _paged_attention(
+                query,
+                key_cache,
+                value_cache,
+                input_metadata,
+                self.num_kv_heads,
+                self.scale,
+                self.alibi_slopes,
+            )
 
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)
@@ -179,31 +172,34 @@ class PagedAttention(nn.Module):
 
 def _make_alibi_bias(
     alibi_slopes: torch.Tensor,
+    num_kv_heads: int,
     batch_size: int,
     seq_len: int,
     dtype: torch.dtype,
 ) -> LowerTriangularMaskWithTensorBias:
-    bias = torch.arange(seq_len, dtype=dtype)
+    bias = torch.arange(seq_len, dtype=dtype, device="cuda")
     # NOTE: HF uses
     #     `bias = bias[None, :].repeat(prompt_len, 1)`
-    # here. It that both biases give the same results, but
+    # here. We find that both biases give the same results, but
     # the bias below more accurately follows the original ALiBi
     # paper.
     bias = bias[None, :] - bias[:, None]
-    bias = bias.to(alibi_slopes.device)
 
     # When using custom attention bias, xformers requires the bias to
     # be sliced from a tensor whose length is a multiple of 8.
     padded_len = (seq_len + 7) // 8 * 8
+    num_heads = alibi_slopes.shape[0]
     bias = torch.empty(
         batch_size,
-        alibi_slopes.shape[0],
+        num_heads,
         seq_len,
         padded_len,
         device=alibi_slopes.device,
         dtype=dtype,
     )[:, :, :, :seq_len].copy_(bias)
     bias.mul_(alibi_slopes[:, None, None])
+    if num_heads != num_kv_heads:
+        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
     attn_bias = LowerTriangularMaskWithTensorBias(bias)
     return attn_bias
 
@@ -219,7 +215,6 @@ def _paged_attention(
 ) -> torch.Tensor:
     output = torch.empty_like(query)
 
-    enable_fp8_kv_cache = key_cache.dtype == torch.uint8
     block_size = value_cache.shape[3]
     num_seqs, num_heads, head_size = query.shape
     max_num_partitions = (
@@ -236,7 +231,7 @@ def _paged_attention(
         max_num_partitions == 1 or num_seqs * num_heads > 512)
     if use_v1:
         # Run PagedAttention V1.
-        attention_ops.paged_attention_v1(
+        ops.paged_attention_v1(
             output,
             query,
             key_cache,
@@ -248,7 +243,6 @@ def _paged_attention(
             block_size,
             input_metadata.max_context_len,
             alibi_slopes,
-            enable_fp8_kv_cache,
         )
     else:
         # Run PagedAttention V2.
@@ -264,7 +258,7 @@ def _paged_attention(
             device=output.device,
         )
         max_logits = torch.empty_like(exp_sums)
-        attention_ops.paged_attention_v2(
+        ops.paged_attention_v2(
             output,
             exp_sums,
             max_logits,
@@ -279,6 +273,5 @@ def _paged_attention(
             block_size,
             input_metadata.max_context_len,
             alibi_slopes,
-            enable_fp8_kv_cache,
         )
     return output
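
A CPU-only sketch of the bias that `_make_alibi_bias` builds, including the new grouped-query handling: a relative-position matrix scaled by per-head ALiBi slopes, with the head dimension unflattened when query heads outnumber KV heads. The xformers multiple-of-8 padding is omitted and the slope values are made up:

    # CPU-only illustrative sketch of the ALiBi bias built above (xformers
    # padding detail omitted; slope values are made up).
    import torch


    def make_alibi_bias_sketch(alibi_slopes: torch.Tensor, num_kv_heads: int,
                               batch_size: int, seq_len: int) -> torch.Tensor:
        pos = torch.arange(seq_len, dtype=torch.float32)
        rel = pos[None, :] - pos[:, None]          # (seq_len, seq_len) distances
        num_heads = alibi_slopes.shape[0]
        bias = rel.expand(batch_size, num_heads, seq_len, seq_len).clone()
        bias.mul_(alibi_slopes[:, None, None])     # scale each head by its slope
        if num_heads != num_kv_heads:
            # Group query heads under their KV head, as in the diff above.
            bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
        return bias


    if __name__ == "__main__":
        slopes = torch.tensor([0.5, 0.25, 0.125, 0.0625])  # 4 query heads
        out = make_alibi_bias_sketch(slopes, num_kv_heads=2, batch_size=1,
                                     seq_len=8)
        print(out.shape)  # torch.Size([1, 2, 2, 8, 8])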

+ 10 - 10
aphrodite/modeling/layers/layernorm.py

@@ -1,23 +1,23 @@
-"""Custom normalization layers"""
+"""Custom normalization layers."""
 from typing import Optional, Tuple, Union
+
 import torch
 import torch.nn as nn
 
-from aphrodite._C import ops as layernorm_ops
+from aphrodite._C import ops
 
 
 class RMSNorm(nn.Module):
     """Root mean square normalization.
 
     Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
-    Refer to the Root Mean Square Layer Normalization paper
-    https://arxiv.org/abs/1910.07467
+    Refer to https://arxiv.org/abs/1910.07467
     """
 
     def __init__(
-            self,
-            hidden_size: int,
-            eps: float = 1e-6,  # the epsilon value used by llama models
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
     ) -> None:
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -28,7 +28,7 @@ class RMSNorm(nn.Module):
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        """PyTorch-native implementation. Equivalent to forward()."""
+        """PyTorch-native implementation equivalent to forward()."""
         orig_dtype = x.dtype
         x = x.to(torch.float32)
         if residual is not None:
@@ -49,7 +49,7 @@ class RMSNorm(nn.Module):
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if residual is not None:
-            layernorm_ops.fused_add_rms_norm(
+            ops.fused_add_rms_norm(
                 x,
                 residual,
                 self.weight.data,
@@ -57,7 +57,7 @@ class RMSNorm(nn.Module):
             )
             return x, residual
         out = torch.empty_like(x)
-        layernorm_ops.rms_norm(
+        ops.rms_norm(
             out,
             x,
             self.weight.data,
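
For reference, a pure-PyTorch sketch of the computation `RMSNorm` performs, w * x / sqrt(E[x^2] + eps), including the optional fused residual add; this follows the `_forward` path above and is a sketch, not the layer itself:

    # Pure-PyTorch sketch of the RMSNorm computation (no custom ops),
    # including the optional fused residual add.
    from typing import Optional, Tuple, Union

    import torch


    def rms_norm(x: torch.Tensor,
                 weight: torch.Tensor,
                 eps: float = 1e-6,
                 residual: Optional[torch.Tensor] = None
                 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + eps)
        out = x.to(orig_dtype) * weight
        return out if residual is None else (out, residual)


    if __name__ == "__main__":
        hidden = torch.randn(2, 16, dtype=torch.float16)
        w = torch.ones(16, dtype=torch.float16)
        print(rms_norm(hidden, w).shape)  # torch.Size([2, 16])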

+ 6 - 3
aphrodite/modeling/layers/linear.py

@@ -275,8 +275,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 current_shard_offset += output_size
             packed_dim = getattr(param, "packed_dim", None)
             for shard_id, shard_offset, shard_size in shard_offsets:
-                # If quantized, we need to adjust the offset and size to
-                # account for the packing.
+                # If quantized, we need to adjust the offset and size to account
+                # for the packing.
                 if packed_dim == output_dim:
                     shard_size = shard_size // param.pack_factor
                     shard_offset = shard_offset // param.pack_factor
@@ -423,7 +423,10 @@ class QKVParallelLinear(ColumnParallelLinear):
                 shard_offset = shard_offset // param.pack_factor
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
-            shard_id = tp_rank // self.num_kv_head_replicas
+            if loaded_shard_id == "q":
+                shard_id = tp_rank
+            else:
+                shard_id = tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
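
The `shard_id` fix above distinguishes query shards from KV shards when KV heads are replicated across tensor-parallel ranks. A tiny standalone sketch of that index selection (helper name and example sizes are made up):

    # Standalone sketch of the shard selection fixed above: each rank loads
    # its own query shard, while ranks that replicate the same KV head group
    # load the same K/V shard.
    def select_shard_id(loaded_shard_id: str, tp_rank: int,
                        num_kv_head_replicas: int) -> int:
        if loaded_shard_id == "q":
            return tp_rank
        # "k" or "v": several ranks share one replica of the KV heads.
        return tp_rank // num_kv_head_replicas


    if __name__ == "__main__":
        # e.g. 8-way tensor parallel with 2 KV heads -> 4 ranks per KV replica.
        for rank in range(8):
            print(rank, select_shard_id("q", rank, 4),
                  select_shard_id("k", rank, 4))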

+ 0 - 368
aphrodite/modeling/layers/moe.py

@@ -1,368 +0,0 @@
-from typing import Tuple
-
-import torch
-from torch import nn
-import torch.nn.functional as F
-import triton
-import triton.language as tl
-
-from aphrodite._C import ops
-from aphrodite.modeling.layers.linear import ReplicatedLinear
-from aphrodite.modeling.megatron.communication_op import (
-    tensor_model_parallel_all_reduce)
-from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from aphrodite.modeling.utils import set_weight_attrs
-
-
-class MoE(nn.Module):
-    """A tensor parallel MOE that shards each expert across all ranks.
-    Each expert's weights are sharded across all ranks. The forward pass
-    will first expand and group the hidden states by experts, then compute
-    the per-rank MLP output of each expert using grouped gemm, and finally
-    reduce the output across ranks.
-    """
-
-    def __init__(
-        self,
-        num_experts: int,
-        top_k: int,
-        hidden_size: int,
-        intermediate_size: int,
-    ):
-        super().__init__()
-        tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = num_experts
-        self.top_k = top_k
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size // tp_size
-
-        self.gate = ReplicatedLinear(self.hidden_size,
-                                     self.num_total_experts,
-                                     bias=False,
-                                     linear_method=None)
-
-        self.w1s = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        self.hidden_size,
-                        self.intermediate_size,
-                        device="cuda"))
-        self.w2s = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        self.intermediate_size,
-                        self.hidden_size,
-                        device="cuda"))
-        self.w3s = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        self.hidden_size,
-                        self.intermediate_size,
-                        device="cuda"))
-
-        set_weight_attrs(self.w1s, {
-            "weight_loader": self.weight_loader,
-            "tp_type": "column"
-        })
-        set_weight_attrs(self.w2s, {
-            "weight_loader": self.weight_loader,
-            "tp_type": "row"
-        })
-        set_weight_attrs(self.w3s, {
-            "weight_loader": self.weight_loader,
-            "tp_type": "column"
-        })
-
-    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
-                      expert_id: int):
-        tp_rank = get_tensor_model_parallel_rank()
-        loaded_weight = loaded_weight.t()
-        # The parallel dimension is 1 for column-parallel, and 0 for
-        # row-parallel.
-        parallel_dim = 1 if getattr(param, "tp_type", None) == "column" else 0
-        param_data = param.data
-        shard_size = param_data.shape[parallel_dim + 1]
-        start_idx = tp_rank * shard_size
-        loaded_weight = loaded_weight.narrow(parallel_dim, start_idx,
-                                             shard_size)
-        assert param_data[expert_id].shape == loaded_weight.shape
-        param_data[expert_id].copy_(loaded_weight)
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, sequence_length, hidden_size = hidden_states.shape
-        hidden_states = hidden_states.view(-1, self.hidden_size)
-        # router_logits: (batch * sequence_length, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights,
-                                                       self.top_k,
-                                                       dim=-1)
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
-        # Step 1: expand and permute hidden states and routing weights to group
-        #        hidden states by experts.
-        (expanded_hidden_states, experts_range, expanded_weights,
-         reverse_indices) = self.expand_and_permutate_hidden_states(
-             hidden_states, selected_experts, routing_weights)
-
-        # Step 2: compute the output of each expert.
-        expanded_hidden_states = self.apply_experts_ffn(
-            expanded_hidden_states, experts_range, self.w1s.data,
-            self.w2s.data, self.w3s.data)
-
-        # Step 3: apply weights to the output of each expert, and reduce
-        # across ranks.
-        expanded_hidden_states.mul_(expanded_weights.unsqueeze(-1))
-        tensor_model_parallel_all_reduce(expanded_hidden_states)
-
-        # Step 4: merge the output of each expert, according to the indices.
-        return self.merge_expert_outputs(expanded_hidden_states,
-                                         reverse_indices).view(
-                                             batch_size, sequence_length,
-                                             hidden_size)
-
-    def expand_and_permutate_hidden_states(
-        self,
-        hidden_states: torch.Tensor,
-        selected_experts: torch.Tensor,
-        routing_weights: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Expand and group hidden states and routing weights according
-        to the selected experts.
-
-        Args:
-            hidden_states (torch.Tensor): [batch_size, hidden_size]
-                hidden states.
-            selected_experts (torch.Tensor): [batch_size, top_k_experts]
-                the indices of the selected experts.
-            routing_weights (torch.Tensor): [batch_size, top_k_experts]
-                the routing weights of the selected experts.
-
-        Returns:
-            expanded_hidden_states: [batch_size * top_k_experts, hidden_size]
-                expanded hidden states that rows are grouped by experts.
-            cum_experts_range: [num_experts + 1] the cumulative range of the
-                experts in expanded_hidden_states, in the first dimension.
-            expanded_weights: [batch_size * top_k_experts] the expanded
-                expert weights for each row in expanded_hidden_states.
-            reverse_indices: [batch_size * top_k_experts] the indices of each
-                row in expanded_hidden_states which maps back to the original
-                hidden states.
-        """
-        reverse_indices = torch.argsort(selected_experts.view(-1), dim=-1)
-        cum_experts_range = torch.zeros(self.num_total_experts + 1,
-                                        dtype=torch.int32,
-                                        device=hidden_states.device)
-        num_rows_per_expert = torch.zeros(self.num_total_experts,
-                                          dtype=torch.int32,
-                                          device=hidden_states.device)
-        ops.bincount(selected_experts.view(-1), num_rows_per_expert)
-        torch.cumsum(num_rows_per_expert, dim=0, out=cum_experts_range[1:])
-        expanded_weights = routing_weights.view(-1)[reverse_indices]
-        reverse_indices.div_(self.top_k, rounding_mode="floor")
-        return (hidden_states[reverse_indices], cum_experts_range,
-                expanded_weights, reverse_indices)
-
-    def apply_experts_ffn(
-        self,
-        expanded_hidden_states: torch.
-        Tensor,  # [batch_size * top_k_experts, hidden_size]
-        cum_experts_range: torch.Tensor,  # [num_experts + 1]
-        w1s: torch.Tensor,  # [num_experts, hidden_size, ffn_dim]
-        w2s: torch.Tensor,  # [num_experts, ffn_dim, hidden_size]
-        w3s: torch.Tensor,  # [num_experts, hidden_size, ffn_dim]
-    ) -> torch.Tensor:  # [batch_size * top_k_experts, hidden_size]
-        grouped_w1_out = grouped_matmul(expanded_hidden_states,
-                                        cum_experts_range, w1s, "silu")
-        grouped_w3_out = grouped_matmul(expanded_hidden_states,
-                                        cum_experts_range, w3s)
-        grouped_w1_out.mul_(grouped_w3_out)
-        return grouped_matmul(grouped_w1_out, cum_experts_range, w2s)
-
-    def merge_expert_outputs(
-            self,
-            expanded_hidden_states: torch.
-        Tensor,  # [batch_size * top_k_experts, hidden_size]
-            reverse_indices,  # [batch_size * top_k_experts]
-    ) -> torch.Tensor:
-        out = torch.zeros(expanded_hidden_states.shape[0] // self.top_k,
-                          self.hidden_size,
-                          device=expanded_hidden_states.device,
-                          dtype=expanded_hidden_states.dtype)
-        out.index_add_(0, reverse_indices, expanded_hidden_states)
-        return out
-
-
-# The following code is adapted from
-# https://github.com/openai/triton/blob/main/python/tutorials/11-grouped-gemm.py
-@triton.jit
-def grouped_matmul_kernel(
-    # [batch_size, k], where each group are stored compactly in the batch
-    # dimension. The range of each group is specified in cumulative_m_range.
-    group_a_ptr,
-    # [num_groups, k, n]
-    group_b_ptr,
-    # [batch_size, n], where each group are stored compactly in the batch
-    # dimension. The range of each group is specified in cumulative_m_range.
-    group_c_ptr,
-    # num of gemm problems
-    group_size,
-    # for each gemm problem with size <m, n, k>, m is stored in
-    # cumulative_m_range[i + i] - cumulative_m_range[i].
-    # n and k are the same for all problems.
-    cumulative_m_range,
-    n,
-    k,
-    # group_a_ptr.stride(0)
-    stride_a0,
-    # group_b_ptr.stride(1)
-    stride_b1,
-    # group_c_ptr.stride(0)
-    stride_c0,
-    # number of virtual SM
-    NUM_SM: tl.constexpr,
-    # tile sizes
-    BLOCK_SIZE_M: tl.constexpr,
-    BLOCK_SIZE_N: tl.constexpr,
-    BLOCK_SIZE_K: tl.constexpr,
-    ACTIVATION: tl.constexpr,
-):
-    tile_idx = tl.program_id(0)
-    last_problem_end = 0
-    for g in range(group_size):
-        # get the gemm size of the current problem
-        a_offset = tl.load(cumulative_m_range + g)
-        gm = tl.load(cumulative_m_range + g + 1) - a_offset
-        gn = n
-        gk = k
-        num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
-        num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
-        num_tiles = num_m_tiles * num_n_tiles
-        # iterate through the tiles in the current gemm problem
-        # pylint: disable=chained-comparison
-        while (tile_idx >= last_problem_end
-               and tile_idx < last_problem_end + num_tiles):
-
-            # pick up a tile from the current gemm problem
-            k = gk
-            a_ptr = group_a_ptr + a_offset * stride_a0
-            b_ptr = group_b_ptr + g * k * n
-            c_ptr = group_c_ptr + a_offset * stride_c0
-            # figure out tile coordinates
-            tile_idx_in_gemm = tile_idx - last_problem_end
-            tile_m_idx = tile_idx_in_gemm // num_n_tiles
-            tile_n_idx = tile_idx_in_gemm % num_n_tiles
-
-            # do regular gemm here
-            offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-            offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-            offs_k = tl.arange(0, BLOCK_SIZE_K)
-            a_ptrs = a_ptr + offs_am[:, None] * stride_a0 + offs_k[None, :]
-            b_ptrs = b_ptr + offs_k[:, None] * stride_b1 + offs_bn[None, :]
-            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
-                                   dtype=tl.float32)
-            for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
-                # hint to Triton compiler to do proper loop pipelining
-                tl.multiple_of(a_ptrs, [16, 16])
-                tl.multiple_of(b_ptrs, [16, 16])
-
-                a = tl.load(a_ptrs,
-                            mask=(offs_k[None, :] < k - kk * BLOCK_SIZE_K) &
-                            (offs_am[:, None] < gm),
-                            other=0.0)
-                b = tl.load(b_ptrs,
-                            mask=(offs_k[:, None] < k - kk * BLOCK_SIZE_K) &
-                            (offs_bn[None, :] < gn),
-                            other=0.0)
-                accumulator += tl.dot(a, b)
-                a_ptrs += BLOCK_SIZE_K
-                b_ptrs += BLOCK_SIZE_K * stride_b1
-
-            if ACTIVATION == "silu":
-                accumulator = silu(accumulator)
-            c = accumulator.to(group_c_ptr.dtype.element_ty)
-
-            offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-            offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-            c_ptrs = c_ptr + stride_c0 * offs_cm[:, None] + offs_cn[None, :]
-            c_mask = (offs_cm[:, None] < gm) & (offs_cn[None, :] < gn)
-
-            tl.store(c_ptrs, c, mask=c_mask)
-
-            # go to the next tile by advancing NUM_SM
-            tile_idx += NUM_SM
-
-        # get ready to go to the next gemm problem
-        last_problem_end = last_problem_end + num_tiles
-
-
-@triton.jit
-def silu(x):
-    return x * tl.sigmoid(x)
-
-
-def grouped_matmul(
-        input: torch.Tensor,  # pylint: disable=redefined-builtin
-        cumulative_group_range: torch.Tensor,
-        group_b_ptr: torch.Tensor,
-        activation: str = ""):
-    """Performs a grouped matrix-matrix product of matrices stored in input
-    and group_b_ptr.
-
-    input is a tensor of shape [batch_size, k] where each group are stored
-    compactly in the batch dimension. The range of each group is specified
-    in cumulative_group_range. This allows the input to have fixed shape
-    regardless of the group sizes.
-
-    Args:
-        input (torch.Tensor): [batch_size, k] compact input.
-        cumulative_group_range (torch.Tensor): [num_groups + 1] the cumulative
-            range of the groups in input.
-        group_b_ptr (torch.Tensor): [num_groups, k, n] the second matrix.
-        activation (str, optional): "" or "silu". Defaults to "".
-
-    Returns:
-        torch.Tensor: [batch_size, n] compact output where groups
-            are stored compactly in the batch dimension.
-    """
-    device = torch.device("cuda")
-    assert cumulative_group_range.shape[0] == group_b_ptr.shape[0] + 1
-    group_size = cumulative_group_range.shape[0] - 1
-    output = torch.zeros(input.shape[0],
-                         group_b_ptr.shape[2],
-                         device=device,
-                         dtype=input.dtype)
-    BLOCK_SIZE_M = 16
-    BLOCK_SIZE_N = 64
-    BLOCK_SIZE_K = 32
-    num_warps = 2
-    NUM_SM = 128
-    num_stages = 5
-    # hand tune the block size for different problem sizes.
-    if input.shape[0] >= 8:
-        num_warps = 4
-        BLOCK_SIZE_N = 128
-    if input.shape[0] >= 32:
-        num_warps = 4
-        BLOCK_SIZE_M = 32
-        BLOCK_SIZE_N = 128
-    # we use a fixed number of CTA, and it's auto-tunable
-    grid = lambda META: (META["NUM_SM"], )
-    grouped_matmul_kernel[grid](group_a_ptr=input,
-                                group_b_ptr=group_b_ptr,
-                                group_c_ptr=output,
-                                group_size=group_size,
-                                cumulative_m_range=cumulative_group_range,
-                                n=group_b_ptr.shape[2],
-                                k=group_b_ptr.shape[1],
-                                stride_a0=input.stride(0),
-                                stride_b1=group_b_ptr.stride(1),
-                                stride_c0=output.stride(0),
-                                ACTIVATION=activation,
-                                BLOCK_SIZE_M=BLOCK_SIZE_M,
-                                BLOCK_SIZE_N=BLOCK_SIZE_N,
-                                BLOCK_SIZE_K=BLOCK_SIZE_K,
-                                NUM_SM=NUM_SM,
-                                num_warps=num_warps,
-                                num_stages=num_stages)
-
-    return output
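
For reference, here is a dense PyTorch sketch of what the removed grouped_matmul computed, per its docstring: each group's rows of the packed input are multiplied by that group's weight matrix, with an optional SiLU. This is an illustrative reference only, not code from the commit.

```python
import torch


def grouped_matmul_reference(
        x: torch.Tensor,                       # [batch_size, k], groups packed along dim 0
        cumulative_group_range: torch.Tensor,  # [num_groups + 1]
        weights: torch.Tensor,                 # [num_groups, k, n]
        activation: str = "") -> torch.Tensor:
    out = torch.zeros(x.shape[0], weights.shape[2],
                      device=x.device, dtype=x.dtype)
    for g in range(weights.shape[0]):
        start = int(cumulative_group_range[g])
        end = int(cumulative_group_range[g + 1])
        if start == end:
            continue
        y = x[start:end] @ weights[g]
        if activation == "silu":
            y = torch.nn.functional.silu(y)
        out[start:end] = y
    return out


x = torch.randn(6, 16)
ranges = torch.tensor([0, 2, 6])   # two groups of sizes 2 and 4
w = torch.randn(2, 16, 32)
print(grouped_matmul_reference(x, ranges, w, activation="silu").shape)
```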

+ 5 - 9
aphrodite/modeling/layers/quantization/__init__.py

@@ -1,20 +1,16 @@
 from typing import Type
 
-from aphrodite.modeling.layers.quantization.squeezellm import SqueezeLLMConfig
+from aphrodite.modeling.layers.quantization.base_config import QuantizationConfig
+from aphrodite.modeling.layers.quantization.awq import AWQConfig
 from aphrodite.modeling.layers.quantization.gptq import GPTQConfig
-from aphrodite.modeling.layers.quantization.base_config import (
-    QuantizationConfig)
-from aphrodite.common.utils import is_hip
+from aphrodite.modeling.layers.quantization.squeezellm import SqueezeLLMConfig
 
 _QUANTIZATION_CONFIG_REGISTRY = {
-    "squeezellm": SqueezeLLMConfig,
+    "awq": AWQConfig,
     "gptq": GPTQConfig,
+    "squeezellm": SqueezeLLMConfig,
 }
 
-if not is_hip():
-    from aphrodite.modeling.layers.quantization.awq import AWQConfig
-    _QUANTIZATION_CONFIG_REGISTRY["awq"] = AWQConfig
-
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
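
The lookup above is a plain name-to-class registry. A standalone sketch of the pattern (class names here are stand-ins, not the real aphrodite classes):

```python
from typing import Dict, Type


class QuantizationConfig: ...
class AWQConfig(QuantizationConfig): ...
class GPTQConfig(QuantizationConfig): ...
class SqueezeLLMConfig(QuantizationConfig): ...


_REGISTRY: Dict[str, Type[QuantizationConfig]] = {
    "awq": AWQConfig,
    "gptq": GPTQConfig,
    "squeezellm": SqueezeLLMConfig,
}


def get_config_class(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in _REGISTRY:
        raise ValueError(f"Invalid quantization method: {quantization}")
    return _REGISTRY[quantization]


assert get_config_class("gptq") is GPTQConfig
```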

+ 4 - 12
aphrodite/modeling/layers/quantization/awq.py

@@ -2,18 +2,11 @@ from typing import Any, Dict, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
+
+from aphrodite._C import ops
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
 from aphrodite.modeling.layers.quantization.base_config import QuantizationConfig
-from aphrodite.common.logger import init_logger
-from aphrodite.common.utils import is_hip
-
-logger = init_logger(__name__)
-
-if is_hip():
-    logger.warning("AWQ is not supported on ROCm.")
-else:
-    from aphrodite._C import ops as quantization_ops
 
 
 class AWQConfig(QuantizationConfig):
@@ -87,7 +80,7 @@ class AWQLinearMethod(LinearMethodBase):
     def create_weights(self, input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
                        output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
         if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
@@ -160,8 +153,7 @@ class AWQLinearMethod(LinearMethodBase):
         pack_factor = self.quant_config.pack_factor
         out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
-        out = quantization_ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
-                                        pack_factor)
+        out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)
         if bias is not None:
             out = out + bias
         return out.reshape(out_shape)
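
The out_shape arithmetic above relies on the AWQ pack factor: with 4-bit weights packed into int32 (the usual AWQ setup, stated here as an assumption), eight quantized values share one stored element, so the GEMM output expands back to the full output-feature count. A quick self-contained check of that arithmetic:

```python
import torch

pack_factor = 32 // 4                 # 8 4-bit values per int32 element
x = torch.randn(2, 7, 4096)           # [..., in_features]
packed_out_cols = 4096 // pack_factor
out_shape = x.shape[:-1] + (packed_out_cols * pack_factor, )
assert out_shape == (2, 7, 4096)
reshaped_x = x.reshape(-1, x.shape[-1])
assert reshaped_x.shape == (14, 4096)
```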

+ 5 - 41
aphrodite/modeling/layers/quantization/base_config.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 
 import torch
 
@@ -50,44 +50,6 @@ class QuantizationConfig(ABC):
         raise ValueError(f"Cannot find any of {keys} in the model's "
                          "quantization config.")
 
-    @classmethod
-    def get_packed_tensors(cls) -> Dict[str, int]:
-        """Returns a dictionary of packed tensor names and their pack dims."""
-        raise NotImplementedError
-
-    @classmethod
-    def get_packed_dim(cls, tensor_name: str) -> Optional[int]:
-        """Returns the pack dim of a tensor if it is packed.
-
-        A tensor is considered packed if each element in the tensor is a
-        packed representation of multiple elements in the original tensor.
-        For example, an INT32 element in the tensor may represent 8 INT4
-        elements in the original tensor.
-        If the tensor is not packed, returns None.
-        """
-        packed_tensors = cls.get_packed_tensors()
-        for packed_tensor_name, pack_dim in packed_tensors.items():
-            if packed_tensor_name in tensor_name:
-                return pack_dim
-        return None
-
-    @classmethod
-    def get_transposed_tensor_names(cls) -> List[str]:
-        raise NotImplementedError
-
-    @classmethod
-    def is_transposed(cls, tensor_name: str) -> bool:
-        """Returns True if a tensor is transposed relative to nn.Linear.weight.
-        """
-        return any(tag in tensor_name
-                   for tag in cls.get_transposed_tensor_names())
-
-    def get_col_parallel_tensor_names(self) -> List[str]:
-        raise NotImplementedError
-
-    def get_row_parallel_tensor_names(self) -> List[str]:
-        raise NotImplementedError
-
     @abstractmethod
     def get_linear_method(self) -> LinearMethodBase:
         """Get the linear method to use for the quantized linear layer."""
@@ -95,6 +57,8 @@ class QuantizationConfig(ABC):
 
     @abstractmethod
     def get_scaled_act_names(self) -> List[str]:
-        """"Returns the activation function names that should be post-scaled.
-            Currently AWQ only."""
+        """Returns the activation function names that should be post-scaled.
+
+        For now, this is only used by AWQ.
+        """
         raise NotImplementedError
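
With the packed/transposed helpers gone, a config subclass only has to provide the remaining abstract methods. A minimal standalone mock of that trimmed interface (names are illustrative, not the real classes):

```python
from abc import ABC, abstractmethod
from typing import List


class LinearMethodBase: ...


class QuantConfigSketch(ABC):
    """Mock of the slimmed-down abstract surface."""

    @abstractmethod
    def get_linear_method(self) -> LinearMethodBase:
        ...

    @abstractmethod
    def get_scaled_act_names(self) -> List[str]:
        ...


class NoScaledActsConfig(QuantConfigSketch):

    def get_linear_method(self) -> LinearMethodBase:
        return LinearMethodBase()

    def get_scaled_act_names(self) -> List[str]:
        return []   # e.g. methods that post-scale no activations


cfg = NoScaledActsConfig()
assert cfg.get_scaled_act_names() == []
```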

+ 8 - 8
aphrodite/modeling/layers/quantization/gptq.py

@@ -6,7 +6,7 @@ from fractions import Fraction
 import torch
 from torch.nn.parameter import Parameter
 
-from aphrodite._C import ops as quantization_ops
+from aphrodite._C import ops
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
 from aphrodite.modeling.layers.quantization.base_config import (
@@ -207,13 +207,13 @@ class GPTQLinearMethod(LinearMethodBase):
             else:
                 weights["g_idx"] = torch.empty((1, 1), device="meta")
             weights["exllama_state"] = ExllamaState.READY
-            quantization_ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
-                                          self.quant_config.weight_bits)
-        output = quantization_ops.gptq_gemm(
-            reshaped_x, weights["qweight"], weights["qzeros"],
-            weights["scales"], weights["g_idx"],
-            weights["exllama_state"] == ExllamaState.READY,
-            self.quant_config.weight_bits)
+            ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
+                             self.quant_config.weight_bits)
+        output = ops.gptq_gemm(reshaped_x, weights["qweight"],
+                               weights["qzeros"], weights["scales"],
+                               weights["g_idx"],
+                               weights["exllama_state"] == ExllamaState.READY,
+                               self.quant_config.weight_bits)
         if bias is not None:
             output = output + bias
         return output.reshape(out_shape)
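
The hunk above keeps GPTQ's lazy ExLlama setup: the first forward pass shuffles the quantized weights, flips the state flag to READY, and every later call goes straight to the GEMM. A generic sketch of that one-time-init pattern with the CUDA ops stubbed out (the real calls are ops.gptq_shuffle and ops.gptq_gemm):

```python
from enum import Enum, auto

import torch


class ExllamaState(Enum):
    UNINITIALIZED = auto()
    READY = auto()


def lazy_gemm(weights: dict, x: torch.Tensor) -> torch.Tensor:
    if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
        # One-time weight reordering would happen here (ops.gptq_shuffle).
        weights["exllama_state"] = ExllamaState.READY
    # Stand-in for ops.gptq_gemm on already-dequantized weights.
    return x @ weights["dequantized"]


w = {"exllama_state": ExllamaState.UNINITIALIZED,
     "dequantized": torch.randn(16, 32)}
y = lazy_gemm(w, torch.randn(4, 16))
assert w["exllama_state"] is ExllamaState.READY and y.shape == (4, 32)
```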

+ 4 - 7
aphrodite/modeling/layers/quantization/squeezellm.py

@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from aphrodite._C import ops as quantization_ops
+from aphrodite._C import ops
 from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
 from aphrodite.modeling.layers.quantization.base_config import QuantizationConfig
@@ -70,7 +70,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
     def create_weights(self, input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
                        output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
         if input_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
@@ -119,15 +119,12 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
         reshaped_x = x.reshape(-1, x.shape[-1])
         if is_hip():
             out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float)
-            quantization_ops.squeezellm_gemm(reshaped_x, qweight, out_f,
-                                             lookup_table)
+            ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
             out = out_f.to(dtype=torch.float16)
-            # do something specific for HIP
         else:
             # NOTE: The output tensor should be zero-initialized.
             out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
-            quantization_ops.squeezellm_gemm(reshaped_x, qweight, out,
-                                             lookup_table)
+            ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
 
         if bias is not None:
             out = out + bias

+ 12 - 9
aphrodite/modeling/layers/rotary_embedding.py

@@ -22,13 +22,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Rotary Positional Embeddings."""
-from typing import Any, Dict, Optional, Tuple, Union
 import math
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 
-from aphrodite._C import ops as pos_encoding_ops
+from aphrodite._C import ops
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -103,7 +103,7 @@ class RotaryEmbedding(nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """PyTorch-native implementation. Equivalent to forward()."""
+        """PyTorch-native implementation equivalent to forward()."""
         query = query.view(*query.shape[:-1], -1, self.head_size)
         key = key.view(*key.shape[:-1], -1, self.head_size)
 
@@ -144,11 +144,10 @@ class RotaryEmbedding(nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # pos_encoding_ops.rotary_embedding() is an in-place operation that
+        # ops.rotary_embedding() is an in-place operation that
         # updates the query and key tensors.
-        pos_encoding_ops.rotary_embedding(positions, query, key,
-                                          self.head_size, self.cos_sin_cache,
-                                          self.is_neox_style)
+        ops.rotary_embedding(positions, query, key, self.head_size,
+                             self.cos_sin_cache, self.is_neox_style)
         return query, key
 
 
@@ -248,7 +247,7 @@ def _yarn_find_correction_range(low_rot: int,
     high = math.ceil(
         _yarn_find_correction_dim(high_rot, dim, base,
                                   max_position_embeddings))
-    return max(low, 0), min(high, dim - 1)  # clamp values just in case
+    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
 
 
 def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
@@ -270,7 +269,10 @@ def _yarn_get_mscale(scale: float = 1) -> float:
 
 
 class YaRNScalingRotaryEmbedding(RotaryEmbedding):
-    """Rotary embedding extended with YaRN method (Peng et al.)"""
+    """RotaryEmbedding extended with YaRN method.
+
+    Credits to Peng et al. github.com/jquesnelle/yarn
+    """
 
     def __init__(
         self,
@@ -342,6 +344,7 @@ def get_rope(
            tuple(rope_scaling.items()) if rope_scaling is not None else None)
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]
+
     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style)
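
ops.rotary_embedding mutates query and key in place; the class also keeps the PyTorch-native _forward for parity checks. A small out-of-place sketch of NeoX-style RoPE, assuming rotary_dim == head_size and the default base of 10000 (an illustration, not the kernel's exact code path):

```python
import torch


def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rope_cos_sin(positions: torch.Tensor, head_size: int,
                 base: float = 10000.0):
    inv_freq = 1.0 / (base**(
        torch.arange(0, head_size, 2, dtype=torch.float32) / head_size))
    freqs = torch.einsum("i,j->ij", positions.float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)        # NeoX layout
    return emb.cos(), emb.sin()


def apply_rope_neox(q: torch.Tensor, cos: torch.Tensor,
                    sin: torch.Tensor) -> torch.Tensor:
    # q: [num_tokens, num_heads, head_size]; cos/sin: [num_tokens, head_size]
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)  # broadcast over heads
    return q * cos + rotate_neox(q) * sin


positions = torch.arange(5)
q = torch.randn(5, 8, 64)
cos, sin = rope_cos_sin(positions, head_size=64)
print(apply_rope_neox(q, cos, sin).shape)          # torch.Size([5, 8, 64])
```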

+ 47 - 68
aphrodite/modeling/layers/sampler.py

@@ -8,7 +8,7 @@ from aphrodite.modeling.sampling_metadata import (SamplingMetadata,
                                                   OutputMetadata,
                                                   SamplingTensors)
 from aphrodite.modeling.megatron.communication_op import (
-    tensor_model_parallel_all_gather)
+    tensor_model_parallel_gather)
 from aphrodite.common.sampling_params import SamplingParams, SamplingType
 from aphrodite.common.sequence import (PromptLogprobs, SampleLogprobs,
                                        SamplerOutput, SequenceData,
@@ -40,30 +40,35 @@ class Sampler(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         # Get the hidden states that we use for sampling.
         hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
 
         # Get the logits for the next tokens.
         logits = _get_logits(hidden_states, embedding, embedding_bias,
                              self.vocab_size)
+
+        # Only perform sampling in the driver worker.
+        # Note: `_get_logits` is still distributed across TP workers because
+        # the `embedding` weight is distributed across TP workers.
+        # TODO: Change the get_logits part to a separate stage.
+        if not sampling_metadata.perform_sampling:
+            return None
+
+        assert logits is not None
         _, vocab_size = logits.shape
 
         output_metadata = OutputMetadata()
 
+        # Apply logits processors (if any)
+        logits = _apply_logits_processors(logits, sampling_metadata)
+
         # Prepare sampling tensors with pinned memory to avoid blocking.
         (sampling_tensors, do_temperatures, do_penalties, do_topks, do_topps,
          do_topas, do_minps, do_tfss, do_eta_cutoffs, do_epsilon_cutoffs,
          do_typical_ps, do_mirostat) = (SamplingTensors.from_sampling_metadata(
              sampling_metadata, vocab_size, logits.device, logits.dtype))
 
-        if do_temperatures:
-            # Apply temperature scaling.
-            # Use in-place division to avoid creating a new tensor.
-            logits = _apply_temperature(logits, sampling_tensors.temperatures,
-                                        sampling_tensors.dynatemp_ranges,
-                                        sampling_tensors.dynatemp_exps)
-
         if do_penalties:
             logits = _apply_penalties(logits, sampling_tensors.prompt_tokens,
                                       sampling_tensors.output_tokens,
@@ -71,16 +76,16 @@ class Sampler(nn.Module):
                                       sampling_tensors.frequency_penalties,
                                       sampling_tensors.repetition_penalties)
 
+        if do_temperatures:
+            logits = _apply_temperature(logits, sampling_tensors.temperatures,
+                                        sampling_tensors.dynatemp_ranges,
+                                        sampling_tensors.dynatemp_exps)
+
         if do_topks or do_topps or do_topas or do_minps:
             logits = _apply_alphabet_soup(logits, sampling_tensors.top_ps,
                                           sampling_tensors.top_ks,
                                           sampling_tensors.top_as,
                                           sampling_tensors.min_ps)
-        # Apply Eta/epsilon cutoff, typical_p, and tail-free sampling,
-        # as described in:
-        # https://arxiv.org/abs/2210.15191
-        # https://arxiv.org/abs/2202.00666
-        # https://www.trentonbricken.com/Tail-Free-Sampling/
         if do_tfss:
             logits = _apply_tfs(logits, sampling_tensors.tfss)
         if do_eta_cutoffs:
@@ -95,10 +100,6 @@ class Sampler(nn.Module):
         banned_tokens = _get_custom_token_bans(sampling_metadata)
         assert len(banned_tokens) == logits.shape[0]
         logits = _apply_token_bans(logits, banned_tokens)
-
-        logits = _apply_logits_processors(sampling_metadata, logits,
-                                          sampling_tensors.output_tokens)
-
         if do_mirostat:
             logits = _mirostat(logits, sampling_tensors, output_metadata)
 
@@ -121,14 +122,15 @@ class Sampler(nn.Module):
 
 def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
                 embedding_bias: Optional[torch.Tensor],
-                vocab_size: int) -> torch.Tensor:
+                vocab_size: int) -> Optional[torch.Tensor]:
     # Get the logits for the next tokens.
     logits = torch.matmul(hidden_states, embedding.t())
     if embedding_bias is not None:
         logits += embedding_bias
-    logits = tensor_model_parallel_all_gather(logits)
+    logits = tensor_model_parallel_gather(logits)
     # Remove paddings in vocab (if any).
-    logits = logits[:, :vocab_size]
+    if logits is not None:
+        logits = logits[:, :vocab_size]
     return logits
 
 
@@ -136,31 +138,9 @@ def _prune_hidden_states(
     hidden_states: torch.Tensor,
     sampling_metadata: SamplingMetadata,
 ) -> torch.Tensor:
-    selected_token_indices: List[int] = []
-    start_idx = 0
-    max_prompt_len = max(
-        sampling_metadata.prompt_lens) if sampling_metadata.prompt_lens else 1
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        if i < sampling_metadata.num_prompts:
-            assert len(seq_ids) == 1, "Prompt input should have only one seq."
-            prompt_len = sampling_metadata.prompt_lens[i]
-            if sampling_params.prompt_logprobs is not None:
-                selected_token_indices.extend(
-                    range(start_idx, start_idx + prompt_len - 1))
-            selected_token_indices.append(start_idx + prompt_len - 1)
-            start_idx += max_prompt_len
-        else:
-            num_seqs = len(seq_ids)
-            selected_token_indices.extend(
-                range(start_idx, start_idx + num_seqs))
-            start_idx += num_seqs
-
-    selected_token_indices = torch.tensor(selected_token_indices,
-                                          dtype=torch.long,
-                                          device=hidden_states.device)
     hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-    return hidden_states.index_select(0, selected_token_indices)
+    return hidden_states.index_select(0,
+                                      sampling_metadata.selected_token_indices)
 
 
 def _get_bin_counts_and_mask(
@@ -194,19 +174,27 @@ def _get_custom_token_bans(
     return banned_tokens
 
 
-def _apply_logits_processors(sampling_metadata: SamplingMetadata,
-                             logits: torch.Tensor,
-                             output_tokens: List[List[int]]) -> torch.Tensor:
-    seq_offset = 0
-
+def _apply_logits_processors(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    logits_row_idx = 0
+    found_logits_processors = False
     for seq_ids, sampling_params in sampling_metadata.seq_groups:
-        seq_end = seq_offset + len(seq_ids)
-
-        for proc in sampling_params.logits_processors:
-            proc(logits[seq_offset:seq_end], output_tokens[seq_offset:seq_end])
-
-        seq_offset = seq_end
-
+        logits_processors = sampling_params.logits_processors
+        if logits_processors:
+            found_logits_processors = True
+            for seq_id in seq_ids:
+                logits_row = logits[logits_row_idx]
+                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
+                for logits_processor in logits_processors:
+                    logits_row = logits_processor(token_ids, logits_row)
+                logits[logits_row_idx] = logits_row
+                logits_row_idx += 1
+        else:
+            logits_row_idx += len(seq_ids)
+    if found_logits_processors:
+        assert logits_row_idx == logits.shape[0]
     return logits
 
 
@@ -537,20 +525,11 @@ def _sample(
     sampling_metadata: SamplingMetadata,
 ) -> List[Tuple[List[int], List[int]]]:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
-    categorized_sample_indices = {t: [] for t in SamplingType}
-    start_idx = 0
+    categorized_sample_indices = sampling_metadata.categorized_sample_indices
     for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
+        _, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            prompt_len = sampling_metadata.prompt_lens[i]
-            start_idx += prompt_len - 1
         categorized_seq_group_ids[sampling_type].append(i)
-        num_seqs = len(seq_ids)
-        categorized_sample_indices[sampling_type].extend(
-            range(start_idx, start_idx + num_seqs))
-        start_idx += num_seqs
 
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     sample_metadata = {}
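
The reworked _apply_logits_processors calls each processor per sequence row with the signature (output_token_ids, logits_row) -> logits_row. A toy processor in that calling convention (the ban-repeats behaviour is made up for illustration):

```python
from typing import Callable, List

import torch

LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]


def ban_generated_tokens(token_ids: List[int],
                         logits_row: torch.Tensor) -> torch.Tensor:
    # Mask out every token id the sequence has already produced.
    if token_ids:
        logits_row[torch.tensor(token_ids)] = -float("inf")
    return logits_row


logits = torch.randn(4, 32)               # one row per sequence
output_token_ids = [[1, 3], [], [7], [2, 2, 5]]
for row, token_ids in zip(logits, output_token_ids):
    ban_generated_tokens(token_ids, row)   # row is updated in place
print(int(torch.isinf(logits).sum()))      # 5 banned logits in total
```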

+ 0 - 0
aphrodite/modeling/layers/triton_kernel/__init__.py


+ 728 - 0
aphrodite/modeling/layers/triton_kernel/prefix_prefill.py

@@ -0,0 +1,728 @@
+# The kernels in this file are adapted from LightLLM's context_attention_fwd:
+# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
+
+import torch
+import triton
+import triton.language as tl
+
+if triton.__version__ >= "2.1.0":
+
+    @triton.jit
+    def _fwd_kernel(
+        Q,
+        K,
+        V,
+        K_cache,
+        V_cache,
+        B_Loc,
+        sm_scale,
+        B_Start_Loc,
+        B_Seqlen,
+        B_Ctxlen,
+        block_size,
+        x,
+        Out,
+        stride_b_loc_b,
+        stride_b_loc_s,
+        stride_qbs,
+        stride_qh,
+        stride_qd,
+        stride_kbs,
+        stride_kh,
+        stride_kd,
+        stride_vbs,
+        stride_vh,
+        stride_vd,
+        stride_obs,
+        stride_oh,
+        stride_od,
+        stride_k_cache_bs,
+        stride_k_cache_h,
+        stride_k_cache_d,
+        stride_k_cache_bl,
+        stride_k_cache_x,
+        stride_v_cache_bs,
+        stride_v_cache_h,
+        stride_v_cache_d,
+        stride_v_cache_bl,
+        BLOCK_M: tl.constexpr,
+        BLOCK_DMODEL: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+    ):
+        cur_batch = tl.program_id(0)
+        cur_head = tl.program_id(1)
+        start_m = tl.program_id(2)
+
+        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
+        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+
+        block_start_loc = BLOCK_M * start_m
+
+        # initialize offsets
+        offs_n = tl.arange(0, BLOCK_N)
+        offs_d = tl.arange(0, BLOCK_DMODEL)
+        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+        off_q = (
+            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
+            cur_head * stride_qh + offs_d[None, :] * stride_qd)
+
+        q = tl.load(
+            Q + off_q,
+            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
+            other=0.0)
+
+        # # initialize pointer to m and l
+        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+
+        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
+            start_n = tl.multiple_of(start_n, BLOCK_N)
+            # -- compute qk ----
+            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
+                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
+                         mask=(start_n + offs_n) < cur_batch_ctx_len,
+                         other=0)
+            off_k = (bn[None, :] * stride_k_cache_bs +
+                     cur_head * stride_k_cache_h +
+                     (offs_d[:, None] // x) * stride_k_cache_d +
+                     ((start_n + offs_n[None, :]) % block_size) *
+                     stride_k_cache_bl +
+                     (offs_d[:, None] % x) * stride_k_cache_x)
+            off_v = (
+                bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
+                offs_d[None, :] * stride_v_cache_d +
+                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
+            k = tl.load(K_cache + off_k,
+                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
+                        other=0.0)
+
+            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+            qk += tl.dot(q, k)
+            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
+                          float("-inf"))
+            qk *= sm_scale
+
+            # -- compute m_ij, p, l_ij
+            m_ij = tl.max(qk, 1)
+            p = tl.exp(qk - m_ij[:, None])
+            l_ij = tl.sum(p, 1)
+            # -- update m_i and l_i
+            m_i_new = tl.maximum(m_i, m_ij)
+            alpha = tl.exp(m_i - m_i_new)
+            beta = tl.exp(m_ij - m_i_new)
+            l_i_new = alpha * l_i + beta * l_ij
+            # -- update output accumulator --
+            # scale p
+            p_scale = beta / l_i_new
+            p = p * p_scale[:, None]
+            # scale acc
+            acc_scale = l_i / l_i_new * alpha
+            acc = acc * acc_scale[:, None]
+            # update acc
+            v = tl.load(V_cache + off_v,
+                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
+                        other=0.0)
+
+            p = p.to(v.dtype)
+            acc += tl.dot(p, v)
+            # # update m_i and l_i
+            l_i = l_i_new
+            m_i = m_i_new
+
+        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+                 offs_d[:, None] * stride_kd)
+        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+                 offs_d[None, :] * stride_vd)
+        k_ptrs = K + off_k
+        v_ptrs = V + off_v
+
+        block_mask = tl.where(
+            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
+
+        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
+            start_n = tl.multiple_of(start_n, BLOCK_N)
+            # -- compute qk ----
+            k = tl.load(k_ptrs +
+                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
+                        mask=(start_n + offs_n[None, :]) <
+                        cur_batch_seq_len - cur_batch_ctx_len,
+                        other=0.0)
+
+            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+            qk += tl.dot(q, k)
+            qk *= sm_scale
+            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
+                          float("-inf"))
+
+            # -- compute m_ij, p, l_ij
+            m_ij = tl.max(qk, 1)
+            p = tl.exp(qk - m_ij[:, None])
+            l_ij = tl.sum(p, 1)
+            # -- update m_i and l_i
+            m_i_new = tl.maximum(m_i, m_ij)
+            alpha = tl.exp(m_i - m_i_new)
+            beta = tl.exp(m_ij - m_i_new)
+            l_i_new = alpha * l_i + beta * l_ij
+            # -- update output accumulator --
+            # scale p
+            p_scale = beta / l_i_new
+            p = p * p_scale[:, None]
+            # scale acc
+            acc_scale = l_i / l_i_new * alpha
+            acc = acc * acc_scale[:, None]
+            # update acc
+            v = tl.load(v_ptrs +
+                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
+                        mask=(start_n + offs_n[:, None]) <
+                        cur_batch_seq_len - cur_batch_ctx_len,
+                        other=0.0)
+
+            p = p.to(v.dtype)
+            acc += tl.dot(p, v)
+            # update m_i and l_i
+            l_i = l_i_new
+            m_i = m_i_new
+        # initialize pointers to output
+        off_o = (
+            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
+            cur_head * stride_oh + offs_d[None, :] * stride_od)
+        out_ptrs = Out + off_o
+        tl.store(out_ptrs,
+                 acc,
+                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
+        return
+
+    @triton.jit
+    def _fwd_kernel_flash_attn_v2(
+        Q,
+        K,
+        V,
+        K_cache,
+        V_cache,
+        B_Loc,
+        sm_scale,
+        B_Start_Loc,
+        B_Seqlen,
+        B_Ctxlen,
+        block_size,
+        x,
+        Out,
+        stride_b_loc_b,
+        stride_b_loc_s,
+        stride_qbs,
+        stride_qh,
+        stride_qd,
+        stride_kbs,
+        stride_kh,
+        stride_kd,
+        stride_vbs,
+        stride_vh,
+        stride_vd,
+        stride_obs,
+        stride_oh,
+        stride_od,
+        stride_k_cache_bs,
+        stride_k_cache_h,
+        stride_k_cache_d,
+        stride_k_cache_bl,
+        stride_k_cache_x,
+        stride_v_cache_bs,
+        stride_v_cache_h,
+        stride_v_cache_d,
+        stride_v_cache_bl,
+        BLOCK_M: tl.constexpr,
+        BLOCK_DMODEL: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+    ):
+        cur_batch = tl.program_id(0)
+        cur_head = tl.program_id(1)
+        start_m = tl.program_id(2)
+
+        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
+        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+
+        block_start_loc = BLOCK_M * start_m
+
+        # initialize offsets
+        offs_n = tl.arange(0, BLOCK_N)
+        offs_d = tl.arange(0, BLOCK_DMODEL)
+        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+        off_q = (
+            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
+            cur_head * stride_qh + offs_d[None, :] * stride_qd)
+
+        q = tl.load(
+            Q + off_q,
+            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
+            other=0.0)
+
+        # # initialize pointer to m and l
+        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+
+        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
+            start_n = tl.multiple_of(start_n, BLOCK_N)
+            # -- compute qk ----
+            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
+                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
+                         mask=(start_n + offs_n) < cur_batch_ctx_len,
+                         other=0)
+            off_k = (bn[None, :] * stride_k_cache_bs +
+                     cur_head * stride_k_cache_h +
+                     (offs_d[:, None] // x) * stride_k_cache_d +
+                     ((start_n + offs_n[None, :]) % block_size) *
+                     stride_k_cache_bl +
+                     (offs_d[:, None] % x) * stride_k_cache_x)
+            off_v = (
+                bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
+                offs_d[None, :] * stride_v_cache_d +
+                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
+            k = tl.load(K_cache + off_k,
+                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
+                        other=0.0)
+
+            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+            qk += tl.dot(q, k)
+            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
+                          float("-inf"))
+            qk *= sm_scale
+
+            # -- compute m_ij, p, l_ij
+            m_ij = tl.max(qk, 1)
+            m_i_new = tl.maximum(m_i, m_ij)
+            p = tl.math.exp(qk - m_i_new[:, None])
+            l_ij = tl.sum(p, 1)
+            # -- update m_i and l_i
+
+            alpha = tl.math.exp(m_i - m_i_new)
+            l_i_new = alpha * l_i + l_ij
+            # -- update output accumulator --
+            # scale p
+            # scale acc
+            acc_scale = alpha
+            # acc_scale = l_i / l_i_new * alpha
+            acc = acc * acc_scale[:, None]
+            # update acc
+            v = tl.load(V_cache + off_v,
+                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
+                        other=0.0)
+
+            p = p.to(v.dtype)
+            acc += tl.dot(p, v)
+            # update m_i and l_i
+            l_i = l_i_new
+            m_i = m_i_new
+
+        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+                 offs_d[:, None] * stride_kd)
+        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+                 offs_d[None, :] * stride_vd)
+        k_ptrs = K + off_k
+        v_ptrs = V + off_v
+
+        block_mask = tl.where(
+            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
+
+        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
+            start_n = tl.multiple_of(start_n, BLOCK_N)
+            # -- compute qk ----
+            k = tl.load(k_ptrs +
+                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
+                        mask=(start_n + offs_n[None, :]) <
+                        cur_batch_seq_len - cur_batch_ctx_len,
+                        other=0.0)
+
+            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+            qk += tl.dot(q, k)
+            qk *= sm_scale
+            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
+                          float("-inf"))
+
+            # -- compute m_ij, p, l_ij
+            m_ij = tl.max(qk, 1)
+            m_i_new = tl.maximum(m_i, m_ij)
+            p = tl.math.exp(qk - m_i_new[:, None])
+            l_ij = tl.sum(p, 1)
+            # -- update m_i and l_i
+
+            alpha = tl.math.exp(m_i - m_i_new)
+            l_i_new = alpha * l_i + l_ij
+            # -- update output accumulator --
+            # scale p
+            # scale acc
+            acc_scale = alpha
+            # acc_scale = l_i / l_i_new * alpha
+            acc = acc * acc_scale[:, None]
+            # update acc
+            v = tl.load(v_ptrs +
+                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
+                        mask=(start_n + offs_n[:, None]) <
+                        cur_batch_seq_len - cur_batch_ctx_len,
+                        other=0.0)
+
+            p = p.to(v.dtype)
+            acc += tl.dot(p, v)
+            # update m_i and l_i
+            l_i = l_i_new
+            m_i = m_i_new
+
+        # acc /= l_i[:, None]
+        # initialize pointers to output
+        off_o = (
+            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
+            cur_head * stride_oh + offs_d[None, :] * stride_od)
+        out_ptrs = Out + off_o
+        tl.store(out_ptrs,
+                 acc,
+                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
+        return
+
+    @triton.jit
+    def _fwd_kernel_alibi(
+        Q,
+        K,
+        V,
+        K_cache,
+        V_cache,
+        B_Loc,
+        sm_scale,
+        B_Start_Loc,
+        B_Seqlen,
+        B_Ctxlen,
+        Alibi_slopes,
+        block_size,
+        x,
+        Out,
+        stride_b_loc_b,
+        stride_b_loc_s,
+        stride_qbs,
+        stride_qh,
+        stride_qd,
+        stride_kbs,
+        stride_kh,
+        stride_kd,
+        stride_vbs,
+        stride_vh,
+        stride_vd,
+        stride_obs,
+        stride_oh,
+        stride_od,
+        stride_k_cache_bs,
+        stride_k_cache_h,
+        stride_k_cache_d,
+        stride_k_cache_bl,
+        stride_k_cache_x,
+        stride_v_cache_bs,
+        stride_v_cache_h,
+        stride_v_cache_d,
+        stride_v_cache_bl,
+        BLOCK_M: tl.constexpr,
+        BLOCK_DMODEL: tl.constexpr,
+        BLOCK_N: tl.constexpr,
+    ):
+        # attn_bias[]
+        cur_batch = tl.program_id(0)
+        cur_head = tl.program_id(1)
+        start_m = tl.program_id(2)
+
+        # cur_batch_seq_len: the length of prompts
+        # cur_batch_ctx_len: the length of prefix
+        # cur_batch_in_all_start_index: the start id of the dim=0
+        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
+        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+
+        block_start_loc = BLOCK_M * start_m
+
+        # initialize offsets
+        offs_n = tl.arange(0, BLOCK_N)
+        offs_d = tl.arange(0, BLOCK_DMODEL)
+        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+        off_q = (
+            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
+            cur_head * stride_qh + offs_d[None, :] * stride_qd)
+
+        q = tl.load(
+            Q + off_q,
+            mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
+            other=0.0)
+
+        # # initialize pointer to m and l
+        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+
+        alibi_slope = tl.load(Alibi_slopes + cur_head)
+        alibi_start_q = tl.arange(
+            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
+        alibi_start_k = 0
+        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
+            start_n = tl.multiple_of(start_n, BLOCK_N)
+            # -- compute qk ----
+            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
+                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
+                         mask=(start_n + offs_n) < cur_batch_ctx_len,
+                         other=0)
+            off_k = (bn[None, :] * stride_k_cache_bs +
+                     cur_head * stride_k_cache_h +
+                     (offs_d[:, None] // x) * stride_k_cache_d +
+                     ((start_n + offs_n[None, :]) % block_size) *
+                     stride_k_cache_bl +
+                     (offs_d[:, None] % x) * stride_k_cache_x)
+            off_v = (
+                bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
+                offs_d[None, :] * stride_v_cache_d +
+                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
+            k = tl.load(K_cache + off_k,
+                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
+                        other=0.0)
+
+            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+            qk += tl.dot(q, k)
+            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
+                          float("-inf"))
+            qk *= sm_scale
+
+            # load alibi
+            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
+                     alibi_start_q[:, None]) * alibi_slope
+            alibi = tl.where(
+                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
+                alibi, float("-inf"))
+            qk += alibi
+            alibi_start_k += BLOCK_N
+
+            # -- compute m_ij, p, l_ij
+            m_ij = tl.max(qk, 1)
+            m_i_new = tl.maximum(m_i, m_ij)
+            p = tl.math.exp(qk - m_i_new[:, None])
+            l_ij = tl.sum(p, 1)
+            # -- update m_i and l_i
+
+            alpha = tl.math.exp(m_i - m_i_new)
+            l_i_new = alpha * l_i + l_ij
+            # -- update output accumulator --
+            # scale p
+            # scale acc
+            acc_scale = alpha
+            # acc_scale = l_i / l_i_new * alpha
+            acc = acc * acc_scale[:, None]
+            # update acc
+            v = tl.load(V_cache + off_v,
+                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
+                        other=0.0)
+
+            p = p.to(v.dtype)
+            acc += tl.dot(p, v, allow_tf32=False)
+            # update m_i and l_i
+            l_i = l_i_new
+            m_i = m_i_new
+
+        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
+                 offs_d[:, None] * stride_kd)
+        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
+                 offs_d[None, :] * stride_vd)
+        k_ptrs = K + off_k
+        v_ptrs = V + off_v
+
+        block_mask = tl.where(
+            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
+
+        # init alibi
+        alibi_slope = tl.load(Alibi_slopes + cur_head)
+        alibi_start_q = tl.arange(
+            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
+        alibi_start_k = cur_batch_ctx_len
+        # # init debugger
+        # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
+        # offset_db_k = tl.arange(0, BLOCK_N)
+        # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
+        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
+            start_n = tl.multiple_of(start_n, BLOCK_N)
+            # -- compute qk ----
+            k = tl.load(k_ptrs +
+                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
+                        mask=(start_n + offs_n[None, :]) <
+                        cur_batch_seq_len - cur_batch_ctx_len,
+                        other=0.0)
+
+            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+            qk += tl.dot(q, k, allow_tf32=False)
+            qk *= sm_scale
+            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
+                          float("-inf"))
+
+            # load alibi
+            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
+                     alibi_start_q[:, None]) * alibi_slope
+            alibi = tl.where(
+                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
+                alibi, float("-inf"))
+            qk += alibi
+            alibi_start_k += BLOCK_N
+
+            # -- compute m_ij, p, l_ij
+            m_ij = tl.max(qk, 1)
+            m_i_new = tl.maximum(m_i, m_ij)
+            p = tl.math.exp(qk - m_i_new[:, None])
+            l_ij = tl.sum(p, 1)
+            # -- update m_i and l_i
+
+            alpha = tl.math.exp(m_i - m_i_new)
+            l_i_new = alpha * l_i + l_ij
+            # -- update output accumulator --
+            # scale p
+            # scale acc
+            acc_scale = alpha
+            # acc_scale = l_i / l_i_new * alpha
+            acc = acc * acc_scale[:, None]
+            # update acc
+            v = tl.load(v_ptrs +
+                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
+                        mask=(start_n + offs_n[:, None]) <
+                        cur_batch_seq_len - cur_batch_ctx_len,
+                        other=0.0)
+
+            p = p.to(v.dtype)
+            acc += tl.dot(p, v, allow_tf32=False)
+            # update m_i and l_i
+            l_i = l_i_new
+            m_i = m_i_new
+
+        acc = acc / l_i[:, None]
+
+        # initialize pointers to output
+        off_o = (
+            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
+            cur_head * stride_oh + offs_d[None, :] * stride_od)
+        out_ptrs = Out + off_o
+        tl.store(out_ptrs,
+                 acc,
+                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
+        return
+
+    @torch.inference_mode()
+    def context_attention_fwd(q,
+                              k,
+                              v,
+                              o,
+                              k_cache,
+                              v_cache,
+                              b_loc,
+                              b_start_loc,
+                              b_seq_len,
+                              b_ctx_len,
+                              max_input_len,
+                              alibi_slopes=None):
+        BLOCK = 128
+        # shape constraints
+        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
+        assert Lq == Lk and Lk == Lv
+        assert Lk in {16, 32, 64, 128}
+
+        sm_scale = 1.0 / (Lq**0.5)
+        batch, head = b_seq_len.shape[0], q.shape[1]
+
+        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,
+
+        num_warps = 8 if Lk <= 64 else 8
+        if alibi_slopes is not None:
+            _fwd_kernel_alibi[grid](
+                q,
+                k,
+                v,
+                k_cache,
+                v_cache,
+                b_loc,
+                sm_scale,
+                b_start_loc,
+                b_seq_len,
+                b_ctx_len,
+                alibi_slopes,
+                v_cache.shape[3],
+                8,
+                o,
+                b_loc.stride(0),
+                b_loc.stride(1),
+                q.stride(0),
+                q.stride(1),
+                q.stride(2),
+                k.stride(0),
+                k.stride(1),
+                k.stride(2),
+                v.stride(0),
+                v.stride(1),
+                v.stride(2),
+                o.stride(0),
+                o.stride(1),
+                o.stride(2),
+                k_cache.stride(0),
+                k_cache.stride(1),
+                k_cache.stride(2),
+                k_cache.stride(3),
+                k_cache.stride(
+                    4
+                ),  #[num_blocks, num_kv_heads, head_size/x, block_size, x]
+                v_cache.stride(0),
+                v_cache.stride(1),
+                v_cache.stride(2),
+                v_cache.stride(
+                    3),  #[num_blocks, num_kv_heads, head_size, block_size]
+                BLOCK_M=BLOCK,
+                BLOCK_DMODEL=Lk,
+                BLOCK_N=BLOCK,
+                num_warps=num_warps,
+                num_stages=1,
+            )
+            return
+
+        _fwd_kernel[grid](
+            q,
+            k,
+            v,
+            k_cache,
+            v_cache,
+            b_loc,
+            sm_scale,
+            b_start_loc,
+            b_seq_len,
+            b_ctx_len,
+            v_cache.shape[3],
+            8,
+            o,
+            b_loc.stride(0),
+            b_loc.stride(1),
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            k.stride(0),
+            k.stride(1),
+            k.stride(2),
+            v.stride(0),
+            v.stride(1),
+            v.stride(2),
+            o.stride(0),
+            o.stride(1),
+            o.stride(2),
+            k_cache.stride(0),
+            k_cache.stride(1),
+            k_cache.stride(2),
+            k_cache.stride(3),
+            k_cache.stride(
+                4),  #[num_blocks, num_kv_heads, head_size/x, block_size, x]
+            v_cache.stride(0),
+            v_cache.stride(1),
+            v_cache.stride(2),
+            v_cache.stride(
+                3),  #[num_blocks, num_kv_heads, head_size, block_size]
+            BLOCK_M=BLOCK,
+            BLOCK_DMODEL=Lk,
+            BLOCK_N=BLOCK,
+            num_warps=num_warps,
+            num_stages=1,
+        )
+        return
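
All three kernels above carry the same streaming-softmax state across key blocks: a running row maximum m_i, a running normalizer l_i, and an unnormalized accumulator acc. A compact PyTorch reference of that update, ignoring sm_scale, masking, ALiBi, and the paged KV-cache indexing; this is the variant with a single final division (as in _fwd_kernel_alibi), whereas _fwd_kernel folds the normalization into each step:

```python
import torch


def streaming_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        block: int = 4) -> torch.Tensor:
    m = torch.full((q.shape[0], ), float("-inf"))   # running row max
    l = torch.zeros(q.shape[0])                     # running normalizer
    acc = torch.zeros(q.shape[0], v.shape[1])       # unnormalized output
    for s in range(0, k.shape[0], block):
        kb, vb = k[s:s + block], v[s:s + block]
        qk = q @ kb.t()
        m_new = torch.maximum(m, qk.max(dim=1).values)
        p = torch.exp(qk - m_new[:, None])
        alpha = torch.exp(m - m_new)                # rescale previous state
        l = alpha * l + p.sum(dim=1)
        acc = acc * alpha[:, None] + p @ vb
        m = m_new
    return acc / l[:, None]


q, k, v = torch.randn(3, 8), torch.randn(10, 8), torch.randn(10, 8)
ref = torch.softmax(q @ k.t(), dim=-1) @ v
assert torch.allclose(streaming_attention(q, k, v), ref, atol=1e-5)
```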

+ 1 - 1
aphrodite/modeling/layers/vocab_parallel_embedding.py

@@ -58,7 +58,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.tp_size = get_tensor_model_parallel_world_size()
-        # Divide the weight matrix along the vocaburaly dimension.
+        # Divide the weight matrix along the vocabulary dimension.
         self.vocab_start_index, self.vocab_end_index = (
             vocab_range_from_global_vocab_size(
                 self.num_embeddings_padded, get_tensor_model_parallel_rank(),
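
For context, the vocabulary split referenced in that comment gives each tensor-parallel rank a contiguous slice of the padded vocab. An illustrative re-implementation of the range computation (an assumption about the helper's behaviour, not the actual source):

```python
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
                                       world_size: int):
    # Even split of the padded vocab across tensor-parallel ranks
    # (illustrative; assumes global_vocab_size % world_size == 0).
    per_partition = global_vocab_size // world_size
    start = rank * per_partition
    return start, start + per_partition


print(vocab_range_from_global_vocab_size(32000, rank=1, world_size=4))
# (8000, 16000)
```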

+ 144 - 3
aphrodite/modeling/megatron/communication_op.py

@@ -1,14 +1,20 @@
+from collections import namedtuple
+from typing import Any, Dict, List, Optional, Union
+
+from torch.distributed import ProcessGroup
 import torch
 
 from aphrodite.modeling.megatron.parallel_state import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tensor_model_parallel_group,
 )
 
 
-def tensor_model_parallel_all_reduce(input_):
+def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     """All-reduce the input tensor across model parallel group.
-    Note: This operation is applied in-place on the input tensor.
+
+    NOTE: This operation is applied in-place on the input tensor.
     """
     # Bypass the function if we are using only 1 GPU.
     if get_tensor_model_parallel_world_size() == 1:
@@ -19,7 +25,8 @@ def tensor_model_parallel_all_reduce(input_):
     return input_
 
 
-def tensor_model_parallel_all_gather(input_, dim=-1):
+def tensor_model_parallel_all_gather(input_: torch.Tensor,
+                                     dim: int = -1) -> torch.Tensor:
     """All-gather the input tensor across model parallel group."""
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
@@ -44,3 +51,137 @@ def tensor_model_parallel_all_gather(input_, dim=-1):
                                           (world_size * input_size[dim], ) +
                                           input_size[dim + 1:])
     return output_tensor
+
+
+def tensor_model_parallel_gather(input_: torch.Tensor,
+                                 dst: int = 0,
+                                 dim: int = -1) -> torch.Tensor:
+    """Gather the input tensor across model parallel group.
+
+    NOTE: We assume that the input tensor is on the same device across
+    all the ranks.
+    """
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+    assert -input_.dim() <= dim < input_.dim(), (
+        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+    if dim < 0:
+        # Convert negative dim to positive.
+        dim += input_.dim()
+    # Allocate output tensor.
+    if get_tensor_model_parallel_rank() == dst:
+        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
+    else:
+        gather_list = None
+    # Gather.
+    torch.distributed.gather(input_,
+                             gather_list,
+                             dst=dst,
+                             group=get_tensor_model_parallel_group())
+    if get_tensor_model_parallel_rank() == dst:
+        output_tensor = torch.cat(gather_list, dim=dim)
+    else:
+        output_tensor = None
+    return output_tensor
+
+
+def broadcast(input_: torch.Tensor,
+              src: int = 0,
+              group: Optional[ProcessGroup] = None):
+    """Broadcast the input tensor."""
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"
+
+    # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
+    if world_size == 1:
+        return input_
+    # Broadcast.
+    torch.distributed.broadcast(input_, src=src, group=group)
+    return input_
+
+
+def broadcast_object_list(obj_list: List[Any],
+                          src: int = 0,
+                          group: Optional[ProcessGroup] = None):
+    """Broadcast the input object list."""
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"
+
+    # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
+    if world_size == 1:
+        return obj_list
+    # Broadcast.
+    torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
+    return obj_list
+
+
+TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
+
+
+def broadcast_tensor_dict(
+    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
+    src: int = 0,
+    group: Optional[ProcessGroup] = None,
+) -> Dict[Any, Union[torch.Tensor, Any]]:
+    """Broadcast the input tensor dictionary."""
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"
+
+    # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
+    if world_size == 1:
+        return tensor_dict
+
+    rank = torch.distributed.get_rank()
+    if rank == src:
+        assert isinstance(
+            tensor_dict,
+            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
+        metadata_list = []
+        for key, value in tensor_dict.items():
+            if isinstance(value, torch.Tensor):
+                assert value.is_cuda, (
+                    f"Tensor {key}: {value} is not on cuda. Currently we only "
+                    f"support broadcasting tensors on cuda.")
+                metadata_list.append(
+                    (key, TensorMetadata(value.dtype, value.size())))
+            else:
+                metadata_list.append((key, value))
+        torch.distributed.broadcast_object_list([metadata_list],
+                                                src=src,
+                                                group=group)
+        for key, value in metadata_list:
+            if isinstance(value, TensorMetadata):
+                tensor = tensor_dict[key]
+                torch.distributed.broadcast(tensor, src=src, group=group)
+    else:
+        recv_metadata_list = [None]
+        torch.distributed.broadcast_object_list(recv_metadata_list,
+                                                src=src,
+                                                group=group)
+        metadata_list = recv_metadata_list[0]
+        tensor_dict = {}
+        async_handles = []
+        for key, value in metadata_list:  # pylint: disable=not-an-iterable
+            if isinstance(value, TensorMetadata):
+                tensor = torch.empty(value.size,
+                                     dtype=value.dtype,
+                                     device="cuda")
+                async_handle = torch.distributed.broadcast(tensor,
+                                                           src=src,
+                                                           async_op=True,
+                                                           group=group)
+                async_handles.append(async_handle)
+                tensor_dict[key] = tensor
+            else:
+                tensor_dict[key] = value
+        for async_handle in async_handles:
+            async_handle.wait()
+    return tensor_dict

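A minimal usage sketch for the new helpers, assuming torch.distributed is already initialized with an NCCL-backed default group and more than one rank; the key names and shapes below are illustrative, not taken from the engine:

    import torch
    from aphrodite.modeling.megatron.communication_op import broadcast_tensor_dict

    def driver_send(input_tokens: torch.Tensor) -> None:
        # Rank 0 owns the real data; tensors must already live on the GPU.
        # Non-tensor values travel inside the pickled metadata list.
        broadcast_tensor_dict(
            {"input_tokens": input_tokens.cuda(), "is_prompt": True}, src=0)

    def worker_receive() -> dict:
        # Other ranks pass no dict; they get freshly allocated CUDA tensors for
        # every TensorMetadata entry, plus the plain-Python values, back.
        return broadcast_tensor_dict(src=0)

Only dtypes and shapes cross the object-list broadcast; the tensor payloads themselves go over regular torch.distributed broadcasts, which is what lets this path carry the control-plane messages that Ray used to.
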
+ 4 - 4
aphrodite/modeling/megatron/parallel_state.py

@@ -1,9 +1,9 @@
 # Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
-# pylint: disable=line-too-long
-# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
+# Adapted from
+# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-"""Model and data parallel groups."""
+"""Tensor and pipeline parallel groups."""
 
 import torch
 
@@ -85,7 +85,7 @@ def initialize_model_parallel(
 
 
 def model_parallel_is_initialized():
-    """Check if model and data parallel groups are initialized."""
+    """Check if tensor and pipeline parallel groups are initialized."""
     return (_TENSOR_MODEL_PARALLEL_GROUP is not None
             and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
 

+ 2 - 2
aphrodite/modeling/megatron/utils.py

@@ -1,7 +1,7 @@
 # Copyright 2023 The PygmalionAI team.
 # Copyright 2023 The vLLM team.
-# pylint: disable=line-too-long
-# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
+# Adapted from
+# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 from typing import Sequence
 

+ 12 - 10
aphrodite/modeling/metadata.py

@@ -1,46 +1,48 @@
-from typing import List, Optional
+from typing import Optional
+
 import torch
 
 
 class InputMetadata:
-    """Metadata for input sequences. Used for PagedAttention.
+    """Metadata for input sequences. Used in PagedAttention.
 
     Args:
-        seq_groups: List of (seq_ids, sampling_params).
-        seq_data: Seq_id -> SequenceData.
         prompt_lens: Lengths of prompts.
         slot_mapping: The address to write the new KV to of each token.
-        context_lens: the length of attention context for each generation token.
         max_context_len: The maximum context length.
+        context_lens: The length of the attention context for each sequence.
         block_tables: The block tables. (Seq id -> list of physical block)
     """
 
     def __init__(
         self,
-        prompt_lens: List[int],
+        is_prompt: bool,
         slot_mapping: torch.Tensor,
+        prompt_lens: Optional[torch.Tensor],
+        max_seq_len: Optional[int],
+        start_loc: Optional[torch.Tensor],
         max_context_len: Optional[int],
         context_lens: Optional[torch.Tensor],
         block_tables: Optional[torch.Tensor],
         use_cuda_graph: bool,
     ) -> None:
+        self.is_prompt = is_prompt
         self.prompt_lens = prompt_lens
+        self.max_seq_len = max_seq_len
+        self.start_loc = start_loc
         self.max_context_len = max_context_len
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
         self.block_tables = block_tables
         self.use_cuda_graph = use_cuda_graph
 
-        self.is_prompt = len(prompt_lens) > 0
-
         # Set during the execution of the first attention op.
         # FIXME: This is a hack.
         self.attn_bias = None
 
     def __repr__(self) -> str:
-        # Print only useful metadata.
         return ("InputMetadata("
-                f"prompt_lens={self.prompt_lens}, "
+                f"is_prompt={self.is_prompt}, "
                 f"max_context_len={self.max_context_len}, "
                 f"slot_mapping={self.slot_mapping}, "
                 f"context_lens={self.context_lens}, "

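For orientation, a decode-step InputMetadata might be constructed roughly as below; every value is made up, and in practice the model runner fills these fields in:

    import torch
    from aphrodite.modeling import InputMetadata

    decode_metadata = InputMetadata(
        is_prompt=False,
        slot_mapping=torch.tensor([5, 21], dtype=torch.long, device="cuda"),
        prompt_lens=None,        # prompt-only fields stay None during decode
        max_seq_len=None,
        start_loc=None,
        max_context_len=9,
        context_lens=torch.tensor([7, 9], dtype=torch.int, device="cuda"),
        block_tables=torch.tensor([[0, 1], [2, 3]],
                                  dtype=torch.int, device="cuda"),
        use_cuda_graph=False,
    )
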
+ 1 - 1
aphrodite/modeling/models/__init__.py

@@ -17,7 +17,7 @@ _MODELS = {
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
-    "PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
+    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "YiForCausalLM": ("yi", "YiForCausalLM"),
 }
 

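For context, a registry entry maps a HuggingFace architecture name to (module, class), and is presumably resolved along these lines (a sketch, not the actual loader code):

    import importlib

    module_name, cls_name = ("phi", "PhiForCausalLM")  # _MODELS["PhiForCausalLM"]
    module = importlib.import_module(f"aphrodite.modeling.models.{module_name}")
    model_cls = getattr(module, cls_name)

so this one-line change points the PhiForCausalLM architecture at a phi module instead of phi_1_5.
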
+ 23 - 18
aphrodite/modeling/models/decilm.py

@@ -31,24 +31,26 @@ from transformers import PretrainedConfig
 
 from aphrodite.modeling.layers.linear import LinearMethodBase
 from aphrodite.modeling.models.llama import LlamaForCausalLM
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
-
-DECI_LM_MODEL_NAME = "Deci/DeciLM-7b-instruct"
+from aphrodite.modeling.weight_utils import (default_weight_loader,
+                                             hf_model_weights_iterator)
 
 
 class DeciLMForCausalLM(LlamaForCausalLM):
     """
     Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct.
-    Based on the llama modeling code..
+    Based on the LLaMA implementation.
+
     The main difference is that DeciLM uses Variable Grouped Query Attention.
     The constant number of GQA heads in the decoder is overridden with a value
-    per layer. Usually, in the HuggingFace implementation, Instead of
-    `config.num_key_value_heads`, we use
-    `config.num_key_value_heads_per_layer[i]` which varies.
-    Currently, PagedAttention does not work well with variable GQA,
-    so we normalize the weights upon loading, and use uniform GQA with the max
-    value instead.
+    per layer.
+
+    Usually, in the HuggingFace implementation, instead of
+    "config.num_key_value_heads", we use
+    "config.num_key_value_heads_per_layer[i]" which varies.
+
+    Currently, PagedAttention does not work well with variable GQA, so we
+    normalize the weights upon loading, and use uniform GQA with the max value
+    instead.
     """
 
     def __init__(
@@ -78,28 +80,31 @@ class DeciLMForCausalLM(LlamaForCausalLM):
                 model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
-            if "rotary_emb.cos_cached" in name:
-                continue
-            if "rotary_emb.sin_cached" in name:
-                continue
 
             if "k_proj" in name or "v_proj" in name:
-                loaded_weight = self.degroup_weight(loaded_weight)
+                loaded_weight = self._degroup_weight(loaded_weight)
 
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
 
-    def degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
+    def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
         hidden_size = self.config.hidden_size
         head_size = self.config.hidden_size // self.config.num_attention_heads
         target_num_kv_heads = self.config.num_key_value_heads

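The renamed _degroup_weight is cut off here, but the idea it implements can be sketched as follows (an illustration under assumed shapes, not the actual method body): reshape each k_proj/v_proj weight into heads and repeat every head until the layer reaches the maximal KV head count, so PagedAttention can treat the model as uniform GQA.

    import torch

    def degroup_kv_weight(loaded_weight: torch.Tensor, head_size: int,
                          target_num_kv_heads: int) -> torch.Tensor:
        """Repeat KV heads so every layer ends up with the same head count."""
        hidden_size = loaded_weight.shape[1]
        num_kv_heads = loaded_weight.shape[0] // head_size
        n_repeats = target_num_kv_heads // num_kv_heads
        w = loaded_weight.view(num_kv_heads, head_size, hidden_size)
        w = torch.repeat_interleave(w, repeats=n_repeats, dim=0)
        return w.reshape(target_num_kv_heads * head_size, hidden_size)
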
+ 8 - 14
aphrodite/modeling/models/gpt_j.py

@@ -16,11 +16,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only GPT-J model compatible with HuggingFace weights.
-
-The input of the model is flattened to a 1D tensor of tokens. The model uses
-InputMetadata to extract the original 2D shape of the input.
-"""
+"""Inference-only GPT-J model compatible with HuggingFace weights."""
 from typing import List, Optional, Tuple
 
 import torch
@@ -148,10 +144,8 @@ class GPTJBlock(nn.Module):
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
-        if config.n_inner is None:
-            inner_dim = 4 * config.n_embd
-        else:
-            inner_dim = config.n_inner
+        inner_dim = (4 * config.n_embd
+                     if config.n_inner is None else config.n_inner)
         self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
         self.attn = GPTJAttention(config, linear_method)
         self.mlp = GPTJMLP(inner_dim, config, linear_method)
@@ -248,7 +242,7 @@ class GPTJForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata, self.lm_head.bias)
         return next_tokens
@@ -275,16 +269,16 @@ class GPTJForCausalLM(nn.Module):
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
-                # skip loading extra bias for GPTQ models
-                if name.endswith("bias") and name not in params_dict:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # skip loading extra bias for GPTQ models
-                if name.endswith("bias") and name not in params_dict:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",

+ 2 - 7
aphrodite/modeling/models/gpt_neox.py

@@ -16,11 +16,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only GPT-NeoX model compatible with HuggingFace weights.
-
-The input of the model is flattened to a 1D tensor of tokens. The model uses
-InputMetadata to extract the original 2D shape of the input.
-"""
+"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
 from typing import List, Optional, Tuple
 
 import torch
@@ -80,7 +76,6 @@ class GPTNeoXAttention(nn.Module):
             bias=self.bias,
             linear_method=linear_method,
         )
-
         scaling = self.head_size**-0.5
         rotary_dim = int(self.head_size * config.rotary_pct)
         assert rotary_dim % 2 == 0
@@ -261,7 +256,7 @@ class GPTNeoXForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.embed_out.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens

+ 12 - 14
aphrodite/modeling/models/llama.py

@@ -21,11 +21,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only LLaMA model compatible with HuggingFace weights.
-
-The input of the model is flattened to a 1D tensor of tokens. The model uses
-InputMetadata to extract the original 2D shape of the input.
-"""
+"""Inference-only LLaMA model compatible with HuggingFace weights."""
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
@@ -133,6 +129,7 @@ class LlamaAttention(nn.Module):
             bias=False,
             linear_method=linear_method,
         )
+
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
@@ -286,7 +283,7 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata)
         return hidden_states
@@ -295,7 +292,7 @@ class LlamaForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
@@ -318,24 +315,25 @@ class LlamaForCausalLM(nn.Module):
                 model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
-            if "rotary_emb.cos_cached" in name:
-                continue
-            if "rotary_emb.sin_cached" in name:
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
-                # skip loading extra bias for GPTQ models
-                if name.endswith("bias") and name not in params_dict:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # skip loading extra bias for GPTQ models
-                if name.endswith("bias") and name not in params_dict:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",

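To make the loading loop easier to follow: stacked_params_mapping (defined just above this hunk) pairs each fused parameter in the model with the per-projection names used in HuggingFace checkpoints; the entries below are shown as an illustration rather than quoted from the file.

    stacked_params_mapping = [
        # (fused param in this model, weight name in the checkpoint, shard id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    # A checkpoint tensor named "model.layers.0.self_attn.k_proj.weight" matches
    # ("qkv_proj", "k_proj", "k"), so the loop rewrites the name to
    # "model.layers.0.self_attn.qkv_proj.weight" and calls that parameter's
    # weight_loader with shard_id="k". The ".bias" guard skips bias entries that
    # GPTQ checkpoints carry but the fused parameter does not define.
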
+ 1 - 1
aphrodite/modeling/models/mistral.py

@@ -288,7 +288,7 @@ class MistralForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens

+ 21 - 48
aphrodite/modeling/models/mixtral.py

@@ -22,13 +22,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Mixtral model."""
-from typing import List, Optional, Tuple, Dict, Any
+from typing import List, Optional, Tuple
+
+import numpy as np
 
 import torch
-from torch import nn
 import torch.nn.functional as F
+
+from torch import nn
 from transformers import MixtralConfig
-import numpy as np
 
 from aphrodite.modeling.metadata import InputMetadata
 from aphrodite.modeling.layers.attention import PagedAttention
@@ -37,7 +39,6 @@ from aphrodite.modeling.layers.linear import (LinearMethodBase,
                                               ReplicatedLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
-from aphrodite.modeling.layers.moe import MoE
 from aphrodite.modeling.layers.rotary_embedding import get_rope
 from aphrodite.modeling.layers.sampler import Sampler
 from aphrodite.modeling.layers.vocab_parallel_embedding import (
@@ -81,7 +82,7 @@ class MixtralMLP(nn.Module):
                                    bias=False,
                                    linear_method=linear_method)
 
-        # TODO: Use aphrodite's SiluAndMul
+        # TODO: Use vllm's SiluAndMul
         self.act_fn = nn.SiLU()
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -168,7 +169,6 @@ class MixtralAttention(nn.Module):
                  num_kv_heads: int,
                  max_position: int = 4096 * 32,
                  rope_theta: float = 10000,
-                 rope_scaling: Optional[Dict[str, Any]] = None,
                  linear_method: Optional[LinearMethodBase] = None,
                  sliding_window: Optional[int] = None) -> None:
         super().__init__()
@@ -191,7 +191,6 @@ class MixtralAttention(nn.Module):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
-        self.rope_scaling = rope_scaling
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
 
@@ -215,7 +214,6 @@ class MixtralAttention(nn.Module):
             max_position=max_position,
             base=int(self.rope_theta),
             is_neox_style=True,
-            rope_scaling=rope_scaling,
         )
         self.attn = PagedAttention(
             self.num_heads,
@@ -252,25 +250,16 @@ class MixtralDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
         self.self_attn = MixtralAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
             sliding_window=config.sliding_window,
             linear_method=linear_method)
-        if linear_method is None:
-            self.block_sparse_moe = MoE(
-                num_experts=config.num_local_experts,
-                top_k=config.num_experts_per_tok,
-                hidden_size=config.hidden_size,
-                intermediate_size=config.intermediate_size)
-        else:
-            self.block_sparse_moe = MixtralMoE(config,
-                                               linear_method=linear_method)
+        self.block_sparse_moe = MixtralMoE(config=config,
+                                           linear_method=linear_method)
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(config.hidden_size,
@@ -332,7 +321,7 @@ class MixtralModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
@@ -373,7 +362,7 @@ class MixtralForCausalLM(nn.Module):
         self,
         hidden_states: Optional[torch.Tensor],
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
@@ -390,13 +379,6 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]
 
-        expert_params_mapping = [
-            # (param_name, weight_name, expert_id)
-            (f"{weight_name}s", f"experts.{expert_id}.{weight_name}.weight",
-             expert_id) for expert_id in range(self.config.num_local_experts)
-            for weight_name in ["w1", "w2", "w3"]
-        ]
-
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path,
@@ -418,23 +400,14 @@ class MixtralForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for param_name, weight_name, expert_id in expert_params_mapping:
-                    if weight_name not in name:
-                        continue
-                    name = name.replace(weight_name, param_name)
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, expert_id=expert_id)
-                    break
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    # Skip experts that are not assigned to this worker.
-                    if ("block_sparse_moe.experts." in name
-                            and name not in params_dict):
-                        continue
-                    param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(param, loaded_weight)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip experts that are not assigned to this worker.
+                if ("block_sparse_moe.experts." in name
+                        and name not in params_dict):
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)

+ 0 - 269
aphrodite/modeling/models/phi1_5.py

@@ -1,269 +0,0 @@
-from typing import List, Optional, Tuple
-
-import torch
-from torch import nn
-from transformers import PretrainedConfig
-
-from aphrodite.modeling.metadata import InputMetadata
-from aphrodite.modeling.layers.activation import get_act_fn
-from aphrodite.modeling.layers.attention import PagedAttention
-from aphrodite.modeling.layers.linear import (ColumnParallelLinear,
-                                              RowParallelLinear,
-                                              QKVParallelLinear,
-                                              LinearMethodBase)
-from aphrodite.modeling.layers.rotary_embedding import get_rope
-from aphrodite.modeling.layers.sampler import Sampler
-from aphrodite.modeling.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding, ParallelLMHead)
-from aphrodite.modeling.megatron.parallel_state import (
-    get_tensor_model_parallel_world_size)
-from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.modeling.hf_downloader import (default_weight_loader,
-                                              hf_model_weights_iterator)
-from aphrodite.common.sequence import SamplerOutput
-
-KVCache = Tuple[torch.Tensor, torch.Tensor]
-
-
-class PhiEmbedding(nn.Module):
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-
-        self.wte = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
-
-    def forward(self, input_ids: torch.LongTensor):
-        return self.wte(input_ids)
-
-
-class PhiAttention(nn.Module):
-
-    def __init__(self,
-                 config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
-        super().__init__()
-        self.total_num_heads = config.num_attention_heads
-        self.hidden_size = config.hidden_size
-        self.head_size = self.hidden_size // self.total_num_heads
-
-        tensor_model_parallel_world_size = (
-            get_tensor_model_parallel_world_size())
-        assert self.total_num_heads % tensor_model_parallel_world_size == 0
-        self.num_heads = (self.total_num_heads //
-                          tensor_model_parallel_world_size)
-
-        # pylint: disable=C0103
-        self.Wqkv = QKVParallelLinear(
-            self.hidden_size,
-            self.head_size,
-            self.total_num_heads,
-            linear_method=linear_method,
-        )
-        self.qkv_proj = QKVParallelLinear(
-            config.hidden_size,
-            self.head_size,
-            self.total_num_heads,
-            bias=False,
-            linear_method=linear_method,
-        )
-        self.out_proj = RowParallelLinear(
-            self.hidden_size,
-            self.hidden_size,
-            linear_method=linear_method,
-        )
-
-        scaling = self.head_size**-0.5
-        rotary_dim = config.rotary_dim
-        assert rotary_dim % 2 == 0
-
-        # pylint: disable=C0301
-        # Refer to:
-        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
-        rope_theta = 10000
-        max_position_embeddings = getattr(config, "n_positions", 2048)
-        self.rotary_emb = get_rope(
-            self.head_size,
-            rotary_dim=rotary_dim,
-            max_position=max_position_embeddings,
-            base=rope_theta,
-        )
-        self.attn = PagedAttention(self.num_heads, self.head_size, scaling)
-
-    def forward(
-        self,
-        position_ids: torch.Tensor,
-        hidden_states: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        qkv, _ = self.Wqkv(hidden_states)
-        q, k, v = qkv.chunk(chunks=3, dim=-1)
-        q, k = self.rotary_emb(position_ids, q, k)
-        k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
-        output, _ = self.out_proj(attn_output)
-        return output
-
-
-class PhiMLP(nn.Module):
-
-    def __init__(self,
-                 config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
-        super().__init__()
-
-        n_inner = getattr(config, "n_inner", None)
-        n_inner = n_inner if n_inner is not None else 4 * config.hidden_size
-
-        self.fc1 = ColumnParallelLinear(
-            config.hidden_size,
-            n_inner,
-            linear_method=linear_method,
-        )
-        self.fc2 = RowParallelLinear(
-            n_inner,
-            config.hidden_size,
-            linear_method=linear_method,
-        )
-        quant_config = getattr(linear_method, "quant_config", None)
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              n_inner)
-
-    def forward(self, hidden_states):
-        hidden_states, _ = self.fc1(hidden_states)
-        hidden_states = self.act(hidden_states)
-        hidden_states, _ = self.fc2(hidden_states)
-        return hidden_states
-
-
-class PhiLayer(nn.Module):
-
-    def __init__(self,
-                 config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
-        super().__init__()
-        self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_epsilon)
-        self.mixer = PhiAttention(config, linear_method)
-        self.mlp = PhiMLP(config, linear_method)
-
-    def forward(
-        self,
-        position_ids: torch.Tensor,
-        hidden_states: torch.Tensor,
-        kv_cache: KVCache,
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        residual = hidden_states
-        hidden_states = self.ln(hidden_states)
-        attn_outputs = self.mixer(
-            position_ids=position_ids,
-            hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            input_metadata=input_metadata,
-        )
-        feed_forward_hidden_states = self.mlp(hidden_states)
-        hidden_states = attn_outputs + feed_forward_hidden_states + residual
-        return hidden_states
-
-
-class PhiModel(nn.Module):
-
-    def __init__(self,
-                 config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
-        super().__init__()
-        self.config = config
-        self.linear_method = linear_method
-        self.embd = PhiEmbedding(config)
-        self.h = nn.ModuleList([
-            PhiLayer(config, linear_method)
-            for _ in range(config.num_hidden_layers)
-        ])
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        hidden_states = self.embd(input_ids)
-        for i in range(self.config.num_hidden_layers):
-            layer = self.h[i]
-            hidden_states = layer(
-                positions,
-                hidden_states,
-                kv_caches[i],
-                input_metadata,
-            )
-        return hidden_states
-
-
-class PhiCausalLMHead(nn.Module):
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-        self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_epsilon)
-        self.linear = ParallelLMHead(config.vocab_size,
-                                     config.hidden_size,
-                                     bias=True)
-
-
-class PhiForCausalLM(nn.Module):
-
-    def __init__(self,
-                 config: PretrainedConfig,
-                 linear_method: Optional[LinearMethodBase] = None):
-        super().__init__()
-        self.config = config
-        self.linear_method = linear_method
-
-        self.transformer = PhiModel(config, linear_method)
-        self.lm_head = PhiCausalLMHead(config)
-        self.sampler = Sampler(config.vocab_size)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: List[KVCache],
-        input_metadata: InputMetadata,
-    ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         input_metadata)
-        hidden_states = self.lm_head.ln(hidden_states)
-        return hidden_states
-
-    def sample(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
-        head = self.lm_head.linear
-        next_tokens = self.sampler(head.weight, hidden_states,
-                                   sampling_metadata, head.bias)
-        return next_tokens
-
-    def load_weights(self,
-                     model_name_or_path: str,
-                     cache_dir: Optional[str] = None,
-                     load_format: str = "auto",
-                     revision: Optional[str] = None):
-        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format, revision):
-            if "rotary_emb.inv_freq" in name:
-                continue
-
-            # skip loading extra bias for GPTQ models
-            if name.endswith("bias") and name not in params_dict:
-                continue
-            param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
-            weight_loader(param, loaded_weight)

+ 19 - 20
aphrodite/modeling/models/yi.py

@@ -21,11 +21,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only Yi model (https://01.ai) compatible with HuggingFace weights.
-
-The input of the model is flattened to a 1D tensor of tokens. The model uses
-InputMetadata to extract the original 2D shape of the input.
-"""
+"""Inference-only Yi model compatible with HuggingFace weights."""
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
@@ -198,24 +194,25 @@ class YiDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-    ) -> torch.Tensor:
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
-        residual = hidden_states
-        hidden_states = self.ln1(hidden_states)
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.ln1(hidden_states)
+        else:
+            hidden_states, residual = self.ln1(hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
         )
-        hidden_states = residual + hidden_states
 
         # Fully Connected
-        residual = hidden_states
-        hidden_states = self.ln2(hidden_states)
+        hidden_states, residual = self.ln2(hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
+        return hidden_states, residual
 
 
 class YiModel(nn.Module):
@@ -247,15 +244,17 @@ class YiModel(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
+        residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states = layer(
+            hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
+                residual,
             )
-        hidden_states = self.norm(hidden_states)
+        hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
@@ -288,7 +287,7 @@ class YiForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
@@ -315,16 +314,16 @@ class YiForCausalLM(nn.Module):
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
-                # skip loading extra bias for GPTQ models
-                if name.endswith("bias") and name not in params_dict:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # skip loading extra bias for GPTQ models
-                if name.endswith("bias") and name not in params_dict:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",

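The Yi layers now thread the residual through each block instead of adding it explicitly. This relies on an RMSNorm whose two-argument form is assumed to return (norm(x + residual), x + residual) in one fused kernel, which removes a separate elementwise add per sublayer. Condensed, the pattern is:

    residual = None
    for i, layer in enumerate(self.layers):
        # Each layer returns (hidden_states, residual); the residual addition
        # happens inside the fused norm rather than as a standalone op.
        hidden_states, residual = layer(positions, hidden_states, kv_caches[i],
                                        input_metadata, residual)
    hidden_states, _ = self.norm(hidden_states, residual)
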
+ 15 - 14
aphrodite/modeling/sampling_metadata.py

@@ -29,23 +29,28 @@ class OutputMetadata(PersistentMetadata):
 
 class SamplingMetadata:
     """Metadata for input sequences. Used in sampler.
+
     Args:
         seq_groups: List of (seq_ids, sampling_params).
         seq_data: Seq_id -> SequenceData.
         prompt_lens: Lengths of prompts.
         selected_token_indices: Token indices selected for sampling.
-        categorized_sample_indices: SamplingType -> token indicies to sample.
+        categorized_sample_indices: SamplingType -> token indices to sample.
+        perform_sampling: Whether to perform sampling. This option is used to
+            restrict sampling to the driver worker and disable it in the other
+            worker processes.
         persistent_metadata: Metadata that persists across iterations.
-        output_metadata: the metadata of the output.
+        output_metadata: The output metadata.
     """
 
     def __init__(
         self,
-        seq_groups: List[Tuple[List[int], SamplingParams]],
-        seq_data: Dict[int, SequenceData],
-        prompt_lens: List[int],
+        seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
+        seq_data: Optional[Dict[int, SequenceData]],
+        prompt_lens: Optional[List[int]],
         selected_token_indices: torch.Tensor,
-        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
+        categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
+        perform_sampling: bool = True,
         persistent_metadata: Optional[PersistentMetadata] = None,
         output_metadata: Optional[OutputMetadata] = None,
     ) -> None:
@@ -54,10 +59,11 @@ class SamplingMetadata:
         self.prompt_lens = prompt_lens
         self.selected_token_indices = selected_token_indices
         self.categorized_sample_indices = categorized_sample_indices
+        self.perform_sampling = perform_sampling
         self.persistent_metadata = persistent_metadata or PersistentMetadata()
         self.output_metadata = output_metadata or OutputMetadata()
 
-        self.num_prompts = len(prompt_lens)
+        self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
 
     def __repr__(self) -> str:
         return (
@@ -67,8 +73,9 @@ class SamplingMetadata:
             f"prompt_lens={self.prompt_lens}, "
             f"selected_token_indices={self.selected_token_indices}, "
             f"categorized_sample_indices={self.categorized_sample_indices}, "
+            f"perform_sampling={self.perform_sampling}, "
             f"persistent_metadata={self.persistent_metadata}, "
-            f"output_metadata={self.output_metadata})")
+            f"output_metadata={self.output_metadata})")
 
 
 @dataclass
@@ -181,12 +188,6 @@ class SamplingTensors:
             if do_mirostat is False and sampling_params.mirostat_mode == 2:
                 do_mirostat = True
 
-            # if not do_alphabet_soup and (top_p < 1.0 - _SAMPLING_EPS
-            #                              or top_k != vocab_size
-            #                              or top_a > 0.0
-            #                              or min_p > _SAMPLING_EPS):
-            #     do_alphabet_soup = True
-
             if (i < sampling_metadata.num_prompts
                     and sampling_params.prompt_logprobs is not None):
                 # For tokens in the prompt that we only need to get their

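A minimal sketch of what the new perform_sampling flag enables (all field values are placeholders): a non-driver worker can run the same code path with a stripped-down SamplingMetadata and skip token selection, which matches the sample() return type changing to Optional[SamplerOutput] throughout this diff.

    import torch
    from aphrodite.modeling import SamplingMetadata

    worker_metadata = SamplingMetadata(
        seq_groups=None,
        seq_data=None,
        prompt_lens=None,
        selected_token_indices=torch.tensor([0], dtype=torch.long, device="cuda"),
        categorized_sample_indices=None,
        perform_sampling=False,  # only the driver worker actually samples
    )
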
+ 47 - 13
aphrodite/processing/block_manager.py

@@ -2,13 +2,10 @@
 import enum
 from typing import Dict, List, Optional, Set, Tuple
 
-from aphrodite.common.block import PhysicalTokenBlock
+from aphrodite.common.block import BlockTable, PhysicalTokenBlock
 from aphrodite.common.sequence import Sequence, SequenceGroup, SequenceStatus
 from aphrodite.common.utils import Device
 
-# Mapping: logical block number -> physical block.
-BlockTable = List[PhysicalTokenBlock]
-
 
 class BlockAllocator:
     """Manages free physical token blocks for a device.
@@ -55,14 +52,14 @@ class BlockAllocator:
 
 
 class AllocStatus(enum.Enum):
-    """Result for BlockSpaceManager.can_allocate().
-    1. OK: seq_grouop can be allocated.
+    """Result for BlockSpaceManager.can_allocate().
+
+    1. OK: seq_group can be allocated now.
     2. LATER: seq_group cannot be allocated.
-        The capacity of the allocator is larger
-        than seq_group required.
+      The allocator's capacity is larger than what the seq_group requires.
     3. NEVER: seq_group can never be allocated.
-        The seq_group is too large to allocate in
-        GPU."""
+      The seq_group is too large to be allocated in the GPU.
+    """
     OK = enum.auto()
     LATER = enum.auto()
     NEVER = enum.auto()
@@ -105,10 +102,13 @@ class BlockSpaceManager:
         # the same prompt. This may not be true for preempted sequences.
         seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
         num_required_blocks = len(seq.logical_token_blocks)
+        if seq_group.prefix is not None and seq_group.prefix.allocated:
+            num_required_blocks -= seq_group.prefix.get_num_blocks()
         if self.block_sliding_window is not None:
             num_required_blocks = min(num_required_blocks,
                                       self.block_sliding_window)
         num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
+
         # Use watermark to avoid frequent cache eviction.
         if (self.num_total_gpu_blocks - num_required_blocks <
                 self.watermark_blocks):
@@ -124,8 +124,20 @@ class BlockSpaceManager:
         seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
 
         # Allocate new physical token blocks that will store the prompt tokens.
+        num_prompt_blocks = len(seq.logical_token_blocks)
         block_table: BlockTable = []
-        for logical_idx in range(len(seq.logical_token_blocks)):
+        prefix_block_table: BlockTable = []
+        num_prefix_blocks = 0
+
+        prefix = seq_group.prefix
+        if prefix is not None and prefix.allocated:
+            # Prefix has already been allocated. Use the existing block table.
+            num_prompt_blocks -= prefix.get_num_blocks()
+            for block in prefix.block_table:
+                block.ref_count += seq_group.num_seqs()
+                block_table.append(block)
+
+        for logical_idx in range(num_prompt_blocks):
             if (self.block_sliding_window is not None
                     and logical_idx >= self.block_sliding_window):
                 block = block_table[logical_idx % self.block_sliding_window]
@@ -135,6 +147,15 @@ class BlockSpaceManager:
             block.ref_count = seq_group.num_seqs()
             block_table.append(block)
 
+        if prefix is not None and not prefix.allocated:
+            # Allocate blocks for the prefix, we will compute the prefix's
+            # KV cache in this run.
+            num_prefix_blocks = prefix.get_num_blocks()
+            prefix_block_table = block_table[:num_prefix_blocks]
+            for block in prefix_block_table:
+                block.ref_count += 1
+            prefix.set_block_table(prefix_block_table)
+
         # Assign the block table for each sequence.
         for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
             self.block_tables[seq.seq_id] = block_table.copy()
@@ -152,13 +173,14 @@ class BlockSpaceManager:
         block_table = self.block_tables[seq.seq_id]
 
         if len(block_table) < len(logical_blocks):
-            # The sequence has a new logical block.
-            # Allocate a new physical block.
             if (self.block_sliding_window
                     and len(block_table) >= self.block_sliding_window):
+                # re-use a block
                 block_table.append(block_table[len(block_table) %
                                                self.block_sliding_window])
             else:
+                # The sequence has a new logical block.
+                # Allocate a new physical block.
                 block = self.gpu_allocator.allocate()
                 block_table.append(block)
                 return None
@@ -208,10 +230,17 @@ class BlockSpaceManager:
 
     def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # CPU block -> GPU block.
+        if seq_group.prefix is not None:
+            # make sure to swap in the prefix first
+            assert seq_group.prefix.allocated and seq_group.prefix.computed
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
+            if seq_group.prefix is not None:
+                for block in seq_group.prefix.block_table:
+                    new_block_table.append(block)
+                    block.ref_count += 1
 
             for cpu_block in block_table:
                 if cpu_block in mapping:
@@ -243,6 +272,11 @@ class BlockSpaceManager:
             block_table = self.block_tables[seq.seq_id]
 
             for gpu_block in block_table:
+                if (seq_group.prefix is not None
+                        and gpu_block in seq_group.prefix.block_table):
+                    # NOTE: We do not swap out the prefix blocks for now.
+                    self.gpu_allocator.free(gpu_block)
+                    continue
                 if gpu_block in mapping:
                     cpu_block = mapping[gpu_block]
                     cpu_block.ref_count += 1

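A quick worked example of the accounting above (numbers are illustrative): with block_size = 16, a shared 96-token prefix spans 6 physical blocks. A new group whose 160-token prompt starts with that prefix has 10 logical blocks, so can_allocate only needs to find 10 - 6 = 4 free GPU blocks, and allocate() bumps ref_count on the 6 resident prefix blocks instead of allocating them again.

    block_size = 16
    prefix_tokens, prompt_tokens = 96, 160           # illustrative values
    prefix_blocks = prefix_tokens // block_size      # 6 blocks already resident
    new_blocks_needed = prompt_tokens // block_size - prefix_blocks  # 10 - 6 = 4
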
+ 10 - 8
aphrodite/processing/policy.py

@@ -1,4 +1,5 @@
-from typing import List
+from collections import deque
+from typing import Deque
 
 from aphrodite.common.sequence import SequenceGroup
 
@@ -15,13 +16,14 @@ class Policy:
     def sort_by_priority(
         self,
         now: float,
-        seq_groups: List[SequenceGroup],
-    ) -> List[SequenceGroup]:
-        return sorted(
-            seq_groups,
-            key=lambda seq_group: self.get_priority(now, seq_group),
-            reverse=True,
-        )
+        seq_groups: Deque[SequenceGroup],
+    ) -> Deque[SequenceGroup]:
+        return deque(
+            sorted(
+                seq_groups,
+                key=lambda seq_group: self.get_priority(now, seq_group),
+                reverse=True,
+            ))
 
 
 class FCFS(Policy):

+ 51 - 33
aphrodite/processing/scheduler.py

@@ -1,6 +1,7 @@
+from collections import deque
 import enum
 import time
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union
 
 from aphrodite.common.config import CacheConfig, SchedulerConfig
 from aphrodite.processing.block_manager import AllocStatus, BlockSpaceManager
@@ -8,6 +9,7 @@ from aphrodite.processing.policy import PolicyFactory
 from aphrodite.common.logger import init_logger
 from aphrodite.common.sequence import (Sequence, SequenceData, SequenceGroup,
                                        SequenceGroupMetadata, SequenceStatus)
+from aphrodite.common.prefix import PrefixPool
 
 logger = init_logger(__name__)
 
@@ -29,7 +31,7 @@ class SchedulerOutputs:
 
     def __init__(
         self,
-        scheduled_seq_groups: List[SequenceGroup],
+        scheduled_seq_groups: Iterable[SequenceGroup],
         prompt_run: bool,
         num_batched_tokens: int,
         blocks_to_swap_in: Dict[int, int],
@@ -75,38 +77,55 @@ class Scheduler:
             num_cpu_blocks=self.cache_config.num_cpu_blocks,
             sliding_window=self.cache_config.sliding_window)
 
-        # TODO: Use deque instead of list for better performance.
+        # Create the prefix pool to cache the prefixes.
+        self.prefix_pool = PrefixPool(self.cache_config.block_size)
+
         # Sequence groups in the WAITING state.
-        self.waiting: List[SequenceGroup] = []
+        self.waiting: Deque[SequenceGroup] = deque()
         # Sequence groups in the RUNNING state.
-        self.running: List[SequenceGroup] = []
+        self.running: Deque[SequenceGroup] = deque()
         # Sequence groups in the SWAPPED state.
-        self.swapped: List[SequenceGroup] = []
+        self.swapped: Deque[SequenceGroup] = deque()
 
     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
         self.waiting.append(seq_group)
 
     def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
+        """Aborts a sequence group with the given ID.
+
+        Check whether a sequence group with the given ID is present in any of
+        the state queues. If so, remove it from that queue and free every
+        unfinished sequence in it with status `FINISHED_ABORTED`. Otherwise,
+        do nothing.
+
+        Args:
+            request_id: The ID(s) of the sequence group to abort.
+        """
         if isinstance(request_id, str):
             request_id = (request_id, )
         request_ids = set(request_id)
         for state_queue in [self.waiting, self.running, self.swapped]:
-            # We need to reverse the list as we are removing elements
-            # from it as we iterate over it. If we don't do it,
-            # indices will get messed up and we will skip over elements.
-            for seq_group in reversed(state_queue):
+            aborted_groups: List[SequenceGroup] = []
+            for seq_group in state_queue:
+                if not request_ids:
+                    # Using 'break' here may add two extra iterations,
+                    # but is acceptable to reduce complexity.
+                    break
                 if seq_group.request_id in request_ids:
-                    # Remove the sequence group from the state queue.
-                    state_queue.remove(seq_group)
-                    for seq in seq_group.get_seqs():
-                        if seq.is_finished():
-                            continue
-                        seq.status = SequenceStatus.FINISHED_ABORTED
-                        self.free_seq(seq)
+                    # Appending aborted group into pending list.
+                    aborted_groups.append(seq_group)
                     request_ids.remove(seq_group.request_id)
-                    if not request_ids:
-                        return
+            for aborted_group in aborted_groups:
+                # Remove the sequence group from the state queue.
+                state_queue.remove(aborted_group)
+                for seq in aborted_group.get_seqs():
+                    if seq.is_finished():
+                        continue
+                    seq.status = SequenceStatus.FINISHED_ABORTED
+                    self.free_seq(seq)
 
     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped
@@ -152,7 +171,7 @@ class Scheduler:
                     for seq in waiting_seqs:
                         seq.status = SequenceStatus.FINISHED_IGNORED
                     ignored_seq_groups.append(seq_group)
-                    self.waiting.pop(0)
+                    self.waiting.popleft()
                     continue
 
                 # If the sequence group cannot be allocated, stop.
@@ -162,11 +181,11 @@ class Scheduler:
                 elif can_allocate == AllocStatus.NEVER:
                     logger.warning(
                         f"Input prompt ({num_prompt_tokens} tokens) is too long"
-                        f" and exceeds the capacity of the block manager.")
+                        f" and exceeds the capacity of block_manager")
                     for seq in waiting_seqs:
                         seq.status = SequenceStatus.FINISHED_IGNORED
                     ignored_seq_groups.append(seq_group)
-                    self.waiting.pop(0)
+                    self.waiting.popleft()
                     continue
 
                 # If the number of batched tokens exceeds the limit, stop.
@@ -188,7 +207,7 @@ class Scheduler:
                     break
                 seq_lens = new_seq_lens
 
-                seq_group = self.waiting.pop(0)
+                seq_group = self.waiting.popleft()
                 self._allocate(seq_group)
                 self.running.append(seq_group)
                 num_curr_seqs += num_new_seqs
@@ -214,14 +233,14 @@ class Scheduler:
         self.running = self.policy.sort_by_priority(now, self.running)
 
         # Reserve new token slots for the running sequence groups.
-        running: List[SequenceGroup] = []
+        running: Deque[SequenceGroup] = deque()
         preempted: List[SequenceGroup] = []
         while self.running:
-            seq_group = self.running.pop(0)
+            seq_group = self.running.popleft()
             while not self.block_manager.can_append_slot(seq_group):
                 if self.running:
                     # Preempt the lowest-priority sequence groups.
-                    victim_seq_group = self.running.pop(-1)
+                    victim_seq_group = self.running.pop()
                     self._preempt(victim_seq_group, blocks_to_swap_out)
                     preempted.append(victim_seq_group)
                 else:
@@ -255,7 +274,7 @@ class Scheduler:
                         self.scheduler_config.max_num_seqs):
                     break
 
-                seq_group = self.swapped.pop(0)
+                seq_group = self.swapped.popleft()
                 self._swap_in(seq_group, blocks_to_swap_in)
                 self._append_slot(seq_group, blocks_to_copy)
                 num_curr_seqs += num_new_seqs
@@ -304,6 +323,7 @@ class Scheduler:
                 sampling_params=seq_group.sampling_params,
                 block_tables=block_tables,
                 persistent_data=persistent_data,
+                prefix=seq_group.prefix,
             )
             seq_group_metadata_list.append(seq_group_metadata)
         return seq_group_metadata_list, scheduler_outputs
@@ -315,10 +335,8 @@ class Scheduler:
         self.block_manager.free(seq)
 
     def free_finished_seq_groups(self) -> None:
-        self.running = [
-            seq_group for seq_group in self.running
-            if not seq_group.is_finished()
-        ]
+        self.running = deque(seq_group for seq_group in self.running
+                             if not seq_group.is_finished())
 
     def _allocate(self, seq_group: SequenceGroup) -> None:
         self.block_manager.allocate(seq_group)
@@ -366,7 +384,7 @@ class Scheduler:
         elif preemption_mode == PreemptionMode.SWAP:
             self._preempt_by_swap(seq_group, blocks_to_swap_out)
         else:
-            assert False, "Invalid preemption mode."
+            raise AssertionError("Invalid preemption mode.")
 
     def _preempt_by_recompute(
         self,
@@ -379,7 +397,7 @@ class Scheduler:
             self.block_manager.free(seq)
         # NOTE: For FCFS, we insert the preempted sequence group to the front
         # of the waiting queue.
-        self.waiting.insert(0, seq_group)
+        self.waiting.appendleft(seq_group)
 
     def _preempt_by_swap(
         self,

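The scheduler hunks above all make the same change: the waiting, running, and swapped queues switch from plain lists to collections.deque, so popleft() and appendleft() replace the O(n) list.pop(0) and list.insert(0, ...). Below is a minimal sketch of that queue discipline, using strings as stand-ins for SequenceGroup objects (illustrative only, not the engine's real types):

from collections import deque

# "job-N" strings stand in for SequenceGroup objects.
waiting = deque(["job-0", "job-1", "job-2"])
running = deque()

# Admit groups in FCFS order: popleft() is O(1), unlike list.pop(0).
while waiting:
    running.append(waiting.popleft())

# Preempt the lowest-priority group (tail of the running queue) and put it
# back at the *front* of the waiting queue, as _preempt_by_recompute does.
victim = running.pop()
waiting.appendleft(victim)

print(list(waiting), list(running))  # ['job-2'] ['job-0', 'job-1']

free_finished_seq_groups keeps the same idea by rebuilding the deque from a generator rather than a list comprehension.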
+ 3 - 7
aphrodite/task_handler/cache_engine.py

@@ -1,4 +1,4 @@
-"""CacheEngine for managing the KV cache"""
+"""CacheEngine class for managing the KV cache."""
 from typing import Dict, List, Tuple
 
 import torch
@@ -34,8 +34,7 @@ class CacheEngine:
         self.head_size = model_config.get_head_size()
         self.num_layers = model_config.get_num_layers(parallel_config)
         self.num_heads = model_config.get_num_kv_heads(parallel_config)
-        self.dtype = (cache_config.cache_dtype
-                      if cache_config.cache_dtype else model_config.dtype)
+        self.dtype = model_config.dtype
 
         self.block_size = cache_config.block_size
         self.num_gpu_blocks = cache_config.num_gpu_blocks
@@ -143,7 +142,6 @@ class CacheEngine:
     @staticmethod
     def get_cache_block_size(
         block_size: int,
-        cache_dtype: torch.dtype,
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
     ) -> int:
@@ -154,9 +152,7 @@ class CacheEngine:
         key_cache_block = block_size * num_heads * head_size
         value_cache_block = key_cache_block
         total = num_layers * (key_cache_block + value_cache_block)
-        if cache_dtype is None:
-            cache_dtype = model_config.dtype
-        dtype_size = _get_dtype_size(cache_dtype)
+        dtype_size = _get_dtype_size(model_config.dtype)
         return dtype_size * total
 
 

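For reference, the remaining get_cache_block_size arithmetic is a straight product: a key block and a value block of block_size * num_heads * head_size elements per layer, scaled by the element size of the model dtype (the separate cache_dtype knob is removed here). A small sketch with made-up model dimensions follows; the numbers are illustrative, not read from any real config:

import torch

def cache_block_size_bytes(block_size: int, num_heads: int, head_size: int,
                           num_layers: int, dtype: torch.dtype) -> int:
    # Mirrors CacheEngine.get_cache_block_size: one key block plus one value
    # block per layer, scaled by the dtype's element size.
    key_cache_block = block_size * num_heads * head_size
    value_cache_block = key_cache_block
    total_elems = num_layers * (key_cache_block + value_cache_block)
    return torch.tensor([], dtype=dtype).element_size() * total_elems

# Illustrative 7B-ish dimensions (assumed, not taken from a real model).
print(cache_block_size_bytes(block_size=16, num_heads=32, head_size=128,
                             num_layers=32, dtype=torch.float16))
# 16 * 32 * 128 = 65536 elements per key block; x2 (K+V) x32 layers
# x2 bytes (fp16) = 8388608 bytes (8 MiB) per cache block.

Worker.profile_num_available_blocks (further down) then divides the free GPU budget, total_gpu_memory * gpu_memory_utilization - peak_memory, by this per-block size to obtain num_gpu_blocks.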
+ 155 - 48
aphrodite/task_handler/model_runner.py

@@ -1,5 +1,5 @@
 import time
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -9,6 +9,8 @@ from aphrodite.common.config import (ModelConfig, ParallelConfig,
                                      SchedulerConfig)
 from aphrodite.common.logger import init_logger
 from aphrodite.modeling import get_model, InputMetadata, SamplingMetadata
+from aphrodite.modeling.megatron.communication_op import (
+    broadcast_tensor_dict)
 from aphrodite.common.sampling_params import SamplingParams, SamplingType
 from aphrodite.common.sequence import (SamplerOutput, SequenceData,
                                        SequenceGroupMetadata)
@@ -31,10 +33,12 @@ class ModelRunner:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        is_driver_worker: bool = False,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.is_driver_worker = is_driver_worker
 
         # model_config can be None in tests/samplers/test_sampler.py.
         # FIXME: This is a hack to make the tests work. Refactor this.
@@ -73,13 +77,17 @@ class ModelRunner:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int],
+               List[int]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
 
         prompt_lens: List[int] = []
+        context_lens: List[int] = []
+        subquery_lens: List[int] = []
+        prefix_block_tables: List[List[int]] = []
         for seq_group_metadata in seq_group_metadata_list:
             assert seq_group_metadata.is_prompt
             seq_ids = list(seq_group_metadata.seq_data.keys())
@@ -90,11 +98,23 @@ class ModelRunner:
             prompt_tokens = seq_data.get_token_ids()
             prompt_len = len(prompt_tokens)
             prompt_lens.append(prompt_len)
+            prefix_len = 0
+            prefix = seq_group_metadata.prefix
+            if prefix is not None and prefix.computed:
+                prefix_len = prefix.get_length()
+                prompt_tokens = prompt_tokens[prefix_len:]
+                prefix_block_tables.append(prefix.get_block_numbers())
+            else:
+                prefix_block_tables.append([])
+            # Record the cached prefix length and the uncached prompt length.
+            context_lens.append(prefix_len)
+            subquery_lens.append(prompt_len - prefix_len)
 
             input_tokens.append(prompt_tokens)
             # NOTE: Here we assume that the first token in the prompt
             # is always the first token in the sequence.
-            input_positions.append(list(range(prompt_len)))
+            input_positions.append(
+                list(range(prefix_len, prefix_len + len(prompt_tokens))))
 
             if seq_group_metadata.block_tables is None:
                 # During memory profiling, the block tables are not initialized
@@ -112,8 +132,11 @@ class ModelRunner:
             # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
             start_idx = 0
             if self.sliding_window is not None:
+                assert prefix_len == 0, (
+                    "Prefix caching is currently not supported with "
+                    "sliding window attention")
                 start_idx = max(0, prompt_len - self.sliding_window)
-            for i in range(prompt_len):
+            for i in range(prefix_len, prompt_len):
                 if i < start_idx:
                     slot_mapping[-1].append(_PAD_SLOT_ID)
                     continue
@@ -123,7 +146,7 @@ class ModelRunner:
                 slot = block_number * self.block_size + block_offset
                 slot_mapping[-1].append(slot)
 
-        max_prompt_len = max(prompt_lens)
+        max_prompt_len = max(subquery_lens)
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_prompt_len,
                                              pad=0,
@@ -136,16 +159,39 @@ class ModelRunner:
                                              max_prompt_len,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long)
+        context_lens_tensor = torch.tensor(context_lens,
+                                           dtype=torch.int,
+                                           device="cuda")
+        # Prepare prefix block tables
+        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
+        block_tables = _make_tensor_with_pad(
+            prefix_block_tables,
+            max_len=max_prompt_block_table_len,
+            pad=0,
+            dtype=torch.int,
+        )
+        start_loc_tensor = torch.arange(0,
+                                        len(prompt_lens) * max_prompt_len,
+                                        max_prompt_len,
+                                        dtype=torch.long,
+                                        device="cuda")
+        prompt_lens_tensor = torch.tensor(prompt_lens,
+                                          dtype=torch.long,
+                                          device="cuda")
 
         input_metadata = InputMetadata(
-            prompt_lens=prompt_lens,
+            is_prompt=True,
             slot_mapping=slot_mapping,
+            prompt_lens=prompt_lens_tensor,
+            max_seq_len=max_prompt_len,
+            start_loc=start_loc_tensor,
             max_context_len=None,
-            context_lens=None,
-            block_tables=None,
+            context_lens=context_lens_tensor,
+            block_tables=block_tables,
             use_cuda_graph=False,
         )
-        return input_tokens, input_positions, input_metadata
+        return (input_tokens, input_positions, input_metadata, prompt_lens,
+                subquery_lens)
 
     def _prepare_decode(
         self,
@@ -206,32 +252,24 @@ class ModelRunner:
                 block_tables.append([])
             batch_size = graph_batch_size
 
-        # When using CUDA graph, we don't need to make the tensors on the GPU
-        # because they will be eventually copied to the designated GPU buffer.
-        device = "cpu" if use_captured_graph else "cuda"
-        pin_memory = use_captured_graph and not self.in_wsl
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device=device,
-                                             pin_memory=pin_memory)
+                                             device="cuda")
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device=device,
-                                                pin_memory=pin_memory)
+                                                device="cuda")
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device=device,
-                                             pin_memory=pin_memory)
+                                             device="cuda")
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device=device,
-                                    pin_memory=pin_memory)
+                                    device="cuda")
 
         if use_captured_graph:
             # The shape of graph_block_tables is
@@ -240,18 +278,24 @@ class ModelRunner:
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.tensor(input_block_tables, device=device)
+            block_tables = torch.tensor(input_block_tables, device="cuda")
         else:
+            max_block_table_len = (max_context_len + self.block_size -
+                                   1) // self.block_size
             block_tables = _make_tensor_with_pad(
                 block_tables,
-                max_len=max_context_len,
+                max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
+                device="cuda",
             )
 
         input_metadata = InputMetadata(
-            prompt_lens=[],
+            is_prompt=False,
             slot_mapping=slot_mapping,
+            prompt_lens=None,
+            max_seq_len=None,
+            start_loc=None,
             max_context_len=max_context_len,
             context_lens=context_lens,
             block_tables=block_tables,
@@ -263,6 +307,7 @@ class ModelRunner:
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         prompt_lens: List[int],
+        subquery_lens: Optional[List[int]],
     ) -> SamplingMetadata:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         selected_token_indices: List[int] = []
@@ -270,7 +315,7 @@ class ModelRunner:
         categorized_sample_indices = {t: [] for t in SamplingType}
         categorized_sample_indices_start_idx = 0
 
-        max_prompt_len = max(prompt_lens) if prompt_lens else 1
+        max_subquery_len = max(subquery_lens) if subquery_lens else 1
         for i, seq_group_metadata in enumerate(seq_group_metadata_list):
             seq_ids = list(seq_group_metadata.seq_data.keys())
             sampling_params = seq_group_metadata.sampling_params
@@ -278,10 +323,11 @@ class ModelRunner:
 
             if seq_group_metadata.is_prompt:
                 assert len(seq_ids) == 1
-                prompt_len = prompt_lens[i]
+                assert subquery_lens is not None
+                subquery_len = subquery_lens[i]
                 if sampling_params.prompt_logprobs is not None:
                     # NOTE: prompt token positions do not need sample, skip
-                    categorized_sample_indices_start_idx += prompt_len - 1
+                    categorized_sample_indices_start_idx += subquery_len - 1
 
                 categorized_sample_indices[
                     sampling_params.sampling_type].append(
@@ -291,10 +337,10 @@ class ModelRunner:
                 if sampling_params.prompt_logprobs is not None:
                     selected_token_indices.extend(
                         range(selected_token_start_idx,
-                              selected_token_start_idx + prompt_len - 1))
+                              selected_token_start_idx + subquery_len - 1))
                 selected_token_indices.append(selected_token_start_idx +
-                                              prompt_len - 1)
-                selected_token_start_idx += max_prompt_len
+                                              subquery_len - 1)
+                selected_token_start_idx += max_subquery_len
             else:
                 num_seqs = len(seq_ids)
                 selected_token_indices.extend(
@@ -334,23 +380,78 @@ class ModelRunner:
         )
         return sampling_metadata
 
+    def prepare_input_tensors(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]:
+        if self.is_driver_worker:
+            # NOTE: We assume that the sequences in the group are either
+            # all prompts or all decodes.
+            is_prompt = seq_group_metadata_list[0].is_prompt
+            # Prepare input tensors.
+            if is_prompt:
+                (input_tokens, input_positions, input_metadata, prompt_lens,
+                 subquery_lens) = self._prepare_prompt(seq_group_metadata_list)
+            else:
+                (input_tokens, input_positions, input_metadata
+                 ) = self._prepare_decode(seq_group_metadata_list)
+                subquery_lens = None
+                prompt_lens = []
+            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens,
+                                                     subquery_lens)
+
+            # Broadcast the metadata.
+            metadata_dict = {
+                "input_tokens": input_tokens,
+                "input_positions": input_positions,
+                "is_prompt": input_metadata.is_prompt,
+                "slot_mapping": input_metadata.slot_mapping,
+                "prompt_lens": input_metadata.prompt_lens,
+                "max_seq_len": input_metadata.max_seq_len,
+                "start_loc": input_metadata.start_loc,
+                "max_context_len": input_metadata.max_context_len,
+                "context_lens": input_metadata.context_lens,
+                "block_tables": input_metadata.block_tables,
+                "use_cuda_graph": input_metadata.use_cuda_graph,
+                "selected_token_indices":
+                sampling_metadata.selected_token_indices,
+            }
+            broadcast_tensor_dict(metadata_dict, src=0)
+        else:
+            metadata_dict = broadcast_tensor_dict(src=0)
+            input_tokens = metadata_dict["input_tokens"]
+            input_positions = metadata_dict["input_positions"]
+            input_metadata = InputMetadata(
+                is_prompt=metadata_dict["is_prompt"],
+                slot_mapping=metadata_dict["slot_mapping"],
+                prompt_lens=metadata_dict["prompt_lens"],
+                max_seq_len=metadata_dict["max_seq_len"],
+                start_loc=metadata_dict["start_loc"],
+                max_context_len=metadata_dict["max_context_len"],
+                context_lens=metadata_dict["context_lens"],
+                block_tables=metadata_dict["block_tables"],
+                use_cuda_graph=metadata_dict["use_cuda_graph"],
+            )
+            sampling_metadata = SamplingMetadata(
+                seq_groups=None,
+                seq_data=None,
+                prompt_lens=None,
+                selected_token_indices=metadata_dict["selected_token_indices"],
+                categorized_sample_indices=None,
+                perform_sampling=False,
+            )
+
+        return input_tokens, input_positions, input_metadata, sampling_metadata
+
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-    ) -> SamplerOutput:
-        # NOTE: We assume that all sequences in the group are all prompts or
-        # all decodes.
-        is_prompt = seq_group_metadata_list[0].is_prompt
-        # Prepare input tensors.
-        if is_prompt:
-            inputs = self._prepare_prompt(seq_group_metadata_list)
-            input_tokens, input_positions, input_metadata = inputs
-        else:
-            inputs = self._prepare_decode(seq_group_metadata_list)
-            input_tokens, input_positions, input_metadata = inputs
-
+    ) -> Optional[SamplerOutput]:
+        input_tokens, input_positions, input_metadata, sampling_metadata = (
+            self.prepare_input_tensors(seq_group_metadata_list))
         # Execute the model.
         if input_metadata.use_cuda_graph:
             graph_batch_size = input_tokens.shape[0]
@@ -364,9 +465,6 @@ class ModelRunner:
             input_metadata=input_metadata,
         )
 
-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 input_metadata.prompt_lens)
-
         # Sample the next token.
         output = self.model.sample(
             hidden_states=hidden_states,
@@ -429,13 +527,22 @@ class ModelRunner:
         context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
         block_tables = torch.from_numpy(self.graph_block_tables).cuda()
 
+        graph_batch_size = _get_graph_batch_size(
+            self.scheduler_config.max_num_seqs)
+        batch_size_capture_list = [
+            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
+        ]
+
         # NOTE: Capturing the largest batch size first may help reduce the
         # memory usage of CUDA graph.
-        for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE):
+        for batch_size in reversed(batch_size_capture_list):
             # Create dummy input_metadata.
             input_metadata = InputMetadata(
-                prompt_lens=[],
+                is_prompt=False,
                 slot_mapping=slot_mapping[:batch_size],
+                prompt_lens=None,
+                max_seq_len=None,
+                start_loc=None,
                 max_context_len=self.max_context_len_to_capture,
                 context_lens=context_lens[:batch_size],
                 block_tables=block_tables[:batch_size],

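The core of the new prompt path above is the prefix bookkeeping: when a Prefix is already computed, its tokens are sliced off the prompt, positions start at prefix_len instead of 0, and the attention metadata records the cached length (context_lens) alongside the uncached remainder (subquery_lens). Here is a stripped-down sketch of that bookkeeping on plain Python lists; split_prompt is a made-up helper name, and there are no tensors or real Prefix objects involved:

def split_prompt(prompt_tokens, prefix_len):
    """Sketch of the prefix bookkeeping in _prepare_prompt: returns the
    tokens that still need a forward pass, their positions, and the
    (context_len, subquery_len) pair recorded for the attention metadata."""
    new_tokens = prompt_tokens[prefix_len:]
    positions = list(range(prefix_len, prefix_len + len(new_tokens)))
    context_len = prefix_len                        # tokens already cached
    subquery_len = len(prompt_tokens) - prefix_len  # tokens left to compute
    return new_tokens, positions, context_len, subquery_len

# A 6-token prompt whose first 4 tokens are covered by a computed prefix.
tokens, positions, ctx, sub = split_prompt([11, 12, 13, 14, 15, 16],
                                           prefix_len=4)
print(tokens)     # [15, 16]
print(positions)  # [4, 5]
print(ctx, sub)   # 4 2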
+ 63 - 26
aphrodite/task_handler/worker.py

@@ -8,6 +8,8 @@ import torch.distributed
 from aphrodite.common.config import (CacheConfig, ModelConfig, ParallelConfig,
                                      SchedulerConfig)
 from aphrodite.modeling import set_random_seed
+from aphrodite.modeling.megatron.communication_op import (
+    broadcast_tensor_dict)
 from aphrodite.modeling.megatron.parallel_state import (
     initialize_model_parallel)
 from aphrodite.common.sequence import SamplerOutput, SequenceGroupMetadata
@@ -28,17 +30,23 @@ class Worker:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
-        rank: Optional[int] = None,
-        distributed_init_method: Optional[str] = None,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
+        self.is_driver_worker = is_driver_worker
+        if self.is_driver_worker:
+            assert self.rank == 0, "The driver worker must have rank 0."
 
         self.model_runner = ModelRunner(model_config, parallel_config,
-                                        scheduler_config)
+                                        scheduler_config, is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
         self.cache_config = None
@@ -50,19 +58,14 @@ class Worker:
         # torch.distributed.all_reduce does not free the input tensor until
         # the synchronization point. This causes the memory usage to grow
         # as the number of all_reduce calls increases. This env var disables
-        # this behaviour.
-
+        # this behavior.
+        # Related issue:
+        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
         os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
 
         # This env var set by Ray causes exceptions with graph building.
         os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-        # Env vars will be set by Ray.
-        self.rank = self.rank if self.rank is not None else int(
-            os.getenv("RANK", "-1"))
-        local_rank = int(os.getenv("LOCAL_RANK", "0"))
-        self.device = torch.device(f"cuda:{local_rank}")
-        if self.rank < 0:
-            raise ValueError("Invalid or unspecified rank.")
+        self.device = torch.device(f"cuda:{self.local_rank}")
         torch.cuda.set_device(self.device)
 
         _check_if_gpu_supports_dtype(self.model_config.dtype)
@@ -83,8 +86,15 @@ class Worker:
         block_size: int,
         gpu_memory_utilization: float,
         cpu_swap_space: int,
-        cache_dtype: torch.dtype,
     ) -> Tuple[int, int]:
+        """Profiles the peak memory usage of the model and returns the maximum
+        number of GPU and CPU cache blocks that can be allocated.
+
+        Args:
+            block_size: The size of the cache block.
+            gpu_memory_utilization: The fraction of the total GPU memory to use.
+            cpu_swap_space: The size of the CPU swap space in bytes.
+        """
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
         torch.cuda.empty_cache()
@@ -100,7 +110,7 @@ class Worker:
         peak_memory = total_gpu_memory - free_gpu_memory
 
         cache_block_size = CacheEngine.get_cache_block_size(
-            block_size, cache_dtype, self.model_config, self.parallel_config)
+            block_size, self.model_config, self.parallel_config)
         num_gpu_blocks = int(
             (total_gpu_memory * gpu_memory_utilization - peak_memory) //
             cache_block_size)
@@ -108,7 +118,6 @@ class Worker:
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
         torch.cuda.empty_cache()
-
         return num_gpu_blocks, num_cpu_blocks
 
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
@@ -126,14 +135,12 @@ class Worker:
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)
 
-    @torch.inference_mode()
-    def execute_model(
+    def cache_swap(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
-    ) -> SamplerOutput:
+    ) -> None:
         # Issue cache operations.
         issued_cache_op = False
         if blocks_to_swap_in:
@@ -148,14 +155,44 @@ class Worker:
 
         cache_events = self.cache_events if issued_cache_op else None
 
-        # Wati for cache operations to finish.
+        # Wait for cache operations to finish.
         # TODO: Profile swapping overhead and optimize if needed.
         if cache_events is not None:
             for event in cache_events:  # pylint: disable=not-an-iterable
                 event.wait()
 
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
+        blocks_to_swap_in: Optional[Dict[int, int]] = None,
+        blocks_to_swap_out: Optional[Dict[int, int]] = None,
+        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
+    ) -> Optional[SamplerOutput]:
+        if self.is_driver_worker:
+            assert seq_group_metadata_list is not None
+            num_seq_groups = len(seq_group_metadata_list)
+            assert blocks_to_swap_in is not None
+            assert blocks_to_swap_out is not None
+            assert blocks_to_copy is not None
+            data = {
+                "num_seq_groups": num_seq_groups,
+                "blocks_to_swap_in": blocks_to_swap_in,
+                "blocks_to_swap_out": blocks_to_swap_out,
+                "blocks_to_copy": blocks_to_copy,
+            }
+            broadcast_tensor_dict(data, src=0)
+        else:
+            data = broadcast_tensor_dict(src=0)
+            num_seq_groups = data["num_seq_groups"]
+            blocks_to_swap_in = data["blocks_to_swap_in"]
+            blocks_to_swap_out = data["blocks_to_swap_out"]
+            blocks_to_copy = data["blocks_to_copy"]
+
+        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
+
         # If there is no input, we don't need to execute the model.
-        if not seq_group_metadata_list:
+        if num_seq_groups == 0:
             return {}
 
         output = self.model_runner.execute_model(seq_group_metadata_list,
@@ -190,19 +227,19 @@ def _init_distributed_environment(
 
     # A small all_reduce for warmup.
     torch.distributed.all_reduce(torch.zeros(1).cuda())
-
     initialize_model_parallel(parallel_config.tensor_parallel_size,
                               parallel_config.pipeline_parallel_size)
 
 
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
+    # Check if the GPU supports the dtype.
     if torch_dtype == torch.bfloat16:
         compute_capability = torch.cuda.get_device_capability()
         if compute_capability[0] < 8:
             gpu_name = torch.cuda.get_device_name()
             raise ValueError(
                 "Bfloat16 is only supported on GPUs with compute capability "
-                f"of at least 8.0. You {gpu_name} GPU has compute capability "
-                f"{compute_capability[0]}.{compute_capability[1]}. Please "
-                "use the `--dtype float16` argument when launching the engine."
-            )
+                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
+                f"{compute_capability[0]}.{compute_capability[1]}. "
+                "You can use float16 instead by explicitly setting the "
+                "`dtype` flag in CLI, for example: --dtype=half.")

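Worker.execute_model above is where the commit's control-plane change lands: instead of Ray shipping arguments to every worker, only the driver (rank 0) sees the scheduler output and broadcasts a small control dict to the other ranks over the existing torch.distributed group. The sketch below mirrors that driver/non-driver split, substituting torch.distributed.broadcast_object_list for the in-tree broadcast_tensor_dict helper; it assumes a process group has already been initialized (e.g. via torchrun), and execute_step is a made-up name:

import torch.distributed as dist

def execute_step(is_driver: bool, scheduler_output=None):
    # Requires an initialized process group (dist.init_process_group).
    if is_driver:
        # The driver packs the control message, mirroring the branch above.
        data = {
            "num_seq_groups": len(scheduler_output),
            "blocks_to_swap_in": {},
            "blocks_to_swap_out": {},
            "blocks_to_copy": {},
        }
        payload = [data]
    else:
        payload = [None]
    # Every rank participates in the broadcast; non-drivers receive the dict.
    dist.broadcast_object_list(payload, src=0)
    data = payload[0]

    if data["num_seq_groups"] == 0:
        return None  # nothing to execute this step
    return data

The same pattern appears in ModelRunner.prepare_input_tensors, where the driver broadcasts the prepared InputMetadata fields and selected_token_indices to the other ranks.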
+ 59 - 0
examples/prefix_cache_example.py

@@ -0,0 +1,59 @@
+from aphrodite import LLM, SamplingParams
+
+prefix = (
+    "You are an expert school principal, skilled in effectively managing "
+    "faculty and staff. Draft 10-15 questions for a potential first grade "
+    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
+    "community, joyful discovery, and life-long learning. The candidate is "
+    "coming in for a first-round panel interview for an 8th grade Math "
+    "teaching role. They have 5 years of previous teaching experience "
+    "as an assistant teacher at a co-ed, public school with experience "
+    "in middle school math teaching. Based on this information, fulfill "
+    "the following paragraph: ")
+
+# Sample prompts.
+prompts = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
+# Create a sampling params object.
+sampling_params = SamplingParams(temperature=0.0)
+
+# Create an LLM.
+llm = LLM(model="EleutherAI/pythia-70m-deduped")
+
+generating_prompts = [prefix + prompt for prompt in prompts]
+
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+outputs = llm.generate(generating_prompts, sampling_params)
+# Print the outputs.
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+print("-" * 80)
+
+# -1 since the last token can change when concatenating prompts.
+prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
+
+# The llm.generate call will batch all prompts and send the batch at once if
+# resources allow. The prefix will only be cached after the first batch is
+# processed, so we need to call generate once to calculate the prefix and
+# cache it.
+outputs = llm.generate(generating_prompts[0],
+                       sampling_params,
+                       prefix_pos=[prefix_pos])
+
+# Subsequent batches can leverage the cached prefix
+outputs = llm.generate(generating_prompts,
+                       sampling_params,
+                       prefix_pos=[prefix_pos] * len(generating_prompts))
+
+# Print the outputs. You should see the same outputs as before
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

+ 1 - 1
kernels/activation_kernels.cu

@@ -1,6 +1,6 @@
+#include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include <c10/cuda/CUDAGuard.h>
-#include <ATen/cuda/CUDAContext.h>
 
 #include "cuda_compat.h"
 #include "dispatch_utils.h"

+ 0 - 1
kernels/attention/attention_dtypes.h

@@ -4,4 +4,3 @@
 #include "dtype_float16.cuh"
 #include "dtype_float32.cuh"
 #include "dtype_bfloat16.cuh"
-#include "dtype_fp8.cuh"

+ 85 - 138
kernels/attention/attention_kernels.cu

@@ -26,7 +26,6 @@
 
 #include "attention_dtypes.h"
 #include "attention_utils.cuh"
-#include "../quantization/kvcache/quant_utils.cuh"
 
 #include <algorithm>
 
@@ -81,19 +80,17 @@ inline __device__ float block_sum(float* red_smem, float sum) {
 // Grid: (num_heads, num_seqs, max_num_partitions).
 template<
   typename scalar_t,
-  typename cache_t,
   int HEAD_SIZE,
   int BLOCK_SIZE,
   int NUM_THREADS,
-  bool ENABLE_FP8_KV_CACHE,
   int PARTITION_SIZE = 0> // Zero means no partitioning.
 __device__ void paged_attention_kernel(
   float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
   float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
   scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
+  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
   const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
@@ -149,7 +146,6 @@ __device__ void paged_attention_kernel(
   constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
   using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
   using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
-  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
 
   constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
   constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
@@ -181,7 +177,7 @@ __device__ void paged_attention_kernel(
 
   // x == THREAD_GROUP_SIZE * VEC_SIZE
   // Each thread group fetches x elements from the key at a time.
-  constexpr int x = 16 / sizeof(cache_t);
+  constexpr int x = 16 / sizeof(scalar_t);
   float qk_max = -FLT_MAX;
 
   // Iterate over the key blocks.
@@ -207,19 +203,13 @@ __device__ void paged_attention_kernel(
 
 #pragma unroll
       for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-        const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
-                                       + kv_head_idx * kv_head_stride
-                                       + physical_block_offset * x;
+        const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+                                        + kv_head_idx * kv_head_stride
+                                        + physical_block_offset * x;
         const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
         const int offset1 = (vec_idx * VEC_SIZE) / x;
         const int offset2 = (vec_idx * VEC_SIZE) % x;
-        if constexpr (ENABLE_FP8_KV_CACHE) {
-          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
-          // Vector conversion from Quant_vec to K_vec.
-          k_vecs[j] = vec_conversion<K_vec, Quant_vec>(k_vec_quant);
-        } else {
-          k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
-        }
+        k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
       }
 
       // Compute dot product.
@@ -293,7 +283,6 @@ __device__ void paged_attention_kernel(
   constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
   using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
   using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
-  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
   using Float_L_vec = typename FloatVec<L_vec>::Type;
 
   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
@@ -319,21 +308,14 @@ __device__ void paged_attention_kernel(
     L_vec logits_vec;
     from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
 
-    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
-                                   + kv_head_idx * kv_head_stride;
+    const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+                                    + kv_head_idx * kv_head_stride;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
       if (row_idx < HEAD_SIZE) {
         const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
-        V_vec v_vec;
-        if constexpr (ENABLE_FP8_KV_CACHE) {
-          V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
-          // Vector conversion from V_quant_vec to V_vec.
-          v_vec = vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
-        } else {
-          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
-        }
+        V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
         if (block_idx == num_context_blocks - 1) {
           // NOTE: When v_vec contains the tokens that are out of the context,
           // we should explicitly zero out the values since they may contain NaNs.
@@ -413,16 +395,14 @@ __device__ void paged_attention_kernel(
 // Grid: (num_heads, num_seqs, 1).
 template<
   typename scalar_t,
-  typename cache_t,
   int HEAD_SIZE,
   int BLOCK_SIZE,
-  int NUM_THREADS,
-  bool ENABLE_FP8_KV_CACHE>
+  int NUM_THREADS>
 __global__ void paged_attention_v1_kernel(
   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
+  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
   const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
@@ -432,7 +412,7 @@ __global__ void paged_attention_v1_kernel(
   const int q_stride,
   const int kv_block_stride,
   const int kv_head_stride) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, ENABLE_FP8_KV_CACHE>(
+  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
     /* exp_sums */ nullptr, /* max_logits */ nullptr,
     out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
     max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
@@ -441,19 +421,17 @@ __global__ void paged_attention_v1_kernel(
 // Grid: (num_heads, num_seqs, max_num_partitions).
 template<
   typename scalar_t,
-  typename cache_t,
   int HEAD_SIZE,
   int BLOCK_SIZE,
   int NUM_THREADS,
-  bool ENABLE_FP8_KV_CACHE,
   int PARTITION_SIZE>
 __global__ void paged_attention_v2_kernel(
   float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
   float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
   scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
+  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
   const int num_kv_heads,                 // [num_heads]
   const float scale,
   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
@@ -463,7 +441,7 @@ __global__ void paged_attention_v2_kernel(
   const int q_stride,
   const int kv_block_stride,
   const int kv_head_stride) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, ENABLE_FP8_KV_CACHE, PARTITION_SIZE>(
+  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
     exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
     block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
     q_stride, kv_block_stride, kv_head_stride);
@@ -572,10 +550,10 @@ __global__ void paged_attention_v2_reduce_kernel(
 
 #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                  \
   APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                       \
-    ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,   \
-      ENABLE_FP8_KV_CACHE>), shared_mem_size);                                                \
-  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,             \
-  ENABLE_FP8_KV_CACHE><<<grid, block, shared_mem_size, stream>>>(                             \
+    ((void*)aphrodite::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>),          \
+    shared_mem_size);                                                                         \
+  aphrodite::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>                      \
+  <<<grid, block, shared_mem_size, stream>>>(                                                 \
     out_ptr,                                                                                  \
     query_ptr,                                                                                \
     key_cache_ptr,                                                                            \
@@ -593,9 +571,7 @@ __global__ void paged_attention_v2_reduce_kernel(
 // TODO: Tune NUM_THREADS.
 template<
   typename T,
-  typename CACHE_T,
   int BLOCK_SIZE,
-  bool ENABLE_FP8_KV_CACHE,
   int NUM_THREADS = 128>
 void paged_attention_v1_launcher(
   torch::Tensor& out,
@@ -626,8 +602,8 @@ void paged_attention_v1_launcher(
 
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
-  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
-  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
+  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
+  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();
 
@@ -671,31 +647,31 @@ void paged_attention_v1_launcher(
   }
 }
 
-#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, ENABLE_FP8_KV_CACHE)       \
-  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, ENABLE_FP8_KV_CACHE>( \
-    out,                                                                    \
-    query,                                                                  \
-    key_cache,                                                              \
-    value_cache,                                                            \
-    num_kv_heads,                                                           \
-    scale,                                                                  \
-    block_tables,                                                           \
-    context_lens,                                                           \
-    max_context_len,                                                        \
+#define CALL_V1_LAUNCHER(T, BLOCK_SIZE)                             \
+  paged_attention_v1_launcher<T, BLOCK_SIZE>(                       \
+    out,                                                            \
+    query,                                                          \
+    key_cache,                                                      \
+    value_cache,                                                    \
+    num_kv_heads,                                                   \
+    scale,                                                          \
+    block_tables,                                                   \
+    context_lens,                                                   \
+    max_context_len,                                                \
     alibi_slopes);
 
 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, ENABLE_FP8_KV_CACHE)\
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T)                              \
   switch (block_size) {                                             \
     case 8:                                                         \
-      CALL_V1_LAUNCHER(T, CACHE_T, 8, ENABLE_FP8_KV_CACHE);         \
+      CALL_V1_LAUNCHER(T, 8);                                       \
       break;                                                        \
     case 16:                                                        \
-      CALL_V1_LAUNCHER(T, CACHE_T, 16, ENABLE_FP8_KV_CACHE);        \
+      CALL_V1_LAUNCHER(T, 16);                                      \
       break;                                                        \
     case 32:                                                        \
-      CALL_V1_LAUNCHER(T, CACHE_T, 32, ENABLE_FP8_KV_CACHE);        \
+      CALL_V1_LAUNCHER(T, 32);                                      \
       break;                                                        \
     default:                                                        \
       TORCH_CHECK(false, "Unsupported block size: ", block_size);   \
@@ -713,34 +689,20 @@ void paged_attention_v1(
   torch::Tensor& context_lens,    // [num_seqs]
   int block_size,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const bool enable_fp8_kv_cache) {
-  if (enable_fp8_kv_cache) {
-    if (query.dtype() == at::ScalarType::Float) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
-    } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
-    } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-    }
+  const c10::optional<torch::Tensor>& alibi_slopes) {
+  if (query.dtype() == at::ScalarType::Float) {
+    CALL_V1_LAUNCHER_BLOCK_SIZE(float);
+  } else if (query.dtype() == at::ScalarType::Half) {
+    CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t);
+  } else if (query.dtype() == at::ScalarType::BFloat16) {
+    CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
   } else {
-    if (query.dtype() == at::ScalarType::Float) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
-    } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
-    } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-    }
+    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
   }
 }
 
 #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \
-  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,             \
-  ENABLE_FP8_KV_CACHE, PARTITION_SIZE>                                                        \
+  aphrodite::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>      \
   <<<grid, block, shared_mem_size, stream>>>(                                                 \
     exp_sums_ptr,                                                                             \
     max_logits_ptr,                                                                           \
@@ -768,9 +730,7 @@ void paged_attention_v1(
 
 template<
   typename T,
-  typename CACHE_T,
   int BLOCK_SIZE,
-  bool ENABLE_FP8_KV_CACHE,
   int NUM_THREADS = 128,
   int PARTITION_SIZE = 512>
 void paged_attention_v2_launcher(
@@ -808,8 +768,8 @@ void paged_attention_v2_launcher(
   float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
   T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
-  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
-  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
+  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
+  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();
 
@@ -856,38 +816,38 @@ void paged_attention_v2_launcher(
   }
 }
 
-#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, ENABLE_FP8_KV_CACHE)       \
-  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, ENABLE_FP8_KV_CACHE>( \
-    out,                                                                    \
-    exp_sums,                                                               \
-    max_logits,                                                             \
-    tmp_out,                                                                \
-    query,                                                                  \
-    key_cache,                                                              \
-    value_cache,                                                            \
-    num_kv_heads,                                                           \
-    scale,                                                                  \
-    block_tables,                                                           \
-    context_lens,                                                           \
-    max_context_len,                                                        \
+#define CALL_V2_LAUNCHER(T, BLOCK_SIZE)                             \
+  paged_attention_v2_launcher<T, BLOCK_SIZE>(                       \
+    out,                                                            \
+    exp_sums,                                                       \
+    max_logits,                                                     \
+    tmp_out,                                                        \
+    query,                                                          \
+    key_cache,                                                      \
+    value_cache,                                                    \
+    num_kv_heads,                                                   \
+    scale,                                                          \
+    block_tables,                                                   \
+    context_lens,                                                   \
+    max_context_len,                                                \
     alibi_slopes);
 
 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, ENABLE_FP8_KV_CACHE)        \
-  switch (block_size) {                                                     \
-    case 8:                                                                 \
-      CALL_V2_LAUNCHER(T, CACHE_T, 8, ENABLE_FP8_KV_CACHE);                 \
-      break;                                                                \
-    case 16:                                                                \
-      CALL_V2_LAUNCHER(T, CACHE_T, 16, ENABLE_FP8_KV_CACHE);                \
-      break;                                                                \
-    case 32:                                                                \
-      CALL_V2_LAUNCHER(T, CACHE_T, 32, ENABLE_FP8_KV_CACHE);                \
-      break;                                                                \
-    default:                                                                \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size);           \
-      break;                                                                \
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T)                              \
+  switch (block_size) {                                             \
+    case 8:                                                         \
+      CALL_V2_LAUNCHER(T, 8);                                       \
+      break;                                                        \
+    case 16:                                                        \
+      CALL_V2_LAUNCHER(T, 16);                                      \
+      break;                                                        \
+    case 32:                                                        \
+      CALL_V2_LAUNCHER(T, 32);                                      \
+      break;                                                        \
+    default:                                                        \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size);   \
+      break;                                                        \
   }
 
 void paged_attention_v2(
@@ -904,28 +864,15 @@ void paged_attention_v2(
   torch::Tensor& context_lens,    // [num_seqs]
   int block_size,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const bool enable_fp8_kv_cache) {
-  if (enable_fp8_kv_cache) {
-    if (query.dtype() == at::ScalarType::Float) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
-    } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
-    } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-    }
+  const c10::optional<torch::Tensor>& alibi_slopes) {
+  if (query.dtype() == at::ScalarType::Float) {
+    CALL_V2_LAUNCHER_BLOCK_SIZE(float);
+  } else if (query.dtype() == at::ScalarType::Half) {
+    CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t);
+  } else if (query.dtype() == at::ScalarType::BFloat16) {
+    CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
   } else {
-    if (query.dtype() == at::ScalarType::Float) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
-    } else if (query.dtype() == at::ScalarType::Half) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
-    } else if (query.dtype() == at::ScalarType::BFloat16) {
-      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
-    } else {
-      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
-    }
+    TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
   }
 }
 

+ 0 - 31
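The CALL_V1/V2 macro ladders above reduce to a two-level dispatch: first on the query dtype, then on the block size, with every other argument forwarded untouched. The same shape in Python, purely as a reading aid; the launcher here is a placeholder, not a real binding:

def paged_attention_v1_dispatch(query_dtype: str, block_size: int):
    # Placeholder standing in for paged_attention_v1_launcher<T, BLOCK_SIZE>.
    def launch(dtype, bs):
        return f"launch<{dtype}, BLOCK_SIZE={bs}>"

    if query_dtype not in ("float32", "float16", "bfloat16"):
        raise TypeError(f"Unsupported data type: {query_dtype}")
    if block_size not in (8, 16, 32):
        raise ValueError(f"Unsupported block size: {block_size}")
    return launch(query_dtype, block_size)

print(paged_attention_v1_dispatch("float16", 16))
# launch<float16, BLOCK_SIZE=16>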
kernels/attention/dtype_fp8.cuh

@@ -1,31 +0,0 @@
-#pragma once
-
-#include "attention_generic.cuh"
-
-#include <stdint.h>
-#include <cuda_fp8.h>
-
-namespace aphrodite {
-// FP8 vector types for quantization of KV Cache
-
-template<>
-struct Vec<uint8_t, 1> {
-    using Type = uint8_t;
-};
-
-template<>
-struct Vec<uint8_t, 2> {
-    using Type = uint16_t;
-};
-
-template<>
-struct Vec<uint8_t, 4> {
-    using Type = uint32_t;
-};
-
-template<>
-struct Vec<uint8_t, 8> {
-    using Type = uint2;
-};
-
-} // namespace aphrodite

+ 3 - 5
kernels/cache.h

@@ -1,3 +1,5 @@
+#pragma once
+
 #include <torch/extension.h>
 
 #include <map>
@@ -25,8 +27,4 @@ void gather_cached_kv(
   torch::Tensor& value,
   torch::Tensor& key_cache,
   torch::Tensor& value_cache,
-  torch::Tensor& slot_mapping);
-
-void convert_fp8(
-  torch::Tensor& src_cache,
-  torch::Tensor& dst_cache);
+  torch::Tensor& slot_mapping);

+ 27 - 98
kernels/cache_kernels.cu

@@ -4,7 +4,6 @@
 
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
-#include "quantization/kvcache/quant_utils.cuh"
 
 #include <algorithm>
 #include <cassert>
@@ -35,7 +34,7 @@ void swap_blocks(
   char *dst_ptr = static_cast<char*>(dst.data_ptr());
 
   const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
-  const at::cuda::OptionalCUDAGuard device_guard(src_device);
+  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   // NOTE: This can be slow if the number of blocks is large.
   for (const auto& pair : block_mapping) {
@@ -132,7 +131,7 @@ void copy_blocks(
   dim3 block(std::min(1024, numel_per_block));
   const at::cuda::OptionalCUDAGuard device_guard(cache_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  APHRODITE_DISPATCH_FLOATING_BYTE_TYPES(
+  APHRODITE_DISPATCH_FLOATING_TYPES(
     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
       aphrodite::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
         key_cache_ptrs_tensor.data_ptr<int64_t>(),
@@ -144,12 +143,12 @@ void copy_blocks(
 
 namespace aphrodite {
 
-template<typename scalar_t, typename cache_t, bool enable_fp8_kv_cache>
+template<typename scalar_t>
 __global__ void reshape_and_cache_kernel(
   const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
   const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
-  cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x]
-  cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size]
+  scalar_t* __restrict__ key_cache,           // [num_blocks, num_heads, head_size/x, block_size, x]
+  scalar_t* __restrict__ value_cache,         // [num_blocks, num_heads, head_size, block_size]
   const int64_t* __restrict__ slot_mapping,   // [num_tokens]
   const int key_stride,
   const int value_stride,
@@ -186,34 +185,13 @@ __global__ void reshape_and_cache_kernel(
                                   + head_idx * head_size * block_size
                                   + head_offset * block_size
                                   + block_offset;
-    scalar_t tgt_key = APHRODITE_LDG(&key[src_key_idx]);
-    scalar_t tgt_value = APHRODITE_LDG(&value[src_value_idx]);
-    if constexpr (enable_fp8_kv_cache) {
-      key_cache[tgt_key_idx] = vec_conversion<uint8_t, scalar_t>(tgt_key);
-      value_cache[tgt_value_idx] = vec_conversion<uint8_t, scalar_t>(tgt_value);
-    } else {
-      key_cache[tgt_key_idx] = tgt_key;
-      value_cache[tgt_value_idx] = tgt_value;
-    }
+    key_cache[tgt_key_idx] = key[src_key_idx];
+    value_cache[tgt_value_idx] = value[src_value_idx];
   }
 }
 
 } // namespace aphrodite
 
-#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, ENABLE_FP8_KV_CACHE)                                \
-  aphrodite::reshape_and_cache_kernel<KV_T, CACHE_T, ENABLE_FP8_KV_CACHE><<<grid, block, 0, stream>>>( \
-    reinterpret_cast<KV_T*>(key.data_ptr()),                                                      \
-    reinterpret_cast<KV_T*>(value.data_ptr()),                                                    \
-    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                             \
-    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                           \
-    slot_mapping.data_ptr<int64_t>(),                                                             \
-    key_stride,                                                                                   \
-    value_stride,                                                                                 \
-    num_heads,                                                                                    \
-    head_size,                                                                                    \
-    block_size,                                                                                   \
-    x);
-
 void reshape_and_cache(
   torch::Tensor& key,           // [num_tokens, num_heads, head_size]
   torch::Tensor& value,         // [num_tokens, num_heads, head_size]
@@ -234,24 +212,23 @@ void reshape_and_cache(
   dim3 block(std::min(num_heads * head_size, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  bool enable_fp8_kv_cache = (key_cache.scalar_type() == at::ScalarType::Byte);
-  if (enable_fp8_kv_cache) {
-    if (key.dtype() == at::ScalarType::Float) {
-      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
-    } else if (key.dtype() == at::ScalarType::Half) {
-      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
-    } else if (key.dtype() == at::ScalarType::BFloat16) {
-      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
-    }
-  } else {
-    if (key.dtype() == at::ScalarType::Float) {
-      CALL_RESHAPE_AND_CACHE(float, float, false);
-    } else if (key.dtype() == at::ScalarType::Half) {
-      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
-    } else if (key.dtype() == at::ScalarType::BFloat16) {
-      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
-    }
-  }
+  APHRODITE_DISPATCH_FLOATING_TYPES(
+    key.scalar_type(),
+    "reshape_and_cache_kernel",
+    [&] {
+      aphrodite::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        key.data_ptr<scalar_t>(),
+        value.data_ptr<scalar_t>(),
+        key_cache.data_ptr<scalar_t>(),
+        value_cache.data_ptr<scalar_t>(),
+        slot_mapping.data_ptr<int64_t>(),
+        key_stride,
+        value_stride,
+        num_heads,
+        head_size,
+        block_size,
+        x);
+    });
 }
 
 namespace aphrodite {
@@ -279,12 +256,12 @@ __global__ void gather_cached_kv_kernel(
     for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) {
       const int tgt_key_idx = token_idx * key_stride + i;
       const int tgt_value_idx = token_idx * value_stride + i;
-
+  
       const int head_idx = i / head_size;
       const int head_offset = i % head_size;
       const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension
       const int x_offset = head_offset % x;
-
+  
       const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
                               + head_idx * (head_size / x) * block_size * x
                               + x_idx * block_size * x
@@ -396,7 +373,7 @@ void gather_cached_kv(
   dim3 block(std::min(num_heads * head_size, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  APHRODITE_DISPATCH_FLOATING_BYTE_TYPES(
+  APHRODITE_DISPATCH_FLOATING_TYPES(
     key.scalar_type(),
     "gather_cached_kv_kernel_optimized",
     [&] {
@@ -414,51 +391,3 @@ void gather_cached_kv(
         x);
     });
 }
-
-namespace aphrodite {
-
-template<typename Tout, typename Tin>
-__global__ void convert_fp8_kernel(
-  const Tin* __restrict__ src_cache,
-  Tout* __restrict__ dst_cache,
-  const int64_t block_stride) {
-  const int64_t block_idx = blockIdx.x;
-  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
-    int64_t idx = block_idx * block_stride + i;
-    dst_cache[idx] = vec_conversion<Tout, Tin>(APHRODITE_LDG(&src_cache[idx]));
-  }
-}
-
-} // namespace aphrodite
-
-#define CALL_CONVERT_FP8(Tout, Tin)                                 \
-  aphrodite::convert_fp8_kernel<Tout, Tin><<<grid, block, 0, stream>>>(  \
-    reinterpret_cast<Tin*>(src_cache.data_ptr()),                   \
-    reinterpret_cast<Tout*>(dst_cache.data_ptr()),                  \
-    block_stride);
-
-void convert_fp8(
-  torch::Tensor& src_cache,
-  torch::Tensor& dst_cache)
-{
-  int64_t num_blocks = src_cache.size(0);
-  int64_t block_stride = src_cache.stride(0);
-
-  dim3 grid(num_blocks);
-  dim3 block(std::min(block_stride, int64_t(512)));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-  if (src_cache.dtype() == at::ScalarType::Float) {
-    CALL_CONVERT_FP8(uint8_t, float);
-  } else if (src_cache.dtype() == at::ScalarType::Half) {
-    CALL_CONVERT_FP8(uint8_t, uint16_t);
-  } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
-    CALL_CONVERT_FP8(uint8_t, __nv_bfloat16);
-  } else if (dst_cache.dtype() == at::ScalarType::Float) {
-    CALL_CONVERT_FP8(float, uint8_t);
-  } else if (dst_cache.dtype() == at::ScalarType::Half) {
-    CALL_CONVERT_FP8(uint16_t, uint8_t);
-  } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
-    CALL_CONVERT_FP8(__nv_bfloat16, uint8_t);
-  }
-}
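
The reshape_and_cache hunk above returns to a plain scalar copy; what it keeps is the row-major offset arithmetic into the paged layouts noted in the comments ([num_blocks, num_heads, head_size/x, block_size, x] for keys, [num_blocks, num_heads, head_size, block_size] for values). A small standalone sketch of those offsets is below; the final key term (block_offset * x + x_offset) is inferred from the layout comment rather than copied from the kernel, so treat it as an assumption.

    // Sketch of the cache offsets implied by the layout comments above.
    #include <cstdint>

    int64_t key_cache_index(int64_t block_idx, int64_t head_idx, int64_t x_idx,
                            int64_t block_offset, int64_t x_offset,
                            int64_t num_heads, int64_t head_size,
                            int64_t block_size, int64_t x) {
      return block_idx * num_heads * (head_size / x) * block_size * x
           + head_idx * (head_size / x) * block_size * x
           + x_idx * block_size * x
           + block_offset * x      // assumed from the layout, not quoted
           + x_offset;
    }

    int64_t value_cache_index(int64_t block_idx, int64_t head_idx,
                              int64_t head_offset, int64_t block_offset,
                              int64_t num_heads, int64_t head_size,
                              int64_t block_size) {
      return block_idx * num_heads * head_size * block_size
           + head_idx * head_size * block_size
           + head_offset * block_size
           + block_offset;
    }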

+ 1 - 1
kernels/cuda_compat.h

@@ -24,4 +24,4 @@
 #else
   #define APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
     hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
-#endif
+#endif

+ 3 - 2
kernels/cuda_utils.h

@@ -1,6 +1,7 @@
+#pragma once
+
 #include <torch/extension.h>
 
 int get_device_attribute(
     int attribute,
-    int device_id);
-    
+    int device_id);

+ 1 - 2
kernels/cuda_utils_kernels.cu

@@ -1,7 +1,6 @@
 #ifdef USE_ROCM
-    #include <hip/hip_runtime.h>
+  #include <hip/hip_runtime.h>
 #endif
-
 int get_device_attribute(
     int attribute,
     int device_id)

+ 8 - 27
kernels/dispatch_utils.h

@@ -2,34 +2,15 @@
  * Adapted from
  * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
  */
+#pragma once
+
 #include <torch/extension.h>
 
-#define APHRODITE_DISPATCH_CASE_FLOATING_TYPES(...)                         \
-  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)                      \
-  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)                       \
+#define APHRODITE_DISPATCH_CASE_FLOATING_TYPES(...)              \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
   AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
 
-#define APHRODITE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)                  \
-  AT_DISPATCH_SWITCH(                                                       \
-    TYPE, NAME, APHRODITE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
-
-#define APHRODITE_DISPATCH_CASE_FLOATING_BYTE_TYPES(...)                    \
-  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)                      \
-  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)                       \
-  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)                   \
-  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
-
-#define APHRODITE_DISPATCH_FLOATING_BYTE_TYPES(TYPE, NAME, ...)             \
-  AT_DISPATCH_SWITCH(                                                       \
-    TYPE, NAME, APHRODITE_DISPATCH_CASE_FLOATING_BYTE_TYPES(__VA_ARGS__))
-
-#define APHRODITE_DISPATCH_CASE_INTEGRAL_TYPES(...)                         \
-  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)                       \
-  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)                       \
-  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)                      \
-  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)                        \
-  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
-
-#define APHRODITE_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...)                  \
-  AT_DISPATCH_SWITCH(                                                       \
-    TYPE, NAME, APHRODITE_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
+#define APHRODITE_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)             \
+  AT_DISPATCH_SWITCH(                                             \
+    TYPE, NAME, APHRODITE_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
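
After this change the header dispatches only over float, half and bfloat16; the byte and integral variants leave together with the FP8 and bincount kernels that needed them. The macro behaves like the stock AT_DISPATCH macros: it switches on the runtime scalar type and exposes scalar_t inside the lambda. A representative call site, with a hypothetical kernel name (the shape matches the real uses in cache_kernels.cu above):

    APHRODITE_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "example_kernel", [&] {
        example_kernel<scalar_t><<<grid, block, 0, stream>>>(
          input.data_ptr<scalar_t>(),
          output.data_ptr<scalar_t>(),
          num_elements);
      });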

+ 2 - 1
kernels/layernorm_kernels.cu

@@ -1,6 +1,6 @@
 #include <torch/extension.h>
-#include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "dispatch_utils.h"
 #include "reduction.cuh"
@@ -35,6 +35,7 @@ __global__ void rms_norm_kernel(
   }
 }
 
+// TODO: Further optimize this kernel.
 template<typename scalar_t>
 __global__ void fused_add_rms_norm_kernel(
   scalar_t* __restrict__ input,           // [..., hidden_size]

+ 0 - 35
kernels/misc_kernels.cu

@@ -1,35 +0,0 @@
-// adapted from https://github.com/rusty1s/pytorch_bincount
-#include <torch/extension.h>
-#include <ATen/cuda/CUDAContext.h>
-
-#include <ATen/ATen.h>
-#include <THC/THCAtomics.cuh>
-
-#include "cuda_compat.h"
-#include "dispatch_utils.h"
-
-#define THREADS 1024
-#define BLOCKS(N) (N + THREADS - 1) / THREADS
-
-namespace aphrodite {
-template <typename scalar_t>
-__global__ void bincount_kernel(scalar_t *__restrict__ src, int32_t *out,
-                                size_t numel) {
-  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
-  const size_t stride = blockDim.x * gridDim.x;
-  for (ptrdiff_t i = index; i < numel; i += stride) {
-    atomicAdd(out + (ptrdiff_t)src[i], 1);
-  }
-}
-}
-
-// create a custom bincount since pytorch's bincount is
-// not cudagraph capturable.
-void aphrodite_bincount(torch::Tensor src, torch::Tensor out) {
-   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  APHRODITE_DISPATCH_INTEGRAL_TYPES(
-    src.scalar_type(), "bincount_kernel", [&] {
-    aphrodite::bincount_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS, 0, stream>>>(
-        src.data<scalar_t>(), out.data<int32_t>(), src.numel());
-  });
-}
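
For reference, the removed launcher still used the deprecated Tensor::data<T>() accessor. Had the file stayed, the modern equivalent would differ only in the accessors, roughly as sketched here:

    // Same launch as the removed code, with data<T>() swapped for data_ptr<T>().
    aphrodite::bincount_kernel<scalar_t><<<BLOCKS(src.numel()), THREADS, 0, stream>>>(
        src.data_ptr<scalar_t>(), out.data_ptr<int32_t>(), src.numel());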

+ 4 - 11
kernels/ops.h

@@ -1,4 +1,5 @@
-#include <cstdint>
+#pragma once
+
 #include <torch/extension.h>
 
 void paged_attention_v1(
@@ -12,8 +13,7 @@ void paged_attention_v1(
   torch::Tensor& context_lens,
   int block_size,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const bool enable_fp8_kv_cache);
+  const c10::optional<torch::Tensor>& alibi_slopes);
 
 void paged_attention_v2(
   torch::Tensor& out,
@@ -29,8 +29,7 @@ void paged_attention_v2(
   torch::Tensor& context_lens,
   int block_size,
   int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const bool enable_fp8_kv_cache);
+  const c10::optional<torch::Tensor>& alibi_slopes);
 
 void rms_norm(
   torch::Tensor& out,
@@ -64,7 +63,6 @@ void gelu_fast(
   torch::Tensor& out,
   torch::Tensor& input);
 
-// The AWQ kernels are only available on CUDA
 #ifndef USE_ROCM
 torch::Tensor awq_gemm(
   torch::Tensor _in_feats,
@@ -93,8 +91,3 @@ void gptq_shuffle(
   torch::Tensor q_weight,
   torch::Tensor q_perm,
   int bit);
-
-void aphrodite_bincount(
-  torch::Tensor src,
-  torch::Tensor out);
-  

+ 1 - 1
kernels/pos_encoding_kernels.cu

@@ -1,6 +1,6 @@
 #include <torch/extension.h>
-#include <c10/cuda/CUDAGuard.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 
 #include "cuda_compat.h"
 #include "dispatch_utils.h"

+ 4 - 15
kernels/pybind.cpp

@@ -48,20 +48,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     &rotary_embedding,
     "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
 
+#ifndef USE_ROCM
   // Quantization ops
-  #ifndef USE_ROCM
+  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
+#endif
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
-  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
-  #endif
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
 
-  // misc
-  ops.def(
-    "bincount",
-    &aphrodite_bincount,
-    "CUDA Graph compatible bincount implementation.");
-
   // Cache ops
   pybind11::module cache_ops = m.def_submodule("cache_ops", "Aphrodite Engine cache ops");
   cache_ops.def(
@@ -80,11 +74,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "gather_cached_kv",
     &gather_cached_kv,
     "Gather key and value from the cache into contiguous QKV tensors");
-  cache_ops.def(
-    "convert_fp8",
-    &convert_fp8,
-    "Convert the KV cache to FP8 datatype");
-    
 
   // Cuda utils
   pybind11::module cuda_utils = m.def_submodule("cuda_utils", "Aphrodite Engine cuda utils");
@@ -92,4 +81,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     "get_device_attribute",
     &get_device_attribute,
     "Gets the specified device attribute.");
-}
+}
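
With aphrodite_bincount and convert_fp8 dropped, the module registers only the remaining ops plus the cache_ops and cuda_utils submodules. The registration itself is the ordinary pybind11-on-torch pattern; a stripped-down sketch with placeholder names (not the module's real contents):

    // Placeholder op and submodule names; only the registration pattern is real.
    #include <torch/extension.h>

    void example_op(torch::Tensor& out, torch::Tensor& input);

    PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
      pybind11::module ops = m.def_submodule("ops", "example ops");
      ops.def("example_op", &example_op, "Example op");

      pybind11::module cache_ops = m.def_submodule("cache_ops", "example cache ops");
      cache_ops.def("example_cache_op", &example_op, "Same op exposed under cache_ops");
    }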

+ 1 - 1
kernels/quantization/gptq/matrix_view.cuh

@@ -271,4 +271,4 @@ public:
 
 }  // namespace gptq
 }  // namespace aphrodite
-#endif
+#endif

+ 2 - 2
kernels/quantization/gptq/q_gemm.cu

@@ -1610,7 +1610,7 @@ void reconstruct_gptq
     if (bit == 2) {
         kernel = reconstruct_gptq_kernel<MatrixView_q2_row, 2>;
     } else if (bit == 8) {
-       	kernel = reconstruct_gptq_kernel<MatrixView_q8_row, 8>;
+        kernel = reconstruct_gptq_kernel<MatrixView_q8_row, 8>;
     } else if (bit == 3) {
         kernel = reconstruct_gptq_3bit_kernel;
         gridDim.y = DIVIDE(height, 32);
@@ -2072,4 +2072,4 @@ void gptq_shuffle
         q_weight.size(1),
         bit
     );
-}
+}

+ 1 - 1
kernels/quantization/gptq/qdq_4.cuh

@@ -144,4 +144,4 @@ __forceinline__ __device__ void dequant_4bit_8_gptq
 }  // namespace gptq
 }  // namespace aphrodite
 
-#endif
+#endif

+ 0 - 270
kernels/quantization/kvcache/quant_utils.cuh

@@ -1,270 +0,0 @@
-#pragma once
-
-#include <assert.h>
-#include <stdint.h>
-#include <float.h>
-#include <type_traits>
-#include "../../attention/attention_dtypes.h"
-#include "../../attention/dtype_float32.cuh"
-#include "../../attention/dtype_float16.cuh"
-#include "../../attention/dtype_bfloat16.cuh"
-
-using namespace aphrodite;
-
-
-template<typename Tout, typename Tin>
-__inline__ __device__ Tout vec_conversion(const Tin& x)
-{
-    return x;
-}
-
-
-// fp8 -> half
-template<>
-__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
-{
-    __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
-    return res.x;
-}
-
-// fp8x2 -> half2
-template<>
-__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
-{
-    union {
-        uint16_t u16[2];
-        uint32_t u32;
-    } tmp;
-    __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
-    tmp.u16[0] = res.x;
-    tmp.u16[1] = res.y;
-    return tmp.u32;
-}
-
-// fp8x4 -> half2x2
-template<>
-__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
-{
-    union {
-        uint2    u32x2;
-        uint32_t u32[2];
-    } tmp;
-    tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
-    tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
-    return tmp.u32x2;
-}
-
-// fp8x8 -> half2x4
-template<>
-__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
-{
-    union {
-        uint4 u64x2;
-        uint2 u64[2];
-    } tmp;
-    tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
-    tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
-    return tmp.u64x2;
-}
-
-// fp8 -> __nv_bfloat16
-template<>
-__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
-{
-    // Note there is no direct convert function from fp8 to bf16.
-    // fp8 -> half
-    __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
-    // half -> float -> bf16
-    float tmp = half_to_float(res.x);
-    return __float2bfloat16(tmp);
-}
-
-// fp8x2 -> __nv_bfloat162
-template<>
-__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
-{
-    __nv_bfloat162 res;
-    res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
-    res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
-    return res;
-}
-
-// fp8x4 -> bf16_4_t
-template<>
-__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
-{
-    bf16_4_t res;
-    res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
-    res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
-    return res;
-}
-
-// fp8x8 -> bf16_8_t
-template<>
-__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
-{
-    bf16_4_t tmp1, tmp2;
-    tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
-    tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
-    bf16_8_t res;
-    res.x = tmp1.x;
-    res.y = tmp1.y;
-    res.z = tmp2.x;
-    res.w = tmp2.y;
-    return res;
-}
-
-// fp8 -> float
-template<>
-__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
-{
-    // fp8 -> half
-    uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
-    // half -> float
-    return half_to_float(tmp);
-}
-
-// fp8x2 -> float2
-template<>
-__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
-{
-    // fp8x2 -> half2
-    uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
-    // half2 -> float2
-    return half2_to_float2(tmp);
-}
-
-// fp8x4 -> float4
-template<>
-__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
-{
-    Float4_ res;
-    res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
-    res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
-    return res;
-}
-
-// fp8x8 -> float8
-template<>
-__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
-{
-    Float4_ tmp1, tmp2;
-    tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
-    tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
-    Float8_ res;
-    res.x = tmp1.x;
-    res.y = tmp1.y;
-    res.z = tmp2.x;
-    res.w = tmp2.y;
-    return res;
-}
-
-
-// half -> fp8
-template<>
-__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
-{
-    __half_raw tmp;
-    tmp.x = a;
-    __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
-    return (uint8_t)res;
-}
-
-// bf16 -> fp8
-template<>
-__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
-{
-    __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
-    return (uint8_t)res;
-}
-
-// float -> fp8
-template<>
-__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
-{
-    __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
-    return (uint8_t)res;
-}
-
-// fp8x4 -> float4
-template<>
-__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
-{
-    Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
-    float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
-    return res;
-}
-
-
-template<>
-__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
-{
-    union {
-        half2    float16;
-        uint32_t uint32;
-    };
-
-    float16 = __float22half2_rn(a);
-    return uint32;
-}
-
-template<>
-__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
-{
-    uint2  b;
-    float2 val;
-    val.x = a.x.x;
-    val.y = a.x.y;
-    b.x   = vec_conversion<uint32_t, float2>(val);
-
-    val.x = a.y.x;
-    val.y = a.y.y;
-    b.y   = vec_conversion<uint32_t, float2>(val);
-
-    return b;
-}
-
-template<>
-__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
-{
-    float4 b;
-    b.x = a.x.x;
-    b.y = a.x.y;
-    b.z = a.y.x;
-    b.w = a.y.y;
-    return b;
-}
-
-template<>
-__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
-{
-    uint4 b;
-    b.x = vec_conversion<uint32_t, float2>(a.x);
-    b.y = vec_conversion<uint32_t, float2>(a.y);
-    b.z = vec_conversion<uint32_t, float2>(a.z);
-    b.w = vec_conversion<uint32_t, float2>(a.w);
-    return b;
-}
-
-template<>
-__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
-    return __float22bfloat162_rn(a);
-}
-
-template<>
-__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
-    bf16_4_t b;
-    b.x = vec_conversion<__nv_bfloat162, float2>(a.x);
-    b.y = vec_conversion<__nv_bfloat162, float2>(a.y);
-    return b;
-}
-
-template<>
-__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
-    bf16_8_t b;
-    b.x = vec_conversion<__nv_bfloat162, float2>(a.x);
-    b.y = vec_conversion<__nv_bfloat162, float2>(a.y);
-    b.z = vec_conversion<__nv_bfloat162, float2>(a.z);
-    b.w = vec_conversion<__nv_bfloat162, float2>(a.w);
-    return b;
-}
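
The deleted quant_utils.cuh built every KV-cache conversion out of a handful of E5M2 intrinsics from cuda_fp8.h. The core round-trip those routines wrapped, reduced to two helpers (names here are illustrative, not from the codebase):

    // Requires CUDA 11.8+ for <cuda_fp8.h>; E5M2 format, as in the removed code.
    #include <cuda_fp8.h>
    #include <cstdint>

    __device__ __forceinline__ uint8_t half_bits_to_fp8_e5m2(uint16_t h) {
      __half_raw hr;
      hr.x = h;
      return (uint8_t)__nv_cvt_halfraw_to_fp8(hr, __NV_SATFINITE, __NV_E5M2);
    }

    __device__ __forceinline__ uint16_t fp8_e5m2_to_half_bits(uint8_t b) {
      return __nv_cvt_fp8_to_halfraw(b, __NV_E5M2).x;
    }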

+ 2 - 1
kernels/quantization/squeezellm/quant_cuda_kernel.cu

@@ -200,6 +200,7 @@ void squeezellm_gemm(
     (width + BLOCKWIDTH - 1) / BLOCKWIDTH
   );
   dim3 threads(BLOCKWIDTH);
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   aphrodite::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
@@ -221,4 +222,4 @@ void squeezellm_gemm(
 }
 
 #undef BLOCKWIDTH
-#undef BLOCKHEIGHT4
+#undef BLOCKHEIGHT4

+ 0 - 1
setup.py

@@ -213,7 +213,6 @@ elif _is_hip():
 ext_modules = []
 
 aphrodite_extension_sources = [
-    "kernels/misc_kernels.cu",
     "kernels/cache_kernels.cu",
     "kernels/attention/attention_kernels.cu",
     "kernels/pos_encoding_kernels.cu",