Explorar o código

chore: rename `task_handler` to `worker` (#985)

AlpinDale hai 2 meses
pai
achega
3bb0f07461
Modificáronse 58 ficheiros con 414 adicións e 467 borrados
  1. 1 1
      aphrodite/attention/backends/abstract.py
  2. 1 1
      aphrodite/attention/backends/flash_attn.py
  3. 1 1
      aphrodite/attention/backends/flashinfer.py
  4. 1 1
      aphrodite/attention/backends/placeholder_attn.py
  5. 2 2
      aphrodite/attention/backends/utils.py
  6. 2 2
      aphrodite/executor/cpu_executor.py
  7. 3 3
      aphrodite/executor/gpu_executor.py
  8. 1 1
      aphrodite/executor/multiproc_worker_utils.py
  9. 1 1
      aphrodite/executor/neuron_executor.py
  10. 1 1
      aphrodite/executor/openvino_executor.py
  11. 1 1
      aphrodite/executor/ray_tpu_executor.py
  12. 2 2
      aphrodite/executor/ray_utils.py
  13. 1 1
      aphrodite/executor/tpu_executor.py
  14. 2 2
      aphrodite/executor/xpu_executor.py
  15. 2 2
      aphrodite/modeling/models/jamba.py
  16. 2 2
      aphrodite/modeling/models/mamba.py
  17. 1 1
      aphrodite/spec_decode/batch_expansion.py
  18. 1 1
      aphrodite/spec_decode/draft_model_runner.py
  19. 1 1
      aphrodite/spec_decode/medusa_worker.py
  20. 1 1
      aphrodite/spec_decode/multi_step_worker.py
  21. 1 1
      aphrodite/spec_decode/proposer_worker_base.py
  22. 2 3
      aphrodite/spec_decode/spec_decode_worker.py
  23. 1 1
      aphrodite/spec_decode/target_model_runner.py
  24. 0 0
      aphrodite/worker/__init__.py
  25. 0 0
      aphrodite/worker/cache_engine.py
  26. 1 1
      aphrodite/worker/cpu_model_runner.py
  27. 2 2
      aphrodite/worker/cpu_worker.py
  28. 1 1
      aphrodite/worker/embedding_model_runner.py
  29. 3 3
      aphrodite/worker/enc_dec_model_runner.py
  30. 1 1
      aphrodite/worker/model_runner.py
  31. 0 0
      aphrodite/worker/model_runner_base.py
  32. 2 2
      aphrodite/worker/multi_step_model_runner.py
  33. 3 3
      aphrodite/worker/multi_step_worker.py
  34. 1 1
      aphrodite/worker/neuron_model_runner.py
  35. 2 2
      aphrodite/worker/neuron_worker.py
  36. 0 0
      aphrodite/worker/openvino_model_runner.py
  37. 2 2
      aphrodite/worker/openvino_worker.py
  38. 1 1
      aphrodite/worker/tpu_model_runner.py
  39. 2 2
      aphrodite/worker/tpu_worker.py
  40. 1 1
      aphrodite/worker/utils.py
  41. 5 5
      aphrodite/worker/worker.py
  42. 1 1
      aphrodite/worker/worker_base.py
  43. 2 2
      aphrodite/worker/xpu_model_runner.py
  44. 4 4
      aphrodite/worker/xpu_worker.py
  45. 1 1
      docs/pages/usage/debugging.md
  46. 1 1
      kernels/attention/attention_kernels.cu
  47. 331 380
      kernels/backup/attention_kernels.cu
  48. 1 1
      tests/engine/test_multiproc_workers.py
  49. 1 1
      tests/lora/conftest.py
  50. 1 1
      tests/lora/test_long_context.py
  51. 1 1
      tests/lora/test_worker.py
  52. 1 1
      tests/models/test_jamba.py
  53. 1 1
      tests/spec_decode/test_multi_step_worker.py
  54. 3 3
      tests/spec_decode/utils.py
  55. 1 2
      tests/worker/test_encoder_decoder_model_runner.py
  56. 3 4
      tests/worker/test_model_input.py
  57. 1 2
      tests/worker/test_model_runner.py
  58. 1 1
      tests/worker/test_swap.py

+ 1 - 1
aphrodite/attention/backends/abstract.py

@@ -8,7 +8,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set,
 import torch
 
 if TYPE_CHECKING:
-    from aphrodite.task_handler.model_runner_base import (
+    from aphrodite.worker.model_runner_base import (
         ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase)
 
 

+ 1 - 1
aphrodite/attention/backends/flash_attn.py

@@ -18,7 +18,7 @@ from aphrodite.attention.backends.utils import (PAD_SLOT_ID,
 from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad
 
 if TYPE_CHECKING:
-    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
+    from aphrodite.worker.model_runner import ModelInputForGPUBuilder
 
 from aphrodite_flash_attn import (
     flash_attn_varlen_func as _flash_attn_varlen_func)

+ 1 - 1
aphrodite/attention/backends/flashinfer.py

@@ -33,7 +33,7 @@ from aphrodite.common.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                                     make_tensor_with_pad)
 
 if TYPE_CHECKING:
-    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
+    from aphrodite.worker.model_runner import ModelInputForGPUBuilder
 
 
 class FlashInferBackend(AttentionBackend):

+ 1 - 1
aphrodite/attention/backends/placeholder_attn.py

@@ -10,7 +10,7 @@ from aphrodite.attention.backends.abstract import (AttentionBackend,
 from aphrodite.attention.backends.utils import CommonAttentionState
 
 if TYPE_CHECKING:
-    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
+    from aphrodite.worker.model_runner import ModelInputForGPUBuilder
 
 # Placeholder attention backend for models like Mamba and embedding models that
 # lack attention.

+ 2 - 2
aphrodite/attention/backends/utils.py

@@ -9,7 +9,7 @@ from aphrodite.attention import (AttentionMetadata, AttentionMetadataBuilder,
 from aphrodite.common.utils import async_tensor_h2d, make_tensor_with_pad
 
 if TYPE_CHECKING:
-    from aphrodite.task_handler.model_runner_base import ModelRunnerBase
+    from aphrodite.worker.model_runner_base import ModelRunnerBase
 
 # Error string(s) for encoder/decoder
 # unsupported attention scenarios
@@ -23,7 +23,7 @@ PAD_SLOT_ID = -1
 _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256
 
 if TYPE_CHECKING:
-    from aphrodite.task_handler.model_runner import ModelInputForGPUBuilder
+    from aphrodite.worker.model_runner import ModelInputForGPUBuilder
 
 
 def is_block_tables_empty(block_tables: Union[None, Dict]):

+ 2 - 2
aphrodite/executor/cpu_executor.py

@@ -18,7 +18,7 @@ from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
 from aphrodite.lora.request import LoRARequest
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
-from aphrodite.task_handler.worker_base import WorkerWrapperBase
+from aphrodite.worker.worker_base import WorkerWrapperBase
 
 
 class CPUExecutor(ExecutorBase):
@@ -121,7 +121,7 @@ class CPUExecutor(ExecutorBase):
         local_rank: int = 0,
         rank: int = 0,
     ):
-        worker_module_name = "aphrodite.task_handler.cpu_worker"
+        worker_module_name = "aphrodite.worker.cpu_worker"
         worker_class_name = "CPUWorker"
 
         wrapper = WorkerWrapperBase(

+ 3 - 3
aphrodite/executor/gpu_executor.py

@@ -9,7 +9,7 @@ from aphrodite.executor.executor_base import ExecutorAsyncBase, ExecutorBase
 from aphrodite.lora.request import LoRARequest
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
-from aphrodite.task_handler.worker_base import WorkerBase, WorkerWrapperBase
+from aphrodite.worker.worker_base import WorkerBase, WorkerWrapperBase
 
 
 def create_worker(worker_module_name: str, worker_class_name: str,
@@ -68,13 +68,13 @@ class GPUExecutor(ExecutorBase):
             self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
         worker_class_fn = None
         if self.scheduler_config.is_multi_step:
-            worker_module_name = "aphrodite.task_handler.multi_step_worker"
+            worker_module_name = "aphrodite.worker.multi_step_worker"
             worker_class_name = "MultiStepWorker"
         elif self.speculative_config:
             worker_module_name = "aphrodite.spec_decode.spec_decode_worker"
             worker_class_name = "create_spec_worker"
         else:
-            worker_module_name = "aphrodite.task_handler.worker"
+            worker_module_name = "aphrodite.worker.worker"
             worker_class_name = "Worker"
         return (worker_module_name, worker_class_name, worker_class_fn)
 

+ 1 - 1
aphrodite/executor/multiproc_worker_utils.py

@@ -142,7 +142,7 @@ class WorkerMonitor(threading.Thread):
 
 
 class ProcessWorkerWrapper:
-    """Local process wrapper for aphrodite.task_handler.Worker,
+    """Local process wrapper for aphrodite.worker.Worker,
     for handling single-node multi-GPU tensor parallel."""
 
     def __init__(self, result_handler: ResultHandler,

+ 1 - 1
aphrodite/executor/neuron_executor.py

@@ -22,7 +22,7 @@ class NeuronExecutor(ExecutorBase):
         self._init_worker()
 
     def _init_worker(self):
-        from aphrodite.task_handler.neuron_worker import NeuronWorker
+        from aphrodite.worker.neuron_worker import NeuronWorker
         distributed_init_method = get_distributed_init_method(
             get_ip(), get_open_port())
         self.driver_worker = NeuronWorker(

+ 1 - 1
aphrodite/executor/openvino_executor.py

@@ -33,7 +33,7 @@ class OpenVINOExecutor(ExecutorBase):
         self._init_worker()
 
     def _init_worker(self):
-        from aphrodite.task_handler.openvino_worker import OpenVINOWorker
+        from aphrodite.worker.openvino_worker import OpenVINOWorker
 
         assert (
             self.parallel_config.world_size == 1

+ 1 - 1
aphrodite/executor/ray_tpu_executor.py

@@ -70,7 +70,7 @@ class RayTPUExecutor(TPUExecutor):
             )
 
             assert self.speculative_config is None
-            worker_module_name = "aphrodite.task_handler.tpu_worker"
+            worker_module_name = "aphrodite.worker.tpu_worker"
             worker_class_name = "TPUWorker"
 
             # GKE does not fetch environment information from metadata server

+ 2 - 2
aphrodite/executor/ray_utils.py

@@ -11,7 +11,7 @@ from aphrodite.common.sequence import ExecuteModelRequest, IntermediateTensors
 from aphrodite.common.utils import get_ip, is_hip, is_xpu
 from aphrodite.executor.msgspec_utils import decode_hook, encode_hook
 from aphrodite.platforms import current_platform
-from aphrodite.task_handler.worker_base import WorkerWrapperBase
+from aphrodite.worker.worker_base import WorkerWrapperBase
 
 PG_WAIT_TIMEOUT = 1800
 
@@ -22,7 +22,7 @@ try:
     from ray.util.placement_group import PlacementGroup
 
     class RayWorkerWrapper(WorkerWrapperBase):
-        """Ray wrapper for aphrodite.task_handler.Worker, allowing Worker to be
+        """Ray wrapper for aphrodite.worker.Worker, allowing Worker to be
         lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
 
         def __init__(self, *args, **kwargs) -> None:

+ 1 - 1
aphrodite/executor/tpu_executor.py

@@ -60,7 +60,7 @@ class TPUExecutor(ExecutorBase):
         rank: int = 0,
         distributed_init_method: Optional[str] = None,
     ):
-        from aphrodite.task_handler.tpu_worker import TPUWorker
+        from aphrodite.worker.tpu_worker import TPUWorker
 
         worker = TPUWorker(**self._get_worker_kwargs(local_rank, rank,
                                                      distributed_init_method))

+ 2 - 2
aphrodite/executor/xpu_executor.py

@@ -12,7 +12,7 @@ from aphrodite.common.utils import make_async
 from aphrodite.executor.executor_base import ExecutorAsyncBase
 from aphrodite.executor.gpu_executor import GPUExecutor
 from aphrodite.modeling.layers.sampler import SamplerOutput
-from aphrodite.task_handler.worker_base import WorkerBase
+from aphrodite.worker.worker_base import WorkerBase
 
 
 class XPUExecutor(GPUExecutor):
@@ -56,7 +56,7 @@ class XPUExecutor(GPUExecutor):
             raise NotImplementedError(
                 "XPU does not support speculative decoding")
         else:
-            worker_module_name = "aphrodite.task_handler.xpu_worker"
+            worker_module_name = "aphrodite.worker.xpu_worker"
             worker_class_name = "XPUWorker"
         return (worker_module_name, worker_class_name, worker_class_fn)
 

+ 2 - 2
aphrodite/modeling/models/jamba.py

@@ -37,8 +37,8 @@ from aphrodite.modeling.models.mamba_cache import MambaCacheManager
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
-from aphrodite.task_handler.model_runner import (_BATCH_SIZES_TO_CAPTURE,
-                                                 _get_graph_batch_size)
+from aphrodite.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
+                                           _get_graph_batch_size)
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 

+ 2 - 2
aphrodite/modeling/models/mamba.py

@@ -32,8 +32,8 @@ from aphrodite.modeling.models.mamba_cache import MambaCacheManager
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.modeling.utils import set_weight_attrs
 from aphrodite.quantization.base_config import QuantizationConfig
-from aphrodite.task_handler.model_runner import (_BATCH_SIZES_TO_CAPTURE,
-                                                 _get_graph_batch_size)
+from aphrodite.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
+                                           _get_graph_batch_size)
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 

+ 1 - 1
aphrodite/spec_decode/batch_expansion.py

@@ -13,7 +13,7 @@ from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
                                               SpeculativeScorer,
                                               SpeculativeScores)
 from aphrodite.spec_decode.util import nvtx_range, split_batch_by_proposal_len
-from aphrodite.task_handler.worker_base import WorkerBase
+from aphrodite.worker.worker_base import WorkerBase
 
 SeqId = int
 TargetSeqId = int

+ 1 - 1
aphrodite/spec_decode/draft_model_runner.py

@@ -18,7 +18,7 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
 from aphrodite.common.sequence import ExecuteModelRequest, IntermediateTensors
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.multimodal import MultiModalInputs
-from aphrodite.task_handler.model_runner import (
+from aphrodite.worker.model_runner import (
     ModelInputForGPUWithSamplingMetadata, ModelRunner)
 
 # A flag to enable debug prints for the updated input tensors

+ 1 - 1
aphrodite/spec_decode/medusa_worker.py

@@ -10,7 +10,7 @@ from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.spec_decode.interfaces import SpeculativeProposals
 from aphrodite.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from aphrodite.spec_decode.top1_proposer import Top1Proposer
-from aphrodite.task_handler.worker import Worker
+from aphrodite.worker.worker import Worker
 
 
 class MedusaWorker(NonLLMProposerWorkerBase, Worker):

+ 1 - 1
aphrodite/spec_decode/multi_step_worker.py

@@ -12,7 +12,7 @@ from aphrodite.spec_decode.interfaces import (SpeculativeProposals,
                                               SpeculativeProposer)
 from aphrodite.spec_decode.proposer_worker_base import ProposerWorkerBase
 from aphrodite.spec_decode.top1_proposer import Top1Proposer
-from aphrodite.task_handler.worker import Worker
+from aphrodite.worker.worker import Worker
 
 
 class MultiStepWorker(Worker, ProposerWorkerBase):

+ 1 - 1
aphrodite/spec_decode/proposer_worker_base.py

@@ -4,7 +4,7 @@ from typing import List, Optional, Set, Tuple
 from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.spec_decode.interfaces import SpeculativeProposer
-from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase
+from aphrodite.worker.worker_base import LoraNotSupportedWorkerBase
 
 
 class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer):

+ 2 - 3
aphrodite/spec_decode/spec_decode_worker.py

@@ -35,9 +35,8 @@ from aphrodite.spec_decode.util import (Timer, create_sequence_group_output,
                                         get_all_num_logprobs,
                                         get_sampled_token_logprobs, nvtx_range,
                                         split_batch_by_proposal_len)
-from aphrodite.task_handler.worker import Worker
-from aphrodite.task_handler.worker_base import (LoraNotSupportedWorkerBase,
-                                                WorkerBase)
+from aphrodite.worker.worker import Worker
+from aphrodite.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase
 
 
 def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":

+ 1 - 1
aphrodite/spec_decode/target_model_runner.py

@@ -4,7 +4,7 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      LoRAConfig, ModelConfig, ParallelConfig,
                                      PromptAdapterConfig, SchedulerConfig)
 from aphrodite.common.sequence import SequenceGroupMetadata
-from aphrodite.task_handler.model_runner import (
+from aphrodite.worker.model_runner import (
     ModelInputForGPUWithSamplingMetadata, ModelRunner)
 
 

+ 0 - 0
aphrodite/task_handler/__init__.py → aphrodite/worker/__init__.py


+ 0 - 0
aphrodite/task_handler/cache_engine.py → aphrodite/worker/cache_engine.py


+ 1 - 1
aphrodite/task_handler/cpu_model_runner.py → aphrodite/worker/cpu_model_runner.py

@@ -16,7 +16,7 @@ from aphrodite.modeling.model_loader import get_model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                   MultiModalInputs)
-from aphrodite.task_handler.model_runner_base import (
+from aphrodite.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
     _add_attn_metadata_broadcastable_dict,
     _add_sampling_metadata_broadcastable_dict,

+ 2 - 2
aphrodite/task_handler/cpu_worker.py → aphrodite/worker/cpu_worker.py

@@ -14,8 +14,8 @@ from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE
 from aphrodite.distributed import (ensure_model_parallel_initialized,
                                    init_distributed_environment)
 from aphrodite.modeling import set_random_seed
-from aphrodite.task_handler.cpu_model_runner import CPUModelRunner
-from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
+from aphrodite.worker.cpu_model_runner import CPUModelRunner
+from aphrodite.worker.worker_base import (LocalOrDistributedWorkerBase,
                                                 LoraNotSupportedWorkerBase,
                                                 WorkerInput)
 

+ 1 - 1
aphrodite/task_handler/embedding_model_runner.py → aphrodite/worker/embedding_model_runner.py

@@ -11,7 +11,7 @@ from aphrodite.common.sequence import (IntermediateTensors, PoolerOutput,
                                        SequenceData, SequenceGroupMetadata)
 from aphrodite.modeling.pooling_metadata import PoolingMetadata
 from aphrodite.multimodal import MultiModalInputs
-from aphrodite.task_handler.model_runner import (GPUModelRunnerBase,
+from aphrodite.worker.model_runner import (GPUModelRunnerBase,
                                                  ModelInputForGPU,
                                                  ModelInputForGPUBuilder)
 

+ 3 - 3
aphrodite/task_handler/enc_dec_model_runner.py → aphrodite/worker/enc_dec_model_runner.py

@@ -24,13 +24,13 @@ from aphrodite.inputs import INPUT_REGISTRY, InputRegistry
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from aphrodite.task_handler.model_runner import (
+from aphrodite.worker.model_runner import (
     GPUModelRunnerBase, ModelInputForGPUBuilder,
     ModelInputForGPUWithSamplingMetadata)
-from aphrodite.task_handler.model_runner_base import (
+from aphrodite.worker.model_runner_base import (
     _add_attn_metadata_broadcastable_dict,
     _add_sampling_metadata_broadcastable_dict)
-from aphrodite.task_handler.utils import assert_enc_dec_mr_supported_scenario
+from aphrodite.worker.utils import assert_enc_dec_mr_supported_scenario
 
 
 @dataclasses.dataclass(frozen=True)

+ 1 - 1
aphrodite/task_handler/model_runner.py → aphrodite/worker/model_runner.py

@@ -51,7 +51,7 @@ from aphrodite.prompt_adapter.layers import PromptAdapterMapping
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
 from aphrodite.prompt_adapter.worker_manager import (
     LRUCacheWorkerPromptAdapterManager)
-from aphrodite.task_handler.model_runner_base import (
+from aphrodite.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
     _add_attn_metadata_broadcastable_dict,
     _add_sampling_metadata_broadcastable_dict,

+ 0 - 0
aphrodite/task_handler/model_runner_base.py → aphrodite/worker/model_runner_base.py


+ 2 - 2
aphrodite/task_handler/multi_step_model_runner.py → aphrodite/worker/multi_step_model_runner.py

@@ -25,9 +25,9 @@ from aphrodite.modeling.layers.sampler import (PromptLogprobs, SampleLogprobs,
                                                get_logprobs,
                                                get_pythonized_sample_results)
 from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
-from aphrodite.task_handler.model_runner import (
+from aphrodite.worker.model_runner import (
     GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata)
-from aphrodite.task_handler.model_runner_base import (
+from aphrodite.worker.model_runner_base import (
     BroadcastableModelInput, _init_attn_metadata_from_tensor_dict,
     _init_frozen_model_input_from_tensor_dict,
     _init_sampling_metadata_from_tensor_dict)

+ 3 - 3
aphrodite/task_handler/multi_step_worker.py → aphrodite/worker/multi_step_worker.py

@@ -7,10 +7,10 @@ import torch
 from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.distributed import broadcast_tensor_dict, get_pp_group
 from aphrodite.modeling.layers.sampler import SamplerOutput
-from aphrodite.task_handler.model_runner_base import BroadcastableModelInput
-from aphrodite.task_handler.multi_step_model_runner import (
+from aphrodite.worker.model_runner_base import BroadcastableModelInput
+from aphrodite.worker.multi_step_model_runner import (
     MultiStepModelRunner, StatefulModelInput)
-from aphrodite.task_handler.worker import Worker, WorkerInput
+from aphrodite.worker.worker import Worker, WorkerInput
 
 
 @dataclass

+ 1 - 1
aphrodite/task_handler/neuron_model_runner.py → aphrodite/worker/neuron_model_runner.py

@@ -16,7 +16,7 @@ from aphrodite.modeling.model_loader.neuron import get_neuron_model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
 from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                   MultiModalInputs)
-from aphrodite.task_handler.model_runner_base import (ModelRunnerBase,
+from aphrodite.worker.model_runner_base import (ModelRunnerBase,
                                                       ModelRunnerInputBase)
 
 if TYPE_CHECKING:

+ 2 - 2
aphrodite/task_handler/neuron_worker.py → aphrodite/worker/neuron_worker.py

@@ -10,8 +10,8 @@ from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.distributed import (ensure_model_parallel_initialized,
                                    init_distributed_environment)
 from aphrodite.modeling import set_random_seed
-from aphrodite.task_handler.neuron_model_runner import NeuronModelRunner
-from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
+from aphrodite.worker.neuron_model_runner import NeuronModelRunner
+from aphrodite.worker.worker_base import (LocalOrDistributedWorkerBase,
                                                 LoraNotSupportedWorkerBase,
                                                 WorkerInput)
 

+ 0 - 0
aphrodite/task_handler/openvino_model_runner.py → aphrodite/worker/openvino_model_runner.py


+ 2 - 2
aphrodite/task_handler/openvino_worker.py → aphrodite/worker/openvino_worker.py

@@ -15,8 +15,8 @@ from aphrodite.distributed import (broadcast_tensor_dict,
                                    init_distributed_environment)
 from aphrodite.modeling import set_random_seed
 from aphrodite.modeling.layers.sampler import SamplerOutput
-from aphrodite.task_handler.openvino_model_runner import OpenVINOModelRunner
-from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase
+from aphrodite.worker.openvino_model_runner import OpenVINOModelRunner
+from aphrodite.worker.worker_base import LoraNotSupportedWorkerBase
 
 
 class OpenVINOCacheEngine:

+ 1 - 1
aphrodite/task_handler/tpu_model_runner.py → aphrodite/worker/tpu_model_runner.py

@@ -23,7 +23,7 @@ from aphrodite.compilation.wrapper import (
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader import get_model
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.task_handler.model_runner_base import (
+from aphrodite.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase,
     _add_attn_metadata_broadcastable_dict,
     _init_attn_metadata_from_tensor_dict)

+ 2 - 2
aphrodite/task_handler/tpu_worker.py → aphrodite/worker/tpu_worker.py

@@ -14,8 +14,8 @@ from aphrodite.common.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
 from aphrodite.distributed import (ensure_model_parallel_initialized,
                                    init_distributed_environment)
 from aphrodite.modeling import set_random_seed
-from aphrodite.task_handler.tpu_model_runner import TPUModelRunner
-from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
+from aphrodite.worker.tpu_model_runner import TPUModelRunner
+from aphrodite.worker.worker_base import (LocalOrDistributedWorkerBase,
                                                 LoraNotSupportedWorkerBase,
                                                 WorkerInput)
 

+ 1 - 1
aphrodite/task_handler/utils.py → aphrodite/worker/utils.py

@@ -3,7 +3,7 @@ Worker-related helper functions.
 '''
 
 from aphrodite.common.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS
-from aphrodite.task_handler.model_runner import GPUModelRunnerBase
+from aphrodite.worker.model_runner import GPUModelRunnerBase
 
 
 def assert_enc_dec_mr_supported_scenario(

+ 5 - 5
aphrodite/task_handler/worker.py → aphrodite/worker/worker.py

@@ -26,12 +26,12 @@ from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader.tensorizer import TensorizerConfig
 from aphrodite.platforms import current_platform
 from aphrodite.prompt_adapter.request import PromptAdapterRequest
-from aphrodite.task_handler.cache_engine import CacheEngine
-from aphrodite.task_handler.embedding_model_runner import EmbeddingModelRunner
-from aphrodite.task_handler.enc_dec_model_runner import (
+from aphrodite.worker.cache_engine import CacheEngine
+from aphrodite.worker.embedding_model_runner import EmbeddingModelRunner
+from aphrodite.worker.enc_dec_model_runner import (
     EncoderDecoderModelRunner)
-from aphrodite.task_handler.model_runner import GPUModelRunnerBase, ModelRunner
-from aphrodite.task_handler.worker_base import (LocalOrDistributedWorkerBase,
+from aphrodite.worker.model_runner import GPUModelRunnerBase, ModelRunner
+from aphrodite.worker.worker_base import (LocalOrDistributedWorkerBase,
                                                 WorkerInput)
 
 

+ 1 - 1
aphrodite/task_handler/worker_base.py → aphrodite/worker/worker_base.py

@@ -15,7 +15,7 @@ from aphrodite.distributed import (broadcast_tensor_dict, get_pp_group,
 from aphrodite.lora.request import LoRARequest
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.platforms import current_platform
-from aphrodite.task_handler.model_runner_base import (BroadcastableModelInput,
+from aphrodite.worker.model_runner_base import (BroadcastableModelInput,
                                                       ModelRunnerBase,
                                                       ModelRunnerInputBase)
 

+ 2 - 2
aphrodite/task_handler/xpu_model_runner.py → aphrodite/worker/xpu_model_runner.py

@@ -22,9 +22,9 @@ from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.model_loader import get_model
 from aphrodite.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                                   MultiModalInputs, MultiModalRegistry)
-from aphrodite.task_handler.model_runner import (AttentionMetadata,
+from aphrodite.worker.model_runner import (AttentionMetadata,
                                                  SamplingMetadata)
-from aphrodite.task_handler.model_runner_base import (
+from aphrodite.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
     _add_attn_metadata_broadcastable_dict,
     _add_sampling_metadata_broadcastable_dict,

+ 4 - 4
aphrodite/task_handler/xpu_worker.py → aphrodite/worker/xpu_worker.py

@@ -17,10 +17,10 @@ from aphrodite.distributed import (ensure_model_parallel_initialized,
                                    init_distributed_environment)
 from aphrodite.distributed.parallel_state import get_pp_group
 from aphrodite.modeling import set_random_seed
-from aphrodite.task_handler.cache_engine import CacheEngine
-from aphrodite.task_handler.worker import Worker
-from aphrodite.task_handler.worker_base import LoraNotSupportedWorkerBase
-from aphrodite.task_handler.xpu_model_runner import XPUModelRunner
+from aphrodite.worker.cache_engine import CacheEngine
+from aphrodite.worker.worker import Worker
+from aphrodite.worker.worker_base import LoraNotSupportedWorkerBase
+from aphrodite.worker.xpu_model_runner import XPUModelRunner
 
 
 class XPUWorker(LoraNotSupportedWorkerBase, Worker):

+ 1 - 1
docs/pages/usage/debugging.md

@@ -28,7 +28,7 @@ Set these environment variables:
 
 ***
 
-If your instance crashes and the error trace shows somewhere around `self.graph.replay()` in `aphrodite/task_handler/model_runner.py`, then it's very likely a CUDA error inside the CUDAGraph. To know the particular operation that causes the error, you can add `--enforce-eager` in the CLI, `- enforce_eager: true` in the YAML config, or `enforce_eager=True` in the `LLM` class. This will disable CUDAGraph optimization, which might make it easier to find the root cause.
+If your instance crashes and the error trace shows somewhere around `self.graph.replay()` in `aphrodite/worker/model_runner.py`, then it's very likely a CUDA error inside the CUDAGraph. To know the particular operation that causes the error, you can add `--enforce-eager` in the CLI, `- enforce_eager: true` in the YAML config, or `enforce_eager=True` in the `LLM` class. This will disable CUDAGraph optimization, which might make it easier to find the root cause.
 
 Here's some other common issues that might cause freezes:
 

+ 1 - 1
kernels/attention/attention_kernels.cu

@@ -728,7 +728,7 @@ void paged_attention_v1_launcher(
   int logits_size = padded_max_seq_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
   // Python-side check in
-  // aphrodite.task_handler.worker._check_if_can_support_max_seq_len Keep that
+  // aphrodite.worker.worker._check_if_can_support_max_seq_len Keep that
   // in sync with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size);
 

+ 331 - 380
kernels/backup/attention_kernels.cu

@@ -1,5 +1,6 @@
 /*
- * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Adapted from
+ * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
  * Copyright (c) 2023, The PygmalionAI team.
  * Copyright (c) 2023, The vLLM team.
  * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
@@ -17,7 +18,7 @@
  * limitations under the License.
  */
 #ifdef USE_ROCM
-#include <hip/hip_runtime.h>
+  #include <hip/hip_runtime.h>
 #endif
 
 #include <torch/extension.h>
@@ -28,15 +29,15 @@
 #include "attention_utils.cuh"
 #include "../quantization/int8_kvcache/quant_utils.cuh"
 #ifdef ENABLE_FP8_E5M2
-#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
+  #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
 #endif
 
 #include <algorithm>
 
 #ifndef USE_ROCM
-#define WARP_SIZE 32
+  #define WARP_SIZE 32
 #else
-#define WARP_SIZE warpSize
+  #define WARP_SIZE warpSize
 #endif
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -47,12 +48,13 @@ enum kv_cache_dtype {
 #ifdef ENABLE_FP8_E5M2
   FP8_E5M2,
 #endif
-  INT8};
+  INT8
+};
 
 namespace aphrodite {
 
 // Utility function for attention softmax.
-template<int NUM_WARPS>
+template <int NUM_WARPS>
 inline __device__ float block_sum(float* red_smem, float sum) {
   // Decompose the thread index into warp / lane.
   int warp = threadIdx.x / WARP_SIZE;
@@ -89,34 +91,29 @@ inline __device__ float block_sum(float* red_smem, float sum) {
 
 // TODO: Merge the last two dimensions of the grid.
 // Grid: (num_heads, num_seqs, max_num_partitions).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int PARTITION_SIZE = 0> // Zero means no partitioning.
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, kv_cache_dtype KV_CACHE_DTYPE,
+          int PARTITION_SIZE = 0>  // Zero means no partitioning.
 __device__ void paged_attention_kernel(
-  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
-  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float k_scale = 1.0f,
-  const float k_zp = 0.0f,
-  const float v_scale = 1.0f,
-  const float v_zp = 0.0f) {
+    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
+    float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                     // max_num_partitions]
+    scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
+                                 // head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float k_scale = 1.0f, const float k_zp = 0.0f,
+    const float v_scale = 1.0f, const float v_zp = 0.0f) {
   const int seq_idx = blockIdx.y;
   const int partition_idx = blockIdx.z;
   const int max_num_partitions = gridDim.z;
@@ -128,22 +125,29 @@ __device__ void paged_attention_kernel(
   }
 
   const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
-  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
+  const int num_blocks_per_partition =
+      USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks;
 
   // [start_block_idx, end_block_idx) is the range of blocks to process.
-  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
-  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
+  const int start_block_idx =
+      USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+  const int end_block_idx =
+      MIN(start_block_idx + num_blocks_per_partition, num_context_blocks);
   const int num_blocks = end_block_idx - start_block_idx;
 
   // [start_token_idx, end_token_idx) is the range of tokens to process.
   const int start_token_idx = start_block_idx * BLOCK_SIZE;
-  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
+  const int end_token_idx =
+      MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len);
   const int num_tokens = end_token_idx - start_token_idx;
 
   constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
-  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
+  constexpr int NUM_THREAD_GROUPS =
+      NUM_THREADS / THREAD_GROUP_SIZE;  // Note: This assumes THREAD_GROUP_SIZE
+                                        // divides NUM_THREADS
   assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
-  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
+  constexpr int NUM_TOKENS_PER_THREAD_GROUP =
+      DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   const int thread_idx = threadIdx.x;
   const int warp_idx = thread_idx / WARP_SIZE;
@@ -153,13 +157,14 @@ __device__ void paged_attention_kernel(
   const int num_heads = gridDim.x;
   const int num_queries_per_kv = num_heads / num_kv_heads;
   const int kv_head_idx = head_idx / num_queries_per_kv;
-  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
+  const float alibi_slope =
+      alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
 
   // A vector type to store a part of a key or a query.
-  // The vector size is configured in such a way that the threads in a thread group
-  // fetch or compute 16 bytes at a time.
-  // For example, if the size of a thread group is 4 and the data type is half,
-  // then the vector size is 16 / (4 * sizeof(half)) == 2.
+  // The vector size is configured in such a way that the threads in a thread
+  // group fetch or compute 16 bytes at a time. For example, if the size of a
+  // thread group is 4 and the data type is half, then the vector size is 16 /
+  // (4 * sizeof(half)) == 2.
   constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
   using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
   using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
@@ -173,18 +178,21 @@ __device__ void paged_attention_kernel(
 
   // Load the query to registers.
   // Each thread in a thread group has a different part of the query.
-  // For example, if the the thread group size is 4, then the first thread in the group
-  // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
-  // th vectors of the query, and so on.
-  // NOTE: Because q is split from a qkv tensor, it may not be contiguous.
+  // For example, if the the thread group size is 4, then the first thread in
+  // the group has 0, 4, 8, ... th vectors of the query, and the second thread
+  // has 1, 5, 9, ... th vectors of the query, and so on. NOTE: Because q is
+  // split from a qkv tensor, it may not be contiguous.
   const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
   __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
 #pragma unroll
-  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
+  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
+       i += NUM_THREAD_GROUPS) {
     const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
-    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+    q_vecs[thread_group_offset][i] =
+        *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
   }
-  __syncthreads(); // TODO: possible speedup if this is replaced with a memory wall right before we use q_vecs
+  __syncthreads();  // TODO: possible speedup if this is replaced with a memory
+                    // wall right before we use q_vecs
 
   // Memory planning.
   extern __shared__ char shared_mem[];
@@ -203,51 +211,60 @@ __device__ void paged_attention_kernel(
   // Each thread group in a warp fetches a key from the block, and computes
   // dot product with the query.
   const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
-  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
     // NOTE: The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by large numbers
-    // (e.g., kv_block_stride).
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+    // because int32 can lead to overflow when this variable is multiplied by
+    // large numbers (e.g., kv_block_stride).
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
 
     // Load a key to registers.
     // Each thread in a thread group has a different part of the key.
-    // For example, if the the thread group size is 4, then the first thread in the group
-    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
-    // vectors of the key, and so on.
+    // For example, if the the thread group size is 4, then the first thread in
+    // the group has 0, 4, 8, ... th vectors of the key, and the second thread
+    // has 1, 5, 9, ... th vectors of the key, and so on.
     for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
-      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+      const int physical_block_offset =
+          (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
       const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
       K_vec k_vecs[NUM_VECS_PER_THREAD];
 
 #pragma unroll
       for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
-        const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
-                                       + kv_head_idx * kv_head_stride
-                                       + physical_block_offset * x;
+        const cache_t* k_ptr =
+            k_cache + physical_block_number * kv_block_stride +
+            kv_head_idx * kv_head_stride + physical_block_offset * x;
         const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
         const int offset1 = (vec_idx * VEC_SIZE) / x;
         const int offset2 = (vec_idx * VEC_SIZE) % x;
         if constexpr (KV_CACHE_DTYPE == INT8) {
-          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
           using Dequant_vec = typename FloatVec<Quant_vec>::Type;
           Dequant_vec k_vec_dequant = int8::dequant(k_vec_quant, k_scale, k_zp);
           k_vecs[j] = int8::vec_conversion<K_vec, Dequant_vec>(k_vec_dequant);
 #ifdef ENABLE_FP8_E5M2
         } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
-          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
           // Vector conversion from Quant_vec to K_vec.
-          k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
+          k_vecs[j] =
+              fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
 #endif
         } else {
-          k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          k_vecs[j] = *reinterpret_cast<const K_vec*>(
+              k_ptr + offset1 * BLOCK_SIZE * x + offset2);
         }
       }
 
       // Compute dot product.
       // This includes a reduction across the threads in the same thread group.
-      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
+      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
+                             q_vecs[thread_group_offset], k_vecs);
       // Add the ALiBi bias if slopes are given.
-      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
+      qk +=
+          (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
 
       if (thread_group_offset == 0) {
         // Store the partial reductions to shared memory.
@@ -300,13 +317,12 @@ __device__ void paged_attention_kernel(
 
   // If partitioning is enabled, store the max logit and exp_sum.
   if (USE_PARTITIONING && thread_idx == 0) {
-    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
-                                       + head_idx * max_num_partitions
-                                       + partition_idx;
+    float* max_logits_ptr = max_logits +
+                            seq_idx * num_heads * max_num_partitions +
+                            head_idx * max_num_partitions + partition_idx;
     *max_logits_ptr = qk_max;
-    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
-                                   + head_idx * max_num_partitions
-                                   + partition_idx;
+    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
+                          head_idx * max_num_partitions + partition_idx;
     *exp_sums_ptr = exp_sum;
   }
 
@@ -319,7 +335,8 @@ __device__ void paged_attention_kernel(
 
   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
   constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
-  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
+  constexpr int NUM_ROWS_PER_THREAD =
+      DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
 
   // NOTE: We use FP32 for the accumulator for better accuracy.
   float accs[NUM_ROWS_PER_THREAD];
@@ -330,18 +347,21 @@ __device__ void paged_attention_kernel(
 
   scalar_t zero_value;
   zero(zero_value);
-  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
+       block_idx += NUM_WARPS) {
     // NOTE: The block number is stored in int32. However, we cast it to int64
-    // because int32 can lead to overflow when this variable is multiplied by large numbers
-    // (e.g., kv_block_stride).
-    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+    // because int32 can lead to overflow when this variable is multiplied by
+    // large numbers (e.g., kv_block_stride).
+    const int64_t physical_block_number =
+        static_cast<int64_t>(block_table[block_idx]);
     const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
     const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
     L_vec logits_vec;
-    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
+    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
+                                                           start_token_idx));
 
-    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
-                                   + kv_head_idx * kv_head_stride;
+    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
+                           kv_head_idx * kv_head_stride;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -350,26 +370,32 @@ __device__ void paged_attention_kernel(
         V_vec v_vec;
         if constexpr (KV_CACHE_DTYPE == INT8) {
           // dequant and conversion
-          V_quant_vec v_vec_quant = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+          V_quant_vec v_vec_quant =
+              *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
           using V_dequant_vec = typename FloatVec<V_quant_vec>::Type;
-          V_dequant_vec v_vec_dequant = int8::dequant(v_vec_quant, v_scale, v_zp);
+          V_dequant_vec v_vec_dequant =
+              int8::dequant(v_vec_quant, v_scale, v_zp);
           v_vec = int8::vec_conversion<V_vec, V_dequant_vec>(v_vec_dequant);
 #ifdef ENABLE_FP8_E5M2
         } else if constexpr (KV_CACHE_DTYPE == FP8_E5M2) {
-          V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+          V_quant_vec v_quant_vec =
+              *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
           // Vector conversion from V_quant_vec to V_vec.
-          v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
+          v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(
+              v_quant_vec);
 #endif
         } else {
           v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
         }
         if (block_idx == num_context_blocks - 1) {
           // NOTE: When v_vec contains the tokens that are out of the context,
-          // we should explicitly zero out the values since they may contain NaNs.
+          // we should explicitly zero out the values since they may contain
+          // NaNs.
           scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
 #pragma unroll
           for (int j = 0; j < V_VEC_SIZE; j++) {
-            v_vec_ptr[j] = token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
+            v_vec_ptr[j] =
+                token_idx + j < context_len ? v_vec_ptr[j] : zero_value;
           }
         }
         accs[i] += dot(logits_vec, v_vec);
@@ -426,9 +452,9 @@ __device__ void paged_attention_kernel(
 
   // Write the final output.
   if (warp_idx == 0) {
-    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                            + head_idx * max_num_partitions * HEAD_SIZE
-                            + partition_idx * HEAD_SIZE;
+    scalar_t* out_ptr =
+        out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
 #pragma unroll
     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
@@ -440,85 +466,77 @@ __device__ void paged_attention_kernel(
 }
 
 // Grid: (num_heads, num_seqs, 1).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  kv_cache_dtype KV_CACHE_DTYPE>
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS,
+          kv_cache_dtype KV_CACHE_DTYPE>
 __global__ void paged_attention_v1_kernel(
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE>(
-    /* exp_sums */ nullptr, /* max_logits */ nullptr,
-    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
-    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
+    scalar_t* __restrict__ out,           // [num_seqs, num_heads, head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float k_scale, const float k_zp, const float v_scale,
+    const float v_zp) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
+                         KV_CACHE_DTYPE>(
+      /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
+      v_cache, num_kv_heads, scale, block_tables, context_lens,
+      max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
+      kv_head_stride, k_scale, k_zp, v_scale, v_zp);
 }
 
 // Grid: (num_heads, num_seqs, max_num_partitions).
-template<
-  typename scalar_t,
-  typename cache_t,
-  int HEAD_SIZE,
-  int BLOCK_SIZE,
-  int NUM_THREADS,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int PARTITION_SIZE>
+template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
+          int NUM_THREADS, kv_cache_dtype KV_CACHE_DTYPE,
+          int PARTITION_SIZE>
 __global__ void paged_attention_v2_kernel(
-  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
-  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
-  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
-  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
-  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
-  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
-  const int num_kv_heads,                 // [num_heads]
-  const float scale,
-  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_blocks_per_seq,
-  const float* __restrict__ alibi_slopes, // [num_heads]
-  const int q_stride,
-  const int kv_block_stride,
-  const int kv_head_stride,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
-  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE, PARTITION_SIZE>(
-    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
-    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
-    q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
+    float* __restrict__ exp_sums,  // [num_seqs, num_heads, max_num_partitions]
+    float* __restrict__ max_logits,       // [num_seqs, num_heads,
+                                          // max_num_partitions]
+    scalar_t* __restrict__ tmp_out,       // [num_seqs, num_heads,
+                                          // max_num_partitions, head_size]
+    const scalar_t* __restrict__ q,       // [num_seqs, num_heads, head_size]
+    const cache_t* __restrict__ k_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size/x, block_size, x]
+    const cache_t* __restrict__ v_cache,  // [num_blocks, num_kv_heads,
+                                          // head_size, block_size]
+    const int num_kv_heads,               // [num_heads]
+    const float scale,
+    const int* __restrict__ block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_blocks_per_seq,
+    const float* __restrict__ alibi_slopes,  // [num_heads]
+    const int q_stride, const int kv_block_stride, const int kv_head_stride,
+    const float k_scale, const float k_zp, const float v_scale,
+    const float v_zp) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
+                         KV_CACHE_DTYPE, PARTITION_SIZE>(
+      exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
+      block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
+      q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale, v_zp);
 }
 
 // Grid: (num_heads, num_seqs).
-template<
-  typename scalar_t,
-  int HEAD_SIZE,
-  int NUM_THREADS,
-  int PARTITION_SIZE>
+template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
+          int PARTITION_SIZE>
 __global__ void paged_attention_v2_reduce_kernel(
-  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
-  const float* __restrict__ exp_sums,     // [num_seqs, num_heads, max_num_partitions]
-  const float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
-  const scalar_t* __restrict__ tmp_out,   // [num_seqs, num_heads, max_num_partitions, head_size]
-  const int* __restrict__ context_lens,   // [num_seqs]
-  const int max_num_partitions) {
+    scalar_t* __restrict__ out,            // [num_seqs, num_heads, head_size]
+    const float* __restrict__ exp_sums,    // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const float* __restrict__ max_logits,  // [num_seqs, num_heads,
+                                           // max_num_partitions]
+    const scalar_t* __restrict__ tmp_out,  // [num_seqs, num_heads,
+                                           // max_num_partitions, head_size]
+    const int* __restrict__ context_lens,  // [num_seqs]
+    const int max_num_partitions) {
   const int num_heads = gridDim.x;
   const int head_idx = blockIdx.x;
   const int seq_idx = blockIdx.y;
@@ -526,9 +544,11 @@ __global__ void paged_attention_v2_reduce_kernel(
   const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
   if (num_partitions == 1) {
     // No need to reduce. Only copy tmp_out to out.
-    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
-    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                                          + head_idx * max_num_partitions * HEAD_SIZE;
+    scalar_t* out_ptr =
+        out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+    const scalar_t* tmp_out_ptr =
+        tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+        head_idx * max_num_partitions * HEAD_SIZE;
     for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
       out_ptr[i] = tmp_out_ptr[i];
     }
@@ -547,8 +567,9 @@ __global__ void paged_attention_v2_reduce_kernel(
 
   // Load max logits to shared memory.
   float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
-  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
-                                           + head_idx * max_num_partitions;
+  const float* max_logits_ptr = max_logits +
+                                seq_idx * num_heads * max_num_partitions +
+                                head_idx * max_num_partitions;
   float max_logit = -FLT_MAX;
   for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
     const float l = max_logits_ptr[i];
@@ -577,9 +598,11 @@ __global__ void paged_attention_v2_reduce_kernel(
   max_logit = APHRODITE_SHFL_SYNC(max_logit, 0);
 
   // Load rescaled exp sums to shared memory.
-  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
-  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
-                                       + head_idx * max_num_partitions;
+  float* shared_exp_sums =
+      reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+  const float* exp_sums_ptr = exp_sums +
+                              seq_idx * num_heads * max_num_partitions +
+                              head_idx * max_num_partitions;
   float global_exp_sum = 0.0f;
   for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
     float l = shared_max_logits[i];
@@ -592,67 +615,47 @@ __global__ void paged_attention_v2_reduce_kernel(
   const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
 
   // Aggregate tmp_out to out.
-  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
-                                        + head_idx * max_num_partitions * HEAD_SIZE;
-  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+  const scalar_t* tmp_out_ptr =
+      tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
+      head_idx * max_num_partitions * HEAD_SIZE;
+  scalar_t* out_ptr =
+      out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
 #pragma unroll
   for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
     float acc = 0.0f;
     for (int j = 0; j < num_partitions; ++j) {
-      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
+      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
+             inv_global_exp_sum;
     }
     from_float(out_ptr[i], acc);
   }
 }
 
-} // namespace aphrodite
-
-#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                        \
-  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                        \
-    ((void*)aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,    \
-      KV_CACHE_DTYPE>), shared_mem_size);                                                           \
-  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,              \
-  KV_CACHE_DTYPE><<<grid, block, shared_mem_size, stream>>>(                                        \
-    out_ptr,                                                                                        \
-    query_ptr,                                                                                      \
-    key_cache_ptr,                                                                                  \
-    value_cache_ptr,                                                                                \
-    num_kv_heads,                                                                                   \
-    scale,                                                                                          \
-    block_tables_ptr,                                                                               \
-    context_lens_ptr,                                                                               \
-    max_num_blocks_per_seq,                                                                         \
-    alibi_slopes_ptr,                                                                               \
-    q_stride,                                                                                       \
-    kv_block_stride,                                                                                \
-    kv_head_stride,                                                                                 \
-    k_scale,                                                                                        \
-    k_zp,                                                                                           \
-    v_scale,                                                                                        \
-    v_zp);
+}  // namespace aphrodite
+
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                 \
+  APHRODITE_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                 \
+      ((void*)aphrodite::paged_attention_v1_kernel<                          \
+          T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_CACHE_DTYPE>),  \
+      shared_mem_size);                                                      \
+  aphrodite::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,    \
+                                       NUM_THREADS, KV_CACHE_DTYPE>          \
+      <<<grid, block, shared_mem_size, stream>>>(                            \
+          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads,  \
+          scale, block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
+          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,       \
+          k_scale, k_zp, v_scale, v_zp);
 
 // TODO: Tune NUM_THREADS.
-template<
-  typename T,
-  typename CACHE_T,
-  int BLOCK_SIZE,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int NUM_THREADS = 128>
+template <typename T, typename CACHE_T, int BLOCK_SIZE,
+          kv_cache_dtype KV_CACHE_DTYPE, int NUM_THREADS = 128>
 void paged_attention_v1_launcher(
-  torch::Tensor& out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  int num_kv_heads,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
+    torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& context_lens,
+    int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const float k_scale, const float k_zp, const float v_scale,
+    const float v_zp) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -665,9 +668,10 @@ void paged_attention_v1_launcher(
   assert(head_size % thread_group_size == 0);
 
   // NOTE: alibi_slopes is optional.
-  const float* alibi_slopes_ptr = alibi_slopes ?
-    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-    : nullptr;
+  const float* alibi_slopes_ptr =
+      alibi_slopes
+          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+          : nullptr;
 
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
@@ -677,11 +681,13 @@ void paged_attention_v1_launcher(
   int* context_lens_ptr = context_lens.data_ptr<int>();
 
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
+  int padded_max_context_len =
+      DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE;
   int logits_size = padded_max_context_len * sizeof(float);
   int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
-  // Python-side check in aphrodite.task_handler.worker._check_if_can_support_max_seq_len
-  // Keep that in sync with the logic here!
+  // Python-side check in
+  // aphrodite.worker.worker._check_if_can_support_max_seq_len Keep that in sync
+  // with the logic here!
   int shared_mem_size = std::max(logits_size, outputs_size);
 
   dim3 grid(num_heads, num_seqs, 1);
@@ -718,56 +724,44 @@ void paged_attention_v1_launcher(
 
 #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)             \
   paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(       \
-    out,                                                                     \
-    query,                                                                   \
-    key_cache,                                                               \
-    value_cache,                                                             \
-    num_kv_heads,                                                            \
-    scale,                                                                   \
-    block_tables,                                                            \
-    context_lens,                                                            \
-    max_context_len,                                                         \
-    alibi_slopes,                                                            \
-    k_scale,                                                                 \
-    k_zp,                                                                    \
-    v_scale,                                                                 \
-    v_zp);
+      out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
+      context_lens, max_context_len, alibi_slopes, k_scale, k_zp, v_scale,   \
+      v_zp);
 
 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)       \
-  switch (block_size) {                                               \
-    case 8:                                                           \
-      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);                \
-      break;                                                          \
-    case 16:                                                          \
-      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);               \
-      break;                                                          \
-    case 32:                                                          \
-      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);               \
-      break;                                                          \
-    default:                                                          \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
-      break;                                                          \
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)   \
+  switch (block_size) {                                           \
+    case 8:                                                       \
+      CALL_V1_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);            \
+      break;                                                      \
+    case 16:                                                      \
+      CALL_V1_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);           \
+      break;                                                      \
+    case 32:                                                      \
+      CALL_V1_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);           \
+      break;                                                      \
+    default:                                                      \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
+      break;                                                      \
   }
 
 void paged_attention_v1(
-  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
-  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
-  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  int num_kv_heads,               // [num_heads]
-  float scale,
-  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
-  torch::Tensor& context_lens,    // [num_seqs]
-  int block_size,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const std::string& kv_cache_dtype,
-  const float k_scale = 1.0f,
-  const float k_zp = 0.0f,
-  const float v_scale = 1.0f,
-  const float v_zp = 0.0f) {
+    torch::Tensor& out,    // [num_seqs, num_heads, head_size]
+    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
+    torch::Tensor&
+        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
+    torch::Tensor&
+        value_cache,   // [num_blocks, num_heads, head_size, block_size]
+    int num_kv_heads,  // [num_heads]
+    float scale,
+    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    torch::Tensor& context_lens,  // [num_seqs]
+    int block_size, int max_context_len,
+    const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, const float k_scale = 1.0f,
+    const float k_zp = 0.0f, const float v_scale = 1.0f,
+    const float v_zp = 0.0f) {
   if (kv_cache_dtype == "auto") {
     if (query.dtype() == at::ScalarType::Float) {
       CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, AUTO);
@@ -805,63 +799,33 @@ void paged_attention_v1(
   }
 }
 
-#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \
-  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,        \
-  KV_CACHE_DTYPE, PARTITION_SIZE>                                                             \
-  <<<grid, block, shared_mem_size, stream>>>(                                                 \
-    exp_sums_ptr,                                                                             \
-    max_logits_ptr,                                                                           \
-    tmp_out_ptr,                                                                              \
-    query_ptr,                                                                                \
-    key_cache_ptr,                                                                            \
-    value_cache_ptr,                                                                          \
-    num_kv_heads,                                                                             \
-    scale,                                                                                    \
-    block_tables_ptr,                                                                         \
-    context_lens_ptr,                                                                         \
-    max_num_blocks_per_seq,                                                                   \
-    alibi_slopes_ptr,                                                                         \
-    q_stride,                                                                                 \
-    kv_block_stride,                                                                          \
-    kv_head_stride,                                                                           \
-    k_scale,                                                                                  \
-    k_zp,                                                                                     \
-    v_scale,                                                                                  \
-    v_zp);                                                                                    \
-  aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>           \
-  <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                                   \
-    out_ptr,                                                                                  \
-    exp_sums_ptr,                                                                             \
-    max_logits_ptr,                                                                           \
-    tmp_out_ptr,                                                                              \
-    context_lens_ptr,                                                                         \
-    max_num_partitions);
-
-template<
-  typename T,
-  typename CACHE_T,
-  int BLOCK_SIZE,
-  kv_cache_dtype KV_CACHE_DTYPE,
-  int NUM_THREADS = 128,
-  int PARTITION_SIZE = 512>
+#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                   \
+  aphrodite::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,      \
+                                       NUM_THREADS, KV_CACHE_DTYPE,            \
+                                       PARTITION_SIZE>                         \
+      <<<grid, block, shared_mem_size, stream>>>(                              \
+          exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
+          value_cache_ptr, num_kv_heads, scale, block_tables_ptr,              \
+          context_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr,          \
+          q_stride, kv_block_stride, kv_head_stride, k_scale, k_zp, v_scale,   \
+          v_zp);                                                               \
+  aphrodite::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,       \
+                                              PARTITION_SIZE>                  \
+      <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                \
+          out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,                  \
+          context_lens_ptr, max_num_partitions);
+
+template <typename T, typename CACHE_T, int BLOCK_SIZE,
+          kv_cache_dtype KV_CACHE_DTYPE, int NUM_THREADS = 128,
+          int PARTITION_SIZE = 512>
 void paged_attention_v2_launcher(
-  torch::Tensor& out,
-  torch::Tensor& exp_sums,
-  torch::Tensor& max_logits,
-  torch::Tensor& tmp_out,
-  torch::Tensor& query,
-  torch::Tensor& key_cache,
-  torch::Tensor& value_cache,
-  int num_kv_heads,
-  float scale,
-  torch::Tensor& block_tables,
-  torch::Tensor& context_lens,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const float k_scale,
-  const float k_zp,
-  const float v_scale,
-  const float v_zp) {
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int num_kv_heads, float scale,
+    torch::Tensor& block_tables, torch::Tensor& context_lens,
+    int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
+    const float k_scale, const float k_zp, const float v_scale,
+    const float v_zp) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -874,9 +838,10 @@ void paged_attention_v2_launcher(
   assert(head_size % thread_group_size == 0);
 
   // NOTE: alibi_slopes is optional.
-  const float* alibi_slopes_ptr = alibi_slopes ?
-    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
-    : nullptr;
+  const float* alibi_slopes_ptr =
+      alibi_slopes
+          ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+          : nullptr;
 
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
@@ -931,64 +896,50 @@ void paged_attention_v2_launcher(
   }
 }
 
-#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)                 \
-  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(           \
-    out,                                                                         \
-    exp_sums,                                                                    \
-    max_logits,                                                                  \
-    tmp_out,                                                                     \
-    query,                                                                       \
-    key_cache,                                                                   \
-    value_cache,                                                                 \
-    num_kv_heads,                                                                \
-    scale,                                                                       \
-    block_tables,                                                                \
-    context_lens,                                                                \
-    max_context_len,                                                             \
-    alibi_slopes,                                                                \
-    k_scale,                                                                     \
-    k_zp,                                                                        \
-    v_scale,                                                                     \
-    v_zp);
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE)         \
+  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, KV_CACHE_DTYPE>(   \
+      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
+      num_kv_heads, scale, block_tables, context_lens, max_context_len,  \
+      alibi_slopes, k_scale, k_zp, v_scale, v_zp);
 
 // NOTE: To reduce the compilation time, we omitted block sizes
 // 1, 2, 4, 64, 128, 256.
-#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)             \
-  switch (block_size) {                                                     \
-    case 8:                                                                 \
-      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);                      \
-      break;                                                                \
-    case 16:                                                                \
-      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);                     \
-      break;                                                                \
-    case 32:                                                                \
-      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);                     \
-      break;                                                                \
-    default:                                                                \
-      TORCH_CHECK(false, "Unsupported block size: ", block_size);           \
-      break;                                                                \
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_CACHE_DTYPE)   \
+  switch (block_size) {                                           \
+    case 8:                                                       \
+      CALL_V2_LAUNCHER(T, CACHE_T, 8, KV_CACHE_DTYPE);            \
+      break;                                                      \
+    case 16:                                                      \
+      CALL_V2_LAUNCHER(T, CACHE_T, 16, KV_CACHE_DTYPE);           \
+      break;                                                      \
+    case 32:                                                      \
+      CALL_V2_LAUNCHER(T, CACHE_T, 32, KV_CACHE_DTYPE);           \
+      break;                                                      \
+    default:                                                      \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size); \
+      break;                                                      \
   }
 
 void paged_attention_v2(
-  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
-  torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
-  torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
-  torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
-  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
-  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
-  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
-  int num_kv_heads,               // [num_heads]
-  float scale,
-  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
-  torch::Tensor& context_lens,    // [num_seqs]
-  int block_size,
-  int max_context_len,
-  const c10::optional<torch::Tensor>& alibi_slopes,
-  const std::string& kv_cache_dtype,
-  const float k_scale = 1.0f,
-  const float k_zp = 0.0f,
-  const float v_scale = 1.0f,
-  const float v_zp = 0.0f) {
+    torch::Tensor& out,         // [num_seqs, num_heads, head_size]
+    torch::Tensor& exp_sums,    // [num_seqs, num_heads, max_num_partitions]
+    torch::Tensor& max_logits,  // [num_seqs, num_heads, max_num_partitions]
+    torch::Tensor&
+        tmp_out,  // [num_seqs, num_heads, max_num_partitions, head_size]
+    torch::Tensor& query,  // [num_seqs, num_heads, head_size]
+    torch::Tensor&
+        key_cache,  // [num_blocks, num_heads, head_size/x, block_size, x]
+    torch::Tensor&
+        value_cache,   // [num_blocks, num_heads, head_size, block_size]
+    int num_kv_heads,  // [num_heads]
+    float scale,
+    torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
+    torch::Tensor& context_lens,  // [num_seqs]
+    int block_size, int max_context_len,
+    const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, const float k_scale = 1.0f,
+    const float k_zp = 0.0f, const float v_scale = 1.0f,
+    const float v_zp = 0.0f) {
   if (kv_cache_dtype == "auto") {
     if (query.dtype() == at::ScalarType::Float) {
       CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, AUTO);

+ 1 - 1
tests/engine/test_multiproc_workers.py

@@ -12,7 +12,7 @@ from aphrodite.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
 
 
 class DummyWorker:
-    """Dummy version of aphrodite.task_handler.worker.Worker"""
+    """Dummy version of aphrodite.worker.worker.Worker"""
 
     def __init__(self, rank: int):
         self.rank = rank

+ 1 - 1
tests/lora/conftest.py

@@ -256,7 +256,7 @@ def llama_2_7b_engine_extra_embeddings():
                              device_config=device_config,
                              **kwargs)
 
-    with patch("aphrodite.task_handler.model_runner.get_model",
+    with patch("aphrodite.worker.model_runner.get_model",
                get_model_patched):
         engine = aphrodite.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
     yield engine.llm_engine

+ 1 - 1
tests/lora/test_long_context.py

@@ -123,7 +123,7 @@ def lora_llm(long_context_infos):
 def test_rotary_emb_replaced(dist_init):
     """Verify rotary emb in all the layers are replaced"""
     from aphrodite.engine.args_tools import EngineArgs
-    from aphrodite.task_handler.model_runner import ModelRunner
+    from aphrodite.worker.model_runner import ModelRunner
     engine_args = EngineArgs("meta-llama/Llama-2-7b-hf",
                              long_lora_scaling_factors=(4.0, ),
                              enable_lora=True)

+ 1 - 1
tests/lora/test_worker.py

@@ -8,7 +8,7 @@ from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig,
                                      SchedulerConfig)
 from aphrodite.lora.models import LoRAMapping
 from aphrodite.lora.request import LoRARequest
-from aphrodite.task_handler.worker import Worker
+from aphrodite.worker.worker import Worker
 
 
 @patch.dict(os.environ, {"RANK": "0"})

+ 1 - 1
tests/models/test_jamba.py

@@ -1,6 +1,6 @@
 import pytest
 
-from aphrodite.task_handler.model_runner import _get_graph_batch_size
+from aphrodite.worker.model_runner import _get_graph_batch_size
 from tests.models.utils import check_outputs_equal
 
 MODELS = ["ai21labs/Jamba-tiny-random"]

+ 1 - 1
tests/spec_decode/test_multi_step_worker.py

@@ -12,7 +12,7 @@ from aphrodite.modeling.utils import set_random_seed
 from aphrodite.spec_decode.draft_model_runner import TP1DraftModelRunner
 from aphrodite.spec_decode.multi_step_worker import MultiStepWorker
 from aphrodite.spec_decode.top1_proposer import Top1Proposer
-from aphrodite.task_handler.worker import Worker
+from aphrodite.worker.worker import Worker
 
 from .utils import (assert_logprobs_dict_allclose, create_batch,
                     create_seq_group_metadata_from_prompts, create_worker,

+ 3 - 3
tests/spec_decode/utils.py

@@ -17,9 +17,9 @@ from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.modeling.layers.sampler import SamplerOutput
 from aphrodite.modeling.utils import set_random_seed
-from aphrodite.task_handler.cache_engine import CacheEngine
-from aphrodite.task_handler.model_runner import ModelRunner
-from aphrodite.task_handler.worker import Worker
+from aphrodite.worker.cache_engine import CacheEngine
+from aphrodite.worker.model_runner import ModelRunner
+from aphrodite.worker.worker import Worker
 
 T = TypeVar("T", bound=Worker)
 

+ 1 - 2
tests/worker/test_encoder_decoder_model_runner.py

@@ -9,8 +9,7 @@ from aphrodite.common.sequence import (SamplingParams, SequenceData,
 from aphrodite.common.utils import is_cpu
 from aphrodite.constants import APHRODITE_TOKEN_ID_ARRAY_TYPE
 from aphrodite.engine.args_tools import EngineArgs
-from aphrodite.task_handler.enc_dec_model_runner import (
-    EncoderDecoderModelRunner)
+from aphrodite.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 
 # CUDA graph scenarios to test
 #

+ 3 - 4
tests/worker/test_model_input.py

@@ -7,11 +7,10 @@ from aphrodite.attention import AttentionMetadata, AttentionMetadataBuilder
 from aphrodite.attention.backends.abstract import AttentionBackend
 from aphrodite.modeling.pooling_metadata import PoolingMetadata
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.task_handler.embedding_model_runner import (
+from aphrodite.worker.embedding_model_runner import (
     ModelInputForGPUWithPoolingMetadata)
-from aphrodite.task_handler.model_runner import (
-    ModelInputForGPUWithSamplingMetadata)
-from aphrodite.task_handler.multi_step_model_runner import StatefulModelInput
+from aphrodite.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+from aphrodite.worker.multi_step_model_runner import StatefulModelInput
 
 
 class MockAttentionBackend(AttentionBackend):

+ 1 - 2
tests/worker/test_model_runner.py

@@ -12,8 +12,7 @@ from aphrodite.distributed.parallel_state import (
     ensure_model_parallel_initialized, init_distributed_environment)
 from aphrodite.engine.args_tools import EngineArgs
 from aphrodite.modeling.sampling_metadata import SamplingMetadata
-from aphrodite.task_handler.model_runner import (ModelRunner,
-                                                 _get_graph_batch_size)
+from aphrodite.worker.model_runner import ModelRunner, _get_graph_batch_size
 
 
 def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:

+ 1 - 1
tests/worker/test_swap.py

@@ -4,7 +4,7 @@ from aphrodite.common.sequence import ExecuteModelRequest
 from aphrodite.common.utils import (get_distributed_init_method, get_ip,
                                     get_open_port)
 from aphrodite.engine.args_tools import EngineArgs
-from aphrodite.task_handler.worker import Worker
+from aphrodite.worker.worker import Worker
 
 
 def test_swap() -> None: