
chore: rename `task_handler` to `worker` (#985)

AlpinDale 2 months ago
commit 3bb0f07461
58 changed files with 414 additions and 467 deletions
  1. aphrodite/attention/backends/abstract.py (+1 -1)
  2. aphrodite/attention/backends/flash_attn.py (+1 -1)
  3. aphrodite/attention/backends/flashinfer.py (+1 -1)
  4. aphrodite/attention/backends/placeholder_attn.py (+1 -1)
  5. aphrodite/attention/backends/utils.py (+2 -2)
  6. aphrodite/executor/cpu_executor.py (+2 -2)
  7. aphrodite/executor/gpu_executor.py (+3 -3)
  8. aphrodite/executor/multiproc_worker_utils.py (+1 -1)
  9. aphrodite/executor/neuron_executor.py (+1 -1)
  10. aphrodite/executor/openvino_executor.py (+1 -1)
  11. aphrodite/executor/ray_tpu_executor.py (+1 -1)
  12. aphrodite/executor/ray_utils.py (+2 -2)
  13. aphrodite/executor/tpu_executor.py (+1 -1)
  14. aphrodite/executor/xpu_executor.py (+2 -2)
  15. aphrodite/modeling/models/jamba.py (+2 -2)
  16. aphrodite/modeling/models/mamba.py (+2 -2)
  17. aphrodite/spec_decode/batch_expansion.py (+1 -1)
  18. aphrodite/spec_decode/draft_model_runner.py (+1 -1)
  19. aphrodite/spec_decode/medusa_worker.py (+1 -1)
  20. aphrodite/spec_decode/multi_step_worker.py (+1 -1)
  21. aphrodite/spec_decode/proposer_worker_base.py (+1 -1)
  22. aphrodite/spec_decode/spec_decode_worker.py (+2 -3)
  23. aphrodite/spec_decode/target_model_runner.py (+1 -1)
  24. aphrodite/worker/__init__.py (+0 -0)
  25. aphrodite/worker/cache_engine.py (+0 -0)
  26. aphrodite/worker/cpu_model_runner.py (+1 -1)
  27. aphrodite/worker/cpu_worker.py (+2 -2)
  28. aphrodite/worker/embedding_model_runner.py (+1 -1)
  29. aphrodite/worker/enc_dec_model_runner.py (+3 -3)
  30. aphrodite/worker/model_runner.py (+1 -1)
  31. aphrodite/worker/model_runner_base.py (+0 -0)
  32. aphrodite/worker/multi_step_model_runner.py (+2 -2)
  33. aphrodite/worker/multi_step_worker.py (+3 -3)
  34. aphrodite/worker/neuron_model_runner.py (+1 -1)
  35. aphrodite/worker/neuron_worker.py (+2 -2)
  36. aphrodite/worker/openvino_model_runner.py (+0 -0)
  37. aphrodite/worker/openvino_worker.py (+2 -2)
  38. aphrodite/worker/tpu_model_runner.py (+1 -1)
  39. aphrodite/worker/tpu_worker.py (+2 -2)
  40. aphrodite/worker/utils.py (+1 -1)
  41. aphrodite/worker/worker.py (+5 -5)
  42. aphrodite/worker/worker_base.py (+1 -1)
  43. aphrodite/worker/xpu_model_runner.py (+2 -2)
  44. aphrodite/worker/xpu_worker.py (+4 -4)
  45. docs/pages/usage/debugging.md (+1 -1)
  46. kernels/attention/attention_kernels.cu (+1 -1)
  47. kernels/backup/attention_kernels.cu (+331 -380)
  48. tests/engine/test_multiproc_workers.py (+1 -1)
  49. tests/lora/conftest.py (+1 -1)
  50. tests/lora/test_long_context.py (+1 -1)
  51. tests/lora/test_worker.py (+1 -1)
  52. tests/models/test_jamba.py (+1 -1)
  53. tests/spec_decode/test_multi_step_worker.py (+1 -1)
  54. tests/spec_decode/utils.py (+3 -3)
  55. tests/worker/test_encoder_decoder_model_runner.py (+1 -2)
  56. tests/worker/test_model_input.py (+3 -4)
  57. tests/worker/test_model_runner.py (+1 -2)
  58. tests/worker/test_swap.py (+1 -1)
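
Every `aphrodite.task_handler.*` import path now lives under `aphrodite.worker.*`; the class and function names themselves are unchanged. A minimal, hypothetical migration sketch for downstream code (the two imports below use paths that appear in this diff):

```python
# Hypothetical downstream update; only the package path changes in this commit.

# Before this commit:
# from aphrodite.task_handler.worker import Worker
# from aphrodite.task_handler.worker_base import WorkerWrapperBase

# After this commit:
from aphrodite.worker.worker import Worker
from aphrodite.worker.worker_base import WorkerWrapperBase
```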

+ 1 - 1
aphrodite/attention/backends/abstract.py

@@ -8,7 +8,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
 import torch
 
 if TYPE_CHECKING:
-    from aphrodite.task_handler.model_runner_base import (
+    from aphrodite.worker.model_runner_base import (
         ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase)
 
 

+ 1 - 1
aphrodite/attention/backends/flash_attn.py

@@ -18,7 +18,7 @@ from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
 from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad
 
 if TYPE_CHECKING:
-    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
+    from aphrodite.worker.model_runner import ModelInputForGPUBuilder
 
 from aphrodite_flash_attn import (
     flash_attn_varlen_func as _flash_attn_varlen_func)

+ 1 - 1
aphrodite/attention/backends/flashinfer.py

@@ -33,7 +33,7 @@ from aphrodite.common.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                                     make_tensor_with_pad)
 
 if TYPE_CHECKING:
-    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
+    from aphrodite.worker.model_runner import ModelInputForGPUBuilder
 
 
 class FlashInferBackend(AttentionBackend):

+ 1 - 1
aphrodite/attention/backends/placeholder_attn.py

@@ -10,7 +10,7 @@ from aphrodite.attention.backends.abstract import (AttentionBackend,
 from aphrodite.attention.backends.utils import CommonAttentionState
 
 if TYPE_CHECKING:
-    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
+    from aphrodite.worker.model_runner import ModelInputForGPUBuilder
 
 # Placeholder attention backend for models like Mamba and embedding models that
 # lack attention.

+ 2 - 2
aphrodite/attention/backends/utils.py

@@ -9,7 +9,7 @@ from aphrodite.attention import (AttentionMetadata, AttentionMetadataBuilder,
 from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad
 
 if TYPE_CHECKING:
-    from aphrodite.task_handler.model_runner_base import ModelRunnerBase
+    from aphrodite.worker.model_runner_base import ModelRunnerBase
 
 # Error string(s) for encoder/decoder
 # unsupported attention scenarios
@@ -23,7 +23,7 @@ PAD_SLOT_ID = -1
 _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256
 
 if TYPE_CHECKING:
-    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
+    from aphrodite.worker.model_runner import ModelInputForGPUBuilder
 
 
 def is_block_tables_empty(block_tables: Union[None, Dict]):

+ 2 - 2
aphrodite/executor/cpu_executor.py

@@ -18,7 +18,7 @@ from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
 from aphrodite.lora.request import LoRARequest
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
-from aphrodite.task_handler.worker_base import WorkerWrapperBase
+from aphrodite.worker.worker_base import WorkerWrapperBase
 
 
 class CPUExecutor(ExecutorBase):
@@ -121,7 +121,7 @@ class CPUExecutor(ExecutorBase):
         local_rank: int = 0,
         rank: int = 0,
     ):
-        worker_module_name = "aphrodite.task_handler.cpu_worker"
+        worker_module_name = "aphrodite.worker.cpu_worker"
         worker_class_name = "CPUWorker"
 
         wrapper = WorkerWrapperBase(

+ 3 - 3
aphrodite/executor/gpu_executor.py

@@ -9,7 +9,7 @@ from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
-from aphrodite.task_handler.worker_base import WorkerBase, WorkerWrapperBase
+from aphrodite.worker.worker_base import WorkerBase, WorkerWrapperBase
 
 
 def create_worker(worker_module_name: str, worker_class_name: str,
@@ -68,13 +68,13 @@ class GPUExecutor(ExecutorBase):
             self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
         worker_class_fn = None
         if self.scheduler_config.is_multi_step:
-            worker_module_name = "aphrodite.task_handler.multi_step_worker"
+            worker_module_name = "aphrodite.worker.multi_step_worker"
             worker_class_name = "MultiStepWorker"
         elif self.speculative_config:
             worker_module_name = "aphrodite.spec_decode.spec_decode_worker"
             worker_class_name = "create_spec_worker"
         else:
-            worker_module_name = "aphrodite.task_handler.worker"
+            worker_module_name = "aphrodite.worker.worker"
             worker_class_name = "Worker"
         return (worker_module_name, worker_class_name, worker_class_fn)
 

+ 1 - 1
aphrodite/executor/multiproc_worker_utils.py

@@ -142,7 +142,7 @@ class WorkerMonitor(threading.Thread):
 
 
 class ProcessWorkerWrapper:
-    """Local process wrapper for aphrodite.task_handler.Worker,
+    """Local process wrapper for aphrodite.worker.Worker,
     for handling single-node multi-GPU tensor parallel."""
 
     def __init__(self, result_handler: ResultHandler,

+ 1 - 1
aphrodite/executor/neuron_executor.py

@@ -22,7 +22,7 @@ class NeuronExecutor(ExecutorBase):
         self._init_worker()
 
     def _init_worker(self):
-        from aphrodite.task_handler.neuron_worker import NeuronWorker
+        from aphrodite.worker.neuron_worker import NeuronWorker
         distributed_init_method = get_distributed_init_method(
             get_ip(), get_open_port())
         self.driver_worker = NeuronWorker(

+ 1 - 1
aphrodite/executor/openvino_executor.py

@@ -33,7 +33,7 @@ class OpenVINOExecutor(ExecutorBase):
         self._init_worker()
 
     def _init_worker(self):
-        from aphrodite.task_handler.openvino_worker import OpenVINOWorker
+        from aphrodite.worker.openvino_worker import OpenVINOWorker
 
         assert (
             self.parallel_config.world_size == 1

+ 1 - 1
aphrodite/executor/ray_tpu_executor.py

@@ -70,7 +70,7 @@ class RayTPUExecutor(TPUExecutor):
             )
 
             assert self.speculative_config is None
-            worker_module_name = "aphrodite.task_handler.tpu_worker"
+            worker_module_name = "aphrodite.worker.tpu_worker"
             worker_class_name = "TPUWorker"
 
             # GKE does not fetch environment information from metadata server

+ 2 - 2
aphrodite/executor/ray_utils.py

@@ -11,7 +11,7 @@ from aphrodite.common.sequence import ExecuteModelRequest, IntermediateTensors
 from aphrodite.common.utils import get_ip, is_hip, is_xpu
 from aphrodite.executor.msgspec_utils import decode_hook, encode_hook
 from aphrodite.platforms import current_platform
-from aphrodite.task_handler.worker_base import WorkerWrapperBase
+from aphrodite.worker.worker_base import WorkerWrapperBase
 
 PG_WAIT_TIMEOUT = 1800
 
@@ -22,7 +22,7 @@ try:
     from ray.util.placement_group import PlacementGroup
 
     class RayWorkerWrapper(WorkerWrapperBase):
-        """Ray wrapper for aphrodite.task_handler.Worker, allowing Worker to be
+        """Ray wrapper for aphrodite.worker.Worker, allowing Worker to be
         lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
 
         def __init__(self, *args, **kwargs) -> None:

+ 1 - 1
aphrodite/executor/tpu_executor.py

@@ -60,7 +60,7 @@ class TPUExecutor(ExecutorBase):
         rank: int = 0,
         distributed_init_method: Optional[str] = None,
     ):
-        from aphrodite.task_handler.tpu_worker import TPUWorker
+        from aphrodite.worker.tpu_worker import TPUWorker
 
         worker = TPUWorker(**self._get_worker_kwargs(local_rank, rank,
                                                      distributed_init_method))

+ 2 - 2
aphrodite/executor/xpu_executor.py

@@ -12,7 +12,7 @@ from aphrodite.common.utils import make_async
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.gpu_executor import GPUExecutor
 from aphrodite.modeling.layers.sampler import SamplerOutput
-from aphrodite.task_handler.worker_base import WorkerBase
+from aphrodite.worker.worker_base import WorkerBase
 
 
 class XPUExecutor(GPUExecutor):
@@ -56,7 +56,7 @@ class XPUExecutor(GPUExecutor):
             raise NotImplementedError(
                 "XPU does not support speculative decoding")
         else:
-            worker_module_name = "aphrodite.task_handler.xpu_worker"
+            worker_module_name = "aphrodite.worker.xpu_worker"
             worker_class_name = "XPUWorker"
         return (worker_module_name, worker_class_name, worker_class_fn)
 

+ 2 - 2
aphrodite/modeling/models/jamba.py

@@ -37,8 +37,8 @@ from aphrodite.modeling.models.mamba_cache import MambaCacheManager
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
-from aphrodite.task_handler.model_runner import (_BATCH_SIZES_TO_CAPTURE,
-                                                 _get_graph_batch_size)
+from aphrodite.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
+                                           _get_graph_batch_size)
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 

+ 2 - 2
aphrodite/modeling/models/mamba.py

@@ -32,8 +32,8 @@ from aphrodite.modeling.models.mamba_cache import MambaCacheManager
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
-from aphrodite.task_handler.model_runner import (_BATCH_SIZES_TO_CAPTURE,
-                                                 _get_graph_batch_size)
+from aphrodite.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
+                                           _get_graph_batch_size)
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 

+ 1 - 1
aphrodite/spec_decode/batch_expansion.py

@@ -13,7 +13,7 @@ from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
                                               SpeculativeScorer,
                                               SpeculativeScores)
 from aphrodite.spec_decode.util import nvtx_range, split_batch_by_proposal_len
-from aphrodite.task_handler.worker_base import WorkerBase
+from aphrodite.worker.worker_base import WorkerBase
 
 SeqId = int
 TargetSeqId = int

+ 1 - 1
aphrodite/spec_decode/draft_model_runner.py

@@ -18,7 +18,7 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
 from aphrodite.common.sequence import ExecuteModelRequest, IntermediateTensors
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.multimodal import MultiModalInputs
-from aphrodite.task_handler.model_runner import (
+from aphrodite.worker.model_runner import (
     ModelInputForGPUWithSamplingMetadata, ModelRunner)
 
 # A flag to enable debug prints for the updated input tensors

+ 1 - 1
aphrodite/spec_decode/medusa_worker.py

@@ -10,7 +10,7 @@ from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.spec_decode.interfaces import SpeculativeProposals
 from aphrodite.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from aphrodite.spec_decode.top1_proposer import Top1Proposer
-from aphrodite.task_handler.worker import Worker
+from aphrodite.worker.worker import Worker
 
 
 class MedusaWorker(NonLLMProposerWorkerBase, Worker):

+ 1 - 1
aphrodite/spec_decode/multi_step_worker.py

@@ -12,7 +12,7 @@ from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
                                               SpeculativeProposer)
 from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase
 from aphrodite.spec_decode.top1_proposer import Top1Proposer
-from aphrodite.task_handler.worker import Worker
+from aphrodite.worker.worker import Worker
 
 
 class MultiStepWorker(Worker, ProposerWorkerBase):

+ 1 - 1
aphrodite/spec_decode/proposer_worker_base.py

@@ -4,7 +4,7 @@ from typing import List, Optional, Set, Tuple
 from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.spec_decode.interfaces import SpeculativeProposer
-from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase
+from aphrodite.worker.worker_base import LoraNotSupportedWorkerBase
 
 
 class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):

+ 2 - 3
aphrodite/spec_decode/spec_decode_worker.py

@@ -35,9 +35,8 @@ from aphrodite.spec_decode.util import (Timer, create_sequence_group_output,
                                         get_all_num_logprobs,
                                         get_sampled_token_logprobs, nvtx_range,
                                         split_batch_by_proposal_len)
-from aphrodite.task_handler.worker import Worker
-from aphrodite.task_handler.worker_base import (LoraNotSupportedWorkerBase,
-                                                WorkerBase)
+from aphrodite.worker.worker import Worker
+from aphrodite.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
 
 
 def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":

+ 1 - 1
aphrodite/spec_decode/target_model_runner.py

@@ -4,7 +4,7 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig)
 from aphrodite.common.sequence import SequenceGroupMetadata
-from aphrodite.task_handler.model_runner import (
+from aphrodite.worker.model_runner import (
     ModelInputForGPUWithSamplingMetadata, ModelRunner)
 
 

+ 0 - 0
aphrodite/task_handler/__init__.py → aphrodite/worker/__init__.py


+ 0 - 0
aphrodite/task_handler/cache_engine.py → aphrodite/worker/cache_engine.py


+ 1 - 1
aphrodite/task_handler/cpu_model_runner.py → aphrodite/worker/cpu_model_runner.py

@@ -16,7 +16,7 @@ from aphrodite.modeling.model_loader import get_model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                   MultiModalInputs)
-from aphrodite.task_handler.model_runner_base import (
+from aphrodite.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
     _add_attn_metadata_broadcastable_dict,
     _add_sampling_metadata_broadcastable_dict,

+ 2 - 2
aphrodite/task_handler/cpu_worker.py → aphrodite/worker/cpu_worker.py

@@ -14,8 +14,8 @@ from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
 from aphrodite.distributed import (ensure_model_parallel_initialized,
                                    init_distributed_environment)
 from aphrodite.modeling import set_random_seed
-from aphrodite.task_handler.cpu_model_runner import CPUModelRunner
-from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
+from aphrodite.worker.cpu_model_runner import CPUModelRunner
+from aphrodite.worker.worker_base import (LocalOrDistributedWorkerBase,
                                                 LoraNotSupportedWorkerBase,
                                                 WorkerInput)
 

+ 1 - 1
aphrodite/task_handler/embedding_model_runner.py → aphrodite/worker/embedding_model_runner.py

@@ -11,7 +11,7 @@ from aphrodite.common.sequence import (IntermediateTensors, PoolerOutput,
                                        SequenceData, SequenceGroupMetadata)
 from aphrodite.modeling.pooling_metadata import PoolingMetadata
 from aphrodite.multimodal import MultiModalInputs
-from aphrodite.task_handler.model_runner import (GPUModelRunnerBase,
+from aphrodite.worker.model_runner import (GPUModelRunnerBase,
                                                  ModelInputForGPU,
                                                  ModelInputForGPUBuilder)
 

+ 3 - 3
aphrodite/task_handler/enc_dec_model_runner.py → aphrodite/worker/enc_dec_model_runner.py

@@ -24,13 +24,13 @@ from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from aphrodite.task_handler.model_runner import (
+from aphrodite.worker.model_runner import (
     GPUModelRunnerBase, ModelInputForGPUBuilder,
     ModelInputForGPUWithSamplingMetadata)
-from aphrodite.task_handler.model_runner_base import (
+from aphrodite.worker.model_runner_base import (
     _add_attn_metadata_broadcastable_dict,
     _add_sampling_metadata_broadcastable_dict)
-from aphrodite.task_handler.utils import assert_enc_dec_mr_supported_scenario
+from aphrodite.worker.utils import assert_enc_dec_mr_supported_scenario
 
 
 @dataclasses.dataclass(frozen=True)

+ 1 - 1
aphrodite/task_handler/model_runner.py → aphrodite/worker/model_runner.py

@@ -51,7 +51,7 @@ from aphrodite.prompt_adapter.layers import PromptAdapterMapping
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 from aphrodite.prompt_adapter.worker_manager import (
     LRUCacheWorkerPromptAdapterManager)
-from aphrodite.task_handler.model_runner_base import (
+from aphrodite.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
     _add_attn_metadata_broadcastable_dict,
     _add_sampling_metadata_broadcastable_dict,

+ 0 - 0
aphrodite/task_handler/model_runner_base.py → aphrodite/worker/model_runner_base.py


+ 2 - 2
aphrodite/task_handler/multi_step_model_runner.py → aphrodite/worker/multi_step_model_runner.py

@@ -25,9 +25,9 @@ from aphrodite.modeling.layers.sampler import (PromptLogprobs, SampleLogprobs,
                                                get_logprobs,
                                                get_pythonized_sample_results)
 from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
-from aphrodite.task_handler.model_runner import (
+from aphrodite.worker.model_runner import (
     GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata)
-from aphrodite.task_handler.model_runner_base import (
+from aphrodite.worker.model_runner_base import (
     BroadcastableModelInput, _init_attn_metadata_from_tensor_dict,
     _init_frozen_model_input_from_tensor_dict,
     _init_sampling_metadata_from_tensor_dict)

+ 3 - 3
aphrodite/task_handler/multi_step_worker.py → aphrodite/worker/multi_step_worker.py

@@ -7,10 +7,10 @@ import torch
 from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.distributed import broadcast_tensor_dict, get_pp_group
 from aphrodite.modeling.layers.sampler import SamplerOutput
-from aphrodite.task_handler.model_runner_base import BroadcastableModelInput
-from aphrodite.task_handler.multi_step_model_runner import (
+from aphrodite.worker.model_runner_base import BroadcastableModelInput
+from aphrodite.worker.multi_step_model_runner import (
     MultiStepModelRunner, StatefulModelInput)
-from aphrodite.task_handler.worker import Worker, WorkerInput
+from aphrodite.worker.worker import Worker, WorkerInput
 
 
 @dataclass

+ 1 - 1
aphrodite/task_handler/neuron_model_runner.py → aphrodite/worker/neuron_model_runner.py

@@ -16,7 +16,7 @@ from aphrodite.modeling.model_loader.neuron import get_neuron_model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                   MultiModalInputs)
-from aphrodite.task_handler.model_runner_base import (ModelRunnerBase,
+from aphrodite.worker.model_runner_base import (ModelRunnerBase,
                                                       ModelRunnerInputBase)
 
 if TYPE_CHECKING:

+ 2 - 2
aphrodite/task_handler/neuron_worker.py → aphrodite/worker/neuron_worker.py

@@ -10,8 +10,8 @@ from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.distributed import (ensure_model_parallel_initialized,
                                    init_distributed_environment)
 from aphrodite.modeling import set_random_seed
-from aphrodite.task_handler.neuron_model_runner import NeuronModelRunner
-from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
+from aphrodite.worker.neuron_model_runner import NeuronModelRunner
+from aphrodite.worker.worker_base import (LocalOrDistributedWorkerBase,
                                                 LoraNotSupportedWorkerBase,
                                                 WorkerInput)
 

+ 0 - 0
aphrodite/task_handler/openvino_model_runner.py → aphrodite/worker/openvino_model_runner.py


+ 2 - 2
aphrodite/task_handler/openvino_worker.py → aphrodite/worker/openvino_worker.py

@@ -15,8 +15,8 @@ from aphrodite.distributed import (broadcast_tensor_dict,
                                    init_distributed_environment)
 from aphrodite.modeling import set_random_seed
 from aphrodite.modeling.layers.sampler import SamplerOutput
-from aphrodite.task_handler.openvino_model_runner import OpenVINOModelRunner
-from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase
+from aphrodite.worker.openvino_model_runner import OpenVINOModelRunner
+from aphrodite.worker.worker_base import LoraNotSupportedWorkerBase
 
 
 class OpenVINOCacheEngine:

+ 1 - 1
aphrodite/task_handler/tpu_model_runner.py → aphrodite/worker/tpu_model_runner.py

@@ -23,7 +23,7 @@ from aphrodite.compilation.wrapper import (
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader import get_model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.task_handler.model_runner_base import (
+from aphrodite.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
     _add_attn_metadata_broadcastable_dict,
     _init_attn_metadata_from_tensor_dict)

+ 2 - 2
aphrodite/task_handler/tpu_worker.py → aphrodite/worker/tpu_worker.py

@@ -14,8 +14,8 @@ from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
 from aphrodite.distributed import (ensure_model_parallel_initialized,
                                    init_distributed_environment)
 from aphrodite.modeling import set_random_seed
-from aphrodite.task_handler.tpu_model_runner import TPUModelRunner
-from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
+from aphrodite.worker.tpu_model_runner import TPUModelRunner
+from aphrodite.worker.worker_base import (LocalOrDistributedWorkerBase,
                                                 LoraNotSupportedWorkerBase,
                                                 WorkerInput)
 

+ 1 - 1
aphrodite/task_handler/utils.py → aphrodite/worker/utils.py

@@ -3,7 +3,7 @@ Worker-related helper functions.
 '''
 
 from aphrodite.common.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS
-from aphrodite.task_handler.model_runner import GPUModelRunnerBase
+from aphrodite.worker.model_runner import GPUModelRunnerBase
 
 
 def assert_enc_dec_mr_supported_scenario(

+ 5 - 5
aphrodite/task_handler/worker.py → aphrodite/worker/worker.py

@@ -26,12 +26,12 @@ from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
 from aphrodite.platforms import current_platform
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
-from aphrodite.task_handler.cache_engine import CacheEngine
-from aphrodite.task_handler.embedding_model_runner import EmbeddingModelRunner
-from aphrodite.task_handler.enc_dec_model_runner import (
+from aphrodite.worker.cache_engine import CacheEngine
+from aphrodite.worker.embedding_model_runner import EmbeddingModelRunner
+from aphrodite.worker.enc_dec_model_runner import (
     EncoderDecoderModelRunner)
-from aphrodite.task_handler.model_runner import GPUModelRunnerBase, ModelRunner
-from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
+from aphrodite.worker.model_runner import GPUModelRunnerBase, ModelRunner
+from aphrodite.worker.worker_base import (LocalOrDistributedWorkerBase,
                                                 WorkerInput)
 
 

+ 1 - 1
aphrodite/task_handler/worker_base.py → aphrodite/worker/worker_base.py

@@ -15,7 +15,7 @@ from aphrodite.distributed import (broadcast_tensor_dict, get_pp_group,
 from aphrodite.lora.request import LoRARequest
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.platforms import current_platform
-from aphrodite.task_handler.model_runner_base import (BroadcastableModelInput,
+from aphrodite.worker.model_runner_base import (BroadcastableModelInput,
                                                       ModelRunnerBase,
                                                       ModelRunnerInputBase)
 

+ 2 - 2
aphrodite/task_handler/xpu_model_runner.py → aphrodite/worker/xpu_model_runner.py

@@ -22,9 +22,9 @@ from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader import get_model
 from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                   MultiModalInputs, MultiModalRegistry)
-from aphrodite.task_handler.model_runner import (AttentionMetadata,
+from aphrodite.worker.model_runner import (AttentionMetadata,
                                                  SamplingMetadata)
-from aphrodite.task_handler.model_runner_base import (
+from aphrodite.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
     _add_attn_metadata_broadcastable_dict,
     _add_sampling_metadata_broadcastable_dict,

+ 4 - 4
aphrodite/task_handler/xpu_worker.py → aphrodite/worker/xpu_worker.py

@@ -17,10 +17,10 @@ from aphrodite.distributed import (ensure_model_parallel_initialized,
                                    init_distributed_environment)
 from aphrodite.distributed.parallel_state import get_pp_group
 from aphrodite.modeling import set_random_seed
-from aphrodite.task_handler.cache_engine import CacheEngine
-from aphrodite.task_handler.worker import Worker
-from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase
-from aphrodite.task_handler.xpu_model_runner import XPUModelRunner
+from aphrodite.worker.cache_engine import CacheEngine
+from aphrodite.worker.worker import Worker
+from aphrodite.worker.worker_base import LoraNotSupportedWorkerBase
+from aphrodite.worker.xpu_model_runner import XPUModelRunner
 
 
 class XPUWorker(LoraNotSupportedWorkerBase, Worker):

+ 1 - 1
docs/pages/usage/debugging.md

@@ -28,7 +28,7 @@ Set these environment variables:
 
 ***
 
-If your instance crashes and the error trace shows somewhere around `self.graph.replay()` in `aphrodite/task_handler/model_runner.py`, then it's very likely a CUDA error inside the CUDAGraph. To know the particular operation that causes the error, you can add `--enforce-eager` in the CLI, `- enforce_eager: true` in the YAML config, or `enforce_eager=True` in the `LLM` class. This will disable CUDAGraph optimization, which might make it easier to find the root cause.
+If your instance crashes and the error trace shows somewhere around `self.graph.replay()` in `aphrodite/worker/model_runner.py`, then it's very likely a CUDA error inside the CUDAGraph. To know the particular operation that causes the error, you can add `--enforce-eager` in the CLI, `- enforce_eager: true` in the YAML config, or `enforce_eager=True` in the `LLM` class. This will disable CUDAGraph optimization, which might make it easier to find the root cause.
 
 Here's some other common issues that might cause freezes:
 
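A minimal sketch of the `enforce_eager` workaround described in that paragraph, assuming the Aphrodite `LLM` entry point accepts the keyword exactly as the doc states (the model name is only a placeholder):

```python
from aphrodite import LLM

# Disable CUDAGraph capture so the failing operation surfaces in eager mode.
llm = LLM(model="facebook/opt-125m", enforce_eager=True)  # placeholder model
outputs = llm.generate(["Hello, my name is"])
print(outputs)
```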

+ 1 - 1
kernels/attention/attention_kernels.cu

@@ -728,7 +728,7 @@ void paged_attention_v1_launcher(
   int logits_size = padded_max_seq_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
   // Python-side check in
-  // aphrodite.task_handler.worker._check_if_can_support_max_seq_len Keep that
+  // aphrodite.worker.worker._check_if_can_support_max_seq_len Keep that
   // in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size);
 

+ 331 - 380
kernels/backup/attention_kernels.cu

@@ -1,5 +1,6 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
  * Copyright (c) 2023, The PygmalionAI team.
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
@@ -17,7 +18,7 @@
  * limitations under the License.
  */
 #ifdef USE_ROCM
-#include <hip/hip_runtime.h>
+  #include <hip/hip_runtime.h>
 #endif
 
 #include <torch/extension.h>
@@ -28,15 +29,15 @@
 #include "attention_utils.cuh"
 #include "../quantization/int8_kvcache/quant_utils.cuh"
 #ifdef ENABLE_FP8_E5M2
-#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
+  #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
 #endif
 
 #include <algorithm>
 
 #ifndef USE_ROCM
-#define WARP_SIZE 32
+  #define WARP_SIZE 32
 #else
-#define WARP_SIZE warpSize
+  #define WARP_SIZE warpSize
 #endif
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -47,12 +48,13 @@ enum kv_cache_dtype {
 #ifdef ENABLE_FP8_E5M2
   FP8_E5M2,
 #endif
-  INT8};
+  INT8
+};
 
 namespace aphrodite {
 
 // Utility function for attention softmax.
-template<int NUM_WARPS>
+template <int NUM_WARPS>
 inline __device__ float block_sum(float* red_smem, float sum) {
   // Decompose the thread index into warp / lane.
   int warp = threadIdx.x / WARP_SIZE;
@@ -89,34 +91,29 @@ inline __device__ float block_sum(float* red_smem, float sum) {
 
 // TODO: Merge the last two dimensions of the grid.
 // Grid: (num_heads, num_seqs, max_num_partitions).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int PARTITION_SIZE = 0> // Zero means no partitioning.
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, kv_cache_dtype KV_CACHE_DTYPE,
+          int PARTITION_SIZE = 0>  // Zero means no partitioning.
 __device__ void paged_attention_kernel(
-  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
-  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float k_scale = 1.0f,
-  const float k_zp = 0.0f,
-  const float v_scale = 1.0f,
-  const float v_zp = 0.0f) {
+    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
+    float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                     // max_num_partitions]
+    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
+                                 // head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float k_scale = 1.0f, const float k_zp = 0.0f,
+    const float v_scale = 1.0f, const float v_zp = 0.0f) {
   const int seq_idx = blockIdx.y;
   const int partition_idx = blockIdx.z;
   const int max_num_partitions = gridDim.z;
@@ -128,22 +125,29 @@ __device__ void paged_attention_kernel(
   }
 
   const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
-  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
+  const int num_blocks_per_partition =
+      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
 
   // [start_block_idx, end_block_idx) is the range of blocks to process.
-  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
-  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
+  const int start_block_idx =
+      USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+  const int end_block_idx =
+      MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
   const int num_blocks = end_block_idx - start_block_idx;
 
   // [start_token_idx, end_token_idx) is the range of tokens to process.
   const int start_token_idx = start_block_idx * BLOCK_SIZE;
-  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
+  const int end_token_idx =
+      MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
   const int num_tokens = end_token_idx - start_token_idx;
 
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
-  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
+  constexpr int NUM_THREAD_GROUPS =
+      NUM_THREADS / THREAD_GROUP_SIZE;  // Note: This assumes THREAD_GROUP_SIZE
+                                        // divides NUM_THREADS
   assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
-  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
+  constexpr int NUM_TOKENS_PER_THREAD_GROUP =
+      DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int thread_idx = threadIdx.x;
   const int warp_idx = thread_idx / WARP_SIZE;
@@ -153,13 +157,14 @@ __device__ void paged_attention_kernel(
   const int num_heads = gridDim.x;
   const int num_queries_per_kv = num_heads / num_kv_heads;
   const int kv_head_idx = head_idx / num_queries_per_kv;
-  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
+  const float alibi_slope =
+      alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
 
   // A vector type to store a part of a key or a query.
-  // The vector size is configured in such a way that the threads in a thread group
-  // fetch or compute 16 bytes at a time.
-  // For example, if the size of a thread group is 4 and the data type is half,
-  // then the vector size is 16 / (4 * sizeof(half)) == 2.
+  // The vector size is configured in such a way that the threads in a thread
+  // group fetch or compute 16 bytes at a time. For example, if the size of a
+  // thread group is 4 and the data type is half, then the vector size is 16 /
+  // (4 * sizeof(half)) == 2.
   constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
   using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
   using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
@@ -173,18 +178,21 @@ __device__ void paged_attention_kernel(
 
   // Load the query to registers.
   // Each thread in a thread group has a different part of the query.
-  // For example, if the the thread group size is 4, then the first thread in the group
-  // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
-  // th vectors of the query, and so on.
-  // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
+  // For example, if the the thread group size is 4, then the first thread in
+  // the group has 0, 4, 8, ... th vectors of the query, and the second thread
+  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE: Because q is
+  // split from a qkv tensor, it may not be contiguous.
   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
   __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
 #pragma unroll
-  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
+  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
+       i += NUM_THREAD_GROUPS) {
     const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
-    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+    q_vecs[thread_group_offset][i] =
+        *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
   }
-  __syncthreads(); // TODO: possible speedup if this is replaced with a memory wall right before we use q_vecs
+  __syncthreads();  // TODO: possible speedup if this is replaced with a memory
+                    // wall right before we use q_vecs
 
   // Memory planning.
   extern __shared__ char shared_mem[];
@@ -203,51 +211,60 @@ __device__ void paged_attention_kernel(
   // Each thread group in a warp fetches a key from the block, and computes
   // Each thread group in a warp fetches a key from the block, and computes
   // dot product with the query.
   // dot product with the query.
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
-  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
     // NOTE: The block number is stored in int32. However, we cast it to int64
     // NOTE: The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by large numbers
-    // (e.g., kv_block_stride).
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+    // because int32 can lead to overflow when this variable is multiplied by
+    // large numbers (e.g., kv_block_stride).
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
 
 
     // Load a key to registers.
     // Load a key to registers.
     // Each thread in a thread group has a different part of the key.
     // Each thread in a thread group has a different part of the key.
-    // For example, if the the thread group size is 4, then the first thread in the group
-    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
-    // vectors of the key, and so on.
+    // For example, if the the thread group size is 4, then the first thread in
+    // the group has 0, 4, 8, ... th vectors of the key, and the second thread
+    // has 1, 5, 9, ... th vectors of the key, and so on.
     for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
-      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+      const int physical_block_offset =
+          (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
       const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
       K_vec k_vecs[NUM_VECS_PER_THREAD];
 
 #pragma unroll
       for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-        const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
-                                       + kv_head_idx * kv_head_stride
-                                       + physical_block_offset * x;
+        const cache_t* k_ptr =
+            k_cache + physical_block_number * kv_block_stride +
+            kv_head_idx * kv_head_stride + physical_block_offset * x;
         const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
         const int offset1 = (vec_idx * VEC_SIZE) / x;
         const int offset2 = (vec_idx * VEC_SIZE) % x;
         if constexpr (KV_CACHE_DTYPE == INT8) {
-          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
           using Dequant_vec = typename FloatVec<Quant_vec>::Type;
           Dequant_vec k_vec_dequant = int8::dequant(k_vec_quant, k_scale, k_zp);
           k_vecs[j] = int8::vec_conversion<K_vec, Dequant_vec>(k_vec_dequant);
 #ifdef ENABLE_FP8_E5M2
         } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
-          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
           // Vector conversion from Quant_vec to K_vec.
-          k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
+          k_vecs[j] =
+              fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
 #endif
         } else {
-          k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          k_vecs[j] = *reinterpret_cast<const K_vec*>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
         }
       }
 
       // Compute dot product.
       // This includes a reduction across the threads in the same thread group.
-      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
+      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
+                             q_vecs[thread_group_offset], k_vecs);
       // Add the ALiBi bias if slopes are given.
-      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
+      qk +=
+          (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
 
       if (thread_group_offset == 0) {
         // Store the partial reductions to shared memory.
@@ -300,13 +317,12 @@ __device__ void paged_attention_kernel(
 
   // If partitioning is enabled, store the max logit and exp_sum.
   if (USE_PARTITIONING && thread_idx == 0) {
-    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
-                                       + head_idx * max_num_partitions
-                                       + partition_idx;
+    float* max_logits_ptr = max_logits +
+                            seq_idx * num_heads * max_num_partitions +
+                            head_idx * max_num_partitions + partition_idx;
     *max_logits_ptr = qk_max;
-    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
-                                   + head_idx * max_num_partitions
-                                   + partition_idx;
+    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
+                          head_idx * max_num_partitions + partition_idx;
     *exp_sums_ptr = exp_sum;
   }
 
@@ -319,7 +335,8 @@ __device__ void paged_attention_kernel(
 
   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
   constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
-  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
+  constexpr int NUM_ROWS_PER_THREAD =
+      DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
 
   // NOTE: We use FP32 for the accumulator for better accuracy.
   float accs[NUM_ROWS_PER_THREAD];
@@ -330,18 +347,21 @@ __device__ void paged_attention_kernel(
 
   scalar_t zero_value;
   zero(zero_value);
-  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
     // NOTE: The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by large numbers
-    // (e.g., kv_block_stride).
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+    // because int32 can lead to overflow when this variable is multiplied by
+    // large numbers (e.g., kv_block_stride).
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
     const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
     L_vec logits_vec;
-    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
+    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
+                                                           start_token_idx));
 
-    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
-                                   + kv_head_idx * kv_head_stride;
+    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
+                           kv_head_idx * kv_head_stride;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -350,26 +370,32 @@ __device__ void paged_attention_kernel(
         V_vec v_vec;
         if constexpr (KV_CACHE_DTYPE == INT8) {
           // dequant and conversion
-          V_quant_vec v_vec_quant = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+          V_quant_vec v_vec_quant =
+              *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
           using V_dequant_vec = typename FloatVec<V_quant_vec>::Type;
-          V_dequant_vec v_vec_dequant = int8::dequant(v_vec_quant, v_scale, v_zp);
+          V_dequant_vec v_vec_dequant =
+              int8::dequant(v_vec_quant, v_scale, v_zp);
           v_vec = int8::vec_conversion<V_vec, V_dequant_vec>(v_vec_dequant);
 #ifdef ENABLE_FP8_E5M2
         } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
-          V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+          V_quant_vec v_quant_vec =
+              *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
           // Vector conversion from V_quant_vec to V_vec.
-          v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
+          v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(
+              v_quant_vec);
 #endif
         } else {
           v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
         }
         if (block_idx == num_context_blocks - 1) {
           // NOTE: When v_vec contains the tokens that are out of the context,
-          // we should explicitly zero out the values since they may contain NaNs.
+          // we should explicitly zero out the values since they may contain
+          // NaNs.
           scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
 #pragma unroll
           for (int j = 0; j < V_VEC_SIZE; j++) {
-            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
+            v_vec_ptr[j] =
+                token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
           }
         }
         accs[i] += dot(logits_vec, v_vec);
@@ -426,9 +452,9 @@ __device__ void paged_attention_kernel(
 
   // Write the final output.
   if (warp_idx == 0) {
-    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                            + head_idx * max_num_partitions * HEAD_SIZE
-                            + partition_idx * HEAD_SIZE;
+    scalar_t* out_ptr =
+        out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -440,85 +466,77 @@ __device__ void paged_attention_kernel(
 }
 
 // Grid: (num_heads, num_seqs, 1).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  kv_cache_dtype KV_CACHE_DTYPE>
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS,
+          kv_cache_dtype KV_CACHE_DTYPE>
 __global__ void paged_attention_v1_kernel(
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE>(
-    /* exp_sums */ nullptr, /* max_logits */ nullptr,
-    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
-    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
+    scalar_t* __restrict__ out,           // [num_seqs, num_heads, head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float k_scale, const float k_zp, const float v_scale,
+    const float v_zp) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
+                         KV_CACHE_DTYPE>(
+      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
+      v_cache, num_kv_heads, scale, block_tables, context_lens,
+      max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
+      kv_head_stride, k_scale, k_zp, v_scale, v_zp);
 }
 
 // Grid: (num_heads, num_seqs, max_num_partitions).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int PARTITION_SIZE>
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, kv_cache_dtype KV_CACHE_DTYPE,
+          int PARTITION_SIZE>
 __global__ void paged_attention_v2_kernel(
-  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
-  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
-  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE, PARTITION_SIZE>(
-    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
-    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
-    q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
+    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
+    float* __restrict__ max_logits,       // [num_seqs, num_heads,
+                                          // max_num_partitions]
+    scalar_t* __restrict__ tmp_out,       // [num_seqs, num_heads,
+                                          // max_num_partitions, head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float k_scale, const float k_zp, const float v_scale,
+    const float v_zp) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
+                         KV_CACHE_DTYPE, PARTITION_SIZE>(
+      exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
+      block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
+      q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
 }
 
 // Grid: (num_heads, num_seqs).
-template<
-  typename scalar_t,
-  int HEAD_SIZE,
-  int NUM_THREADS,
-  int PARTITION_SIZE>
+template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
+          int PARTITION_SIZE>
 __global__ void paged_attention_v2_reduce_kernel(
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
-  const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
-  const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
-  const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_partitions) {
+    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
+    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                           // max_num_partitions, head_size]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_partitions) {
   const int num_heads = gridDim.x;
   const int head_idx = blockIdx.x;
   const int seq_idx = blockIdx.y;
@@ -526,9 +544,11 @@ __global__ void paged_attention_v2_reduce_kernel(
   const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
   if (num_partitions == 1) {
     // No need to reduce. Only copy tmp_out to out.
-    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
-    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                                          + head_idx * max_num_partitions * HEAD_SIZE;
+    scalar_t* out_ptr =
+        out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+    const scalar_t* tmp_out_ptr =
+        tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE;
     for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
       out_ptr[i] = tmp_out_ptr[i];
     }
@@ -547,8 +567,9 @@ __global__ void paged_attention_v2_reduce_kernel(
 
   // Load max logits to shared memory.
   float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
-  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
-                                           + head_idx * max_num_partitions;
+  const float* max_logits_ptr = max_logits +
+                                seq_idx * num_heads * max_num_partitions +
+                                head_idx * max_num_partitions;
   float max_logit = -FLT_MAX;
   for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
     const float l = max_logits_ptr[i];
@@ -577,9 +598,11 @@ __global__ void paged_attention_v2_reduce_kernel(
   max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);
 
   // Load rescaled exp sums to shared memory.
-  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
-  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
-                                       + head_idx * max_num_partitions;
+  float* shared_exp_sums =
+      reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+  const float* exp_sums_ptr = exp_sums +
+                              seq_idx * num_heads * max_num_partitions +
+                              head_idx * max_num_partitions;
   float global_exp_sum = 0.0f;
   for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
     float l = shared_max_logits[i];
@@ -592,67 +615,47 @@ __global__ void paged_attention_v2_reduce_kernel(
   const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
 
   // Aggregate tmp_out to out.
-  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                                        + head_idx * max_num_partitions * HEAD_SIZE;
-  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+  const scalar_t* tmp_out_ptr =
+      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+      head_idx * max_num_partitions * HEAD_SIZE;
+  scalar_t* out_ptr =
+      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
 #pragma unroll
   for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
     float acc = 0.0f;
     for (int j = 0; j < num_partitions; ++j) {
-      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
+      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
+             inv_global_exp_sum;
     }
     from_float(out_ptr[i], acc);
   }
 }
 
-} // namespace aphrodite
-
-#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                        \
-  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                        \
-    ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,    \
-      KV_CACHE_DTYPE>), shared_mem_size);                                                           \
-  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,              \
-  KV_CACHE_DTYPE><<<grid, block, shared_mem_size, stream>>>(                                        \
-    out_ptr,                                                                                        \
-    query_ptr,                                                                                      \
-    key_cache_ptr,                                                                                  \
-    value_cache_ptr,                                                                                \
-    num_kv_heads,                                                                                   \
-    scale,                                                                                          \
-    block_tables_ptr,                                                                               \
-    context_lens_ptr,                                                                               \
-    max_num_blocks_per_seq,                                                                         \
-    alibi_slopes_ptr,                                                                               \
-    q_stride,                                                                                       \
-    kv_block_stride,                                                                                \
-    kv_head_stride,                                                                                 \
-    k_scale,                                                                                        \
-    k_zp,                                                                                           \
-    v_scale,                                                                                        \
-    v_zp);
+}  // namespace aphrodite
+
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                 \
+  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                 \
+      ((void*)aphrodite::paged_attention_v1_kernel<                          \
+          T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE>),  \
+      shared_mem_size);                                                      \
+  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,    \
+                                       NUM_THREADS, KV_CACHE_DTYPE>          \
+      <<<grid, block, shared_mem_size, stream>>>(                            \
+          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads,  \
+          scale, block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
+          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,       \
+          k_scale, k_zp, v_scale, v_zp);
 
 // TODO: Tune NUM_THREADS.
-template<
-  typename T,
-  typename CACHE_T,
-  int BLOCK_SIZE,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int NUM_THREADS = 128>
+template <typename T, typename CACHE_T, int BLOCK_SIZE,
+          kv_cache_dtype KV_CACHE_DTYPE, int NUM_THREADS = 128>
 void paged_attention_v1_launcher(
-  torch::Tensor& out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  int num_kv_heads,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
+    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& context_lens,
+    int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const float k_scale, const float k_zp, const float v_scale,
+    const float v_zp) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -665,9 +668,10 @@ void paged_attention_v1_launcher(
   assert(head_size % thread_group_size == 0);
 
   // NOTE: alibi_slopes is optional.
-  const float* alibi_slopes_ptr = alibi_slopes ?
-    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-    : nullptr;
+  const float* alibi_slopes_ptr =
+      alibi_slopes
+          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+          : nullptr;
 
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
@@ -677,11 +681,13 @@ void paged_attention_v1_launcher(
   int* context_lens_ptr = context_lens.data_ptr<int>();
 
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
+  int padded_max_context_len =
+      DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
-  // Python-side check in aphrodite.task_handler.worker._check_if_can_support_max_seq_len
-  // Keep that in sync with the logic here!
+  // Python-side check in
+  // aphrodite.worker.worker._check_if_can_support_max_seq_len Keep that in sync
+  // with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size);
 
   dim3 grid(num_heads, num_seqs, 1);
@@ -718,56 +724,44 @@ void paged_attention_v1_launcher(
 
 #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)             \
   paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(       \
-    out,                                                                     \
-    query,                                                                   \
-    key_cache,                                                               \
-    value_cache,                                                             \
-    num_kv_heads,                                                            \
-    scale,                                                                   \
-    block_tables,                                                            \
-    context_lens,                                                            \
-    max_context_len,                                                         \
-    alibi_slopes,                                                            \
-    k_scale,                                                                 \
-    k_zp,                                                                    \
-    v_scale,                                                                 \
-    v_zp);
+      out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
+      context_lens, max_context_len, alibi_slopes, k_scale, k_zp, v_scale,   \
+      v_zp);
 
 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)       \
-  switch (block_size) {                                               \
-    case 8:                                                           \
-      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);                \
-      break;                                                          \
-    case 16:                                                          \
-      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);               \
-      break;                                                          \
-    case 32:                                                          \
-      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);               \
-      break;                                                          \
-    default:                                                          \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
-      break;                                                          \
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)   \
+  switch (block_size) {                                           \
+    case 8:                                                       \
+      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);            \
+      break;                                                      \
+    case 16:                                                      \
+      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);           \
+      break;                                                      \
+    case 32:                                                      \
+      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);           \
+      break;                                                      \
+    default:                                                      \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
+      break;                                                      \
   }
 
 void paged_attention_v1(
-  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
-  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
-  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  int num_kv_heads,               // [num_heads]
-  float scale,
-  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
-  torch::Tensor& context_lens,    // [num_seqs]
-  int block_size,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const std::string& kv_cache_dtype,
-  const float k_scale = 1.0f,
-  const float k_zp = 0.0f,
-  const float v_scale = 1.0f,
-  const float v_zp = 0.0f) {
+    torch::Tensor& out,    // [num_seqs, num_heads, head_size]
+    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
+    torch::Tensor&
+        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
+    torch::Tensor&
+        value_cache,   // [num_blocks, num_heads, head_size, block_size]
+    int num_kv_heads,  // [num_heads]
+    float scale,
+    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    torch::Tensor& context_lens,  // [num_seqs]
+    int block_size, int max_context_len,
+    const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, const float k_scale = 1.0f,
+    const float k_zp = 0.0f, const float v_scale = 1.0f,
+    const float v_zp = 0.0f) {
   if (kv_cache_dtype == "auto") {
     if (query.dtype() == at::ScalarType::Float) {
       CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
@@ -805,63 +799,33 @@ void paged_attention_v1(
   }
 }
 
-#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \
-  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,        \
-  KV_CACHE_DTYPE, PARTITION_SIZE>                                                             \
-  <<<grid, block, shared_mem_size, stream>>>(                                                 \
-    exp_sums_ptr,                                                                             \
-    max_logits_ptr,                                                                           \
-    tmp_out_ptr,                                                                              \
-    query_ptr,                                                                                \
-    key_cache_ptr,                                                                            \
-    value_cache_ptr,                                                                          \
-    num_kv_heads,                                                                             \
-    scale,                                                                                    \
-    block_tables_ptr,                                                                         \
-    context_lens_ptr,                                                                         \
-    max_num_blocks_per_seq,                                                                   \
-    alibi_slopes_ptr,                                                                         \
-    q_stride,                                                                                 \
-    kv_block_stride,                                                                          \
-    kv_head_stride,                                                                           \
-    k_scale,                                                                                  \
-    k_zp,                                                                                     \
-    v_scale,                                                                                  \
-    v_zp);                                                                                    \
-  aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>           \
-  <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                                   \
-    out_ptr,                                                                                  \
-    exp_sums_ptr,                                                                             \
-    max_logits_ptr,                                                                           \
-    tmp_out_ptr,                                                                              \
-    context_lens_ptr,                                                                         \
-    max_num_partitions);
-
-template<
-  typename T,
-  typename CACHE_T,
-  int BLOCK_SIZE,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int NUM_THREADS = 128,
-  int PARTITION_SIZE = 512>
+#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                   \
+  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,      \
+                                       NUM_THREADS, KV_CACHE_DTYPE,            \
+                                       PARTITION_SIZE>                         \
+      <<<grid, block, shared_mem_size, stream>>>(                              \
+          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
+          value_cache_ptr, num_kv_heads, scale, block_tables_ptr,              \
+          context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr,          \
+          q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale,   \
+          v_zp);                                                               \
+  aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,       \
+                                              PARTITION_SIZE>                  \
+      <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                \
+          out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,                  \
+          context_lens_ptr, max_num_partitions);
+
+template <typename T, typename CACHE_T, int BLOCK_SIZE,
+          kv_cache_dtype KV_CACHE_DTYPE, int NUM_THREADS = 128,
+          int PARTITION_SIZE = 512>
 void paged_attention_v2_launcher(
-  torch::Tensor& out,
-  torch::Tensor& exp_sums,
-  torch::Tensor& max_logits,
-  torch::Tensor& tmp_out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  int num_kv_heads,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& context_lens,
+    int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const float k_scale, const float k_zp, const float v_scale,
+    const float v_zp) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -874,9 +838,10 @@ void paged_attention_v2_launcher(
   assert(head_size % thread_group_size == 0);
 
   // NOTE: alibi_slopes is optional.
-  const float* alibi_slopes_ptr = alibi_slopes ?
-    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-    : nullptr;
+  const float* alibi_slopes_ptr =
+      alibi_slopes
+          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+          : nullptr;
 
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
@@ -931,64 +896,50 @@ void paged_attention_v2_launcher(
   }
 }
 
-#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)                 \
-  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(           \
-    out,                                                                         \
-    exp_sums,                                                                    \
-    max_logits,                                                                  \
-    tmp_out,                                                                     \
-    query,                                                                       \
-    key_cache,                                                                   \
-    value_cache,                                                                 \
-    num_kv_heads,                                                                \
-    scale,                                                                       \
-    block_tables,                                                                \
-    context_lens,                                                                \
-    max_context_len,                                                             \
-    alibi_slopes,                                                                \
-    k_scale,                                                                     \
-    k_zp,                                                                        \
-    v_scale,                                                                     \
-    v_zp);
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)         \
+  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(   \
+      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
+      num_kv_heads, scale, block_tables, context_lens, max_context_len,  \
+      alibi_slopes, k_scale, k_zp, v_scale, v_zp);
 
 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)             \
-  switch (block_size) {                                                     \
-    case 8:                                                                 \
-      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);                      \
-      break;                                                                \
-    case 16:                                                                \
-      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);                     \
-      break;                                                                \
-    case 32:                                                                \
-      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);                     \
-      break;                                                                \
-    default:                                                                \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size);           \
-      break;                                                                \
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)   \
+  switch (block_size) {                                           \
+    case 8:                                                       \
+      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);            \
+      break;                                                      \
+    case 16:                                                      \
+      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);           \
+      break;                                                      \
+    case 32:                                                      \
+      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);           \
+      break;                                                      \
+    default:                                                      \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
+      break;                                                      \
   }
 
 void paged_attention_v2(
-  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
-  torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
-  torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
-  torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
-  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
-  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  int num_kv_heads,               // [num_heads]
-  float scale,
-  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
-  torch::Tensor& context_lens,    // [num_seqs]
-  int block_size,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const std::string& kv_cache_dtype,
-  const float k_scale = 1.0f,
-  const float k_zp = 0.0f,
-  const float v_scale = 1.0f,
-  const float v_zp = 0.0f) {
+    torch::Tensor& out,         // [num_seqs, num_heads, head_size]
+    torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
+    torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
+    torch::Tensor&
+        tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
+    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
+    torch::Tensor&
+        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
+    torch::Tensor&
+        value_cache,   // [num_blocks, num_heads, head_size, block_size]
+    int num_kv_heads,  // [num_heads]
+    float scale,
+    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    torch::Tensor& context_lens,  // [num_seqs]
+    int block_size, int max_context_len,
+    const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, const float k_scale = 1.0f,
+    const float k_zp = 0.0f, const float v_scale = 1.0f,
+    const float v_zp = 0.0f) {
   if (kv_cache_dtype == "auto") {
     if (query.dtype() == at::ScalarType::Float) {
       CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, AUTO);

+ 1 - 1
tests/engine/test_multiproc_workers.py

@@ -12,7 +12,7 @@ from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
 
 
 class DummyWorker:
-    """Dummy version of aphrodite.task_handler.worker.Worker"""
+    """Dummy version of aphrodite.worker.worker.Worker"""
 
     def __init__(self, rank: int):
         self.rank = rank

+ 1 - 1
tests/lora/conftest.py

@@ -256,7 +256,7 @@ def llama_2_7b_engine_extra_embeddings():
                              device_config=device_config,
                              **kwargs)
 
-    with patch("aphrodite.task_handler.model_runner.get_model",
+    with patch("aphrodite.worker.model_runner.get_model",
                get_model_patched):
         engine = aphrodite.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
     yield engine.llm_engine

+ 1 - 1
tests/lora/test_long_context.py

@@ -123,7 +123,7 @@ def lora_llm(long_context_infos):
 def test_rotary_emb_replaced(dist_init):
     """Verify rotary emb in all the layers are replaced"""
     from aphrodite.engine.args_tools import EngineArgs
-    from aphrodite.task_handler.model_runner import ModelRunner
+    from aphrodite.worker.model_runner import ModelRunner
     engine_args = EngineArgs("meta-llama/Llama-2-7b-hf",
                              long_lora_scaling_factors=(4.0, ),
                              enable_lora=True)

+ 1 - 1
tests/lora/test_worker.py

@@ -8,7 +8,7 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      SchedulerConfig)
 from aphrodite.lora.models import LoRAMapping
 from aphrodite.lora.request import LoRARequest
-from aphrodite.task_handler.worker import Worker
+from aphrodite.worker.worker import Worker
 
 
 @patch.dict(os.environ, {"RANK": "0"})

+ 1 - 1
tests/models/test_jamba.py

@@ -1,6 +1,6 @@
 import pytest
 
-from aphrodite.task_handler.model_runner import _get_graph_batch_size
+from aphrodite.worker.model_runner import _get_graph_batch_size
 from tests.models.utils import check_outputs_equal
 
 MODELS = ["ai21labs/Jamba-tiny-random"]

+ 1 - 1
tests/spec_decode/test_multi_step_worker.py

@@ -12,7 +12,7 @@ from aphrodite.modeling.utils import set_random_seed
 from aphrodite.spec_decode.draft_model_runner import TP1DraftModelRunner
 from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
 from aphrodite.spec_decode.top1_proposer import Top1Proposer
-from aphrodite.task_handler.worker import Worker
+from aphrodite.worker.worker import Worker
 
 from .utils import (assert_logprobs_dict_allclose, create_batch,
                     create_seq_group_metadata_from_prompts, create_worker,

+ 3 - 3
tests/spec_decode/utils.py

@@ -17,9 +17,9 @@ from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.utils import set_random_seed
-from aphrodite.task_handler.cache_engine import CacheEngine
-from aphrodite.task_handler.model_runner import ModelRunner
-from aphrodite.task_handler.worker import Worker
+from aphrodite.worker.cache_engine import CacheEngine
+from aphrodite.worker.model_runner import ModelRunner
+from aphrodite.worker.worker import Worker
 
 T = TypeVar("T", bound=Worker)
 

+ 1 - 2
tests/worker/test_encoder_decoder_model_runner.py

@@ -9,8 +9,7 @@ from aphrodite.common.sequence import (SamplingParams, SequenceData,
 from aphrodite.common.utils import is_cpu
 from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.engine.args_tools import EngineArgs
-from aphrodite.task_handler.enc_dec_model_runner import (
-    EncoderDecoderModelRunner)
+from aphrodite.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 
 # CUDA graph scenarios to test
 #

+ 3 - 4
tests/worker/test_model_input.py

@@ -7,11 +7,10 @@ from aphrodite.attention import AttentionMetadata, AttentionMetadataBuilder
 from aphrodite.attention.backends.abstract import AttentionBackend
 from aphrodite.modeling.pooling_metadata import PoolingMetadata
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.task_handler.embedding_model_runner import (
+from aphrodite.worker.embedding_model_runner import (
     ModelInputForGPUWithPoolingMetadata)
-from aphrodite.task_handler.model_runner import (
-    ModelInputForGPUWithSamplingMetadata)
-from aphrodite.task_handler.multi_step_model_runner import StatefulModelInput
+from aphrodite.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+from aphrodite.worker.multi_step_model_runner import StatefulModelInput
 
 
 class MockAttentionBackend(AttentionBackend):

+ 1 - 2
tests/worker/test_model_runner.py

@@ -12,8 +12,7 @@ from aphrodite.distributed.parallel_state import (
     ensure_model_parallel_initialized, init_distributed_environment)
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.task_handler.model_runner import (ModelRunner,
-                                                 _get_graph_batch_size)
+from aphrodite.worker.model_runner import ModelRunner, _get_graph_batch_size
 
 
 def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:

+ 1 - 1
tests/worker/test_swap.py

@@ -4,7 +4,7 @@ from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                     get_open_port)
 from aphrodite.engine.args_tools import EngineArgs
-from aphrodite.task_handler.worker import Worker
+from aphrodite.worker.worker import Worker
 
 
 def test_swap() -> None: